From 09946f15a270074a436655869150389dff9876fb Mon Sep 17 00:00:00 2001 From: Vinicius Rangel Date: Wed, 24 Jul 2024 13:05:34 -0300 Subject: [PATCH] shader recompiler: auto cast between u32 and u64 during ssa pass --- src/shader_recompiler/ir/ir_emitter.cpp | 15 ++++++-- src/shader_recompiler/ir/opcodes.inc | 5 ++- .../ir/passes/ssa_rewrite_pass.cpp | 38 ++++++++++++++++--- 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/src/shader_recompiler/ir/ir_emitter.cpp b/src/shader_recompiler/ir/ir_emitter.cpp index 5cc96c800..e60d111f0 100644 --- a/src/shader_recompiler/ir/ir_emitter.cpp +++ b/src/shader_recompiler/ir/ir_emitter.cpp @@ -145,7 +145,7 @@ void IREmitter::SetThreadBitScalarReg(IR::ScalarReg reg, const U1& value) { template <> U32 IREmitter::GetScalarReg(IR::ScalarReg reg) { - return Inst(Opcode::GetScalarRegister, reg); + return Inst(Opcode::GetScalarRegister, reg, Imm32(32)); } template <> @@ -155,7 +155,7 @@ F32 IREmitter::GetScalarReg(IR::ScalarReg reg) { template <> U64 IREmitter::GetScalarReg(IR::ScalarReg reg) { - return Inst(Opcode::GetScalarRegister, reg); + return Inst(Opcode::GetScalarRegister, reg, Imm32(64)); } template <> @@ -165,7 +165,7 @@ F64 IREmitter::GetScalarReg(IR::ScalarReg reg) { template <> U32 IREmitter::GetVectorReg(IR::VectorReg reg) { - return Inst(Opcode::GetVectorRegister, reg); + return Inst(Opcode::GetVectorRegister, reg, Imm32(32)); } template <> @@ -175,7 +175,7 @@ F32 IREmitter::GetVectorReg(IR::VectorReg reg) { template <> U64 IREmitter::GetVectorReg(IR::VectorReg reg) { - return Inst(Opcode::GetVectorRegister, reg); + return Inst(Opcode::GetVectorRegister, reg, Imm32(64)); } template <> @@ -1278,6 +1278,13 @@ U16U32U64 IREmitter::UConvert(size_t result_bitsize, const U16U32U64& value) { default: ThrowInvalidType(value.Type()); } + case 32: + switch (value.Type()) { + case Type::U64: + return Inst(Opcode::ConvertU32U64, value); + default: + ThrowInvalidType(value.Type()); + } case 64: switch (value.Type()) { case Type::U32: diff --git a/src/shader_recompiler/ir/opcodes.inc b/src/shader_recompiler/ir/opcodes.inc index f5b0ff362..11d146f72 100644 --- a/src/shader_recompiler/ir/opcodes.inc +++ b/src/shader_recompiler/ir/opcodes.inc @@ -43,9 +43,9 @@ OPCODE(WriteSharedU128, Void, U32, OPCODE(GetUserData, U32, ScalarReg, ) OPCODE(GetThreadBitScalarReg, U1, ScalarReg, ) OPCODE(SetThreadBitScalarReg, Void, ScalarReg, U1, ) -OPCODE(GetScalarRegister, U32, ScalarReg, ) +OPCODE(GetScalarRegister, U32, ScalarReg, U32, ) OPCODE(SetScalarRegister, Void, ScalarReg, U32, ) -OPCODE(GetVectorRegister, U32, VectorReg, ) +OPCODE(GetVectorRegister, U32, VectorReg, U32, ) OPCODE(SetVectorRegister, Void, VectorReg, U32, ) OPCODE(GetGotoVariable, U1, U32, ) OPCODE(SetGotoVariable, Void, U32, U1, ) @@ -291,6 +291,7 @@ OPCODE(ConvertF64U32, F64, U32, OPCODE(ConvertF32U16, F32, U16, ) OPCODE(ConvertU16U32, U16, U32, ) OPCODE(ConvertU64U32, U64, U32, ) +OPCODE(ConvertU32U64, U32, U64, ) OPCODE(ConvertU64F32, U64, F32, ) // Image operations diff --git a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp index 6a43ad6be..6a686fadb 100644 --- a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp +++ b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp @@ -310,7 +310,8 @@ private: DefTable current_def; }; -void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { +void VisitInst(Pass& pass, IR::Block* block, const IR::Block::iterator& iter) { + auto& inst{*iter}; const IR::Opcode opcode{inst.GetOpcode()}; switch (opcode) { case IR::Opcode::SetThreadBitScalarReg: @@ -348,13 +349,37 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { case IR::Opcode::GetThreadBitScalarReg: case IR::Opcode::GetScalarRegister: { const IR::ScalarReg reg{inst.Arg(0).ScalarReg()}; - inst.ReplaceUsesWith( - pass.ReadVariable(reg, block, opcode == IR::Opcode::GetThreadBitScalarReg)); + bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg; + IR::Value value = pass.ReadVariable(reg, block, thread_bit); + + if (!thread_bit) { + size_t bit_size{inst.Arg(1).U32()}; + if (bit_size == 32 && value.Type() == IR::Type::U64) { + auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})}; + value = IR::U32{IR::Value{&*it}}; + } else if (bit_size == 64 && value.Type() == IR::Type::U32) { + auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})}; + value = IR::U64{IR::Value{&*it}}; + } + } + + inst.ReplaceUsesWith(value); break; } case IR::Opcode::GetVectorRegister: { const IR::VectorReg reg{inst.Arg(0).VectorReg()}; - inst.ReplaceUsesWith(pass.ReadVariable(reg, block)); + IR::Value value = pass.ReadVariable(reg, block); + + size_t bit_size{inst.Arg(1).U32()}; + if (bit_size == 32 && value.Type() == IR::Type::U64) { + auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})}; + value = IR::U32{IR::Value{&*it}}; + } else if (bit_size == 64 && value.Type() == IR::Type::U32) { + auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})}; + value = IR::U64{IR::Value{&*it}}; + } + + inst.ReplaceUsesWith(value); break; } case IR::Opcode::GetGotoVariable: @@ -384,8 +409,9 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { } void VisitBlock(Pass& pass, IR::Block* block) { - for (IR::Inst& inst : block->Instructions()) { - VisitInst(pass, block, inst); + const auto end{block->end()}; + for (auto iter = block->begin(); iter != end; ++iter) { + VisitInst(pass, block, iter); } pass.SealBlock(block); }