draft for picture description models
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in: parent 6c22cba0a7, commit e1cba8a825
docling/datamodel/pipeline_options.py
@@ -1,8 +1,8 @@
 from enum import Enum
 from pathlib import Path
-from typing import List, Literal, Optional, Union
+from typing import Annotated, Any, Dict, List, Literal, Optional, Union

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import AnyUrl, BaseModel, ConfigDict, Field


 class TableFormerMode(str, Enum):
@@ -61,6 +61,46 @@ class TesseractOcrOptions(OcrOptions):
     )


+class PicDescBaseOptions(BaseModel):
+    kind: str
+    batch_size: int = 8
+    scale: float = 2
+
+    bitmap_area_threshold: float = (
+        0.2  # percentage of the area for a bitmap to be processed with the models
+    )
+
+
+class PicDescApiOptions(PicDescBaseOptions):
+    kind: Literal["api"] = "api"
+
+    url: AnyUrl = AnyUrl("")
+    headers: Dict[str, str] = {}
+    params: Dict[str, Any] = {}
+    timeout: float = 20
+
+    llm_prompt: str = ""
+    provenance: str = ""
+
+
+class PicDescVllmOptions(PicDescBaseOptions):
+    kind: Literal["vllm"] = "vllm"
+
+    # For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html
+
+    # Parameters for LLaVA-1.6/LLaVA-NeXT
+    llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
+    llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
+    llm_extra: Dict[str, Any] = dict(max_model_len=8192)
+
+    # Parameters for Phi-3-Vision
+    # llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
+    # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
+    # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)
+
+    sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)
+
+
 class PipelineOptions(BaseModel):
     create_legacy_output: bool = (
         True  # This default will be set to False in a future version of docling
@@ -71,11 +111,15 @@ class PdfPipelineOptions(PipelineOptions):
     artifacts_path: Optional[Union[Path, str]] = None
     do_table_structure: bool = True  # True: perform table structure extraction
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
+    do_picture_description: bool = False

     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[EasyOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions] = (
         Field(EasyOcrOptions(), discriminator="kind")
     )
+    picture_description_options: Annotated[
+        Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind")
+    ] = PicDescApiOptions()  # TODO: needs defaults or optional

     images_scale: float = 1.0
     generate_page_images: bool = False
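For orientation, a minimal sketch of how the new options above could be enabled on a PdfPipelineOptions instance; the endpoint URL, extra params, prompt, and provenance string are placeholder values, not defaults from this commit:

    from docling.datamodel.pipeline_options import PdfPipelineOptions, PicDescApiOptions

    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_picture_description = True
    pipeline_options.picture_description_options = PicDescApiOptions(
        url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
        params={"model": "my-vlm"},                       # placeholder extra payload fields
        llm_prompt="Describe the image in three sentences.",
        provenance="my-vlm",
        timeout=60,
    )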
docling/models/pic_description_api_model.py (new file, 99 lines)
@@ -0,0 +1,99 @@
+import base64
+import io
+import logging
+from typing import List, Optional
+
+import httpx
+from docling_core.types.doc import PictureItem
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+from pydantic import BaseModel, ConfigDict
+
+from docling.datamodel.pipeline_options import PicDescApiOptions
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+
+_log = logging.getLogger(__name__)
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: str
+
+
+class ResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[ResponseChoice]
+    created: int
+    usage: ResponseUsage
+
+
+class PictureDescriptionApiModel(PictureDescriptionBaseModel):
+
+    def __init__(self, enabled: bool, options: PicDescApiOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PicDescApiOptions
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        assert picture.image is not None
+
+        img_io = io.BytesIO()
+        picture.image.pil_image.save(img_io, "PNG")
+
+        image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": self.options.llm_prompt,
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        ]
+
+        payload = {
+            "messages": messages,
+            **self.options.params,
+        }
+
+        r = httpx.post(
+            str(self.options.url),
+            headers=self.options.headers,
+            json=payload,
+            timeout=self.options.timeout,
+        )
+        if not r.is_success:
+            _log.error(f"Error calling the API. Response was {r.text}")
+        r.raise_for_status()
+
+        api_resp = ApiResponse.model_validate_json(r.text)
+        generated_text = api_resp.choices[0].message.content.strip()
+
+        return PictureDescriptionData(
+            provenance=self.options.provenance,
+            text=generated_text,
+        )
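The ApiResponse model above is shaped for an OpenAI-style chat-completions reply (see the "returned by openai" note). As a rough illustration with made-up values, a response it would accept, and the text-extraction step, look like:

    example = {
        "id": "chatcmpl-123",        # illustrative values only
        "model": "my-vlm",
        "created": 1730000000,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": "A bar chart of quarterly revenue."},
                "finish_reason": "stop",
            }
        ],
        "usage": {"prompt_tokens": 421, "completion_tokens": 17, "total_tokens": 438},
    }
    # Reusing the ApiResponse class defined in the new module above:
    description = ApiResponse.model_validate(example).choices[0].message.content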
docling/models/pic_description_base_model.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import logging
+from pathlib import Path
+from typing import Any, Iterable
+
+from docling_core.types.doc import (
+    DoclingDocument,
+    NodeItem,
+    PictureClassificationClass,
+    PictureItem,
+)
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+
+from docling.datamodel.pipeline_options import PicDescBaseOptions
+from docling.models.base_model import BaseEnrichmentModel
+
+
+class PictureDescriptionBaseModel(BaseEnrichmentModel):
+
+    def __init__(self, enabled: bool, options: PicDescBaseOptions):
+        self.enabled = enabled
+        self.options = options
+        self.provenance = "TODO"
+
+    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
+        # TODO: once the image classifier is active, we can differentiate among image types
+        return self.enabled and isinstance(element, PictureItem)
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        raise NotImplementedError
+
+    def __call__(
+        self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
+    ) -> Iterable[Any]:
+        if not self.enabled:
+            return
+
+        for element in element_batch:
+            assert isinstance(element, PictureItem)
+            assert element.image is not None
+
+            annotation = self._annotate_image(element)
+            element.annotations.append(annotation)
+
+            yield element
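Because the base class above leaves _annotate_image to subclasses, adding another backend only requires overriding that one method. A hypothetical minimal subclass (not part of this commit) might look like:

    from docling_core.types.doc import PictureItem
    from docling_core.types.doc.document import PictureDescriptionData

    from docling.datamodel.pipeline_options import PicDescBaseOptions
    from docling.models.pic_description_base_model import PictureDescriptionBaseModel


    class PictureDescriptionStubModel(PictureDescriptionBaseModel):
        """Hypothetical backend that returns a fixed caption for every picture."""

        def __init__(self, enabled: bool, options: PicDescBaseOptions):
            super().__init__(enabled=enabled, options=options)
            self.provenance = "stub-captioner"  # placeholder provenance label

        def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
            assert picture.image is not None
            return PictureDescriptionData(
                provenance=self.provenance,
                text="Placeholder description.",
            )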
docling/models/pic_description_vllm_model.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+import json
+from typing import List
+
+from docling_core.types.doc import PictureItem
+from docling_core.types.doc.document import (  # TODO: move import to docling_core.types.doc
+    PictureDescriptionData,
+)
+
+from docling.datamodel.pipeline_options import PicDescVllmOptions
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+
+
+class PictureDescriptionVllmModel(PictureDescriptionBaseModel):
+
+    def __init__(self, enabled: bool, options: PicDescVllmOptions):
+        super().__init__(enabled=enabled, options=options)
+        self.options: PicDescVllmOptions
+
+        if self.enabled:
+            raise NotImplementedError
+
+        if self.enabled:
+            try:
+                from vllm import LLM, SamplingParams  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
+                )
+
+            self.sampling_params = SamplingParams(**self.options.sampling_params)  # type: ignore
+            self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra)  # type: ignore
+
+            # Generate a stable hash from the extra parameters
+            def create_hash(t):
+                return ""
+
+            params_hash = create_hash(
+                json.dumps(self.options.llm_extra, sort_keys=True)
+                + json.dumps(self.options.sampling_params, sort_keys=True)
+            )
+            self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"
+
+    def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
+        assert picture.image is not None
+
+        from vllm import RequestOutput
+
+        inputs = [
+            {
+                "prompt": self.options.llm_prompt,
+                "multi_modal_data": {"image": picture.image.pil_image},
+            }
+        ]
+        outputs: List[RequestOutput] = self.llm.generate(  # type: ignore
+            inputs, sampling_params=self.sampling_params  # type: ignore
+        )
+
+        generated_text = outputs[0].outputs[0].text
+        return PictureDescriptionData(provenance=self.provenance, text=generated_text)
docling/pipeline/standard_pdf_pipeline.py
@@ -11,6 +11,8 @@ from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
     EasyOcrOptions,
     PdfPipelineOptions,
+    PicDescApiOptions,
+    PicDescVllmOptions,
     TesseractCliOcrOptions,
     TesseractOcrOptions,
 )
@@ -23,6 +25,9 @@ from docling.models.page_preprocessing_model import (
     PagePreprocessingModel,
     PagePreprocessingOptions,
 )
+from docling.models.pic_description_api_model import PictureDescriptionApiModel
+from docling.models.pic_description_base_model import PictureDescriptionBaseModel
+from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel
 from docling.models.table_structure_model import TableStructureModel
 from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
 from docling.models.tesseract_ocr_model import TesseractOcrModel
@@ -83,8 +88,15 @@ class StandardPdfPipeline(PaginatedPipeline):
             PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
         ]

+        # Picture description model
+        if (pic_desc_model := self.get_pic_description_model()) is None:
+            raise RuntimeError(
+                f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
+            )
+
         self.enrichment_pipe = [
             # Other models working on `NodeItem` elements in the DoclingDocument
+            pic_desc_model,
         ]

     @staticmethod
@@ -120,6 +132,23 @@ class StandardPdfPipeline(PaginatedPipeline):
             )
         return None

+    def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]:
+        if isinstance(
+            self.pipeline_options.picture_description_options, PicDescApiOptions
+        ):
+            return PictureDescriptionApiModel(
+                enabled=self.pipeline_options.do_picture_description,
+                options=self.pipeline_options.picture_description_options,
+            )
+        elif isinstance(
+            self.pipeline_options.picture_description_options, PicDescVllmOptions
+        ):
+            return PictureDescriptionVllmModel(
+                enabled=self.pipeline_options.do_picture_description,
+                options=self.pipeline_options.picture_description_options,
+            )
+        return None
+
     def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
         with TimeRecorder(conv_res, "page_init"):
             page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
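Putting the pieces together, a rough end-to-end usage sketch, assuming the usual DocumentConverter / PdfFormatOption wiring in docling (the endpoint and input file name are placeholders; this wiring is not part of the commit itself):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import PdfPipelineOptions, PicDescApiOptions
    from docling.document_converter import DocumentConverter, PdfFormatOption

    pipeline_options = PdfPipelineOptions()
    pipeline_options.do_picture_description = True
    pipeline_options.picture_description_options = PicDescApiOptions(
        url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
        llm_prompt="Describe this image.",
        provenance="my-vlm",
    )
    # Note: pictures must carry their cropped images (picture.image) for the
    # enrichment model to annotate them, as asserted in the base model above.

    converter = DocumentConverter(
        format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
    )
    result = converter.convert("report.pdf")  # placeholder input file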