mirror of https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00

commit 7c97b494ec (parent f159075b67)

    added the VlmPredictionToken

    Signed-off-by: Peter Staar <taa@zurich.ibm.com>
@@ -1,6 +1,10 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
+from docling_core.types.io import (
+    DocumentStream,
+)
+
 from docling_core.types.doc import (
     BoundingBox,
     DocItemLabel,
@@ -12,9 +16,6 @@ from docling_core.types.doc import (
 from docling_core.types.doc.page import SegmentedPdfPage, TextCell
 
 # DO NOT REMOVE; explicitly exposed from this location
-from docling_core.types.io import (
-    DocumentStream,
-)
 from PIL.Image import Image
 from pydantic import BaseModel, ConfigDict
 
@@ -127,12 +128,6 @@ class ErrorItem(BaseModel):
     error_message: str
 
 
-# class Cell(BaseModel):
-#     id: int
-#     text: str
-#     bbox: BoundingBox
-
-
 class Cluster(BaseModel):
     id: int
     label: DocItemLabel
@@ -153,9 +148,15 @@ class BasePageElement(BaseModel):
 class LayoutPrediction(BaseModel):
     clusters: List[Cluster] = []
 
 
+class VlmPredictionToken(BaseModel):
+    text: str = ""
+    token: int = -1
+    logprob: float = -1
+
+
 class VlmPrediction(BaseModel):
     text: str = ""
+    generated_tokens: list[VlmPredictionToken] = []
+    generation_time: float = -1
 
 
 class ContainerElement(
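For orientation, a minimal sketch of the extended data model in use; the token ids and logprob values below are made up for illustration:

    from docling.datamodel.base_models import VlmPrediction, VlmPredictionToken

    # Assemble a prediction that carries per-token metadata (illustrative values).
    pred = VlmPrediction(
        text="<doctag>...</doctag>",
        generation_time=1.7,
        generated_tokens=[
            VlmPredictionToken(text="<doctag>", token=50280, logprob=-0.02),
            VlmPredictionToken(text="...", token=1234, logprob=-1.35),
        ],
    )

    # Per-token logprobs can then be aggregated, e.g. into a rough
    # sequence-level confidence score.
    confidence = sum(t.logprob for t in pred.generated_tokens)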
@@ -20,30 +20,8 @@ import filetype
 
 # DO NOT REMOVE; explicitly exposed from this location
 from docling_core.types.doc import (
-    DocItem,
     DocItemLabel,
     DoclingDocument,
-    PictureItem,
-    SectionHeaderItem,
-    TableItem,
-    TextItem,
 )
-from docling_core.types.doc.document import ListItem
-from docling_core.types.legacy_doc.base import (
-    BaseText,
-    Figure,
-    GlmTableCell,
-    PageDimensions,
-    PageReference,
-    Prov,
-    Ref,
-    Table as DsSchemaTable,
-    TableCell,
-)
-from docling_core.types.legacy_doc.document import (
-    CCSDocumentDescription as DsDocumentDescription,
-    CCSFileInfoObject as DsFileInfoObject,
-    ExportedCCSDocument as DsDocument,
-)
 from docling_core.utils.file import resolve_source_to_stream
 from docling_core.utils.legacy import docling_document_to_legacy
@@ -1,19 +1,7 @@
 import logging
-import time
-from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
-from docling.datamodel.base_models import Page, VlmPrediction
-from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options import (
-    AcceleratorOptions,
-    HuggingFaceVlmOptions,
-)
-from docling.models.base_model import BasePageModel
-from docling.utils.accelerator_utils import decide_device
-from docling.utils.profiling import TimeRecorder
-
 _log = logging.getLogger(__name__)
 
 
@@ -4,7 +4,7 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
     AcceleratorOptions,
@@ -29,6 +29,8 @@ class HuggingFaceMlxModel(BasePageModel):
         self.vlm_options = vlm_options
 
+        self.max_tokens = 4096
+
         if self.enabled:
             try:
                 from mlx_vlm import generate, load  # type: ignore
@@ -40,29 +42,32 @@ class HuggingFaceMlxModel(BasePageModel):
                 )
 
            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
-            print(f"model init: {repo_cache_folder}")
+            _log.debug(f"model init: {repo_cache_folder}")
 
            self.apply_chat_template = apply_chat_template
            self.stream_generate = stream_generate
 
            # PARAMETERS:
            if artifacts_path is None:
-                print(f"before HuggingFaceVlmModel.download_models: {self.vlm_options.repo_id}")
+                _log.debug(
+                    f"before HuggingFaceVlmModel.download_models: {self.vlm_options.repo_id}"
+                )
                # artifacts_path = self.download_models(self.vlm_options.repo_id)
                artifacts_path = HuggingFaceVlmModel.download_models(
-                    self.vlm_options.repo_id, progress=True,
+                    self.vlm_options.repo_id,
+                    progress=True,
                )
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder
 
-            print(f"downloaded model: {artifacts_path}")
+            _log.debug(f"downloaded model: {artifacts_path}")
 
            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
 
            ## Load the model
-            print("start loading model ...")
+            _log.debug("start loading model ...")
            self.vlm_model, self.processor = load(artifacts_path)
-            print("loaded model ...")
+            _log.debug("loaded model ...")
            self.config = load_config(artifacts_path)
 
        """
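For reference, the cache resolution this constructor performs, isolated as a standalone sketch; the wrapper function and its parameter names are illustrative (HuggingFaceVlmModel.download_models is the real helper it delegates to):

    from pathlib import Path
    from typing import Callable, Optional

    def resolve_artifacts(
        repo_id: str,
        artifacts_path: Optional[Path],
        download: Callable[[str], Path],
    ) -> Path:
        """Mirror the three artifacts_path cases handled in __init__."""
        repo_cache_folder = repo_id.replace("/", "--")
        if artifacts_path is None:
            # No local path given: download into the default cache.
            return download(repo_id)
        if (artifacts_path / repo_cache_folder).exists():
            # A cache root was given and already holds this repo.
            return artifacts_path / repo_cache_folder
        # Otherwise treat the given path as the model directory itself.
        return artifacts_path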
@@ -117,9 +122,11 @@ class HuggingFaceMlxModel(BasePageModel):
                 )
 
                 start_time = time.time()
-                print("start generating ...")
+                _log.debug("start generating ...")
 
                 # Call model to generate:
+                tokens: list[VlmPredictionToken] = []
+
                 output = ""
                 for token in self.stream_generate(
                     self.vlm_model,
@@ -129,23 +136,31 @@ class HuggingFaceMlxModel(BasePageModel):
                     max_tokens=4096,
                     verbose=False,
                 ):
-                    print(token.text, end="", flush=True)
+                    # Record per-token metadata; mlx-vlm reports logprobs either
+                    # as a flat vocab-sized vector or with a leading batch axis of 1.
+                    # print(token.logprobs.shape)
+                    if len(token.logprobs.shape) == 1:
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[token.token],
+                            )
+                        )
+                    elif len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1:
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[0, token.token],
+                            )
+                        )
+
+                    # print(token.text, end="", flush=True)
                     output += token.text
 
                     if "</doctag>" in token.text:
                         break
 
                 generation_time = time.time() - start_time
                 page_tags = output
 
-                _log.debug(f"Generation time {generation_time:.2f} seconds.")
+                _log.debug(tokens)
 
-                # inference_time = time.time() - start_time
-                # tokens_per_second = num_tokens / generation_time
-                # print("")
-                # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                # print(f"Total tokens on page: {num_tokens:.2f}")
-                # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                # print("")
-                page.predictions.vlm_response = VlmPrediction(text=page_tags)
+                _log.debug(f"Generation time {generation_time:.2f} seconds.")
+                page.predictions.vlm_response = VlmPrediction(
+                    text=page_tags,
+                    generation_time=generation_time,
+                    generated_tokens=tokens,
+                )
 
                 yield page
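The shape check above handles the two layouts in which a streamed token can carry its logprobs: a flat (vocab_size,) vector, or the same vector behind a leading batch axis of size 1. A standalone sketch of that selection logic, with numpy standing in for the MLX array type (function name and demo values are illustrative):

    import numpy as np

    def token_logprob(logprobs: np.ndarray, token_id: int) -> float:
        """Select the sampled token's logprob from one generation step."""
        if logprobs.ndim == 1:  # shape (vocab_size,)
            return float(logprobs[token_id])
        if logprobs.ndim == 2 and logprobs.shape[0] == 1:  # shape (1, vocab_size)
            return float(logprobs[0, token_id])
        raise ValueError(f"unexpected logprobs shape: {logprobs.shape}")

    # Illustrative step over a 5-token vocabulary, sampled token id 3:
    step = np.log(np.array([0.1, 0.2, 0.1, 0.5, 0.1]))
    assert abs(token_logprob(step, 3) - np.log(0.5)) < 1e-12
    assert abs(token_logprob(step[None, :], 3) - np.log(0.5)) < 1e-12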
@@ -183,6 +183,4 @@ class HuggingFaceVlmModel_AutoModelForCausalLM(BasePageModel):
         else:
             raise ValueError(f"No prompt template for {self.vlm_options.repo_id}")
-
-
         return ""
 
@@ -4,8 +4,6 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
-from transformers import AutoProcessor, LlavaForConditionalGeneration
-
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import (
@@ -35,7 +33,6 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
         self.vlm_options = vlm_options
 
         if self.enabled:
-            import torch
             from transformers import (  # type: ignore
                 AutoProcessor,
                 LlavaForConditionalGeneration,
@@ -116,6 +113,8 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                 use_cache=self.use_cache,  # Enables KV caching which can improve performance
             )
 
+            _log.debug(generate_ids)
+
             num_tokens = len(generate_ids[0])
             generation_time = time.time() - start_time
@@ -125,19 +124,9 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
                 clean_up_tokenization_spaces=False,
             )[0]
 
-            """
-            _log.debug(
-                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-            )
-            """
-            # inference_time = time.time() - start_time
-            # tokens_per_second = num_tokens / generation_time
-            # print("")
-            # print(f"Page Inference Time: {inference_time:.2f} seconds")
-            # print(f"Total tokens on page: {num_tokens:.2f}")
-            # print(f"Tokens/sec: {tokens_per_second:.2f}")
-            # print("")
-            page.predictions.vlm_response = VlmPrediction(text=response)
+            _log.debug(
+                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+            )
+            # Only a token count is available here; VlmPrediction.generated_tokens
+            # expects per-token VlmPredictionToken entries, so it keeps its default.
+            page.predictions.vlm_response = VlmPrediction(
+                text=response,
+                generation_time=generation_time,
+            )
 
             yield page
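The tokens-per-second figure that the deleted comment block used to print can be recomputed downstream from the stored generation_time and a token count; a minimal sketch (the helper name is illustrative):

    def tokens_per_second(num_tokens: int, generation_time: float) -> float:
        """Throughput, as the commented-out debug code used to report it."""
        return num_tokens / generation_time if generation_time > 0 else 0.0

    # e.g. 512 tokens generated in 8.0 seconds -> 64.0 tokens/s
    print(tokens_per_second(512, 8.0))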
@@ -147,10 +136,11 @@ class HuggingFaceVlmModel_LlavaForConditionalGeneration(BasePageModel):
         # prompt = f"<s>[INST]{self.vlm_options.prompt}\n[IMG][/INST]"
         chat = [
             {
-                "role": "user", "content": [
+                "role": "user",
+                "content": [
                     {"type": "text", "content": self.vlm_options.prompt},
                     {"type": "image"},
-                ]
+                ],
             }
         ]
         prompt = self.processor.apply_chat_template(chat)
@@ -11,7 +11,6 @@ from docling.datamodel.pipeline_options import (
     InferenceFramework,
     ResponseFormat,
     VlmPipelineOptions,
-    smoldocling_vlm_mlx_conversion_options,
 )
 from docling.document_converter import DocumentConverter, PdfFormatOption
 from docling.pipeline.vlm_pipeline import VlmPipeline
@@ -1,5 +1,4 @@
 import logging
-import time
 from pathlib import Path
 
 from docling_core.types.doc import ImageRefMode, TableItem, TextItem