Fix gapless decode and combine split buffers

This commit is contained in:
Vladislav Mikhalin 2024-10-29 20:37:04 +03:00
parent da5f7f232a
commit d7f78e6720
6 changed files with 177 additions and 103 deletions

View File

@ -165,24 +165,8 @@ struct AjmDevice {
p_instance->gapless.skip_samples = params.skip_samples; p_instance->gapless.skip_samples = params.skip_samples;
} }
ASSERT_MSG(job.input.buffers.size() <= job.output.buffers.size(), if (!job.input.buffer.empty()) {
"Unsupported combination of input/output buffers."); p_instance->Decode(&job.input, &job.output);
for (size_t i = 0; i < job.input.buffers.size(); ++i) {
// Decode as much of the input bitstream as possible.
const auto& in_buffer = job.input.buffers[i];
auto& out_buffer = job.output.buffers[i];
const u8* in_address = in_buffer.data();
u8* out_address = out_buffer.data();
const auto [in_remain, out_remain] = p_instance->Decode(
in_address, in_buffer.size(), out_address, out_buffer.size(), &job.output);
if (job.output.p_stream != nullptr) {
job.output.p_stream->input_consumed += in_buffer.size() - in_remain;
job.output.p_stream->output_written += out_buffer.size() - out_remain;
job.output.p_stream->total_decoded_samples += p_instance->decoded_samples;
}
} }
if (job.output.p_gapless_decode != nullptr) { if (job.output.p_gapless_decode != nullptr) {
@ -439,7 +423,6 @@ int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* p_batch, u32 batch_size
const auto batch_info = std::make_shared<BatchInfo>(); const auto batch_info = std::make_shared<BatchInfo>();
auto batch_id = dev->batches.Create(batch_info); auto batch_id = dev->batches.Create(batch_info);
if (!batch_id.has_value()) { if (!batch_id.has_value()) {
LOG_ERROR(Lib_Ajm, "Too many batches in job!");
return ORBIS_AJM_ERROR_OUT_OF_MEMORY; return ORBIS_AJM_ERROR_OUT_OF_MEMORY;
} }
batch_info->id = batch_id.value(); batch_info->id = batch_id.value();
@ -471,7 +454,7 @@ int PS4_SYSV_ABI sceAjmBatchStartBuffer(u32 context, u8* p_batch, u32 batch_size
case Identifier::AjmIdentInputRunBuf: { case Identifier::AjmIdentInputRunBuf: {
auto& buffer = AjmBufferExtract<AjmChunkBuffer>(p_current); auto& buffer = AjmBufferExtract<AjmChunkBuffer>(p_current);
u8* p_begin = reinterpret_cast<u8*>(buffer.p_address); u8* p_begin = reinterpret_cast<u8*>(buffer.p_address);
job.input.buffers.emplace_back( job.input.buffer.append_range(
std::vector<u8>(p_begin, p_begin + buffer.header.size)); std::vector<u8>(p_begin, p_begin + buffer.header.size));
break; break;
} }
@ -614,7 +597,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
std::lock_guard guard(dev->batches_mutex); std::lock_guard guard(dev->batches_mutex);
const auto opt_batch = dev->batches.Get(batch_id); const auto opt_batch = dev->batches.Get(batch_id);
if (!opt_batch.has_value()) { if (!opt_batch.has_value()) {
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_INVALID_BATCH");
return ORBIS_AJM_ERROR_INVALID_BATCH; return ORBIS_AJM_ERROR_INVALID_BATCH;
} }
@ -623,7 +605,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
bool expected = false; bool expected = false;
if (!batch->waiting.compare_exchange_strong(expected, true)) { if (!batch->waiting.compare_exchange_strong(expected, true)) {
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_BUSY");
return ORBIS_AJM_ERROR_BUSY; return ORBIS_AJM_ERROR_BUSY;
} }
@ -631,7 +612,6 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
batch->finished.acquire(); batch->finished.acquire();
} else if (!batch->finished.try_acquire_for(std::chrono::milliseconds(timeout))) { } else if (!batch->finished.try_acquire_for(std::chrono::milliseconds(timeout))) {
batch->waiting = false; batch->waiting = false;
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_IN_PROGRESS");
return ORBIS_AJM_ERROR_IN_PROGRESS; return ORBIS_AJM_ERROR_IN_PROGRESS;
} }
@ -641,11 +621,9 @@ int PS4_SYSV_ABI sceAjmBatchWait(const u32 context, const u32 batch_id, const u3
} }
if (batch->canceled) { if (batch->canceled) {
LOG_INFO(Lib_Ajm, "ORBIS_AJM_ERROR_CANCELLED");
return ORBIS_AJM_ERROR_CANCELLED; return ORBIS_AJM_ERROR_CANCELLED;
} }
LOG_INFO(Lib_Ajm, "ORBIS_OK");
return ORBIS_OK; return ORBIS_OK;
} }
@ -656,7 +634,7 @@ int PS4_SYSV_ABI sceAjmDecAt9ParseConfigData() {
int PS4_SYSV_ABI sceAjmDecMp3ParseFrame(const u8* buf, u32 stream_size, int parse_ofl, int PS4_SYSV_ABI sceAjmDecMp3ParseFrame(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame) { AjmDecMp3ParseFrame* frame) {
LOG_INFO(Lib_Ajm, "called parse_ofl = {}", parse_ofl); LOG_INFO(Lib_Ajm, "called stream_size = {} parse_ofl = {}", stream_size, parse_ofl);
if (buf == nullptr || stream_size < 4 || frame == nullptr) { if (buf == nullptr || stream_size < 4 || frame == nullptr) {
return ORBIS_AJM_ERROR_INVALID_PARAMETER; return ORBIS_AJM_ERROR_INVALID_PARAMETER;
} }
@ -688,6 +666,9 @@ int PS4_SYSV_ABI sceAjmInstanceCodecType() {
int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmInstanceFlags flags, int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmInstanceFlags flags,
u32* out_instance) { u32* out_instance) {
LOG_INFO(Lib_Ajm, "called context = {}, codec_type = {}, flags = {:#x}", context,
magic_enum::enum_name(codec_type), flags.raw);
if (codec_type >= AjmCodecType::Max) { if (codec_type >= AjmCodecType::Max) {
return ORBIS_AJM_ERROR_INVALID_PARAMETER; return ORBIS_AJM_ERROR_INVALID_PARAMETER;
} }
@ -720,8 +701,8 @@ int PS4_SYSV_ABI sceAjmInstanceCreate(u32 context, AjmCodecType codec_type, AjmI
instance->flags = flags; instance->flags = flags;
dev->instances[index] = std::move(instance); dev->instances[index] = std::move(instance);
*out_instance = index; *out_instance = index;
LOG_INFO(Lib_Ajm, "called codec_type = {}, flags = {:#x}, instance = {}",
magic_enum::enum_name(codec_type), flags.raw, index); LOG_INFO(Lib_Ajm, "instance = {}", index);
return ORBIS_OK; return ORBIS_OK;
} }

View File

@ -26,17 +26,23 @@ AjmAt9Decoder::~AjmAt9Decoder() {
} }
void AjmAt9Decoder::Reset() { void AjmAt9Decoder::Reset() {
num_frames = 0; total_decoded_samples = 0;
decoded_samples = 0;
gapless = {}; gapless = {};
ResetCodec();
}
void AjmAt9Decoder::ResetCodec() {
Atrac9ReleaseHandle(handle); Atrac9ReleaseHandle(handle);
handle = Atrac9GetHandle(); handle = Atrac9GetHandle();
Atrac9InitDecoder(handle, config_data); Atrac9InitDecoder(handle, config_data);
Atrac9CodecInfo codec_info; Atrac9CodecInfo codec_info;
Atrac9GetCodecInfo(handle, &codec_info); Atrac9GetCodecInfo(handle, &codec_info);
bytes_remain = codec_info.superframeSize; num_frames = 0;
superframe_bytes_remain = codec_info.superframeSize;
gapless.skipped_samples = 0;
gapless_decoded_samples = 0;
} }
void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) { void AjmAt9Decoder::Initialize(const void* buffer, u32 buffer_size) {
@ -58,72 +64,106 @@ void AjmAt9Decoder::GetCodecInfo(void* out_info) {
codec_info->uiSuperFrameSize = decoder_codec_info.superframeSize; codec_info->uiSuperFrameSize = decoder_codec_info.superframeSize;
} }
std::tuple<u32, u32> AjmAt9Decoder::Decode(const u8* in_buf, u32 in_size_in, u8* out_buf, void AjmAt9Decoder::Decode(const AjmJobInput* input, AjmJobOutput* output) {
u32 out_size_in, AjmJobOutput* output) { LOG_TRACE(Lib_Ajm, "Decoding with instance {} in size = {}", index, input->buffer.size());
const auto decoder_handle = static_cast<Atrac9Handle*>(handle);
Atrac9CodecInfo codec_info; Atrac9CodecInfo codec_info;
Atrac9GetCodecInfo(handle, &codec_info); Atrac9GetCodecInfo(handle, &codec_info);
int bytes_used = 0; size_t out_buffer_index = 0;
int num_superframes = 0; std::span<const u8> in_buf(input->buffer);
std::span<u8> out_buf = output->buffers[out_buffer_index];
u32 in_size = in_size_in; const auto should_decode = [&] {
u32 out_size = out_size_in; if (in_buf.empty() || out_buf.empty()) {
const auto ShouldDecode = [&] {
if (in_size == 0 || out_size == 0) {
return false; return false;
} }
if (gapless.total_samples != 0 && gapless.total_samples < decoded_samples) { if (gapless.total_samples != 0 && gapless.total_samples < gapless_decoded_samples) {
return false; return false;
} }
return true; return true;
}; };
const auto written_size = codec_info.channels * codec_info.frameSamples * sizeof(u16); const auto write_output = [&](std::span<s16> pcm) {
std::vector<s16> pcm_buffer(written_size >> 1); while (!pcm.empty()) {
while (ShouldDecode()) { auto size = std::min(pcm.size() * sizeof(u16), out_buf.size());
u32 ret = Atrac9Decode(decoder_handle, in_buf, pcm_buffer.data(), &bytes_used); std::memcpy(out_buf.data(), pcm.data(), size);
ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret); pcm = pcm.subspan(size >> 1);
in_buf += bytes_used; out_buf = out_buf.subspan(size);
in_size -= bytes_used; if (out_buf.empty()) {
if (output->p_mframe) { out_buffer_index += 1;
++output->p_mframe->num_frames; if (out_buffer_index >= output->buffers.size()) {
return pcm.empty();
}
out_buf = output->buffers[out_buffer_index];
}
} }
num_frames++; return true;
bytes_remain -= bytes_used; };
int num_superframes = 0;
const auto pcm_frame_size = codec_info.channels * codec_info.frameSamples * sizeof(u16);
std::vector<s16> pcm_buffer(pcm_frame_size >> 1);
while (should_decode()) {
int bytes_used = 0;
u32 ret = Atrac9Decode(handle, in_buf.data(), pcm_buffer.data(), &bytes_used);
ASSERT_MSG(ret == At9Status::ERR_SUCCESS, "Atrac9Decode failed ret = {:#x}", ret);
in_buf = in_buf.subspan(bytes_used);
superframe_bytes_remain -= bytes_used;
const size_t samples_remain = gapless.total_samples != 0
? gapless.total_samples - gapless_decoded_samples
: std::numeric_limits<size_t>::max();
bool written = false;
if (gapless.skipped_samples < gapless.skip_samples) { if (gapless.skipped_samples < gapless.skip_samples) {
gapless.skipped_samples += decoder_handle->Config.FrameSamples; gapless.skipped_samples += codec_info.frameSamples;
if (gapless.skipped_samples > gapless.skip_samples) { if (gapless.skipped_samples > gapless.skip_samples) {
const auto size = gapless.skipped_samples - gapless.skip_samples; const u32 nsamples = gapless.skipped_samples - gapless.skip_samples;
const auto start = decoder_handle->Config.FrameSamples - size; const auto start = codec_info.frameSamples - nsamples;
memcpy(out_buf, pcm_buffer.data() + start, size * sizeof(s16)); written = write_output({pcm_buffer.data() + start, nsamples});
out_buf += size * sizeof(s16); gapless.skipped_samples = gapless.skip_samples;
out_size -= size * sizeof(s16); total_decoded_samples += nsamples;
gapless_decoded_samples += nsamples;
} }
} else { } else {
memcpy(out_buf, pcm_buffer.data(), written_size); written =
out_buf += written_size; write_output({pcm_buffer.data(), std::min(pcm_buffer.size(), samples_remain)});
out_size -= written_size; total_decoded_samples += codec_info.frameSamples;
gapless_decoded_samples += codec_info.frameSamples;
} }
decoded_samples += decoder_handle->Config.FrameSamples;
num_frames += 1;
if ((num_frames % codec_info.framesInSuperframe) == 0) { if ((num_frames % codec_info.framesInSuperframe) == 0) {
in_buf += bytes_remain; if (superframe_bytes_remain) {
in_size -= bytes_remain; if (output->p_stream) {
bytes_remain = codec_info.superframeSize; output->p_stream->input_consumed += superframe_bytes_remain;
num_superframes++; }
in_buf = in_buf.subspan(superframe_bytes_remain);
}
superframe_bytes_remain = codec_info.superframeSize;
num_superframes += 1;
}
if (output->p_stream) {
output->p_stream->input_consumed += bytes_used;
if (written) {
output->p_stream->output_written +=
std::min(pcm_frame_size, samples_remain * sizeof(16));
}
}
if (output->p_mframe) {
output->p_mframe->num_frames += 1;
} }
} }
if (gapless.total_samples == decoded_samples) { if (gapless_decoded_samples >= gapless.total_samples) {
decoded_samples = 0;
if (flags.gapless_loop) { if (flags.gapless_loop) {
gapless.skipped_samples = 0; ResetCodec();
} }
} }
LOG_TRACE(Lib_Ajm, "Decoded {} samples, frame count: {}", decoded_samples, num_frames); if (output->p_stream) {
return std::tuple(in_size, out_size); output->p_stream->total_decoded_samples = total_decoded_samples;
}
LOG_TRACE(Lib_Ajm, "Decoded buffer, in remain = {}, out remain = {}", in_buf.size(),
out_buf.size());
} }
} // namespace Libraries::Ajm } // namespace Libraries::Ajm

View File

@ -30,6 +30,8 @@ struct AjmAt9Decoder final : AjmInstance {
std::fstream file; std::fstream file;
int length; int length;
u8 config_data[ORBIS_AT9_CONFIG_DATA_SIZE]; u8 config_data[ORBIS_AT9_CONFIG_DATA_SIZE];
u32 superframe_bytes_remain{};
u32 num_frames{};
explicit AjmAt9Decoder(); explicit AjmAt9Decoder();
~AjmAt9Decoder() override; ~AjmAt9Decoder() override;
@ -43,8 +45,10 @@ struct AjmAt9Decoder final : AjmInstance {
return sizeof(AjmSidebandDecAt9CodecInfo); return sizeof(AjmSidebandDecAt9CodecInfo);
} }
std::tuple<u32, u32> Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, void Decode(const AjmJobInput* input, AjmJobOutput* output) override;
AjmJobOutput* output) override;
private:
void ResetCodec();
}; };
} // namespace Libraries::Ajm } // namespace Libraries::Ajm

View File

@ -101,7 +101,7 @@ struct AjmJobInput {
std::optional<AjmSidebandResampleParameters> resample_parameters; std::optional<AjmSidebandResampleParameters> resample_parameters;
std::optional<AjmSidebandFormat> format; std::optional<AjmSidebandFormat> format;
std::optional<AjmSidebandGaplessDecode> gapless_decode; std::optional<AjmSidebandGaplessDecode> gapless_decode;
boost::container::small_vector<std::vector<u8>, 4> buffers; std::vector<u8> buffer;
}; };
struct AjmJobOutput { struct AjmJobOutput {
@ -132,9 +132,8 @@ struct AjmInstance {
AjmInstanceFlags flags{.raw = 0}; AjmInstanceFlags flags{.raw = 0};
u32 num_channels{}; u32 num_channels{};
u32 index{}; u32 index{};
u32 bytes_remain{}; u32 gapless_decoded_samples{};
u32 num_frames{}; u32 total_decoded_samples{};
u32 decoded_samples{};
AjmSidebandFormat format{}; AjmSidebandFormat format{};
AjmSidebandGaplessDecode gapless{}; AjmSidebandGaplessDecode gapless{};
AjmSidebandResampleParameters resample_parameters{}; AjmSidebandResampleParameters resample_parameters{};
@ -149,8 +148,7 @@ struct AjmInstance {
virtual void GetCodecInfo(void* out_info) = 0; virtual void GetCodecInfo(void* out_info) = 0;
virtual u32 GetCodecInfoSize() = 0; virtual u32 GetCodecInfoSize() = 0;
virtual std::tuple<u32, u32> Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, virtual void Decode(const AjmJobInput* input, AjmJobOutput* output) = 0;
AjmJobOutput* output) = 0;
}; };
} // namespace Libraries::Ajm } // namespace Libraries::Ajm

View File

@ -73,20 +73,51 @@ void AjmMp3Decoder::Reset() {
ASSERT_MSG(c, "Could not allocate audio codec context"); ASSERT_MSG(c, "Could not allocate audio codec context");
int ret = avcodec_open2(c, codec, nullptr); int ret = avcodec_open2(c, codec, nullptr);
ASSERT_MSG(ret >= 0, "Could not open codec"); ASSERT_MSG(ret >= 0, "Could not open codec");
decoded_samples = 0; total_decoded_samples = 0;
num_frames = 0;
} }
std::tuple<u32, u32> AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_buf, u32 out_size, void AjmMp3Decoder::Decode(const AjmJobInput* input, AjmJobOutput* output) {
AjmJobOutput* output) {
AVPacket* pkt = av_packet_alloc(); AVPacket* pkt = av_packet_alloc();
while (in_size > 0 && out_size > 0) {
int ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, buf, in_size, AV_NOPTS_VALUE,
AV_NOPTS_VALUE, 0);
ASSERT_MSG(ret >= 0, "Error while parsing {}", ret);
buf += ret;
in_size -= ret;
size_t out_buffer_index = 0;
std::span<const u8> in_buf(input->buffer);
std::span<u8> out_buf = output->buffers[out_buffer_index];
const auto should_decode = [&] {
if (in_buf.empty() || out_buf.empty()) {
return false;
}
if (gapless.total_samples != 0 && gapless.total_samples < gapless_decoded_samples) {
return false;
}
return true;
};
const auto write_output = [&](std::span<s16> pcm) {
while (!pcm.empty()) {
auto size = std::min(pcm.size() * sizeof(u16), out_buf.size());
std::memcpy(out_buf.data(), pcm.data(), size);
pcm = pcm.subspan(size >> 1);
out_buf = out_buf.subspan(size);
if (out_buf.empty()) {
out_buffer_index += 1;
if (out_buffer_index >= output->buffers.size()) {
return pcm.empty();
}
out_buf = output->buffers[out_buffer_index];
}
}
return true;
};
while (should_decode()) {
int ret = av_parser_parse2(parser, c, &pkt->data, &pkt->size, in_buf.data(), in_buf.size(),
AV_NOPTS_VALUE, AV_NOPTS_VALUE, 0);
ASSERT_MSG(ret >= 0, "Error while parsing {}", ret);
in_buf = in_buf.subspan(ret);
if (output->p_stream) {
output->p_stream->input_consumed += ret;
}
if (pkt->size) { if (pkt->size) {
// Send the packet with the compressed data to the decoder // Send the packet with the compressed data to the decoder
pkt->pts = parser->pts; pkt->pts = parser->pts;
@ -107,22 +138,43 @@ std::tuple<u32, u32> AjmMp3Decoder::Decode(const u8* buf, u32 in_size, u8* out_b
if (frame->format != AV_SAMPLE_FMT_S16) { if (frame->format != AV_SAMPLE_FMT_S16) {
frame = ConvertAudioFrame(frame); frame = ConvertAudioFrame(frame);
} }
const auto size = frame->ch_layout.nb_channels * frame->nb_samples * sizeof(u16); const auto frame_samples = frame->ch_layout.nb_channels * frame->nb_samples;
std::memcpy(out_buf, frame->data[0], size); const auto size = frame_samples * sizeof(u16);
file.write((const char*)frame->data[0], size); if (gapless.skipped_samples < gapless.skip_samples) {
out_buf += size; gapless.skipped_samples += frame_samples;
out_size -= size; if (gapless.skipped_samples > gapless.skip_samples) {
decoded_samples += frame->nb_samples; const u32 nsamples = gapless.skipped_samples - gapless.skip_samples;
num_frames++; const auto start = frame_samples - nsamples;
write_output({reinterpret_cast<s16*>(frame->data[0]), nsamples});
gapless.skipped_samples = gapless.skip_samples;
total_decoded_samples += nsamples;
gapless_decoded_samples += nsamples;
}
} else {
write_output({reinterpret_cast<s16*>(frame->data[0]), size >> 1});
total_decoded_samples += frame_samples;
gapless_decoded_samples += frame_samples;
}
av_frame_free(&frame); av_frame_free(&frame);
if (output->p_stream) {
output->p_stream->output_written += size;
}
if (output->p_mframe) {
output->p_mframe->num_frames += 1;
}
} }
} }
} }
av_packet_free(&pkt); av_packet_free(&pkt);
if (output->p_mframe) { if (gapless_decoded_samples >= gapless.total_samples) {
output->p_mframe->num_frames += num_frames; if (flags.gapless_loop) {
gapless.skipped_samples = 0;
gapless_decoded_samples = 0;
}
}
if (output->p_stream) {
output->p_stream->total_decoded_samples = total_decoded_samples;
} }
return std::make_tuple(in_size, out_size);
} }
int AjmMp3Decoder::ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, int AjmMp3Decoder::ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,

View File

@ -74,8 +74,7 @@ struct AjmMp3Decoder : public AjmInstance {
return sizeof(AjmSidebandDecMp3CodecInfo); return sizeof(AjmSidebandDecMp3CodecInfo);
} }
std::tuple<u32, u32> Decode(const u8* in_buf, u32 in_size, u8* out_buf, u32 out_size, void Decode(const AjmJobInput* input, AjmJobOutput* output) override;
AjmJobOutput* output) override;
static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl, static int ParseMp3Header(const u8* buf, u32 stream_size, int parse_ofl,
AjmDecMp3ParseFrame* frame); AjmDecMp3ParseFrame* frame);