Mirror of https://github.com/DS4SD/docling.git (synced 2025-12-08 12:48:28 +00:00)
feat: [Beta] Extraction with schema (#2138)
* Add DocumentConverter.extract and full extraction pipeline
* Add DocumentConverter.extract template arg
* Add NuExtract model
* Add Extraction pipeline
* Add proper test, support pydantic class types
* Add qr bill example
* Add base_extraction_pipeline
* Add types
* Update typing of ExtractionResult and inner fields
* Factor out extract to DocumentExtractor
* Address mypy issues
* Add DocumentExtractor
* Resolve circular import issue
* Clean up imports, remove Optional for template arg
* Move new type definitions into datamodel
* Update comments
* Respect page-range, disable test_extraction for CI

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
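A minimal usage sketch of the new (beta) extraction API, mirroring the test added in this PR; the `Invoice` template model and the file path are illustrative, and the template can equally be a JSON string, a dict, or a pydantic instance. Only PDF and image inputs are supported for now.

    from pydantic import BaseModel, Field

    from docling.datamodel.base_models import InputFormat
    from docling.document_extractor import DocumentExtractor


    class Invoice(BaseModel):  # illustrative template model
        bill_no: str = Field(examples=["A123"])
        total: float = Field(default=10.0, examples=[20.0])


    extractor = DocumentExtractor(allowed_formats=[InputFormat.IMAGE, InputFormat.PDF])
    # extract() takes a path, a string, or a DocumentStream and returns an ExtractionResult
    result = extractor.extract("tests/data_scanned/qr_bill_example.jpg", template=Invoice)
    for page in result.pages:
        print(page.page_no, page.extracted_data)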
docling/datamodel/base_models.py
@@ -1,7 +1,7 @@
 import math
 from collections import defaultdict
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

 import numpy as np
 from docling_core.types.doc import (
@@ -32,6 +32,18 @@ from pydantic import (
 if TYPE_CHECKING:
     from docling.backend.pdf_backend import PdfPageBackend

+from docling.backend.abstract_backend import AbstractDocumentBackend
+from docling.datamodel.pipeline_options import PipelineOptions
+
+
+class BaseFormatOption(BaseModel):
+    """Base class for format options used by _DocumentConversionInput."""
+
+    pipeline_options: Optional[PipelineOptions] = None
+    backend: Type[AbstractDocumentBackend]
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+

 class ConversionStatus(str, Enum):
     PENDING = "pending"
docling/datamodel/document.py
@@ -2,12 +2,13 @@ import csv
 import logging
 import re
 import tarfile
-from collections.abc import Iterable
+from collections.abc import Iterable, Mapping
 from enum import Enum
 from io import BytesIO
 from pathlib import Path, PurePath
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     List,
     Literal,
@@ -72,7 +73,7 @@ from docling.utils.profiling import ProfilingItem
 from docling.utils.utils import create_file_hash

 if TYPE_CHECKING:
-    from docling.document_converter import FormatOption
+    from docling.datamodel.base_models import BaseFormatOption

 _log = logging.getLogger(__name__)

@@ -238,7 +239,8 @@ class _DocumentConversionInput(BaseModel):
     limits: Optional[DocumentLimits] = DocumentLimits()

     def docs(
-        self, format_options: Dict[InputFormat, "FormatOption"]
+        self,
+        format_options: Mapping[InputFormat, "BaseFormatOption"],
     ) -> Iterable[InputDocument]:
         for item in self.path_or_stream_iterator:
             obj = (
docling/datamodel/extraction.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Data models for document extraction functionality."""

from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field

from docling.datamodel.base_models import ConversionStatus, ErrorItem
from docling.datamodel.document import InputDocument


class ExtractedPageData(BaseModel):
    """Data model for extracted content from a single page."""

    page_no: int = Field(..., description="1-indexed page number")
    extracted_data: Optional[Dict[str, Any]] = Field(
        None, description="Extracted structured data from the page"
    )
    raw_text: Optional[str] = Field(None, description="Raw extracted text")
    errors: List[str] = Field(
        default_factory=list,
        description="Any errors encountered during extraction for this page",
    )


class ExtractionResult(BaseModel):
    """Result of document extraction."""

    input: InputDocument
    status: ConversionStatus = ConversionStatus.PENDING
    errors: List[ErrorItem] = []

    # Pages field - always a list for consistency
    pages: List[ExtractedPageData] = Field(
        default_factory=list, description="Extracted data from each page"
    )


# Type alias for template parameters that can be string, dict, or BaseModel
ExtractionTemplateType = Union[str, Dict[str, Any], BaseModel, Type[BaseModel]]
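`ExtractionTemplateType` accepts four shapes. A small sketch of equivalent templates (values are illustrative, taken from the test added in this PR), any of which can be passed as `template=` to `DocumentExtractor.extract`:

    from pydantic import BaseModel, Field

    str_template = '{"bill_no": "string", "total": "number"}'  # JSON string
    dict_template = {"bill_no": "string", "total": "number"}   # dict


    class BillTemplate(BaseModel):  # pydantic class
        bill_no: str = Field(examples=["A123"])
        total: float = 10.0


    instance_template = BillTemplate(bill_no="4321")  # pydantic instance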
docling/datamodel/pipeline_options.py
@@ -37,6 +37,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
 from docling.datamodel.vlm_model_specs import (
     GRANITE_VISION_OLLAMA as granite_vision_vlm_ollama_conversion_options,
     GRANITE_VISION_TRANSFORMERS as granite_vision_vlm_conversion_options,
+    NU_EXTRACT_2B_TRANSFORMERS,
     SMOLDOCLING_MLX as smoldocling_vlm_mlx_conversion_options,
     SMOLDOCLING_TRANSFORMERS as smoldocling_vlm_conversion_options,
     VlmModelType,
@@ -247,12 +248,9 @@ class OcrEngine(str, Enum):
     RAPIDOCR = "rapidocr"


-class PipelineOptions(BaseModel):
+class PipelineOptions(BaseOptions):
     """Base pipeline options."""

-    create_legacy_output: bool = (
-        True  # This default will be set to False on a future version of docling
-    )
     document_timeout: Optional[float] = None
     accelerator_options: AcceleratorOptions = AcceleratorOptions()
     enable_remote_services: bool = False
@@ -296,6 +294,13 @@ class AsrPipelineOptions(PipelineOptions):
     artifacts_path: Optional[Union[Path, str]] = None


+class VlmExtractionPipelineOptions(PipelineOptions):
+    """Options for extraction pipeline."""
+
+    artifacts_path: Optional[Union[Path, str]] = None
+    vlm_options: Union[InlineVlmOptions] = NU_EXTRACT_2B_TRANSFORMERS
+
+
 class PdfPipelineOptions(PaginatedPipelineOptions):
     """Options for the PDF pipeline."""
docling/datamodel/vlm_model_specs.py
@@ -247,6 +247,23 @@ DOLPHIN_TRANSFORMERS = InlineVlmOptions(
     temperature=0.0,
 )

+# NuExtract
+NU_EXTRACT_2B_TRANSFORMERS = InlineVlmOptions(
+    repo_id="numind/NuExtract-2.0-2B",
+    prompt="",  # This won't be used, template is passed separately
+    torch_dtype="bfloat16",
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    response_format=ResponseFormat.PLAINTEXT,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+

 class VlmModelType(str, Enum):
     SMOLDOCLING = "smoldocling"
docling/document_converter.py
@@ -28,6 +28,7 @@ from docling.backend.noop_backend import NoOpBackend
 from docling.backend.xml.jats_backend import JatsDocumentBackend
 from docling.backend.xml.uspto_backend import PatentUsptoDocumentBackend
 from docling.datamodel.base_models import (
+    BaseFormatOption,
     ConversionStatus,
     DoclingComponentType,
     DocumentStream,
@@ -57,12 +58,8 @@ _log = logging.getLogger(__name__)
 _PIPELINE_CACHE_LOCK = threading.Lock()


-class FormatOption(BaseModel):
+class FormatOption(BaseFormatOption):
     pipeline_cls: Type[BasePipeline]
-    pipeline_options: Optional[PipelineOptions] = None
-    backend: Type[AbstractDocumentBackend]
-
-    model_config = ConfigDict(arbitrary_types_allowed=True)

     @model_validator(mode="after")
     def set_optional_field_default(self) -> "FormatOption":
@@ -191,7 +188,7 @@ class DocumentConverter:
         self.allowed_formats = (
             allowed_formats if allowed_formats is not None else list(InputFormat)
         )
-        self.format_to_options = {
+        self.format_to_options: Dict[InputFormat, FormatOption] = {
             format: (
                 _get_default_option(format=format)
                 if (custom_option := (format_options or {}).get(format)) is None
docling/document_extractor.py (new file, 325 lines)
@@ -0,0 +1,325 @@
import hashlib
import logging
import sys
import threading
import time
import warnings
from collections.abc import Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union

from pydantic import ConfigDict, model_validator, validate_call

from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.base_models import (
    BaseFormatOption,
    ConversionStatus,
    DoclingComponentType,
    DocumentStream,
    ErrorItem,
    InputFormat,
)
from docling.datamodel.document import (
    InputDocument,
    _DocumentConversionInput,  # intentionally reused builder
)
from docling.datamodel.extraction import ExtractionResult, ExtractionTemplateType
from docling.datamodel.pipeline_options import PipelineOptions
from docling.datamodel.settings import (
    DEFAULT_PAGE_RANGE,
    DocumentLimits,
    PageRange,
    settings,
)
from docling.exceptions import ConversionError
from docling.pipeline.base_extraction_pipeline import BaseExtractionPipeline
from docling.pipeline.extraction_vlm_pipeline import ExtractionVlmPipeline
from docling.utils.utils import chunkify

_log = logging.getLogger(__name__)
_PIPELINE_CACHE_LOCK = threading.Lock()


class ExtractionFormatOption(BaseFormatOption):
    """Per-format configuration for extraction.

    Notes:
    - `pipeline_cls` must subclass `BaseExtractionPipeline`.
    - `pipeline_options` is typed as `PipelineOptions` which MUST inherit from
      `BaseOptions` (as used by `BaseExtractionPipeline`).
    - `backend` is the document-opening backend used by `_DocumentConversionInput`.
    """

    pipeline_cls: Type[BaseExtractionPipeline]

    @model_validator(mode="after")
    def set_optional_field_default(self) -> "ExtractionFormatOption":
        if self.pipeline_options is None:
            # `get_default_options` comes from BaseExtractionPipeline
            self.pipeline_options = self.pipeline_cls.get_default_options()  # type: ignore[assignment]
        return self


def _get_default_extraction_option(fmt: InputFormat) -> ExtractionFormatOption:
    """Return the default extraction option for a given input format.

    Defaults mirror the converter's *backend* choices, while the pipeline is
    the VLM extractor. This duplication will be removed when we deduplicate
    the format registry between convert/extract.
    """
    format_to_default_backend: Dict[InputFormat, Type[AbstractDocumentBackend]] = {
        InputFormat.IMAGE: PyPdfiumDocumentBackend,
        InputFormat.PDF: PyPdfiumDocumentBackend,
    }

    backend = format_to_default_backend.get(fmt)
    if backend is None:
        raise RuntimeError(f"No default extraction backend configured for {fmt}")

    return ExtractionFormatOption(
        pipeline_cls=ExtractionVlmPipeline,
        backend=backend,
    )


class DocumentExtractor:
    """Standalone extractor class.

    Public API:
    - `extract(...) -> ExtractionResult`
    - `extract_all(...) -> Iterator[ExtractionResult]`

    Implementation intentionally reuses `_DocumentConversionInput` to build
    `InputDocument` with the correct backend per format.
    """

    def __init__(
        self,
        allowed_formats: Optional[List[InputFormat]] = None,
        extraction_format_options: Optional[
            Dict[InputFormat, ExtractionFormatOption]
        ] = None,
    ) -> None:
        self.allowed_formats: List[InputFormat] = (
            allowed_formats if allowed_formats is not None else list(InputFormat)
        )
        # Build per-format options with defaults, then apply any user overrides
        overrides = extraction_format_options or {}
        self.extraction_format_to_options: Dict[InputFormat, ExtractionFormatOption] = {
            fmt: overrides.get(fmt, _get_default_extraction_option(fmt))
            for fmt in self.allowed_formats
        }

        # Cache pipelines by (class, options-hash)
        self._initialized_pipelines: Dict[
            Tuple[Type[BaseExtractionPipeline], str], BaseExtractionPipeline
        ] = {}

    # ---------------------------- Public API ---------------------------------

    @validate_call(config=ConfigDict(strict=True))
    def extract(
        self,
        source: Union[Path, str, DocumentStream],
        template: ExtractionTemplateType,
        headers: Optional[Dict[str, str]] = None,
        raises_on_error: bool = True,
        max_num_pages: int = sys.maxsize,
        max_file_size: int = sys.maxsize,
        page_range: PageRange = DEFAULT_PAGE_RANGE,
    ) -> ExtractionResult:
        all_res = self.extract_all(
            source=[source],
            headers=headers,
            raises_on_error=raises_on_error,
            max_num_pages=max_num_pages,
            max_file_size=max_file_size,
            page_range=page_range,
            template=template,
        )
        return next(all_res)

    @validate_call(config=ConfigDict(strict=True))
    def extract_all(
        self,
        source: Iterable[Union[Path, str, DocumentStream]],
        template: ExtractionTemplateType,
        headers: Optional[Dict[str, str]] = None,
        raises_on_error: bool = True,
        max_num_pages: int = sys.maxsize,
        max_file_size: int = sys.maxsize,
        page_range: PageRange = DEFAULT_PAGE_RANGE,
    ) -> Iterator[ExtractionResult]:
        warnings.warn(
            "The extract API is currently experimental and may change without prior notice.\n"
            "Only PDF and image formats are supported.",
            UserWarning,
            stacklevel=2,
        )

        limits = DocumentLimits(
            max_num_pages=max_num_pages,
            max_file_size=max_file_size,
            page_range=page_range,
        )
        conv_input = _DocumentConversionInput(
            path_or_stream_iterator=source, limits=limits, headers=headers
        )

        ext_res_iter = self._extract(
            conv_input, raises_on_error=raises_on_error, template=template
        )

        had_result = False
        for ext_res in ext_res_iter:
            had_result = True
            if raises_on_error and ext_res.status not in {
                ConversionStatus.SUCCESS,
                ConversionStatus.PARTIAL_SUCCESS,
            }:
                raise ConversionError(
                    f"Extraction failed for: {ext_res.input.file} with status: {ext_res.status}"
                )
            else:
                yield ext_res

        if not had_result and raises_on_error:
            raise ConversionError(
                "Extraction failed because the provided file has no recognizable format or it wasn't in the list of allowed formats."
            )

    # --------------------------- Internal engine ------------------------------

    def _extract(
        self,
        conv_input: _DocumentConversionInput,
        raises_on_error: bool,
        template: ExtractionTemplateType,
    ) -> Iterator[ExtractionResult]:
        start_time = time.monotonic()

        for input_batch in chunkify(
            conv_input.docs(self.extraction_format_to_options),
            settings.perf.doc_batch_size,
        ):
            _log.info("Going to extract document batch...")
            process_func = partial(
                self._process_document_extraction,
                raises_on_error=raises_on_error,
                template=template,
            )

            if (
                settings.perf.doc_batch_concurrency > 1
                and settings.perf.doc_batch_size > 1
            ):
                with ThreadPoolExecutor(
                    max_workers=settings.perf.doc_batch_concurrency
                ) as pool:
                    for item in pool.map(
                        process_func,
                        input_batch,
                    ):
                        yield item
            else:
                for item in map(
                    process_func,
                    input_batch,
                ):
                    elapsed = time.monotonic() - start_time
                    start_time = time.monotonic()
                    _log.info(
                        f"Finished extracting document {item.input.file.name} in {elapsed:.2f} sec."
                    )
                    yield item

    def _process_document_extraction(
        self,
        in_doc: InputDocument,
        raises_on_error: bool,
        template: ExtractionTemplateType,
    ) -> ExtractionResult:
        valid = (
            self.allowed_formats is not None and in_doc.format in self.allowed_formats
        )
        if valid:
            return self._execute_extraction_pipeline(
                in_doc, raises_on_error=raises_on_error, template=template
            )
        else:
            error_message = f"File format not allowed: {in_doc.file}"
            if raises_on_error:
                raise ConversionError(error_message)
            else:
                error_item = ErrorItem(
                    component_type=DoclingComponentType.USER_INPUT,
                    module_name="",
                    error_message=error_message,
                )
                return ExtractionResult(
                    input=in_doc, status=ConversionStatus.SKIPPED, errors=[error_item]
                )

    def _execute_extraction_pipeline(
        self,
        in_doc: InputDocument,
        raises_on_error: bool,
        template: ExtractionTemplateType,
    ) -> ExtractionResult:
        if not in_doc.valid:
            if raises_on_error:
                raise ConversionError(f"Input document {in_doc.file} is not valid.")
            else:
                return ExtractionResult(input=in_doc, status=ConversionStatus.FAILURE)

        pipeline = self._get_pipeline(in_doc.format)
        if pipeline is None:
            if raises_on_error:
                raise ConversionError(
                    f"No extraction pipeline could be initialized for {in_doc.file}."
                )
            else:
                return ExtractionResult(input=in_doc, status=ConversionStatus.FAILURE)

        return pipeline.execute(
            in_doc, raises_on_error=raises_on_error, template=template
        )

    def _get_pipeline(
        self, doc_format: InputFormat
    ) -> Optional[BaseExtractionPipeline]:
        """Retrieve or initialize a pipeline, reusing instances based on class and options."""
        fopt = self.extraction_format_to_options.get(doc_format)
        if fopt is None or fopt.pipeline_options is None:
            return None

        pipeline_class = fopt.pipeline_cls
        pipeline_options = fopt.pipeline_options
        options_hash = self._get_pipeline_options_hash(pipeline_options)

        cache_key = (pipeline_class, options_hash)
        with _PIPELINE_CACHE_LOCK:
            if cache_key not in self._initialized_pipelines:
                _log.info(
                    f"Initializing extraction pipeline for {pipeline_class.__name__} with options hash {options_hash}"
                )
                self._initialized_pipelines[cache_key] = pipeline_class(
                    pipeline_options=pipeline_options  # type: ignore[arg-type]
                )
            else:
                _log.debug(
                    f"Reusing cached extraction pipeline for {pipeline_class.__name__} with options hash {options_hash}"
                )

        return self._initialized_pipelines[cache_key]

    @staticmethod
    def _get_pipeline_options_hash(pipeline_options: PipelineOptions) -> str:
        """Generate a stable hash of pipeline options to use as part of the cache key."""
        options_str = str(pipeline_options.model_dump())
        return hashlib.md5(
            options_str.encode("utf-8"), usedforsecurity=False
        ).hexdigest()
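A hedged sketch of overriding the per-format defaults defined above; the values shown (pipeline class, options, backend) are the same ones `_get_default_extraction_option` already uses, so this only illustrates the override mechanism and `extract_all`; the input paths are illustrative.

    from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import VlmExtractionPipelineOptions
    from docling.document_extractor import DocumentExtractor, ExtractionFormatOption
    from docling.pipeline.extraction_vlm_pipeline import ExtractionVlmPipeline

    pdf_option = ExtractionFormatOption(
        pipeline_cls=ExtractionVlmPipeline,
        pipeline_options=VlmExtractionPipelineOptions(),  # e.g. set artifacts_path or vlm_options here
        backend=PyPdfiumDocumentBackend,
    )
    extractor = DocumentExtractor(
        allowed_formats=[InputFormat.PDF],
        extraction_format_options={InputFormat.PDF: pdf_option},
    )
    # extract_all yields one ExtractionResult per source
    for res in extractor.extract_all(
        source=["invoice_1.pdf", "invoice_2.pdf"],
        template='{"bill_no": "string", "total": "number"}',
    ):
        print(res.input.file, res.status)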
docling/models/vlm_models_inline/nuextract_transformers_model.py (new file, 290 lines)
@@ -0,0 +1,290 @@
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional, Union

import numpy as np
from PIL.Image import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig

from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
)
from docling.datamodel.base_models import VlmPrediction
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BaseVlmModel
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device

_log = logging.getLogger(__name__)


# Source code from https://huggingface.co/numind/NuExtract-2.0-8B
def process_all_vision_info(messages, examples=None):
    """
    Process vision information from both messages and in-context examples, supporting batch processing.

    Args:
        messages: List of message dictionaries (single input) OR list of message lists (batch input)
        examples: Optional list of example dictionaries (single input) OR list of example lists (batch)

    Returns:
        A flat list of all images in the correct order:
        - For single input: example images followed by message images
        - For batch input: interleaved as (item1 examples, item1 input, item2 examples, item2 input, etc.)
        - Returns None if no images were found
    """
    try:
        from qwen_vl_utils import fetch_image, process_vision_info
    except ImportError:
        raise ImportError(
            "qwen-vl-utils is required for NuExtractTransformersModel. "
            "Please install it with: pip install qwen-vl-utils"
        )

    from qwen_vl_utils import fetch_image, process_vision_info

    # Helper function to extract images from examples
    def extract_example_images(example_item):
        if not example_item:
            return []

        # Handle both list of examples and single example
        examples_to_process = (
            example_item if isinstance(example_item, list) else [example_item]
        )
        images = []

        for example in examples_to_process:
            if (
                isinstance(example.get("input"), dict)
                and example["input"].get("type") == "image"
            ):
                images.append(fetch_image(example["input"]))

        return images

    # Normalize inputs to always be batched format
    is_batch = messages and isinstance(messages[0], list)
    messages_batch = messages if is_batch else [messages]
    is_batch_examples = (
        examples
        and isinstance(examples, list)
        and (isinstance(examples[0], list) or examples[0] is None)
    )
    examples_batch = (
        examples
        if is_batch_examples
        else ([examples] if examples is not None else None)
    )

    # Ensure examples batch matches messages batch if provided
    if examples and len(examples_batch) != len(messages_batch):
        if not is_batch and len(examples_batch) == 1:
            # Single example set for a single input is fine
            pass
        else:
            raise ValueError("Examples batch length must match messages batch length")

    # Process all inputs, maintaining correct order
    all_images = []
    for i, message_group in enumerate(messages_batch):
        # Get example images for this input
        if examples and i < len(examples_batch):
            input_example_images = extract_example_images(examples_batch[i])
            all_images.extend(input_example_images)

        # Get message images for this input
        input_message_images = process_vision_info(message_group)[0] or []
        all_images.extend(input_message_images)

    return all_images if all_images else None


class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        accelerator_options: AcceleratorOptions,
        vlm_options: InlineVlmOptions,
    ):
        self.enabled = enabled
        self.vlm_options = vlm_options

        if self.enabled:
            import torch

            self.device = decide_device(
                accelerator_options.device,
                supported_devices=vlm_options.supported_devices,
            )
            _log.debug(f"Available device for NuExtract VLM: {self.device}")

            self.max_new_tokens = vlm_options.max_new_tokens
            self.temperature = vlm_options.temperature

            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

            if artifacts_path is None:
                artifacts_path = self.download_models(self.vlm_options.repo_id)
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder

            self.processor = AutoProcessor.from_pretrained(
                artifacts_path,
                trust_remote_code=vlm_options.trust_remote_code,
                use_fast=True,
            )
            self.processor.tokenizer.padding_side = "left"

            self.vlm_model = AutoModelForImageTextToText.from_pretrained(
                artifacts_path,
                device_map=self.device,
                torch_dtype=self.vlm_options.torch_dtype,
                _attn_implementation=(
                    "flash_attention_2"
                    if self.device.startswith("cuda")
                    and accelerator_options.cuda_use_flash_attention2
                    else "sdpa"
                ),
                trust_remote_code=vlm_options.trust_remote_code,
            )
            self.vlm_model = torch.compile(self.vlm_model)  # type: ignore

            # Load generation config
            self.generation_config = GenerationConfig.from_pretrained(artifacts_path)

    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Union[str, list[str]],
    ) -> Iterable[VlmPrediction]:
        """
        Batched inference for NuExtract VLM using the specialized input format.

        Args:
            image_batch: Iterable of PIL Images or numpy arrays
            prompt: Either:
                - str: Single template used for all images
                - list[str]: List of templates (one per image, must match image count)
        """
        import torch
        from PIL import Image as PILImage

        # Normalize images to RGB PIL
        pil_images: list[Image] = []
        for img in image_batch:
            if isinstance(img, np.ndarray):
                if img.ndim == 3 and img.shape[2] in (3, 4):
                    pil_img = PILImage.fromarray(img.astype(np.uint8))
                elif img.ndim == 2:
                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                else:
                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
            else:
                pil_img = img
            if pil_img.mode != "RGB":
                pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

        if not pil_images:
            return

        # Normalize templates (1 per image)
        if isinstance(prompt, str):
            templates = [prompt] * len(pil_images)
        else:
            if len(prompt) != len(pil_images):
                raise ValueError(
                    f"Number of templates ({len(prompt)}) must match number of images ({len(pil_images)})"
                )
            templates = prompt

        # Construct NuExtract input format
        inputs = []
        for pil_img, template in zip(pil_images, templates):
            input_item = {
                "document": {"type": "image", "image": pil_img},
                "template": template,
            }
            inputs.append(input_item)

        # Create messages structure for batch processing
        messages = [
            [
                {
                    "role": "user",
                    "content": [x["document"]],
                }
            ]
            for x in inputs
        ]

        # Apply chat template to each example individually
        texts = [
            self.processor.tokenizer.apply_chat_template(
                messages[i],
                template=x["template"],
                tokenize=False,
                add_generation_prompt=True,
            )
            for i, x in enumerate(inputs)
        ]

        # Process vision inputs using qwen-vl-utils
        image_inputs = process_all_vision_info(messages)

        # Process with the processor
        processor_inputs = self.processor(
            text=texts,
            images=image_inputs,
            padding=True,
            return_tensors="pt",
            **self.vlm_options.extra_processor_kwargs,
        )
        processor_inputs = {k: v.to(self.device) for k, v in processor_inputs.items()}

        # Generate
        gen_kwargs = {
            **processor_inputs,
            "max_new_tokens": self.max_new_tokens,
            "generation_config": self.generation_config,
            **self.vlm_options.extra_generation_config,
        }
        if self.temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = self.temperature
        else:
            gen_kwargs["do_sample"] = False

        start_time = time.time()
        with torch.inference_mode():
            generated_ids = self.vlm_model.generate(**gen_kwargs)
        generation_time = time.time() - start_time

        # Trim generated sequences
        input_len = processor_inputs["input_ids"].shape[1]
        trimmed_sequences = generated_ids[:, input_len:]

        # Decode with the processor/tokenizer
        decoded_texts: list[str] = self.processor.batch_decode(
            trimmed_sequences,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        # Optional logging
        if generated_ids.shape[0] > 0:  # type: ignore
            _log.debug(
                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
                f"for batch size {generated_ids.shape[0]}."  # type: ignore
            )

        for text in decoded_texts:
            # Apply decode_response to the output text
            decoded_text = self.vlm_options.decode_response(text)
            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
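For reference, a rough sketch of how the extraction pipeline drives this model: `process_images` takes page images plus the serialized template as the prompt and yields one `VlmPrediction` per image. The option objects passed below are the defaults wired up elsewhere in this PR; the page image path is illustrative.

    from PIL import Image

    from docling.datamodel.accelerator_options import AcceleratorOptions
    from docling.datamodel.vlm_model_specs import NU_EXTRACT_2B_TRANSFORMERS
    from docling.models.vlm_models_inline.nuextract_transformers_model import (
        NuExtractTransformersModel,
    )

    model = NuExtractTransformersModel(
        enabled=True,
        artifacts_path=None,  # None triggers an automatic model download
        accelerator_options=AcceleratorOptions(),
        vlm_options=NU_EXTRACT_2B_TRANSFORMERS,
    )
    page = Image.open("page.png").convert("RGB")
    template = '{"bill_no": "string", "total": "number"}'
    for prediction in model.process_images([page], template):
        print(prediction.text)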
docling/pipeline/base_extraction_pipeline.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import logging
from abc import ABC, abstractmethod
from typing import Optional

from docling.datamodel.base_models import ConversionStatus, ErrorItem
from docling.datamodel.document import InputDocument
from docling.datamodel.extraction import ExtractionResult, ExtractionTemplateType
from docling.datamodel.pipeline_options import BaseOptions

_log = logging.getLogger(__name__)


class BaseExtractionPipeline(ABC):
    def __init__(self, pipeline_options: BaseOptions):
        self.pipeline_options = pipeline_options

    def execute(
        self,
        in_doc: InputDocument,
        raises_on_error: bool,
        template: Optional[ExtractionTemplateType] = None,
    ) -> ExtractionResult:
        ext_res = ExtractionResult(input=in_doc)

        try:
            ext_res = self._extract_data(ext_res, template)
            ext_res.status = self._determine_status(ext_res)
        except Exception as e:
            ext_res.status = ConversionStatus.FAILURE
            error_item = ErrorItem(
                component_type="extraction_pipeline",
                module_name=self.__class__.__name__,
                error_message=str(e),
            )
            ext_res.errors.append(error_item)
            if raises_on_error:
                raise e

        return ext_res

    @abstractmethod
    def _extract_data(
        self,
        ext_res: ExtractionResult,
        template: Optional[ExtractionTemplateType] = None,
    ) -> ExtractionResult:
        """Subclass must populate ext_res.pages/errors and return the result."""
        raise NotImplementedError

    @abstractmethod
    def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
        """Subclass must decide SUCCESS/PARTIAL_SUCCESS/FAILURE based on ext_res."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def get_default_options(cls) -> BaseOptions:
        pass
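A minimal sketch of what a concrete subclass has to provide under the contract above; the `EchoExtractionPipeline` class is hypothetical and only records the template as raw text.

    from typing import Optional

    from docling.datamodel.base_models import ConversionStatus
    from docling.datamodel.extraction import (
        ExtractedPageData,
        ExtractionResult,
        ExtractionTemplateType,
    )
    from docling.datamodel.pipeline_options import BaseOptions, PipelineOptions
    from docling.pipeline.base_extraction_pipeline import BaseExtractionPipeline


    class EchoExtractionPipeline(BaseExtractionPipeline):
        """Hypothetical pipeline that echoes the template back as raw text."""

        def _extract_data(
            self,
            ext_res: ExtractionResult,
            template: Optional[ExtractionTemplateType] = None,
        ) -> ExtractionResult:
            ext_res.pages.append(ExtractedPageData(page_no=1, raw_text=str(template)))
            return ext_res

        def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
            return ConversionStatus.SUCCESS if ext_res.pages else ConversionStatus.FAILURE

        @classmethod
        def get_default_options(cls) -> BaseOptions:
            return PipelineOptions()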
docling/pipeline/extraction_vlm_pipeline.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import inspect
import json
import logging
from pathlib import Path
from typing import Optional

from PIL.Image import Image
from pydantic import BaseModel

from docling.backend.abstract_backend import PaginatedDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import ConversionStatus, ErrorItem
from docling.datamodel.document import InputDocument
from docling.datamodel.extraction import (
    ExtractedPageData,
    ExtractionResult,
    ExtractionTemplateType,
)
from docling.datamodel.pipeline_options import BaseOptions, VlmExtractionPipelineOptions
from docling.datamodel.settings import settings
from docling.models.vlm_models_inline.nuextract_transformers_model import (
    NuExtractTransformersModel,
)
from docling.pipeline.base_extraction_pipeline import BaseExtractionPipeline
from docling.utils.accelerator_utils import decide_device

_log = logging.getLogger(__name__)


class ExtractionVlmPipeline(BaseExtractionPipeline):
    def __init__(self, pipeline_options: VlmExtractionPipelineOptions):
        super().__init__(pipeline_options)

        # Initialize VLM model with default options
        self.accelerator_options = pipeline_options.accelerator_options
        self.pipeline_options: VlmExtractionPipelineOptions

        artifacts_path: Optional[Path] = None
        if pipeline_options.artifacts_path is not None:
            artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
        elif settings.artifacts_path is not None:
            artifacts_path = Path(settings.artifacts_path).expanduser()

        if artifacts_path is not None and not artifacts_path.is_dir():
            raise RuntimeError(
                f"The value of {artifacts_path=} is not valid. "
                "When defined, it must point to a folder containing all models required by the pipeline."
            )

        # Create VLM model instance
        self.vlm_model = NuExtractTransformersModel(
            enabled=True,
            artifacts_path=artifacts_path,  # Will download automatically
            accelerator_options=self.accelerator_options,
            vlm_options=pipeline_options.vlm_options,
        )

    def _extract_data(
        self,
        ext_res: ExtractionResult,
        template: Optional[ExtractionTemplateType] = None,
    ) -> ExtractionResult:
        """Extract data using the VLM model."""
        try:
            # Get images from input document using the backend
            images = self._get_images_from_input(ext_res.input)
            if not images:
                ext_res.status = ConversionStatus.FAILURE
                ext_res.errors.append(
                    ErrorItem(
                        component_type="extraction_pipeline",
                        module_name=self.__class__.__name__,
                        error_message="No images found in document",
                    )
                )
                return ext_res

            # Use provided template or default prompt
            if template is not None:
                prompt = self._serialize_template(template)
            else:
                prompt = "Extract all text and structured information from this document. Return as JSON."

            # Process all images with VLM model
            start_page, end_page = ext_res.input.limits.page_range
            for i, image in enumerate(images):
                # Calculate the actual page number based on the filtered range
                page_number = start_page + i
                try:
                    predictions = list(self.vlm_model.process_images([image], prompt))

                    if predictions:
                        # Parse the extracted text as JSON if possible, otherwise use as-is
                        extracted_text = predictions[0].text
                        extracted_data = None

                        try:
                            extracted_data = json.loads(extracted_text)
                        except (json.JSONDecodeError, ValueError):
                            # If not valid JSON, keep extracted_data as None
                            pass

                        # Create page data with proper structure
                        page_data = ExtractedPageData(
                            page_no=page_number,
                            extracted_data=extracted_data,
                            raw_text=extracted_text,  # Always populate raw_text
                        )
                        ext_res.pages.append(page_data)
                    else:
                        # Add error page data
                        page_data = ExtractedPageData(
                            page_no=page_number,
                            extracted_data=None,
                            errors=["No extraction result from VLM model"],
                        )
                        ext_res.pages.append(page_data)

                except Exception as e:
                    _log.error(f"Error processing page {page_number}: {e}")
                    page_data = ExtractedPageData(
                        page_no=page_number, extracted_data=None, errors=[str(e)]
                    )
                    ext_res.pages.append(page_data)

        except Exception as e:
            _log.error(f"Error during extraction: {e}")
            ext_res.errors.append(
                ErrorItem(
                    component_type="extraction_pipeline",
                    module_name=self.__class__.__name__,
                    error_message=str(e),
                )
            )

        return ext_res

    def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
        """Determine the status based on extraction results."""
        if ext_res.pages and not any(page.errors for page in ext_res.pages):
            return ConversionStatus.SUCCESS
        else:
            return ConversionStatus.FAILURE

    def _get_images_from_input(self, input_doc: InputDocument) -> list[Image]:
        """Extract images from input document using the backend."""
        images = []

        try:
            backend = input_doc._backend

            assert isinstance(backend, PdfDocumentBackend)
            # Use the backend's pagination interface
            page_count = backend.page_count()

            # Respect page range limits, following the same pattern as PaginatedPipeline
            start_page, end_page = input_doc.limits.page_range
            _log.info(
                f"Processing pages {start_page}-{end_page} of {page_count} total pages for extraction"
            )

            for page_num in range(page_count):
                # Only process pages within the specified range (0-based indexing)
                if start_page - 1 <= page_num <= end_page - 1:
                    try:
                        page_backend = backend.load_page(page_num)
                        if page_backend.is_valid():
                            # Get page image at a reasonable scale
                            page_image = page_backend.get_page_image(
                                scale=self.pipeline_options.vlm_options.scale
                            )
                            images.append(page_image)
                        else:
                            _log.warning(f"Page {page_num + 1} backend is not valid")
                    except Exception as e:
                        _log.error(f"Error loading page {page_num + 1}: {e}")

        except Exception as e:
            _log.error(f"Error getting images from input document: {e}")

        return images

    def _serialize_template(self, template: ExtractionTemplateType) -> str:
        """Serialize template to string based on its type."""
        if isinstance(template, str):
            return template
        elif isinstance(template, dict):
            return json.dumps(template, indent=2)
        elif isinstance(template, BaseModel):
            return template.model_dump_json(indent=2)
        elif inspect.isclass(template) and issubclass(template, BaseModel):
            from polyfactory.factories.pydantic_factory import ModelFactory

            class ExtractionTemplateFactory(ModelFactory[template]):  # type: ignore
                __use_examples__ = True  # prefer Field(examples=...) when present
                __use_defaults__ = True  # use field defaults instead of random values

            return ExtractionTemplateFactory.build().model_dump_json(indent=2)  # type: ignore
        else:
            raise ValueError(f"Unsupported template type: {type(template)}")

    @classmethod
    def get_default_options(cls) -> BaseOptions:
        return VlmExtractionPipelineOptions()
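Worth noting from `_serialize_template` above: when a pydantic class is passed as the template, polyfactory builds an example instance from `Field(examples=...)` values and field defaults, and that JSON is what the model receives as the template. A standalone sketch of the same mechanism, with an illustrative `BillTemplate` model:

    from polyfactory.factories.pydantic_factory import ModelFactory
    from pydantic import BaseModel, Field


    class BillTemplate(BaseModel):  # illustrative template
        bill_no: str = Field(examples=["A123"])
        total: float = 10.0


    class BillTemplateFactory(ModelFactory[BillTemplate]):
        __use_examples__ = True  # prefer Field(examples=...) when present
        __use_defaults__ = True  # use field defaults instead of random values


    # Roughly what ExtractionVlmPipeline sends as the prompt for a class template:
    print(BillTemplateFactory.build().model_dump_json(indent=2))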
pyproject.toml
@@ -71,6 +71,7 @@ dependencies = [
     # 'scipy (>=1.6.0,<2.0.0) ; python_version >= "3.10"',
     # 'scipy (>=1.6.0,<1.14.0) ; python_version < "3.10"',
     "accelerate>=1.0.0,<2",
+    "polyfactory>=2.22.2",
 ]

 [project.urls]
@@ -94,6 +95,7 @@ vlm = [
     'accelerate (>=1.2.1,<2.0.0)',
     'mlx-vlm (>=0.3.0,<1.0.0) ; python_version >= "3.10" and sys_platform == "darwin" and platform_machine == "arm64"',
     'vllm (>=0.10.0,<1.0.0) ; python_version >= "3.10" and sys_platform == "linux" and platform_machine == "x86_64"',
+    "qwen-vl-utils>=0.0.11",
 ]
 rapidocr = [
     'rapidocr (>=3.3,<4.0.0) ; python_version < "3.14"',
@@ -255,6 +257,7 @@ module = [
     "transformers.*",
     "pylatexenc.*",
     "vllm.*",
+    "qwen_vl_utils.*",
 ]
 ignore_missing_imports = true
tests/data_scanned/qr_bill_example.jpg (new binary file, vendored; not shown. Size: 285 KiB)
tests/test_extraction.py (new file, 108 lines)
@@ -0,0 +1,108 @@
"""
Test unit for document extraction functionality.
"""

import os
from pathlib import Path

import pytest
from pydantic import BaseModel, Field

from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter
from docling.document_extractor import DocumentExtractor

IS_CI = bool(os.getenv("CI"))


class ExampleTemplate(BaseModel):
    bill_no: str = Field(
        examples=["A123", "5414"]
    )  # provide some examples, but not the actual value of the test sample
    total: float = Field(
        default=10.0, examples=[20.0]
    )  # provide a default value and some examples


@pytest.fixture
def extractor() -> DocumentExtractor:
    """Create a document converter instance for testing."""

    return DocumentExtractor(allowed_formats=[InputFormat.IMAGE, InputFormat.PDF])


@pytest.fixture
def test_file_path() -> Path:
    """Get the path to the test QR bill image."""
    return Path(__file__).parent / "data_scanned" / "qr_bill_example.jpg"
    # return Path("tests/data/pdf/code_and_formula.pdf")


@pytest.mark.skipif(
    IS_CI, reason="Skipping test in CI because the dataset is too heavy."
)
def test_extraction_with_string_template(
    extractor: DocumentExtractor, test_file_path: Path
) -> None:
    """Test extraction using string template."""
    str_templ = '{"bill_no": "string", "total": "number"}'

    result = extractor.extract(test_file_path, template=str_templ)

    print(result.pages)

    assert result.status is not None
    assert len(result.pages) == 1
    assert result.pages[0].extracted_data["bill_no"] == "3139"
    assert result.pages[0].extracted_data["total"] == 3949.75


@pytest.mark.skipif(
    IS_CI, reason="Skipping test in CI because the dataset is too heavy."
)
def test_extraction_with_dict_template(
    extractor: DocumentExtractor, test_file_path: Path
) -> None:
    """Test extraction using dictionary template."""
    dict_templ = {
        "bill_no": "string",
        "total": "number",
    }

    result = extractor.extract(test_file_path, template=dict_templ)

    assert len(result.pages) == 1
    assert result.pages[0].extracted_data["bill_no"] == "3139"
    assert result.pages[0].extracted_data["total"] == 3949.75


@pytest.mark.skipif(
    IS_CI, reason="Skipping test in CI because the dataset is too heavy."
)
def test_extraction_with_pydantic_instance_template(
    extractor: DocumentExtractor, test_file_path: Path
) -> None:
    """Test extraction using pydantic instance template."""
    pydantic_instance_templ = ExampleTemplate(bill_no="4321")

    result = extractor.extract(test_file_path, template=pydantic_instance_templ)

    assert len(result.pages) == 1
    assert result.pages[0].extracted_data["bill_no"] == "3139"
    assert result.pages[0].extracted_data["total"] == 3949.75


@pytest.mark.skipif(
    IS_CI, reason="Skipping test in CI because the dataset is too heavy."
)
def test_extraction_with_pydantic_class_template(
    extractor: DocumentExtractor, test_file_path: Path
) -> None:
    """Test extraction using pydantic class template."""
    pydantic_class_templ = ExampleTemplate

    result = extractor.extract(test_file_path, template=pydantic_class_templ)

    assert len(result.pages) == 1
    assert result.pages[0].extracted_data["bill_no"] == "3139"
    assert result.pages[0].extracted_data["total"] == 3949.75
97
uv.lock
generated
97
uv.lock
generated
@@ -247,6 +247,59 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/43/53afb8ba17218f19b77c7834128566c5bbb100a0ad9ba2e8e89d089d7079/autopep8-2.3.2-py2.py3-none-any.whl", hash = "sha256:ce8ad498672c845a0c3de2629c15b635ec2b05ef8177a6e7c91c74f3e9b51128", size = 45807, upload-time = "2025-01-14T14:46:15.466Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "av"
|
||||
version = "15.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/17/89/940a509ee7e9449f0c877fa984b37b7cc485546035cc67bbc353f2ac20f3/av-15.0.0.tar.gz", hash = "sha256:871c1a9becddf00b60b1294dc0bff9ff193ac31286aeec1a34039bd27e650183", size = 3833128, upload-time = "2025-07-03T16:23:48.455Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/de/c6/4646aeffca77fb9e557509528341fdef409f7e5de44c70858fb639bb9a3e/av-15.0.0-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:f20c7565ad9aed8a5e3ca7ed30b151d4d8a937e072b6a4901c3200134fe7c68b", size = 21808084, upload-time = "2025-07-03T16:21:17.028Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/a2/8a0349e2ebf998f2305b61365240a748bc137f94f431e769c2ac83c5a321/av-15.0.0-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:0d8b78a88f0fdaf6591bca32b41301e40ba60be294b0698318948c4d1fa6f206", size = 26989279, upload-time = "2025-07-03T16:21:20.63Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/de/f7/2e3b9cc831a1891914ca09aaeac88195f36f24a22f8c18e57637604a8ef1/av-15.0.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:5d14f6cea6c4966d5478a50555fa92af4948f83e7843b63b747d4a451c53d4f1", size = 33955236, upload-time = "2025-07-03T16:21:23.752Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/fa/cc4e32d85d6e765f9e9c2680ce9bee6a4d66c8a069f136322be04a66e70d/av-15.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:09f516947890dcf27482af2f0f7b31a579dbd11d5566dd74ce5f1f6396c452b7", size = 37681552, upload-time = "2025-07-03T16:21:27.265Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/e3/438e1095c064fd21f1325ddae9383b4bcdc8f8493247144ed15bc1b931a2/av-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:908e2fb4358210a463f81e2dbfac77d5977cc53a400fea3f6decef6f9f9267e4", size = 39179769, upload-time = "2025-07-03T16:21:31.941Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/b7/e6c27a8bd75e3eede07c1ce888fc1aa6293ba35393d2f4adc1d2e41d563b/av-15.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:911fe092b1c75c35d3e9d836b750ff725599a343b6126449cb2c1b4aa8ac2792", size = 39725200, upload-time = "2025-07-03T16:21:35.73Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/42/06e91b07c77465af1b845ac5cf83be1b4cbe042fd940509ae3c5ad70e386/av-15.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:25743a08b674596f3b993392259a4953a445b4211796d168c992174c983b76f0", size = 36639563, upload-time = "2025-07-03T16:21:39.395Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/d6/92aafdd8ef420100aca3503b7772ca2484d3688b83b09ca6f96bfb47b7c1/av-15.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2ecd5df62b9697a9304a084fbfed13fa890ec9ba2f647aaed35dca291991c7b1", size = 40482430, upload-time = "2025-07-03T16:21:42.671Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/41/1a/22d7b2a151d4aeff6a1fb530e25c8d677dd59580418cab4a95c4628d5009/av-15.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:7156d1b326e328aaba7f10a0d89bec53d087aba16f4c5c7ae13890b9eefde972", size = 31363168, upload-time = "2025-07-03T16:21:45.809Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e5/2a/40e0ec34e8235e4a1f9fe60288cd1eebe6413765931b5b74aeb3ce79c422/av-15.0.0-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:eb19386466aafbac4ede549ed7dc6198714e8d35ecc238d5b5c0d91e770d53d4", size = 21793541, upload-time = "2025-07-03T16:21:48.819Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/21/74acec5492a901699a94715e94cb83772679b92183592a3d8b3e58cf0202/av-15.0.0-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:e3c841befff26823524f3260d29fb3162540535c43238587b24226d345c82af3", size = 26973175, upload-time = "2025-07-03T16:21:51.63Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/d9/04e7fc09c6246aaf8e695620cc026779e366c49dcab561f8f434fbed3256/av-15.0.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:fe50ddab68af27bb9f7123dac5b1ff43ee8c7d941499c625018f3cac7da01ff3", size = 34423925, upload-time = "2025-07-03T16:21:54.628Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/e0/f4a93c901d65188ffe21e182499abf5304403f47e24da001b21042c888ec/av-15.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:5877f9dacf04bba9e966e0feb707e0fc2955476dc50cc6de5707604f51440e1b", size = 38178587, upload-time = "2025-07-03T16:21:57.966Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/64/6dcfb449ed287a590ecf70d6259f1f2c06fa9a576996f53d1949d65c4ee5/av-15.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e2a5f0d1817ab73370fdb35e2e2ecd4c2e5a45d43b8d96d5ae8dfe86098fb9b3", size = 39683188, upload-time = "2025-07-03T16:22:01.315Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/ef/a89775afc0487a4f5ab892b423972ae47bd3ef004faeb666135c657ea308/av-15.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:710ad0307da524be553db123c0681edadb5cefc15baa49cf25217364fb7a80b5", size = 40230243, upload-time = "2025-07-03T16:22:04.811Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/35/4b/39c40ce50c7290b5091afe75264c31bb1afb53e918c16991c808131a5d27/av-15.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d275d3015ab13aadc7bf38df3b2398ad992e30b1685cd350fd46c71913e98af4", size = 37059511, upload-time = "2025-07-03T16:22:09.216Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d3/39/947815be601b2dc9f43ea59fc5582cb7125070ef352cb0157ea29b98b796/av-15.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5b300f824a7ca1854ca29a37265281fa07d3dd0f69a6d2ff55d4c54ee3d734e2", size = 40993811, upload-time = "2025-07-03T16:22:12.524Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/de/ca/0c77802f70248bc3e182451a174db30fca349858840c4fbf1c7f8e1beaa0/av-15.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:cb7a85b4fe853a9c2725cdf02f457221fcc24f0391c8333b25a3a889e16ff26d", size = 31358970, upload-time = "2025-07-03T16:22:15.78Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/89/81/c5d009ea9c01a513b7af6aac2ac49c0f2f7193345071cd6dd4d91bef3ab9/av-15.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:84e2ede9459e64e768f4bc56d9df65da9e94b704ee3eccfe2e5b1da1da754313", size = 21782026, upload-time = "2025-07-03T16:22:18.41Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/8a/ffe9fcac35a07efc6aa0d765015efa499d88823c01499f318760460f8088/av-15.0.0-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:9473ed92d6942c5a449a2c79d49f3425eb0272499d1a3559b32c1181ff736a08", size = 26974939, upload-time = "2025-07-03T16:22:21.493Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/e7/0816e52134dc2d0259bb1aaad78573eacaf2bebc1a643de34e3384b520d6/av-15.0.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:56a53fe4e09bebd99355eaa0ce221b681eaf205bdda114f5e17fb79f3c3746ad", size = 34573486, upload-time = "2025-07-03T16:22:24.684Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/f4/07cc05712e9824a4bb68beea44eb5a7369dee3f00fa258879190004b7fc5/av-15.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:247dd9a99d7ed3577b8c1e9977e811f423b04504ff36c9dcd7a4de3e6e5fe5ad", size = 38418908, upload-time = "2025-07-03T16:22:27.799Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/19/48/7f3a21a41e291f8c5b8a98f95cfef308ce1b024a634413ce910c270efd7d/av-15.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:fc50a7d5f60109221ccf44f8fa4c56ce73f22948b7f19b1717fcc58f7fbc383e", size = 40010257, upload-time = "2025-07-03T16:22:31.15Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6d/c9/ced392e82d39084544d2d0c05decb36446028928eddf0d40ec3d8fe6c050/av-15.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:77deaec8943abfebd4e262924f2f452d6594cf0bc67d8d98aac0462b476e4182", size = 40381801, upload-time = "2025-07-03T16:22:34.254Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/73/a23ad111200e27f5773e94b0b6f9e2ea492a72ded7f4787a358d9d504a8b/av-15.0.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:601d9b0740e47a17ec96ba2a537ebfd4d6edc859ae6f298475c06caa51f0a019", size = 37219417, upload-time = "2025-07-03T16:22:37.497Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/0c/2ac20143b74e3792ede40bfd397ce72fa4e76a03999c2fd0aee3997b6971/av-15.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e021f67e0db7256c9f5d3d6a2a4237a4a4a804b131b33e7f2778981070519b20", size = 41242077, upload-time = "2025-07-03T16:22:40.86Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bd/30/40452705dffbfef0f5505d36218970dfeff0a86048689910219c8717b310/av-15.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:383f1b57520d790069d85fc75f43cfa32fca07f5fb3fb842be37bd596638602c", size = 31357617, upload-time = "2025-07-03T16:22:43.934Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/27/c2e248498ce78dd504b0b1818ce88e71e30a7e26c348bdf5d6467d7b06f7/av-15.0.0-cp313-cp313-macosx_13_0_arm64.whl", hash = "sha256:0701c116f32bd9478023f610722f6371d15ca0c068ff228d355f54a7cf23d9cb", size = 21746400, upload-time = "2025-07-03T16:22:46.604Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/d8/11f8452f19f4ddc189e978b215420131db40e3919135c14a0d13520f7c94/av-15.0.0-cp313-cp313-macosx_13_0_x86_64.whl", hash = "sha256:57fb6232494ec575b8e78e5a9ef9b811d78f8d67324476ec8430ca3146751124", size = 26939576, upload-time = "2025-07-03T16:22:49.255Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/00/1c/b109fd41487d91b8843f9e199b65e89ca533a612ec788b11ed0ba9812ea3/av-15.0.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:801a3e0afd5c36df70d012d083bfca67ab22d0ebd2c860c0d9432ac875bc0ad6", size = 34284344, upload-time = "2025-07-03T16:22:52.373Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/99/71/aee35fa182d0a41227fbd3f4250fd94c54acdd2995025ee59dd948bba930/av-15.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:d5e97791b96741b344bf6dbea4fb14481c117b1f7fe8113721e8d80e26cbb388", size = 38130346, upload-time = "2025-07-03T16:22:56.755Z" },
{ url = "https://files.pythonhosted.org/packages/b7/c4/2d9bbc9c42a804c99bc571eeacb2fe1582fe9cfdb726616876cada937d6a/av-15.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:acb4e4aa6bb394d3a9e60feb4cb7a856fc7bac01f3c99019b1d0f11c898c682c", size = 39728857, upload-time = "2025-07-03T16:23:00.392Z" },
{ url = "https://files.pythonhosted.org/packages/7c/d6/a5746e9fb4fdf326e9897abd7538413210e66f35ad4793fe30f87859249d/av-15.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:02d2d80bdbe184f1f3f49b3f5eae7f0ff7cba0a62ab3b18be0505715e586ad29", size = 40109012, upload-time = "2025-07-03T16:23:04.1Z" },
{ url = "https://files.pythonhosted.org/packages/77/1f/da89798231ad0feacfaaea4efec4f1779060226986f97498eabe2c7c54a8/av-15.0.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:603f3ae751f6678df5d8b949f92c6f8257064bba8b3e8db606a24c29d31b4e25", size = 36929211, upload-time = "2025-07-03T16:23:07.694Z" },
{ url = "https://files.pythonhosted.org/packages/d5/4c/2bcabe65a1c19e552f03540f16155a0d02cb9b7a90d31242ab3e0c7ea0d8/av-15.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:682686a9ea2745e63c8878641ec26b1787b9210533f3e945a6e07e24ab788c2e", size = 40967172, upload-time = "2025-07-03T16:23:13.488Z" },
{ url = "https://files.pythonhosted.org/packages/c9/f0/fe14adaa670ab7a3f709805a8494fd0a2eeb6a5b18b8c59dc6014639a5b1/av-15.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:5758231163b5486dfbf664036be010b7f5ebb24564aaeb62577464be5ea996e0", size = 31332650, upload-time = "2025-07-03T16:23:16.558Z" },
{ url = "https://files.pythonhosted.org/packages/b8/89/11d5e4ef0341a617d63b615609d9a0f3afe244835a5e464425a11ca20036/av-15.0.0-cp39-cp39-macosx_13_0_arm64.whl", hash = "sha256:4a110aecebd7daef08f8be68ac9d6540f716a492f1994886a65eab9d19de39e2", size = 21838897, upload-time = "2025-07-03T16:23:19.406Z" },
{ url = "https://files.pythonhosted.org/packages/2f/26/fa9439ec34d2199b68f80d0453325526720c1933449dfedff7dd71014948/av-15.0.0-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:5bcf99da2b1c67ed6b6a0d070cc218eccf05698fc960db9b8f42d36779714294", size = 27021427, upload-time = "2025-07-03T16:23:21.992Z" },
{ url = "https://files.pythonhosted.org/packages/1f/11/5ddaf6ab6f1444a6ae2e8cef7663f8b2cff330bf2355ebae16ff3c4210ee/av-15.0.0-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:f483c1dcefe1bb9b96bc5813b57acf03d13c717aea477088e26119392c53aa81", size = 34073519, upload-time = "2025-07-03T16:23:24.955Z" },
{ url = "https://files.pythonhosted.org/packages/24/78/25c1b850aa6dd76307494a03ee971981c2ba203f5ea4053accbbc9f7071e/av-15.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:ad7faa2b906954fd75fa9d3b52c81cdf9a1b202df305de34456a2a1d4aee625f", size = 37808712, upload-time = "2025-07-03T16:23:28.046Z" },
{ url = "https://files.pythonhosted.org/packages/b3/46/3c192579f73d00eb4e856b6a25b1b128a20a70fe07a8268b67dc1ad4dc75/av-15.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:e1a8d63fa3af2136c70e330a9845a4a2314d936c8a487760598ed7692024cc93", size = 39306799, upload-time = "2025-07-03T16:23:31.333Z" },
{ url = "https://files.pythonhosted.org/packages/9d/07/f8204ac0c5805f870fd59722dab7b7fc720467d978e77a042e8d3f74917a/av-15.0.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:5abc363915c0bb4d5973e41095881ce7dd715fc921e8366732a6b1a2e91f928a", size = 39849840, upload-time = "2025-07-03T16:23:35.046Z" },
{ url = "https://files.pythonhosted.org/packages/a1/9b/d1b520391bea1597cc4e9b7f62984f5707cf0ffa56e7cf9fe1c6c0a99344/av-15.0.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:679720aba7540a974a7911a20cce5d0fd2c32a8ae3f7371283b9361140b8d0bb", size = 36759597, upload-time = "2025-07-03T16:23:38.727Z" },
{ url = "https://files.pythonhosted.org/packages/7b/b0/3d05231a212b26fe134b4b2a5d5cd3d7634b133e2b4909f9d984b4b7154a/av-15.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:886a73abb874d8f1813d750714ea271ecc6c1e1b489e4d8381cdd4e1ab3fced2", size = 40606055, upload-time = "2025-07-03T16:23:42.391Z" },
{ url = "https://files.pythonhosted.org/packages/a8/14/3c941a9e3032b2dfb8d5194f4c9325566aff3448683fae2d612c883e340f/av-15.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:224087661cc20f0de052f05c2a47ff35eccd00702f8c8a4260fe5d469c6d591d", size = 31389895, upload-time = "2025-07-03T16:23:45.438Z" },
]

[[package]]
name = "babel"
version = "2.17.0"
@@ -985,6 +1038,7 @@ dependencies = [
{ name = "pandas" },
{ name = "pillow" },
{ name = "pluggy" },
{ name = "polyfactory" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pylatexenc" },
@@ -1019,6 +1073,7 @@ tesserocr = [
vlm = [
{ name = "accelerate" },
{ name = "mlx-vlm", marker = "python_full_version >= '3.10' and platform_machine == 'arm64' and sys_platform == 'darwin'" },
{ name = "qwen-vl-utils" },
{ name = "transformers" },
{ name = "vllm", marker = "python_full_version >= '3.10' and platform_machine == 'x86_64' and sys_platform == 'linux'" },
]
@@ -1087,12 +1142,14 @@ requires-dist = [
{ name = "pandas", specifier = ">=2.1.4,<3.0.0" },
{ name = "pillow", specifier = ">=10.0.0,<12.0.0" },
{ name = "pluggy", specifier = ">=1.0.0,<2.0.0" },
{ name = "polyfactory", specifier = ">=2.22.2" },
{ name = "pydantic", specifier = ">=2.0.0,<3.0.0" },
{ name = "pydantic-settings", specifier = ">=2.3.0,<3.0.0" },
{ name = "pylatexenc", specifier = ">=2.10,<3.0" },
{ name = "pypdfium2", specifier = ">=4.30.0,!=4.30.1,<5.0.0" },
{ name = "python-docx", specifier = ">=1.1.2,<2.0.0" },
{ name = "python-pptx", specifier = ">=1.0.2,<2.0.0" },
{ name = "qwen-vl-utils", marker = "extra == 'vlm'", specifier = ">=0.0.11" },
{ name = "rapidocr", marker = "python_full_version < '3.14' and extra == 'rapidocr'", specifier = ">=3.3,<4.0.0" },
{ name = "requests", specifier = ">=2.32.2,<3.0.0" },
{ name = "rtree", specifier = ">=1.3.0,<2.0.0" },
@@ -1351,6 +1408,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702, upload-time = "2025-01-22T15:41:25.929Z" },
]

[[package]]
name = "faker"
version = "37.6.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "tzdata" },
]
sdist = { url = "https://files.pythonhosted.org/packages/24/cd/f7679c20f07d9e2013123b7f7e13809a3450a18d938d58e86081a486ea15/faker-37.6.0.tar.gz", hash = "sha256:0f8cc34f30095184adf87c3c24c45b38b33ad81c35ef6eb0a3118f301143012c", size = 1907960, upload-time = "2025-08-26T15:56:27.419Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/7d/8b50e4ac772719777be33661f4bde320793400a706f5eb214e4de46f093c/faker-37.6.0-py3-none-any.whl", hash = "sha256:3c5209b23d7049d596a51db5d76403a0ccfea6fc294ffa2ecfef6a8843b1e6a7", size = 1949837, upload-time = "2025-08-26T15:56:25.33Z" },
]

[[package]]
name = "fastapi"
version = "0.116.1"
@@ -4761,6 +4830,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]

[[package]]
name = "polyfactory"
version = "2.22.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "faker" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4e/a6/950d13856d995705df33b92451559fd317207a9c43629ab1771135a0c966/polyfactory-2.22.2.tar.gz", hash = "sha256:a3297aa0b004f2b26341e903795565ae88507c4d86e68b132c2622969028587a", size = 254462, upload-time = "2025-08-15T06:23:21.28Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e7/fe/d52c90e07c458f38b26f9972a25cb011b2744813f76fcd6121dde64744fa/polyfactory-2.22.2-py3-none-any.whl", hash = "sha256:9bea58ac9a80375b4153cd60820f75e558b863e567e058794d28c6a52b84118a", size = 63715, upload-time = "2025-08-15T06:23:19.664Z" },
]

[[package]]
name = "pre-commit"
version = "3.8.0"
@@ -5975,6 +6057,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/fb/fdc3dd8ec46f5226b4cde299839f10c625886bd18adbeaa8a59ffe104356/pyzmq-27.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8ffe40c216c41756ca05188c3e24a23142334b304f7aebd75c24210385e35573", size = 544636, upload-time = "2025-08-21T04:23:24.736Z" },
]

[[package]]
name = "qwen-vl-utils"
version = "0.0.11"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "av" },
{ name = "packaging" },
{ name = "pillow" },
{ name = "requests" },
]
sdist = { url = "https://files.pythonhosted.org/packages/42/9f/1229a40ebd49f689a0252144126f3865f31bb4151e942cf781a2936f0c4d/qwen_vl_utils-0.0.11.tar.gz", hash = "sha256:083ba1e5cfa5002165b1e3bddd4d6d26d1d6d34473884033ef12ae3fe8496cd5", size = 7924, upload-time = "2025-04-21T10:38:47.461Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/0a/c2/ad7f93e1eea4ea0aefd1cc6fbe7a7095fd2f03a4d8fe2c3707e612b0866e/qwen_vl_utils-0.0.11-py3-none-any.whl", hash = "sha256:7fd5287ac04d6c1f01b93bf053b0be236a35149e414c9e864e3cc5bf2fe8cb7b", size = 7584, upload-time = "2025-04-21T10:38:45.595Z" },
]

[[package]]
name = "rapidocr"
version = "3.3.1"