mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
adding doctr ocr to pipeline
This commit is contained in:
parent
976e92e289
commit
2acce04305
@ -196,6 +196,15 @@ class TesseractOcrOptions(OcrOptions):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DoctrOcrOptions(OcrOptions):
|
||||||
|
kind: ClassVar[Literal["doctr"]] = "doctr"
|
||||||
|
|
||||||
|
lang: Optional[List[str]] = None
|
||||||
|
model_name: str = "db_resnet50"
|
||||||
|
pretrained: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OcrMacOptions(OcrOptions):
|
class OcrMacOptions(OcrOptions):
|
||||||
"""Options for the Mac OCR engine."""
|
"""Options for the Mac OCR engine."""
|
||||||
|
|
||||||
|
115
docling/models/doctr_ocr_model.py
Normal file
115
docling/models/doctr_ocr_model.py
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Type
|
||||||
|
import numpy as np
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||||
|
from docling_core.types.doc.page import BoundingRectangle, TextCell
|
||||||
|
|
||||||
|
from docling.datamodel.base_models import Page
|
||||||
|
from docling.datamodel.document import ConversionResult
|
||||||
|
from docling.datamodel.pipeline_options import (
|
||||||
|
AcceleratorOptions,
|
||||||
|
OcrOptions,
|
||||||
|
DoctrOcrOptions,
|
||||||
|
)
|
||||||
|
from docling.datamodel.settings import settings
|
||||||
|
from docling.models.base_ocr_model import BaseOcrModel
|
||||||
|
from docling.utils.profiling import TimeRecorder
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DoctrOcrModel(BaseOcrModel):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enabled: bool,
|
||||||
|
artifacts_path: Optional[Path],
|
||||||
|
options: OcrOptions,
|
||||||
|
accelerator_options: AcceleratorOptions,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
enabled=enabled,
|
||||||
|
artifacts_path=artifacts_path,
|
||||||
|
options=options,
|
||||||
|
accelerator_options=accelerator_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.enabled:
|
||||||
|
try:
|
||||||
|
from doctr.models import ocr_predictor
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"The 'python-doctr' library is not installed. Install it via `pip install python-doctr`."
|
||||||
|
)
|
||||||
|
|
||||||
|
_log.debug("Initializing Doctr OCR engine")
|
||||||
|
# Initialize a simple Doctr OCR model
|
||||||
|
self.model = ocr_predictor(pretrained=True)
|
||||||
|
else:
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
def __call__(self, conv_res: ConversionResult, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||||
|
if not self.enabled or self.model is None:
|
||||||
|
yield from page_batch
|
||||||
|
return
|
||||||
|
|
||||||
|
from doctr.io import DocumentFile
|
||||||
|
|
||||||
|
for page in page_batch:
|
||||||
|
assert page._backend is not None
|
||||||
|
if not page._backend.is_valid():
|
||||||
|
yield page
|
||||||
|
else:
|
||||||
|
with TimeRecorder(conv_res, "ocr"):
|
||||||
|
pil_image = page._backend.get_page_image(scale=1).convert("RGB")
|
||||||
|
|
||||||
|
# 2) Convert it to raw PNG bytes
|
||||||
|
buf = BytesIO()
|
||||||
|
pil_image.save(buf, format="PNG")
|
||||||
|
img_bytes = buf.getvalue()
|
||||||
|
|
||||||
|
# 3) Wrap in a list and hand to doctr
|
||||||
|
doc = DocumentFile.from_images([img_bytes])
|
||||||
|
|
||||||
|
result = self.model(doc)
|
||||||
|
|
||||||
|
all_cells = []
|
||||||
|
if len(result.pages) > 0:
|
||||||
|
doc_page = result.pages[0]
|
||||||
|
for block in doc_page.blocks:
|
||||||
|
for line in block.lines:
|
||||||
|
line_text = " ".join(w.value for w in line.words)
|
||||||
|
(left, top), (right, bottom) = line.geometry
|
||||||
|
if line.words:
|
||||||
|
conf = float(np.mean([w.confidence for w in line.words]))
|
||||||
|
else:
|
||||||
|
conf = 0.0
|
||||||
|
|
||||||
|
all_cells.append(
|
||||||
|
TextCell(
|
||||||
|
index=len(all_cells),
|
||||||
|
text=line_text,
|
||||||
|
orig=line_text,
|
||||||
|
from_ocr=True,
|
||||||
|
confidence=conf,
|
||||||
|
rect=BoundingRectangle.from_bounding_box(
|
||||||
|
BoundingBox.from_tuple(
|
||||||
|
coord=(left, top, right, bottom),
|
||||||
|
origin=CoordOrigin.TOPLEFT,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Attach the OCR cells to the page
|
||||||
|
page.cells = self.post_process_cells(all_cells, page.cells)
|
||||||
|
|
||||||
|
if settings.debug.visualize_ocr:
|
||||||
|
self.draw_ocr_rects_and_cells(conv_res, page, [])
|
||||||
|
|
||||||
|
yield page
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_options_type(cls) -> Type[OcrOptions]:
|
||||||
|
return DoctrOcrOptions
|
@ -5,6 +5,7 @@ from docling.models.picture_description_vlm_model import PictureDescriptionVlmMo
|
|||||||
from docling.models.rapid_ocr_model import RapidOcrModel
|
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||||
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
||||||
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
||||||
|
from docling.models.doctr_ocr_model import DoctrOcrModel
|
||||||
|
|
||||||
|
|
||||||
def ocr_engines():
|
def ocr_engines():
|
||||||
@ -15,6 +16,7 @@ def ocr_engines():
|
|||||||
RapidOcrModel,
|
RapidOcrModel,
|
||||||
TesseractOcrModel,
|
TesseractOcrModel,
|
||||||
TesseractOcrCliModel,
|
TesseractOcrCliModel,
|
||||||
|
DoctrOcrModel,
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user