feat: Rich tables support for HTML backend (#2324)

* Rich tables support for HTML backend

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Decoupling JATS backend from HTML backend, ways of creating tables changed significantly

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* updated and added tests

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Refactored parse_table_data in html_backend into a few smaller functions

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Changed the scope of a few functions in html_backend.py, making them static where possible

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

* Fix for HTML tables that have tbody and/or thead; these tables are now also properly supported

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>

---------

Signed-off-by: Maksym Lysak <mly@zurich.ibm.com>
Co-authored-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
Maxim Lysak
2025-09-29 18:12:16 +02:00
committed by GitHub
parent 325877aee9
commit c803abed9a
46 changed files with 9233 additions and 5815 deletions

View File

@@ -17,8 +17,11 @@ from docling_core.types.doc import (
DocumentOrigin,
GroupItem,
GroupLabel,
RefItem,
RichTableCell,
TableCell,
TableData,
TableItem,
TextItem,
)
from docling_core.types.doc.document import ContentLayer, Formatting, Script
@@ -276,10 +279,175 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
# reset context
self.ctx = _Context()
self._walk(content, doc)
return doc
def _walk(self, element: Tag, doc: DoclingDocument) -> None:
@staticmethod
def group_cell_elements(
    group_name: str,
    doc: DoclingDocument,
    provs_in_cell: list[RefItem],
    docling_table: TableItem,
) -> RefItem:
    """Move the items of one table cell under a freshly created group.

    A group labelled ``GroupLabel.UNSPECIFIED`` is added to ``doc`` with
    ``docling_table`` as its parent. Each reference in ``provs_in_cell`` is
    appended to the group's children, removed from the children of the item's
    previous parent (when listed there), and the item's ``parent`` is pointed
    at the group.

    Args:
        group_name: Name given to the new group.
        doc: Document that owns the referenced items and receives the group.
        provs_in_cell: References to the items found inside the cell.
        docling_table: Table item that becomes the parent of the group.

    Returns:
        A reference to the new group.
    """
    cell_group = doc.add_group(
        name=group_name,
        label=GroupLabel.UNSPECIFIED,
        parent=docling_table,
    )

    for item_ref in provs_in_cell:
        cell_group.children.append(item_ref)

        item = item_ref.resolve(doc)
        former_parent = item.parent.resolve(doc)

        # Detach the item from wherever it was originally attached.
        own_ref = item.get_ref()
        if own_ref in former_parent.children:
            former_parent.children.remove(own_ref)

        item.parent = cell_group.get_ref()

    return cell_group.get_ref()
@staticmethod
def process_rich_table_cells(
    provs_in_cell: list[RefItem],
    group_name: str,
    doc: DoclingDocument,
    docling_table: TableItem,
) -> tuple[bool, RefItem]:
    """Decide whether a cell is "rich" and produce the reference it should carry.

    * A single ``TextItem`` is a plain cell: the item is deleted from ``doc``
      and ``(False, <reference to that item>)`` is returned.
    * Anything else (several items, or one non-text item) is a rich cell: the
      items are grouped via ``group_cell_elements`` and
      ``(True, <reference to the group>)`` is returned.

    Args:
        provs_in_cell: References to the items found inside the cell. Must be
            non-empty; the first element is read unconditionally.
        group_name: Name for the group created for a rich cell.
        doc: Document that owns the referenced items.
        docling_table: Table item the group is attached to.

    Returns:
        A ``(is_rich, ref)`` tuple as described above.

    Raises:
        IndexError: If ``provs_in_cell`` is empty.
    """
    first_ref = provs_in_cell[0]

    if len(provs_in_cell) == 1:
        sole_item = first_ref.resolve(doc)
        if isinstance(sole_item, TextItem):
            # Cell has only one element and it's just a text
            doc.delete_items(node_items=[sole_item])
            return False, first_ref

    # Several elements, or a single non-text element: wrap them in a group.
    group_ref = HTMLDocumentBackend.group_cell_elements(
        group_name, doc, provs_in_cell, docling_table
    )
    return True, group_ref
def parse_table_data(
    self,
    element: Tag,
    doc: DoclingDocument,
    docling_table: TableItem,
    num_rows: int,
    num_cols: int,
) -> Optional[TableData]:
    """Fill ``docling_table`` with the cells of an HTML ``<table>`` element.

    Direct ``<thead>``/``<tbody>`` children of ``element`` are unwrapped in
    place, then every direct ``<tr>`` and its direct ``<td>``/``<th>`` cells
    are visited. Each cell's sub-tree is walked with ``self._walk``; depending
    on what that produced, a ``RichTableCell`` or a plain ``TableCell`` is
    registered through ``doc.add_table_cell``.

    Args:
        element: The ``<table>`` tag; it is modified in place.
        doc: Document that receives the cell content.
        docling_table: Table item the cells are added to.
        num_rows: Number of grid rows used for the occupancy grid.
        num_cols: Number of grid columns used for the occupancy grid.

    Returns:
        A ``TableData`` with the given dimensions and an empty ``table_cells``
        list; the cells themselves are attached to ``docling_table``.
    """
    sections = cast(list[Tag], element.find_all(["thead", "tbody"], recursive=False))
    for section in sections:
        section.unwrap()

    _log.debug(f"The table has {num_rows} rows and {num_cols} cols.")

    # Tracks which grid slots are already taken by (spanning) cells.
    occupancy: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]
    data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])

    row_idx = -1
    start_row_span = 0
    # Non-recursive lookups: rows and cells of nested tables are left to the
    # recursive walk of the cell that contains them.
    for tr in element("tr", recursive=False):
        if not isinstance(tr, Tag):
            continue
        row_cells = tr(["td", "th"], recursive=False)

        # A row is a column header until a <td> shows up; it is a row header
        # only if every cell is a <th> whose row span differs from 1.
        col_header = True
        row_header = True
        for candidate in row_cells:
            if not isinstance(candidate, Tag):
                continue
            _, candidate_row_span = HTMLDocumentBackend._get_cell_spans(candidate)
            if candidate.name == "td":
                col_header = False
                row_header = False
            elif candidate_row_span == 1:
                row_header = False

        if row_header:
            start_row_span += 1
        else:
            row_idx += 1
            start_row_span = 0

        col_idx = 0
        for cell_tag in row_cells:
            if not isinstance(cell_tag, Tag):
                continue

            # Replace <inline-formula> nodes by their "$$...$$" text.
            for formula in cell_tag("inline-formula"):
                math_parts = formula.text.split("$$")
                if len(math_parts) == 3:
                    formula.replace_with(NavigableString(f"$${math_parts[1]}$$"))

            # Parse the cell sub-tree; the returned refs decide whether the
            # cell is rich.
            refs_in_cell: list[RefItem] = self._walk(cell_tag, doc)
            is_rich = False
            rich_ref = None
            if len(refs_in_cell) > 0:
                group_name = f"rich_cell_group_{len(doc.tables)}_{col_idx}_{start_row_span + row_idx}"
                is_rich, rich_ref = HTMLDocumentBackend.process_rich_table_cells(
                    refs_in_cell, group_name, doc, docling_table
                )

            text = self.get_text(cell_tag).strip()
            col_span, row_span = self._get_cell_spans(cell_tag)
            if row_header:
                row_span -= 1

            # Skip grid slots already claimed by earlier spanning cells.
            while (
                col_idx < num_cols
                and occupancy[row_idx + start_row_span][col_idx] is not None
            ):
                col_idx += 1

            for row_off in range(start_row_span, start_row_span + row_span):
                for col_off in range(col_span):
                    if row_idx + row_off < num_rows and col_idx + col_off < num_cols:
                        occupancy[row_idx + row_off][col_idx + col_off] = text

            first_row = start_row_span + row_idx
            last_row = first_row + row_span
            last_col = col_idx + col_span
            is_row_header_cell = (not col_header) and cell_tag.name == "th"

            if is_rich:
                rich_cell = RichTableCell(
                    text=text,
                    row_span=row_span,
                    col_span=col_span,
                    start_row_offset_idx=first_row,
                    end_row_offset_idx=last_row,
                    start_col_offset_idx=col_idx,
                    end_col_offset_idx=last_col,
                    column_header=col_header,
                    row_header=is_row_header_cell,
                    ref=rich_ref,  # points to an artificial group around children
                )
                doc.add_table_cell(table_item=docling_table, cell=rich_cell)
            else:
                simple_cell = TableCell(
                    text=text,
                    row_span=row_span,
                    col_span=col_span,
                    start_row_offset_idx=first_row,
                    end_row_offset_idx=last_row,
                    start_col_offset_idx=col_idx,
                    end_col_offset_idx=last_col,
                    column_header=col_header,
                    row_header=is_row_header_cell,
                )
                doc.add_table_cell(table_item=docling_table, cell=simple_cell)

    return data
def _walk(self, element: Tag, doc: DoclingDocument) -> list[RefItem]:
"""Parse an XML tag by recursively walking its content.
While walking, the method buffers inline text across tags like <b> or <span>,
@@ -289,17 +457,18 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
element: The XML tag to parse.
doc: The Docling document to be updated with the parsed content.
"""
added_refs: list[RefItem] = []
buffer: AnnotatedTextList = AnnotatedTextList()
def flush_buffer():
if not buffer:
return
return added_refs
annotated_text_list: AnnotatedTextList = buffer.simplify_text_elements()
parts = annotated_text_list.split_by_newline()
buffer.clear()
if not "".join([el.text for el in annotated_text_list]):
return
return added_refs
for annotated_text_list in parts:
with self._use_inline_group(annotated_text_list, doc):
@@ -309,15 +478,16 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
annotated_text.text.strip()
)
if annotated_text.code:
doc.add_code(
docling_code2 = doc.add_code(
parent=self.parents[self.level],
text=seg_clean,
content_layer=self.content_layer,
formatting=annotated_text.formatting,
hyperlink=annotated_text.hyperlink,
)
added_refs.append(docling_code2.get_ref())
else:
doc.add_text(
docling_text2 = doc.add_text(
parent=self.parents[self.level],
label=DocItemLabel.TEXT,
text=seg_clean,
@@ -325,25 +495,31 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
formatting=annotated_text.formatting,
hyperlink=annotated_text.hyperlink,
)
added_refs.append(docling_text2.get_ref())
for node in element.contents:
if isinstance(node, Tag):
name = node.name.lower()
if name == "img":
flush_buffer()
self._emit_image(node, doc)
im_ref3 = self._emit_image(node, doc)
added_refs.append(im_ref3)
elif name in _FORMAT_TAG_MAP:
with self._use_format([name]):
self._walk(node, doc)
wk = self._walk(node, doc)
added_refs.extend(wk)
elif name == "a":
with self._use_hyperlink(node):
self._walk(node, doc)
wk2 = self._walk(node, doc)
added_refs.extend(wk2)
elif name in _BLOCK_TAGS:
flush_buffer()
self._handle_block(node, doc)
blk = self._handle_block(node, doc)
added_refs.extend(blk)
elif node.find(_BLOCK_TAGS):
flush_buffer()
self._walk(node, doc)
wk3 = self._walk(node, doc)
added_refs.extend(wk3)
else:
buffer.extend(
self._extract_text_and_hyperlink_recursively(
@@ -363,6 +539,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
)
flush_buffer()
return added_refs
@staticmethod
def _collect_parent_format_tags(item: PageElement) -> list[str]:
@@ -581,7 +758,8 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
self.level -= 1
self.content_layer = current_layer
def _handle_heading(self, tag: Tag, doc: DoclingDocument) -> None:
def _handle_heading(self, tag: Tag, doc: DoclingDocument) -> list[RefItem]:
added_ref = []
tag_name = tag.name.lower()
# set default content layer to BODY as soon as we encounter a heading
self.content_layer = ContentLayer.BODY
@@ -596,12 +774,13 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
for key in self.parents.keys():
self.parents[key] = None
self.level = 0
self.parents[self.level + 1] = doc.add_title(
docling_title = self.parents[self.level + 1] = doc.add_title(
text_clean,
content_layer=self.content_layer,
formatting=annotated_text.formatting,
hyperlink=annotated_text.hyperlink,
)
added_ref = [docling_title.get_ref()]
# the other levels need to be lowered by 1 if a title was set
else:
level -= 1
@@ -623,7 +802,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
_log.debug(f"Remove the tail of level {key}")
self.parents[key] = None
self.level = level
self.parents[self.level + 1] = doc.add_heading(
docling_heading = self.parents[self.level + 1] = doc.add_heading(
parent=self.parents[self.level],
text=text_clean,
orig=annotated_text.text,
@@ -632,12 +811,15 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
formatting=annotated_text.formatting,
hyperlink=annotated_text.hyperlink,
)
added_ref = [docling_heading.get_ref()]
self.level += 1
for img_tag in tag("img"):
if isinstance(img_tag, Tag):
self._emit_image(img_tag, doc)
im_ref = self._emit_image(img_tag, doc)
added_ref.append(im_ref)
return added_ref
def _handle_list(self, tag: Tag, doc: DoclingDocument) -> None:
def _handle_list(self, tag: Tag, doc: DoclingDocument) -> RefItem:
tag_name = tag.name.lower()
start: Optional[int] = None
name: str = ""
@@ -765,20 +947,50 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
self.parents[self.level + 1] = None
self.level -= 1
return list_group.get_ref()
def _handle_block(self, tag: Tag, doc: DoclingDocument) -> None:
@staticmethod
def get_html_table_row_col(tag: Tag) -> tuple[int, int]:
    """Count the rows and columns of an HTML table, taking spans into account.

    Direct ``<thead>``/``<tbody>`` children of ``tag`` are unwrapped in place
    first, so that their ``<tr>`` rows become direct children of the table.
    Only direct rows and direct cells are inspected.

    A row is counted only if at least one of its cells is a ``<td>`` or has a
    row span of 1; the column count is the largest sum of column spans found
    in any row.

    Args:
        tag: The ``<table>`` tag; it is modified in place.

    Returns:
        A ``(num_rows, num_cols)`` tuple.
    """
    for t in cast(list[Tag], tag.find_all(["thead", "tbody"], recursive=False)):
        t.unwrap()

    # Find the number of rows and columns (taking into account spans)
    num_rows: int = 0
    num_cols: int = 0
    for row in tag("tr", recursive=False):
        if not isinstance(row, Tag):
            continue
        col_count = 0
        is_row_header = True
        for cell in row(["td", "th"], recursive=False):
            # Guard the cell itself; the original re-checked ``row`` here,
            # which is always a Tag at this point and so filtered nothing.
            if not isinstance(cell, Tag):
                continue
            col_span, row_span = HTMLDocumentBackend._get_cell_spans(cell)
            col_count += col_span
            if cell.name == "td" or row_span == 1:
                is_row_header = False
        num_cols = max(num_cols, col_count)
        if not is_row_header:
            num_rows += 1

    return num_rows, num_cols
def _handle_block(self, tag: Tag, doc: DoclingDocument) -> list[RefItem]:
added_refs = []
tag_name = tag.name.lower()
if tag_name == "figure":
img_tag = tag.find("img")
if isinstance(img_tag, Tag):
self._emit_image(img_tag, doc)
im_ref = self._emit_image(img_tag, doc)
added_refs.append(im_ref)
elif tag_name in {"h1", "h2", "h3", "h4", "h5", "h6"}:
self._handle_heading(tag, doc)
heading_refs = self._handle_heading(tag, doc)
added_refs.extend(heading_refs)
elif tag_name in {"ul", "ol"}:
self._handle_list(tag, doc)
list_ref = self._handle_list(tag, doc)
added_refs.append(list_ref)
elif tag_name in {"p", "address", "summary"}:
text_list = self._extract_text_and_hyperlink_recursively(
@@ -791,15 +1003,16 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
if seg := annotated_text.text.strip():
seg_clean = HTMLDocumentBackend._clean_unicode(seg)
if annotated_text.code:
doc.add_code(
docling_code = doc.add_code(
parent=self.parents[self.level],
text=seg_clean,
content_layer=self.content_layer,
formatting=annotated_text.formatting,
hyperlink=annotated_text.hyperlink,
)
added_refs.append(docling_code.get_ref())
else:
doc.add_text(
docling_text = doc.add_text(
parent=self.parents[self.level],
label=DocItemLabel.TEXT,
text=seg_clean,
@@ -807,22 +1020,27 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
formatting=annotated_text.formatting,
hyperlink=annotated_text.hyperlink,
)
added_refs.append(docling_text.get_ref())
for img_tag in tag("img"):
if isinstance(img_tag, Tag):
self._emit_image(img_tag, doc)
elif tag_name == "table":
data = HTMLDocumentBackend.parse_table_data(tag)
num_rows, num_cols = self.get_html_table_row_col(tag)
data_e = TableData(num_rows=num_rows, num_cols=num_cols)
docling_table = doc.add_table(
data=data_e,
parent=self.parents[self.level],
content_layer=self.content_layer,
)
added_refs.append(docling_table.get_ref())
self.parse_table_data(tag, doc, docling_table, num_rows, num_cols)
for img_tag in tag("img"):
if isinstance(img_tag, Tag):
self._emit_image(tag, doc)
if data is not None:
doc.add_table(
data=data,
parent=self.parents[self.level],
content_layer=self.content_layer,
)
im_ref2 = self._emit_image(tag, doc)
added_refs.append(im_ref2)
elif tag_name in {"pre"}:
# handle monospace code snippets (pre).
@@ -835,13 +1053,14 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
text_clean = HTMLDocumentBackend._clean_unicode(
annotated_text.text.strip()
)
doc.add_code(
docling_code2 = doc.add_code(
parent=self.parents[self.level],
text=text_clean,
content_layer=self.content_layer,
formatting=annotated_text.formatting,
hyperlink=annotated_text.hyperlink,
)
added_refs.append(docling_code2.get_ref())
elif tag_name == "footer":
with self._use_footer(tag, doc):
@@ -850,8 +1069,9 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
elif tag_name == "details":
with self._use_details(tag, doc):
self._walk(tag, doc)
return added_refs
def _emit_image(self, img_tag: Tag, doc: DoclingDocument) -> None:
def _emit_image(self, img_tag: Tag, doc: DoclingDocument) -> RefItem:
figure = img_tag.find_parent("figure")
caption: AnnotatedTextList = AnnotatedTextList()
@@ -894,11 +1114,12 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
hyperlink=caption_anno_text.hyperlink,
)
doc.add_picture(
docling_pic = doc.add_picture(
caption=caption_item,
parent=self.parents[self.level],
content_layer=self.content_layer,
)
return docling_pic.get_ref()
@staticmethod
def get_text(item: PageElement) -> str:
@@ -996,106 +1217,3 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
)
return int_spans
@staticmethod
def parse_table_data(element: Tag) -> Optional[TableData]:  # noqa: C901
    """Convert an HTML ``<table>`` element into a ``TableData`` object.

    The table is scanned twice: once to size the grid (summing column spans
    per row and counting the rows that are not pure row-header rows), and once
    to create a ``TableCell`` per ``<td>``/``<th>`` with its text, spans and
    grid offsets.

    Args:
        element: The ``<table>`` tag to parse. ``<inline-formula>`` nodes
            inside its cells are replaced in place by their ``$$...$$`` text.

    Returns:
        The parsed table, or ``None`` when ``element`` contains a nested
        ``<table>``.
    """
    nested_tables = element.find("table")
    if nested_tables is not None:
        _log.debug("Skipping nested table.")
        return None

    # Find the number of rows and columns (taking into account spans)
    num_rows = 0
    num_cols = 0
    for row in element("tr"):
        col_count = 0
        is_row_header = True
        if not isinstance(row, Tag):
            continue
        for cell in row(["td", "th"]):
            # Guard the cell itself; the original re-checked ``row`` here,
            # which is always a Tag at this point and so filtered nothing.
            if not isinstance(cell, Tag):
                continue
            col_span, row_span = HTMLDocumentBackend._get_cell_spans(cell)
            col_count += col_span
            if cell.name == "td" or row_span == 1:
                is_row_header = False
        num_cols = max(num_cols, col_count)
        if not is_row_header:
            num_rows += 1

    _log.debug(f"The table has {num_rows} rows and {num_cols} cols.")

    # Occupancy grid: a non-None entry marks a slot taken by an earlier cell.
    grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]

    data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])

    # Iterate over the rows in the table
    start_row_span = 0
    row_idx = -1
    for row in element("tr"):
        if not isinstance(row, Tag):
            continue

        # For each row, find all the column cells (both <td> and <th>)
        cells = row(["td", "th"])

        # Check if cell is in a column header or row header
        col_header = True
        row_header = True
        for html_cell in cells:
            if isinstance(html_cell, Tag):
                _, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
                if html_cell.name == "td":
                    col_header = False
                    row_header = False
                elif row_span == 1:
                    row_header = False
        if not row_header:
            row_idx += 1
            start_row_span = 0
        else:
            # Row-header rows do not open a new grid row; they are offset
            # from the last regular row instead.
            start_row_span += 1

        # Extract the text content of each cell
        col_idx = 0
        for html_cell in cells:
            if not isinstance(html_cell, Tag):
                continue

            # extract inline formulas
            for formula in html_cell("inline-formula"):
                math_parts = formula.text.split("$$")
                if len(math_parts) == 3:
                    math_formula = f"$${math_parts[1]}$$"
                    formula.replace_with(NavigableString(math_formula))

            # TODO: extract content correctly from table-cells with lists
            text = HTMLDocumentBackend.get_text(html_cell).strip()
            col_span, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
            if row_header:
                row_span -= 1
            # Skip grid slots already claimed by earlier spanning cells.
            # NOTE(review): the row index is not bounds-checked here, so a
            # table whose trailing rows are all row headers may raise
            # IndexError - confirm whether that can occur in practice.
            while (
                col_idx < num_cols
                and grid[row_idx + start_row_span][col_idx] is not None
            ):
                col_idx += 1
            for r in range(start_row_span, start_row_span + row_span):
                for c in range(col_span):
                    if row_idx + r < num_rows and col_idx + c < num_cols:
                        grid[row_idx + r][col_idx + c] = text

            table_cell = TableCell(
                text=text,
                row_span=row_span,
                col_span=col_span,
                start_row_offset_idx=start_row_span + row_idx,
                end_row_offset_idx=start_row_span + row_idx + row_span,
                start_col_offset_idx=col_idx,
                end_col_offset_idx=col_idx + col_span,
                column_header=col_header,
                row_header=((not col_header) and html_cell.name == "th"),
            )
            data.table_cells.append(table_cell)

    return data

View File

@@ -2,9 +2,9 @@ import logging
import traceback
from io import BytesIO
from pathlib import Path
from typing import Final, Optional, Union
from typing import Final, Optional, Union, cast
from bs4 import BeautifulSoup, Tag
from bs4 import BeautifulSoup, NavigableString, Tag
from docling_core.types.doc import (
DocItemLabel,
DoclingDocument,
@@ -12,6 +12,8 @@ from docling_core.types.doc import (
GroupItem,
GroupLabel,
NodeItem,
TableCell,
TableData,
TextItem,
)
from lxml import etree
@@ -535,6 +537,110 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
return
@staticmethod
def parse_table_data(element: Tag) -> Optional[TableData]:  # noqa: C901
    """Convert a ``<table>`` element of a JATS document into ``TableData``.

    The table is scanned twice: once to size the grid (summing column spans
    per row and counting the rows that are not pure row-header rows), and once
    to create a ``TableCell`` per ``<td>``/``<th>`` with its text, spans and
    grid offsets. Cell spans and cell text are obtained through the static
    helpers of ``HTMLDocumentBackend``.

    Args:
        element: The ``<table>`` tag to parse. ``<inline-formula>`` nodes
            inside its cells are replaced in place by their ``$$...$$`` text.

    Returns:
        The parsed table, or ``None`` when ``element`` contains a nested
        ``<table>``.
    """
    # TODO, see how to implement proper support for rich tables from HTML backend
    nested_tables = element.find("table")
    if nested_tables is not None:
        _log.debug("Skipping nested table.")
        return None

    # Find the number of rows and columns (taking into account spans)
    num_rows = 0
    num_cols = 0
    for row in element("tr"):
        col_count = 0
        is_row_header = True
        if not isinstance(row, Tag):
            continue
        for cell in row(["td", "th"]):
            # Guard the cell itself; the original re-checked ``row`` here,
            # which is always a Tag at this point and so filtered nothing.
            if not isinstance(cell, Tag):
                continue
            col_span, row_span = HTMLDocumentBackend._get_cell_spans(cell)
            col_count += col_span
            if cell.name == "td" or row_span == 1:
                is_row_header = False
        num_cols = max(num_cols, col_count)
        if not is_row_header:
            num_rows += 1

    _log.debug(f"The table has {num_rows} rows and {num_cols} cols.")

    # Occupancy grid: a non-None entry marks a slot taken by an earlier cell.
    grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]

    data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])

    # Iterate over the rows in the table
    start_row_span = 0
    row_idx = -1
    for row in element("tr"):
        if not isinstance(row, Tag):
            continue

        # For each row, find all the column cells (both <td> and <th>)
        cells = row(["td", "th"])

        # Check if cell is in a column header or row header
        col_header = True
        row_header = True
        for html_cell in cells:
            if isinstance(html_cell, Tag):
                _, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
                if html_cell.name == "td":
                    col_header = False
                    row_header = False
                elif row_span == 1:
                    row_header = False
        if not row_header:
            row_idx += 1
            start_row_span = 0
        else:
            # Row-header rows do not open a new grid row; they are offset
            # from the last regular row instead.
            start_row_span += 1

        # Extract the text content of each cell
        col_idx = 0
        for html_cell in cells:
            if not isinstance(html_cell, Tag):
                continue

            # extract inline formulas
            for formula in html_cell("inline-formula"):
                math_parts = formula.text.split("$$")
                if len(math_parts) == 3:
                    math_formula = f"$${math_parts[1]}$$"
                    formula.replace_with(NavigableString(math_formula))

            # TODO: extract content correctly from table-cells with lists
            text = HTMLDocumentBackend.get_text(html_cell).strip()
            col_span, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
            if row_header:
                row_span -= 1
            # Skip grid slots already claimed by earlier spanning cells.
            # NOTE(review): the row index is not bounds-checked here, so a
            # table whose trailing rows are all row headers may raise
            # IndexError - confirm whether that can occur in practice.
            while (
                col_idx < num_cols
                and grid[row_idx + start_row_span][col_idx] is not None
            ):
                col_idx += 1
            for r in range(start_row_span, start_row_span + row_span):
                for c in range(col_span):
                    if row_idx + r < num_rows and col_idx + c < num_cols:
                        grid[row_idx + r][col_idx + c] = text

            table_cell = TableCell(
                text=text,
                row_span=row_span,
                col_span=col_span,
                start_row_offset_idx=start_row_span + row_idx,
                end_row_offset_idx=start_row_span + row_idx + row_span,
                start_col_offset_idx=col_idx,
                end_col_offset_idx=col_idx + col_span,
                column_header=col_header,
                row_header=((not col_header) and html_cell.name == "th"),
            )
            data.table_cells.append(table_cell)

    return data
def _add_table(
self, doc: DoclingDocument, parent: NodeItem, table_xml_component: Table
) -> None:
@@ -543,8 +649,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
if not isinstance(table_tag, Tag):
return
data = HTMLDocumentBackend.parse_table_data(table_tag)
data = JatsDocumentBackend.parse_table_data(table_tag)
# TODO: format label vs caption once styling is supported
label = table_xml_component["label"]
caption = table_xml_component["caption"]
@@ -554,7 +659,6 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
if table_text
else None
)
if data is not None:
doc.add_table(data=data, parent=parent, caption=table_caption)