mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-11 22:28:31 +00:00
feat: Rich tables support for HTML backend (#2324)
* Rich tables support for HTML backend Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> * Decoupling JATS backend from HTML backend, ways of creating tables changed significantly Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> * updated and added tests Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> * Refactored parse_table_data in html_backend into few smaller functions Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> * Changing scope of few functions in html_backend.py, making them static, when possible Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> * Fix for HTML tables that have tbody and/or thead, now these tables are also properly supported Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> --------- Signed-off-by: Maksym Lysak <mly@zurich.ibm.com> Co-authored-by: Maksym Lysak <mly@zurich.ibm.com>
This commit is contained in:
@@ -17,8 +17,11 @@ from docling_core.types.doc import (
|
||||
DocumentOrigin,
|
||||
GroupItem,
|
||||
GroupLabel,
|
||||
RefItem,
|
||||
RichTableCell,
|
||||
TableCell,
|
||||
TableData,
|
||||
TableItem,
|
||||
TextItem,
|
||||
)
|
||||
from docling_core.types.doc.document import ContentLayer, Formatting, Script
|
||||
@@ -276,10 +279,175 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
# reset context
|
||||
self.ctx = _Context()
|
||||
self._walk(content, doc)
|
||||
|
||||
return doc
|
||||
|
||||
def _walk(self, element: Tag, doc: DoclingDocument) -> None:
|
||||
@staticmethod
|
||||
def group_cell_elements(
|
||||
group_name: str,
|
||||
doc: DoclingDocument,
|
||||
provs_in_cell: list[RefItem],
|
||||
docling_table: TableItem,
|
||||
) -> RefItem:
|
||||
group_element = doc.add_group(
|
||||
label=GroupLabel.UNSPECIFIED,
|
||||
name=group_name,
|
||||
parent=docling_table,
|
||||
)
|
||||
for prov in provs_in_cell:
|
||||
group_element.children.append(prov)
|
||||
pr_item = prov.resolve(doc)
|
||||
item_parent = pr_item.parent.resolve(doc)
|
||||
if pr_item.get_ref() in item_parent.children:
|
||||
item_parent.children.remove(pr_item.get_ref())
|
||||
pr_item.parent = group_element.get_ref()
|
||||
ref_for_rich_cell = group_element.get_ref()
|
||||
return ref_for_rich_cell
|
||||
|
||||
@staticmethod
|
||||
def process_rich_table_cells(
|
||||
provs_in_cell: list[RefItem],
|
||||
group_name: str,
|
||||
doc: DoclingDocument,
|
||||
docling_table: TableItem,
|
||||
) -> tuple[bool, RefItem]:
|
||||
rich_table_cell = False
|
||||
ref_for_rich_cell = provs_in_cell[0]
|
||||
if len(provs_in_cell) > 1:
|
||||
# Cell has multiple elements, we need to group them
|
||||
rich_table_cell = True
|
||||
ref_for_rich_cell = HTMLDocumentBackend.group_cell_elements(
|
||||
group_name, doc, provs_in_cell, docling_table
|
||||
)
|
||||
elif len(provs_in_cell) == 1:
|
||||
item_ref = provs_in_cell[0]
|
||||
pr_item = item_ref.resolve(doc)
|
||||
if isinstance(pr_item, TextItem):
|
||||
# Cell has only one element and it's just a text
|
||||
rich_table_cell = False
|
||||
doc.delete_items(node_items=[pr_item])
|
||||
else:
|
||||
rich_table_cell = True
|
||||
ref_for_rich_cell = HTMLDocumentBackend.group_cell_elements(
|
||||
group_name, doc, provs_in_cell, docling_table
|
||||
)
|
||||
|
||||
return rich_table_cell, ref_for_rich_cell
|
||||
|
||||
def parse_table_data(
|
||||
self,
|
||||
element: Tag,
|
||||
doc: DoclingDocument,
|
||||
docling_table: TableItem,
|
||||
num_rows: int,
|
||||
num_cols: int,
|
||||
) -> Optional[TableData]:
|
||||
for t in cast(list[Tag], element.find_all(["thead", "tbody"], recursive=False)):
|
||||
t.unwrap()
|
||||
|
||||
_log.debug(f"The table has {num_rows} rows and {num_cols} cols.")
|
||||
grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]
|
||||
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])
|
||||
|
||||
# Iterate over the rows in the table
|
||||
start_row_span = 0
|
||||
row_idx = -1
|
||||
|
||||
# We don't want this recursive to support nested tables
|
||||
for row in element("tr", recursive=False):
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
# For each row, find all the column cells (both <td> and <th>)
|
||||
# We don't want this recursive to support nested tables
|
||||
cells = row(["td", "th"], recursive=False)
|
||||
# Check if cell is in a column header or row header
|
||||
col_header = True
|
||||
row_header = True
|
||||
for html_cell in cells:
|
||||
if isinstance(html_cell, Tag):
|
||||
_, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
|
||||
if html_cell.name == "td":
|
||||
col_header = False
|
||||
row_header = False
|
||||
elif row_span == 1:
|
||||
row_header = False
|
||||
if not row_header:
|
||||
row_idx += 1
|
||||
start_row_span = 0
|
||||
else:
|
||||
start_row_span += 1
|
||||
|
||||
# Extract the text content of each cell
|
||||
col_idx = 0
|
||||
for html_cell in cells:
|
||||
if not isinstance(html_cell, Tag):
|
||||
continue
|
||||
|
||||
# extract inline formulas
|
||||
for formula in html_cell("inline-formula"):
|
||||
math_parts = formula.text.split("$$")
|
||||
if len(math_parts) == 3:
|
||||
math_formula = f"$${math_parts[1]}$$"
|
||||
formula.replace_with(NavigableString(math_formula))
|
||||
|
||||
provs_in_cell: list[RefItem] = []
|
||||
# Parse table cell sub-tree for Rich Cells content:
|
||||
provs_in_cell = self._walk(html_cell, doc)
|
||||
|
||||
rich_table_cell = False
|
||||
ref_for_rich_cell = None
|
||||
if len(provs_in_cell) > 0:
|
||||
group_name = f"rich_cell_group_{len(doc.tables)}_{col_idx}_{start_row_span + row_idx}"
|
||||
rich_table_cell, ref_for_rich_cell = (
|
||||
HTMLDocumentBackend.process_rich_table_cells(
|
||||
provs_in_cell, group_name, doc, docling_table
|
||||
)
|
||||
)
|
||||
|
||||
# Extracting text
|
||||
text = self.get_text(html_cell).strip()
|
||||
col_span, row_span = self._get_cell_spans(html_cell)
|
||||
if row_header:
|
||||
row_span -= 1
|
||||
while (
|
||||
col_idx < num_cols
|
||||
and grid[row_idx + start_row_span][col_idx] is not None
|
||||
):
|
||||
col_idx += 1
|
||||
for r in range(start_row_span, start_row_span + row_span):
|
||||
for c in range(col_span):
|
||||
if row_idx + r < num_rows and col_idx + c < num_cols:
|
||||
grid[row_idx + r][col_idx + c] = text
|
||||
|
||||
if rich_table_cell:
|
||||
rich_cell = RichTableCell(
|
||||
text=text,
|
||||
row_span=row_span,
|
||||
col_span=col_span,
|
||||
start_row_offset_idx=start_row_span + row_idx,
|
||||
end_row_offset_idx=start_row_span + row_idx + row_span,
|
||||
start_col_offset_idx=col_idx,
|
||||
end_col_offset_idx=col_idx + col_span,
|
||||
column_header=col_header,
|
||||
row_header=((not col_header) and html_cell.name == "th"),
|
||||
ref=ref_for_rich_cell, # points to an artificial group around children
|
||||
)
|
||||
doc.add_table_cell(table_item=docling_table, cell=rich_cell)
|
||||
else:
|
||||
simple_cell = TableCell(
|
||||
text=text,
|
||||
row_span=row_span,
|
||||
col_span=col_span,
|
||||
start_row_offset_idx=start_row_span + row_idx,
|
||||
end_row_offset_idx=start_row_span + row_idx + row_span,
|
||||
start_col_offset_idx=col_idx,
|
||||
end_col_offset_idx=col_idx + col_span,
|
||||
column_header=col_header,
|
||||
row_header=((not col_header) and html_cell.name == "th"),
|
||||
)
|
||||
doc.add_table_cell(table_item=docling_table, cell=simple_cell)
|
||||
return data
|
||||
|
||||
def _walk(self, element: Tag, doc: DoclingDocument) -> list[RefItem]:
|
||||
"""Parse an XML tag by recursively walking its content.
|
||||
|
||||
While walking, the method buffers inline text across tags like <b> or <span>,
|
||||
@@ -289,17 +457,18 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
element: The XML tag to parse.
|
||||
doc: The Docling document to be updated with the parsed content.
|
||||
"""
|
||||
added_refs: list[RefItem] = []
|
||||
buffer: AnnotatedTextList = AnnotatedTextList()
|
||||
|
||||
def flush_buffer():
|
||||
if not buffer:
|
||||
return
|
||||
return added_refs
|
||||
annotated_text_list: AnnotatedTextList = buffer.simplify_text_elements()
|
||||
parts = annotated_text_list.split_by_newline()
|
||||
buffer.clear()
|
||||
|
||||
if not "".join([el.text for el in annotated_text_list]):
|
||||
return
|
||||
return added_refs
|
||||
|
||||
for annotated_text_list in parts:
|
||||
with self._use_inline_group(annotated_text_list, doc):
|
||||
@@ -309,15 +478,16 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
annotated_text.text.strip()
|
||||
)
|
||||
if annotated_text.code:
|
||||
doc.add_code(
|
||||
docling_code2 = doc.add_code(
|
||||
parent=self.parents[self.level],
|
||||
text=seg_clean,
|
||||
content_layer=self.content_layer,
|
||||
formatting=annotated_text.formatting,
|
||||
hyperlink=annotated_text.hyperlink,
|
||||
)
|
||||
added_refs.append(docling_code2.get_ref())
|
||||
else:
|
||||
doc.add_text(
|
||||
docling_text2 = doc.add_text(
|
||||
parent=self.parents[self.level],
|
||||
label=DocItemLabel.TEXT,
|
||||
text=seg_clean,
|
||||
@@ -325,25 +495,31 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
formatting=annotated_text.formatting,
|
||||
hyperlink=annotated_text.hyperlink,
|
||||
)
|
||||
added_refs.append(docling_text2.get_ref())
|
||||
|
||||
for node in element.contents:
|
||||
if isinstance(node, Tag):
|
||||
name = node.name.lower()
|
||||
if name == "img":
|
||||
flush_buffer()
|
||||
self._emit_image(node, doc)
|
||||
im_ref3 = self._emit_image(node, doc)
|
||||
added_refs.append(im_ref3)
|
||||
elif name in _FORMAT_TAG_MAP:
|
||||
with self._use_format([name]):
|
||||
self._walk(node, doc)
|
||||
wk = self._walk(node, doc)
|
||||
added_refs.extend(wk)
|
||||
elif name == "a":
|
||||
with self._use_hyperlink(node):
|
||||
self._walk(node, doc)
|
||||
wk2 = self._walk(node, doc)
|
||||
added_refs.extend(wk2)
|
||||
elif name in _BLOCK_TAGS:
|
||||
flush_buffer()
|
||||
self._handle_block(node, doc)
|
||||
blk = self._handle_block(node, doc)
|
||||
added_refs.extend(blk)
|
||||
elif node.find(_BLOCK_TAGS):
|
||||
flush_buffer()
|
||||
self._walk(node, doc)
|
||||
wk3 = self._walk(node, doc)
|
||||
added_refs.extend(wk3)
|
||||
else:
|
||||
buffer.extend(
|
||||
self._extract_text_and_hyperlink_recursively(
|
||||
@@ -363,6 +539,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
)
|
||||
|
||||
flush_buffer()
|
||||
return added_refs
|
||||
|
||||
@staticmethod
|
||||
def _collect_parent_format_tags(item: PageElement) -> list[str]:
|
||||
@@ -581,7 +758,8 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
self.level -= 1
|
||||
self.content_layer = current_layer
|
||||
|
||||
def _handle_heading(self, tag: Tag, doc: DoclingDocument) -> None:
|
||||
def _handle_heading(self, tag: Tag, doc: DoclingDocument) -> list[RefItem]:
|
||||
added_ref = []
|
||||
tag_name = tag.name.lower()
|
||||
# set default content layer to BODY as soon as we encounter a heading
|
||||
self.content_layer = ContentLayer.BODY
|
||||
@@ -596,12 +774,13 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
for key in self.parents.keys():
|
||||
self.parents[key] = None
|
||||
self.level = 0
|
||||
self.parents[self.level + 1] = doc.add_title(
|
||||
docling_title = self.parents[self.level + 1] = doc.add_title(
|
||||
text_clean,
|
||||
content_layer=self.content_layer,
|
||||
formatting=annotated_text.formatting,
|
||||
hyperlink=annotated_text.hyperlink,
|
||||
)
|
||||
added_ref = [docling_title.get_ref()]
|
||||
# the other levels need to be lowered by 1 if a title was set
|
||||
else:
|
||||
level -= 1
|
||||
@@ -623,7 +802,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
_log.debug(f"Remove the tail of level {key}")
|
||||
self.parents[key] = None
|
||||
self.level = level
|
||||
self.parents[self.level + 1] = doc.add_heading(
|
||||
docling_heading = self.parents[self.level + 1] = doc.add_heading(
|
||||
parent=self.parents[self.level],
|
||||
text=text_clean,
|
||||
orig=annotated_text.text,
|
||||
@@ -632,12 +811,15 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
formatting=annotated_text.formatting,
|
||||
hyperlink=annotated_text.hyperlink,
|
||||
)
|
||||
added_ref = [docling_heading.get_ref()]
|
||||
self.level += 1
|
||||
for img_tag in tag("img"):
|
||||
if isinstance(img_tag, Tag):
|
||||
self._emit_image(img_tag, doc)
|
||||
im_ref = self._emit_image(img_tag, doc)
|
||||
added_ref.append(im_ref)
|
||||
return added_ref
|
||||
|
||||
def _handle_list(self, tag: Tag, doc: DoclingDocument) -> None:
|
||||
def _handle_list(self, tag: Tag, doc: DoclingDocument) -> RefItem:
|
||||
tag_name = tag.name.lower()
|
||||
start: Optional[int] = None
|
||||
name: str = ""
|
||||
@@ -765,20 +947,50 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
|
||||
self.parents[self.level + 1] = None
|
||||
self.level -= 1
|
||||
return list_group.get_ref()
|
||||
|
||||
def _handle_block(self, tag: Tag, doc: DoclingDocument) -> None:
|
||||
@staticmethod
|
||||
def get_html_table_row_col(tag: Tag) -> tuple[int, int]:
|
||||
for t in cast(list[Tag], tag.find_all(["thead", "tbody"], recursive=False)):
|
||||
t.unwrap()
|
||||
# Find the number of rows and columns (taking into account spans)
|
||||
num_rows: int = 0
|
||||
num_cols: int = 0
|
||||
for row in tag("tr", recursive=False):
|
||||
col_count = 0
|
||||
is_row_header = True
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
for cell in row(["td", "th"], recursive=False):
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
cell_tag = cast(Tag, cell)
|
||||
col_span, row_span = HTMLDocumentBackend._get_cell_spans(cell_tag)
|
||||
col_count += col_span
|
||||
if cell_tag.name == "td" or row_span == 1:
|
||||
is_row_header = False
|
||||
num_cols = max(num_cols, col_count)
|
||||
if not is_row_header:
|
||||
num_rows += 1
|
||||
return num_rows, num_cols
|
||||
|
||||
def _handle_block(self, tag: Tag, doc: DoclingDocument) -> list[RefItem]:
|
||||
added_refs = []
|
||||
tag_name = tag.name.lower()
|
||||
|
||||
if tag_name == "figure":
|
||||
img_tag = tag.find("img")
|
||||
if isinstance(img_tag, Tag):
|
||||
self._emit_image(img_tag, doc)
|
||||
im_ref = self._emit_image(img_tag, doc)
|
||||
added_refs.append(im_ref)
|
||||
|
||||
elif tag_name in {"h1", "h2", "h3", "h4", "h5", "h6"}:
|
||||
self._handle_heading(tag, doc)
|
||||
heading_refs = self._handle_heading(tag, doc)
|
||||
added_refs.extend(heading_refs)
|
||||
|
||||
elif tag_name in {"ul", "ol"}:
|
||||
self._handle_list(tag, doc)
|
||||
list_ref = self._handle_list(tag, doc)
|
||||
added_refs.append(list_ref)
|
||||
|
||||
elif tag_name in {"p", "address", "summary"}:
|
||||
text_list = self._extract_text_and_hyperlink_recursively(
|
||||
@@ -791,15 +1003,16 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
if seg := annotated_text.text.strip():
|
||||
seg_clean = HTMLDocumentBackend._clean_unicode(seg)
|
||||
if annotated_text.code:
|
||||
doc.add_code(
|
||||
docling_code = doc.add_code(
|
||||
parent=self.parents[self.level],
|
||||
text=seg_clean,
|
||||
content_layer=self.content_layer,
|
||||
formatting=annotated_text.formatting,
|
||||
hyperlink=annotated_text.hyperlink,
|
||||
)
|
||||
added_refs.append(docling_code.get_ref())
|
||||
else:
|
||||
doc.add_text(
|
||||
docling_text = doc.add_text(
|
||||
parent=self.parents[self.level],
|
||||
label=DocItemLabel.TEXT,
|
||||
text=seg_clean,
|
||||
@@ -807,22 +1020,27 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
formatting=annotated_text.formatting,
|
||||
hyperlink=annotated_text.hyperlink,
|
||||
)
|
||||
added_refs.append(docling_text.get_ref())
|
||||
|
||||
for img_tag in tag("img"):
|
||||
if isinstance(img_tag, Tag):
|
||||
self._emit_image(img_tag, doc)
|
||||
|
||||
elif tag_name == "table":
|
||||
data = HTMLDocumentBackend.parse_table_data(tag)
|
||||
num_rows, num_cols = self.get_html_table_row_col(tag)
|
||||
data_e = TableData(num_rows=num_rows, num_cols=num_cols)
|
||||
docling_table = doc.add_table(
|
||||
data=data_e,
|
||||
parent=self.parents[self.level],
|
||||
content_layer=self.content_layer,
|
||||
)
|
||||
added_refs.append(docling_table.get_ref())
|
||||
self.parse_table_data(tag, doc, docling_table, num_rows, num_cols)
|
||||
|
||||
for img_tag in tag("img"):
|
||||
if isinstance(img_tag, Tag):
|
||||
self._emit_image(tag, doc)
|
||||
if data is not None:
|
||||
doc.add_table(
|
||||
data=data,
|
||||
parent=self.parents[self.level],
|
||||
content_layer=self.content_layer,
|
||||
)
|
||||
im_ref2 = self._emit_image(tag, doc)
|
||||
added_refs.append(im_ref2)
|
||||
|
||||
elif tag_name in {"pre"}:
|
||||
# handle monospace code snippets (pre).
|
||||
@@ -835,13 +1053,14 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
text_clean = HTMLDocumentBackend._clean_unicode(
|
||||
annotated_text.text.strip()
|
||||
)
|
||||
doc.add_code(
|
||||
docling_code2 = doc.add_code(
|
||||
parent=self.parents[self.level],
|
||||
text=text_clean,
|
||||
content_layer=self.content_layer,
|
||||
formatting=annotated_text.formatting,
|
||||
hyperlink=annotated_text.hyperlink,
|
||||
)
|
||||
added_refs.append(docling_code2.get_ref())
|
||||
|
||||
elif tag_name == "footer":
|
||||
with self._use_footer(tag, doc):
|
||||
@@ -850,8 +1069,9 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
elif tag_name == "details":
|
||||
with self._use_details(tag, doc):
|
||||
self._walk(tag, doc)
|
||||
return added_refs
|
||||
|
||||
def _emit_image(self, img_tag: Tag, doc: DoclingDocument) -> None:
|
||||
def _emit_image(self, img_tag: Tag, doc: DoclingDocument) -> RefItem:
|
||||
figure = img_tag.find_parent("figure")
|
||||
caption: AnnotatedTextList = AnnotatedTextList()
|
||||
|
||||
@@ -894,11 +1114,12 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
hyperlink=caption_anno_text.hyperlink,
|
||||
)
|
||||
|
||||
doc.add_picture(
|
||||
docling_pic = doc.add_picture(
|
||||
caption=caption_item,
|
||||
parent=self.parents[self.level],
|
||||
content_layer=self.content_layer,
|
||||
)
|
||||
return docling_pic.get_ref()
|
||||
|
||||
@staticmethod
|
||||
def get_text(item: PageElement) -> str:
|
||||
@@ -996,106 +1217,3 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
|
||||
)
|
||||
|
||||
return int_spans
|
||||
|
||||
@staticmethod
|
||||
def parse_table_data(element: Tag) -> Optional[TableData]: # noqa: C901
|
||||
nested_tables = element.find("table")
|
||||
if nested_tables is not None:
|
||||
_log.debug("Skipping nested table.")
|
||||
return None
|
||||
|
||||
# Find the number of rows and columns (taking into account spans)
|
||||
num_rows = 0
|
||||
num_cols = 0
|
||||
for row in element("tr"):
|
||||
col_count = 0
|
||||
is_row_header = True
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
for cell in row(["td", "th"]):
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
cell_tag = cast(Tag, cell)
|
||||
col_span, row_span = HTMLDocumentBackend._get_cell_spans(cell_tag)
|
||||
col_count += col_span
|
||||
if cell_tag.name == "td" or row_span == 1:
|
||||
is_row_header = False
|
||||
num_cols = max(num_cols, col_count)
|
||||
if not is_row_header:
|
||||
num_rows += 1
|
||||
|
||||
_log.debug(f"The table has {num_rows} rows and {num_cols} cols.")
|
||||
|
||||
grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]
|
||||
|
||||
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])
|
||||
|
||||
# Iterate over the rows in the table
|
||||
start_row_span = 0
|
||||
row_idx = -1
|
||||
for row in element("tr"):
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
|
||||
# For each row, find all the column cells (both <td> and <th>)
|
||||
cells = row(["td", "th"])
|
||||
|
||||
# Check if cell is in a column header or row header
|
||||
col_header = True
|
||||
row_header = True
|
||||
for html_cell in cells:
|
||||
if isinstance(html_cell, Tag):
|
||||
_, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
|
||||
if html_cell.name == "td":
|
||||
col_header = False
|
||||
row_header = False
|
||||
elif row_span == 1:
|
||||
row_header = False
|
||||
if not row_header:
|
||||
row_idx += 1
|
||||
start_row_span = 0
|
||||
else:
|
||||
start_row_span += 1
|
||||
|
||||
# Extract the text content of each cell
|
||||
col_idx = 0
|
||||
for html_cell in cells:
|
||||
if not isinstance(html_cell, Tag):
|
||||
continue
|
||||
|
||||
# extract inline formulas
|
||||
for formula in html_cell("inline-formula"):
|
||||
math_parts = formula.text.split("$$")
|
||||
if len(math_parts) == 3:
|
||||
math_formula = f"$${math_parts[1]}$$"
|
||||
formula.replace_with(NavigableString(math_formula))
|
||||
|
||||
# TODO: extract content correctly from table-cells with lists
|
||||
text = HTMLDocumentBackend.get_text(html_cell).strip()
|
||||
col_span, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
|
||||
if row_header:
|
||||
row_span -= 1
|
||||
while (
|
||||
col_idx < num_cols
|
||||
and grid[row_idx + start_row_span][col_idx] is not None
|
||||
):
|
||||
col_idx += 1
|
||||
for r in range(start_row_span, start_row_span + row_span):
|
||||
for c in range(col_span):
|
||||
if row_idx + r < num_rows and col_idx + c < num_cols:
|
||||
grid[row_idx + r][col_idx + c] = text
|
||||
|
||||
table_cell = TableCell(
|
||||
text=text,
|
||||
row_span=row_span,
|
||||
col_span=col_span,
|
||||
start_row_offset_idx=start_row_span + row_idx,
|
||||
end_row_offset_idx=start_row_span + row_idx + row_span,
|
||||
start_col_offset_idx=col_idx,
|
||||
end_col_offset_idx=col_idx + col_span,
|
||||
column_header=col_header,
|
||||
row_header=((not col_header) and html_cell.name == "th"),
|
||||
)
|
||||
data.table_cells.append(table_cell)
|
||||
|
||||
return data
|
||||
|
||||
@@ -2,9 +2,9 @@ import logging
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Final, Optional, Union
|
||||
from typing import Final, Optional, Union, cast
|
||||
|
||||
from bs4 import BeautifulSoup, Tag
|
||||
from bs4 import BeautifulSoup, NavigableString, Tag
|
||||
from docling_core.types.doc import (
|
||||
DocItemLabel,
|
||||
DoclingDocument,
|
||||
@@ -12,6 +12,8 @@ from docling_core.types.doc import (
|
||||
GroupItem,
|
||||
GroupLabel,
|
||||
NodeItem,
|
||||
TableCell,
|
||||
TableData,
|
||||
TextItem,
|
||||
)
|
||||
from lxml import etree
|
||||
@@ -535,6 +537,110 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
||||
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def parse_table_data(element: Tag) -> Optional[TableData]: # noqa: C901
|
||||
# TODO, see how to implement proper support for rich tables from HTML backend
|
||||
nested_tables = element.find("table")
|
||||
if nested_tables is not None:
|
||||
_log.debug("Skipping nested table.")
|
||||
return None
|
||||
|
||||
# Find the number of rows and columns (taking into account spans)
|
||||
num_rows = 0
|
||||
num_cols = 0
|
||||
for row in element("tr"):
|
||||
col_count = 0
|
||||
is_row_header = True
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
for cell in row(["td", "th"]):
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
cell_tag = cast(Tag, cell)
|
||||
col_span, row_span = HTMLDocumentBackend._get_cell_spans(cell_tag)
|
||||
col_count += col_span
|
||||
if cell_tag.name == "td" or row_span == 1:
|
||||
is_row_header = False
|
||||
num_cols = max(num_cols, col_count)
|
||||
if not is_row_header:
|
||||
num_rows += 1
|
||||
|
||||
_log.debug(f"The table has {num_rows} rows and {num_cols} cols.")
|
||||
|
||||
grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]
|
||||
|
||||
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])
|
||||
|
||||
# Iterate over the rows in the table
|
||||
start_row_span = 0
|
||||
row_idx = -1
|
||||
for row in element("tr"):
|
||||
if not isinstance(row, Tag):
|
||||
continue
|
||||
|
||||
# For each row, find all the column cells (both <td> and <th>)
|
||||
cells = row(["td", "th"])
|
||||
|
||||
# Check if cell is in a column header or row header
|
||||
col_header = True
|
||||
row_header = True
|
||||
for html_cell in cells:
|
||||
if isinstance(html_cell, Tag):
|
||||
_, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
|
||||
if html_cell.name == "td":
|
||||
col_header = False
|
||||
row_header = False
|
||||
elif row_span == 1:
|
||||
row_header = False
|
||||
if not row_header:
|
||||
row_idx += 1
|
||||
start_row_span = 0
|
||||
else:
|
||||
start_row_span += 1
|
||||
|
||||
# Extract the text content of each cell
|
||||
col_idx = 0
|
||||
for html_cell in cells:
|
||||
if not isinstance(html_cell, Tag):
|
||||
continue
|
||||
|
||||
# extract inline formulas
|
||||
for formula in html_cell("inline-formula"):
|
||||
math_parts = formula.text.split("$$")
|
||||
if len(math_parts) == 3:
|
||||
math_formula = f"$${math_parts[1]}$$"
|
||||
formula.replace_with(NavigableString(math_formula))
|
||||
|
||||
# TODO: extract content correctly from table-cells with lists
|
||||
text = HTMLDocumentBackend.get_text(html_cell).strip()
|
||||
col_span, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
|
||||
if row_header:
|
||||
row_span -= 1
|
||||
while (
|
||||
col_idx < num_cols
|
||||
and grid[row_idx + start_row_span][col_idx] is not None
|
||||
):
|
||||
col_idx += 1
|
||||
for r in range(start_row_span, start_row_span + row_span):
|
||||
for c in range(col_span):
|
||||
if row_idx + r < num_rows and col_idx + c < num_cols:
|
||||
grid[row_idx + r][col_idx + c] = text
|
||||
|
||||
table_cell = TableCell(
|
||||
text=text,
|
||||
row_span=row_span,
|
||||
col_span=col_span,
|
||||
start_row_offset_idx=start_row_span + row_idx,
|
||||
end_row_offset_idx=start_row_span + row_idx + row_span,
|
||||
start_col_offset_idx=col_idx,
|
||||
end_col_offset_idx=col_idx + col_span,
|
||||
column_header=col_header,
|
||||
row_header=((not col_header) and html_cell.name == "th"),
|
||||
)
|
||||
data.table_cells.append(table_cell)
|
||||
|
||||
return data
|
||||
|
||||
def _add_table(
|
||||
self, doc: DoclingDocument, parent: NodeItem, table_xml_component: Table
|
||||
) -> None:
|
||||
@@ -543,8 +649,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
||||
if not isinstance(table_tag, Tag):
|
||||
return
|
||||
|
||||
data = HTMLDocumentBackend.parse_table_data(table_tag)
|
||||
|
||||
data = JatsDocumentBackend.parse_table_data(table_tag)
|
||||
# TODO: format label vs caption once styling is supported
|
||||
label = table_xml_component["label"]
|
||||
caption = table_xml_component["caption"]
|
||||
@@ -554,7 +659,6 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
||||
if table_text
|
||||
else None
|
||||
)
|
||||
|
||||
if data is not None:
|
||||
doc.add_table(data=data, parent=parent, caption=table_caption)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user