diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp index 02ac74e19..70e620230 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp @@ -58,4 +58,8 @@ Id EmitUnpackHalf2x16(EmitContext& ctx, Id value) { return ctx.OpUnpackHalf2x16(ctx.F32[2], value); } +Id EmitQuantizeHalf2x16(EmitContext& ctx, Id value) { /* Emits OpQuantizeToF16 on `value` with result type ctx.F32[2] and returns the resulting id. */ + return ctx.OpQuantizeToF16(ctx.F32[2], value); +} + } // namespace Shader::Backend::SPIRV diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h index f0bb9fd7e..ecb63a956 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h +++ b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h @@ -197,6 +197,7 @@ Id EmitPackFloat2x16(EmitContext& ctx, Id value); Id EmitUnpackFloat2x16(EmitContext& ctx, Id value); Id EmitPackHalf2x16(EmitContext& ctx, Id value); Id EmitUnpackHalf2x16(EmitContext& ctx, Id value); +Id EmitQuantizeHalf2x16(EmitContext& ctx, Id value); Id EmitFPAbs16(EmitContext& ctx, Id value); Id EmitFPAbs32(EmitContext& ctx, Id value); Id EmitFPAbs64(EmitContext& ctx, Id value); diff --git a/src/shader_recompiler/ir/ir_emitter.cpp b/src/shader_recompiler/ir/ir_emitter.cpp index 5ac08e7dc..b2d354435 100644 --- a/src/shader_recompiler/ir/ir_emitter.cpp +++ b/src/shader_recompiler/ir/ir_emitter.cpp @@ -795,6 +795,10 @@ Value IREmitter::UnpackHalf2x16(const U32& value) { return Inst(Opcode::UnpackHalf2x16, value); } +Value IREmitter::QuantizeHalf2x16(const Value& value) { /* Creates, via Inst(), a QuantizeHalf2x16 instruction with `value` as its only argument and returns the result. */ + return Inst(Opcode::QuantizeHalf2x16, value); +} + F32F64 IREmitter::FPMul(const F32F64& a, const F32F64& b) { if (a.Type() != b.Type()) { UNREACHABLE_MSG("Mismatching types {} and {}", a.Type(), b.Type()); diff --git a/src/shader_recompiler/ir/ir_emitter.h 
b/src/shader_recompiler/ir/ir_emitter.h index d1dc44d74..8ddd1a013 100644 --- a/src/shader_recompiler/ir/ir_emitter.h +++ b/src/shader_recompiler/ir/ir_emitter.h @@ -175,6 +175,7 @@ public: [[nodiscard]] U32 PackHalf2x16(const Value& vector); [[nodiscard]] Value UnpackHalf2x16(const U32& value); + [[nodiscard]] Value QuantizeHalf2x16(const Value& value); [[nodiscard]] F32F64 FPAdd(const F32F64& a, const F32F64& b); [[nodiscard]] F32F64 FPSub(const F32F64& a, const F32F64& b); diff --git a/src/shader_recompiler/ir/opcodes.inc b/src/shader_recompiler/ir/opcodes.inc index b45151dba..04521a834 100644 --- a/src/shader_recompiler/ir/opcodes.inc +++ b/src/shader_recompiler/ir/opcodes.inc @@ -187,6 +187,7 @@ OPCODE(PackFloat2x16, U32, F16x OPCODE(UnpackFloat2x16, F16x2, U32, ) OPCODE(PackHalf2x16, U32, F32x2, ) OPCODE(UnpackHalf2x16, F32x2, U32, ) +OPCODE(QuantizeHalf2x16, F32x2, F32x2, ) // Floating-point operations OPCODE(FPAbs32, F32, F32, ) diff --git a/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp b/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp index 26d819d8e..022c94be7 100644 --- a/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp @@ -204,6 +204,19 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) { } } +void FoldUnpackHalf2x16(IR::Block& block, IR::Inst& inst) { /* Folds UnpackHalf2x16(PackHalf2x16(x)) into QuantizeHalf2x16(x); any other operand leaves `inst` untouched. */ + const IR::Value value{inst.Arg(0)}; + if (value.IsImmediate()) { /* Immediate operands are skipped: nothing is folded for them. */ + return; + } + IR::Inst* const arg_inst{value.InstRecursive()}; + if (arg_inst->GetOpcode() == IR::Opcode::PackHalf2x16) { + // When an unpack reverses a pack-half instruction, preserve the precision loss by quantizing instead of forwarding the original value. 
+ IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)}; /* Emitter bound to `block` at the position of `inst`. */ + inst.ReplaceUsesWithAndRemove(ir.QuantizeHalf2x16(arg_inst->Arg(0))); + } +} + template void FoldAdd(IR::Block& block, IR::Inst& inst) { if (!FoldCommutative(inst, [](T a, T b) { return a + b; })) { @@ -343,7 +356,7 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) { case IR::Opcode::PackHalf2x16: return FoldInverseFunc(inst, IR::Opcode::UnpackHalf2x16); case IR::Opcode::UnpackHalf2x16: - return FoldInverseFunc(inst, IR::Opcode::PackHalf2x16); + return FoldUnpackHalf2x16(block, inst); case IR::Opcode::PackFloat2x16: return FoldInverseFunc(inst, IR::Opcode::UnpackFloat2x16); case IR::Opcode::UnpackFloat2x16: