fix: Move common OCR code in the BaseOcrModel class

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
This commit is contained in:
Nikos Livathinos 2024-11-11 17:19:35 +01:00
parent 7234dc3a42
commit 7a0f16079d
4 changed files with 25 additions and 38 deletions

View File

@ -10,7 +10,7 @@ from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import find_objects, label
from docling.datamodel.base_models import OcrCell, Page
from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import OcrOptions
from docling.datamodel.settings import settings
@ -98,7 +98,7 @@ class BaseOcrModel(BasePageModel):
return ocr_rects
# Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
def filter_ocr_cells(self, ocr_cells, programmatic_cells):
def _filter_ocr_cells(self, ocr_cells, programmatic_cells):
# Create R-tree index for programmatic cells
p = index.Property()
p.dimension = 2
@ -119,6 +119,23 @@ class BaseOcrModel(BasePageModel):
]
return filtered_ocr_cells
def post_process_cells(self, ocr_cells, programmatic_cells):
r"""
Post-process the ocr and programmatic cells and return the final list of of cells
"""
if self.options.force_full_page_ocr:
# If a full page OCR is forced, use only the OCR cells
cells = [
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
for c_ocr in ocr_cells
]
return cells
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self._filter_ocr_cells(ocr_cells, programmatic_cells)
programmatic_cells.extend(filtered_ocr_cells)
return programmatic_cells
def draw_ocr_rects_and_cells(self, conv_res, page, ocr_rects, show: bool = False):
image = copy.deepcopy(page.image)
draw = ImageDraw.Draw(image, "RGBA")

View File

@ -88,18 +88,8 @@ class EasyOcrModel(BaseOcrModel):
]
all_ocr_cells.extend(cells)
if self.options.force_full_page_ocr:
# If a full page OCR is forced, use only the OCR cells
page.cells = [
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
for c_ocr in all_ocr_cells
]
else:
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(
all_ocr_cells, page.cells
)
page.cells.extend(filtered_ocr_cells)
# Post-process the cells
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
# DEBUG code:
if settings.debug.visualize_ocr:

View File

@ -170,18 +170,8 @@ class TesseractOcrCliModel(BaseOcrModel):
)
all_ocr_cells.append(cell)
if self.options.force_full_page_ocr:
# If a full page OCR is forced, use only the OCR cells
page.cells = [
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
for c_ocr in all_ocr_cells
]
else:
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(
all_ocr_cells, page.cells
)
page.cells.extend(filtered_ocr_cells)
# Post-process the cells
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
# DEBUG code:
if settings.debug.visualize_ocr:

View File

@ -140,18 +140,8 @@ class TesseractOcrModel(BaseOcrModel):
# del high_res_image
all_ocr_cells.extend(cells)
if self.options.force_full_page_ocr:
# If a full page OCR is forced, use only the OCR cells
page.cells = [
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
for c_ocr in all_ocr_cells
]
else:
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(
all_ocr_cells, page.cells
)
page.cells.extend(filtered_ocr_cells)
# Post-process the cells
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
# DEBUG code:
if settings.debug.visualize_ocr: