Skip to content

Commit

Permalink
refactor sync keying
Browse files Browse the repository at this point in the history
  • Loading branch information
martty committed Dec 26, 2024
1 parent 9cfa722 commit d14ebca
Showing 1 changed file with 39 additions and 44 deletions.
83 changes: 39 additions & 44 deletions src/runtime/vk/Backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -839,17 +839,45 @@ namespace vuk {
return nullptr;
}

void init_sync(Type* base_ty, StreamResourceUse src_use, void* value, bool enforce_unique = true) {
uint64_t value_identity(Type* base_ty, void* value) {
uint64_t key = 0;
PartialStreamResourceUse psru{ src_use };
if (base_ty->hash_value == current_module->types.builtin_image) {
auto& img_att = *reinterpret_cast<ImageAttachment*>(value);
key = reinterpret_cast<uint64_t>(img_att.image.image);
psru.subrange.image = { img_att.base_level, img_att.level_count, img_att.base_layer, img_att.layer_count };
} else if (base_ty->hash_value == current_module->types.builtin_buffer) {
auto buf = reinterpret_cast<Buffer*>(value);
key = reinterpret_cast<uint64_t>(buf->allocation);
hash_combine(key, buf->offset);
} else if (base_ty->kind == Type::ARRAY_TY) {
if (base_ty->array.count > 0) { // for an array, we key off the the first element, as the array syncs together
auto elem_ty = base_ty->array.T->get();
auto elems = reinterpret_cast<std::byte*>(value);
return value_identity(elem_ty, elems);
} else { // zero-len arrays
return 0;
}
} else if (base_ty->hash_value == current_module->types.builtin_sampled_image) { // only image syncs
auto& img_att = reinterpret_cast<SampledImage*>(value)->ia;
key = reinterpret_cast<uint64_t>(img_att.image.image);
} else if (base_ty->kind == Type::INTEGER_TY) { // TODO: generalise
return 0;
} else if (base_ty->kind == Type::POINTER_TY) {
key = reinterpret_cast<ptr_base*>(value)->device_address;
} else if (base_ty->is_bufferlike_view()) {
auto& v = *reinterpret_cast<view<BufferLike<void>>*>(value);
key = v.ptr.device_address;
hash_combine(key, v.count);
} else { // other types just key on the voidptr
key = reinterpret_cast<uint64_t>(value);
}
return key;
}

void init_sync(Type* base_ty, StreamResourceUse src_use, void* value, bool enforce_unique = true) {
PartialStreamResourceUse psru{ src_use };
if (base_ty->hash_value == current_module->types.builtin_image) {
auto& img_att = *reinterpret_cast<ImageAttachment*>(value);
psru.subrange.image = { img_att.base_level, img_att.level_count, img_att.base_layer, img_att.layer_count };
} else if (base_ty->kind == Type::ARRAY_TY) { // for an array, we init all elements
auto elem_ty = base_ty->array.T->get();
auto size = base_ty->array.count;
Expand All @@ -859,12 +887,10 @@ namespace vuk {
elems += elem_ty->size;
}
return;
} else if (base_ty->kind == Type::POINTER_TY) {
key = reinterpret_cast<ptr_base*>(value)->device_address;
} else { // other types just key on the voidptr
key = reinterpret_cast<uint64_t>(value);
}

uint64_t key = value_identity(base_ty, value);

if (enforce_unique) {
assert(last_modify.find(key) == last_modify.end());
last_modify.emplace(key, new (this->arena.ensure_space(sizeof(PartialStreamResourceUse))) PartialStreamResourceUse(psru));
Expand All @@ -879,7 +905,6 @@ namespace vuk {
}
auto& dst_use = *maybe_dst_use;

uint64_t key = 0;
if (base_ty->kind == Type::ARRAY_TY) {
auto elem_ty = base_ty->array.T->get();
auto size = base_ty->array.count;
Expand All @@ -889,20 +914,15 @@ namespace vuk {
elems += elem_ty->size;
}
return;
} else if (base_ty->hash_value == current_module->types.builtin_image) {
auto& img_att = *reinterpret_cast<ImageAttachment*>(value);
key = reinterpret_cast<uint64_t>(img_att.image.image);
} else if (base_ty->hash_value == current_module->types.builtin_buffer) {
auto buf = reinterpret_cast<Buffer*>(value);
key = reinterpret_cast<uint64_t>(buf->allocation);
hash_combine(key, buf->offset);
} else if (base_ty->hash_value == current_module->types.builtin_sampled_image) { // sync the image
auto& img_att = reinterpret_cast<SampledImage*>(value)->ia;
add_sync(current_module->types.get_builtin_image().get(), dst_use, &img_att);
return;
} else if (base_ty->kind == Type::POINTER_TY) {
key = reinterpret_cast<ptr_base*>(value)->device_address;
} else { // no other types require sync
}

uint64_t key = value_identity(base_ty, value);

if (key == 0) { // doesn't require sync
return;
}

Expand Down Expand Up @@ -976,32 +996,7 @@ namespace vuk {
}

StreamResourceUse& last_use(Type* base_ty, void* value) {
uint64_t key = 0;
if (base_ty->hash_value == current_module->types.builtin_image) {
auto& img_att = *reinterpret_cast<ImageAttachment*>(value);
key = reinterpret_cast<uint64_t>(img_att.image.image);
} else if (base_ty->hash_value == current_module->types.builtin_buffer) {
auto buf = reinterpret_cast<Buffer*>(value);
key = reinterpret_cast<uint64_t>(buf->allocation);
hash_combine(key, buf->offset);
} else if (base_ty->kind == Type::ARRAY_TY) {
if (base_ty->array.count > 0) { // for an array, we key off the the first element, as the array syncs together
auto elem_ty = base_ty->array.T->get();
auto elems = reinterpret_cast<std::byte*>(value);
return last_use(elem_ty, elems);
} else { // zero-len arrays
return *last_modify.at(0);
}
} else if (base_ty->hash_value == current_module->types.builtin_sampled_image) { // only image syncs
auto& img_att = reinterpret_cast<SampledImage*>(value)->ia;
key = reinterpret_cast<uint64_t>(img_att.image.image);
} else if (base_ty->kind == Type::INTEGER_TY) { // TODO: generalise
return *last_modify.at(0);
} else if (base_ty->kind == Type::POINTER_TY) {
key = reinterpret_cast<ptr_base*>(value)->device_address;
} else { // other types just key on the voidptr
key = reinterpret_cast<uint64_t>(value);
}
uint64_t key = value_identity(base_ty, value);

return *last_modify.at(key);
}
Expand Down

0 comments on commit d14ebca

Please sign in to comment.