mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-25 19:44:34 +00:00
Keep page.parsed_page.textline_cells and page.cells in sync, including OCR
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
6613b9e98b
commit
e310c5cff3
@ -130,19 +130,49 @@ class BaseOcrModel(BasePageModel, BaseModelWithOptions):
|
||||
]
|
||||
return filtered_ocr_cells
|
||||
|
||||
def post_process_cells(self, ocr_cells, programmatic_cells):
|
||||
def post_process_cells(self, ocr_cells, page):
|
||||
r"""
|
||||
Post-process the ocr and programmatic cells and return the final list of of cells
|
||||
Post-process the OCR cells and update the page object.
|
||||
Treats page.parsed_page as authoritative when available, with page.cells for compatibility.
|
||||
"""
|
||||
if self.options.force_full_page_ocr:
|
||||
# If a full page OCR is forced, use only the OCR cells
|
||||
cells = ocr_cells
|
||||
return cells
|
||||
# Get existing cells (prefer parsed_page, fallback to page.cells)
|
||||
existing_cells = self._get_existing_cells(page)
|
||||
|
||||
## Remove OCR cells which overlap with programmatic cells.
|
||||
filtered_ocr_cells = self._filter_ocr_cells(ocr_cells, programmatic_cells)
|
||||
programmatic_cells.extend(filtered_ocr_cells)
|
||||
return programmatic_cells
|
||||
# Combine existing and OCR cells with overlap filtering
|
||||
final_cells = self._combine_cells(existing_cells, ocr_cells)
|
||||
|
||||
# Update both structures efficiently
|
||||
self._update_page_structures(page, final_cells)
|
||||
|
||||
def _get_existing_cells(self, page):
|
||||
"""Get existing cells, preferring parsed_page when available."""
|
||||
return page.parsed_page.textline_cells if page.parsed_page else page.cells
|
||||
|
||||
def _combine_cells(self, existing_cells, ocr_cells):
|
||||
"""Combine existing and OCR cells with filtering and re-indexing."""
|
||||
if self.options.force_full_page_ocr:
|
||||
combined = ocr_cells
|
||||
else:
|
||||
filtered_ocr_cells = self._filter_ocr_cells(ocr_cells, existing_cells)
|
||||
combined = list(existing_cells) + filtered_ocr_cells
|
||||
|
||||
# Re-index in-place
|
||||
for i, cell in enumerate(combined):
|
||||
cell.index = i
|
||||
|
||||
return combined
|
||||
|
||||
def _update_page_structures(self, page, final_cells):
|
||||
"""Update both page structures efficiently."""
|
||||
if page.parsed_page:
|
||||
# Update parsed_page as primary source
|
||||
page.parsed_page.textline_cells = final_cells
|
||||
page.parsed_page.has_lines = bool(final_cells)
|
||||
# Sync to page.cells for compatibility
|
||||
page.cells = final_cells
|
||||
else:
|
||||
# Legacy fallback: only page.cells available
|
||||
page.cells = final_cells
|
||||
|
||||
def draw_ocr_rects_and_cells(self, conv_res, page, ocr_rects, show: bool = False):
|
||||
image = copy.deepcopy(page.image)
|
||||
|
@ -177,7 +177,7 @@ class EasyOcrModel(BaseOcrModel):
|
||||
all_ocr_cells.extend(cells)
|
||||
|
||||
# Post-process the cells
|
||||
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
|
||||
self.post_process_cells(all_ocr_cells, page)
|
||||
|
||||
# DEBUG code:
|
||||
if settings.debug.visualize_ocr:
|
||||
|
@ -176,9 +176,9 @@ class LayoutModel(BasePageModel):
|
||||
# Apply postprocessing
|
||||
|
||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||
page.cells, clusters, page.size
|
||||
page, clusters
|
||||
).postprocess()
|
||||
# processed_clusters, processed_cells = clusters, page.cells
|
||||
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
@ -198,7 +198,7 @@ class LayoutModel(BasePageModel):
|
||||
)
|
||||
)
|
||||
|
||||
page.cells = processed_cells
|
||||
# page.cells is already updated by LayoutPostprocessor
|
||||
page.predictions.layout = LayoutPrediction(
|
||||
clusters=processed_clusters
|
||||
)
|
||||
|
@ -132,7 +132,7 @@ class OcrMacModel(BaseOcrModel):
|
||||
all_ocr_cells.extend(cells)
|
||||
|
||||
# Post-process the cells
|
||||
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
|
||||
self.post_process_cells(all_ocr_cells, page)
|
||||
|
||||
# DEBUG code:
|
||||
if settings.debug.visualize_ocr:
|
||||
|
@ -306,7 +306,7 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
all_ocr_cells.append(cell)
|
||||
|
||||
# Post-process the cells
|
||||
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
|
||||
self.post_process_cells(all_ocr_cells, page)
|
||||
|
||||
# DEBUG code:
|
||||
if settings.debug.visualize_ocr:
|
||||
|
@ -235,7 +235,7 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
all_ocr_cells.extend(cells)
|
||||
|
||||
# Post-process the cells
|
||||
page.cells = self.post_process_cells(all_ocr_cells, page.cells)
|
||||
self.post_process_cells(all_ocr_cells, page)
|
||||
|
||||
# DEBUG code:
|
||||
if settings.debug.visualize_ocr:
|
||||
|
@ -194,11 +194,12 @@ class LayoutPostprocessor:
|
||||
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
||||
}
|
||||
|
||||
def __init__(self, cells: List[TextCell], clusters: List[Cluster], page_size: Size):
|
||||
"""Initialize processor with cells and clusters."""
|
||||
"""Initialize processor with cells and spatial indices."""
|
||||
self.cells = cells
|
||||
self.page_size = page_size
|
||||
def __init__(self, page, clusters: List[Cluster]):
|
||||
"""Initialize processor with page and clusters."""
|
||||
# Get cells from best available source (prefer parsed_page)
|
||||
self.cells = self._get_page_cells(page)
|
||||
self.page = page
|
||||
self.page_size = page.size
|
||||
self.all_clusters = clusters
|
||||
self.regular_clusters = [
|
||||
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
||||
@ -214,6 +215,24 @@ class LayoutPostprocessor:
|
||||
[c for c in self.special_clusters if c.label in self.WRAPPER_TYPES]
|
||||
)
|
||||
|
||||
def _get_page_cells(self, page):
|
||||
"""Get cells from best available source (prefer parsed_page)."""
|
||||
return (
|
||||
page.parsed_page.textline_cells
|
||||
if page.parsed_page is not None
|
||||
else page.cells
|
||||
)
|
||||
|
||||
def _update_page_structures(self, final_cells):
|
||||
"""Update both page structures efficiently."""
|
||||
if self.page.parsed_page is not None:
|
||||
# Update parsed_page as primary source
|
||||
self.page.parsed_page.textline_cells = final_cells
|
||||
self.page.parsed_page.has_lines = len(final_cells) > 0
|
||||
|
||||
# Legacy fallback: only page.cells available
|
||||
self.page.cells = final_cells
|
||||
|
||||
def postprocess(self) -> Tuple[List[Cluster], List[TextCell]]:
|
||||
"""Main processing pipeline."""
|
||||
self.regular_clusters = self._process_regular_clusters()
|
||||
@ -240,6 +259,9 @@ class LayoutPostprocessor:
|
||||
for child in cluster.children:
|
||||
child.cells = self._sort_cells(child.cells)
|
||||
|
||||
# Update page structures with processed cells
|
||||
self._update_page_structures(self.cells)
|
||||
|
||||
return final_clusters, self.cells
|
||||
|
||||
def _process_regular_clusters(self) -> List[Cluster]:
|
||||
|
Loading…
Reference in New Issue
Block a user