ir_passes: Adjust shared memory barrier pass to cover more cases

This commit is contained in:
IndecisiveTurtle 2025-06-10 23:18:48 +03:00
parent 2b3eef0114
commit 2554ea55aa

View File

@ -1,6 +1,7 @@
// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project // SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
// SPDX-License-Identifier: GPL-2.0-or-later // SPDX-License-Identifier: GPL-2.0-or-later
#include <unordered_set>
#include "shader_recompiler/ir/breadth_first_search.h" #include "shader_recompiler/ir/breadth_first_search.h"
#include "shader_recompiler/ir/ir_emitter.h" #include "shader_recompiler/ir/ir_emitter.h"
#include "shader_recompiler/ir/program.h" #include "shader_recompiler/ir/program.h"
@ -49,11 +50,14 @@ static void EmitBarrierInBlock(IR::Block* block) {
} }
} }
using NodeSet = std::unordered_set<const IR::Block*>;
// Inserts a barrier after divergent conditional blocks to avoid undefined // Inserts a barrier after divergent conditional blocks to avoid undefined
// behavior when some threads write and others read from shared memory. // behavior when some threads write and others read from shared memory.
static void EmitBarrierInMergeBlock(const IR::AbstractSyntaxNode::Data& data) { static void EmitBarrierInMergeBlock(const IR::AbstractSyntaxNode::Data& data,
NodeSet& divergence_end, u32& divergence_depth) {
const IR::U1 cond = data.if_node.cond; const IR::U1 cond = data.if_node.cond;
const auto insert_barrier = const auto is_divergent_cond =
IR::BreadthFirstSearch(cond, [](IR::Inst* inst) -> std::optional<bool> { IR::BreadthFirstSearch(cond, [](IR::Inst* inst) -> std::optional<bool> {
if (inst->GetOpcode() == IR::Opcode::GetAttributeU32 && if (inst->GetOpcode() == IR::Opcode::GetAttributeU32 &&
inst->Arg(0).Attribute() == IR::Attribute::LocalInvocationId) { inst->Arg(0).Attribute() == IR::Attribute::LocalInvocationId) {
@ -61,11 +65,15 @@ static void EmitBarrierInMergeBlock(const IR::AbstractSyntaxNode::Data& data) {
} }
return std::nullopt; return std::nullopt;
}); });
if (insert_barrier) { if (is_divergent_cond) {
IR::Block* const merge = data.if_node.merge; if (divergence_depth == 0) {
auto insert_point = std::ranges::find_if_not(merge->Instructions(), IR::IsPhi); IR::Block* const merge = data.if_node.merge;
IR::IREmitter ir{*merge, insert_point}; auto insert_point = std::ranges::find_if_not(merge->Instructions(), IR::IsPhi);
ir.Barrier(); IR::IREmitter ir{*merge, insert_point};
ir.Barrier();
}
++divergence_depth;
divergence_end.emplace(data.if_node.merge);
} }
} }
@ -87,19 +95,22 @@ void SharedMemoryBarrierPass(IR::Program& program, const RuntimeInfo& runtime_in
return; return;
} }
using Type = IR::AbstractSyntaxNode::Type; using Type = IR::AbstractSyntaxNode::Type;
u32 branch_depth{}; u32 divergence_depth{};
NodeSet divergence_end;
for (const IR::AbstractSyntaxNode& node : program.syntax_list) { for (const IR::AbstractSyntaxNode& node : program.syntax_list) {
if (node.type == Type::EndIf) { if (node.type == Type::EndIf) {
--branch_depth; if (divergence_end.contains(node.data.end_if.merge)) {
--divergence_depth;
}
continue; continue;
} }
// Check if branch depth is zero, we don't want to insert barrier in potentially divergent // Check if branch depth is zero, we don't want to insert barrier in potentially divergent
// code. // code.
if (node.type == Type::If && branch_depth++ == 0) { if (node.type == Type::If) {
EmitBarrierInMergeBlock(node.data); EmitBarrierInMergeBlock(node.data, divergence_end, divergence_depth);
continue; continue;
} }
if (node.type == Type::Block && branch_depth == 0) { if (node.type == Type::Block && divergence_depth == 0) {
EmitBarrierInBlock(node.data.block); EmitBarrierInBlock(node.data.block);
} }
} }