From fd6611ed54749e0f1d2618a1d8a009d70417e0f0 Mon Sep 17 00:00:00 2001
From: IndecisiveTurtle <47210458+raphaelthegreat@users.noreply.github.com>
Date: Fri, 6 Sep 2024 15:43:59 +0300
Subject: [PATCH] shader_recompiler: Separate thread bit scalars

* We can assume guest shader never mixes them with normal sgprs. This helps avoid errors where ssa could view an sgpr write dominating a thread bit read, due to how control flow is structurized, even though it's not possible in actual control flow
---
Review notes (this section is ignored by git am):
 - The ThreadBitScalar Def/SetDef overloads index ssa_sbit_values. The received
   revision indexed ssa_sreg_values, which left the newly added array unused and
   kept thread bits sharing definition slots with ordinary sgprs.
 - NOTE(review): template argument lists (<...>) and line breaks were missing
   from the received copy and have been restored from context; confirm the
   context lines against the target tree before applying. The blob hashes on
   the index lines are unchanged from the received copy.

 src/shader_recompiler/ir/basic_block.h        |  1 +
 .../ir/passes/ssa_rewrite_pass.cpp            | 46 ++++++++++++++-----
 2 files changed, 35 insertions(+), 12 deletions(-)

diff --git a/src/shader_recompiler/ir/basic_block.h b/src/shader_recompiler/ir/basic_block.h
index 1eb11469c..11ae969bc 100644
--- a/src/shader_recompiler/ir/basic_block.h
+++ b/src/shader_recompiler/ir/basic_block.h
@@ -147,6 +147,7 @@ public:
 
     /// Intrusively store the value of a register in the block.
     std::array<Value, NumScalarRegs> ssa_sreg_values;
+    std::array<Value, NumScalarRegs> ssa_sbit_values;
     std::array<Value, NumVectorRegs> ssa_vreg_values;
 
     bool has_multiple_predecessors{false};
diff --git a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp
index ea27c64f7..54dce0355 100644
--- a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp
+++ b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp
@@ -44,8 +44,17 @@ struct GotoVariable : FlagTag {
     u32 index;
 };
 
-using Variant = std::variant<IR::ScalarReg, IR::VectorReg, GotoVariable, SccFlagTag, ExecFlagTag,
-                             VccFlagTag, VccLoTag, VccHiTag, M0Tag>;
+struct ThreadBitScalar : FlagTag {
+    ThreadBitScalar() = default;
+    explicit ThreadBitScalar(IR::ScalarReg sgpr_) : sgpr{sgpr_} {}
+
+    auto operator<=>(const ThreadBitScalar&) const noexcept = default;
+
+    IR::ScalarReg sgpr;
+};
+
+using Variant = std::variant<IR::ScalarReg, IR::VectorReg, GotoVariable, ThreadBitScalar,
+                             SccFlagTag, ExecFlagTag, VccFlagTag, VccLoTag, VccHiTag, M0Tag>;
 using ValueMap = std::unordered_map<IR::Block*, IR::Value>;
 
 struct DefTable {
@@ -70,6 +79,13 @@ struct DefTable {
         goto_vars[variable.index].insert_or_assign(block, value);
     }
 
+    const IR::Value& Def(IR::Block* block, ThreadBitScalar variable) {
+        return block->ssa_sbit_values[RegIndex(variable.sgpr)];
+    }
+    void SetDef(IR::Block* block, ThreadBitScalar variable, const IR::Value& value) {
+        block->ssa_sbit_values[RegIndex(variable.sgpr)] = value;
+    }
+
     const IR::Value& Def(IR::Block* block, SccFlagTag) {
         return scc_flag[block];
     }
@@ -173,7 +189,7 @@ public:
     }
 
     template <typename Type>
-    IR::Value ReadVariable(Type variable, IR::Block* root_block, bool is_thread_bit = false) {
+    IR::Value ReadVariable(Type variable, IR::Block* root_block) {
         boost::container::small_vector<ReadState<Type>, 64> stack{
             ReadState<Type>(nullptr),
             ReadState<Type>(root_block),
@@ -201,7 +217,7 @@ public:
                 } else if (!block->IsSsaSealed()) {
                     // Incomplete CFG
                     IR::Inst* phi{&*block->PrependNewInst(block->begin(), IR::Opcode::Phi)};
-                    phi->SetFlags(is_thread_bit ? IR::Type::U1 : IR::TypeOf(UndefOpcode(variable)));
+                    phi->SetFlags(IR::TypeOf(UndefOpcode(variable)));
 
                     incomplete_phis[block].insert_or_assign(variable, phi);
                     stack.back().result = IR::Value{&*phi};
@@ -214,7 +230,7 @@ public:
                 } else {
                     // Break potential cycles with operandless phi
                     IR::Inst* const phi{&*block->PrependNewInst(block->begin(), IR::Opcode::Phi)};
-                    phi->SetFlags(is_thread_bit ? IR::Type::U1 : IR::TypeOf(UndefOpcode(variable)));
+                    phi->SetFlags(IR::TypeOf(UndefOpcode(variable)));
 
                     WriteVariable(variable, block, IR::Value{phi});
 
@@ -263,9 +279,7 @@ private:
     template <typename Type>
     IR::Value AddPhiOperands(Type variable, IR::Inst& phi, IR::Block* block) {
         for (IR::Block* const imm_pred : block->ImmPredecessors()) {
-            const bool is_thread_bit =
-                std::is_same_v<Type, IR::ScalarReg> && phi.Flags<IR::Type>() == IR::Type::U1;
-            phi.AddPhiOperand(imm_pred, ReadVariable(variable, imm_pred, is_thread_bit));
+            phi.AddPhiOperand(imm_pred, ReadVariable(variable, imm_pred));
         }
         return TryRemoveTrivialPhi(phi, block, UndefOpcode(variable));
     }
@@ -313,7 +327,11 @@ private:
 void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
     const IR::Opcode opcode{inst.GetOpcode()};
     switch (opcode) {
-    case IR::Opcode::SetThreadBitScalarReg:
+    case IR::Opcode::SetThreadBitScalarReg: {
+        const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
+        pass.WriteVariable(ThreadBitScalar{reg}, block, inst.Arg(1));
+        break;
+    }
     case IR::Opcode::SetScalarRegister: {
         const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
         pass.WriteVariable(reg, block, inst.Arg(1));
@@ -345,11 +363,15 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
     case IR::Opcode::SetM0:
         pass.WriteVariable(M0Tag{}, block, inst.Arg(0));
         break;
-    case IR::Opcode::GetThreadBitScalarReg:
+    case IR::Opcode::GetThreadBitScalarReg: {
+        const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
+        const IR::Value value = pass.ReadVariable(ThreadBitScalar{reg}, block);
+        inst.ReplaceUsesWith(value);
+        break;
+    }
     case IR::Opcode::GetScalarRegister: {
         const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
-        const bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg;
-        const IR::Value value = pass.ReadVariable(reg, block, thread_bit);
+        const IR::Value value = pass.ReadVariable(reg, block);
         inst.ReplaceUsesWith(value);
         break;
     }