mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-31 14:34:40 +00:00
draft for picture description models
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
6c22cba0a7
commit
e1cba8a825
@ -1,8 +1,8 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
class TableFormerMode(str, Enum):
|
class TableFormerMode(str, Enum):
|
||||||
@ -61,6 +61,46 @@ class TesseractOcrOptions(OcrOptions):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PicDescBaseOptions(BaseModel):
    # Settings shared by all picture description option classes.

    # Discriminator tag; the concrete option classes narrow it to a literal.
    kind: str
    batch_size: int = 8
    scale: float = 2

    # Area threshold a bitmap must reach to be processed with the models.
    # NOTE(review): 0.2 reads as a fraction rather than a percentage, and the
    # reference area is not visible here -- confirm against the consumer.
    bitmap_area_threshold: float = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
class PicDescApiOptions(PicDescBaseOptions):
    # Options for the picture description model that calls a remote HTTP API.

    kind: Literal["api"] = "api"

    # The default must be a parseable URL: it is built when the class body is
    # evaluated, and an empty string is rejected by URL validation, which would
    # make the whole module fail to import. Callers are expected to override
    # this with the endpoint they actually use.
    url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
    # Sent as HTTP headers with every request (e.g. for authorization).
    headers: Dict[str, str] = {}
    # Extra top-level entries merged into the JSON request payload.
    params: Dict[str, Any] = {}
    # Timeout handed to the HTTP client for each request.
    timeout: float = 20

    # Text part of the user message sent along with the image.
    llm_prompt: str = ""
    # Stored as `provenance` on the generated description.
    provenance: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class PicDescVllmOptions(PicDescBaseOptions):
    # Options for the picture description model backed by vLLM.

    kind: Literal["vllm"] = "vllm"

    # More example parameters are listed at
    # https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html

    # Defaults below target LLaVA-1.6 / LLaVA-NeXT.
    llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
    llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
    # Extra keyword arguments forwarded to the vLLM `LLM` constructor.
    llm_extra: Dict[str, Any] = {"max_model_len": 8192}

    # Alternative values for Phi-3-Vision:
    # llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
    # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
    # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)

    # Keyword arguments used to build the vLLM `SamplingParams`.
    sampling_params: Dict[str, Any] = {"max_tokens": 64, "seed": 42}
|
||||||
|
|
||||||
|
|
||||||
class PipelineOptions(BaseModel):
|
class PipelineOptions(BaseModel):
|
||||||
create_legacy_output: bool = (
|
create_legacy_output: bool = (
|
||||||
True # This defautl will be set to False on a future version of docling
|
True # This defautl will be set to False on a future version of docling
|
||||||
@ -71,11 +111,15 @@ class PdfPipelineOptions(PipelineOptions):
|
|||||||
artifacts_path: Optional[Union[Path, str]] = None
|
artifacts_path: Optional[Union[Path, str]] = None
|
||||||
do_table_structure: bool = True # True: perform table structure extraction
|
do_table_structure: bool = True # True: perform table structure extraction
|
||||||
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
|
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
|
||||||
|
do_picture_description: bool = False
|
||||||
|
|
||||||
table_structure_options: TableStructureOptions = TableStructureOptions()
|
table_structure_options: TableStructureOptions = TableStructureOptions()
|
||||||
ocr_options: Union[EasyOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions] = (
|
ocr_options: Union[EasyOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions] = (
|
||||||
Field(EasyOcrOptions(), discriminator="kind")
|
Field(EasyOcrOptions(), discriminator="kind")
|
||||||
)
|
)
|
||||||
|
picture_description_options: Annotated[
|
||||||
|
Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind")
|
||||||
|
] = PicDescApiOptions() # TODO: needs defaults or optional
|
||||||
|
|
||||||
images_scale: float = 1.0
|
images_scale: float = 1.0
|
||||||
generate_page_images: bool = False
|
generate_page_images: bool = False
|
||||||
|
99
docling/models/pic_description_api_model.py
Normal file
99
docling/models/pic_description_api_model.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from docling_core.types.doc import PictureItem
|
||||||
|
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
|
||||||
|
PictureDescriptionData,
|
||||||
|
)
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
|
from docling.datamodel.pipeline_options import PicDescApiOptions
|
||||||
|
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
    # One message of an API response: who wrote it and its text.
    role: str
    content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseChoice(BaseModel):
    # One entry of the `choices` list in an API response.
    index: int
    message: ChatMessage
    finish_reason: str
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseUsage(BaseModel):
    # Token counts reported in the `usage` section of an API response.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class ApiResponse(BaseModel):
    # Top-level structure of the JSON body returned by the API.

    # No protected namespaces are configured for this model.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    model: Optional[str] = None  # returned by openai
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage
|
||||||
|
|
||||||
|
|
||||||
|
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
    """Picture description model that delegates to a remote HTTP API.

    Each picture is PNG-encoded, embedded as a base64 data URL in a single
    user message together with the configured prompt, and POSTed to
    `options.url`. The text of the first returned choice becomes the
    description.
    """

    def __init__(self, enabled: bool, options: PicDescApiOptions):
        """Store the enabled flag and the API options via the base class."""
        super().__init__(enabled=enabled, options=options)
        # Annotation only: narrows the type of the attribute set by the base class.
        self.options: PicDescApiOptions

    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
        """Request a description of `picture` from the API.

        Args:
            picture: Picture item; its `image` must be set.

        Returns:
            A `PictureDescriptionData` holding the stripped text of the first
            response choice and the provenance configured in the options.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        assert picture.image is not None

        # Serialize the picture as PNG in memory and base64-encode it so it
        # can be inlined in the JSON payload as a data URL.
        img_io = io.BytesIO()
        picture.image.pil_image.save(img_io, "PNG")
        image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": self.options.llm_prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                    },
                ],
            }
        ]

        # User-supplied params come last, so they may add to (or override)
        # the top-level payload entries.
        payload = {
            "messages": messages,
            **self.options.params,
        }

        r = httpx.post(
            str(self.options.url),
            headers=self.options.headers,
            json=payload,
            timeout=self.options.timeout,
        )
        if not r.is_success:
            # Log the response body before raising: the exception raised
            # below does not carry it in its message.
            _log.error("Error calling the API. Response was %s", r.text)
        r.raise_for_status()

        api_resp = ApiResponse.model_validate_json(r.text)
        generated_text = api_resp.choices[0].message.content.strip()

        return PictureDescriptionData(
            provenance=self.options.provenance,
            text=generated_text,
        )
|
46
docling/models/pic_description_base_model.py
Normal file
46
docling/models/pic_description_base_model.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
|
from docling_core.types.doc import (
|
||||||
|
DoclingDocument,
|
||||||
|
NodeItem,
|
||||||
|
PictureClassificationClass,
|
||||||
|
PictureItem,
|
||||||
|
)
|
||||||
|
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
|
||||||
|
PictureDescriptionData,
|
||||||
|
)
|
||||||
|
|
||||||
|
from docling.datamodel.pipeline_options import PicDescBaseOptions
|
||||||
|
from docling.models.base_model import BaseEnrichmentModel
|
||||||
|
|
||||||
|
|
||||||
|
class PictureDescriptionBaseModel(BaseEnrichmentModel):
    """Base class for enrichment models that attach descriptions to pictures.

    Subclasses implement `_annotate_image`; this class handles filtering the
    elements and appending the produced annotation to each picture.
    """

    def __init__(self, enabled: bool, options: PicDescBaseOptions):
        """Store the enabled flag and options.

        Args:
            enabled: When False, no element is processable and calling the
                model yields nothing.
            options: Picture description options.
        """
        self.enabled = enabled
        self.options = options
        self.provenance = "TODO"

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """Return True when the model is enabled and `element` is a picture."""
        # TODO: once the image classifier is active, we can differentiate among image types
        return self.enabled and isinstance(element, PictureItem)

    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
        """Produce the description for one picture; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # `NotImplemented` is a comparison sentinel, not an exception; raising
        # it fails with a TypeError instead of signalling a missing override.
        raise NotImplementedError

    def __call__(
        self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
    ) -> Iterable[Any]:
        """Annotate every picture of the batch and yield it.

        Yields nothing when the model is disabled. Each element must be a
        `PictureItem` with its `image` set.
        """
        if not self.enabled:
            return

        for element in element_batch:
            assert isinstance(element, PictureItem)
            assert element.image is not None

            annotation = self._annotate_image(element)
            element.annotations.append(annotation)

            yield element
|
59
docling/models/pic_description_vllm_model.py
Normal file
59
docling/models/pic_description_vllm_model.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from docling_core.types.doc import PictureItem
|
||||||
|
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
|
||||||
|
PictureDescriptionData,
|
||||||
|
)
|
||||||
|
|
||||||
|
from docling.datamodel.pipeline_options import PicDescVllmOptions
|
||||||
|
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class PictureDescriptionVllmModel(PictureDescriptionBaseModel):
    """Picture description model that runs a vision-language model via vLLM."""

    def __init__(self, enabled: bool, options: PicDescVllmOptions):
        """Set up the vLLM engine and sampling parameters when enabled.

        Raises:
            NotImplementedError: If `enabled` is True; the vLLM backend is
                still guarded as not ready for use.
            ImportError: If vLLM is not installed (only reachable once the
                guard above is removed).
        """
        super().__init__(enabled=enabled, options=options)
        # Annotation only: narrows the type of the attribute set by the base class.
        self.options: PicDescVllmOptions

        if self.enabled:
            # Draft guard. `NotImplemented` is a comparison sentinel, not an
            # exception, so raising it would surface as a confusing TypeError.
            raise NotImplementedError(
                "The vLLM picture description model is not implemented yet."
            )

        # NOTE: unreachable while the guard above is in place.
        if self.enabled:
            import hashlib

            try:
                from vllm import LLM, SamplingParams  # type: ignore
            except ImportError:
                raise ImportError(
                    "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
                )

            self.sampling_params = SamplingParams(**self.options.sampling_params)  # type: ignore
            self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra)  # type: ignore

            # Generate a stable hash from the extra parameters
            def create_hash(t: str) -> str:
                return hashlib.sha256(t.encode("utf-8")).hexdigest()

            # sort_keys makes the serialization, and thus the hash,
            # independent of dict insertion order.
            params_hash = create_hash(
                json.dumps(self.options.llm_extra, sort_keys=True)
                + json.dumps(self.options.sampling_params, sort_keys=True)
            )
            self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"

    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
        """Generate a description for `picture` with the loaded vLLM model.

        Args:
            picture: Picture item; its `image` must be set.

        Returns:
            A `PictureDescriptionData` with the text of the first generated
            output and this model's provenance.
        """
        assert picture.image is not None

        from vllm import RequestOutput

        inputs = [
            {
                "prompt": self.options.llm_prompt,
                "multi_modal_data": {"image": picture.image.pil_image},
            }
        ]
        outputs: List[RequestOutput] = self.llm.generate(  # type: ignore
            inputs, sampling_params=self.sampling_params  # type: ignore
        )

        generated_text = outputs[0].outputs[0].text
        return PictureDescriptionData(provenance=self.provenance, text=generated_text)
|
@ -11,6 +11,8 @@ from docling.datamodel.document import ConversionResult
|
|||||||
from docling.datamodel.pipeline_options import (
|
from docling.datamodel.pipeline_options import (
|
||||||
EasyOcrOptions,
|
EasyOcrOptions,
|
||||||
PdfPipelineOptions,
|
PdfPipelineOptions,
|
||||||
|
PicDescApiOptions,
|
||||||
|
PicDescVllmOptions,
|
||||||
TesseractCliOcrOptions,
|
TesseractCliOcrOptions,
|
||||||
TesseractOcrOptions,
|
TesseractOcrOptions,
|
||||||
)
|
)
|
||||||
@ -23,6 +25,9 @@ from docling.models.page_preprocessing_model import (
|
|||||||
PagePreprocessingModel,
|
PagePreprocessingModel,
|
||||||
PagePreprocessingOptions,
|
PagePreprocessingOptions,
|
||||||
)
|
)
|
||||||
|
from docling.models.pic_description_api_model import PictureDescriptionApiModel
|
||||||
|
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
|
||||||
|
from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel
|
||||||
from docling.models.table_structure_model import TableStructureModel
|
from docling.models.table_structure_model import TableStructureModel
|
||||||
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
||||||
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
||||||
@ -83,8 +88,15 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
|
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Picture description model
|
||||||
|
if (pic_desc_model := self.get_pic_description_model()) is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
|
||||||
|
)
|
||||||
|
|
||||||
self.enrichment_pipe = [
|
self.enrichment_pipe = [
|
||||||
# Other models working on `NodeItem` elements in the DoclingDocument
|
# Other models working on `NodeItem` elements in the DoclingDocument
|
||||||
|
pic_desc_model,
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -120,6 +132,23 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]:
    """Build the picture description model matching the configured options.

    Returns the API-backed or the vLLM-backed model depending on the type of
    `pipeline_options.picture_description_options`, or None when the options
    are of neither type. The model's enabled flag is taken from
    `pipeline_options.do_picture_description`.
    """
    pic_options = self.pipeline_options.picture_description_options

    if isinstance(pic_options, PicDescApiOptions):
        return PictureDescriptionApiModel(
            enabled=self.pipeline_options.do_picture_description,
            options=pic_options,
        )

    if isinstance(pic_options, PicDescVllmOptions):
        return PictureDescriptionVllmModel(
            enabled=self.pipeline_options.do_picture_description,
            options=pic_options,
        )

    return None
|
||||||
|
|
||||||
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
|
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
|
||||||
with TimeRecorder(conv_res, "page_init"):
|
with TimeRecorder(conv_res, "page_init"):
|
||||||
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
|
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
|
||||||
|
Loading…
Reference in New Issue
Block a user