introduce img understand pipeline

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Michele Dolfi 2024-09-22 20:24:38 +02:00
parent 1f4b224ab6
commit a122a7be4c
8 changed files with 561 additions and 11 deletions

docling/datamodel/base_models.py

@@ -224,18 +224,27 @@ class TableStructurePrediction(BaseModel):
 class TextElement(BasePageElement): ...
 
 
+class FigureClassificationData(BaseModel):
+    provenance: str
+    predicted_class: str
+    confidence: float
+
+
+class FigureDescriptionData(BaseModel):
+    text: str
+    provenance: str = ""
+
+
 class FigureData(BaseModel):
-    pass
+    classification: Optional[FigureClassificationData] = None
+    description: Optional[FigureDescriptionData] = None
 
 
 class FigureElement(BasePageElement):
-    data: Optional[FigureData] = None
-    provenance: Optional[str] = None
-    predicted_class: Optional[str] = None
-    confidence: Optional[float] = None
+    data: FigureData = FigureData()
 
 
-class FigureClassificationPrediction(BaseModel):
+class FigurePrediction(BaseModel):
     figure_count: int = 0
     figure_map: Dict[int, FigureElement] = {}
@@ -248,7 +257,7 @@ class EquationPrediction(BaseModel):
 class PagePredictions(BaseModel):
     layout: Optional[LayoutPrediction] = None
     tablestructure: Optional[TableStructurePrediction] = None
-    figures_classification: Optional[FigureClassificationPrediction] = None
+    figures_prediction: Optional[FigurePrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
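For orientation, a minimal sketch of how the new figure models nest. Only the field names come from the diff above; the text and provenance values are invented for illustration.

from docling.datamodel.base_models import FigureData, FigureDescriptionData

# Each description records which model produced it via `provenance`.
desc = FigureDescriptionData(
    text="A bar chart comparing F1 scores across layout models.",  # illustrative value
    provenance="llava-v1.6-mistral-7b-hf-deadbeef",  # illustrative value
)
data = FigureData(description=desc)

# The classification slot stays empty until a classifier model fills it.
assert data.classification is None
print(data.description.text)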

docling/models/img_understand_api_model.py

@@ -0,0 +1,123 @@
import base64
import datetime
import io
import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import httpx
from PIL import Image
from pydantic import AnyUrl, BaseModel, ConfigDict

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
    ImgUnderstandBaseModel,
    ImgUnderstandOptions,
)

_log = logging.getLogger(__name__)


class ImgUnderstandApiOptions(ImgUnderstandOptions):
    kind: Literal["api"] = "api"

    url: AnyUrl
    headers: Dict[str, str]
    params: Dict[str, Any]
    timeout: float = 20

    llm_prompt: str
    provenance: str


class ChatMessage(BaseModel):
    role: str
    content: str


class ResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ApiResponse(BaseModel):
    model_config = ConfigDict(
        protected_namespaces=(),
    )

    id: str
    model_id: Optional[str] = None  # returned by watsonx
    model: Optional[str] = None  # returned by openai
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage


class ImgUnderstandApiModel(ImgUnderstandBaseModel):

    def __init__(self, enabled: bool, options: ImgUnderstandApiOptions):
        super().__init__(enabled=enabled, options=options)
        self.options: ImgUnderstandApiOptions

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:

        if not self.enabled:
            return [FigureDescriptionData() for _ in batch]

        results = []
        for cluster, image in batch:
            img_io = io.BytesIO()
            image.save(img_io, "PNG")
            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self.options.llm_prompt,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{image_base64}"
                            },
                        },
                    ],
                }
            ]

            payload = {
                "messages": messages,
                **self.options.params,
            }

            r = httpx.post(
                str(self.options.url),
                headers=self.options.headers,
                json=payload,
                timeout=self.options.timeout,
            )
            if not r.is_success:
                _log.error(f"Error calling the API. Response was {r.text}")
            r.raise_for_status()

            api_resp = ApiResponse.model_validate_json(r.text)
            generated_text = api_resp.choices[0].message.content.strip()
            results.append(
                FigureDescriptionData(
                    text=generated_text, provenance=self.options.provenance
                )
            )
            _log.info(f"Generated description: {generated_text}")

        return results
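The API backend wraps each figure crop into an OpenAI-style chat request carrying a base64-encoded PNG. A hedged configuration sketch follows; the endpoint URL, model name, and token variable are placeholders, not values from this commit.

import os

from docling.models.img_understand_api_model import ImgUnderstandApiOptions

# Placeholder endpoint and model name; any OpenAI-compatible chat API should fit this shape.
api_options = ImgUnderstandApiOptions(
    url="https://api.example.com/v1/chat/completions",
    headers={"Authorization": f"Bearer {os.environ.get('API_TOKEN', '')}"},
    params=dict(model="my-vision-model", max_tokens=256, seed=42),
    timeout=60,
    llm_prompt="Describe this figure in three sentences.",
    provenance="my-vision-model",
)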

docling/models/img_understand_base_model.py

@@ -0,0 +1,145 @@
import logging
import time
from typing import Iterable, List, Literal, Tuple

from PIL import Image
from pydantic import BaseModel

from docling.datamodel.base_models import (
    Cluster,
    FigureData,
    FigureDescriptionData,
    FigureElement,
    FigurePrediction,
    Page,
)

_log = logging.getLogger(__name__)


class ImgUnderstandOptions(BaseModel):
    kind: str
    batch_size: int = 8
    scale: float = 2

    # If the relative area of the image with respect to the whole page
    # is larger than this threshold it will be processed, otherwise not.
    # TODO: implement the skip logic
    min_area: float = 0.05


class ImgUnderstandBaseModel:

    def __init__(self, enabled: bool, options: ImgUnderstandOptions):
        self.enabled = enabled
        self.options = options

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        raise NotImplementedError()

    def _flush_merge(
        self,
        page: Page,
        cluster_figure_batch: List[Tuple[Cluster, Image.Image]],
        figures_prediction: FigurePrediction,
    ):
        start_time = time.time()
        results_batch = self._annotate_image_batch(cluster_figure_batch)
        assert len(results_batch) == len(
            cluster_figure_batch
        ), "The returned annotations do not match the input size"
        end_time = time.time()
        _log.info(
            f"Batch of {len(results_batch)} images processed in {end_time-start_time:.1f} seconds. Time per image is {(end_time-start_time) / len(results_batch):.3f} seconds."
        )

        for (cluster, _), desc_data in zip(cluster_figure_batch, results_batch):
            if cluster.id not in figures_prediction.figure_map:
                figures_prediction.figure_map[cluster.id] = FigureElement(
                    label=cluster.label,
                    id=cluster.id,
                    data=FigureData(description=desc_data),
                    cluster=cluster,
                    page_no=page.page_no,
                )
            elif figures_prediction.figure_map[cluster.id].data.description is None:
                figures_prediction.figure_map[cluster.id].data.description = desc_data
            else:
                _log.warning(
                    f"Conflicting predictions. "
                    f"Another model ({figures_prediction.figure_map[cluster.id].data.description.provenance}) "
                    f"was already predicting an image description. The new prediction will be skipped."
                )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:

            # This model could be the first one initializing figures_prediction
            if page.predictions.figures_prediction is None:
                page.predictions.figures_prediction = FigurePrediction()

            # Select the picture clusters
            in_clusters = []
            for cluster in page.predictions.layout.clusters:
                if cluster.label != "Picture":
                    continue

                crop_bbox = cluster.bbox.scaled(
                    scale=self.options.scale
                ).to_top_left_origin(page_height=page.size.height * self.options.scale)
                in_clusters.append(
                    (
                        cluster,
                        crop_bbox.as_tuple(),
                    )
                )

            if not len(in_clusters):
                yield page
                continue

            # Save the figure count using the proper object
            if (
                page.predictions.figures_prediction.figure_count > 0
                and page.predictions.figures_prediction.figure_count != len(in_clusters)
            ):
                raise RuntimeError(
                    "Different models predicted a different number of figures."
                )
            page.predictions.figures_prediction.figure_count = len(in_clusters)

            cluster_figure_batch = []
            page_image = page.get_image(scale=self.options.scale)
            if page_image is None:
                raise RuntimeError("The page image cannot be generated.")

            for cluster, figure_bbox in in_clusters:
                figure = page_image.crop(figure_bbox)
                cluster_figure_batch.append((cluster, figure))

                # If enough figures are collected, flush the batch
                if len(cluster_figure_batch) == self.options.batch_size:
                    self._flush_merge(
                        page=page,
                        cluster_figure_batch=cluster_figure_batch,
                        figures_prediction=page.predictions.figures_prediction,
                    )
                    cluster_figure_batch = []

            # Final flush
            if len(cluster_figure_batch) > 0:
                self._flush_merge(
                    page=page,
                    cluster_figure_batch=cluster_figure_batch,
                    figures_prediction=page.predictions.figures_prediction,
                )
                cluster_figure_batch = []

            yield page
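Subclasses only have to implement _annotate_image_batch; cropping, batching, and merging results into figure_map is handled by the base class. A toy backend, not part of this commit, that illustrates the contract:

from typing import Iterable, List, Literal, Tuple

from PIL import Image

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
    ImgUnderstandBaseModel,
    ImgUnderstandOptions,
)


class ImgUnderstandDummyOptions(ImgUnderstandOptions):
    kind: Literal["dummy"] = "dummy"


class ImgUnderstandDummyModel(ImgUnderstandBaseModel):
    # Returns the crop size instead of calling a vision-language model.
    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        return [
            FigureDescriptionData(
                text=f"Figure crop of {im.width}x{im.height} pixels.",
                provenance="dummy",
            )
            for _, im in batch
        ]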

docling/models/img_understand_vllm_model.py

@@ -0,0 +1,87 @@
import json
import logging
from typing import Any, Dict, Iterable, List, Literal, Tuple

from PIL import Image

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
    ImgUnderstandBaseModel,
    ImgUnderstandOptions,
)
from docling.utils.utils import create_hash

_log = logging.getLogger(__name__)


class ImgUnderstandVllmOptions(ImgUnderstandOptions):
    kind: Literal["vllm"] = "vllm"

    # For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html

    # Parameters for LLaVA-1.6/LLaVA-NeXT
    llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
    llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
    llm_extra: Dict[str, Any] = dict(max_model_len=8192)

    # Parameters for Phi-3-Vision
    # llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
    # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
    # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)

    sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)


class ImgUnderstandVllmModel(ImgUnderstandBaseModel):

    def __init__(self, enabled: bool, options: ImgUnderstandVllmOptions):
        super().__init__(enabled=enabled, options=options)
        self.options: ImgUnderstandVllmOptions

        if self.enabled:
            try:
                from vllm import LLM, SamplingParams
            except ImportError:
                raise ImportError(
                    "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
                )

            self.sampling_params = SamplingParams(**self.options.sampling_params)
            self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra)

            # Generate a stable hash from the extra parameters
            params_hash = create_hash(
                json.dumps(self.options.llm_extra, sort_keys=True)
                + json.dumps(self.options.sampling_params, sort_keys=True)
            )
            self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:

        if not self.enabled:
            return [FigureDescriptionData() for _ in batch]

        from vllm import RequestOutput

        inputs = [
            {
                "prompt": self.options.llm_prompt,
                "multi_modal_data": {"image": im},
            }
            for _, im in batch
        ]
        outputs: List[RequestOutput] = self.llm.generate(
            inputs, sampling_params=self.sampling_params
        )

        results = []
        for o in outputs:
            generated_text = o.outputs[0].text
            results.append(
                FigureDescriptionData(text=generated_text, provenance=self.provenance)
            )
            _log.info(f"Generated description: {generated_text}")

        return results
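Switching to the Phi-3-Vision parameters shown in the comments above is just an options override. A sketch (model availability and GPU requirements are not checked here):

from docling.models.img_understand_vllm_model import ImgUnderstandVllmOptions

phi3_options = ImgUnderstandVllmOptions(
    llm_name="microsoft/Phi-3-vision-128k-instruct",
    llm_prompt="<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n",
    llm_extra=dict(max_num_seqs=5, trust_remote_code=True),
    sampling_params=dict(max_tokens=64, seed=42),
)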

docling/models/page_assemble_model.py

@@ -98,18 +98,17 @@ class PageAssembleModel:
                     body.append(tbl)
                 elif cluster.label == LayoutModel.FIGURE_LABEL:
                     fig = None
-                    if page.predictions.figures_classification:
-                        fig = page.predictions.figures_classification.figure_map.get(
+                    if page.predictions.figures_prediction:
+                        fig = page.predictions.figures_prediction.figure_map.get(
                             cluster.id, None
                         )
                     if (
                         not fig
-                    ):  # fallback: add figure without classification, if it isn't present
+                    ):  # fallback: add figure with default data, if it isn't present
                         fig = FigureElement(
                             label=cluster.label,
                             id=cluster.id,
                             text="",
-                            data=None,
                             cluster=cluster,
                             page_no=page.page_no,
                         )

docling/pipeline/img_understand_pipeline.py

@@ -0,0 +1,53 @@
from pathlib import Path
from typing import Union

from pydantic import BaseModel, Field

from docling.datamodel.base_models import PipelineOptions
from docling.models.img_understand_api_model import (
    ImgUnderstandApiModel,
    ImgUnderstandApiOptions,
)
from docling.models.img_understand_vllm_model import (
    ImgUnderstandVllmModel,
    ImgUnderstandVllmOptions,
)
from docling.pipeline.standard_model_pipeline import StandardModelPipeline


class ImgUnderstandPipelineOptions(PipelineOptions):
    do_img_understand: bool = True
    img_understand_options: Union[ImgUnderstandApiOptions, ImgUnderstandVllmOptions] = (
        Field(ImgUnderstandVllmOptions(), discriminator="kind")
    )


class ImgUnderstandPipeline(StandardModelPipeline):

    def __init__(
        self, artifacts_path: Path, pipeline_options: ImgUnderstandPipelineOptions
    ):
        super().__init__(artifacts_path, pipeline_options)

        if isinstance(
            pipeline_options.img_understand_options, ImgUnderstandVllmOptions
        ):
            self.model_pipe.append(
                ImgUnderstandVllmModel(
                    enabled=pipeline_options.do_img_understand,
                    options=pipeline_options.img_understand_options,
                )
            )
        elif isinstance(
            pipeline_options.img_understand_options, ImgUnderstandApiOptions
        ):
            self.model_pipe.append(
                ImgUnderstandApiModel(
                    enabled=pipeline_options.do_img_understand,
                    options=pipeline_options.img_understand_options,
                )
            )
        else:
            raise RuntimeError(
                f"The specified image understanding kind is not supported: {pipeline_options.img_understand_options.kind}."
            )
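The pipeline is selected through DocumentConverter, as the example script further down shows for the API backend. A minimal sketch with the default vLLM backend instead:

from docling.document_converter import DocumentConverter
from docling.pipeline.img_understand_pipeline import (
    ImgUnderstandPipeline,
    ImgUnderstandPipelineOptions,
    ImgUnderstandVllmOptions,
)

# Uses the default ImgUnderstandVllmOptions, i.e. llava-hf/llava-v1.6-mistral-7b-hf.
doc_converter = DocumentConverter(
    pipeline_cls=ImgUnderstandPipeline,
    pipeline_options=ImgUnderstandPipelineOptions(
        img_understand_options=ImgUnderstandVllmOptions()
    ),
)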


@@ -0,0 +1,132 @@
import logging
import os
import time
from pathlib import Path
from typing import Iterable

import httpx
from dotenv import load_dotenv

from docling.datamodel.base_models import ConversionStatus
from docling.datamodel.document import ConversionResult, DocumentConversionInput
from docling.document_converter import DocumentConverter
from docling.pipeline.img_understand_pipeline import (
    ImgUnderstandApiOptions,
    ImgUnderstandPipeline,
    ImgUnderstandPipelineOptions,
    ImgUnderstandVllmOptions,
)

_log = logging.getLogger(__name__)


def export_documents(
    conv_results: Iterable[ConversionResult],
    output_dir: Path,
):
    output_dir.mkdir(parents=True, exist_ok=True)

    success_count = 0
    failure_count = 0

    for conv_res in conv_results:
        if conv_res.status == ConversionStatus.SUCCESS:
            success_count += 1
            doc_filename = conv_res.input.file.stem

            # # Export Deep Search document JSON format:
            # with (output_dir / f"{doc_filename}.json").open("w") as fp:
            #     fp.write(json.dumps(conv_res.render_as_dict()))

            # # Export Text format:
            # with (output_dir / f"{doc_filename}.txt").open("w") as fp:
            #     fp.write(conv_res.render_as_text())

            # # Export Markdown format:
            # with (output_dir / f"{doc_filename}.md").open("w") as fp:
            #     fp.write(conv_res.render_as_markdown())

            # # Export Document Tags format:
            # with (output_dir / f"{doc_filename}.doctags").open("w") as fp:
            #     fp.write(conv_res.render_as_doctags())
        else:
            _log.info(f"Document {conv_res.input.file} failed to convert.")
            failure_count += 1

    _log.info(
        f"Processed {success_count + failure_count} docs, of which {failure_count} failed"
    )

    return success_count, failure_count


def _get_iam_access_token(api_key: str) -> str:
    res = httpx.post(
        url="https://iam.cloud.ibm.com/identity/token",
        headers={
            "Content-Type": "application/x-www-form-urlencoded",
        },
        data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={api_key}",
    )
    res.raise_for_status()
    api_out = res.json()
    print(f"{api_out=}")
    return api_out["access_token"]


def main():
    logging.basicConfig(level=logging.INFO)

    input_doc_paths = [
        Path("./tests/data/2206.01062.pdf"),
    ]

    load_dotenv()
    api_key = os.environ.get("WX_API_KEY")
    project_id = os.environ.get("WX_PROJECT_ID")

    doc_converter = DocumentConverter(
        pipeline_cls=ImgUnderstandPipeline,
        # TODO: make DocumentConverter provide the correct default value
        # for pipeline_options, given the pipeline_cls
        pipeline_options=ImgUnderstandPipelineOptions(
            img_understand_options=ImgUnderstandApiOptions(
                url="https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2023-05-29",
                headers={
                    "Authorization": "Bearer " + _get_iam_access_token(api_key=api_key),
                },
                params=dict(
                    model_id="meta-llama/llama3-llava-next-8b-hf",
                    project_id=project_id,
                    max_tokens=512,
                    seed=42,
                ),
                llm_prompt="Describe this figure in three sentences.",
                provenance="llama3-llava-next-8b-hf",
            )
        ),
    )

    # Define input files
    input = DocumentConversionInput.from_paths(input_doc_paths)

    start_time = time.time()

    conv_results = doc_converter.convert(input)
    success_count, failure_count = export_documents(
        conv_results, output_dir=Path("./scratch")
    )

    end_time = time.time() - start_time

    _log.info(f"All documents were converted in {end_time:.2f} seconds.")

    if failure_count > 0:
        raise RuntimeError(
            f"The example failed converting {failure_count} of {len(input_doc_paths)} documents."
        )


if __name__ == "__main__":
    main()

pyproject.toml

@@ -41,6 +41,7 @@ pyarrow = "^16.1.0"
 #########
 # extras:
 #########
+# vllm = { version = "^0.5.0", optional = true, markers = "sys_platform != 'darwin' or platform_machine != 'x86_64'" }
 python-dotenv = { version = "^1.0.1", optional = true }
 llama-index-embeddings-huggingface = { version = "^0.3.1", optional = true }
 llama-index-llms-huggingface-api = { version = "^0.2.0", optional = true }
@@ -84,6 +85,7 @@ nbqa = "^1.9.0"
 datasets = "^2.21.0"
 
 [tool.poetry.extras]
+# vllm = ["vllm"]
 examples = [
     "python-dotenv",
     # LlamaIndex examples: