From b0dd81a2b9a804a4545ba7aab5c98de1eff81860 Mon Sep 17 00:00:00 2001 From: IndecisiveTurtle <47210458+raphaelthegreat@users.noreply.github.com> Date: Fri, 14 Feb 2025 11:39:15 +0200 Subject: [PATCH] shader_recompiler: Complete buffer aliasing support * Add a bunch more types into buffers, such as F32 for float reads/writes and 8/16 bit integer types for formatted buffers --- .../backend/spirv/emit_spirv.cpp | 7 +- .../spirv/emit_spirv_context_get_set.cpp | 116 ++++++++++-------- .../backend/spirv/spirv_emit_context.cpp | 17 ++- .../ir/passes/resource_tracking_pass.cpp | 15 ++- .../renderer_vulkan/vk_instance.cpp | 14 ++- 5 files changed, 111 insertions(+), 58 deletions(-) diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp index 3712380f5..2a5b9335e 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp @@ -242,14 +242,17 @@ void SetupCapabilities(const Info& info, const Profile& profile, EmitContext& ct ctx.AddCapability(spv::Capability::Image1D); ctx.AddCapability(spv::Capability::Sampled1D); ctx.AddCapability(spv::Capability::ImageQuery); + ctx.AddCapability(spv::Capability::Int8); + ctx.AddCapability(spv::Capability::Int16); + ctx.AddCapability(spv::Capability::Int64); + ctx.AddCapability(spv::Capability::UniformAndStorageBuffer8BitAccess); + ctx.AddCapability(spv::Capability::UniformAndStorageBuffer16BitAccess); if (info.uses_fp16) { ctx.AddCapability(spv::Capability::Float16); - ctx.AddCapability(spv::Capability::Int16); } if (info.uses_fp64) { ctx.AddCapability(spv::Capability::Float64); } - ctx.AddCapability(spv::Capability::Int64); if (info.has_storage_images) { ctx.AddCapability(spv::Capability::StorageImageExtendedFormats); ctx.AddCapability(spv::Capability::StorageImageReadWithoutFormat); diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp index 5dffd1be4..798dfe8c8 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp @@ -398,83 +398,96 @@ void EmitSetPatch(EmitContext& ctx, IR::Patch patch, Id value) { ctx.OpStore(pointer, value); } -template -static Id EmitLoadBufferU32xN(EmitContext& ctx, u32 handle, Id address) { - const auto& buffer = ctx.buffers[handle]; - address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset); - const auto [id, pointer_type] = buffer[BufferAlias::U32]; +template +static Id EmitLoadBufferB32xN(EmitContext& ctx, u32 handle, Id address) { + const auto& spv_buffer = ctx.buffers[handle]; + if (Sirit::ValidId(spv_buffer.offset)) { + address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset); + } const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u)); + const auto& data_types = alias == BufferAlias::U32 ? ctx.U32 : ctx.F32; + const auto [id, pointer_type] = spv_buffer[alias]; if constexpr (N == 1) { const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index)}; - return ctx.OpLoad(ctx.U32[1], ptr); + return ctx.OpLoad(data_types[1], ptr); } else { boost::container::static_vector ids; for (u32 i = 0; i < N; i++) { const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i)); const Id ptr{ ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index_i)}; - ids.push_back(ctx.OpLoad(ctx.U32[1], ptr)); + ids.push_back(ctx.OpLoad(data_types[1], ptr)); } - return ctx.OpCompositeConstruct(ctx.U32[N], ids); + return ctx.OpCompositeConstruct(data_types[N], ids); } } Id EmitLoadBufferU8(EmitContext& ctx, IR::Inst*, u32 handle, Id address) { - const Id byte_index{ctx.OpBitwiseAnd(ctx.U32[1], address, ctx.ConstU32(3u))}; - const Id bit_offset{ctx.OpShiftLeftLogical(ctx.U32[1], byte_index, ctx.ConstU32(3u))}; - const Id dword{EmitLoadBufferU32xN<1>(ctx, handle, address)}; - return ctx.OpBitFieldUExtract(ctx.U32[1], dword, bit_offset, ctx.ConstU32(8u)); + const auto& spv_buffer = ctx.buffers[handle]; + if (Sirit::ValidId(spv_buffer.offset)) { + address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset); + } + const auto [id, pointer_type] = spv_buffer[BufferAlias::U8]; + const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, address)}; + return ctx.OpUConvert(ctx.U32[1], ctx.OpLoad(ctx.U8, ptr)); } Id EmitLoadBufferU16(EmitContext& ctx, IR::Inst*, u32 handle, Id address) { - const Id byte_index{ctx.OpBitwiseAnd(ctx.U32[1], address, ctx.ConstU32(2u))}; - const Id bit_offset{ctx.OpShiftLeftLogical(ctx.U32[1], byte_index, ctx.ConstU32(3u))}; - const Id dword{EmitLoadBufferU32xN<1>(ctx, handle, address)}; - return ctx.OpBitFieldUExtract(ctx.U32[1], dword, bit_offset, ctx.ConstU32(16u)); + const auto& spv_buffer = ctx.buffers[handle]; + if (Sirit::ValidId(spv_buffer.offset)) { + address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset); + } + const auto [id, pointer_type] = spv_buffer[BufferAlias::U16]; + const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(1u)); + const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index)}; + return ctx.OpUConvert(ctx.U32[1], ctx.OpLoad(ctx.U16, ptr)); } Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst*, u32 handle, Id address) { - return EmitLoadBufferU32xN<1>(ctx, handle, address); + return EmitLoadBufferB32xN<1, BufferAlias::U32>(ctx, handle, address); } Id EmitLoadBufferU32x2(EmitContext& ctx, IR::Inst*, u32 handle, Id address) { - return EmitLoadBufferU32xN<2>(ctx, handle, address); + return EmitLoadBufferB32xN<2, BufferAlias::U32>(ctx, handle, address); } Id EmitLoadBufferU32x3(EmitContext& ctx, IR::Inst*, u32 handle, Id address) { - return EmitLoadBufferU32xN<3>(ctx, handle, address); + return EmitLoadBufferB32xN<3, BufferAlias::U32>(ctx, handle, address); } Id EmitLoadBufferU32x4(EmitContext& ctx, IR::Inst*, u32 handle, Id address) { - return EmitLoadBufferU32xN<4>(ctx, handle, address); + return EmitLoadBufferB32xN<4, BufferAlias::U32>(ctx, handle, address); } Id EmitLoadBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) { - return ctx.OpBitcast(ctx.F32[1], EmitLoadBufferU32(ctx, inst, handle, address)); + return EmitLoadBufferB32xN<1, BufferAlias::F32>(ctx, handle, address); } Id EmitLoadBufferF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) { - return ctx.OpBitcast(ctx.F32[2], EmitLoadBufferU32x2(ctx, inst, handle, address)); + return EmitLoadBufferB32xN<2, BufferAlias::F32>(ctx, handle, address); } Id EmitLoadBufferF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) { - return ctx.OpBitcast(ctx.F32[3], EmitLoadBufferU32x3(ctx, inst, handle, address)); + return EmitLoadBufferB32xN<3, BufferAlias::F32>(ctx, handle, address); } Id EmitLoadBufferF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) { - return ctx.OpBitcast(ctx.F32[4], EmitLoadBufferU32x4(ctx, inst, handle, address)); + return EmitLoadBufferB32xN<4, BufferAlias::F32>(ctx, handle, address); } Id EmitLoadBufferFormatF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) { UNREACHABLE_MSG("SPIR-V instruction"); } -template -static void EmitStoreBufferU32xN(EmitContext& ctx, u32 handle, Id address, Id value) { - auto& buffer = ctx.buffers[handle]; - address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset); - const auto [id, pointer_type] = buffer[BufferAlias::U32]; +template +static void EmitStoreBufferB32xN(EmitContext& ctx, u32 handle, Id address, Id value) { + const auto& spv_buffer = ctx.buffers[handle]; + if (Sirit::ValidId(spv_buffer.offset)) { + address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset); + } const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u)); + const auto& data_types = alias == BufferAlias::U32 ? ctx.U32 : ctx.F32; + const auto [id, pointer_type] = spv_buffer[alias]; if constexpr (N == 1) { const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index)}; ctx.OpStore(ptr, value); @@ -482,58 +495,63 @@ static void EmitStoreBufferU32xN(EmitContext& ctx, u32 handle, Id address, Id va for (u32 i = 0; i < N; i++) { const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i)); const Id ptr = - ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index_i); - ctx.OpStore(ptr, ctx.OpCompositeExtract(ctx.U32[1], value, i)); + ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index_i); + ctx.OpStore(ptr, ctx.OpCompositeExtract(data_types[1], value, i)); } } } void EmitStoreBufferU8(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) { - const Id byte_index{ctx.OpBitwiseAnd(ctx.U32[1], address, ctx.ConstU32(3u))}; - const Id bit_offset{ctx.OpShiftLeftLogical(ctx.U32[1], byte_index, ctx.ConstU32(3u))}; - const Id dword{EmitLoadBufferU32xN<1>(ctx, handle, address)}; - const Id new_val{ctx.OpBitFieldInsert(ctx.U32[1], dword, value, bit_offset, ctx.ConstU32(8u))}; - EmitStoreBufferU32xN<1>(ctx, handle, address, new_val); + const auto& spv_buffer = ctx.buffers[handle]; + if (Sirit::ValidId(spv_buffer.offset)) { + address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset); + } + const auto [id, pointer_type] = spv_buffer[BufferAlias::U8]; + const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, address)}; + ctx.OpStore(ptr, ctx.OpUConvert(ctx.U8, value)); } void EmitStoreBufferU16(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) { - const Id byte_index{ctx.OpBitwiseAnd(ctx.U32[1], address, ctx.ConstU32(2u))}; - const Id bit_offset{ctx.OpShiftLeftLogical(ctx.U32[1], byte_index, ctx.ConstU32(3u))}; - const Id dword{EmitLoadBufferU32xN<1>(ctx, handle, address)}; - const Id new_val{ctx.OpBitFieldInsert(ctx.U32[1], dword, value, bit_offset, ctx.ConstU32(16u))}; - EmitStoreBufferU32xN<1>(ctx, handle, address, new_val); + const auto& spv_buffer = ctx.buffers[handle]; + if (Sirit::ValidId(spv_buffer.offset)) { + address = ctx.OpIAdd(ctx.U32[1], address, spv_buffer.offset); + } + const auto [id, pointer_type] = spv_buffer[BufferAlias::U16]; + const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(1u)); + const Id ptr{ctx.OpAccessChain(pointer_type, id, ctx.u32_zero_value, index)}; + ctx.OpStore(ptr, ctx.OpUConvert(ctx.U16, value)); } void EmitStoreBufferU32(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) { - EmitStoreBufferU32xN<1>(ctx, handle, address, value); + EmitStoreBufferB32xN<1, BufferAlias::U32>(ctx, handle, address, value); } void EmitStoreBufferU32x2(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) { - EmitStoreBufferU32xN<2>(ctx, handle, address, value); + EmitStoreBufferB32xN<2, BufferAlias::U32>(ctx, handle, address, value); } void EmitStoreBufferU32x3(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) { - EmitStoreBufferU32xN<3>(ctx, handle, address, value); + EmitStoreBufferB32xN<3, BufferAlias::U32>(ctx, handle, address, value); } void EmitStoreBufferU32x4(EmitContext& ctx, IR::Inst*, u32 handle, Id address, Id value) { - EmitStoreBufferU32xN<4>(ctx, handle, address, value); + EmitStoreBufferB32xN<4, BufferAlias::U32>(ctx, handle, address, value); } void EmitStoreBufferF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) { - EmitStoreBufferU32(ctx, inst, handle, address, ctx.OpBitcast(ctx.U32[1], value)); + EmitStoreBufferB32xN<1, BufferAlias::F32>(ctx, handle, address, value); } void EmitStoreBufferF32x2(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) { - EmitStoreBufferU32x2(ctx, inst, handle, address, ctx.OpBitcast(ctx.U32[2], value)); + EmitStoreBufferB32xN<2, BufferAlias::F32>(ctx, handle, address, value); } void EmitStoreBufferF32x3(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) { - EmitStoreBufferU32x3(ctx, inst, handle, address, ctx.OpBitcast(ctx.U32[3], value)); + EmitStoreBufferB32xN<3, BufferAlias::F32>(ctx, handle, address, value); } void EmitStoreBufferF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) { - EmitStoreBufferU32x4(ctx, inst, handle, address, ctx.OpBitcast(ctx.U32[4], value)); + EmitStoreBufferB32xN<4, BufferAlias::F32>(ctx, handle, address, value); } void EmitStoreBufferFormatF32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address, Id value) { diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp index c95bc2560..18985500e 100644 --- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp +++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp @@ -107,6 +107,8 @@ Id EmitContext::Def(const IR::Value& value) { void EmitContext::DefineArithmeticTypes() { void_id = Name(TypeVoid(), "void_id"); U1[1] = Name(TypeBool(), "bool_id"); + U8 = Name(TypeUInt(8), "u8_id"); + U16 = Name(TypeUInt(16), "u16_id"); if (info.uses_fp16) { F16[1] = Name(TypeFloat(16), "f16_id"); U16 = Name(TypeUInt(16), "u16_id"); @@ -638,8 +640,21 @@ void EmitContext::DefineBuffers() { for (const auto& desc : info.buffers) { const auto buf_sharp = desc.GetSharp(info); const bool is_storage = desc.IsStorage(buf_sharp, profile); + + // Define aliases depending on the shader usage. auto& spv_buffer = buffers.emplace_back(binding.buffer++, desc.buffer_type); - spv_buffer[BufferAlias::U32] = DefineBuffer(is_storage, desc.is_written, 2, desc.buffer_type, U32[1]); + if (True(desc.used_types & IR::Type::U32)) { + spv_buffer[BufferAlias::U32] = DefineBuffer(is_storage, desc.is_written, 2, desc.buffer_type, U32[1]); + } + if (True(desc.used_types & IR::Type::F32)) { + spv_buffer[BufferAlias::F32] = DefineBuffer(is_storage, desc.is_written, 2, desc.buffer_type, F32[1]); + } + if (True(desc.used_types & IR::Type::U16)) { + spv_buffer[BufferAlias::U16] = DefineBuffer(is_storage, desc.is_written, 1, desc.buffer_type, U16); + } + if (True(desc.used_types & IR::Type::U8)) { + spv_buffer[BufferAlias::U8] = DefineBuffer(is_storage, desc.is_written, 0, desc.buffer_type, U8); + } ++binding.unified; } } diff --git a/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp b/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp index 5737707b0..bdcb65fc3 100644 --- a/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp +++ b/src/shader_recompiler/ir/passes/resource_tracking_pass.cpp @@ -78,7 +78,20 @@ bool IsDataRingInstruction(const IR::Inst& inst) { } IR::Type BufferDataType(const IR::Inst& inst, AmdGpu::NumberFormat num_format) { - return IR::Type::U32; + switch (inst.GetOpcode()) { + case IR::Opcode::LoadBufferU8: + case IR::Opcode::StoreBufferU8: + return IR::Type::U8; + case IR::Opcode::LoadBufferU16: + case IR::Opcode::StoreBufferU16: + return IR::Type::U16; + case IR::Opcode::LoadBufferFormatF32: + case IR::Opcode::StoreBufferFormatF32: + // Formatted buffer loads can use a variety of types. + return IR::Type::U32 | IR::Type::F32 | IR::Type::U16 | IR::Type::U8; + default: + return IR::Type::U32; + } } bool IsImageAtomicInstruction(const IR::Inst& inst) { diff --git a/src/video_core/renderer_vulkan/vk_instance.cpp b/src/video_core/renderer_vulkan/vk_instance.cpp index 780779c0b..761ef6fff 100644 --- a/src/video_core/renderer_vulkan/vk_instance.cpp +++ b/src/video_core/renderer_vulkan/vk_instance.cpp @@ -1,14 +1,11 @@ // SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later -#include -#include #include #include #include #include "common/assert.h" -#include "common/config.h" #include "common/debug.h" #include "sdl_window.h" #include "video_core/renderer_vulkan/liverpool_to_vk.h" @@ -208,7 +205,8 @@ std::string Instance::GetDriverVersionName() { bool Instance::CreateDevice() { const vk::StructureChain feature_chain = physical_device - .getFeatures2(); + const auto vk11_features = feature_chain.get(); const auto vk12_features = feature_chain.get(); vk::StructureChain device_chain = { vk::DeviceCreateInfo{ @@ -351,12 +350,17 @@ bool Instance::CreateDevice() { }, }, vk::PhysicalDeviceVulkan11Features{ - .shaderDrawParameters = true, + .storageBuffer16BitAccess = vk11_features.storageBuffer16BitAccess, + .uniformAndStorageBuffer16BitAccess = vk11_features.uniformAndStorageBuffer16BitAccess, + .shaderDrawParameters = vk11_features.shaderDrawParameters, }, vk::PhysicalDeviceVulkan12Features{ .samplerMirrorClampToEdge = vk12_features.samplerMirrorClampToEdge, .drawIndirectCount = vk12_features.drawIndirectCount, + .storageBuffer8BitAccess = vk12_features.storageBuffer8BitAccess, + .uniformAndStorageBuffer8BitAccess = vk12_features.uniformAndStorageBuffer8BitAccess, .shaderFloat16 = vk12_features.shaderFloat16, + .shaderInt8 = vk12_features.shaderInt8, .scalarBlockLayout = vk12_features.scalarBlockLayout, .uniformBufferStandardLayout = vk12_features.uniformBufferStandardLayout, .separateDepthStencilLayouts = vk12_features.separateDepthStencilLayouts,