mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-30 22:14:37 +00:00
fix: Move common OCR code in the BaseOcrModel class
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
This commit is contained in:
parent
7234dc3a42
commit
7a0f16079d
@ -10,7 +10,7 @@ from PIL import Image, ImageDraw
|
||||
from rtree import index
|
||||
from scipy.ndimage import find_objects, label
|
||||
|
||||
from docling.datamodel.base_models import OcrCell, Page
|
||||
from docling.datamodel.base_models import Cell, OcrCell, Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import OcrOptions
|
||||
from docling.datamodel.settings import settings
|
||||
@ -98,7 +98,7 @@ class BaseOcrModel(BasePageModel):
|
||||
return ocr_rects
|
||||
|
||||
# Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
|
||||
def filter_ocr_cells(self, ocr_cells, programmatic_cells):
|
||||
def _filter_ocr_cells(self, ocr_cells, programmatic_cells):
|
||||
# Create R-tree index for programmatic cells
|
||||
p = index.Property()
|
||||
p.dimension = 2
|
||||
@ -119,6 +119,23 @@ class BaseOcrModel(BasePageModel):
|
||||
]
|
||||
return filtered_ocr_cells
|
||||
|
||||
def post_process_cells(self, ocr_cells, programmatic_cells):
|
||||
r"""
|
||||
Post-process the ocr and programmatic cells and return the final list of of cells
|
||||
"""
|
||||
if self.options.force_full_page_ocr:
|
||||
# If a full page OCR is forced, use only the OCR cells
|
||||
cells = [
|
||||
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
|
||||
for c_ocr in ocr_cells
|
||||
]
|
||||
return cells
|
||||
|
||||
## Remove OCR cells which overlap with programmatic cells.
|
||||
filtered_ocr_cells = self._filter_ocr_cells(ocr_cells, programmatic_cells)
|
||||
programmatic_cells.extend(filtered_ocr_cells)
|
||||
return programmatic_cells
|
||||
|
||||
def draw_ocr_rects_and_cells(self, conv_res, page, ocr_rects, show: bool = False):
|
||||
image = copy.deepcopy(page.image)
|
||||
draw = ImageDraw.Draw(image, "RGBA")
|
||||
|
@ -88,18 +88,8 @@ class EasyOcrModel(BaseOcrModel):
|
||||
]
|
||||
all_ocr_cells.extend(cells)
|
||||
|
||||
if self.options.force_full_page_ocr:
|
||||
# If a full page OCR is forced, use only the OCR cells
|
||||
page.cells = [
|
||||
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
|
||||
for c_ocr in all_ocr_cells
|
||||
]
|
||||
else:
|
||||
## Remove OCR cells which overlap with programmatic cells.
|
||||
filtered_ocr_cells = self.filter_ocr_cells(
|
||||
all_ocr_cells, page.cells
|
||||
)
|
||||
page.cells.extend(filtered_ocr_cells)
|
||||
# Post-process the cells
|
||||
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
|
||||
|
||||
# DEBUG code:
|
||||
if settings.debug.visualize_ocr:
|
||||
|
@ -170,18 +170,8 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
)
|
||||
all_ocr_cells.append(cell)
|
||||
|
||||
if self.options.force_full_page_ocr:
|
||||
# If a full page OCR is forced, use only the OCR cells
|
||||
page.cells = [
|
||||
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
|
||||
for c_ocr in all_ocr_cells
|
||||
]
|
||||
else:
|
||||
## Remove OCR cells which overlap with programmatic cells.
|
||||
filtered_ocr_cells = self.filter_ocr_cells(
|
||||
all_ocr_cells, page.cells
|
||||
)
|
||||
page.cells.extend(filtered_ocr_cells)
|
||||
# Post-process the cells
|
||||
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
|
||||
|
||||
# DEBUG code:
|
||||
if settings.debug.visualize_ocr:
|
||||
|
@ -140,18 +140,8 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
# del high_res_image
|
||||
all_ocr_cells.extend(cells)
|
||||
|
||||
if self.options.force_full_page_ocr:
|
||||
# If a full page OCR is forced, use only the OCR cells
|
||||
page.cells = [
|
||||
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
|
||||
for c_ocr in all_ocr_cells
|
||||
]
|
||||
else:
|
||||
## Remove OCR cells which overlap with programmatic cells.
|
||||
filtered_ocr_cells = self.filter_ocr_cells(
|
||||
all_ocr_cells, page.cells
|
||||
)
|
||||
page.cells.extend(filtered_ocr_cells)
|
||||
# Post-process the cells
|
||||
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
|
||||
|
||||
# DEBUG code:
|
||||
if settings.debug.visualize_ocr:
|
||||
|
Loading…
Reference in New Issue
Block a user