From 634b04c517400546c53beb04091b64e923da6c23 Mon Sep 17 00:00:00 2001 From: Frodo Baggins Date: Wed, 11 Dec 2024 00:26:03 -0800 Subject: [PATCH] refactor and handle wider DS instructions --- src/shader_recompiler/ir/ir_emitter.cpp | 13 ++ src/shader_recompiler/ir/ir_emitter.h | 2 + .../ir/passes/hull_shader_transform.cpp | 193 ++++++++---------- 3 files changed, 105 insertions(+), 103 deletions(-) diff --git a/src/shader_recompiler/ir/ir_emitter.cpp b/src/shader_recompiler/ir/ir_emitter.cpp index 0a4dd039d..b216d325d 100644 --- a/src/shader_recompiler/ir/ir_emitter.cpp +++ b/src/shader_recompiler/ir/ir_emitter.cpp @@ -574,6 +574,19 @@ Value IREmitter::CompositeConstruct(const Value& e1, const Value& e2, const Valu } } +Value IREmitter::CompositeConstruct(std::span elements) { + switch (elements.size()) { + case 2: + return CompositeConstruct(elements[0], elements[1]); + case 3: + return CompositeConstruct(elements[0], elements[1], elements[2]); + case 4: + return CompositeConstruct(elements[0], elements[1], elements[2], elements[3]); + default: + UNREACHABLE_MSG("Composite construct with greater than 4 elements"); + } +} + Value IREmitter::CompositeExtract(const Value& vector, size_t element) { const auto read{[&](Opcode opcode, size_t limit) -> Value { if (element >= limit) { diff --git a/src/shader_recompiler/ir/ir_emitter.h b/src/shader_recompiler/ir/ir_emitter.h index b6cafa0e0..3bd6ef1ec 100644 --- a/src/shader_recompiler/ir/ir_emitter.h +++ b/src/shader_recompiler/ir/ir_emitter.h @@ -148,6 +148,8 @@ public: [[nodiscard]] Value CompositeConstruct(const Value& e1, const Value& e2, const Value& e3); [[nodiscard]] Value CompositeConstruct(const Value& e1, const Value& e2, const Value& e3, const Value& e4); + [[nodiscard]] Value CompositeConstruct(std::span values); + [[nodiscard]] Value CompositeExtract(const Value& vector, size_t element); [[nodiscard]] Value CompositeInsert(const Value& vector, const Value& object, size_t element); diff --git a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp index 11c58025b..b45ee6db6 100644 --- a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp +++ b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp @@ -257,8 +257,8 @@ private: enum class AttributeRegion : u32 { InputCP, OutputCP, PatchConst }; -static AttributeRegion FindRegionKind(IR::Inst* ring_access, const Shader::Info& info, - const Shader::RuntimeInfo& runtime_info) { +static AttributeRegion GetAttributeRegionKind(IR::Inst* ring_access, const Shader::Info& info, + const Shader::RuntimeInfo& runtime_info) { u32 count = ring_access->Flags(); if (count == 0) { return AttributeRegion::InputCP; @@ -327,6 +327,21 @@ static IR::U32 TryOptimizeAddressModulo(IR::U32 addr, u32 stride, IR::IREmitter& return addr; } +// Read a TCS input (InputCP region) or TES input (OutputCP region) +static IR::F32 ReadTessInputComponent(IR::U32 addr, const u32 stride, IR::IREmitter& ir, + u32 off_dw) { + if (off_dw > 0) { + addr = ir.IAdd(addr, ir.Imm32(off_dw)); + } + const IR::U32 control_point_index = ir.IDiv(addr, ir.Imm32(stride)); + const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir); + const IR::U32 attr_index = + ir.ShiftRightLogical(ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u)); + const IR::U32 comp_index = + ir.ShiftRightLogical(ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); + return ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index); +} + } // namespace void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { @@ -391,95 +406,76 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { break; } - // case IR::Opcode::WriteSharedU128: // TODO case IR::Opcode::WriteSharedU32: - case IR::Opcode::WriteSharedU64: { - // DumpIR(program, "before_walk"); - // RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir); - + case IR::Opcode::WriteSharedU64: + case IR::Opcode::WriteSharedU128: { const u32 num_dwords = opcode == IR::Opcode::WriteSharedU32 ? 1 : (opcode == IR::Opcode::WriteSharedU64 ? 2 : 4); const IR::U32 addr{inst.Arg(0)}; - const IR::U32 data{inst.Arg(1)}; - const auto [data_lo, data_hi] = [&] -> std::pair { - if (num_dwords == 1) { - return {IR::U32{data}, IR::U32{}}; - } - const auto* prod = data.InstRecursive(); - return {IR::U32{prod->Arg(0)}, IR::U32{prod->Arg(1)}}; - }(); + const IR::U32 data{inst.Arg(1).Resolve()}; const auto SetOutput = [&](IR::U32 addr, IR::U32 value, AttributeRegion output_kind, - u32 off_dw = 0) { - const IR::F32 data = ir.BitCast(value); - if (off_dw > 0) { - addr = ir.IAdd(addr, ir.Imm32(off_dw)); - } + u32 off_dw) { + const IR::F32 data_component = ir.BitCast(value); if (output_kind == AttributeRegion::OutputCP) { + if (off_dw > 0) { + addr = ir.IAdd(addr, ir.Imm32(off_dw)); + } + u32 stride = runtime_info.hs_info.hs_output_cp_stride; // Invocation ID array index is implicit, handled by SPIRV backend - IR::U32 addr_for_attrs = TryOptimizeAddressModulo( - addr, runtime_info.hs_info.hs_output_cp_stride, ir); - - IR::U32 attr_index = ir.ShiftRightLogical( - ir.IMod(addr_for_attrs, - ir.Imm32(runtime_info.hs_info.hs_output_cp_stride)), - ir.Imm32(4u)); - IR::U32 comp_index = ir.ShiftRightLogical( + const IR::U32 addr_for_attrs = TryOptimizeAddressModulo(addr, stride, ir); + const IR::U32 attr_index = ir.ShiftRightLogical( + ir.IMod(addr_for_attrs, ir.Imm32(stride)), ir.Imm32(4u)); + const IR::U32 comp_index = ir.ShiftRightLogical( ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); - ir.SetTcsGenericAttribute(data, attr_index, comp_index); + ir.SetTcsGenericAttribute(data_component, attr_index, comp_index); } else { ASSERT(output_kind == AttributeRegion::PatchConst); ASSERT_MSG(addr.IsImmediate(), "patch addr non imm, inst {}", fmt::ptr(addr.Inst())); - ir.SetPatch(IR::PatchGeneric(addr.U32() >> 2), data); + ir.SetPatch(IR::PatchGeneric((addr.U32() >> 2) + off_dw), data_component); } }; - AttributeRegion region = FindRegionKind(&inst, info, runtime_info); - SetOutput(addr, data, region); - if (num_dwords > 1) { - // TODO handle WriteSharedU128 - SetOutput(addr, data_hi, region, 1); + AttributeRegion region = GetAttributeRegionKind(&inst, info, runtime_info); + if (num_dwords == 1) { + SetOutput(addr, data, region, 0); + } else { + for (auto i = 0; i < num_dwords; i++) { + SetOutput(addr, IR::U32{data.Inst()->Arg(i)}, region, i); + } } inst.Invalidate(); break; } case IR::Opcode::LoadSharedU32: { - // case IR::Opcode::LoadSharedU64: - // case IR::Opcode::LoadSharedU128: + case IR::Opcode::LoadSharedU64: + case IR::Opcode::LoadSharedU128: const IR::U32 addr{inst.Arg(0)}; - AttributeRegion region = FindRegionKind(&inst, info, runtime_info); - - ASSERT(region == AttributeRegion::InputCP || region == AttributeRegion::OutputCP); - switch (region) { - case AttributeRegion::InputCP: { - IR::U32 control_point_index = - ir.IDiv(addr, ir.Imm32(runtime_info.hs_info.ls_stride)); - - IR::U32 addr_for_attrs = - TryOptimizeAddressModulo(addr, runtime_info.hs_info.ls_stride, ir); - - IR::U32 attr_index = ir.ShiftRightLogical( - ir.IMod(addr_for_attrs, ir.Imm32(runtime_info.hs_info.ls_stride)), - ir.Imm32(4u)); - IR::U32 comp_index = ir.ShiftRightLogical( - ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); - IR::Value attr_read = - ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index); - attr_read = ir.BitCast(IR::F32{attr_read}); - inst.ReplaceUsesWithAndRemove(attr_read); - break; - } - case AttributeRegion::OutputCP: { - UNREACHABLE_MSG("Unhandled output control point read"); - break; - } - default: - break; + AttributeRegion region = GetAttributeRegionKind(&inst, info, runtime_info); + const u32 num_dwords = opcode == IR::Opcode::LoadSharedU32 + ? 1 + : (opcode == IR::Opcode::LoadSharedU64 ? 2 : 4); + ASSERT_MSG(region == AttributeRegion::InputCP, + "Unhandled read of output or patchconst attribute in hull shader"); + IR::Value attr_read; + if (num_dwords == 1) { + attr_read = ir.BitCast( + ReadTessInputComponent(addr, runtime_info.hs_info.ls_stride, ir, 0)); + } else { + boost::container::static_vector read_components; + for (auto i = 0; i < num_dwords; i++) { + const IR::F32 component = + ReadTessInputComponent(addr, runtime_info.hs_info.ls_stride, ir, i); + read_components.push_back(ir.BitCast(component)); + } + attr_read = ir.CompositeConstruct(read_components); } + inst.ReplaceUsesWithAndRemove(attr_read); + break; } default: @@ -534,44 +530,34 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { const auto opcode = inst.GetOpcode(); switch (inst.GetOpcode()) { case IR::Opcode::LoadSharedU32: { - // case IR::Opcode::LoadSharedU64: - // case IR::Opcode::LoadSharedU128: // TODO + case IR::Opcode::LoadSharedU64: + case IR::Opcode::LoadSharedU128: const IR::U32 addr{inst.Arg(0)}; - AttributeRegion region = FindRegionKind(&inst, info, runtime_info); - - ASSERT(region == AttributeRegion::OutputCP || - region == AttributeRegion::PatchConst); - switch (region) { - case AttributeRegion::OutputCP: { - IR::U32 control_point_index = - ir.IDiv(addr, ir.Imm32(runtime_info.vs_info.hs_output_cp_stride)); - - IR::U32 addr_for_attrs = TryOptimizeAddressModulo( - addr, runtime_info.vs_info.hs_output_cp_stride, ir); - - IR::U32 attr_index = ir.ShiftRightLogical( - ir.IMod(addr_for_attrs, ir.Imm32(runtime_info.vs_info.hs_output_cp_stride)), - ir.Imm32(4u)); - IR::U32 comp_index = ir.ShiftRightLogical( - ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u)); - IR::Value attr_read = - ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index); - attr_read = ir.BitCast(IR::F32{attr_read}); - inst.ReplaceUsesWithAndRemove(attr_read); - break; + AttributeRegion region = GetAttributeRegionKind(&inst, info, runtime_info); + const u32 num_dwords = opcode == IR::Opcode::LoadSharedU32 + ? 1 + : (opcode == IR::Opcode::LoadSharedU64 ? 2 : 4); + const auto GetInput = [&](IR::U32 addr, u32 off_dw) -> IR::F32 { + if (region == AttributeRegion::OutputCP) { + return ReadTessInputComponent( + addr, runtime_info.vs_info.hs_output_cp_stride, ir, off_dw); + } else { + ASSERT(region == AttributeRegion::PatchConst); + return ir.GetPatch(IR::PatchGeneric((addr.U32() >> 2) + off_dw)); + } + }; + IR::Value attr_read; + if (num_dwords == 1) { + attr_read = ir.BitCast(GetInput(addr, 0)); + } else { + boost::container::static_vector read_components; + for (auto i = 0; i < num_dwords; i++) { + const IR::F32 component = GetInput(addr, i); + read_components.push_back(ir.BitCast(component)); + } + attr_read = ir.CompositeConstruct(read_components); } - case AttributeRegion::PatchConst: { - // TODO if assert fails then make generic patch attrs into array and dyn index - ASSERT_MSG(addr.IsImmediate(), "patch addr non imm, inst {}", - fmt::ptr(addr.Inst())); - IR::Value get_patch = ir.GetPatch(IR::PatchGeneric(addr.U32() >> 2)); - inst.ReplaceUsesWithAndRemove(get_patch); - break; - } - default: - break; - } - + inst.ReplaceUsesWithAndRemove(attr_read); break; } default: @@ -638,8 +624,8 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) { auto sharp_location = FindTessConstantSharp(&inst); if (sharp_location && sharp_location->ptr_base == info.tess_consts_ptr_base && sharp_location->dword_off == info.tess_consts_dword_offset) { - // Replace the load with a special attribute load (for readability and easier - // pattern matching) + // Replace the load with a special attribute load (for readability and + // easier pattern matching) IR::Value index = inst.Arg(1); ASSERT_MSG(index.IsImmediate(), @@ -766,6 +752,7 @@ void TessellationPostprocess(IR::Program& program, RuntimeInfo& runtime_info) { } } + // TODO delete for (IR::Block* block : program.blocks) { for (IR::Inst& inst : block->Instructions()) { switch (inst.GetOpcode()) {