mirror of https://github.com/DS4SD/docling.git
synced 2025-07-26 20:14:47 +00:00

commit 7c97b494ec (parent f159075b67)
added the VlmPredictionToken

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
@@ -1,6 +1,10 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
+from docling_core.types.io import (
+    DocumentStream,
+)
+
 from docling_core.types.doc import (
     BoundingBox,
     DocItemLabel,
@@ -12,9 +16,6 @@ from docling_core.types.doc import (
 from docling_core.types.doc.page import SegmentedPdfPage, TextCell
 
 # DO NOT REMOVE; explicitly exposed from this location
-from docling_core.types.io import (
-    DocumentStream,
-)
 from PIL.Image import Image
 from pydantic import BaseModel, ConfigDict
 
@@ -127,12 +128,6 @@ class ErrorItem(BaseModel):
     error_message: str
 
 
-# class Cell(BaseModel):
-#     id: int
-#     text: str
-#     bbox: BoundingBox
-
-
 class Cluster(BaseModel):
     id: int
     label: DocItemLabel
@@ -153,9 +148,15 @@ class BasePageElement(BaseModel):
 class LayoutPrediction(BaseModel):
     clusters: List[Cluster] = []
 
+class VlmPredictionToken(BaseModel):
+    text: str = ""
+    token: int = -1
+    logprob: float = -1
 
 class VlmPrediction(BaseModel):
     text: str = ""
+    generated_tokens: list[VlmPredictionToken] = []
+    generation_time: float = -1
 
 
 class ContainerElement(
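The two new models above record, for every generated token, its text, vocabulary id and log-probability, next to the overall generated text and timing of the prediction. A minimal usage sketch (illustrative only, not part of the commit; all values are invented):

from docling.datamodel.base_models import VlmPrediction, VlmPredictionToken

# Invented tokens; in the pipeline they come from the VLM inference loop further below.
tokens = [
    VlmPredictionToken(text="<doctag>", token=321, logprob=-0.01),
    VlmPredictionToken(text="<text>", token=654, logprob=-0.05),
]
prediction = VlmPrediction(
    text="".join(t.text for t in tokens),
    generated_tokens=tokens,
    generation_time=0.42,
)

# Rough confidence proxy: mean token log-probability.
mean_logprob = sum(t.logprob for t in tokens) / len(tokens)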
@@ -20,30 +20,8 @@ import filetype

# DO NOT REMOVE; explicitly exposed from this location
from docling_core.types.doc import (
    DocItem,
    DocItemLabel,
    DoclingDocument,
    PictureItem,
    SectionHeaderItem,
    TableItem,
    TextItem,
)
from docling_core.types.doc.document import ListItem
from docling_core.types.legacy_doc.base import (
    BaseText,
    Figure,
    GlmTableCell,
    PageDimensions,
    PageReference,
    Prov,
    Ref,
    Table as DsSchemaTable,
    TableCell,
)
from docling_core.types.legacy_doc.document import (
    CCSDocumentDescription as DsDocumentDescription,
    CCSFileInfoObject as DsFileInfoObject,
    ExportedCCSDocument as DsDocument,
)
from docling_core.utils.file import resolve_source_to_stream
from docling_core.utils.legacy import docling_document_to_legacy
@@ -1,19 +1,7 @@
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional

from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
    AcceleratorOptions,
    HuggingFaceVlmOptions,
)
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)

@@ -4,7 +4,7 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
     AcceleratorOptions,
@@ -29,6 +29,8 @@ class HuggingFaceMlxModel(BasePageModel):
 
         self.vlm_options = vlm_options
 
+        self.max_tokens = 4096
+
         if self.enabled:
             try:
                 from mlx_vlm import generate, load  # type: ignore
@@ -40,29 +42,32 @@ class HuggingFaceMlxModel(BasePageModel):
                 )
 
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-            print(f"model init: {repo_cache_folder}")
+            _log.debug(f"model init: {repo_cache_folder}")
 
             self.apply_chat_template = apply_chat_template
             self.stream_generate = stream_generate
 
             # PARAMETERS:
             if artifacts_path is None:
-                print(f"before HuggingFaceVlmModel.download_models: {self.vlm_options.repo_id}")
+                _log.debug(
+                    f"before HuggingFaceVlmModel.download_models: {self.vlm_options.repo_id}"
+                )
                 # artifacts_path = self.download_models(self.vlm_options.repo_id)
                 artifacts_path = HuggingFaceVlmModel.download_models(
-                    self.vlm_options.repo_id, progress=True,
+                    self.vlm_options.repo_id,
+                    progress=True,
                 )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
-            print(f"downloaded model: {artifacts_path}")
+            _log.debug(f"downloaded model: {artifacts_path}")
 
             self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
 
             ## Load the model
-            print("start loading model ...")
+            _log.debug("start loading model ...")
             self.vlm_model, self.processor = load(artifacts_path)
-            print("loaded model ...")
+            _log.debug("loaded model ...")
             self.config = load_config(artifacts_path)
 
             """
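The constructor above resolves the model artifacts (via HuggingFaceVlmModel.download_models when no artifacts_path is given) and then loads them with mlx-vlm. A rough standalone sketch of the same loading sequence, assuming mlx-vlm is installed; the repo id is one example of an MLX-converted checkpoint and is not taken from this diff:

from mlx_vlm import load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Example checkpoint; any MLX-converted VLM repo id or local path should work the same way.
model_path = "ds4sd/SmolDocling-256M-preview-mlx-bf16"

vlm_model, processor = load(model_path)
config = load_config(model_path)
prompt = apply_chat_template(processor, config, "Convert this page to docling.", num_images=1)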
@@ -117,9 +122,11 @@ class HuggingFaceMlxModel(BasePageModel):
                     )
 
                     start_time = time.time()
-                    print("start generating ...")
+                    _log.debug("start generating ...")
 
                     # Call model to generate:
+                    tokens: list[VlmPredictionToken] = []
+
                     output = ""
                     for token in self.stream_generate(
                         self.vlm_model,
@@ -129,23 +136,31 @@ class HuggingFaceMlxModel(BasePageModel):
                         max_tokens=4096,
                         verbose=False,
                     ):
-                        print(token.text, end="", flush=True)
+                        print(token.logprobs.shape)
+                        if len(token.logprobs.shape)==1:
+                            tokens.append(VlmPredictionToken(text=token.text,
+                                                             token=token.token,
+                                                             logprob=token.logprobs[token.token]))
+                        elif len(token.logprobs.shape)==2 and token.logprobs.shape[0]==1:
+                            tokens.append(VlmPredictionToken(text=token.text,
+                                                             token=token.token,
+                                                             logprob=token.logprobs[0, token.token]))
+
+
+                        # print(token.text, end="", flush=True)
                         output += token.text
+
                         if "</doctag>" in token.text:
                             break
 
                     generation_time = time.time() - start_time
                     page_tags = output
 
+                    _log.debug(f"Generation time {generation_time:.2f} seconds.")
+                    print(tokens)
 
-                    # inference_time = time.time() - start_time
-                    # tokens_per_second = num_tokens / generation_time
-                    # print("")
-                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                    # print(f"Total tokens on page: {num_tokens:.2f}")
-                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                    # print("")
-                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
-                    _log.debug(f"Generation time {generation_time:.2f} seconds.")
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags,
+                                                                   generation_time=generation_time,
+                                                                   generated_tokens=tokens)
 
                     yield page
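The loop above stores one VlmPredictionToken per streamed token whenever the logprobs array has one of the two expected shapes. A small sketch (not part of the commit; the threshold is an arbitrary assumption) of how the collected tokens could be inspected afterwards:

import math

from docling.datamodel.base_models import VlmPredictionToken

LOGPROB_THRESHOLD = math.log(0.5)  # arbitrary cut-off for "uncertain" tokens

def uncertain_tokens(tokens: list[VlmPredictionToken]) -> list[VlmPredictionToken]:
    """Return the generated tokens whose log-probability falls below the threshold."""
    return [t for t in tokens if t.logprob < LOGPROB_THRESHOLD]

# e.g. uncertain_tokens(page.predictions.vlm_response.generated_tokens)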
@@ -170,7 +170,7 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
 
     def formulate_prompt(self) -> str:
         """Formulate a prompt for the VLM."""
-        if self.vlm_options.repo_id=="microsoft/Phi-4-multimodal-instruct":
+        if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
             user_prompt = "<|user|>"
             assistant_prompt = "<|assistant|>"
             prompt_suffix = "<|end|>"
@@ -183,6 +183,4 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         else:
             raise ValueError(f"No prompt template for {self.vlm_options.repo_id}")
-
-
         return ""
 
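For microsoft/Phi-4-multimodal-instruct, formulate_prompt wraps the question in the chat markers shown above. The middle of that branch is not part of these hunks, so the assembly below is only a hedged sketch following the model card's single-image convention (the <|image_1|> placeholder is an assumption, not taken from this diff):

def phi4_prompt(question: str) -> str:
    """Assemble a single-image Phi-4-multimodal prompt from its chat markers (illustrative sketch)."""
    user_prompt = "<|user|>"
    assistant_prompt = "<|assistant|>"
    prompt_suffix = "<|end|>"
    return f"{user_prompt}<|image_1|>{question}{prompt_suffix}{assistant_prompt}"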
@@ -4,8 +4,6 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
-from transformers import AutoProcessor, LlavaForConditionalGeneration
-
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
@@ -35,7 +33,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
        self.vlm_options = vlm_options

        if self.enabled:
            import torch
            from transformers import (  # type: ignore
                AutoProcessor,
                LlavaForConditionalGeneration,
@@ -116,6 +113,8 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                         use_cache=self.use_cache,  # Enables KV caching which can improve performance
                     )
 
+                    print(generate_ids)
+
                     num_tokens = len(generate_ids[0])
                     generation_time = time.time() - start_time
 
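num_tokens and generation_time above are the two values needed for the throughput figure that the commented-out prints (dropped in the next hunk) used to compute; a tiny sketch, kept outside the class for illustration:

def tokens_per_second(num_tokens: int, generation_time: float) -> float:
    """Throughput estimate from a token count and the wall-clock generation time."""
    return num_tokens / generation_time if generation_time > 0 else 0.0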
@@ -125,32 +124,23 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                         clean_up_tokenization_spaces=False,
                     )[0]
 
-                    """
-                    _log.debug(
-                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-                    )
-                    """
-                    # inference_time = time.time() - start_time
-                    # tokens_per_second = num_tokens / generation_time
-                    # print("")
-                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                    # print(f"Total tokens on page: {num_tokens:.2f}")
-                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                    # print("")
-                    page.predictions.vlm_response = VlmPrediction(text=response)
+                    page.predictions.vlm_response = VlmPrediction(text=response,
+                                                                   generated_tokens=num_tokens,
+                                                                   generation_time=generation_time)
 
                     yield page
 
     def formulate_prompt(self) -> str:
         """Formulate a prompt for the VLM."""
-        if self.vlm_options.repo_id=="mistral-community/pixtral-12b":
-            #prompt = f"<s>[INST]{self.vlm_options.prompt}\n[IMG][/INST]"
+        if self.vlm_options.repo_id == "mistral-community/pixtral-12b":
+            # prompt = f"<s>[INST]{self.vlm_options.prompt}\n[IMG][/INST]"
             chat = [
                 {
-                    "role": "user", "content": [
+                    "role": "user",
+                    "content": [
                         {"type": "text", "content": self.vlm_options.prompt},
                         {"type": "image"},
-                    ]
+                    ],
                 }
             ]
             prompt = self.processor.apply_chat_template(chat)
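The pixtral branch above builds a chat structure and renders it through the processor's chat template. A standalone sketch of the same pattern (the message layout mirrors the diff; loading the checkpoint downloads a large model, so this is illustrative rather than something to run casually):

from transformers import AutoProcessor

# Same repo id as checked in formulate_prompt above.
processor = AutoProcessor.from_pretrained("mistral-community/pixtral-12b")

chat = [
    {
        "role": "user",
        "content": [
            # Placeholder instruction; the pipeline uses vlm_options.prompt here.
            {"type": "text", "content": "Convert this page to docling."},
            {"type": "image"},
        ],
    }
]
prompt = processor.apply_chat_template(chat)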
@@ -11,7 +11,6 @@ from docling.datamodel.pipeline_options import (
    InferenceFramework,
    ResponseFormat,
    VlmPipelineOptions,
    smoldocling_vlm_mlx_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline
@@ -1,5 +1,4 @@
import logging
import time
from pathlib import Path

from docling_core.types.doc import ImageRefMode, TableItem, TextItem