From 2078da1f6d0fef8022e3d4a93b4eba63c6d6884b Mon Sep 17 00:00:00 2001 From: Frodo Baggins Date: Sun, 1 Dec 2024 21:53:06 -0800 Subject: [PATCH] save some work --- .../ir/passes/constant_propagation_pass.cpp | 36 ++- .../ir/passes/hull_shader_transform.cpp | 229 +++++++++++++++--- src/shader_recompiler/recompiler.cpp | 4 +- src/shader_recompiler/runtime_info.h | 2 +- 4 files changed, 229 insertions(+), 42 deletions(-) diff --git a/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp b/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp index 6a27cba04..71f4a2c42 100644 --- a/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir/passes/constant_propagation_pass.cpp @@ -216,6 +216,18 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) { } } +template +void FoldMul(IR::Block& block, IR::Inst& inst) { + if (!FoldCommutative(inst, [](T a, T b) { return a * b; })) { + return; + } + const IR::Value rhs{inst.Arg(1)}; + if (rhs.IsImmediate() && Arg(rhs) == 0) { + inst.ReplaceUsesWithAndRemove(IR::Value(0u)); + return; + } +} + void FoldCmpClass(IR::Block& block, IR::Inst& inst) { ASSERT_MSG(inst.Arg(1).IsImmediate(), "Unable to resolve compare operation"); const auto class_mask = static_cast(inst.Arg(1).U32()); @@ -281,6 +293,14 @@ void FoldReadLane(IR::Block& block, IR::Inst& inst) { } } +void FoldTessAttrAccess(IR::Inst& inst) { + if (inst.GetOpcode() == IR::Opcode::GetTessGenericAttribute) { + // Fold the vertex index + } + // Fold the attr index + // Fold the component index +} + void ConstantPropagation(IR::Block& block, IR::Inst& inst) { switch (inst.GetOpcode()) { case IR::Opcode::IAdd32: @@ -292,10 +312,19 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) { FoldWhenAllImmediates(inst, [](u32 a) { return static_cast(a); }); return; case IR::Opcode::IMul32: - FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a * b; }); + FoldMul(block, inst); return; case IR::Opcode::UDiv32: - FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a / b; }); + FoldWhenAllImmediates(inst, [](u32 a, u32 b) { + ASSERT_MSG(b != 0, "Folding UDiv32 with divisor 0"); + return a / b; + }); + return; + case IR::Opcode::UMod32: + FoldWhenAllImmediates(inst, [](u32 a, u32 b) { + ASSERT_MSG(b != 0, "Folding UMod32 with modulo 0"); + return a % b; + }); return; case IR::Opcode::FPCmpClass32: FoldCmpClass(block, inst); @@ -452,6 +481,9 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) { return FoldConvert(inst, IR::Opcode::ConvertF16F32); case IR::Opcode::ConvertF16F32: return FoldConvert(inst, IR::Opcode::ConvertF32F16); + case IR::Opcode::GetTessGenericAttribute: + case IR::Opcode::SetTcsGenericAttribute: + return FoldTessAttrAccess(inst); default: break; } diff --git a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp index 81598ee84..a67ce640d 100644 --- a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp +++ b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp @@ -1,13 +1,18 @@ // SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later #include +#include "common/assert.h" +#include "shader_recompiler/info.h" +#include "shader_recompiler/ir/attribute.h" #include "shader_recompiler/ir/breadth_first_search.h" #include "shader_recompiler/ir/ir_emitter.h" +#include "shader_recompiler/ir/opcodes.h" #include "shader_recompiler/ir/program.h" // TODO delelte #include "common/io_file.h" #include "common/path_util.h" +#include "shader_recompiler/runtime_info.h" namespace Shader::Optimization { @@ -251,13 +256,131 @@ std::optional FindTessConstantSharp(IR::Inst* read_const_buff return TessSharpLocation{.ptr_base = sharp_ptr_base.ScalarReg(), .dword_off = sharp_dword_offset.U32()}; } - UNREACHABLE_MSG("failed to match tess constants sharp buf"); return {}; } -static IR::Program* g_program; // TODO delete +// Walker that helps deduce what type of attribute a DS instruction is reading +// or writing, which could be an input control point, output control point, +// or per-patch constant (PatchConst). +// For certain ReadConstBuffer instructions using the tess constants V#, +// which we preprocess and transform into a named GetAttribute, we visit the users +// recursively and increment a counter on the Load/WriteShared users. +// Namely TcsNumPatches (from m_hsNumPatch), TcsOutputBase (m_hsOutputBase), +// and TcsPatchConstBase (m_patchConstBase). +// In addr calculations, the term TcsNumPatches * ls_stride * #input_cp_in_patch +// is used as an addend to skip the region for input control points, and similarly +// TcsNumPatches * hs_cp_stride * #output_cp_in_patch is used to skip the region +// for output control points. +// The Input CP, Output CP, and PatchConst regions are laid out in that order for the +// entire thread group, so seeing the TcsNumPatches attribute used in an addr calc should +// increment the "region counter" by 1 for the given Load/WriteShared +// +// TODO this will break if AMD compiler used distributive property like +// TcsNumPatches * (ls_stride * #input_cp_in_patch + hs_cp_stride * #output_cp_in_patch) +// +// TODO can we just look at address post-constant folding, pull out all the constants +// and find the interval it's inside of? (phis are still a problem here) +class TessConstantUseWalker { +public: + void MarkTessAttributeUsers(IR::Inst* get_attribute) { + uint inc; + switch (get_attribute->Arg(0).Attribute()) { + case IR::Attribute::TcsNumPatches: + case IR::Attribute::TcsOutputBase: + inc = 1; + break; + case IR::Attribute::TcsPatchConstBase: + inc = 2; + break; + default: + return; + } -enum AttributeRegion { InputCP, OutputCP, PatchConst, Unknown }; + for (IR::Use use : get_attribute->Uses()) { + MarkTessAttributeUsersHelper(use, inc); + } + + ++seq_num; + } + +private: + void MarkTessAttributeUsersHelper(IR::Use use, uint inc) { + IR::Inst* inst = use.user; + + switch (use.user->GetOpcode()) { + case IR::Opcode::LoadSharedU32: + case IR::Opcode::LoadSharedU64: + case IR::Opcode::LoadSharedU128: + case IR::Opcode::WriteSharedU32: + case IR::Opcode::WriteSharedU64: + case IR::Opcode::WriteSharedU128: { + uint counter = inst->Flags(); + inst->SetFlags(counter + inc); + // Stop here + return; + } + case IR::Opcode::Phi: { + struct PhiCounter { + u16 seq_num; + u8 base_edge; + u8 counter; + }; + + PhiCounter count = inst->Flags(); + ASSERT_MSG(count.counter == 0 || count.base_edge == use.operand); + // the point of seq_num is to tell us if we've already traversed this + // phi on the current walk. Alternatively we could keep a set of phi's + // seen on the current walk. This is to handle phi cycles + if (count.seq_num == 0) { + // First time we've encountered this phi + count.seq_num = seq_num; + // Mark the phi as having been traversed through this edge + count.base_edge = use.operand; + count.counter = inc; + } else if (count.seq_num < seq_num) { + count.seq_num = seq_num; + // Make sure the other phi edge has never been visited before. + // I think the other phi edges should be either undefs + // or self-referential edges, due to loops or something. + // TODO better explanation + ASSERT(count.base_edge == use.operand); + count.counter += inc; + } + // Else: This is some back edge to a previously visited phi, like a + // loop induction variable + inst->SetFlags(count); + break; + } + default: + break; + } + + for (IR::Use use : inst->Uses()) { + MarkTessAttributeUsersHelper(use, inc); + } + } + + uint seq_num{1u}; +}; + +enum class AttributeRegion : u32 { InputCP, OutputCP, PatchConst, Unknown }; + +static AttributeRegion FindRegionKind(IR::Inst* ring_access, const Shader::Info& info, + const Shader::RuntimeInfo& runtime_info) { + u32 count = ring_access->Flags(); + if (count == 0) { + return AttributeRegion::InputCP; + } else if (info.l_stage == LogicalStage::TessellationControl && + runtime_info.hs_info.IsPassthrough()) { + ASSERT(count <= 1); + return AttributeRegion::PatchConst; + } else { + ASSERT(count <= 2); + return AttributeRegion(count); + } +} + +static IR::Program* g_program; // TODO delete struct RingAddressInfo { AttributeRegion region{}; @@ -336,35 +459,10 @@ private: .DoMatch(node)) { products.back().as_factors.emplace_back(IR::Value(u32(2 << (b.U32() - 1)))); Visit(a); - } else if (MakeInstPattern(MatchIgnore(), MatchValue(b)) + } else if (MakeInstPattern(MatchValue(a), MatchU32(0)) .DoMatch(node)) { - IR::Inst* read_const_buffer = node.InstRecursive(); - IR::Value index = read_const_buffer->Arg(1); - - if (index.IsImmediate()) { - u32 offset = index.U32(); - if (offset < static_cast(IR::Attribute::TcsFirstEdgeTessFactorIndex) - - static_cast(IR::Attribute::TcsLsStride) + 1) { - IR::Attribute tess_constant_attr = static_cast( - static_cast(IR::Attribute::TcsLsStride) + offset); - IR::IREmitter ir{*read_const_buffer->GetParent(), - IR::Block::InstructionList::s_iterator_to(*read_const_buffer)}; - - ASSERT(tess_constant_attr != - IR::Attribute::TcsOffChipTessellationFactorThreshold); - IR::U32 replacement = ir.GetAttributeU32(tess_constant_attr); - - read_const_buffer->ReplaceUsesWithAndRemove(replacement); - // Unwrap the attribute from the GetAttribute Inst and push back as a factor - // (more convenient for scanning the factors later) - node = IR::Value{tess_constant_attr}; - - if (IR::Value{read_const_buffer} == products.back().as_nested_value) { - products.back().as_nested_value = replacement; - } - } - } - products.back().as_factors.emplace_back(node); + printf("here\n"); + products.back().as_factors.emplace_back(a.Attribute()); } else if (MakeInstPattern(MatchValue(a), MatchU32(0)) .DoMatch(node)) { products.back().as_factors.emplace_back(a); @@ -547,6 +645,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { const u32 num_dwords = opcode == IR::Opcode::WriteSharedU32 ? 1 : (opcode == IR::Opcode::WriteSharedU64 ? 2 : 4); + const IR::U32 addr = IR::U32{inst.Arg(0)}; const IR::Value data = inst.Arg(1); const auto [data_lo, data_hi] = [&] -> std::pair { if (num_dwords == 1) { @@ -564,7 +663,12 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { const u32 comp = offset_dw & 3; // Invocation ID array index is implicit, handled by SPIRV backend // ir.SetAttribute(IR::Attribute::Param0 + param, data, comp); - ir.SetTcsGenericAttribute(data, ir.Imm32(attr_no), ir.Imm32(comp)); + IR::U32 attr_index = ir.ShiftRightLogical( + ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)), + ir.Imm32(4u)); + IR::U32 comp_index = + ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u)); + ir.SetTcsGenericAttribute(data, attr_index, comp_index); } else { ASSERT(output_kind == AttributeRegion::PatchConst); ir.SetPatch(IR::PatchGeneric(offset_dw), data); @@ -668,6 +772,7 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { // case IR::Opcode::LoadSharedU64: // case IR::Opcode::LoadSharedU128: // TODO RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir); + const IR::U32 addr = IR::U32{inst.Arg(0)}; ASSERT(address_info.region == AttributeRegion::OutputCP || address_info.region == AttributeRegion::PatchConst); @@ -681,8 +786,13 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { IR::U32 control_point_index = ir.IDiv(IR::U32{address_info.offset_in_patch}, ir.Imm32(runtime_info.vs_info.hs_output_cp_stride)); - IR::Value attr_read = ir.GetTessGenericAttribute( - control_point_index, ir.Imm32(attr_no), ir.Imm32(comp)); + IR::U32 attr_index = ir.ShiftRightLogical( + ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)), + ir.Imm32(4u)); + IR::U32 comp_index = + ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u)); + IR::Value attr_read = + ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index); attr_read = ir.BitCast(IR::F32{attr_read}); inst.ReplaceUsesWithAndRemove(attr_read); break; @@ -742,6 +852,7 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) { // break; TODO continue; } + UNREACHABLE_MSG("Failed to match tess constant sharp"); } continue; } @@ -755,6 +866,41 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) { ASSERT(info.FoundTessConstantsSharp()); + for (IR::Block* block : program.blocks) { + for (IR::Inst& inst : block->Instructions()) { + switch (inst.GetOpcode()) { + case IR::Opcode::ReadConstBuffer: { + auto sharp_location = FindTessConstantSharp(&inst); + if (sharp_location && sharp_location->ptr_base == info.tess_consts_ptr_base && + sharp_location->dword_off == info.tess_consts_dword_offset) { + // Replace the load with a special attribute load (for readability and easier + // pattern matching) + IR::Value index = inst.Arg(1); + + ASSERT_MSG(index.IsImmediate(), + "Tessellation constant read with dynamic index"); + u32 offset = index.U32(); + ASSERT(offset < static_cast(IR::Attribute::TcsFirstEdgeTessFactorIndex) - + static_cast(IR::Attribute::TcsLsStride) + 1); + IR::Attribute tess_constant_attr = static_cast( + static_cast(IR::Attribute::TcsLsStride) + offset); + IR::IREmitter ir{*block, IR::Block::InstructionList::s_iterator_to(inst)}; + + ASSERT(tess_constant_attr != + IR::Attribute::TcsOffChipTessellationFactorThreshold); + IR::U32 replacement = ir.GetAttributeU32(tess_constant_attr); + + inst.ReplaceUsesWithAndRemove(replacement); + } + break; + } + + default: + break; + } + } + } + if (info.l_stage == LogicalStage::TessellationControl) { // Replace the BFEs on V1 (packed with patch id and output cp id) for easier pattern // matching @@ -778,8 +924,8 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) { IR::IREmitter ir(*block, it); IR::Value replacement; if (runtime_info.hs_info.IsPassthrough()) { - // Deal with annoying pattern in BB where InvocationID use makes no sense - // (in addr calculation for patchconst write) + // Deal with annoying pattern in BB where InvocationID use makes no + // sense (in addr calculation for patchconst write) replacement = ir.Imm32(0); } else { replacement = ir.GetAttributeU32(IR::Attribute::InvocationId); @@ -808,10 +954,19 @@ void TessellationPostprocess(IR::Program& program, RuntimeInfo& runtime_info) { case IR::Attribute::TcsCpStride: inst.ReplaceUsesWithAndRemove(IR::Value(tess_constants.m_hsCpStride)); break; + // Should verify that these are only used in address calculations for attr + // read/write + // Replace with 0 so we can dynamically index the control points within the + // region allocated for this patch (input or output). These terms should only + // contribute to the base address of that region, so replacing with 0 *should* + // be fine case IR::Attribute::TcsNumPatches: case IR::Attribute::TcsOutputBase: - case IR::Attribute::TcsPatchConstSize: case IR::Attribute::TcsPatchConstBase: + case IR::Attribute::TessPatchIdInVgt: + inst.ReplaceUsesWithAndRemove(IR::Value(0u)); + break; + case IR::Attribute::TcsPatchConstSize: case IR::Attribute::TcsPatchOutputSize: case IR::Attribute::TcsFirstEdgeTessFactorIndex: default: diff --git a/src/shader_recompiler/recompiler.cpp b/src/shader_recompiler/recompiler.cpp index 584211602..0859a9f5a 100644 --- a/src/shader_recompiler/recompiler.cpp +++ b/src/shader_recompiler/recompiler.cpp @@ -84,8 +84,8 @@ IR::Program TranslateProgram(std::span code, Pools& pools, Info& info Shader::Optimization::SsaRewritePass(program.post_order_blocks); Shader::Optimization::IdentityRemovalPass(program.blocks); - // Shader::Optimization::ConstantPropagationPass(program.post_order_blocks); - dumpMatchingIR("post_ssa"); + Shader::Optimization::ConstantPropagationPass( + program.post_order_blocks); // TODO const fold spam for now dumpMatchingIR("post_ssa"); if (stage == Stage::Hull) { Shader::Optimization::TessellationPreprocess(program, runtime_info); Shader::Optimization::ConstantPropagationPass(program.post_order_blocks); diff --git a/src/shader_recompiler/runtime_info.h b/src/shader_recompiler/runtime_info.h index 228ddc60a..9e3788ce2 100644 --- a/src/shader_recompiler/runtime_info.h +++ b/src/shader_recompiler/runtime_info.h @@ -119,7 +119,7 @@ struct HullRuntimeInfo { auto operator<=>(const HullRuntimeInfo&) const noexcept = default; - bool IsPassthrough() { + bool IsPassthrough() const { return hs_output_base == 0; };