From 6f30048f3a494e1d6f3ec14f331a22d839b29bc0 Mon Sep 17 00:00:00 2001 From: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Date: Tue, 12 Nov 2024 14:39:22 +0100 Subject: [PATCH] fix token counting bug, minor revamping Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> --- .../advanced_chunking_with_merging.ipynb | 183 +++++++++--------- 1 file changed, 91 insertions(+), 92 deletions(-) diff --git a/docs/examples/advanced_chunking_with_merging.ipynb b/docs/examples/advanced_chunking_with_merging.ipynb index 5c4d0446..c1975bec 100644 --- a/docs/examples/advanced_chunking_with_merging.ipynb +++ b/docs/examples/advanced_chunking_with_merging.ipynb @@ -40,7 +40,7 @@ "from dataclasses import dataclass\n", "from pathlib import Path\n", "from tempfile import mkdtemp\n", - "from typing import Any, Iterator, Optional, Union\n", + "from typing import Iterator, Optional, Union\n", "\n", "import lancedb\n", "import semchunk\n", @@ -54,7 +54,7 @@ "from docling_core.types import DoclingDocument\n", "from pydantic import ConfigDict, PositiveInt\n", "from sentence_transformers import SentenceTransformer\n", - "from transformers import AutoTokenizer\n", + "from transformers import AutoTokenizer, PreTrainedTokenizerBase\n", "\n", "from docling.document_converter import DocumentConverter" ] @@ -65,11 +65,12 @@ "metadata": {}, "outputs": [], "source": [ - "DOC_SOURCE = \"http://bill.murdocks.org/iccbr2011murdock_web.pdf\"\n", "EMBED_MODEL_ID = \"sentence-transformers/all-MiniLM-L6-v2\"\n", - "TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)\n", - "EMBED_MODEL = SentenceTransformer(EMBED_MODEL_ID)\n", - "MAX_TOKENS = 64" + "MAX_TOKENS = 64\n", + "DOC_SOURCE = \"http://bill.murdocks.org/iccbr2011murdock_web.pdf\"\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)\n", + "embed_model = SentenceTransformer(EMBED_MODEL_ID)" ] }, { @@ -85,14 +86,14 @@ "metadata": {}, "outputs": [], "source": [ - "class HybridChunker(BaseChunker):\n", + "class HybridChunker(BaseChunker): # TODO: improve naming\n", "\n", " model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)\n", "\n", " inner_chunker: BaseChunker = HierarchicalChunker()\n", - " # TODO: improve typing for tokenizer below (ran into issues with `PreTrainedTokenizer`):\n", - " tokenizer: Any\n", + " tokenizer: PreTrainedTokenizerBase\n", " max_tokens: PositiveInt\n", + " delim: str = \"\\n\"\n", "\n", " def _count_tokens(self, text: Optional[Union[str, list[str]]]):\n", " if text is None:\n", @@ -104,9 +105,6 @@ " return total\n", " return len(self.tokenizer.tokenize(text, max_length=None))\n", "\n", - " def _make_splitter(self):\n", - " return semchunk.chunkerify(self.tokenizer, self.max_tokens)\n", - "\n", " @dataclass\n", " class _ChunkLengthInfo:\n", " total_len: int\n", @@ -116,6 +114,7 @@ " def _doc_chunk_length(self, doc_chunk: DocChunk):\n", " text_length = self._count_tokens(doc_chunk.text)\n", " # Note that count_tokens handles None and lists, making this code simpler:\n", + " # TODO check if delim properly considered\n", " headings_length = self._count_tokens(doc_chunk.meta.headings)\n", " captions_length = self._count_tokens(doc_chunk.meta.captions)\n", " total = text_length + headings_length + captions_length\n", @@ -137,14 +136,13 @@ " new_chunk = DocChunk(text=window_text, meta=meta)\n", " return new_chunk\n", "\n", - " @classmethod\n", - " def _merge_text(cls, t1, t2):\n", + " def _merge_text(self, t1, t2):\n", " if t1 == \"\":\n", " return t2\n", " elif t2 == \"\":\n", " return t1\n", " else:\n", - " return t1 + \"\\n\" + t2\n", + " return f\"{t1}{self.delim}{t2}\"\n", "\n", " def _split_by_doc_items(self, doc_chunk: DocChunk) -> list[DocChunk]:\n", " if doc_chunk.meta.doc_items == None or len(doc_chunk.meta.doc_items) <= 1:\n", @@ -206,20 +204,23 @@ " def _split_using_plain_text(\n", " self,\n", " doc_chunk: DocChunk,\n", - " plain_text_splitter,\n", - " ):\n", + " ) -> list[DocChunk]:\n", " lengths = self._doc_chunk_length(doc_chunk)\n", " if lengths.total_len <= self.max_tokens:\n", " return [doc_chunk]\n", " else:\n", + "\n", " # How much room is there for text after subtracting out the headers and captions:\n", " available_length = self.max_tokens - lengths.other_len\n", + " sem_chunker = semchunk.chunkerify(\n", + " self.tokenizer, chunk_size=available_length\n", + " )\n", " if available_length <= 0:\n", " raise ValueError(\n", - " \"Headers and captions for this chunk are longer than the total amount of size for the chunk. This is not supported now.\"\n", - " )\n", + " \"Headers and captions for this chunk are longer than the total amount of size for the chunk. This is not supported now.\"\n", + " ) # TODO switch to warning\n", " text = doc_chunk.text\n", - " segments = plain_text_splitter.chunk(text)\n", + " segments = sem_chunker.chunk(text)\n", " chunks = [DocChunk(text=s, meta=doc_chunk.meta) for s in segments]\n", " return chunks\n", "\n", @@ -283,36 +284,33 @@ " )\n", " return final_merged_chunks\n", "\n", - " @classmethod\n", - " def _make_text_for_embedding(cls, chunk: DocChunk):\n", + " def _make_text_for_embedding(self, chunk: DocChunk):\n", " output = \"\"\n", " if chunk.meta.headings != None:\n", " for h in chunk.meta.headings:\n", - " output += h + \"\\n\"\n", + " output += h + self.delim\n", " if chunk.meta.captions != None:\n", " for c in chunk.meta.captions:\n", - " output += c + \"\\n\"\n", + " output += c + self.delim\n", " output += chunk.text\n", " return output\n", "\n", - " def _adjust_chunks_for_fixed_size(self, chunks: list[DocChunk], splitter):\n", - " split_by_items = [x for c in chunks for x in self._split_by_doc_items(c)]\n", - " split_recursively = [\n", - " x for c in split_by_items for x in self._split_using_plain_text(c, splitter)\n", - " ]\n", - " merged = self._merge_chunks(split_recursively)\n", - " text_expanded = [\n", + " def _adjust_chunks_for_fixed_size(self, chunks: list[DocChunk]):\n", + " res = chunks\n", + " res = [x for c in res for x in self._split_by_doc_items(c)]\n", + " res = [x for c in res for x in self._split_using_plain_text(c)]\n", + " res = self._merge_chunks(res)\n", + " res = [\n", " DocChunk.model_validate(\n", " {**c.model_dump(), \"text\": self._make_text_for_embedding(c)}\n", " )\n", - " for c in merged\n", + " for c in res\n", " ]\n", - " return text_expanded\n", + " return res\n", "\n", " def chunk(self, dl_doc: DoclingDocument, **kwargs) -> Iterator[BaseChunk]:\n", " preliminary_chunks = self.inner_chunker.chunk(dl_doc=dl_doc, **kwargs)\n", - " splitter = self._make_splitter()\n", - " output_chunks = self._adjust_chunks_for_fixed_size(preliminary_chunks, splitter)\n", + " output_chunks = self._adjust_chunks_for_fixed_size(preliminary_chunks)\n", " return iter(output_chunks)" ] }, @@ -327,49 +325,50 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, + "outputs": [], + "source": [ + "conv_res = DocumentConverter().convert(source=DOC_SOURCE)\n", + "doc = conv_res.document" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using CPU. Note: This module is much faster with a GPU.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "J. William Murdock\n", - "murdockj@us.ibm.com IBM T.J. Watson Research Center P.O. Box 704 Yorktown Heights, NY 10598\n", - "39\n", - "J. William Murdock\n", - "Abstract. The Jeopardy! television quiz show asks natural-language questions and requires natural-language answers. One useful source of information for answering Jeopardy! questions is text from written sources such as encyclopedias or news articles. A text passage may partially or fully indicate that some candidate answer is the correct answer to the question. Recognizing\n", - "70\n", - "J. William Murdock\n", - "whether it does requires determining the extent to which what the passage is saying about the candidate answer is similar to what the question is saying about the desired answer. This paper describes how structure mapping [1] (an algorithm originally developed for analogical reasoning) is applied to determine similarity between content in questions and passages. That algorithm\n", - "70\n", - "J. William Murdock\n", - "is one of many used in the Watson question answering system [2]. It contributes a significant amount to Watson's effectiveness.\n", - "32\n", - "1 Introduction\n", - "Watson is a question answering system built on a set of technologies known as DeepQA [2]. Watson has been customized and configured to compete at Jeopardy!, an American television quiz show. Watson takes in a question and produces a ranked list of answers with confidence scores attached to each of these answers.\n", - "62\n" + "chunk.text='J. William Murdock\\nmurdockj@us.ibm.com IBM T.J. Watson Research Center P.O. Box 704 Yorktown Heights, NY 10598'\n", + "num tokens: 39\n", + "\n", + "chunk.text='J. William Murdock\\nAbstract. The Jeopardy! television quiz show asks natural-language questions and requires natural-language answers. One useful source of information for answering Jeopardy! questions is text from written sources such as encyclopedias or news articles. A text passage may partially or fully indicate that some candidate answer is the correct'\n", + "num tokens: 64\n", + "\n", + "chunk.text='J. William Murdock\\nanswer to the question. Recognizing whether it does requires determining the extent to which what the passage is saying about the candidate answer is similar to what the question is saying about the desired answer. This paper describes how structure mapping [1] (an algorithm originally developed for analogical reasoning) is applied'\n", + "num tokens: 64\n", + "\n", + "chunk.text=\"J. William Murdock\\nto determine similarity between content in questions and passages. That algorithm is one of many used in the Watson question answering system [2]. It contributes a significant amount to Watson's effectiveness.\"\n", + "num tokens: 44\n", + "\n", + "chunk.text='1 Introduction\\nWatson is a question answering system built on a set of technologies known as DeepQA [2]. Watson has been customized and configured to compete at Jeopardy!, an American television quiz show. Watson takes in a question and produces a ranked list of answers with confidence scores attached to each of these answers.'\n", + "num tokens: 62\n", + "\n" ] } ], "source": [ - "conv_res = DocumentConverter().convert(source=DOC_SOURCE)\n", - "doc = conv_res.document\n", - "\n", "chunker = HybridChunker(\n", - " tokenizer=TOKENIZER,\n", + " tokenizer=tokenizer,\n", " max_tokens=MAX_TOKENS,\n", ")\n", "chunks = list(chunker.chunk(dl_doc=doc))\n", "\n", "for chunk in chunks[:5]:\n", - " print(chunk.text)\n", - " print(chunker._count_tokens(chunk.text))" + " print(f\"{chunk.text=}\")\n", + " print(f\"num tokens: {len(tokenizer.tokenize(chunk.text, max_length=None))}\")\n", + " print()" ] }, { @@ -381,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -423,14 +422,6 @@ " \n", "