diff --git a/src/core/libraries/ajm/ajm.cpp b/src/core/libraries/ajm/ajm.cpp index 19bc6511f..c31a3c4c0 100644 --- a/src/core/libraries/ajm/ajm.cpp +++ b/src/core/libraries/ajm/ajm.cpp @@ -165,24 +165,8 @@ struct AjmDevice { p_instance->gapless.skip_samples = params.skip_samples; } - ASSERT_MSG(job.input.buffers.size() <= job.output.buffers.size(), - "Unsupported combination of input/output buffers."); - - for (size_t i = 0; i < job.input.buffers.size(); ++i) { - // Decode as much of the input bitstream as possible. - const auto& in_buffer = job.input.buffers[i]; - auto& out_buffer = job.output.buffers[i]; - - const u8* in_address = in_buffer.data(); - u8* out_address = out_buffer.data(); - const auto [in_remain, out_remain] = p_instance->Decode( - in_address, in_buffer.size(), out_address, out_buffer.size(), &job.output); - - if (job.output.p_stream != nullptr) { - job.output.p_stream->input_consumed += in_buffer.size() - in_remain; - job.output.p_stream->output_written += out_buffer.size() - out_remain; - job.output.p_stream->total_decoded_samples += p_instance->decoded_samples; - } + if (!job.input.buffer.empty()) { + p_instance->Decode(&job.input, &job.output); } if (job.output.p_gapless_decode != nullptr) { @@ -439,7 +423,6 @@ int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* p_batch, u32 batch_size const auto batch_info = std::make_shared(); auto batch_id = dev->batches.Create(batch_info); if (!batch_id.has_value()) { - LOG_ERROR(Lib_Ajm, "Too many batches in job!"); return ORBIS_AJM_ERROR_OUT_OF_MEMORY; } batch_info->id = batch_id.value(); @@ -471,7 +454,7 @@ int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* p_batch, u32 batch_size case Identifier::AjmIdentInputRunBuf: { auto& buffer = AjmBufferExtract(p_current); u8* p_begin = reinterpret_cast(buffer.p_address); - job.input.buffers.emplace_back( + job.input.buffer.append_range( std::vector(p_begin, p_begin + buffer.header.size)); break; } @@ -614,7 +597,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3 std::lock_guard guard(dev->batches_mutex); const auto opt_batch = dev->batches.Get(batch_id); if (!opt_batch.has_value()) { - LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_INVALID_BATCH"); return ORBIS_AJM_ERROR_INVALID_BATCH; } @@ -623,7 +605,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3 bool expected = false; if (!batch->waiting.compare_exchange_strong(expected, true)) { - LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_BUSY"); return ORBIS_AJM_ERROR_BUSY; } @@ -631,7 +612,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3 batch->finished.acquire(); } else if (!batch->finished.try_acquire_for(std::chrono::milliseconds(timeout))) { batch->waiting = false; - LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_IN_PROGRESS"); return ORBIS_AJM_ERROR_IN_PROGRESS; } @@ -641,11 +621,9 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3 } if (batch->canceled) { - LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_CANCELLED"); return ORBIS_AJM_ERROR_CANCELLED; } - LOG_INFO(Lib_Ajm, "ORBIS_OK"); return ORBIS_OK; } @@ -656,7 +634,7 @@ int PS4_SYSV_ABI sceAjmDecAt9ParseConfigData() { int PS4_SYSV_ABI sceAjmDecMp3ParseFrame(const u8* buf, u32 stream_size, int parse_ofl, AjmDecMp3ParseFrame* frame) { - LOG_INFO(Lib_Ajm, "called parse_ofl = {}", parse_ofl); + LOG_INFO(Lib_Ajm, "called stream_size = {} parse_ofl = {}", stream_size, parse_ofl); if (buf == nullptr || stream_size < 4 || frame == nullptr) { return ORBIS_AJM_ERROR_INVALID_PARAMETER; } @@ -688,6 +666,9 @@ int PS4_SYSV_ABI sceAjmInstanceCodecType() { int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmInstanceFlags flags, u32* out_instance) { + LOG_INFO(Lib_Ajm, "called context = {}, codec_type = {}, flags = {:#x}", context, + magic_enum::enum_name(codec_type), flags.raw); + if (codec_type >= AjmCodecType::Max) { return ORBIS_AJM_ERROR_INVALID_PARAMETER; } @@ -720,8 +701,8 @@ int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmI instance->flags = flags; dev->instances[index] = std::move(instance); *out_instance = index; - LOG_INFO(Lib_Ajm, "called codec_type = {}, flags = {:#x}, instance = {}", - magic_enum::enum_name(codec_type), flags.raw, index); + + LOG_INFO(Lib_Ajm, "instance = {}", index); return ORBIS_OK; } diff --git a/src/core/libraries/ajm/ajm_at9.cpp b/src/core/libraries/ajm/ajm_at9.cpp index 7e13c31e7..59e45bf57 100644 --- a/src/core/libraries/ajm/ajm_at9.cpp +++ b/src/core/libraries/ajm/ajm_at9.cpp @@ -26,17 +26,23 @@ AjmAt9Decoder::~AjmAt9Decoder() { } void AjmAt9Decoder::Reset() { - num_frames = 0; - decoded_samples = 0; + total_decoded_samples = 0; gapless = {}; + ResetCodec(); +} + +void AjmAt9Decoder::ResetCodec() { Atrac9ReleaseHandle(handle); handle = Atrac9GetHandle(); Atrac9InitDecoder(handle, config_data); Atrac9CodecInfo codec_info; Atrac9GetCodecInfo(handle, &codec_info); - bytes_remain = codec_info.superframeSize; + num_frames = 0; + superframe_bytes_remain = codec_info.superframeSize; + gapless.skipped_samples = 0; + gapless_decoded_samples = 0; } void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) { @@ -58,72 +64,106 @@ void AjmAt9Decoder::GetCodecInfo(void* out_info) { codec_info->uiSuperFrameSize = decoder_codec_info.superframeSize; } -std::tuple AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size_in, u8* out_buf, - u32 out_size_in, AjmJobOutput* output) { - const auto decoder_handle = static_cast(handle); +void AjmAt9Decoder::Decode(const AjmJobInput* input, AjmJobOutput* output) { + LOG_TRACE(Lib_Ajm, "Decoding with instance {} in size = {}", index, input->buffer.size()); Atrac9CodecInfo codec_info; Atrac9GetCodecInfo(handle, &codec_info); - int bytes_used = 0; - int num_superframes = 0; - - u32 in_size = in_size_in; - u32 out_size = out_size_in; - - const auto ShouldDecode = [&] { - if (in_size == 0 || out_size == 0) { + size_t out_buffer_index = 0; + std::span in_buf(input->buffer); + std::span out_buf = output->buffers[out_buffer_index]; + const auto should_decode = [&] { + if (in_buf.empty() || out_buf.empty()) { return false; } - if (gapless.total_samples != 0 && gapless.total_samples < decoded_samples) { + if (gapless.total_samples != 0 && gapless.total_samples < gapless_decoded_samples) { return false; } return true; }; - const auto written_size = codec_info.channels * codec_info.frameSamples * sizeof(u16); - std::vector pcm_buffer(written_size >> 1); - while (ShouldDecode()) { - u32 ret = Atrac9Decode(decoder_handle, in_buf, pcm_buffer.data(), &bytes_used); - ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret); - in_buf += bytes_used; - in_size -= bytes_used; - if (output->p_mframe) { - ++output->p_mframe->num_frames; + const auto write_output = [&](std::span pcm) { + while (!pcm.empty()) { + auto size = std::min(pcm.size() * sizeof(u16), out_buf.size()); + std::memcpy(out_buf.data(), pcm.data(), size); + pcm = pcm.subspan(size >> 1); + out_buf = out_buf.subspan(size); + if (out_buf.empty()) { + out_buffer_index += 1; + if (out_buffer_index >= output->buffers.size()) { + return pcm.empty(); + } + out_buf = output->buffers[out_buffer_index]; + } } - num_frames++; - bytes_remain -= bytes_used; + return true; + }; + + int num_superframes = 0; + const auto pcm_frame_size = codec_info.channels * codec_info.frameSamples * sizeof(u16); + std::vector pcm_buffer(pcm_frame_size >> 1); + while (should_decode()) { + int bytes_used = 0; + u32 ret = Atrac9Decode(handle, in_buf.data(), pcm_buffer.data(), &bytes_used); + ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret); + in_buf = in_buf.subspan(bytes_used); + superframe_bytes_remain -= bytes_used; + const size_t samples_remain = gapless.total_samples != 0 + ? gapless.total_samples - gapless_decoded_samples + : std::numeric_limits::max(); + bool written = false; if (gapless.skipped_samples < gapless.skip_samples) { - gapless.skipped_samples += decoder_handle->Config.FrameSamples; + gapless.skipped_samples += codec_info.frameSamples; if (gapless.skipped_samples > gapless.skip_samples) { - const auto size = gapless.skipped_samples - gapless.skip_samples; - const auto start = decoder_handle->Config.FrameSamples - size; - memcpy(out_buf, pcm_buffer.data() + start, size * sizeof(s16)); - out_buf += size * sizeof(s16); - out_size -= size * sizeof(s16); + const u32 nsamples = gapless.skipped_samples - gapless.skip_samples; + const auto start = codec_info.frameSamples - nsamples; + written = write_output({pcm_buffer.data() + start, nsamples}); + gapless.skipped_samples = gapless.skip_samples; + total_decoded_samples += nsamples; + gapless_decoded_samples += nsamples; } } else { - memcpy(out_buf, pcm_buffer.data(), written_size); - out_buf += written_size; - out_size -= written_size; + written = + write_output({pcm_buffer.data(), std::min(pcm_buffer.size(), samples_remain)}); + total_decoded_samples += codec_info.frameSamples; + gapless_decoded_samples += codec_info.frameSamples; } - decoded_samples += decoder_handle->Config.FrameSamples; + + num_frames += 1; if ((num_frames % codec_info.framesInSuperframe) == 0) { - in_buf += bytes_remain; - in_size -= bytes_remain; - bytes_remain = codec_info.superframeSize; - num_superframes++; + if (superframe_bytes_remain) { + if (output->p_stream) { + output->p_stream->input_consumed += superframe_bytes_remain; + } + in_buf = in_buf.subspan(superframe_bytes_remain); + } + superframe_bytes_remain = codec_info.superframeSize; + num_superframes += 1; + } + if (output->p_stream) { + output->p_stream->input_consumed += bytes_used; + if (written) { + output->p_stream->output_written += + std::min(pcm_frame_size, samples_remain * sizeof(16)); + } + } + if (output->p_mframe) { + output->p_mframe->num_frames += 1; } } - if (gapless.total_samples == decoded_samples) { - decoded_samples = 0; + if (gapless_decoded_samples >= gapless.total_samples) { if (flags.gapless_loop) { - gapless.skipped_samples = 0; + ResetCodec(); } } - LOG_TRACE(Lib_Ajm, "Decoded {} samples, frame count: {}", decoded_samples, num_frames); - return std::tuple(in_size, out_size); + if (output->p_stream) { + output->p_stream->total_decoded_samples = total_decoded_samples; + } + + LOG_TRACE(Lib_Ajm, "Decoded buffer, in remain = {}, out remain = {}", in_buf.size(), + out_buf.size()); } } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_at9.h b/src/core/libraries/ajm/ajm_at9.h index d4e268eea..919314f26 100644 --- a/src/core/libraries/ajm/ajm_at9.h +++ b/src/core/libraries/ajm/ajm_at9.h @@ -30,6 +30,8 @@ struct AjmAt9Decoder final : AjmInstance { std::fstream file; int length; u8 config_data[ORBIS_AT9_CONFIG_DATA_SIZE]; + u32 superframe_bytes_remain{}; + u32 num_frames{}; explicit AjmAt9Decoder(); ~AjmAt9Decoder() override; @@ -43,8 +45,10 @@ struct AjmAt9Decoder final : AjmInstance { return sizeof(AjmSidebandDecAt9CodecInfo); } - std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, - AjmJobOutput* output) override; + void Decode(const AjmJobInput* input, AjmJobOutput* output) override; + +private: + void ResetCodec(); }; } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_instance.h b/src/core/libraries/ajm/ajm_instance.h index a14eb5ac7..126ce1ee4 100644 --- a/src/core/libraries/ajm/ajm_instance.h +++ b/src/core/libraries/ajm/ajm_instance.h @@ -101,7 +101,7 @@ struct AjmJobInput { std::optional resample_parameters; std::optional format; std::optional gapless_decode; - boost::container::small_vector, 4> buffers; + std::vector buffer; }; struct AjmJobOutput { @@ -132,9 +132,8 @@ struct AjmInstance { AjmInstanceFlags flags{.raw = 0}; u32 num_channels{}; u32 index{}; - u32 bytes_remain{}; - u32 num_frames{}; - u32 decoded_samples{}; + u32 gapless_decoded_samples{}; + u32 total_decoded_samples{}; AjmSidebandFormat format{}; AjmSidebandGaplessDecode gapless{}; AjmSidebandResampleParameters resample_parameters{}; @@ -149,8 +148,7 @@ struct AjmInstance { virtual void GetCodecInfo(void* out_info) = 0; virtual u32 GetCodecInfoSize() = 0; - virtual std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, - AjmJobOutput* output) = 0; + virtual void Decode(const AjmJobInput* input, AjmJobOutput* output) = 0; }; } // namespace Libraries::Ajm diff --git a/src/core/libraries/ajm/ajm_mp3.cpp b/src/core/libraries/ajm/ajm_mp3.cpp index acdf3d3f0..0a43883b5 100644 --- a/src/core/libraries/ajm/ajm_mp3.cpp +++ b/src/core/libraries/ajm/ajm_mp3.cpp @@ -73,20 +73,51 @@ void AjmMp3Decoder::Reset() { ASSERT_MSG(c, "Could not allocate audio codec context"); int ret = avcodec_open2(c, codec, nullptr); ASSERT_MSG(ret >= 0, "Could not open codec"); - decoded_samples = 0; - num_frames = 0; + total_decoded_samples = 0; } -std::tuple AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_buf, u32 out_size, - AjmJobOutput* output) { +void AjmMp3Decoder::Decode(const AjmJobInput* input, AjmJobOutput* output) { AVPacket* pkt = av_packet_alloc(); - while (in_size > 0 && out_size > 0) { - int ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, buf, in_size, AV_NOPTS_VALUE, - AV_NOPTS_VALUE, 0); - ASSERT_MSG(ret >= 0, "Error while parsing {}", ret); - buf += ret; - in_size -= ret; + size_t out_buffer_index = 0; + std::span in_buf(input->buffer); + std::span out_buf = output->buffers[out_buffer_index]; + const auto should_decode = [&] { + if (in_buf.empty() || out_buf.empty()) { + return false; + } + if (gapless.total_samples != 0 && gapless.total_samples < gapless_decoded_samples) { + return false; + } + return true; + }; + + const auto write_output = [&](std::span pcm) { + while (!pcm.empty()) { + auto size = std::min(pcm.size() * sizeof(u16), out_buf.size()); + std::memcpy(out_buf.data(), pcm.data(), size); + pcm = pcm.subspan(size >> 1); + out_buf = out_buf.subspan(size); + if (out_buf.empty()) { + out_buffer_index += 1; + if (out_buffer_index >= output->buffers.size()) { + return pcm.empty(); + } + out_buf = output->buffers[out_buffer_index]; + } + } + return true; + }; + + while (should_decode()) { + int ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, in_buf.data(), in_buf.size(), + AV_NOPTS_VALUE, AV_NOPTS_VALUE, 0); + ASSERT_MSG(ret >= 0, "Error while parsing {}", ret); + in_buf = in_buf.subspan(ret); + + if (output->p_stream) { + output->p_stream->input_consumed += ret; + } if (pkt->size) { // Send the packet with the compressed data to the decoder pkt->pts = parser->pts; @@ -107,22 +138,43 @@ std::tuple AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_b if (frame->format != AV_SAMPLE_FMT_S16) { frame = ConvertAudioFrame(frame); } - const auto size = frame->ch_layout.nb_channels * frame->nb_samples * sizeof(u16); - std::memcpy(out_buf, frame->data[0], size); - file.write((const char*)frame->data[0], size); - out_buf += size; - out_size -= size; - decoded_samples += frame->nb_samples; - num_frames++; + const auto frame_samples = frame->ch_layout.nb_channels * frame->nb_samples; + const auto size = frame_samples * sizeof(u16); + if (gapless.skipped_samples < gapless.skip_samples) { + gapless.skipped_samples += frame_samples; + if (gapless.skipped_samples > gapless.skip_samples) { + const u32 nsamples = gapless.skipped_samples - gapless.skip_samples; + const auto start = frame_samples - nsamples; + write_output({reinterpret_cast(frame->data[0]), nsamples}); + gapless.skipped_samples = gapless.skip_samples; + total_decoded_samples += nsamples; + gapless_decoded_samples += nsamples; + } + } else { + write_output({reinterpret_cast(frame->data[0]), size >> 1}); + total_decoded_samples += frame_samples; + gapless_decoded_samples += frame_samples; + } av_frame_free(&frame); + if (output->p_stream) { + output->p_stream->output_written += size; + } + if (output->p_mframe) { + output->p_mframe->num_frames += 1; + } } } } av_packet_free(&pkt); - if (output->p_mframe) { - output->p_mframe->num_frames += num_frames; + if (gapless_decoded_samples >= gapless.total_samples) { + if (flags.gapless_loop) { + gapless.skipped_samples = 0; + gapless_decoded_samples = 0; + } + } + if (output->p_stream) { + output->p_stream->total_decoded_samples = total_decoded_samples; } - return std::make_tuple(in_size, out_size); } int AjmMp3Decoder::ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, diff --git a/src/core/libraries/ajm/ajm_mp3.h b/src/core/libraries/ajm/ajm_mp3.h index 69ba25e19..8a93ca78c 100644 --- a/src/core/libraries/ajm/ajm_mp3.h +++ b/src/core/libraries/ajm/ajm_mp3.h @@ -74,8 +74,7 @@ struct AjmMp3Decoder : public AjmInstance { return sizeof(AjmSidebandDecMp3CodecInfo); } - std::tuple Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, - AjmJobOutput* output) override; + void Decode(const AjmJobInput* input, AjmJobOutput* output) override; static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, AjmDecMp3ParseFrame* frame);