From b66db74c617962cebc71a1c0f37fd4eb418afe5e Mon Sep 17 00:00:00 2001 From: Frodo Baggins Date: Sun, 24 Nov 2024 20:21:57 -0800 Subject: [PATCH] impl tcs/tes read attr insts --- .../spirv/emit_spirv_context_get_set.cpp | 65 ++++++++----------- .../ir/passes/hull_shader_transform.cpp | 33 +++++----- 2 files changed, 44 insertions(+), 54 deletions(-) diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp index 1326c201c..7ba1f966c 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp @@ -49,13 +49,19 @@ Id VsOutputAttrPointer(EmitContext& ctx, VsOutput output) { Id OutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) { if (IR::IsParam(attr)) { - const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; - const auto& info{ctx.output_params.at(index)}; - ASSERT(info.num_components > 0); - if (info.num_components == 1) { - return info.id; + const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)}; + if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) { + const auto component_ptr = ctx.TypePointer(spv::StorageClass::Output, ctx.F32[1]); + return ctx.OpAccessChain(component_ptr, ctx.output_attr_array, ctx.ConstU32(attr_index), + ctx.ConstU32(element)); } else { - return ctx.OpAccessChain(info.pointer_type, info.id, ctx.ConstU32(element)); + const auto& info{ctx.output_params.at(attr_index)}; + ASSERT(info.num_components > 0); + if (info.num_components == 1) { + return info.id; + } else { + return ctx.OpAccessChain(info.pointer_type, info.id, ctx.ConstU32(element)); + } } } if (IR::IsMrt(attr)) { @@ -86,9 +92,13 @@ Id OutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) { std::pair OutputAttrComponentType(EmitContext& ctx, IR::Attribute attr) { if (IR::IsParam(attr)) { - const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; - const auto& info{ctx.output_params.at(index)}; - return {info.component_type, info.is_integer}; + if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) { + return {ctx.F32[1], false}; + } else { + const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; + const auto& info{ctx.output_params.at(index)}; + return {info.component_type, info.is_integer}; + } } if (IR::IsMrt(attr)) { const u32 index{u32(attr) - u32(IR::Attribute::RenderTarget0)}; @@ -220,22 +230,12 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, Id index) { UNREACHABLE(); } - // TODO refactor - if (ctx.l_stage == LogicalStage::TessellationControl || - ctx.l_stage == LogicalStage::TessellationEval) { - if (IR::IsTessCoord(attr)) { - const u32 component = attr == IR::Attribute::TessellationEvaluationPointU ? 0 : 1; - const auto component_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]); - const auto pointer{ - ctx.OpAccessChain(component_ptr, ctx.tess_coord, ctx.ConstU32(component))}; - return ctx.OpLoad(ctx.F32[1], pointer); - } - ASSERT(IR::IsParam(attr)); - const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)}; - const auto attr_comp_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]); - return ctx.OpLoad(ctx.F32[1], - ctx.OpAccessChain(attr_comp_ptr, ctx.input_attr_array, index, - ctx.ConstU32(attr_index), ctx.ConstU32(comp))); + if (IR::IsTessCoord(attr)) { + const u32 component = attr == IR::Attribute::TessellationEvaluationPointU ? 0 : 1; + const auto component_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]); + const auto pointer{ + ctx.OpAccessChain(component_ptr, ctx.tess_coord, ctx.ConstU32(component))}; + return ctx.OpLoad(ctx.F32[1], pointer); } if (IR::IsParam(attr)) { @@ -337,17 +337,6 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, u32 elemen LOG_WARNING(Render_Vulkan, "Ignoring pos1 export"); return; } - - // TODO refactor - if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) { - const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)}; - const auto component_ptr = ctx.TypePointer(spv::StorageClass::Output, ctx.F32[1]); - Id pointer = ctx.OpAccessChain(component_ptr, ctx.output_attr_array, - ctx.ConstU32(attr_index), ctx.ConstU32(element)); - ctx.OpStore(pointer, value); - return; - } - const Id pointer{OutputAttrPointer(ctx, attr, element)}; const auto component_type{OutputAttrComponentType(ctx, attr)}; if (component_type.second) { @@ -358,7 +347,9 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, u32 elemen } Id EmitGetTessGenericAttribute(EmitContext& ctx, Id vertex_index, Id attr_index, Id comp_index) { - UNREACHABLE(); + const auto attr_comp_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]); + return ctx.OpLoad(ctx.F32[1], ctx.OpAccessChain(attr_comp_ptr, ctx.input_attr_array, + vertex_index, attr_index, comp_index)); } void EmitSetTcsGenericAttribute(EmitContext& ctx, Id value, Id attr_index, Id comp_index) { diff --git a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp index 88e01d706..81598ee84 100644 --- a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp +++ b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp @@ -593,15 +593,14 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { case AttributeRegion::InputCP: { u32 offset_dw = (address_info.attribute_byte_offset % runtime_info.hs_info.ls_stride) >> 2; - const u32 param = offset_dw >> 2; + const u32 attr_no = offset_dw >> 2; const u32 comp = offset_dw & 3; - IR::Value control_point_index = - ir.IDiv(IR::U32{address_info.offset_in_patch}, - ir.Imm32(runtime_info.hs_info.ls_stride)); - IR::Value get_attrib = - ir.GetAttribute(IR::Attribute::Param0 + param, comp, control_point_index); - get_attrib = ir.BitCast(IR::F32{get_attrib}); - inst.ReplaceUsesWithAndRemove(get_attrib); + IR::U32 control_point_index = ir.IDiv(IR::U32{address_info.offset_in_patch}, + ir.Imm32(runtime_info.hs_info.ls_stride)); + IR::Value attr_read = ir.GetTessGenericAttribute( + control_point_index, ir.Imm32(attr_no), ir.Imm32(comp)); + attr_read = ir.BitCast(IR::F32{attr_read}); + inst.ReplaceUsesWithAndRemove(attr_read); break; } case AttributeRegion::OutputCP: { @@ -641,10 +640,10 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { const auto invocation_id = ir.GetAttributeU32(IR::Attribute::InvocationId); for (u32 attr_no = 0; attr_no < num_attributes; attr_no++) { for (u32 comp = 0; comp < 4; comp++) { - const auto input_attr = - ir.GetAttribute(IR::Attribute::Param0 + attr_no, comp, invocation_id); + IR::F32 attr_read = + ir.GetTessGenericAttribute(invocation_id, ir.Imm32(attr_no), ir.Imm32(comp)); // InvocationId is implicit index for output control point writes - ir.SetTcsGenericAttribute(input_attr, ir.Imm32(attr_no), ir.Imm32(comp)); + ir.SetTcsGenericAttribute(attr_read, ir.Imm32(attr_no), ir.Imm32(comp)); } } // TODO: wrap rest of program with if statement when passthrough? @@ -677,15 +676,15 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { u32 offset_dw = (address_info.attribute_byte_offset % runtime_info.vs_info.hs_output_cp_stride) >> 2; - const u32 param = offset_dw >> 2; + const u32 attr_no = offset_dw >> 2; const u32 comp = offset_dw & 3; - IR::Value control_point_index = + IR::U32 control_point_index = ir.IDiv(IR::U32{address_info.offset_in_patch}, ir.Imm32(runtime_info.vs_info.hs_output_cp_stride)); - IR::Value get_attrib = - ir.GetAttribute(IR::Attribute::Param0 + param, comp, control_point_index); - get_attrib = ir.BitCast(IR::F32{get_attrib}); - inst.ReplaceUsesWithAndRemove(get_attrib); + IR::Value attr_read = ir.GetTessGenericAttribute( + control_point_index, ir.Imm32(attr_no), ir.Imm32(comp)); + attr_read = ir.BitCast(IR::F32{attr_read}); + inst.ReplaceUsesWithAndRemove(attr_read); break; } case AttributeRegion::PatchConst: {