From 8ab3da6b686ed40b8e809809ede5afd7bb857212 Mon Sep 17 00:00:00 2001 From: Frodo Baggins Date: Sun, 24 Nov 2024 15:41:35 -0800 Subject: [PATCH] attr arrays in TCS/TES --- .../spirv/emit_spirv_context_get_set.cpp | 37 ++++++++++++------- .../backend/spirv/spirv_emit_context.cpp | 19 ++++++++++ 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp index a31a35b4a..d4a189f0a 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp @@ -47,18 +47,6 @@ Id VsOutputAttrPointer(EmitContext& ctx, VsOutput output) { } } -// TODO refactor -Id SpecialOutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) { - const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)}; - if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) { - // TODO refactor all this - const auto component_ptr = ctx.TypePointer(spv::StorageClass::Output, ctx.F32[1]); - return ctx.OpAccessChain(component_ptr, ctx.output_attr_array, ctx.ConstU32(attr_index), - ctx.ConstU32(element)); - } - UNREACHABLE(); -} - Id OutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) { if (IR::IsParam(attr)) { const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; @@ -232,7 +220,16 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, Id index) { UNREACHABLE(); } - if (ctx.info.l_stage == LogicalStage::TessellationControl) { + // TODO refactor + if (ctx.l_stage == LogicalStage::TessellationControl || + ctx.l_stage == LogicalStage::TessellationEval) { + if (IR::IsTessCoord(attr)) { + const u32 component = attr == IR::Attribute::TessellationEvaluationPointU ? 0 : 1; + const auto component_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]); + const auto pointer{ + ctx.OpAccessChain(component_ptr, ctx.tess_coord, ctx.ConstU32(component))}; + return ctx.OpLoad(ctx.F32[1], pointer); + } ASSERT(IR::IsParam(attr)); const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)}; const auto attr_comp_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]); @@ -343,7 +340,19 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, u32 elemen // TODO refactor if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) { - ctx.OpStore(SpecialOutputAttrPointer(ctx, attr, element), value); + const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)}; + const auto component_ptr = ctx.TypePointer(spv::StorageClass::Output, ctx.F32[1]); + Id pointer = ctx.OpAccessChain(component_ptr, ctx.output_attr_array, + ctx.ConstU32(attr_index), ctx.ConstU32(element)); + ctx.OpStore(pointer, value); + return; + } else if (ctx.l_stage == LogicalStage::TessellationControl) { + const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)}; + const auto component_ptr = ctx.TypePointer(spv::StorageClass::Output, ctx.F32[1]); + Id pointer = ctx.OpAccessChain(component_ptr, ctx.output_attr_array, + ctx.OpLoad(ctx.U32[1], ctx.invocation_id), + ctx.ConstU32(attr_index), ctx.ConstU32(element)); + ctx.OpStore(pointer, value); return; } diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp index 3d7f611a4..62064feff 100644 --- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp +++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp @@ -421,6 +421,7 @@ void EmitContext::DefineInputs() { tess_coord = DefineInput(F32[3], std::nullopt, spv::BuiltIn::TessCoord); primitive_id = DefineVariable(U32[1], spv::BuiltIn::PrimitiveId, spv::StorageClass::Input); +#if 0 for (u32 i = 0; i < IR::NumParams; i++) { const IR::Attribute param{IR::Attribute::Param0 + i}; if (!info.loads.GetAny(param)) { @@ -433,6 +434,14 @@ void EmitContext::DefineInputs() { Name(id, fmt::format("in_attr{}", i)); input_params[i] = {id, input_f32, F32[1], 4}; } +#else + const u32 num_attrs = runtime_info.vs_info.hs_output_cp_stride >> 4; + const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; + // The input vertex count isn't statically known, so make length 32 (what glslang does) + const Id patch_array_type{TypeArray(per_vertex_type, ConstU32(32u))}; + input_attr_array = DefineInput(patch_array_type, 0); + Name(input_attr_array, "in_attrs"); +#endif u32 patch_base_location = runtime_info.vs_info.hs_output_cp_stride >> 4; for (size_t index = 0; index < 30; ++index) { @@ -502,6 +511,7 @@ void EmitContext::DefineOutputs() { Decorate(output_tess_level_inner, spv::Decoration::Patch); } +#if 0 for (u32 i = 0; i < IR::NumParams; i++) { const IR::Attribute param{IR::Attribute::Param0 + i}; if (!info.stores.GetAny(param)) { @@ -514,6 +524,15 @@ void EmitContext::DefineOutputs() { Name(id, fmt::format("out_attr{}", i)); output_params[i] = {id, output_f32, F32[1], 4}; } +#else + const u32 num_attrs = runtime_info.hs_info.hs_cp_stride >> 4; + const Id per_vertex_type{TypeArray(F32[4], ConstU32(num_attrs))}; + // The input vertex count isn't statically known, so make length 32 (what glslang does) + const Id patch_array_type{ + TypeArray(per_vertex_type, ConstU32(runtime_info.hs_info.output_control_points))}; + output_attr_array = DefineOutput(patch_array_type, 0); + Name(output_attr_array, "out_attrs"); +#endif u32 patch_base_location = runtime_info.hs_info.hs_output_cp_stride >> 4; for (size_t index = 0; index < 30; ++index) {