more refactors

This commit is contained in:
Frodo Baggins 2024-12-06 19:56:01 -08:00
parent 3af854539a
commit 8058377e53

View File

@ -380,208 +380,10 @@ static AttributeRegion FindRegionKind(IR::Inst* ring_access, const Shader::Info&
}
}
static IR::Program* g_program; // TODO delete
// Describes where a single ring access points, as filled in by
// Pass::WalkRingAccess.
struct RingAddressInfo {
    // Region the access was classified into.
    AttributeRegion region = AttributeRegion{};

    // Sum of the constant addends of the address, in bytes.
    u32 attribute_byte_offset = 0;

    // InputCP / OutputCP: offset measured from the start of the patch's memory;
    // attribute_byte_offset is already included in it.
    // PatchConst: not relevant.
    IR::U32 offset_in_patch{IR::Value(0u)};
};
// Decomposes the address operand of a ring access (shared-memory load/store or
// buffer store) into a sum of products, then uses the tessellation attributes
// that appear in those products to decide which attribute region the access
// targets and which offset inside that region it uses.
class Pass {
public:
    Pass(Info& info_, RuntimeInfo& runtime_info_) : info(info_), runtime_info(runtime_info_) {
        InitTessConstants(info.tess_consts_ptr_base, info.tess_consts_dword_offset, info,
                          runtime_info, tess_constants);
    }

    // Analyses the address operand of `access`.
    //
    // access       - a LoadShared*/WriteShared* instruction (address is Arg(0)) or a
    //                StoreBufferU32* instruction (address is Arg(1)). Any other opcode
    //                hits UNREACHABLE().
    // insert_point - emitter used to build the IAdd chain stored in
    //                RingAddressInfo::offset_in_patch.
    //
    // Returns the region, the constant byte offset and the in-patch offset value.
    RingAddressInfo WalkRingAccess(IR::Inst* access, IR::IREmitter& insert_point) {
        Reset();
        RingAddressInfo address_info{};

        IR::Value addr;
        switch (access->GetOpcode()) {
        case IR::Opcode::LoadSharedU32:
        case IR::Opcode::LoadSharedU64:
        case IR::Opcode::LoadSharedU128:
        case IR::Opcode::WriteSharedU32:
        case IR::Opcode::WriteSharedU64:
        case IR::Opcode::WriteSharedU128:
            addr = access->Arg(0);
            break;
        case IR::Opcode::StoreBufferU32:
        case IR::Opcode::StoreBufferU32x2:
        case IR::Opcode::StoreBufferU32x3:
        case IR::Opcode::StoreBufferU32x4:
            addr = access->Arg(1);
            break;
        default:
            UNREACHABLE();
        }

        // The whole address starts out as a single product; Visit() splits it at
        // every top-level addition.
        products.emplace_back(addr);
        Visit(addr);
        FindIndexInfo(address_info, insert_point);
        return address_info;
    }

private:
    // Clears the per-access walk state.
    void Reset() {
        within_mul = false;
        products.clear();
    }

    // Recursively walks `node`, appending leaf factors to the current (last)
    // product and opening a new product for the right operand of each addition
    // that is not nested inside a multiplication.
    void Visit(IR::Value node) {
        IR::Value a, b, c;
        if (MakeInstPattern<IR::Opcode::IMul32>(MatchValue(a), MatchValue(b)).DoMatch(node)) {
            // Both operands are factors of the same product.
            const bool saved_within_mul = within_mul;
            within_mul = true;
            Visit(a);
            Visit(b);
            within_mul = saved_within_mul;
        } else if (MakeInstPattern<IR::Opcode::IAdd32>(MatchValue(a), MatchValue(b))
                       .DoMatch(node)) {
            if (within_mul) {
                // Deliberate trap kept from the original: a sum nested inside a
                // product is not expected. The statement after it only matters if
                // the trap is ever removed.
                UNREACHABLE_MSG("Test");
                products.back().as_factors.emplace_back(IR::U32{node});
            } else {
                // Left operand continues the current product, right operand
                // starts a new one.
                products.back().as_nested_value = IR::U32{a};
                Visit(a);
                products.emplace_back(b);
                Visit(b);
            }
        } else if (MakeInstPattern<IR::Opcode::ShiftLeftLogical32>(MatchValue(a), MatchImm(b))
                       .DoMatch(node)) {
            // x << n is x * 2^n: record 2^n as a constant factor.
            // The previous form, 2 << (n - 1), had an out-of-range shift count for
            // n == 0 and overflowed signed int for n == 31.
            // NOTE(review): n >= 32 is still not handled; presumably it never
            // occurs in address math, confirm.
            products.back().as_factors.emplace_back(IR::Value(u32{1} << b.U32()));
            Visit(a);
        } else if (MakeInstPattern<IR::Opcode::GetAttributeU32>(MatchValue(a), MatchU32(0))
                       .DoMatch(node)) {
            // Keep the attribute itself as the factor so FindIndexInfo can
            // recognise it. (A second, identical pattern used to follow this
            // branch; it could never match and was removed.)
            products.back().as_factors.emplace_back(a.Attribute());
        } else if (MakeInstPattern<IR::Opcode::BitFieldSExtract>(MatchValue(a), MatchIgnore(),
                                                                 MatchIgnore())
                       .DoMatch(node)) {
            // Bit-field extracts are looked through.
            Visit(a);
        } else if (MakeInstPattern<IR::Opcode::BitFieldUExtract>(MatchValue(a), MatchIgnore(),
                                                                 MatchIgnore())
                       .DoMatch(node)) {
            Visit(a);
        } else if (MakeInstPattern<IR::Opcode::BitCastF32U32>(MatchValue(a)).DoMatch(node)) {
            // Bit casts are looked through.
            Visit(a);
        } else if (MakeInstPattern<IR::Opcode::BitCastU32F32>(MatchValue(a)).DoMatch(node)) {
            Visit(a);
        } else if (node.TryInstRecursive() &&
                   node.InstRecursive()->GetOpcode() == IR::Opcode::Phi) {
            // Debug trap kept from the original; a phi is otherwise treated as an
            // opaque leaf.
            DEBUG_ASSERT(false && "Phi test");
            products.back().as_factors.emplace_back(node);
        } else {
            // Anything else is an opaque leaf factor.
            products.back().as_factors.emplace_back(node);
        }
    }

    // Classifies the products collected by Visit() and fills `address_info`.
    // `ir` is used to emit the additions that build offset_in_patch.
    void FindIndexInfo(RingAddressInfo& address_info, IR::IREmitter& ir) {
        // infer which attribute base the address is indexing
        // by how many addends are multiplied by TessellationDataConstantBuffer::m_hsNumPatch.
        // Also handle m_hsOutputBase or m_patchConstBase
        u32 region_count = 0;

        // Remove addends except for the attribute offset and possibly the
        // control point index calc
        std::erase_if(products, [&](Product& p) {
            for (IR::Value& value : p.as_factors) {
                if (value.Type() == IR::Type::Attribute) {
                    if (value.Attribute() == IR::Attribute::TcsNumPatches ||
                        value.Attribute() == IR::Attribute::TcsOutputBase) {
                        ++region_count;
                        return true;
                    } else if (value.Attribute() == IR::Attribute::TcsPatchConstBase) {
                        region_count += 2;
                        return true;
                    } else if (value.Attribute() == IR::Attribute::TessPatchIdInVgt) {
                        return true;
                    }
                }
            }
            return false;
        });

        // DumpIR(*g_program, "before_crash");
        // Look for some term with a dynamic index (should be the control point index)
        for (Product& product : products) {
            auto& factors = product.as_factors;
            const bool has_dynamic_factor =
                std::any_of(factors.begin(), factors.end(), [](const IR::Value& v) {
                    return !v.IsImmediate() || v.Type() == IR::Type::Attribute;
                });
            if (has_dynamic_factor) {
                // Remember this as the index term
                address_info.offset_in_patch =
                    ir.IAdd(address_info.offset_in_patch, product.as_nested_value);
            } else {
                ASSERT_MSG(factors.size() == 1, "factors all const but not const folded");
                // Otherwise assume it contributes to the attribute
                address_info.offset_in_patch =
                    ir.IAdd(address_info.offset_in_patch, IR::U32{factors[0]});
                address_info.attribute_byte_offset += factors[0].U32();
            }
        }

        if (region_count == 0) {
            address_info.region = AttributeRegion::InputCP;
        } else if (info.l_stage == LogicalStage::TessellationControl &&
                   runtime_info.hs_info.IsPassthrough()) {
            ASSERT(region_count <= 1);
            address_info.region = AttributeRegion::PatchConst;
        } else {
            ASSERT(region_count <= 2);
            // The accumulated count is used directly as the enum value.
            address_info.region = static_cast<AttributeRegion>(region_count);
        }
    }

    Info& info;
    RuntimeInfo& runtime_info;
    TessellationDataConstantBuffer tess_constants;

    // True while Visit() is below an IMul32 node.
    bool within_mul{};

    // One product in the sum of products making up an address
    struct Product {
        explicit Product(IR::Value val_) : as_nested_value(val_) {}

        // IR value used as an addend in address calc
        IR::U32 as_nested_value;
        // all the leaves that feed the multiplication, linear
        // TODO small_vector
        // boost::container::small_vector<IR::Value, 4> as_factors;
        std::vector<IR::Value> as_factors;
    };

    // Addends of the address currently being analysed; rebuilt for each access.
    std::vector<Product> products;
};
} // namespace
void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
TessConstantUseWalker walker;
g_program = &program; // TODO delete
Info& info = program.info;
Pass pass(info, runtime_info);
for (IR::Block* block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) {
if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) {
walker.MarkTessAttributeUsers(&inst);
}
}
}
for (IR::Block* block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) {
@ -646,7 +448,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
case IR::Opcode::WriteSharedU32:
case IR::Opcode::WriteSharedU64: {
// DumpIR(program, "before_walk");
RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir);
// RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir);
const u32 num_dwords = opcode == IR::Opcode::WriteSharedU32
? 1
@ -669,7 +471,6 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
}
if (output_kind == AttributeRegion::OutputCP) {
#if 0
// Invocation ID array index is implicit, handled by SPIRV backend
IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)),
@ -677,17 +478,11 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u));
ir.SetTcsGenericAttribute(data, attr_index, comp_index);
#else
u32 offset_dw = address_info.attribute_byte_offset >> 2;
const u32 param = offset_dw >> 2;
const u32 comp = offset_dw & 3;
// Invocation ID array index is implicit, handled by SPIRV backend
ir.SetTcsGenericAttribute(data, ir.Imm32(param), ir.Imm32(comp));
#endif
} else {
ASSERT(output_kind == AttributeRegion::PatchConst);
ir.SetPatch(IR::PatchGeneric(address_info.attribute_byte_offset >> 2),
data);
ASSERT_MSG(addr.IsImmediate(), "patch addr non imm, inst {}",
fmt::ptr(addr.Inst()));
ir.SetPatch(IR::PatchGeneric(addr.U32() >> 2), data);
}
};
@ -704,20 +499,21 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
case IR::Opcode::LoadSharedU32: {
// case IR::Opcode::LoadSharedU64:
// case IR::Opcode::LoadSharedU128:
RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir);
const IR::U32 addr{inst.Arg(0)};
AttributeRegion region = FindRegionKind(&inst, info, runtime_info);
ASSERT(address_info.region == AttributeRegion::InputCP ||
address_info.region == AttributeRegion::OutputCP);
switch (address_info.region) {
ASSERT(region == AttributeRegion::InputCP || region == AttributeRegion::OutputCP);
switch (region) {
case AttributeRegion::InputCP: {
u32 offset_dw =
(address_info.attribute_byte_offset % runtime_info.hs_info.ls_stride) >> 2;
const u32 attr_no = offset_dw >> 2;
const u32 comp = offset_dw & 3;
IR::U32 control_point_index = ir.IDiv(IR::U32{address_info.offset_in_patch},
ir.Imm32(runtime_info.hs_info.ls_stride));
IR::Value attr_read = ir.GetTessGenericAttribute(
control_point_index, ir.Imm32(attr_no), ir.Imm32(comp));
IR::U32 control_point_index =
ir.IDiv(addr, ir.Imm32(runtime_info.hs_info.ls_stride));
IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsLsStride)),
ir.Imm32(4u));
IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u));
IR::Value attr_read =
ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index);
attr_read = ir.BitCast<IR::U32>(IR::F32{attr_read});
inst.ReplaceUsesWithAndRemove(attr_read);
break;
@ -775,17 +571,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
// TODO refactor
void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
TessConstantUseWalker walker;
Info& info = program.info;
Pass pass(info, runtime_info);
for (IR::Block* block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) {
if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) {
walker.MarkTessAttributeUsers(&inst);
}
}
}
for (IR::Block* block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) {
@ -795,8 +581,7 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
case IR::Opcode::LoadSharedU32: {
// case IR::Opcode::LoadSharedU64:
// case IR::Opcode::LoadSharedU128: // TODO
RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir);
const IR::U32 addr = IR::U32{inst.Arg(0)};
const IR::U32 addr{inst.Arg(0)};
AttributeRegion region = FindRegionKind(&inst, info, runtime_info);
ASSERT(region == AttributeRegion::OutputCP ||
@ -817,9 +602,10 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
break;
}
case AttributeRegion::PatchConst: {
// TODO make patch consts into dynamic offset
u32 offset_dw = address_info.attribute_byte_offset >> 2;
IR::Value get_patch = ir.GetPatch(IR::PatchGeneric(offset_dw));
// TODO if assert fails then make patch consts into dynamic offset
ASSERT_MSG(addr.IsImmediate(), "patch addr non imm, inst {}",
fmt::ptr(addr.Inst()));
IR::Value get_patch = ir.GetPatch(IR::PatchGeneric(addr.U32() >> 2));
inst.ReplaceUsesWithAndRemove(get_patch);
break;
}
@ -963,6 +749,38 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) {
}
}
}
TessConstantUseWalker walker;
for (IR::Block* block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) {
if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) {
walker.MarkTessAttributeUsers(&inst);
}
}
}
for (IR::Block* block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) {
if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) {
switch (inst.Arg(0).Attribute()) {
// Should verify that these are only used in address calculations for attr
// read/write
// Replace with 0 so we can dynamically index the control points within the
// region allocated for this patch (input or output). These terms should only
// contribute to the base address of that region, so replacing with 0 *should*
// be fine
case IR::Attribute::TcsNumPatches:
case IR::Attribute::TcsOutputBase:
case IR::Attribute::TcsPatchConstBase:
case IR::Attribute::TessPatchIdInVgt:
inst.ReplaceUsesWithAndRemove(IR::Value(0u));
break;
default:
break;
}
}
}
}
}
void TessellationPostprocess(IR::Program& program, RuntimeInfo& runtime_info) {
@ -982,21 +800,6 @@ void TessellationPostprocess(IR::Program& program, RuntimeInfo& runtime_info) {
case IR::Attribute::TcsCpStride:
inst.ReplaceUsesWithAndRemove(IR::Value(tess_constants.m_hsCpStride));
break;
// Should verify that these are only used in address calculations for attr
// read/write
// Replace with 0 so we can dynamically index the control points within the
// region allocated for this patch (input or output). These terms should only
// contribute to the base address of that region, so replacing with 0 *should*
// be fine
case IR::Attribute::TcsNumPatches:
case IR::Attribute::TcsOutputBase:
case IR::Attribute::TcsPatchConstBase:
case IR::Attribute::TessPatchIdInVgt:
inst.ReplaceUsesWithAndRemove(IR::Value(0u));
break;
case IR::Attribute::TcsPatchConstSize:
case IR::Attribute::TcsPatchOutputSize:
case IR::Attribute::TcsFirstEdgeTessFactorIndex:
default:
break;
}