Prepare existing code for use with new multi-stage VLM pipeline
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
@@ -1,7 +1,7 @@
 import math
 from collections import defaultdict
 from enum import Enum
-from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import numpy as np
 from docling_core.types.doc import (
@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
     keep_empty_clusters: bool = (
         False  # Whether to keep clusters that contain no text cells
     )
+    skip_cell_assignment: bool = (
+        False  # Skip cell-to-cluster assignment for VLM-only processing
+    )
     model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2

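Usage sketch (illustrative, not part of the diff): a VLM-only configuration would flip the new flag on the layout options; the import path is assumed to be the module where LayoutOptions already lives.

    from docling.datamodel.pipeline_options import LayoutOptions  # assumed location

    # Keep layout clusters but skip assigning text cells to them,
    # since a VLM will read the page image directly.
    layout_options = LayoutOptions(skip_cell_assignment=True)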
@@ -1,11 +1,17 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Generic, Optional, Protocol, Type
+from typing import Generic, Optional, Protocol, Type, Union

+import numpy as np
 from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
+from PIL.Image import Image
 from typing_extensions import TypeVar

-from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
+from docling.datamodel.base_models import (
+    ItemAndImageEnrichmentElement,
+    Page,
+    VlmPrediction,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
 from docling.datamodel.settings import settings
@@ -26,6 +32,46 @@ class BasePageModel(ABC):
         pass


+class BaseVlmModel(ABC):
+    """Base class for Vision-Language Models that adds image processing capability."""
+
+    @abstractmethod
+    def process_images(
+        self, image_batch: Iterable[Union[Image, np.ndarray]]
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata."""
+
+
+class BaseVlmPageModel(BasePageModel, BaseVlmModel):
+    """Base implementation for VLM models that inherit from BasePageModel.
+
+    Provides a default __call__ implementation that extracts images from pages,
+    processes them using process_images, and attaches results back to pages.
+    """
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        """Extract images from pages, process them, and attach results back."""
+
+    @abstractmethod
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
+                If vlm_options.prompt is callable and no prompt is provided, raises ValueError.
+
+        Raises:
+            ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
+        """
+
+
 EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)

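A concrete backend is expected to subclass BaseVlmPageModel and implement process_images. Minimal illustrative sketch (the DummyVlmModel name and its canned output are hypothetical, not part of the commit):

    class DummyVlmModel(BaseVlmPageModel):
        """Toy model that emits one fixed prediction per input image."""

        def process_images(self, image_batch, prompt=None):
            for _image in image_batch:
                # A real implementation would run VLM inference here.
                yield VlmPrediction(text="<doctag></doctag>", generation_time=0.0)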
@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder

 class PagePreprocessingOptions(BaseModel):
     images_scale: Optional[float]
+    skip_cell_extraction: bool = (
+        False  # Skip text cell extraction for VLM-only processing
+    )


 class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
             else:
                 with TimeRecorder(conv_res, "page_parse"):
                     page = self._populate_page_images(page)
-                    page = self._parse_page_cells(conv_res, page)
+                    if not self.options.skip_cell_extraction:
+                        page = self._parse_page_cells(conv_res, page)
                 yield page

     # Generate the page image and store it in the page object
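Sketch of how this skip flag is meant to be set for a VLM-only run (illustrative; the images_scale value is an example and the pipeline wiring is omitted):

    preprocessing_options = PagePreprocessingOptions(
        images_scale=2.0,            # example value
        skip_cell_extraction=True,   # no text cells needed; the VLM reads the page image
    )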
@@ -3,7 +3,10 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union

+import numpy as np
+from PIL.Image import Image
+
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,7 +18,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersModelType,
     TransformersPromptStyle,
 )
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -25,7 +28,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)


-class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -159,7 +162,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
         generation_time = time.time() - start_time
         generated_texts = self.processor.batch_decode(
             generated_ids[:, inputs["input_ids"].shape[1] :],
-            skip_special_tokens=False,
+            skip_special_tokens=True,
         )[0]

         num_tokens = len(generated_ids[0])
@@ -214,3 +217,101 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
         raise RuntimeError(
             f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
         )
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata in a single batched inference call."""
+        pil_images: list[Image] = []
+
+        for img in image_batch:
+            # Convert numpy array to PIL Image if needed
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+
+            # Ensure image is in RGB mode (handles RGBA, L, etc.)
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+
+            pil_images.append(pil_img)
+
+        if len(pil_images) == 0:
+            return
+
+        # Handle prompt with priority: parameter > vlm_options.prompt > error
+        if prompt is not None:
+            user_prompt = prompt
+        elif not callable(self.vlm_options.prompt):
+            user_prompt = self.vlm_options.prompt
+        else:
+            raise ValueError(
+                "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                "Please provide a prompt parameter when calling process_images directly."
+            )
+
+        formatted_prompt = self.formulate_prompt(user_prompt)
+        prompts: list[str] = [formatted_prompt] * len(pil_images)
+
+        inputs = self.processor(
+            text=prompts, images=pil_images, return_tensors="pt", padding=True
+        ).to(self.device)
+
+        start_time = time.time()
+        generated_ids = self.vlm_model.generate(
+            **inputs,
+            max_new_tokens=self.max_new_tokens,
+            use_cache=self.use_cache,
+            temperature=self.temperature,
+            generation_config=self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        )
+        generation_time = time.time() - start_time
+
+        # Determine per-sample prompt lengths
+        try:
+            attention_mask = inputs["attention_mask"]
+            input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
+        except KeyError:
+            tokenizer = (
+                self.processor.tokenizer
+            )  # Expect tokenizer to be present when text is provided
+            pad_token_id = tokenizer.pad_token_id
+            if pad_token_id is not None:
+                input_lengths = (
+                    (inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
+                )
+            else:
+                # Fallback: assume uniform prompt length (least accurate but preserves execution)
+                uniform_len = int(inputs["input_ids"].shape[1])
+                input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
+
+        trimmed_sequences: list[list[int]] = [
+            generated_ids[i, int(input_lengths[i]) :].tolist()
+            for i in range(generated_ids.shape[0])
+        ]
+        decoded_texts: list[str] = self.processor.batch_decode(
+            trimmed_sequences, skip_special_tokens=True
+        )
+
+        # Logging tokens count for the first sample as a representative metric
+        if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
+            _log.debug(
+                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+            )
+
+        for text in decoded_texts:
+            yield VlmPrediction(text=text, generation_time=generation_time)
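With the class now built on BaseVlmPageModel, raw images can be run through it without constructing Page objects. Hedged usage sketch (assumes vlm_model is an already-initialized HuggingFaceTransformersVlmModel and the image files are hypothetical; an explicit prompt is required whenever vlm_options.prompt is a callable):

    from PIL import Image

    images = [Image.open("page_1.png"), Image.open("page_2.png")]  # hypothetical files
    for prediction in vlm_model.process_images(images, prompt="Convert this page to DocTags."):
        print(prediction.text)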
@@ -1,8 +1,12 @@
 import logging
+import threading
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

+import numpy as np
+from PIL.Image import Image
+
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
 from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder

 _log = logging.getLogger(__name__)

+# Global lock for MLX model calls - MLX models are not thread-safe
+# All MLX models share this lock to prevent concurrent MLX operations
+_MLX_GLOBAL_LOCK = threading.Lock()
+

-class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
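The lock is created once at module import, so every HuggingFaceMlxModel instance in the process shares it and MLX calls are serialized across threads. The same pattern in isolation (generic sketch, unrelated to docling):

    import threading

    _GLOBAL_LOCK = threading.Lock()  # one lock shared by all instances in this module


    class SerializedWorker:
        def run(self, job):
            # Only one thread at a time enters the critical section,
            # no matter how many SerializedWorker instances exist.
            with _GLOBAL_LOCK:
                return self._do_work(job)

        def _do_work(self, job):
            return job  # stand-in for the non-thread-safe operation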
@@ -92,51 +100,57 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     self.processor, self.config, user_prompt, num_images=1
                 )

-                start_time = time.time()
-                _log.debug("start generating ...")
+                # MLX models are not thread-safe - use global lock to serialize access
+                with _MLX_GLOBAL_LOCK:
+                    _log.debug(
+                        "MLX model: Acquired global lock for __call__ method"
+                    )
+                    start_time = time.time()
+                    _log.debug("start generating ...")

                     # Call model to generate:
                     tokens: list[VlmPredictionToken] = []

                     output = ""
                     for token in self.stream_generate(
                         self.vlm_model,
                         self.processor,
                         prompt,
                         [hi_res_image],
                         max_tokens=self.max_tokens,
                         verbose=False,
                         temp=self.temperature,
                     ):
                         if len(token.logprobs.shape) == 1:
                             tokens.append(
                                 VlmPredictionToken(
                                     text=token.text,
                                     token=token.token,
                                     logprob=token.logprobs[token.token],
                                 )
                             )
                         elif (
                             len(token.logprobs.shape) == 2
                             and token.logprobs.shape[0] == 1
                         ):
                             tokens.append(
                                 VlmPredictionToken(
                                     text=token.text,
                                     token=token.token,
                                     logprob=token.logprobs[0, token.token],
                                 )
                             )
                         else:
                             _log.warning(
                                 f"incompatible shape for logprobs: {token.logprobs.shape}"
                             )

                         output += token.text
                         if "</doctag>" in token.text:
                             break

                     generation_time = time.time() - start_time
+                    _log.debug("MLX model: Released global lock")
                 page_tags = output

                 _log.debug(
@@ -149,3 +163,82 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 )

                 yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Optional[str] = None,
+    ) -> Iterable[VlmPrediction]:
+        from mlx_vlm import generate
+
+        # MLX models are not thread-safe - use global lock to serialize access
+        with _MLX_GLOBAL_LOCK:
+            _log.debug("MLX model: Acquired global lock for thread safety")
+            for image in image_batch:
+                # Convert numpy array to PIL Image if needed
+                if isinstance(image, np.ndarray):
+                    if image.ndim == 3 and image.shape[2] in [3, 4]:
+                        # RGB or RGBA array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8))
+                    elif image.ndim == 2:
+                        # Grayscale array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                    else:
+                        raise ValueError(
+                            f"Unsupported numpy array shape: {image.shape}"
+                        )

+                # Ensure image is in RGB mode (handles RGBA, L, etc.)
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+
+                # Handle prompt with priority: parameter > vlm_options.prompt > error
+                if prompt is not None:
+                    user_prompt = prompt
+                elif not callable(self.vlm_options.prompt):
+                    user_prompt = self.vlm_options.prompt
+                else:
+                    raise ValueError(
+                        "vlm_options.prompt is callable but no prompt parameter provided to process_images. "
+                        "Please provide a prompt parameter when calling process_images directly."
+                    )
+
+                # Use the MLX chat template approach like in the __call__ method
+                formatted_prompt = self.apply_chat_template(
+                    self.processor, self.config, user_prompt, num_images=1
+                )
+
+                # Generate text from the image - MLX can accept PIL Images directly despite type annotations
+                start_time = time.time()
+                generated_result = generate(
+                    self.vlm_model,
+                    self.processor,
+                    formatted_prompt,
+                    image=image,  # Pass PIL Image directly - much more efficient than disk I/O
+                    verbose=False,
+                    temp=self.temperature,
+                    max_tokens=self.max_tokens,
+                )
+                generation_time = time.time() - start_time
+
+                # MLX generate returns a tuple (text, info_dict), extract just the text
+                if isinstance(generated_result, tuple):
+                    generated_text = generated_result[0]
+                    _log.debug(
+                        f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
+                    )
+                else:
+                    generated_text = generated_result
+
+                _log.debug(f"Generated text in {generation_time:.2f}s.")
+                yield VlmPrediction(
+                    text=generated_text,
+                    generation_time=generation_time,
+                    # MLX generate doesn't expose tokens directly, so we leave it empty
+                    generated_tokens=[],
+                )
+            _log.debug("MLX model: Released global lock")
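Both process_images implementations accept PIL images or numpy arrays and normalize everything to RGB PIL images before inference. The conversion rule, extracted as a standalone sketch (the helper name is illustrative, not part of the change):

    import numpy as np
    from PIL import Image as PILImage

    def to_rgb_pil(img):
        """Return an RGB PIL image for a PIL image or a numpy array (illustrative helper)."""
        if isinstance(img, np.ndarray):
            if img.ndim == 3 and img.shape[2] in (3, 4):   # RGB or RGBA array
                img = PILImage.fromarray(img.astype(np.uint8))
            elif img.ndim == 2:                            # grayscale array
                img = PILImage.fromarray(img.astype(np.uint8), mode="L")
            else:
                raise ValueError(f"Unsupported numpy array shape: {img.shape}")
        return img if img.mode == "RGB" else img.convert("RGB")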
@@ -239,15 +239,18 @@ class LayoutPostprocessor:
         final_clusters = self._sort_clusters(
             self.regular_clusters + self.special_clusters, mode="id"
         )
-        for cluster in final_clusters:
-            cluster.cells = self._sort_cells(cluster.cells)
-            # Also sort cells in children if any
-            for child in cluster.children:
-                child.cells = self._sort_cells(child.cells)

-        assert self.page.parsed_page is not None
-        self.page.parsed_page.textline_cells = self.cells
-        self.page.parsed_page.has_lines = len(self.cells) > 0
+        # Conditionally process cells if not skipping cell assignment
+        if not self.options.skip_cell_assignment:
+            for cluster in final_clusters:
+                cluster.cells = self._sort_cells(cluster.cells)
+                # Also sort cells in children if any
+                for child in cluster.children:
+                    child.cells = self._sort_cells(child.cells)
+
+            assert self.page.parsed_page is not None
+            self.page.parsed_page.textline_cells = self.cells
+            self.page.parsed_page.has_lines = len(self.cells) > 0

         return final_clusters, self.cells

@@ -264,36 +267,38 @@ class LayoutPostprocessor:
             if cluster.label in self.LABEL_REMAPPING:
                 cluster.label = self.LABEL_REMAPPING[cluster.label]

-        # Initial cell assignment
-        clusters = self._assign_cells_to_clusters(clusters)
+        # Conditionally assign cells to clusters
+        if not self.options.skip_cell_assignment:
+            # Initial cell assignment
+            clusters = self._assign_cells_to_clusters(clusters)

            # Remove clusters with no cells (if keep_empty_clusters is False),
            # but always keep clusters with label DocItemLabel.FORMULA
            if not self.options.keep_empty_clusters:
                clusters = [
                    cluster
                    for cluster in clusters
                    if cluster.cells or cluster.label == DocItemLabel.FORMULA
                ]

            # Handle orphaned cells
            unassigned = self._find_unassigned_cells(clusters)
            if unassigned and self.options.create_orphan_clusters:
                next_id = max((c.id for c in self.all_clusters), default=0) + 1
                orphan_clusters = []
                for i, cell in enumerate(unassigned):
                    conf = cell.confidence

                    orphan_clusters.append(
                        Cluster(
                            id=next_id + i,
                            label=DocItemLabel.TEXT,
                            bbox=cell.to_bounding_box(),
                            confidence=conf,
                            cells=[cell],
                        )
                    )
                clusters.extend(orphan_clusters)

         # Iterative refinement
         prev_count = len(clusters) + 1
@@ -350,12 +355,15 @@
                 b=max(c.bbox.b for c in contained),
             )

-            # Collect all cells from children
-            all_cells = []
-            for child in contained:
-                all_cells.extend(child.cells)
-            special.cells = self._deduplicate_cells(all_cells)
-            special.cells = self._sort_cells(special.cells)
+            # Conditionally collect cells from children
+            if not self.options.skip_cell_assignment:
+                all_cells = []
+                for child in contained:
+                    all_cells.extend(child.cells)
+                special.cells = self._deduplicate_cells(all_cells)
+                special.cells = self._sort_cells(special.cells)
+            else:
+                special.cells = []

             picture_clusters = [
                 c for c in special_clusters if c.label == DocItemLabel.PICTURE