diff --git a/include/vuk/runtime/vk/Program.hpp b/include/vuk/runtime/vk/Program.hpp index 417cfe29..283f57eb 100644 --- a/include/vuk/runtime/vk/Program.hpp +++ b/include/vuk/runtime/vk/Program.hpp @@ -114,6 +114,7 @@ namespace vuk { private: void flatten_bindings(); + Descriptors& ensure_set(size_t set_index); }; struct ShaderModule { diff --git a/src/runtime/vk/Program.cpp b/src/runtime/vk/Program.cpp index 42441bd4..c88d1dd7 100644 --- a/src/runtime/vk/Program.cpp +++ b/src/runtime/vk/Program.cpp @@ -178,6 +178,16 @@ namespace vuk { } } + Program::Descriptors& Program::ensure_set(size_t set) { + if (set >= sets.size()) { + sets.resize(set + 1, std::nullopt); + } + if (!sets[set]) { + sets[set] = Descriptors{}; + } + return *sets[set]; + } + VkShaderStageFlagBits Program::introspect(const uint32_t* ir, size_t word_count) { spirv_cross::Compiler refl(ir, word_count); auto resources = refl.get_shader_resources(); @@ -249,11 +259,8 @@ namespace vuk { reflect_members(refl, refl.get_type(ub.type_id), un.members); } un.size = refl.get_declared_struct_size(type); - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(un); + + ensure_set(set).bindings.push_back(un); } for (auto& sb : resources.storage_buffers) { @@ -275,11 +282,8 @@ namespace vuk { un.is_hlsl_counter_buffer = refl.buffer_is_hlsl_counter_buffer(sb.id); un.non_writable = refl.get_decoration(sb.id, spv::DecorationNonWritable); un.non_readable = refl.get_decoration(sb.id, spv::DecorationNonReadable); - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(un); + + ensure_set(set).bindings.push_back(un); } for (auto& si : resources.sampled_images) { @@ -293,11 +297,8 @@ namespace vuk { // maybe spirv cross bug? t.array_size = type.array.size() == 1 ? (type.array[0] == 1 ? 0 : type.array[0]) : -1; t.shadow = type.image.depth; - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(t); + + ensure_set(set).bindings.push_back(t); } for (auto& sa : resources.separate_samplers) { @@ -311,11 +312,8 @@ namespace vuk { // maybe spirv cross bug? t.array_size = type.array.size() == 1 ? (type.array[0] == 1 ? 0 : type.array[0]) : -1; t.shadow = type.image.depth; - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(t); + + ensure_set(set).bindings.push_back(t); } for (auto& si : resources.separate_images) { @@ -328,11 +326,8 @@ namespace vuk { t.stage = stage; // maybe spirv cross bug? t.array_size = type.array.size() == 1 ? (type.array[0] == 1 ? 0 : type.array[0]) : -1; - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(t); + + ensure_set(set).bindings.push_back(t); } for (auto& sb : resources.storage_images) { @@ -347,11 +342,8 @@ namespace vuk { un.non_readable = refl.get_decoration(sb.id, spv::DecorationNonReadable); // maybe spirv cross bug? un.array_size = type.array.size() == 1 ? (type.array[0] == 1 ? 0 : type.array[0]) : -1; - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(un); + + ensure_set(set).bindings.push_back(un); } // subpass inputs @@ -363,11 +355,8 @@ namespace vuk { s.name = std::string(si.name.c_str()); s.binding = binding; s.stage = stage; - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(s); + + ensure_set(set).bindings.push_back(s); } // ASs @@ -380,11 +369,8 @@ namespace vuk { s.binding = binding; s.stage = stage; s.array_size = type.array.size() == 1 ? (type.array[0] == 1 ? 0 : type.array[0]) : -1; - if (set >= sets.size()) { - sets.resize(set + 1, std::nullopt); - sets[set] = Descriptors{}; - } - sets[set]->bindings.push_back(s); + + ensure_set(set).bindings.push_back(s); } for (auto& sc : refl.get_specialization_constants()) {