feat: update inference code to shuffle layout elements and discard initial prompt

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
This commit is contained in:
ElHachem02
2025-12-03 12:59:31 +01:00
parent 54cd6d7406
commit 0904dbb95a

View File

@@ -19,6 +19,8 @@ from PIL import Image as PILImage
if TYPE_CHECKING:
from docling_core.types.doc.page import SegmentedPage
import random
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import ConversionStatus, Page
@@ -84,13 +86,16 @@ class ThreadedLayoutVlmPipeline(BasePipeline):
class LayoutAwareVlmOptions(type(base_vlm_options)): # type: ignore[misc]
def build_prompt(
self,
page: Optional[SegmentedPage],
*,
_internal_page: Optional[Page] = None,
) -> str:
base_prompt = self.prompt
augmented_prompt = base_prompt
# Only augment the "Convert this page to docling." base prompt
if base_prompt != "Convert this page to docling.":
return base_prompt
# In this layout-aware pipeline, _internal_page is always provided
if _internal_page is None:
return base_prompt
@@ -111,6 +116,10 @@ class ThreadedLayoutVlmPipeline(BasePipeline):
label=cluster.label
)
if tag_name == DocumentToken.TABLE:
print("Found a table!")
tag_name = "otsl"
# Convert bbox to tuple and get location tokens
bbox_tuple = cluster.bbox.as_tuple()
location_tokens = DocumentToken.get_location(
@@ -124,13 +133,17 @@ class ThreadedLayoutVlmPipeline(BasePipeline):
layout_elements.append(xml_element)
if layout_elements:
# Shuffle elements
random.shuffle(layout_elements)
# Join elements with newlines and wrap in layout tags
layout_xml = (
"<layout>" + "\n".join(layout_elements) + "</layout>"
)
layout_injection = f"{layout_xml}"
augmented_prompt = base_prompt + layout_injection augmented_prompt = layout_injection
print(f"final prompt is {augmented_prompt}")
_log.debug(
"Enhanced Prompt with Layout Info: %s\n", augmented_prompt