refactor pattern matching and optimize modulos (disabled)

This commit is contained in:
Frodo Baggins 2024-12-07 22:17:49 -08:00
parent 8058377e53
commit b6482f4b0d
3 changed files with 242 additions and 153 deletions

View File

@@ -7,6 +7,7 @@
#include "shader_recompiler/ir/breadth_first_search.h"
#include "shader_recompiler/ir/ir_emitter.h"
#include "shader_recompiler/ir/opcodes.h"
#include "shader_recompiler/ir/pattern_matching.h"
#include "shader_recompiler/ir/program.h"
// TODO delete
@@ -90,114 +91,8 @@ static void DumpIR(IR::Program& program, std::string phase) {
* Must be placed in uniform control flow
*/
// Bad pattern matching attempt
template <typename Derived>
struct MatchObject {
inline bool DoMatch(IR::Value v) {
return static_cast<Derived*>(this)->DoMatch(v);
}
};
struct MatchValue : MatchObject<MatchValue> {
MatchValue(IR::Value& return_val_) : return_val(return_val_) {}
inline bool DoMatch(IR::Value v) {
return_val = v;
return true;
}
private:
IR::Value& return_val;
};
struct MatchIgnore : MatchObject<MatchIgnore> {
MatchIgnore() {}
inline bool DoMatch(IR::Value v) {
return true;
}
};
struct MatchImm : MatchObject<MatchImm> {
MatchImm(IR::Value& v) : return_val(v) {}
inline bool DoMatch(IR::Value v) {
if (!v.IsImmediate()) {
return false;
}
return_val = v;
return true;
}
private:
IR::Value& return_val;
};
// Specific
struct MatchAttribute : MatchObject<MatchAttribute> {
MatchAttribute(IR::Attribute attribute_) : attribute(attribute_) {}
inline bool DoMatch(IR::Value v) {
return v.Type() == IR::Type::Attribute && v.Attribute() == attribute;
}
private:
IR::Attribute attribute;
};
// Specific
struct MatchU32 : MatchObject<MatchU32> {
MatchU32(u32 imm_) : imm(imm_) {}
inline bool DoMatch(IR::Value v) {
return v.Type() == IR::Type::U32 && v.U32() == imm;
}
private:
u32 imm;
};
template <IR::Opcode opcode, typename... Args>
struct MatchInstObject : MatchObject<MatchInstObject<opcode, Args...>> {
static_assert(sizeof...(Args) == IR::NumArgsOf(opcode));
MatchInstObject(Args&&... args) : pattern(std::forward_as_tuple(args...)) {}
inline bool DoMatch(IR::Value v) {
IR::Inst* inst = v.TryInstRecursive();
if (!inst || inst->GetOpcode() != opcode) {
return false;
}
bool matched = true;
[&]<std::size_t... Is>(std::index_sequence<Is...>) {
((matched = matched && std::get<Is>(pattern).DoMatch(inst->Arg(Is))), ...);
}(std::make_index_sequence<sizeof...(Args)>{});
return matched;
}
private:
using MatchArgs = std::tuple<Args&...>;
MatchArgs pattern;
};
template <IR::Opcode opcode, typename... Args>
auto MakeInstPattern(Args&&... args) {
return MatchInstObject<opcode, Args...>(std::forward<Args>(args)...);
}
struct MatchFoldImm : MatchObject<MatchFoldImm> {
MatchFoldImm(IR::Value& v) : return_val(v) {}
inline bool DoMatch(IR::Value v);
private:
IR::Value& return_val;
};
// Represent address as sum of products
// Address calculations look something like this, but can vary wildly due to decisions made by
// the PS4 compiler (instruction selection, etc.)
// Input control point:
// PrimitiveId * input_cp_stride * #cp_per_input_patch + index * input_cp_stride + (attr# * 16 +
// component)
@@ -208,10 +103,10 @@ private:
// #patches * input_cp_stride * #cp_per_input_patch + #patches * output_patch_stride
// + PrimitiveId * per_patch_output_stride + (attr# * 16 + component)
// Sort terms left to right
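// Worked example (illustrative only; all sizes hypothetical): with
// input_cp_stride = 64, #cp_per_input_patch = 3, PrimitiveId = 5, index = 1,
// attr# = 2, component = 1, the input control point address is
// 5 * 64 * 3 + 1 * 64 + (2 * 16 + 1) = 960 + 64 + 33 = 1057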
namespace {
using namespace Shader::Optimization::PatternMatching;
static void InitTessConstants(IR::ScalarReg sharp_ptr_base, s32 sharp_dword_offset,
Shader::Info& info, Shader::RuntimeInfo& runtime_info,
TessellationDataConstantBuffer& tess_constants) {
@@ -239,20 +134,17 @@ std::optional<TessSharpLocation> FindTessConstantSharp(IR::Inst* read_const_buff
IR::Value rv = IR::Value{read_const_buffer};
IR::Value handle = read_const_buffer->Arg(0);
if (MakeInstPattern<IR::Opcode::CompositeConstructU32x4>(
MakeInstPattern<IR::Opcode::GetUserData>(MatchImm(sharp_dword_offset)), MatchIgnore(),
MatchIgnore(), MatchIgnore())
.DoMatch(handle)) {
if (M_COMPOSITECONSTRUCTU32X4(M_GETUSERDATA(MatchImm(sharp_dword_offset)), MatchIgnore(),
MatchIgnore(), MatchIgnore())
.Match(handle)) {
return TessSharpLocation{.ptr_base = IR::ScalarReg::Max,
.dword_off = static_cast<u32>(sharp_dword_offset.ScalarReg())};
} else if (MakeInstPattern<IR::Opcode::CompositeConstructU32x4>(
MakeInstPattern<IR::Opcode::ReadConst>(
MakeInstPattern<IR::Opcode::CompositeConstructU32x2>(
MakeInstPattern<IR::Opcode::GetUserData>(MatchImm(sharp_ptr_base)),
MatchIgnore()),
MatchImm(sharp_dword_offset)),
} else if (M_COMPOSITECONSTRUCTU32X4(
M_READCONST(M_COMPOSITECONSTRUCTU32X2(M_GETUSERDATA(MatchImm(sharp_ptr_base)),
MatchIgnore()),
MatchImm(sharp_dword_offset)),
MatchIgnore(), MatchIgnore(), MatchIgnore())
.DoMatch(handle)) {
.Match(handle)) {
return TessSharpLocation{.ptr_base = sharp_ptr_base.ScalarReg(),
.dword_off = sharp_dword_offset.U32()};
}
@@ -322,32 +214,32 @@ private:
case IR::Opcode::Phi: {
struct PhiCounter {
u16 seq_num;
u8 base_edge;
u8 unique_edge;
u8 counter;
};
PhiCounter count = inst->Flags<PhiCounter>();
ASSERT_MSG(count.counter == 0 || count.base_edge == use.operand);
ASSERT_MSG(count.counter == 0 || count.unique_edge == use.operand);
// The point of seq_num is to tell us whether we've already traversed this
// phi on the current walk; this is how phi cycles are handled.
// Alternatively, we could keep a set of phis seen on the current walk.
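// (e.g. a loop-carried phi can feed a computation whose result flows back
// into the same phi; seq_num lets the walk terminate instead of recursing)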
if (count.seq_num == 0) {
// First time we've encountered this phi
count.seq_num = seq_num;
// Mark the phi as having been traversed through this edge
count.base_edge = use.operand;
// Mark the phi as having been traversed originally through this edge
count.unique_edge = use.operand;
count.counter = inc;
} else if (count.seq_num < seq_num) {
count.seq_num = seq_num;
// Make sure the other phi edge has never been visited before.
// I think the other phi edges should be either undefs
// or self-referential edges introduced by loops.
// TODO better explanation
ASSERT(count.base_edge == use.operand);
// For now, assume we are visiting this phi via the same edge
// as on other walks. If not, some dataflow analysis might be necessary
ASSERT(count.unique_edge == use.operand);
count.counter += inc;
} else {
// count.seq_num == seq_num
// there's a cycle, and we've already been here on this walk
return;
}
// Else: This is some back edge to a previously visited phi, like a
// loop induction variable
inst->SetFlags<PhiCounter>(count);
break;
}
@@ -363,7 +255,7 @@ private:
uint seq_num{1u};
};
enum class AttributeRegion : u32 { InputCP, OutputCP, PatchConst, Unknown };
enum class AttributeRegion : u32 { InputCP, OutputCP, PatchConst };
static AttributeRegion FindRegionKind(IR::Inst* ring_access, const Shader::Info& info,
const Shader::RuntimeInfo& runtime_info) {
@@ -380,6 +272,62 @@ static AttributeRegion FindRegionKind(IR::Inst* ring_access, const Shader::Info&
}
}
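// Conservatively decide whether 'term' is provably a multiple of 'stride':
// an immediate equal to the stride, the stride attributes themselves, a
// low-bitfield extract of a divisible value, or a product with a divisible
// factor all qualify; anything else is assumed not divisible.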
static bool IsDivisibleByStride(IR::Value term, u32 stride) {
IR::Value a, b;
if (MatchU32(stride).Match(term)) {
return true;
} else if (M_GETATTRIBUTEU32(MatchAttribute(IR::Attribute::TcsLsStride), MatchU32(0))
.Match(term) ||
M_GETATTRIBUTEU32(MatchAttribute(IR::Attribute::TcsCpStride), MatchU32(0))
.Match(term)) {
// TODO if we fold in constants earlier (don't produce attributes, instead just emit
// constants) then this case isn't needed. Also should assert that the correct attribute is
// being used depending on stage and whether this is an input or output attribute
return true;
} else if (M_BITFIELDUEXTRACT(MatchValue(a), MatchU32(0), MatchU32(24)).Match(term) ||
M_BITFIELDSEXTRACT(MatchValue(a), MatchU32(0), MatchU32(24)).Match(term)) {
return IsDivisibleByStride(a, stride);
} else if (M_IMUL32(MatchValue(a), MatchValue(b)).Match(term)) {
return IsDivisibleByStride(a, stride) || IsDivisibleByStride(b, stride);
}
return false;
}
// Return true if we can eliminate any addends
static bool TryOptimizeAddendInModulo(IR::Value addend, u32 stride, std::vector<IR::U32>& addends) {
IR::Value a, b;
if (M_IADD32(MatchValue(a), MatchValue(b)).Match(addend)) {
bool ret = false;
ret = TryOptimizeAddendInModulo(a, stride, addends);
ret |= TryOptimizeAddendInModulo(b, stride, addends);
return ret;
} else if (!IsDivisibleByStride(addend, stride)) {
addends.push_back(IR::U32{addend});
return false;
} else {
return true;
}
}
// In the calculation addr = (a + b + ...) % stride,
// use the fact that
// (a + b) mod N = (a mod N + b mod N) mod N.
// If any addend is divisible by stride, then we can replace it with 0 in the attribute
// or component index calculation
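// Example (illustrative): with stride = 16 and addr = PrimitiveId * 16 + off,
// (PrimitiveId * 16 + off) % 16 == ((PrimitiveId * 16) % 16 + off % 16) % 16
// == off % 16, so the PrimitiveId * 16 addend can be dropped from addr
// before the IMod is emitted.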
static IR::U32 TryOptimizeAddressModulo(IR::U32 addr, u32 stride, IR::IREmitter& ir) {
#if 0
std::vector<IR::U32> addends;
if (TryOptimizeAddendInModulo(addr, stride, addends)) {
addr = ir.Imm32(0);
for (auto& addend : addends) {
addr = ir.IAdd(addr, addend);
}
LOG_INFO(Render_Recompiler, "OPTIMIZED attr index");
}
#endif
return addr;
}
} // namespace
void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
@@ -472,11 +420,15 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
if (output_kind == AttributeRegion::OutputCP) {
// Invocation ID array index is implicit, handled by SPIRV backend
IR::U32 addr_for_attrs = TryOptimizeAddressModulo(
addr, runtime_info.hs_info.hs_output_cp_stride, ir);
IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)),
ir.IMod(addr_for_attrs,
ir.Imm32(runtime_info.hs_info.hs_output_cp_stride)),
ir.Imm32(4u));
IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u));
IR::U32 comp_index = ir.ShiftRightLogical(
ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u));
ir.SetTcsGenericAttribute(data, attr_index, comp_index);
} else {
ASSERT(output_kind == AttributeRegion::PatchConst);
@@ -507,11 +459,19 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
case AttributeRegion::InputCP: {
IR::U32 control_point_index =
ir.IDiv(addr, ir.Imm32(runtime_info.hs_info.ls_stride));
if (info.pgm_hash == 0xb5fb5174) {
printf("here\n");
}
IR::U32 addr_for_attrs =
TryOptimizeAddressModulo(addr, runtime_info.hs_info.ls_stride, ir);
IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsLsStride)),
ir.IMod(addr_for_attrs, ir.Imm32(runtime_info.hs_info.ls_stride)),
ir.Imm32(4u));
IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u));
IR::U32 comp_index = ir.ShiftRightLogical(
ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u));
IR::Value attr_read =
ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index);
attr_read = ir.BitCast<IR::U32>(IR::F32{attr_read});
@@ -590,11 +550,15 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
case AttributeRegion::OutputCP: {
IR::U32 control_point_index =
ir.IDiv(addr, ir.Imm32(runtime_info.vs_info.hs_output_cp_stride));
IR::U32 addr_for_attrs = TryOptimizeAddressModulo(
addr, runtime_info.vs_info.hs_output_cp_stride, ir);
IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)),
ir.IMod(addr_for_attrs, ir.Imm32(runtime_info.vs_info.hs_output_cp_stride)),
ir.Imm32(4u));
IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u));
IR::U32 comp_index = ir.ShiftRightLogical(
ir.BitwiseAnd(addr_for_attrs, ir.Imm32(0xFU)), ir.Imm32(2u));
IR::Value attr_read =
ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index);
attr_read = ir.BitCast<IR::U32>(IR::F32{attr_read});
@@ -602,7 +566,7 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
break;
}
case AttributeRegion::PatchConst: {
// TODO if assert fails then make patch consts into dynamic offset
// TODO if assert fails then make generic patch attrs into array and dyn index
ASSERT_MSG(addr.IsImmediate(), "patch addr non imm, inst {}",
fmt::ptr(addr.Inst()));
IR::Value get_patch = ir.GetPatch(IR::PatchGeneric(addr.U32() >> 2));
@@ -655,8 +619,8 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) {
InitTessConstants(sharp_location->ptr_base,
static_cast<s32>(sharp_location->dword_off), info,
runtime_info, tess_constants);
// break; TODO
continue;
break; // TODO
// continue;
}
UNREACHABLE_MSG("Failed to match tess constant sharp");
}
@@ -718,23 +682,23 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) {
for (IR::Block* block : program.blocks) {
for (auto it = block->Instructions().begin(); it != block->Instructions().end(); it++) {
IR::Inst& inst = *it;
if (MakeInstPattern<IR::Opcode::BitFieldUExtract>(
MakeInstPattern<IR::Opcode::GetAttributeU32>(
MatchAttribute(IR::Attribute::PackedHullInvocationInfo), MatchIgnore()),
if (M_BITFIELDUEXTRACT(
M_GETATTRIBUTEU32(MatchAttribute(IR::Attribute::PackedHullInvocationInfo),
MatchIgnore()),
MatchU32(0), MatchU32(8))
.DoMatch(IR::Value{&inst})) {
.Match(IR::Value{&inst})) {
IR::IREmitter emit(*block, it);
// IR::Value replacement =
// emit.GetAttributeU32(IR::Attribute::TessPatchIdInVgt);
// TODO should be fine but check this
IR::Value replacement(0u);
inst.ReplaceUsesWithAndRemove(replacement);
} else if (MakeInstPattern<IR::Opcode::BitFieldUExtract>(
MakeInstPattern<IR::Opcode::GetAttributeU32>(
} else if (M_BITFIELDUEXTRACT(
M_GETATTRIBUTEU32(
MatchAttribute(IR::Attribute::PackedHullInvocationInfo),
MatchIgnore()),
MatchU32(8), MatchU32(5))
.DoMatch(IR::Value{&inst})) {
.Match(IR::Value{&inst})) {
IR::IREmitter ir(*block, it);
IR::Value replacement;
if (runtime_info.hs_info.IsPassthrough()) {

View File

@@ -0,0 +1,124 @@
#include "shader_recompiler/ir/attribute.h"
#include "shader_recompiler/ir/value.h"
namespace Shader::Optimization::PatternMatching {
// Bad pattern matching attempt
template <typename Derived>
struct MatchObject {
inline bool Match(IR::Value v) {
return static_cast<Derived*>(this)->Match(v);
}
};
struct MatchValue : MatchObject<MatchValue> {
MatchValue(IR::Value& return_val_) : return_val(return_val_) {}
inline bool Match(IR::Value v) {
return_val = v;
return true;
}
private:
IR::Value& return_val;
};
struct MatchIgnore : MatchObject<MatchIgnore> {
MatchIgnore() {}
inline bool Match(IR::Value v) {
return true;
}
};
struct MatchImm : MatchObject<MatchImm> {
MatchImm(IR::Value& v) : return_val(v) {}
inline bool Match(IR::Value v) {
if (!v.IsImmediate()) {
return false;
}
return_val = v;
return true;
}
private:
IR::Value& return_val;
};
// Specific
struct MatchAttribute : MatchObject<MatchAttribute> {
MatchAttribute(IR::Attribute attribute_) : attribute(attribute_) {}
inline bool Match(IR::Value v) {
return v.Type() == IR::Type::Attribute && v.Attribute() == attribute;
}
private:
IR::Attribute attribute;
};
// Specific
struct MatchU32 : MatchObject<MatchU32> {
MatchU32(u32 imm_) : imm(imm_) {}
inline bool Match(IR::Value v) {
return v.IsImmediate() && v.Type() == IR::Type::U32 && v.U32() == imm;
}
private:
u32 imm;
};
template <IR::Opcode opcode, typename... Args>
struct MatchInstObject : MatchObject<MatchInstObject<opcode, Args...>> {
static_assert(sizeof...(Args) == IR::NumArgsOf(opcode));
MatchInstObject(Args&&... args) : pattern(std::forward_as_tuple(args...)) {}
inline bool Match(IR::Value v) {
IR::Inst* inst = v.TryInstRecursive();
if (!inst || inst->GetOpcode() != opcode) {
return false;
}
bool matched = true;
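// Fold over the index sequence: apply each sub-matcher to the corresponding
// instruction argument; the && short-circuits once one sub-matcher fails.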
[&]<std::size_t... Is>(std::index_sequence<Is...>) {
((matched = matched && std::get<Is>(pattern).Match(inst->Arg(Is))), ...);
}(std::make_index_sequence<sizeof...(Args)>{});
return matched;
}
private:
using MatchArgs = std::tuple<Args&...>;
MatchArgs pattern;
};
template <IR::Opcode opcode, typename... Args>
inline auto MakeInstPattern(Args&&... args) {
return MatchInstObject<opcode, Args...>(std::forward<Args>(args)...);
}
// Conveniences. TODO maybe delete
#define M_READCONST(...) MakeInstPattern<IR::Opcode::ReadConst>(__VA_ARGS__)
#define M_GETUSERDATA(...) MakeInstPattern<IR::Opcode::GetUserData>(__VA_ARGS__)
#define M_BITFIELDUEXTRACT(...) MakeInstPattern<IR::Opcode::BitFieldUExtract>(__VA_ARGS__)
#define M_BITFIELDSEXTRACT(...) MakeInstPattern<IR::Opcode::BitFieldSExtract>(__VA_ARGS__)
#define M_GETATTRIBUTEU32(...) MakeInstPattern<IR::Opcode::GetAttributeU32>(__VA_ARGS__)
#define M_UMOD32(...) MakeInstPattern<IR::Opcode::UMod32>(__VA_ARGS__)
#define M_SHIFTRIGHTLOGICAL32(...) MakeInstPattern<IR::Opcode::ShiftRightLogical32>(__VA_ARGS__)
#define M_IADD32(...) MakeInstPattern<IR::Opcode::IAdd32>(__VA_ARGS__)
#define M_IMUL32(...) MakeInstPattern<IR::Opcode::IMul32>(__VA_ARGS__)
#define M_BITWISEAND32(...) MakeInstPattern<IR::Opcode::BitwiseAnd32>(__VA_ARGS__)
#define M_GETTESSGENERICATTRIBUTE(...) \
MakeInstPattern<IR::Opcode::GetTessGenericAttribute>(__VA_ARGS__)
#define M_SETTCSGENERICATTRIBUTE(...) \
MakeInstPattern<IR::Opcode::SetTcsGenericAttribute>(__VA_ARGS__)
#define M_COMPOSITECONSTRUCTU32X2(...) \
MakeInstPattern<IR::Opcode::CompositeConstructU32x2>(__VA_ARGS__)
#define M_COMPOSITECONSTRUCTU32X4(...) \
MakeInstPattern<IR::Opcode::CompositeConstructU32x4>(__VA_ARGS__)
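// Usage sketch (illustrative; 'v' is a hypothetical IR::Value under test):
// IR::Value x;
// if (M_IADD32(MatchValue(x), MatchU32(4)).Match(v)) {
// // v is an IAdd32 whose second operand is the immediate 4;
// // x now holds the first operand.
// }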
} // namespace Shader::Optimization::PatternMatching

View File

@@ -85,7 +85,7 @@ IR::Program TranslateProgram(std::span<const u32> code, Pools& pools, Info& info
Shader::Optimization::SsaRewritePass(program.post_order_blocks);
Shader::Optimization::IdentityRemovalPass(program.blocks);
Shader::Optimization::ConstantPropagationPass(
program.post_order_blocks); // TODO const fold spam for now dumpMatchingIR("post_ssa");
program.post_order_blocks); // TODO const fold spam for now while testing
if (stage == Stage::Hull) {
Shader::Optimization::TessellationPreprocess(program, runtime_info);
Shader::Optimization::ConstantPropagationPass(program.post_order_blocks);
@@ -93,6 +93,7 @@ IR::Program TranslateProgram(std::span<const u32> code, Pools& pools, Info& info
Shader::Optimization::HullShaderTransform(program, runtime_info);
dumpMatchingIR("post_hull");
Shader::Optimization::TessellationPostprocess(program, runtime_info);
dumpMatchingIR("post_hull_postprocess");
} else if (info.l_stage == LogicalStage::TessellationEval) {
Shader::Optimization::TessellationPreprocess(program, runtime_info);
Shader::Optimization::ConstantPropagationPass(program.post_order_blocks);