Use correct types

This commit is contained in:
Lander Gallastegi 2025-03-11 17:37:50 +01:00 committed by Lander Gallastegi
parent 676b06db0f
commit 79f4648c77
3 changed files with 75 additions and 75 deletions

View File

@ -89,7 +89,7 @@ static void OperationFma(Inst* inst, ImmValueList& inst_values, ComputeImmValues
ComputeImmValues(inst->Arg(2), args2, cache);
const auto op = [](const ImmValue& a, const ImmValue& b, const ImmValue& c) {
return ImmValue::fma(a, b, c);
return ImmValue::fma(ImmF32F64(a), ImmF32F64(b), ImmF32F64(c));
};
CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1, args2);
@ -242,7 +242,7 @@ static void OperationLdexp(Inst* inst, ImmValueList& inst_values, ComputeImmValu
ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) {
return a.ldexp(b);
return a.ldexp(ImmU32(b));
};
CartesianInvoke(op, std::inserter(inst_values, inst_values.begin()), args0, args1);
@ -346,7 +346,7 @@ static void OperationShiftLeft(Inst* inst, ImmValueList& inst_values, ComputeImm
ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) {
return a << b;
return a << ImmU32(b);
};
SetSigned(args1, false);
@ -359,7 +359,7 @@ static void OperationShiftRight(Inst* inst, bool is_signed, ImmValueList& inst_v
ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) {
return a >> b;
return a >> ImmU32(b);
};
SetSigned(args0, is_signed);
@ -460,7 +460,7 @@ static void OperationCompositeExtract(Inst* inst, ImmValueList& inst_values, Com
ComputeImmValues(inst->Arg(1), args1, cache);
const auto op = [](const ImmValue& a, const ImmValue& b) {
return a.Extract(b);
return a.Extract(ImmU32(b));
};
SetSigned(args1, false);

View File

@ -345,8 +345,8 @@ ImmValue ImmValue::Bitcast(IR::Type new_type, bool new_signed) const noexcept {
return result;
}
ImmValue ImmValue::Extract(const ImmValue& index) const noexcept {
ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions());
ImmValue ImmValue::Extract(const ImmU32& index) const noexcept {
ASSERT(index.imm_values[0].imm_u32 < Dimensions());
ImmValue result;
result.type = BaseType();
result.is_signed = IsSigned();
@ -354,8 +354,8 @@ ImmValue ImmValue::Extract(const ImmValue& index) const noexcept {
return result;
}
ImmValue ImmValue::Insert(const ImmValue& value, const ImmValue& index) const noexcept {
ASSERT(index.type == Type::U32 && !index.is_signed && index.imm_values[0].imm_u32 < Dimensions());
ImmValue ImmValue::Insert(const ImmValue& value, const ImmU32& index) const noexcept {
ASSERT(index.imm_values[0].imm_u32 < Dimensions());
ASSERT(value.type == BaseType() && value.IsSigned() == IsSigned());
ImmValue result = *this;
result.imm_values[index.imm_values[0].imm_u32] = value.imm_values[0];
@ -879,8 +879,7 @@ ImmValue ImmValue::operator^(const ImmValue& other) const noexcept {
}
}
ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept {
ASSERT(other.type == Type::U32 && other.Dimensions() == 1);
ImmValue ImmValue::operator<<(const ImmU32& other) const noexcept {
switch (type) {
case Type::U1:
return ImmValue(imm_values[0].imm_u1 << other.imm_values[0].imm_u1);
@ -905,8 +904,7 @@ ImmValue ImmValue::operator<<(const ImmValue& other) const noexcept {
}
}
ImmValue ImmValue::operator>>(const ImmValue& other) const noexcept {
ASSERT(other.type == Type::U32 && other.Dimensions() == 1);
ImmValue ImmValue::operator>>(const ImmU32& other) const noexcept {
switch (type) {
case Type::U1:
return ImmValue(imm_values[0].imm_u1 >> other.imm_values[0].imm_u1);
@ -1170,13 +1168,13 @@ ImmValue& ImmValue::operator^=(const ImmValue& other) noexcept {
return *this;
}
ImmValue& ImmValue::operator<<=(const ImmValue& other) noexcept {
ImmValue& ImmValue::operator<<=(const ImmU32& other) noexcept {
ImmValue result = *this << other;
*this = result;
return *this;
}
ImmValue& ImmValue::operator>>=(const ImmValue& other) noexcept {
ImmValue& ImmValue::operator>>=(const ImmU32& other) noexcept {
ImmValue result = *this >> other;
*this = result;
return *this;
@ -1271,8 +1269,7 @@ ImmValue ImmValue::exp2() const noexcept {
}
}
ImmValue ImmValue::ldexp(const ImmValue& exp) const noexcept {
ASSERT(type == exp.type);
ImmValue ImmValue::ldexp(const ImmU32& exp) const noexcept {
switch (type) {
case Type::F32:
return ImmValue(std::ldexp(imm_values[0].imm_f32, exp.imm_values[0].imm_s32));
@ -1414,7 +1411,7 @@ bool ImmValue::isnan() const noexcept {
}
}
ImmValue ImmValue::fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept {
ImmValue ImmValue::fma(const ImmF32F64& a, const ImmF32F64& b, const ImmF32F64& c) noexcept {
ASSERT(a.type == b.type && b.type == c.type);
switch (a.type) {
case Type::F32:

View File

@ -16,6 +16,58 @@ namespace Shader::IR {
// Live IR::Value but can only hold immediate values. Additionally, can hold vectors of values.
// Has arithmetic operations defined for it. Usefull for computing a value at shader compile time.
template <IR::Type type_, bool is_signed_>
class TypedImmValue;
using ImmU1 = TypedImmValue<Type::U1, false>;
using ImmU8 = TypedImmValue<Type::U8, false>;
using ImmS8 = TypedImmValue<Type::U8, true>;
using ImmU16 = TypedImmValue<Type::U16, false>;
using ImmS16 = TypedImmValue<Type::U16, true>;
using ImmU32 = TypedImmValue<Type::U32, false>;
using ImmS32 = TypedImmValue<Type::U32, true>;
using ImmF32 = TypedImmValue<Type::F32, true>;
using ImmU64 = TypedImmValue<Type::U64, false>;
using ImmS64 = TypedImmValue<Type::U64, true>;
using ImmF64 = TypedImmValue<Type::F64, true>;
using ImmS32F32 = TypedImmValue<Type::U32 | Type::F32, true>;
using ImmS64F64 = TypedImmValue<Type::U64 | Type::F64, true>;
using ImmU32U64 = TypedImmValue<Type::U32 | Type::U64, false>;
using ImmS32S64 = TypedImmValue<Type::U32 | Type::U64, true>;
using ImmU16U32U64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, false>;
using ImmS16S32S64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, true>;
using ImmF32F64 = TypedImmValue<Type::F32 | Type::F64, true>;
using ImmUAny = TypedImmValue<Type::U1 | Type::U8 | Type::U16 | Type::U32 | Type::U64, false>;
using ImmSAny = TypedImmValue<Type::U8 | Type::U16 | Type::U32 | Type::U64, true>;
using ImmU32x2 = TypedImmValue<Type::U32x2, false>;
using ImmU32x3 = TypedImmValue<Type::U32x3, false>;
using ImmU32x4 = TypedImmValue<Type::U32x4, false>;
using ImmS32x2 = TypedImmValue<Type::U32x2, true>;
using ImmS32x3 = TypedImmValue<Type::U32x3, true>;
using ImmS32x4 = TypedImmValue<Type::U32x4, true>;
using ImmF32x2 = TypedImmValue<Type::F32x2, true>;
using ImmF32x3 = TypedImmValue<Type::F32x3, true>;
using ImmF32x4 = TypedImmValue<Type::F32x4, true>;
using ImmF64x2 = TypedImmValue<Type::F64x2, true>;
using ImmF64x3 = TypedImmValue<Type::F64x3, true>;
using ImmF64x4 = TypedImmValue<Type::F64x4, true>;
using ImmS32F32x2 = TypedImmValue<Type::U32x2 | Type::F32x2, true>;
using ImmS32F32x3 = TypedImmValue<Type::U32x3 | Type::F32x3, true>;
using ImmS32F32x4 = TypedImmValue<Type::U32x4 | Type::F32x4, true>;
using ImmF32F64x2 = TypedImmValue<Type::F32x2 | Type::F64x2, true>;
using ImmF32F64x3 = TypedImmValue<Type::F32x3 | Type::F64x3, true>;
using ImmF32F64x4 = TypedImmValue<Type::F32x4 | Type::F64x4, true>;
using ImmU32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, false>;
using ImmS32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, true>;
using ImmF32xAny = TypedImmValue<Type::F32 | Type::F32x2 | Type::F32x3 | Type::F32x4, true>;
using ImmF64xAny = TypedImmValue<Type::F64 | Type::F64x2 | Type::F64x3 | Type::F64x4, true>;
using ImmS32F32xAny = TypedImmValue<Type::U32 | Type::F32 | Type::U32x2 | Type::F32x2 |
Type::U32x3 | Type::F32x3 | Type::U32x4 | Type::F32x4,
true>;
using ImmF32F64xAny = TypedImmValue<Type::F32 | Type::F64 | Type::F32x2 | Type::F64x2 |
Type::F32x3 | Type::F64x3 | Type::F32x4 | Type::F64x4,
true>;
class ImmValue {
public:
ImmValue() noexcept = default;
@ -59,8 +111,8 @@ public:
[[nodiscard]] ImmValue Convert(IR::Type new_type, bool new_signed) const noexcept;
[[nodiscard]] ImmValue Bitcast(IR::Type new_type, bool new_signed) const noexcept;
[[nodiscard]] ImmValue Extract(const ImmValue& index) const noexcept;
[[nodiscard]] ImmValue Insert(const ImmValue& value, const ImmValue& index) const noexcept;
[[nodiscard]] ImmValue Extract(const ImmU32& index) const noexcept;
[[nodiscard]] ImmValue Insert(const ImmValue& value, const ImmU32& indndex) const noexcept;
[[nodiscard]] bool U1() const;
[[nodiscard]] u8 U8() const;
@ -104,8 +156,8 @@ public:
[[nodiscard]] ImmValue operator&(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator|(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator^(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator<<(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator>>(const ImmValue& other) const noexcept;
[[nodiscard]] ImmValue operator<<(const ImmU32& other) const noexcept;
[[nodiscard]] ImmValue operator>>(const ImmU32& other) const noexcept;
[[nodiscard]] ImmValue operator~() const noexcept;
[[nodiscard]] ImmValue operator++(int) noexcept;
@ -125,8 +177,8 @@ public:
ImmValue& operator&=(const ImmValue& other) noexcept;
ImmValue& operator|=(const ImmValue& other) noexcept;
ImmValue& operator^=(const ImmValue& other) noexcept;
ImmValue& operator<<=(const ImmValue& other) noexcept;
ImmValue& operator>>=(const ImmValue& other) noexcept;
ImmValue& operator<<=(const ImmU32& other) noexcept;
ImmValue& operator>>=(const ImmU32& other) noexcept;
[[nodiscard]] ImmValue abs() const noexcept;
[[nodiscard]] ImmValue recip() const noexcept;
@ -135,7 +187,7 @@ public:
[[nodiscard]] ImmValue sin() const noexcept;
[[nodiscard]] ImmValue cos() const noexcept;
[[nodiscard]] ImmValue exp2() const noexcept;
[[nodiscard]] ImmValue ldexp(const ImmValue& exp) const noexcept;
[[nodiscard]] ImmValue ldexp(const ImmU32& exp) const noexcept;
[[nodiscard]] ImmValue log2() const noexcept;
[[nodiscard]] ImmValue clamp(const ImmValue& min, const ImmValue& max) const noexcept;
[[nodiscard]] ImmValue floor() const noexcept;
@ -145,7 +197,7 @@ public:
[[nodiscard]] ImmValue fract() const noexcept;
[[nodiscard]] bool isnan() const noexcept;
[[nodiscard]] static ImmValue fma(const ImmValue& a, const ImmValue& b, const ImmValue& c) noexcept;
[[nodiscard]] static ImmValue fma(const ImmF32F64& a, const ImmF32F64& b, const ImmF32F64& c) noexcept;
static bool IsSupportedValue(const IR::Value& value) noexcept;
private:
@ -193,55 +245,6 @@ public:
}
};
using ImmU1 = TypedImmValue<Type::U1, false>;
using ImmU8 = TypedImmValue<Type::U8, false>;
using ImmS8 = TypedImmValue<Type::U8, true>;
using ImmU16 = TypedImmValue<Type::U16, false>;
using ImmS16 = TypedImmValue<Type::U16, true>;
using ImmU32 = TypedImmValue<Type::U32, false>;
using ImmS32 = TypedImmValue<Type::U32, true>;
using ImmF32 = TypedImmValue<Type::F32, true>;
using ImmU64 = TypedImmValue<Type::U64, false>;
using ImmS64 = TypedImmValue<Type::U64, true>;
using ImmF64 = TypedImmValue<Type::F64, true>;
using ImmS32F32 = TypedImmValue<Type::U32 | Type::F32, true>;
using ImmS64F64 = TypedImmValue<Type::U64 | Type::F64, true>;
using ImmU32U64 = TypedImmValue<Type::U32 | Type::U64, false>;
using ImmS32S64 = TypedImmValue<Type::U32 | Type::U64, true>;
using ImmU16U32U64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, false>;
using ImmS16S32S64 = TypedImmValue<Type::U16 | Type::U32 | Type::U64, true>;
using ImmF32F64 = TypedImmValue<Type::F32 | Type::F64, true>;
using ImmUAny = TypedImmValue<Type::U1 | Type::U8 | Type::U16 | Type::U32 | Type::U64, false>;
using ImmSAny = TypedImmValue<Type::U8 | Type::U16 | Type::U32 | Type::U64, true>;
using ImmU32x2 = TypedImmValue<Type::U32x2, false>;
using ImmU32x3 = TypedImmValue<Type::U32x3, false>;
using ImmU32x4 = TypedImmValue<Type::U32x4, false>;
using ImmS32x2 = TypedImmValue<Type::U32x2, true>;
using ImmS32x3 = TypedImmValue<Type::U32x3, true>;
using ImmS32x4 = TypedImmValue<Type::U32x4, true>;
using ImmF32x2 = TypedImmValue<Type::F32x2, true>;
using ImmF32x3 = TypedImmValue<Type::F32x3, true>;
using ImmF32x4 = TypedImmValue<Type::F32x4, true>;
using ImmF64x2 = TypedImmValue<Type::F64x2, true>;
using ImmF64x3 = TypedImmValue<Type::F64x3, true>;
using ImmF64x4 = TypedImmValue<Type::F64x4, true>;
using ImmS32F32x2 = TypedImmValue<Type::U32x2 | Type::F32x2, true>;
using ImmS32F32x3 = TypedImmValue<Type::U32x3 | Type::F32x3, true>;
using ImmS32F32x4 = TypedImmValue<Type::U32x4 | Type::F32x4, true>;
using ImmF32F64x2 = TypedImmValue<Type::F32x2 | Type::F64x2, true>;
using ImmF32F64x3 = TypedImmValue<Type::F32x3 | Type::F64x3, true>;
using ImmF32F64x4 = TypedImmValue<Type::F32x4 | Type::F64x4, true>;
using ImmU32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, false>;
using ImmS32xAny = TypedImmValue<Type::U32 | Type::U32x2 | Type::U32x3 | Type::U32x4, true>;
using ImmF32xAny = TypedImmValue<Type::F32 | Type::F32x2 | Type::F32x3 | Type::F32x4, true>;
using ImmF64xAny = TypedImmValue<Type::F64 | Type::F64x2 | Type::F64x3 | Type::F64x4, true>;
using ImmS32F32xAny = TypedImmValue<ImmS32F32::static_type | ImmS32F32x2::static_type |
ImmS32F32x3::static_type | ImmS32F32x4::static_type,
true>;
using ImmF32F64xAny = TypedImmValue<ImmF32F64::static_type | ImmF32F64x2::static_type |
ImmF32F64x3::static_type | ImmF32F64x4::static_type,
true>;
inline bool ImmValue::IsEmpty() const noexcept {
return type == Type::Void;
}