From 50a3081084fe8ff7dd71b065d35d1aac837f1c9c Mon Sep 17 00:00:00 2001 From: Vladislav Mikhalin Date: Mon, 22 Sep 2025 18:50:57 +0300 Subject: [PATCH] ajm: handle ParseRiffHeader flag (#3618) * ajm: handle ParseRiffheader flag * small optimizations and cleanup * allow uninitialized instances handle RIFF * fixed audio cutoff and small refactoring * small fix to the returned data * fix gapless init, reset total samples on RIFF init * warning reporting + consume input buffer on gapless loop --- src/core/libraries/ajm/ajm_at9.cpp | 120 +++++++++++++++++++++--- src/core/libraries/ajm/ajm_at9.h | 9 +- src/core/libraries/ajm/ajm_instance.cpp | 77 ++++++++++----- src/core/libraries/ajm/ajm_instance.h | 12 ++- src/core/libraries/ajm/ajm_mp3.cpp | 11 ++- src/core/libraries/ajm/ajm_mp3.h | 4 +- 6 files changed, 184 insertions(+), 49 deletions(-) diff --git a/src/core/libraries/ajm/ajm_at9.cpp b/src/core/libraries/ajm/ajm_at9.cpp index 936929ae8..014d1a4e5 100644 --- a/src/core/libraries/ajm/ajm_at9.cpp +++ b/src/core/libraries/ajm/ajm_at9.cpp @@ -10,15 +10,51 @@ extern "C" { #include } -#include - namespace Libraries::Ajm { +struct ChunkHeader { + u32 tag; + u32 length; +}; +static_assert(sizeof(ChunkHeader) == 8); + +struct AudioFormat { + u16 fmt_type; + u16 num_channels; + u32 avg_sample_rate; + u32 avg_byte_rate; + u16 block_align; + u16 bits_per_sample; + u16 ext_size; + union { + u16 valid_bits_per_sample; + u16 samples_per_block; + u16 reserved; + }; + u32 channel_mask; + u8 guid[16]; + u32 version; + u8 config_data[4]; + u32 reserved2; +}; +static_assert(sizeof(AudioFormat) == 52); + +struct SampleData { + u32 sample_length; + u32 encoder_delay; + u32 encoder_delay2; +}; +static_assert(sizeof(SampleData) == 12); + +struct RIFFHeader { + u32 riff; + u32 size; + u32 wave; +}; +static_assert(sizeof(RIFFHeader) == 12); + AjmAt9Decoder::AjmAt9Decoder(AjmFormatEncoding format, AjmAt9CodecFlags flags) - : m_format(format), m_flags(flags), m_handle(Atrac9GetHandle()) { - ASSERT_MSG(m_handle, "Atrac9GetHandle failed"); - AjmAt9Decoder::Reset(); -} + : m_format(format), m_flags(flags), m_handle(Atrac9GetHandle()) {} AjmAt9Decoder::~AjmAt9Decoder() { Atrac9ReleaseHandle(m_handle); @@ -42,6 +78,7 @@ void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) { AjmAt9Decoder::Reset(); m_pcm_buffer.resize(m_codec_info.frameSamples * m_codec_info.channels * GetPCMSize(m_format), 0); + m_is_initialized = true; } void AjmAt9Decoder::GetInfo(void* out_info) const { @@ -52,8 +89,64 @@ void AjmAt9Decoder::GetInfo(void* out_info) const { info->next_frame_size = static_cast(m_handle)->Config.FrameBytes; } -std::tuple AjmAt9Decoder::ProcessData(std::span& in_buf, SparseOutputBuffer& output, - AjmInstanceGapless& gapless) { +u8 g_at9_guid[] = {0xD2, 0x42, 0xE1, 0x47, 0xBA, 0x36, 0x8D, 0x4D, + 0x88, 0xFC, 0x61, 0x65, 0x4F, 0x8C, 0x83, 0x6C}; + +void AjmAt9Decoder::ParseRIFFHeader(std::span& in_buf, AjmInstanceGapless& gapless) { + auto* header = reinterpret_cast(in_buf.data()); + in_buf = in_buf.subspan(sizeof(RIFFHeader)); + + ASSERT(header->riff == 'FFIR'); + ASSERT(header->wave == 'EVAW'); + + auto* chunk = reinterpret_cast(in_buf.data()); + in_buf = in_buf.subspan(sizeof(ChunkHeader)); + while (chunk->tag != 'atad') { + switch (chunk->tag) { + case ' tmf': { + ASSERT(chunk->length == sizeof(AudioFormat)); + auto* fmt = reinterpret_cast(in_buf.data()); + + ASSERT(fmt->fmt_type == 0xFFFE); + ASSERT(memcmp(fmt->guid, g_at9_guid, 16) == 0); + AjmDecAt9InitializeParameters init_params = {}; + std::memcpy(init_params.config_data, fmt->config_data, ORBIS_AT9_CONFIG_DATA_SIZE); + Initialize(&init_params, sizeof(init_params)); + break; + } + case 'tcaf': { + ASSERT(chunk->length == sizeof(SampleData)); + auto* samples = reinterpret_cast(in_buf.data()); + + gapless.init.total_samples = samples->sample_length; + gapless.init.skip_samples = samples->encoder_delay; + gapless.Reset(); + break; + } + default: + break; + } + in_buf = in_buf.subspan(chunk->length); + + chunk = reinterpret_cast(in_buf.data()); + in_buf = in_buf.subspan(sizeof(ChunkHeader)); + } +} + +std::tuple AjmAt9Decoder::ProcessData(std::span& in_buf, + SparseOutputBuffer& output, + AjmInstanceGapless& gapless) { + bool is_reset = false; + if (True(m_flags & AjmAt9CodecFlags::ParseRiffHeader) && + *reinterpret_cast(in_buf.data()) == 'FFIR') { + ParseRIFFHeader(in_buf, gapless); + is_reset = true; + } + + if (!m_is_initialized) { + return {0, 0, is_reset}; + } + int ret = 0; int bytes_used = 0; switch (m_format) { @@ -118,7 +211,7 @@ std::tuple AjmAt9Decoder::ProcessData(std::span& in_buf, SparseOut m_num_frames = 0; } - return {1, samples_written}; + return {1, m_codec_info.frameSamples, is_reset}; } AjmSidebandFormat AjmAt9Decoder::GetFormat() const { @@ -134,12 +227,13 @@ AjmSidebandFormat AjmAt9Decoder::GetFormat() const { } u32 AjmAt9Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const { - const auto max_samples = + const auto skip_samples = + std::min(gapless.current.skip_samples, m_codec_info.frameSamples); + const auto samples = gapless.init.total_samples != 0 - ? std::min(gapless.current.total_samples, u32(m_codec_info.frameSamples)) + ? std::min(gapless.current.total_samples, m_codec_info.frameSamples - skip_samples) : m_codec_info.frameSamples; - const auto skip_samples = std::min(u32(gapless.current.skip_samples), max_samples); - return (max_samples - skip_samples) * m_codec_info.channels * GetPCMSize(m_format); + return samples * m_codec_info.channels * GetPCMSize(m_format); } } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_at9.h b/src/core/libraries/ajm/ajm_at9.h index 689681dec..3262f1aa0 100644 --- a/src/core/libraries/ajm/ajm_at9.h +++ b/src/core/libraries/ajm/ajm_at9.h @@ -9,6 +9,7 @@ #include "libatrac9.h" #include +#include namespace Libraries::Ajm { @@ -36,8 +37,8 @@ struct AjmAt9Decoder final : AjmCodec { void GetInfo(void* out_info) const override; AjmSidebandFormat GetFormat() const override; u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override; - std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, - AjmInstanceGapless& gapless) override; + std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, + AjmInstanceGapless& gapless) override; private: template @@ -49,8 +50,12 @@ private: return output.Write(pcm_data.subspan(0, pcm_size)); } + void ParseRIFFHeader(std::span& input, AjmInstanceGapless& gapless); + const AjmFormatEncoding m_format; const AjmAt9CodecFlags m_flags; + + bool m_is_initialized{}; void* m_handle{}; u8 m_config_data[ORBIS_AT9_CONFIG_DATA_SIZE]{}; u32 m_superframe_bytes_remain{}; diff --git a/src/core/libraries/ajm/ajm_instance.cpp b/src/core/libraries/ajm/ajm_instance.cpp index 01b1d2b21..c4ea395b9 100644 --- a/src/core/libraries/ajm/ajm_instance.cpp +++ b/src/core/libraries/ajm/ajm_instance.cpp @@ -52,21 +52,22 @@ AjmInstance::AjmInstance(AjmCodecType codec_type, AjmInstanceFlags flags) : m_fl } } +void AjmInstance::Reset() { + m_total_samples = 0; + m_gapless.Reset(); + m_codec->Reset(); +} + void AjmInstance::ExecuteJob(AjmJob& job) { const auto control_flags = job.flags.control_flags; if (True(control_flags & AjmJobControlFlags::Reset)) { LOG_TRACE(Lib_Ajm, "Resetting instance {}", job.instance_id); - m_format = {}; - m_gapless = {}; - m_resample_parameters = {}; - m_total_samples = 0; - m_codec->Reset(); + Reset(); } if (job.input.init_params.has_value()) { LOG_TRACE(Lib_Ajm, "Initializing instance {}", job.instance_id); auto& params = job.input.init_params.value(); m_codec->Initialize(¶ms, sizeof(params)); - is_initialized = true; } if (job.input.resample_parameters.has_value()) { LOG_ERROR(Lib_Ajm, "Unimplemented: resample parameters"); @@ -78,20 +79,37 @@ void AjmInstance::ExecuteJob(AjmJob& job) { } if (job.input.gapless_decode.has_value()) { auto& params = job.input.gapless_decode.value(); - if (params.total_samples != 0) { - const auto max = std::max(params.total_samples, m_gapless.init.total_samples); - m_gapless.current.total_samples += max - m_gapless.init.total_samples; - m_gapless.init.total_samples = max; - } - if (params.skip_samples != 0) { - const auto max = std::max(params.skip_samples, m_gapless.init.skip_samples); - m_gapless.current.skip_samples += max - m_gapless.init.skip_samples; - m_gapless.init.skip_samples = max; - } - } - if (!is_initialized) { - return; + const auto samples_processed = + m_gapless.init.total_samples - m_gapless.current.total_samples; + if (params.total_samples != 0 || params.skip_samples == 0) { + if (params.total_samples >= samples_processed) { + const auto sample_difference = + s64(m_gapless.init.total_samples) - params.total_samples; + + m_gapless.init.total_samples = params.total_samples; + m_gapless.current.total_samples -= sample_difference; + } else { + LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_INVALID_PARAMETER"); + job.output.p_result->result = ORBIS_AJM_RESULT_INVALID_PARAMETER; + return; + } + } + + const auto samples_skipped = m_gapless.init.skip_samples - m_gapless.current.skip_samples; + if (params.skip_samples != 0 || params.total_samples == 0) { + if (params.skip_samples >= samples_skipped) { + const auto sample_difference = + s32(m_gapless.init.skip_samples) - params.skip_samples; + + m_gapless.init.skip_samples = params.skip_samples; + m_gapless.current.skip_samples -= sample_difference; + } else { + LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_INVALID_PARAMETER"); + job.output.p_result->result = ORBIS_AJM_RESULT_INVALID_PARAMETER; + return; + } + } } if (!job.input.buffer.empty() && !job.output.buffers.empty()) { @@ -104,12 +122,23 @@ void AjmInstance::ExecuteJob(AjmJob& job) { while (!in_buf.empty() && !out_buf.IsEmpty() && !m_gapless.IsEnd()) { if (!HasEnoughSpace(out_buf)) { if (job.output.p_mframe == nullptr || frames_decoded == 0) { + LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM ({} < {})", + out_buf.Size(), m_codec->GetNextFrameSize(m_gapless)); job.output.p_result->result = ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM; break; } } - const auto [nframes, nsamples] = m_codec->ProcessData(in_buf, out_buf, m_gapless); + const auto [nframes, nsamples, reset] = + m_codec->ProcessData(in_buf, out_buf, m_gapless); + if (reset) { + m_total_samples = 0; + } + if (!nframes) { + LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_NOT_INITIALIZED"); + job.output.p_result->result = ORBIS_AJM_RESULT_NOT_INITIALIZED; + break; + } frames_decoded += nframes; m_total_samples += nsamples; @@ -118,10 +147,10 @@ void AjmInstance::ExecuteJob(AjmJob& job) { } } - if (m_gapless.IsEnd()) { + const auto total_decoded_samples = m_total_samples; + if (m_flags.gapless_loop && m_gapless.IsEnd()) { in_buf = in_buf.subspan(in_buf.size()); - m_gapless.current.total_samples = m_gapless.init.total_samples; - m_gapless.current.skip_samples = m_gapless.init.skip_samples; + m_gapless.Reset(); m_codec->Reset(); } if (job.output.p_mframe) { @@ -130,7 +159,7 @@ void AjmInstance::ExecuteJob(AjmJob& job) { if (job.output.p_stream) { job.output.p_stream->input_consumed = in_size - in_buf.size(); job.output.p_stream->output_written = out_size - out_buf.Size(); - job.output.p_stream->total_decoded_samples = m_total_samples; + job.output.p_stream->total_decoded_samples = total_decoded_samples; } } diff --git a/src/core/libraries/ajm/ajm_instance.h b/src/core/libraries/ajm/ajm_instance.h index e02ac6ffb..ad0a82f29 100644 --- a/src/core/libraries/ajm/ajm_instance.h +++ b/src/core/libraries/ajm/ajm_instance.h @@ -65,6 +65,12 @@ struct AjmInstanceGapless { bool IsEnd() const { return init.total_samples != 0 && current.total_samples == 0; } + + void Reset() { + current.total_samples = init.total_samples; + current.skip_samples = init.skip_samples; + current.skipped_samples = 0; + } }; class AjmCodec { @@ -76,8 +82,8 @@ public: virtual void GetInfo(void* out_info) const = 0; virtual AjmSidebandFormat GetFormat() const = 0; virtual u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const = 0; - virtual std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, - AjmInstanceGapless& gapless) = 0; + virtual std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, + AjmInstanceGapless& gapless) = 0; }; class AjmInstance { @@ -89,6 +95,7 @@ public: private: bool HasEnoughSpace(const SparseOutputBuffer& output) const; std::optional GetNumRemainingSamples() const; + void Reset(); AjmInstanceFlags m_flags{}; AjmSidebandFormat m_format{}; @@ -96,7 +103,6 @@ private: AjmSidebandResampleParameters m_resample_parameters{}; u32 m_total_samples{}; std::unique_ptr m_codec; - bool is_initialized = false; }; } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_mp3.cpp b/src/core/libraries/ajm/ajm_mp3.cpp index 2c572a01b..f17f53d51 100644 --- a/src/core/libraries/ajm/ajm_mp3.cpp +++ b/src/core/libraries/ajm/ajm_mp3.cpp @@ -138,8 +138,9 @@ void AjmMp3Decoder::GetInfo(void* out_info) const { } } -std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, SparseOutputBuffer& output, - AjmInstanceGapless& gapless) { +std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, + SparseOutputBuffer& output, + AjmInstanceGapless& gapless) { AVPacket* pkt = av_packet_alloc(); if ((!m_header.has_value() || m_frame_samples == 0) && in_buf.size() >= 4) { @@ -155,7 +156,7 @@ std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, SparseOut in_buf = in_buf.subspan(ret); u32 frames_decoded = 0; - u32 samples_written = 0; + u32 samples_decoded = 0; if (pkt->size) { // Send the packet with the compressed data to the decoder @@ -176,6 +177,7 @@ std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, SparseOut UNREACHABLE_MSG("Error during decoding"); } frame = ConvertAudioFrame(frame); + samples_decoded += u32(frame->nb_samples); frames_decoded += 1; u32 skip_samples = 0; @@ -205,7 +207,6 @@ std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, SparseOut } const auto samples = pcm_written / m_codec_context->ch_layout.nb_channels; - samples_written += samples; gapless.current.skipped_samples += frame->nb_samples - samples; if (gapless.init.total_samples != 0) { gapless.current.total_samples -= samples; @@ -217,7 +218,7 @@ std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, SparseOut av_packet_free(&pkt); - return {frames_decoded, samples_written}; + return {frames_decoded, samples_decoded, false}; } u32 AjmMp3Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const { diff --git a/src/core/libraries/ajm/ajm_mp3.h b/src/core/libraries/ajm/ajm_mp3.h index 70c949550..7ac65fdba 100644 --- a/src/core/libraries/ajm/ajm_mp3.h +++ b/src/core/libraries/ajm/ajm_mp3.h @@ -71,8 +71,8 @@ public: void GetInfo(void* out_info) const override; AjmSidebandFormat GetFormat() const override; u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override; - std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, - AjmInstanceGapless& gapless) override; + std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, + AjmInstanceGapless& gapless) override; static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, AjmDecMp3ParseFrame* frame);