Store users as list. Use iterator facade for uses

This commit is contained in:
Frodo Baggins 2024-10-23 21:41:23 -07:00
parent 3c6d335be4
commit ed3e12176f
2 changed files with 140 additions and 29 deletions

View File

@ -147,7 +147,7 @@ void Inst::AddPhiOperand(Block* predecessor, const Value& value) {
}
void Inst::Invalidate() {
ASSERT(uses.empty());
ASSERT(users.list.empty());
ClearArgs();
ReplaceOpcode(Opcode::Void);
}
@ -175,13 +175,52 @@ void Inst::ClearArgs() {
}
}
UseIterator Inst::UserList::UseBegin() {
return UseIterator(list.begin(), list.end());
}
UseIterator Inst::UserList::UseEnd() {
return UseIterator(list.end(), list.end());
}
boost::iterator_range<UseIterator> Inst::UserList::Uses() {
return boost::make_iterator_range(UseBegin(), UseEnd());
}
void Inst::UserList::AddUse(IR::Inst* user, u32 operand) {
DEBUG_ASSERT(operand < 31);
auto it = std::find_if(list.begin(), list.end(),
[&](const UserNode& user_node) { return user_node.user == user; });
u32 operand_pos = 1 << operand;
if (it == list.end()) {
list.emplace_front(user, operand_pos);
} else {
DEBUG_ASSERT((it->operand_mask & operand_pos) == 0);
it->operand_mask |= operand_pos;
}
++num_uses;
}
void Inst::UserList::RemoveUse(IR::Inst* user, u32 operand) {
auto it = std::find_if(list.begin(), list.end(),
[&](const UserNode& user_node) { return user_node.user == user; });
DEBUG_ASSERT(it != list.end());
u32 operand_pos = 1 << operand;
DEBUG_ASSERT((it->operand_mask & operand_pos) != 0);
it->operand_mask &= ~operand_pos;
if (it->operand_mask == 0) {
list.erase(it);
}
--num_uses;
}
void Inst::ReplaceUsesWith(Value replacement, bool preserve) {
// Copy since user->SetArg will mutate this->uses
// Could also do temp_uses = std::move(uses) but more readable
boost::container::list<IR::Use> temp_uses = uses;
for (auto& [user, operand] : temp_uses) {
DEBUG_ASSERT(user->Arg(operand).Inst() == this);
user->SetArg(operand, replacement);
UserList temp_users = users;
for (IR::Use use : temp_users.Uses()) {
DEBUG_ASSERT(use.user->Arg(use.operand).Inst() == this);
use.user->SetArg(use.operand, replacement);
}
Invalidate();
if (preserve) {
@ -205,14 +244,23 @@ void Inst::ReplaceOpcode(IR::Opcode opcode) {
}
void Inst::Use(Inst* used, u32 operand) {
DEBUG_ASSERT(0 == std::count(used->uses.begin(), used->uses.end(), IR::Use(this, operand)));
used->uses.emplace_front(this, operand);
used->users.AddUse(this, operand);
}
void Inst::UndoUse(Inst* used, u32 operand) {
IR::Use use(this, operand);
DEBUG_ASSERT(1 == std::count(used->uses.begin(), used->uses.end(), use));
used->uses.remove(use);
used->users.RemoveUse(this, operand);
}
UseIterator Inst::UseBegin() {
return users.UseBegin();
}
UseIterator Inst::UseEnd() {
return users.UseEnd();
}
boost::iterator_range<UseIterator> Inst::Uses() {
return boost::make_iterator_range(UseBegin(), UseEnd());
}
} // namespace Shader::IR

View File

@ -110,13 +110,10 @@ public:
struct Use {
Inst* user;
u32 operand;
Use() = default;
Use(Inst* user_, u32 operand_) : user(user_), operand(operand_) {}
Use(const Use&) = default;
bool operator==(const Use&) const noexcept = default;
};
class UseIterator;
class Inst : public boost::intrusive::list_base_hook<> {
public:
explicit Inst(IR::Opcode op_, u32 flags_) noexcept;
@ -130,12 +127,12 @@ public:
/// Get the number of uses this instruction has.
[[nodiscard]] int UseCount() const noexcept {
return uses.size();
return users.num_uses;
}
/// Determines whether this instruction has uses or not.
[[nodiscard]] bool HasUses() const noexcept {
return uses.size() > 0;
return users.num_uses > 0;
}
/// Get the opcode this microinstruction represents.
@ -213,20 +210,22 @@ public:
return std::bit_cast<DefinitionType>(definition);
}
auto users() {
auto Users() {
return boost::adaptors::transform(users.list,
[](UserList::UserNode& user) { return user.user; });
}
const auto Users() const {
return boost::adaptors::transform(
boost::unique(uses,
[](const IR::Use& a, const IR::Use& b) { return a.user == b.user; }),
[](const IR::Use& use) { return use.user; });
users.list,
[](const UserList::UserNode& user) -> const IR::Inst* { return user.user; });
}
auto user_begin() {
return users().begin();
}
UseIterator UseBegin();
UseIterator UseEnd();
auto user_end() {
return users().end();
}
boost::iterator_range<UseIterator> Uses();
const boost::iterator_range<const UseIterator> Uses() const;
private:
struct NonTriviallyDummy {
@ -246,9 +245,73 @@ private:
std::array<Value, 6> args;
};
boost::container::list<IR::Use> uses;
struct UserList {
struct UserNode {
IR::Inst* user;
u32 operand_mask;
};
boost::container::list<UserNode> list;
using UserIterator = boost::container::list<UserNode>::iterator;
u32 num_uses{};
void AddUse(IR::Inst* user, u32 operand);
void RemoveUse(IR::Inst* user, u32 operand);
UseIterator UseBegin();
UseIterator UseEnd();
boost::iterator_range<UseIterator> Uses();
} users;
friend class UseIterator;
};
static_assert(sizeof(Inst) <= 152, "Inst size unintentionally increased");
static_assert(sizeof(Inst) <= 160, "Inst size unintentionally increased");
class UseIterator
: public boost::iterator_facade<UseIterator, IR::Use, boost::forward_traversal_tag, IR::Use> {
public:
UseIterator() : user_it(), user_end(), bitmask_pos(0) {};
explicit UseIterator(IR::Inst::UserList::UserIterator user_begin_,
IR::Inst::UserList::UserIterator user_end_)
: user_it(user_begin_), user_end(user_end_), bitmask_pos(0) {
if (user_it != user_end) {
bitmask_pos = std::countr_zero(user_it->operand_mask);
DEBUG_ASSERT(user_it->operand_mask != 0);
}
};
private:
friend class boost::iterator_core_access;
void increment() {
// Assumes inst has less than 31 operands
u32 mask = 1 << (bitmask_pos + 1);
u32 use_mask = user_it->operand_mask & ~(mask - 1);
if (use_mask == 0) {
++user_it;
if (user_it == user_end) {
bitmask_pos = 0;
return;
}
use_mask = user_it->operand_mask;
ASSERT(use_mask != 0);
}
bitmask_pos = std::countr_zero(use_mask);
};
bool equal(UseIterator const& other) const {
return user_it == other.user_it && bitmask_pos == other.bitmask_pos;
};
IR::Use dereference() const {
return IR::Use(user_it->user, bitmask_pos);
};
IR::Inst::UserList::UserIterator user_it;
IR::Inst::UserList::UserIterator user_end;
u32 bitmask_pos;
}; // class UseIterator
using U1 = TypedValue<Type::U1>;
using U8 = TypedValue<Type::U8>;