diff --git a/src/shader_recompiler/ir/compute_value/compute.cpp b/src/shader_recompiler/ir/compute_value/compute.cpp index 14e7518f4..0f1707981 100644 --- a/src/shader_recompiler/ir/compute_value/compute.cpp +++ b/src/shader_recompiler/ir/compute_value/compute.cpp @@ -89,7 +89,7 @@ static void OperationFma(Inst* inst, ImmValueList& inst_values, ComputeImmValues ComputeImmValues(inst->Arg(2), args2, cache); const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& c) { - return ImmValue::fma(a, b, c); + return ImmValue::fma(ImmF32F64(a), ImmF32F64(b), ImmF32F64(c)); }; CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2); @@ -242,7 +242,7 @@ static void OperationLdexp(Inst* inst, ImmValueList& inst_values, ComputeImmValu ComputeImmValues(inst->Arg(1), args1, cache); const auto op = [](const ImmValue& a, const ImmValue& b) { - return a.ldexp(b); + return a.ldexp(ImmU32(b)); }; CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); @@ -346,7 +346,7 @@ static void OperationShiftLeft(Inst* inst, ImmValueList& inst_values, ComputeImm ComputeImmValues(inst->Arg(1), args1, cache); const auto op = [](const ImmValue& a, const ImmValue& b) { - return a << b; + return a << ImmU32(b); }; SetSigned(args1, false); @@ -359,7 +359,7 @@ static void OperationShiftRight(Inst* inst, bool is_signed, ImmValueList& inst_v ComputeImmValues(inst->Arg(1), args1, cache); const auto op = [](const ImmValue& a, const ImmValue& b) { - return a >> b; + return a >> ImmU32(b); }; SetSigned(args0, is_signed); @@ -460,7 +460,7 @@ static void OperationCompositeExtract(Inst* inst, ImmValueList& inst_values, Com ComputeImmValues(inst->Arg(1), args1, cache); const auto op = [](const ImmValue& a, const ImmValue& b) { - return a.Extract(b); + return a.Extract(ImmU32(b)); }; SetSigned(args1, false); diff --git a/src/shader_recompiler/ir/compute_value/imm_value.cpp b/src/shader_recompiler/ir/compute_value/imm_value.cpp index 6c433c244..068069d2e 100644 --- a/src/shader_recompiler/ir/compute_value/imm_value.cpp +++ b/src/shader_recompiler/ir/compute_value/imm_value.cpp @@ -345,8 +345,8 @@ ImmValue ImmValue::Bitcast(IR::Type new_type, bool new_signed) const noexcept { return result; } -ImmValue ImmValue::Extract(const ImmValue& index) const noexcept { - ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions()); +ImmValue ImmValue::Extract(const ImmU32& index) const noexcept { + ASSERT(index.imm_values[0].imm_u32 < Dimensions()); ImmValue result; result.type = BaseType(); result.is_signed = IsSigned(); @@ -354,8 +354,8 @@ ImmValue ImmValue::Extract(const ImmValue& index) const noexcept { return result; } -ImmValue ImmValue::Insert(const ImmValue& value, const ImmValue& index) const noexcept { - ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions()); +ImmValue ImmValue::Insert(const ImmValue& value, const ImmU32& index) const noexcept { + ASSERT(index.imm_values[0].imm_u32 < Dimensions()); ASSERT(value.type == BaseType() && value.IsSigned() == IsSigned()); ImmValue result = *this; result.imm_values[index.imm_values[0].imm_u32] = value.imm_values[0]; @@ -879,8 +879,7 @@ ImmValue ImmValue::operator^(const ImmValue& other) const noexcept { } } -ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept { - ASSERT(other.type == Type::U32 && other.Dimensions() == 1); +ImmValue ImmValue::operator<<(const ImmU32& other) const noexcept { switch (type) { case Type::U1: return ImmValue(imm_values[0].imm_u1 << other.imm_values[0].imm_u1); @@ -905,8 +904,7 @@ ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept { } } -ImmValue ImmValue::operator>>(const ImmValue& other) const noexcept { - ASSERT(other.type == Type::U32 && other.Dimensions() == 1); +ImmValue ImmValue::operator>>(const ImmU32& other) const noexcept { switch (type) { case Type::U1: return ImmValue(imm_values[0].imm_u1 >> other.imm_values[0].imm_u1); @@ -1170,13 +1168,13 @@ ImmValue& ImmValue::operator^=(const ImmValue& other) noexcept { return *this; } -ImmValue& ImmValue::operator<<=(const ImmValue& other) noexcept { +ImmValue& ImmValue::operator<<=(const ImmU32& other) noexcept { ImmValue result = *this << other; *this = result; return *this; } -ImmValue& ImmValue::operator>>=(const ImmValue& other) noexcept { +ImmValue& ImmValue::operator>>=(const ImmU32& other) noexcept { ImmValue result = *this >> other; *this = result; return *this; @@ -1271,8 +1269,7 @@ ImmValue ImmValue::exp2() const noexcept { } } -ImmValue ImmValue::ldexp(const ImmValue& exp) const noexcept { - ASSERT(type == exp.type); +ImmValue ImmValue::ldexp(const ImmU32& exp) const noexcept { switch (type) { case Type::F32: return ImmValue(std::ldexp(imm_values[0].imm_f32, exp.imm_values[0].imm_s32)); @@ -1414,7 +1411,7 @@ bool ImmValue::isnan() const noexcept { } } -ImmValue ImmValue::fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept { +ImmValue ImmValue::fma(const ImmF32F64& a, const ImmF32F64& b, const ImmF32F64& c) noexcept { ASSERT(a.type == b.type && b.type == c.type); switch (a.type) { case Type::F32: diff --git a/src/shader_recompiler/ir/compute_value/imm_value.h b/src/shader_recompiler/ir/compute_value/imm_value.h index 78696d83a..017bd339d 100644 --- a/src/shader_recompiler/ir/compute_value/imm_value.h +++ b/src/shader_recompiler/ir/compute_value/imm_value.h @@ -16,6 +16,58 @@ namespace Shader::IR { // Live IR::Value but can only hold immediate values. Additionally, can hold vectors of values. // Has arithmetic operations defined for it. Usefull for computing a value at shader compile time. +template +class TypedImmValue; + +using ImmU1 = TypedImmValue; +using ImmU8 = TypedImmValue; +using ImmS8 = TypedImmValue; +using ImmU16 = TypedImmValue; +using ImmS16 = TypedImmValue; +using ImmU32 = TypedImmValue; +using ImmS32 = TypedImmValue; +using ImmF32 = TypedImmValue; +using ImmU64 = TypedImmValue; +using ImmS64 = TypedImmValue; +using ImmF64 = TypedImmValue; +using ImmS32F32 = TypedImmValue; +using ImmS64F64 = TypedImmValue; +using ImmU32U64 = TypedImmValue; +using ImmS32S64 = TypedImmValue; +using ImmU16U32U64 = TypedImmValue; +using ImmS16S32S64 = TypedImmValue; +using ImmF32F64 = TypedImmValue; +using ImmUAny = TypedImmValue; +using ImmSAny = TypedImmValue; +using ImmU32x2 = TypedImmValue; +using ImmU32x3 = TypedImmValue; +using ImmU32x4 = TypedImmValue; +using ImmS32x2 = TypedImmValue; +using ImmS32x3 = TypedImmValue; +using ImmS32x4 = TypedImmValue; +using ImmF32x2 = TypedImmValue; +using ImmF32x3 = TypedImmValue; +using ImmF32x4 = TypedImmValue; +using ImmF64x2 = TypedImmValue; +using ImmF64x3 = TypedImmValue; +using ImmF64x4 = TypedImmValue; +using ImmS32F32x2 = TypedImmValue; +using ImmS32F32x3 = TypedImmValue; +using ImmS32F32x4 = TypedImmValue; +using ImmF32F64x2 = TypedImmValue; +using ImmF32F64x3 = TypedImmValue; +using ImmF32F64x4 = TypedImmValue; +using ImmU32xAny = TypedImmValue; +using ImmS32xAny = TypedImmValue; +using ImmF32xAny = TypedImmValue; +using ImmF64xAny = TypedImmValue; +using ImmS32F32xAny = TypedImmValue; +using ImmF32F64xAny = TypedImmValue; + class ImmValue { public: ImmValue() noexcept = default; @@ -59,8 +111,8 @@ public: [[nodiscard]] ImmValue Convert(IR::Type new_type, bool new_signed) const noexcept; [[nodiscard]] ImmValue Bitcast(IR::Type new_type, bool new_signed) const noexcept; - [[nodiscard]] ImmValue Extract(const ImmValue& index) const noexcept; - [[nodiscard]] ImmValue Insert(const ImmValue& value, const ImmValue& index) const noexcept; + [[nodiscard]] ImmValue Extract(const ImmU32& index) const noexcept; + [[nodiscard]] ImmValue Insert(const ImmValue& value, const ImmU32& indndex) const noexcept; [[nodiscard]] bool U1() const; [[nodiscard]] u8 U8() const; @@ -104,8 +156,8 @@ public: [[nodiscard]] ImmValue operator&(const ImmValue& other) const noexcept; [[nodiscard]] ImmValue operator|(const ImmValue& other) const noexcept; [[nodiscard]] ImmValue operator^(const ImmValue& other) const noexcept; - [[nodiscard]] ImmValue operator<<(const ImmValue& other) const noexcept; - [[nodiscard]] ImmValue operator>>(const ImmValue& other) const noexcept; + [[nodiscard]] ImmValue operator<<(const ImmU32& other) const noexcept; + [[nodiscard]] ImmValue operator>>(const ImmU32& other) const noexcept; [[nodiscard]] ImmValue operator~() const noexcept; [[nodiscard]] ImmValue operator++(int) noexcept; @@ -125,8 +177,8 @@ public: ImmValue& operator&=(const ImmValue& other) noexcept; ImmValue& operator|=(const ImmValue& other) noexcept; ImmValue& operator^=(const ImmValue& other) noexcept; - ImmValue& operator<<=(const ImmValue& other) noexcept; - ImmValue& operator>>=(const ImmValue& other) noexcept; + ImmValue& operator<<=(const ImmU32& other) noexcept; + ImmValue& operator>>=(const ImmU32& other) noexcept; [[nodiscard]] ImmValue abs() const noexcept; [[nodiscard]] ImmValue recip() const noexcept; @@ -135,7 +187,7 @@ public: [[nodiscard]] ImmValue sin() const noexcept; [[nodiscard]] ImmValue cos() const noexcept; [[nodiscard]] ImmValue exp2() const noexcept; - [[nodiscard]] ImmValue ldexp(const ImmValue& exp) const noexcept; + [[nodiscard]] ImmValue ldexp(const ImmU32& exp) const noexcept; [[nodiscard]] ImmValue log2() const noexcept; [[nodiscard]] ImmValue clamp(const ImmValue& min, const ImmValue& max) const noexcept; [[nodiscard]] ImmValue floor() const noexcept; @@ -145,7 +197,7 @@ public: [[nodiscard]] ImmValue fract() const noexcept; [[nodiscard]] bool isnan() const noexcept; - [[nodiscard]] static ImmValue fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept; + [[nodiscard]] static ImmValue fma(const ImmF32F64& a, const ImmF32F64& b, const ImmF32F64& c) noexcept; static bool IsSupportedValue(const IR::Value& value) noexcept; private: @@ -193,55 +245,6 @@ public: } }; -using ImmU1 = TypedImmValue; -using ImmU8 = TypedImmValue; -using ImmS8 = TypedImmValue; -using ImmU16 = TypedImmValue; -using ImmS16 = TypedImmValue; -using ImmU32 = TypedImmValue; -using ImmS32 = TypedImmValue; -using ImmF32 = TypedImmValue; -using ImmU64 = TypedImmValue; -using ImmS64 = TypedImmValue; -using ImmF64 = TypedImmValue; -using ImmS32F32 = TypedImmValue; -using ImmS64F64 = TypedImmValue; -using ImmU32U64 = TypedImmValue; -using ImmS32S64 = TypedImmValue; -using ImmU16U32U64 = TypedImmValue; -using ImmS16S32S64 = TypedImmValue; -using ImmF32F64 = TypedImmValue; -using ImmUAny = TypedImmValue; -using ImmSAny = TypedImmValue; -using ImmU32x2 = TypedImmValue; -using ImmU32x3 = TypedImmValue; -using ImmU32x4 = TypedImmValue; -using ImmS32x2 = TypedImmValue; -using ImmS32x3 = TypedImmValue; -using ImmS32x4 = TypedImmValue; -using ImmF32x2 = TypedImmValue; -using ImmF32x3 = TypedImmValue; -using ImmF32x4 = TypedImmValue; -using ImmF64x2 = TypedImmValue; -using ImmF64x3 = TypedImmValue; -using ImmF64x4 = TypedImmValue; -using ImmS32F32x2 = TypedImmValue; -using ImmS32F32x3 = TypedImmValue; -using ImmS32F32x4 = TypedImmValue; -using ImmF32F64x2 = TypedImmValue; -using ImmF32F64x3 = TypedImmValue; -using ImmF32F64x4 = TypedImmValue; -using ImmU32xAny = TypedImmValue; -using ImmS32xAny = TypedImmValue; -using ImmF32xAny = TypedImmValue; -using ImmF64xAny = TypedImmValue; -using ImmS32F32xAny = TypedImmValue; -using ImmF32F64xAny = TypedImmValue; - inline bool ImmValue::IsEmpty() const noexcept { return type == Type::Void; }