diff --git a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp index 0bf0dc014..dccce4922 100644 --- a/src/shader_recompiler/ir/passes/hull_shader_transform.cpp +++ b/src/shader_recompiler/ir/passes/hull_shader_transform.cpp @@ -380,208 +380,10 @@ static AttributeRegion FindRegionKind(IR::Inst* ring_access, const Shader::Info& } } -static IR::Program* g_program; // TODO delete - -struct RingAddressInfo { - AttributeRegion region{}; - u32 attribute_byte_offset{}; - // For InputCP and OutputCP, offset from the start of the patch's memory (including - // attribute_byte_offset) For PatchConst, not relevant - IR::U32 offset_in_patch{IR::Value(0u)}; -}; - -class Pass { -public: - Pass(Info& info_, RuntimeInfo& runtime_info_) : info(info_), runtime_info(runtime_info_) { - InitTessConstants(info.tess_consts_ptr_base, info.tess_consts_dword_offset, info, - runtime_info, tess_constants); - } - - RingAddressInfo WalkRingAccess(IR::Inst* access, IR::IREmitter& insert_point) { - Reset(); - RingAddressInfo address_info{}; - - IR::Value addr; - switch (access->GetOpcode()) { - case IR::Opcode::LoadSharedU32: - case IR::Opcode::LoadSharedU64: - case IR::Opcode::LoadSharedU128: - case IR::Opcode::WriteSharedU32: - case IR::Opcode::WriteSharedU64: - case IR::Opcode::WriteSharedU128: - addr = access->Arg(0); - break; - case IR::Opcode::StoreBufferU32: - case IR::Opcode::StoreBufferU32x2: - case IR::Opcode::StoreBufferU32x3: - case IR::Opcode::StoreBufferU32x4: - addr = access->Arg(1); - break; - default: - UNREACHABLE(); - } - - products.emplace_back(addr); - Visit(addr); - - FindIndexInfo(address_info, insert_point); - - return address_info; - } - -private: - void Reset() { - within_mul = false; - products.clear(); - } - - void Visit(IR::Value node) { - IR::Value a, b, c; - - if (MakeInstPattern(MatchValue(a), MatchValue(b)).DoMatch(node)) { - bool saved_within_mul = within_mul; - within_mul = true; - Visit(a); - Visit(b); - within_mul = saved_within_mul; - } else if (MakeInstPattern(MatchValue(a), MatchValue(b)) - .DoMatch(node)) { - if (within_mul) { - UNREACHABLE_MSG("Test"); - products.back().as_factors.emplace_back(IR::U32{node}); - } else { - products.back().as_nested_value = IR::U32{a}; - Visit(a); - products.emplace_back(b); - Visit(b); - } - } else if (MakeInstPattern(MatchValue(a), MatchImm(b)) - .DoMatch(node)) { - products.back().as_factors.emplace_back(IR::Value(u32(2 << (b.U32() - 1)))); - Visit(a); - } else if (MakeInstPattern(MatchValue(a), MatchU32(0)) - .DoMatch(node)) { - products.back().as_factors.emplace_back(a.Attribute()); - } else if (MakeInstPattern(MatchValue(a), MatchU32(0)) - .DoMatch(node)) { - products.back().as_factors.emplace_back(a); - } else if (MakeInstPattern(MatchValue(a), MatchIgnore(), - MatchIgnore()) - .DoMatch(node)) { - Visit(a); - } else if (MakeInstPattern(MatchValue(a), MatchIgnore(), - MatchIgnore()) - .DoMatch(node)) { - Visit(a); - } else if (MakeInstPattern(MatchValue(a)).DoMatch(node)) { - return Visit(a); - } else if (MakeInstPattern(MatchValue(a)).DoMatch(node)) { - return Visit(a); - } else if (node.TryInstRecursive() && - node.InstRecursive()->GetOpcode() == IR::Opcode::Phi) { - DEBUG_ASSERT(false && "Phi test"); - products.back().as_factors.emplace_back(node); - } else { - products.back().as_factors.emplace_back(node); - } - } - - void FindIndexInfo(RingAddressInfo& address_info, IR::IREmitter& ir) { - // infer which attribute base the address is indexing - // by how many addends are multiplied by TessellationDataConstantBuffer::m_hsNumPatch. - // Also handle m_hsOutputBase or m_patchConstBase - u32 region_count = 0; - - // Remove addends except for the attribute offset and possibly the - // control point index calc - std::erase_if(products, [&](Product& p) { - for (IR::Value& value : p.as_factors) { - if (value.Type() == IR::Type::Attribute) { - if (value.Attribute() == IR::Attribute::TcsNumPatches || - value.Attribute() == IR::Attribute::TcsOutputBase) { - ++region_count; - return true; - } else if (value.Attribute() == IR::Attribute::TcsPatchConstBase) { - region_count += 2; - return true; - } else if (value.Attribute() == IR::Attribute::TessPatchIdInVgt) { - return true; - } - } - } - return false; - }); - - // DumpIR(*g_program, "before_crash"); - - // Look for some term with a dynamic index (should be the control point index) - for (auto i = 0; i < products.size(); i++) { - auto& factors = products[i].as_factors; - // Remember this as the index term - if (std::any_of(factors.begin(), factors.end(), [&](const IR::Value& v) { - return !v.IsImmediate() || v.Type() == IR::Type::Attribute; - })) { - address_info.offset_in_patch = - ir.IAdd(address_info.offset_in_patch, products[i].as_nested_value); - } else { - ASSERT_MSG(factors.size() == 1, "factors all const but not const folded"); - // Otherwise assume it contributes to the attribute - address_info.offset_in_patch = - ir.IAdd(address_info.offset_in_patch, IR::U32{factors[0]}); - address_info.attribute_byte_offset += factors[0].U32(); - } - } - - if (region_count == 0) { - address_info.region = AttributeRegion::InputCP; - } else if (info.l_stage == LogicalStage::TessellationControl && - runtime_info.hs_info.IsPassthrough()) { - ASSERT(region_count <= 1); - address_info.region = AttributeRegion::PatchConst; - } else { - ASSERT(region_count <= 2); - address_info.region = AttributeRegion(region_count); - } - } - - Info& info; - RuntimeInfo& runtime_info; - - TessellationDataConstantBuffer tess_constants; - bool within_mul{}; - - // One product in the sum of products making up an address - struct Product { - Product(IR::Value val_) : as_nested_value(val_), as_factors() {} - Product(const Product& other) = default; - ~Product() = default; - - // IR value used as an addend in address calc - IR::U32 as_nested_value; - // all the leaves that feed the multiplication, linear - // TODO small_vector - // boost::container::small_vector as_factors; - std::vector as_factors; - }; - - std::vector products; -}; - } // namespace void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { - TessConstantUseWalker walker; - g_program = &program; // TODO delete Info& info = program.info; - Pass pass(info, runtime_info); - - for (IR::Block* block : program.blocks) { - for (IR::Inst& inst : block->Instructions()) { - if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) { - walker.MarkTessAttributeUsers(&inst); - } - } - } for (IR::Block* block : program.blocks) { for (IR::Inst& inst : block->Instructions()) { @@ -646,7 +448,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { case IR::Opcode::WriteSharedU32: case IR::Opcode::WriteSharedU64: { // DumpIR(program, "before_walk"); - RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir); + // RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir); const u32 num_dwords = opcode == IR::Opcode::WriteSharedU32 ? 1 @@ -669,7 +471,6 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { } if (output_kind == AttributeRegion::OutputCP) { -#if 0 // Invocation ID array index is implicit, handled by SPIRV backend IR::U32 attr_index = ir.ShiftRightLogical( ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsCpStride)), @@ -677,17 +478,11 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { IR::U32 comp_index = ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u)); ir.SetTcsGenericAttribute(data, attr_index, comp_index); -#else - u32 offset_dw = address_info.attribute_byte_offset >> 2; - const u32 param = offset_dw >> 2; - const u32 comp = offset_dw & 3; - // Invocation ID array index is implicit, handled by SPIRV backend - ir.SetTcsGenericAttribute(data, ir.Imm32(param), ir.Imm32(comp)); -#endif } else { ASSERT(output_kind == AttributeRegion::PatchConst); - ir.SetPatch(IR::PatchGeneric(address_info.attribute_byte_offset >> 2), - data); + ASSERT_MSG(addr.IsImmediate(), "patch addr non imm, inst {}", + fmt::ptr(addr.Inst())); + ir.SetPatch(IR::PatchGeneric(addr.U32() >> 2), data); } }; @@ -704,20 +499,21 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { case IR::Opcode::LoadSharedU32: { // case IR::Opcode::LoadSharedU64: // case IR::Opcode::LoadSharedU128: - RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir); + const IR::U32 addr{inst.Arg(0)}; + AttributeRegion region = FindRegionKind(&inst, info, runtime_info); - ASSERT(address_info.region == AttributeRegion::InputCP || - address_info.region == AttributeRegion::OutputCP); - switch (address_info.region) { + ASSERT(region == AttributeRegion::InputCP || region == AttributeRegion::OutputCP); + switch (region) { case AttributeRegion::InputCP: { - u32 offset_dw = - (address_info.attribute_byte_offset % runtime_info.hs_info.ls_stride) >> 2; - const u32 attr_no = offset_dw >> 2; - const u32 comp = offset_dw & 3; - IR::U32 control_point_index = ir.IDiv(IR::U32{address_info.offset_in_patch}, - ir.Imm32(runtime_info.hs_info.ls_stride)); - IR::Value attr_read = ir.GetTessGenericAttribute( - control_point_index, ir.Imm32(attr_no), ir.Imm32(comp)); + IR::U32 control_point_index = + ir.IDiv(addr, ir.Imm32(runtime_info.hs_info.ls_stride)); + IR::U32 attr_index = ir.ShiftRightLogical( + ir.IMod(addr, ir.GetAttributeU32(IR::Attribute::TcsLsStride)), + ir.Imm32(4u)); + IR::U32 comp_index = + ir.ShiftRightLogical(ir.BitwiseAnd(addr, ir.Imm32(0xFU)), ir.Imm32(2u)); + IR::Value attr_read = + ir.GetTessGenericAttribute(control_point_index, attr_index, comp_index); attr_read = ir.BitCast(IR::F32{attr_read}); inst.ReplaceUsesWithAndRemove(attr_read); break; @@ -775,17 +571,7 @@ void HullShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { // TODO refactor void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { - TessConstantUseWalker walker; Info& info = program.info; - Pass pass(info, runtime_info); - - for (IR::Block* block : program.blocks) { - for (IR::Inst& inst : block->Instructions()) { - if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) { - walker.MarkTessAttributeUsers(&inst); - } - } - } for (IR::Block* block : program.blocks) { for (IR::Inst& inst : block->Instructions()) { @@ -795,8 +581,7 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { case IR::Opcode::LoadSharedU32: { // case IR::Opcode::LoadSharedU64: // case IR::Opcode::LoadSharedU128: // TODO - RingAddressInfo address_info = pass.WalkRingAccess(&inst, ir); - const IR::U32 addr = IR::U32{inst.Arg(0)}; + const IR::U32 addr{inst.Arg(0)}; AttributeRegion region = FindRegionKind(&inst, info, runtime_info); ASSERT(region == AttributeRegion::OutputCP || @@ -817,9 +602,10 @@ void DomainShaderTransform(IR::Program& program, RuntimeInfo& runtime_info) { break; } case AttributeRegion::PatchConst: { - // TODO make patch consts into dynamic offset - u32 offset_dw = address_info.attribute_byte_offset >> 2; - IR::Value get_patch = ir.GetPatch(IR::PatchGeneric(offset_dw)); + // TODO if assert fails then make patch consts into dynamic offset + ASSERT_MSG(addr.IsImmediate(), "patch addr non imm, inst {}", + fmt::ptr(addr.Inst())); + IR::Value get_patch = ir.GetPatch(IR::PatchGeneric(addr.U32() >> 2)); inst.ReplaceUsesWithAndRemove(get_patch); break; } @@ -963,6 +749,38 @@ void TessellationPreprocess(IR::Program& program, RuntimeInfo& runtime_info) { } } } + + TessConstantUseWalker walker; + for (IR::Block* block : program.blocks) { + for (IR::Inst& inst : block->Instructions()) { + if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) { + walker.MarkTessAttributeUsers(&inst); + } + } + } + + for (IR::Block* block : program.blocks) { + for (IR::Inst& inst : block->Instructions()) { + if (inst.GetOpcode() == IR::Opcode::GetAttributeU32) { + switch (inst.Arg(0).Attribute()) { + // Should verify that these are only used in address calculations for attr + // read/write + // Replace with 0 so we can dynamically index the control points within the + // region allocated for this patch (input or output). These terms should only + // contribute to the base address of that region, so replacing with 0 *should* + // be fine + case IR::Attribute::TcsNumPatches: + case IR::Attribute::TcsOutputBase: + case IR::Attribute::TcsPatchConstBase: + case IR::Attribute::TessPatchIdInVgt: + inst.ReplaceUsesWithAndRemove(IR::Value(0u)); + break; + default: + break; + } + } + } + } } void TessellationPostprocess(IR::Program& program, RuntimeInfo& runtime_info) { @@ -982,21 +800,6 @@ void TessellationPostprocess(IR::Program& program, RuntimeInfo& runtime_info) { case IR::Attribute::TcsCpStride: inst.ReplaceUsesWithAndRemove(IR::Value(tess_constants.m_hsCpStride)); break; - // Should verify that these are only used in address calculations for attr - // read/write - // Replace with 0 so we can dynamically index the control points within the - // region allocated for this patch (input or output). These terms should only - // contribute to the base address of that region, so replacing with 0 *should* - // be fine - case IR::Attribute::TcsNumPatches: - case IR::Attribute::TcsOutputBase: - case IR::Attribute::TcsPatchConstBase: - case IR::Attribute::TessPatchIdInVgt: - inst.ReplaceUsesWithAndRemove(IR::Value(0u)); - break; - case IR::Attribute::TcsPatchConstSize: - case IR::Attribute::TcsPatchOutputSize: - case IR::Attribute::TcsFirstEdgeTessFactorIndex: default: break; }