feat: add options for choosing OCR engines (#118)

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Co-authored-by: Nikos Livathinos <nli@zurich.ibm.com>
Co-authored-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Michele Dolfi
2024-10-08 19:07:08 +02:00
committed by GitHub
parent d412c363d7
commit f96ea86a00
20 changed files with 699 additions and 32 deletions

View File

@@ -3,21 +3,21 @@ import logging
from abc import abstractmethod
from typing import Iterable, List, Tuple
import numpy
import numpy as np
from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import find_objects, label
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.datamodel.pipeline_options import OcrOptions
_log = logging.getLogger(__name__)
class BaseOcrModel:
def __init__(self, config):
self.config = config
self.enabled = config["enabled"]
def __init__(self, enabled: bool, options: OcrOptions):
self.enabled = enabled
self.options = options
# Computes the optimum amount and coordinates of rectangles to OCR on a given page
def get_ocr_rects(self, page: Page) -> Tuple[bool, List[BoundingBox]]:

View File

@@ -4,21 +4,33 @@ from typing import Iterable
import numpy
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.datamodel.pipeline_options import EasyOcrOptions
from docling.models.base_ocr_model import BaseOcrModel
_log = logging.getLogger(__name__)
class EasyOcrModel(BaseOcrModel):
def __init__(self, config):
super().__init__(config)
def __init__(self, enabled: bool, options: EasyOcrOptions):
super().__init__(enabled=enabled, options=options)
self.options: EasyOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
if self.enabled:
import easyocr
try:
import easyocr
except ImportError:
raise ImportError(
"EasyOCR is not installed. Please install it via `pip install easyocr` to use this OCR engine. "
"Alternatively, Docling has support for other OCR engines. See the documentation."
)
self.reader = easyocr.Reader(config["lang"])
self.reader = easyocr.Reader(
lang_list=self.options.lang,
model_storage_directory=self.options.model_storage_directory,
download_enabled=self.options.download_enabled,
)
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
@@ -31,6 +43,9 @@ class EasyOcrModel(BaseOcrModel):
all_ocr_cells = []
for ocr_rect in ocr_rects:
# Skip zero area boxes
if ocr_rect.area() == 0:
continue
high_res_image = page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect
)

View File

@@ -0,0 +1,167 @@
import io
import logging
import tempfile
from subprocess import PIPE, Popen
from typing import Iterable, Tuple
import pandas as pd
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.datamodel.pipeline_options import TesseractCliOcrOptions
from docling.models.base_ocr_model import BaseOcrModel
_log = logging.getLogger(__name__)
class TesseractOcrCliModel(BaseOcrModel):
def __init__(self, enabled: bool, options: TesseractCliOcrOptions):
super().__init__(enabled=enabled, options=options)
self.options: TesseractCliOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
self._name = None
self._version = None
if self.enabled:
try:
self._get_name_and_version()
except Exception as exc:
raise RuntimeError(
f"Tesseract is not available, aborting: {exc} "
"Install tesseract on your system and the tesseract binary is discoverable. "
"The actual command for Tesseract can be specified in `pipeline_options.ocr_options.tesseract_cmd='tesseract'`. "
"Alternatively, Docling has support for other OCR engines. See the documentation."
)
def _get_name_and_version(self) -> Tuple[str, str]:
if self._name != None and self._version != None:
return self._name, self._version
cmd = [self.options.tesseract_cmd, "--version"]
proc = Popen(cmd, stdout=PIPE, stderr=PIPE)
stdout, stderr = proc.communicate()
proc.wait()
# HACK: Windows versions of Tesseract output the version to stdout, Linux versions
# to stderr, so check both.
version_line = (
(stdout.decode("utf8").strip() or stderr.decode("utf8").strip())
.split("\n")[0]
.strip()
)
# If everything else fails...
if not version_line:
version_line = "tesseract XXX"
name, version = version_line.split(" ")
self._name = name
self._version = version
return name, version
def _run_tesseract(self, ifilename: str):
cmd = [self.options.tesseract_cmd]
if self.options.lang is not None and len(self.options.lang) > 0:
cmd.append("-l")
cmd.append("+".join(self.options.lang))
if self.options.path is not None:
cmd.append("--tessdata-dir")
cmd.append(self.options.path)
cmd += [ifilename, "stdout", "tsv"]
_log.info("command: {}".format(" ".join(cmd)))
proc = Popen(cmd, stdout=PIPE)
output, _ = proc.communicate()
# _log.info(output)
# Decode the byte string to a regular string
decoded_data = output.decode("utf-8")
# _log.info(decoded_data)
# Read the TSV file generated by Tesseract
df = pd.read_csv(io.StringIO(decoded_data), sep="\t")
# Display the dataframe (optional)
# _log.info("df: ", df.head())
# Filter rows that contain actual text (ignore header or empty rows)
df_filtered = df[df["text"].notnull() & (df["text"].str.strip() != "")]
return df_filtered
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
if not self.enabled:
yield from page_batch
return
for page in page_batch:
ocr_rects = self.get_ocr_rects(page)
all_ocr_cells = []
for ocr_rect in ocr_rects:
# Skip zero area boxes
if ocr_rect.area() == 0:
continue
high_res_image = page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect
)
with tempfile.NamedTemporaryFile(suffix=".png", mode="w") as image_file:
fname = image_file.name
high_res_image.save(fname)
df = self._run_tesseract(fname)
# _log.info(df)
# Print relevant columns (bounding box and text)
for ix, row in df.iterrows():
text = row["text"]
conf = row["conf"]
l = float(row["left"])
b = float(row["top"])
w = float(row["width"])
h = float(row["height"])
t = b + h
r = l + w
cell = OcrCell(
id=ix,
text=text,
confidence=conf / 100.0,
bbox=BoundingBox.from_tuple(
coord=(
(l / self.scale) + ocr_rect.l,
(b / self.scale) + ocr_rect.t,
(r / self.scale) + ocr_rect.l,
(t / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
),
)
all_ocr_cells.append(cell)
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)
page.cells.extend(filtered_ocr_cells)
# DEBUG code:
# self.draw_ocr_rects_and_cells(page, ocr_rects)
yield page

View File

@@ -0,0 +1,122 @@
import logging
from typing import Iterable
import numpy
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.datamodel.pipeline_options import TesseractCliOcrOptions
from docling.models.base_ocr_model import BaseOcrModel
_log = logging.getLogger(__name__)
class TesseractOcrModel(BaseOcrModel):
def __init__(self, enabled: bool, options: TesseractCliOcrOptions):
super().__init__(enabled=enabled, options=options)
self.options: TesseractCliOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
self.reader = None
if self.enabled:
setup_errmsg = (
"tesserocr is not correctly installed. "
"Please install it via `pip install tesserocr` to use this OCR engine. "
"Note that tesserocr might have to be manually compiled for working with"
"your Tesseract installation. The Docling documentation provides examples for it. "
"Alternatively, Docling has support for other OCR engines. See the documentation."
)
try:
import tesserocr
except ImportError:
raise ImportError(setup_errmsg)
try:
tesseract_version = tesserocr.tesseract_version()
_log.debug("Initializing TesserOCR: %s", tesseract_version)
except:
raise ImportError(setup_errmsg)
# Initialize the tesseractAPI
lang = "+".join(self.options.lang)
if self.options.path is not None:
self.reader = tesserocr.PyTessBaseAPI(
path=self.options.path,
lang=lang,
psm=tesserocr.PSM.AUTO,
init=True,
oem=tesserocr.OEM.DEFAULT,
)
else:
self.reader = tesserocr.PyTessBaseAPI(
lang=lang,
psm=tesserocr.PSM.AUTO,
init=True,
oem=tesserocr.OEM.DEFAULT,
)
self.reader_RIL = tesserocr.RIL
def __del__(self):
if self.reader is not None:
# Finalize the tesseractAPI
self.reader.End()
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
if not self.enabled:
yield from page_batch
return
for page in page_batch:
ocr_rects = self.get_ocr_rects(page)
all_ocr_cells = []
for ocr_rect in ocr_rects:
# Skip zero area boxes
if ocr_rect.area() == 0:
continue
high_res_image = page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect
)
# Retrieve text snippets with their bounding boxes
self.reader.SetImage(high_res_image)
boxes = self.reader.GetComponentImages(self.reader_RIL.TEXTLINE, True)
cells = []
for ix, (im, box, _, _) in enumerate(boxes):
# Set the area of interest. Tesseract uses Bottom-Left for the origin
self.reader.SetRectangle(box["x"], box["y"], box["w"], box["h"])
# Extract text within the bounding box
text = self.reader.GetUTF8Text().strip()
confidence = self.reader.MeanTextConf()
left = box["x"] / self.scale
bottom = box["y"] / self.scale
right = (box["x"] + box["w"]) / self.scale
top = (box["y"] + box["h"]) / self.scale
cells.append(
OcrCell(
id=ix,
text=text,
confidence=confidence,
bbox=BoundingBox.from_tuple(
coord=(left, top, right, bottom),
origin=CoordOrigin.TOPLEFT,
),
)
)
# del high_res_image
all_ocr_cells.extend(cells)
## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)
page.cells.extend(filtered_ocr_cells)
# DEBUG code:
# self.draw_ocr_rects_and_cells(page, ocr_rects)
yield page