impl tcs/tes read attr insts

This commit is contained in:
Frodo Baggins 2024-11-24 20:21:57 -08:00
parent 13523d8a09
commit b66db74c61
2 changed files with 44 additions and 54 deletions

View File

@ -49,8 +49,13 @@ Id VsOutputAttrPointer(EmitContext& ctx, VsOutput output) {
Id OutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) { Id OutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) {
if (IR::IsParam(attr)) { if (IR::IsParam(attr)) {
const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)};
const auto& info{ctx.output_params.at(index)}; if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) {
const auto component_ptr = ctx.TypePointer(spv::StorageClass::Output, ctx.F32[1]);
return ctx.OpAccessChain(component_ptr, ctx.output_attr_array, ctx.ConstU32(attr_index),
ctx.ConstU32(element));
} else {
const auto& info{ctx.output_params.at(attr_index)};
ASSERT(info.num_components > 0); ASSERT(info.num_components > 0);
if (info.num_components == 1) { if (info.num_components == 1) {
return info.id; return info.id;
@ -58,6 +63,7 @@ Id OutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) {
return ctx.OpAccessChain(info.pointer_type, info.id, ctx.ConstU32(element)); return ctx.OpAccessChain(info.pointer_type, info.id, ctx.ConstU32(element));
} }
} }
}
if (IR::IsMrt(attr)) { if (IR::IsMrt(attr)) {
const u32 index{u32(attr) - u32(IR::Attribute::RenderTarget0)}; const u32 index{u32(attr) - u32(IR::Attribute::RenderTarget0)};
const auto& info{ctx.frag_outputs.at(index)}; const auto& info{ctx.frag_outputs.at(index)};
@ -86,10 +92,14 @@ Id OutputAttrPointer(EmitContext& ctx, IR::Attribute attr, u32 element) {
std::pair<Id, bool> OutputAttrComponentType(EmitContext& ctx, IR::Attribute attr) { std::pair<Id, bool> OutputAttrComponentType(EmitContext& ctx, IR::Attribute attr) {
if (IR::IsParam(attr)) { if (IR::IsParam(attr)) {
if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) {
return {ctx.F32[1], false};
} else {
const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; const u32 index{u32(attr) - u32(IR::Attribute::Param0)};
const auto& info{ctx.output_params.at(index)}; const auto& info{ctx.output_params.at(index)};
return {info.component_type, info.is_integer}; return {info.component_type, info.is_integer};
} }
}
if (IR::IsMrt(attr)) { if (IR::IsMrt(attr)) {
const u32 index{u32(attr) - u32(IR::Attribute::RenderTarget0)}; const u32 index{u32(attr) - u32(IR::Attribute::RenderTarget0)};
const auto& info{ctx.frag_outputs.at(index)}; const auto& info{ctx.frag_outputs.at(index)};
@ -220,9 +230,6 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, Id index) {
UNREACHABLE(); UNREACHABLE();
} }
// TODO refactor
if (ctx.l_stage == LogicalStage::TessellationControl ||
ctx.l_stage == LogicalStage::TessellationEval) {
if (IR::IsTessCoord(attr)) { if (IR::IsTessCoord(attr)) {
const u32 component = attr == IR::Attribute::TessellationEvaluationPointU ? 0 : 1; const u32 component = attr == IR::Attribute::TessellationEvaluationPointU ? 0 : 1;
const auto component_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]); const auto component_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]);
@ -230,13 +237,6 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, u32 comp, Id index) {
ctx.OpAccessChain(component_ptr, ctx.tess_coord, ctx.ConstU32(component))}; ctx.OpAccessChain(component_ptr, ctx.tess_coord, ctx.ConstU32(component))};
return ctx.OpLoad(ctx.F32[1], pointer); return ctx.OpLoad(ctx.F32[1], pointer);
} }
ASSERT(IR::IsParam(attr));
const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)};
const auto attr_comp_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]);
return ctx.OpLoad(ctx.F32[1],
ctx.OpAccessChain(attr_comp_ptr, ctx.input_attr_array, index,
ctx.ConstU32(attr_index), ctx.ConstU32(comp)));
}
if (IR::IsParam(attr)) { if (IR::IsParam(attr)) {
const u32 index{u32(attr) - u32(IR::Attribute::Param0)}; const u32 index{u32(attr) - u32(IR::Attribute::Param0)};
@ -337,17 +337,6 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, u32 elemen
LOG_WARNING(Render_Vulkan, "Ignoring pos1 export"); LOG_WARNING(Render_Vulkan, "Ignoring pos1 export");
return; return;
} }
// TODO refactor
if (ctx.stage == Stage::Local && ctx.runtime_info.ls_info.links_with_tcs) {
const u32 attr_index{u32(attr) - u32(IR::Attribute::Param0)};
const auto component_ptr = ctx.TypePointer(spv::StorageClass::Output, ctx.F32[1]);
Id pointer = ctx.OpAccessChain(component_ptr, ctx.output_attr_array,
ctx.ConstU32(attr_index), ctx.ConstU32(element));
ctx.OpStore(pointer, value);
return;
}
const Id pointer{OutputAttrPointer(ctx, attr, element)}; const Id pointer{OutputAttrPointer(ctx, attr, element)};
const auto component_type{OutputAttrComponentType(ctx, attr)}; const auto component_type{OutputAttrComponentType(ctx, attr)};
if (component_type.second) { if (component_type.second) {
@ -358,7 +347,9 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, u32 elemen
} }
Id EmitGetTessGenericAttribute(EmitContext& ctx, Id vertex_index, Id attr_index, Id comp_index) { Id EmitGetTessGenericAttribute(EmitContext& ctx, Id vertex_index, Id attr_index, Id comp_index) {
UNREACHABLE(); const auto attr_comp_ptr = ctx.TypePointer(spv::StorageClass::Input, ctx.F32[1]);
return ctx.OpLoad(ctx.F32[1], ctx.OpAccessChain(attr_comp_ptr, ctx.input_attr_array,
vertex_index, attr_index, comp_index));
} }
void EmitSetTcsGenericAttribute(EmitContext& ctx, Id value, Id attr_index, Id comp_index) { void EmitSetTcsGenericAttribute(EmitContext& ctx, Id value, Id attr_index, Id comp_index) {

View File

@ -593,15 +593,14 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
case AttributeRegion::InputCP: { case AttributeRegion::InputCP: {
u32 offset_dw = u32 offset_dw =
(address_info.attribute_byte_offset % runtime_info.hs_info.ls_stride) >> 2; (address_info.attribute_byte_offset % runtime_info.hs_info.ls_stride) >> 2;
const u32 param = offset_dw >> 2; const u32 attr_no = offset_dw >> 2;
const u32 comp = offset_dw & 3; const u32 comp = offset_dw & 3;
IR::Value control_point_index = IR::U32 control_point_index = ir.IDiv(IR::U32{address_info.offset_in_patch},
ir.IDiv(IR::U32{address_info.offset_in_patch},
ir.Imm32(runtime_info.hs_info.ls_stride)); ir.Imm32(runtime_info.hs_info.ls_stride));
IR::Value get_attrib = IR::Value attr_read = ir.GetTessGenericAttribute(
ir.GetAttribute(IR::Attribute::Param0 + param, comp, control_point_index); control_point_index, ir.Imm32(attr_no), ir.Imm32(comp));
get_attrib = ir.BitCast<IR::U32>(IR::F32{get_attrib}); attr_read = ir.BitCast<IR::U32>(IR::F32{attr_read});
inst.ReplaceUsesWithAndRemove(get_attrib); inst.ReplaceUsesWithAndRemove(attr_read);
break; break;
} }
case AttributeRegion::OutputCP: { case AttributeRegion::OutputCP: {
@ -641,10 +640,10 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
const auto invocation_id = ir.GetAttributeU32(IR::Attribute::InvocationId); const auto invocation_id = ir.GetAttributeU32(IR::Attribute::InvocationId);
for (u32 attr_no = 0; attr_no < num_attributes; attr_no++) { for (u32 attr_no = 0; attr_no < num_attributes; attr_no++) {
for (u32 comp = 0; comp < 4; comp++) { for (u32 comp = 0; comp < 4; comp++) {
const auto input_attr = IR::F32 attr_read =
ir.GetAttribute(IR::Attribute::Param0 + attr_no, comp, invocation_id); ir.GetTessGenericAttribute(invocation_id, ir.Imm32(attr_no), ir.Imm32(comp));
// InvocationId is implicit index for output control point writes // InvocationId is implicit index for output control point writes
ir.SetTcsGenericAttribute(input_attr, ir.Imm32(attr_no), ir.Imm32(comp)); ir.SetTcsGenericAttribute(attr_read, ir.Imm32(attr_no), ir.Imm32(comp));
} }
} }
// TODO: wrap rest of program with if statement when passthrough? // TODO: wrap rest of program with if statement when passthrough?
@ -677,15 +676,15 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) {
u32 offset_dw = (address_info.attribute_byte_offset % u32 offset_dw = (address_info.attribute_byte_offset %
runtime_info.vs_info.hs_output_cp_stride) >> runtime_info.vs_info.hs_output_cp_stride) >>
2; 2;
const u32 param = offset_dw >> 2; const u32 attr_no = offset_dw >> 2;
const u32 comp = offset_dw & 3; const u32 comp = offset_dw & 3;
IR::Value control_point_index = IR::U32 control_point_index =
ir.IDiv(IR::U32{address_info.offset_in_patch}, ir.IDiv(IR::U32{address_info.offset_in_patch},
ir.Imm32(runtime_info.vs_info.hs_output_cp_stride)); ir.Imm32(runtime_info.vs_info.hs_output_cp_stride));
IR::Value get_attrib = IR::Value attr_read = ir.GetTessGenericAttribute(
ir.GetAttribute(IR::Attribute::Param0 + param, comp, control_point_index); control_point_index, ir.Imm32(attr_no), ir.Imm32(comp));
get_attrib = ir.BitCast<IR::U32>(IR::F32{get_attrib}); attr_read = ir.BitCast<IR::U32>(IR::F32{attr_read});
inst.ReplaceUsesWithAndRemove(get_attrib); inst.ReplaceUsesWithAndRemove(attr_read);
break; break;
} }
case AttributeRegion::PatchConst: { case AttributeRegion::PatchConst: {