refactor to make raytracing shaders a stage of lib, with an array of shaders and a hit group block to define which entry points to use
polymonster authored and GBDixonAlex committed Jan 12, 2025
1 parent 188270c commit 0684737
128 changes: 76 additions & 52 deletions pmfx_pipeline.py
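
For orientation, the change means a raytracing pipeline now declares a single "lib" stage whose value is an array of shader entry points, instead of the individual rg/ch/ah/mi stages removed from get_shader_stages() below. A minimal sketch of such a pipeline as the Python dict pmfx_pipeline.py would see after parsing; the list-valued "lib" stage is what this diff handles, while the hit group block mentioned in the commit message is not shown in the change, so its keys here are purely illustrative:

# hypothetical parsed pipeline description; only the list-valued "lib" stage
# is confirmed by this diff, the hit group block's shape is an assumption
raytracing_pipeline = {
    "lib": ["ray_gen", "closest_hit", "any_hit", "miss"],
    # the commit message mentions a hit group block that selects entry points;
    # its exact keys are not visible in this change
    "hit_groups": [
        {"name": "hit_group_0", "closest_hit": "closest_hit", "any_hit": "any_hit"}
    ],
}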
@@ -42,18 +42,19 @@ def pmfx_hash(src):
return zlib.adler32(bytes(str(src).encode("utf8")))


# combine two 32-bit hashes
def pmfx_hash_combine(h1: int, h2: int) -> int:
combined_data = h1.to_bytes(4, 'little') + h2.to_bytes(4, 'little')
return zlib.adler32(combined_data)


# return names of supported shader stages
def get_shader_stages():
return [
"vs",
"ps",
"cs",
"rg",
"ch",
"ah",
"mi",
"is",
"ca"
"lib"
]


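As a quick, self-contained sketch of how the new helper can fold several per-shader hashes into one value (mirroring how generate_pipeline_permutation accumulates "lib_hash" further down), using the two functions defined above and made-up source strings:

import zlib

def pmfx_hash(src):
    # adler32 of the object's string representation, as defined above
    return zlib.adler32(bytes(str(src).encode("utf8")))

def pmfx_hash_combine(h1: int, h2: int) -> int:
    # concatenate two 32-bit hashes little-endian and hash the result
    combined_data = h1.to_bytes(4, 'little') + h2.to_bytes(4, 'little')
    return zlib.adler32(combined_data)

# fold the source hashes of several lib entry points into a single 32-bit value
lib_hash = 0
for src_hash in ["ray_gen_src", "closest_hit_src", "miss_src"]:  # hypothetical inputs
    lib_hash = pmfx_hash_combine(lib_hash, pmfx_hash(src_hash))
print(hex(lib_hash))
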
@@ -773,13 +774,6 @@ def cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, outpu
return 0, error_list, output_list


# convert satage to correct hlsl profile
def hlsl_stage(stage):
if stage in ["rg", "ch", "ah", "mi"]:
return "lib"
return stage


# compile a hlsl version 2
def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_filepath):
exe = os.path.join(info.tools_dir, "bin", "dxc", "dxc")
@@ -792,7 +786,7 @@ def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_fil
if info.shader_platform == "metal":
error_code, error_list, output_list = cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, output_filepath)
elif info.shader_platform == "hlsl":
cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, hlsl_stage(stage), info.shader_version, entry_point, output_filepath, temp_filepath)
cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, stage, info.shader_version, entry_point, output_filepath, temp_filepath)
cmdline += " " + build_pmfx.get_info().args
error_code, error_list, output_list = build_pmfx.call_wait_subprocess(cmdline)

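With hlsl_stage() removed, the stage name ("lib" for raytracing) is passed straight through as the dxc target profile. A small sketch of the command string the line above builds, using made-up paths, entry point and shader version; this only illustrates the string formatting, not a verified dxc invocation:

# hypothetical inputs for illustration only
exe = "dxc"
stage = "lib"
shader_version = "6_3"
entry_point = "ray_gen"
output_filepath = "build/ray_gen.bin"
temp_filepath = "temp/ray_gen.hlsl"
cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(
    exe, stage, shader_version, entry_point, output_filepath, temp_filepath)
print(cmdline)  # dxc -T lib_6_3 -E ray_gen -Fo build/ray_gen.bin temp/ray_gen.hlsl
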
@@ -1066,42 +1060,63 @@ def generate_pipeline_permutation(pipeline_name, pipeline, output_pmfx, shaders,
print(" pipeline: {} {}".format(pipeline_name, permutation_name))
resources = dict()
output_pipeline = dict(pipeline)
# lookup info from compiled shaders and combine resources

# gather entry points
entry_points = list()
for stage in get_shader_stages():
if stage in pipeline:
entry_point = pipeline[stage]
if entry_point not in shaders[stage]:
output_pipeline["error_code"] = 1
continue
# lookup shader info, and redirect to shared shaders
shader_info = shaders[stage][entry_point][pemutation_id]
if "lookup" in shader_info:
lookup = shader_info["lookup"]
shader_info = dict(shaders[stage][lookup[0]][lookup[1]])
if type(pipeline[stage]) is list:
for entry_point in pipeline[stage]:
entry_points.append((stage, entry_point, True))
else:
entry_points.append((stage, pipeline[stage], False))

# clear lib
if "lib" in output_pipeline:
output_pipeline["lib_hash"] = 0
output_pipeline["lib"].clear()

# lookup info from compiled shaders and combine resources
for (stage, entry_point, lib) in entry_points:
# check entry exists
if entry_point not in shaders[stage]:
output_pipeline["error_code"] = 1
continue
# lookup shader info, and redirect to shared shaders
shader_info = shaders[stage][entry_point][pemutation_id]
if "lookup" in shader_info:
lookup = shader_info["lookup"]
shader_info = dict(shaders[stage][lookup[0]][lookup[1]])

if lib:
output_pipeline[stage].append(shader_info["filename"])
output_pipeline["lib_hash"] = pmfx_hash_combine(output_pipeline["lib_hash"], pmfx_hash(shader_info["src_hash"]))
else:
output_pipeline[stage] = shader_info["filename"]
output_pipeline["{}_hash:".format(stage)] = pmfx_hash(shader_info["src_hash"])
shader = shader_info
resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
# generate vertex layout
if stage == "vs":
pmfx_vertex_layout = dict()
if "vertex_layout" in pipeline:
pmfx_vertex_layout = pipeline["vertex_layout"]
output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
# extract numthreads
if stage == "cs":
for attrib in shader["attributes"]:
if attrib.find("numthreads") != -1:
start, end = cgu.enclose_start_end("(", ")", attrib, 0)
xyz = attrib[start:end].split(",")
numthreads = []
for a in xyz:
numthreads.append(int(a.strip()))
output_pipeline["numthreads"] = numthreads

# set non zero error codes to track failures
if shader_info["error_code"] != 0:
output_pipeline["error_code"] = shader_info["error_code"]
output_pipeline["{}_hash".format(stage)] = pmfx_hash(shader_info["src_hash"])

shader = shader_info
resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
# generate vertex layout
if stage == "vs":
pmfx_vertex_layout = dict()
if "vertex_layout" in pipeline:
pmfx_vertex_layout = pipeline["vertex_layout"]
output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
# extract numthreads
if stage == "cs":
for attrib in shader["attributes"]:
if attrib.find("numthreads") != -1:
start, end = cgu.enclose_start_end("(", ")", attrib, 0)
xyz = attrib[start:end].split(",")
numthreads = []
for a in xyz:
numthreads.append(int(a.strip()))
output_pipeline["numthreads"] = numthreads

# set non zero error codes to track failures
if shader_info["error_code"] != 0:
output_pipeline["error_code"] = shader_info["error_code"]

# build pipeline layout
output_pipeline["pipeline_layout"] = generate_pipeline_layout(output_pmfx, pipeline, resources)
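
As a sketch of the data the loop above produces for a raytracing pipeline (filenames and hash values are made up), the "lib" stage collects a list of compiled shader filenames plus one combined lib_hash, while non-lib stages keep a single filename and a per-stage hash:

# hypothetical output_pipeline contents after the entry point loop, for illustration
output_pipeline = {
    "lib": ["ray_gen.o", "closest_hit.o", "miss.o"],  # one compiled file appended per entry point
    "lib_hash": 0x1f2e3d4c,  # pmfx_hash_combine folded over each shader's src_hash
    "error_code": 0,
}
# for comparison, a non-lib stage would instead set e.g.
# output_pipeline["cs"] = "compute_main.o" and output_pipeline["cs_hash"] = 0x0badf00d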
@@ -1340,9 +1355,13 @@ def generate_pmfx(file, root):
pipeline = pipelines[pipeline_key]
for stage in get_shader_stages():
if stage in pipeline:
stage_shader = (stage, pipeline[stage])
if stage_shader not in shader_list:
shader_list.append(stage_shader)
if type(pipeline[stage]) is list:
for shader in pipeline[stage]:
stage_shader = (stage, shader)
else:
stage_shader = (stage, pipeline[stage])
if stage_shader not in shader_list:
shader_list.append(stage_shader)

# gather permutations
permutation_jobs = []
@@ -1357,8 +1376,13 @@
pipeline_jobs.append((pipeline_key, id))
for stage in get_shader_stages():
if stage in pipeline:
permutation_jobs.append(
pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))
if type(pipeline[stage]) is list:
for shader in pipeline[stage]:
permutation_jobs.append(
pool.apply_async(generate_shader_info_permutation, (pmfx, shader, stage, permute, define_list)))
else:
permutation_jobs.append(
pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))

# wait on shader permutations
shaders = dict()
