simplified gapless decoding

This commit is contained in:
Vladislav Mikhalin 2024-11-13 21:28:19 +03:00
parent 4d4d9d5e2c
commit 6dcf249b78
6 changed files with 93 additions and 91 deletions

View File

@ -53,8 +53,7 @@ void AjmAt9Decoder::GetInfo(void* out_info) const {
}
std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) {
AjmInstanceGapless& gapless) {
int ret = 0;
int bytes_used = 0;
switch (m_format) {
@ -79,32 +78,37 @@ std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
m_superframe_bytes_remain -= bytes_used;
u32 skipped_samples = 0;
if (gapless.skipped_samples < gapless.skip_samples) {
skipped_samples = std::min(u32(m_codec_info.frameSamples),
u32(gapless.skip_samples - gapless.skipped_samples));
gapless.skipped_samples += skipped_samples;
u32 skip_samples = 0;
if (gapless.current.skip_samples > 0) {
skip_samples = std::min(u16(m_codec_info.frameSamples), gapless.current.skip_samples);
gapless.current.skip_samples -= skip_samples;
}
const auto max_samples = max_samples_per_channel.has_value()
? max_samples_per_channel.value() * m_codec_info.channels
: std::numeric_limits<u32>::max();
const auto max_pcm = gapless.init.total_samples != 0
? gapless.current.total_samples * m_codec_info.channels
: std::numeric_limits<u32>::max();
size_t samples_written = 0;
size_t pcm_written = 0;
switch (m_format) {
case AjmFormatEncoding::S16:
samples_written = WriteOutputSamples<s16>(output, skipped_samples, max_samples);
pcm_written = WriteOutputSamples<s16>(output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::S32:
samples_written = WriteOutputSamples<s32>(output, skipped_samples, max_samples);
pcm_written = WriteOutputSamples<s32>(output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::Float:
samples_written = WriteOutputSamples<float>(output, skipped_samples, max_samples);
pcm_written = WriteOutputSamples<float>(output, skip_samples, max_pcm);
break;
default:
UNREACHABLE();
}
const auto samples_written = pcm_written / m_codec_info.channels;
gapless.current.skipped_samples += m_codec_info.frameSamples - samples_written;
if (gapless.init.total_samples != 0) {
gapless.current.total_samples -= samples_written;
}
m_num_frames += 1;
if ((m_num_frames % m_codec_info.framesInSuperframe) == 0) {
if (m_superframe_bytes_remain) {
@ -114,7 +118,7 @@ std::tuple<u32, u32> AjmAt9Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
m_num_frames = 0;
}
return {1, samples_written / m_codec_info.channels};
return {1, samples_written};
}
AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
@ -129,10 +133,13 @@ AjmSidebandFormat AjmAt9Decoder::GetFormat() const {
};
}
u32 AjmAt9Decoder::GetNextFrameSize(u32 skip_samples, u32 max_samples) const {
skip_samples = std::min({skip_samples, u32(m_codec_info.frameSamples), max_samples});
return (std::min(u32(m_codec_info.frameSamples), max_samples) - skip_samples) *
m_codec_info.channels * GetPCMSize(m_format);
// Computes the size the next decoded frame needs in the output buffer, taking the
// current gapless state into account. The result is
// (samples per channel to emit) * channels * GetPCMSize(m_format).
u32 AjmAt9Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {
    const u32 frame_samples = u32(m_codec_info.frameSamples);

    // A configured total (init.total_samples != 0) caps the frame at whatever is
    // still left in the working counter; otherwise a full frame is assumed.
    u32 limit = frame_samples;
    if (gapless.init.total_samples != 0) {
        limit = std::min(gapless.current.total_samples, frame_samples);
    }

    // Samples still pending to be skipped never reach the output, but no more than
    // the capped frame length can be dropped from this frame.
    const u32 to_skip = std::min(u32(gapless.current.skip_samples), limit);

    return (limit - to_skip) * m_codec_info.channels * GetPCMSize(m_format);
}
} // namespace Libraries::Ajm

View File

@ -35,10 +35,9 @@ struct AjmAt9Decoder final : AjmCodec {
void Initialize(const void* buffer, u32 buffer_size) override;
void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const override;
u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples) override;
AjmInstanceGapless& gapless) override;
private:
template <class T>

View File

@ -59,7 +59,6 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
m_format = {};
m_gapless = {};
m_resample_parameters = {};
m_gapless_samples = 0;
m_total_samples = 0;
m_codec->Reset();
}
@ -79,10 +78,14 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
if (job.input.gapless_decode.has_value()) {
auto& params = job.input.gapless_decode.value();
if (params.total_samples != 0) {
m_gapless.total_samples = std::max(params.total_samples, m_gapless.total_samples);
const auto max = std::max(params.total_samples, m_gapless.init.total_samples);
m_gapless.current.total_samples += max - m_gapless.init.total_samples;
m_gapless.init.total_samples = max;
}
if (params.skip_samples != 0) {
m_gapless.skip_samples = std::max(params.skip_samples, m_gapless.skip_samples);
const auto max = std::max(params.skip_samples, m_gapless.init.skip_samples);
m_gapless.current.skip_samples += max - m_gapless.init.skip_samples;
m_gapless.init.skip_samples = max;
}
}
@ -93,22 +96,29 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
u32 frames_decoded = 0;
auto in_size = in_buf.size();
auto out_size = out_buf.Size();
while (!in_buf.empty() && !out_buf.IsEmpty() && !IsGaplessEnd()) {
while (!in_buf.empty() && !out_buf.IsEmpty() && !m_gapless.IsEnd()) {
if (!HasEnoughSpace(out_buf)) {
if (job.output.p_mframe == nullptr || frames_decoded == 0) {
job.output.p_result->result = ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM;
break;
}
}
const auto [nframes, nsamples] =
m_codec->ProcessData(in_buf, out_buf, m_gapless, GetNumRemainingSamples());
const auto [nframes, nsamples] = m_codec->ProcessData(in_buf, out_buf, m_gapless);
frames_decoded += nframes;
m_total_samples += nsamples;
m_gapless_samples += nsamples;
if (job.output.p_mframe == nullptr) {
if (False(job.flags.run_flags & AjmJobRunFlags::MultipleFrames)) {
break;
}
}
if (m_gapless.IsEnd()) {
in_buf = in_buf.subspan(in_buf.size());
m_gapless.current.total_samples = m_gapless.init.total_samples;
m_gapless.current.skip_samples = m_gapless.init.skip_samples;
m_codec->Reset();
}
if (job.output.p_mframe) {
job.output.p_mframe->num_frames = frames_decoded;
}
@ -119,38 +129,19 @@ void AjmInstance::ExecuteJob(AjmJob& job) {
}
}
if (m_flags.gapless_loop && m_gapless.total_samples != 0 &&
m_gapless_samples >= m_gapless.total_samples) {
m_gapless_samples = 0;
m_gapless.skipped_samples = 0;
m_codec->Reset();
}
if (job.output.p_format != nullptr) {
*job.output.p_format = m_codec->GetFormat();
}
if (job.output.p_gapless_decode != nullptr) {
*job.output.p_gapless_decode = m_gapless;
*job.output.p_gapless_decode = m_gapless.current;
}
if (job.output.p_codec_info != nullptr) {
m_codec->GetInfo(job.output.p_codec_info);
}
}
bool AjmInstance::IsGaplessEnd() const {
return m_gapless.total_samples != 0 && m_gapless_samples >= m_gapless.total_samples;
}
bool AjmInstance::HasEnoughSpace(const SparseOutputBuffer& output) const {
const auto skip =
m_gapless.skip_samples - std::min(m_gapless.skip_samples, m_gapless.skipped_samples);
const auto remain = GetNumRemainingSamples().value_or(std::numeric_limits<u32>::max());
return output.Size() >= m_codec->GetNextFrameSize(skip, remain);
}
std::optional<u32> AjmInstance::GetNumRemainingSamples() const {
return m_gapless.total_samples != 0
? std::optional<u32>{m_gapless.total_samples - m_gapless_samples}
: std::optional<u32>{};
return output.Size() >= m_codec->GetNextFrameSize(m_gapless);
}
} // namespace Libraries::Ajm

View File

@ -58,6 +58,15 @@ private:
std::span<std::span<u8>>::iterator m_current;
};
// Gapless decoding state kept per instance: the configured parameters and a
// working copy of them.
struct AjmInstanceGapless {
    // Parameters as configured (a zero total_samples means no total is set).
    AjmSidebandGaplessDecode init{};
    // Working copy; this is the part tested by IsEnd().
    AjmSidebandGaplessDecode current{};

    // Returns true only when a total sample count was configured and the working
    // counter has reached zero.
    bool IsEnd() const {
        if (init.total_samples == 0) {
            return false;
        }
        return current.total_samples == 0;
    }
};
class AjmCodec {
public:
virtual ~AjmCodec() = default;
@ -66,10 +75,9 @@ public:
virtual void Reset() = 0;
virtual void GetInfo(void* out_info) const = 0;
virtual AjmSidebandFormat GetFormat() const = 0;
virtual u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const = 0;
virtual u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const = 0;
virtual std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) = 0;
AjmInstanceGapless& gapless) = 0;
};
class AjmInstance {
@ -79,18 +87,14 @@ public:
void ExecuteJob(AjmJob& job);
private:
bool IsGaplessEnd() const;
bool HasEnoughSpace(const SparseOutputBuffer& output) const;
std::optional<u32> GetNumRemainingSamples() const;
AjmInstanceFlags m_flags{};
AjmSidebandFormat m_format{};
AjmSidebandGaplessDecode m_gapless{};
AjmInstanceGapless m_gapless{};
AjmSidebandResampleParameters m_resample_parameters{};
u32 m_gapless_samples{};
u32 m_total_samples{};
std::unique_ptr<AjmCodec> m_codec;
};

View File

@ -137,8 +137,7 @@ void AjmMp3Decoder::GetInfo(void* out_info) const {
}
std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) {
AjmInstanceGapless& gapless) {
AVPacket* pkt = av_packet_alloc();
if ((!m_header.has_value() || m_frame_samples == 0) && in_buf.size() >= 4) {
@ -154,12 +153,7 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
in_buf = in_buf.subspan(ret);
u32 frames_decoded = 0;
u32 samples_decoded = 0;
auto max_samples =
max_samples_per_channel.has_value()
? max_samples_per_channel.value() * m_codec_context->ch_layout.nb_channels
: std::numeric_limits<u32>::max();
u32 samples_written = 0;
if (pkt->size) {
// Send the packet with the compressed data to the decoder
@ -182,32 +176,37 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
frame = ConvertAudioFrame(frame);
frames_decoded += 1;
u32 skipped_samples = 0;
if (gapless.skipped_samples < gapless.skip_samples) {
skipped_samples = std::min(u32(frame->nb_samples),
u32(gapless.skip_samples - gapless.skipped_samples));
gapless.skipped_samples += skipped_samples;
u32 skip_samples = 0;
if (gapless.current.skip_samples > 0) {
skip_samples = std::min(u16(frame->nb_samples), gapless.current.skip_samples);
gapless.current.skip_samples -= skip_samples;
}
const auto max_pcm =
gapless.init.total_samples != 0
? gapless.current.total_samples * m_codec_context->ch_layout.nb_channels
: std::numeric_limits<u32>::max();
u32 pcm_written = 0;
switch (m_format) {
case AjmFormatEncoding::S16:
samples_decoded +=
WriteOutputSamples<s16>(frame, output, skipped_samples, max_samples);
pcm_written = WriteOutputPCM<s16>(frame, output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::S32:
samples_decoded +=
WriteOutputSamples<s32>(frame, output, skipped_samples, max_samples);
pcm_written = WriteOutputPCM<s32>(frame, output, skip_samples, max_pcm);
break;
case AjmFormatEncoding::Float:
samples_decoded +=
WriteOutputSamples<float>(frame, output, skipped_samples, max_samples);
pcm_written = WriteOutputPCM<float>(frame, output, skip_samples, max_pcm);
break;
default:
UNREACHABLE();
}
if (max_samples_per_channel.has_value()) {
max_samples -= samples_decoded;
const auto samples = pcm_written / m_codec_context->ch_layout.nb_channels;
samples_written += samples;
gapless.current.skipped_samples += frame->nb_samples - samples;
if (gapless.init.total_samples != 0) {
gapless.current.total_samples -= samples;
}
av_frame_free(&frame);
@ -216,13 +215,16 @@ std::tuple<u32, u32> AjmMp3Decoder::ProcessData(std::span<u8>& in_buf, SparseOut
av_packet_free(&pkt);
return {frames_decoded, samples_decoded / m_codec_context->ch_layout.nb_channels};
return {frames_decoded, samples_written};
}
u32 AjmMp3Decoder::GetNextFrameSize(u32 skip_samples, u32 max_samples) const {
skip_samples = std::min({skip_samples, m_frame_samples, max_samples});
return (std::min(m_frame_samples, max_samples) - skip_samples) *
m_codec_context->ch_layout.nb_channels * GetPCMSize(m_format);
// Computes the size the next decoded frame needs in the output buffer, taking the
// current gapless state into account. The result is
// (samples per channel to emit) * channel count * GetPCMSize(m_format).
u32 AjmMp3Decoder::GetNextFrameSize(const AjmInstanceGapless& gapless) const {
    // Start from a full frame and cap it by the remaining total when one is configured.
    u32 frame_samples = m_frame_samples;
    if (gapless.init.total_samples != 0) {
        frame_samples = std::min(gapless.current.total_samples, m_frame_samples);
    }

    // Pending skip samples are not written out; at most the capped frame length
    // can be dropped from this frame.
    const u32 to_skip = std::min(u32(gapless.current.skip_samples), frame_samples);
    const u32 samples_out = frame_samples - to_skip;

    return samples_out * m_codec_context->ch_layout.nb_channels * GetPCMSize(m_format);
}
class BitReader {

View File

@ -70,22 +70,21 @@ public:
void Initialize(const void* buffer, u32 buffer_size) override {}
void GetInfo(void* out_info) const override;
AjmSidebandFormat GetFormat() const override;
u32 GetNextFrameSize(u32 skip_samples, u32 max_samples) const override;
u32 GetNextFrameSize(const AjmInstanceGapless& gapless) const override;
std::tuple<u32, u32> ProcessData(std::span<u8>& input, SparseOutputBuffer& output,
AjmSidebandGaplessDecode& gapless,
std::optional<u32> max_samples_per_channel) override;
AjmInstanceGapless& gapless) override;
static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame);
private:
template <class T>
size_t WriteOutputSamples(AVFrame* frame, SparseOutputBuffer& output, u32 skipped_samples,
u32 max_samples) {
size_t WriteOutputPCM(AVFrame* frame, SparseOutputBuffer& output, u32 skipped_samples,
u32 max_pcm) {
std::span<T> pcm_data(reinterpret_cast<T*>(frame->data[0]),
frame->nb_samples * frame->ch_layout.nb_channels);
pcm_data = pcm_data.subspan(skipped_samples * frame->ch_layout.nb_channels);
return output.Write(pcm_data.subspan(0, std::min(u32(pcm_data.size()), max_samples)));
return output.Write(pcm_data.subspan(0, std::min(u32(pcm_data.size()), max_pcm)));
}
AVFrame* ConvertAudioFrame(AVFrame* frame);