diff --git a/src/core/libraries/ajm/ajm.h b/src/core/libraries/ajm/ajm.h index 10e23482e..4935b09b3 100644 --- a/src/core/libraries/ajm/ajm.h +++ b/src/core/libraries/ajm/ajm.h @@ -137,6 +137,7 @@ union AjmInstanceFlags { u64 codec : 28; }; }; +static_assert(sizeof(AjmInstanceFlags) == 8); struct AjmDecMp3ParseFrame; diff --git a/src/core/libraries/ajm/ajm_at9.cpp b/src/core/libraries/ajm/ajm_at9.cpp index 07647c521..91a2c86f0 100644 --- a/src/core/libraries/ajm/ajm_at9.cpp +++ b/src/core/libraries/ajm/ajm_at9.cpp @@ -14,8 +14,8 @@ extern "C" { namespace Libraries::Ajm { -AjmAt9Decoder::AjmAt9Decoder() { - m_handle = Atrac9GetHandle(); +AjmAt9Decoder::AjmAt9Decoder(AjmFormatEncoding format) + : m_format(format), m_handle(Atrac9GetHandle()) { ASSERT_MSG(m_handle, "Atrac9GetHandle failed"); AjmAt9Decoder::Reset(); } @@ -40,7 +40,20 @@ void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) { const auto params = reinterpret_cast(buffer); std::memcpy(m_config_data, params->config_data, ORBIS_AT9_CONFIG_DATA_SIZE); AjmAt9Decoder::Reset(); - m_pcm_buffer.resize(m_codec_info.frameSamples * m_codec_info.channels, 0); + m_pcm_buffer.resize(m_codec_info.frameSamples * m_codec_info.channels * GetPointCodeSize(), 0); +} + +u8 AjmAt9Decoder::GetPointCodeSize() { + switch (m_format) { + case AjmFormatEncoding::S16: + return sizeof(s16); + case AjmFormatEncoding::S32: + return sizeof(s32); + case AjmFormatEncoding::Float: + return sizeof(float); + default: + UNREACHABLE(); + } } void AjmAt9Decoder::GetInfo(void* out_info) { @@ -53,28 +66,55 @@ void AjmAt9Decoder::GetInfo(void* out_info) { std::tuple AjmAt9Decoder::ProcessData(std::span& in_buf, SparseOutputBuffer& output, AjmSidebandGaplessDecode& gapless, - u32 max_samples_per_channel) { + std::optional max_samples_per_channel) { + int ret = 0; int bytes_used = 0; - u32 ret = Atrac9Decode(m_handle, in_buf.data(), m_pcm_buffer.data(), &bytes_used); + switch (m_format) { + case AjmFormatEncoding::S16: + ret = Atrac9Decode(m_handle, in_buf.data(), reinterpret_cast(m_pcm_buffer.data()), + &bytes_used); + break; + case AjmFormatEncoding::S32: + ret = Atrac9DecodeS32(m_handle, in_buf.data(), reinterpret_cast(m_pcm_buffer.data()), + &bytes_used); + break; + case AjmFormatEncoding::Float: + ret = Atrac9DecodeF32(m_handle, in_buf.data(), + reinterpret_cast(m_pcm_buffer.data()), &bytes_used); + break; + default: + UNREACHABLE(); + } ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret); in_buf = in_buf.subspan(bytes_used); m_superframe_bytes_remain -= bytes_used; - std::span pcm_data{m_pcm_buffer}; + u32 skipped_samples = 0; if (gapless.skipped_samples < gapless.skip_samples) { - const auto skipped_samples = std::min(u32(m_codec_info.frameSamples), - u32(gapless.skip_samples - gapless.skipped_samples)); + skipped_samples = std::min(u32(m_codec_info.frameSamples), + u32(gapless.skip_samples - gapless.skipped_samples)); gapless.skipped_samples += skipped_samples; - pcm_data = pcm_data.subspan(skipped_samples * m_codec_info.channels); } - const auto max_samples = max_samples_per_channel == std::numeric_limits::max() - ? max_samples_per_channel - : max_samples_per_channel * m_codec_info.channels; + const auto max_samples = max_samples_per_channel.has_value() + ? max_samples_per_channel.value() * m_codec_info.channels + : std::numeric_limits::max(); - const auto pcm_size = std::min(u32(pcm_data.size()), max_samples); - const auto written = output.Write(pcm_data.subspan(0, pcm_size)); + size_t samples_written = 0; + switch (m_format) { + case AjmFormatEncoding::S16: + samples_written = WriteOutputSamples(output, skipped_samples, max_samples); + break; + case AjmFormatEncoding::S32: + samples_written = WriteOutputSamples(output, skipped_samples, max_samples); + break; + case AjmFormatEncoding::Float: + samples_written = WriteOutputSamples(output, skipped_samples, max_samples); + break; + default: + UNREACHABLE(); + } m_num_frames += 1; if ((m_num_frames % m_codec_info.framesInSuperframe) == 0) { @@ -85,7 +125,7 @@ std::tuple AjmAt9Decoder::ProcessData(std::span& in_buf, SparseOut m_num_frames = 0; } - return {1, (written / m_codec_info.channels) / sizeof(s16)}; + return {1, samples_written / m_codec_info.channels}; } } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_at9.h b/src/core/libraries/ajm/ajm_at9.h index e89e69cec..ac87b5e49 100644 --- a/src/core/libraries/ajm/ajm_at9.h +++ b/src/core/libraries/ajm/ajm_at9.h @@ -8,6 +8,8 @@ #include "libatrac9.h" +#include + namespace Libraries::Ajm { constexpr s32 ORBIS_AJM_DEC_AT9_MAX_CHANNELS = 8; @@ -20,22 +22,35 @@ struct AjmSidebandDecAt9CodecInfo { }; struct AjmAt9Decoder final : AjmCodec { - explicit AjmAt9Decoder(); + explicit AjmAt9Decoder(AjmFormatEncoding format); ~AjmAt9Decoder() override; void Reset() override; void Initialize(const void* buffer, u32 buffer_size) override; void GetInfo(void* out_info) override; std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, - AjmSidebandGaplessDecode& gapless, u32 max_samples) override; + AjmSidebandGaplessDecode& gapless, + std::optional max_samples) override; private: + u8 GetPointCodeSize(); + + template + size_t WriteOutputSamples(SparseOutputBuffer& output, u32 skipped_samples, u32 max_samples) { + std::span pcm_data{reinterpret_cast(m_pcm_buffer.data()), + m_pcm_buffer.size() / sizeof(T)}; + pcm_data = pcm_data.subspan(skipped_samples * m_codec_info.channels); + const auto pcm_size = std::min(u32(pcm_data.size()), max_samples); + return output.Write(pcm_data.subspan(0, pcm_size)); + } + + const AjmFormatEncoding m_format; void* m_handle{}; u8 m_config_data[ORBIS_AT9_CONFIG_DATA_SIZE]{}; u32 m_superframe_bytes_remain{}; u32 m_num_frames{}; Atrac9CodecInfo m_codec_info{}; - std::vector m_pcm_buffer; + std::vector m_pcm_buffer; }; } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_context.cpp b/src/core/libraries/ajm/ajm_context.cpp index d831ab484..e30e1c478 100644 --- a/src/core/libraries/ajm/ajm_context.cpp +++ b/src/core/libraries/ajm/ajm_context.cpp @@ -149,7 +149,6 @@ s32 AjmContext::InstanceCreate(AjmCodecType codec_type, AjmInstanceFlags flags, if (!IsRegistered(codec_type)) { return ORBIS_AJM_ERROR_CODEC_NOT_REGISTERED; } - ASSERT_MSG(flags.format == 0, "Only signed 16-bit PCM output is supported currently!"); std::optional opt_index; { std::unique_lock lock(instances_mutex); diff --git a/src/core/libraries/ajm/ajm_instance.cpp b/src/core/libraries/ajm/ajm_instance.cpp index 259fbc268..3202b3a3d 100644 --- a/src/core/libraries/ajm/ajm_instance.cpp +++ b/src/core/libraries/ajm/ajm_instance.cpp @@ -9,14 +9,29 @@ namespace Libraries::Ajm { +constexpr int ORBIS_AJM_RESULT_NOT_INITIALIZED = 0x00000001; +constexpr int ORBIS_AJM_RESULT_INVALID_DATA = 0x00000002; +constexpr int ORBIS_AJM_RESULT_INVALID_PARAMETER = 0x00000004; +constexpr int ORBIS_AJM_RESULT_PARTIAL_INPUT = 0x00000008; +constexpr int ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM = 0x00000010; +constexpr int ORBIS_AJM_RESULT_STREAM_CHANGE = 0x00000020; +constexpr int ORBIS_AJM_RESULT_TOO_MANY_CHANNELS = 0x00000040; +constexpr int ORBIS_AJM_RESULT_UNSUPPORTED_FLAG = 0x00000080; +constexpr int ORBIS_AJM_RESULT_SIDEBAND_TRUNCATED = 0x00000100; +constexpr int ORBIS_AJM_RESULT_PRIORITY_PASSED = 0x00000200; +constexpr int ORBIS_AJM_RESULT_CODEC_ERROR = 0x40000000; +constexpr int ORBIS_AJM_RESULT_FATAL = 0x80000000; + AjmInstance::AjmInstance(AjmCodecType codec_type, AjmInstanceFlags flags) : m_flags(flags) { + LOG_CRITICAL(Lib_Ajm, "Creating instance with format {}", + magic_enum::enum_name(AjmFormatEncoding(flags.format))); switch (codec_type) { case AjmCodecType::At9Dec: { - m_codec = std::make_unique(); + m_codec = std::make_unique(AjmFormatEncoding(flags.format)); break; } case AjmCodecType::Mp3Dec: { - m_codec = std::make_unique(); + m_codec = std::make_unique(AjmFormatEncoding(flags.format)); break; } default: @@ -62,9 +77,10 @@ void AjmInstance::ExecuteJob(AjmJob& job) { auto in_size = in_buf.size(); auto out_size = out_buf.Size(); while (!in_buf.empty() && !out_buf.IsEmpty() && !IsGaplessEnd()) { - const u32 samples_remain = m_gapless.total_samples != 0 - ? m_gapless.total_samples - m_gapless_samples - : std::numeric_limits::max(); + const auto samples_remain = + m_gapless.total_samples != 0 + ? std::optional{m_gapless.total_samples - m_gapless_samples} + : std::optional{}; const auto [nframes, nsamples] = m_codec->ProcessData(in_buf, out_buf, m_gapless, samples_remain); frames_decoded += nframes; diff --git a/src/core/libraries/ajm/ajm_instance.h b/src/core/libraries/ajm/ajm_instance.h index 7c00c43ad..96a30ef47 100644 --- a/src/core/libraries/ajm/ajm_instance.h +++ b/src/core/libraries/ajm/ajm_instance.h @@ -14,19 +14,6 @@ namespace Libraries::Ajm { -constexpr int ORBIS_AJM_RESULT_NOT_INITIALIZED = 0x00000001; -constexpr int ORBIS_AJM_RESULT_INVALID_DATA = 0x00000002; -constexpr int ORBIS_AJM_RESULT_INVALID_PARAMETER = 0x00000004; -constexpr int ORBIS_AJM_RESULT_PARTIAL_INPUT = 0x00000008; -constexpr int ORBIS_AJM_RESULT_NOT_ENOUGH_ROOM = 0x00000010; -constexpr int ORBIS_AJM_RESULT_STREAM_CHANGE = 0x00000020; -constexpr int ORBIS_AJM_RESULT_TOO_MANY_CHANNELS = 0x00000040; -constexpr int ORBIS_AJM_RESULT_UNSUPPORTED_FLAG = 0x00000080; -constexpr int ORBIS_AJM_RESULT_SIDEBAND_TRUNCATED = 0x00000100; -constexpr int ORBIS_AJM_RESULT_PRIORITY_PASSED = 0x00000200; -constexpr int ORBIS_AJM_RESULT_CODEC_ERROR = 0x40000000; -constexpr int ORBIS_AJM_RESULT_FATAL = 0x80000000; - class SparseOutputBuffer { public: SparseOutputBuffer(std::span> chunks) @@ -34,18 +21,19 @@ public: template size_t Write(std::span pcm) { - size_t bytes_written = 0; + size_t samples_written = 0; while (!pcm.empty() && !IsEmpty()) { auto size = std::min(pcm.size() * sizeof(T), m_current->size()); std::memcpy(m_current->data(), pcm.data(), size); - bytes_written += size; - pcm = pcm.subspan(size / sizeof(T)); + const auto nsamples = size / sizeof(T); + samples_written += nsamples; + pcm = pcm.subspan(nsamples); *m_current = m_current->subspan(size); if (m_current->empty()) { ++m_current; } } - return bytes_written; + return samples_written; } bool IsEmpty() { @@ -65,11 +53,6 @@ private: std::span>::iterator m_current; }; -struct DecodeResult { - u32 bytes_consumed{}; - u32 bytes_written{}; -}; - class AjmCodec { public: virtual ~AjmCodec() = default; @@ -79,7 +62,7 @@ public: virtual void GetInfo(void* out_info) = 0; virtual std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, AjmSidebandGaplessDecode& gapless, - u32 max_samples) = 0; + std::optional max_samples_per_channel) = 0; }; class AjmInstance { @@ -100,9 +83,6 @@ private: u32 m_total_samples{}; std::unique_ptr m_codec; - - // AjmCodecType codec_type; - // u32 index{}; }; } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_mp3.cpp b/src/core/libraries/ajm/ajm_mp3.cpp index 542b31bee..d39b5bb88 100644 --- a/src/core/libraries/ajm/ajm_mp3.cpp +++ b/src/core/libraries/ajm/ajm_mp3.cpp @@ -30,9 +30,27 @@ static constexpr std::array UnkTable = {0x48, 0x90}; SwrContext* swr_context{}; -AVFrame* ConvertAudioFrame(AVFrame* frame) { +static AVSampleFormat AjmToAVSampleFormat(AjmFormatEncoding format) { + switch (format) { + case AjmFormatEncoding::S16: + return AV_SAMPLE_FMT_S16; + case AjmFormatEncoding::S32: + return AV_SAMPLE_FMT_S32; + case AjmFormatEncoding::Float: + return AV_SAMPLE_FMT_FLT; + default: + UNREACHABLE(); + } +} + +AVFrame* AjmMp3Decoder::ConvertAudioFrame(AVFrame* frame) { + AVSampleFormat format = AjmToAVSampleFormat(m_format); + if (frame->format == format) { + return frame; + } + auto pcm16_frame = av_frame_clone(frame); - pcm16_frame->format = AV_SAMPLE_FMT_S16; + pcm16_frame->format = format; if (swr_context) { swr_free(&swr_context); @@ -40,9 +58,9 @@ AVFrame* ConvertAudioFrame(AVFrame* frame) { } AVChannelLayout in_ch_layout = frame->ch_layout; AVChannelLayout out_ch_layout = pcm16_frame->ch_layout; - swr_alloc_set_opts2(&swr_context, &out_ch_layout, AV_SAMPLE_FMT_S16, frame->sample_rate, - &in_ch_layout, AVSampleFormat(frame->format), frame->sample_rate, 0, - nullptr); + swr_alloc_set_opts2(&swr_context, &out_ch_layout, AVSampleFormat(pcm16_frame->format), + frame->sample_rate, &in_ch_layout, AVSampleFormat(frame->format), + frame->sample_rate, 0, nullptr); swr_init(swr_context); const auto res = swr_convert_frame(swr_context, pcm16_frame, frame); if (res < 0) { @@ -53,11 +71,9 @@ AVFrame* ConvertAudioFrame(AVFrame* frame) { return pcm16_frame; } -AjmMp3Decoder::AjmMp3Decoder() { - m_codec = avcodec_find_decoder(AV_CODEC_ID_MP3); - ASSERT_MSG(m_codec, "MP3 m_codec not found"); - m_parser = av_parser_init(m_codec->id); - ASSERT_MSG(m_parser, "Parser not found"); +AjmMp3Decoder::AjmMp3Decoder(AjmFormatEncoding format) + : m_format(format), m_codec(avcodec_find_decoder(AV_CODEC_ID_MP3)), + m_parser(av_parser_init(m_codec->id)) { AjmMp3Decoder::Reset(); } @@ -81,7 +97,7 @@ void AjmMp3Decoder::GetInfo(void* out_info) { std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, SparseOutputBuffer& output, AjmSidebandGaplessDecode& gapless, - u32 max_samples) { + std::optional max_samples_per_channel) { AVPacket* pkt = av_packet_alloc(); int ret = av_parser_parse2(m_parser, m_codec_context, &pkt->data, &pkt->size, in_buf.data(), @@ -109,24 +125,37 @@ std::tuple AjmMp3Decoder::ProcessData(std::span& in_buf, SparseOut } else if (ret < 0) { UNREACHABLE_MSG("Error during decoding"); } - if (frame->format != AV_SAMPLE_FMT_S16) { - frame = ConvertAudioFrame(frame); - } + frame = ConvertAudioFrame(frame); frames_decoded += 1; - samples_decoded += frame->nb_samples; - const auto size = frame->ch_layout.nb_channels * frame->nb_samples * sizeof(u16); - std::span pcm_data(reinterpret_cast(frame->data[0]), size >> 1); + u32 skipped_samples = 0; if (gapless.skipped_samples < gapless.skip_samples) { - const auto skipped_samples = std::min( - u32(frame->nb_samples), u32(gapless.skip_samples - gapless.skipped_samples)); + skipped_samples = std::min(u32(frame->nb_samples), + u32(gapless.skip_samples - gapless.skipped_samples)); gapless.skipped_samples += skipped_samples; - pcm_data = pcm_data.subspan(skipped_samples * frame->ch_layout.nb_channels); - samples_decoded -= skipped_samples; } - const auto pcm_size = std::min(u32(pcm_data.size()), max_samples); - output.Write(pcm_data.subspan(0, pcm_size)); + const auto max_samples = + max_samples_per_channel.has_value() + ? max_samples_per_channel.value() * frame->ch_layout.nb_channels + : std::numeric_limits::max(); + + switch (m_format) { + case AjmFormatEncoding::S16: + samples_decoded += + WriteOutputSamples(frame, output, skipped_samples, max_samples); + break; + case AjmFormatEncoding::S32: + samples_decoded += + WriteOutputSamples(frame, output, skipped_samples, max_samples); + break; + case AjmFormatEncoding::Float: + samples_decoded += + WriteOutputSamples(frame, output, skipped_samples, max_samples); + break; + default: + UNREACHABLE(); + } av_frame_free(&frame); } diff --git a/src/core/libraries/ajm/ajm_mp3.h b/src/core/libraries/ajm/ajm_mp3.h index d2acf7abc..e6f4f7af6 100644 --- a/src/core/libraries/ajm/ajm_mp3.h +++ b/src/core/libraries/ajm/ajm_mp3.h @@ -7,11 +7,7 @@ #include "core/libraries/ajm/ajm_instance.h" extern "C" { -struct AVCodec; -struct AVCodecContext; -struct AVCodecParserContext; -struct AVFrame; -struct AVPacket; +#include } namespace Libraries::Ajm { @@ -54,19 +50,34 @@ struct AjmSidebandDecMp3CodecInfo { class AjmMp3Decoder : public AjmCodec { public: - explicit AjmMp3Decoder(); + explicit AjmMp3Decoder(AjmFormatEncoding format); ~AjmMp3Decoder() override; void Reset() override; void Initialize(const void* buffer, u32 buffer_size) override {} void GetInfo(void* out_info) override; std::tuple ProcessData(std::span& input, SparseOutputBuffer& output, - AjmSidebandGaplessDecode& gapless, u32 max_samples) override; + AjmSidebandGaplessDecode& gapless, + std::optional max_samples_per_channel) override; static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, AjmDecMp3ParseFrame* frame); private: + template + size_t WriteOutputSamples(AVFrame* frame, SparseOutputBuffer& output, u32 skipped_samples, + u32 max_samples) { + const auto size = frame->ch_layout.nb_channels * frame->nb_samples * sizeof(T); + std::span pcm_data(reinterpret_cast(frame->data[0]), size >> 1); + pcm_data = pcm_data.subspan(skipped_samples * frame->ch_layout.nb_channels); + const auto pcm_size = std::min(u32(pcm_data.size()), max_samples); + const auto samples_written = output.Write(pcm_data.subspan(0, pcm_size)); + return samples_written / frame->ch_layout.nb_channels; + } + + AVFrame* ConvertAudioFrame(AVFrame* frame); + + const AjmFormatEncoding m_format; const AVCodec* m_codec = nullptr; AVCodecContext* m_codec_context = nullptr; AVCodecParserContext* m_parser = nullptr;