ajm: handle ParseRiffHeader flag (#3618)

* ajm: handle ParseRiffheader flag

* small optimizations and cleanup

* allow uninitialized instances handle RIFF

* fixed audio cutoff and small refactoring

* small fix to the returned data

* fix gapless init, reset total samples on RIFF init

* warning reporting + consume input buffer on gapless loop
This commit is contained in:
Vladislav Mikhalin
2025-09-22 18:50:57 +03:00
committed by GitHub
parent 525d24a7fc
commit 50a3081084
6 changed files with 184 additions and 49 deletions

View File

@@ -10,15 +10,51 @@ extern "C" {
#include <libatrac9.h>
}
#include <vector>
namespace Libraries::Ajm {
struct ChunkHeader {
u32 tag;
u32 length;
};
static_assert(sizeof(ChunkHeader) == 8);
struct AudioFormat {
u16 fmt_type;
u16 num_channels;
u32 avg_sample_rate;
u32 avg_byte_rate;
u16 block_align;
u16 bits_per_sample;
u16 ext_size;
union {
u16 valid_bits_per_sample;
u16 samples_per_block;
u16 reserved;
};
u32 channel_mask;
u8 guid[16];
u32 version;
u8 config_data[4];
u32 reserved2;
};
static_assert(sizeof(AudioFormat) == 52);
struct SampleData {
u32 sample_length;
u32 encoder_delay;
u32 encoder_delay2;
};
static_assert(sizeof(SampleData) == 12);
struct RIFFHeader {
u32 riff;
u32 size;
u32 wave;
};
static_assert(sizeof(RIFFHeader) == 12);
AjmAt9Decoder::AjmAt9Decoder(AjmFormatEncoding format, AjmAt9CodecFlags flags)
: m_format(format), m_flags(flags), m_handle(Atrac9GetHandle()) {
ASSERT_MSG(m_handle, "Atrac9GetHandle failed");
AjmAt9Decoder::Reset();
}
: m_format(format), m_flags(flags), m_handle(Atrac9GetHandle()) {}
AjmAt9Decoder::~AjmAt9Decoder() {
Atrac9ReleaseHandle(m_handle);
@@ -42,6 +78,7 @@ void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) {
AjmAt9Decoder::Reset();
m_pcm_buffer.resize(m_codec_info.frameSamples * m_codec_info.channels * GetPCMSize(m_format),
0);
m_is_initialized = true;
}
void AjmAt9Decoder::GetInfo(void* out_info) const {
@@ -52,8 +89,64 @@ void AjmAt9Decoder::GetInfo(void* out_info) const {
info->next_frame_size = static_cast<Atrac9Handle*>(m_handle)->Config.FrameBytes;
}
std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) {
u8 g_at9_guid[] = {0xD2, 0x42, 0xE1, 0x47, 0xBA, 0x36, 0x8D, 0x4D,
0x88, 0xFC, 0x61, 0x65, 0x4F, 0x8C, 0x83, 0x6C};
void AjmAt9Decoder::ParseRIFFHeader(std::span<u8>& in_buf, AjmInstanceGapless& gapless) {
auto* header = reinterpret_cast<RIFFHeader*>(in_buf.data());
in_buf = in_buf.subspan(sizeof(RIFFHeader));
ASSERT(header->riff == 'FFIR');
ASSERT(header->wave == 'EVAW');
auto* chunk = reinterpret_cast<ChunkHeader*>(in_buf.data());
in_buf = in_buf.subspan(sizeof(ChunkHeader));
while (chunk->tag != 'atad') {
switch (chunk->tag) {
case ' tmf': {
ASSERT(chunk->length == sizeof(AudioFormat));
auto* fmt = reinterpret_cast<AudioFormat*>(in_buf.data());
ASSERT(fmt->fmt_type == 0xFFFE);
ASSERT(memcmp(fmt->guid, g_at9_guid, 16) == 0);
AjmDecAt9InitializeParameters init_params = {};
std::memcpy(init_params.config_data, fmt->config_data, ORBIS_AT9_CONFIG_DATA_SIZE);
Initialize(&init_params, sizeof(init_params));
break;
}
case 'tcaf': {
ASSERT(chunk->length == sizeof(SampleData));
auto* samples = reinterpret_cast<SampleData*>(in_buf.data());
gapless.init.total_samples = samples->sample_length;
gapless.init.skip_samples = samples->encoder_delay;
gapless.Reset();
break;
}
default:
break;
}
in_buf = in_buf.subspan(chunk->length);
chunk = reinterpret_cast<ChunkHeader*>(in_buf.data());
in_buf = in_buf.subspan(sizeof(ChunkHeader));
}
}
std::tuple<u32, u32, bool> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf,
SparseOutputBuffer& output,
AjmInstanceGapless& gapless) {
bool is_reset = false;
if (True(m_flags & AjmAt9CodecFlags::ParseRiffHeader) &&
*reinterpret_cast<u32*>(in_buf.data()) == 'FFIR') {
ParseRIFFHeader(in_buf, gapless);
is_reset = true;
}
if (!m_is_initialized) {
return {0, 0, is_reset};
}
int ret = 0;
int bytes_used = 0;
switch (m_format) {
@@ -118,7 +211,7 @@ std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
m_num_frames = 0;
}
return {1, samples_written};
return {1, m_codec_info.frameSamples, is_reset};
}
AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
@@ -134,12 +227,13 @@ AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
}
u32 AjmAt9Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {
const auto max_samples =
const auto skip_samples =
std::min<u32>(gapless.current.skip_samples, m_codec_info.frameSamples);
const auto samples =
gapless.init.total_samples != 0
? std::min(gapless.current.total_samples, u32(m_codec_info.frameSamples))
? std::min<u32>(gapless.current.total_samples, m_codec_info.frameSamples - skip_samples)
: m_codec_info.frameSamples;
const auto skip_samples = std::min(u32(gapless.current.skip_samples), max_samples);
return (max_samples - skip_samples) * m_codec_info.channels * GetPCMSize(m_format);
return samples * m_codec_info.channels * GetPCMSize(m_format);
}
} // namespace Libraries::Ajm

View File

@@ -9,6 +9,7 @@
#include "libatrac9.h"
#include <span>
#include <vector>
namespace Libraries::Ajm {
@@ -36,8 +37,8 @@ struct AjmAt9Decoder final : AjmCodec {
void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) override;
std::tuple<u32, u32, bool> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) override;
private:
template <class T>
@@ -49,8 +50,12 @@ private:
return output.Write(pcm_data.subspan(0, pcm_size));
}
void ParseRIFFHeader(std::span<u8>& input, AjmInstanceGapless& gapless);
const AjmFormatEncoding m_format;
const AjmAt9CodecFlags m_flags;
bool m_is_initialized{};
void* m_handle{};
u8 m_config_data[ORBIS_AT9_CONFIG_DATA_SIZE]{};
u32 m_superframe_bytes_remain{};

View File

@@ -52,21 +52,22 @@ AjmInstance::AjmInstance(AjmCodecType codec_type, AjmInstanceFlags flags) : m_fl
}
}
void AjmInstance::Reset() {
m_total_samples = 0;
m_gapless.Reset();
m_codec->Reset();
}
void AjmInstance::ExecuteJob(AjmJob& job) {
const auto control_flags = job.flags.control_flags;
if (True(control_flags & AjmJobControlFlags::Reset)) {
LOG_TRACE(Lib_Ajm, "Resetting instance {}", job.instance_id);
m_format = {};
m_gapless = {};
m_resample_parameters = {};
m_total_samples = 0;
m_codec->Reset();
Reset();
}
if (job.input.init_params.has_value()) {
LOG_TRACE(Lib_Ajm, "Initializing instance {}", job.instance_id);
auto& params = job.input.init_params.value();
m_codec->Initialize(&params, sizeof(params));
is_initialized = true;
}
if (job.input.resample_parameters.has_value()) {
LOG_ERROR(Lib_Ajm, "Unimplemented: resample parameters");
@@ -78,20 +79,37 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
}
if (job.input.gapless_decode.has_value()) {
auto& params = job.input.gapless_decode.value();
if (params.total_samples != 0) {
const auto max = std::max(params.total_samples, m_gapless.init.total_samples);
m_gapless.current.total_samples += max - m_gapless.init.total_samples;
m_gapless.init.total_samples = max;
}
if (params.skip_samples != 0) {
const auto max = std::max(params.skip_samples, m_gapless.init.skip_samples);
m_gapless.current.skip_samples += max - m_gapless.init.skip_samples;
m_gapless.init.skip_samples = max;
}
}
if (!is_initialized) {
return;
const auto samples_processed =
m_gapless.init.total_samples - m_gapless.current.total_samples;
if (params.total_samples != 0 || params.skip_samples == 0) {
if (params.total_samples >= samples_processed) {
const auto sample_difference =
s64(m_gapless.init.total_samples) - params.total_samples;
m_gapless.init.total_samples = params.total_samples;
m_gapless.current.total_samples -= sample_difference;
} else {
LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_INVALID_PARAMETER");
job.output.p_result->result = ORBIS_AJM_RESULT_INVALID_PARAMETER;
return;
}
}
const auto samples_skipped = m_gapless.init.skip_samples - m_gapless.current.skip_samples;
if (params.skip_samples != 0 || params.total_samples == 0) {
if (params.skip_samples >= samples_skipped) {
const auto sample_difference =
s32(m_gapless.init.skip_samples) - params.skip_samples;
m_gapless.init.skip_samples = params.skip_samples;
m_gapless.current.skip_samples -= sample_difference;
} else {
LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_INVALID_PARAMETER");
job.output.p_result->result = ORBIS_AJM_RESULT_INVALID_PARAMETER;
return;
}
}
}
if (!job.input.buffer.empty() && !job.output.buffers.empty()) {
@@ -104,12 +122,23 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
while (!in_buf.empty() && !out_buf.IsEmpty() && !m_gapless.IsEnd()) {
if (!HasEnoughSpace(out_buf)) {
if (job.output.p_mframe == nullptr || frames_decoded == 0) {
LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM ({} < {})",
out_buf.Size(), m_codec->GetNextFrameSize(m_gapless));
job.output.p_result->result = ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM;
break;
}
}
const auto [nframes, nsamples] = m_codec->ProcessData(in_buf, out_buf, m_gapless);
const auto [nframes, nsamples, reset] =
m_codec->ProcessData(in_buf, out_buf, m_gapless);
if (reset) {
m_total_samples = 0;
}
if (!nframes) {
LOG_WARNING(Lib_Ajm, "ORBIS_AJM_RESULT_NOT_INITIALIZED");
job.output.p_result->result = ORBIS_AJM_RESULT_NOT_INITIALIZED;
break;
}
frames_decoded += nframes;
m_total_samples += nsamples;
@@ -118,10 +147,10 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
}
}
if (m_gapless.IsEnd()) {
const auto total_decoded_samples = m_total_samples;
if (m_flags.gapless_loop && m_gapless.IsEnd()) {
in_buf = in_buf.subspan(in_buf.size());
m_gapless.current.total_samples = m_gapless.init.total_samples;
m_gapless.current.skip_samples = m_gapless.init.skip_samples;
m_gapless.Reset();
m_codec->Reset();
}
if (job.output.p_mframe) {
@@ -130,7 +159,7 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
if (job.output.p_stream) {
job.output.p_stream->input_consumed = in_size - in_buf.size();
job.output.p_stream->output_written = out_size - out_buf.Size();
job.output.p_stream->total_decoded_samples = m_total_samples;
job.output.p_stream->total_decoded_samples = total_decoded_samples;
}
}

View File

@@ -65,6 +65,12 @@ struct AjmInstanceGapless {
bool IsEnd() const {
return init.total_samples != 0 && current.total_samples == 0;
}
void Reset() {
current.total_samples = init.total_samples;
current.skip_samples = init.skip_samples;
current.skipped_samples = 0;
}
};
class AjmCodec {
@@ -76,8 +82,8 @@ public:
virtual void GetInfo(void* out_info) const = 0;
virtual AjmSidebandFormat GetFormat() const = 0;
virtual u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const = 0;
virtual std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) = 0;
virtual std::tuple<u32, u32, bool> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) = 0;
};
class AjmInstance {
@@ -89,6 +95,7 @@ public:
private:
bool HasEnoughSpace(const SparseOutputBuffer& output) const;
std::optional<u32> GetNumRemainingSamples() const;
void Reset();
AjmInstanceFlags m_flags{};
AjmSidebandFormat m_format{};
@@ -96,7 +103,6 @@ private:
AjmSidebandResampleParameters m_resample_parameters{};
u32 m_total_samples{};
std::unique_ptr<AjmCodec> m_codec;
bool is_initialized = false;
};
} // namespace Libraries::Ajm

View File

@@ -138,8 +138,9 @@ void AjmMp3Decoder::GetInfo(void* out_info) const {
}
}
std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) {
std::tuple<u32, u32, bool> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf,
SparseOutputBuffer& output,
AjmInstanceGapless& gapless) {
AVPacket* pkt = av_packet_alloc();
if ((!m_header.has_value() || m_frame_samples == 0) && in_buf.size() >= 4) {
@@ -155,7 +156,7 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
in_buf = in_buf.subspan(ret);
u32 frames_decoded = 0;
u32 samples_written = 0;
u32 samples_decoded = 0;
if (pkt->size) {
// Send the packet with the compressed data to the decoder
@@ -176,6 +177,7 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
UNREACHABLE_MSG("Error during decoding");
}
frame = ConvertAudioFrame(frame);
samples_decoded += u32(frame->nb_samples);
frames_decoded += 1;
u32 skip_samples = 0;
@@ -205,7 +207,6 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
}
const auto samples = pcm_written / m_codec_context->ch_layout.nb_channels;
samples_written += samples;
gapless.current.skipped_samples += frame->nb_samples - samples;
if (gapless.init.total_samples != 0) {
gapless.current.total_samples -= samples;
@@ -217,7 +218,7 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
av_packet_free(&pkt);
return {frames_decoded, samples_written};
return {frames_decoded, samples_decoded, false};
}
u32 AjmMp3Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {

View File

@@ -71,8 +71,8 @@ public:
void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) override;
std::tuple<u32, u32, bool> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmInstanceGapless& gapless) override;
static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame);