Skip to content

Commit

Permalink
Try #7518:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Feb 6, 2023
2 parents e1b0bbf + 47abe67 commit cf13f56
Showing 1 changed file with 159 additions and 12 deletions.
171 changes: 159 additions & 12 deletions crates/bevy_render/src/render_resource/shader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,11 @@ pub enum ProcessShaderError {
expected: String,
value: String,
},
#[error("Invalid shader def definition for '{shader_def_name}': {value}")]
InvalidShaderDefDefinitionValue {
shader_def_name: String,
value: String,
},
}

pub struct ShaderImportProcessor {
Expand Down Expand Up @@ -388,6 +393,7 @@ pub struct ShaderProcessor {
else_ifdef_regex: Regex,
else_regex: Regex,
endif_regex: Regex,
define_regex: Regex,
def_regex: Regex,
def_regex_delimited: Regex,
}
Expand All @@ -397,10 +403,11 @@ impl Default for ShaderProcessor {
Self {
ifdef_regex: Regex::new(r"^\s*#\s*ifdef\s*([\w|\d|_]+)").unwrap(),
ifndef_regex: Regex::new(r"^\s*#\s*ifndef\s*([\w|\d|_]+)").unwrap(),
ifop_regex: Regex::new(r"^\s*#\s*if\s*([\w|\d|_]+)\s*([^\s]*)\s*([\w|\d]+)").unwrap(),
ifop_regex: Regex::new(r"^\s*#\s*if\s*([\w|\d|_]+)\s*([^\s]*)\s*([-\w|\d]+)").unwrap(),
else_ifdef_regex: Regex::new(r"^\s*#\s*else\s+ifdef\s*([\w|\d|_]+)").unwrap(),
else_regex: Regex::new(r"^\s*#\s*else").unwrap(),
endif_regex: Regex::new(r"^\s*#\s*endif").unwrap(),
define_regex: Regex::new(r"^\s*#\s*define\s*([\w|\d|_]+)\s*([-\w|\d]+)?").unwrap(),
def_regex: Regex::new(r"#\s*([\w|\d|_]+)").unwrap(),
def_regex_delimited: Regex::new(r"#\s*\{([\w|\d|_]+)\}").unwrap(),
}
Expand Down Expand Up @@ -449,24 +456,34 @@ impl ShaderProcessor {
shader_defs: &[ShaderDefVal],
shaders: &HashMap<Handle<Shader>, Shader>,
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
) -> Result<ProcessedShader, ProcessShaderError> {
let mut shader_defs_unique =
HashMap::<String, ShaderDefVal>::from_iter(shader_defs.iter().map(|v| match v {
ShaderDefVal::Bool(k, _) | ShaderDefVal::Int(k, _) | ShaderDefVal::UInt(k, _) => {
(k.clone(), v.clone())
}
}));
self.process_inner(shader, &mut shader_defs_unique, shaders, import_handles)
}

fn process_inner(
&self,
shader: &Shader,
shader_defs_unique: &mut HashMap<String, ShaderDefVal>,
shaders: &HashMap<Handle<Shader>, Shader>,
import_handles: &HashMap<ShaderImport, Handle<Shader>>,
) -> Result<ProcessedShader, ProcessShaderError> {
let shader_str = match &shader.source {
Source::Wgsl(source) => source.deref(),
Source::Glsl(source, _stage) => source.deref(),
Source::SpirV(source) => {
if shader_defs.is_empty() {
if shader_defs_unique.is_empty() {
return Ok(ProcessedShader::SpirV(source.clone()));
}
return Err(ProcessShaderError::ShaderFormatDoesNotSupportShaderDefs);
}
};

let shader_defs_unique =
HashMap::<String, ShaderDefVal>::from_iter(shader_defs.iter().map(|v| match v {
ShaderDefVal::Bool(k, _) | ShaderDefVal::Int(k, _) | ShaderDefVal::UInt(k, _) => {
(k.clone(), v.clone())
}
}));
let mut scopes = vec![Scope::new(true)];
let mut final_string = String::new();
for line in shader_str.lines() {
Expand Down Expand Up @@ -544,6 +561,26 @@ impl ShaderProcessor {
let current_valid = scopes.last().unwrap().is_accepting_lines();

scopes.push(Scope::new(current_valid && new_scope));
} else if let Some(cap) = self.define_regex.captures(line) {
let def = cap.get(1).unwrap();
let name = def.as_str().to_string();

if let Some(val) = cap.get(2) {
if let Ok(val) = val.as_str().parse() {
shader_defs_unique.insert(name.clone(), ShaderDefVal::UInt(name, val));
} else if let Ok(val) = val.as_str().parse() {
shader_defs_unique.insert(name.clone(), ShaderDefVal::Int(name, val));
} else if let Ok(val) = val.as_str().parse() {
shader_defs_unique.insert(name.clone(), ShaderDefVal::Bool(name, val));
} else {
return Err(ProcessShaderError::InvalidShaderDefDefinitionValue {
shader_def_name: name,
value: val.as_str().to_string(),
});
}
} else {
shader_defs_unique.insert(name.clone(), ShaderDefVal::Bool(name, true));
}
} else if let Some(cap) = self.else_ifdef_regex.captures(line) {
// When should we accept the code in an
//
Expand Down Expand Up @@ -627,7 +664,7 @@ impl ShaderProcessor {
shaders,
&import,
shader,
shader_defs,
shader_defs_unique,
&mut final_string,
)?;
} else if let Some(cap) = SHADER_IMPORT_PROCESSOR
Expand All @@ -640,7 +677,7 @@ impl ShaderProcessor {
shaders,
&import,
shader,
shader_defs,
shader_defs_unique,
&mut final_string,
)?;
} else if SHADER_IMPORT_PROCESSOR
Expand Down Expand Up @@ -695,15 +732,15 @@ impl ShaderProcessor {
shaders: &HashMap<Handle<Shader>, Shader>,
import: &ShaderImport,
shader: &Shader,
shader_defs: &[ShaderDefVal],
shader_defs_unique: &mut HashMap<String, ShaderDefVal>,
final_string: &mut String,
) -> Result<(), ProcessShaderError> {
let imported_shader = import_handles
.get(import)
.and_then(|handle| shaders.get(handle))
.ok_or_else(|| ProcessShaderError::UnresolvedImport(import.clone()))?;
let imported_processed =
self.process(imported_shader, shader_defs, shaders, import_handles)?;
self.process_inner(imported_shader, shader_defs_unique, shaders, import_handles)?;

match &shader.source {
Source::Wgsl(_) => {
Expand Down Expand Up @@ -2441,4 +2478,114 @@ fn vertex(
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED_REPLACED);
}

#[test]
fn process_shader_define_in_shader() {
#[rustfmt::skip]
const WGSL: &str = r"
#ifdef NOW_DEFINED
defined at start
#endif
#define NOW_DEFINED
#ifdef NOW_DEFINED
defined at end
#endif
";

#[rustfmt::skip]
const EXPECTED: &str = r"
defined at end
";
let processor = ShaderProcessor::default();
let result = processor
.process(
&Shader::from_wgsl(WGSL),
&[],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}

#[test]
fn process_shader_define_in_shader_with_value() {
#[rustfmt::skip]
const WGSL: &str = r"
#define DEFUINT 1
#define DEFINT -1
#define DEFBOOL false
#if DEFUINT == 1
uint: #DEFUINT
#endif
#if DEFINT == -1
int: #DEFINT
#endif
#if DEFBOOL == false
bool: #DEFBOOL
#endif
";

#[rustfmt::skip]
const EXPECTED: &str = r"
uint: 1
int: -1
bool: false
";
let processor = ShaderProcessor::default();
let result = processor
.process(
&Shader::from_wgsl(WGSL),
&[],
&HashMap::default(),
&HashMap::default(),
)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}

#[test]
fn process_shader_define_across_imports() {
#[rustfmt::skip]
const FOO: &str = r"
#define IMPORTED
";
const BAR: &str = r"
#IMPORTED
";
#[rustfmt::skip]
const INPUT: &str = r"
#import FOO
#import BAR
";
#[rustfmt::skip]
const EXPECTED: &str = r"
true
";
let processor = ShaderProcessor::default();
let mut shaders = HashMap::default();
let mut import_handles = HashMap::default();
{
let foo_handle = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 0).typed();
shaders.insert(foo_handle.clone_weak(), Shader::from_wgsl(FOO));
import_handles.insert(
ShaderImport::Custom("FOO".to_string()),
foo_handle.clone_weak(),
);
}
{
let bar_handle = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 1).typed();
shaders.insert(bar_handle.clone_weak(), Shader::from_wgsl(BAR));
import_handles.insert(
ShaderImport::Custom("BAR".to_string()),
bar_handle.clone_weak(),
);
}
let result = processor
.process(&Shader::from_wgsl(INPUT), &[], &shaders, &import_handles)
.unwrap();
assert_eq!(result.get_wgsl_source().unwrap(), EXPECTED);
}
}

0 comments on commit cf13f56

Please sign in to comment.