first working ASR pipeline

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Committed by Peter Staar on 2025-06-16 19:06:47 +02:00
parent 05b8485dfb
commit cbd2e535db
4 changed files with 1292 additions and 120 deletions


@@ -48,6 +48,7 @@ from docling.datamodel.pipeline_options import (
     PaginatedPipelineOptions,
     PdfBackend,
     PdfPipelineOptions,
+    PipelineOptions,
     ProcessingPipeline,
     TableFormerMode,
     VlmPipelineOptions,
@@ -466,12 +467,14 @@ def convert(  # noqa: C901
         ),
     ] = None,
 ):
+    log_format = "%(asctime)s\t%(levelname)s\t%(name)s: %(message)s"
+
     if verbose == 0:
-        logging.basicConfig(level=logging.WARNING)
+        logging.basicConfig(level=logging.WARNING, format=log_format)
     elif verbose == 1:
-        logging.basicConfig(level=logging.INFO)
+        logging.basicConfig(level=logging.INFO, format=log_format)
     else:
-        logging.basicConfig(level=logging.DEBUG)
+        logging.basicConfig(level=logging.DEBUG, format=log_format)
 
     settings.debug.visualize_cells = debug_visualize_cells
     settings.debug.visualize_layout = debug_visualize_layout
@@ -546,7 +549,8 @@ def convert(  # noqa: C901
         ocr_options.lang = ocr_lang_list
 
     accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
 
-    pipeline_options: PaginatedPipelineOptions
+    # pipeline_options: PaginatedPipelineOptions
+    pipeline_options: PipelineOptions
 
     format_options: Dict[InputFormat, FormatOption] = {}
@@ -593,7 +597,7 @@ def convert(  # noqa: C901
             backend=backend,  # pdf_backend
         )
 
-        format_options: Dict[InputFormat, FormatOption] = {
+        format_options = {
             InputFormat.PDF: pdf_format_option,
             InputFormat.IMAGE: pdf_format_option,
         }
@@ -624,7 +628,7 @@ def convert(  # noqa: C901
             pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
         )
 
-        format_options: Dict[InputFormat, FormatOption] = {
+        format_options = {
            InputFormat.PDF: pdf_format_option,
            InputFormat.IMAGE: pdf_format_option,
         }
@@ -638,6 +642,7 @@ def convert(  # noqa: C901
         if asr_model == AsrModelType.WHISPER_TINY:
             pipeline_options.asr_options = WHISPER_TINY
         else:
+            _log.warning("falling back to the default ASR model: WHISPER_TINY")
             pipeline_options.asr_options = WHISPER_TINY
 
         audio_format_option = AudioFormatOption(
@@ -646,15 +651,10 @@ def convert(  # noqa: C901
             backend=AudioBackend,
         )
 
-        format_options: Dict[InputFormat, FormatOption] = {
+        format_options = {
             InputFormat.AUDIO_WAV: audio_format_option,
         }
 
-        """
-        if asr_model == AsrModelType.WHISPER_TINY:
-            pipeline_options.asr_options = WHISPER_TINY:
-        """
-
     if artifacts_path is not None:
         pipeline_options.artifacts_path = artifacts_path
         # audio_pipeline_options.artifacts_path = artifacts_path
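
Note: the CLI changes above ultimately just build a format_options mapping and hand it to a DocumentConverter. A minimal programmatic sketch of the same wiring is shown below; the import locations of AudioFormatOption, AsrPipelineOptions, and WHISPER_TINY are assumptions inferred from this diff (only AudioBackend's module is shown explicitly), and "meeting.wav" is a hypothetical input file.

    # Sketch only: mirrors the CLI's ASR branch above. Import paths marked as
    # assumed are inferred from this diff, not confirmed by it.
    from docling.backend.audio_backend import AudioBackend  # shown in this commit
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import (  # assumed location
        WHISPER_TINY,
        AsrPipelineOptions,
    )
    from docling.document_converter import AudioFormatOption, DocumentConverter  # assumed location
    from docling.pipeline.asr_pipeline import AsrPipeline

    pipeline_options = AsrPipelineOptions()
    pipeline_options.asr_options = WHISPER_TINY  # same model the CLI falls back to

    converter = DocumentConverter(
        format_options={
            InputFormat.AUDIO_WAV: AudioFormatOption(
                pipeline_cls=AsrPipeline,
                pipeline_options=pipeline_options,
                backend=AudioBackend,
            )
        }
    )
    result = converter.convert("meeting.wav")  # hypothetical input
    print(result.document.export_to_markdown())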


@@ -5,9 +5,13 @@ from io import BytesIO
 from pathlib import Path
 from typing import List, Optional, Union, cast
 
-import soundfile as sf
+import librosa  # type: ignore
+import numpy as np
+import soundfile as sf  # type: ignore
 from docling_core.types.doc.labels import DocItemLabel
 from pydantic import BaseModel, Field, validator
+from pydub import AudioSegment  # type: ignore
+from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
 
 from docling.backend.abstract_backend import AbstractDocumentBackend
 from docling.backend.audio_backend import AudioBackend
@@ -29,58 +33,297 @@ from docling.datamodel.settings import settings
 from docling.pipeline.base_pipeline import BasePipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder
 
-import librosa
-
 _log = logging.getLogger(__name__)
 
 
-class ConversationEntry(BaseModel):
+class _ConversationItem(BaseModel):
     text: str
-    start_time: float = Field(
-        ..., ge=0, description="Start time in seconds from video start"
+    start_time: Optional[float] = Field(
+        None, description="Start time in seconds from video start"
     )
-    end_time: float = Field(
-        ..., ge=0, description="End time in seconds from video start"
+    end_time: Optional[float] = Field(
+        None, ge=0, description="End time in seconds from video start"
     )
-    speaker_id: int = Field(..., ge=0, description="Numeric speaker identifier")
+    speaker_id: Optional[int] = Field(None, description="Numeric speaker identifier")
     speaker: Optional[str] = Field(
         None, description="Speaker name, defaults to speaker-{speaker_id}"
     )
 
-    @validator("end_time")
-    def end_time_must_be_after_start(cls, v, values):
-        if "start_time" in values and v <= values["start_time"]:
-            raise ValueError("end_time must be greater than start_time")
-        return v
-
-    @validator("speaker", always=True)
-    def set_default_speaker_name(cls, v, values):
-        if v is None and "speaker_id" in values:
-            return f"speaker-{values['speaker_id']}"
-        return v
-
     def __lt__(self, other):
-        if not isinstance(other, ConversationEntry):
+        if not isinstance(other, _ConversationItem):
             return NotImplemented
         return self.start_time < other.start_time
 
     def __eq__(self, other):
-        if not isinstance(other, ConversationEntry):
+        if not isinstance(other, _ConversationItem):
             return NotImplemented
         return self.start_time == other.start_time
 
     def to_string(self) -> str:
         """Format the conversation entry as a string"""
-        return f"[time: {self.start_time}-{self.end_time}] [speaker:{self.speaker}] {self.text}"
+        result = ""
+        if (self.start_time is not None) and (self.end_time is not None):
+            result += f"[time: {self.start_time}-{self.end_time}] "
+        if self.speaker is not None:
+            result += f"[speaker:{self.speaker}] "
+        result += self.text
+        return result
+
+
+class _WhisperASR:
+    def __init__(self, model_name: str = "openai/whisper-small"):
+        """
+        Transcriber using Hugging Face Transformers Whisper + energy-based VAD.
+        """
+        print(f"Loading Whisper model: {model_name}")
+
+        self.device = "cpu"
+        self.transcriber = pipeline(
+            "automatic-speech-recognition",
+            model=model_name,
+            return_timestamps=True,
+            device=self.device,
+        )
+
+    def _energy_vad(
+        self,
+        y: np.ndarray,
+        sr: int,
+        frame_length=2048,
+        hop_length=512,
+        threshold_percentile=85,
+    ):
+        """
+        Simple energy-based VAD.
+        Returns a list of (start_time, end_time) tuples for speech segments.
+        """
+        _log.debug(f"_energy_vad (sr={sr}): {y.shape}")
+
+        energy = np.array(
+            [
+                np.sum(np.abs(y[i : i + frame_length] ** 2))
+                for i in range(0, len(y), hop_length)
+            ]
+        )
+        _log.debug(f"energy: {energy}")
+
+        threshold = np.percentile(energy, threshold_percentile) * 0.3
+        _log.debug(f"threshold: {threshold}")
+
+        speech_frames = energy > threshold
+        _log.debug(f"speech_frames: {speech_frames}")
+
+        frame_times = librosa.frames_to_time(
+            np.arange(len(energy)), sr=sr, hop_length=hop_length
+        )
+
+        segments = []
+        start_time = None
+        for i, is_speech in enumerate(speech_frames):
+            t = frame_times[i]
+            if is_speech and start_time is None:
+                start_time = t
+            elif not is_speech and start_time is not None:
+                segments.append((start_time, t))
+                start_time = None
+        if start_time is not None:
+            segments.append((start_time, frame_times[-1]))
+
+        return segments
+
+    def _merge_vad_segments(self, segments, min_duration=5.0, max_gap=0.5):
+        """
+        Merge short/adjacent speech segments to improve transcription quality.
+        """
+        if not segments:
+            return []
+
+        merged = []
+        current_start, current_end = segments[0]
+
+        for start, end in segments[1:]:
+            gap = start - current_end
+            if gap <= max_gap or (current_end - current_start) < min_duration:
+                current_end = end  # merge
+            else:
+                if current_end - current_start >= 1.0:  # skip ultra-short
+                    merged.append((current_start, current_end))
+                current_start, current_end = start, end
+
+        if current_end - current_start >= 1.0:
+            merged.append((current_start, current_end))
+
+        return merged
+
+    def run(self, conv_res: ConversionResult) -> ConversionResult:
+        """
+        Transcribe audio using custom VAD and Whisper and add the timestamped
+        segments to the document. Returns the updated ConversionResult.
+        """
+        audio_path = conv_res.input.file
+
+        _log.info(f"Loading audio and resampling: {audio_path}")
+        y, sr = librosa.load(audio_path, sr=16000)
+
+        speech_segments = self._energy_vad(y=y, sr=int(sr))
+        speech_segments = self._merge_vad_segments(speech_segments)
+        _log.info(f"#-speech: {len(speech_segments)}")
+
+        _log.info("Preparing AudioSegment for chunk slicing...")
+        pcm = (y * 32767).astype(np.int16).tobytes()
+        audio_seg = AudioSegment(data=pcm, sample_width=2, frame_rate=16000, channels=1)
+
+        result = self._create_conversation_entries_v2(speech_segments, audio_seg)
+        result.sort()
+
+        for item in result:
+            conv_res.document.add_text(label=DocItemLabel.TEXT, text=item.to_string())
+
+        conv_res.status = ConversionStatus.SUCCESS
+        return conv_res
+
+    def _create_conversation_entries_v1(
+        self, speech_segments, audio_seg
+    ) -> list[_ConversationItem]:
+        """
+        Chunk audio based on speech_segments, transcribe with Whisper,
+        and return structured _ConversationItem items.
+        """
+        results = []
+        chunk_id = 0
+
+        for start, end in speech_segments:
+            duration = end - start
+            while duration > 0:
+                sub_end = min(start + 30.0, end)
+                chunk = audio_seg[start * 1000 : sub_end * 1000]
+                samples = (
+                    np.array(chunk.get_array_of_samples()).astype(np.float32) / 32768.0
+                )
+                try:
+                    _log.debug(
+                        f"Transcribing chunk {chunk_id}: {start:.2f}s - {sub_end:.2f}s [{sub_end - start:.2f}]"
+                    )
+                    result = self.transcriber(samples, return_timestamps=True)
+
+                    # Adjust timestamps globally
+                    for seg in result["chunks"]:
+                        t0, t1 = seg["timestamp"]
+                        if t0 is None or t1 is None or t1 <= t0:
+                            _log.warning(f"skipping bad segment: {seg}")
+                            continue
+                        item = _ConversationItem(
+                            text=seg["text"].strip(),
+                            start_time=start + t0,
+                            end_time=start + t1,
+                        )
+                        results.append(item)
+
+                    start = sub_end
+                    duration = end - start
+                    chunk_id += 1
+                except Exception as exc:
+                    _log.error(f"Exception: {exc}")
+
+        return results
+
+    def _create_conversation_entries_v2(
+        self, speech_segments, audio_seg
+    ) -> list[_ConversationItem]:
+        """
+        Chunk audio based on speech_segments, transcribe with Whisper,
+        and return structured _ConversationItem items.
+        """
+        results = []
+        chunk_id = 0
+
+        if len(speech_segments) == 0:
+            return []
+
+        any_valid = False
+        last_valid_offset: float = speech_segments[0][0]
+
+        for start, end in speech_segments:
+            if any_valid:
+                last_valid_offset = min(start, last_valid_offset)
+            else:
+                last_valid_offset = start
+
+            duration = end - last_valid_offset
+            if duration > 0.2:
+                sub_end = min(last_valid_offset + 30.0, end)
+
+                chunk_i0 = int(last_valid_offset * 1000)
+                chunk_i1 = int(sub_end * 1000)
+                chunk = audio_seg[chunk_i0:chunk_i1]
+                samples = (
+                    np.array(chunk.get_array_of_samples()).astype(np.float32) / 32768.0
+                )
+                chunk_id += 1
+
+                try:
+                    result = self.transcriber(samples, return_timestamps=True)
+
+                    any_valid = False
+                    last_valid_offset_ = last_valid_offset
+                    for seg in result["chunks"]:
+                        t0, t1 = seg["timestamp"]
+                        if t0 is None or t1 is None or t1 <= t0:
+                            _log.warning(f" => skipping bad segment: {seg}")
+                            continue
+
+                        global_start = round(last_valid_offset_ + t0, 2)
+                        global_end = round(last_valid_offset_ + t1, 2)
+                        text = seg["text"].strip()
+
+                        results.append(
+                            _ConversationItem(
+                                start_time=global_start, end_time=global_end, text=text
+                            )
+                        )
+                        last_valid_offset = max(global_end, last_valid_offset)
+                        any_valid = True
+
+                    if not any_valid:
+                        _log.warning(
+                            "No valid transcription in chunk, nudging forward 1s."
+                        )
+                        last_valid_offset += 1.0
+                except Exception as e:
+                    _log.error(f"Whisper failed: {e}")
+                    last_valid_offset += 1.0
+
+                duration = end - last_valid_offset
+            else:
+                any_valid = False
+
+        return results
+
+
 class _WhisperModel:
     def __init__(self):
         _log.info("initialisation `_WhisperModel`")
-        from transformers import WhisperForConditionalGeneration, WhisperProcessor
 
-        self.model_repo = "openai/whisper-tiny"
+        self.device = "cpu"
+        self.chunk_length = 30
+        self.batch_size = 8
+
+        # self.model_repo = "openai/whisper-tiny"
+        # self.model_repo = "openai/whisper-small"
+        self.model_repo = "openai/whisper-medium"
+        # self.model_repo = "openai/whisper-large"
 
         self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
         self.model = WhisperForConditionalGeneration.from_pretrained(
@@ -89,100 +332,43 @@ class _WhisperModel:
         # FIXME
         self.max_new_tokens = 256
 
         _log.info(f"model is loaded: {self.model_repo}")
 
-    def run(self, conv_res: ConversionResult):
-        # fpath = Path(conv_res.input.file)
-        # _log.info(f"`_WhisperModel::run: {conv_res}`")
-        _log.info(f"`_WhisperModel::run: {conv_res.input}`")
-        _log.info(f"`_WhisperModel::run: {conv_res.input.file}`")
-
-        if os.path.exists(str(conv_res.input.file)):
-            print("file exists")
-        else:
-            print("file does not exist")
-
-        #
-        _log.info(f"sampling-rate: {self.processor.feature_extractor.sampling_rate}")
+        self.pipe = pipeline(
+            "automatic-speech-recognition",
+            model=self.model_repo,
+            chunk_length_s=self.chunk_length,
+            device=self.device,
+        )
+
+    def run(self, conv_res: ConversionResult) -> ConversionResult:
+        return self._run_pipeline(conv_res=conv_res)
 
+    def _run_pipeline(self, conv_res: ConversionResult) -> ConversionResult:
         try:
             fpath = conv_res.input.file
 
-            # array, sampling_rate = sf.read(fpath)#, samplerate=processor.feature_extractor.sampling_rate)
-            array, sampling_rate = sf.read(
-                fpath
-            )  # , samplerate=self.processor.feature_extractor.sampling_rate)
-            _log.info(
-                f"read the file .. (sampling-rate: {sampling_rate}, array: {array.shape})"
-            )
-
             array, sampling_rate = librosa.load(fpath, sr=16000)
-            _log.info(
-                f"read the file .. (sampling-rate: {sampling_rate}, array: {array.shape})"
-            )
 
-            processed_input = self.processor(
-                array,
-                sampling_rate=self.processor.feature_extractor.sampling_rate,  # sampling_rate,
-                return_tensors="pt",
-            )
-            print(processed_input)
-
-            # pre-process to get the input features
-            input_features = self.processor(
-                array, sampling_rate=sampling_rate, return_tensors="pt"
-            ).input_features
-            _log.info(f"got input-features: {input_features.shape}")
-
-            _log.info(f"max new tokens: {self.max_new_tokens}")
-
-            # generate token ids by running model forward sequentially
-            predicted_ids = self.model.generate(
-                input_features, max_new_tokens=self.max_new_tokens, return_timestamps=True
-            )
-            _log.info("ran model ..")
-
-            """
-            transcription = self.processor.batch_decode(predicted_ids,
-                                                        skip_special_tokens=False,
-                                                        decode_with_timestamps=True)
-            _log.info("decoded output ..")
-            print(f"Transcription: {transcription}")
-            """
-
-            conversation = []
-
-            print("Timestamp info:")
-            for pidi, pid in enumerate(predicted_ids):
-                # timestamps = processor.tokenizer.decode(pid, decode_with_timestamps=True)
-                timestamps = self.processor.tokenizer.decode(pid, output_offsets=True)
-                print(f"Predicted id [{pidi}]: {timestamps['text']}")
-                for offset in timestamps["offsets"]:
-                    print(f" => {offset['timestamp']}: {offset['text']}")
-
-                    item = ConversationEntry(
-                        text=offset["text"],
-                        speaker_id=pidi,
-                        start_time=offset["timestamp"][0],
-                        end_time=offset["timestamp"][1],
-                    )
-                    conv_res.document.add_text(
-                        label=DocItemLabel.TEXT, text=item.to_string()
-                    )
+            prediction = self.pipe(
+                inputs=array, batch_size=self.batch_size, return_timestamps=True
+            )  # ["chunks"]
+
+            for chunk in prediction["chunks"]:
+                item = _ConversationItem(
+                    text=chunk["text"],
+                    start_time=chunk["timestamp"][0],
+                    end_time=chunk["timestamp"][1],
+                )
+                conv_res.document.add_text(
+                    label=DocItemLabel.TEXT, text=item.to_string()
+                )
 
             conv_res.status = ConversionStatus.SUCCESS
-            print("document: \n\n", conv_res.document.export_to_markdown())
-
         except Exception as exc:
-            conv_res.status = ConversionStatus.FAILED
-            print(exc)
+            conv_res.status = ConversionStatus.FAILURE
+            _log.error(f"Failed to convert with {self.model_repo}: {exc}")
 
         return conv_res
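
The replacement _run_pipeline above leans entirely on the Hugging Face transformers ASR pipeline, which chunks long-form audio internally when chunk_length_s is set. A standalone sketch of that call, outside docling, looks roughly like this; the file name and the whisper-tiny checkpoint are illustrative choices, not this commit's defaults.

    # Standalone sketch of the chunked transformers ASR call used by _run_pipeline.
    import librosa
    from transformers import pipeline

    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-tiny",  # _WhisperModel above defaults to whisper-medium
        chunk_length_s=30,            # long inputs are split into 30 s chunks internally
        device="cpu",
    )

    waveform, sampling_rate = librosa.load("sample.wav", sr=16000)  # Whisper expects 16 kHz
    prediction = asr(waveform, batch_size=8, return_timestamps=True)

    for chunk in prediction["chunks"]:
        start, end = chunk["timestamp"]  # seconds; may be None at chunk boundaries
        print(f"[{start}-{end}] {chunk['text'].strip()}")
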
@@ -206,7 +392,8 @@ class AsrPipeline(BasePipeline):
                 "When defined, it must point to a folder containing all models required by the pipeline."
             )
 
-        self._model = _WhisperModel()
+        # self._model = _WhisperModel()
+        self._model = _WhisperASR()
 
     def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
         status = ConversionStatus.SUCCESS
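
The preprocessing inside _WhisperASR can also be exercised on its own. The sketch below reproduces its two steps outside docling: the energy-based VAD over a 16 kHz waveform, and the re-encoding to 16-bit PCM so pydub can slice speech segments by milliseconds. "speech.wav" is an illustrative input; the frame sizes and threshold mirror the defaults in the class above.

    import librosa
    import numpy as np
    from pydub import AudioSegment

    y, sr = librosa.load("speech.wav", sr=16000)

    # Energy-based VAD: frame energies, a percentile-derived threshold, frame times.
    frame_length, hop_length = 2048, 512
    energy = np.array(
        [np.sum(np.abs(y[i : i + frame_length] ** 2)) for i in range(0, len(y), hop_length)]
    )
    threshold = np.percentile(energy, 85) * 0.3  # same scaling as _energy_vad
    speech_frames = energy > threshold
    frame_times = librosa.frames_to_time(np.arange(len(energy)), sr=sr, hop_length=hop_length)

    # Collapse the boolean frame mask into (start, end) speech segments in seconds.
    segments, start = [], None
    for t, is_speech in zip(frame_times, speech_frames):
        if is_speech and start is None:
            start = t
        elif not is_speech and start is not None:
            segments.append((start, t))
            start = None
    if start is not None:
        segments.append((start, frame_times[-1]))

    # Re-encode the float waveform as 16-bit PCM so pydub can slice by milliseconds,
    # exactly as _WhisperASR.run does before handing chunks to the transcriber.
    pcm = (y * 32767).astype(np.int16).tobytes()
    audio = AudioSegment(data=pcm, sample_width=2, frame_rate=16000, channels=1)
    for seg_start, seg_end in segments:
        chunk = audio[int(seg_start * 1000) : int(seg_end * 1000)]
        samples = np.array(chunk.get_array_of_samples()).astype(np.float32) / 32768.0
        # `samples` is the float array a Whisper transcriber would receive per segment.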


@@ -70,6 +70,8 @@ dependencies = [
     'scipy (>=1.6.0,<2.0.0)',
     # 'scipy (>=1.6.0,<2.0.0) ; python_version >= "3.10"',
     # 'scipy (>=1.6.0,<1.14.0) ; python_version < "3.10"',
+    "pydub[asr]>=0.25.1",
+    "pyannote-audio[asr]>=1.1.2",
 ]
 
 [project.urls]

uv.lock (generated, 983 changed lines): diff suppressed because it is too large.