From 06847373449ec123a8230108575427d46c04aa3a Mon Sep 17 00:00:00 2001 From: polymonster Date: Sun, 12 Jan 2025 17:08:35 +0000 Subject: [PATCH] - refactor to make raytracing shaders, a stage of lib with an array of shaders and a hit group block to define which entry points to use --- pmfx_pipeline.py | 128 ++++++++++++++++++++++++++++------------------- 1 file changed, 76 insertions(+), 52 deletions(-) diff --git a/pmfx_pipeline.py b/pmfx_pipeline.py index cb18920..339cd1e 100644 --- a/pmfx_pipeline.py +++ b/pmfx_pipeline.py @@ -42,18 +42,19 @@ def pmfx_hash(src): return zlib.adler32(bytes(str(src).encode("utf8"))) +# combine 2, 32 bit hashes +def pmfx_hash_combine(h1: int, h2: int) -> int: + combined_data = h1.to_bytes(4, 'little') + h2.to_bytes(4, 'little') + return zlib.adler32(combined_data) + + # return names of supported shader stages def get_shader_stages(): return [ "vs", "ps", "cs", - "rg", - "ch", - "ah", - "mi", - "is", - "ca" + "lib" ] @@ -773,13 +774,6 @@ def cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, outpu return 0, error_list, output_list -# convert satage to correct hlsl profile -def hlsl_stage(stage): - if stage in ["rg", "ch", "ah", "mi"]: - return "lib" - return stage - - # compile a hlsl version 2 def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_filepath): exe = os.path.join(info.tools_dir, "bin", "dxc", "dxc") @@ -792,7 +786,7 @@ def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_fil if info.shader_platform == "metal": error_code, error_list, output_list = cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, output_filepath) elif info.shader_platform == "hlsl": - cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, hlsl_stage(stage), info.shader_version, entry_point, output_filepath, temp_filepath) + cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, stage, info.shader_version, entry_point, output_filepath, temp_filepath) cmdline += " " + build_pmfx.get_info().args error_code, error_list, output_list = build_pmfx.call_wait_subprocess(cmdline) @@ -1066,42 +1060,63 @@ def generate_pipeline_permutation(pipeline_name, pipeline, output_pmfx, shaders, print(" pipeline: {} {}".format(pipeline_name, permutation_name)) resources = dict() output_pipeline = dict(pipeline) - # lookup info from compiled shaders and combine resources + + # gather entry points + entry_points = list() for stage in get_shader_stages(): if stage in pipeline: - entry_point = pipeline[stage] - if entry_point not in shaders[stage]: - output_pipeline["error_code"] = 1 - continue - # lookup shader info, and redirect to shared shaders - shader_info = shaders[stage][entry_point][pemutation_id] - if "lookup" in shader_info: - lookup = shader_info["lookup"] - shader_info = dict(shaders[stage][lookup[0]][lookup[1]]) + if type(pipeline[stage]) is list: + for entry_point in pipeline[stage]: + entry_points.append((stage, entry_point, True)) + else: + entry_points.append((stage, pipeline[stage], False)) + + # clear lib + if "lib" in output_pipeline: + output_pipeline["lib_hash"] = 0 + output_pipeline["lib"].clear() + + # lookup info from compiled shaders and combine resources + for (stage, entry_point, lib) in entry_points: + # check entry exists + if entry_point not in shaders[stage]: + output_pipeline["error_code"] = 1 + continue + # lookup shader info, and redirect to shared shaders + shader_info = shaders[stage][entry_point][pemutation_id] + if "lookup" in shader_info: + lookup = shader_info["lookup"] + shader_info = dict(shaders[stage][lookup[0]][lookup[1]]) + + if lib: + output_pipeline[stage].append(shader_info["filename"]) + output_pipeline["lib_hash"] = pmfx_hash_combine(output_pipeline["lib_hash"], pmfx_hash(shader_info["src_hash"])) + else: output_pipeline[stage] = shader_info["filename"] - output_pipeline["{}_hash:".format(stage)] = pmfx_hash(shader_info["src_hash"]) - shader = shader_info - resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"]) - # generate vertex layout - if stage == "vs": - pmfx_vertex_layout = dict() - if "vertex_layout" in pipeline: - pmfx_vertex_layout = pipeline["vertex_layout"] - output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout) - # extract numthreads - if stage == "cs": - for attrib in shader["attributes"]: - if attrib.find("numthreads") != -1: - start, end = cgu.enclose_start_end("(", ")", attrib, 0) - xyz = attrib[start:end].split(",") - numthreads = [] - for a in xyz: - numthreads.append(int(a.strip())) - output_pipeline["numthreads"] = numthreads - - # set non zero error codes to track failures - if shader_info["error_code"] != 0: - output_pipeline["error_code"] = shader_info["error_code"] + output_pipeline["{}_hash".format(stage)] = pmfx_hash(shader_info["src_hash"]) + + shader = shader_info + resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"]) + # generate vertex layout + if stage == "vs": + pmfx_vertex_layout = dict() + if "vertex_layout" in pipeline: + pmfx_vertex_layout = pipeline["vertex_layout"] + output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout) + # extract numthreads + if stage == "cs": + for attrib in shader["attributes"]: + if attrib.find("numthreads") != -1: + start, end = cgu.enclose_start_end("(", ")", attrib, 0) + xyz = attrib[start:end].split(",") + numthreads = [] + for a in xyz: + numthreads.append(int(a.strip())) + output_pipeline["numthreads"] = numthreads + + # set non zero error codes to track failures + if shader_info["error_code"] != 0: + output_pipeline["error_code"] = shader_info["error_code"] # build pipeline layout output_pipeline["pipeline_layout"] = generate_pipeline_layout(output_pmfx, pipeline, resources) @@ -1340,9 +1355,13 @@ def generate_pmfx(file, root): pipeline = pipelines[pipeline_key] for stage in get_shader_stages(): if stage in pipeline: - stage_shader = (stage, pipeline[stage]) - if stage_shader not in shader_list: - shader_list.append(stage_shader) + if type(pipeline[stage]) is list: + for shader in pipeline[stage]: + stage_shader = (stage, shader) + else: + stage_shader = (stage, pipeline[stage]) + if stage_shader not in shader_list: + shader_list.append(stage_shader) # gather permutations permutation_jobs = [] @@ -1357,8 +1376,13 @@ def generate_pmfx(file, root): pipeline_jobs.append((pipeline_key, id)) for stage in get_shader_stages(): if stage in pipeline: - permutation_jobs.append( - pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list))) + if type(pipeline[stage]) is list: + for shader in pipeline[stage]: + permutation_jobs.append( + pool.apply_async(generate_shader_info_permutation, (pmfx, shader, stage, permute, define_list))) + else: + permutation_jobs.append( + pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list))) # wait on shader permutations shaders = dict()