docling/examples/rag_llamaindex.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# RAG with Docling and 🦙 LlamaIndex"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"# requirements for this example:\n",
"%pip install -qq docling docling-core python-dotenv llama-index-embeddings-huggingface llama-index-llms-huggingface-api llama-index-vector-stores-milvus"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"from tempfile import TemporaryDirectory\n",
"\n",
"from dotenv import load_dotenv\n",
"from pydantic import TypeAdapter\n",
"from rich.pretty import pprint\n",
"\n",
"load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"warnings.filterwarnings(action=\"ignore\", category=UserWarning, module=\"pydantic|torch\")\n",
"warnings.filterwarnings(action=\"ignore\", category=FutureWarning, module=\"easyocr\")\n",
"# https://github.com/huggingface/transformers/issues/5486:\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helpers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below we define:\n",
"\n",
"- `DoclingPDFReader` which will be used to create LlamaIndex documents,\n",
"- `DoclingNodeParser`, which can be used to create LlamaIndex nodes out of JSON-based documents, and\n",
"- a helper function for QA printing"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from enum import Enum\n",
"from pathlib import Path\n",
"from typing import Any, Iterable\n",
"\n",
"from llama_index.core.readers.base import BasePydanticReader\n",
"from llama_index.core.schema import Document as LIDocument\n",
"\n",
"from docling.document_converter import DocumentConverter\n",
"\n",
"_KEY_DL_DOC_HASH = \"dl_doc_hash\"\n",
"_KEY_ORIGIN = \"origin\"\n",
"\n",
"\n",
"class DoclingPDFReader(BasePydanticReader):\n",
" class ParseType(str, Enum):\n",
" MARKDOWN = \"markdown\"\n",
" JSON = \"json\"\n",
"\n",
" parse_type: ParseType = ParseType.MARKDOWN\n",
" include_origin: bool = False\n",
"\n",
" def lazy_load_data(\n",
" self,\n",
" file_path: str | Path | Iterable[str] | Iterable[Path],\n",
" *args: Any,\n",
" **load_kwargs: Any,\n",
" ) -> Iterable[LIDocument]:\n",
" file_paths = (\n",
" file_path\n",
" if isinstance(file_path, Iterable) and not isinstance(file_path, str)\n",
" else [file_path]\n",
" )\n",
" converter = DocumentConverter()\n",
" for source in file_paths:\n",
" dl_doc = converter.convert_single(source).output\n",
" match self.parse_type:\n",
" case self.ParseType.MARKDOWN:\n",
" text = dl_doc.export_to_markdown()\n",
" case self.ParseType.JSON:\n",
" text = dl_doc.model_dump_json()\n",
" case _:\n",
" raise RuntimeError(\n",
" f\"Unexpected export type encountered: {self.export_type}\"\n",
" )\n",
" origin = str(source) if isinstance(source, Path) else source\n",
" li_doc = LIDocument(text=text)\n",
" li_doc.metadata = {\n",
" _KEY_DL_DOC_HASH: dl_doc.file_info.document_hash,\n",
" }\n",
" if self.include_origin:\n",
" li_doc.metadata[_KEY_ORIGIN] = origin\n",
" yield li_doc"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, Iterable, Sequence\n",
"\n",
"from docling_core.transforms.chunker import BaseChunker, HierarchicalChunker\n",
"from docling_core.types import Document as DLDocument\n",
"from llama_index.core import Document as LIDocument\n",
"from llama_index.core.node_parser.interface import NodeParser\n",
"from llama_index.core.node_parser.node_utils import IdFuncCallable, default_id_func\n",
"from llama_index.core.schema import (\n",
" BaseNode,\n",
" NodeRelationship,\n",
" RelatedNodeType,\n",
" TextNode,\n",
")\n",
"from llama_index.core.utils import get_tqdm_iterable\n",
"\n",
"\n",
"class DoclingNodeParser(NodeParser):\n",
" chunker: BaseChunker = HierarchicalChunker(heading_as_metadata=True)\n",
"\n",
" def _parse_nodes(\n",
" self,\n",
" nodes: Sequence[BaseNode],\n",
" show_progress: bool = False,\n",
" **kwargs: Any,\n",
" ) -> list[BaseNode]:\n",
" id_func: IdFuncCallable = self.id_func or default_id_func\n",
" nodes_with_progress: Iterable[BaseNode] = get_tqdm_iterable(\n",
" items=nodes, show_progress=show_progress, desc=\"Parsing nodes\"\n",
" )\n",
" all_nodes: list[BaseNode] = []\n",
" for input_node in nodes_with_progress:\n",
" li_doc = LIDocument.model_validate(input_node)\n",
" dl_doc: DLDocument = DLDocument.model_validate_json(li_doc.get_content())\n",
" chunk_iter = self.chunker.chunk(dl_doc=dl_doc)\n",
" for i, chunk in enumerate(chunk_iter):\n",
" rels: dict[NodeRelationship, RelatedNodeType] = {\n",
" NodeRelationship.SOURCE: li_doc.as_related_node_info(),\n",
" }\n",
" metadata = chunk.model_dump(\n",
" exclude=\"text\",\n",
" exclude_none=True,\n",
" )\n",
" # by default we exclude all meta keys from embedding/LLM — unless allowed\n",
" excl_meta_keys = [k for k in metadata if k not in {\"heading\"}]\n",
" if self.include_metadata:\n",
" excl_meta_keys = [k for k in li_doc.metadata] + excl_meta_keys\n",
" node = TextNode(\n",
" id_=id_func(i=i, doc=li_doc),\n",
" text=chunk.text,\n",
" excluded_embed_metadata_keys=excl_meta_keys,\n",
" excluded_llm_metadata_keys=excl_meta_keys,\n",
" relationships=rels,\n",
" )\n",
" node.metadata = metadata\n",
" all_nodes.append(node)\n",
" return all_nodes"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"from llama_index.core.base.response.schema import RESPONSE_TYPE\n",
"\n",
"\n",
"def print_qa(query: str, query_res: RESPONSE_TYPE):\n",
" def clip(inp, max_len=100):\n",
" if isinstance(inp, str):\n",
" return f\"{inp[:max_len]}{'...' if len(inp) > max_len else ''}\"\n",
" else:\n",
" return inp\n",
"\n",
" print(\n",
" f\"Question:\\n{query}\\n\\nAnswer:\\n{json.dumps(clip(query_res.response.strip()))}\"\n",
" )\n",
" for i, res in enumerate(query_res.source_nodes):\n",
" print()\n",
" print(f\"Source {i+1}:\")\n",
" print(f\" text: {json.dumps(clip(res.text.strip()))}\")\n",
" for key in res.metadata:\n",
" print(f\" {key}: {clip(res.metadata.get(key))}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reader and node parser"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using native Docling format (as JSON)**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To leverage Docling's rich document structure format, we can namely export to JSON and use the `DoclingNodeParser` accordingly:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"reader = DoclingPDFReader(parse_type=DoclingPDFReader.ParseType.JSON)\n",
"node_parser = DoclingNodeParser()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Using Markdown**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, to just use the flat Markdown export instead of the native document format, one can uncomment and use the following:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# from llama_index.core.node_parser import MarkdownNodeParser\n",
"\n",
"# reader = DoclingPDFReader(parse_type=DoclingPDFReader.ParseType.MARKDOWN)\n",
"# node_parser = MarkdownNodeParser()"
]
},
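{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, one can inspect the chunks the node parser produces before indexing. The commented sketch below assumes the JSON-based `reader` from above and uses the standard `NodeParser.get_nodes_from_documents()` entry point:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# optional sanity check (sketch, assumes the JSON-based reader):\n",
"# docs = reader.load_data(file_path=\"https://arxiv.org/pdf/2206.01062\")\n",
"# nodes = node_parser.get_nodes_from_documents(documents=docs)\n",
"# pprint(nodes[0], max_string=60)"
]
},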
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Transformations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our transformations currently include the `node_parser`:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"transformations = [node_parser]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One can include add more transformations, e.g. further chunking based on text size / overlap, as shown below:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# from llama_index.core.node_parser import TokenTextSplitter\n",
"\n",
"# splitter = TokenTextSplitter(\n",
"# chunk_size=1024,\n",
"# chunk_overlap=0,\n",
"# )\n",
"# transformations.append(splitter)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Embed model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.embeddings.huggingface import HuggingFaceEmbedding\n",
"\n",
"embed_model = HuggingFaceEmbedding(model_name=\"intfloat/multilingual-e5-small\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vector store"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"INGEST = True # whether to ingest from scratch or reuse an existing vector store"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.vector_stores.milvus import MilvusVectorStore\n",
"\n",
"MILVUS_URI = os.environ.get(\n",
" \"MILVUS_URI\", f\"{(tmp_dir := TemporaryDirectory()).name}/milvus_demo.db\"\n",
")\n",
"\n",
"vector_store = MilvusVectorStore(\n",
" uri=MILVUS_URI,\n",
" collection_name=\"docling_li_demo\",\n",
" dim=len(embed_model.get_text_embedding(\"hi\")),\n",
" overwrite=INGEST,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core import StorageContext, VectorStoreIndex\n",
"\n",
"if INGEST:\n",
" # in this case we ingest the data into the vector store\n",
" docs = reader.load_data(\n",
" file_path=\"https://arxiv.org/pdf/2206.01062\", # DocLayNet paper\n",
" )\n",
" storage_context = StorageContext.from_defaults(vector_store=vector_store)\n",
" index = VectorStoreIndex.from_documents(\n",
" documents=docs,\n",
" embed_model=embed_model,\n",
" storage_context=storage_context,\n",
" transformations=transformations,\n",
" )\n",
"else:\n",
" # in this case we just load the vector store index\n",
" index = VectorStoreIndex.from_vector_store(\n",
" vector_store=vector_store,\n",
" embed_model=embed_model,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### LLM"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI\n",
"\n",
"HF_API_KEY = os.environ.get(\"HF_API_KEY\")\n",
"\n",
"llm = HuggingFaceInferenceAPI(\n",
" token=HF_API_KEY,\n",
" model_name=\"mistralai/Mistral-7B-Instruct-v0.3\",\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RAG"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Question:\n",
"How many pages were annotated by humans?\n",
"\n",
"Answer:\n",
"\"80863 pages were annotated by humans.\"\n",
"\n",
"Source 1:\n",
" text: \"DocLayNet contains 80863 PDF pages. Among these, 7059 carry two instances of human annotations, and ...\"\n",
" dl_doc_hash: 5dfbd8c115a15fd3396b68409124cfee29fc8efac7b5c846634ff924e635e0dc\n",
" path: $.main-text[37]\n",
" page: 2\n",
" bbox: [317.2852478027344, 116.46983337402344, 559.7131958007812, 201.73675537109375]\n",
" heading: 3 THE DOCLAYNET DATASET\n",
"\n",
"Source 2:\n",
" text: \"In this paper, we present the DocLayNet dataset. It provides pageby-page layout annotation ground-tr...\"\n",
" dl_doc_hash: 5dfbd8c115a15fd3396b68409124cfee29fc8efac7b5c846634ff924e635e0dc\n",
" path: $.main-text[23]\n",
" page: 2\n",
" bbox: [53.50020980834961, 212.36782836914062, 295.56396484375, 286.4964599609375]\n",
" heading: 1 INTRODUCTION\n"
]
}
],
"source": [
"query_engine = index.as_query_engine(llm=llm)\n",
"QUERY = \"How many pages were annotated by humans?\"\n",
"query_res = query_engine.query(QUERY)\n",
"print_qa(query=QUERY, query_res=query_res)"
]
},
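{
"cell_type": "markdown",
"metadata": {},
"source": [
"Retrieval can be tuned as well; for instance, the commented sketch below widens the number of retrieved sources via `similarity_top_k`, which `as_query_engine()` forwards to the underlying retriever:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# sketch: retrieve more source nodes per query (default is 2)\n",
"# query_engine = index.as_query_engine(llm=llm, similarity_top_k=3)\n",
"# query_res = query_engine.query(QUERY)\n",
"# print_qa(query=QUERY, query_res=query_res)"
]
},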
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}