shader recompiler: auto cast between u32 and u64 during ssa pass

This commit is contained in:
Vinicius Rangel 2024-07-24 13:05:34 -03:00
parent 21ce67e8a0
commit 09946f15a2
No known key found for this signature in database
GPG Key ID: A5B154D904B761D9
3 changed files with 46 additions and 12 deletions

View File

@ -145,7 +145,7 @@ void IREmitter::SetThreadBitScalarReg(IR::ScalarReg reg, const U1& value) {
// GetScalarReg<U32>: emits a GetScalarRegister instruction typed as U32 for `reg`.
// Each line of this hunk shows the pre-change text followed by the post-change text.
// The post-change version appends Imm32(32) as a second operand; the SSA pass's
// VisitInst reads that operand back (inst.Arg(1).U32()) as the requested bit size.
template <> template <>
U32 IREmitter::GetScalarReg(IR::ScalarReg reg) { U32 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return Inst<U32>(Opcode::GetScalarRegister, reg); return Inst<U32>(Opcode::GetScalarRegister, reg, Imm32(32));
} }
template <> template <>
@ -155,7 +155,7 @@ F32 IREmitter::GetScalarReg(IR::ScalarReg reg) {
// GetScalarReg<U64>: emits a GetScalarRegister instruction typed as U64 for `reg`.
// Each line of this hunk shows the pre-change text followed by the post-change text.
// The post-change version appends Imm32(64) as a second operand; the SSA pass's
// VisitInst reads that operand back (inst.Arg(1).U32()) as the requested bit size.
template <> template <>
U64 IREmitter::GetScalarReg(IR::ScalarReg reg) { U64 IREmitter::GetScalarReg(IR::ScalarReg reg) {
return Inst<U64>(Opcode::GetScalarRegister, reg); return Inst<U64>(Opcode::GetScalarRegister, reg, Imm32(64));
} }
template <> template <>
@ -165,7 +165,7 @@ F64 IREmitter::GetScalarReg(IR::ScalarReg reg) {
// GetVectorReg<U32>: emits a GetVectorRegister instruction typed as U32 for `reg`.
// Each line of this hunk shows the pre-change text followed by the post-change text.
// The post-change version appends Imm32(32) as a second operand; the SSA pass's
// VisitInst reads that operand back (inst.Arg(1).U32()) as the requested bit size.
template <> template <>
U32 IREmitter::GetVectorReg(IR::VectorReg reg) { U32 IREmitter::GetVectorReg(IR::VectorReg reg) {
return Inst<U32>(Opcode::GetVectorRegister, reg); return Inst<U32>(Opcode::GetVectorRegister, reg, Imm32(32));
} }
template <> template <>
@ -175,7 +175,7 @@ F32 IREmitter::GetVectorReg(IR::VectorReg reg) {
// GetVectorReg<U64>: emits a GetVectorRegister instruction typed as U64 for `reg`.
// Each line of this hunk shows the pre-change text followed by the post-change text.
// The post-change version appends Imm32(64) as a second operand; the SSA pass's
// VisitInst reads that operand back (inst.Arg(1).U32()) as the requested bit size.
template <> template <>
U64 IREmitter::GetVectorReg(IR::VectorReg reg) { U64 IREmitter::GetVectorReg(IR::VectorReg reg) {
return Inst<U64>(Opcode::GetVectorRegister, reg); return Inst<U64>(Opcode::GetVectorRegister, reg, Imm32(64));
} }
template <> template <>
@ -1278,6 +1278,13 @@ U16U32U64 IREmitter::UConvert(size_t result_bitsize, const U16U32U64& value) {
default: default:
ThrowInvalidType(value.Type()); ThrowInvalidType(value.Type());
} }
case 32:
switch (value.Type()) {
case Type::U64:
return Inst<U32>(Opcode::ConvertU32U64, value);
default:
ThrowInvalidType(value.Type());
}
case 64: case 64:
switch (value.Type()) { switch (value.Type()) {
case Type::U32: case Type::U32:

View File

@ -43,9 +43,9 @@ OPCODE(WriteSharedU128, Void, U32,
OPCODE(GetUserData, U32, ScalarReg, ) OPCODE(GetUserData, U32, ScalarReg, )
OPCODE(GetThreadBitScalarReg, U1, ScalarReg, ) OPCODE(GetThreadBitScalarReg, U1, ScalarReg, )
OPCODE(SetThreadBitScalarReg, Void, ScalarReg, U1, ) OPCODE(SetThreadBitScalarReg, Void, ScalarReg, U1, )
OPCODE(GetScalarRegister, U32, ScalarReg, ) OPCODE(GetScalarRegister, U32, ScalarReg, U32, )
OPCODE(SetScalarRegister, Void, ScalarReg, U32, ) OPCODE(SetScalarRegister, Void, ScalarReg, U32, )
OPCODE(GetVectorRegister, U32, VectorReg, ) OPCODE(GetVectorRegister, U32, VectorReg, U32, )
OPCODE(SetVectorRegister, Void, VectorReg, U32, ) OPCODE(SetVectorRegister, Void, VectorReg, U32, )
OPCODE(GetGotoVariable, U1, U32, ) OPCODE(GetGotoVariable, U1, U32, )
OPCODE(SetGotoVariable, Void, U32, U1, ) OPCODE(SetGotoVariable, Void, U32, U1, )
@ -291,6 +291,7 @@ OPCODE(ConvertF64U32, F64, U32,
OPCODE(ConvertF32U16, F32, U16, ) OPCODE(ConvertF32U16, F32, U16, )
OPCODE(ConvertU16U32, U16, U32, ) OPCODE(ConvertU16U32, U16, U32, )
OPCODE(ConvertU64U32, U64, U32, ) OPCODE(ConvertU64U32, U64, U32, )
OPCODE(ConvertU32U64, U32, U64, )
OPCODE(ConvertU64F32, U64, F32, ) OPCODE(ConvertU64F32, U64, F32, )
// Image operations // Image operations

View File

@ -310,7 +310,8 @@ private:
DefTable current_def; DefTable current_def;
}; };
void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { void VisitInst(Pass& pass, IR::Block* block, const IR::Block::iterator& iter) {
auto& inst{*iter};
const IR::Opcode opcode{inst.GetOpcode()}; const IR::Opcode opcode{inst.GetOpcode()};
switch (opcode) { switch (opcode) {
case IR::Opcode::SetThreadBitScalarReg: case IR::Opcode::SetThreadBitScalarReg:
@ -348,13 +349,37 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
case IR::Opcode::GetThreadBitScalarReg: case IR::Opcode::GetThreadBitScalarReg:
case IR::Opcode::GetScalarRegister: { case IR::Opcode::GetScalarRegister: {
const IR::ScalarReg reg{inst.Arg(0).ScalarReg()}; const IR::ScalarReg reg{inst.Arg(0).ScalarReg()};
inst.ReplaceUsesWith( bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg;
pass.ReadVariable(reg, block, opcode == IR::Opcode::GetThreadBitScalarReg)); IR::Value value = pass.ReadVariable(reg, block, thread_bit);
if (!thread_bit) {
size_t bit_size{inst.Arg(1).U32()};
if (bit_size == 32 && value.Type() == IR::Type::U64) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})};
value = IR::U32{IR::Value{&*it}};
} else if (bit_size == 64 && value.Type() == IR::Type::U32) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})};
value = IR::U64{IR::Value{&*it}};
}
}
inst.ReplaceUsesWith(value);
break; break;
} }
case IR::Opcode::GetVectorRegister: { case IR::Opcode::GetVectorRegister: {
const IR::VectorReg reg{inst.Arg(0).VectorReg()}; const IR::VectorReg reg{inst.Arg(0).VectorReg()};
inst.ReplaceUsesWith(pass.ReadVariable(reg, block)); IR::Value value = pass.ReadVariable(reg, block);
size_t bit_size{inst.Arg(1).U32()};
if (bit_size == 32 && value.Type() == IR::Type::U64) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU32U64, {value})};
value = IR::U32{IR::Value{&*it}};
} else if (bit_size == 64 && value.Type() == IR::Type::U32) {
auto it{block->PrependNewInst(iter, IR::Opcode::ConvertU64U32, {value})};
value = IR::U64{IR::Value{&*it}};
}
inst.ReplaceUsesWith(value);
break; break;
} }
case IR::Opcode::GetGotoVariable: case IR::Opcode::GetGotoVariable:
@ -384,8 +409,9 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) {
} }
// VisitBlock: runs VisitInst over every instruction of `block`, then calls
// pass.SealBlock(block). Each line of this hunk shows the pre-change text followed
// by the post-change text. The pre-change loop is a range-for over
// block->Instructions(); the post-change loop walks explicit iterators so that
// VisitInst receives the instruction's position, which it forwards to
// block->PrependNewInst when it inserts a ConvertU32U64 / ConvertU64U32 instruction.
void VisitBlock(Pass& pass, IR::Block* block) { void VisitBlock(Pass& pass, IR::Block* block) {
// NOTE(review): post-change, `end` is captured once while VisitInst may insert
// instructions into the block during the loop; this assumes block->end() and `iter`
// stay valid across PrependNewInst — confirm against the Block container type.
for (IR::Inst& inst : block->Instructions()) { const auto end{block->end()};
VisitInst(pass, block, inst); for (auto iter = block->begin(); iter != end; ++iter) {
VisitInst(pass, block, iter);
} }
pass.SealBlock(block); pass.SealBlock(block);
} }