Use correct types

This commit is contained in:
Lander Gallastegi 2025-03-11 17:37:50 +01:00 committed by Lander Gallastegi
parent 676b06db0f
commit 79f4648c77
3 changed files with 75 additions and 75 deletions

View File

@ -89,7 +89,7 @@ static void OperationFma(Inst* inst, ImmValueList& inst_values, ComputeImmValues
ComputeImmValues(inst->Arg(2), args2, cache); ComputeImmValues(inst->Arg(2), args2, cache);
const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& c) { const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& c) {
return ImmValue::fma(a, b, c); return ImmValue::fma(ImmF32F64(a), ImmF32F64(b), ImmF32F64(c));
}; };
CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2); CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2);
@ -242,7 +242,7 @@ static void OperationLdexp(Inst* inst, ImmValueList& inst_values, ComputeImmValu
ComputeImmValues(inst->Arg(1), args1, cache); ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) { const auto op = [](const ImmValue& a, const ImmValue& b) {
return a.ldexp(b); return a.ldexp(ImmU32(b));
}; };
CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1); CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1);
@ -346,7 +346,7 @@ static void OperationShiftLeft(Inst* inst, ImmValueList& inst_values, ComputeImm
ComputeImmValues(inst->Arg(1), args1, cache); ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) { const auto op = [](const ImmValue& a, const ImmValue& b) {
return a << b; return a << ImmU32(b);
}; };
SetSigned(args1, false); SetSigned(args1, false);
@ -359,7 +359,7 @@ static void OperationShiftRight(Inst* inst, bool is_signed, ImmValueList& inst_v
ComputeImmValues(inst->Arg(1), args1, cache); ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) { const auto op = [](const ImmValue& a, const ImmValue& b) {
return a >> b; return a >> ImmU32(b);
}; };
SetSigned(args0, is_signed); SetSigned(args0, is_signed);
@ -460,7 +460,7 @@ static void OperationCompositeExtract(Inst* inst, ImmValueList& inst_values, Com
ComputeImmValues(inst->Arg(1), args1, cache); ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) { const auto op = [](const ImmValue& a, const ImmValue& b) {
return a.Extract(b); return a.Extract(ImmU32(b));
}; };
SetSigned(args1, false); SetSigned(args1, false);

View File

@ -345,8 +345,8 @@ ImmValue ImmValue::Bitcast(IR::Type new_type, bool new_signed) const noexcept {
return result; return result;
} }
ImmValue ImmValue::Extract(const ImmValue& index) const noexcept { ImmValue ImmValue::Extract(const ImmU32& index) const noexcept {
ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions()); ASSERT(index.imm_values[0].imm_u32 < Dimensions());
ImmValue result; ImmValue result;
result.type = BaseType(); result.type = BaseType();
result.is_signed = IsSigned(); result.is_signed = IsSigned();
@ -354,8 +354,8 @@ ImmValue ImmValue::Extract(const ImmValue& index) const noexcept {
return result; return result;
} }
ImmValue ImmValue::Insert(const ImmValue& value, const ImmValue& index) const noexcept { ImmValue ImmValue::Insert(const ImmValue& value, const ImmU32& index) const noexcept {
ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions()); ASSERT(index.imm_values[0].imm_u32 < Dimensions());
ASSERT(value.type == BaseType() && value.IsSigned() == IsSigned()); ASSERT(value.type == BaseType() && value.IsSigned() == IsSigned());
ImmValue result = *this; ImmValue result = *this;
result.imm_values[index.imm_values[0].imm_u32] = value.imm_values[0]; result.imm_values[index.imm_values[0].imm_u32] = value.imm_values[0];
@ -879,8 +879,7 @@ ImmValue ImmValue::operator^(const ImmValue& other) const noexcept {
} }
} }
ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept { ImmValue ImmValue::operator<<(const ImmU32& other) const noexcept {
ASSERT(other.type == Type::U32 && other.Dimensions() == 1);
switch (type) { switch (type) {
case Type::U1: case Type::U1:
return ImmValue(imm_values[0].imm_u1 << other.imm_values[0].imm_u1); return ImmValue(imm_values[0].imm_u1 << other.imm_values[0].imm_u1);
@ -905,8 +904,7 @@ ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept {
} }
} }
ImmValue ImmValue::operator>>(const ImmValue& other) const noexcept { ImmValue ImmValue::operator>>(const ImmU32& other) const noexcept {
ASSERT(other.type == Type::U32 && other.Dimensions() == 1);
switch (type) { switch (type) {
case Type::U1: case Type::U1:
return ImmValue(imm_values[0].imm_u1 >> other.imm_values[0].imm_u1); return ImmValue(imm_values[0].imm_u1 >> other.imm_values[0].imm_u1);
@ -1170,13 +1168,13 @@ ImmValue& ImmValue::operator^=(const ImmValue& other) noexcept {
return *this; return *this;
} }
ImmValue& ImmValue::operator<<=(const ImmValue& other) noexcept { ImmValue& ImmValue::operator<<=(const ImmU32& other) noexcept {
ImmValue result = *this << other; ImmValue result = *this << other;
*this = result; *this = result;
return *this; return *this;
} }
ImmValue& ImmValue::operator>>=(const ImmValue& other) noexcept { ImmValue& ImmValue::operator>>=(const ImmU32& other) noexcept {
ImmValue result = *this >> other; ImmValue result = *this >> other;
*this = result; *this = result;
return *this; return *this;
@ -1271,8 +1269,7 @@ ImmValue ImmValue::exp2() const noexcept {
} }
} }
ImmValue ImmValue::ldexp(const ImmValue& exp) const noexcept { ImmValue ImmValue::ldexp(const ImmU32& exp) const noexcept {
ASSERT(type == exp.type);
switch (type) { switch (type) {
case Type::F32: case Type::F32:
return ImmValue(std::ldexp(imm_values[0].imm_f32, exp.imm_values[0].imm_s32)); return ImmValue(std::ldexp(imm_values[0].imm_f32, exp.imm_values[0].imm_s32));
@ -1414,7 +1411,7 @@ bool ImmValue::isnan() const noexcept {
} }
} }
ImmValue ImmValue::fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept { ImmValue ImmValue::fma(const ImmF32F64& a, const ImmF32F64& b, const ImmF32F64& c) noexcept {
ASSERT(a.type == b.type && b.type == c.type); ASSERT(a.type == b.type && b.type == c.type);
switch (a.type) { switch (a.type) {
case Type::F32: case Type::F32:

View File

@ -16,6 +16,58 @@ namespace Shader::IR {
// Live IR::Value but can only hold immediate values. Additionally, can hold vectors of values. // Live IR::Value but can only hold immediate values. Additionally, can hold vectors of values.
// Has arithmetic operations defined for it. Usefull for computing a value at shader compile time. // Has arithmetic operations defined for it. Usefull for computing a value at shader compile time.
template <IR::Type type_, bool is_signed_>
class TypedImmValue;
using ImmU1 = TypedImmValue<Type::U1, false>;
using ImmU8 = TypedImmValue<Type::U8, false>;
using ImmS8 = TypedImmValue<Type::U8, true>;
using ImmU16 = TypedImmValue<Type::U16, false>;
using ImmS16 = TypedImmValue<Type::U16, true>;
using ImmU32 = TypedImmValue<Type::U32, false>;
using ImmS32 = TypedImmValue<Type::U32, true>;
using ImmF32 = TypedImmValue<Type::F32, true>;
using ImmU64 = TypedImmValue<Type::U64, false>;
using ImmS64 = TypedImmValue<Type::U64, true>;
using ImmF64 = TypedImmValue<Type::F64, true>;
using ImmS32F32 = TypedImmValue<Type::U32 | Type::F32, true>;
using ImmS64F64 = TypedImmValue<Type::U64 | Type::F64, true>;
using ImmU32U64 = TypedImmValue<Type::U32 | Type::U64, false>;
using ImmS32S64 = TypedImmValue<Type::U32 | Type::U64, true>;
using ImmU16U32U64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, false>;
using ImmS16S32S64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, true>;
using ImmF32F64 = TypedImmValue<Type::F32 | Type::F64, true>;
using ImmUAny = TypedImmValue<Type::U1 | Type::U8 | Type::U16 | Type::U32 | Type::U64, false>;
using ImmSAny = TypedImmValue<Type::U8 | Type::U16 | Type::U32 | Type::U64, true>;
using ImmU32x2 = TypedImmValue<Type::U32x2, false>;
using ImmU32x3 = TypedImmValue<Type::U32x3, false>;
using ImmU32x4 = TypedImmValue<Type::U32x4, false>;
using ImmS32x2 = TypedImmValue<Type::U32x2, true>;
using ImmS32x3 = TypedImmValue<Type::U32x3, true>;
using ImmS32x4 = TypedImmValue<Type::U32x4, true>;
using ImmF32x2 = TypedImmValue<Type::F32x2, true>;
using ImmF32x3 = TypedImmValue<Type::F32x3, true>;
using ImmF32x4 = TypedImmValue<Type::F32x4, true>;
using ImmF64x2 = TypedImmValue<Type::F64x2, true>;
using ImmF64x3 = TypedImmValue<Type::F64x3, true>;
using ImmF64x4 = TypedImmValue<Type::F64x4, true>;
using ImmS32F32x2 = TypedImmValue<Type::U32x2 | Type::F32x2, true>;
using ImmS32F32x3 = TypedImmValue<Type::U32x3 | Type::F32x3, true>;
using ImmS32F32x4 = TypedImmValue<Type::U32x4 | Type::F32x4, true>;
using ImmF32F64x2 = TypedImmValue<Type::F32x2 | Type::F64x2, true>;
using ImmF32F64x3 = TypedImmValue<Type::F32x3 | Type::F64x3, true>;
using ImmF32F64x4 = TypedImmValue<Type::F32x4 | Type::F64x4, true>;
using ImmU32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, false>;
using ImmS32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, true>;
using ImmF32xAny = TypedImmValue<Type::F32 | Type::F32x2 | Type::F32x3 | Type::F32x4, true>;
using ImmF64xAny = TypedImmValue<Type::F64 | Type::F64x2 | Type::F64x3 | Type::F64x4, true>;
using ImmS32F32xAny = TypedImmValue<Type::U32 | Type::F32 | Type::U32x2 | Type::F32x2 |
Type::U32x3 | Type::F32x3 | Type::U32x4 | Type::F32x4,
true>;
using ImmF32F64xAny = TypedImmValue<Type::F32 | Type::F64 | Type::F32x2 | Type::F64x2 |
Type::F32x3 | Type::F64x3 | Type::F32x4 | Type::F64x4,
true>;
class ImmValue { class ImmValue {
public: public:
ImmValue() noexcept = default; ImmValue() noexcept = default;
@ -59,8 +111,8 @@ public:
[[nodiscard]] ImmValue Convert(IR::Type new_type, bool new_signed) const noexcept; [[nodiscard]] ImmValue Convert(IR::Type new_type, bool new_signed) const noexcept;
[[nodiscard]] ImmValue Bitcast(IR::Type new_type, bool new_signed) const noexcept; [[nodiscard]] ImmValue Bitcast(IR::Type new_type, bool new_signed) const noexcept;
[[nodiscard]] ImmValue Extract(const ImmValue& index) const noexcept; [[nodiscard]] ImmValue Extract(const ImmU32& index) const noexcept;
[[nodiscard]] ImmValue Insert(const ImmValue& value, const ImmValue& index) const noexcept; [[nodiscard]] ImmValue Insert(const ImmValue& value, const ImmU32& indndex) const noexcept;
[[nodiscard]] bool U1() const; [[nodiscard]] bool U1() const;
[[nodiscard]] u8 U8() const; [[nodiscard]] u8 U8() const;
@ -104,8 +156,8 @@ public:
[[nodiscard]] ImmValue operator&(const ImmValue& other) const noexcept; [[nodiscard]] ImmValue operator&(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator|(const ImmValue& other) const noexcept; [[nodiscard]] ImmValue operator|(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator^(const ImmValue& other) const noexcept; [[nodiscard]] ImmValue operator^(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator<<(const ImmValue& other) const noexcept; [[nodiscard]] ImmValue operator<<(const ImmU32& other) const noexcept;
[[nodiscard]] ImmValue operator>>(const ImmValue& other) const noexcept; [[nodiscard]] ImmValue operator>>(const ImmU32& other) const noexcept;
[[nodiscard]] ImmValue operator~() const noexcept; [[nodiscard]] ImmValue operator~() const noexcept;
[[nodiscard]] ImmValue operator++(int) noexcept; [[nodiscard]] ImmValue operator++(int) noexcept;
@ -125,8 +177,8 @@ public:
ImmValue& operator&=(const ImmValue& other) noexcept; ImmValue& operator&=(const ImmValue& other) noexcept;
ImmValue& operator|=(const ImmValue& other) noexcept; ImmValue& operator|=(const ImmValue& other) noexcept;
ImmValue& operator^=(const ImmValue& other) noexcept; ImmValue& operator^=(const ImmValue& other) noexcept;
ImmValue& operator<<=(const ImmValue& other) noexcept; ImmValue& operator<<=(const ImmU32& other) noexcept;
ImmValue& operator>>=(const ImmValue& other) noexcept; ImmValue& operator>>=(const ImmU32& other) noexcept;
[[nodiscard]] ImmValue abs() const noexcept; [[nodiscard]] ImmValue abs() const noexcept;
[[nodiscard]] ImmValue recip() const noexcept; [[nodiscard]] ImmValue recip() const noexcept;
@ -135,7 +187,7 @@ public:
[[nodiscard]] ImmValue sin() const noexcept; [[nodiscard]] ImmValue sin() const noexcept;
[[nodiscard]] ImmValue cos() const noexcept; [[nodiscard]] ImmValue cos() const noexcept;
[[nodiscard]] ImmValue exp2() const noexcept; [[nodiscard]] ImmValue exp2() const noexcept;
[[nodiscard]] ImmValue ldexp(const ImmValue& exp) const noexcept; [[nodiscard]] ImmValue ldexp(const ImmU32& exp) const noexcept;
[[nodiscard]] ImmValue log2() const noexcept; [[nodiscard]] ImmValue log2() const noexcept;
[[nodiscard]] ImmValue clamp(const ImmValue& min, const ImmValue& max) const noexcept; [[nodiscard]] ImmValue clamp(const ImmValue& min, const ImmValue& max) const noexcept;
[[nodiscard]] ImmValue floor() const noexcept; [[nodiscard]] ImmValue floor() const noexcept;
@ -145,7 +197,7 @@ public:
[[nodiscard]] ImmValue fract() const noexcept; [[nodiscard]] ImmValue fract() const noexcept;
[[nodiscard]] bool isnan() const noexcept; [[nodiscard]] bool isnan() const noexcept;
[[nodiscard]] static ImmValue fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept; [[nodiscard]] static ImmValue fma(const ImmF32F64& a, const ImmF32F64& b, const ImmF32F64& c) noexcept;
static bool IsSupportedValue(const IR::Value& value) noexcept; static bool IsSupportedValue(const IR::Value& value) noexcept;
private: private:
@ -193,55 +245,6 @@ public:
} }
}; };
using ImmU1 = TypedImmValue<Type::U1, false>;
using ImmU8 = TypedImmValue<Type::U8, false>;
using ImmS8 = TypedImmValue<Type::U8, true>;
using ImmU16 = TypedImmValue<Type::U16, false>;
using ImmS16 = TypedImmValue<Type::U16, true>;
using ImmU32 = TypedImmValue<Type::U32, false>;
using ImmS32 = TypedImmValue<Type::U32, true>;
using ImmF32 = TypedImmValue<Type::F32, true>;
using ImmU64 = TypedImmValue<Type::U64, false>;
using ImmS64 = TypedImmValue<Type::U64, true>;
using ImmF64 = TypedImmValue<Type::F64, true>;
using ImmS32F32 = TypedImmValue<Type::U32 | Type::F32, true>;
using ImmS64F64 = TypedImmValue<Type::U64 | Type::F64, true>;
using ImmU32U64 = TypedImmValue<Type::U32 | Type::U64, false>;
using ImmS32S64 = TypedImmValue<Type::U32 | Type::U64, true>;
using ImmU16U32U64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, false>;
using ImmS16S32S64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, true>;
using ImmF32F64 = TypedImmValue<Type::F32 | Type::F64, true>;
using ImmUAny = TypedImmValue<Type::U1 | Type::U8 | Type::U16 | Type::U32 | Type::U64, false>;
using ImmSAny = TypedImmValue<Type::U8 | Type::U16 | Type::U32 | Type::U64, true>;
using ImmU32x2 = TypedImmValue<Type::U32x2, false>;
using ImmU32x3 = TypedImmValue<Type::U32x3, false>;
using ImmU32x4 = TypedImmValue<Type::U32x4, false>;
using ImmS32x2 = TypedImmValue<Type::U32x2, true>;
using ImmS32x3 = TypedImmValue<Type::U32x3, true>;
using ImmS32x4 = TypedImmValue<Type::U32x4, true>;
using ImmF32x2 = TypedImmValue<Type::F32x2, true>;
using ImmF32x3 = TypedImmValue<Type::F32x3, true>;
using ImmF32x4 = TypedImmValue<Type::F32x4, true>;
using ImmF64x2 = TypedImmValue<Type::F64x2, true>;
using ImmF64x3 = TypedImmValue<Type::F64x3, true>;
using ImmF64x4 = TypedImmValue<Type::F64x4, true>;
using ImmS32F32x2 = TypedImmValue<Type::U32x2 | Type::F32x2, true>;
using ImmS32F32x3 = TypedImmValue<Type::U32x3 | Type::F32x3, true>;
using ImmS32F32x4 = TypedImmValue<Type::U32x4 | Type::F32x4, true>;
using ImmF32F64x2 = TypedImmValue<Type::F32x2 | Type::F64x2, true>;
using ImmF32F64x3 = TypedImmValue<Type::F32x3 | Type::F64x3, true>;
using ImmF32F64x4 = TypedImmValue<Type::F32x4 | Type::F64x4, true>;
using ImmU32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, false>;
using ImmS32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, true>;
using ImmF32xAny = TypedImmValue<Type::F32 | Type::F32x2 | Type::F32x3 | Type::F32x4, true>;
using ImmF64xAny = TypedImmValue<Type::F64 | Type::F64x2 | Type::F64x3 | Type::F64x4, true>;
using ImmS32F32xAny = TypedImmValue<ImmS32F32::static_type | ImmS32F32x2::static_type |
ImmS32F32x3::static_type | ImmS32F32x4::static_type,
true>;
using ImmF32F64xAny = TypedImmValue<ImmF32F64::static_type | ImmF32F64x2::static_type |
ImmF32F64x3::static_type | ImmF32F64x4::static_type,
true>;
inline bool ImmValue::IsEmpty() const noexcept { inline bool ImmValue::IsEmpty() const noexcept {
return type == Type::Void; return type == Type::Void;
} }