draft for picture description models

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2024-11-06 11:38:12 +01:00
parent 6c22cba0a7
commit e1cba8a825
5 changed files with 279 additions and 2 deletions

View File

@ -1,8 +1,8 @@
from enum import Enum
from pathlib import Path
from typing import List, Literal, Optional, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field
from pydantic import AnyUrl, BaseModel, ConfigDict, Field
class TableFormerMode(str, Enum):
@ -61,6 +61,46 @@ class TesseractOcrOptions(OcrOptions):
)
class PicDescBaseOptions(BaseModel):
kind: str
batch_size: int = 8
scale: float = 2
bitmap_area_threshold: float = (
0.2 # percentage of the area for a bitmap to processed with the models
)
class PicDescApiOptions(PicDescBaseOptions):
kind: Literal["api"] = "api"
url: AnyUrl = AnyUrl("")
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
timeout: float = 20
llm_prompt: str = ""
provenance: str = ""
class PicDescVllmOptions(PicDescBaseOptions):
kind: Literal["vllm"] = "vllm"
# For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html
# Parameters for LLaVA-1.6/LLaVA-NeXT
llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
llm_extra: Dict[str, Any] = dict(max_model_len=8192)
# Parameters for Phi-3-Vision
# llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
# llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
# llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)
sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)
class PipelineOptions(BaseModel):
create_legacy_output: bool = (
True # This defautl will be set to False on a future version of docling
@ -71,11 +111,15 @@ class PdfPipelineOptions(PipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
do_picture_description: bool = False
table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[EasyOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions] = (
Field(EasyOcrOptions(), discriminator="kind")
)
picture_description_options: Annotated[
Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind")
] = PicDescApiOptions() # TODO: needs defaults or optional
images_scale: float = 1.0
generate_page_images: bool = False

View File

@ -0,0 +1,99 @@
import base64
import io
import logging
from typing import List, Optional
import httpx
from docling_core.types.doc import PictureItem
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from pydantic import BaseModel, ConfigDict
from docling.datamodel.pipeline_options import PicDescApiOptions
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
_log = logging.getLogger(__name__)
class ChatMessage(BaseModel):
role: str
content: str
class ResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str
class ResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ApiResponse(BaseModel):
model_config = ConfigDict(
protected_namespaces=(),
)
id: str
model: Optional[str] = None # returned bu openai
choices: List[ResponseChoice]
created: int
usage: ResponseUsage
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
def __init__(self, enabled: bool, options: PicDescApiOptions):
super().__init__(enabled=enabled, options=options)
self.options: PicDescApiOptions
def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
assert picture.image is not None
img_io = io.BytesIO()
picture.image.pil_image.save(img_io, "PNG")
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": self.options.llm_prompt,
},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
},
],
}
]
payload = {
"messages": messages,
**self.options.params,
}
r = httpx.post(
str(self.options.url),
headers=self.options.headers,
json=payload,
timeout=self.options.timeout,
)
if not r.is_success:
_log.error(f"Error calling the API. Reponse was {r.text}")
r.raise_for_status()
api_resp = ApiResponse.model_validate_json(r.text)
generated_text = api_resp.choices[0].message.content.strip()
return PictureDescriptionData(
provenance=self.options.provenance,
text=generated_text,
)

View File

@ -0,0 +1,46 @@
import logging
from pathlib import Path
from typing import Any, Iterable
from docling_core.types.doc import (
DoclingDocument,
NodeItem,
PictureClassificationClass,
PictureItem,
)
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from docling.datamodel.pipeline_options import PicDescBaseOptions
from docling.models.base_model import BaseEnrichmentModel
class PictureDescriptionBaseModel(BaseEnrichmentModel):
def __init__(self, enabled: bool, options: PicDescBaseOptions):
self.enabled = enabled
self.options = options
self.provenance = "TODO"
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
# TODO: once the image classifier is active, we can differentiate among image types
return self.enabled and isinstance(element, PictureItem)
def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
raise NotImplemented
def __call__(
self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
) -> Iterable[Any]:
if not self.enabled:
return
for element in element_batch:
assert isinstance(element, PictureItem)
assert element.image is not None
annotation = self._annotate_image(element)
element.annotations.append(annotation)
yield element

View File

@ -0,0 +1,59 @@
import json
from typing import List
from docling_core.types.doc import PictureItem
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from docling.datamodel.pipeline_options import PicDescVllmOptions
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
class PictureDescriptionVllmModel(PictureDescriptionBaseModel):
def __init__(self, enabled: bool, options: PicDescVllmOptions):
super().__init__(enabled=enabled, options=options)
self.options: PicDescVllmOptions
if self.enabled:
raise NotImplemented
if self.enabled:
try:
from vllm import LLM, SamplingParams # type: ignore
except ImportError:
raise ImportError(
"VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
)
self.sampling_params = SamplingParams(**self.options.sampling_params) # type: ignore
self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra) # type: ignore
# Generate a stable hash from the extra parameters
def create_hash(t):
return ""
params_hash = create_hash(
json.dumps(self.options.llm_extra, sort_keys=True)
+ json.dumps(self.options.sampling_params, sort_keys=True)
)
self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"
def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
assert picture.image is not None
from vllm import RequestOutput
inputs = [
{
"prompt": self.options.llm_prompt,
"multi_modal_data": {"image": picture.image.pil_image},
}
]
outputs: List[RequestOutput] = self.llm.generate( # type: ignore
inputs, sampling_params=self.sampling_params # type: ignore
)
generated_text = outputs[0].outputs[0].text
return PictureDescriptionData(provenance=self.provenance, text=generated_text)

View File

@ -11,6 +11,8 @@ from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
PdfPipelineOptions,
PicDescApiOptions,
PicDescVllmOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
)
@ -23,6 +25,9 @@ from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.pic_description_api_model import PictureDescriptionApiModel
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel
from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel
@ -83,8 +88,15 @@ class StandardPdfPipeline(PaginatedPipeline):
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
]
# Picture description model
if (pic_desc_model := self.get_pic_description_model()) is None:
raise RuntimeError(
f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
)
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
pic_desc_model,
]
@staticmethod
@ -120,6 +132,23 @@ class StandardPdfPipeline(PaginatedPipeline):
)
return None
def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]:
if isinstance(
self.pipeline_options.picture_description_options, PicDescApiOptions
):
return PictureDescriptionApiModel(
enabled=self.pipeline_options.do_picture_description,
options=self.pipeline_options.picture_description_options,
)
elif isinstance(
self.pipeline_options.picture_description_options, PicDescVllmOptions
):
return PictureDescriptionVllmModel(
enabled=self.pipeline_options.do_picture_description,
options=self.pipeline_options.picture_description_options,
)
return None
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
with TimeRecorder(conv_res, "page_init"):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore