simplified gapless decoding

This commit is contained in:
Vladislav Mikhalin 2024-11-13 21:28:19 +03:00
parent 4d4d9d5e2c
commit 6dcf249b78
6 changed files with 93 additions and 91 deletions

View File

@ -53,8 +53,7 @@ void AjmAt9Decoder::GetInfo(void* out_info) const {
} }
std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output, std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless, AjmInstanceGapless& gapless) {
std::optional<u32> max_samples_per_channel) {
int ret = 0; int ret = 0;
int bytes_used = 0; int bytes_used = 0;
switch (m_format) { switch (m_format) {
@ -79,32 +78,37 @@ std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
m_superframe_bytes_remain -= bytes_used; m_superframe_bytes_remain -= bytes_used;
u32 skipped_samples = 0; u32 skip_samples = 0;
if (gapless.skipped_samples < gapless.skip_samples) { if (gapless.current.skip_samples > 0) {
skipped_samples = std::min(u32(m_codec_info.frameSamples), skip_samples = std::min(u16(m_codec_info.frameSamples), gapless.current.skip_samples);
u32(gapless.skip_samples - gapless.skipped_samples)); gapless.current.skip_samples -= skip_samples;
gapless.skipped_samples += skipped_samples;
} }
const auto max_samples = max_samples_per_channel.has_value() const auto max_pcm = gapless.init.total_samples != 0
? max_samples_per_channel.value() * m_codec_info.channels ? gapless.current.total_samples * m_codec_info.channels
: std::numeric_limits<u32>::max(); : std::numeric_limits<u32>::max();
size_t samples_written = 0; size_t pcm_written = 0;
switch (m_format) { switch (m_format) {
case AjmFormatEncoding::S16: case AjmFormatEncoding::S16:
samples_written = WriteOutputSamples<s16>(output, skipped_samples, max_samples); pcm_written = WriteOutputSamples<s16>(output, skip_samples, max_pcm);
break; break;
case AjmFormatEncoding::S32: case AjmFormatEncoding::S32:
samples_written = WriteOutputSamples<s32>(output, skipped_samples, max_samples); pcm_written = WriteOutputSamples<s32>(output, skip_samples, max_pcm);
break; break;
case AjmFormatEncoding::Float: case AjmFormatEncoding::Float:
samples_written = WriteOutputSamples<float>(output, skipped_samples, max_samples); pcm_written = WriteOutputSamples<float>(output, skip_samples, max_pcm);
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
} }
const auto samples_written = pcm_written / m_codec_info.channels;
gapless.current.skipped_samples += m_codec_info.frameSamples - samples_written;
if (gapless.init.total_samples != 0) {
gapless.current.total_samples -= samples_written;
}
m_num_frames += 1; m_num_frames += 1;
if ((m_num_frames % m_codec_info.framesInSuperframe) == 0) { if ((m_num_frames % m_codec_info.framesInSuperframe) == 0) {
if (m_superframe_bytes_remain) { if (m_superframe_bytes_remain) {
@ -114,7 +118,7 @@ std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
m_num_frames = 0; m_num_frames = 0;
} }
return {1, samples_written / m_codec_info.channels}; return {1, samples_written};
} }
AjmSidebandFormat AjmAt9Decoder::GetFormat() const { AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
@ -129,10 +133,13 @@ AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
}; };
} }
u32 AjmAt9Decoder::GetNextFrameSize(u32 skip_samples, u32 max_samples) const { u32 AjmAt9Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {
skip_samples = std::min({skip_samples, u32(m_codec_info.frameSamples), max_samples}); const auto max_samples =
return (std::min(u32(m_codec_info.frameSamples), max_samples) - skip_samples) * gapless.init.total_samples != 0
m_codec_info.channels * GetPCMSize(m_format); ? std::min(gapless.current.total_samples, u32(m_codec_info.frameSamples))
: m_codec_info.frameSamples;
const auto skip_samples = std::min(u32(gapless.current.skip_samples), max_samples);
return (max_samples - skip_samples) * m_codec_info.channels * GetPCMSize(m_format);
} }
} // namespace Libraries::Ajm } // namespace Libraries::Ajm

View File

@ -35,10 +35,9 @@ struct AjmAt9Decoder final : AjmCodec {
void Initialize(const void* buffer, u32 buffer_size) override; void Initialize(const void* buffer, u32 buffer_size) override;
void GetInfo(void* out_info) const override; void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override; AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const override; u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output, std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless, AjmInstanceGapless& gapless) override;
std::optional<u32> max_samples) override;
private: private:
template <class T> template <class T>

View File

@ -59,7 +59,6 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
m_format = {}; m_format = {};
m_gapless = {}; m_gapless = {};
m_resample_parameters = {}; m_resample_parameters = {};
m_gapless_samples = 0;
m_total_samples = 0; m_total_samples = 0;
m_codec->Reset(); m_codec->Reset();
} }
@ -79,10 +78,14 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
if (job.input.gapless_decode.has_value()) { if (job.input.gapless_decode.has_value()) {
auto& params = job.input.gapless_decode.value(); auto& params = job.input.gapless_decode.value();
if (params.total_samples != 0) { if (params.total_samples != 0) {
m_gapless.total_samples = std::max(params.total_samples, m_gapless.total_samples); const auto max = std::max(params.total_samples, m_gapless.init.total_samples);
m_gapless.current.total_samples += max - m_gapless.init.total_samples;
m_gapless.init.total_samples = max;
} }
if (params.skip_samples != 0) { if (params.skip_samples != 0) {
m_gapless.skip_samples = std::max(params.skip_samples, m_gapless.skip_samples); const auto max = std::max(params.skip_samples, m_gapless.init.skip_samples);
m_gapless.current.skip_samples += max - m_gapless.init.skip_samples;
m_gapless.init.skip_samples = max;
} }
} }
@ -93,22 +96,29 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
u32 frames_decoded = 0; u32 frames_decoded = 0;
auto in_size = in_buf.size(); auto in_size = in_buf.size();
auto out_size = out_buf.Size(); auto out_size = out_buf.Size();
while (!in_buf.empty() && !out_buf.IsEmpty() && !IsGaplessEnd()) { while (!in_buf.empty() && !out_buf.IsEmpty() && !m_gapless.IsEnd()) {
if (!HasEnoughSpace(out_buf)) { if (!HasEnoughSpace(out_buf)) {
if (job.output.p_mframe == nullptr || frames_decoded == 0) { if (job.output.p_mframe == nullptr || frames_decoded == 0) {
job.output.p_result->result = ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM; job.output.p_result->result = ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM;
break; break;
} }
} }
const auto [nframes, nsamples] =
m_codec->ProcessData(in_buf, out_buf, m_gapless, GetNumRemainingSamples()); const auto [nframes, nsamples] = m_codec->ProcessData(in_buf, out_buf, m_gapless);
frames_decoded += nframes; frames_decoded += nframes;
m_total_samples += nsamples; m_total_samples += nsamples;
m_gapless_samples += nsamples;
if (job.output.p_mframe == nullptr) { if (False(job.flags.run_flags & AjmJobRunFlags::MultipleFrames)) {
break; break;
} }
} }
if (m_gapless.IsEnd()) {
in_buf = in_buf.subspan(in_buf.size());
m_gapless.current.total_samples = m_gapless.init.total_samples;
m_gapless.current.skip_samples = m_gapless.init.skip_samples;
m_codec->Reset();
}
if (job.output.p_mframe) { if (job.output.p_mframe) {
job.output.p_mframe->num_frames = frames_decoded; job.output.p_mframe->num_frames = frames_decoded;
} }
@ -119,38 +129,19 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
} }
} }
if (m_flags.gapless_loop && m_gapless.total_samples != 0 &&
m_gapless_samples >= m_gapless.total_samples) {
m_gapless_samples = 0;
m_gapless.skipped_samples = 0;
m_codec->Reset();
}
if (job.output.p_format != nullptr) { if (job.output.p_format != nullptr) {
*job.output.p_format = m_codec->GetFormat(); *job.output.p_format = m_codec->GetFormat();
} }
if (job.output.p_gapless_decode != nullptr) { if (job.output.p_gapless_decode != nullptr) {
*job.output.p_gapless_decode = m_gapless; *job.output.p_gapless_decode = m_gapless.current;
} }
if (job.output.p_codec_info != nullptr) { if (job.output.p_codec_info != nullptr) {
m_codec->GetInfo(job.output.p_codec_info); m_codec->GetInfo(job.output.p_codec_info);
} }
} }
bool AjmInstance::IsGaplessEnd() const {
return m_gapless.total_samples != 0 && m_gapless_samples >= m_gapless.total_samples;
}
bool AjmInstance::HasEnoughSpace(const SparseOutputBuffer& output) const { bool AjmInstance::HasEnoughSpace(const SparseOutputBuffer& output) const {
const auto skip = return output.Size() >= m_codec->GetNextFrameSize(m_gapless);
m_gapless.skip_samples - std::min(m_gapless.skip_samples, m_gapless.skipped_samples);
const auto remain = GetNumRemainingSamples().value_or(std::numeric_limits<u32>::max());
return output.Size() >= m_codec->GetNextFrameSize(skip, remain);
}
std::optional<u32> AjmInstance::GetNumRemainingSamples() const {
return m_gapless.total_samples != 0
? std::optional<u32>{m_gapless.total_samples - m_gapless_samples}
: std::optional<u32>{};
} }
} // namespace Libraries::Ajm } // namespace Libraries::Ajm

View File

@ -58,6 +58,15 @@ private:
std::span<std::span<u8>>::iterator m_current; std::span<std::span<u8>>::iterator m_current;
}; };
struct AjmInstanceGapless {
AjmSidebandGaplessDecode init{};
AjmSidebandGaplessDecode current{};
bool IsEnd() const {
return init.total_samples != 0 && current.total_samples == 0;
}
};
class AjmCodec { class AjmCodec {
public: public:
virtual ~AjmCodec() = default; virtual ~AjmCodec() = default;
@ -66,10 +75,9 @@ public:
virtual void Reset() = 0; virtual void Reset() = 0;
virtual void GetInfo(void* out_info) const = 0; virtual void GetInfo(void* out_info) const = 0;
virtual AjmSidebandFormat GetFormat() const = 0; virtual AjmSidebandFormat GetFormat() const = 0;
virtual u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const = 0; virtual u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const = 0;
virtual std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output, virtual std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless, AjmInstanceGapless& gapless) = 0;
std::optional<u32> max_samples_per_channel) = 0;
}; };
class AjmInstance { class AjmInstance {
@ -79,18 +87,14 @@ public:
void ExecuteJob(AjmJob& job); void ExecuteJob(AjmJob& job);
private: private:
bool IsGaplessEnd() const;
bool HasEnoughSpace(const SparseOutputBuffer& output) const; bool HasEnoughSpace(const SparseOutputBuffer& output) const;
std::optional<u32> GetNumRemainingSamples() const; std::optional<u32> GetNumRemainingSamples() const;
AjmInstanceFlags m_flags{}; AjmInstanceFlags m_flags{};
AjmSidebandFormat m_format{}; AjmSidebandFormat m_format{};
AjmSidebandGaplessDecode m_gapless{}; AjmInstanceGapless m_gapless{};
AjmSidebandResampleParameters m_resample_parameters{}; AjmSidebandResampleParameters m_resample_parameters{};
u32 m_gapless_samples{};
u32 m_total_samples{}; u32 m_total_samples{};
std::unique_ptr<AjmCodec> m_codec; std::unique_ptr<AjmCodec> m_codec;
}; };

View File

@ -137,8 +137,7 @@ void AjmMp3Decoder::GetInfo(void* out_info) const {
} }
std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output, std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless, AjmInstanceGapless& gapless) {
std::optional<u32> max_samples_per_channel) {
AVPacket* pkt = av_packet_alloc(); AVPacket* pkt = av_packet_alloc();
if ((!m_header.has_value() || m_frame_samples == 0) && in_buf.size() >= 4) { if ((!m_header.has_value() || m_frame_samples == 0) && in_buf.size() >= 4) {
@ -154,12 +153,7 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
in_buf = in_buf.subspan(ret); in_buf = in_buf.subspan(ret);
u32 frames_decoded = 0; u32 frames_decoded = 0;
u32 samples_decoded = 0; u32 samples_written = 0;
auto max_samples =
max_samples_per_channel.has_value()
? max_samples_per_channel.value() * m_codec_context->ch_layout.nb_channels
: std::numeric_limits<u32>::max();
if (pkt->size) { if (pkt->size) {
// Send the packet with the compressed data to the decoder // Send the packet with the compressed data to the decoder
@ -182,32 +176,37 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
frame = ConvertAudioFrame(frame); frame = ConvertAudioFrame(frame);
frames_decoded += 1; frames_decoded += 1;
u32 skipped_samples = 0; u32 skip_samples = 0;
if (gapless.skipped_samples < gapless.skip_samples) { if (gapless.current.skip_samples > 0) {
skipped_samples = std::min(u32(frame->nb_samples), skip_samples = std::min(u16(frame->nb_samples), gapless.current.skip_samples);
u32(gapless.skip_samples - gapless.skipped_samples)); gapless.current.skip_samples -= skip_samples;
gapless.skipped_samples += skipped_samples;
} }
const auto max_pcm =
gapless.init.total_samples != 0
? gapless.current.total_samples * m_codec_context->ch_layout.nb_channels
: std::numeric_limits<u32>::max();
u32 pcm_written = 0;
switch (m_format) { switch (m_format) {
case AjmFormatEncoding::S16: case AjmFormatEncoding::S16:
samples_decoded += pcm_written = WriteOutputPCM<s16>(frame, output, skip_samples, max_pcm);
WriteOutputSamples<s16>(frame, output, skipped_samples, max_samples);
break; break;
case AjmFormatEncoding::S32: case AjmFormatEncoding::S32:
samples_decoded += pcm_written = WriteOutputPCM<s32>(frame, output, skip_samples, max_pcm);
WriteOutputSamples<s32>(frame, output, skipped_samples, max_samples);
break; break;
case AjmFormatEncoding::Float: case AjmFormatEncoding::Float:
samples_decoded += pcm_written = WriteOutputPCM<float>(frame, output, skip_samples, max_pcm);
WriteOutputSamples<float>(frame, output, skipped_samples, max_samples);
break; break;
default: default:
UNREACHABLE(); UNREACHABLE();
} }
if (max_samples_per_channel.has_value()) { const auto samples = pcm_written / m_codec_context->ch_layout.nb_channels;
max_samples -= samples_decoded; samples_written += samples;
gapless.current.skipped_samples += frame->nb_samples - samples;
if (gapless.init.total_samples != 0) {
gapless.current.total_samples -= samples;
} }
av_frame_free(&frame); av_frame_free(&frame);
@ -216,13 +215,16 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
av_packet_free(&pkt); av_packet_free(&pkt);
return {frames_decoded, samples_decoded / m_codec_context->ch_layout.nb_channels}; return {frames_decoded, samples_written};
} }
u32 AjmMp3Decoder::GetNextFrameSize(u32 skip_samples, u32 max_samples) const { u32 AjmMp3Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {
skip_samples = std::min({skip_samples, m_frame_samples, max_samples}); const auto max_samples = gapless.init.total_samples != 0
return (std::min(m_frame_samples, max_samples) - skip_samples) * ? std::min(gapless.current.total_samples, m_frame_samples)
m_codec_context->ch_layout.nb_channels * GetPCMSize(m_format); : m_frame_samples;
const auto skip_samples = std::min(u32(gapless.current.skip_samples), max_samples);
return (max_samples - skip_samples) * m_codec_context->ch_layout.nb_channels *
GetPCMSize(m_format);
} }
class BitReader { class BitReader {

View File

@ -70,22 +70,21 @@ public:
void Initialize(const void* buffer, u32 buffer_size) override {} void Initialize(const void* buffer, u32 buffer_size) override {}
void GetInfo(void* out_info) const override; void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override; AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const override; u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output, std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless, AjmInstanceGapless& gapless) override;
std::optional<u32> max_samples_per_channel) override;
static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame); AjmDecMp3ParseFrame* frame);
private: private:
template <class T> template <class T>
size_t WriteOutputSamples(AVFrame* frame, SparseOutputBuffer& output, u32 skipped_samples, size_t WriteOutputPCM(AVFrame* frame, SparseOutputBuffer& output, u32 skipped_samples,
u32 max_samples) { u32 max_pcm) {
std::span<T> pcm_data(reinterpret_cast<T*>(frame->data[0]), std::span<T> pcm_data(reinterpret_cast<T*>(frame->data[0]),
frame->nb_samples * frame->ch_layout.nb_channels); frame->nb_samples * frame->ch_layout.nb_channels);
pcm_data = pcm_data.subspan(skipped_samples * frame->ch_layout.nb_channels); pcm_data = pcm_data.subspan(skipped_samples * frame->ch_layout.nb_channels);
return output.Write(pcm_data.subspan(0, std::min(u32(pcm_data.size()), max_samples))); return output.Write(pcm_data.subspan(0, std::min(u32(pcm_data.size()), max_pcm)));
} }
AVFrame* ConvertAudioFrame(AVFrame* frame); AVFrame* ConvertAudioFrame(AVFrame* frame);