added the tesseract_model.py

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2024-10-02 16:40:24 +02:00
parent bfdc4e32cc
commit 8d1c1d6dd5

View File

@ -2,6 +2,7 @@ import logging
from typing import Iterable from typing import Iterable
import numpy import numpy
import pandas as pd
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
from docling.datamodel.pipeline_options import TesseractOcrOptions from docling.datamodel.pipeline_options import TesseractOcrOptions
@ -9,8 +10,8 @@ from docling.models.base_ocr_model import BaseOcrModel
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class TesseractModel(BaseOcrModel): class TesseractModel(BaseOcrModel):
def __init__(self, enabled: bool, options: TesseractOcrOptions): def __init__(self, enabled: bool, options: TesseractOcrOptions):
super().__init__(enabled=enabled, options=options) super().__init__(enabled=enabled, options=options)
self.options: TesseractOcrOptions self.options: TesseractOcrOptions
@ -18,10 +19,65 @@ class TesseractModel(BaseOcrModel):
self.scale = 3 # multiplier for 72 dpi == 216 dpi. self.scale = 3 # multiplier for 72 dpi == 216 dpi.
if self.enabled: if self.enabled:
import tesserocr try:
self._get_name_and_version()
except Exception as exc:
_log.error(f"Tesseract is not supported, aborting ...")
self.enabled = False
def _get_name_and_version(self) -> Tuple[str, str]:
self.reader = easyocr.Reader(lang_list=self.options.lang) if self._name!=None and self._version!=None:
return self._name, self._version
cmd = ['tesseract', '--version']
proc = Popen(cmd, stdout=PIPE, stderr=PIPE)
stdout, stderr = proc.communicate()
proc.wait()
# HACK: Windows versions of Tesseract output the version to stdout, Linux versions
# to stderr, so check both.
version_line = (stdout.decode('utf8').strip() or stderr.decode('utf8').strip()).split('\n')[0].strip()
# If everything else fails...
if not version_line:
version_line = 'tesseract XXX'
name, version = version_line.split(' ')
self._name = name
self._version = version
return name, version
def _run_tesseract(self, ifilename, languages=None):
cmd = ["tesseract"]
if languages:
cmd += ['-l', '+'.join(languages)]
cmd += [ifilename, 'stdout', "tsv"]
logger.info("command: {}".format(" ".join(cmd)))
proc = Popen(cmd, stdout=PIPE)
output, _ = proc.communicate()
# Read the TSV file generated by Tesseract
df = pd.read_csv('output_file_name.tsv', sep='\t')
# Display the dataframe (optional)
print(df.head())
# Filter rows that contain actual text (ignore header or empty rows)
df_filtered = df[df['text'].notnull() & (df['text'].str.strip() != '')]
return df_filtered
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
if not self.enabled: if not self.enabled:
@ -36,37 +92,55 @@ class TesseractModel(BaseOcrModel):
high_res_image = page._backend.get_page_image( high_res_image = page._backend.get_page_image(
scale=self.scale, cropbox=ocr_rect scale=self.scale, cropbox=ocr_rect
) )
im = numpy.array(high_res_image) print(high_res_image)
result = self.reader.readtext(im)
# FIXME: do we really need to save the image to a file
fname = "temporary-file.png"
high_res_image.save(fname)
del high_res_image if os.path.exists(fname):
del im df = self._run_tesseract(fname)
os.remove(fname)
else:
_log.error(f"no image file: {fname}")
# Print relevant columns (bounding box and text)
for index, row in df_filtered.iterrows():
print(row)
text = row["text"]
conf = row["confidence"]
l = float(row['left'])
t = float(row['top'])
w = float(row['width'])
h = float(row['height'])
cells = [ b = t-h
OcrCell( r = l+w
cell = OcrCell(
id=ix, id=ix,
text=line[1], text=text,
confidence=line[2], confidence=line[2],
bbox=BoundingBox.from_tuple( bbox=BoundingBox.from_tuple(
coord=( coord=(
(line[0][0][0] / self.scale) + ocr_rect.l, (l / self.scale) + ocr_rect.l,
(line[0][0][1] / self.scale) + ocr_rect.t, (b / self.scale) + ocr_rect.t,
(line[0][2][0] / self.scale) + ocr_rect.l, (r / self.scale) + ocr_rect.l,
(line[0][2][1] / self.scale) + ocr_rect.t, (t / self.scale) + ocr_rect.t,
), ),
origin=CoordOrigin.TOPLEFT, origin=CoordOrigin.TOPLEFT,
), ),
) )
for ix, line in enumerate(result) all_ocr_cells.append(cell)
]
all_ocr_cells.extend(cells)
## Remove OCR cells which overlap with programmatic cells. ## Remove OCR cells which overlap with programmatic cells.
filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells) filtered_ocr_cells = self.filter_ocr_cells(all_ocr_cells, page.cells)
page.cells.extend(filtered_ocr_cells) page.cells.extend(filtered_ocr_cells)
# DEBUG code: # DEBUG code:
# self.draw_ocr_rects_and_cells(page, ocr_rects) self.draw_ocr_rects_and_cells(page, ocr_rects)
yield page yield page