mirror of https://github.com/DS4SD/docling.git
first working ASR pipeline
Signed-off-by: Peter Staar <taa@zurich.ibm.com>

parent 05b8485dfb
commit cbd2e535db
@@ -48,6 +48,7 @@ from docling.datamodel.pipeline_options import (
     PaginatedPipelineOptions,
     PdfBackend,
     PdfPipelineOptions,
+    PipelineOptions,
     ProcessingPipeline,
     TableFormerMode,
     VlmPipelineOptions,
@@ -466,12 +467,14 @@ def convert(  # noqa: C901
         ),
     ] = None,
 ):
+    log_format = "%(asctime)s\t%(levelname)s\t%(name)s: %(message)s"
+
     if verbose == 0:
-        logging.basicConfig(level=logging.WARNING)
+        logging.basicConfig(level=logging.WARNING, format=log_format)
     elif verbose == 1:
-        logging.basicConfig(level=logging.INFO)
+        logging.basicConfig(level=logging.INFO, format=log_format)
     else:
-        logging.basicConfig(level=logging.DEBUG)
+        logging.basicConfig(level=logging.DEBUG, format=log_format)

     settings.debug.visualize_cells = debug_visualize_cells
     settings.debug.visualize_layout = debug_visualize_layout
@@ -546,7 +549,8 @@ def convert(  # noqa: C901
         ocr_options.lang = ocr_lang_list

     accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
-    pipeline_options: PaginatedPipelineOptions
+    # pipeline_options: PaginatedPipelineOptions
+    pipeline_options: PipelineOptions

     format_options: Dict[InputFormat, FormatOption] = {}

@@ -593,7 +597,7 @@ def convert(  # noqa: C901
             backend=backend,  # pdf_backend
         )

-        format_options: Dict[InputFormat, FormatOption] = {
+        format_options = {
             InputFormat.PDF: pdf_format_option,
             InputFormat.IMAGE: pdf_format_option,
         }
@@ -624,7 +628,7 @@ def convert(  # noqa: C901
             pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
         )

-        format_options: Dict[InputFormat, FormatOption] = {
+        format_options = {
             InputFormat.PDF: pdf_format_option,
             InputFormat.IMAGE: pdf_format_option,
         }
@@ -638,6 +642,7 @@ def convert(  # noqa: C901
         if asr_model == AsrModelType.WHISPER_TINY:
             pipeline_options.asr_options = WHISPER_TINY
         else:
+            _log.warning("falling back in base ASR model: WHISPER_TINY")
             pipeline_options.asr_options = WHISPER_TINY

         audio_format_option = AudioFormatOption(
@@ -646,15 +651,10 @@ def convert(  # noqa: C901
             backend=AudioBackend,
         )

-        format_options: Dict[InputFormat, FormatOption] = {
+        format_options = {
             InputFormat.AUDIO_WAV: audio_format_option,
         }

-        """
-        if asr_model == AsrModelType.WHISPER_TINY:
-            pipeline_options.asr_options = WHISPER_TINY:
-        """
-
     if artifacts_path is not None:
         pipeline_options.artifacts_path = artifacts_path
         # audio_pipeline_options.artifacts_path = artifacts_path
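The convert() hunks above wire audio input into the CLI: a PipelineOptions-typed pipeline_options, an AudioFormatOption registered for InputFormat.AUDIO_WAV, and a fallback to the WHISPER_TINY ASR options. The same path can be driven from the Python API; a minimal sketch, not part of this commit, assuming AudioFormatOption is importable from docling.document_converter and ships with working defaults:

# Sketch only: AudioFormatOption's import location and its default pipeline/backend
# wiring are assumptions based on the CLI changes above.
from docling.datamodel.base_models import InputFormat
from docling.document_converter import AudioFormatOption, DocumentConverter

converter = DocumentConverter(
    format_options={InputFormat.AUDIO_WAV: AudioFormatOption()}
)
result = converter.convert("recording.wav")  # hypothetical input file
print(result.document.export_to_markdown())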
@@ -5,9 +5,13 @@ from io import BytesIO
 from pathlib import Path
 from typing import List, Optional, Union, cast

-import soundfile as sf
+import librosa  # type: ignore
+import numpy as np
+import soundfile as sf  # type: ignore
 from docling_core.types.doc.labels import DocItemLabel
 from pydantic import BaseModel, Field, validator
+from pydub import AudioSegment  # type: ignore
+from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline

 from docling.backend.abstract_backend import AbstractDocumentBackend
 from docling.backend.audio_backend import AudioBackend
@@ -29,58 +33,297 @@ from docling.datamodel.settings import settings
 from docling.pipeline.base_pipeline import BasePipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder

-import librosa
-
 _log = logging.getLogger(__name__)


-class ConversationEntry(BaseModel):
+class _ConversationItem(BaseModel):
     text: str
-    start_time: float = Field(
-        ..., ge=0, description="Start time in seconds from video start"
+    start_time: Optional[float] = Field(
+        None, description="Start time in seconds from video start"
     )
-    end_time: float = Field(
-        ..., ge=0, description="End time in seconds from video start"
+    end_time: Optional[float] = Field(
+        None, ge=0, description="End time in seconds from video start"
     )
-    speaker_id: int = Field(..., ge=0, description="Numeric speaker identifier")
+    speaker_id: Optional[int] = Field(None, description="Numeric speaker identifier")
     speaker: Optional[str] = Field(
         None, description="Speaker name, defaults to speaker-{speaker_id}"
     )

-    @validator("end_time")
-    def end_time_must_be_after_start(cls, v, values):
-        if "start_time" in values and v <= values["start_time"]:
-            raise ValueError("end_time must be greater than start_time")
-        return v
-
-    @validator("speaker", always=True)
-    def set_default_speaker_name(cls, v, values):
-        if v is None and "speaker_id" in values:
-            return f"speaker-{values['speaker_id']}"
-        return v
-
     def __lt__(self, other):
-        if not isinstance(other, ConversationEntry):
+        if not isinstance(other, _ConversationItem):
             return NotImplemented
         return self.start_time < other.start_time

     def __eq__(self, other):
-        if not isinstance(other, ConversationEntry):
+        if not isinstance(other, _ConversationItem):
             return NotImplemented
         return self.start_time == other.start_time

     def to_string(self) -> str:
         """Format the conversation entry as a string"""
-        return f"[time: {self.start_time}-{self.end_time}] [speaker:{self.speaker}] {self.text}"
+        result = ""
+        if (self.start_time is not None) and (self.end_time is not None):
+            result += f"[time: {self.start_time}-{self.end_time}] "
+
+        if self.speaker is not None:
+            result += f"[speaker:{self.speaker}] "
+
+        result += self.text
+        return result
+
+
+class _WhisperASR:
+    def __init__(self, model_name: str = "openai/whisper-small"):
+        """
+        Transcriber using Hugging Face Transformers Whisper + energy-based VAD.
+        """
+        print(f"Loading Whisper model: {model_name}")
+
+        self.device = "cpu"
+
+        self.transcriber = pipeline(
+            "automatic-speech-recognition",
+            model=model_name,
+            return_timestamps=True,
+            device=self.device,
+        )
+
+    def _energy_vad(
+        self,
+        y: np.ndarray,
+        sr: int,
+        frame_length=2048,
+        hop_length=512,
+        threshold_percentile=85,
+    ):
+        """
+        Simple energy-based VAD.
+        Returns list of (start_time, end_time) tuples for speech segments.
+        """
+        _log.debug(f"_energy_vad {sr}: ", y.shape)
+        energy = np.array(
+            [
+                np.sum(np.abs(y[i : i + frame_length] ** 2))
+                for i in range(0, len(y), hop_length)
+            ]
+        )
+        _log.debug(f"energy: {energy}")
+
+        threshold = np.percentile(energy, threshold_percentile) * 0.3
+        _log.debug(f"threshold: {threshold}")
+
+        speech_frames = energy > threshold
+        _log.debug(f"speech_frames: {speech_frames}")
+
+        frame_times = librosa.frames_to_time(
+            np.arange(len(energy)), sr=sr, hop_length=hop_length
+        )
+
+        segments = []
+        start_time = None
+
+        for i, is_speech in enumerate(speech_frames):
+            t = frame_times[i]
+            if is_speech and start_time is None:
+                start_time = t
+            elif not is_speech and start_time is not None:
+                segments.append((start_time, t))
+                start_time = None
+
+        if start_time is not None:
+            segments.append((start_time, frame_times[-1]))
+
+        return segments
+
+    def _merge_vad_segments(self, segments, min_duration=5.0, max_gap=0.5):
+        """
+        Merge short/adjacent speech segments to improve transcription quality.
+        """
+        if not segments:
+            return []
+
+        merged = []
+        current_start, current_end = segments[0]
+
+        for start, end in segments[1:]:
+            gap = start - current_end
+            if gap <= max_gap or (current_end - current_start) < min_duration:
+                current_end = end  # merge
+            else:
+                if current_end - current_start >= 1.0:  # skip ultra-short
+                    merged.append((current_start, current_end))
+                current_start, current_end = start, end
+
+        if current_end - current_start >= 1.0:
+            merged.append((current_start, current_end))
+
+        return merged
+
+    def run(self, conv_res: ConversionResult) -> ConversionResult:
+        """
+        Transcribe audio using custom VAD and Whisper, returning timestamped segments.
+        Returns list of {"start", "end", "text"} dictionaries.
+        """
+        audio_path = conv_res.input.file
+
+        _log.info(f"Loading audio and resampling: {audio_path}")
+        y, sr = librosa.load(audio_path, sr=16000)
+
+        speech_segments = self._energy_vad(y=y, sr=int(sr))
+        speech_segments = self._merge_vad_segments(speech_segments)
+        _log.info("#-speech: ", len(speech_segments))
+
+        _log.info("Preparing AudioSegment for chunk slicing...")
+        pcm = (y * 32767).astype(np.int16).tobytes()
+        audio_seg = AudioSegment(data=pcm, sample_width=2, frame_rate=16000, channels=1)
+
+        result = self._create_conversation_entries_v2(speech_segments, audio_seg)
+        result.sort()
+
+        for _ in result:
+            conv_res.document.add_text(label=DocItemLabel.TEXT, text=_.to_string())
+
+        conv_res.status = ConversionStatus.SUCCESS
+        return conv_res
+
+    def _create_conversation_entries_v1(
+        self, speech_segments, audio_seg
+    ) -> list[_ConversationItem]:
+        """
+        Chunk audio based on speech_segments, transcribe with Whisper,
+        and return structured _ConversationItem items.
+        """
+        results = []
+        chunk_id = 0
+
+        for start, end in speech_segments:
+            duration = end - start
+            while duration > 0:
+                sub_end = min(start + 30.0, end)
+                chunk = audio_seg[start * 1000 : sub_end * 1000]
+                samples = (
+                    np.array(chunk.get_array_of_samples()).astype(np.float32) / 32768.0
+                )
+
+                try:
+                    _log.debug(
+                        f"Transcribing chunk {chunk_id}: {start:.2f}s - {sub_end:.2f}s [{sub_end - start:.2f}]"
+                    )
+                    result = self.transcriber(samples, return_timestamps=True)
+
+                    # Adjust timestamps globally
+                    for seg in result["chunks"]:
+                        t0, t1 = seg["timestamp"]
+                        if t0 is None or t1 is None or t1 <= t0:
+                            _log.warning(f"skipping bad segment: {seg}")
+                            continue
+
+                        item = _ConversationItem(
+                            text=seg["text"].strip(),
+                            start_time=start + t0,
+                            end_time=start + t1,
+                        )
+                        results.append(item)
+
+                    start = sub_end
+                    duration = end - start
+                    chunk_id += 1
+                except Exception as exc:
+                    _log.error(f"Exception: {exc}")
+
+        return results
+
+    def _create_conversation_entries_v2(
+        self, speech_segments, audio_seg
+    ) -> list[_ConversationItem]:
+        """
+        Chunk audio based on speech_segments, transcribe with Whisper,
+        and return structured _ConversationItem items.
+        """
+        results = []
+        chunk_id = 0
+
+        if len(speech_segments) == 0:
+            return []
+
+        any_valid = False
+        last_valid_offset: float = speech_segments[0][0]
+
+        for start, end in speech_segments:
+            if any_valid:
+                last_valid_offset = min(start, last_valid_offset)
+            else:
+                last_valid_offset = start
+
+            duration = end - last_valid_offset
+
+            if duration > 0.2:
+                sub_end = min(last_valid_offset + 30.0, end)
+
+                chunk_i0 = int(last_valid_offset * 1000)
+                chunk_i1 = int(sub_end * 1000)
+
+                chunk = audio_seg[chunk_i0:chunk_i1]
+                samples = (
+                    np.array(chunk.get_array_of_samples()).astype(np.float32) / 32768.0
+                )
+                chunk_id += 1
+
+                try:
+                    result = self.transcriber(samples, return_timestamps=True)
+
+                    any_valid = False
+
+                    last_valid_offset_ = last_valid_offset
+
+                    for seg in result["chunks"]:
+                        t0, t1 = seg["timestamp"]
+                        if t0 is None or t1 is None or t1 <= t0:
+                            _log.warning(f" => skipping bad segment: {seg}")
+                            continue
+
+                        global_start = round(last_valid_offset_ + t0, 2)
+                        global_end = round(last_valid_offset_ + t1, 2)
+                        text = seg["text"].strip()
+
+                        results.append(
+                            _ConversationItem(
+                                start_time=global_start, end_time=global_end, text=text
+                            )
+                        )
+                        last_valid_offset = max(global_end, last_valid_offset)
+                        any_valid = True
+
+                    if not any_valid:
+                        _log.warning(
+                            "No valid transcription in chunk, nudging forward 1s."
+                        )
+                        last_valid_offset += 1.0
+
+                except Exception as e:
+                    _log.error(f"Whisper failed: {e}")
+                    last_valid_offset += 1.0
+
+                duration = end - last_valid_offset
+            else:
+                any_valid = False
+
+        return results


 class _WhisperModel:
     def __init__(self):
         _log.info("initialisation `_WhisperModel`")

-        from transformers import WhisperForConditionalGeneration, WhisperProcessor
+        self.device = "cpu"
+        self.chunk_length = 30

-        self.model_repo = "openai/whisper-tiny"
+        self.batch_size = 8

+        # self.model_repo = "openai/whisper-tiny"
+        # self.model_repo = "openai/whisper-small"
+        self.model_repo = "openai/whisper-medium"
+        # self.model_repo = "openai/whisper-large"
+
         self.processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
         self.model = WhisperForConditionalGeneration.from_pretrained(
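The hunk above introduces _WhisperASR, whose _energy_vad splits the waveform into speech segments by thresholding frame energies before any transcription happens. A self-contained sketch of that idea, runnable outside docling (the input file name is hypothetical):

# Standalone sketch of the frame-energy VAD used by _WhisperASR._energy_vad above.
import librosa
import numpy as np


def energy_vad(y, sr, frame_length=2048, hop_length=512, threshold_percentile=85):
    # Frame-wise signal energy, hopping through the waveform.
    energy = np.array(
        [np.sum(np.abs(y[i : i + frame_length] ** 2)) for i in range(0, len(y), hop_length)]
    )
    # Frames above a scaled percentile of the energy distribution count as speech.
    threshold = np.percentile(energy, threshold_percentile) * 0.3
    speech = energy > threshold
    times = librosa.frames_to_time(np.arange(len(energy)), sr=sr, hop_length=hop_length)
    segments, start = [], None
    for t, is_speech in zip(times, speech):
        if is_speech and start is None:
            start = t
        elif not is_speech and start is not None:
            segments.append((start, t))
            start = None
    if start is not None:
        segments.append((start, times[-1]))
    return segments


y, sr = librosa.load("sample.wav", sr=16000)  # hypothetical input file
print(energy_vad(y, sr))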
@@ -89,100 +332,43 @@ class _WhisperModel:

         # FIXME
         self.max_new_tokens = 256

         _log.info(f"model is loaded: {self.model_repo}")

-    def run(self, conv_res: ConversionResult):
-        # fpath = Path(conv_res.input.file)
-        # _log.info(f"`_WhisperModel::run: {conv_res}`")
-        _log.info(f"`_WhisperModel::run: {conv_res.input}`")
-        _log.info(f"`_WhisperModel::run: {conv_res.input.file}`")
-
-        if os.path.exists(str(conv_res.input.file)):
-            print("file exists")
-        else:
-            print("file does not exist")
-        #
-
-        _log.info(f"sampling-rate: {self.processor.feature_extractor.sampling_rate}")
-
+        self.pipe = pipeline(
+            "automatic-speech-recognition",
+            model=self.model_repo,
+            chunk_length_s=self.chunk_length,
+            device=self.device,
+        )
+
+    def run(self, conv_res: ConversionResult) -> ConversionResult:
+        return self._run_pipeline(conv_res=conv_res)
+
+    def _run_pipeline(self, conv_res: ConversionResult) -> ConversionResult:
         try:
             fpath = conv_res.input.file
-            # array, sampling_rate = sf.read(fpath)#, samplerate=processor.feature_extractor.sampling_rate)
-            array, sampling_rate = sf.read(
-                fpath
-            )  # , samplerate=self.processor.feature_extractor.sampling_rate)
-
-            _log.info(
-                f"read the file .. (sampling-rate: {sampling_rate}, array: {array.shape})"
-            )
-
             array, sampling_rate = librosa.load(fpath, sr=16000)

-            _log.info(
-                f"read the file .. (sampling-rate: {sampling_rate}, array: {array.shape})"
-            )
-
-            processed_input = self.processor(
-                array,
-                sampling_rate=self.processor.feature_extractor.sampling_rate,  # sampling_rate,
-                return_tensors="pt",
-            )
-            print(processed_input)
-
-            # pre-process to get the input features
-            input_features = self.processor(
-                array, sampling_rate=sampling_rate, return_tensors="pt"
-            ).input_features
-
-            _log.info(f"got input-features: {input_features.shape}")
-            _log.info(f"max new tokens: {self.max_new_tokens}")
-
-            # generate token ids by running model forward sequentially
-            predicted_ids = self.model.generate(
-                input_features, max_new_tokens=self.max_new_tokens, return_timestamps=True
-            )
-
-            _log.info("ran model ..")
-
-            """
-            transcription = self.processor.batch_decode(predicted_ids,
-                                                        skip_special_tokens=False,
-                                                        decode_with_timestamps=True)
-
-            _log.info("decoded output ..")
-
-            print(f"Transcription: {transcription}")
-            """
-
-            conversation = []
-
-            print("Timestamp info:")
-            for pidi, pid in enumerate(predicted_ids):
-                # timestamps = processor.tokenizer.decode(pid, decode_with_timestamps=True)
-                timestamps = self.processor.tokenizer.decode(pid, output_offsets=True)
-                print(f"Predicted id [{pidi}]: {timestamps['text']}")
-                for offset in timestamps["offsets"]:
-                    print(f" => {offset['timestamp']}: {offset['text']}")
-
-                    item = ConversationEntry(
-                        text=offset["text"],
-                        speaker_id=pidi,
-                        start_time=offset["timestamp"][0],
-                        end_time=offset["timestamp"][1],
-                    )
-                    conv_res.document.add_text(
-                        label=DocItemLabel.TEXT, text=item.to_string()
-                    )
+            prediction = self.pipe(
+                inputs=array, batch_size=self.batch_size, return_timestamps=True
+            )  # ["chunks"]
+
+            for _ in prediction["chunks"]:
+                item = _ConversationItem(
+                    text=_["text"],
+                    start_time=_["timestamp"][0],
+                    end_time=_["timestamp"][1],
+                )
+                conv_res.document.add_text(
+                    label=DocItemLabel.TEXT, text=item.to_string()
+                )

             conv_res.status = ConversionStatus.SUCCESS

-            print("document: \n\n", conv_res.document.export_to_markdown())
-
         except Exception as exc:
-            conv_res.status = ConversionStatus.FAILED
-            print(exc)
+            conv_res.status = ConversionStatus.FAILURE
+            _log.error(f"Failed to convert with {self.model_repo}: {exc}")

         return conv_res

@@ -206,7 +392,8 @@ class AsrPipeline(BasePipeline):
                 "When defined, it must point to a folder containing all models required by the pipeline."
             )

-        self._model = _WhisperModel()
+        # self._model = _WhisperModel()
+        self._model = _WhisperASR()

     def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
         status = ConversionStatus.SUCCESS
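Taken together, _WhisperASR.run above boils down to: load and resample with librosa, detect speech with the energy VAD, slice the waveform with pydub, and transcribe each slice with the transformers ASR pipeline (the pyproject hunk below adds the new runtime dependencies). A condensed sketch of that flow; the file name and model choice are illustrative, not taken from this commit:

# Condensed sketch of the chunked transcription flow implemented by _WhisperASR.run.
import librosa
import numpy as np
from pydub import AudioSegment
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny", device="cpu")

y, sr = librosa.load("interview.wav", sr=16000)  # hypothetical input file
pcm = (y * 32767).astype(np.int16).tobytes()
audio = AudioSegment(data=pcm, sample_width=2, frame_rate=16000, channels=1)

# One 30-second window stands in for the VAD-produced speech segments.
for start, end in [(0.0, min(30.0, len(y) / sr))]:
    chunk = audio[int(start * 1000) : int(end * 1000)]
    samples = np.array(chunk.get_array_of_samples()).astype(np.float32) / 32768.0
    out = asr(samples, return_timestamps=True)
    for seg in out["chunks"]:
        t0, t1 = seg["timestamp"]
        if t0 is None or t1 is None:
            continue
        print(f"[time: {start + t0:.2f}-{start + t1:.2f}] {seg['text'].strip()}")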
@@ -70,6 +70,8 @@ dependencies = [
     'scipy (>=1.6.0,<2.0.0)',
     # 'scipy (>=1.6.0,<2.0.0) ; python_version >= "3.10"',
     # 'scipy (>=1.6.0,<1.14.0) ; python_version < "3.10"',
+    "pydub[asr]>=0.25.1",
+    "pyannote-audio[asr]>=1.1.2",
 ]

 [project.urls]