mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-10 21:58:15 +00:00
feat: [Beta] Extraction with schema (#2138)
* Add DocumentConverter.extract and full extraction pipeline Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add DocumentConverter.extract template arg Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add NuExtract model Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add Extraction pipeline Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add proper test, support pydantic class types Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add qr bill example Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add base_extraction_pipeline Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add types Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Update typing of ExtractionResult and inner fields Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Factor out extract to DocumentExtractor Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Address mypy issues Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Add DocumentExtractor Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Resolve circular import issue Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Clean up imports, remove Optional for template arg Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Move new type definitions into datamodel Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Update comments Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Respect page-range, disable test_extraction for CI Signed-off-by: Christoph Auer <cau@zurich.ibm.com> --------- Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import (
|
||||
@@ -32,6 +32,18 @@ from pydantic import (
|
||||
if TYPE_CHECKING:
|
||||
from docling.backend.pdf_backend import PdfPageBackend
|
||||
|
||||
from docling.backend.abstract_backend import AbstractDocumentBackend
|
||||
from docling.datamodel.pipeline_options import PipelineOptions
|
||||
|
||||
|
||||
class BaseFormatOption(BaseModel):
|
||||
"""Base class for format options used by _DocumentConversionInput."""
|
||||
|
||||
pipeline_options: Optional[PipelineOptions] = None
|
||||
backend: Type[AbstractDocumentBackend]
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ConversionStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
|
||||
@@ -2,12 +2,13 @@ import csv
|
||||
import logging
|
||||
import re
|
||||
import tarfile
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterable, Mapping
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path, PurePath
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
@@ -72,7 +73,7 @@ from docling.utils.profiling import ProfilingItem
|
||||
from docling.utils.utils import create_file_hash
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import FormatOption
|
||||
from docling.datamodel.base_models import BaseFormatOption
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
@@ -238,7 +239,8 @@ class _DocumentConversionInput(BaseModel):
|
||||
limits: Optional[DocumentLimits] = DocumentLimits()
|
||||
|
||||
def docs(
|
||||
self, format_options: Dict[InputFormat, "FormatOption"]
|
||||
self,
|
||||
format_options: Mapping[InputFormat, "BaseFormatOption"],
|
||||
) -> Iterable[InputDocument]:
|
||||
for item in self.path_or_stream_iterator:
|
||||
obj = (
|
||||
|
||||
39
docling/datamodel/extraction.py
Normal file
39
docling/datamodel/extraction.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""Data models for document extraction functionality."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from docling.datamodel.base_models import ConversionStatus, ErrorItem
|
||||
from docling.datamodel.document import InputDocument
|
||||
|
||||
|
||||
class ExtractedPageData(BaseModel):
|
||||
"""Data model for extracted content from a single page."""
|
||||
|
||||
page_no: int = Field(..., description="1-indexed page number")
|
||||
extracted_data: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Extracted structured data from the page"
|
||||
)
|
||||
raw_text: Optional[str] = Field(None, description="Raw extracted text")
|
||||
errors: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Any errors encountered during extraction for this page",
|
||||
)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
"""Result of document extraction."""
|
||||
|
||||
input: InputDocument
|
||||
status: ConversionStatus = ConversionStatus.PENDING
|
||||
errors: List[ErrorItem] = []
|
||||
|
||||
# Pages field - always a list for consistency
|
||||
pages: List[ExtractedPageData] = Field(
|
||||
default_factory=list, description="Extracted data from each page"
|
||||
)
|
||||
|
||||
|
||||
# Type alias for template parameters that can be string, dict, or BaseModel
|
||||
ExtractionTemplateType = Union[str, Dict[str, Any], BaseModel, Type[BaseModel]]
|
||||
@@ -37,6 +37,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
|
||||
from docling.datamodel.vlm_model_specs import (
|
||||
GRANITE_VISION_OLLAMA as granite_vision_vlm_ollama_conversion_options,
|
||||
GRANITE_VISION_TRANSFORMERS as granite_vision_vlm_conversion_options,
|
||||
NU_EXTRACT_2B_TRANSFORMERS,
|
||||
SMOLDOCLING_MLX as smoldocling_vlm_mlx_conversion_options,
|
||||
SMOLDOCLING_TRANSFORMERS as smoldocling_vlm_conversion_options,
|
||||
VlmModelType,
|
||||
@@ -247,12 +248,9 @@ class OcrEngine(str, Enum):
|
||||
RAPIDOCR = "rapidocr"
|
||||
|
||||
|
||||
class PipelineOptions(BaseModel):
|
||||
class PipelineOptions(BaseOptions):
|
||||
"""Base pipeline options."""
|
||||
|
||||
create_legacy_output: bool = (
|
||||
True # This default will be set to False on a future version of docling
|
||||
)
|
||||
document_timeout: Optional[float] = None
|
||||
accelerator_options: AcceleratorOptions = AcceleratorOptions()
|
||||
enable_remote_services: bool = False
|
||||
@@ -296,6 +294,13 @@ class AsrPipelineOptions(PipelineOptions):
|
||||
artifacts_path: Optional[Union[Path, str]] = None
|
||||
|
||||
|
||||
class VlmExtractionPipelineOptions(PipelineOptions):
|
||||
"""Options for extraction pipeline."""
|
||||
|
||||
artifacts_path: Optional[Union[Path, str]] = None
|
||||
vlm_options: Union[InlineVlmOptions] = NU_EXTRACT_2B_TRANSFORMERS
|
||||
|
||||
|
||||
class PdfPipelineOptions(PaginatedPipelineOptions):
|
||||
"""Options for the PDF pipeline."""
|
||||
|
||||
|
||||
@@ -247,6 +247,23 @@ DOLPHIN_TRANSFORMERS = InlineVlmOptions(
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# NuExtract
|
||||
NU_EXTRACT_2B_TRANSFORMERS = InlineVlmOptions(
|
||||
repo_id="numind/NuExtract-2.0-2B",
|
||||
prompt="", # This won't be used, template is passed separately
|
||||
torch_dtype="bfloat16",
|
||||
inference_framework=InferenceFramework.TRANSFORMERS,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
response_format=ResponseFormat.PLAINTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CPU,
|
||||
AcceleratorDevice.CUDA,
|
||||
AcceleratorDevice.MPS,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
class VlmModelType(str, Enum):
|
||||
SMOLDOCLING = "smoldocling"
|
||||
|
||||
@@ -28,6 +28,7 @@ from docling.backend.noop_backend import NoOpBackend
|
||||
from docling.backend.xml.jats_backend import JatsDocumentBackend
|
||||
from docling.backend.xml.uspto_backend import PatentUsptoDocumentBackend
|
||||
from docling.datamodel.base_models import (
|
||||
BaseFormatOption,
|
||||
ConversionStatus,
|
||||
DoclingComponentType,
|
||||
DocumentStream,
|
||||
@@ -57,12 +58,8 @@ _log = logging.getLogger(__name__)
|
||||
_PIPELINE_CACHE_LOCK = threading.Lock()
|
||||
|
||||
|
||||
class FormatOption(BaseModel):
|
||||
class FormatOption(BaseFormatOption):
|
||||
pipeline_cls: Type[BasePipeline]
|
||||
pipeline_options: Optional[PipelineOptions] = None
|
||||
backend: Type[AbstractDocumentBackend]
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_optional_field_default(self) -> "FormatOption":
|
||||
@@ -191,7 +188,7 @@ class DocumentConverter:
|
||||
self.allowed_formats = (
|
||||
allowed_formats if allowed_formats is not None else list(InputFormat)
|
||||
)
|
||||
self.format_to_options = {
|
||||
self.format_to_options: Dict[InputFormat, FormatOption] = {
|
||||
format: (
|
||||
_get_default_option(format=format)
|
||||
if (custom_option := (format_options or {}).get(format)) is None
|
||||
|
||||
325
docling/document_extractor.py
Normal file
325
docling/document_extractor.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Iterable, Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import ConfigDict, model_validator, validate_call
|
||||
|
||||
from docling.backend.abstract_backend import AbstractDocumentBackend
|
||||
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
|
||||
from docling.datamodel.base_models import (
|
||||
BaseFormatOption,
|
||||
ConversionStatus,
|
||||
DoclingComponentType,
|
||||
DocumentStream,
|
||||
ErrorItem,
|
||||
InputFormat,
|
||||
)
|
||||
from docling.datamodel.document import (
|
||||
InputDocument,
|
||||
_DocumentConversionInput, # intentionally reused builder
|
||||
)
|
||||
from docling.datamodel.extraction import ExtractionResult, ExtractionTemplateType
|
||||
from docling.datamodel.pipeline_options import PipelineOptions
|
||||
from docling.datamodel.settings import (
|
||||
DEFAULT_PAGE_RANGE,
|
||||
DocumentLimits,
|
||||
PageRange,
|
||||
settings,
|
||||
)
|
||||
from docling.exceptions import ConversionError
|
||||
from docling.pipeline.base_extraction_pipeline import BaseExtractionPipeline
|
||||
from docling.pipeline.extraction_vlm_pipeline import ExtractionVlmPipeline
|
||||
from docling.utils.utils import chunkify
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
_PIPELINE_CACHE_LOCK = threading.Lock()
|
||||
|
||||
|
||||
class ExtractionFormatOption(BaseFormatOption):
|
||||
"""Per-format configuration for extraction.
|
||||
|
||||
Notes:
|
||||
- `pipeline_cls` must subclass `BaseExtractionPipeline`.
|
||||
- `pipeline_options` is typed as `PipelineOptions` which MUST inherit from
|
||||
`BaseOptions` (as used by `BaseExtractionPipeline`).
|
||||
- `backend` is the document-opening backend used by `_DocumentConversionInput`.
|
||||
"""
|
||||
|
||||
pipeline_cls: Type[BaseExtractionPipeline]
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_optional_field_default(self) -> "ExtractionFormatOption":
|
||||
if self.pipeline_options is None:
|
||||
# `get_default_options` comes from BaseExtractionPipeline
|
||||
self.pipeline_options = self.pipeline_cls.get_default_options() # type: ignore[assignment]
|
||||
return self
|
||||
|
||||
|
||||
def _get_default_extraction_option(fmt: InputFormat) -> ExtractionFormatOption:
|
||||
"""Return the default extraction option for a given input format.
|
||||
|
||||
Defaults mirror the converter's *backend* choices, while the pipeline is
|
||||
the VLM extractor. This duplication will be removed when we deduplicate
|
||||
the format registry between convert/extract.
|
||||
"""
|
||||
format_to_default_backend: Dict[InputFormat, Type[AbstractDocumentBackend]] = {
|
||||
InputFormat.IMAGE: PyPdfiumDocumentBackend,
|
||||
InputFormat.PDF: PyPdfiumDocumentBackend,
|
||||
}
|
||||
|
||||
backend = format_to_default_backend.get(fmt)
|
||||
if backend is None:
|
||||
raise RuntimeError(f"No default extraction backend configured for {fmt}")
|
||||
|
||||
return ExtractionFormatOption(
|
||||
pipeline_cls=ExtractionVlmPipeline,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
|
||||
class DocumentExtractor:
|
||||
"""Standalone extractor class.
|
||||
|
||||
Public API:
|
||||
- `extract(...) -> ExtractionResult`
|
||||
- `extract_all(...) -> Iterator[ExtractionResult]`
|
||||
|
||||
Implementation intentionally reuses `_DocumentConversionInput` to build
|
||||
`InputDocument` with the correct backend per format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allowed_formats: Optional[List[InputFormat]] = None,
|
||||
extraction_format_options: Optional[
|
||||
Dict[InputFormat, ExtractionFormatOption]
|
||||
] = None,
|
||||
) -> None:
|
||||
self.allowed_formats: List[InputFormat] = (
|
||||
allowed_formats if allowed_formats is not None else list(InputFormat)
|
||||
)
|
||||
# Build per-format options with defaults, then apply any user overrides
|
||||
overrides = extraction_format_options or {}
|
||||
self.extraction_format_to_options: Dict[InputFormat, ExtractionFormatOption] = {
|
||||
fmt: overrides.get(fmt, _get_default_extraction_option(fmt))
|
||||
for fmt in self.allowed_formats
|
||||
}
|
||||
|
||||
# Cache pipelines by (class, options-hash)
|
||||
self._initialized_pipelines: Dict[
|
||||
Tuple[Type[BaseExtractionPipeline], str], BaseExtractionPipeline
|
||||
] = {}
|
||||
|
||||
# ---------------------------- Public API ---------------------------------
|
||||
|
||||
@validate_call(config=ConfigDict(strict=True))
|
||||
def extract(
|
||||
self,
|
||||
source: Union[Path, str, DocumentStream],
|
||||
template: ExtractionTemplateType,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
raises_on_error: bool = True,
|
||||
max_num_pages: int = sys.maxsize,
|
||||
max_file_size: int = sys.maxsize,
|
||||
page_range: PageRange = DEFAULT_PAGE_RANGE,
|
||||
) -> ExtractionResult:
|
||||
all_res = self.extract_all(
|
||||
source=[source],
|
||||
headers=headers,
|
||||
raises_on_error=raises_on_error,
|
||||
max_num_pages=max_num_pages,
|
||||
max_file_size=max_file_size,
|
||||
page_range=page_range,
|
||||
template=template,
|
||||
)
|
||||
return next(all_res)
|
||||
|
||||
@validate_call(config=ConfigDict(strict=True))
|
||||
def extract_all(
|
||||
self,
|
||||
source: Iterable[Union[Path, str, DocumentStream]],
|
||||
template: ExtractionTemplateType,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
raises_on_error: bool = True,
|
||||
max_num_pages: int = sys.maxsize,
|
||||
max_file_size: int = sys.maxsize,
|
||||
page_range: PageRange = DEFAULT_PAGE_RANGE,
|
||||
) -> Iterator[ExtractionResult]:
|
||||
warnings.warn(
|
||||
"The extract API is currently experimental and may change without prior notice.\n"
|
||||
"Only PDF and image formats are supported.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
limits = DocumentLimits(
|
||||
max_num_pages=max_num_pages,
|
||||
max_file_size=max_file_size,
|
||||
page_range=page_range,
|
||||
)
|
||||
conv_input = _DocumentConversionInput(
|
||||
path_or_stream_iterator=source, limits=limits, headers=headers
|
||||
)
|
||||
|
||||
ext_res_iter = self._extract(
|
||||
conv_input, raises_on_error=raises_on_error, template=template
|
||||
)
|
||||
|
||||
had_result = False
|
||||
for ext_res in ext_res_iter:
|
||||
had_result = True
|
||||
if raises_on_error and ext_res.status not in {
|
||||
ConversionStatus.SUCCESS,
|
||||
ConversionStatus.PARTIAL_SUCCESS,
|
||||
}:
|
||||
raise ConversionError(
|
||||
f"Extraction failed for: {ext_res.input.file} with status: {ext_res.status}"
|
||||
)
|
||||
else:
|
||||
yield ext_res
|
||||
|
||||
if not had_result and raises_on_error:
|
||||
raise ConversionError(
|
||||
"Extraction failed because the provided file has no recognizable format or it wasn't in the list of allowed formats."
|
||||
)
|
||||
|
||||
# --------------------------- Internal engine ------------------------------
|
||||
|
||||
def _extract(
|
||||
self,
|
||||
conv_input: _DocumentConversionInput,
|
||||
raises_on_error: bool,
|
||||
template: ExtractionTemplateType,
|
||||
) -> Iterator[ExtractionResult]:
|
||||
start_time = time.monotonic()
|
||||
|
||||
for input_batch in chunkify(
|
||||
conv_input.docs(self.extraction_format_to_options),
|
||||
settings.perf.doc_batch_size,
|
||||
):
|
||||
_log.info("Going to extract document batch...")
|
||||
process_func = partial(
|
||||
self._process_document_extraction,
|
||||
raises_on_error=raises_on_error,
|
||||
template=template,
|
||||
)
|
||||
|
||||
if (
|
||||
settings.perf.doc_batch_concurrency > 1
|
||||
and settings.perf.doc_batch_size > 1
|
||||
):
|
||||
with ThreadPoolExecutor(
|
||||
max_workers=settings.perf.doc_batch_concurrency
|
||||
) as pool:
|
||||
for item in pool.map(
|
||||
process_func,
|
||||
input_batch,
|
||||
):
|
||||
yield item
|
||||
else:
|
||||
for item in map(
|
||||
process_func,
|
||||
input_batch,
|
||||
):
|
||||
elapsed = time.monotonic() - start_time
|
||||
start_time = time.monotonic()
|
||||
_log.info(
|
||||
f"Finished extracting document {item.input.file.name} in {elapsed:.2f} sec."
|
||||
)
|
||||
yield item
|
||||
|
||||
def _process_document_extraction(
|
||||
self,
|
||||
in_doc: InputDocument,
|
||||
raises_on_error: bool,
|
||||
template: ExtractionTemplateType,
|
||||
) -> ExtractionResult:
|
||||
valid = (
|
||||
self.allowed_formats is not None and in_doc.format in self.allowed_formats
|
||||
)
|
||||
if valid:
|
||||
return self._execute_extraction_pipeline(
|
||||
in_doc, raises_on_error=raises_on_error, template=template
|
||||
)
|
||||
else:
|
||||
error_message = f"File format not allowed: {in_doc.file}"
|
||||
if raises_on_error:
|
||||
raise ConversionError(error_message)
|
||||
else:
|
||||
error_item = ErrorItem(
|
||||
component_type=DoclingComponentType.USER_INPUT,
|
||||
module_name="",
|
||||
error_message=error_message,
|
||||
)
|
||||
return ExtractionResult(
|
||||
input=in_doc, status=ConversionStatus.SKIPPED, errors=[error_item]
|
||||
)
|
||||
|
||||
def _execute_extraction_pipeline(
|
||||
self,
|
||||
in_doc: InputDocument,
|
||||
raises_on_error: bool,
|
||||
template: ExtractionTemplateType,
|
||||
) -> ExtractionResult:
|
||||
if not in_doc.valid:
|
||||
if raises_on_error:
|
||||
raise ConversionError(f"Input document {in_doc.file} is not valid.")
|
||||
else:
|
||||
return ExtractionResult(input=in_doc, status=ConversionStatus.FAILURE)
|
||||
|
||||
pipeline = self._get_pipeline(in_doc.format)
|
||||
if pipeline is None:
|
||||
if raises_on_error:
|
||||
raise ConversionError(
|
||||
f"No extraction pipeline could be initialized for {in_doc.file}."
|
||||
)
|
||||
else:
|
||||
return ExtractionResult(input=in_doc, status=ConversionStatus.FAILURE)
|
||||
|
||||
return pipeline.execute(
|
||||
in_doc, raises_on_error=raises_on_error, template=template
|
||||
)
|
||||
|
||||
def _get_pipeline(
|
||||
self, doc_format: InputFormat
|
||||
) -> Optional[BaseExtractionPipeline]:
|
||||
"""Retrieve or initialize a pipeline, reusing instances based on class and options."""
|
||||
fopt = self.extraction_format_to_options.get(doc_format)
|
||||
if fopt is None or fopt.pipeline_options is None:
|
||||
return None
|
||||
|
||||
pipeline_class = fopt.pipeline_cls
|
||||
pipeline_options = fopt.pipeline_options
|
||||
options_hash = self._get_pipeline_options_hash(pipeline_options)
|
||||
|
||||
cache_key = (pipeline_class, options_hash)
|
||||
with _PIPELINE_CACHE_LOCK:
|
||||
if cache_key not in self._initialized_pipelines:
|
||||
_log.info(
|
||||
f"Initializing extraction pipeline for {pipeline_class.__name__} with options hash {options_hash}"
|
||||
)
|
||||
self._initialized_pipelines[cache_key] = pipeline_class(
|
||||
pipeline_options=pipeline_options # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
_log.debug(
|
||||
f"Reusing cached extraction pipeline for {pipeline_class.__name__} with options hash {options_hash}"
|
||||
)
|
||||
|
||||
return self._initialized_pipelines[cache_key]
|
||||
|
||||
@staticmethod
|
||||
def _get_pipeline_options_hash(pipeline_options: PipelineOptions) -> str:
|
||||
"""Generate a stable hash of pipeline options to use as part of the cache key."""
|
||||
options_str = str(pipeline_options.model_dump())
|
||||
return hashlib.md5(
|
||||
options_str.encode("utf-8"), usedforsecurity=False
|
||||
).hexdigest()
|
||||
290
docling/models/vlm_models_inline/nuextract_transformers_model.py
Normal file
290
docling/models/vlm_models_inline/nuextract_transformers_model.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
|
||||
|
||||
from docling.datamodel.accelerator_options import (
|
||||
AcceleratorOptions,
|
||||
)
|
||||
from docling.datamodel.base_models import VlmPrediction
|
||||
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
|
||||
from docling.models.base_model import BaseVlmModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Source code from https://huggingface.co/numind/NuExtract-2.0-8B
|
||||
def process_all_vision_info(messages, examples=None):
|
||||
"""
|
||||
Process vision information from both messages and in-context examples, supporting batch processing.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries (single input) OR list of message lists (batch input)
|
||||
examples: Optional list of example dictionaries (single input) OR list of example lists (batch)
|
||||
|
||||
Returns:
|
||||
A flat list of all images in the correct order:
|
||||
- For single input: example images followed by message images
|
||||
- For batch input: interleaved as (item1 examples, item1 input, item2 examples, item2 input, etc.)
|
||||
- Returns None if no images were found
|
||||
"""
|
||||
try:
|
||||
from qwen_vl_utils import fetch_image, process_vision_info
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"qwen-vl-utils is required for NuExtractTransformersModel. "
|
||||
"Please install it with: pip install qwen-vl-utils"
|
||||
)
|
||||
|
||||
from qwen_vl_utils import fetch_image, process_vision_info
|
||||
|
||||
# Helper function to extract images from examples
|
||||
def extract_example_images(example_item):
|
||||
if not example_item:
|
||||
return []
|
||||
|
||||
# Handle both list of examples and single example
|
||||
examples_to_process = (
|
||||
example_item if isinstance(example_item, list) else [example_item]
|
||||
)
|
||||
images = []
|
||||
|
||||
for example in examples_to_process:
|
||||
if (
|
||||
isinstance(example.get("input"), dict)
|
||||
and example["input"].get("type") == "image"
|
||||
):
|
||||
images.append(fetch_image(example["input"]))
|
||||
|
||||
return images
|
||||
|
||||
# Normalize inputs to always be batched format
|
||||
is_batch = messages and isinstance(messages[0], list)
|
||||
messages_batch = messages if is_batch else [messages]
|
||||
is_batch_examples = (
|
||||
examples
|
||||
and isinstance(examples, list)
|
||||
and (isinstance(examples[0], list) or examples[0] is None)
|
||||
)
|
||||
examples_batch = (
|
||||
examples
|
||||
if is_batch_examples
|
||||
else ([examples] if examples is not None else None)
|
||||
)
|
||||
|
||||
# Ensure examples batch matches messages batch if provided
|
||||
if examples and len(examples_batch) != len(messages_batch):
|
||||
if not is_batch and len(examples_batch) == 1:
|
||||
# Single example set for a single input is fine
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Examples batch length must match messages batch length")
|
||||
|
||||
# Process all inputs, maintaining correct order
|
||||
all_images = []
|
||||
for i, message_group in enumerate(messages_batch):
|
||||
# Get example images for this input
|
||||
if examples and i < len(examples_batch):
|
||||
input_example_images = extract_example_images(examples_batch[i])
|
||||
all_images.extend(input_example_images)
|
||||
|
||||
# Get message images for this input
|
||||
input_message_images = process_vision_info(message_group)[0] or []
|
||||
all_images.extend(input_message_images)
|
||||
|
||||
return all_images if all_images else None
|
||||
|
||||
|
||||
class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Path],
|
||||
accelerator_options: AcceleratorOptions,
|
||||
vlm_options: InlineVlmOptions,
|
||||
):
|
||||
self.enabled = enabled
|
||||
self.vlm_options = vlm_options
|
||||
|
||||
if self.enabled:
|
||||
import torch
|
||||
|
||||
self.device = decide_device(
|
||||
accelerator_options.device,
|
||||
supported_devices=vlm_options.supported_devices,
|
||||
)
|
||||
_log.debug(f"Available device for NuExtract VLM: {self.device}")
|
||||
|
||||
self.max_new_tokens = vlm_options.max_new_tokens
|
||||
self.temperature = vlm_options.temperature
|
||||
|
||||
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models(self.vlm_options.repo_id)
|
||||
elif (artifacts_path / repo_cache_folder).exists():
|
||||
artifacts_path = artifacts_path / repo_cache_folder
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
artifacts_path,
|
||||
trust_remote_code=vlm_options.trust_remote_code,
|
||||
use_fast=True,
|
||||
)
|
||||
self.processor.tokenizer.padding_side = "left"
|
||||
|
||||
self.vlm_model = AutoModelForImageTextToText.from_pretrained(
|
||||
artifacts_path,
|
||||
device_map=self.device,
|
||||
torch_dtype=self.vlm_options.torch_dtype,
|
||||
_attn_implementation=(
|
||||
"flash_attention_2"
|
||||
if self.device.startswith("cuda")
|
||||
and accelerator_options.cuda_use_flash_attention2
|
||||
else "sdpa"
|
||||
),
|
||||
trust_remote_code=vlm_options.trust_remote_code,
|
||||
)
|
||||
self.vlm_model = torch.compile(self.vlm_model) # type: ignore
|
||||
|
||||
# Load generation config
|
||||
self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
|
||||
|
||||
def process_images(
|
||||
self,
|
||||
image_batch: Iterable[Union[Image, np.ndarray]],
|
||||
prompt: Union[str, list[str]],
|
||||
) -> Iterable[VlmPrediction]:
|
||||
"""
|
||||
Batched inference for NuExtract VLM using the specialized input format.
|
||||
|
||||
Args:
|
||||
image_batch: Iterable of PIL Images or numpy arrays
|
||||
prompt: Either:
|
||||
- str: Single template used for all images
|
||||
- list[str]: List of templates (one per image, must match image count)
|
||||
"""
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Normalize images to RGB PIL
|
||||
pil_images: list[Image] = []
|
||||
for img in image_batch:
|
||||
if isinstance(img, np.ndarray):
|
||||
if img.ndim == 3 and img.shape[2] in (3, 4):
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8))
|
||||
elif img.ndim == 2:
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
|
||||
else:
|
||||
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
|
||||
else:
|
||||
pil_img = img
|
||||
if pil_img.mode != "RGB":
|
||||
pil_img = pil_img.convert("RGB")
|
||||
pil_images.append(pil_img)
|
||||
|
||||
if not pil_images:
|
||||
return
|
||||
|
||||
# Normalize templates (1 per image)
|
||||
if isinstance(prompt, str):
|
||||
templates = [prompt] * len(pil_images)
|
||||
else:
|
||||
if len(prompt) != len(pil_images):
|
||||
raise ValueError(
|
||||
f"Number of templates ({len(prompt)}) must match number of images ({len(pil_images)})"
|
||||
)
|
||||
templates = prompt
|
||||
|
||||
# Construct NuExtract input format
|
||||
inputs = []
|
||||
for pil_img, template in zip(pil_images, templates):
|
||||
input_item = {
|
||||
"document": {"type": "image", "image": pil_img},
|
||||
"template": template,
|
||||
}
|
||||
inputs.append(input_item)
|
||||
|
||||
# Create messages structure for batch processing
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [x["document"]],
|
||||
}
|
||||
]
|
||||
for x in inputs
|
||||
]
|
||||
|
||||
# Apply chat template to each example individually
|
||||
texts = [
|
||||
self.processor.tokenizer.apply_chat_template(
|
||||
messages[i],
|
||||
template=x["template"],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
for i, x in enumerate(inputs)
|
||||
]
|
||||
|
||||
# Process vision inputs using qwen-vl-utils
|
||||
image_inputs = process_all_vision_info(messages)
|
||||
|
||||
# Process with the processor
|
||||
processor_inputs = self.processor(
|
||||
text=texts,
|
||||
images=image_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
**self.vlm_options.extra_processor_kwargs,
|
||||
)
|
||||
processor_inputs = {k: v.to(self.device) for k, v in processor_inputs.items()}
|
||||
|
||||
# Generate
|
||||
gen_kwargs = {
|
||||
**processor_inputs,
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"generation_config": self.generation_config,
|
||||
**self.vlm_options.extra_generation_config,
|
||||
}
|
||||
if self.temperature > 0:
|
||||
gen_kwargs["do_sample"] = True
|
||||
gen_kwargs["temperature"] = self.temperature
|
||||
else:
|
||||
gen_kwargs["do_sample"] = False
|
||||
|
||||
start_time = time.time()
|
||||
with torch.inference_mode():
|
||||
generated_ids = self.vlm_model.generate(**gen_kwargs)
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
# Trim generated sequences
|
||||
input_len = processor_inputs["input_ids"].shape[1]
|
||||
trimmed_sequences = generated_ids[:, input_len:]
|
||||
|
||||
# Decode with the processor/tokenizer
|
||||
decoded_texts: list[str] = self.processor.batch_decode(
|
||||
trimmed_sequences,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
|
||||
# Optional logging
|
||||
if generated_ids.shape[0] > 0: # type: ignore
|
||||
_log.debug(
|
||||
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
|
||||
f"for batch size {generated_ids.shape[0]}." # type: ignore
|
||||
)
|
||||
|
||||
for text in decoded_texts:
|
||||
# Apply decode_response to the output text
|
||||
decoded_text = self.vlm_options.decode_response(text)
|
||||
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
|
||||
58
docling/pipeline/base_extraction_pipeline.py
Normal file
58
docling/pipeline/base_extraction_pipeline.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from docling.datamodel.base_models import ConversionStatus, ErrorItem
|
||||
from docling.datamodel.document import InputDocument
|
||||
from docling.datamodel.extraction import ExtractionResult, ExtractionTemplateType
|
||||
from docling.datamodel.pipeline_options import BaseOptions
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseExtractionPipeline(ABC):
|
||||
def __init__(self, pipeline_options: BaseOptions):
|
||||
self.pipeline_options = pipeline_options
|
||||
|
||||
def execute(
|
||||
self,
|
||||
in_doc: InputDocument,
|
||||
raises_on_error: bool,
|
||||
template: Optional[ExtractionTemplateType] = None,
|
||||
) -> ExtractionResult:
|
||||
ext_res = ExtractionResult(input=in_doc)
|
||||
|
||||
try:
|
||||
ext_res = self._extract_data(ext_res, template)
|
||||
ext_res.status = self._determine_status(ext_res)
|
||||
except Exception as e:
|
||||
ext_res.status = ConversionStatus.FAILURE
|
||||
error_item = ErrorItem(
|
||||
component_type="extraction_pipeline",
|
||||
module_name=self.__class__.__name__,
|
||||
error_message=str(e),
|
||||
)
|
||||
ext_res.errors.append(error_item)
|
||||
if raises_on_error:
|
||||
raise e
|
||||
|
||||
return ext_res
|
||||
|
||||
@abstractmethod
|
||||
def _extract_data(
|
||||
self,
|
||||
ext_res: ExtractionResult,
|
||||
template: Optional[ExtractionTemplateType] = None,
|
||||
) -> ExtractionResult:
|
||||
"""Subclass must populate ext_res.pages/errors and return the result."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
|
||||
"""Subclass must decide SUCCESS/PARTIAL_SUCCESS/FAILURE based on ext_res."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_default_options(cls) -> BaseOptions:
|
||||
pass
|
||||
204
docling/pipeline/extraction_vlm_pipeline.py
Normal file
204
docling/pipeline/extraction_vlm_pipeline.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.backend.abstract_backend import PaginatedDocumentBackend
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
from docling.datamodel.base_models import ConversionStatus, ErrorItem
|
||||
from docling.datamodel.document import InputDocument
|
||||
from docling.datamodel.extraction import (
|
||||
ExtractedPageData,
|
||||
ExtractionResult,
|
||||
ExtractionTemplateType,
|
||||
)
|
||||
from docling.datamodel.pipeline_options import BaseOptions, VlmExtractionPipelineOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.vlm_models_inline.nuextract_transformers_model import (
|
||||
NuExtractTransformersModel,
|
||||
)
|
||||
from docling.pipeline.base_extraction_pipeline import BaseExtractionPipeline
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtractionVlmPipeline(BaseExtractionPipeline):
|
||||
def __init__(self, pipeline_options: VlmExtractionPipelineOptions):
|
||||
super().__init__(pipeline_options)
|
||||
|
||||
# Initialize VLM model with default options
|
||||
self.accelerator_options = pipeline_options.accelerator_options
|
||||
self.pipeline_options: VlmExtractionPipelineOptions
|
||||
|
||||
artifacts_path: Optional[Path] = None
|
||||
if pipeline_options.artifacts_path is not None:
|
||||
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
|
||||
elif settings.artifacts_path is not None:
|
||||
artifacts_path = Path(settings.artifacts_path).expanduser()
|
||||
|
||||
if artifacts_path is not None and not artifacts_path.is_dir():
|
||||
raise RuntimeError(
|
||||
f"The value of {artifacts_path=} is not valid. "
|
||||
"When defined, it must point to a folder containing all models required by the pipeline."
|
||||
)
|
||||
|
||||
# Create VLM model instance
|
||||
self.vlm_model = NuExtractTransformersModel(
|
||||
enabled=True,
|
||||
artifacts_path=artifacts_path, # Will download automatically
|
||||
accelerator_options=self.accelerator_options,
|
||||
vlm_options=pipeline_options.vlm_options,
|
||||
)
|
||||
|
||||
def _extract_data(
|
||||
self,
|
||||
ext_res: ExtractionResult,
|
||||
template: Optional[ExtractionTemplateType] = None,
|
||||
) -> ExtractionResult:
|
||||
"""Extract data using the VLM model."""
|
||||
try:
|
||||
# Get images from input document using the backend
|
||||
images = self._get_images_from_input(ext_res.input)
|
||||
if not images:
|
||||
ext_res.status = ConversionStatus.FAILURE
|
||||
ext_res.errors.append(
|
||||
ErrorItem(
|
||||
component_type="extraction_pipeline",
|
||||
module_name=self.__class__.__name__,
|
||||
error_message="No images found in document",
|
||||
)
|
||||
)
|
||||
return ext_res
|
||||
|
||||
# Use provided template or default prompt
|
||||
if template is not None:
|
||||
prompt = self._serialize_template(template)
|
||||
else:
|
||||
prompt = "Extract all text and structured information from this document. Return as JSON."
|
||||
|
||||
# Process all images with VLM model
|
||||
start_page, end_page = ext_res.input.limits.page_range
|
||||
for i, image in enumerate(images):
|
||||
# Calculate the actual page number based on the filtered range
|
||||
page_number = start_page + i
|
||||
try:
|
||||
predictions = list(self.vlm_model.process_images([image], prompt))
|
||||
|
||||
if predictions:
|
||||
# Parse the extracted text as JSON if possible, otherwise use as-is
|
||||
extracted_text = predictions[0].text
|
||||
extracted_data = None
|
||||
|
||||
try:
|
||||
extracted_data = json.loads(extracted_text)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# If not valid JSON, keep extracted_data as None
|
||||
pass
|
||||
|
||||
# Create page data with proper structure
|
||||
page_data = ExtractedPageData(
|
||||
page_no=page_number,
|
||||
extracted_data=extracted_data,
|
||||
raw_text=extracted_text, # Always populate raw_text
|
||||
)
|
||||
ext_res.pages.append(page_data)
|
||||
else:
|
||||
# Add error page data
|
||||
page_data = ExtractedPageData(
|
||||
page_no=page_number,
|
||||
extracted_data=None,
|
||||
errors=["No extraction result from VLM model"],
|
||||
)
|
||||
ext_res.pages.append(page_data)
|
||||
|
||||
except Exception as e:
|
||||
_log.error(f"Error processing page {page_number}: {e}")
|
||||
page_data = ExtractedPageData(
|
||||
page_no=page_number, extracted_data=None, errors=[str(e)]
|
||||
)
|
||||
ext_res.pages.append(page_data)
|
||||
|
||||
except Exception as e:
|
||||
_log.error(f"Error during extraction: {e}")
|
||||
ext_res.errors.append(
|
||||
ErrorItem(
|
||||
component_type="extraction_pipeline",
|
||||
module_name=self.__class__.__name__,
|
||||
error_message=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
return ext_res
|
||||
|
||||
def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
|
||||
"""Determine the status based on extraction results."""
|
||||
if ext_res.pages and not any(page.errors for page in ext_res.pages):
|
||||
return ConversionStatus.SUCCESS
|
||||
else:
|
||||
return ConversionStatus.FAILURE
|
||||
|
||||
def _get_images_from_input(self, input_doc: InputDocument) -> list[Image]:
|
||||
"""Extract images from input document using the backend."""
|
||||
images = []
|
||||
|
||||
try:
|
||||
backend = input_doc._backend
|
||||
|
||||
assert isinstance(backend, PdfDocumentBackend)
|
||||
# Use the backend's pagination interface
|
||||
page_count = backend.page_count()
|
||||
|
||||
# Respect page range limits, following the same pattern as PaginatedPipeline
|
||||
start_page, end_page = input_doc.limits.page_range
|
||||
_log.info(
|
||||
f"Processing pages {start_page}-{end_page} of {page_count} total pages for extraction"
|
||||
)
|
||||
|
||||
for page_num in range(page_count):
|
||||
# Only process pages within the specified range (0-based indexing)
|
||||
if start_page - 1 <= page_num <= end_page - 1:
|
||||
try:
|
||||
page_backend = backend.load_page(page_num)
|
||||
if page_backend.is_valid():
|
||||
# Get page image at a reasonable scale
|
||||
page_image = page_backend.get_page_image(
|
||||
scale=self.pipeline_options.vlm_options.scale
|
||||
)
|
||||
images.append(page_image)
|
||||
else:
|
||||
_log.warning(f"Page {page_num + 1} backend is not valid")
|
||||
except Exception as e:
|
||||
_log.error(f"Error loading page {page_num + 1}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
_log.error(f"Error getting images from input document: {e}")
|
||||
|
||||
return images
|
||||
|
||||
def _serialize_template(self, template: ExtractionTemplateType) -> str:
|
||||
"""Serialize template to string based on its type."""
|
||||
if isinstance(template, str):
|
||||
return template
|
||||
elif isinstance(template, dict):
|
||||
return json.dumps(template, indent=2)
|
||||
elif isinstance(template, BaseModel):
|
||||
return template.model_dump_json(indent=2)
|
||||
elif inspect.isclass(template) and issubclass(template, BaseModel):
|
||||
from polyfactory.factories.pydantic_factory import ModelFactory
|
||||
|
||||
class ExtractionTemplateFactory(ModelFactory[template]): # type: ignore
|
||||
__use_examples__ = True # prefer Field(examples=...) when present
|
||||
__use_defaults__ = True # use field defaults instead of random values
|
||||
|
||||
return ExtractionTemplateFactory.build().model_dump_json(indent=2) # type: ignore
|
||||
else:
|
||||
raise ValueError(f"Unsupported template type: {type(template)}")
|
||||
|
||||
@classmethod
|
||||
def get_default_options(cls) -> BaseOptions:
|
||||
return VlmExtractionPipelineOptions()
|
||||
Reference in New Issue
Block a user