mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-18 09:31:02 +00:00
feat: Add adaptive OCR, factor out treatment of OCR areas and cell filtering (#38)
* Introduce adaptive OCR Signed-off-by: Christoph Auer <cau@zurich.ibm.com> * Factor out BaseOcrModel, add docling-parse backend tests, fixes * Make easyocr default dep Signed-off-by: Christoph Auer <cau@zurich.ibm.com> --------- Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
124
docling/models/base_ocr_model.py
Normal file
124
docling/models/base_ocr_model.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import copy
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
from rtree import index
|
||||
from scipy.ndimage import find_objects, label
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseOcrModel:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.enabled = config["enabled"]
|
||||
|
||||
# Computes the optimum amount and coordinates of rectangles to OCR on a given page
|
||||
def get_ocr_rects(self, page: Page) -> Tuple[bool, List[BoundingBox]]:
|
||||
BITMAP_COVERAGE_TRESHOLD = 0.75
|
||||
|
||||
def find_ocr_rects(size, bitmap_rects):
|
||||
image = Image.new(
|
||||
"1", (round(size.width), round(size.height))
|
||||
) # '1' mode is binary
|
||||
|
||||
# Draw all bitmap rects into a binary image
|
||||
draw = ImageDraw.Draw(image)
|
||||
for rect in bitmap_rects:
|
||||
x0, y0, x1, y1 = rect.as_tuple()
|
||||
x0, y0, x1, y1 = round(x0), round(y0), round(x1), round(y1)
|
||||
draw.rectangle([(x0, y0), (x1, y1)], fill=1)
|
||||
|
||||
np_image = np.array(image)
|
||||
|
||||
# Find the connected components
|
||||
labeled_image, num_features = label(
|
||||
np_image > 0
|
||||
) # Label black (0 value) regions
|
||||
|
||||
# Find enclosing bounding boxes for each connected component.
|
||||
slices = find_objects(labeled_image)
|
||||
bounding_boxes = [
|
||||
BoundingBox(
|
||||
l=slc[1].start,
|
||||
t=slc[0].start,
|
||||
r=slc[1].stop - 1,
|
||||
b=slc[0].stop - 1,
|
||||
coord_origin=CoordOrigin.TOPLEFT,
|
||||
)
|
||||
for slc in slices
|
||||
]
|
||||
|
||||
# Compute area fraction on page covered by bitmaps
|
||||
area_frac = np.sum(np_image > 0) / (size.width * size.height)
|
||||
|
||||
return (area_frac, bounding_boxes) # fraction covered # boxes
|
||||
|
||||
bitmap_rects = page._backend.get_bitmap_rects()
|
||||
coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)
|
||||
|
||||
# return full-page rectangle if sufficiently covered with bitmaps
|
||||
if coverage > BITMAP_COVERAGE_TRESHOLD:
|
||||
return [
|
||||
BoundingBox(
|
||||
l=0,
|
||||
t=0,
|
||||
r=page.size.width,
|
||||
b=page.size.height,
|
||||
coord_origin=CoordOrigin.TOPLEFT,
|
||||
)
|
||||
]
|
||||
# return individual rectangles if the bitmap coverage is smaller
|
||||
elif coverage < BITMAP_COVERAGE_TRESHOLD:
|
||||
return ocr_rects
|
||||
|
||||
# Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
|
||||
def filter_ocr_cells(self, ocr_cells, programmatic_cells):
|
||||
# Create R-tree index for programmatic cells
|
||||
p = index.Property()
|
||||
p.dimension = 2
|
||||
idx = index.Index(properties=p)
|
||||
for i, cell in enumerate(programmatic_cells):
|
||||
idx.insert(i, cell.bbox.as_tuple())
|
||||
|
||||
def is_overlapping_with_existing_cells(ocr_cell):
|
||||
# Query the R-tree to get overlapping rectangles
|
||||
possible_matches_index = list(idx.intersection(ocr_cell.bbox.as_tuple()))
|
||||
|
||||
return (
|
||||
len(possible_matches_index) > 0
|
||||
) # this is a weak criterion but it works.
|
||||
|
||||
filtered_ocr_cells = [
|
||||
rect for rect in ocr_cells if not is_overlapping_with_existing_cells(rect)
|
||||
]
|
||||
return filtered_ocr_cells
|
||||
|
||||
def draw_ocr_rects_and_cells(self, page, ocr_rects):
|
||||
image = copy.deepcopy(page.image)
|
||||
draw = ImageDraw.Draw(image, "RGBA")
|
||||
|
||||
# Draw OCR rectangles as yellow filled rect
|
||||
for rect in ocr_rects:
|
||||
x0, y0, x1, y1 = rect.as_tuple()
|
||||
shade_color = (255, 255, 0, 40) # transparent yellow
|
||||
draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)
|
||||
|
||||
# Draw OCR and programmatic cells
|
||||
for tc in page.cells:
|
||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
||||
color = "red"
|
||||
if isinstance(tc, OcrCell):
|
||||
color = "magenta"
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline=color)
|
||||
image.show()
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
pass
|
||||
@@ -1,20 +1,18 @@
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyOcrModel:
|
||||
class EasyOcrModel(BaseOcrModel):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.enabled = config["enabled"]
|
||||
super().__init__(config)
|
||||
|
||||
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
|
||||
|
||||
if self.enabled:
|
||||
@@ -29,49 +27,44 @@ class EasyOcrModel:
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
# rects = page._fpage.
|
||||
high_res_image = page.get_image(scale=self.scale)
|
||||
im = numpy.array(high_res_image)
|
||||
result = self.reader.readtext(im)
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
del high_res_image
|
||||
del im
|
||||
|
||||
cells = [
|
||||
OcrCell(
|
||||
id=ix,
|
||||
text=line[1],
|
||||
confidence=line[2],
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=(
|
||||
line[0][0][0] / self.scale,
|
||||
line[0][0][1] / self.scale,
|
||||
line[0][2][0] / self.scale,
|
||||
line[0][2][1] / self.scale,
|
||||
),
|
||||
origin=CoordOrigin.TOPLEFT,
|
||||
),
|
||||
all_ocr_cells = []
|
||||
for ocr_rect in ocr_rects:
|
||||
high_res_image = page._backend.get_page_image(
|
||||
scale=self.scale, cropbox=ocr_rect
|
||||
)
|
||||
for ix, line in enumerate(result)
|
||||
]
|
||||
im = numpy.array(high_res_image)
|
||||
result = self.reader.readtext(im)
|
||||
|
||||
page.cells = cells # For now, just overwrites all digital cells.
|
||||
del high_res_image
|
||||
del im
|
||||
|
||||
cells = [
|
||||
OcrCell(
|
||||
id=ix,
|
||||
text=line[1],
|
||||
confidence=line[2],
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=(
|
||||
(line[0][0][0] / self.scale) + ocr_rect.l,
|
||||
(line[0][0][1] / self.scale) + ocr_rect.t,
|
||||
(line[0][2][0] / self.scale) + ocr_rect.l,
|
||||
(line[0][2][1] / self.scale) + ocr_rect.t,
|
||||
),
|
||||
origin=CoordOrigin.TOPLEFT,
|
||||
),
|
||||
)
|
||||
for ix, line in enumerate(result)
|
||||
]
|
||||
all_ocr_cells.extend(cells)
|
||||
|
||||
## Remove OCR cells which overlap with programmatic cells.
|
||||
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)
|
||||
|
||||
page.cells.extend(filtered_ocr_cells)
|
||||
|
||||
# DEBUG code:
|
||||
def draw_clusters_and_cells():
|
||||
image = copy.deepcopy(page.image)
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
cell_color = (
|
||||
random.randint(30, 140),
|
||||
random.randint(30, 140),
|
||||
random.randint(30, 140),
|
||||
)
|
||||
for tc in cells:
|
||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
|
||||
image.show()
|
||||
|
||||
# draw_clusters_and_cells()
|
||||
# self.draw_ocr_rects_and_cells(page, ocr_rects)
|
||||
|
||||
yield page
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import random
|
||||
from typing import Iterable, List
|
||||
|
||||
import numpy
|
||||
|
||||
Reference in New Issue
Block a user