Signed-off-by: felix <felixdittrich92@gmail.com>
This commit is contained in:
felix 2025-03-21 21:11:31 +01:00
parent 35f185f545
commit 7c87467ea5
2 changed files with 27 additions and 11 deletions

View File

@ -161,7 +161,9 @@ class OnnxtrOcrOptions(OcrOptions):
det_arch: str = "fast_base" det_arch: str = "fast_base"
reco_arch: str = "crnn_vgg16_bn" # NOTE: This can be also a hf hub model reco_arch: str = "crnn_vgg16_bn" # NOTE: This can be also a hf hub model
det_bs: int = 1 # NOTE: Should be 1 because docling seems not to support batch processing yet det_bs: int = (
1 # NOTE: Should be 1 because docling seems not to support batch processing yet
)
reco_bs: int = 512 reco_bs: int = 512
auto_correct_orientation: bool = False auto_correct_orientation: bool = False
preserve_aspect_ratio: bool = True preserve_aspect_ratio: bool = True

View File

@ -49,7 +49,6 @@ class OnnxtrOcrModel(BaseOcrModel):
"Alternatively, Docling has support for other OCR engines. See the documentation." "Alternatively, Docling has support for other OCR engines. See the documentation."
) )
if options.auto_correct_orientation: if options.auto_correct_orientation:
config = { config = {
"assume_straight_pages": False, "assume_straight_pages": False,
@ -69,8 +68,16 @@ class OnnxtrOcrModel(BaseOcrModel):
} }
self.reader = ocr_predictor( self.reader = ocr_predictor(
det_arch=from_hub(self.options.det_arch) if self.options.det_arch.count("/") == 1 else self.options.det_arch, det_arch=(
reco_arch=from_hub(self.options.reco_arch) if self.options.reco_arch.count("/") == 1 else self.options.reco_arch, from_hub(self.options.det_arch)
if self.options.det_arch.count("/") == 1
else self.options.det_arch
),
reco_arch=(
from_hub(self.options.reco_arch)
if self.options.reco_arch.count("/") == 1
else self.options.reco_arch
),
preserve_aspect_ratio=self.options.preserve_aspect_ratio, preserve_aspect_ratio=self.options.preserve_aspect_ratio,
symmetric_pad=self.options.symmetric_pad, symmetric_pad=self.options.symmetric_pad,
paragraph_break=self.options.paragraph_break, paragraph_break=self.options.paragraph_break,
@ -78,7 +85,9 @@ class OnnxtrOcrModel(BaseOcrModel):
**config, **config,
) )
def _to_absolute_and_docling_format(self, geom: list[list[float]], img_shape: tuple[int, int]) -> tuple[int, int, int, int]: def _to_absolute_and_docling_format(
self, geom: list[list[float]], img_shape: tuple[int, int]
) -> tuple[int, int, int, int]:
""" """
Convert a bounding box or polygon from relative to absolute coordinates and return in [x1, y1, x2, y2] format. Convert a bounding box or polygon from relative to absolute coordinates and return in [x1, y1, x2, y2] format.
@ -105,12 +114,15 @@ class OnnxtrOcrModel(BaseOcrModel):
x1, y1 = min(p[0] for p in abs_points), min(p[1] for p in abs_points) x1, y1 = min(p[0] for p in abs_points), min(p[1] for p in abs_points)
x2, y2 = max(p[0] for p in abs_points), max(p[1] for p in abs_points) x2, y2 = max(p[0] for p in abs_points), max(p[1] for p in abs_points)
else: else:
raise ValueError(f"Invalid geometry format: {geom}. Expected either 2 or 4 points.") raise ValueError(
f"Invalid geometry format: {geom}. Expected either 2 or 4 points."
)
return x1, y1, x2, y2 return x1, y1, x2, y2
def __call__(
def __call__(self, conv_res: ConversionResult, page_batch: Iterable[Page]) -> Iterable[Page]: self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
if not self.enabled: if not self.enabled:
yield from page_batch yield from page_batch
return return
@ -129,7 +141,9 @@ class OnnxtrOcrModel(BaseOcrModel):
if ocr_rect.area() == 0: if ocr_rect.area() == 0:
continue continue
with page._backend.get_page_image(scale=self.scale, cropbox=ocr_rect) as high_res_image: with page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect
) as high_res_image:
im_width, im_height = high_res_image.size im_width, im_height = high_res_image.size
result = self.reader([numpy.array(high_res_image)]) result = self.reader([numpy.array(high_res_image)])
@ -151,7 +165,8 @@ class OnnxtrOcrModel(BaseOcrModel):
rect=BoundingRectangle.from_bounding_box( rect=BoundingRectangle.from_bounding_box(
BoundingBox.from_tuple( BoundingBox.from_tuple(
self._to_absolute_and_docling_format( self._to_absolute_and_docling_format(
word.geometry, img_shape=(im_height, im_width) word.geometry,
img_shape=(im_height, im_width),
), ),
origin=CoordOrigin.TOPLEFT, origin=CoordOrigin.TOPLEFT,
) )
@ -168,7 +183,6 @@ class OnnxtrOcrModel(BaseOcrModel):
yield page yield page
@classmethod @classmethod
def get_options_type(cls) -> Type[OcrOptions]: def get_options_type(cls) -> Type[OcrOptions]:
return OnnxtrOcrOptions return OnnxtrOcrOptions