diff --git a/src/shader_recompiler/ir/compute_value/compute.cpp b/src/shader_recompiler/ir/compute_value/compute.cpp index 0f1707981..829de9d72 100644 --- a/src/shader_recompiler/ir/compute_value/compute.cpp +++ b/src/shader_recompiler/ir/compute_value/compute.cpp @@ -467,13 +467,193 @@ static void OperationCompositeExtract(Inst* inst, ImmValueList& inst_values, Com CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); } +static void OperationInsert(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1, args2; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + ComputeImmValues(inst->Arg(2), args2, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& c) { + return a.Insert(b, ImmU32(c)); + }; + + SetSigned(args2, false); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2); +} + static void DoInstructionOperation(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { switch (inst->GetOpcode()) { + case Opcode::CompositeConstructU32x2: + case Opcode::CompositeConstructU32x2x2: + case Opcode::CompositeConstructF16x2: + case Opcode::CompositeConstructF32x2: + case Opcode::CompositeConstructF32x2x2: + case Opcode::CompositeConstructF64x2: + OperationCompositeConstruct<2>(inst, inst_values, cache); + break; + case Opcode::CompositeConstructU32x3: + case Opcode::CompositeConstructF16x3: + case Opcode::CompositeConstructF32x3: + case Opcode::CompositeConstructF64x3: + OperationCompositeConstruct<3>(inst, inst_values, cache); + break; + case Opcode::CompositeConstructU32x4: + case Opcode::CompositeConstructF16x4: + case Opcode::CompositeConstructF32x4: + case Opcode::CompositeConstructF64x4: + OperationCompositeConstruct<4>(inst, inst_values, cache); + break; + case Opcode::CompositeExtractU32x2: + case Opcode::CompositeExtractU32x3: + case Opcode::CompositeExtractU32x4: + case Opcode::CompositeExtractF16x2: + case Opcode::CompositeExtractF16x3: + case Opcode::CompositeExtractF16x4: + case Opcode::CompositeExtractF32x2: + case Opcode::CompositeExtractF32x3: + case Opcode::CompositeExtractF32x4: + case Opcode::CompositeExtractF64x2: + case Opcode::CompositeExtractF64x3: + case Opcode::CompositeExtractF64x4: + OperationCompositeExtract(inst, inst_values, cache); + break; + case Opcode::CompositeInsertU32x2: + case Opcode::CompositeInsertU32x3: + case Opcode::CompositeInsertU32x4: + case Opcode::CompositeInsertF16x2: + case Opcode::CompositeInsertF16x3: + case Opcode::CompositeInsertF16x4: + case Opcode::CompositeInsertF32x2: + case Opcode::CompositeInsertF32x3: + case Opcode::CompositeInsertF32x4: + case Opcode::CompositeInsertF64x2: + case Opcode::CompositeInsertF64x3: + case Opcode::CompositeInsertF64x4: + OperationInsert(inst, inst_values, cache); + break; + case Opcode::BitCastU16F16: + OperationBitCast(inst, IR::Type::U16, false, inst_values, cache); + break; + case Opcode::BitCastU32F32: + OperationBitCast(inst, IR::Type::U32, false, inst_values, cache); + break; + case Opcode::BitCastU64F64: + OperationBitCast(inst, IR::Type::U64, false, inst_values, cache); + break; + case Opcode::BitCastF16U16: + OperationBitCast(inst, IR::Type::F16, true, inst_values, cache); + break; + case Opcode::BitCastF32U32: + OperationBitCast(inst, IR::Type::F32, true, inst_values, cache); + break; + case Opcode::BitCastF64U64: + OperationBitCast(inst, IR::Type::F64, true, inst_values, cache); + break; + case Opcode::FPAbs32: + case Opcode::FPAbs64: + OperationAbs(inst, inst_values, cache); + break; + case Opcode::FPAdd32: + case Opcode::FPAdd64: + OperationAdd(inst, false, inst_values, cache); + break; + case Opcode::FPSub32: + OperationSub(inst, false, inst_values, cache); + break; + case Opcode::FPMul32: + case Opcode::FPMul64: + OperationMul(inst, false, inst_values, cache); + break; + case Opcode::FPDiv32: + case Opcode::FPDiv64: + OperationDiv(inst, false, inst_values, cache); + break; + case Opcode::FPFma32: + case Opcode::FPFma64: + OperationFma(inst, inst_values, cache); + break; + case Opcode::FPMin32: + case Opcode::FPMin64: + OperationMin(inst, false, inst_values, cache); + break; + case Opcode::FPMax32: + case Opcode::FPMax64: + OperationMax(inst, false, inst_values, cache); + break; + case Opcode::FPNeg32: + case Opcode::FPNeg64: + OperationNeg(inst, inst_values, cache); + break; + case Opcode::FPRecip32: + case Opcode::FPRecip64: + OperationRecip(inst, inst_values, cache); + break; + case Opcode::FPRecipSqrt32: + case Opcode::FPRecipSqrt64: + OperationRecipSqrt(inst, inst_values, cache); + break; + case Opcode::FPSqrt: + OperationSqrt(inst, inst_values, cache); + break; + case Opcode::FPSin: + OperationSin(inst, inst_values, cache); + break; + case Opcode::FPCos: + OperationCos(inst, inst_values, cache); + break; + case Opcode::FPExp2: + OperationExp2(inst, inst_values, cache); + break; + case Opcode::FPLdexp: + OperationLdexp(inst, inst_values, cache); + break; + case Opcode::FPLog2: + OperationLog2(inst, inst_values, cache); + break; + case Opcode::FPClamp32: + case Opcode::FPClamp64: + OperationClamp(inst, false, inst_values, cache); + break; + case Opcode::FPRoundEven32: + case Opcode::FPRoundEven64: + OperationRound(inst, inst_values, cache); + break; + case Opcode::FPFloor32: + case Opcode::FPFloor64: + OperationFloor(inst, inst_values, cache); + break; + case Opcode::FPCeil32: + case Opcode::FPCeil64: + OperationCeil(inst, inst_values, cache); + break; + case Opcode::FPTrunc32: + case Opcode::FPTrunc64: + OperationTrunc(inst, inst_values, cache); + break; + case Opcode::FPFract32: + case Opcode::FPFract64: + OperationFract(inst, inst_values, cache); + break; default: break; } } +static bool IsSelectInst(Inst* inst) { + switch (inst->GetOpcode()) { + case Opcode::SelectU1: + case Opcode::SelectU8: + case Opcode::SelectU16: + case Opcode::SelectU32: + case Opcode::SelectU64: + case Opcode::SelectF32: + case Opcode::SelectF64: + return true; + default: + return false; + } +} + void ComputeImmValues(const Value& value, ImmValueList& values, ComputeImmValuesCache& cache) { Value resolved = value.Resolve(); if (ImmValue::IsSupportedValue(resolved)) { @@ -494,8 +674,11 @@ void ComputeImmValues(const Value& value, ImmValueList& values, ComputeImmValues for (size_t i = 0; i < inst->NumArgs(); ++i) { ComputeImmValues(inst->Arg(i), inst_values, cache); } + } if (IsSelectInst(inst)) { + ComputeImmValues(inst->Arg(1), inst_values, cache); + ComputeImmValues(inst->Arg(2), inst_values, cache); } else { - + DoInstructionOperation(inst, inst_values, cache); } values.insert(inst_values.begin(), inst_values.end()); }