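# Advanced Chunking

This example builds a token-aware chunker on top of Docling's `HierarchicalChunker`. Chunks are split and merged so that each one, including its headings and captions, fits a given token budget measured with the embedding model's tokenizer. The resulting chunks are then embedded and indexed in LanceDB for vector retrieval.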

In [1]:
# %pip install -qU docling docling-core sentence-transformers transformers semchunk lancedb pydantic

# FIXME: temporary install line. It installs docling-core from the `expand-chunking`
# git branch instead of a release; switch back to the line above once possible.
# NOTE(review): unlike the line above, this one does not install `docling` itself,
# which the Setup cell imports (`docling.document_converter`) — confirm it is
# already present in the environment.
%pip install -qU "docling-core @ git+https://github.com/DS4SD/docling-core.git@expand-chunking" sentence-transformers transformers semchunk lancedb pydantic

Note: you may need to restart the kernel to use updated packages.


## Setup

In [2]:
import warnings
from dataclasses import dataclass
from pathlib import Path
from tempfile import mkdtemp
from typing import Iterator, Optional, Self, Union

import lancedb
import semchunk
from docling_core.transforms.chunker import (
    BaseChunk,
    BaseChunker,
    DocChunk,
    DocMeta,
    HierarchicalChunker,
)
from docling_core.types import DoclingDocument
from pydantic import ConfigDict, PositiveInt, TypeAdapter, model_validator
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from docling.document_converter import DocumentConverter

In [3]:
# Embedding model; its tokenizer is also what the chunker counts tokens with.
EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
# Token budget per chunk (headings and captions included).
MAX_TOKENS = 64
# Sample document to convert, chunk and index.
DOC_SOURCE = "http://bill.murdocks.org/iccbr2011murdock_web.pdf"

embed_model = SentenceTransformer(EMBED_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)

## Chunker Definition

In [4]:
class DocChunker(BaseChunker):
    """Chunker that adapts the output of an inner chunker to a token budget.

    Preliminary chunks from ``inner_chunker`` are post-processed in three passes
    (see ``_adjust_chunks_for_fixed_size``):

    1. chunks exceeding ``max_tokens`` are split along their doc items;
    2. pieces still exceeding ``max_tokens`` are split on plain text via
       ``semchunk``;
    3. consecutive pieces sharing the same headings and captions are merged back
       together as long as the result still fits ``max_tokens``.

    Token counts include the chunk's headings and captions. When lengths of
    separate items/chunks are summed, the tokens that ``delim`` may add in the
    joined text are not counted.
    """

    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

    # Tokenizer used both to count tokens and to drive semchunk's splitting.
    tokenizer: PreTrainedTokenizerBase

    # Chunker producing the preliminary chunks that get split/merged here.
    inner_chunker: BaseChunker = HierarchicalChunker()
    # Token budget per chunk; when None it is resolved in `patch_max_tokens`
    # from `tokenizer.model_max_length`.
    max_tokens: Optional[int] = None
    # Delimiter used when joining the texts of consecutive items/chunks.
    delim: str = "\n"

    @model_validator(mode="after")
    def patch_max_tokens(self) -> Self:
        """Resolve ``max_tokens`` from the tokenizer when it was not provided.

        Raises:
            pydantic validation error: if ``tokenizer.model_max_length`` is not a
                positive integer.
        """
        if self.max_tokens is None:
            self.max_tokens = TypeAdapter(PositiveInt).validate_python(
                self.tokenizer.model_max_length
            )
        return self

    def _count_tokens(self, text: Optional[Union[str, list[str]]]) -> int:
        """Count tokens of ``text``.

        ``None`` counts as 0 and a list counts as the sum of its elements, so
        optional headings/captions can be passed directly.
        """
        if text is None:
            return 0
        if isinstance(text, list):
            return sum(self._count_tokens(t) for t in text)
        return len(self.tokenizer.tokenize(text, max_length=None))

    @dataclass
    class _ChunkLengthInfo:
        """Token lengths of a chunk, split into text vs. headings/captions."""

        total_len: int  # text_len + other_len
        text_len: int  # tokens of the chunk text
        other_len: int  # tokens of the headings and captions

    def _doc_chunk_length(self, doc_chunk: DocChunk):
        """Return the ``_ChunkLengthInfo`` (text vs. headings+captions) of a chunk."""
        text_length = self._count_tokens(doc_chunk.text)
        # `_count_tokens` handles None and lists, so optional metadata is fine.
        # TODO check if delim properly considered
        headings_length = self._count_tokens(doc_chunk.meta.headings)
        captions_length = self._count_tokens(doc_chunk.meta.captions)
        total = text_length + headings_length + captions_length
        return self._ChunkLengthInfo(
            total_len=total,
            text_len=text_length,
            other_len=total - text_length,
        )

    def _make_chunk_from_doc_items(
        self, doc_chunk: DocChunk, window_text: str, window_start: int, window_end: int
    ):
        """Build a chunk from items ``window_start..window_end`` (inclusive) of
        ``doc_chunk``, keeping its headings and captions."""
        meta = DocMeta(
            doc_items=doc_chunk.meta.doc_items[window_start : window_end + 1],
            headings=doc_chunk.meta.headings,
            captions=doc_chunk.meta.captions,
        )
        return DocChunk.from_data(text=window_text, meta=meta, delim=self.delim)

    def _merge_text(self, t1, t2):
        """Join two texts with ``delim``, skipping the delimiter if either is empty."""
        if t1 == "":
            return t2
        elif t2 == "":
            return t1
        else:
            return f"{t1}{self.delim}{t2}"

    def _split_by_doc_items(self, doc_chunk: DocChunk) -> list[DocChunk]:
        """Split ``doc_chunk`` into windows of consecutive doc items that fit.

        Items are greedily packed into windows whose text tokens, plus the
        chunk's heading/caption tokens, stay within ``max_tokens``. A single item
        that does not fit on its own becomes a chunk by itself; it is left for
        ``_split_using_plain_text`` to break up further.
        """
        doc_items = doc_chunk.meta.doc_items
        if doc_items is None or len(doc_items) <= 1:
            return [doc_chunk]
        length = self._doc_chunk_length(doc_chunk)
        if length.total_len <= self.max_tokens:
            return [doc_chunk]

        chunks = []
        other_length = length.other_len
        num_items = len(doc_items)
        window_start = 0
        window_end = 0  # index of the item currently being considered
        window_text = ""
        window_text_length = 0
        while window_end < num_items:
            text = doc_items[window_end].text
            text_length = self._count_tokens(text)
            # `<=` keeps the budget test consistent with the rest of the class.
            fits = text_length + window_text_length + other_length <= self.max_tokens
            is_last = window_end == num_items - 1
            if fits and not is_last:
                # Room left and more items to come: grow the window.
                window_text = self._merge_text(window_text, text)
                window_text_length += text_length
                window_end += 1
            elif fits:
                # The last item fits: emit the final window including it.
                window_text = self._merge_text(window_text, text)
                chunks.append(
                    self._make_chunk_from_doc_items(
                        doc_chunk, window_text, window_start, window_end
                    )
                )
                window_end = num_items
            elif window_start == window_end:
                # A single item too long on its own: emit it as is; the plain
                # text splitter will break it up later.
                window_text = self._merge_text(window_text, text)
                chunks.append(
                    self._make_chunk_from_doc_items(
                        doc_chunk, window_text, window_start, window_end
                    )
                )
                window_start = window_end + 1
                window_end = window_start
                window_text = ""
                window_text_length = 0
            else:
                # The current item overflows a non-empty window (whose items did
                # fit): emit the window without it and restart at this item.
                chunks.append(
                    self._make_chunk_from_doc_items(
                        doc_chunk, window_text, window_start, window_end - 1
                    )
                )
                window_start = window_end
                window_text = ""
                window_text_length = 0
        return chunks

    def _split_using_plain_text(
        self,
        doc_chunk: DocChunk,
    ) -> list[DocChunk]:
        """Split the text of ``doc_chunk`` with semchunk so each piece fits.

        The text budget is ``max_tokens`` minus the heading/caption tokens. If
        headings and captions alone exhaust the budget, a warning is issued and
        the chunk is dropped (an empty list is returned).
        """
        lengths = self._doc_chunk_length(doc_chunk)
        if lengths.total_len <= self.max_tokens:
            return [
                DocChunk.from_data(
                    delim=self.delim,
                    **doc_chunk.export_json_dict(),
                )
            ]

        # How much room is left for text after subtracting headings and captions.
        available_length = self.max_tokens - lengths.other_len
        # Check before building the semchunk chunker: a non-positive chunk size
        # makes no sense.
        if available_length <= 0:
            warnings.warn(
                "Headers and captions for this chunk are longer than the total amount of size for the chunk. Chunk will be ignored."
            )
            return []
        sem_chunker = semchunk.chunkerify(self.tokenizer, chunk_size=available_length)
        segments = sem_chunker.chunk(doc_chunk.text)
        return [
            DocChunk.from_data(text=s, meta=doc_chunk.meta, delim=self.delim)
            for s in segments
        ]

    def _merge_window(self, window: list[DocChunk]) -> DocChunk:
        """Combine a non-empty run of chunks sharing headings/captions into one."""
        if len(window) == 1:
            # Just one chunk: use it as is.
            return window[0]
        first = window[0]
        merged_text = ""
        for c in window:
            merged_text = self._merge_text(merged_text, c.text)
        new_meta = DocMeta(
            doc_items=[item for c in window for item in (c.meta.doc_items or [])],
            # All chunks in the window share these, by construction.
            headings=first.meta.headings,
            captions=first.meta.captions,
        )
        return DocChunk.from_data(text=merged_text, meta=new_meta, delim=self.delim)

    def _merge_chunks_with_matching_metadata(
        self, chunks: list[DocChunk]
    ) -> list[DocChunk]:
        """Greedily merge consecutive chunks with identical headings and captions.

        A chunk joins the current window when it has the same headings and
        captions as the window and the window's tokens (all texts plus a single
        copy of the headings/captions) stay within ``max_tokens``. Otherwise the
        window is emitted and a new one starts at that chunk. The trailing window
        is always emitted too.
        """
        output_chunks: list[DocChunk] = []
        window: list[DocChunk] = []
        window_key = None
        window_length = 0
        for chunk in chunks:
            lengths = self._doc_chunk_length(chunk)
            key = (chunk.meta.headings, chunk.meta.captions)
            if (
                window
                and key == window_key
                and window_length + lengths.text_len <= self.max_tokens
            ):
                # Room left and same metadata: extend the window.
                window.append(chunk)
                window_length += lengths.text_len
            else:
                # No room or new metadata: close the window, start a new one here.
                if window:
                    output_chunks.append(self._merge_window(window))
                window = [chunk]
                window_key = key
                window_length = lengths.total_len
        if window:
            output_chunks.append(self._merge_window(window))
        return output_chunks

    def _merge_chunks(self, chunks: list[DocChunk]) -> list[DocChunk]:
        """Merge small chunks back together after splitting."""
        # Merge as many chunks as possible that share the same headings+captions.
        # TODO: a second pass merging chunks with *different* headings/captions
        # (run after this one so merges within a section are preferred) is not
        # implemented yet.
        return self._merge_chunks_with_matching_metadata(chunks)

    def _adjust_chunks_for_fixed_size(self, chunks: list[DocChunk]):
        """Run split-by-items, split-by-text, then merge over ``chunks``."""
        res = [x for c in chunks for x in self._split_by_doc_items(c)]
        res = [x for c in res for x in self._split_using_plain_text(c)]
        return self._merge_chunks(res)

    def chunk(self, dl_doc: DoclingDocument, **kwargs) -> Iterator[BaseChunk]:
        """Chunk ``dl_doc``, aiming for every chunk to fit into ``max_tokens``.

        Args:
            dl_doc: document to chunk.
            **kwargs: forwarded to ``inner_chunker.chunk``.

        Returns:
            Iterator over the resulting chunks.
        """
        preliminary_chunks = self.inner_chunker.chunk(dl_doc=dl_doc, **kwargs)
        output_chunks = self._adjust_chunks_for_fixed_size(preliminary_chunks)
        return iter(output_chunks)

## Usage

In [5]:
# Convert the source PDF into a DoclingDocument.
converter = DocumentConverter()
conv_res = converter.convert(source=DOC_SOURCE)
doc = conv_res.document

In [6]:
chunker = DocChunker(
    tokenizer=tokenizer,
    max_tokens=MAX_TOKENS,  # optional, derived from `tokenizer` if not provided
)
chunks = list(chunker.chunk(dl_doc=doc))

N_PREVIEW = 5  # number of chunks to preview below


def count_tokens(text: str) -> int:
    """Number of tokens in ``text`` according to the embedding model's tokenizer."""
    return len(tokenizer.tokenize(text, max_length=None))


# For each previewed chunk, show the raw text and the texts used for embedding
# and generation, each with its token count.
for chunk in chunks[:N_PREVIEW]:
    for label, txt in (
        ("chunk.text", chunk.text),
        ("chunk.get_text_for_embedding()", chunk.get_text_for_embedding()),
        ("chunk.get_text_for_generation()", chunk.get_text_for_generation()),
    ):
        print(f"{label} ({count_tokens(txt)} tokens):\n{repr(txt)}")
    print()

chunk.text (33 tokens):
'murdockj@us.ibm.com IBM T.J. Watson Research Center P.O. Box 704 Yorktown Heights, NY 10598'
chunk.get_text_for_embedding() (39 tokens):
'J. William Murdock\nmurdockj@us.ibm.com IBM T.J. Watson Research Center P.O. Box 704 Yorktown Heights, NY 10598'
chunk.get_text_for_generation() (39 tokens):
'J. William Murdock\nmurdockj@us.ibm.com IBM T.J. Watson Research Center P.O. Box 704 Yorktown Heights, NY 10598'

chunk.text (58 tokens):
'Abstract. The Jeopardy! television quiz show asks natural-language questions and requires natural-language answers. One useful source of information for answering Jeopardy! questions is text from written sources such as encyclopedias or news articles. A text passage may partially or fully indicate that some candidate answer is the correct'
chunk.get_text_for_embedding() (64 tokens):
'J. William Murdock\nAbstract. The Jeopardy! television quiz show asks natural-language questions and requires natural-language answers. One useful sourc

## Vector Retrieval

In [7]:
def make_lancedb_index(db_uri, index_name, chunks: list[DocChunk], embedding_model):
    """Embed ``chunks`` and store them in a LanceDB table.

    Args:
        db_uri: LanceDB database URI (here, a local directory path).
        index_name: name of the table to create.
        chunks: chunks to index; each is embedded from
            ``chunk.get_text_for_embedding()``.
        embedding_model: model exposing ``encode`` (e.g. a ``SentenceTransformer``).

    Returns:
        The LanceDB table, with one row per chunk holding ``vector``, ``text``,
        ``headings`` and ``captions``.

    Raises:
        ValueError: if ``chunks`` is empty.
    """
    if not chunks:
        raise ValueError("cannot build an index from an empty list of chunks")
    # One batched encode call instead of one call per chunk.
    embeddings = embedding_model.encode([c.get_text_for_embedding() for c in chunks])
    data = [
        {
            "vector": emb,
            "text": c.text,
            "headings": c.meta.headings,
            "captions": c.meta.captions,
        }
        for c, emb in zip(chunks, embeddings)
    ]
    db = lancedb.connect(db_uri)
    # exist_ok=True: open an existing table of the same name instead of raising
    # (see LanceDB `create_table` docs for how `data` is treated in that case).
    tbl = db.create_table(index_name, data=data, exist_ok=True)
    return tbl


# Keep the index in a fresh temporary directory (change as needed).
db_dir = Path(mkdtemp())
db_uri = str(db_dir.joinpath("docling.db"))
index = make_lancedb_index(db_uri, doc.name, chunks, embed_model)

# Embed a sample query and retrieve its 5 nearest chunks.
sample_query = "Making SME greedy and pragmatic"
sample_embedding = embed_model.encode(sample_query)
results = index.search(sample_embedding).limit(5)

results.to_pandas()

Unnamed: 0,vector,text,headings,captions,_distance
0,"[-0.025746439, 0.03888134, 0.0033668755, -0.03...","3. Forbus, K. and Oblinger, D. (1990). Making ...",[References],,0.332435
1,"[0.04400234, -0.034766007, -0.00025527124, 0.0...","4. McCord, M. C. (1990). Slot Grammar: A Syste...",[References],,1.525625
2,"[0.10043394, 0.00652478, 0.011601829, -0.06390...",passage using semantic and/or syntactic edges:...,[3 Syntactic-Semantic Graphs],,1.569923
3,"[0.025994677, 0.08402823, 0.03268827, -0.03727...","In using this algorithm, we have encountered a...",[4 Algorithm],,1.576838
4,"[0.050165094, 0.08015387, 0.035965856, 0.00846...",word order) are more aggressive in what they c...,[5 Evaluation and Conclusions],,1.580265
