shader_recompiler: Add use tracking for Insts

This commit is contained in:
Frodo Baggins 2024-10-21 16:41:02 -07:00 committed by IndecisiveTurtle
parent 920acb8d8b
commit 38ac0024bf
5 changed files with 78 additions and 45 deletions

View File

@ -2,6 +2,7 @@
// SPDX-License-Identifier: GPL-2.0-or-later // SPDX-License-Identifier: GPL-2.0-or-later
#include <algorithm> #include <algorithm>
#include <any>
#include <memory> #include <memory>
#include "shader_recompiler/exception.h" #include "shader_recompiler/exception.h"
@ -119,10 +120,10 @@ void Inst::SetArg(size_t index, Value value) {
} }
const IR::Value arg{Arg(index)}; const IR::Value arg{Arg(index)};
if (!arg.IsImmediate()) { if (!arg.IsImmediate()) {
UndoUse(arg); UndoUse(arg.Inst(), index);
} }
if (!value.IsImmediate()) { if (!value.IsImmediate()) {
Use(value); Use(value.Inst(), index);
} }
if (op == Opcode::Phi) { if (op == Opcode::Phi) {
phi_args[index].second = value; phi_args[index].second = value;
@ -143,29 +144,32 @@ Block* Inst::PhiBlock(size_t index) const {
void Inst::AddPhiOperand(Block* predecessor, const Value& value) { void Inst::AddPhiOperand(Block* predecessor, const Value& value) {
if (!value.IsImmediate()) { if (!value.IsImmediate()) {
Use(value); Use(value.Inst(), phi_args.size());
} }
phi_args.emplace_back(predecessor, value); phi_args.emplace_back(predecessor, value);
} }
void Inst::Invalidate() { void Inst::Invalidate() {
ASSERT(uses.empty());
ClearArgs(); ClearArgs();
ReplaceOpcode(Opcode::Void); ReplaceOpcode(Opcode::Void);
} }
void Inst::ClearArgs() { void Inst::ClearArgs() {
if (op == Opcode::Phi) { if (op == Opcode::Phi) {
for (auto& pair : phi_args) { for (auto i = 0; i < phi_args.size(); i++) {
auto& pair = phi_args[i];
IR::Value& value{pair.second}; IR::Value& value{pair.second};
if (!value.IsImmediate()) { if (!value.IsImmediate()) {
UndoUse(value); UndoUse(value.Inst(), i);
} }
} }
phi_args.clear(); phi_args.clear();
} else { } else {
for (auto& value : args) { for (auto i = 0; i < args.size(); i++) {
auto& value = args[i];
if (!value.IsImmediate()) { if (!value.IsImmediate()) {
UndoUse(value); UndoUse(value.Inst(), i);
} }
} }
// Reset arguments to null // Reset arguments to null
@ -174,13 +178,21 @@ void Inst::ClearArgs() {
} }
} }
void Inst::ReplaceUsesWith(Value replacement) { void Inst::ReplaceUsesWith(Value replacement, bool preserve) {
Invalidate(); // Copy since user->SetArg will mutate this->uses
ReplaceOpcode(Opcode::Identity); // Could also do temp_uses = std::move(uses) but more readable
if (!replacement.IsImmediate()) { boost::container::list<IR::Use> temp_uses = uses;
Use(replacement); for (auto& [user, operand] : temp_uses) {
DEBUG_ASSERT(user->Arg(operand).Inst() == this);
user->SetArg(operand, replacement);
}
Invalidate();
if (preserve) {
// Still useful to have Identity for indirection.
// SSA pass would be more complicated without it
ReplaceOpcode(Opcode::Identity);
SetArg(0, replacement);
} }
args[0] = replacement;
} }
void Inst::ReplaceOpcode(IR::Opcode opcode) { void Inst::ReplaceOpcode(IR::Opcode opcode) {
@ -195,14 +207,15 @@ void Inst::ReplaceOpcode(IR::Opcode opcode) {
op = opcode; op = opcode;
} }
void Inst::Use(const Value& value) { void Inst::Use(Inst* used, u32 operand) {
Inst* const inst{value.Inst()}; DEBUG_ASSERT(0 == std::count(used->uses.begin(), used->uses.end(), IR::Use(this, operand)));
++inst->use_count; used->uses.emplace_front(this, operand);
} }
void Inst::UndoUse(const Value& value) { void Inst::UndoUse(Inst* used, u32 operand) {
Inst* const inst{value.Inst()}; IR::Use use(this, operand);
--inst->use_count; DEBUG_ASSERT(1 == std::count(used->uses.begin(), used->uses.end(), use));
used->uses.remove(use);
} }
} // namespace Shader::IR } // namespace Shader::IR

View File

@ -43,7 +43,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
if (is_lhs_immediate && is_rhs_immediate) { if (is_lhs_immediate && is_rhs_immediate) {
const auto result{imm_fn(Arg<T>(lhs), Arg<T>(rhs))}; const auto result{imm_fn(Arg<T>(lhs), Arg<T>(rhs))};
inst.ReplaceUsesWith(IR::Value{result}); inst.ReplaceUsesWithAndRemove(IR::Value{result});
return false; return false;
} }
if (is_lhs_immediate && !is_rhs_immediate) { if (is_lhs_immediate && !is_rhs_immediate) {
@ -75,7 +75,7 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
return false; return false;
} }
using Indices = std::make_index_sequence<Common::LambdaTraits<decltype(func)>::NUM_ARGS>; using Indices = std::make_index_sequence<Common::LambdaTraits<decltype(func)>::NUM_ARGS>;
inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{})); inst.ReplaceUsesWithAndRemove(EvalImmediates(inst, func, Indices{}));
return true; return true;
} }
@ -83,12 +83,12 @@ template <IR::Opcode op, typename Dest, typename Source>
void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) { void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
const IR::Value value{inst.Arg(0)}; const IR::Value value{inst.Arg(0)};
if (value.IsImmediate()) { if (value.IsImmediate()) {
inst.ReplaceUsesWith(IR::Value{std::bit_cast<Dest>(Arg<Source>(value))}); inst.ReplaceUsesWithAndRemove(IR::Value{std::bit_cast<Dest>(Arg<Source>(value))});
return; return;
} }
IR::Inst* const arg_inst{value.InstRecursive()}; IR::Inst* const arg_inst{value.InstRecursive()};
if (arg_inst->GetOpcode() == reverse) { if (arg_inst->GetOpcode() == reverse) {
inst.ReplaceUsesWith(arg_inst->Arg(0)); inst.ReplaceUsesWithAndRemove(arg_inst->Arg(0));
return; return;
} }
} }
@ -131,7 +131,7 @@ void FoldCompositeExtract(IR::Inst& inst, IR::Opcode construct, IR::Opcode inser
if (!result) { if (!result) {
return; return;
} }
inst.ReplaceUsesWith(*result); inst.ReplaceUsesWithAndRemove(*result);
} }
void FoldConvert(IR::Inst& inst, IR::Opcode opposite) { void FoldConvert(IR::Inst& inst, IR::Opcode opposite) {
@ -141,7 +141,7 @@ void FoldConvert(IR::Inst& inst, IR::Opcode opposite) {
} }
IR::Inst* const producer{value.InstRecursive()}; IR::Inst* const producer{value.InstRecursive()};
if (producer->GetOpcode() == opposite) { if (producer->GetOpcode() == opposite) {
inst.ReplaceUsesWith(producer->Arg(0)); inst.ReplaceUsesWithAndRemove(producer->Arg(0));
} }
} }
@ -152,9 +152,9 @@ void FoldLogicalAnd(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)}; const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate()) { if (rhs.IsImmediate()) {
if (rhs.U1()) { if (rhs.U1()) {
inst.ReplaceUsesWith(inst.Arg(0)); inst.ReplaceUsesWithAndRemove(inst.Arg(0));
} else { } else {
inst.ReplaceUsesWith(IR::Value{false}); inst.ReplaceUsesWithAndRemove(IR::Value{false});
} }
} }
} }
@ -162,7 +162,7 @@ void FoldLogicalAnd(IR::Inst& inst) {
void FoldSelect(IR::Inst& inst) { void FoldSelect(IR::Inst& inst) {
const IR::Value cond{inst.Arg(0)}; const IR::Value cond{inst.Arg(0)};
if (cond.IsImmediate()) { if (cond.IsImmediate()) {
inst.ReplaceUsesWith(cond.U1() ? inst.Arg(1) : inst.Arg(2)); inst.ReplaceUsesWithAndRemove(cond.U1() ? inst.Arg(1) : inst.Arg(2));
} }
} }
@ -173,9 +173,9 @@ void FoldLogicalOr(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)}; const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate()) { if (rhs.IsImmediate()) {
if (rhs.U1()) { if (rhs.U1()) {
inst.ReplaceUsesWith(IR::Value{true}); inst.ReplaceUsesWithAndRemove(IR::Value{true});
} else { } else {
inst.ReplaceUsesWith(inst.Arg(0)); inst.ReplaceUsesWithAndRemove(inst.Arg(0));
} }
} }
} }
@ -183,12 +183,12 @@ void FoldLogicalOr(IR::Inst& inst) {
void FoldLogicalNot(IR::Inst& inst) { void FoldLogicalNot(IR::Inst& inst) {
const IR::U1 value{inst.Arg(0)}; const IR::U1 value{inst.Arg(0)};
if (value.IsImmediate()) { if (value.IsImmediate()) {
inst.ReplaceUsesWith(IR::Value{!value.U1()}); inst.ReplaceUsesWithAndRemove(IR::Value{!value.U1()});
return; return;
} }
IR::Inst* const arg{value.InstRecursive()}; IR::Inst* const arg{value.InstRecursive()};
if (arg->GetOpcode() == IR::Opcode::LogicalNot) { if (arg->GetOpcode() == IR::Opcode::LogicalNot) {
inst.ReplaceUsesWith(arg->Arg(0)); inst.ReplaceUsesWithAndRemove(arg->Arg(0));
} }
} }
@ -199,7 +199,7 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
} }
IR::Inst* const arg_inst{value.InstRecursive()}; IR::Inst* const arg_inst{value.InstRecursive()};
if (arg_inst->GetOpcode() == reverse) { if (arg_inst->GetOpcode() == reverse) {
inst.ReplaceUsesWith(arg_inst->Arg(0)); inst.ReplaceUsesWithAndRemove(arg_inst->Arg(0));
return; return;
} }
} }
@ -211,7 +211,7 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
} }
const IR::Value rhs{inst.Arg(1)}; const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) { if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
inst.ReplaceUsesWith(inst.Arg(0)); inst.ReplaceUsesWithAndRemove(inst.Arg(0));
return; return;
} }
} }
@ -226,7 +226,8 @@ void FoldCmpClass(IR::Block& block, IR::Inst& inst) {
} else if ((class_mask & IR::FloatClassFunc::Finite) == IR::FloatClassFunc::Finite) { } else if ((class_mask & IR::FloatClassFunc::Finite) == IR::FloatClassFunc::Finite) {
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)}; IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
const IR::F32 value = IR::F32{inst.Arg(0)}; const IR::F32 value = IR::F32{inst.Arg(0)};
inst.ReplaceUsesWith(ir.LogicalNot(ir.LogicalOr(ir.FPIsInf(value), ir.FPIsInf(value)))); inst.ReplaceUsesWithAndRemove(
ir.LogicalNot(ir.LogicalOr(ir.FPIsInf(value), ir.FPIsInf(value))));
} else { } else {
UNREACHABLE(); UNREACHABLE();
} }
@ -237,7 +238,7 @@ void FoldReadLane(IR::Inst& inst) {
IR::Inst* prod = inst.Arg(0).InstRecursive(); IR::Inst* prod = inst.Arg(0).InstRecursive();
while (prod->GetOpcode() == IR::Opcode::WriteLane) { while (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (prod->Arg(2).U32() == lane) { if (prod->Arg(2).U32() == lane) {
inst.ReplaceUsesWith(prod->Arg(1)); inst.ReplaceUsesWithAndRemove(prod->Arg(1));
return; return;
} }
prod = prod->Arg(0).InstRecursive(); prod = prod->Arg(0).InstRecursive();

View File

@ -25,7 +25,7 @@ void LowerSharedMemToRegisters(IR::Program& program) {
}); });
ASSERT(it != ds_writes.end()); ASSERT(it != ds_writes.end());
// Replace data read with value written. // Replace data read with value written.
inst.ReplaceUsesWith((*it)->Arg(1)); inst.ReplaceUsesWithAndRemove((*it)->Arg(1));
} }
} }
} }

View File

@ -596,7 +596,7 @@ void PatchImageSampleInstruction(IR::Block& block, IR::Inst& inst, Info& info,
} }
return ir.ImageSampleImplicitLod(handle, coords, bias, offset, inst_info); return ir.ImageSampleImplicitLod(handle, coords, bias, offset, inst_info);
}(); }();
inst.ReplaceUsesWith(new_inst); inst.ReplaceUsesWithAndRemove(new_inst);
} }
void PatchImageInstruction(IR::Block& block, IR::Inst& inst, Info& info, Descriptors& descriptors) { void PatchImageInstruction(IR::Block& block, IR::Inst& inst, Info& info, Descriptors& descriptors) {

View File

@ -8,6 +8,7 @@
#include <cstring> #include <cstring>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include <boost/container/list.hpp>
#include <boost/container/small_vector.hpp> #include <boost/container/small_vector.hpp>
#include <boost/intrusive/list.hpp> #include <boost/intrusive/list.hpp>
@ -107,6 +108,16 @@ public:
explicit TypedValue(IR::Inst* inst_) : TypedValue(Value(inst_)) {} explicit TypedValue(IR::Inst* inst_) : TypedValue(Value(inst_)) {}
}; };
struct Use {
Inst* user;
u32 operand;
Use() = default;
Use(Inst* user_, u32 operand_) : user(user_), operand(operand_) {}
Use(const Use&) = default;
bool operator==(const Use&) const noexcept = default;
};
class Inst : public boost::intrusive::list_base_hook<> { class Inst : public boost::intrusive::list_base_hook<> {
public: public:
explicit Inst(IR::Opcode op_, u32 flags_) noexcept; explicit Inst(IR::Opcode op_, u32 flags_) noexcept;
@ -120,12 +131,12 @@ public:
/// Get the number of uses this instruction has. /// Get the number of uses this instruction has.
[[nodiscard]] int UseCount() const noexcept { [[nodiscard]] int UseCount() const noexcept {
return use_count; return uses.size();
} }
/// Determines whether this instruction has uses or not. /// Determines whether this instruction has uses or not.
[[nodiscard]] bool HasUses() const noexcept { [[nodiscard]] bool HasUses() const noexcept {
return use_count > 0; return uses.size() > 0;
} }
/// Get the opcode this microinstruction represents. /// Get the opcode this microinstruction represents.
@ -167,7 +178,13 @@ public:
void Invalidate(); void Invalidate();
void ClearArgs(); void ClearArgs();
void ReplaceUsesWith(Value replacement); void ReplaceUsesWithAndRemove(Value replacement) {
ReplaceUsesWith(replacement, false);
}
void ReplaceUsesWith(Value replacement) {
ReplaceUsesWith(replacement, true);
}
void ReplaceOpcode(IR::Opcode opcode); void ReplaceOpcode(IR::Opcode opcode);
@ -202,11 +219,11 @@ private:
NonTriviallyDummy() noexcept {} NonTriviallyDummy() noexcept {}
}; };
void Use(const Value& value); void Use(Inst* used, u32 operand);
void UndoUse(const Value& value); void UndoUse(Inst* used, u32 operand);
void ReplaceUsesWith(Value replacement, bool preserve);
IR::Opcode op{}; IR::Opcode op{};
int use_count{};
u32 flags{}; u32 flags{};
u32 definition{}; u32 definition{};
union { union {
@ -214,8 +231,10 @@ private:
boost::container::small_vector<std::pair<Block*, Value>, 2> phi_args; boost::container::small_vector<std::pair<Block*, Value>, 2> phi_args;
std::array<Value, 6> args; std::array<Value, 6> args;
}; };
boost::container::list<IR::Use> uses;
}; };
static_assert(sizeof(Inst) <= 128, "Inst size unintentionally increased"); static_assert(sizeof(Inst) <= 152, "Inst size unintentionally increased");
using U1 = TypedValue<Type::U1>; using U1 = TypedValue<Type::U1>;
using U8 = TypedValue<Type::U8>; using U8 = TypedValue<Type::U8>;