From 676b06db0ff1d6a94abd5c762a4b0268492526e5 Mon Sep 17 00:00:00 2001 From: Lander Gallastegi Date: Tue, 11 Mar 2025 15:43:06 +0100 Subject: [PATCH] Squashed commit of the following: commit 39328be45ed83d252e491b61da5cb4d767ff100d Author: Lander Gallastegi Date: Tue Mar 11 15:40:44 2025 +0100 Fix trivialy copiable commit f0633525b343dab7ea75cd7c809c936cb67e572f Author: Lander Gallastegi Date: Tue Mar 11 00:29:42 2025 +0100 Compute value commit 8c42a014ee925b61c5ea5721423da3211a635eb8 Author: Lander Gallastegi Date: Tue Mar 11 00:29:31 2025 +0100 Add missing operations --- CMakeLists.txt | 2 + .../ir/compute_value/compute.cpp | 503 ++++++++++++++++++ .../ir/compute_value/compute.h | 18 + .../ir/compute_value/imm_value.cpp | 462 +++++++++++++++- .../ir/compute_value/imm_value.h | 31 ++ 5 files changed, 997 insertions(+), 19 deletions(-) create mode 100644 src/shader_recompiler/ir/compute_value/compute.cpp create mode 100644 src/shader_recompiler/ir/compute_value/compute.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 041642840..5b723f76c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -843,6 +843,8 @@ set(SHADER_RECOMPILER src/shader_recompiler/exception.h src/shader_recompiler/ir/passes/shared_memory_barrier_pass.cpp src/shader_recompiler/ir/passes/shared_memory_to_storage_pass.cpp src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp + src/shader_recompiler/ir/compute_value/compute.cpp + src/shader_recompiler/ir/compute_value/compute.h src/shader_recompiler/ir/compute_value/imm_value.cpp src/shader_recompiler/ir/compute_value/imm_value.h src/shader_recompiler/ir/abstract_syntax_list.cpp diff --git a/src/shader_recompiler/ir/compute_value/compute.cpp b/src/shader_recompiler/ir/compute_value/compute.cpp new file mode 100644 index 000000000..14e7518f4 --- /dev/null +++ b/src/shader_recompiler/ir/compute_value/compute.cpp @@ -0,0 +1,503 @@ +// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + 
+#include
+#include
+#include
+#include
+#include "shader_recompiler/ir/compute_value/compute.h"
+
+namespace Shader::IR {
+
+template
+static void CartesianInvokeImpl(Func func, OutputIt& out_it, // by reference: all recursion levels must advance the same inserter
+                                std::tuple& arglists_its,
+                                const std::tuple& arglists_tuple) {
+    if constexpr (Level == N) {
+        auto get_tuple = [&](std::index_sequence) {
+            return std::forward_as_tuple(*std::get(arglists_its)...);
+        };
+        *out_it++ = std::apply(func, get_tuple(std::make_index_sequence{})); // result is already a prvalue, no std::move needed
+        return;
+    } else {
+        const auto& arglist = std::get(arglists_tuple);
+        for (auto it = arglist.begin(); it != arglist.end(); ++it) {
+            std::get(arglists_its) = it;
+            CartesianInvokeImpl(func, out_it, arglists_its, arglists_tuple); // a by-value copy here would keep a hint iterator the callee's insert invalidates
+        }
+    }
+}
+
+template
+static void CartesianInvoke(Func func, OutputIt out_it, const ArgLists&... arg_lists) {
+    constexpr size_t N = sizeof...(ArgLists);
+    const std::tuple arglists_tuple = std::forward_as_tuple(arg_lists...);
+
+    std::tuple arglists_it;
+    CartesianInvokeImpl(func, out_it, arglists_it, arglists_tuple);
+}
+
+static void SetSigned(ImmValueList& values, bool is_signed) {
+    for (auto& value : values) {
+        value.SetSigned(is_signed);
+    }
+}
+
+static void OperationAbs(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) {
+    ImmValueList args;
+    ComputeImmValues(inst->Arg(0), args, cache);
+
+    const auto op = [](const ImmValue& a) {
+        return a.abs();
+    };
+
+    std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op);
+}
+
+static void OperationAdd(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) {
+    ImmValueList args0, args1;
+    ComputeImmValues(inst->Arg(0), args0, cache);
+    ComputeImmValues(inst->Arg(1), args1, cache);
+
+    const auto op = [](const ImmValue& a, const ImmValue& b) {
+        return a + b;
+    };
+
+    SetSigned(args0, is_signed);
+    SetSigned(args1, is_signed);
+    CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1);
+}
+
+static
void OperationSub(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a - b; + }; + + SetSigned(args0, is_signed); + SetSigned(args1, is_signed); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationFma(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1, args2; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + ComputeImmValues(inst->Arg(2), args2, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& c) { + return ImmValue::fma(a, b, c); + }; + + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2); +} + +static void OperationMin(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1, is_legacy_args; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + if (inst->NumArgs() > 2) { + ComputeImmValues(inst->Arg(2), is_legacy_args, cache); + } else { + is_legacy_args.insert(ImmValue(false)); + } + + const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& is_legacy) { + if (is_legacy.U1()) { + if (a.isnan()) return b; + if (b.isnan()) return a; + } + return std::min(a, b); + }; + + SetSigned(args0, is_signed); + SetSigned(args1, is_signed); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, is_legacy_args); +} + +static void OperationMax(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1, is_legacy_args; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + if 
(inst->NumArgs() > 2) { + ComputeImmValues(inst->Arg(2), is_legacy_args, cache); + } else { + is_legacy_args.insert(ImmValue(false)); + } + + const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& is_legacy) { + if (is_legacy.U1()) { + if (a.isnan()) return b; + if (b.isnan()) return a; + } + return std::max(a, b); + }; + + SetSigned(args0, is_signed); + SetSigned(args1, is_signed); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, is_legacy_args); +} + +static void OperationMul(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a * b; + }; + + SetSigned(args0, is_signed); + SetSigned(args1, is_signed); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationDiv(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a / b; + }; + + SetSigned(args0, is_signed); + SetSigned(args1, is_signed); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationNeg(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return -a; + }; + + SetSigned(args, true); + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationRecip(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const 
auto op = [](const ImmValue& a) { + return a.recip(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationRecipSqrt(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.rsqrt(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationSqrt(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.sqrt(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationSin(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.sin(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationExp2(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.exp2(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationLdexp(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a.ldexp(b); + }; + + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationCos(Inst* inst, ImmValueList& inst_values, 
ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.cos(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationLog2(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.log2(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationClamp(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1, args2; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + ComputeImmValues(inst->Arg(2), args2, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& c) { + return a.clamp(b, c); + }; + + SetSigned(args0, is_signed); + SetSigned(args1, is_signed); + SetSigned(args2, is_signed); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2); +} + +static void OperationRound(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.round(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationFloor(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.floor(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationCeil(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + 
ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.ceil(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationTrunc(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.trunc(); + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationFract(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [](const ImmValue& a) { + return a.fract(); + }; + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationShiftLeft(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a << b; + }; + + SetSigned(args1, false); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationShiftRight(Inst* inst, bool is_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a >> b; + }; + + SetSigned(args0, is_signed); + SetSigned(args1, false); + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationBitwiseNot(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, 
cache); + + const auto op = [](const ImmValue& a) { + return ~a; + }; + + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationBitwiseAnd(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a & b; + }; + + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationBitwiseOr(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a | b; + }; + + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationBitwiseXor(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args0, args1; + ComputeImmValues(inst->Arg(0), args0, cache); + ComputeImmValues(inst->Arg(1), args1, cache); + + const auto op = [](const ImmValue& a, const ImmValue& b) { + return a ^ b; + }; + + CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); +} + +static void OperationConvert(Inst* inst, bool is_signed, Type new_type, bool new_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + ImmValueList args; + ComputeImmValues(inst->Arg(0), args, cache); + + const auto op = [new_type, new_signed](const ImmValue& a) { + return a.Convert(new_type, new_signed); + }; + + SetSigned(args, is_signed); + std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op); +} + +static void OperationBitCast(Inst* inst, Type new_type, bool new_signed, ImmValueList& inst_values, ComputeImmValuesCache& cache) { + 
ImmValueList args;
+    ComputeImmValues(inst->Arg(0), args, cache);
+
+    const auto op = [new_type, new_signed](const ImmValue& a) {
+        return a.Bitcast(new_type, new_signed);
+    };
+
+    std::transform(args.begin(), args.end(), std::inserter(inst_values, inst_values.begin()), op);
+}
+
+template
+static void OperationCompositeConstruct(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) {
+    std::array args;
+    for (size_t i = 0; i < N; ++i) {
+        ComputeImmValues(inst->Arg(i), args[i], cache);
+    }
+
+    const auto op = [](const Args&... args) {
+        return ImmValue(args...);
+    };
+
+    const auto call_cartesian = [&](std::index_sequence) {
+        CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args[I]...);
+    };
+    call_cartesian(std::make_index_sequence{});
+}
+
+static void OperationCompositeExtract(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) {
+    ImmValueList args0, args1;
+    ComputeImmValues(inst->Arg(0), args0, cache);
+    ComputeImmValues(inst->Arg(1), args1, cache);
+
+    const auto op = [](const ImmValue& a, const ImmValue& b) {
+        return a.Extract(b);
+    };
+
+    SetSigned(args1, false);
+    CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1);
+}
+
+static void DoInstructionOperation(Inst* inst, ImmValueList& inst_values, ComputeImmValuesCache& cache) {
+    switch (inst->GetOpcode()) {
+    default:
+        break;
+    }
+}
+
+void ComputeImmValues(const Value& value, ImmValueList& values, ComputeImmValuesCache& cache) { // adds every immediate `value` may take to `values`
+    Value resolved = value.Resolve();
+    if (ImmValue::IsSupportedValue(resolved)) {
+        values.insert(ImmValue(resolved));
+        return;
+    }
+    if (resolved.Type() != Type::Opaque) {
+        return;
+    }
+    Inst* inst = resolved.InstRecursive();
+    auto [it, inserted] = cache.emplace(inst, ImmValueList{}); // lookup and cycle-breaking placeholder in one step
+    if (!inserted) {
+        values.insert(it->second.begin(), it->second.end()); // cached (or still in progress further up a phi cycle)
+        return;
+    }
+    if (inst->GetOpcode() == Opcode::Phi) {
+        for (size_t i = 0; i < inst->NumArgs(); ++i) {
+            ImmValueList arg_values; // scratch set: the recursion may insert into the flat_map and move its elements
+            ComputeImmValues(inst->Arg(i), arg_values, cache);
+            cache[inst].insert(arg_values.begin(), arg_values.end()); // fresh lookup, never a reference held across the call
+        }
+    }
+    it = cache.find(inst); // re-acquire: the iterator obtained before the recursion may be stale
+    values.insert(it->second.begin(), it->second.end());
+}
+
+} // namespace Shader::IR
diff --git a/src/shader_recompiler/ir/compute_value/compute.h b/src/shader_recompiler/ir/compute_value/compute.h
new file mode 100644
index 000000000..fbfe46575
--- /dev/null
+++ b/src/shader_recompiler/ir/compute_value/compute.h
@@ -0,0 +1,18 @@
+// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#pragma once
+
+#include
+#include
+#include "shader_recompiler/ir/compute_value/imm_value.h"
+#include "shader_recompiler/ir/value.h"
+
+namespace Shader::IR {
+
+using ImmValueList = boost::container::flat_set;
+using ComputeImmValuesCache = boost::container::flat_map;
+
+void ComputeImmValues(const Value& value, ImmValueList& values, ComputeImmValuesCache& cache);
+
+} // namespace Shader::IR
diff --git a/src/shader_recompiler/ir/compute_value/imm_value.cpp b/src/shader_recompiler/ir/compute_value/imm_value.cpp
index f222ea009..6c433c244 100644
--- a/src/shader_recompiler/ir/compute_value/imm_value.cpp
+++ b/src/shader_recompiler/ir/compute_value/imm_value.cpp
@@ -7,29 +7,28 @@
 namespace Shader::IR {
 
 ImmValue::ImmValue(const IR::Value& value) noexcept {
-    IR::Value resolved = value.Resolve();
-    type = resolved.Type();
+    type = value.Type();
     switch (type) {
     case Type::U1:
-        imm_values[0].imm_u1 = resolved.U1();
+        imm_values[0].imm_u1 = value.U1();
         break;
     case Type::U8:
-        imm_values[0].imm_u8 = resolved.U8();
+        imm_values[0].imm_u8 = value.U8();
         break;
     case Type::U16:
-        imm_values[0].imm_u16 = resolved.U16();
+        imm_values[0].imm_u16 = value.U16();
         break;
     case Type::U32:
-        imm_values[0].imm_u32 = resolved.U32();
+        imm_values[0].imm_u32 = value.U32();
         break;
     case Type::F32:
-        imm_values[0].imm_f32 = resolved.F32();
+        imm_values[0].imm_f32 = value.F32();
         break;
     case Type::U64:
-        imm_values[0].imm_u64 = resolved.U64();
+
imm_values[0].imm_u64 = value.U64(); break; case Type::F64: - imm_values[0].imm_f64 = resolved.F64(); + imm_values[0].imm_f64 = value.F64(); break; default: UNREACHABLE_MSG("Invalid type {}", type); @@ -160,6 +159,44 @@ ImmValue::ImmValue(f64 value1, f64 value2, f64 value3, f64 value4) noexcept imm_values[3].imm_f64 = value4; } +ImmValue::ImmValue(const ImmValue& value1, const ImmValue& value2) noexcept + : type{value1.type}, is_signed{value1.is_signed} { + ASSERT(value1.type == value2.type && value1.is_signed == value2.is_signed); + switch (value1.Dimensions()) { + case 1: + imm_values[0] = value1.imm_values[0]; + imm_values[1] = value2.imm_values[0]; + break; + case 2: + imm_values[0] = value1.imm_values[0]; + imm_values[1] = value1.imm_values[1]; + imm_values[2] = value2.imm_values[0]; + imm_values[3] = value2.imm_values[1]; + break; + default: + UNREACHABLE_MSG("Invalid dimensions {}", value1.Dimensions()); + } +} + +ImmValue::ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3) noexcept + : type{value1.type}, is_signed{value1.is_signed} { + ASSERT(value1.type == value2.type && value1.type == value3.type && value1.is_signed == value2.is_signed && + value1.is_signed == value3.is_signed && value1.Dimensions() == 1); + imm_values[0] = value1.imm_values[0]; + imm_values[1] = value2.imm_values[0]; + imm_values[2] = value3.imm_values[0]; +} + +ImmValue::ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3, const ImmValue& value4) noexcept + : type{value1.type}, is_signed{value1.is_signed} { + ASSERT(value1.type == value2.type && value1.type == value3.type && value1.type == value4.type && value1.is_signed == value2.is_signed && + value1.is_signed == value3.is_signed && value1.is_signed == value4.is_signed && value1.Dimensions() == 1); + imm_values[0] = value1.imm_values[0]; + imm_values[1] = value2.imm_values[0]; + imm_values[2] = value3.imm_values[0]; + imm_values[3] = value4.imm_values[0]; +} + IR::Type 
ImmValue::BaseType() const noexcept {
     switch (type) {
     case Type::U1:
@@ -229,6 +266,105 @@ void ImmValue::SameSignAs(const ImmValue& other) noexcept {
     SetSigned(other.IsSigned());
 }
 
+// Numeric conversion of the first component to new_type via static_cast.
+// Only the source/target pairs listed below are handled; any other pair ends in UNREACHABLE_MSG.
+ImmValue ImmValue::Convert(IR::Type new_type, bool new_signed) const noexcept {
+    switch (new_type) {
+    case Type::U16: {
+        switch (type) {
+        case Type::U32:
+            return ImmValue(static_cast(imm_values[0].imm_u32));
+        default:
+            break;
+        }
+        break;
+    }
+    case Type::U32: {
+        if (new_signed) {
+            switch (type) {
+            case Type::F32:
+                return ImmValue(static_cast(imm_values[0].imm_f32));
+            case Type::F64:
+                return ImmValue(static_cast(imm_values[0].imm_f64));
+            default:
+                break;
+            }
+        } else {
+            switch (type) {
+            case Type::U16:
+                return ImmValue(static_cast(imm_values[0].imm_u16));
+            case Type::U32:
+                if (is_signed) {
+                    return ImmValue(static_cast(imm_values[0].imm_s32));
+                }
+                break;
+            case Type::F32:
+                return ImmValue(static_cast(imm_values[0].imm_f32));
+            default:
+                break;
+            }
+        }
+        break; // unsupported source for U32: report it instead of falling through into the F32 case
+    }
+    case Type::F32: {
+        switch (type) {
+        case Type::U16:
+            return ImmValue(static_cast(imm_values[0].imm_u16));
+        case Type::U32:
+            if (is_signed) {
+                return ImmValue(static_cast(imm_values[0].imm_s32));
+            } else {
+                return ImmValue(static_cast(imm_values[0].imm_u32));
+            }
+        case Type::F64:
+            return ImmValue(static_cast(imm_values[0].imm_f64));
+        default:
+            break;
+        }
+        break;
+    }
+    case Type::F64: {
+        switch (type) {
+        case Type::F32:
+            return ImmValue(static_cast(imm_values[0].imm_f32));
+        default:
+            break;
+        }
+        break;
+    }
+    default:
+        break;
+    }
+    UNREACHABLE_MSG("Invalid conversion from {} {} to {} {}", is_signed ? "signed" : "unsigned",
+                    type, new_signed ?
"signed" : "unsigned", new_type); +} + +ImmValue ImmValue::Bitcast(IR::Type new_type, bool new_signed) const noexcept { + ImmValue result; + result.type = new_type; + result.is_signed = new_signed; + result.imm_values = imm_values; + ASSERT(Dimensions() == result.Dimensions()); + return result; +} + +ImmValue ImmValue::Extract(const ImmValue& index) const noexcept { + ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions()); + ImmValue result; + result.type = BaseType(); + result.is_signed = IsSigned(); + result.imm_values[0] = imm_values[index.imm_values[0].imm_u32]; + return result; +} + +ImmValue ImmValue::Insert(const ImmValue& value, const ImmValue& index) const noexcept { + ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions()); + ASSERT(value.type == BaseType() && value.IsSigned() == IsSigned()); + ImmValue result = *this; + result.imm_values[index.imm_values[0].imm_u32] = value.imm_values[0]; + return result; +} + bool ImmValue::operator==(const ImmValue& other) const noexcept { if (type != other.type) { return false; @@ -747,24 +880,24 @@ ImmValue ImmValue::operator^(const ImmValue& other) const noexcept { } ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept { - ASSERT(type == other.type); + ASSERT(other.type == Type::U32 && other.Dimensions() == 1); switch (type) { case Type::U1: return ImmValue(imm_values[0].imm_u1 << other.imm_values[0].imm_u1); case Type::U8: - return is_signed && other.is_signed + return is_signed ? ImmValue(imm_values[0].imm_s8 << other.imm_values[0].imm_s8) : ImmValue(imm_values[0].imm_u8 << other.imm_values[0].imm_u8); case Type::U16: - return is_signed && other.is_signed + return is_signed ? ImmValue(imm_values[0].imm_s16 << other.imm_values[0].imm_s16) : ImmValue(imm_values[0].imm_u16 << other.imm_values[0].imm_u16); case Type::U32: - return is_signed && other.is_signed + return is_signed ? 
ImmValue(imm_values[0].imm_s32 << other.imm_values[0].imm_s32) : ImmValue(imm_values[0].imm_u32 << other.imm_values[0].imm_u32); case Type::U64: - return is_signed && other.is_signed + return is_signed ? ImmValue(imm_values[0].imm_s64 << other.imm_values[0].imm_s64) : ImmValue(imm_values[0].imm_u64 << other.imm_values[0].imm_u64); default: @@ -773,24 +906,24 @@ ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept { } ImmValue ImmValue::operator>>(const ImmValue& other) const noexcept { - ASSERT(type == other.type); + ASSERT(other.type == Type::U32 && other.Dimensions() == 1); switch (type) { case Type::U1: return ImmValue(imm_values[0].imm_u1 >> other.imm_values[0].imm_u1); case Type::U8: - return is_signed && other.is_signed + return is_signed ? ImmValue(imm_values[0].imm_s8 >> other.imm_values[0].imm_s8) : ImmValue(imm_values[0].imm_u8 >> other.imm_values[0].imm_u8); case Type::U16: - return is_signed && other.is_signed + return is_signed ? ImmValue(imm_values[0].imm_s16 >> other.imm_values[0].imm_s16) : ImmValue(imm_values[0].imm_u16 >> other.imm_values[0].imm_u16); case Type::U32: - return is_signed && other.is_signed + return is_signed ? ImmValue(imm_values[0].imm_s32 >> other.imm_values[0].imm_s32) : ImmValue(imm_values[0].imm_u32 >> other.imm_values[0].imm_u32); case Type::U64: - return is_signed && other.is_signed + return is_signed ? ImmValue(imm_values[0].imm_s64 >> other.imm_values[0].imm_s64) : ImmValue(imm_values[0].imm_u64 >> other.imm_values[0].imm_u64); default: @@ -1049,6 +1182,297 @@ ImmValue& ImmValue::operator>>=(const ImmValue& other) noexcept { return *this; } +ImmValue ImmValue::abs() const noexcept { + switch (type) { + case Type::U8: + return is_signed ? ImmValue(std::abs(imm_values[0].imm_s8)) + : ImmValue(imm_values[0].imm_u8); + case Type::U16: + return is_signed ? ImmValue(std::abs(imm_values[0].imm_s16)) + : ImmValue(imm_values[0].imm_u16); + case Type::U32: + return is_signed ? 
ImmValue(std::abs(imm_values[0].imm_s32)) + : ImmValue(imm_values[0].imm_u32); + case Type::U64: + return is_signed ? ImmValue(std::abs(imm_values[0].imm_s64)) + : ImmValue(imm_values[0].imm_u64); + case Type::F32: + return ImmValue(std::abs(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::abs(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::recip() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(1.0f / imm_values[0].imm_f32); + case Type::F64: + return ImmValue(1.0 / imm_values[0].imm_f64); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::sqrt() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::sqrt(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::sqrt(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::rsqrt() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(1.0f / std::sqrt(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(1.0 / std::sqrt(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::sin() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::sin(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::sin(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::cos() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::cos(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::cos(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::exp2() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::exp2(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::exp2(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); 
+    }
+}
+
+ImmValue ImmValue::ldexp(const ImmValue& exp) const noexcept { // returns *this scaled by 2^exp (std::ldexp)
+    ASSERT(exp.type == Type::U32 && exp.Dimensions() == 1); // exponent is an integer scalar (read as imm_s32 below), not the float type of *this
+    switch (type) {
+    case Type::F32:
+        return ImmValue(std::ldexp(imm_values[0].imm_f32, exp.imm_values[0].imm_s32));
+    case Type::F64:
+        return ImmValue(std::ldexp(imm_values[0].imm_f64, exp.imm_values[0].imm_s32));
+    default:
+        UNREACHABLE_MSG("Invalid type {}", type);
+    }
+}
+
+ImmValue ImmValue::log2() const noexcept {
+    switch (type) {
+    case Type::F32:
+        return ImmValue(std::log2(imm_values[0].imm_f32));
+    case Type::F64:
+        return ImmValue(std::log2(imm_values[0].imm_f64));
+    default:
+        UNREACHABLE_MSG("Invalid type {}", type);
+    }
+}
+
+ImmValue ImmValue::clamp(const ImmValue& min, const ImmValue& max) const noexcept {
+    ASSERT(type == min.type && min.type == max.type);
+    switch (type) {
+    case Type::U8:
+        return is_signed && min.is_signed && max.is_signed
+                   ? ImmValue(std::clamp(imm_values[0].imm_s8, min.imm_values[0].imm_s8,
+                                         max.imm_values[0].imm_s8))
+                   : ImmValue(std::clamp(imm_values[0].imm_u8, min.imm_values[0].imm_u8,
+                                         max.imm_values[0].imm_u8));
+    case Type::U16:
+        return is_signed && min.is_signed && max.is_signed
+                   ? ImmValue(std::clamp(imm_values[0].imm_s16, min.imm_values[0].imm_s16,
+                                         max.imm_values[0].imm_s16))
+                   : ImmValue(std::clamp(imm_values[0].imm_u16, min.imm_values[0].imm_u16,
+                                         max.imm_values[0].imm_u16));
+    case Type::U32:
+        return is_signed && min.is_signed && max.is_signed
+                   ? ImmValue(std::clamp(imm_values[0].imm_s32, min.imm_values[0].imm_s32,
+                                         max.imm_values[0].imm_s32))
+                   : ImmValue(std::clamp(imm_values[0].imm_u32, min.imm_values[0].imm_u32,
+                                         max.imm_values[0].imm_u32));
+    case Type::U64:
+        return is_signed && min.is_signed && max.is_signed
+                   ?
ImmValue(std::clamp(imm_values[0].imm_s64, min.imm_values[0].imm_s64, + max.imm_values[0].imm_s64)) + : ImmValue(std::clamp(imm_values[0].imm_u64, min.imm_values[0].imm_u64, + max.imm_values[0].imm_u64)); + case Type::F32: + return ImmValue(std::clamp(imm_values[0].imm_f32, min.imm_values[0].imm_f32, + max.imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::clamp(imm_values[0].imm_f64, min.imm_values[0].imm_f64, + max.imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::floor() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::floor(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::floor(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::ceil() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::ceil(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::ceil(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::round() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::round(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::round(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::trunc() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(std::trunc(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(std::trunc(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +ImmValue ImmValue::fract() const noexcept { + switch (type) { + case Type::F32: + return ImmValue(imm_values[0].imm_f32 - std::floor(imm_values[0].imm_f32)); + case Type::F64: + return ImmValue(imm_values[0].imm_f64 - std::floor(imm_values[0].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + +bool ImmValue::isnan() const noexcept { + switch (type) { + case 
Type::F32: + return std::isnan(imm_values[0].imm_f32); + case Type::F64: + return std::isnan(imm_values[0].imm_f64); + case Type::F32x2: + return std::isnan(imm_values[0].imm_f32) || std::isnan(imm_values[1].imm_f32); + case Type::F64x2: + return std::isnan(imm_values[0].imm_f64) || std::isnan(imm_values[1].imm_f64); + case Type::F32x3: + return std::isnan(imm_values[0].imm_f32) || std::isnan(imm_values[1].imm_f32) || + std::isnan(imm_values[2].imm_f32); + case Type::F64x3: + return std::isnan(imm_values[0].imm_f64) || std::isnan(imm_values[1].imm_f64) || + std::isnan(imm_values[2].imm_f64); + case Type::F32x4: + return std::isnan(imm_values[0].imm_f32) || std::isnan(imm_values[1].imm_f32) || + std::isnan(imm_values[2].imm_f32) || std::isnan(imm_values[3].imm_f32); + case Type::F64x4: + return std::isnan(imm_values[0].imm_f64) || std::isnan(imm_values[1].imm_f64) || + std::isnan(imm_values[2].imm_f64) || std::isnan(imm_values[3].imm_f64); + default: + UNREACHABLE_MSG("Invalid type {}", type); + } +} + /* fma: fused multiply-add (a * b) + c computed with std::fma, applied component by component for scalar and 2/3/4-component F32/F64 operands. All three operands must share one type (asserted); the result has as many components as the operands. Non-float types hit UNREACHABLE_MSG. */ +ImmValue ImmValue::fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept { + ASSERT(a.type == b.type && b.type == c.type); + switch (a.type) { + case Type::F32: + return ImmValue( + std::fma(a.imm_values[0].imm_f32, b.imm_values[0].imm_f32, c.imm_values[0].imm_f32)); + case Type::F64: + return ImmValue( + std::fma(a.imm_values[0].imm_f64, b.imm_values[0].imm_f64, c.imm_values[0].imm_f64)); + case Type::F32x2: + return ImmValue( + std::fma(a.imm_values[0].imm_f32, b.imm_values[0].imm_f32, c.imm_values[0].imm_f32), + std::fma(a.imm_values[1].imm_f32, b.imm_values[1].imm_f32, c.imm_values[1].imm_f32)); + case Type::F64x2: + return ImmValue( + std::fma(a.imm_values[0].imm_f64, b.imm_values[0].imm_f64, c.imm_values[0].imm_f64), + std::fma(a.imm_values[1].imm_f64, b.imm_values[1].imm_f64, c.imm_values[1].imm_f64)); + case Type::F32x3: + return ImmValue( + std::fma(a.imm_values[0].imm_f32, b.imm_values[0].imm_f32, c.imm_values[0].imm_f32), +
std::fma(a.imm_values[1].imm_f32, b.imm_values[1].imm_f32, c.imm_values[1].imm_f32), + std::fma(a.imm_values[2].imm_f32, b.imm_values[2].imm_f32, c.imm_values[2].imm_f32)); + case Type::F64x3: + return ImmValue( + std::fma(a.imm_values[0].imm_f64, b.imm_values[0].imm_f64, c.imm_values[0].imm_f64), + std::fma(a.imm_values[1].imm_f64, b.imm_values[1].imm_f64, c.imm_values[1].imm_f64), + std::fma(a.imm_values[2].imm_f64, b.imm_values[2].imm_f64, c.imm_values[2].imm_f64)); + case Type::F32x4: + return ImmValue( + std::fma(a.imm_values[0].imm_f32, b.imm_values[0].imm_f32, c.imm_values[0].imm_f32), + std::fma(a.imm_values[1].imm_f32, b.imm_values[1].imm_f32, c.imm_values[1].imm_f32), + std::fma(a.imm_values[2].imm_f32, b.imm_values[2].imm_f32, c.imm_values[2].imm_f32), + std::fma(a.imm_values[3].imm_f32, b.imm_values[3].imm_f32, c.imm_values[3].imm_f32)); + case Type::F64x4: + return ImmValue( + std::fma(a.imm_values[0].imm_f64, b.imm_values[0].imm_f64, c.imm_values[0].imm_f64), + std::fma(a.imm_values[1].imm_f64, b.imm_values[1].imm_f64, c.imm_values[1].imm_f64), + std::fma(a.imm_values[2].imm_f64, b.imm_values[2].imm_f64, c.imm_values[2].imm_f64), + std::fma(a.imm_values[3].imm_f64, b.imm_values[3].imm_f64, c.imm_values[3].imm_f64)); + default: + UNREACHABLE_MSG("Invalid type {}", a.type); + } +} + /* IsSupportedValue: returns true only when the IR value's type is one of the scalar types U1, U8, U16, U32, U64, F32 or F64; every other IR type yields false. */ +bool ImmValue::IsSupportedValue(const IR::Value& value) noexcept { + switch (value.Type()) { + case IR::Type::U1: + case IR::Type::U8: + case IR::Type::U16: + case IR::Type::U32: + case IR::Type::U64: + case IR::Type::F32: + case IR::Type::F64: + return true; + default: + return false; + } +} + } // namespace Shader::IR namespace std { diff --git a/src/shader_recompiler/ir/compute_value/imm_value.h b/src/shader_recompiler/ir/compute_value/imm_value.h index 7ece9f48e..78696d83a 100644 --- a/src/shader_recompiler/ir/compute_value/imm_value.h +++ b/src/shader_recompiler/ir/compute_value/imm_value.h @@ -19,6 +19,7 @@ namespace Shader::IR { class ImmValue { public: ImmValue()
noexcept = default; + ImmValue(const ImmValue& value) noexcept = default; explicit ImmValue(const IR::Value& value) noexcept; explicit ImmValue(bool value) noexcept; explicit ImmValue(u8 value) noexcept; @@ -43,6 +44,9 @@ public: ImmValue(f64 value1, f64 value2) noexcept; ImmValue(f64 value1, f64 value2, f64 value3) noexcept; ImmValue(f64 value1, f64 value2, f64 value3, f64 value4) noexcept; + ImmValue(const ImmValue& value1, const ImmValue& value2) noexcept; + ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3) noexcept; + ImmValue(const ImmValue& value1, const ImmValue& value2, const ImmValue& value3, const ImmValue& value4) noexcept; [[nodiscard]] bool IsEmpty() const noexcept; [[nodiscard]] IR::Type Type() const noexcept; @@ -53,6 +57,11 @@ public: void SetSigned(bool signed_) noexcept; void SameSignAs(const ImmValue& other) noexcept; + [[nodiscard]] ImmValue Convert(IR::Type new_type, bool new_signed) const noexcept; + [[nodiscard]] ImmValue Bitcast(IR::Type new_type, bool new_signed) const noexcept; + [[nodiscard]] ImmValue Extract(const ImmValue& index) const noexcept; + [[nodiscard]] ImmValue Insert(const ImmValue& value, const ImmValue& index) const noexcept; + [[nodiscard]] bool U1() const; [[nodiscard]] u8 U8() const; [[nodiscard]] s8 S8() const; @@ -78,6 +87,8 @@ public: [[nodiscard]] std::tuple F64x3() const; [[nodiscard]] std::tuple F64x4() const; + ImmValue& operator=(const ImmValue& value) noexcept = default; + [[nodiscard]] bool operator==(const ImmValue& other) const noexcept; [[nodiscard]] bool operator!=(const ImmValue& other) const noexcept; [[nodiscard]] bool operator<(const ImmValue& other) const noexcept; @@ -117,6 +128,26 @@ public: ImmValue& operator<<=(const ImmValue& other) noexcept; ImmValue& operator>>=(const ImmValue& other) noexcept; + [[nodiscard]] ImmValue abs() const noexcept; + [[nodiscard]] ImmValue recip() const noexcept; + [[nodiscard]] ImmValue sqrt() const noexcept; + [[nodiscard]] ImmValue 
rsqrt() const noexcept; + [[nodiscard]] ImmValue sin() const noexcept; + [[nodiscard]] ImmValue cos() const noexcept; + [[nodiscard]] ImmValue exp2() const noexcept; + [[nodiscard]] ImmValue ldexp(const ImmValue& exp) const noexcept; + [[nodiscard]] ImmValue log2() const noexcept; + [[nodiscard]] ImmValue clamp(const ImmValue& min, const ImmValue& max) const noexcept; + [[nodiscard]] ImmValue floor() const noexcept; + [[nodiscard]] ImmValue ceil() const noexcept; + [[nodiscard]] ImmValue round() const noexcept; + [[nodiscard]] ImmValue trunc() const noexcept; + [[nodiscard]] ImmValue fract() const noexcept; + [[nodiscard]] bool isnan() const noexcept; + + [[nodiscard]] static ImmValue fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept; + + static bool IsSupportedValue(const IR::Value& value) noexcept; private: union Value { bool imm_u1;