diff --git a/docling/models/base_ocr_model.py b/docling/models/base_ocr_model.py index d8b3262e..38b5e52c 100644 --- a/docling/models/base_ocr_model.py +++ b/docling/models/base_ocr_model.py @@ -10,7 +10,7 @@ from PIL import Image, ImageDraw from rtree import index from scipy.ndimage import find_objects, label -from docling.datamodel.base_models import OcrCell, Page +from docling.datamodel.base_models import Cell, OcrCell, Page from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import OcrOptions from docling.datamodel.settings import settings @@ -98,7 +98,7 @@ class BaseOcrModel(BasePageModel): return ocr_rects # Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell. - def filter_ocr_cells(self, ocr_cells, programmatic_cells): + def _filter_ocr_cells(self, ocr_cells, programmatic_cells): # Create R-tree index for programmatic cells p = index.Property() p.dimension = 2 @@ -119,6 +119,23 @@ class BaseOcrModel(BasePageModel): ] return filtered_ocr_cells + def post_process_cells(self, ocr_cells, programmatic_cells): + r""" + Post-process the ocr and programmatic cells and return the final list of cells + """ + if self.options.force_full_page_ocr: + # If a full page OCR is forced, use only the OCR cells + cells = [ + Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox) + for c_ocr in ocr_cells + ] + return cells + + ## Remove OCR cells which overlap with programmatic cells. 
+ filtered_ocr_cells = self._filter_ocr_cells(ocr_cells, programmatic_cells) + programmatic_cells.extend(filtered_ocr_cells) + return programmatic_cells + def draw_ocr_rects_and_cells(self, conv_res, page, ocr_rects, show: bool = False): image = copy.deepcopy(page.image) draw = ImageDraw.Draw(image, "RGBA") diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py index 824242ca..f8d0cf8d 100644 --- a/docling/models/easyocr_model.py +++ b/docling/models/easyocr_model.py @@ -88,18 +88,8 @@ class EasyOcrModel(BaseOcrModel): ] all_ocr_cells.extend(cells) - if self.options.force_full_page_ocr: - # If a full page OCR is forced, use only the OCR cells - page.cells = [ - Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox) - for c_ocr in all_ocr_cells - ] - else: - ## Remove OCR cells which overlap with programmatic cells. - filtered_ocr_cells = self.filter_ocr_cells( - all_ocr_cells, page.cells - ) - page.cells.extend(filtered_ocr_cells) + # Post-process the cells + page.cells = self.post_process_cells(all_ocr_cells, page.cells) # DEBUG code: if settings.debug.visualize_ocr: diff --git a/docling/models/tesseract_ocr_cli_model.py b/docling/models/tesseract_ocr_cli_model.py index daee0572..9a50eee0 100644 --- a/docling/models/tesseract_ocr_cli_model.py +++ b/docling/models/tesseract_ocr_cli_model.py @@ -170,18 +170,8 @@ class TesseractOcrCliModel(BaseOcrModel): ) all_ocr_cells.append(cell) - if self.options.force_full_page_ocr: - # If a full page OCR is forced, use only the OCR cells - page.cells = [ - Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox) - for c_ocr in all_ocr_cells - ] - else: - ## Remove OCR cells which overlap with programmatic cells. 
- filtered_ocr_cells = self.filter_ocr_cells( - all_ocr_cells, page.cells - ) - page.cells.extend(filtered_ocr_cells) + # Post-process the cells + page.cells = self.post_process_cells(all_ocr_cells, page.cells) # DEBUG code: if settings.debug.visualize_ocr: diff --git a/docling/models/tesseract_ocr_model.py b/docling/models/tesseract_ocr_model.py index bb33327d..b2bd358b 100644 --- a/docling/models/tesseract_ocr_model.py +++ b/docling/models/tesseract_ocr_model.py @@ -140,18 +140,8 @@ class TesseractOcrModel(BaseOcrModel): # del high_res_image all_ocr_cells.extend(cells) - if self.options.force_full_page_ocr: - # If a full page OCR is forced, use only the OCR cells - page.cells = [ - Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox) - for c_ocr in all_ocr_cells - ] - else: - ## Remove OCR cells which overlap with programmatic cells. - filtered_ocr_cells = self.filter_ocr_cells( - all_ocr_cells, page.cells - ) - page.cells.extend(filtered_ocr_cells) + # Post-process the cells + page.cells = self.post_process_cells(all_ocr_cells, page.cells) # DEBUG code: if settings.debug.visualize_ocr: