feat: add support for google ocr

Add support for Google OCR (Google Cloud Vision) as an OCR engine.

Signed-off-by: Mr.Haddad <bushr.haddad@gmail.com>
This commit is contained in:
Mr.Haddad 2025-01-08 12:10:48 +03:00
parent ead396ab40
commit 88e86d4235
7 changed files with 392 additions and 2 deletions

View File

@ -29,6 +29,7 @@ from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
GoogleOcrOptions,
OcrEngine,
OcrMacOptions,
OcrOptions,
@ -347,6 +348,8 @@ def convert(
ocr_options = OcrMacOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.RAPIDOCR:
ocr_options = RapidOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.GOOGLEOCR:
ocr_options = GoogleOcrOptions(force_full_page_ocr=force_ocr)
else:
raise RuntimeError(f"Unexpected OCR engine type {ocr_engine}")

View File

@ -151,6 +151,19 @@ class EasyOcrOptions(OcrOptions):
)
class GoogleOcrOptions(OcrOptions):
    """Options for the GoogleOcr engine."""

    kind: Literal["googleocr"] = "googleocr"

    # Language codes; the OCR model forwards them to the API as
    # ``language_hints``.
    lang: List[str] = ["en", "de"]

    # Path to the service-account credentials file. The default is resolved
    # from the GOOGLE_CONFIG_FILE_PATH environment variable each time an
    # options object is created (not once at import time), so a variable
    # exported after this module was imported is still honoured.
    google_ocr_config_file_path: Optional[str] = Field(
        default_factory=lambda: os.getenv("GOOGLE_CONFIG_FILE_PATH")
    )

    # Despite the name, this is the API endpoint host name: the OCR model
    # passes it to the client as ``api_endpoint``.
    google_ocr_region: Optional[str] = "eu-vision.googleapis.com"

    model_config = ConfigDict(
        extra="forbid",
    )
class TesseractCliOcrOptions(OcrOptions):
"""Options for the TesseractCli engine."""
@ -207,6 +220,7 @@ class OcrEngine(str, Enum):
TESSERACT = "tesseract"
OCRMAC = "ocrmac"
RAPIDOCR = "rapidocr"
GOOGLEOCR = "googleocr"
class PipelineOptions(BaseModel):
@ -233,6 +247,7 @@ class PdfPipelineOptions(PipelineOptions):
TesseractOcrOptions,
OcrMacOptions,
RapidOcrOptions,
GoogleOcrOptions,
] = Field(EasyOcrOptions(), discriminator="kind")
images_scale: float = 1.0

View File

@ -0,0 +1,180 @@
import io
import logging
from typing import Iterable
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import GoogleOcrOptions
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class GoogleOcrModel(BaseOcrModel):
    """OCR model that sends page crops to the Google Cloud Vision API.

    For every OCR rectangle of a page, the crop is rendered at ``self.scale``
    times the page resolution, encoded as PNG and submitted to
    ``ImageAnnotatorClient.text_detection``. Each word of the returned
    ``full_text_annotation`` becomes one ``OcrCell`` in page coordinates.
    """

    def __init__(self, enabled: bool, options: GoogleOcrOptions):
        """Create the model and, when enabled, the Vision API client.

        Args:
            enabled: Whether OCR should run. When False no Google library is
                imported and no client is created.
            options: Engine options (language hints, credentials file path,
                API endpoint).

        Raises:
            ImportError: If the Google client libraries are not installed.
            FileNotFoundError: If no credentials file path is configured.
        """
        super().__init__(enabled=enabled, options=options)
        self.options: GoogleOcrOptions

        self.scale = 3  # multiplier for 72 dpi == 216 dpi.
        self.reader = None

        if not self.enabled:
            return

        # Only the imports are guarded, so that an ImportError raised later
        # (e.g. while loading credentials) is not misreported as a missing
        # package.
        try:
            from google.cloud import vision
            from google.oauth2 import service_account
        except ImportError as exc:
            raise ImportError(
                "Failed to import required libraries for Google OCR. Ensure that the "
                "'google-cloud-vision' and 'google-auth' packages are installed. "
                "You can install them using 'pip install google-cloud-vision google-auth'."
            ) from exc

        # Initialize the Google Cloud Vision client
        _log.debug("Initializing Google OCR ")
        self.image_context = {"language_hints": self.options.lang}
        client_options = {"api_endpoint": self.options.google_ocr_region}

        if self.options.google_ocr_config_file_path is None:
            raise FileNotFoundError(
                "Google OCR Config File is missing. Please provide a valid file path "
                "via the GOOGLE_CONFIG_FILE_PATH environment variable."
            )

        google_creds = service_account.Credentials.from_service_account_file(
            self.options.google_ocr_config_file_path
        )
        self.reader = vision.ImageAnnotatorClient(
            credentials=google_creds, client_options=client_options
        )

    def _ocr_rect_to_cells(self, page: Page, ocr_rect) -> list:
        """Run OCR on one rectangle of ``page`` and return its word cells.

        Args:
            page: Page whose backend renders the crop.
            ocr_rect: Crop rectangle, as produced by ``get_ocr_rects``; its
                ``l`` / ``t`` offsets map results back to page coordinates.

        Returns:
            A list of ``OcrCell`` objects, one per detected word, with ids
            starting at 0 for this rectangle.
        """
        # Already imported successfully in __init__, so this is only a cheap
        # lookup of the loaded module.
        from google.cloud import vision

        high_res_image = page._backend.get_page_image(
            scale=self.scale, cropbox=ocr_rect
        )

        # The API takes the encoded image bytes, not a Pillow image.
        with io.BytesIO() as buffer:
            high_res_image.save(buffer, "PNG")
            content = buffer.getvalue()

        google_response = self.reader.text_detection(
            image=vision.Image(content=content),
            image_context=self.image_context,
        )

        # The API reports per-image failures inside the response rather than
        # by raising. Surface them in the log; any annotations that were
        # returned are still processed below.
        if google_response.error.message:
            _log.error(
                "Google OCR returned an error: %s", google_response.error.message
            )

        cells = []
        for file_page in google_response.full_text_annotation.pages:
            for block in file_page.blocks:
                for paragraph in block.paragraphs:
                    for word in paragraph.words:
                        text = "".join(symbol.text for symbol in word.symbols)
                        confidence = word.confidence * 100

                        # Axis-aligned box around the word polygon, scaled
                        # back from image pixels and shifted by the crop
                        # origin.
                        vertices = word.bounding_box.vertices
                        xs = [vertex.x for vertex in vertices]
                        ys = [vertex.y for vertex in vertices]
                        left = (min(xs) / self.scale) + ocr_rect.l
                        right = (max(xs) / self.scale) + ocr_rect.l
                        top = (min(ys) / self.scale) + ocr_rect.t
                        bottom = (max(ys) / self.scale) + ocr_rect.t

                        cells.append(
                            OcrCell(
                                id=len(cells),
                                text=text,
                                confidence=confidence,
                                bbox=BoundingBox.from_tuple(
                                    coord=(left, top, right, bottom),
                                    origin=CoordOrigin.TOPLEFT,
                                ),
                            )
                        )
        return cells

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Run OCR on each page of the batch and yield the pages.

        Pages are passed through untouched when the model is disabled or the
        page backend is invalid. Otherwise ``page.cells`` is replaced by the
        result of ``post_process_cells`` on the OCR cells and existing cells.

        Args:
            conv_res: Conversion result, used for timing and debug drawing.
            page_batch: Pages to process.

        Yields:
            Every page of ``page_batch``, in order.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
                continue

            with TimeRecorder(conv_res, "ocr"):
                assert self.reader is not None
                ocr_rects = self.get_ocr_rects(page)

                all_ocr_cells = []
                for ocr_rect in ocr_rects:
                    # Skip zero area boxes
                    if ocr_rect.area() == 0:
                        continue
                    all_ocr_cells.extend(self._ocr_rect_to_cells(page, ocr_rect))

                # Post-process the cells
                page.cells = self.post_process_cells(all_ocr_cells, page.cells)

            # DEBUG code:
            if settings.debug.visualize_ocr:
                self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects, show=True)

            yield page

View File

@ -11,6 +11,7 @@ from docling.datamodel.base_models import AssembledUnit, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
GoogleOcrOptions,
OcrMacOptions,
PdfPipelineOptions,
RapidOcrOptions,
@ -20,6 +21,7 @@ from docling.datamodel.pipeline_options import (
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.ds_glm_model import GlmModel, GlmOptions
from docling.models.easyocr_model import EasyOcrModel
from docling.models.google_ocr_model import GoogleOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.ocr_mac_model import OcrMacModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
@ -143,6 +145,11 @@ class StandardPdfPipeline(PaginatedPipeline):
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
elif isinstance(self.pipeline_options.ocr_options, GoogleOcrOptions):
return GoogleOcrModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
return None
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:

183
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
[[package]]
name = "aiohappyeyeballs"
@ -376,6 +376,17 @@ webencodings = "*"
[package.extras]
css = ["tinycss2 (>=1.1.0,<1.5)"]
[[package]]
name = "cachetools"
version = "5.5.0"
description = "Extensible memoizing collections and decorators"
optional = false
python-versions = ">=3.7"
files = [
{file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"},
{file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"},
]
[[package]]
name = "certifi"
version = "2024.8.30"
@ -1355,6 +1366,102 @@ gitdb = ">=4.0.1,<5"
doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"]
test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"]
[[package]]
name = "google-api-core"
version = "2.24.0"
description = "Google API client core library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_api_core-2.24.0-py3-none-any.whl", hash = "sha256:10d82ac0fca69c82a25b3efdeefccf6f28e02ebb97925a8cce8edbfe379929d9"},
{file = "google_api_core-2.24.0.tar.gz", hash = "sha256:e255640547a597a4da010876d333208ddac417d60add22b6851a0c66a831fcaf"},
]
[package.dependencies]
google-auth = ">=2.14.1,<3.0.dev0"
googleapis-common-protos = ">=1.56.2,<2.0.dev0"
grpcio = [
{version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
{version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
]
grpcio-status = [
{version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""},
{version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""},
]
proto-plus = [
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
]
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
requests = ">=2.18.0,<3.0.0.dev0"
[package.extras]
async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"]
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"]
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"]
[[package]]
name = "google-auth"
version = "2.37.0"
description = "Google Authentication Library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_auth-2.37.0-py2.py3-none-any.whl", hash = "sha256:42664f18290a6be591be5329a96fe30184be1a1badb7292a7f686a9659de9ca0"},
{file = "google_auth-2.37.0.tar.gz", hash = "sha256:0054623abf1f9c83492c63d3f47e77f0a544caa3d40b2d98e099a611c2dd5d00"},
]
[package.dependencies]
cachetools = ">=2.0.0,<6.0"
pyasn1-modules = ">=0.2.1"
rsa = ">=3.1.4,<5"
[package.extras]
aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"]
enterprise-cert = ["cryptography", "pyopenssl"]
pyjwt = ["cryptography (>=38.0.3)", "pyjwt (>=2.0)"]
pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
reauth = ["pyu2f (>=0.1.5)"]
requests = ["requests (>=2.20.0,<3.0.0.dev0)"]
[[package]]
name = "google-cloud-vision"
version = "3.9.0"
description = "Google Cloud Vision API client library"
optional = false
python-versions = ">=3.7"
files = [
{file = "google_cloud_vision-3.9.0-py2.py3-none-any.whl", hash = "sha256:9acec27ee05bd197f0d89c97e9719712ef245e0c37fd428e6af09a15a082fbef"},
{file = "google_cloud_vision-3.9.0.tar.gz", hash = "sha256:21226aac9cb4ba45bf89cc2e107aea19e4f78f9736eb1de56837e0c2989fecff"},
]
[package.dependencies]
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
proto-plus = [
{version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""},
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
]
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
[[package]]
name = "googleapis-common-protos"
version = "1.66.0"
description = "Common protobufs used in Google APIs"
optional = false
python-versions = ">=3.7"
files = [
{file = "googleapis_common_protos-1.66.0-py2.py3-none-any.whl", hash = "sha256:d7abcd75fabb2e0ec9f74466401f6c119a0b498e27370e9be4c94cb7e382b8ed"},
{file = "googleapis_common_protos-1.66.0.tar.gz", hash = "sha256:c3e7b33d15fdca5374cc0a7346dd92ffa847425cc4ea941d970f13680052ec8c"},
]
[package.dependencies]
protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0"
[package.extras]
grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
[[package]]
name = "griffe"
version = "1.5.1"
@ -1450,6 +1557,22 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.67.1)"]
[[package]]
name = "grpcio-status"
version = "1.67.1"
description = "Status proto mapping for gRPC"
optional = false
python-versions = ">=3.8"
files = [
{file = "grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd"},
{file = "grpcio_status-1.67.1.tar.gz", hash = "sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11"},
]
[package.dependencies]
googleapis-common-protos = ">=1.5.5"
grpcio = ">=1.67.1"
protobuf = ">=5.26.1,<6.0dev"
[[package]]
name = "h11"
version = "0.14.0"
@ -4363,6 +4486,23 @@ files = [
{file = "propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64"},
]
[[package]]
name = "proto-plus"
version = "1.25.0"
description = "Beautiful, Pythonic protocol buffers."
optional = false
python-versions = ">=3.7"
files = [
{file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"},
{file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"},
]
[package.dependencies]
protobuf = ">=3.19.0,<6.0.0dev"
[package.extras]
testing = ["google-api-core (>=1.31.5)"]
[[package]]
name = "protobuf"
version = "5.29.1"
@ -4492,6 +4632,31 @@ files = [
[package.extras]
test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"]
[[package]]
name = "pyasn1"
version = "0.6.1"
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
]
[[package]]
name = "pyasn1-modules"
version = "0.4.1"
description = "A collection of ASN.1-based protocols modules"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
{file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
]
[package.dependencies]
pyasn1 = ">=0.4.6,<0.7.0"
[[package]]
name = "pyclipper"
version = "1.3.0.post6"
@ -5843,6 +6008,20 @@ files = [
{file = "rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d"},
]
[[package]]
name = "rsa"
version = "4.9"
description = "Pure-Python RSA implementation"
optional = false
python-versions = ">=3.6,<4"
files = [
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
]
[package.dependencies]
pyasn1 = ">=0.1.3"
[[package]]
name = "rtree"
version = "1.3.0"
@ -7613,4 +7792,4 @@ tesserocr = ["tesserocr"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "e83ff77c43954474022132b205f9b0156014580d4a2b7d60e6daa756ec2e6433"
content-hash = "4ae1dbfdbaaf1d91a0fd40a1dc583b16b62f17d05dbc3b96f467b07ed79a139a"

View File

@ -36,6 +36,10 @@ pydantic-settings = "^2.3.0"
huggingface_hub = ">=0.23,<1"
requests = "^2.32.3"
easyocr = "^1.7"
google-api-core="^2.13.0"
google-auth="^2.23.4"
google-cloud-vision="^3.4.5"
googleapis-common-protos="^1.61.0"
tesserocr = { version = "^2.7.1", optional = true }
certifi = ">=2024.7.4"
rtree = "^1.3.0"

View File

@ -7,6 +7,7 @@ from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
GoogleOcrOptions,
OcrMacOptions,
OcrOptions,
PdfPipelineOptions,
@ -62,6 +63,7 @@ def test_e2e_conversions():
TesseractOcrOptions(force_full_page_ocr=True),
TesseractCliOcrOptions(force_full_page_ocr=True),
RapidOcrOptions(force_full_page_ocr=True),
GoogleOcrOptions(force_full_page_ocr=True),
]
# only works on mac