save some work

This commit is contained in:
Frodo Baggins 2024-12-01 21:53:06 -08:00
parent 256ee8f7a0
commit 2078da1f6d
4 changed files with 229 additions and 42 deletions

View File

@ -216,6 +216,18 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
}
}
template <typename T>
void FoldMul(IR::Block& block, IR::Inst& inst) {
if (!FoldCommutative<T>(inst, [](T a, T b) { return a * b; })) {
return;
}
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
inst.ReplaceUsesWithAndRemove(IR::Value(0u));
return;
}
}
void FoldCmpClass(IR::Block& block, IR::Inst& inst) {
ASSERT_MSG(inst.Arg(1).IsImmediate(), "Unable to resolve compare operation");
const auto class_mask = static_cast<IR::FloatClassFunc>(inst.Arg(1).U32());
@ -281,6 +293,14 @@ void FoldReadLane(IR::Block& block, IR::Inst& inst) {
}
}
void FoldTessAttrAccess(IR::Inst& inst) {
if (inst.GetOpcode() == IR::Opcode::GetTessGenericAttribute) {
// Fold the vertex index
}
// Fold the attr index
// Fold the component index
}
void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
switch (inst.GetOpcode()) {
case IR::Opcode::IAdd32:
@ -292,10 +312,19 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
FoldWhenAllImmediates(inst, [](u32 a) { return static_cast<float>(a); });
return;
case IR::Opcode::IMul32:
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a * b; });
FoldMul<u32>(block, inst);
return;
case IR::Opcode::UDiv32:
FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a / b; });
FoldWhenAllImmediates(inst, [](u32 a, u32 b) {
ASSERT_MSG(b != 0, "Folding UDiv32 with divisor 0");
return a / b;
});
return;
case IR::Opcode::UMod32:
FoldWhenAllImmediates(inst, [](u32 a, u32 b) {
ASSERT_MSG(b != 0, "Folding UMod32 with modulo 0");
return a % b;
});
return;
case IR::Opcode::FPCmpClass32:
FoldCmpClass(block, inst);
@ -452,6 +481,9 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
return FoldConvert(inst, IR::Opcode::ConvertF16F32);
case IR::Opcode::ConvertF16F32:
return FoldConvert(inst, IR::Opcode::ConvertF32F16);
case IR::Opcode::GetTessGenericAttribute:
case IR::Opcode::SetTcsGenericAttribute:
return FoldTessAttrAccess(inst);
default:
break;
}

View File

@ -1,13 +1,18 @@
// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later
#include <numeric>
#include "common/assert.h"
#include "shader_recompiler/info.h"
#include "shader_recompiler/ir/attribute.h"
#include "shader_recompiler/ir/breadth_first_search.h"
#include "shader_recompiler/ir/ir_emitter.h"
#include "shader_recompiler/ir/opcodes.h"
#include "shader_recompiler/ir/program.h"
// TODO delelte
#include "common/io_file.h"
#include "common/path_util.h"
#include "shader_recompiler/runtime_info.h"
namespace Shader::Optimization {
@ -251,13 +256,131 @@ std::optional<TessSharpLocation> FindTessConstantSharp(IR::Inst* read_const_buff
return TessSharpLocation{.ptr_base = sharp_ptr_base.ScalarReg(),
.dword_off = sharp_dword_offset.U32()};
}
UNREACHABLE_MSG("failed to match tess constants sharp buf");
return {};
}
static IR::Program* g_program; // TODO delete
// Walker that helps deduce what type of attribute a DS instruction is reading
// or writing, which could be an input control point, output control point,
// or per-patch constant (PatchConst).
// For certain ReadConstBuffer instructions using the tess constants V#,
// which we preprocess and transform into a named GetAttribute, we visit the users
// recursively and increment a counter on the Load/WriteShared users.
// Namely TcsNumPatches (from m_hsNumPatch), TcsOutputBase (m_hsOutputBase),
// and TcsPatchConstBase (m_patchConstBase).
// In addr calculations, the term TcsNumPatches * ls_stride * #input_cp_in_patch
// is used as an addend to skip the region for input control points, and similarly
// TcsNumPatches * hs_cp_stride * #output_cp_in_patch is used to skip the region
// for output control points.
// The Input CP, Output CP, and PatchConst regions are laid out in that order for the
// entire thread group, so seeing the TcsNumPatches attribute used in an addr calc should
// increment the "region counter" by 1 for the given Load/WriteShared
//
// TODO this will break if AMD compiler used distributive property like
// TcsNumPatches * (ls_stride * #input_cp_in_patch + hs_cp_stride * #output_cp_in_patch)
//
// TODO can we just look at address post-constant folding, pull out all the constants
// and find the interval it's inside of? (phis are still a problem here)
class TessConstantUseWalker {
public:
void MarkTessAttributeUsers(IR::Inst* get_attribute) {
uint inc;
switch (get_attribute->Arg(0).Attribute()) {
case IR::Attribute::TcsNumPatches:
case IR::Attribute::TcsOutputBase:
inc = 1;
break;
case IR::Attribute::TcsPatchConstBase:
inc = 2;
break;
default:
return;
}
enum AttributeRegion { InputCP, OutputCP, PatchConst, Unknown };
for (IR::Use use : get_attribute->Uses()) {
MarkTessAttributeUsersHelper(use, inc);
}
++seq_num;
}
private:
void MarkTessAttributeUsersHelper(IR::Use use, uint inc) {
IR::Inst* inst = use.user;
switch (use.user->GetOpcode()) {
case IR::Opcode::LoadSharedU32:
case IR::Opcode::LoadSharedU64:
case IR::Opcode::LoadSharedU128:
case IR::Opcode::WriteSharedU32:
case IR::Opcode::WriteSharedU64:
case IR::Opcode::WriteSharedU128: {
uint counter = inst->Flags<u32>();
inst->SetFlags<u32>(counter + inc);
// Stop here
return;
}
case IR::Opcode::Phi: {
struct PhiCounter {
u16 seq_num;
u8 base_edge;
u8 counter;
};
PhiCounter count = inst->Flags<PhiCounter>();
ASSERT_MSG(count.counter == 0 || count.base_edge == use.operand);
// the point of seq_num is to tell us if we've already traversed this
// phi on the current walk. Alternatively we could keep a set of phi's
// seen on the current walk. This is to handle phi cycles
if (count.seq_num == 0) {
// First time we've encountered this phi
count.seq_num = seq_num;
// Mark the phi as having been traversed through this edge
count.base_edge = use.operand;
count.counter = inc;
} else if (count.seq_num < seq_num) {
count.seq_num = seq_num;
// Make sure the other phi edge has never been visited before.
// I think the other phi edges should be either undefs
// or self-referential edges, due to loops or something.
// TODO better explanation
ASSERT(count.base_edge == use.operand);
count.counter += inc;
}
// Else: This is some back edge to a previously visited phi, like a
// loop induction variable
inst->SetFlags<PhiCounter>(count);
break;
}
default:
break;
}
for (IR::Use use : inst->Uses()) {
MarkTessAttributeUsersHelper(use, inc);
}
}
uint seq_num{1u};
};
enum class AttributeRegion : u32 { InputCP, OutputCP, PatchConst, Unknown };
static AttributeRegion FindRegionKind(IR::Inst* ring_access, const Shader::Info& info,
const Shader::RuntimeInfo& runtime_info) {
u32 count = ring_access->Flags<u32>();
if (count == 0) {
return AttributeRegion::InputCP;
} else if (info.l_stage == LogicalStage::TessellationControl &&
runtime_info.hs_info.IsPassthrough()) {
ASSERT(count <= 1);
return AttributeRegion::PatchConst;
} else {
ASSERT(count <= 2);
return AttributeRegion(count);
}
}
static IR::Program* g_program; // TODO delete
struct RingAddressInfo {
AttributeRegion region{};
@ -336,35 +459,10 @@ private:
.DoMatch(node)) {
products.back().as_factors.emplace_back(IR::Value(u32(2 << (b.U32() - 1))));
Visit(a);
} else if (MakeInstPattern<IR::Opcode::ReadConstBuffer>(MatchIgnore(), MatchValue(b))
} else if (MakeInstPattern<IR::Opcode::GetAttributeU32>(MatchValue(a), MatchU32(0))
.DoMatch(node)) {
IR::Inst* read_const_buffer = node.InstRecursive();
IR::Value index = read_const_buffer->Arg(1);
if (index.IsImmediate()) {
u32 offset = index.U32();
if (offset < static_cast<u32>(IR::Attribute::TcsFirstEdgeTessFactorIndex) -
static_cast<u32>(IR::Attribute::TcsLsStride) + 1) {
IR::Attribute tess_constant_attr = static_cast<IR::Attribute>(
static_cast<u32>(IR::Attribute::TcsLsStride) + offset);
IR::IREmitter ir{*read_const_buffer->GetParent(),
IR::Block::InstructionList::s_iterator_to(*read_const_buffer)};
ASSERT(tess_constant_attr !=
IR::Attribute::TcsOffChipTessellationFactorThreshold);
IR::U32 replacement = ir.GetAttributeU32(tess_constant_attr);
read_const_buffer->ReplaceUsesWithAndRemove(replacement);
// Unwrap the attribute from the GetAttribute Inst and push back as a factor
// (more convenient for scanning the factors later)
node = IR::Value{tess_constant_attr};
if (IR::Value{read_const_buffer} == products.back().as_nested_value) {
products.back().as_nested_value = replacement;
}
}
}
products.back().as_factors.emplace_back(node);
printf("here\n");
products.back().as_factors.emplace_back(a.Attribute());
} else if (MakeInstPattern<IR::Opcode::GetAttributeU32>(MatchValue(a), MatchU32(0))
.DoMatch(node)) {
products.back().as_factors.emplace_back(a);
@ -547,6 +645,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
const u32 num_dwords = opcode == IR::Opcode::WriteSharedU32
? 1
: (opcode == IR::Opcode::WriteSharedU64 ? 2 : 4);
const IR::U32 addr = IR::U32{inst.Arg(0)};
const IR::Value data = inst.Arg(1);
const auto [data_lo, data_hi] = [&] -> std::pair<IR::U32, IR::U32> {
if (num_dwords == 1) {
@ -564,7 +663,12 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
const u32 comp = offset_dw & 3;
// Invocation ID array index is implicit, handled by SPIRV backend
// ir.SetAttribute(IR::Attribute::Param0 + param, data, comp);
ir.SetTcsGenericAttribute(data, ir.Imm32(attr_no), ir.Imm32(comp));
IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)),
ir.Imm32(4u));
IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u));
ir.SetTcsGenericAttribute(data, attr_index, comp_index);
} else {
ASSERT(output_kind == AttributeRegion::PatchConst);
ir.SetPatch(IR::PatchGeneric(offset_dw), data);
@ -668,6 +772,7 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
// case IR::Opcode::LoadSharedU64:
// case IR::Opcode::LoadSharedU128: // TODO
RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir);
const IR::U32 addr = IR::U32{inst.Arg(0)};
ASSERT(address_info.region == AttributeRegion::OutputCP ||
address_info.region == AttributeRegion::PatchConst);
@ -681,8 +786,13 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
IR::U32 control_point_index =
ir.IDiv(IR::U32{address_info.offset_in_patch},
ir.Imm32(runtime_info.vs_info.hs_output_cp_stride));
IR::Value attr_read = ir.GetTessGenericAttribute(
control_point_index, ir.Imm32(attr_no), ir.Imm32(comp));
IR::U32 attr_index = ir.ShiftRightLogical(
ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)),
ir.Imm32(4u));
IR::U32 comp_index =
ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u));
IR::Value attr_read =
ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index);
attr_read = ir.BitCast<IR::U32>(IR::F32{attr_read});
inst.ReplaceUsesWithAndRemove(attr_read);
break;
@ -742,6 +852,7 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) {
// break; TODO
continue;
}
UNREACHABLE_MSG("Failed to match tess constant sharp");
}
continue;
}
@ -755,6 +866,41 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) {
ASSERT(info.FoundTessConstantsSharp());
for (IR::Block* block : program.blocks) {
for (IR::Inst& inst : block->Instructions()) {
switch (inst.GetOpcode()) {
case IR::Opcode::ReadConstBuffer: {
auto sharp_location = FindTessConstantSharp(&inst);
if (sharp_location && sharp_location->ptr_base == info.tess_consts_ptr_base &&
sharp_location->dword_off == info.tess_consts_dword_offset) {
// Replace the load with a special attribute load (for readability and easier
// pattern matching)
IR::Value index = inst.Arg(1);
ASSERT_MSG(index.IsImmediate(),
"Tessellation constant read with dynamic index");
u32 offset = index.U32();
ASSERT(offset < static_cast<u32>(IR::Attribute::TcsFirstEdgeTessFactorIndex) -
static_cast<u32>(IR::Attribute::TcsLsStride) + 1);
IR::Attribute tess_constant_attr = static_cast<IR::Attribute>(
static_cast<u32>(IR::Attribute::TcsLsStride) + offset);
IR::IREmitter ir{*block, IR::Block::InstructionList::s_iterator_to(inst)};
ASSERT(tess_constant_attr !=
IR::Attribute::TcsOffChipTessellationFactorThreshold);
IR::U32 replacement = ir.GetAttributeU32(tess_constant_attr);
inst.ReplaceUsesWithAndRemove(replacement);
}
break;
}
default:
break;
}
}
}
if (info.l_stage == LogicalStage::TessellationControl) {
// Replace the BFEs on V1 (packed with patch id and output cp id) for easier pattern
// matching
@ -778,8 +924,8 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) {
IR::IREmitter ir(*block, it);
IR::Value replacement;
if (runtime_info.hs_info.IsPassthrough()) {
// Deal with annoying pattern in BB where InvocationID use makes no sense
// (in addr calculation for patchconst write)
// Deal with annoying pattern in BB where InvocationID use makes no
// sense (in addr calculation for patchconst write)
replacement = ir.Imm32(0);
} else {
replacement = ir.GetAttributeU32(IR::Attribute::InvocationId);
@ -808,10 +954,19 @@ void TessellationPostprocess(IR::Program& program, RuntimeInfo& runtime_info) {
case IR::Attribute::TcsCpStride:
inst.ReplaceUsesWithAndRemove(IR::Value(tess_constants.m_hsCpStride));
break;
// Should verify that these are only used in address calculations for attr
// read/write
// Replace with 0 so we can dynamically index the control points within the
// region allocated for this patch (input or output). These terms should only
// contribute to the base address of that region, so replacing with 0 *should*
// be fine
case IR::Attribute::TcsNumPatches:
case IR::Attribute::TcsOutputBase:
case IR::Attribute::TcsPatchConstSize:
case IR::Attribute::TcsPatchConstBase:
case IR::Attribute::TessPatchIdInVgt:
inst.ReplaceUsesWithAndRemove(IR::Value(0u));
break;
case IR::Attribute::TcsPatchConstSize:
case IR::Attribute::TcsPatchOutputSize:
case IR::Attribute::TcsFirstEdgeTessFactorIndex:
default:

View File

@ -84,8 +84,8 @@ IR::Program TranslateProgram(std::span<const u32> code, Pools& pools, Info& info
Shader::Optimization::SsaRewritePass(program.post_order_blocks);
Shader::Optimization::IdentityRemovalPass(program.blocks);
// Shader::Optimization::ConstantPropagationPass(program.post_order_blocks);
dumpMatchingIR("post_ssa");
Shader::Optimization::ConstantPropagationPass(
program.post_order_blocks); // TODO const fold spam for now dumpMatchingIR("post_ssa");
if (stage == Stage::Hull) {
Shader::Optimization::TessellationPreprocess(program, runtime_info);
Shader::Optimization::ConstantPropagationPass(program.post_order_blocks);

View File

@ -119,7 +119,7 @@ struct HullRuntimeInfo {
auto operator<=>(const HullRuntimeInfo&) const noexcept = default;
bool IsPassthrough() {
bool IsPassthrough() const {
return hs_output_base == 0;
};