mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
feat: update inference code to shuffle layout elements and discard initial prompt
Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
This commit is contained in:
@@ -19,6 +19,8 @@ from PIL import Image as PILImage
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from docling_core.types.doc.page import SegmentedPage
|
from docling_core.types.doc.page import SegmentedPage
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
from docling.backend.abstract_backend import AbstractDocumentBackend
|
from docling.backend.abstract_backend import AbstractDocumentBackend
|
||||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||||
from docling.datamodel.base_models import ConversionStatus, Page
|
from docling.datamodel.base_models import ConversionStatus, Page
|
||||||
@@ -84,13 +86,16 @@ class ThreadedLayoutVlmPipeline(BasePipeline):
|
|||||||
class LayoutAwareVlmOptions(type(base_vlm_options)): # type: ignore[misc]
|
class LayoutAwareVlmOptions(type(base_vlm_options)): # type: ignore[misc]
|
||||||
def build_prompt(
|
def build_prompt(
|
||||||
self,
|
self,
|
||||||
page: Optional[SegmentedPage],
|
|
||||||
*,
|
*,
|
||||||
_internal_page: Optional[Page] = None,
|
_internal_page: Optional[Page] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
base_prompt = self.prompt
|
base_prompt = self.prompt
|
||||||
augmented_prompt = base_prompt
|
augmented_prompt = base_prompt
|
||||||
|
|
||||||
|
# Only augment the default "Convert this page to docling." prompt; return any other prompt unchanged
|
||||||
|
if base_prompt != "Convert this page to docling.":
|
||||||
|
return base_prompt
|
||||||
|
|
||||||
# In this layout-aware pipeline, _internal_page is always provided
|
# In this layout-aware pipeline, _internal_page is always provided
|
||||||
if _internal_page is None:
|
if _internal_page is None:
|
||||||
return base_prompt
|
return base_prompt
|
||||||
@@ -111,6 +116,10 @@ class ThreadedLayoutVlmPipeline(BasePipeline):
|
|||||||
label=cluster.label
|
label=cluster.label
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if tag_name == DocumentToken.TABLE:
|
||||||
|
print("Found a table!")
|
||||||
|
tag_name = "otsl"
|
||||||
|
|
||||||
# Convert bbox to tuple and get location tokens
|
# Convert bbox to tuple and get location tokens
|
||||||
bbox_tuple = cluster.bbox.as_tuple()
|
bbox_tuple = cluster.bbox.as_tuple()
|
||||||
location_tokens = DocumentToken.get_location(
|
location_tokens = DocumentToken.get_location(
|
||||||
@@ -124,13 +133,17 @@ class ThreadedLayoutVlmPipeline(BasePipeline):
|
|||||||
layout_elements.append(xml_element)
|
layout_elements.append(xml_element)
|
||||||
|
|
||||||
if layout_elements:
|
if layout_elements:
|
||||||
|
# Shuffle elements
|
||||||
|
random.shuffle(layout_elements)
|
||||||
|
|
||||||
# Join elements with newlines and wrap in layout tags
|
# Join elements with newlines and wrap in layout tags
|
||||||
layout_xml = (
|
layout_xml = (
|
||||||
"<layout>" + "\n".join(layout_elements) + "</layout>"
|
"<layout>" + "\n".join(layout_elements) + "</layout>"
|
||||||
)
|
)
|
||||||
layout_injection = f"{layout_xml}"
|
layout_injection = f"{layout_xml}"
|
||||||
|
|
||||||
augmented_prompt = base_prompt + layout_injection
|
augmented_prompt = layout_injection
|
||||||
|
print(f"final prompt is {augmented_prompt}")
|
||||||
|
|
||||||
_log.debug(
|
_log.debug(
|
||||||
"Enhanced Prompt with Layout Info: %s\n", augmented_prompt
|
"Enhanced Prompt with Layout Info: %s\n", augmented_prompt
|
||||||
|
|||||||
Reference in New Issue
Block a user