diff --git a/src/shader_recompiler/ir/compute_value/compute.cpp b/src/shader_recompiler/ir/compute_value/compute.cpp index 829de9d72..7ba40d5bb 100644 --- a/src/shader_recompiler/ir/compute_value/compute.cpp +++ b/src/shader_recompiler/ir/compute_value/compute.cpp @@ -11,8 +11,8 @@ namespace Shader::IR { template static void CartesianInvokeImpl(Func func, OutputIt out_it, - std::tuple& arglists_its, - const std::tuple& arglists_tuple) { + std::tuple& arglists_its, + const std::tuple& arglists_tuple) { if constexpr (Level == N) { auto get_tuple = [&](std::index_sequence) { return std::forward_as_tuple(*std::get(arglists_its)...); @@ -23,7 +23,8 @@ static void CartesianInvokeImpl(Func func, OutputIt out_it, const auto& arglist = std::get(arglists_tuple); for (auto it = arglist.begin(); it != arglist.end(); ++it) { std::get(arglists_its) = it; - CartesianInvokeImpl(func, out_it, arglists_its, arglists_tuple); + CartesianInvokeImpl( + func, out_it, arglists_its, arglists_tuple); } } } @@ -34,7 +35,8 @@ static void CartesianInvoke(Func func, OutputIt out_it, const ArgLists&... 
arg_l const std::tuple arglists_tuple = std::forward_as_tuple(arg_lists...); std::tuple arglists_it; - CartesianInvokeImpl(func, out_it, arglists_it, arglists_tuple); + CartesianInvokeImpl(func, out_it, arglists_it, + arglists_tuple); } static void SetSigned(ImmValueList& values, bool is_signed) { @@ -47,35 +49,31 @@ static void OperationAbs(Inst* inst, ImmValueList& inst_values, ComputeImmValues ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.abs(); - }; + const auto op = [](const ImmValue& a) { return a.abs(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } -static void OperationAdd(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationAdd(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a + b; - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a + b; }; SetSigned(args0, is_signed); SetSigned(args1, is_signed); CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } -static void OperationSub(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationSub(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a - b; - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a - b; }; SetSigned(args0, is_signed); SetSigned(args1, is_signed); @@ -95,7 +93,8 @@ static void OperationFma(Inst* inst, ImmValueList& inst_values, 
ComputeImmValues CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2); } -static void OperationMin(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationMin(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1, is_legacy_args; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); @@ -107,18 +106,22 @@ static void OperationMin(Inst* inst, bool is_signed, ImmValueList& inst_values, const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& is_legacy) { if (is_legacy.U1()) { - if (a.isnan()) return b; - if (b.isnan()) return a; + if (a.isnan()) + return b; + if (b.isnan()) + return a; } return std::min(a, b); }; SetSigned(args0, is_signed); SetSigned(args1, is_signed); - CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, is_legacy_args); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, + is_legacy_args); } -static void OperationMax(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationMax(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1, is_legacy_args; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); @@ -130,39 +133,53 @@ static void OperationMax(Inst* inst, bool is_signed, ImmValueList& inst_values, const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& is_legacy) { if (is_legacy.U1()) { - if (a.isnan()) return b; - if (b.isnan()) return a; + if (a.isnan()) + return b; + if (b.isnan()) + return a; } return std::max(a, b); }; SetSigned(args0, is_signed); SetSigned(args1, is_signed); - CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, is_legacy_args); + CartesianInvoke(op, 
std::inserter(inst_values, inst_values.begin()), args0, args1, + is_legacy_args); } -static void OperationMul(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationMul(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a * b; - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a * b; }; SetSigned(args0, is_signed); SetSigned(args1, is_signed); CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } -static void OperationDiv(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationDiv(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a / b; - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a / b; }; + + SetSigned(args0, is_signed); + SetSigned(args1, is_signed); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationMod(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { return a % b; }; SetSigned(args0, is_signed); SetSigned(args1, is_signed); @@ -173,9 +190,7 @@ static void OperationNeg(Inst* inst, ImmValueList& inst_values, ComputeImmValues ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return -a; - }; 
+ const auto op = [](const ImmValue& a) { return -a; }; SetSigned(args, true); std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); @@ -185,20 +200,17 @@ static void OperationRecip(Inst* inst, ImmValueList& inst_values, ComputeImmValu ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.recip(); - }; + const auto op = [](const ImmValue& a) { return a.recip(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } -static void OperationRecipSqrt(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationRecipSqrt(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.rsqrt(); - }; + const auto op = [](const ImmValue& a) { return a.rsqrt(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -207,9 +219,7 @@ static void OperationSqrt(Inst* inst, ImmValueList& inst_values, ComputeImmValue ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.sqrt(); - }; + const auto op = [](const ImmValue& a) { return a.sqrt(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -218,9 +228,7 @@ static void OperationSin(Inst* inst, ImmValueList& inst_values, ComputeImmValues ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.sin(); - }; + const auto op = [](const ImmValue& a) { return a.sin(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -229,9 +237,7 @@ static void OperationExp2(Inst* inst, ImmValueList& inst_values, ComputeImmValue ImmValueList args; ComputeImmValues(inst->Arg(0), 
args, cache); - const auto op = [](const ImmValue& a) { - return a.exp2(); - }; + const auto op = [](const ImmValue& a) { return a.exp2(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -241,9 +247,7 @@ static void OperationLdexp(Inst* inst, ImmValueList& inst_values, ComputeImmValu ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a.ldexp(ImmU32(b)); - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a.ldexp(ImmU32(b)); }; CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } @@ -252,9 +256,7 @@ static void OperationCos(Inst* inst, ImmValueList& inst_values, ComputeImmValues ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.cos(); - }; + const auto op = [](const ImmValue& a) { return a.cos(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -263,14 +265,13 @@ static void OperationLog2(Inst* inst, ImmValueList& inst_values, ComputeImmValue ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.log2(); - }; + const auto op = [](const ImmValue& a) { return a.log2(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } -static void OperationClamp(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationClamp(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1, args2; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); @@ -290,9 +291,7 @@ static void OperationRound(Inst* inst, ImmValueList& inst_values, ComputeImmValu ImmValueList args; ComputeImmValues(inst->Arg(0), args, 
cache); - const auto op = [](const ImmValue& a) { - return a.round(); - }; + const auto op = [](const ImmValue& a) { return a.round(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -301,9 +300,7 @@ static void OperationFloor(Inst* inst, ImmValueList& inst_values, ComputeImmValu ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.floor(); - }; + const auto op = [](const ImmValue& a) { return a.floor(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -312,9 +309,7 @@ static void OperationCeil(Inst* inst, ImmValueList& inst_values, ComputeImmValue ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.ceil(); - }; + const auto op = [](const ImmValue& a) { return a.ceil(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -323,9 +318,7 @@ static void OperationTrunc(Inst* inst, ImmValueList& inst_values, ComputeImmValu ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.trunc(); - }; + const auto op = [](const ImmValue& a) { return a.trunc(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } @@ -334,87 +327,80 @@ static void OperationFract(Inst* inst, ImmValueList& inst_values, ComputeImmValu ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return a.fract(); - }; + const auto op = [](const ImmValue& a) { return a.fract(); }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } -static void OperationShiftLeft(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationShiftLeft(Inst* inst, ImmValueList& inst_values, + 
ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a << ImmU32(b); - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a << ImmU32(b); }; SetSigned(args1, false); CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } -static void OperationShiftRight(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationShiftRight(Inst* inst, bool is_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a >> ImmU32(b); - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a >> ImmU32(b); }; SetSigned(args0, is_signed); SetSigned(args1, false); CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } -static void OperationBitwiseNot(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationBitwiseNot(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); - const auto op = [](const ImmValue& a) { - return ~a; - }; + const auto op = [](const ImmValue& a) { return ~a; }; std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } -static void OperationBitwiseAnd(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationBitwiseAnd(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - 
return a & b; - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a & b; }; CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } -static void OperationBitwiseOr(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationBitwiseOr(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a | b; - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a | b; }; CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } -static void OperationBitwiseXor(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationBitwiseXor(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a ^ b; - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a ^ b; }; CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } -static void OperationConvert(Inst* inst, bool is_signed, Type new_type, bool new_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationConvert(Inst* inst, bool is_signed, Type new_type, bool new_signed, + ImmValueList& inst_values, ComputeImmValuesCache& cache) { ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); @@ -426,7 +412,8 @@ static void OperationConvert(Inst* inst, bool is_signed, Type new_type, bool new std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } -static void OperationBitCast(Inst* inst, Type new_type, bool new_signed, ImmValueList& inst_values, 
ComputeImmValuesCache& cache) { +static void OperationBitCast(Inst* inst, Type new_type, bool new_signed, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args; ComputeImmValues(inst->Arg(0), args, cache); @@ -437,16 +424,15 @@ static void OperationBitCast(Inst* inst, Type new_type, bool new_signed, ImmValu std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); } -template -static void OperationCompositeConstruct(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +template +static void OperationCompositeConstruct(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { std::array args; for (size_t i = 0; i < N; ++i) { ComputeImmValues(inst->Arg(i), args[i], cache); } - const auto op = [](const Args&... args) { - return ImmValue(args...); - }; + const auto op = [](const Args&... args) { return ImmValue(args...); }; const auto call_cartesian = [&](std::index_sequence) { CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args[I]...); @@ -454,14 +440,13 @@ static void OperationCompositeConstruct(Inst* inst, ImmValueList& inst_values, C call_cartesian(std::make_index_sequence{}); } -static void OperationCompositeExtract(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void OperationCompositeExtract(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { ImmValueList args0, args1; ComputeImmValues(inst->Arg(0), args0, cache); ComputeImmValues(inst->Arg(1), args1, cache); - const auto op = [](const ImmValue& a, const ImmValue& b) { - return a.Extract(ImmU32(b)); - }; + const auto op = [](const ImmValue& a, const ImmValue& b) { return a.Extract(ImmU32(b)); }; SetSigned(args1, false); CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); @@ -481,176 +466,274 @@ static void OperationInsert(Inst* inst, ImmValueList& inst_values, ComputeImmVal CartesianInvoke(op, 
std::inserter(inst_values, inst_values.begin()), args0, args1, args2); } -static void DoInstructionOperation(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { +static void DoInstructionOperation(Inst* inst, ImmValueList& inst_values, + ComputeImmValuesCache& cache) { switch (inst->GetOpcode()) { - case Opcode::CompositeConstructU32x2: - case Opcode::CompositeConstructU32x2x2: - case Opcode::CompositeConstructF16x2: - case Opcode::CompositeConstructF32x2: - case Opcode::CompositeConstructF32x2x2: - case Opcode::CompositeConstructF64x2: - OperationCompositeConstruct<2>(inst, inst_values, cache); - break; - case Opcode::CompositeConstructU32x3: - case Opcode::CompositeConstructF16x3: - case Opcode::CompositeConstructF32x3: - case Opcode::CompositeConstructF64x3: - OperationCompositeConstruct<3>(inst, inst_values, cache); - break; - case Opcode::CompositeConstructU32x4: - case Opcode::CompositeConstructF16x4: - case Opcode::CompositeConstructF32x4: - case Opcode::CompositeConstructF64x4: - OperationCompositeConstruct<4>(inst, inst_values, cache); - break; - case Opcode::CompositeExtractU32x2: - case Opcode::CompositeExtractU32x3: - case Opcode::CompositeExtractU32x4: - case Opcode::CompositeExtractF16x2: - case Opcode::CompositeExtractF16x3: - case Opcode::CompositeExtractF16x4: - case Opcode::CompositeExtractF32x2: - case Opcode::CompositeExtractF32x3: - case Opcode::CompositeExtractF32x4: - case Opcode::CompositeExtractF64x2: - case Opcode::CompositeExtractF64x3: - case Opcode::CompositeExtractF64x4: - OperationCompositeExtract(inst, inst_values, cache); - break; - case Opcode::CompositeInsertU32x2: - case Opcode::CompositeInsertU32x3: - case Opcode::CompositeInsertU32x4: - case Opcode::CompositeInsertF16x2: - case Opcode::CompositeInsertF16x3: - case Opcode::CompositeInsertF16x4: - case Opcode::CompositeInsertF32x2: - case Opcode::CompositeInsertF32x3: - case Opcode::CompositeInsertF32x4: - case Opcode::CompositeInsertF64x2: - case 
Opcode::CompositeInsertF64x3: - case Opcode::CompositeInsertF64x4: - OperationInsert(inst, inst_values, cache); - break; - case Opcode::BitCastU16F16: - OperationBitCast(inst, IR::Type::U16, false, inst_values, cache); - break; - case Opcode::BitCastU32F32: - OperationBitCast(inst, IR::Type::U32, false, inst_values, cache); - break; - case Opcode::BitCastU64F64: - OperationBitCast(inst, IR::Type::U64, false, inst_values, cache); - break; - case Opcode::BitCastF16U16: - OperationBitCast(inst, IR::Type::F16, true, inst_values, cache); - break; - case Opcode::BitCastF32U32: - OperationBitCast(inst, IR::Type::F32, true, inst_values, cache); - break; - case Opcode::BitCastF64U64: - OperationBitCast(inst, IR::Type::F64, true, inst_values, cache); - break; - case Opcode::FPAbs32: - case Opcode::FPAbs64: - OperationAbs(inst, inst_values, cache); - break; - case Opcode::FPAdd32: - case Opcode::FPAdd64: - OperationAdd(inst, false, inst_values, cache); - break; - case Opcode::FPSub32: - OperationSub(inst, false, inst_values, cache); - break; - case Opcode::FPMul32: - case Opcode::FPMul64: - OperationMul(inst, false, inst_values, cache); - break; - case Opcode::FPDiv32: - case Opcode::FPDiv64: - OperationDiv(inst, false, inst_values, cache); - break; - case Opcode::FPFma32: - case Opcode::FPFma64: - OperationFma(inst, inst_values, cache); - break; - case Opcode::FPMin32: - case Opcode::FPMin64: - OperationMin(inst, false, inst_values, cache); - break; - case Opcode::FPMax32: - case Opcode::FPMax64: - OperationMax(inst, false, inst_values, cache); - break; - case Opcode::FPNeg32: - case Opcode::FPNeg64: - OperationNeg(inst, inst_values, cache); - break; - case Opcode::FPRecip32: - case Opcode::FPRecip64: - OperationRecip(inst, inst_values, cache); - break; - case Opcode::FPRecipSqrt32: - case Opcode::FPRecipSqrt64: - OperationRecipSqrt(inst, inst_values, cache); - break; - case Opcode::FPSqrt: - OperationSqrt(inst, inst_values, cache); - break; - case Opcode::FPSin: - 
OperationSin(inst, inst_values, cache); - break; - case Opcode::FPCos: - OperationCos(inst, inst_values, cache); - break; - case Opcode::FPExp2: - OperationExp2(inst, inst_values, cache); - break; - case Opcode::FPLdexp: - OperationLdexp(inst, inst_values, cache); - break; - case Opcode::FPLog2: - OperationLog2(inst, inst_values, cache); - break; - case Opcode::FPClamp32: - case Opcode::FPClamp64: - OperationClamp(inst, false, inst_values, cache); - break; - case Opcode::FPRoundEven32: - case Opcode::FPRoundEven64: - OperationRound(inst, inst_values, cache); - break; - case Opcode::FPFloor32: - case Opcode::FPFloor64: - OperationFloor(inst, inst_values, cache); - break; - case Opcode::FPCeil32: - case Opcode::FPCeil64: - OperationCeil(inst, inst_values, cache); - break; - case Opcode::FPTrunc32: - case Opcode::FPTrunc64: - OperationTrunc(inst, inst_values, cache); - break; - case Opcode::FPFract32: - case Opcode::FPFract64: - OperationFract(inst, inst_values, cache); - break; - default: - break; + case Opcode::CompositeConstructU32x2: + case Opcode::CompositeConstructU32x2x2: + case Opcode::CompositeConstructF16x2: + case Opcode::CompositeConstructF32x2: + case Opcode::CompositeConstructF32x2x2: + case Opcode::CompositeConstructF64x2: + OperationCompositeConstruct<2>(inst, inst_values, cache); + break; + case Opcode::CompositeConstructU32x3: + case Opcode::CompositeConstructF16x3: + case Opcode::CompositeConstructF32x3: + case Opcode::CompositeConstructF64x3: + OperationCompositeConstruct<3>(inst, inst_values, cache); + break; + case Opcode::CompositeConstructU32x4: + case Opcode::CompositeConstructF16x4: + case Opcode::CompositeConstructF32x4: + case Opcode::CompositeConstructF64x4: + OperationCompositeConstruct<4>(inst, inst_values, cache); + break; + case Opcode::CompositeExtractU32x2: + case Opcode::CompositeExtractU32x3: + case Opcode::CompositeExtractU32x4: + case Opcode::CompositeExtractF16x2: + case Opcode::CompositeExtractF16x3: + case 
Opcode::CompositeExtractF16x4: + case Opcode::CompositeExtractF32x2: + case Opcode::CompositeExtractF32x3: + case Opcode::CompositeExtractF32x4: + case Opcode::CompositeExtractF64x2: + case Opcode::CompositeExtractF64x3: + case Opcode::CompositeExtractF64x4: + OperationCompositeExtract(inst, inst_values, cache); + break; + case Opcode::CompositeInsertU32x2: + case Opcode::CompositeInsertU32x3: + case Opcode::CompositeInsertU32x4: + case Opcode::CompositeInsertF16x2: + case Opcode::CompositeInsertF16x3: + case Opcode::CompositeInsertF16x4: + case Opcode::CompositeInsertF32x2: + case Opcode::CompositeInsertF32x3: + case Opcode::CompositeInsertF32x4: + case Opcode::CompositeInsertF64x2: + case Opcode::CompositeInsertF64x3: + case Opcode::CompositeInsertF64x4: + OperationInsert(inst, inst_values, cache); + break; + case Opcode::BitCastU16F16: + OperationBitCast(inst, IR::Type::U16, false, inst_values, cache); + break; + case Opcode::BitCastU32F32: + OperationBitCast(inst, IR::Type::U32, false, inst_values, cache); + break; + case Opcode::BitCastU64F64: + OperationBitCast(inst, IR::Type::U64, false, inst_values, cache); + break; + case Opcode::BitCastF16U16: + OperationBitCast(inst, IR::Type::F16, true, inst_values, cache); + break; + case Opcode::BitCastF32U32: + OperationBitCast(inst, IR::Type::F32, true, inst_values, cache); + break; + case Opcode::BitCastF64U64: + OperationBitCast(inst, IR::Type::F64, true, inst_values, cache); + break; + case Opcode::FPAbs32: + case Opcode::FPAbs64: + case Opcode::IAbs32: + OperationAbs(inst, inst_values, cache); + break; + case Opcode::FPAdd32: + case Opcode::FPAdd64: + OperationAdd(inst, true, inst_values, cache); + break; + case Opcode::IAdd32: + case Opcode::IAdd64: + OperationAdd(inst, false, inst_values, cache); + break; + case Opcode::FPSub32: + OperationSub(inst, true, inst_values, cache); + break; + case Opcode::ISub32: + case Opcode::ISub64: + OperationSub(inst, false, inst_values, cache); + break; + case Opcode::FPMul32: 
+ case Opcode::FPMul64: + OperationMul(inst, true, inst_values, cache); + break; + case Opcode::IMul32: + case Opcode::IMul64: + OperationMul(inst, false, inst_values, cache); + break; + case Opcode::FPDiv32: + case Opcode::FPDiv64: + case Opcode::SDiv32: + OperationDiv(inst, true, inst_values, cache); + break; + case Opcode::UDiv32: + OperationDiv(inst, false, inst_values, cache); + break; + case Opcode::SMod32: + OperationMod(inst, true, inst_values, cache); + break; + case Opcode::UMod32: + OperationMod(inst, false, inst_values, cache); + break; + case Opcode::INeg32: + case Opcode::INeg64: + OperationNeg(inst, inst_values, cache); + break; + case Opcode::FPFma32: + case Opcode::FPFma64: + OperationFma(inst, inst_values, cache); + break; + case Opcode::FPMin32: + case Opcode::FPMin64: + case Opcode::SMin32: + OperationMin(inst, true, inst_values, cache); + break; + case Opcode::UMin32: + OperationMin(inst, false, inst_values, cache); + break; + case Opcode::FPMax32: + case Opcode::FPMax64: + case Opcode::SMax32: + OperationMax(inst, true, inst_values, cache); + break; + case Opcode::UMax32: + OperationMax(inst, false, inst_values, cache); + break; + case Opcode::FPNeg32: + case Opcode::FPNeg64: + OperationNeg(inst, inst_values, cache); + break; + case Opcode::FPRecip32: + case Opcode::FPRecip64: + OperationRecip(inst, inst_values, cache); + break; + case Opcode::FPRecipSqrt32: + case Opcode::FPRecipSqrt64: + OperationRecipSqrt(inst, inst_values, cache); + break; + case Opcode::FPSqrt: + OperationSqrt(inst, inst_values, cache); + break; + case Opcode::FPSin: + OperationSin(inst, inst_values, cache); + break; + case Opcode::FPCos: + OperationCos(inst, inst_values, cache); + break; + case Opcode::FPExp2: + OperationExp2(inst, inst_values, cache); + break; + case Opcode::FPLdexp: + OperationLdexp(inst, inst_values, cache); + break; + case Opcode::FPLog2: + OperationLog2(inst, inst_values, cache); + break; + case Opcode::FPClamp32: + case Opcode::FPClamp64: + case 
Opcode::SClamp32: + OperationClamp(inst, true, inst_values, cache); + break; + case Opcode::UClamp32: + OperationClamp(inst, false, inst_values, cache); + break; + case Opcode::FPRoundEven32: + case Opcode::FPRoundEven64: + OperationRound(inst, inst_values, cache); + break; + case Opcode::FPFloor32: + case Opcode::FPFloor64: + OperationFloor(inst, inst_values, cache); + break; + case Opcode::FPCeil32: + case Opcode::FPCeil64: + OperationCeil(inst, inst_values, cache); + break; + case Opcode::FPTrunc32: + case Opcode::FPTrunc64: + OperationTrunc(inst, inst_values, cache); + break; + case Opcode::FPFract32: + case Opcode::FPFract64: + OperationFract(inst, inst_values, cache); + break; + case Opcode::ShiftLeftLogical32: + case Opcode::ShiftLeftLogical64: + OperationShiftLeft(inst, inst_values, cache); + break; + case Opcode::ShiftRightLogical32: + case Opcode::ShiftRightLogical64: + OperationShiftRight(inst, false, inst_values, cache); + break; + case Opcode::ShiftRightArithmetic32: + case Opcode::ShiftRightArithmetic64: + OperationShiftRight(inst, true, inst_values, cache); + break; + case Opcode::BitwiseAnd32: + case Opcode::BitwiseAnd64: + case Opcode::LogicalAnd: + OperationBitwiseAnd(inst, inst_values, cache); + break; + case Opcode::BitwiseOr32: + case Opcode::BitwiseOr64: + case Opcode::LogicalOr: + OperationBitwiseOr(inst, inst_values, cache); + break; + case Opcode::BitwiseXor32: + case Opcode::LogicalXor: + OperationBitwiseXor(inst, inst_values, cache); + break; + case Opcode::BitwiseNot32: + case Opcode::LogicalNot: + OperationBitwiseNot(inst, inst_values, cache); + break; + case Opcode::ConvertU16U32: + OperationConvert(inst, false, Type::U16, false, inst_values, cache); + break; + case Opcode::ConvertS32F32: + case Opcode::ConvertS32F64: + OperationConvert(inst, true, Type::U32, true, inst_values, cache); + break; + case Opcode::ConvertU32F32: + OperationConvert(inst, true, Type::U32, false, inst_values, cache); + break; + case Opcode::ConvertU32U16: + 
OperationConvert(inst, false, Type::U32, false, inst_values, cache); + break; + case Opcode::ConvertF32F16: + case Opcode::ConvertF32F64: + case Opcode::ConvertF32S32: + OperationConvert(inst, true, Type::F32, true, inst_values, cache); + break; + case Opcode::ConvertF32U32: + OperationConvert(inst, false, Type::F32, true, inst_values, cache); + break; + case Opcode::ConvertF64F32: + case Opcode::ConvertF64S32: + OperationConvert(inst, true, Type::F64, true, inst_values, cache); + break; + case Opcode::ConvertF64U32: + OperationConvert(inst, false, Type::F64, true, inst_values, cache); + break; + default: + break; } } static bool IsSelectInst(Inst* inst) { switch (inst->GetOpcode()) { - case Opcode::SelectU1: - case Opcode::SelectU8: - case Opcode::SelectU16: - case Opcode::SelectU32: - case Opcode::SelectU64: - case Opcode::SelectF32: - case Opcode::SelectF64: - return true; - default: - return false; + case Opcode::SelectU1: + case Opcode::SelectU8: + case Opcode::SelectU16: + case Opcode::SelectU32: + case Opcode::SelectU64: + case Opcode::SelectF32: + case Opcode::SelectF64: + return true; + default: + return false; } } @@ -674,7 +757,8 @@ void ComputeImmValues(const Value& value, ImmValueList& values, ComputeImmValues for (size_t i = 0; i < inst->NumArgs(); ++i) { ComputeImmValues(inst->Arg(i), inst_values, cache); } - } if (IsSelectInst(inst)) { + } + if (IsSelectInst(inst)) { ComputeImmValues(inst->Arg(1), inst_values, cache); ComputeImmValues(inst->Arg(2), inst_values, cache); } else { diff --git a/src/shader_recompiler/ir/compute_value/compute.h b/src/shader_recompiler/ir/compute_value/compute.h index fbfe46575..8b6e7b86b 100644 --- a/src/shader_recompiler/ir/compute_value/compute.h +++ b/src/shader_recompiler/ir/compute_value/compute.h @@ -3,11 +3,15 @@ #pragma once -#include #include +#include #include "shader_recompiler/ir/compute_value/imm_value.h" #include "shader_recompiler/ir/value.h" +// Given a value (immediate or not), compute all the possible 
immediate values +// that it can represent. If the value can't be computed statically, the list will +// be empty. + namespace Shader::IR { using ImmValueList = boost::container::flat_set; diff --git a/src/shader_recompiler/ir/compute_value/imm_value.cpp b/src/shader_recompiler/ir/compute_value/imm_value.cpp index 068069d2e..e94533e57 100644 --- a/src/shader_recompiler/ir/compute_value/imm_value.cpp +++ b/src/shader_recompiler/ir/compute_value/imm_value.cpp @@ -180,17 +180,20 @@ ImmValue::ImmValue(const ImmValue& value1, const ImmValue& value2) noexcept ImmValue::ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3) noexcept : type{value1.type}, is_signed{value1.is_signed} { - ASSERT(value1.type == value2.type && value1.type == value3.type && value1.is_signed == value2.is_signed && - value1.is_signed == value3.is_signed && value1.Dimensions() == 1); + ASSERT(value1.type == value2.type && value1.type == value3.type && + value1.is_signed == value2.is_signed && value1.is_signed == value3.is_signed && + value1.Dimensions() == 1); imm_values[0] = value1.imm_values[0]; imm_values[1] = value2.imm_values[0]; imm_values[2] = value3.imm_values[0]; } -ImmValue::ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3, const ImmValue& value4) noexcept +ImmValue::ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3, + const ImmValue& value4) noexcept : type{value1.type}, is_signed{value1.is_signed} { - ASSERT(value1.type == value2.type && value1.type == value3.type && value1.type == value4.type && value1.is_signed == value2.is_signed && - value1.is_signed == value3.is_signed && value1.is_signed == value4.is_signed && value1.Dimensions() == 1); + ASSERT(value1.type == value2.type && value1.type == value3.type && value1.type == value4.type && + value1.is_signed == value2.is_signed && value1.is_signed == value3.is_signed && + value1.is_signed == value4.is_signed && value1.Dimensions() == 1); imm_values[0] = 
value1.imm_values[0]; imm_values[1] = value2.imm_values[0]; imm_values[2] = value3.imm_values[0]; @@ -280,12 +283,12 @@ ImmValue ImmValue::Convert(IR::Type new_type, bool new_signed) const noexcept { case Type::U32: { if (new_signed) { switch (type) { - case Type::F32: - return ImmValue(static_cast(imm_values[0].imm_f32)); - case Type::F64: - return ImmValue(static_cast(imm_values[0].imm_f64)); - default: - break; + case Type::F32: + return ImmValue(static_cast(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(static_cast(imm_values[0].imm_f64)); + default: + break; } } else { switch (type) { @@ -884,21 +887,17 @@ ImmValue ImmValue::operator<<(const ImmU32& other) const noexcept { case Type::U1: return ImmValue(imm_values[0].imm_u1 << other.imm_values[0].imm_u1); case Type::U8: - return is_signed - ? ImmValue(imm_values[0].imm_s8 << other.imm_values[0].imm_s8) - : ImmValue(imm_values[0].imm_u8 << other.imm_values[0].imm_u8); + return is_signed ? ImmValue(imm_values[0].imm_s8 << other.imm_values[0].imm_s8) + : ImmValue(imm_values[0].imm_u8 << other.imm_values[0].imm_u8); case Type::U16: - return is_signed - ? ImmValue(imm_values[0].imm_s16 << other.imm_values[0].imm_s16) - : ImmValue(imm_values[0].imm_u16 << other.imm_values[0].imm_u16); + return is_signed ? ImmValue(imm_values[0].imm_s16 << other.imm_values[0].imm_s16) + : ImmValue(imm_values[0].imm_u16 << other.imm_values[0].imm_u16); case Type::U32: - return is_signed - ? ImmValue(imm_values[0].imm_s32 << other.imm_values[0].imm_s32) - : ImmValue(imm_values[0].imm_u32 << other.imm_values[0].imm_u32); + return is_signed ? ImmValue(imm_values[0].imm_s32 << other.imm_values[0].imm_s32) + : ImmValue(imm_values[0].imm_u32 << other.imm_values[0].imm_u32); case Type::U64: - return is_signed - ? ImmValue(imm_values[0].imm_s64 << other.imm_values[0].imm_s64) - : ImmValue(imm_values[0].imm_u64 << other.imm_values[0].imm_u64); + return is_signed ? 
ImmValue(imm_values[0].imm_s64 << other.imm_values[0].imm_s64) + : ImmValue(imm_values[0].imm_u64 << other.imm_values[0].imm_u64); default: UNREACHABLE_MSG("Invalid type {}", type); } @@ -909,21 +908,17 @@ ImmValue ImmValue::operator>>(const ImmU32& other) const noexcept { case Type::U1: return ImmValue(imm_values[0].imm_u1 >> other.imm_values[0].imm_u1); case Type::U8: - return is_signed - ? ImmValue(imm_values[0].imm_s8 >> other.imm_values[0].imm_s8) - : ImmValue(imm_values[0].imm_u8 >> other.imm_values[0].imm_u8); + return is_signed ? ImmValue(imm_values[0].imm_s8 >> other.imm_values[0].imm_s8) + : ImmValue(imm_values[0].imm_u8 >> other.imm_values[0].imm_u8); case Type::U16: - return is_signed - ? ImmValue(imm_values[0].imm_s16 >> other.imm_values[0].imm_s16) - : ImmValue(imm_values[0].imm_u16 >> other.imm_values[0].imm_u16); + return is_signed ? ImmValue(imm_values[0].imm_s16 >> other.imm_values[0].imm_s16) + : ImmValue(imm_values[0].imm_u16 >> other.imm_values[0].imm_u16); case Type::U32: - return is_signed - ? ImmValue(imm_values[0].imm_s32 >> other.imm_values[0].imm_s32) - : ImmValue(imm_values[0].imm_u32 >> other.imm_values[0].imm_u32); + return is_signed ? ImmValue(imm_values[0].imm_s32 >> other.imm_values[0].imm_s32) + : ImmValue(imm_values[0].imm_u32 >> other.imm_values[0].imm_u32); case Type::U64: - return is_signed - ? ImmValue(imm_values[0].imm_s64 >> other.imm_values[0].imm_s64) - : ImmValue(imm_values[0].imm_u64 >> other.imm_values[0].imm_u64); + return is_signed ? 
ImmValue(imm_values[0].imm_s64 >> other.imm_values[0].imm_s64) + : ImmValue(imm_values[0].imm_u64 >> other.imm_values[0].imm_u64); default: UNREACHABLE_MSG("Invalid type {}", type); } diff --git a/src/shader_recompiler/ir/compute_value/imm_value.h b/src/shader_recompiler/ir/compute_value/imm_value.h index 9ffefc382..74b9d39b7 100644 --- a/src/shader_recompiler/ir/compute_value/imm_value.h +++ b/src/shader_recompiler/ir/compute_value/imm_value.h @@ -62,10 +62,10 @@ using ImmS32xAny = TypedImmValue; using ImmF64xAny = TypedImmValue; using ImmS32F32xAny = TypedImmValue; using ImmF32F64xAny = TypedImmValue; class ImmValue { @@ -98,7 +98,8 @@ public: ImmValue(f64 value1, f64 value2, f64 value3, f64 value4) noexcept; ImmValue(const ImmValue& value1, const ImmValue& value2) noexcept; ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3) noexcept; - ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3, const ImmValue& value4) noexcept; + ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3, + const ImmValue& value4) noexcept; [[nodiscard]] bool IsEmpty() const noexcept; [[nodiscard]] IR::Type Type() const noexcept; @@ -196,10 +197,12 @@ public: [[nodiscard]] ImmValue trunc() const noexcept; [[nodiscard]] ImmValue fract() const noexcept; [[nodiscard]] bool isnan() const noexcept; - - [[nodiscard]] static ImmValue fma(const ImmF32F64& a, const ImmF32F64& b, const ImmF32F64& c) noexcept; - + + [[nodiscard]] static ImmValue fma(const ImmF32F64& a, const ImmF32F64& b, + const ImmF32F64& c) noexcept; + static bool IsSupportedValue(const IR::Value& value) noexcept; + private: union Value { bool imm_u1;