Simplify bounds checking in generated SPIR-V

This commit is contained in:
Marcin Mikołajczyk 2025-06-07 19:38:52 +01:00
parent 382ff2fb4a
commit 72553c44d8
2 changed files with 11 additions and 9 deletions

View File

@ -20,24 +20,26 @@ auto AccessBoundsCheck(EmitContext& ctx, Id index, Id buffer_size, auto emit_fun
zero_value = ctx.u16_zero_value;
result_type = ctx.U16;
} else {
static_assert("type not supported");
static_assert(false, "type not supported");
}
if (Sirit::ValidId(buffer_size)) {
// Bounds checking enabled, wrap in a conditional branch to make sure that
// the atomic is not mistakenly executed when the index is out of bounds.
const Id in_bounds = ctx.OpULessThan(ctx.U1[1], index, buffer_size);
const Id ib_label = ctx.OpLabel();
const Id oob_label = ctx.OpLabel();
const Id end_label = ctx.OpLabel();
ctx.OpSelectionMerge(end_label, spv::SelectionControlMask::MaskNone);
ctx.OpBranchConditional(in_bounds, ib_label, oob_label);
ctx.OpBranchConditional(in_bounds, ib_label, end_label);
const auto last_label = ctx.last_label;
ctx.AddLabel(ib_label);
const auto ib_result = emit_func();
ctx.OpBranch(end_label);
ctx.AddLabel(oob_label);
ctx.OpBranch(end_label);
ctx.AddLabel(end_label);
return ctx.OpPhi(result_type, ib_result, ib_label, zero_value, oob_label);
if (Sirit::ValidId(ib_result)) {
return ctx.OpPhi(result_type, ib_result, ib_label, zero_value, last_label);
} else {
return Id{0};
}
}
// Bounds checking not enabled, just perform the atomic operation.
return emit_func();

View File

@ -53,7 +53,7 @@ void EmitWriteSharedU16(EmitContext& ctx, Id offset, Id value) {
const Id pointer =
ctx.OpAccessChain(ctx.shared_u16, ctx.shared_memory_u16, ctx.u32_zero_value, index);
ctx.OpStore(pointer, value);
return ctx.OpUndef(ctx.U16);
return Id{0};
});
}
@ -66,7 +66,7 @@ void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value) {
const Id pointer =
ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index);
ctx.OpStore(pointer, value);
return ctx.OpUndef(ctx.U32[1]);
return Id{0};
});
}
@ -79,7 +79,7 @@ void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value) {
const Id pointer{
ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)};
ctx.OpStore(pointer, value);
return ctx.OpUndef(ctx.U64);
return Id{0};
});
}