import argparse
import logging
import os
import re
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Optional, Union
import numpy as np
from docling_core.types.doc import (
DoclingDocument,
ImageRefMode,
NodeItem,
TextItem,
)
from docling_core.types.doc.document import (
ContentLayer,
DocItem,
FormItem,
GraphCell,
KeyValueItem,
PictureItem,
RichTableCell,
TableCell,
TableItem,
)
from PIL import Image, ImageFilter
from pydantic import BaseModel, ConfigDict
from tqdm import tqdm
from docling.backend.json.docling_json_backend import DoclingJSONBackend
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
ConvertPipelineOptions,
PdfPipelineOptions,
PictureDescriptionApiOptions,
)
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
from docling.exceptions import OperationNotAllowed
from docling.models.base_model import BaseModelWithOptions, GenericEnrichmentModel
from docling.pipeline.simple_pipeline import SimplePipeline
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
from docling.utils.api_image_request import api_image_request
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify
# Example of applying OCR to a Docling document as a post-processing step,
# using "nanonets-ocr2-3b" served via LM Studio.
# Requires a running LM Studio inference server with the "nanonets-ocr2-3b" model pre-loaded.
# To run:
# uv run python docs/examples/post_process_ocr_with_vlm.py
LM_STUDIO_URL = "http://localhost:1234/v1/chat/completions"
LM_STUDIO_MODEL = "nanonets-ocr2-3b"
DEFAULT_PROMPT = "Extract the text from the above document as if you were reading it naturally. Output pure text, no html and no markdown. Pay attention to line breaks and don't miss text after a line break. Put all text in one line."
VERBOSE = True
SHOW_IMAGE = False
SHOW_EMPTY_CROPS = False
SHOW_NONEMPTY_CROPS = False
PRINT_RESULT_MARKDOWN = False
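

# Optional connectivity check (a sketch, not part of the Docling API): LM Studio's
# OpenAI-compatible server also exposes GET /v1/models, so we can fail fast when no
# server is running instead of timing out on every element. Uses only the stdlib.
def check_lm_studio_reachable(timeout: float = 5.0) -> bool:
    import urllib.request

    # Derive the models endpoint from the chat-completions URL defined above
    models_url = LM_STUDIO_URL.rsplit("/chat/completions", 1)[0] + "/models"
    try:
        with urllib.request.urlopen(models_url, timeout=timeout) as resp:
            return resp.status == 200
    except OSError:
        return False
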
def is_empty_fast_with_lines_pil(
    pil_img: Image.Image,
    downscale_max_side: int = 48,
    grad_threshold: float = 15.0,  # how strong a gradient must be to count as an edge
    min_line_coverage: float = 0.6,  # a line must cover 60% of the height/width
    max_allowed_lines: int = 10,  # allow up to this many strong lines
    edge_fraction_threshold: float = 0.0035,
):
"""
Fast 'empty' detector using only PIL + NumPy.
Treats an image as empty if:
- It has very few edges overall, OR
- Edges can be explained by at most `max_allowed_lines` long vertical/horizontal lines.
Returns:
(is_empty: bool, remaining_edge_fraction: float, debug: dict)
"""
# 1) Convert to grayscale
gray = pil_img.convert("L")
    # 2) Aggressive downscale to a small fixed square (aspect ratio is not preserved)
    w0, h0 = gray.size
    max_side = max(w0, h0)
    if max_side > downscale_max_side:
        new_w = downscale_max_side
        new_h = downscale_max_side
        gray = gray.resize((new_w, new_h), resample=Image.BILINEAR)
w, h = gray.size
if w == 0 or h == 0:
return True, 0.0, {"reason": "zero_size"}
# 3) Small blur to reduce noise
gray = gray.filter(ImageFilter.BoxBlur(1))
    # 4) Convert to NumPy: np.asarray gives shape (h, w), while PIL's .size is (w, h)
    arr = np.asarray(gray, dtype=np.float32)
H, W = arr.shape
# 5) Compute simple gradients (forward differences)
gx = np.zeros_like(arr)
gy = np.zeros_like(arr)
gx[:, :-1] = arr[:, 1:] - arr[:, :-1] # horizontal differences
gy[:-1, :] = arr[1:, :] - arr[:-1, :] # vertical differences
mag = np.hypot(gx, gy) # gradient magnitude
# 6) Threshold gradients to get edges (boolean mask)
edges = mag > grad_threshold
edge_fraction = edges.mean()
# Quick early-exit: almost no edges => empty
if edge_fraction < edge_fraction_threshold:
return True, float(edge_fraction), {"reason": "few_edges"}
# 7) Detect strong vertical & horizontal lines via edge sums
col_sum = edges.sum(axis=0) # per column
row_sum = edges.sum(axis=1) # per row
# Line must have edge pixels in at least `min_line_coverage` of the dimension
vert_line_cols = np.where(col_sum >= min_line_coverage * H)[0]
horiz_line_rows = np.where(row_sum >= min_line_coverage * W)[0]
num_lines = len(vert_line_cols) + len(horiz_line_rows)
# If we have more long lines than allowed => non-empty
if num_lines > max_allowed_lines:
return (
False,
float(edge_fraction),
{
"reason": "too_many_lines",
"num_lines": int(num_lines),
"edge_fraction": float(edge_fraction),
},
)
# 8) Mask out those lines and recompute remaining edges
line_mask = np.zeros_like(edges, dtype=bool)
if len(vert_line_cols) > 0:
line_mask[:, vert_line_cols] = True
if len(horiz_line_rows) > 0:
line_mask[horiz_line_rows, :] = True
remaining_edges = edges & ~line_mask
remaining_edge_fraction = remaining_edges.mean()
is_empty = remaining_edge_fraction < edge_fraction_threshold
debug = {
"original_edge_fraction": float(edge_fraction),
"remaining_edge_fraction": float(remaining_edge_fraction),
"num_vert_lines": len(vert_line_cols),
"num_horiz_lines": len(horiz_line_rows),
}
return is_empty, float(remaining_edge_fraction), debug
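

# Quick illustration of the detector (hypothetical usage, not called by the pipeline):
# a flat white crop should be flagged empty, while a crop full of random speckle
# should not, since its edges cannot be explained by a few long lines.
def _demo_empty_detector() -> None:
    blank = Image.new("L", (120, 40), color=255)
    noise = np.random.default_rng(0).integers(0, 256, size=(40, 120)).astype(np.uint8)
    speckled = Image.fromarray(noise)
    print(is_empty_fast_with_lines_pil(blank)[0])  # expected: True
    print(is_empty_fast_with_lines_pil(speckled)[0])  # expected: False
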
def remove_break_lines(text: str) -> str:
# Replace any newline types with a single space
cleaned = re.sub(r"[\r\n]+", " ", text)
# Collapse multiple spaces into one
cleaned = re.sub(r"\s+", " ", cleaned)
return cleaned.strip()
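
# For example, remove_break_lines("total:\n 42\r\nunits") returns "total: 42 units".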
def safe_crop(img: Image.Image, bbox):
left, top, right, bottom = bbox
# Clamp to image boundaries
left = max(0, min(left, img.width))
top = max(0, min(top, img.height))
right = max(0, min(right, img.width))
bottom = max(0, min(bottom, img.height))
return img.crop((left, top, right, bottom))
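
# For example, on a 100x100 image, safe_crop(img, (-10, 20, 150, 80)) behaves like
# img.crop((0, 20, 100, 80)): coordinates are clamped to the image bounds, whereas
# a raw PIL crop would silently pad the out-of-bounds region.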
def no_long_repeats(s: str, threshold: int) -> bool:
"""
Returns False if the string `s` contains more than `threshold`
identical characters in a row, otherwise True.
"""
pattern = r"(.)\1{" + str(threshold) + ",}"
return re.search(pattern, s) is None
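
# For example, no_long_repeats("aaab", threshold=2) is False (three identical
# characters in a row exceed the threshold), while no_long_repeats("abab", 2) is True.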
class PostOcrEnrichmentElement(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    item: Union[DocItem, TableCell, RichTableCell, GraphCell]
    image: list[Image.Image]  # Needs to be a list of images for multi-provenance elements
class PostOcrEnrichmentPipelineOptions(ConvertPipelineOptions):
api_options: PictureDescriptionApiOptions
class PostOcrEnrichmentPipeline(SimplePipeline):
def __init__(self, pipeline_options: PostOcrEnrichmentPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: PostOcrEnrichmentPipelineOptions
self.enrichment_pipe = [
PostOcrApiEnrichmentModel(
enabled=True,
enable_remote_services=True,
artifacts_path=None,
options=self.pipeline_options.api_options,
accelerator_options=AcceleratorOptions(),
)
]
@classmethod
def get_default_options(cls) -> PostOcrEnrichmentPipelineOptions:
return PostOcrEnrichmentPipelineOptions()
def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
def _prepare_elements(
conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
) -> Iterable[NodeItem]:
for doc_element, _level in conv_res.document.iterate_items(
traverse_pictures=True,
included_content_layers={
ContentLayer.BODY,
ContentLayer.FURNITURE,
},
): # With all content layers, with traverse_pictures=True
                prepared_elements = model.prepare_element(
                    conv_res=conv_res, element=doc_element
                )  # prepare_element may return multiple items
if prepared_elements is not None:
yield prepared_elements
with TimeRecorder(conv_res, "doc_enrich", scope=ProfilingScope.DOCUMENT):
for model in self.enrichment_pipe:
for element_batch in chunkify(
_prepare_elements(conv_res, model),
model.elements_batch_size,
):
for element in model(
doc=conv_res.document, element_batch=element_batch
): # Must exhaust!
pass
return conv_res
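

# Note on the batching loop above: chunkify groups the lazily prepared elements into
# batches of at most `elements_batch_size`; conceptually, chunkify(range(10), 4)
# yields the batches [0, 1, 2, 3], [4, 5, 6, 7], [8, 9].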
class PostOcrApiEnrichmentModel(
GenericEnrichmentModel[PostOcrEnrichmentElement], BaseModelWithOptions
):
    expansion_factor: float = 0.001  # overridden in __init__
def prepare_element(
self, conv_res: ConversionResult, element: NodeItem
) -> Optional[list[PostOcrEnrichmentElement]]:
if not self.is_processable(doc=conv_res.document, element=element):
return None
allowed = (DocItem, TableItem, GraphCell)
assert isinstance(element, allowed)
if isinstance(element, (KeyValueItem, FormItem)):
# Yield from the graphCells inside here.
result = []
for c in element.graph.cells:
element_prov = c.prov # Key / Value have only one provenance!
bbox = element_prov.bbox
page_ix = element_prov.page_no
bbox = bbox.scale_to_size(
old_size=conv_res.document.pages[page_ix].size,
new_size=conv_res.document.pages[page_ix].image.size,
)
expanded_bbox = bbox.expand_by_scale(
x_scale=self.expansion_factor, y_scale=self.expansion_factor
).to_top_left_origin(
page_height=conv_res.document.pages[page_ix].image.size.height
)
good_bbox = True
if (
expanded_bbox.l > expanded_bbox.r
or expanded_bbox.t > expanded_bbox.b
):
good_bbox = False
if good_bbox:
cropped_image = conv_res.document.pages[
page_ix
].image.pil_image.crop(expanded_bbox.as_tuple())
is_empty, rem_frac, debug = is_empty_fast_with_lines_pil(
cropped_image
)
if is_empty:
if SHOW_EMPTY_CROPS:
try:
cropped_image.show()
except Exception as e:
print(f"Error with image: {e}")
print(
f"Detected empty form item image crop: {rem_frac} - {debug}"
)
else:
result.append(
PostOcrEnrichmentElement(item=c, image=[cropped_image])
)
return result
elif isinstance(element, TableItem):
element_prov = element.prov[0]
page_ix = element_prov.page_no
result = []
for i, row in enumerate(element.data.grid):
for j, cell in enumerate(row):
if hasattr(cell, "bbox"):
if cell.bbox:
bbox = cell.bbox
bbox = bbox.scale_to_size(
old_size=conv_res.document.pages[page_ix].size,
new_size=conv_res.document.pages[page_ix].image.size,
)
expanded_bbox = bbox.expand_by_scale(
x_scale=self.table_cell_expansion_factor,
y_scale=self.table_cell_expansion_factor,
).to_top_left_origin(
page_height=conv_res.document.pages[
page_ix
].image.size.height
)
good_bbox = True
if (
expanded_bbox.l > expanded_bbox.r
or expanded_bbox.t > expanded_bbox.b
):
good_bbox = False
if good_bbox:
cropped_image = conv_res.document.pages[
page_ix
].image.pil_image.crop(expanded_bbox.as_tuple())
is_empty, rem_frac, debug = (
is_empty_fast_with_lines_pil(cropped_image)
)
if is_empty:
if SHOW_EMPTY_CROPS:
try:
cropped_image.show()
except Exception as e:
print(f"Error with image: {e}")
print(
f"Detected empty table cell image crop: {rem_frac} - {debug}"
)
else:
if SHOW_NONEMPTY_CROPS:
cropped_image.show()
result.append(
PostOcrEnrichmentElement(
item=cell, image=[cropped_image]
)
)
return result
else:
            multiple_crops = []
            # Crop the image from the page, iterating over all provenances
            for element_prov in element.prov:
bbox = element_prov.bbox
page_ix = element_prov.page_no
bbox = bbox.scale_to_size(
old_size=conv_res.document.pages[page_ix].size,
new_size=conv_res.document.pages[page_ix].image.size,
)
expanded_bbox = bbox.expand_by_scale(
x_scale=self.expansion_factor, y_scale=self.expansion_factor
).to_top_left_origin(
page_height=conv_res.document.pages[page_ix].image.size.height
)
good_bbox = True
if (
expanded_bbox.l > expanded_bbox.r
or expanded_bbox.t > expanded_bbox.b
):
good_bbox = False
if hasattr(element, "text"):
if good_bbox:
cropped_image = conv_res.document.pages[
page_ix
].image.pil_image.crop(expanded_bbox.as_tuple())
is_empty, rem_frac, debug = is_empty_fast_with_lines_pil(
cropped_image
)
if is_empty:
if SHOW_EMPTY_CROPS:
try:
cropped_image.show()
except Exception as e:
print(f"Error with image: {e}")
print(f"Detected empty text crop: {rem_frac} - {debug}")
else:
multiple_crops.append(cropped_image)
if hasattr(element, "text"):
print(f"\nOLD TEXT: {element.text}")
else:
print("Not a text element")
if len(multiple_crops) > 0:
# good crops
return [PostOcrEnrichmentElement(item=element, image=multiple_crops)]
else:
# nothing
return []
@classmethod
def get_options_type(cls) -> type[PictureDescriptionApiOptions]:
return PictureDescriptionApiOptions
def __init__(
self,
*,
enabled: bool,
enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionApiOptions,
accelerator_options: AcceleratorOptions,
):
self.enabled = enabled
self.options = options
self.concurrency = 2
self.expansion_factor = 0.05
self.table_cell_expansion_factor = 0.0 # do not modify table cell size
self.elements_batch_size = 4
self._accelerator_options = accelerator_options
self._artifacts_path = (
Path(artifacts_path) if isinstance(artifacts_path, str) else artifacts_path
)
if self.enabled and not enable_remote_services:
raise OperationNotAllowed(
"Enable remote services by setting pipeline_options.enable_remote_services=True."
)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
return self.enabled
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
def _api_request(image: Image.Image) -> str:
res = api_image_request(
image=image,
prompt=self.options.prompt,
url=self.options.url,
                timeout=30,  # hardcoded here instead of self.options.timeout
headers=self.options.headers,
**self.options.params,
)
return res[0]
with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
yield from executor.map(_api_request, images)
    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[list[PostOcrEnrichmentElement]],
    ) -> Iterable[NodeItem]:
        if not self.enabled:
            for element_stack in element_batch:
                for element in element_stack:
                    yield element.item
            return
        elements: list[Union[DocItem, TableCell, RichTableCell, GraphCell]] = []
        images: list[Image.Image] = []
        img_ind_per_element: list[int] = []
for element_stack in element_batch:
for element in element_stack:
allowed = (DocItem, TableCell, RichTableCell, GraphCell)
assert isinstance(element.item, allowed)
                for ind, img in enumerate(element.image):
                    elements.append(element.item)
                    images.append(img)
                    img_ind_per_element.append(ind)
if not images:
return
outputs = list(self._annotate_images(images))
for item, output, img_ind in zip(elements, outputs, img_ind_per_element):
            # The model sometimes returns HTML tags, which we don't need in the
            # output, so it's better to strip them.
            def clean_html_tags(text):
                for tag in [
                    "<table>",
                    "</table>",
                    "<tr>",
                    "</tr>",
                    "<td>",
                    "</td>",
                    "<th>",
                    "</th>",
                    "<thead>",
                    "</thead>",
                    "<tbody>",
                    "</tbody>",
                    "<strong>",
                    "</strong>",
                ]:
                    text = text.replace(tag, "")
                return text
output = clean_html_tags(output).strip()
output = remove_break_lines(output)
            # Last-resort guard against hallucinations:
            # drop a known hallucinated opening phrase
            if output.startswith("The first of these"):
                output = ""
if no_long_repeats(output, 50):
                if VERBOSE:
                    if isinstance(item, TextItem):
                        print(f"\nOLD TEXT: {item.text}")
                # Re-populate the text
                if isinstance(item, (TextItem, GraphCell)):
                    if img_ind > 0:
                        # Concatenate texts across several provenances
                        item.text += " " + output
                    else:
                        item.text = output
elif isinstance(item, (TableCell, RichTableCell)):
item.text = output
elif isinstance(item, PictureItem):
pass
else:
raise ValueError(f"Unknown item type: {type(item)}")
                if VERBOSE:
                    if isinstance(item, TextItem):
                        print(f"NEW TEXT: {item.text}")
# Take care of charspans for relevant types
if isinstance(item, GraphCell):
item.prov.charspan = (0, len(item.text))
elif isinstance(item, TextItem):
item.prov[0].charspan = (0, len(item.text))
yield item
def convert_pdf(pdf_path: Path, out_intermediate_json: Path):
# Let's prepare a Docling document json with embedded page images
pipeline_options = PdfPipelineOptions()
pipeline_options.generate_page_images = True
pipeline_options.generate_picture_images = True
    pipeline_options.images_scale = 2.0
    # All of the options below are optional; there are internal defaults.
    doc_converter = DocumentConverter(
        allowed_formats=[InputFormat.PDF],
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=StandardPdfPipeline, pipeline_options=pipeline_options
            )
        },
    )
if VERBOSE:
print(
"Converting PDF to get a Docling document json with embedded page images..."
)
conv_result = doc_converter.convert(pdf_path)
conv_result.document.save_as_json(
filename=out_intermediate_json, image_mode=ImageRefMode.EMBEDDED
)
if PRINT_RESULT_MARKDOWN:
md1 = conv_result.document.export_to_markdown()
print("*** ORIGINAL MARKDOWN ***")
print(md1)
def post_process_json(in_json: Path, out_final_json: Path):
# Post-Process OCR on top of existing Docling document, per element's bounding box:
print(f"Post-process all bounding boxes with OCR... {os.path.basename(in_json)}")
pipeline_options = PostOcrEnrichmentPipelineOptions(
api_options=PictureDescriptionApiOptions(
url=LM_STUDIO_URL,
prompt=DEFAULT_PROMPT,
provenance="lm-studio-ocr",
batch_size=4,
concurrency=2,
scale=2.0,
params={"model": LM_STUDIO_MODEL},
)
)
doc_converter = DocumentConverter(
format_options={
InputFormat.JSON_DOCLING: FormatOption(
pipeline_cls=PostOcrEnrichmentPipeline,
pipeline_options=pipeline_options,
backend=DoclingJSONBackend,
)
}
)
result = doc_converter.convert(in_json)
if SHOW_IMAGE:
result.document.pages[1].image.pil_image.show()
result.document.save_as_json(out_final_json)
if PRINT_RESULT_MARKDOWN:
md = result.document.export_to_markdown()
print("*** MARKDOWN ***")
print(md)
def process_pdf(pdf_path: Path, scratch_dir: Path, out_dir: Path):
inter_json = scratch_dir / (pdf_path.stem + ".json")
final_json = out_dir / (pdf_path.stem + ".json")
inter_json.parent.mkdir(parents=True, exist_ok=True)
final_json.parent.mkdir(parents=True, exist_ok=True)
    if final_json.exists() and final_json.stat().st_size > 0:
        print(f"Result already found at '{final_json}', skipping...")
        return  # already done
convert_pdf(pdf_path, inter_json)
post_process_json(inter_json, final_json)
def process_json(json_path: Path, out_dir: Path):
final_json = out_dir / (json_path.stem + ".json")
final_json.parent.mkdir(parents=True, exist_ok=True)
if final_json.exists() and final_json.stat().st_size > 0:
return # already done
post_process_json(json_path, final_json)
def filter_jsons_by_ocr_list(jsons, folder):
"""
jsons: list[Path] - JSON files
folder: Path - folder containing ocr_documents.txt
"""
ocr_file = folder / "ocr_documents.txt"
# If the file doesn't exist, return the list unchanged
if not ocr_file.exists():
return jsons
# Read file names (strip whitespace, ignore empty lines)
with ocr_file.open("r", encoding="utf-8") as f:
allowed = {line.strip() for line in f if line.strip()}
# Keep only JSONs whose stem is in allowed list
filtered = [p for p in jsons if p.stem in allowed]
return filtered
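
# A hypothetical ocr_documents.txt simply lists document stems, one per line:
#
#   invoice_001
#   report_2024_q1
#
# With that file present, only invoice_001.json and report_2024_q1.json are processed.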
def run_jsons(in_path: Path, out_dir: Path):
if in_path.is_dir():
jsons = sorted(in_path.glob("*.json"))
if not jsons:
raise SystemExit("Folder mode expects one or more .json files")
        # Look for ocr_documents.txt; if present, process only the JSONs listed in it
        filtered_jsons = filter_jsons_by_ocr_list(jsons, in_path)
        for j in tqdm(filtered_jsons):
            print(f"\nProcessing file... {j}")
process_json(j, out_dir)
else:
raise SystemExit("Invalid --in path")
def main():
logging.getLogger().setLevel(logging.ERROR)
p = argparse.ArgumentParser(description="PDF/JSON -> final JSON pipeline")
p.add_argument(
"--in",
dest="in_path",
default="tests/data/pdf/2305.03393v1-pg9.pdf",
required=False,
help="Path to a PDF/JSON file or a folder of JSONs",
)
p.add_argument(
"--out",
dest="out_dir",
default="scratch/",
required=False,
help="Folder for final JSONs (scratch goes inside)",
)
args = p.parse_args()
in_path = Path(args.in_path).expanduser().resolve()
out_dir = Path(args.out_dir).expanduser().resolve()
print(f"in_path: {in_path}")
print(f"out_dir: {out_dir}")
scratch_dir = out_dir / "temp"
if not in_path.exists():
raise SystemExit(f"Input not found: {in_path}")
if in_path.is_file():
if in_path.suffix.lower() == ".pdf":
process_pdf(in_path, scratch_dir, out_dir)
elif in_path.suffix.lower() == ".json":
process_json(in_path, out_dir)
else:
raise SystemExit("Single-file mode expects a .pdf or .json")
else:
run_jsons(in_path, out_dir)
if __name__ == "__main__":
main()