shader_recompiler: Fix quad/rect list FS passthrough semantics.

This commit is contained in:
squidbus 2024-12-23 15:39:53 -08:00 committed by IndecisiveTurtle
parent df4613a82e
commit 9a183f4820
3 changed files with 23 additions and 18 deletions

View File

@ -3,6 +3,7 @@
#include <sirit/sirit.h>
#include "shader_recompiler/backend/spirv/emit_spirv_quad_rect.h"
#include "shader_recompiler/runtime_info.h"
namespace Shader::Backend::SPIRV {
@ -11,9 +12,9 @@ using Sirit::Id;
constexpr u32 SPIRV_VERSION_1_3 = 0x00010300;
struct QuadRectListEmitter : public Sirit::Module {
explicit QuadRectListEmitter(size_t num_attribs_)
: Sirit::Module{SPIRV_VERSION_1_3}, num_attribs{num_attribs_}, inputs{num_attribs},
outputs{num_attribs} {
explicit QuadRectListEmitter(const FragmentRuntimeInfo& fs_info_)
: Sirit::Module{SPIRV_VERSION_1_3}, fs_info{fs_info_}, inputs{fs_info_.num_inputs},
outputs{fs_info_.num_inputs} {
void_id = TypeVoid();
bool_id = TypeBool();
float_id = TypeFloat(32);
@ -161,7 +162,7 @@ struct QuadRectListEmitter : public Sirit::Module {
const Id in_position{OpLoad(vec4_id, OpAccessChain(input_vec4, gl_in, index, Int(0)))};
OpStore(OpAccessChain(output_vec4, gl_out, invocation_id, Int(0)), in_position);
for (int i = 0; i < num_attribs; i++) {
for (int i = 0; i < inputs.size(); i++) {
// out_paramN[gl_InvocationID] = in_paramN[gl_InvocationID];
const Id in_param{OpLoad(vec4_id, OpAccessChain(input_vec4, inputs[i], index))};
OpStore(OpAccessChain(output_vec4, outputs[i], invocation_id), in_param);
@ -189,7 +190,7 @@ struct QuadRectListEmitter : public Sirit::Module {
OpStore(OpAccessChain(output_vec4, gl_per_vertex, Int(0)), position);
// out_paramN = in_paramN[index];
for (int i = 0; i < num_attribs; i++) {
for (int i = 0; i < inputs.size(); i++) {
const Id param{OpLoad(vec4_id, OpAccessChain(input_vec4, inputs[i], index))};
OpStore(outputs[i], param);
}
@ -252,11 +253,11 @@ private:
} else {
gl_per_vertex = AddOutput(gl_per_vertex_type);
}
for (int i = 0; i < num_attribs; i++) {
for (int i = 0; i < fs_info.num_inputs; i++) {
outputs[i] = AddOutput(model == spv::ExecutionModel::TessellationControl
? TypeArray(vec4_id, Int(4))
: vec4_id);
Decorate(outputs[i], spv::Decoration::Location, i);
Decorate(outputs[i], spv::Decoration::Location, fs_info.inputs[i].param_index);
}
}
@ -271,14 +272,14 @@ private:
const Id gl_per_vertex_array{TypeArray(gl_per_vertex_type, Constant(uint_id, 32U))};
gl_in = AddInput(gl_per_vertex_array);
const Id float_arr{TypeArray(vec4_id, Int(32))};
for (int i = 0; i < num_attribs; i++) {
for (int i = 0; i < fs_info.num_inputs; i++) {
inputs[i] = AddInput(float_arr);
Decorate(inputs[i], spv::Decoration::Location, i);
Decorate(inputs[i], spv::Decoration::Location, fs_info.inputs[i].param_index);
}
}
private:
size_t num_attribs;
FragmentRuntimeInfo fs_info;
Id main;
Id void_id;
Id bool_id;
@ -309,8 +310,8 @@ private:
std::vector<Id> interfaces;
};
std::vector<u32> EmitAuxilaryTessShader(AuxShaderType type, size_t num_attribs) {
QuadRectListEmitter ctx{num_attribs};
std::vector<u32> EmitAuxilaryTessShader(AuxShaderType type, const FragmentRuntimeInfo& fs_info) {
QuadRectListEmitter ctx{fs_info};
switch (type) {
case AuxShaderType::RectListTCS:
ctx.EmitRectListTCS();

View File

@ -6,6 +6,10 @@
#include <vector>
#include "common/types.h"
namespace Shader {
struct FragmentRuntimeInfo;
}
namespace Shader::Backend::SPIRV {
enum class AuxShaderType : u32 {
@ -14,6 +18,7 @@ enum class AuxShaderType : u32 {
PassthroughTES,
};
[[nodiscard]] std::vector<u32> EmitAuxilaryTessShader(AuxShaderType type, size_t num_attribs);
[[nodiscard]] std::vector<u32> EmitAuxilaryTessShader(AuxShaderType type,
const FragmentRuntimeInfo& fs_info);
} // namespace Shader::Backend::SPIRV

View File

@ -108,8 +108,7 @@ GraphicsPipeline::GraphicsPipeline(
"Primitive restart index other than -1 is not supported yet");
const bool is_rect_list = key.prim_type == AmdGpu::PrimitiveType::RectList;
const bool is_quad_list = key.prim_type == AmdGpu::PrimitiveType::QuadList;
const size_t num_fs_inputs =
runtime_infos[u32(Shader::LogicalStage::Fragment)].fs_info.num_inputs;
const auto& fs_info = runtime_infos[u32(Shader::LogicalStage::Fragment)].fs_info;
const vk::PipelineTessellationStateCreateInfo tessellation_state = {
.patchControlPoints = is_rect_list ? 3U : (is_quad_list ? 4U : key.patch_control_points),
};
@ -237,7 +236,7 @@ GraphicsPipeline::GraphicsPipeline(
});
} else if (is_rect_list || is_quad_list) {
const auto type = is_quad_list ? AuxShaderType::QuadListTCS : AuxShaderType::RectListTCS;
auto tcs = Shader::Backend::SPIRV::EmitAuxilaryTessShader(type, num_fs_inputs);
auto tcs = Shader::Backend::SPIRV::EmitAuxilaryTessShader(type, fs_info);
shader_stages.emplace_back(vk::PipelineShaderStageCreateInfo{
.stage = vk::ShaderStageFlagBits::eTessellationControl,
.module = CompileSPV(tcs, instance.GetDevice()),
@ -252,8 +251,8 @@ GraphicsPipeline::GraphicsPipeline(
.pName = "main",
});
} else if (is_rect_list || is_quad_list) {
auto tes = Shader::Backend::SPIRV::EmitAuxilaryTessShader(AuxShaderType::PassthroughTES,
num_fs_inputs);
auto tes =
Shader::Backend::SPIRV::EmitAuxilaryTessShader(AuxShaderType::PassthroughTES, fs_info);
shader_stages.emplace_back(vk::PipelineShaderStageCreateInfo{
.stage = vk::ShaderStageFlagBits::eTessellationEvaluation,
.module = CompileSPV(tes, instance.GetDevice()),