mirror of https://github.com/DS4SD/docling.git
introduce img understand pipeline
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
commit a122a7be4c (parent 1f4b224ab6)
docling/datamodel/base_models.py
@@ -224,18 +224,27 @@ class TableStructurePrediction(BaseModel):
 class TextElement(BasePageElement): ...
 
 
+class FigureClassificationData(BaseModel):
+    provenance: str
+    predicted_class: str
+    confidence: float
+
+
+class FigureDescriptionData(BaseModel):
+    text: str
+    provenance: str = ""
+
+
 class FigureData(BaseModel):
-    pass
+    classification: Optional[FigureClassificationData] = None
+    description: Optional[FigureDescriptionData] = None
 
 
 class FigureElement(BasePageElement):
-    data: Optional[FigureData] = None
-    provenance: Optional[str] = None
-    predicted_class: Optional[str] = None
-    confidence: Optional[float] = None
+    data: FigureData = FigureData()
 
 
-class FigureClassificationPrediction(BaseModel):
+class FigurePrediction(BaseModel):
     figure_count: int = 0
     figure_map: Dict[int, FigureElement] = {}
 
@@ -248,7 +257,7 @@ class EquationPrediction(BaseModel):
 class PagePredictions(BaseModel):
     layout: Optional[LayoutPrediction] = None
     tablestructure: Optional[TableStructurePrediction] = None
-    figures_classification: Optional[FigureClassificationPrediction] = None
+    figures_prediction: Optional[FigurePrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
 
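To illustrate how the new figure data model composes, here is a minimal, hypothetical sketch using only the classes introduced above (all values are invented for illustration):

from docling.datamodel.base_models import (
    FigureClassificationData,
    FigureData,
    FigureDescriptionData,
)

# Hypothetical values, for illustration only.
description = FigureDescriptionData(
    text="A bar chart comparing model accuracy.", provenance="my-vlm"
)
classification = FigureClassificationData(
    provenance="my-classifier", predicted_class="bar_chart", confidence=0.92
)
figure_data = FigureData(classification=classification, description=description)
# FigureElement now carries this single `data` field in place of the removed
# provenance / predicted_class / confidence fields.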
docling/models/img_understand_api_model.py (new file, 123 lines)

import base64
import datetime
import io
import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

import httpx
from PIL import Image
from pydantic import AnyUrl, BaseModel, ConfigDict

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
    ImgUnderstandBaseModel,
    ImgUnderstandOptions,
)

_log = logging.getLogger(__name__)


class ImgUnderstandApiOptions(ImgUnderstandOptions):
    kind: Literal["api"] = "api"

    url: AnyUrl
    headers: Dict[str, str]
    params: Dict[str, Any]
    timeout: float = 20

    llm_prompt: str
    provenance: str


class ChatMessage(BaseModel):
    role: str
    content: str


class ResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ApiResponse(BaseModel):
    model_config = ConfigDict(
        protected_namespaces=(),
    )

    id: str
    model_id: Optional[str] = None  # returned by watsonx
    model: Optional[str] = None  # returned by openai
    choices: List[ResponseChoice]
    created: int
    usage: ResponseUsage


class ImgUnderstandApiModel(ImgUnderstandBaseModel):

    def __init__(self, enabled: bool, options: ImgUnderstandApiOptions):
        super().__init__(enabled=enabled, options=options)
        self.options: ImgUnderstandApiOptions

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:

        if not self.enabled:
            return [FigureDescriptionData() for _ in batch]

        results = []
        for cluster, image in batch:
            img_io = io.BytesIO()
            image.save(img_io, "PNG")
            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self.options.llm_prompt,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{image_base64}"
                            },
                        },
                    ],
                }
            ]

            payload = {
                "messages": messages,
                **self.options.params,
            }

            r = httpx.post(
                str(self.options.url),
                headers=self.options.headers,
                json=payload,
                timeout=self.options.timeout,
            )
            if not r.is_success:
                _log.error(f"Error calling the API. Response was {r.text}")
            r.raise_for_status()

            api_resp = ApiResponse.model_validate_json(r.text)
            generated_text = api_resp.choices[0].message.content.strip()
            results.append(
                FigureDescriptionData(
                    text=generated_text, provenance=self.options.provenance
                )
            )
            _log.info(f"Generated description: {generated_text}")

        return results
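For context, the options above target any OpenAI-compatible chat-completions endpoint. A hedged sketch of wiring ImgUnderstandApiModel to such an endpoint (URL, token and model name are placeholders, not part of this commit):

from docling.models.img_understand_api_model import (
    ImgUnderstandApiModel,
    ImgUnderstandApiOptions,
)

# Placeholder endpoint, token and model name, for illustration only.
options = ImgUnderstandApiOptions(
    url="https://example.com/v1/chat/completions",
    headers={"Authorization": "Bearer <token>"},
    params=dict(model="my-vision-model", max_tokens=256, seed=42),
    llm_prompt="Describe this figure in three sentences.",
    provenance="my-vision-model",
    timeout=60,
)
model = ImgUnderstandApiModel(enabled=True, options=options)
# _annotate_image_batch() then issues one POST per (cluster, image) pair.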
docling/models/img_understand_base_model.py (new file, 145 lines)

import logging
import time
from typing import Iterable, List, Literal, Tuple

from PIL import Image
from pydantic import BaseModel

from docling.datamodel.base_models import (
    Cluster,
    FigureData,
    FigureDescriptionData,
    FigureElement,
    FigurePrediction,
    Page,
)

_log = logging.getLogger(__name__)


class ImgUnderstandOptions(BaseModel):
    kind: str
    batch_size: int = 8
    scale: float = 2

    # If the relative area of the image with respect to the whole page
    # is larger than this threshold it will be processed, otherwise not.
    # TODO: implement the skip logic
    min_area: float = 0.05


class ImgUnderstandBaseModel:

    def __init__(self, enabled: bool, options: ImgUnderstandOptions):
        self.enabled = enabled
        self.options = options

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        raise NotImplementedError()

    def _flush_merge(
        self,
        page: Page,
        cluster_figure_batch: List[Tuple[Cluster, Image.Image]],
        figures_prediction: FigurePrediction,
    ):
        start_time = time.time()
        results_batch = self._annotate_image_batch(cluster_figure_batch)
        assert len(results_batch) == len(
            cluster_figure_batch
        ), "The returned annotations do not match the input size"
        end_time = time.time()
        _log.info(
            f"Batch of {len(results_batch)} images processed in {end_time-start_time:.1f} seconds. Time per image is {(end_time-start_time) / len(results_batch):.3f} seconds."
        )

        for (cluster, _), desc_data in zip(cluster_figure_batch, results_batch):
            if cluster.id not in figures_prediction.figure_map:
                figures_prediction.figure_map[cluster.id] = FigureElement(
                    label=cluster.label,
                    id=cluster.id,
                    data=FigureData(description=desc_data),
                    cluster=cluster,
                    page_no=page.page_no,
                )
            elif figures_prediction.figure_map[cluster.id].data.description is None:
                figures_prediction.figure_map[cluster.id].data.description = desc_data
            else:
                _log.warning(
                    f"Conflicting predictions. "
                    f"Another model ({figures_prediction.figure_map[cluster.id].data.description.provenance}) "
                    f"already predicted an image description. The new prediction will be skipped."
                )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:

        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:

            # This model could be the first one initializing figures_prediction
            if page.predictions.figures_prediction is None:
                page.predictions.figures_prediction = FigurePrediction()

            # Select the picture clusters
            in_clusters = []
            for cluster in page.predictions.layout.clusters:
                if cluster.label != "Picture":
                    continue

                crop_bbox = cluster.bbox.scaled(
                    scale=self.options.scale
                ).to_top_left_origin(page_height=page.size.height * self.options.scale)
                in_clusters.append(
                    (
                        cluster,
                        crop_bbox.as_tuple(),
                    )
                )

            if not len(in_clusters):
                yield page
                continue

            # Save the figure count using the proper object
            if (
                page.predictions.figures_prediction.figure_count > 0
                and page.predictions.figures_prediction.figure_count != len(in_clusters)
            ):
                raise RuntimeError(
                    "Different models predicted a different number of figures."
                )
            page.predictions.figures_prediction.figure_count = len(in_clusters)

            cluster_figure_batch = []
            page_image = page.get_image(scale=self.options.scale)
            if page_image is None:
                raise RuntimeError("The page image cannot be generated.")

            for cluster, figure_bbox in in_clusters:
                figure = page_image.crop(figure_bbox)
                cluster_figure_batch.append((cluster, figure))

                # If enough figures are collected, flush the batch
                if len(cluster_figure_batch) == self.options.batch_size:
                    self._flush_merge(
                        page=page,
                        cluster_figure_batch=cluster_figure_batch,
                        figures_prediction=page.predictions.figures_prediction,
                    )
                    cluster_figure_batch = []

            # Final flush
            if len(cluster_figure_batch) > 0:
                self._flush_merge(
                    page=page,
                    cluster_figure_batch=cluster_figure_batch,
                    figures_prediction=page.predictions.figures_prediction,
                )
                cluster_figure_batch = []

            yield page
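The base class above factors out cropping, batching and merging into figures_prediction, so a backend only has to implement _annotate_image_batch. A minimal, hypothetical subclass as a sketch of the extension point (not part of this commit):

from typing import Iterable, List, Tuple

from PIL import Image

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import ImgUnderstandBaseModel


class ImgUnderstandDummyModel(ImgUnderstandBaseModel):
    # Hypothetical backend returning a fixed placeholder description per figure.
    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:
        return [
            FigureDescriptionData(text="<no description>", provenance="dummy")
            for _ in batch
        ]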
docling/models/img_understand_vllm_model.py (new file, 87 lines)

import json
import logging
from typing import Any, Dict, Iterable, List, Literal, Tuple

from PIL import Image

from docling.datamodel.base_models import Cluster, FigureDescriptionData
from docling.models.img_understand_base_model import (
    ImgUnderstandBaseModel,
    ImgUnderstandOptions,
)
from docling.utils.utils import create_hash

_log = logging.getLogger(__name__)


class ImgUnderstandVllmOptions(ImgUnderstandOptions):
    kind: Literal["vllm"] = "vllm"

    # For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html

    # Parameters for LLaVA-1.6/LLaVA-NeXT
    llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
    llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
    llm_extra: Dict[str, Any] = dict(max_model_len=8192)

    # Parameters for Phi-3-Vision
    # llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
    # llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
    # llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)

    sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)


class ImgUnderstandVllmModel(ImgUnderstandBaseModel):

    def __init__(self, enabled: bool, options: ImgUnderstandVllmOptions):
        super().__init__(enabled=enabled, options=options)
        self.options: ImgUnderstandVllmOptions

        if self.enabled:
            try:
                from vllm import LLM, SamplingParams
            except ImportError:
                raise ImportError(
                    "VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
                )

            self.sampling_params = SamplingParams(**self.options.sampling_params)
            self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra)

            # Generate a stable hash from the extra parameters
            params_hash = create_hash(
                json.dumps(self.options.llm_extra, sort_keys=True)
                + json.dumps(self.options.sampling_params, sort_keys=True)
            )
            self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"

    def _annotate_image_batch(
        self, batch: Iterable[Tuple[Cluster, Image.Image]]
    ) -> List[FigureDescriptionData]:

        if not self.enabled:
            return [FigureDescriptionData() for _ in batch]

        from vllm import RequestOutput

        inputs = [
            {
                "prompt": self.options.llm_prompt,
                "multi_modal_data": {"image": im},
            }
            for _, im in batch
        ]
        outputs: List[RequestOutput] = self.llm.generate(
            inputs, sampling_params=self.sampling_params
        )

        results = []
        for o in outputs:
            generated_text = o.outputs[0].text
            results.append(
                FigureDescriptionData(text=generated_text, provenance=self.provenance)
            )
            _log.info(f"Generated description: {generated_text}")

        return results
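The commented-out block in the options above lists alternative settings for Phi-3-Vision; switching to them would presumably look like this sketch (not exercised in this commit):

from docling.models.img_understand_vllm_model import ImgUnderstandVllmOptions

# Sketch: selecting the Phi-3-Vision parameters listed (commented out) above.
phi3_options = ImgUnderstandVllmOptions(
    llm_name="microsoft/Phi-3-vision-128k-instruct",
    llm_prompt="<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n",
    llm_extra=dict(max_num_seqs=5, trust_remote_code=True),
)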
docling/models/page_assemble_model.py
@@ -98,18 +98,17 @@ class PageAssembleModel:
                     body.append(tbl)
                 elif cluster.label == LayoutModel.FIGURE_LABEL:
                     fig = None
-                    if page.predictions.figures_classification:
-                        fig = page.predictions.figures_classification.figure_map.get(
+                    if page.predictions.figures_prediction:
+                        fig = page.predictions.figures_prediction.figure_map.get(
                             cluster.id, None
                         )
                     if (
                         not fig
-                    ):  # fallback: add figure without classification, if it isn't present
+                    ):  # fallback: add figure with default data, if it isn't present
                         fig = FigureElement(
                             label=cluster.label,
                             id=cluster.id,
                             text="",
-                            data=None,
                             cluster=cluster,
                             page_no=page.page_no,
                         )
docling/pipeline/img_understand_pipeline.py (new file, 53 lines)

from pathlib import Path
from typing import Union

from pydantic import BaseModel, Field

from docling.datamodel.base_models import PipelineOptions
from docling.models.img_understand_api_model import (
    ImgUnderstandApiModel,
    ImgUnderstandApiOptions,
)
from docling.models.img_understand_vllm_model import (
    ImgUnderstandVllmModel,
    ImgUnderstandVllmOptions,
)
from docling.pipeline.standard_model_pipeline import StandardModelPipeline


class ImgUnderstandPipelineOptions(PipelineOptions):
    do_img_understand: bool = True
    img_understand_options: Union[ImgUnderstandApiOptions, ImgUnderstandVllmOptions] = (
        Field(ImgUnderstandVllmOptions(), discriminator="kind")
    )


class ImgUnderstandPipeline(StandardModelPipeline):

    def __init__(
        self, artifacts_path: Path, pipeline_options: ImgUnderstandPipelineOptions
    ):
        super().__init__(artifacts_path, pipeline_options)

        if isinstance(
            pipeline_options.img_understand_options, ImgUnderstandVllmOptions
        ):
            self.model_pipe.append(
                ImgUnderstandVllmModel(
                    enabled=pipeline_options.do_img_understand,
                    options=pipeline_options.img_understand_options,
                )
            )
        elif isinstance(
            pipeline_options.img_understand_options, ImgUnderstandApiOptions
        ):
            self.model_pipe.append(
                ImgUnderstandApiModel(
                    enabled=pipeline_options.do_img_understand,
                    options=pipeline_options.img_understand_options,
                )
            )
        else:
            raise RuntimeError(
                f"The specified image understanding kind is not supported: {pipeline_options.img_understand_options.kind}."
            )
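Since ImgUnderstandVllmOptions is the discriminated default, selecting the pipeline with the local vLLM backend is the shortest path. A sketch, assuming DocumentConverter accepts pipeline_cls and pipeline_options as in the example file below:

from docling.document_converter import DocumentConverter
from docling.pipeline.img_understand_pipeline import (
    ImgUnderstandPipeline,
    ImgUnderstandPipelineOptions,
    ImgUnderstandVllmOptions,
)

# Uses the default LLaVA-1.6 settings from ImgUnderstandVllmOptions.
doc_converter = DocumentConverter(
    pipeline_cls=ImgUnderstandPipeline,
    pipeline_options=ImgUnderstandPipelineOptions(
        img_understand_options=ImgUnderstandVllmOptions(),
    ),
)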
examples/img_understand_pipeline.py (new file, 132 lines)

import logging
import os
import time
from pathlib import Path
from typing import Iterable

import httpx
from dotenv import load_dotenv

from docling.datamodel.base_models import ConversionStatus
from docling.datamodel.document import ConversionResult, DocumentConversionInput
from docling.document_converter import DocumentConverter
from docling.pipeline.img_understand_pipeline import (
    ImgUnderstandApiOptions,
    ImgUnderstandPipeline,
    ImgUnderstandPipelineOptions,
    ImgUnderstandVllmOptions,
)

_log = logging.getLogger(__name__)


def export_documents(
    conv_results: Iterable[ConversionResult],
    output_dir: Path,
):
    output_dir.mkdir(parents=True, exist_ok=True)

    success_count = 0
    failure_count = 0

    for conv_res in conv_results:
        if conv_res.status == ConversionStatus.SUCCESS:
            success_count += 1
            doc_filename = conv_res.input.file.stem

            # # Export Deep Search document JSON format:
            # with (output_dir / f"{doc_filename}.json").open("w") as fp:
            #     fp.write(json.dumps(conv_res.render_as_dict()))

            # # Export Text format:
            # with (output_dir / f"{doc_filename}.txt").open("w") as fp:
            #     fp.write(conv_res.render_as_text())

            # # Export Markdown format:
            # with (output_dir / f"{doc_filename}.md").open("w") as fp:
            #     fp.write(conv_res.render_as_markdown())

            # # Export Document Tags format:
            # with (output_dir / f"{doc_filename}.doctags").open("w") as fp:
            #     fp.write(conv_res.render_as_doctags())

        else:
            _log.info(f"Document {conv_res.input.file} failed to convert.")
            failure_count += 1

    _log.info(
        f"Processed {success_count + failure_count} docs, of which {failure_count} failed"
    )

    return success_count, failure_count


def _get_iam_access_token(api_key: str) -> str:
    res = httpx.post(
        url="https://iam.cloud.ibm.com/identity/token",
        headers={
            "Content-Type": "application/x-www-form-urlencoded",
        },
        data=f"grant_type=urn:ibm:params:oauth:grant-type:apikey&apikey={api_key}",
    )
    res.raise_for_status()
    api_out = res.json()
    print(f"{api_out=}")
    return api_out["access_token"]


def main():
    logging.basicConfig(level=logging.INFO)

    input_doc_paths = [
        Path("./tests/data/2206.01062.pdf"),
    ]

    load_dotenv()
    api_key = os.environ.get("WX_API_KEY")
    project_id = os.environ.get("WX_PROJECT_ID")

    doc_converter = DocumentConverter(
        pipeline_cls=ImgUnderstandPipeline,
        # TODO: make DocumentConverter provide the correct default value
        # for pipeline_options, given the pipeline_cls
        pipeline_options=ImgUnderstandPipelineOptions(
            img_understand_options=ImgUnderstandApiOptions(
                url="https://us-south.ml.cloud.ibm.com/ml/v1/text/chat?version=2023-05-29",
                headers={
                    "Authorization": "Bearer " + _get_iam_access_token(api_key=api_key),
                },
                params=dict(
                    model_id="meta-llama/llama3-llava-next-8b-hf",
                    project_id=project_id,
                    max_tokens=512,
                    seed=42,
                ),
                llm_prompt="Describe this figure in three sentences.",
                provenance="llama3-llava-next-8b-hf",
            )
        ),
    )

    # Define input files
    input = DocumentConversionInput.from_paths(input_doc_paths)

    start_time = time.time()

    conv_results = doc_converter.convert(input)
    success_count, failure_count = export_documents(
        conv_results, output_dir=Path("./scratch")
    )

    end_time = time.time() - start_time

    _log.info(f"All documents were converted in {end_time:.2f} seconds.")

    if failure_count > 0:
        raise RuntimeError(
            f"The example failed converting {failure_count} of {len(input_doc_paths)} documents."
        )


if __name__ == "__main__":
    main()
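The example reads the watsonx credentials via load_dotenv(); a .env file (or exported environment variables) along these lines is assumed, with placeholder values:

WX_API_KEY=<your IBM Cloud API key>
WX_PROJECT_ID=<your watsonx.ai project id>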
pyproject.toml
@@ -41,6 +41,7 @@ pyarrow = "^16.1.0"
 #########
 # extras:
 #########
+# vllm = { version = "^0.5.0", optional = true, markers = "sys_platform != 'darwin' or platform_machine != 'x86_64'" }
 python-dotenv = { version = "^1.0.1", optional = true }
 llama-index-embeddings-huggingface = { version = "^0.3.1", optional = true }
 llama-index-llms-huggingface-api = { version = "^0.2.0", optional = true }
@@ -84,6 +85,7 @@ nbqa = "^1.9.0"
 datasets = "^2.21.0"
 
 [tool.poetry.extras]
+# vllm = ["vllm"]
 examples = [
     "python-dotenv",
     # LlamaIndex examples: