Merge branch 'main' into docx-markdown-formatting

Signed-off-by: SimJeg <sjegou@nvidia.com>
SimJeg 2025-03-31 11:22:23 +02:00
commit fbfb37f363
395 changed files with 50662 additions and 39728 deletions

11
.actor/.dockerignore Normal file

@ -0,0 +1,11 @@
**/__pycache__
**/*.pyc
**/*.pyo
**/*.pyd
.git
.gitignore
.env
.venv
*.log
.pytest_cache
.coverage

69
.actor/CHANGELOG.md Normal file

@ -0,0 +1,69 @@
# Changelog
All notable changes to the Docling Actor will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [1.1.0] - 2025-03-09
### Changed
- Switched from full Docling CLI to docling-serve API
- Using the official quay.io/ds4sd/docling-serve-cpu Docker image
- Reduced Docker image size (from ~6GB to ~4GB)
- Implemented multi-stage Docker build to handle dependencies
- Improved Docker build process to ensure compatibility with docling-serve-cpu image
- Added new Python processor script for reliable API communication and content extraction
- Enhanced response handling with better content extraction logic
- Fixed ES modules compatibility issue with Apify CLI
- Added explicit tmpfs volume for temporary files
- Fixed environment variables format in actor.json
- Created optimized dependency installation approach
- Improved API compatibility with docling-serve
- Updated endpoint from custom `/convert` to standard `/v1alpha/convert/source`
- Revised JSON payload structure to match docling-serve API format
- Added proper output field parsing based on format
- Enhanced startup process with health checks
- Added configurable API host and port through environment variables
- Better content type handling for different output formats
- Updated error handling to align with API responses
### Fixed
- Fixed actor input file conflict in get_actor_input(): now checks for and removes an existing /tmp/actor-input/INPUT directory if found, ensuring valid JSON input parsing.
### Technical Details
- Actor Specification v1
- Using quay.io/ds4sd/docling-serve-cpu:latest base image
- Node.js 20.x for Apify CLI
- Eliminated Python dependencies
- Simplified Docker build process
## [1.0.0] - 2025-02-07
### Added
- Initial release of Docling Actor
- Support for multiple document formats (PDF, DOCX, images)
- OCR capabilities for scanned documents
- Multiple output formats (md, json, html, text, doctags)
- Comprehensive error handling and logging
- Dataset records with processing status
- Memory monitoring and resource optimization
- Security features including non-root user execution
### Technical Details
- Actor Specification v1
- Docling v2.17.0
- Python 3.11
- Node.js 20.x
- Comprehensive error codes:
  - 10: Invalid input
  - 11: URL inaccessible
  - 12: Docling processing failed
  - 13: Output file missing
  - 14: Storage operation failed
  - 15: OCR processing failed

87
.actor/Dockerfile Normal file

@ -0,0 +1,87 @@
# Build stage for installing dependencies
FROM node:20-slim AS builder
# Install necessary tools and prepare dependencies environment in one layer
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
&& rm -rf /var/lib/apt/lists/* \
&& mkdir -p /build/bin /build/lib/node_modules \
&& cp /usr/local/bin/node /build/bin/
# Set working directory
WORKDIR /build
# Create package.json and install Apify CLI in one layer
RUN echo '{"name":"docling-actor-dependencies","version":"1.0.0","description":"Dependencies for Docling Actor","private":true,"type":"module","engines":{"node":">=18"}}' > package.json \
&& npm install apify-cli@latest \
&& cp -r node_modules/* lib/node_modules/ \
&& printf '#!/bin/sh\n/tmp/docling-tools/bin/node /tmp/docling-tools/lib/node_modules/apify-cli/bin/run "$@"\n' > bin/actor \
&& chmod +x bin/actor \
# Clean up npm cache to reduce image size
&& npm cache clean --force
# Final stage with docling-serve-cpu
FROM quay.io/ds4sd/docling-serve-cpu:latest
LABEL maintainer="Vaclav Vancura <@vancura>" \
description="Apify Actor for document processing using Docling" \
version="1.1.0"
# Set only essential environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
DOCLING_SERVE_HOST=0.0.0.0 \
DOCLING_SERVE_PORT=5001
# Switch to root temporarily to set up directories and permissions
USER root
WORKDIR /app
# Install required tools and create directories in a single layer
RUN dnf install -y \
jq \
&& dnf clean all \
&& mkdir -p /build-files \
/tmp \
/tmp/actor-input \
/tmp/actor-output \
/tmp/actor-storage \
/tmp/apify_input \
/apify_input \
/opt/app-root/src/.EasyOCR/user_network \
/tmp/easyocr-models \
&& chown 1000:1000 /build-files \
&& chown -R 1000:1000 /opt/app-root/src/.EasyOCR \
&& chmod 1777 /tmp \
&& chmod 1777 /tmp/easyocr-models \
&& chmod 777 /tmp/actor-input /tmp/actor-output /tmp/actor-storage /tmp/apify_input /apify_input \
# Fix for uv_os_get_passwd error in Node.js
&& echo "docling:x:1000:1000:Docling User:/app:/bin/sh" >> /etc/passwd
# Set environment variable to tell EasyOCR to use a writable location for models
ENV EASYOCR_MODULE_PATH=/tmp/easyocr-models
# Copy only required files
COPY --chown=1000:1000 .actor/actor.sh .actor/actor.sh
COPY --chown=1000:1000 .actor/actor.json .actor/actor.json
COPY --chown=1000:1000 .actor/input_schema.json .actor/input_schema.json
COPY --chown=1000:1000 .actor/docling_processor.py .actor/docling_processor.py
RUN chmod +x .actor/actor.sh
# Copy the build files from builder
COPY --from=builder --chown=1000:1000 /build /build-files
# Switch to non-root user
USER 1000
# Declare /tmp as a volume for temporary files (VOLUME alone does not mount a tmpfs)
VOLUME ["/tmp"]
# Create additional volumes for OCR models persistence
VOLUME ["/tmp/easyocr-models"]
# Expose the docling-serve API port
EXPOSE 5001
# Run the actor script
ENTRYPOINT [".actor/actor.sh"]

314
.actor/README.md Normal file

@ -0,0 +1,314 @@
# Docling Actor on Apify
[![Docling Actor](https://apify.com/actor-badge?actor=vancura/docling?fpr=docling)](https://apify.com/vancura/docling)
This Actor (specification v1) wraps the [Docling project](https://ds4sd.github.io/docling/) to provide serverless document processing in the cloud. It can process complex documents (PDF, DOCX, images) and convert them into structured formats (Markdown, JSON, HTML, Text, or DocTags) with optional OCR support.
## What are Actors?
[Actors](https://docs.apify.com/platform/actors?fpr=docling) are serverless microservices running on the [Apify Platform](https://apify.com/?fpr=docling). They are based on the [Actor SDK](https://docs.apify.com/sdk/js?fpr=docling) and can be found in the [Apify Store](https://apify.com/store?fpr=docling). Learn more about Actors in the [Apify Whitepaper](https://whitepaper.actor?fpr=docling).
## Table of Contents
1. [Features](#features)
2. [Usage](#usage)
3. [Input Parameters](#input-parameters)
4. [Output](#output)
5. [Performance & Resources](#performance--resources)
6. [Troubleshooting](#troubleshooting)
7. [Local Development](#local-development)
8. [Architecture](#architecture)
9. [License](#license)
10. [Acknowledgments](#acknowledgments)
11. [Security Considerations](#security-considerations)
## Features
- Leverages the official docling-serve-cpu Docker image for efficient document processing
- Processes multiple document formats:
  - PDF documents (scanned or digital)
  - Microsoft Office files (DOCX, XLSX, PPTX)
  - Images (PNG, JPG, TIFF)
  - Other text-based formats
- Provides OCR capabilities for scanned documents
- Exports to multiple formats:
  - Markdown
  - JSON
  - HTML
  - Plain Text
  - DocTags (structured format)
- No local setup needed—just provide input via a simple JSON config
## Usage
### Using Apify Console
1. Go to the Apify Actor page.
2. Click "Run".
3. In the input form, fill in:
   - The document URLs (`http_sources`).
   - The processing options (`options`), for example the output formats in `to_formats` (`md`, `json`, `html`, `text`, or `doctags`) and the `do_ocr` boolean toggle.
4. The Actor will run and produce its outputs in the default key-value store under the key `OUTPUT`.
### Using Apify API
```bash
curl --request POST \
--url "https://api.apify.com/v2/acts/vancura~docling/runs" \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer YOUR_API_TOKEN' \
--data '{
"options": {
"to_formats": ["md", "json", "html", "text", "doctags"]
},
"http_sources": [
{"url": "https://vancura.dev/assets/actor-test/facial-hairstyles-and-filtering-facepiece-respirators.pdf"},
{"url": "https://arxiv.org/pdf/2408.09869"}
]
}'
```
### Using Apify CLI
```bash
apify call vancura/docling --input='{
"options": {
"to_formats": ["md", "json", "html", "text", "doctags"]
},
"http_sources": [
{"url": "https://vancura.dev/assets/actor-test/facial-hairstyles-and-filtering-facepiece-respirators.pdf"},
{"url": "https://arxiv.org/pdf/2408.09869"}
]
}'
```
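The same request can also be prepared from Python. The block below is a minimal sketch using only the standard library; it assumes the Apify API v2 "run Actor" route (`/v2/acts/{actorId}/runs`), and the helper name `build_run_request` is illustrative rather than part of any SDK. It only builds the request so the payload can be inspected before anything is sent.

```python
import json
import urllib.request

# Assumed endpoint: Apify API v2 "run Actor" route for the vancura/docling Actor.
RUN_URL = "https://api.apify.com/v2/acts/vancura~docling/runs"


def build_run_request(token: str, urls: list[str], formats: list[str]) -> urllib.request.Request:
    """Build (but do not send) the POST request that starts an Actor run."""
    payload = {
        "options": {"to_formats": formats},
        "http_sources": [{"url": url} for url in urls],
    }
    return urllib.request.Request(
        RUN_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {token}",
        },
        method="POST",
    )


request = build_run_request(
    "YOUR_API_TOKEN",
    ["https://arxiv.org/pdf/2408.09869"],
    ["md", "json"],
)
print(request.get_method(), request.full_url)
print(request.data.decode("utf-8"))
```

To actually start a run, pass the request to `urllib.request.urlopen(request)` with a real token in place of the placeholder.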
## Input Parameters
The Actor accepts a JSON schema matching the file `.actor/input_schema.json`. Below is a summary of the fields:
| Field          | Type   | Required | Default | Description |
|----------------|--------|----------|---------|-------------|
| `http_sources` | array  | Yes      | None    | List of `{"url": ...}` objects to process; see https://github.com/DS4SD/docling-serve?tab=readme-ov-file#url-endpoint |
| `options`      | object | Yes      | None    | Processing options such as `to_formats`; see https://github.com/DS4SD/docling-serve?tab=readme-ov-file#common-parameters |
### Example Input
```json
{
"options": {
"to_formats": ["md", "json", "html", "text", "doctags"]
},
"http_sources": [
{"url": "https://vancura.dev/assets/actor-test/facial-hairstyles-and-filtering-facepiece-respirators.pdf"},
{"url": "https://arxiv.org/pdf/2408.09869"}
]
}
```
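Before submitting an input, it can help to check it locally. The following is a rough sketch, not the Actor's own validation: the function name `validate_actor_input` is hypothetical, the required fields follow `.actor/input_schema.json`, and the list of formats is the one given in this README.

```python
from urllib.parse import urlparse

# Output formats listed in this README.
ALLOWED_FORMATS = ("md", "json", "html", "text", "doctags")


def validate_actor_input(data: dict) -> list[str]:
    """Return a list of problems found in an Actor input dict (empty when valid)."""
    problems: list[str] = []

    sources = data.get("http_sources")
    if not isinstance(sources, list) or not sources:
        problems.append("http_sources must be a non-empty list")
    else:
        for index, source in enumerate(sources):
            url = source.get("url") if isinstance(source, dict) else None
            parsed = urlparse(url) if isinstance(url, str) else None
            if parsed is None or parsed.scheme not in ("http", "https") or not parsed.netloc:
                problems.append(f"http_sources[{index}].url must be an http(s) URL")

    options = data.get("options")
    if not isinstance(options, dict):
        problems.append("options must be an object")
    else:
        formats = options.get("to_formats", [])
        if not isinstance(formats, list):
            problems.append("options.to_formats must be a list")
        else:
            for fmt in formats:
                if fmt not in ALLOWED_FORMATS:
                    problems.append(f"unsupported output format: {fmt!r}")

    return problems


example = {
    "options": {"to_formats": ["md", "json"]},
    "http_sources": [{"url": "https://arxiv.org/pdf/2408.09869"}],
}
print(validate_actor_input(example) or "input looks valid")
```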
## Output
The Actor provides three types of outputs:
1. **Processed Documents in a ZIP** - The run log prints the direct URL to your result, for example:
   ```text
   Your document is available at: https://api.apify.com/v2/key-value-stores/[YOUR_STORE_ID]/records/OUTPUT
   ```
2. **Processing Log** - Available in the key-value store as `DOCLING_LOG`
3. **Dataset Record** - Contains processing metadata with:
- Direct link to the processed output zip file
- Processing status
You can access the results in several ways:
1. **Direct URL** (shown in Actor run logs):
```text
https://api.apify.com/v2/key-value-stores/[STORE_ID]/records/OUTPUT
```
2. **Programmatically** via Apify CLI:
```bash
apify key-value-stores get-value OUTPUT
```
3. **Dataset** - Check the "Dataset" tab in the Actor run details to see processing metadata
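Since the `OUTPUT` record is a ZIP archive, a short script can build its URL and inspect the archive once downloaded. This is a sketch under assumptions: the helper names are hypothetical, and the file name used in the demo archive is made up, since the actual member names depend on the processed documents.

```python
import io
import zipfile


def record_url(store_id: str, key: str = "OUTPUT") -> str:
    """Build the key-value store record URL in the form shown above."""
    return f"https://api.apify.com/v2/key-value-stores/{store_id}/records/{key}"


def list_zip_members(zip_bytes: bytes) -> list[str]:
    """Return the file names contained in a ZIP archive held in memory."""
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as archive:
        return archive.namelist()


# Demo with an in-memory archive standing in for a downloaded OUTPUT record.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as demo:
    demo.writestr("document.md", "# Document Title\n")  # hypothetical member name

print(record_url("STORE_ID"))
print(list_zip_members(buffer.getvalue()))
```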
### Example Outputs
#### Markdown (md)
```markdown
# Document Title
## Section 1
Content of section 1...
## Section 2
Content of section 2...
```
#### JSON
```json
{
"title": "Document Title",
"sections": [
{
"level": 1,
"title": "Section 1",
"content": "Content of section 1..."
}
]
}
```
#### HTML
```html
<h1>Document Title</h1>
<h2>Section 1</h2>
<p>Content of section 1...</p>
```
### Processing Logs (`DOCLING_LOG`)
The Actor maintains detailed processing logs including:
- API request and response details
- Processing steps and timing
- Error messages and stack traces
- Input validation results
Access logs via:
```bash
apify key-value-stores get-record DOCLING_LOG
```
## Performance & Resources
- **Docker Image Size**: ~4GB
- **Memory Requirements**:
  - Minimum: 2 GB RAM
  - Recommended: 4 GB RAM for large or complex documents
- **Processing Time**:
  - Simple documents: 15-30 seconds
  - Complex PDFs with OCR: 1-3 minutes
  - Large documents (100+ pages): 3-10 minutes
## Troubleshooting
Common issues and solutions:
1. **Document URL Not Accessible**
   - Ensure the URL is publicly accessible
   - Check if the document requires authentication
   - Verify the URL leads directly to the document
2. **OCR Processing Fails**
   - Verify the document is not password-protected
   - Check if the image quality is sufficient
   - Try processing with OCR disabled
3. **API Response Issues**
   - Check the logs for detailed error messages
   - Ensure the document format is supported
   - Verify the URL is correctly formatted
4. **Output Format Issues**
   - Verify the output format is supported
   - Check if the document structure is compatible
   - Review the `DOCLING_LOG` for specific errors
### Error Handling
The Actor implements comprehensive error handling:
- Detailed error messages in `DOCLING_LOG`
- Proper exit codes for different failure scenarios
- Automatic cleanup on failure
- Dataset records with processing status
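When a run is driven from another program, the exit codes can be mapped to readable names. The sketch below assumes the two codes declared in `actor.sh` (`ERR_API_UNAVAILABLE=15` and `ERR_INVALID_INPUT=16`); the class and function names are hypothetical, and any other non-zero code is reported as unknown.

```python
from enum import IntEnum


class ActorExitCode(IntEnum):
    """Exit codes declared in actor.sh; names mirror the shell constants."""

    API_UNAVAILABLE = 15
    INVALID_INPUT = 16


def describe_exit_code(code: int) -> str:
    """Translate a process exit code into a short, readable label."""
    if code == 0:
        return "success"
    try:
        return ActorExitCode(code).name.lower()
    except ValueError:
        return f"unknown failure (exit code {code})"


for code in (0, 15, 16, 1):
    print(code, describe_exit_code(code))
```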
## Local Development
If you wish to develop or modify this Actor locally:
1. Clone the repository.
2. Ensure Docker is installed.
3. The Actor files are located in the `.actor` directory:
- `Dockerfile` - Defines the container environment
- `actor.json` - Actor configuration and metadata
- `actor.sh` - Main execution script that starts the docling-serve API and orchestrates document processing
- `input_schema.json` - Input parameter definitions
- `dataset_schema.json` - Dataset output format definition
- `CHANGELOG.md` - Change log documenting all notable changes
- `README.md` - This documentation
4. Run the Actor locally using:
```bash
apify run
```
### Actor Structure
```text
.actor/
├── Dockerfile # Container definition
├── actor.json # Actor metadata
├── actor.sh # Execution script (also starts docling-serve API)
├── input_schema.json # Input parameters
├── dataset_schema.json # Dataset output format definition
├── docling_processor.py # Python script for API communication
├── CHANGELOG.md # Version history and changes
└── README.md # This documentation
```
## Architecture
This Actor uses a lightweight architecture based on the official `quay.io/ds4sd/docling-serve-cpu` Docker image:
- **Base Image**: `quay.io/ds4sd/docling-serve-cpu:latest` (~4GB)
- **Multi-Stage Build**: Uses a multi-stage Docker build to include only necessary tools
- **API Communication**: Uses the RESTful API provided by docling-serve
- **Request Flow**:
1. The actor script starts the docling-serve API on port 5001
2. Performs health checks to ensure the API is running
3. Processes the input parameters
4. Creates a JSON payload for the docling-serve API with proper format:
```json
{
"options": {
"to_formats": ["md"],
"do_ocr": true
},
"http_sources": [{"url": "https://example.com/document.pdf"}]
}
```
5. Makes a POST request to the `/v1alpha/convert/source` endpoint
6. Processes the response and stores it in the key-value store
- **Dependencies**:
- Node.js and the Apify CLI, copied from the build stage
- `jq`, installed in the final image with `dnf`; `curl`, which the Actor script expects to be available in the base image
- **Security**: Runs as a non-root user for enhanced security
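The request flow above can be sketched in Python. This is an illustration rather than the Actor's implementation (which uses a shell script with `jq` and `curl`): the function names are hypothetical, the port check mirrors the socket probe in `actor.sh`, and the payload step mirrors how the script adds `return_as_file` to `options`.

```python
import socket
import time

# Endpoint used by the Actor script inside the container.
API_ENDPOINT = "http://localhost:5001/v1alpha/convert/source"


def build_convert_payload(actor_input: dict) -> dict:
    """Copy the Actor input and request a file response, as actor.sh does with jq."""
    payload = dict(actor_input)
    options = dict(payload.get("options") or {})
    options["return_as_file"] = True
    payload["options"] = options
    return payload


def wait_for_port(host: str, port: int, attempts: int = 5, delay: float = 1.0) -> bool:
    """Return True once a TCP connection to host:port succeeds, False after all attempts."""
    for _ in range(attempts):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            if sock.connect_ex((host, port)) == 0:
                return True
        time.sleep(delay)
    return False


payload = build_convert_payload(
    {
        "options": {"to_formats": ["md"], "do_ocr": True},
        "http_sources": [{"url": "https://example.com/document.pdf"}],
    }
)
print(payload["options"])
```

The original input dict is left unmodified, so the same input can be logged or reused after the payload is built.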
## License
This wrapper project is under the MIT License, matching the original Docling license. See [LICENSE](../LICENSE) for details.
## Acknowledgments
- [Docling](https://ds4sd.github.io/docling/) and [docling-serve-cpu](https://quay.io/repository/ds4sd/docling-serve-cpu) by IBM
- [Apify](https://apify.com/?fpr=docling) for the serverless actor environment
## Security Considerations
- Actor runs under a non-root user for enhanced security
- Input URLs are validated before processing
- Temporary files are securely managed and cleaned up
- Process isolation through Docker containerization
- Secure handling of processing artifacts

11
.actor/actor.json Normal file

@ -0,0 +1,11 @@
{
"actorSpecification": 1,
"name": "docling",
"version": "0.0",
"environmentVariables": {},
"dockerFile": "./Dockerfile",
"input": "./input_schema.json",
"scripts": {
"run": "./actor.sh"
}
}

419
.actor/actor.sh Executable file

@ -0,0 +1,419 @@
#!/bin/bash
export PATH=$PATH:/build-files/node_modules/.bin
# Function to upload content to the key-value store
upload_to_kvs() {
local content_file="$1"
local key_name="$2"
local content_type="$3"
local description="$4"
# Find the Apify CLI command
find_apify_cmd
local apify_cmd="$FOUND_APIFY_CMD"
if [ -n "$apify_cmd" ]; then
echo "Uploading $description to key-value store (key: $key_name)..."
# Create a temporary home directory with write permissions
setup_temp_environment
# Use the --no-update-notifier flag if available
if $apify_cmd --help | grep -q "\--no-update-notifier"; then
if $apify_cmd --no-update-notifier actor:set-value "$key_name" --contentType "$content_type" < "$content_file"; then
echo "Successfully uploaded $description to key-value store"
local url="https://api.apify.com/v2/key-value-stores/${APIFY_DEFAULT_KEY_VALUE_STORE_ID}/records/$key_name"
echo "$description available at: $url"
cleanup_temp_environment
return 0
fi
else
# Fall back to regular command if flag isn't available
if $apify_cmd actor:set-value "$key_name" --contentType "$content_type" < "$content_file"; then
echo "Successfully uploaded $description to key-value store"
local url="https://api.apify.com/v2/key-value-stores/${APIFY_DEFAULT_KEY_VALUE_STORE_ID}/records/$key_name"
echo "$description available at: $url"
cleanup_temp_environment
return 0
fi
fi
echo "ERROR: Failed to upload $description to key-value store"
cleanup_temp_environment
return 1
else
echo "ERROR: Apify CLI not found for $description upload"
return 1
fi
}
# Function to find Apify CLI command
find_apify_cmd() {
FOUND_APIFY_CMD=""
for cmd in "apify" "actor" "/usr/local/bin/apify" "/usr/bin/apify" "/opt/apify/cli/bin/apify"; do
if command -v "$cmd" &> /dev/null; then
FOUND_APIFY_CMD="$cmd"
break
fi
done
}
# Function to set up temporary environment for Apify CLI
setup_temp_environment() {
export TMPDIR="/tmp/apify-home-${RANDOM}"
mkdir -p "$TMPDIR"
export APIFY_DISABLE_VERSION_CHECK=1
export NODE_OPTIONS="--no-warnings"
export HOME="$TMPDIR" # Override home directory to writable location
}
# Function to clean up temporary environment
cleanup_temp_environment() {
rm -rf "$TMPDIR" 2>/dev/null || true
}
# Function to push data to Apify dataset
push_to_dataset() {
# Example usage: push_to_dataset "$RESULT_URL" "$OUTPUT_SIZE" "zip"
local result_url="$1"
local size="$2"
local format="$3"
# Find Apify CLI command
find_apify_cmd
local apify_cmd="$FOUND_APIFY_CMD"
if [ -n "$apify_cmd" ]; then
echo "Adding record to dataset..."
setup_temp_environment
# Use the --no-update-notifier flag if available
if $apify_cmd --help | grep -q "\--no-update-notifier"; then
if $apify_cmd --no-update-notifier actor:push-data "{\"output_file\": \"${result_url}\", \"format\": \"${format}\", \"size\": \"${size}\", \"status\": \"success\"}"; then
echo "Successfully added record to dataset"
else
echo "Warning: Failed to add record to dataset"
fi
else
# Fall back to regular command
if $apify_cmd actor:push-data "{\"output_file\": \"${result_url}\", \"format\": \"${format}\", \"size\": \"${size}\", \"status\": \"success\"}"; then
echo "Successfully added record to dataset"
else
echo "Warning: Failed to add record to dataset"
fi
fi
cleanup_temp_environment
fi
}
# --- Setup logging and error handling ---
LOG_FILE="/tmp/docling.log"
touch "$LOG_FILE" || {
echo "Fatal: Cannot create log file at $LOG_FILE"
exit 1
}
# Log to both console and file
exec 1> >(tee -a "$LOG_FILE")
exec 2> >(tee -a "$LOG_FILE" >&2)
# Exit codes
readonly ERR_API_UNAVAILABLE=15
readonly ERR_INVALID_INPUT=16
# --- Debug environment ---
echo "Date: $(date)"
echo "Python version: $(python --version 2>&1)"
echo "Docling-serve path: $(which docling-serve 2>/dev/null || echo 'Not found')"
echo "Working directory: $(pwd)"
# --- Get input ---
echo "Getting Apify Actor Input"
INPUT=$(apify actor get-input 2>/dev/null)
# --- Setup tools ---
echo "Setting up tools..."
TOOLS_DIR="/tmp/docling-tools"
mkdir -p "$TOOLS_DIR"
# Copy tools if available
if [ -d "/build-files" ]; then
echo "Copying tools from /build-files..."
cp -r /build-files/* "$TOOLS_DIR/"
export PATH="$TOOLS_DIR/bin:$PATH"
else
echo "Warning: No build files directory found. Some tools may be unavailable."
fi
# Copy Python processor script to tools directory
PYTHON_SCRIPT_PATH="$(dirname "$0")/docling_processor.py"
if [ -f "$PYTHON_SCRIPT_PATH" ]; then
echo "Copying Python processor script to tools directory..."
cp "$PYTHON_SCRIPT_PATH" "$TOOLS_DIR/"
chmod +x "$TOOLS_DIR/docling_processor.py"
else
echo "ERROR: Python processor script not found at $PYTHON_SCRIPT_PATH"
exit 1
fi
# Check OCR directories and ensure they're writable
echo "Checking OCR directory permissions..."
OCR_DIR="/opt/app-root/src/.EasyOCR"
if [ -d "$OCR_DIR" ]; then
# Test if we can write to the directory
if touch "$OCR_DIR/test_write" 2>/dev/null; then
echo "[✓] OCR directory is writable"
rm "$OCR_DIR/test_write"
else
echo "[✗] OCR directory is not writable, setting up alternative in /tmp"
# Create alternative in /tmp (which is writable)
mkdir -p "/tmp/.EasyOCR/user_network"
export EASYOCR_MODULE_PATH="/tmp/.EasyOCR"
fi
else
echo "OCR directory not found, creating in /tmp"
mkdir -p "/tmp/.EasyOCR/user_network"
export EASYOCR_MODULE_PATH="/tmp/.EasyOCR"
fi
# --- Starting the API ---
echo "Starting docling-serve API..."
# Create a dedicated working directory in /tmp (writable)
API_DIR="/tmp/docling-api"
mkdir -p "$API_DIR"
cd "$API_DIR"
echo "API working directory: $(pwd)"
# Find docling-serve executable
DOCLING_SERVE_PATH=$(which docling-serve)
echo "Docling-serve executable: $DOCLING_SERVE_PATH"
# Start the API with minimal parameters to avoid any issues
echo "Starting docling-serve API..."
"$DOCLING_SERVE_PATH" run --host 0.0.0.0 --port 5001 > "$API_DIR/docling-serve.log" 2>&1 &
API_PID=$!
echo "Started docling-serve API with PID: $API_PID"
# A more reliable wait for API startup
echo "Waiting for API to initialize..."
MAX_TRIES=30
tries=0
started=false
while [ $tries -lt $MAX_TRIES ]; do
tries=$((tries + 1))
# Check if process is still running
if ! ps -p $API_PID > /dev/null; then
echo "ERROR: docling-serve API process terminated unexpectedly after $tries seconds"
break
fi
# Check log for startup completion or errors
if grep -q "Application startup complete" "$API_DIR/docling-serve.log" 2>/dev/null; then
echo "[✓] API startup completed successfully after $tries seconds"
started=true
break
fi
if grep -q "Permission denied\|PermissionError" "$API_DIR/docling-serve.log" 2>/dev/null; then
echo "ERROR: Permission errors detected in API startup"
break
fi
# Sleep and check again
sleep 1
# Output a progress indicator every 5 seconds
if [ $((tries % 5)) -eq 0 ]; then
echo "Still waiting for API startup... ($tries/$MAX_TRIES seconds)"
fi
done
# Show log content regardless of outcome
echo "docling-serve log output so far:"
tail -n 20 "$API_DIR/docling-serve.log"
# Verify the API is running
if ! ps -p $API_PID > /dev/null; then
echo "ERROR: docling-serve API failed to start"
if [ -f "$API_DIR/docling-serve.log" ]; then
echo "Full log output:"
cat "$API_DIR/docling-serve.log"
fi
exit $ERR_API_UNAVAILABLE
fi
if [ "$started" != "true" ]; then
echo "WARNING: API process is running but startup completion was not detected"
echo "Will attempt to continue anyway..."
fi
# Try to verify API is responding at this point
echo "Verifying API responsiveness..."
(python -c "
import sys, time, socket
for i in range(5):
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.settimeout(1)
        result = s.connect_ex(('localhost', 5001))
        if result == 0:
            s.close()
            print('Port 5001 is open and accepting connections')
            sys.exit(0)
        s.close()
    except Exception as e:
        pass
    time.sleep(1)
print('Could not connect to API port after 5 attempts')
sys.exit(1)
" && echo "API verification succeeded") || echo "API verification failed, but continuing anyway"
# Define API endpoint
DOCLING_API_ENDPOINT="http://localhost:5001/v1alpha/convert/source"
# --- Processing document ---
echo "Starting document processing..."
echo "Reading input from Apify..."
echo "Input content:" >&2
echo "$INPUT" >&2 # Send the raw input to stderr for debugging
echo "$INPUT" # Send the clean JSON to stdout for processing
# Create the request JSON
REQUEST_JSON=$(echo "$INPUT" | jq '.options += {"return_as_file": true}')
echo "Creating request JSON:" >&2
echo "$REQUEST_JSON" >&2
echo "$REQUEST_JSON" > "$API_DIR/request.json"
# Send the conversion request using our Python script
#echo "Sending conversion request to docling-serve API..."
#python "$TOOLS_DIR/docling_processor.py" \
# --api-endpoint "$DOCLING_API_ENDPOINT" \
# --request-json "$API_DIR/request.json" \
# --output-dir "$API_DIR" \
# --output-format "$OUTPUT_FORMAT"
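echo "Curl the Docling API"
# --fail makes curl return a non-zero exit code on HTTP errors instead of saving the error body as output.zip
curl -s --fail -H "content-type: application/json" -X POST --data-binary "@$API_DIR/request.json" -o "$API_DIR/output.zip" "$DOCLING_API_ENDPOINT"
CURL_EXIT_CODE=$?
# --- Check for various potential output files ---
echo "Checking for output files..."
if [ "$CURL_EXIT_CODE" -eq 0 ] && [ -f "$API_DIR/output.zip" ]; then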
echo "Conversion completed successfully! Output file found."
# Get content from the converted file
OUTPUT_SIZE=$(wc -c < "$API_DIR/output.zip")
echo "Output file found with size: $OUTPUT_SIZE bytes"
# Calculate the access URL for result display
RESULT_URL="https://api.apify.com/v2/key-value-stores/${APIFY_DEFAULT_KEY_VALUE_STORE_ID}/records/OUTPUT"
echo "=============================="
echo "PROCESSING COMPLETE!"
echo "Output size: ${OUTPUT_SIZE} bytes"
echo "=============================="
# Set the output content type based on format
CONTENT_TYPE="application/zip"
# Upload the document content using our function
upload_to_kvs "$API_DIR/output.zip" "OUTPUT" "$CONTENT_TYPE" "Document content"
# Only proceed with dataset record if document upload succeeded
if [ $? -eq 0 ]; then
echo "Your document is available at: ${RESULT_URL}"
echo "=============================="
# Push data to dataset
push_to_dataset "$RESULT_URL" "$OUTPUT_SIZE" "zip"
fi
else
echo "ERROR: No converted output file found at $API_DIR/output.zip"
# Create error metadata (no single document URL is recorded, since the request may contain several sources)
ERROR_METADATA="{\"status\":\"error\",\"error\":\"No converted output file found\"}"
echo "$ERROR_METADATA" > "/tmp/actor-output/OUTPUT"
chmod 644 "/tmp/actor-output/OUTPUT"
echo "Error information has been saved to /tmp/actor-output/OUTPUT"
fi
# --- Verify output files for debugging ---
echo "=== Final Output Verification ==="
echo "Files in /tmp/actor-output:"
ls -la /tmp/actor-output/ 2>/dev/null || echo "Cannot list /tmp/actor-output/"
echo "All operations completed. The output should be available in the default key-value store."
echo "Content URL: ${RESULT_URL:-No URL available}"
# --- Cleanup function ---
cleanup() {
echo "Running cleanup..."
# Stop the API process
if [ -n "$API_PID" ]; then
echo "Stopping docling-serve API (PID: $API_PID)..."
kill $API_PID 2>/dev/null || true
fi
# Export log file to KVS if it exists
# DO THIS BEFORE REMOVING TOOLS DIRECTORY
if [ -f "$LOG_FILE" ]; then
if [ -s "$LOG_FILE" ]; then
echo "Log file is not empty, pushing to key-value store (key: DOCLING_LOG)..."
# Upload log using our function; the key matches the DOCLING_LOG name documented in the README
upload_to_kvs "$LOG_FILE" "DOCLING_LOG" "text/plain" "Log file"
else
echo "Warning: log file exists but is empty"
fi
else
echo "Warning: No log file found"
fi
# Clean up temporary files AFTER log is uploaded
echo "Cleaning up temporary files..."
if [ -d "$API_DIR" ]; then
echo "Removing API working directory: $API_DIR"
rm -rf "$API_DIR" 2>/dev/null || echo "Warning: Failed to remove $API_DIR"
fi
if [ -d "$TOOLS_DIR" ]; then
echo "Removing tools directory: $TOOLS_DIR"
rm -rf "$TOOLS_DIR" 2>/dev/null || echo "Warning: Failed to remove $TOOLS_DIR"
fi
# Keep log file until the very end
echo "Script execution completed at $(date)"
echo "Actor execution completed"
}
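# Register cleanup. Note: the trap is installed at the end of the script, so early exits above
# (for example a failed API startup) do not run cleanup.
trap cleanup EXIT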


@ -0,0 +1,31 @@
{
"title": "Docling Actor Dataset",
"description": "Records of document processing results from the Docling Actor",
"type": "object",
"schemaVersion": 1,
"properties": {
"url": {
"title": "Document URL",
"type": "string",
"description": "URL of the processed document"
},
"output_file": {
"title": "Result URL",
"type": "string",
"description": "Direct URL to the processed result in key-value store"
},
"status": {
"title": "Processing Status",
"type": "string",
"description": "Status of the document processing",
"enum": ["success", "error"]
},
"error": {
"title": "Error Details",
"type": "string",
"description": "Error message if processing failed",
"optional": true
}
},
"required": ["url", "output_file", "status"]
}

27
.actor/input_schema.json Normal file

@ -0,0 +1,27 @@
{
"title": "Docling Actor Input",
"description": "Options for processing documents with Docling via the docling-serve API.",
"type": "object",
"schemaVersion": 1,
"properties": {
"http_sources": {
"title": "Document URLs",
"type": "array",
"description": "URLs of documents to process. Supported formats: PDF, DOCX, PPTX, XLSX, HTML, MD, XML, images, and more.",
"editor": "json",
"prefill": [
{ "url": "https://vancura.dev/assets/actor-test/facial-hairstyles-and-filtering-facepiece-respirators.pdf" }
]
},
"options": {
"title": "Processing Options",
"type": "object",
"description": "Document processing configuration options",
"editor": "json",
"prefill": {
"to_formats": ["md"]
}
}
},
"required": ["options", "http_sources"]
}

2
.github/SECURITY.md vendored

@ -20,4 +20,4 @@ After the initial reply to your report, the security team will keep you informed
## Security Alerts
We will send announcements of security vulnerabilities and steps to remediate on the [Docling announcements](https://github.com/DS4SD/docling/discussions/categories/announcements).
We will send announcements of security vulnerabilities and steps to remediate on the [Docling announcements](https://github.com/docling-project/docling/discussions/categories/announcements).


@ -8,7 +8,7 @@ runs:
using: 'composite'
steps:
- name: Install poetry
run: pipx install poetry==1.8.3
run: pipx install poetry==1.8.5
shell: bash
- uses: actions/setup-python@v5
with:


@ -1,19 +1,28 @@
on:
workflow_call:
env:
HF_HUB_DOWNLOAD_TIMEOUT: "60"
HF_HUB_ETAG_TIMEOUT: "60"
jobs:
run-checks:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
steps:
- uses: actions/checkout@v4
- name: Install tesseract
run: sudo apt-get update && sudo apt-get install -y tesseract-ocr tesseract-ocr-eng tesseract-ocr-fra tesseract-ocr-deu tesseract-ocr-spa libleptonica-dev libtesseract-dev pkg-config
run: sudo apt-get update && sudo apt-get install -y tesseract-ocr tesseract-ocr-eng tesseract-ocr-fra tesseract-ocr-deu tesseract-ocr-spa tesseract-ocr-script-latn libleptonica-dev libtesseract-dev pkg-config
- name: Set TESSDATA_PREFIX
run: |
echo "TESSDATA_PREFIX=$(dpkg -L tesseract-ocr-eng | grep tessdata$)" >> "$GITHUB_ENV"
- name: Cache Hugging Face models
uses: actions/cache@v4
with:
path: ~/.cache/huggingface
key: huggingface-cache-py${{ matrix.python-version }}
- uses: ./.github/actions/setup-poetry
with:
python-version: ${{ matrix.python-version }}
@ -28,7 +37,7 @@ jobs:
run: |
for file in docs/examples/*.py; do
# Skip batch_convert.py
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal|export_multimodal|custom_convert|develop_picture_enrichment).py ]]; then
if [[ "$(basename "$file")" =~ ^(batch_convert|minimal_vlm_pipeline|minimal|export_multimodal|custom_convert|develop_picture_enrichment|rapidocr_with_custom_models|offline_convert|pictures_description|pictures_description_api).py ]]; then
echo "Skipping $file"
continue
fi


@ -10,7 +10,7 @@ on:
jobs:
build-docs:
if: ${{ github.event_name == 'push' || (github.event.pull_request.head.repo.full_name != 'DS4SD/docling' && github.event.pull_request.head.repo.full_name != 'ds4sd/docling') }}
if: ${{ github.event_name == 'push' || (github.event.pull_request.head.repo.full_name != 'docling-project/docling' && github.event.pull_request.head.repo.full_name != 'docling-project/docling') }}
uses: ./.github/workflows/docs.yml
with:
deploy: false


@ -15,5 +15,5 @@ env:
jobs:
code-checks:
if: ${{ github.event_name == 'push' || (github.event.pull_request.head.repo.full_name != 'DS4SD/docling' && github.event.pull_request.head.repo.full_name != 'ds4sd/docling') }}
if: ${{ github.event_name == 'push' || (github.event.pull_request.head.repo.full_name != 'docling-project/docling' && github.event.pull_request.head.repo.full_name != 'docling-project/docling') }}
uses: ./.github/workflows/checks.yml

View File

@ -17,4 +17,3 @@ jobs:
- name: Build and push docs
if: inputs.deploy
run: poetry run mkdocs gh-deploy --force

File diff suppressed because it is too large

View File

@ -1,129 +1,3 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement using
[deepsearch-core@zurich.ibm.com](mailto:deepsearch-core@zurich.ibm.com).
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
[https://www.contributor-covenant.org/version/2/0/code_of_conduct.html](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html).
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
Homepage: [https://www.contributor-covenant.org](https://www.contributor-covenant.org)
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq). Translations are available at
[https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations).
This project adheres to the [Docling - Code of Conduct and Covenant](https://github.com/docling-project/community/blob/main/CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code.

View File

@ -2,85 +2,7 @@
Our project welcomes external contributions. If you have an itch, please feel
free to scratch it.
To contribute code or documentation, please submit a [pull request](https://github.com/DS4SD/docling/pulls).
A good way to familiarize yourself with the codebase and contribution process is
to look for and tackle low-hanging fruit in the [issue tracker](https://github.com/DS4SD/docling/issues).
Before embarking on a more ambitious contribution, please quickly [get in touch](#communication) with us.
For general questions or support requests, please refer to the [discussion section](https://github.com/DS4SD/docling/discussions).
**Note: We appreciate your effort and want to avoid situations where a contribution
requires extensive rework (by you or by us), sits in the backlog for a long time, or
cannot be accepted at all!**
### Proposing New Features
If you would like to implement a new feature, please [raise an issue](https://github.com/DS4SD/docling/issues)
before sending a pull request so the feature can be discussed. This is to avoid
you spending valuable time working on a feature that the project developers
are not interested in accepting into the codebase.
### Fixing Bugs
If you would like to fix a bug, please [raise an issue](https://github.com/DS4SD/docling/issues) before sending a
pull request so it can be tracked.
### Merge Approval
The project maintainers use LGTM (Looks Good To Me) in comments on the code
review to indicate acceptance. A change requires LGTMs from two of the
maintainers of each component affected.
For a list of the maintainers, see the [MAINTAINERS.md](MAINTAINERS.md) page.
## Legal
Each source file must include a license header for the MIT
Software. Using the SPDX format is the simplest approach,
e.g.
```
/*
Copyright IBM Inc. All rights reserved.
SPDX-License-Identifier: MIT
*/
```
We have tried to make it as easy as possible to make contributions. This
applies to how we handle the legal aspects of contribution. We use the
same approach - the [Developer's Certificate of Origin 1.1 (DCO)](https://github.com/hyperledger/fabric/blob/master/docs/source/DCO1.1.txt) - that the Linux® Kernel [community](https://elinux.org/Developer_Certificate_Of_Origin)
uses to manage code contributions.
We simply ask that when submitting a patch for review, the developer
must include a sign-off statement in the commit message.
Here is an example Signed-off-by line, which indicates that the
submitter accepts the DCO:
```
Signed-off-by: John Doe <john.doe@example.com>
```
You can include this automatically when you commit a change to your
local git repository using the following command:
```
git commit -s
```
### New dependencies
This project strictly adheres to using dependencies that are compatible with the MIT license to ensure maximum flexibility and permissiveness in its usage and distribution. As a result, dependencies licensed under restrictive terms such as GPL, LGPL, AGPL, or similar are explicitly excluded. These licenses impose additional requirements and limitations that are incompatible with the MIT license's minimal restrictions, potentially affecting derivative works and redistribution. By maintaining this policy, the project ensures simplicity and freedom for both developers and users, avoiding conflicts with stricter copyleft provisions.
## Communication
Please feel free to connect with us using the [discussion section](https://github.com/DS4SD/docling/discussions).
For more details on the contributing guidelines head to the Docling Project [community repository](https://github.com/docling-project/community).
## Developing

View File

@ -4,7 +4,7 @@ ENV GIT_SSH_COMMAND="ssh -o StrictHostKeyChecking=no"
RUN apt-get update \
&& apt-get install -y libgl1 libglib2.0-0 curl wget git procps \
&& apt-get clean
&& rm -rf /var/lib/apt/lists/*
# This will install torch with *only* cpu support
# Remove the --extra-index-url part if you want to install all the gpu requirements
@ -16,8 +16,7 @@ ENV TORCH_HOME=/tmp/
COPY docs/examples/minimal.py /root/minimal.py
RUN python -c 'from deepsearch_glm.utils.load_pretrained_models import load_pretrained_nlp_models; load_pretrained_nlp_models(verbose=True);'
RUN python -c 'from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline; StandardPdfPipeline.download_models_hf(force=True);'
RUN docling-tools models download
# On container environments, always set a thread budget to avoid undesired thread congestion.
ENV OMP_NUM_THREADS=4
@ -25,3 +24,6 @@ ENV OMP_NUM_THREADS=4
# On container shell:
# > cd /root/
# > python minimal.py
# Running as `docker run -e DOCLING_ARTIFACTS_PATH=/root/.cache/docling/models` will use the
# model weights included in the container image.

View File

@ -2,9 +2,6 @@
- Christoph Auer - [@cau-git](https://github.com/cau-git)
- Michele Dolfi - [@dolfim-ibm](https://github.com/dolfim-ibm)
- Maxim Lysak - [@maxmnemonic](https://github.com/maxmnemonic)
- Nikos Livathinos - [@nikos-livathinos](https://github.com/nikos-livathinos)
- Ahmed Nassar - [@nassarofficial](https://github.com/nassarofficial)
- Panos Vagenas - [@vagenas](https://github.com/vagenas)
- Peter Staar - [@PeterStaar-IBM](https://github.com/PeterStaar-IBM)

View File

@ -1,6 +1,6 @@
<p align="center">
<a href="https://github.com/ds4sd/docling">
<img loading="lazy" alt="Docling" src="https://github.com/DS4SD/docling/raw/main/docs/assets/docling_processing.png" width="100%"/>
<a href="https://github.com/docling-project/docling">
<img loading="lazy" alt="Docling" src="https://github.com/docling-project/docling/raw/main/docs/assets/docling_processing.png" width="100%"/>
</a>
</p>
@ -11,7 +11,7 @@
</p>
[![arXiv](https://img.shields.io/badge/arXiv-2408.09869-b31b1b.svg)](https://arxiv.org/abs/2408.09869)
[![Docs](https://img.shields.io/badge/docs-live-brightgreen)](https://ds4sd.github.io/docling/)
[![Docs](https://img.shields.io/badge/docs-live-brightgreen)](https://docling-project.github.io/docling/)
[![PyPI version](https://img.shields.io/pypi/v/docling)](https://pypi.org/project/docling/)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/docling)](https://pypi.org/project/docling/)
[![Poetry](https://img.shields.io/endpoint?url=https://python-poetry.org/badge/v0.json)](https://python-poetry.org/)
@ -19,27 +19,30 @@
[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[![Pydantic v2](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/pydantic/pydantic/main/docs/badge/v2.json)](https://pydantic.dev)
[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
[![License MIT](https://img.shields.io/github/license/DS4SD/docling)](https://opensource.org/licenses/MIT)
[![License MIT](https://img.shields.io/github/license/docling-project/docling)](https://opensource.org/licenses/MIT)
[![PyPI Downloads](https://static.pepy.tech/badge/docling/month)](https://pepy.tech/projects/docling)
[![Docling Actor](https://apify.com/actor-badge?actor=vancura/docling?fpr=docling)](https://apify.com/vancura/docling)
[![LF AI & Data](https://img.shields.io/badge/LF%20AI%20%26%20Data-003778?logo=linuxfoundation&logoColor=fff&color=0094ff&labelColor=003778)](https://lfaidata.foundation/projects/)
Docling parses documents and exports them to the desired format with ease and speed.
Docling simplifies document processing, parsing diverse formats — including advanced PDF understanding — and providing seamless integrations with the gen AI ecosystem.
## Features
* 🗂️ Reads popular document formats (PDF, DOCX, PPTX, XLSX, Images, HTML, AsciiDoc & Markdown) and exports to HTML, Markdown and JSON (with embedded and referenced images)
* 📑 Advanced PDF document understanding including page layout, reading order & table structures
* 🧩 Unified, expressive [DoclingDocument](https://ds4sd.github.io/docling/concepts/docling_document/) representation format
* 🤖 Easy integration with 🦙 LlamaIndex & 🦜🔗 LangChain for powerful RAG / QA applications
* 🔍 OCR support for scanned PDFs
* 🗂️ Parsing of [multiple document formats][supported_formats] incl. PDF, DOCX, XLSX, HTML, images, and more
* 📑 Advanced PDF understanding incl. page layout, reading order, table structure, code, formulas, image classification, and more
* 🧬 Unified, expressive [DoclingDocument][docling_document] representation format
* ↪️ Various [export formats][supported_formats] and options, including Markdown, HTML, and lossless JSON
* 🔒 Local execution capabilities for sensitive data and air-gapped environments
* 🤖 Plug-and-play [integrations][integrations] incl. LangChain, LlamaIndex, Crew AI & Haystack for agentic AI
* 🔍 Extensive OCR support for scanned PDFs and images
* 🥚 Support of Visual Language Models ([SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview)) 🆕
* 💻 Simple and convenient CLI
Explore the [documentation](https://ds4sd.github.io/docling/) to discover plenty of examples and unlock the full power of Docling!
### Coming soon
* ♾️ Equation & code extraction
* 📝 Metadata extraction, including title, authors, references & language
* 🦜🔗 Native LangChain extension
* 📝 Chart understanding (Barchart, Piechart, LinePlot, etc)
* 📝 Complex chemistry understanding (Molecular structures)
## Installation
@ -50,11 +53,11 @@ pip install docling
Works on macOS, Linux and Windows environments. Both x86_64 and arm64 architectures.
More [detailed installation instructions](https://ds4sd.github.io/docling/installation/) are available in the docs.
More [detailed installation instructions](https://docling-project.github.io/docling/installation/) are available in the docs.
## Getting started
To convert individual documents, use `convert()`, for example:
To convert individual documents with Python, use `convert()`, for example:
```python
from docling.document_converter import DocumentConverter
@ -65,28 +68,44 @@ result = converter.convert(source)
print(result.document.export_to_markdown()) # output: "## Docling Technical Report[...]"
```
More [advanced usage options](https://ds4sd.github.io/docling/usage/) are available in
More [advanced usage options](https://docling-project.github.io/docling/usage/) are available in
the docs.
## CLI
Docling has a built-in CLI to run conversions.
```bash
docling https://arxiv.org/pdf/2206.01062
```
You can also use 🥚[SmolDocling](https://huggingface.co/ds4sd/SmolDocling-256M-preview) and other VLMs via Docling CLI:
```bash
docling --pipeline vlm --vlm-model smoldocling https://arxiv.org/pdf/2206.01062
```
This will use MLX acceleration on supported Apple Silicon hardware.
Read more [here](https://docling-project.github.io/docling/usage/)
## Documentation
Check out Docling's [documentation](https://ds4sd.github.io/docling/), for details on
Check out Docling's [documentation](https://docling-project.github.io/docling/), for details on
installation, usage, concepts, recipes, extensions, and more.
## Examples
Go hands-on with our [examples](https://ds4sd.github.io/docling/examples/),
Go hands-on with our [examples](https://docling-project.github.io/docling/examples/),
demonstrating how to address different application use cases with Docling.
## Integrations
To further accelerate your AI application development, check out Docling's native
[integrations](https://ds4sd.github.io/docling/integrations/) with popular frameworks
[integrations](https://docling-project.github.io/docling/integrations/) with popular frameworks
and tools.
## Get help and support
Please feel free to connect with us using the [discussion section](https://github.com/DS4SD/docling/discussions).
Please feel free to connect with us using the [discussion section](https://github.com/docling-project/docling/discussions).
## Technical report
@ -94,7 +113,7 @@ For more details on Docling's inner workings, check out the [Docling Technical R
## Contributing
Please read [Contributing to Docling](https://github.com/DS4SD/docling/blob/main/CONTRIBUTING.md) for details.
Please read [Contributing to Docling](https://github.com/docling-project/docling/blob/main/CONTRIBUTING.md) for details.
## References
@ -118,6 +137,14 @@ If you use Docling in your projects, please consider citing the following:
The Docling codebase is under MIT license.
For individual model usage, please refer to the model licenses found in the original packages.
## IBM ❤️ Open Source AI
## LF AI & Data
Docling has been brought to you by IBM.
Docling is hosted as a project in the [LF AI & Data Foundation](https://lfaidata.foundation/projects/).
### IBM ❤️ Open Source AI
The project was started by the AI for knowledge team at IBM Research Zurich.
[supported_formats]: https://docling-project.github.io/docling/usage/supported_formats/
[docling_document]: https://docling-project.github.io/docling/concepts/docling_document/
[integrations]: https://docling-project.github.io/docling/integrations/

View File

@ -27,7 +27,6 @@ class AbstractDocumentBackend(ABC):
def supports_pagination(cls) -> bool:
pass
@abstractmethod
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()

View File

@ -24,7 +24,6 @@ _log = logging.getLogger(__name__)
class AsciiDocBackend(DeclarativeDocumentBackend):
def __init__(self, in_doc: InputDocument, path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
@ -381,7 +380,7 @@ class AsciiDocBackend(DeclarativeDocumentBackend):
end_row_offset_idx=row_idx + row_span,
start_col_offset_idx=col_idx,
end_col_offset_idx=col_idx + col_span,
col_header=False,
column_header=row_idx == 0,
row_header=False,
)
data.table_cells.append(cell)

View File

@ -0,0 +1,125 @@
import csv
import logging
import warnings
from io import BytesIO, StringIO
from pathlib import Path
from typing import Set, Union
from docling_core.types.doc import DoclingDocument, DocumentOrigin, TableCell, TableData
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
class CsvDocumentBackend(DeclarativeDocumentBackend):
content: StringIO
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
# Load content
try:
if isinstance(self.path_or_stream, BytesIO):
self.content = StringIO(self.path_or_stream.getvalue().decode("utf-8"))
elif isinstance(self.path_or_stream, Path):
self.content = StringIO(self.path_or_stream.read_text("utf-8"))
self.valid = True
except Exception as e:
raise RuntimeError(
f"CsvDocumentBackend could not load document with hash {self.document_hash}"
) from e
return
def is_valid(self) -> bool:
return self.valid
@classmethod
def supports_pagination(cls) -> bool:
return False
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()
self.path_or_stream = None
@classmethod
def supported_formats(cls) -> Set[InputFormat]:
return {InputFormat.CSV}
def convert(self) -> DoclingDocument:
"""
Parses the CSV data into a structured document model.
"""
# Detect CSV dialect
head = self.content.readline()
dialect = csv.Sniffer().sniff(head, ",;\t|:")
_log.info(f'Parsing CSV with delimiter: "{dialect.delimiter}"')
if dialect.delimiter not in {",", ";", "\t", "|", ":"}:
raise RuntimeError(
f"Cannot convert csv with unknown delimiter {dialect.delimiter}."
)
# Parse CSV
self.content.seek(0)
result = csv.reader(self.content, dialect=dialect, strict=True)
self.csv_data = list(result)
_log.info(f"Detected {len(self.csv_data)} lines")
# Ensure uniform column length
expected_length = len(self.csv_data[0])
is_uniform = all(len(row) == expected_length for row in self.csv_data)
if not is_uniform:
warnings.warn(
f"Inconsistent column lengths detected in CSV data. "
f"Expected {expected_length} columns, but found rows with varying lengths. "
f"Ensure all rows have the same number of columns."
)
# Parse the CSV into a structured document model
origin = DocumentOrigin(
filename=self.file.name or "file.csv",
mimetype="text/csv",
binary_hash=self.document_hash,
)
doc = DoclingDocument(name=self.file.stem or "file.csv", origin=origin)
if self.is_valid():
# Convert CSV data to table
if self.csv_data:
num_rows = len(self.csv_data)
num_cols = max(len(row) for row in self.csv_data)
table_data = TableData(
num_rows=num_rows,
num_cols=num_cols,
table_cells=[],
)
# Convert each cell to TableCell
for row_idx, row in enumerate(self.csv_data):
for col_idx, cell_value in enumerate(row):
cell = TableCell(
text=str(cell_value),
row_span=1, # CSV doesn't support merged cells
col_span=1,
start_row_offset_idx=row_idx,
end_row_offset_idx=row_idx + 1,
start_col_offset_idx=col_idx,
end_col_offset_idx=col_idx + 1,
column_header=row_idx == 0, # First row as header
row_header=False,
)
table_data.table_cells.append(cell)
doc.add_table(data=table_data)
else:
raise RuntimeError(
f"Cannot convert doc with {self.document_hash} because the backend failed to init."
)
return doc
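
Note on the CSV backend above: its parsing step can be tried in isolation with the standard library alone. The sketch below follows the same sequence (sniff the dialect from the first line, rewind, read all rows strictly) and then marks first-row cells as column headers. `parse_csv` and `header_flags` are hypothetical helper names, not part of docling, and the sample data is made up.

```python
import csv
from io import StringIO
from typing import List

# Same candidate delimiters the backend passes to the sniffer.
ALLOWED_DELIMITERS = ",;\t|:"


def parse_csv(text: str) -> List[List[str]]:
    """Sniff the dialect from the first line, then parse the whole text."""
    content = StringIO(text)
    head = content.readline()
    # Raises csv.Error if no candidate delimiter can be determined.
    dialect = csv.Sniffer().sniff(head, ALLOWED_DELIMITERS)
    content.seek(0)  # rewind so the header row is parsed as well
    return list(csv.reader(content, dialect=dialect, strict=True))


def header_flags(rows: List[List[str]]) -> List[List[bool]]:
    """Mirror `column_header=row_idx == 0`: only first-row cells are headers."""
    return [[row_idx == 0 for _ in row] for row_idx, row in enumerate(rows)]


rows = parse_csv("name;age\nAda;36\n")
print(rows)
print(header_flags(rows))
```

Sniffing only the first line keeps detection cheap, but it relies on the header row containing the delimiter; a single-column file gives the sniffer nothing to work with.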

View File

@ -6,12 +6,12 @@ from typing import Iterable, List, Optional, Union
import pypdfium2 as pdfium
from docling_core.types.doc import BoundingBox, CoordOrigin, Size
from docling_core.types.doc.page import BoundingRectangle, SegmentedPdfPage, TextCell
from docling_parse.pdf_parsers import pdf_parser_v1
from PIL import Image, ImageDraw
from pypdfium2 import PdfPage
from docling.backend.pdf_backend import PdfDocumentBackend, PdfPageBackend
from docling.datamodel.base_models import Cell
from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
@ -68,8 +68,11 @@ class DoclingParsePageBackend(PdfPageBackend):
return text_piece
def get_text_cells(self) -> Iterable[Cell]:
cells: List[Cell] = []
def get_segmented_page(self) -> Optional[SegmentedPdfPage]:
return None
def get_text_cells(self) -> Iterable[TextCell]:
cells: List[TextCell] = []
cell_counter = 0
if not self.valid:
@ -91,19 +94,24 @@ class DoclingParsePageBackend(PdfPageBackend):
text_piece = self._dpage["cells"][i]["content"]["rnormalized"]
cells.append(
Cell(
id=cell_counter,
TextCell(
index=cell_counter,
text=text_piece,
bbox=BoundingBox(
# l=x0, b=y0, r=x1, t=y1,
l=x0 * page_size.width / parser_width,
b=y0 * page_size.height / parser_height,
r=x1 * page_size.width / parser_width,
t=y1 * page_size.height / parser_height,
coord_origin=CoordOrigin.BOTTOMLEFT,
orig=text_piece,
from_ocr=False,
rect=BoundingRectangle.from_bounding_box(
BoundingBox(
# l=x0, b=y0, r=x1, t=y1,
l=x0 * page_size.width / parser_width,
b=y0 * page_size.height / parser_height,
r=x1 * page_size.width / parser_width,
t=y1 * page_size.height / parser_height,
coord_origin=CoordOrigin.BOTTOMLEFT,
)
).to_top_left_origin(page_size.height),
)
)
cell_counter += 1
def draw_clusters_and_cells():
@ -112,7 +120,7 @@ class DoclingParsePageBackend(PdfPageBackend):
) # make new image to avoid drawing on the saved ones
draw = ImageDraw.Draw(image)
for c in cells:
x0, y0, x1, y1 = c.bbox.as_tuple()
x0, y0, x1, y1 = c.rect.to_bounding_box().as_tuple()
cell_color = (
random.randint(30, 140),
random.randint(30, 140),
@ -132,7 +140,7 @@ class DoclingParsePageBackend(PdfPageBackend):
return cells
def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
AREA_THRESHOLD = 32 * 32
AREA_THRESHOLD = 0 # 32 * 32
for i in range(len(self._dpage["images"])):
bitmap = self._dpage["images"][i]
@ -163,7 +171,7 @@ class DoclingParsePageBackend(PdfPageBackend):
l=0, r=0, t=0, b=0, coord_origin=CoordOrigin.BOTTOMLEFT
)
else:
padbox = cropbox.to_bottom_left_origin(page_size.height)
padbox = cropbox.to_bottom_left_origin(page_size.height).model_copy()
padbox.r = page_size.width - padbox.r
padbox.t = page_size.height - padbox.t
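
Note on the coordinate handling in this file: text cells are first scaled from parser units to page units in a bottom-left origin and then converted to a top-left origin. The arithmetic can be sketched in plain Python as below. `Box`, `scale_to_page` and `to_top_left` are hypothetical stand-ins, not the `docling_core` classes, and the flip assumes the usual convention that a y value becomes `page_height - y`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class Box:
    """Hypothetical bounding box: left, top, right, bottom."""

    l: float
    t: float
    r: float
    b: float


def scale_to_page(
    x0: float,
    y0: float,
    x1: float,
    y1: float,
    parser_size: Tuple[float, float],
    page_size: Tuple[float, float],
) -> Box:
    """Scale parser coordinates to page units, keeping a bottom-left origin."""
    sx = page_size[0] / parser_size[0]
    sy = page_size[1] / parser_size[1]
    return Box(l=x0 * sx, b=y0 * sy, r=x1 * sx, t=y1 * sy)


def to_top_left(box: Box, page_height: float) -> Box:
    """Flip the y axis so that y grows downwards from the top of the page."""
    return Box(l=box.l, r=box.r, t=page_height - box.t, b=page_height - box.b)


bottom_left = scale_to_page(10, 20, 110, 40, parser_size=(200, 400), page_size=(100, 200))
top_left = to_top_left(bottom_left, page_height=200)
print(bottom_left)
print(top_left)
```

After the flip, `t` is smaller than `b`, which is what consumers of top-left coordinates expect.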

View File

@ -6,12 +6,14 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import pypdfium2 as pdfium
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, SegmentedPdfPage, TextCell
from docling_parse.pdf_parsers import pdf_parser_v2
from PIL import Image, ImageDraw
from pypdfium2 import PdfPage
from docling.backend.pdf_backend import PdfDocumentBackend, PdfPageBackend
from docling.datamodel.base_models import Cell, Size
from docling.datamodel.base_models import Size
from docling.utils.locks import pypdfium2_lock
if TYPE_CHECKING:
from docling.datamodel.document import InputDocument
@ -77,8 +79,11 @@ class DoclingParseV2PageBackend(PdfPageBackend):
return text_piece
def get_text_cells(self) -> Iterable[Cell]:
cells: List[Cell] = []
def get_segmented_page(self) -> Optional[SegmentedPdfPage]:
return None
def get_text_cells(self) -> Iterable[TextCell]:
cells: List[TextCell] = []
cell_counter = 0
if not self.valid:
@ -105,16 +110,20 @@ class DoclingParseV2PageBackend(PdfPageBackend):
text_piece = cell_data[cells_header.index("text")]
cells.append(
Cell(
id=cell_counter,
TextCell(
index=cell_counter,
text=text_piece,
bbox=BoundingBox(
# l=x0, b=y0, r=x1, t=y1,
l=x0 * page_size.width / parser_width,
b=y0 * page_size.height / parser_height,
r=x1 * page_size.width / parser_width,
t=y1 * page_size.height / parser_height,
coord_origin=CoordOrigin.BOTTOMLEFT,
orig=text_piece,
from_ocr=False,
rect=BoundingRectangle.from_bounding_box(
BoundingBox(
# l=x0, b=y0, r=x1, t=y1,
l=x0 * page_size.width / parser_width,
b=y0 * page_size.height / parser_height,
r=x1 * page_size.width / parser_width,
t=y1 * page_size.height / parser_height,
coord_origin=CoordOrigin.BOTTOMLEFT,
)
).to_top_left_origin(page_size.height),
)
)
@ -140,7 +149,7 @@ class DoclingParseV2PageBackend(PdfPageBackend):
return cells
def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
AREA_THRESHOLD = 32 * 32
AREA_THRESHOLD = 0 # 32 * 32
images = self._dpage["sanitized"]["images"]["data"]
images_header = self._dpage["sanitized"]["images"]["header"]
@ -178,24 +187,28 @@ class DoclingParseV2PageBackend(PdfPageBackend):
l=0, r=0, t=0, b=0, coord_origin=CoordOrigin.BOTTOMLEFT
)
else:
padbox = cropbox.to_bottom_left_origin(page_size.height)
padbox = cropbox.to_bottom_left_origin(page_size.height).model_copy()
padbox.r = page_size.width - padbox.r
padbox.t = page_size.height - padbox.t
image = (
self._ppage.render(
scale=scale * 1.5,
rotation=0, # no additional rotation
crop=padbox.as_tuple(),
)
.to_pil()
.resize(size=(round(cropbox.width * scale), round(cropbox.height * scale)))
) # We resize the image from 1.5x the given scale to make it sharper.
with pypdfium2_lock:
image = (
self._ppage.render(
scale=scale * 1.5,
rotation=0, # no additional rotation
crop=padbox.as_tuple(),
)
.to_pil()
.resize(
size=(round(cropbox.width * scale), round(cropbox.height * scale))
)
) # We resize the image from 1.5x the given scale to make it sharper.
return image
def get_size(self) -> Size:
return Size(width=self._ppage.get_width(), height=self._ppage.get_height())
with pypdfium2_lock:
return Size(width=self._ppage.get_width(), height=self._ppage.get_height())
def unload(self):
self._ppage = None
@ -206,23 +219,24 @@ class DoclingParseV2DocumentBackend(PdfDocumentBackend):
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
self._pdoc = pdfium.PdfDocument(self.path_or_stream)
self.parser = pdf_parser_v2("fatal")
with pypdfium2_lock:
self._pdoc = pdfium.PdfDocument(self.path_or_stream)
self.parser = pdf_parser_v2("fatal")
success = False
if isinstance(self.path_or_stream, BytesIO):
success = self.parser.load_document_from_bytesio(
self.document_hash, self.path_or_stream
)
elif isinstance(self.path_or_stream, Path):
success = self.parser.load_document(
self.document_hash, str(self.path_or_stream)
)
success = False
if isinstance(self.path_or_stream, BytesIO):
success = self.parser.load_document_from_bytesio(
self.document_hash, self.path_or_stream
)
elif isinstance(self.path_or_stream, Path):
success = self.parser.load_document(
self.document_hash, str(self.path_or_stream)
)
if not success:
raise RuntimeError(
f"docling-parse v2 could not load document {self.document_hash}."
)
if not success:
raise RuntimeError(
f"docling-parse v2 could not load document {self.document_hash}."
)
def page_count(self) -> int:
# return len(self._pdoc) # To be replaced with docling-parse API
@ -236,9 +250,10 @@ class DoclingParseV2DocumentBackend(PdfDocumentBackend):
return len_2
def load_page(self, page_no: int) -> DoclingParseV2PageBackend:
return DoclingParseV2PageBackend(
self.parser, self.document_hash, page_no, self._pdoc[page_no]
)
with pypdfium2_lock:
return DoclingParseV2PageBackend(
self.parser, self.document_hash, page_no, self._pdoc[page_no]
)
def is_valid(self) -> bool:
return self.page_count() > 0
@ -246,5 +261,6 @@ class DoclingParseV2DocumentBackend(PdfDocumentBackend):
def unload(self):
super().unload()
self.parser.unload_document(self.document_hash)
self._pdoc.close()
self._pdoc = None
with pypdfium2_lock:
self._pdoc.close()
self._pdoc = None
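
Note on the `with pypdfium2_lock:` blocks added in this file: the diff does not show how `pypdfium2_lock` is defined, so the sketch below assumes it behaves like a module-level `threading.Lock` that serializes calls into a library that is not thread-safe. `render_lock`, `FakeRenderer` and `render_all` are hypothetical names; no pypdfium2 call is made.

```python
import threading
import time

# Hypothetical stand-in for a shared, module-level lock.
render_lock = threading.Lock()


class FakeRenderer:
    """Records how many threads are inside `render` at the same time."""

    def __init__(self) -> None:
        self.active = 0
        self.max_active = 0
        self._meta = threading.Lock()  # protects the counters only

    def render(self) -> None:
        with self._meta:
            self.active += 1
            self.max_active = max(self.max_active, self.active)
        time.sleep(0.001)  # pretend to do non-thread-safe work
        with self._meta:
            self.active -= 1


def render_all(renderer: FakeRenderer, n_threads: int = 8) -> int:
    """Call `render` from several threads, each holding the shared lock."""

    def worker() -> None:
        with render_lock:
            renderer.render()

    threads = [threading.Thread(target=worker) for _ in range(n_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return renderer.max_active


print(render_all(FakeRenderer()))
```

Because every caller takes the same lock, at most one thread is ever inside `render`, at the cost of serializing that part of the work.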

View File

@ -0,0 +1,192 @@
import logging
import random
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import pypdfium2 as pdfium
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
from docling_parse.pdf_parser import DoclingPdfParser, PdfDocument
from PIL import Image, ImageDraw
from pypdfium2 import PdfPage
from docling.backend.pdf_backend import PdfDocumentBackend, PdfPageBackend
from docling.datamodel.base_models import Size
from docling.utils.locks import pypdfium2_lock
if TYPE_CHECKING:
from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
class DoclingParseV4PageBackend(PdfPageBackend):
def __init__(self, parsed_page: SegmentedPdfPage, page_obj: PdfPage):
self._ppage = page_obj
self._dpage = parsed_page
self.valid = parsed_page is not None
def is_valid(self) -> bool:
return self.valid
def get_text_in_rect(self, bbox: BoundingBox) -> str:
# Find intersecting cells on the page
text_piece = ""
page_size = self.get_size()
scale = (
1 # FIX - Replace with param in get_text_in_rect across backends (optional)
)
for i, cell in enumerate(self._dpage.textline_cells):
cell_bbox = (
cell.rect.to_bounding_box()
.to_top_left_origin(page_height=page_size.height)
.scaled(scale)
)
overlap_frac = cell_bbox.intersection_area_with(bbox) / cell_bbox.area()
if overlap_frac > 0.5:
if len(text_piece) > 0:
text_piece += " "
text_piece += cell.text
return text_piece
def get_segmented_page(self) -> Optional[SegmentedPdfPage]:
return self._dpage
def get_text_cells(self) -> Iterable[TextCell]:
page_size = self.get_size()
[tc.to_top_left_origin(page_size.height) for tc in self._dpage.textline_cells]
# for cell in self._dpage.textline_cells:
# rect = cell.rect
#
# assert (
# rect.to_bounding_box().l <= rect.to_bounding_box().r
# ), f"left is > right on bounding box {rect.to_bounding_box()} of rect {rect}"
# assert (
# rect.to_bounding_box().t <= rect.to_bounding_box().b
# ), f"top is > bottom on bounding box {rect.to_bounding_box()} of rect {rect}"
return self._dpage.textline_cells
def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
AREA_THRESHOLD = 0 # 32 * 32
images = self._dpage.bitmap_resources
for img in images:
cropbox = img.rect.to_bounding_box().to_top_left_origin(
self.get_size().height
)
if cropbox.area() > AREA_THRESHOLD:
cropbox = cropbox.scaled(scale=scale)
yield cropbox
def get_page_image(
self, scale: float = 1, cropbox: Optional[BoundingBox] = None
) -> Image.Image:
page_size = self.get_size()
if not cropbox:
cropbox = BoundingBox(
l=0,
r=page_size.width,
t=0,
b=page_size.height,
coord_origin=CoordOrigin.TOPLEFT,
)
padbox = BoundingBox(
l=0, r=0, t=0, b=0, coord_origin=CoordOrigin.BOTTOMLEFT
)
else:
padbox = cropbox.to_bottom_left_origin(page_size.height).model_copy()
padbox.r = page_size.width - padbox.r
padbox.t = page_size.height - padbox.t
with pypdfium2_lock:
image = (
self._ppage.render(
scale=scale * 1.5,
rotation=0, # no additional rotation
crop=padbox.as_tuple(),
)
.to_pil()
.resize(
size=(round(cropbox.width * scale), round(cropbox.height * scale))
)
) # We resize the image from 1.5x the given scale to make it sharper.
return image
def get_size(self) -> Size:
with pypdfium2_lock:
return Size(width=self._ppage.get_width(), height=self._ppage.get_height())
# TODO: Take width and height from docling-parse.
# return Size(
# width=self._dpage.dimension.width,
# height=self._dpage.dimension.height,
# )
def unload(self):
self._ppage = None
self._dpage = None
class DoclingParseV4DocumentBackend(PdfDocumentBackend):
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
with pypdfium2_lock:
self._pdoc = pdfium.PdfDocument(self.path_or_stream)
self.parser = DoclingPdfParser(loglevel="fatal")
self.dp_doc: PdfDocument = self.parser.load(path_or_stream=self.path_or_stream)
success = self.dp_doc is not None
if not success:
raise RuntimeError(
f"docling-parse v4 could not load document {self.document_hash}."
)
def page_count(self) -> int:
# return len(self._pdoc) # To be replaced with docling-parse API
len_1 = len(self._pdoc)
len_2 = self.dp_doc.number_of_pages()
if len_1 != len_2:
_log.error(f"Inconsistent number of pages: {len_1}!={len_2}")
return len_2
def load_page(
self, page_no: int, create_words: bool = True, create_textlines: bool = True
) -> DoclingParseV4PageBackend:
with pypdfium2_lock:
return DoclingParseV4PageBackend(
self.dp_doc.get_page(
page_no + 1,
create_words=create_words,
create_textlines=create_textlines,
),
self._pdoc[page_no],
)
def is_valid(self) -> bool:
return self.page_count() > 0
def unload(self):
super().unload()
self.dp_doc.unload()
with pypdfium2_lock:
self._pdoc.close()
self._pdoc = None
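
`get_page_image` above converts a top-left-origin crop box into the margins passed as `crop=` to the renderer, then resizes the 1.5x render back to the requested scale. The sketch below reproduces only that arithmetic in plain Python so it can be checked by hand. The function names are hypothetical, and it assumes (as the code above implies) that the crop tuple is ordered left, bottom, right, top and measured inward from the page borders.

```python
from typing import Tuple


def crop_margins(
    l: float, t: float, r: float, b: float, page_width: float, page_height: float
) -> Tuple[float, float, float, float]:
    """Turn a top-left-origin box into (left, bottom, right, top) margins.

    Mirrors the padbox computation: flip to bottom-left origin, then replace
    the right and top coordinates by their distance from the page border.
    """
    left_margin = l
    bottom_margin = page_height - b  # box bottom in bottom-left coordinates
    right_margin = page_width - r
    top_margin = t  # page_height - (page_height - t)
    return (left_margin, bottom_margin, right_margin, top_margin)


def output_size(l: float, t: float, r: float, b: float, scale: float) -> Tuple[int, int]:
    """Pixel size of the final image after resizing to the requested scale."""
    return (round((r - l) * scale), round((b - t) * scale))
```

For a 200 x 300 page and a crop box with `l=10, t=20, r=110, b=220`, the margins are 10 from the left, 80 from the bottom, 90 from the right, and 20 from the top.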

@@ -0,0 +1,271 @@
# -*- coding: utf-8 -*-
"""
Adapted from https://github.com/xiilei/dwml/blob/master/dwml/latex_dict.py
On 23/01/2025
"""
from __future__ import unicode_literals
CHARS = ("{", "}", "_", "^", "#", "&", "$", "%", "~")
BLANK = ""
BACKSLASH = "\\"
ALN = "&"
CHR = {
# Unicode : Latex Math Symbols
# Top accents
"\u0300": "\\grave{{{0}}}",
"\u0301": "\\acute{{{0}}}",
"\u0302": "\\hat{{{0}}}",
"\u0303": "\\tilde{{{0}}}",
"\u0304": "\\bar{{{0}}}",
"\u0305": "\\overbar{{{0}}}",
"\u0306": "\\breve{{{0}}}",
"\u0307": "\\dot{{{0}}}",
"\u0308": "\\ddot{{{0}}}",
"\u0309": "\\ovhook{{{0}}}",
"\u030a": "\\ocirc{{{0}}}",
"\u030c": "\\check{{{0}}}",
"\u0310": "\\candra{{{0}}}",
"\u0312": "\\oturnedcomma{{{0}}}",
"\u0315": "\\ocommatopright{{{0}}}",
"\u031a": "\\droang{{{0}}}",
"\u0338": "\\not{{{0}}}",
"\u20d0": "\\leftharpoonaccent{{{0}}}",
"\u20d1": "\\rightharpoonaccent{{{0}}}",
"\u20d2": "\\vertoverlay{{{0}}}",
"\u20d6": "\\overleftarrow{{{0}}}",
"\u20d7": "\\vec{{{0}}}",
"\u20db": "\\dddot{{{0}}}",
"\u20dc": "\\ddddot{{{0}}}",
"\u20e1": "\\overleftrightarrow{{{0}}}",
"\u20e7": "\\annuity{{{0}}}",
"\u20e9": "\\widebridgeabove{{{0}}}",
"\u20f0": "\\asteraccent{{{0}}}",
# Bottom accents
"\u0330": "\\wideutilde{{{0}}}",
"\u0331": "\\underbar{{{0}}}",
"\u20e8": "\\threeunderdot{{{0}}}",
"\u20ec": "\\underrightharpoondown{{{0}}}",
"\u20ed": "\\underleftharpoondown{{{0}}}",
"\u20ee": "\\underleftarrow{{{0}}}",
"\u20ef": "\\underrightarrow{{{0}}}",
# Over | group
"\u23b4": "\\overbracket{{{0}}}",
"\u23dc": "\\overparen{{{0}}}",
"\u23de": "\\overbrace{{{0}}}",
# Under| group
"\u23b5": "\\underbracket{{{0}}}",
"\u23dd": "\\underparen{{{0}}}",
"\u23df": "\\underbrace{{{0}}}",
}
CHR_BO = {
# Big operators,
"\u2140": "\\Bbbsum",
"\u220f": "\\prod",
"\u2210": "\\coprod",
"\u2211": "\\sum",
"\u222b": "\\int",
"\u22c0": "\\bigwedge",
"\u22c1": "\\bigvee",
"\u22c2": "\\bigcap",
"\u22c3": "\\bigcup",
"\u2a00": "\\bigodot",
"\u2a01": "\\bigoplus",
"\u2a02": "\\bigotimes",
}
T = {
"\u2192": "\\rightarrow ",
# Greek letters
"\U0001d6fc": "\\alpha ",
"\U0001d6fd": "\\beta ",
"\U0001d6fe": "\\gamma ",
"\U0001d6ff": "\\delta ",
"\U0001d700": "\\epsilon ",
"\U0001d701": "\\zeta ",
"\U0001d702": "\\eta ",
"\U0001d703": "\\theta ",
"\U0001d704": "\\iota ",
"\U0001d705": "\\kappa ",
"\U0001d706": "\\lambda ",
"\U0001d707": "\\mu ",
"\U0001d708": "\\nu ",
"\U0001d709": "\\xi ",
"\U0001d70a": "\\omicron ",
"\U0001d70b": "\\pi ",
"\U0001d70c": "\\rho ",
"\U0001d70d": "\\varsigma ",
"\U0001d70e": "\\sigma ",
"\U0001d70f": "\\tau ",
"\U0001d710": "\\upsilon ",
"\U0001d711": "\\phi ",
"\U0001d712": "\\chi ",
"\U0001d713": "\\psi ",
"\U0001d714": "\\omega ",
"\U0001d715": "\\partial ",
"\U0001d716": "\\varepsilon ",
"\U0001d717": "\\vartheta ",
"\U0001d718": "\\varkappa ",
"\U0001d719": "\\varphi ",
"\U0001d71a": "\\varrho ",
"\U0001d71b": "\\varpi ",
# Relation symbols
"\u2190": "\\leftarrow ",
"\u2191": "\\uparrow ",
"\u2192": "\\rightarrow ",
"\u2193": "\\downarrow ",
"\u2194": "\\leftrightarrow ",
"\u2195": "\\updownarrow ",
"\u2196": "\\nwarrow ",
"\u2197": "\\nearrow ",
"\u2198": "\\searrow ",
"\u2199": "\\swarrow ",
"\u22ee": "\\vdots ",
"\u22ef": "\\cdots ",
"\u22f0": "\\adots ",
"\u22f1": "\\ddots ",
"\u2260": "\\ne ",
"\u2264": "\\leq ",
"\u2265": "\\geq ",
"\u2266": "\\leqq ",
"\u2267": "\\geqq ",
"\u2268": "\\lneqq ",
"\u2269": "\\gneqq ",
"\u226a": "\\ll ",
"\u226b": "\\gg ",
"\u2208": "\\in ",
"\u2209": "\\notin ",
"\u220b": "\\ni ",
"\u220c": "\\nni ",
# Ordinary symbols
"\u221e": "\\infty ",
# Binary operators
"\u00b1": "\\pm ",
"\u2213": "\\mp ",
# Italic, Latin, uppercase
"\U0001d434": "A",
"\U0001d435": "B",
"\U0001d436": "C",
"\U0001d437": "D",
"\U0001d438": "E",
"\U0001d439": "F",
"\U0001d43a": "G",
"\U0001d43b": "H",
"\U0001d43c": "I",
"\U0001d43d": "J",
"\U0001d43e": "K",
"\U0001d43f": "L",
"\U0001d440": "M",
"\U0001d441": "N",
"\U0001d442": "O",
"\U0001d443": "P",
"\U0001d444": "Q",
"\U0001d445": "R",
"\U0001d446": "S",
"\U0001d447": "T",
"\U0001d448": "U",
"\U0001d449": "V",
"\U0001d44a": "W",
"\U0001d44b": "X",
"\U0001d44c": "Y",
"\U0001d44d": "Z",
# Italic, Latin, lowercase
"\U0001d44e": "a",
"\U0001d44f": "b",
"\U0001d450": "c",
"\U0001d451": "d",
"\U0001d452": "e",
"\U0001d453": "f",
"\U0001d454": "g",
"\U0001d456": "i",
"\U0001d457": "j",
"\U0001d458": "k",
"\U0001d459": "l",
"\U0001d45a": "m",
"\U0001d45b": "n",
"\U0001d45c": "o",
"\U0001d45d": "p",
"\U0001d45e": "q",
"\U0001d45f": "r",
"\U0001d460": "s",
"\U0001d461": "t",
"\U0001d462": "u",
"\U0001d463": "v",
"\U0001d464": "w",
"\U0001d465": "x",
"\U0001d466": "y",
"\U0001d467": "z",
}
FUNC = {
"sin": "\\sin({fe})",
"cos": "\\cos({fe})",
"tan": "\\tan({fe})",
"arcsin": "\\arcsin({fe})",
"arccos": "\\arccos({fe})",
"arctan": "\\arctan({fe})",
"arccot": "\\arccot({fe})",
"sinh": "\\sinh({fe})",
"cosh": "\\cosh({fe})",
"tanh": "\\tanh({fe})",
"coth": "\\coth({fe})",
"sec": "\\sec({fe})",
"csc": "\\csc({fe})",
}
FUNC_PLACE = "{fe}"
BRK = "\\\\"
CHR_DEFAULT = {
"ACC_VAL": "\\hat{{{0}}}",
}
POS = {
"top": "\\overline{{{0}}}", # not sure
"bot": "\\underline{{{0}}}",
}
POS_DEFAULT = {
"BAR_VAL": "\\overline{{{0}}}",
}
SUB = "_{{{0}}}"
SUP = "^{{{0}}}"
F = {
"bar": "\\frac{{{num}}}{{{den}}}",
"skw": r"^{{{num}}}/_{{{den}}}",
"noBar": "\\genfrac{{}}{{}}{{0pt}}{{}}{{{num}}}{{{den}}}",
"lin": "{{{num}}}/{{{den}}}",
}
F_DEFAULT = "\\frac{{{num}}}{{{den}}}"
D = "\\left{left}{text}\\right{right}"
D_DEFAULT = {
"left": "(",
"right": ")",
"null": ".",
}
RAD = "\\sqrt[{deg}]{{{text}}}"
RAD_DEFAULT = "\\sqrt{{{text}}}"
ARR = "{text}"
LIM_FUNC = {
"lim": "\\lim_{{{lim}}}",
"max": "\\max_{{{lim}}}",
"min": "\\min_{{{lim}}}",
}
LIM_TO = ("\\rightarrow", "\\to")
LIM_UPP = "\\overset{{{lim}}}{{{text}}}"
M = "\\begin{{matrix}}{text}\\end{{matrix}}"
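
The templates in this table are filled with `str.format`, so every literal LaTeX brace has to be doubled and only the placeholder keeps single braces. A short sketch of how such templates expand is shown below; the three constants are illustrative copies of entries from the table, and `is_valid_template` is a hypothetical helper added here for checking brace balance, not part of the module.

```python
# Illustrative copies of three templates from the table above.
F_BAR = "\\frac{{{num}}}{{{den}}}"
ACCENT_HAT = "\\hat{{{0}}}"
RAD = "\\sqrt[{deg}]{{{text}}}"


def is_valid_template(template: str) -> bool:
    """Return True if str.format can parse and fill the template.

    Hypothetical helper: it supplies a dummy value for every placeholder
    name used in the table and reports whether formatting succeeds.
    """
    try:
        template.format("x", num="n", den="d", deg="2", text="t", lim="l")
    except (ValueError, KeyError, IndexError):
        return False
    return True


# "{{" and "}}" become literal braces; "{num}", "{0}", ... are substituted.
fraction = F_BAR.format(num="a", den="b")
accent = ACCENT_HAT.format("x")
root = RAD.format(deg="3", text="y")
```

A template with one closing brace too many is rejected by `str.format` with a `ValueError`, so a check like `is_valid_template` run over every value in the table is a cheap way to catch brace typos before they surface at conversion time.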

@@ -0,0 +1,453 @@
"""
Office Math Markup Language (OMML)
Adapted from https://github.com/xiilei/dwml/blob/master/dwml/omml.py
On 23/01/2025
"""
import lxml.etree as ET
from pylatexenc.latexencode import UnicodeToLatexEncoder
from docling.backend.docx.latex.latex_dict import (
ALN,
ARR,
BACKSLASH,
BLANK,
BRK,
CHARS,
CHR,
CHR_BO,
CHR_DEFAULT,
D_DEFAULT,
F_DEFAULT,
FUNC,
FUNC_PLACE,
LIM_FUNC,
LIM_TO,
LIM_UPP,
POS,
POS_DEFAULT,
RAD,
RAD_DEFAULT,
SUB,
SUP,
D,
F,
M,
T,
)
OMML_NS = "{http://schemas.openxmlformats.org/officeDocument/2006/math}"
def load(stream):
tree = ET.parse(stream)
for omath in tree.findall(OMML_NS + "oMath"):
yield oMath2Latex(omath)
def load_string(string):
root = ET.fromstring(string)
for omath in root.findall(OMML_NS + "oMath"):
yield oMath2Latex(omath)
def escape_latex(strs):
last = None
new_chr = []
strs = strs.replace(r"\\", "\\")
for c in strs:
if (c in CHARS) and (last != BACKSLASH):
new_chr.append(BACKSLASH + c)
else:
new_chr.append(c)
last = c
return BLANK.join(new_chr)
def get_val(key, default=None, store=CHR):
if key is not None:
return key if not store else store.get(key, key)
else:
return default
class NotSupport(Exception):
"""Raised when an OMML construct has no LaTeX mapping."""
class Tag2Method(object):
def call_method(self, elm, stag=None):
getmethod = self.tag2meth.get
if stag is None:
stag = elm.tag.replace(OMML_NS, "")
method = getmethod(stag)
if method:
return method(self, elm)
else:
return None
def process_children_list(self, elm, include=None):
"""
Process the children of elm; return an iterable.
"""
for _e in list(elm):
if OMML_NS not in _e.tag:
continue
stag = _e.tag.replace(OMML_NS, "")
if include and (stag not in include):
continue
t = self.call_method(_e, stag=stag)
if t is None:
t = self.process_unknow(_e, stag)
if t is None:
continue
yield (stag, t, _e)
def process_children_dict(self, elm, include=None):
"""
Process the children of elm; return a dict.
"""
latex_chars = dict()
for stag, t, e in self.process_children_list(elm, include):
latex_chars[stag] = t
return latex_chars
def process_children(self, elm, include=None):
"""
Process the children of elm; return a string.
"""
return BLANK.join(
(
t if not isinstance(t, Tag2Method) else str(t)
for stag, t, e in self.process_children_list(elm, include)
)
)
def process_unknow(self, elm, stag):
return None
class Pr(Tag2Method):
text = ""
__val_tags = ("chr", "pos", "begChr", "endChr", "type")
__innerdict = None # can't use the __dict__
""" common properties of element"""
def __init__(self, elm):
self.__innerdict = {}
self.text = self.process_children(elm)
def __str__(self):
return self.text
def __unicode__(self):
return self.__str__()
def __getattr__(self, name):
return self.__innerdict.get(name, None)
def do_brk(self, elm):
self.__innerdict["brk"] = BRK
return BRK
def do_common(self, elm):
stag = elm.tag.replace(OMML_NS, "")
if stag in self.__val_tags:
t = elm.get("{0}val".format(OMML_NS))
self.__innerdict[stag] = t
return None
tag2meth = {
"brk": do_brk,
"chr": do_common,
"pos": do_common,
"begChr": do_common,
"endChr": do_common,
"type": do_common,
}
class oMath2Latex(Tag2Method):
"""
Convert oMath element of omml to latex
"""
_t_dict = T
__direct_tags = ("box", "sSub", "sSup", "sSubSup", "num", "den", "deg", "e")
u = UnicodeToLatexEncoder(
replacement_latex_protection="braces-all",
unknown_char_policy="keep",
unknown_char_warning=False,
)
def __init__(self, element):
self._latex = self.process_children(element)
def __str__(self):
return self.latex.replace("  ", " ")
def __unicode__(self):
return self.__str__()
def process_unknow(self, elm, stag):
if stag in self.__direct_tags:
return self.process_children(elm)
elif stag[-2:] == "Pr":
return Pr(elm)
else:
return None
@property
def latex(self):
return self._latex
def do_acc(self, elm):
"""
the accent function
"""
c_dict = self.process_children_dict(elm)
latex_s = get_val(
c_dict["accPr"].chr, default=CHR_DEFAULT.get("ACC_VAL"), store=CHR
)
return latex_s.format(c_dict["e"])
def do_bar(self, elm):
"""
the bar function
"""
c_dict = self.process_children_dict(elm)
pr = c_dict["barPr"]
latex_s = get_val(pr.pos, default=POS_DEFAULT.get("BAR_VAL"), store=POS)
return pr.text + latex_s.format(c_dict["e"])
def do_d(self, elm):
"""
the delimiter object
"""
c_dict = self.process_children_dict(elm)
pr = c_dict["dPr"]
null = D_DEFAULT.get("null")
s_val = get_val(pr.begChr, default=D_DEFAULT.get("left"), store=T)
e_val = get_val(pr.endChr, default=D_DEFAULT.get("right"), store=T)
delim = pr.text + D.format(
left=null if not s_val else escape_latex(s_val),
text=c_dict["e"],
right=null if not e_val else escape_latex(e_val),
)
return delim
def do_spre(self, elm):
"""
the Pre-Sub-Superscript object -- not supported yet
"""
pass
def do_sub(self, elm):
text = self.process_children(elm)
return SUB.format(text)
def do_sup(self, elm):
text = self.process_children(elm)
return SUP.format(text)
def do_f(self, elm):
"""
the fraction object
"""
c_dict = self.process_children_dict(elm)
pr = c_dict["fPr"]
latex_s = get_val(pr.type, default=F_DEFAULT, store=F)
return pr.text + latex_s.format(num=c_dict.get("num"), den=c_dict.get("den"))
def do_func(self, elm):
"""
the Function-Apply object (examples: sin, cos)
"""
c_dict = self.process_children_dict(elm)
func_name = c_dict.get("fName")
return func_name.replace(FUNC_PLACE, c_dict.get("e"))
def do_fname(self, elm):
"""
the func name
"""
latex_chars = []
for stag, t, e in self.process_children_list(elm):
if stag == "r":
if FUNC.get(t):
latex_chars.append(FUNC[t])
else:
raise NotSupport("Not support func %s" % t)
else:
latex_chars.append(t)
t = BLANK.join(latex_chars)
return t if FUNC_PLACE in t else t + FUNC_PLACE # do_func will replace this
def do_groupchr(self, elm):
"""
the Group-Character object
"""
c_dict = self.process_children_dict(elm)
pr = c_dict["groupChrPr"]
latex_s = get_val(pr.chr)
return pr.text + latex_s.format(c_dict["e"])
def do_rad(self, elm):
"""
the radical object
"""
c_dict = self.process_children_dict(elm)
text = c_dict.get("e")
deg_text = c_dict.get("deg")
if deg_text:
return RAD.format(deg=deg_text, text=text)
else:
return RAD_DEFAULT.format(text=text)
def do_eqarr(self, elm):
"""
the Array object
"""
return ARR.format(
text=BRK.join(
[t for stag, t, e in self.process_children_list(elm, include=("e",))]
)
)
def do_limlow(self, elm):
"""
the Lower-Limit object
"""
t_dict = self.process_children_dict(elm, include=("e", "lim"))
latex_s = LIM_FUNC.get(t_dict["e"])
if not latex_s:
raise NotSupport("Not support lim %s" % t_dict["e"])
else:
return latex_s.format(lim=t_dict.get("lim"))
def do_limupp(self, elm):
"""
the Upper-Limit object
"""
t_dict = self.process_children_dict(elm, include=("e", "lim"))
return LIM_UPP.format(lim=t_dict.get("lim"), text=t_dict.get("e"))
def do_lim(self, elm):
"""
the lower limit of the limLow object and the upper limit of the limUpp function
"""
return self.process_children(elm).replace(LIM_TO[0], LIM_TO[1])
def do_m(self, elm):
"""
the Matrix object
"""
rows = []
for stag, t, e in self.process_children_list(elm):
if stag == "mPr":
pass
elif stag == "mr":
rows.append(t)
return M.format(text=BRK.join(rows))
def do_mr(self, elm):
"""
a single row of the matrix m
"""
return ALN.join(
[t for stag, t, e in self.process_children_list(elm, include=("e",))]
)
def do_nary(self, elm):
"""
the n-ary object
"""
res = []
bo = ""
for stag, t, e in self.process_children_list(elm):
if stag == "naryPr":
bo = get_val(t.chr, store=CHR_BO)
else:
res.append(t)
return bo + BLANK.join(res)
def process_unicode(self, s):
# s = s if isinstance(s,unicode) else unicode(s,'utf-8')
# print(s, self._t_dict.get(s, s), unicode_to_latex(s))
# _str.append( self._t_dict.get(s, s) )
out_latex_str = self.u.unicode_to_latex(s)
# print(s, out_latex_str)
if (
s.startswith("{") is False
and out_latex_str.startswith("{")
and s.endswith("}") is False
and out_latex_str.endswith("}")
):
out_latex_str = f" {out_latex_str[1:-1]} "
# print(s, out_latex_str)
if "ensuremath" in out_latex_str:
out_latex_str = out_latex_str.replace("\\ensuremath{", " ")
out_latex_str = out_latex_str.replace("}", " ")
# print(s, out_latex_str)
if out_latex_str.strip().startswith("\\text"):
out_latex_str = f" \\text{{{out_latex_str}}} "
# print(s, out_latex_str)
return out_latex_str
def do_r(self, elm):
"""
Get the text from an 'r' element and try to convert it to LaTeX symbols
@todo text style support , (sty)
@todo \text (latex pure text support)
"""
_str = []
_base_str = []
for s in elm.findtext("./{0}t".format(OMML_NS)):
out_latex_str = self.process_unicode(s)
_str.append(out_latex_str)
_base_str.append(s)
proc_str = escape_latex(BLANK.join(_str))
base_proc_str = BLANK.join(_base_str)
if "{" not in base_proc_str and "\\{" in proc_str:
proc_str = proc_str.replace("\\{", "{")
if "}" not in base_proc_str and "\\}" in proc_str:
proc_str = proc_str.replace("\\}", "}")
return proc_str
tag2meth = {
"acc": do_acc,
"r": do_r,
"bar": do_bar,
"sub": do_sub,
"sup": do_sup,
"f": do_f,
"func": do_func,
"fName": do_fname,
"groupChr": do_groupchr,
"d": do_d,
"rad": do_rad,
"eqArr": do_eqarr,
"limLow": do_limlow,
"limUpp": do_limupp,
"lim": do_lim,
"m": do_m,
"mr": do_mr,
"nary": do_nary,
}
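
The converter above dispatches on the namespace-stripped tag name through a `tag2meth` table and recurses into pass-through tags such as `num`, `den`, and `e`. The sketch below shows that dispatch idea on a single fraction, using only the standard library. `MiniConverter`, its `handlers` table, and `SAMPLE` are hypothetical and far smaller than the real class: there is no property handling, escaping, or Unicode mapping.

```python
import xml.etree.ElementTree as ET

NS = "{http://schemas.openxmlformats.org/officeDocument/2006/math}"


class MiniConverter:
    """Toy OMML-to-LaTeX converter covering only runs and fractions."""

    def convert(self, elm: ET.Element) -> str:
        parts = []
        for child in elm:
            if not child.tag.startswith(NS):
                continue  # ignore elements outside the math namespace
            stag = child.tag[len(NS):]
            method = self.handlers.get(stag)
            if method is not None:
                # Table values are plain functions, so pass self explicitly.
                parts.append(method(self, child))
            elif stag in ("num", "den", "e"):
                # Pass-through containers: just process their children.
                parts.append(self.convert(child))
        return "".join(parts)

    def do_r(self, elm: ET.Element) -> str:
        """A run: return the text of its <m:t> child."""
        return elm.findtext(NS + "t") or ""

    def do_f(self, elm: ET.Element) -> str:
        """A fraction: convert numerator and denominator separately."""
        num = self.convert(elm.find(NS + "num"))
        den = self.convert(elm.find(NS + "den"))
        return "\\frac{{{num}}}{{{den}}}".format(num=num, den=den)

    handlers = {"r": do_r, "f": do_f}


SAMPLE = (
    '<m:oMath xmlns:m="http://schemas.openxmlformats.org/officeDocument/2006/math">'
    "<m:f><m:num><m:r><m:t>a</m:t></m:r></m:num>"
    "<m:den><m:r><m:t>b</m:t></m:r></m:den></m:f>"
    "</m:oMath>"
)

latex = MiniConverter().convert(ET.fromstring(SAMPLE))
```

Keeping the mapping in a table rather than an `if`/`elif` chain means supporting a new OMML tag only requires adding one method and one table entry.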

@@ -1,17 +1,22 @@
import logging
from io import BytesIO
from pathlib import Path
from typing import Set, Union
from typing import Final, Optional, Union, cast
from bs4 import BeautifulSoup
from bs4 import BeautifulSoup, NavigableString, PageElement, Tag
from bs4.element import PreformattedString
from docling_core.types.doc import (
DocItem,
DocItemLabel,
DoclingDocument,
DocumentOrigin,
GroupItem,
GroupLabel,
TableCell,
TableData,
)
from docling_core.types.doc.document import ContentLayer
from typing_extensions import override
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.datamodel.base_models import InputFormat
@@ -19,21 +24,38 @@ from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
# tags that generate NodeItem elements
TAGS_FOR_NODE_ITEMS: Final = [
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
"p",
"pre",
"ul",
"ol",
"li",
"table",
"figure",
"img",
]
class HTMLDocumentBackend(DeclarativeDocumentBackend):
@override
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
_log.debug("About to init HTML backend...")
self.soup = None
self.soup: Optional[Tag] = None
# HTML file:
self.path_or_stream = path_or_stream
# Initialise the parents for the hierarchy
self.max_levels = 10
self.level = 0
self.parents = {} # type: ignore
self.parents: dict[int, Optional[Union[DocItem, GroupItem]]] = {}
for i in range(0, self.max_levels):
self.parents[i] = None
self.labels = {} # type: ignore
try:
if isinstance(self.path_or_stream, BytesIO):
@@ -45,16 +67,20 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
self.soup = BeautifulSoup(html_content, "html.parser")
except Exception as e:
raise RuntimeError(
f"Could not initialize HTML backend for file with hash {self.document_hash}."
"Could not initialize HTML backend for file with "
f"hash {self.document_hash}."
) from e
@override
def is_valid(self) -> bool:
return self.soup is not None
@classmethod
@override
def supports_pagination(cls) -> bool:
return False
@override
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()
@@ -62,9 +88,11 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
self.path_or_stream = None
@classmethod
def supported_formats(cls) -> Set[InputFormat]:
@override
def supported_formats(cls) -> set[InputFormat]:
return {InputFormat.HTML}
@override
def convert(self) -> DoclingDocument:
# access self.path_or_stream to load stuff
origin = DocumentOrigin(
@@ -78,108 +106,118 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
if self.is_valid():
assert self.soup is not None
content = self.soup.body or self.soup
# Replace <br> tags with newline characters
for br in self.soup.body.find_all("br"):
br.replace_with("\n")
doc = self.walk(self.soup.body, doc)
# TODO: remove style to avoid losing text from tags like i, b, span, ...
for br in content("br"):
br.replace_with(NavigableString("\n"))
headers = content.find(["h1", "h2", "h3", "h4", "h5", "h6"])
self.content_layer = (
ContentLayer.BODY if headers is None else ContentLayer.FURNITURE
)
self.walk(content, doc)
else:
raise RuntimeError(
f"Cannot convert doc with {self.document_hash} because the backend failed to init."
f"Cannot convert doc with {self.document_hash} because the backend "
"failed to init."
)
return doc
def walk(self, element, doc):
try:
# Iterate over elements in the body of the document
for idx, element in enumerate(element.children):
def walk(self, tag: Tag, doc: DoclingDocument) -> None:
# Iterate over elements in the body of the document
text: str = ""
for element in tag.children:
if isinstance(element, Tag):
try:
self.analyse_element(element, idx, doc)
self.analyze_tag(cast(Tag, element), doc)
except Exception as exc_child:
_log.error(" -> error treating child: ", exc_child)
_log.error(" => element: ", element, "\n")
_log.error(
f"Error processing child from tag {tag.name}: {repr(exc_child)}"
)
raise exc_child
elif isinstance(element, NavigableString) and not isinstance(
element, PreformattedString
):
# Floating text outside paragraphs or analyzed tags
text += element
siblings: list[Tag] = [
item for item in element.next_siblings if isinstance(item, Tag)
]
if element.next_sibling is None or any(
[item.name in TAGS_FOR_NODE_ITEMS for item in siblings]
):
text = text.strip()
if text and tag.name in ["div"]:
doc.add_text(
parent=self.parents[self.level],
label=DocItemLabel.TEXT,
text=text,
content_layer=self.content_layer,
)
text = ""
except Exception as exc:
pass
return
return doc
def analyse_element(self, element, idx, doc):
"""
if element.name!=None:
_log.debug("\t"*self.level, idx, "\t", f"{element.name} ({self.level})")
"""
if element.name in self.labels:
self.labels[element.name] += 1
def analyze_tag(self, tag: Tag, doc: DoclingDocument) -> None:
if tag.name in ["h1", "h2", "h3", "h4", "h5", "h6"]:
self.handle_header(tag, doc)
elif tag.name in ["p"]:
self.handle_paragraph(tag, doc)
elif tag.name in ["pre"]:
self.handle_code(tag, doc)
elif tag.name in ["ul", "ol"]:
self.handle_list(tag, doc)
elif tag.name in ["li"]:
self.handle_list_item(tag, doc)
elif tag.name == "table":
self.handle_table(tag, doc)
elif tag.name == "figure":
self.handle_figure(tag, doc)
elif tag.name == "img":
self.handle_image(tag, doc)
else:
self.labels[element.name] = 1
self.walk(tag, doc)
if element.name in ["h1", "h2", "h3", "h4", "h5", "h6"]:
self.handle_header(element, idx, doc)
elif element.name in ["p"]:
self.handle_paragraph(element, idx, doc)
elif element.name in ["pre"]:
self.handle_code(element, idx, doc)
elif element.name in ["ul", "ol"]:
self.handle_list(element, idx, doc)
elif element.name in ["li"]:
self.handle_listitem(element, idx, doc)
elif element.name == "table":
self.handle_table(element, idx, doc)
elif element.name == "figure":
self.handle_figure(element, idx, doc)
elif element.name == "img":
self.handle_image(element, idx, doc)
else:
self.walk(element, doc)
def get_text(self, item: PageElement) -> str:
"""Get the text content of a tag."""
parts: list[str] = self.extract_text_recursively(item)
def get_direct_text(self, item):
"""Get the direct text of the <li> element (ignoring nested lists)."""
text = item.find(string=True, recursive=False)
if isinstance(text, str):
return text.strip()
return ""
return "".join(parts) + " "
# Function to recursively extract text from all child nodes
def extract_text_recursively(self, item):
result = []
def extract_text_recursively(self, item: PageElement) -> list[str]:
result: list[str] = []
if isinstance(item, str):
if isinstance(item, NavigableString):
return [item]
if item.name not in ["ul", "ol"]:
try:
# Iterate over the children (and their text and tails)
for child in item:
try:
# Recursively get the child's text content
result.extend(self.extract_text_recursively(child))
except:
pass
except:
_log.warn("item has no children")
pass
tag = cast(Tag, item)
if tag.name not in ["ul", "ol"]:
for child in tag:
# Recursively get the child's text content
result.extend(self.extract_text_recursively(child))
return "".join(result) + " "
return ["".join(result) + " "]
def handle_header(self, element, idx, doc):
def handle_header(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles header tags (h1, h2, etc.)."""
hlevel = int(element.name.replace("h", ""))
slevel = hlevel - 1
label = DocItemLabel.SECTION_HEADER
text = element.text.strip()
self.content_layer = ContentLayer.BODY
if hlevel == 1:
for key, val in self.parents.items():
for key in self.parents.keys():
self.parents[key] = None
self.level = 1
self.parents[self.level] = doc.add_text(
parent=self.parents[0], label=DocItemLabel.TITLE, text=text
parent=self.parents[0],
label=DocItemLabel.TITLE,
text=text,
content_layer=self.content_layer,
)
else:
if hlevel > self.level:
@@ -190,13 +228,14 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
name=f"header-{i}",
label=GroupLabel.SECTION,
parent=self.parents[i - 1],
content_layer=self.content_layer,
)
self.level = hlevel
elif hlevel < self.level:
# remove the tail
for key, val in self.parents.items():
for key in self.parents.keys():
if key > hlevel:
self.parents[key] = None
self.level = hlevel
@@ -204,43 +243,59 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
self.parents[hlevel] = doc.add_heading(
parent=self.parents[hlevel - 1],
text=text,
level=hlevel,
level=hlevel - 1,
content_layer=self.content_layer,
)
def handle_code(self, element, idx, doc):
def handle_code(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles monospace code snippets (pre)."""
if element.text is None:
return
text = element.text.strip()
label = DocItemLabel.CODE
if len(text) == 0:
return
doc.add_text(parent=self.parents[self.level], label=label, text=text)
if text:
doc.add_code(
parent=self.parents[self.level],
text=text,
content_layer=self.content_layer,
)
def handle_paragraph(self, element, idx, doc):
def handle_paragraph(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles paragraph tags (p)."""
if element.text is None:
return
text = element.text.strip()
label = DocItemLabel.PARAGRAPH
if len(text) == 0:
return
doc.add_text(parent=self.parents[self.level], label=label, text=text)
if text:
doc.add_text(
parent=self.parents[self.level],
label=DocItemLabel.TEXT,
text=text,
content_layer=self.content_layer,
)
def handle_list(self, element, idx, doc):
def handle_list(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles list tags (ul, ol) and their list items."""
if element.name == "ul":
# create a list group
self.parents[self.level + 1] = doc.add_group(
parent=self.parents[self.level], name="list", label=GroupLabel.LIST
parent=self.parents[self.level],
name="list",
label=GroupLabel.LIST,
content_layer=self.content_layer,
)
elif element.name == "ol":
start_attr = element.get("start")
start: int = (
int(start_attr)
if isinstance(start_attr, str) and start_attr.isnumeric()
else 1
)
# create a list group
self.parents[self.level + 1] = doc.add_group(
parent=self.parents[self.level],
name="ordered list",
name="ordered list" + (f" start {start}" if start != 1 else ""),
label=GroupLabel.ORDERED_LIST,
content_layer=self.content_layer,
)
self.level += 1
@@ -249,25 +304,36 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
self.parents[self.level + 1] = None
self.level -= 1
def handle_listitem(self, element, idx, doc):
"""Handles listitem tags (li)."""
nested_lists = element.find(["ul", "ol"])
def handle_list_item(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles list item tags (li)."""
nested_list = element.find(["ul", "ol"])
parent_list_label = self.parents[self.level].label
index_in_list = len(self.parents[self.level].children) + 1
parent = self.parents[self.level]
if parent is None:
_log.debug(f"list-item has no parent in DoclingDocument: {element}")
return
parent_label: str = parent.label
index_in_list = len(parent.children) + 1
if (
parent_label == GroupLabel.ORDERED_LIST
and isinstance(parent, GroupItem)
and parent.name
):
start_in_list: str = parent.name.split(" ")[-1]
start: int = int(start_in_list) if start_in_list.isnumeric() else 1
index_in_list += start - 1
if nested_lists:
name = element.name
if nested_list:
# Text in list item can be hidden within hierarchy, hence
# we need to extract it recursively
text = self.extract_text_recursively(element)
text: str = self.get_text(element)
# Flatten text, remove break lines:
text = text.replace("\n", "").replace("\r", "")
text = " ".join(text.split()).strip()
marker = ""
enumerated = False
if parent_list_label == GroupLabel.ORDERED_LIST:
if parent_label == GroupLabel.ORDERED_LIST:
marker = str(index_in_list)
enumerated = True
@@ -277,83 +343,105 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
text=text,
enumerated=enumerated,
marker=marker,
parent=self.parents[self.level],
parent=parent,
content_layer=self.content_layer,
)
self.level += 1
self.walk(element, doc)
self.parents[self.level + 1] = None
self.level -= 1
else:
self.walk(element, doc)
self.walk(element, doc)
self.parents[self.level + 1] = None
self.level -= 1
elif isinstance(element.text, str):
elif element.text.strip():
text = element.text.strip()
marker = ""
enumerated = False
if parent_list_label == GroupLabel.ORDERED_LIST:
if parent_label == GroupLabel.ORDERED_LIST:
marker = f"{str(index_in_list)}."
enumerated = True
doc.add_list_item(
text=text,
enumerated=enumerated,
marker=marker,
parent=self.parents[self.level],
parent=parent,
content_layer=self.content_layer,
)
else:
_log.warn("list-item has no text: ", element)
def handle_table(self, element, idx, doc):
"""Handles table tags."""
_log.debug(f"list-item has no text: {element}")
@staticmethod
def parse_table_data(element: Tag) -> Optional[TableData]:
nested_tables = element.find("table")
if nested_tables is not None:
_log.warn("detected nested tables: skipping for now")
return
_log.debug("Skipping nested table.")
return None
# Count the number of rows (number of <tr> elements)
num_rows = len(element.find_all("tr"))
num_rows = len(element("tr"))
# Find the number of columns (taking into account colspan)
num_cols = 0
for row in element.find_all("tr"):
for row in element("tr"):
col_count = 0
for cell in row.find_all(["td", "th"]):
colspan = int(cell.get("colspan", 1))
if not isinstance(row, Tag):
continue
for cell in row(["td", "th"]):
if not isinstance(cell, Tag):
continue
val = cast(Tag, cell).get("colspan", "1")
colspan = int(val) if (isinstance(val, str) and val.isnumeric()) else 1
col_count += colspan
num_cols = max(num_cols, col_count)
grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])
# Iterate over the rows in the table
for row_idx, row in enumerate(element.find_all("tr")):
for row_idx, row in enumerate(element("tr")):
if not isinstance(row, Tag):
continue
# For each row, find all the column cells (both <td> and <th>)
cells = row.find_all(["td", "th"])
cells = row(["td", "th"])
# Check if each cell in the row is a header -> means it is a column header
col_header = True
for j, html_cell in enumerate(cells):
if html_cell.name == "td":
for html_cell in cells:
if isinstance(html_cell, Tag) and html_cell.name == "td":
col_header = False
# Extract the text content of each cell
col_idx = 0
# Extract and print the text content of each cell
for _, html_cell in enumerate(cells):
for html_cell in cells:
if not isinstance(html_cell, Tag):
continue
# extract inline formulas
for formula in html_cell("inline-formula"):
math_parts = formula.text.split("$$")
if len(math_parts) == 3:
math_formula = f"$${math_parts[1]}$$"
formula.replace_with(NavigableString(math_formula))
# TODO: extract content correctly from table-cells with lists
text = html_cell.text
try:
text = self.extract_table_cell_text(html_cell)
except Exception as exc:
_log.warn("exception: ", exc)
exit(-1)
# label = html_cell.name
col_span = int(html_cell.get("colspan", 1))
row_span = int(html_cell.get("rowspan", 1))
col_val = html_cell.get("colspan", "1")
col_span = (
int(col_val)
if isinstance(col_val, str) and col_val.isnumeric()
else 1
)
row_val = html_cell.get("rowspan", "1")
row_span = (
int(row_val)
if isinstance(row_val, str) and row_val.isnumeric()
else 1
)
while grid[row_idx][col_idx] is not None:
col_idx += 1
@ -361,7 +449,7 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
for c in range(col_span):
grid[row_idx + r][col_idx + c] = text
cell = TableCell(
table_cell = TableCell(
text=text,
row_span=row_span,
col_span=col_span,
@ -369,73 +457,90 @@ class HTMLDocumentBackend(DeclarativeDocumentBackend):
end_row_offset_idx=row_idx + row_span,
start_col_offset_idx=col_idx,
end_col_offset_idx=col_idx + col_span,
col_header=col_header,
column_header=col_header,
row_header=((not col_header) and html_cell.name == "th"),
)
data.table_cells.append(cell)
data.table_cells.append(table_cell)
doc.add_table(data=data, parent=self.parents[self.level])
return data
def get_list_text(self, list_element, level=0):
def handle_table(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles table tags."""
table_data = HTMLDocumentBackend.parse_table_data(element)
if table_data is not None:
doc.add_table(
data=table_data,
parent=self.parents[self.level],
content_layer=self.content_layer,
)
def get_list_text(self, list_element: Tag, level: int = 0) -> list[str]:
"""Recursively extract text from <ul> or <ol> with proper indentation."""
result = []
bullet_char = "*" # Default bullet character for unordered lists
if list_element.name == "ol": # For ordered lists, use numbers
for i, li in enumerate(list_element.find_all("li", recursive=False), 1):
for i, li in enumerate(list_element("li", recursive=False), 1):
if not isinstance(li, Tag):
continue
# Add numbering for ordered lists
result.append(f"{' ' * level}{i}. {li.get_text(strip=True)}")
# Handle nested lists
nested_list = li.find(["ul", "ol"])
if nested_list:
if isinstance(nested_list, Tag):
result.extend(self.get_list_text(nested_list, level + 1))
elif list_element.name == "ul": # For unordered lists, use bullet points
for li in list_element.find_all("li", recursive=False):
for li in list_element("li", recursive=False):
if not isinstance(li, Tag):
continue
# Add bullet points for unordered lists
result.append(
f"{' ' * level}{bullet_char} {li.get_text(strip=True)}"
)
# Handle nested lists
nested_list = li.find(["ul", "ol"])
if nested_list:
if isinstance(nested_list, Tag):
result.extend(self.get_list_text(nested_list, level + 1))
return result
def extract_table_cell_text(self, cell):
"""Extract text from a table cell, including lists with indents."""
contains_lists = cell.find(["ul", "ol"])
if contains_lists is None:
return cell.text
else:
_log.debug(
"should extract the content correctly for table-cells with lists ..."
)
return cell.text
def handle_figure(self, element, idx, doc):
def handle_figure(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles image tags (img)."""
# Extract the image URI from the <img> tag
# image_uri = root.xpath('//figure//img/@src')[0]
contains_captions = element.find(["figcaption"])
if contains_captions is None:
doc.add_picture(parent=self.parents[self.level], caption=None)
if not isinstance(contains_captions, Tag):
doc.add_picture(
parent=self.parents[self.level],
caption=None,
content_layer=self.content_layer,
)
else:
texts = []
for item in contains_captions:
texts.append(item.text)
fig_caption = doc.add_text(
label=DocItemLabel.CAPTION, text=("".join(texts)).strip()
label=DocItemLabel.CAPTION,
text=("".join(texts)).strip(),
content_layer=self.content_layer,
)
doc.add_picture(
parent=self.parents[self.level],
caption=fig_caption,
content_layer=self.content_layer,
)
def handle_image(self, element, idx, doc):
def handle_image(self, element: Tag, doc: DoclingDocument) -> None:
"""Handles image tags (img)."""
doc.add_picture(parent=self.parents[self.level], caption=None)
_log.debug(f"ignoring <img> tags at the moment: {element}")
doc.add_picture(
parent=self.parents[self.level],
caption=None,
content_layer=self.content_layer,
)
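
The guarded `colspan`/`rowspan` parsing in `parse_table_data` can be sketched in isolation. This is a minimal illustration, not the backend's code: `parse_span` and `count_columns` are hypothetical helper names, and cells are represented only by their raw attribute values, which are assumed to be strings, lists, or `None`.

```python
from typing import Any


def parse_span(value: Any, default: int = 1) -> int:
    """Return a span as an int, falling back to `default` for non-numeric input."""
    # Hypothetical helper: mirrors the isinstance/isnumeric guard used in the diff.
    if isinstance(value, str) and value.isnumeric():
        return int(value)
    return default


def count_columns(rows: list[list[Any]]) -> int:
    """Width of the widest row; each cell is given by its raw colspan value."""
    return max((sum(parse_span(v) for v in row) for row in rows), default=0)
```

The point of the guard is that a malformed attribute such as `colspan="abc"` degrades to a span of 1 instead of raising `ValueError` from `int()`.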

View File

@ -0,0 +1,58 @@
from io import BytesIO
from pathlib import Path
from typing import Union
from docling_core.types.doc import DoclingDocument
from typing_extensions import override
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
class DoclingJSONBackend(DeclarativeDocumentBackend):
@override
def __init__(
self, in_doc: InputDocument, path_or_stream: Union[BytesIO, Path]
) -> None:
super().__init__(in_doc, path_or_stream)
# given we need to store any actual conversion exception for raising it from
# convert(), this captures the successful result or the actual error in a
# mutually exclusive way:
self._doc_or_err = self._get_doc_or_err()
@override
def is_valid(self) -> bool:
return isinstance(self._doc_or_err, DoclingDocument)
@classmethod
@override
def supports_pagination(cls) -> bool:
return False
@classmethod
@override
def supported_formats(cls) -> set[InputFormat]:
return {InputFormat.JSON_DOCLING}
def _get_doc_or_err(self) -> Union[DoclingDocument, Exception]:
try:
json_data: Union[str, bytes]
if isinstance(self.path_or_stream, Path):
with open(self.path_or_stream, encoding="utf-8") as f:
json_data = f.read()
elif isinstance(self.path_or_stream, BytesIO):
json_data = self.path_or_stream.getvalue()
else:
raise RuntimeError(f"Unexpected: {type(self.path_or_stream)=}")
return DoclingDocument.model_validate_json(json_data=json_data)
except Exception as e:
return e
@override
def convert(self) -> DoclingDocument:
if isinstance(self._doc_or_err, DoclingDocument):
return self._doc_or_err
else:
raise self._doc_or_err
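
The "store the result or the error" approach used by `DoclingJSONBackend` can be shown with the standard library alone. The sketch below is an assumption-laden stand-in: `LazyJSON` is a hypothetical class, and `json.loads` replaces `DoclingDocument.model_validate_json`.

```python
import json
from typing import Union


class LazyJSON:
    """Hypothetical stand-in for a backend that parses once in __init__."""

    def __init__(self, raw: str) -> None:
        # Holds either the parsed value or the exception, never both.
        self._data_or_err: Union[dict, Exception] = self._load(raw)

    @staticmethod
    def _load(raw: str) -> Union[dict, Exception]:
        try:
            data = json.loads(raw)
            if not isinstance(data, dict):
                raise ValueError("expected a JSON object")
            return data
        except Exception as e:
            # Capture the error so convert() can re-raise it later.
            return e

    def is_valid(self) -> bool:
        return isinstance(self._data_or_err, dict)

    def convert(self) -> dict:
        if isinstance(self._data_or_err, dict):
            return self._data_or_err
        raise self._data_or_err
```

With this shape, construction never raises, `is_valid()` is a cheap type check, and the original exception surfaces only when `convert()` is called.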

View File

@ -3,32 +3,40 @@ import re
import warnings
from io import BytesIO
from pathlib import Path
from typing import Set, Union
from typing import List, Optional, Set, Union
import marko
import marko.element
import marko.ext
import marko.ext.gfm
import marko.inline
from docling_core.types.doc import (
DocItem,
DocItemLabel,
DoclingDocument,
DocumentOrigin,
GroupLabel,
NodeItem,
TableCell,
TableData,
TextItem,
)
from marko import Markdown
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
_MARKER_BODY = "DOCLING_DOC_MD_HTML_EXPORT"
_START_MARKER = f"#_#_{_MARKER_BODY}_START_#_#"
_STOP_MARKER = f"#_#_{_MARKER_BODY}_STOP_#_#"
class MarkdownDocumentBackend(DeclarativeDocumentBackend):
def shorten_underscore_sequences(self, markdown_text, max_length=10):
def _shorten_underscore_sequences(self, markdown_text: str, max_length: int = 10):
# This regex will match any sequence of underscores
pattern = r"_+"
@ -63,7 +71,8 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
self.in_table = False
self.md_table_buffer: list[str] = []
self.inline_text_buffer = ""
self.inline_texts: list[str] = []
self._html_blocks: int = 0
try:
if isinstance(self.path_or_stream, BytesIO):
@ -72,7 +81,7 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
# very long sequences of underscores will lead to unnecessary long processing times.
# In any proper Markdown files, underscores have to be escaped,
# otherwise they represent emphasis (bold or italic)
self.markdown = self.shorten_underscore_sequences(text_stream)
self.markdown = self._shorten_underscore_sequences(text_stream)
if isinstance(self.path_or_stream, Path):
with open(self.path_or_stream, "r", encoding="utf-8") as f:
md_content = f.read()
@ -80,7 +89,7 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
# very long sequences of underscores will lead to unnecessary long processing times.
# In any proper Markdown files, underscores have to be escaped,
# otherwise they represent emphasis (bold or italic)
self.markdown = self.shorten_underscore_sequences(md_content)
self.markdown = self._shorten_underscore_sequences(md_content)
self.valid = True
_log.debug(self.markdown)
@ -90,13 +99,13 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
) from e
return
def close_table(self, doc=None):
def _close_table(self, doc: DoclingDocument):
if self.in_table:
_log.debug("=== TABLE START ===")
for md_table_row in self.md_table_buffer:
_log.debug(md_table_row)
_log.debug("=== TABLE END ===")
tcells = []
tcells: List[TableCell] = []
result_table = []
for n, md_table_row in enumerate(self.md_table_buffer):
data = []
@ -127,7 +136,7 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
end_row_offset_idx=trow_ind + row_span,
start_col_offset_idx=tcol_ind,
end_col_offset_idx=tcol_ind + col_span,
col_header=False,
column_header=trow_ind == 0,
row_header=False,
)
tcells.append(icell)
@ -137,33 +146,47 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
self.in_table = False
self.md_table_buffer = [] # clean table markdown buffer
# Initialize Docling TableData
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=tcells)
table_data = TableData(
num_rows=num_rows, num_cols=num_cols, table_cells=tcells
)
# Populate
for tcell in tcells:
data.table_cells.append(tcell)
table_data.table_cells.append(tcell)
if len(tcells) > 0:
doc.add_table(data=data)
doc.add_table(data=table_data)
return
def process_inline_text(self, parent_element, doc=None):
# self.inline_text_buffer += str(text_in)
txt = self.inline_text_buffer.strip()
def _process_inline_text(
self, parent_item: Optional[NodeItem], doc: DoclingDocument
):
txt = " ".join(self.inline_texts)
if len(txt) > 0:
doc.add_text(
label=DocItemLabel.PARAGRAPH,
parent=parent_element,
parent=parent_item,
text=txt,
)
self.inline_text_buffer = ""
self.inline_texts = []
def _iterate_elements(
self,
element: marko.element.Element,
depth: int,
doc: DoclingDocument,
visited: Set[marko.element.Element],
parent_item: Optional[NodeItem] = None,
):
if element in visited:
return
def iterate_elements(self, element, depth=0, doc=None, parent_element=None):
# Iterates over all elements in the AST
# Check for different element types and process relevant details
if isinstance(element, marko.block.Heading):
self.close_table(doc)
self.process_inline_text(parent_element, doc)
if isinstance(element, marko.block.Heading) and len(element.children) > 0:
self._close_table(doc)
self._process_inline_text(parent_item, doc)
_log.debug(
f" - Heading level {element.level}, content: {element.children[0].children}"
f" - Heading level {element.level}, content: {element.children[0].children}" # type: ignore
)
if element.level == 1:
doc_label = DocItemLabel.TITLE
@ -172,10 +195,10 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
# Header could have arbitrary inclusion of bold, italic or emphasis,
# hence we need to traverse the tree to get full text of a header
strings = []
strings: List[str] = []
# Define a recursive function to traverse the tree
def traverse(node):
def traverse(node: marko.block.BlockElement):
# Check if the node has a "children" attribute
if hasattr(node, "children"):
# If "children" is a list, continue traversal
@ -189,121 +212,147 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
traverse(element)
snippet_text = "".join(strings)
if len(snippet_text) > 0:
parent_element = doc.add_text(
label=doc_label, parent=parent_element, text=snippet_text
)
if doc_label == DocItemLabel.SECTION_HEADER:
parent_item = doc.add_heading(
text=snippet_text,
level=element.level - 1,
parent=parent_item,
)
else:
parent_item = doc.add_text(
label=doc_label, parent=parent_item, text=snippet_text
)
elif isinstance(element, marko.block.List):
self.close_table(doc)
self.process_inline_text(parent_element, doc)
_log.debug(f" - List {'ordered' if element.ordered else 'unordered'}")
list_label = GroupLabel.LIST
if element.ordered:
list_label = GroupLabel.ORDERED_LIST
parent_element = doc.add_group(
label=list_label, name=f"list", parent=parent_element
)
has_non_empty_list_items = False
for child in element.children:
if isinstance(child, marko.block.ListItem) and len(child.children) > 0:
has_non_empty_list_items = True
break
elif isinstance(element, marko.block.ListItem):
self.close_table(doc)
self.process_inline_text(parent_element, doc)
self._close_table(doc)
self._process_inline_text(parent_item, doc)
_log.debug(f" - List {'ordered' if element.ordered else 'unordered'}")
if has_non_empty_list_items:
label = GroupLabel.ORDERED_LIST if element.ordered else GroupLabel.LIST
parent_item = doc.add_group(
label=label, name=f"list", parent=parent_item
)
elif (
isinstance(element, marko.block.ListItem)
and len(element.children) > 0
and isinstance((first_child := element.children[0]), marko.block.Paragraph)
):
self._close_table(doc)
self._process_inline_text(parent_item, doc)
_log.debug(" - List item")
snippet_text = str(element.children[0].children[0].children)
snippet_text = str(first_child.children[0].children) # type: ignore
is_numbered = False
if parent_element.label == GroupLabel.ORDERED_LIST:
if (
parent_item is not None
and isinstance(parent_item, DocItem)
and parent_item.label == GroupLabel.ORDERED_LIST
):
is_numbered = True
doc.add_list_item(
enumerated=is_numbered, parent=parent_element, text=snippet_text
enumerated=is_numbered, parent=parent_item, text=snippet_text
)
visited.add(first_child)
elif isinstance(element, marko.inline.Image):
self.close_table(doc)
self.process_inline_text(parent_element, doc)
self._close_table(doc)
self._process_inline_text(parent_item, doc)
_log.debug(f" - Image with alt: {element.title}, url: {element.dest}")
doc.add_picture(parent=parent_element, caption=element.title)
elif isinstance(element, marko.block.Paragraph):
self.process_inline_text(parent_element, doc)
fig_caption: Optional[TextItem] = None
if element.title is not None and element.title != "":
fig_caption = doc.add_text(
label=DocItemLabel.CAPTION, text=element.title
)
doc.add_picture(parent=parent_item, caption=fig_caption)
elif isinstance(element, marko.block.Paragraph) and len(element.children) > 0:
self._process_inline_text(parent_item, doc)
elif isinstance(element, marko.inline.RawText):
_log.debug(f" - Paragraph (raw text): {element.children}")
snippet_text = str(element.children).strip()
snippet_text = element.children.strip()
# Detect start of the table:
if "|" in snippet_text:
# most likely part of the markdown table
self.in_table = True
if len(self.md_table_buffer) > 0:
self.md_table_buffer[len(self.md_table_buffer) - 1] += str(
snippet_text
)
self.md_table_buffer[len(self.md_table_buffer) - 1] += snippet_text
else:
self.md_table_buffer.append(snippet_text)
else:
self.close_table(doc)
self.in_table = False
self._close_table(doc)
# most likely just inline text
self.inline_text_buffer += str(
element.children
) # do not strip an inline text, as it may contain important spaces
self.inline_texts.append(str(element.children))
elif isinstance(element, marko.inline.CodeSpan):
self.close_table(doc)
self.process_inline_text(parent_element, doc)
self._close_table(doc)
self._process_inline_text(parent_item, doc)
_log.debug(f" - Code Span: {element.children}")
snippet_text = str(element.children).strip()
doc.add_text(
label=DocItemLabel.CODE, parent=parent_element, text=snippet_text
)
doc.add_code(parent=parent_item, text=snippet_text)
elif isinstance(element, marko.block.CodeBlock):
self.close_table(doc)
self.process_inline_text(parent_element, doc)
elif (
isinstance(element, (marko.block.CodeBlock, marko.block.FencedCode))
and len(element.children) > 0
and isinstance((first_child := element.children[0]), marko.inline.RawText)
and len(snippet_text := (first_child.children.strip())) > 0
):
self._close_table(doc)
self._process_inline_text(parent_item, doc)
_log.debug(f" - Code Block: {element.children}")
snippet_text = str(element.children[0].children).strip()
doc.add_text(
label=DocItemLabel.CODE, parent=parent_element, text=snippet_text
)
elif isinstance(element, marko.block.FencedCode):
self.close_table(doc)
self.process_inline_text(parent_element, doc)
_log.debug(f" - Code Block: {element.children}")
snippet_text = str(element.children[0].children).strip()
doc.add_text(
label=DocItemLabel.CODE, parent=parent_element, text=snippet_text
)
doc.add_code(parent=parent_item, text=snippet_text)
elif isinstance(element, marko.inline.LineBreak):
self.process_inline_text(parent_element, doc)
if self.in_table:
_log.debug("Line break in a table")
self.md_table_buffer.append("")
elif isinstance(element, marko.block.HTMLBlock):
self.process_inline_text(parent_element, doc)
self.close_table(doc)
self._html_blocks += 1
self._process_inline_text(parent_item, doc)
self._close_table(doc)
_log.debug("HTML Block: {}".format(element))
if (
len(element.children) > 0
len(element.body) > 0
): # If Marko doesn't return any content for HTML block, skip it
snippet_text = str(element.children).strip()
doc.add_text(
label=DocItemLabel.CODE, parent=parent_element, text=snippet_text
)
html_block = element.body.strip()
# wrap in markers to enable post-processing in convert()
text_to_add = f"{_START_MARKER}{html_block}{_STOP_MARKER}"
doc.add_code(parent=parent_item, text=text_to_add)
else:
if not isinstance(element, str):
self.close_table(doc)
self._close_table(doc)
_log.debug("Some other element: {}".format(element))
processed_block_types = (
marko.block.Heading,
marko.block.CodeBlock,
marko.block.FencedCode,
marko.inline.RawText,
)
# Iterate through the element's children (if any)
if not isinstance(element, marko.block.ListItem):
if not isinstance(element, marko.block.Heading):
if not isinstance(element, marko.block.FencedCode):
# if not isinstance(element, marko.block.Paragraph):
if hasattr(element, "children"):
for child in element.children:
self.iterate_elements(child, depth + 1, doc, parent_element)
if hasattr(element, "children") and not isinstance(
element, processed_block_types
):
for child in element.children:
self._iterate_elements(
element=child,
depth=depth + 1,
doc=doc,
visited=visited,
parent_item=parent_item,
)
def is_valid(self) -> bool:
return self.valid
@ -337,8 +386,51 @@ class MarkdownDocumentBackend(DeclarativeDocumentBackend):
marko_parser = Markdown()
parsed_ast = marko_parser.parse(self.markdown)
# Start iterating from the root of the AST
self.iterate_elements(parsed_ast, 0, doc, None)
self.process_inline_text(None, doc) # handle last hanging inline text
self._iterate_elements(
element=parsed_ast,
depth=0,
doc=doc,
parent_item=None,
visited=set(),
)
self._process_inline_text(None, doc) # handle last hanging inline text
self._close_table(doc=doc) # handle any last hanging table
# if HTML blocks were detected, export to HTML and delegate to HTML backend
if self._html_blocks > 0:
# export to HTML
html_backend_cls = HTMLDocumentBackend
html_str = doc.export_to_html()
def _restore_original_html(txt, regex):
_txt, count = re.subn(regex, "", txt)
if count != self._html_blocks:
raise RuntimeError(
"An internal error has occurred during Markdown conversion."
)
return _txt
# restore original HTML by removing previously added markers
for regex in [
rf"<pre>\s*<code>\s*{_START_MARKER}",
rf"{_STOP_MARKER}\s*</code>\s*</pre>",
]:
html_str = _restore_original_html(txt=html_str, regex=regex)
self._html_blocks = 0
# delegate to HTML backend
stream = BytesIO(bytes(html_str, encoding="utf-8"))
in_doc = InputDocument(
path_or_stream=stream,
format=InputFormat.HTML,
backend=html_backend_cls,
filename=self.file.name,
)
html_backend_obj = html_backend_cls(
in_doc=in_doc, path_or_stream=stream
)
doc = html_backend_obj.convert()
else:
raise RuntimeError(
f"Cannot convert md with {self.document_hash} because the backend failed to init."
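
The marker round trip in `convert()` (wrap each HTML block, export, then strip the markers together with the surrounding `<pre><code>` tags) can be sketched as below. The marker strings and the function names `wrap` and `unwrap_all` are hypothetical. The sketch assumes the exporter emits the wrapped text verbatim inside `<pre><code>`, and it adds `re.escape` around the markers, which the diff does not do.

```python
import re

# Hypothetical marker strings, shaped like the ones in the diff.
START = "#_#_EXAMPLE_START_#_#"
STOP = "#_#_EXAMPLE_STOP_#_#"


def wrap(html_block: str) -> str:
    """Surround a raw HTML block with markers so it can be found after export."""
    return f"{START}{html_block}{STOP}"


def unwrap_all(html: str, expected_blocks: int) -> str:
    """Remove markers and their <pre><code> wrappers, checking the block count."""
    for pattern in (
        rf"<pre>\s*<code>\s*{re.escape(START)}",
        rf"{re.escape(STOP)}\s*</code>\s*</pre>",
    ):
        html, count = re.subn(pattern, "", html)
        if count != expected_blocks:
            # A mismatch means a marker was lost or altered during export.
            raise RuntimeError("marker count mismatch")
    return html
```

Comparing the substitution count against the number of wrapped blocks is what turns a silent mis-restoration into an explicit error.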

View File

@ -26,6 +26,7 @@ _log = logging.getLogger(__name__)
from typing import Any, List
from PIL import Image as PILImage
from pydantic import BaseModel
@ -44,7 +45,6 @@ class ExcelTable(BaseModel):
class MsExcelDocumentBackend(DeclarativeDocumentBackend):
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
@ -164,7 +164,7 @@ class MsExcelDocumentBackend(DeclarativeDocumentBackend):
end_row_offset_idx=excel_cell.row + excel_cell.row_span,
start_col_offset_idx=excel_cell.col,
end_col_offset_idx=excel_cell.col + excel_cell.col_span,
col_header=False,
column_header=excel_cell.row == 0,
row_header=False,
)
table_data.table_cells.append(cell)
@ -173,7 +173,7 @@ class MsExcelDocumentBackend(DeclarativeDocumentBackend):
return doc
def _find_data_tables(self, sheet: Worksheet):
def _find_data_tables(self, sheet: Worksheet) -> List[ExcelTable]:
"""
Find all compact rectangular data tables in a sheet.
"""
@ -326,49 +326,18 @@ class MsExcelDocumentBackend(DeclarativeDocumentBackend):
self, doc: DoclingDocument, sheet: Worksheet
) -> DoclingDocument:
# FIXME: mypy does not agree with _images ...
"""
# Iterate over images in the sheet
for idx, image in enumerate(sheet._images): # Access embedded images
# Iterate over byte images in the sheet
for idx, image in enumerate(sheet._images): # type: ignore
image_bytes = BytesIO(image.ref.blob)
pil_image = Image.open(image_bytes)
try:
pil_image = PILImage.open(image.ref)
doc.add_picture(
parent=self.parents[0],
image=ImageRef.from_pil(image=pil_image, dpi=72),
caption=None,
)
"""
# FIXME: mypy does not agree with _charts ...
"""
for idx, chart in enumerate(sheet._charts): # Access embedded charts
chart_path = f"chart_{idx + 1}.png"
_log.info(
f"Chart found, but dynamic rendering is required for: {chart_path}"
)
_log.info(f"Chart {idx + 1}:")
# Chart type
_log.info(f"Type: {type(chart).__name__}")
# Title
if chart.title:
_log.info(f"Title: {chart.title}")
else:
_log.info("No title")
# Data series
for series in chart.series:
_log.info(" => series ...")
_log.info(f"Data Series: {series.title}")
_log.info(f"Values: {series.values}")
_log.info(f"Categories: {series.categories}")
# Position
# _log.info(f"Anchor Cell: {chart.anchor}")
"""
doc.add_picture(
parent=self.parents[0],
image=ImageRef.from_pil(image=pil_image, dpi=72),
caption=None,
)
except Exception:
_log.error("could not extract the image from excel sheets")
return doc

View File

@ -16,6 +16,7 @@ from docling_core.types.doc import (
TableCell,
TableData,
)
from docling_core.types.doc.document import ContentLayer
from PIL import Image, UnidentifiedImageError
from pptx import Presentation
from pptx.enum.shapes import MSO_SHAPE_TYPE, PP_PLACEHOLDER
@ -98,21 +99,28 @@ class MsPowerpointDocumentBackend(DeclarativeDocumentBackend, PaginatedDocumentB
return doc
def generate_prov(self, shape, slide_ind, text=""):
left = shape.left
top = shape.top
width = shape.width
height = shape.height
def generate_prov(
self, shape, slide_ind, text="", slide_size=Size(width=1, height=1)
):
if shape.left:
left = shape.left
top = shape.top
width = shape.width
height = shape.height
else:
left = 0
top = 0
width = slide_size.width
height = slide_size.height
shape_bbox = [left, top, left + width, top + height]
shape_bbox = BoundingBox.from_tuple(shape_bbox, origin=CoordOrigin.BOTTOMLEFT)
# prov = [{"bbox": shape_bbox, "page": parent_slide, "span": [0, len(text)]}]
prov = ProvenanceItem(
page_no=slide_ind + 1, charspan=[0, len(text)], bbox=shape_bbox
)
return prov
def handle_text_elements(self, shape, parent_slide, slide_ind, doc):
def handle_text_elements(self, shape, parent_slide, slide_ind, doc, slide_size):
is_a_list = False
is_list_group_created = False
enum_list_item_value = 0
@ -121,7 +129,7 @@ class MsPowerpointDocumentBackend(DeclarativeDocumentBackend, PaginatedDocumentB
list_text = ""
list_label = GroupLabel.LIST
doc_label = DocItemLabel.LIST_ITEM
prov = self.generate_prov(shape, slide_ind, shape.text.strip())
prov = self.generate_prov(shape, slide_ind, shape.text.strip(), slide_size)
# Identify if shape contains lists
for paragraph in shape.text_frame.paragraphs:
@ -270,18 +278,17 @@ class MsPowerpointDocumentBackend(DeclarativeDocumentBackend, PaginatedDocumentB
)
return
def handle_pictures(self, shape, parent_slide, slide_ind, doc):
# Get the image bytes
image = shape.image
image_bytes = image.blob
im_dpi, _ = image.dpi
def handle_pictures(self, shape, parent_slide, slide_ind, doc, slide_size):
# Open it with PIL
try:
# Get the image bytes
image = shape.image
image_bytes = image.blob
im_dpi, _ = image.dpi
pil_image = Image.open(BytesIO(image_bytes))
# shape has picture
prov = self.generate_prov(shape, slide_ind, "")
prov = self.generate_prov(shape, slide_ind, "", slide_size)
doc.add_picture(
parent=parent_slide,
image=ImageRef.from_pil(image=pil_image, dpi=im_dpi),
@ -292,13 +299,13 @@ class MsPowerpointDocumentBackend(DeclarativeDocumentBackend, PaginatedDocumentB
_log.warning(f"Warning: image cannot be loaded by Pillow: {e}")
return
def handle_tables(self, shape, parent_slide, slide_ind, doc):
def handle_tables(self, shape, parent_slide, slide_ind, doc, slide_size):
# Handling tables, images, charts
if shape.has_table:
table = shape.table
table_xml = shape._element
prov = self.generate_prov(shape, slide_ind, "")
prov = self.generate_prov(shape, slide_ind, "", slide_size)
num_cols = 0
num_rows = len(table.rows)
@ -340,7 +347,7 @@ class MsPowerpointDocumentBackend(DeclarativeDocumentBackend, PaginatedDocumentB
end_row_offset_idx=row_idx + row_span,
start_col_offset_idx=col_idx,
end_col_offset_idx=col_idx + col_span,
col_header=False,
column_header=row_idx == 0,
row_header=False,
)
if len(cell.text.strip()) > 0:
@ -375,17 +382,19 @@ class MsPowerpointDocumentBackend(DeclarativeDocumentBackend, PaginatedDocumentB
name=f"slide-{slide_ind}", label=GroupLabel.CHAPTER, parent=parents[0]
)
size = Size(width=slide_width, height=slide_height)
parent_page = doc.add_page(page_no=slide_ind + 1, size=size)
slide_size = Size(width=slide_width, height=slide_height)
parent_page = doc.add_page(page_no=slide_ind + 1, size=slide_size)
def handle_shapes(shape, parent_slide, slide_ind, doc):
handle_groups(shape, parent_slide, slide_ind, doc)
def handle_shapes(shape, parent_slide, slide_ind, doc, slide_size):
handle_groups(shape, parent_slide, slide_ind, doc, slide_size)
if shape.has_table:
# Handle Tables
self.handle_tables(shape, parent_slide, slide_ind, doc)
self.handle_tables(shape, parent_slide, slide_ind, doc, slide_size)
if shape.shape_type == MSO_SHAPE_TYPE.PICTURE:
# Handle Pictures
self.handle_pictures(shape, parent_slide, slide_ind, doc)
self.handle_pictures(
shape, parent_slide, slide_ind, doc, slide_size
)
# If shape doesn't have any text, move on to the next shape
if not hasattr(shape, "text"):
return
@ -397,16 +406,37 @@ class MsPowerpointDocumentBackend(DeclarativeDocumentBackend, PaginatedDocumentB
_log.warning("Warning: shape has text but not text_frame")
return
# Handle other text elements, including lists (bullet lists, numbered lists)
self.handle_text_elements(shape, parent_slide, slide_ind, doc)
self.handle_text_elements(
shape, parent_slide, slide_ind, doc, slide_size
)
return
def handle_groups(shape, parent_slide, slide_ind, doc):
def handle_groups(shape, parent_slide, slide_ind, doc, slide_size):
if shape.shape_type == MSO_SHAPE_TYPE.GROUP:
for groupedshape in shape.shapes:
handle_shapes(groupedshape, parent_slide, slide_ind, doc)
handle_shapes(
groupedshape, parent_slide, slide_ind, doc, slide_size
)
# Loop through each shape in the slide
for shape in slide.shapes:
handle_shapes(shape, parent_slide, slide_ind, doc)
handle_shapes(shape, parent_slide, slide_ind, doc, slide_size)
# Handle notes slide
if slide.has_notes_slide:
notes_slide = slide.notes_slide
notes_text = notes_slide.notes_text_frame.text.strip()
if notes_text:
bbox = BoundingBox(l=0, t=0, r=0, b=0)
prov = ProvenanceItem(
page_no=slide_ind + 1, charspan=[0, len(notes_text)], bbox=bbox
)
doc.add_text(
label=DocItemLabel.TEXT,
parent=parent_slide,
text=notes_text,
prov=prov,
content_layer=ContentLayer.FURNITURE,
)
return doc
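
The fallback added to `generate_prov` (use the whole slide when a shape reports no position) can be sketched without python-pptx. `bbox_or_full_slide` is a hypothetical function operating on plain numbers. Unlike the diff, it also defaults a missing `top`, `width`, or `height` to 0, which is an assumption of this sketch.

```python
from typing import Optional, Tuple


def bbox_or_full_slide(
    left: Optional[int],
    top: Optional[int],
    width: Optional[int],
    height: Optional[int],
    slide_width: int,
    slide_height: int,
) -> Tuple[int, int, int, int]:
    """Return (l, t, r, b), falling back to the full slide if `left` is falsy."""
    # Same truthiness test as the diff: a None or zero `left` triggers the fallback.
    if left:
        l, t, w, h = left, top or 0, width or 0, height or 0
    else:
        l, t, w, h = 0, 0, slide_width, slide_height
    return (l, t, l + w, t + h)
```

Note that the truthiness check treats `left == 0` the same as a missing value, so a shape anchored at the left edge also receives the full-slide box.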

View File

@ -2,23 +2,31 @@ import logging
import re
from io import BytesIO
from pathlib import Path
from typing import Set, Union
from typing import Any, Optional, Union
import docx
from docling_core.types.doc import (
DocItemLabel,
DoclingDocument,
DocumentOrigin,
GroupLabel,
ImageRef,
NodeItem,
TableCell,
TableData,
)
from docx import Document
from docx.document import Document as DocxDocument
from docx.oxml.table import CT_Tc
from docx.oxml.xmlchemy import BaseOxmlElement
from docx.table import Table, _Cell
from docx.text.paragraph import Paragraph
from lxml import etree
from lxml.etree import XPath
from PIL import Image, UnidentifiedImageError
from typing_extensions import override
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.backend.docx.latex.omml import oMath2Latex
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
@ -26,8 +34,10 @@ _log = logging.getLogger(__name__)
class MsWordDocumentBackend(DeclarativeDocumentBackend):
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
@override
def __init__(
self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]
) -> None:
super().__init__(in_doc, path_or_stream)
self.XML_KEY = (
"{http://schemas.openxmlformats.org/wordprocessingml/2006/main}val"
@ -37,19 +47,20 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
}
# self.initialise(path_or_stream)
# Word file:
self.path_or_stream = path_or_stream
self.valid = False
self.path_or_stream: Union[BytesIO, Path] = path_or_stream
self.valid: bool = False
# Initialise the parents for the hierarchy
self.max_levels = 10
self.level_at_new_list = None
self.parents = {} # type: ignore
self.max_levels: int = 10
self.level_at_new_list: Optional[int] = None
self.parents: dict[int, Optional[NodeItem]] = {}
self.numbered_headers: dict[int, int] = {}
for i in range(-1, self.max_levels):
self.parents[i] = None
self.level = 0
self.listIter = 0
self.history = {
self.history: dict[str, Any] = {
"names": [None],
"levels": [None],
"numids": [None],
@ -59,9 +70,9 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
self.docx_obj = None
try:
if isinstance(self.path_or_stream, BytesIO):
self.docx_obj = docx.Document(self.path_or_stream)
self.docx_obj = Document(self.path_or_stream)
elif isinstance(self.path_or_stream, Path):
self.docx_obj = docx.Document(str(self.path_or_stream))
self.docx_obj = Document(str(self.path_or_stream))
self.valid = True
except Exception as e:
@ -69,13 +80,16 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
f"MsWordDocumentBackend could not load document with hash {self.document_hash}"
) from e
@override
def is_valid(self) -> bool:
return self.valid
@classmethod
@override
def supports_pagination(cls) -> bool:
return False
@override
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()
@ -83,11 +97,17 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
self.path_or_stream = None
@classmethod
def supported_formats(cls) -> Set[InputFormat]:
@override
def supported_formats(cls) -> set[InputFormat]:
return {InputFormat.DOCX}
@override
def convert(self) -> DoclingDocument:
# Parses the DOCX into a structured document model.
"""Parses the DOCX into a structured document model.
Returns:
The parsed document.
"""
origin = DocumentOrigin(
filename=self.file.name or "file",
@ -105,23 +125,29 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
f"Cannot convert doc with {self.document_hash} because the backend failed to init."
)
def update_history(self, name, level, numid, ilevel):
def update_history(
self,
name: str,
level: Optional[int],
numid: Optional[int],
ilevel: Optional[int],
):
self.history["names"].append(name)
self.history["levels"].append(level)
self.history["numids"].append(numid)
self.history["indents"].append(ilevel)
def prev_name(self):
def prev_name(self) -> Optional[str]:
return self.history["names"][-1]
def prev_level(self):
def prev_level(self) -> Optional[int]:
return self.history["levels"][-1]
def prev_numid(self):
def prev_numid(self) -> Optional[int]:
return self.history["numids"][-1]
def prev_indent(self):
def prev_indent(self) -> Optional[int]:
return self.history["indents"][-1]
def get_level(self) -> int:
@ -131,13 +157,19 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
return k
return 0
def walk_linear(self, body, docx_obj, doc) -> DoclingDocument:
def walk_linear(
self,
body: BaseOxmlElement,
docx_obj: DocxDocument,
doc: DoclingDocument,
) -> DoclingDocument:
for element in body:
tag_name = etree.QName(element).localname
# Check for Inline Images (blip elements)
namespaces = {
"a": "http://schemas.openxmlformats.org/drawingml/2006/main",
"r": "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main",
}
xpath_expr = XPath(".//a:blip", namespaces=namespaces)
drawing_blip = xpath_expr(element)
@ -150,7 +182,15 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
_log.debug("could not parse a table, broken docx table")
elif drawing_blip:
self.handle_pictures(element, docx_obj, drawing_blip, doc)
self.handle_pictures(docx_obj, drawing_blip, doc)
# Check for the sdt containers, like table of contents
elif tag_name in ["sdt"]:
sdt_content = element.find(".//w:sdtContent", namespaces=namespaces)
if sdt_content is not None:
# Iterate paragraphs, runs, or text inside <w:sdtContent>.
paragraphs = sdt_content.findall(".//w:p", namespaces=namespaces)
for p in paragraphs:
self.handle_text_elements(p, docx_obj, doc)
# Check for Text
elif tag_name in ["p"]:
# "tcPr", "sectPr"
@ -159,7 +199,7 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
_log.debug(f"Ignoring element in DOCX with tag: {tag_name}")
return doc
def str_to_int(self, s, default=0):
def str_to_int(self, s: Optional[str], default: Optional[int] = 0) -> Optional[int]:
if s is None:
return None
try:
@ -167,7 +207,7 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
except ValueError:
return default
def split_text_and_number(self, input_string):
def split_text_and_number(self, input_string: str) -> list[str]:
match = re.match(r"(\D+)(\d+)$|^(\d+)(\D+)", input_string)
if match:
parts = list(filter(None, match.groups()))
@ -175,7 +215,9 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
else:
return [input_string]
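
Review note: a minimal standalone sketch of the style-id splitting used above, kept outside the class so it can be run on its own. It reuses the regular expression from the diff; the line elided between the hunks is assumed to return `parts`, and the free-function form is only for illustration.

```python
import re


def split_text_and_number(label: str) -> list[str]:
    """Split a style id such as 'Heading2' into its text and number parts."""
    # Same pattern as in the diff: text followed by digits, or digits followed by text.
    match = re.match(r"(\D+)(\d+)$|^(\d+)(\D+)", label)
    if match:
        # Only one alternative can match, so drop the unmatched (None) groups.
        return list(filter(None, match.groups()))
    # No trailing or leading number: return the label unchanged.
    return [label]


print(split_text_and_number("Heading2"))
print(split_text_and_number("1Heading"))
print(split_text_and_number("Normal"))
```

The caller in `get_label_and_level` then sorts the two parts so that the digit string and the word `Heading` can be told apart regardless of their order in the style id.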
def get_numId_and_ilvl(self, paragraph):
def get_numId_and_ilvl(
self, paragraph: Paragraph
) -> tuple[Optional[int], Optional[int]]:
# Access the XML element of the paragraph
numPr = paragraph._element.find(
".//w:numPr", namespaces=paragraph._element.nsmap
@ -188,13 +230,11 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
numId = numId_elem.get(self.XML_KEY) if numId_elem is not None else None
ilvl = ilvl_elem.get(self.XML_KEY) if ilvl_elem is not None else None
return self.str_to_int(numId, default=None), self.str_to_int(
ilvl, default=None
)
return self.str_to_int(numId, None), self.str_to_int(ilvl, None)
return None, None # If the paragraph is not part of a list
def get_label_and_level(self, paragraph):
def get_label_and_level(self, paragraph: Paragraph) -> tuple[str, Optional[int]]:
if paragraph.style is None:
return "Normal", None
label = paragraph.style.style_id
@ -204,20 +244,20 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
parts = label.split(":")
if len(parts) == 2:
return parts[0], int(parts[1])
return parts[0], self.str_to_int(parts[1], None)
parts = self.split_text_and_number(label)
if "Heading" in label and len(parts) == 2:
parts.sort()
label_str = ""
label_level = 0
label_str: str = ""
label_level: Optional[int] = 0
if parts[0] == "Heading":
label_str = parts[0]
label_level = self.str_to_int(parts[1], default=None)
label_level = self.str_to_int(parts[1], None)
if parts[1] == "Heading":
label_str = parts[1]
label_level = self.str_to_int(parts[0], default=None)
label_level = self.str_to_int(parts[0], None)
return label_str, label_level
else:
return label, None
@ -280,10 +320,39 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
return paragraph_text.strip()
def handle_text_elements(self, element, docx_obj, doc):
paragraph = docx.text.paragraph.Paragraph(element, docx_obj)
def handle_equations_in_text(self, element, text):
only_texts = []
only_equations = []
texts_and_equations = []
for subt in element.iter():
tag_name = etree.QName(subt).localname
if tag_name == "t" and "math" not in subt.tag:
only_texts.append(subt.text)
texts_and_equations.append(subt.text)
elif "oMath" in subt.tag and "oMathPara" not in subt.tag:
latex_equation = str(oMath2Latex(subt))
only_equations.append(latex_equation)
texts_and_equations.append(latex_equation)
if paragraph.text is None:
if "".join(only_texts).strip() != text.strip():
# If we are not able to reconstruct the initial raw text
# do not try to parse equations and return the original
return text, []
return "".join(texts_and_equations), only_equations
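
Review note: the inline-equation branch further down splits the reconstructed paragraph text around each LaTeX string. The sketch below shows that idea in isolation, under two assumptions: the function name and the `("text" | "formula", value)` tuples are hypothetical, and `str.partition` is swapped in for the `split(eq, maxsplit=1)` calls in the diff so that an equation missing from the text is skipped instead of raising `IndexError`.

```python
def split_inline_equations(text: str, equations: list[str]) -> list[tuple[str, str]]:
    """Return ("text", ...) and ("formula", ...) pieces in reading order."""
    pieces: list[tuple[str, str]] = []
    remaining = text
    for eq in equations:
        if not remaining:
            break
        before, found, after = remaining.partition(eq)
        if not found:
            # Equation string not present in the remaining text: skip it.
            continue
        if before:
            pieces.append(("text", before))
        pieces.append(("formula", eq))
        remaining = after
    if remaining:
        # Trailing text after the last equation.
        pieces.append(("text", remaining))
    return pieces


for kind, value in split_inline_equations("Let a^2 hold and b too", ["a^2", "b"]):
    print(kind, repr(value))
```

In the backend each `"text"` piece would become a `PARAGRAPH` item and each `"formula"` piece a `FORMULA` item under one inline group.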
def handle_text_elements(
self,
element: BaseOxmlElement,
docx_obj: DocxDocument,
doc: DoclingDocument,
) -> None:
paragraph = Paragraph(element, docx_obj)
raw_text = paragraph.text
text, equations = self.handle_equations_in_text(element=element, text=raw_text)
if text is None:
return
text = self.format_paragraph(paragraph)
@ -299,13 +368,13 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
numid = None
# Handle lists
if numid is not None and ilevel is not None:
if (
numid is not None
and ilevel is not None
and p_style_id not in ["Title", "Heading"]
):
self.add_listitem(
element,
docx_obj,
doc,
p_style_id,
p_level,
numid,
ilevel,
text,
@ -313,20 +382,77 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
)
self.update_history(p_style_id, p_level, numid, ilevel)
return
elif numid is None and self.prev_numid() is not None: # Close list
for key, val in self.parents.items():
if key >= self.level_at_new_list:
elif (
numid is None
and self.prev_numid() is not None
and p_style_id not in ["Title", "Heading"]
): # Close list
if self.level_at_new_list:
for key in range(len(self.parents)):
if key >= self.level_at_new_list:
self.parents[key] = None
self.level = self.level_at_new_list - 1
self.level_at_new_list = None
else:
for key in range(len(self.parents)):
self.parents[key] = None
self.level = self.level_at_new_list - 1
self.level_at_new_list = None
self.level = 0
if p_style_id in ["Title"]:
for key, val in self.parents.items():
for key in range(len(self.parents)):
self.parents[key] = None
self.parents[0] = doc.add_text(
parent=None, label=DocItemLabel.TITLE, text=text
)
elif "Heading" in p_style_id:
self.add_header(element, docx_obj, doc, p_style_id, p_level, text)
style_element = getattr(paragraph.style, "element", None)
if style_element:
is_numbered_style = (
"<w:numPr>" in style_element.xml or "<w:numPr>" in element.xml
)
else:
is_numbered_style = False
self.add_header(doc, p_level, text, is_numbered_style)
elif len(equations) > 0:
if (raw_text is None or len(raw_text) == 0) and len(text) > 0:
# Standalone equation
level = self.get_level()
doc.add_text(
label=DocItemLabel.FORMULA,
parent=self.parents[level - 1],
text=text,
)
else:
# Inline equation
level = self.get_level()
inline_equation = doc.add_group(
label=GroupLabel.INLINE, parent=self.parents[level - 1]
)
text_tmp = text
for eq in equations:
if len(text_tmp) == 0:
break
pre_eq_text = text_tmp.split(eq, maxsplit=1)[0]
text_tmp = text_tmp.split(eq, maxsplit=1)[1]
if len(pre_eq_text) > 0:
doc.add_text(
label=DocItemLabel.PARAGRAPH,
parent=inline_equation,
text=pre_eq_text,
)
doc.add_text(
label=DocItemLabel.FORMULA,
parent=inline_equation,
text=eq,
)
if len(text_tmp) > 0:
doc.add_text(
label=DocItemLabel.PARAGRAPH,
parent=inline_equation,
text=text_tmp,
)
elif p_style_id in [
"Paragraph",
@ -354,7 +480,13 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
self.update_history(p_style_id, p_level, numid, ilevel)
return
def add_header(self, element, docx_obj, doc, curr_name, curr_level, text: str):
def add_header(
self,
doc: DoclingDocument,
curr_level: Optional[int],
text: str,
is_numbered_style: bool = False,
) -> None:
level = self.get_level()
if isinstance(curr_level, int):
if curr_level > level:
@ -367,41 +499,64 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
)
elif curr_level < level:
# remove the tail
for key, val in self.parents.items():
for key in range(len(self.parents)):
if key >= curr_level:
self.parents[key] = None
self.parents[curr_level] = doc.add_heading(
parent=self.parents[curr_level - 1],
text=text,
level=curr_level,
)
current_level = curr_level
parent_level = curr_level - 1
add_level = curr_level
else:
self.parents[self.level] = doc.add_heading(
parent=self.parents[self.level - 1],
text=text,
level=1,
)
current_level = self.level
parent_level = self.level - 1
add_level = 1
if is_numbered_style:
if add_level in self.numbered_headers:
self.numbered_headers[add_level] += 1
else:
self.numbered_headers[add_level] = 1
text = f"{self.numbered_headers[add_level]} {text}"
# Reset deeper levels
next_level = add_level + 1
while next_level in self.numbered_headers:
self.numbered_headers[next_level] = 0
next_level += 1
# Scan upper levels
previous_level = add_level - 1
while previous_level in self.numbered_headers:
# MSWord convention: no empty sublevels
# I.e., sub-sub section (2.0.1) without a sub-section (2.1)
# is processed as 2.1.1
if self.numbered_headers[previous_level] == 0:
self.numbered_headers[previous_level] += 1
text = f"{self.numbered_headers[previous_level]}.{text}"
previous_level -= 1
self.parents[current_level] = doc.add_heading(
parent=self.parents[parent_level],
text=text,
level=add_level,
)
return
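
Review note: the numbering logic added to `add_header` can be exercised without a document. The sketch below copies the counter handling into a hypothetical free function `number_heading`, with a plain dict standing in for `self.numbered_headers`; it is meant as a way to check the "no empty sublevels" behaviour, not as the backend's API.

```python
def number_heading(counters: dict[int, int], level: int, text: str) -> str:
    """Prefix `text` with a hierarchical number such as '2.1.1'."""
    # Increment the counter of the current level.
    counters[level] = counters.get(level, 0) + 1
    text = f"{counters[level]} {text}"

    # Reset deeper levels that were used before.
    next_level = level + 1
    while next_level in counters:
        counters[next_level] = 0
        next_level += 1

    # Prepend the counters of the upper levels.
    previous_level = level - 1
    while previous_level in counters:
        # No empty sublevels: a skipped level counts as 1, not 0.
        if counters[previous_level] == 0:
            counters[previous_level] += 1
        text = f"{counters[previous_level]}.{text}"
        previous_level -= 1
    return text


counters: dict[int, int] = {}
for level, title in [(1, "Intro"), (2, "Scope"), (2, "Goals"), (1, "Methods"), (3, "Detail")]:
    print(number_heading(counters, level, title))
```

The last call shows the case described in the comments: a level-3 heading directly under a new level-1 heading is numbered as if the missing level-2 heading existed.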
def add_listitem(
self,
element,
docx_obj,
doc,
p_style_id,
p_level,
numid,
ilevel,
doc: DoclingDocument,
numid: int,
ilevel: int,
text: str,
is_numbered=False,
):
# is_numbered = is_numbered
is_numbered: bool = False,
) -> None:
enum_marker = ""
level = self.get_level()
prev_indent = self.prev_indent()
if self.prev_numid() is None: # Open new list
self.level_at_new_list = level # type: ignore
self.level_at_new_list = level
self.parents[level] = doc.add_group(
label=GroupLabel.LIST, name="list", parent=self.parents[level - 1]
@ -420,10 +575,13 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
)
elif (
self.prev_numid() == numid and self.prev_indent() < ilevel
self.prev_numid() == numid
and self.level_at_new_list is not None
and prev_indent is not None
and prev_indent < ilevel
): # Open indented list
for i in range(
self.level_at_new_list + self.prev_indent() + 1,
self.level_at_new_list + prev_indent + 1,
self.level_at_new_list + ilevel + 1,
):
# Determine if this is an unordered list or an ordered list.
@ -452,7 +610,12 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
text=text,
)
elif self.prev_numid() == numid and ilevel < self.prev_indent(): # Close list
elif (
self.prev_numid() == numid
and self.level_at_new_list is not None
and prev_indent is not None
and ilevel < prev_indent
): # Close list
for k, v in self.parents.items():
if k > self.level_at_new_list + ilevel:
self.parents[k] = None
@ -470,7 +633,7 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
)
self.listIter = 0
elif self.prev_numid() == numid or self.prev_indent() == ilevel:
elif self.prev_numid() == numid or prev_indent == ilevel:
# TODO: Set marker and enumerated arguments if this is an enumeration element.
self.listIter += 1
if is_numbered:
@ -484,31 +647,16 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
)
return
def handle_tables(self, element, docx_obj, doc):
# Function to check if a cell has a colspan (gridSpan)
def get_colspan(cell):
grid_span = cell._element.xpath("@w:gridSpan")
if grid_span:
return int(grid_span[0]) # Return the number of columns spanned
return 1 # Default is 1 (no colspan)
# Function to check if a cell has a rowspan (vMerge)
def get_rowspan(cell):
v_merge = cell._element.xpath("@w:vMerge")
if v_merge:
return v_merge[
0
] # 'restart' indicates the beginning of a rowspan, others are continuation
return 1
table = docx.table.Table(element, docx_obj)
def handle_tables(
self,
element: BaseOxmlElement,
docx_obj: DocxDocument,
doc: DoclingDocument,
) -> None:
table: Table = Table(element, docx_obj)
num_rows = len(table.rows)
num_cols = 0
for row in table.rows:
# Calculate the max number of columns
num_cols = max(num_cols, sum(get_colspan(cell) for cell in row.cells))
num_cols = len(table.columns)
_log.debug(f"Table grid with {num_rows} rows and {num_cols} columns")
if num_rows == 1 and num_cols == 1:
cell_element = table.rows[0].cells[0]
@ -517,59 +665,56 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
self.walk_linear(cell_element._element, docx_obj, doc)
return
# Initialize the table grid
table_grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])
data = TableData(num_rows=num_rows, num_cols=num_cols)
cell_set: set[CT_Tc] = set()
for row_idx, row in enumerate(table.rows):
_log.debug(f"Row index {row_idx} with {len(row.cells)} populated cells")
col_idx = 0
for c, cell in enumerate(row.cells):
row_span = get_rowspan(cell)
col_span = get_colspan(cell)
while col_idx < num_cols:
cell: _Cell = row.cells[col_idx]
_log.debug(
f" col {col_idx} grid_span {cell.grid_span} grid_cols_before {row.grid_cols_before}"
)
if cell is None or cell._tc in cell_set:
_log.debug(" skipped since repeated content")
col_idx += cell.grid_span
continue
else:
cell_set.add(cell._tc)
cell_text = cell.text
# In case cell doesn't return text via docx library:
if len(cell_text) == 0:
cell_xml = cell._element
spanned_idx = row_idx
spanned_tc: Optional[CT_Tc] = cell._tc
while spanned_tc == cell._tc:
spanned_idx += 1
spanned_tc = (
table.rows[spanned_idx].cells[col_idx]._tc
if spanned_idx < num_rows
else None
)
_log.debug(f" spanned before row {spanned_idx}")
texts = [""]
for elem in cell_xml.iter():
if elem.tag.endswith("t"): # <w:t> tags that contain text
if elem.text:
texts.append(elem.text)
# Join the collected text
cell_text = " ".join(texts).strip()
# Find the next available column in the grid
while table_grid[row_idx][col_idx] is not None:
col_idx += 1
# Fill the grid with the cell value, considering rowspan and colspan
for i in range(row_span if row_span == "restart" else 1):
for j in range(col_span):
table_grid[row_idx + i][col_idx + j] = ""
cell = TableCell(
text=cell_text,
row_span=row_span,
col_span=col_span,
start_row_offset_idx=row_idx,
end_row_offset_idx=row_idx + row_span,
table_cell = TableCell(
text=cell.text,
row_span=spanned_idx - row_idx,
col_span=cell.grid_span,
start_row_offset_idx=row.grid_cols_before + row_idx,
end_row_offset_idx=row.grid_cols_before + spanned_idx,
start_col_offset_idx=col_idx,
end_col_offset_idx=col_idx + col_span,
col_header=False,
end_col_offset_idx=col_idx + cell.grid_span,
column_header=row.grid_cols_before + row_idx == 0,
row_header=False,
)
data.table_cells.append(cell)
data.table_cells.append(table_cell)
col_idx += cell.grid_span
level = self.get_level()
doc.add_table(data=data, parent=self.parents[level - 1])
return
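
Review note: the rewritten `handle_tables` relies on merged cells showing up as the same underlying cell at every grid position they cover, and uses a set to emit each one only once. The sketch below simulates that with a grid of string ids instead of python-docx cells. The function name and dict layout are hypothetical, and the column span is derived from repeated ids here, whereas the diff reads it from `cell.grid_span`.

```python
def collect_cells(grid: list[list[str]]) -> list[dict]:
    """Emit each logical cell once, with its row and column span."""
    num_rows = len(grid)
    num_cols = len(grid[0]) if grid else 0
    seen: set[str] = set()
    cells: list[dict] = []
    for row_idx, row in enumerate(grid):
        col_idx = 0
        while col_idx < num_cols:
            cell_id = row[col_idx]
            # Column span: consecutive positions in this row with the same id.
            col_span = 1
            while col_idx + col_span < num_cols and row[col_idx + col_span] == cell_id:
                col_span += 1
            if cell_id in seen:
                # Continuation of a vertically merged cell: skip it.
                col_idx += col_span
                continue
            seen.add(cell_id)
            # Row span: walk down while the same id occupies this column.
            end_row = row_idx + 1
            while end_row < num_rows and grid[end_row][col_idx] == cell_id:
                end_row += 1
            cells.append(
                {
                    "id": cell_id,
                    "row": row_idx,
                    "col": col_idx,
                    "row_span": end_row - row_idx,
                    "col_span": col_span,
                }
            )
            col_idx += col_span
    return cells


# "A" spans two columns, "B" spans two rows.
for cell in collect_cells([["A", "A", "B"], ["C", "D", "B"]]):
    print(cell)
```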
def handle_pictures(self, element, docx_obj, drawing_blip, doc):
def get_docx_image(element, drawing_blip):
def handle_pictures(
self, docx_obj: DocxDocument, drawing_blip: Any, doc: DoclingDocument
) -> None:
def get_docx_image(drawing_blip):
rId = drawing_blip[0].get(
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
)
@ -579,11 +724,11 @@ class MsWordDocumentBackend(DeclarativeDocumentBackend):
image_data = image_part.blob # Get the binary image data
return image_data
image_data = get_docx_image(element, drawing_blip)
image_bytes = BytesIO(image_data)
level = self.get_level()
# Open the BytesIO object with PIL to create an Image
try:
image_data = get_docx_image(drawing_blip)
image_bytes = BytesIO(image_data)
pil_image = Image.open(image_bytes)
doc.add_picture(
parent=self.parents[level - 1],


@ -4,21 +4,25 @@ from pathlib import Path
from typing import Iterable, Optional, Set, Union
from docling_core.types.doc import BoundingBox, Size
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
from PIL import Image
from docling.backend.abstract_backend import PaginatedDocumentBackend
from docling.datamodel.base_models import Cell, InputFormat
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
class PdfPageBackend(ABC):
@abstractmethod
def get_text_in_rect(self, bbox: BoundingBox) -> str:
pass
@abstractmethod
def get_text_cells(self) -> Iterable[Cell]:
def get_segmented_page(self) -> Optional[SegmentedPdfPage]:
pass
@abstractmethod
def get_text_cells(self) -> Iterable[TextCell]:
pass
@abstractmethod
@ -45,7 +49,6 @@ class PdfPageBackend(ABC):
class PdfDocumentBackend(PaginatedDocumentBackend):
def __init__(self, in_doc: InputDocument, path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)


@ -7,12 +7,13 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Union
import pypdfium2 as pdfium
import pypdfium2.raw as pdfium_c
from docling_core.types.doc import BoundingBox, CoordOrigin, Size
from docling_core.types.doc.page import BoundingRectangle, SegmentedPdfPage, TextCell
from PIL import Image, ImageDraw
from pypdfium2 import PdfTextPage
from pypdfium2._helpers.misc import PdfiumError
from docling.backend.pdf_backend import PdfDocumentBackend, PdfPageBackend
from docling.datamodel.base_models import Cell
from docling.utils.locks import pypdfium2_lock
if TYPE_CHECKING:
from docling.datamodel.document import InputDocument
@ -24,6 +25,7 @@ class PyPdfiumPageBackend(PdfPageBackend):
def __init__(
self, pdfium_doc: pdfium.PdfDocument, document_hash: str, page_no: int
):
# Note: lock applied by the caller
self.valid = True # No better way to tell from pypdfium.
try:
self._ppage: pdfium.PdfPage = pdfium_doc[page_no]
@ -39,101 +41,123 @@ class PyPdfiumPageBackend(PdfPageBackend):
return self.valid
def get_bitmap_rects(self, scale: float = 1) -> Iterable[BoundingBox]:
AREA_THRESHOLD = 32 * 32
for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]):
pos = obj.get_pos()
cropbox = BoundingBox.from_tuple(
pos, origin=CoordOrigin.BOTTOMLEFT
).to_top_left_origin(page_height=self.get_size().height)
AREA_THRESHOLD = 0 # 32 * 32
page_size = self.get_size()
with pypdfium2_lock:
for obj in self._ppage.get_objects(filter=[pdfium_c.FPDF_PAGEOBJ_IMAGE]):
pos = obj.get_pos()
cropbox = BoundingBox.from_tuple(
pos, origin=CoordOrigin.BOTTOMLEFT
).to_top_left_origin(page_height=page_size.height)
if cropbox.area() > AREA_THRESHOLD:
cropbox = cropbox.scaled(scale=scale)
if cropbox.area() > AREA_THRESHOLD:
cropbox = cropbox.scaled(scale=scale)
yield cropbox
yield cropbox
def get_text_in_rect(self, bbox: BoundingBox) -> str:
if not self.text_page:
self.text_page = self._ppage.get_textpage()
with pypdfium2_lock:
if not self.text_page:
self.text_page = self._ppage.get_textpage()
if bbox.coord_origin != CoordOrigin.BOTTOMLEFT:
bbox = bbox.to_bottom_left_origin(self.get_size().height)
text_piece = self.text_page.get_text_bounded(*bbox.as_tuple())
with pypdfium2_lock:
text_piece = self.text_page.get_text_bounded(*bbox.as_tuple())
return text_piece
def get_text_cells(self) -> Iterable[Cell]:
if not self.text_page:
self.text_page = self._ppage.get_textpage()
def get_segmented_page(self) -> Optional[SegmentedPdfPage]:
return None
def get_text_cells(self) -> Iterable[TextCell]:
with pypdfium2_lock:
if not self.text_page:
self.text_page = self._ppage.get_textpage()
cells = []
cell_counter = 0
page_size = self.get_size()
for i in range(self.text_page.count_rects()):
rect = self.text_page.get_rect(i)
text_piece = self.text_page.get_text_bounded(*rect)
x0, y0, x1, y1 = rect
cells.append(
Cell(
id=cell_counter,
text=text_piece,
bbox=BoundingBox(
l=x0, b=y0, r=x1, t=y1, coord_origin=CoordOrigin.BOTTOMLEFT
).to_top_left_origin(page_size.height),
with pypdfium2_lock:
for i in range(self.text_page.count_rects()):
rect = self.text_page.get_rect(i)
text_piece = self.text_page.get_text_bounded(*rect)
x0, y0, x1, y1 = rect
cells.append(
TextCell(
index=cell_counter,
text=text_piece,
orig=text_piece,
from_ocr=False,
rect=BoundingRectangle.from_bounding_box(
BoundingBox(
l=x0,
b=y0,
r=x1,
t=y1,
coord_origin=CoordOrigin.BOTTOMLEFT,
)
).to_top_left_origin(page_size.height),
)
)
)
cell_counter += 1
cell_counter += 1
# PyPdfium2 produces very fragmented cells, with sub-word level boundaries, in many PDFs.
# The cell merging code below is to clean this up.
def merge_horizontal_cells(
cells: List[Cell],
cells: List[TextCell],
horizontal_threshold_factor: float = 1.0,
vertical_threshold_factor: float = 0.5,
) -> List[Cell]:
) -> List[TextCell]:
if not cells:
return []
def group_rows(cells: List[Cell]) -> List[List[Cell]]:
def group_rows(cells: List[TextCell]) -> List[List[TextCell]]:
rows = []
current_row = [cells[0]]
row_top = cells[0].bbox.t
row_bottom = cells[0].bbox.b
row_height = cells[0].bbox.height
row_top = cells[0].rect.to_bounding_box().t
row_bottom = cells[0].rect.to_bounding_box().b
row_height = cells[0].rect.to_bounding_box().height
for cell in cells[1:]:
vertical_threshold = row_height * vertical_threshold_factor
if (
abs(cell.bbox.t - row_top) <= vertical_threshold
and abs(cell.bbox.b - row_bottom) <= vertical_threshold
abs(cell.rect.to_bounding_box().t - row_top)
<= vertical_threshold
and abs(cell.rect.to_bounding_box().b - row_bottom)
<= vertical_threshold
):
current_row.append(cell)
row_top = min(row_top, cell.bbox.t)
row_bottom = max(row_bottom, cell.bbox.b)
row_top = min(row_top, cell.rect.to_bounding_box().t)
row_bottom = max(row_bottom, cell.rect.to_bounding_box().b)
row_height = row_bottom - row_top
else:
rows.append(current_row)
current_row = [cell]
row_top = cell.bbox.t
row_bottom = cell.bbox.b
row_height = cell.bbox.height
row_top = cell.rect.to_bounding_box().t
row_bottom = cell.rect.to_bounding_box().b
row_height = cell.rect.to_bounding_box().height
if current_row:
rows.append(current_row)
return rows
def merge_row(row: List[Cell]) -> List[Cell]:
def merge_row(row: List[TextCell]) -> List[TextCell]:
merged = []
current_group = [row[0]]
for cell in row[1:]:
prev_cell = current_group[-1]
avg_height = (prev_cell.bbox.height + cell.bbox.height) / 2
avg_height = (
prev_cell.rect.height + cell.rect.to_bounding_box().height
) / 2
if (
cell.bbox.l - prev_cell.bbox.r
cell.rect.to_bounding_box().l
- prev_cell.rect.to_bounding_box().r
<= avg_height * horizontal_threshold_factor
):
current_group.append(cell)
@ -146,24 +170,30 @@ class PyPdfiumPageBackend(PdfPageBackend):
return merged
def merge_group(group: List[Cell]) -> Cell:
def merge_group(group: List[TextCell]) -> TextCell:
if len(group) == 1:
return group[0]
merged_text = "".join(cell.text for cell in group)
merged_bbox = BoundingBox(
l=min(cell.bbox.l for cell in group),
t=min(cell.bbox.t for cell in group),
r=max(cell.bbox.r for cell in group),
b=max(cell.bbox.b for cell in group),
l=min(cell.rect.to_bounding_box().l for cell in group),
t=min(cell.rect.to_bounding_box().t for cell in group),
r=max(cell.rect.to_bounding_box().r for cell in group),
b=max(cell.rect.to_bounding_box().b for cell in group),
)
return TextCell(
index=group[0].index,
text=merged_text,
orig=merged_text,
rect=BoundingRectangle.from_bounding_box(merged_bbox),
from_ocr=False,
)
return Cell(id=group[0].id, text=merged_text, bbox=merged_bbox)
rows = group_rows(cells)
merged_cells = [cell for row in rows for cell in merge_row(row)]
for i, cell in enumerate(merged_cells, 1):
cell.id = i
cell.index = i
return merged_cells
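
Review note: the horizontal merge step is easier to follow on plain boxes than on `TextCell` objects. The sketch below is a simplified stand-in, not the backend code: `Box` is a hypothetical dataclass with a top-left origin, and only the per-row merge is shown (the row grouping is assumed to have happened already).

```python
from dataclasses import dataclass


@dataclass
class Box:
    l: float
    t: float
    r: float
    b: float
    text: str

    @property
    def height(self) -> float:
        return abs(self.b - self.t)


def merge_row(row: list[Box], horizontal_threshold_factor: float = 1.0) -> list[Box]:
    """Merge boxes whose horizontal gap is small relative to their height."""
    if not row:
        return []
    groups: list[list[Box]] = [[row[0]]]
    for box in row[1:]:
        prev = groups[-1][-1]
        avg_height = (prev.height + box.height) / 2
        if box.l - prev.r <= avg_height * horizontal_threshold_factor:
            groups[-1].append(box)  # close enough: same group
        else:
            groups.append([box])  # gap too wide: start a new group
    return [
        Box(
            l=min(x.l for x in group),
            t=min(x.t for x in group),
            r=max(x.r for x in group),
            b=max(x.b for x in group),
            text="".join(x.text for x in group),
        )
        for group in groups
    ]


row = [Box(0, 0, 10, 10, "Hel"), Box(11, 0, 20, 10, "lo"), Box(60, 0, 80, 10, "world")]
for merged in merge_row(row):
    print(merged)
```

As in the diff, fragments are concatenated without a separator, which suits the sub-word fragments the comment above describes.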
@ -173,7 +203,7 @@ class PyPdfiumPageBackend(PdfPageBackend):
) # make new image to avoid drawing on the saved ones
draw = ImageDraw.Draw(image)
for c in cells:
x0, y0, x1, y1 = c.bbox.as_tuple()
x0, y0, x1, y1 = c.rect.to_bounding_box().as_tuple()
cell_color = (
random.randint(30, 140),
random.randint(30, 140),
@ -210,24 +240,28 @@ class PyPdfiumPageBackend(PdfPageBackend):
l=0, r=0, t=0, b=0, coord_origin=CoordOrigin.BOTTOMLEFT
)
else:
padbox = cropbox.to_bottom_left_origin(page_size.height)
padbox = cropbox.to_bottom_left_origin(page_size.height).model_copy()
padbox.r = page_size.width - padbox.r
padbox.t = page_size.height - padbox.t
image = (
self._ppage.render(
scale=scale * 1.5,
rotation=0, # no additional rotation
crop=padbox.as_tuple(),
)
.to_pil()
.resize(size=(round(cropbox.width * scale), round(cropbox.height * scale)))
) # We resize the image from 1.5x the given scale to make it sharper.
with pypdfium2_lock:
image = (
self._ppage.render(
scale=scale * 1.5,
rotation=0, # no additional rotation
crop=padbox.as_tuple(),
)
.to_pil()
.resize(
size=(round(cropbox.width * scale), round(cropbox.height * scale))
)
) # We resize the image from 1.5x the given scale to make it sharper.
return image
def get_size(self) -> Size:
return Size(width=self._ppage.get_width(), height=self._ppage.get_height())
with pypdfium2_lock:
return Size(width=self._ppage.get_width(), height=self._ppage.get_height())
def unload(self):
self._ppage = None
@ -239,22 +273,26 @@ class PyPdfiumDocumentBackend(PdfDocumentBackend):
super().__init__(in_doc, path_or_stream)
try:
self._pdoc = pdfium.PdfDocument(self.path_or_stream)
with pypdfium2_lock:
self._pdoc = pdfium.PdfDocument(self.path_or_stream)
except PdfiumError as e:
raise RuntimeError(
f"pypdfium could not load document with hash {self.document_hash}"
) from e
def page_count(self) -> int:
return len(self._pdoc)
with pypdfium2_lock:
return len(self._pdoc)
def load_page(self, page_no: int) -> PyPdfiumPageBackend:
return PyPdfiumPageBackend(self._pdoc, self.document_hash, page_no)
with pypdfium2_lock:
return PyPdfiumPageBackend(self._pdoc, self.document_hash, page_no)
def is_valid(self) -> bool:
return self.page_count() > 0
def unload(self):
super().unload()
self._pdoc.close()
self._pdoc = None
with pypdfium2_lock:
self._pdoc.close()
self._pdoc = None
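
Review note: every pypdfium2 call in this file is now wrapped in `pypdfium2_lock`. The sketch below illustrates the pattern with stand-ins only: `library_lock` and `FakePage` are hypothetical, and the lock is assumed to be a plain `threading.Lock` (the diff does not show its definition).

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for the shared lock imported in the diff.
library_lock = threading.Lock()


class FakePage:
    """Stand-in for an object from a library that is not thread-safe."""

    def __init__(self) -> None:
        self.active_calls = 0
        self.max_active_calls = 0

    def get_width(self) -> float:
        # Track how many threads are inside the "library" at once.
        self.active_calls += 1
        self.max_active_calls = max(self.max_active_calls, self.active_calls)
        self.active_calls -= 1
        return 612.0


def get_page_width(page: FakePage) -> float:
    # Serialise every call into the library through one module-level lock.
    with library_lock:
        return page.get_width()


page = FakePage()
with ThreadPoolExecutor(max_workers=8) as pool:
    widths = list(pool.map(lambda _: get_page_width(page), range(200)))
print(len(widths), page.max_active_calls)
```

If the real lock is non-reentrant, a method holding it must not call another method that acquires it again. That would be consistent with the "lock applied by the caller" note on the page constructor and with `get_size()` being called before the `with` blocks in `get_bitmap_rects` and `get_text_in_rect`.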


@ -0,0 +1,710 @@
import logging
import traceback
from io import BytesIO
from pathlib import Path
from typing import Final, Optional, Union
from bs4 import BeautifulSoup, Tag
from docling_core.types.doc import (
DocItemLabel,
DoclingDocument,
DocumentOrigin,
GroupItem,
GroupLabel,
NodeItem,
TextItem,
)
from lxml import etree
from typing_extensions import TypedDict, override
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
JATS_DTD_URL: Final = ["JATS-journalpublishing", "JATS-archive"]
DEFAULT_HEADER_ACKNOWLEDGMENTS: Final = "Acknowledgments"
DEFAULT_HEADER_ABSTRACT: Final = "Abstract"
DEFAULT_HEADER_REFERENCES: Final = "References"
DEFAULT_TEXT_ETAL: Final = "et al."
class Abstract(TypedDict):
label: str
content: str
class Author(TypedDict):
name: str
affiliation_names: list[str]
class Citation(TypedDict):
author_names: str
title: str
source: str
year: str
volume: str
page: str
pub_id: str
publisher_name: str
publisher_loc: str
class Table(TypedDict):
label: str
caption: str
content: str
class XMLComponents(TypedDict):
title: str
authors: list[Author]
abstract: list[Abstract]
class JatsDocumentBackend(DeclarativeDocumentBackend):
"""Backend to parse articles in XML format tagged according to JATS definition.
The Journal Article Tag Suite (JATS) is a definition standard for the
representation of journal articles in XML format. Several publishers and journal
archives provide content in JATS format, including PubMed Central® (PMC), bioRxiv,
medRxiv, or Springer Nature.
Refer to https://jats.nlm.nih.gov for more details on JATS.
The code from this document backend has been developed by modifying parts of the
PubMed Parser library (version 0.5.0, released on 12.08.2024):
Achakulvisut et al., (2020).
Pubmed Parser: A Python Parser for PubMed Open-Access XML Subset and MEDLINE XML
Dataset XML Dataset.
Journal of Open Source Software, 5(46), 1979,
https://doi.org/10.21105/joss.01979
"""
@override
def __init__(
self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]
) -> None:
super().__init__(in_doc, path_or_stream)
self.path_or_stream = path_or_stream
# Initialize the root of the document hierarchy
self.root: Optional[NodeItem] = None
self.valid = False
try:
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.seek(0)
self.tree: etree._ElementTree = etree.parse(self.path_or_stream)
doc_info: etree.DocInfo = self.tree.docinfo
if doc_info.system_url and any(
[kwd in doc_info.system_url for kwd in JATS_DTD_URL]
):
self.valid = True
return
for ent in doc_info.internalDTD.iterentities():
if ent.system_url and any(
[kwd in ent.system_url for kwd in JATS_DTD_URL]
):
self.valid = True
return
except Exception as exc:
raise RuntimeError(
f"Could not initialize JATS backend for file with hash {self.document_hash}."
) from exc
@override
def is_valid(self) -> bool:
return self.valid
@classmethod
@override
def supports_pagination(cls) -> bool:
return False
@override
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()
self.path_or_stream = None
@classmethod
@override
def supported_formats(cls) -> set[InputFormat]:
return {InputFormat.XML_JATS}
@override
def convert(self) -> DoclingDocument:
try:
# Create empty document
origin = DocumentOrigin(
filename=self.file.name or "file",
mimetype="application/xml",
binary_hash=self.document_hash,
)
doc = DoclingDocument(name=self.file.stem or "file", origin=origin)
# Get metadata XML components
xml_components: XMLComponents = self._parse_metadata()
# Add metadata to the document
self._add_metadata(doc, xml_components)
# walk over the XML body
body = self.tree.xpath("//body")
if self.root and len(body) > 0:
self._walk_linear(doc, self.root, body[0])
# walk over the XML back matter
back = self.tree.xpath("//back")
if self.root and len(back) > 0:
self._walk_linear(doc, self.root, back[0])
except Exception:
_log.error(traceback.format_exc())
return doc
@staticmethod
def _get_text(node: etree._Element, sep: Optional[str] = None) -> str:
skip_tags = ["term", "disp-formula", "inline-formula"]
text: str = (
node.text.replace("\n", " ")
if (node.tag not in skip_tags and node.text)
else ""
)
for child in list(node):
if child.tag not in skip_tags:
# TODO: apply styling according to child.tag when supported by docling-core
text += JatsDocumentBackend._get_text(child, sep)
if sep:
text = text.rstrip(sep) + sep
text += child.tail.replace("\n", " ") if child.tail else ""
return text
def _find_metadata(self) -> Optional[etree._Element]:
meta_names: list[str] = ["article-meta", "book-part-meta"]
meta: Optional[etree._Element] = None
for name in meta_names:
node = self.tree.xpath(f".//{name}")
if len(node) > 0:
meta = node[0]
break
return meta
def _parse_abstract(self) -> list[Abstract]:
# TODO: address cases with multiple sections
abs_list: list[Abstract] = []
for abs_node in self.tree.xpath(".//abstract"):
abstract: Abstract = dict(label="", content="")
texts = []
for abs_par in abs_node.xpath("p"):
texts.append(JatsDocumentBackend._get_text(abs_par).strip())
abstract["content"] = " ".join(texts)
label_node = abs_node.xpath("title|label")
if len(label_node) > 0:
abstract["label"] = (label_node[0].text or "").strip()
abs_list.append(abstract)
return abs_list
def _parse_authors(self) -> list[Author]:
# Get mapping between affiliation ids and names
authors: list[Author] = []
meta: Optional[etree._Element] = self._find_metadata()
if meta is None:
return authors
affiliation_names = []
for affiliation_node in meta.xpath(".//aff[@id]"):
aff = ", ".join([t for t in affiliation_node.itertext() if t.strip()])
aff = aff.replace("\n", " ")
label = affiliation_node.xpath("label")
if label:
# TODO: once superscript is supported, add label with formatting
aff = aff.removeprefix(f"{label[0].text}, ")
affiliation_names.append(aff)
affiliation_ids_names = {
id: name
for id, name in zip(meta.xpath(".//aff[@id]/@id"), affiliation_names)
}
# Get author names and affiliation names
for author_node in meta.xpath(
'.//contrib-group/contrib[@contrib-type="author"]'
):
author: Author = {
"name": "",
"affiliation_names": [],
}
# Affiliation names
affiliation_ids = [
a.attrib["rid"] for a in author_node.xpath('xref[@ref-type="aff"]')
]
for id in affiliation_ids:
if id in affiliation_ids_names:
author["affiliation_names"].append(affiliation_ids_names[id])
# Name
author["name"] = (
author_node.xpath("name/given-names")[0].text
+ " "
+ author_node.xpath("name/surname")[0].text
)
authors.append(author)
return authors
def _parse_title(self) -> str:
meta_names: list[str] = [
"article-meta",
"collection-meta",
"book-meta",
"book-part-meta",
]
title_names: list[str] = ["article-title", "subtitle", "title", "label"]
titles: list[str] = [
" ".join(
elem.text.replace("\n", " ").strip()
for elem in list(title_node)
if elem.tag in title_names
).strip()
for title_node in self.tree.xpath(
"|".join([f".//{item}/title-group" for item in meta_names])
)
]
text = " - ".join(titles)
return text
def _parse_metadata(self) -> XMLComponents:
"""Parsing JATS document metadata."""
xml_components: XMLComponents = {
"title": self._parse_title(),
"authors": self._parse_authors(),
"abstract": self._parse_abstract(),
}
return xml_components
def _add_abstract(
self, doc: DoclingDocument, xml_components: XMLComponents
) -> None:
for abstract in xml_components["abstract"]:
text: str = abstract["content"]
title: str = abstract["label"] or DEFAULT_HEADER_ABSTRACT
if not text:
continue
parent = doc.add_heading(parent=self.root, text=title)
doc.add_text(
parent=parent,
text=text,
label=DocItemLabel.TEXT,
)
return
def _add_authors(self, doc: DoclingDocument, xml_components: XMLComponents) -> None:
# TODO: once docling supports text formatting, add affiliation reference to
# author names through superscripts
authors: list = [item["name"] for item in xml_components["authors"]]
authors_str = ", ".join(authors)
affiliations: list = [
item
for author in xml_components["authors"]
for item in author["affiliation_names"]
]
affiliations_str = "; ".join(list(dict.fromkeys(affiliations)))
if authors_str:
doc.add_text(
parent=self.root,
text=authors_str,
label=DocItemLabel.PARAGRAPH,
)
if affiliations_str:
doc.add_text(
parent=self.root,
text=affiliations_str,
label=DocItemLabel.PARAGRAPH,
)
return
def _add_citation(self, doc: DoclingDocument, parent: NodeItem, text: str) -> None:
if isinstance(parent, GroupItem) and parent.label == GroupLabel.LIST:
doc.add_list_item(text=text, enumerated=False, parent=parent)
else:
doc.add_text(text=text, label=DocItemLabel.TEXT, parent=parent)
return
def _parse_element_citation(self, node: etree._Element) -> str:
citation: Citation = {
"author_names": "",
"title": "",
"source": "",
"year": "",
"volume": "",
"page": "",
"pub_id": "",
"publisher_name": "",
"publisher_loc": "",
}
_log.debug("Citation parsing started")
# Author names
names = []
for name_node in node.xpath(".//name"):
name_str = (
name_node.xpath("surname")[0].text.replace("\n", " ").strip()
+ " "
+ name_node.xpath("given-names")[0].text.replace("\n", " ").strip()
)
names.append(name_str)
etal_node = node.xpath(".//etal")
if len(etal_node) > 0:
etal_text = etal_node[0].text or DEFAULT_TEXT_ETAL
names.append(etal_text)
citation["author_names"] = ", ".join(names)
titles: list[str] = [
"article-title",
"chapter-title",
"data-title",
"issue-title",
"part-title",
"trans-title",
]
title_node: Optional[etree._Element] = None
for name in titles:
name_node = node.xpath(name)
if len(name_node) > 0:
title_node = name_node[0]
break
citation["title"] = (
JatsDocumentBackend._get_text(title_node)
if title_node is not None
else node.text.replace("\n", " ").strip()
)
# Journal, year, publisher name, publisher location, volume, elocation
fields: list[str] = [
"source",
"year",
"publisher-name",
"publisher-loc",
"volume",
]
for item in fields:
item_node = node.xpath(item)
if len(item_node) > 0:
citation[item.replace("-", "_")] = ( # type: ignore[literal-required]
item_node[0].text.replace("\n", " ").strip()
)
# Publication identifier
if len(node.xpath("pub-id")) > 0:
pub_id: list[str] = []
for id_node in node.xpath("pub-id"):
id_type = id_node.get("assigning-authority") or id_node.get(
"pub-id-type"
)
id_text = id_node.text
if id_type and id_text:
pub_id.append(
id_type.replace("\n", " ").strip().upper()
+ ": "
+ id_text.replace("\n", " ").strip()
)
if pub_id:
citation["pub_id"] = ", ".join(pub_id)
# Pages
if len(node.xpath("elocation-id")) > 0:
citation["page"] = (
node.xpath("elocation-id")[0].text.replace("\n", " ").strip()
)
elif len(node.xpath("fpage")) > 0:
citation["page"] = node.xpath("fpage")[0].text.replace("\n", " ").strip()
if len(node.xpath("lpage")) > 0:
citation["page"] += (
"–" + node.xpath("lpage")[0].text.replace("\n", " ").strip()
)
# Flatten the citation to string
text = ""
if citation["author_names"]:
text += citation["author_names"].rstrip(".") + ". "
if citation["title"]:
text += citation["title"] + ". "
if citation["source"]:
text += citation["source"] + ". "
if citation["publisher_name"]:
if citation["publisher_loc"]:
text += f"{citation['publisher_loc']}: "
text += citation["publisher_name"] + ". "
if citation["volume"]:
text = text.rstrip(". ")
text += f" {citation['volume']}. "
if citation["page"]:
text = text.rstrip(". ")
if citation["volume"]:
text += ":"
text += citation["page"] + ". "
if citation["year"]:
text = text.rstrip(". ")
text += f" ({citation['year']})."
if citation["pub_id"]:
text = text.rstrip(".") + ". "
text += citation["pub_id"]
_log.debug("Citation flattened")
return text
def _add_equation(
self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
) -> None:
math_text = node.text
math_parts = math_text.split("$$")
if len(math_parts) == 3:
math_formula = math_parts[1]
doc.add_text(label=DocItemLabel.FORMULA, text=math_formula, parent=parent)
return
def _add_figure_captions(
self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
) -> None:
label_node = node.xpath("label")
label: Optional[str] = (
JatsDocumentBackend._get_text(label_node[0]).strip() if label_node else ""
)
caption_node = node.xpath("caption")
caption: Optional[str]
if len(caption_node) > 0:
caption = ""
for caption_par in list(caption_node[0]):
if caption_par.xpath(".//supplementary-material"):
continue
caption += JatsDocumentBackend._get_text(caption_par).strip() + " "
caption = caption.strip()
else:
caption = None
# TODO: format label vs caption once styling is supported
fig_text: str = f"{label}{' ' if label and caption else ''}{caption or ''}"
fig_caption: Optional[TextItem] = (
doc.add_text(label=DocItemLabel.CAPTION, text=fig_text)
if fig_text
else None
)
doc.add_picture(parent=parent, caption=fig_caption)
return
# TODO: add footnotes when DocItemLabel.FOOTNOTE and styling are supported
# def _add_footnote_group(self, doc: DoclingDocument, parent: NodeItem, node: etree._Element) -> None:
# new_parent = doc.add_group(label=GroupLabel.LIST, name="footnotes", parent=parent)
# for child in node.iterchildren(tag="fn"):
# text = JatsDocumentBackend._get_text(child)
# doc.add_list_item(text=text, parent=new_parent)
def _add_metadata(
self, doc: DoclingDocument, xml_components: XMLComponents
) -> None:
self._add_title(doc, xml_components)
self._add_authors(doc, xml_components)
self._add_abstract(doc, xml_components)
return
def _add_table(
self, doc: DoclingDocument, parent: NodeItem, table_xml_component: Table
) -> None:
soup = BeautifulSoup(table_xml_component["content"], "html.parser")
table_tag = soup.find("table")
if not isinstance(table_tag, Tag):
return
data = HTMLDocumentBackend.parse_table_data(table_tag)
# TODO: format label vs caption once styling is supported
label = table_xml_component["label"]
caption = table_xml_component["caption"]
table_text: str = f"{label}{' ' if label and caption else ''}{caption}"
table_caption: Optional[TextItem] = (
doc.add_text(label=DocItemLabel.CAPTION, text=table_text)
if table_text
else None
)
if data is not None:
doc.add_table(data=data, parent=parent, caption=table_caption)
return
def _add_tables(
self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
) -> None:
table: Table = {"label": "", "caption": "", "content": ""}
# Content
if len(node.xpath("table")) > 0:
table_content_node = node.xpath("table")[0]
elif len(node.xpath("alternatives/table")) > 0:
table_content_node = node.xpath("alternatives/table")[0]
else:
table_content_node = None
if table_content_node is not None:
table["content"] = etree.tostring(table_content_node).decode("utf-8")
# Caption
caption_node = node.xpath("caption")
caption: Optional[str]
if caption_node:
caption = ""
for caption_par in list(caption_node[0]):
if caption_par.xpath(".//supplementary-material"):
continue
caption += JatsDocumentBackend._get_text(caption_par).strip() + " "
caption = caption.strip()
else:
caption = None
if caption is not None:
table["caption"] = caption
# Label
if len(node.xpath("label")) > 0:
table["label"] = node.xpath("label")[0].text
try:
self._add_table(doc, parent, table)
except Exception:
_log.warning(f"Skipping unsupported table in {self.file}")
return
def _add_title(self, doc: DoclingDocument, xml_components: XMLComponents) -> None:
self.root = doc.add_text(
parent=None,
text=xml_components["title"],
label=DocItemLabel.TITLE,
)
return
def _walk_linear(
self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
) -> str:
skip_tags = ["term"]
flush_tags = ["ack", "sec", "list", "boxed-text", "disp-formula", "fig"]
new_parent: NodeItem = parent
node_text: str = (
node.text.replace("\n", " ")
if (node.tag not in skip_tags and node.text)
else ""
)
for child in list(node):
stop_walk: bool = False
# flush text into TextItem for some tags in paragraph nodes
if node.tag == "p" and node_text.strip() and child.tag in flush_tags:
doc.add_text(
label=DocItemLabel.TEXT, text=node_text.strip(), parent=parent
)
node_text = ""
# add elements and decide whether to stop walking
if child.tag in ("sec", "ack"):
header = child.xpath("title|label")
text: Optional[str] = None
if len(header) > 0:
text = JatsDocumentBackend._get_text(header[0])
elif child.tag == "ack":
text = DEFAULT_HEADER_ACKNOWLEDGMENTS
if text:
new_parent = doc.add_heading(text=text, parent=parent)
elif child.tag == "list":
new_parent = doc.add_group(
label=GroupLabel.LIST, name="list", parent=parent
)
elif child.tag == "list-item":
# TODO: address any type of content (another list, formula,...)
# TODO: address list type and item label
text = JatsDocumentBackend._get_text(child).strip()
new_parent = doc.add_list_item(text=text, parent=parent)
stop_walk = True
elif child.tag == "fig":
self._add_figure_captions(doc, parent, child)
stop_walk = True
elif child.tag == "table-wrap":
self._add_tables(doc, parent, child)
stop_walk = True
elif child.tag == "supplementary-material":
stop_walk = True
elif child.tag == "fn-group":
# header = child.xpath(".//title") or child.xpath(".//label")
# if header:
# text = JatsDocumentBackend._get_text(header[0])
# fn_parent = doc.add_heading(text=text, parent=new_parent)
# self._add_footnote_group(doc, fn_parent, child)
stop_walk = True
elif child.tag == "ref-list" and node.tag != "ref-list":
header = child.xpath("title|label")
text = (
JatsDocumentBackend._get_text(header[0])
if len(header) > 0
else DEFAULT_HEADER_REFERENCES
)
new_parent = doc.add_heading(text=text, parent=parent)
new_parent = doc.add_group(
parent=new_parent, label=GroupLabel.LIST, name="list"
)
elif child.tag == "element-citation":
text = self._parse_element_citation(child)
self._add_citation(doc, parent, text)
stop_walk = True
elif child.tag == "mixed-citation":
text = JatsDocumentBackend._get_text(child).strip()
self._add_citation(doc, parent, text)
stop_walk = True
elif child.tag == "tex-math":
self._add_equation(doc, parent, child)
stop_walk = True
elif child.tag == "inline-formula":
# TODO: address inline formulas when supported by docling-core
stop_walk = True
# step into child
if not stop_walk:
new_text = self._walk_linear(doc, new_parent, child)
if not (node.getparent().tag == "p" and node.tag in flush_tags):
node_text += new_text
# pick up the tail text
node_text += child.tail.replace("\n", " ") if child.tail else ""
# create paragraph
if node.tag == "p" and node_text.strip():
doc.add_text(label=DocItemLabel.TEXT, text=node_text.strip(), parent=parent)
return ""
else:
# backpropagate the text
return node_text
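
Illustrative note (not part of this commit): the text handling in `_get_text` and `_walk_linear` relies on the element `text`/`tail` model, where an element's own text comes first and each child's `tail` holds the text that follows it. The sketch below reproduces that idea with the standard-library `xml.etree.ElementTree` instead of `lxml`, and omits the `sep` argument. The function name `flatten_text` and the sample paragraph are hypothetical.

```python
# Minimal sketch, assuming the same skip-tag list as the JATS backend above.
import xml.etree.ElementTree as ET

SKIP_TAGS = {"term", "disp-formula", "inline-formula"}


def flatten_text(node: ET.Element) -> str:
    """Concatenate text and tails recursively, dropping the content of skipped tags."""
    # The node's own leading text is kept unless the node itself is skipped.
    text = node.text.replace("\n", " ") if (node.tag not in SKIP_TAGS and node.text) else ""
    for child in node:
        if child.tag not in SKIP_TAGS:
            text += flatten_text(child)
        # The tail belongs to the parent's flow, so it is kept even for skipped children.
        text += child.tail.replace("\n", " ") if child.tail else ""
    return text


sample = ET.fromstring(
    "<p>Energy is <italic>conserved</italic> when "
    "<inline-formula>E=mc^2</inline-formula> holds.</p>"
)
flat = flatten_text(sample)
print(flat)
```

In this sketch the formula content is dropped but its tail is preserved, which is why a double space can remain where the inline formula used to be.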

View File

@ -1,592 +0,0 @@
import logging
from io import BytesIO
from pathlib import Path
from typing import Any, Set, Union
import lxml
from bs4 import BeautifulSoup
from docling_core.types.doc import (
DocItemLabel,
DoclingDocument,
DocumentOrigin,
GroupLabel,
TableCell,
TableData,
)
from lxml import etree
from typing_extensions import TypedDict, override
from docling.backend.abstract_backend import DeclarativeDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.document import InputDocument
_log = logging.getLogger(__name__)
class Paragraph(TypedDict):
text: str
headers: list[str]
class Author(TypedDict):
name: str
affiliation_names: list[str]
class Table(TypedDict):
label: str
caption: str
content: str
class FigureCaption(TypedDict):
label: str
caption: str
class Reference(TypedDict):
author_names: str
title: str
journal: str
year: str
class XMLComponents(TypedDict):
title: str
authors: list[Author]
abstract: str
paragraphs: list[Paragraph]
tables: list[Table]
figure_captions: list[FigureCaption]
references: list[Reference]
class PubMedDocumentBackend(DeclarativeDocumentBackend):
"""
The code from this document backend has been developed by modifying parts of the PubMed Parser library (version 0.5.0, released on 12.08.2024):
Achakulvisut et al., (2020).
Pubmed Parser: A Python Parser for PubMed Open-Access XML Subset and MEDLINE XML Dataset XML Dataset.
Journal of Open Source Software, 5(46), 1979,
https://doi.org/10.21105/joss.01979
"""
@override
def __init__(self, in_doc: "InputDocument", path_or_stream: Union[BytesIO, Path]):
super().__init__(in_doc, path_or_stream)
self.path_or_stream = path_or_stream
# Initialize parents for the document hierarchy
self.parents: dict = {}
self.valid = False
try:
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.seek(0)
self.tree: lxml.etree._ElementTree = etree.parse(self.path_or_stream)
if "/NLM//DTD JATS" in self.tree.docinfo.public_id:
self.valid = True
except Exception as exc:
raise RuntimeError(
f"Could not initialize PubMed backend for file with hash {self.document_hash}."
) from exc
@override
def is_valid(self) -> bool:
return self.valid
@classmethod
@override
def supports_pagination(cls) -> bool:
return False
@override
def unload(self):
if isinstance(self.path_or_stream, BytesIO):
self.path_or_stream.close()
self.path_or_stream = None
@classmethod
@override
def supported_formats(cls) -> Set[InputFormat]:
return {InputFormat.XML_PUBMED}
@override
def convert(self) -> DoclingDocument:
# Create empty document
origin = DocumentOrigin(
filename=self.file.name or "file",
mimetype="application/xml",
binary_hash=self.document_hash,
)
doc = DoclingDocument(name=self.file.stem or "file", origin=origin)
_log.debug("Trying to convert PubMed XML document...")
# Get parsed XML components
xml_components: XMLComponents = self._parse()
# Add XML components to the document
doc = self._populate_document(doc, xml_components)
return doc
def _parse_title(self) -> str:
title: str = " ".join(
[
t.replace("\n", "")
for t in self.tree.xpath(".//title-group/article-title")[0].itertext()
]
)
return title
def _parse_authors(self) -> list[Author]:
# Get mapping between affiliation ids and names
affiliation_names = []
for affiliation_node in self.tree.xpath(".//aff[@id]"):
affiliation_names.append(
": ".join([t for t in affiliation_node.itertext() if t != "\n"])
)
affiliation_ids_names = {
id: name
for id, name in zip(self.tree.xpath(".//aff[@id]/@id"), affiliation_names)
}
# Get author names and affiliation names
authors: list[Author] = []
for author_node in self.tree.xpath(
'.//contrib-group/contrib[@contrib-type="author"]'
):
author: Author = {
"name": "",
"affiliation_names": [],
}
# Affiliation names
affiliation_ids = [
a.attrib["rid"] for a in author_node.xpath('xref[@ref-type="aff"]')
]
for id in affiliation_ids:
if id in affiliation_ids_names:
author["affiliation_names"].append(affiliation_ids_names[id])
# Name
author["name"] = (
author_node.xpath("name/surname")[0].text
+ " "
+ author_node.xpath("name/given-names")[0].text
)
authors.append(author)
return authors
def _parse_abstract(self) -> str:
texts = []
for abstract_node in self.tree.xpath(".//abstract"):
for text in abstract_node.itertext():
texts.append(text.replace("\n", ""))
abstract: str = "".join(texts)
return abstract
def _parse_main_text(self) -> list[Paragraph]:
paragraphs: list[Paragraph] = []
for paragraph_node in self.tree.xpath("//body//p"):
# Skip captions
if "/caption" in paragraph_node.getroottree().getpath(paragraph_node):
continue
paragraph: Paragraph = {"text": "", "headers": []}
# Text
paragraph["text"] = "".join(
[t.replace("\n", "") for t in paragraph_node.itertext()]
)
# Header
path = "../title"
while len(paragraph_node.xpath(path)) > 0:
paragraph["headers"].append(
"".join(
[
t.replace("\n", "")
for t in paragraph_node.xpath(path)[0].itertext()
]
)
)
path = "../" + path
paragraphs.append(paragraph)
return paragraphs
def _parse_tables(self) -> list[Table]:
tables: list[Table] = []
for table_node in self.tree.xpath(".//body//table-wrap"):
table: Table = {"label": "", "caption": "", "content": ""}
# Content
if len(table_node.xpath("table")) > 0:
table_content_node = table_node.xpath("table")[0]
elif len(table_node.xpath("alternatives/table")) > 0:
table_content_node = table_node.xpath("alternatives/table")[0]
else:
table_content_node = None
if table_content_node != None:
table["content"] = etree.tostring(table_content_node).decode("utf-8")
# Caption
if len(table_node.xpath("caption/p")) > 0:
caption_node = table_node.xpath("caption/p")[0]
elif len(table_node.xpath("caption/title")) > 0:
caption_node = table_node.xpath("caption/title")[0]
else:
caption_node = None
if caption_node != None:
table["caption"] = "".join(
[t.replace("\n", "") for t in caption_node.itertext()]
)
# Label
if len(table_node.xpath("label")) > 0:
table["label"] = table_node.xpath("label")[0].text
tables.append(table)
return tables
def _parse_figure_captions(self) -> list[FigureCaption]:
figure_captions: list[FigureCaption] = []
if not (self.tree.xpath(".//fig")):
return figure_captions
for figure_node in self.tree.xpath(".//fig"):
figure_caption: FigureCaption = {
"caption": "",
"label": "",
}
# Label
if figure_node.xpath("label"):
figure_caption["label"] = "".join(
[
t.replace("\n", "")
for t in figure_node.xpath("label")[0].itertext()
]
)
# Caption
if figure_node.xpath("caption"):
caption = ""
for caption_node in figure_node.xpath("caption")[0].getchildren():
caption += (
"".join([t.replace("\n", "") for t in caption_node.itertext()])
+ "\n"
)
figure_caption["caption"] = caption
figure_captions.append(figure_caption)
return figure_captions
def _parse_references(self) -> list[Reference]:
references: list[Reference] = []
for reference_node_abs in self.tree.xpath(".//ref-list/ref"):
reference: Reference = {
"author_names": "",
"title": "",
"journal": "",
"year": "",
}
reference_node: Any = None
for tag in ["mixed-citation", "element-citation", "citation"]:
if len(reference_node_abs.xpath(tag)) > 0:
reference_node = reference_node_abs.xpath(tag)[0]
break
if reference_node is None:
continue
if all(
not (ref_type in ["citation-type", "publication-type"])
for ref_type in reference_node.attrib.keys()
):
continue
# Author names
names = []
if len(reference_node.xpath("name")) > 0:
for name_node in reference_node.xpath("name"):
name_str = " ".join(
[t.text for t in name_node.getchildren() if (t.text != None)]
)
names.append(name_str)
elif len(reference_node.xpath("person-group")) > 0:
for name_node in reference_node.xpath("person-group")[0]:
name_str = (
name_node.xpath("given-names")[0].text
+ " "
+ name_node.xpath("surname")[0].text
)
names.append(name_str)
reference["author_names"] = "; ".join(names)
# Title
if len(reference_node.xpath("article-title")) > 0:
reference["title"] = " ".join(
[
t.replace("\n", " ")
for t in reference_node.xpath("article-title")[0].itertext()
]
)
# Journal
if len(reference_node.xpath("source")) > 0:
reference["journal"] = reference_node.xpath("source")[0].text
# Year
if len(reference_node.xpath("year")) > 0:
reference["year"] = reference_node.xpath("year")[0].text
if (
not (reference_node.xpath("article-title"))
and not (reference_node.xpath("journal"))
and not (reference_node.xpath("year"))
):
reference["title"] = reference_node.text
references.append(reference)
return references
def _parse(self) -> XMLComponents:
"""Parsing PubMed document."""
xml_components: XMLComponents = {
"title": self._parse_title(),
"authors": self._parse_authors(),
"abstract": self._parse_abstract(),
"paragraphs": self._parse_main_text(),
"tables": self._parse_tables(),
"figure_captions": self._parse_figure_captions(),
"references": self._parse_references(),
}
return xml_components
def _populate_document(
self, doc: DoclingDocument, xml_components: XMLComponents
) -> DoclingDocument:
self._add_title(doc, xml_components)
self._add_authors(doc, xml_components)
self._add_abstract(doc, xml_components)
self._add_main_text(doc, xml_components)
if xml_components["tables"]:
self._add_tables(doc, xml_components)
if xml_components["figure_captions"]:
self._add_figure_captions(doc, xml_components)
self._add_references(doc, xml_components)
return doc
def _add_figure_captions(
self, doc: DoclingDocument, xml_components: XMLComponents
) -> None:
self.parents["Figures"] = doc.add_heading(
parent=self.parents["Title"], text="Figures"
)
for figure_caption_xml_component in xml_components["figure_captions"]:
figure_caption_text = (
figure_caption_xml_component["label"]
+ ": "
+ figure_caption_xml_component["caption"].strip()
)
fig_caption = doc.add_text(
label=DocItemLabel.CAPTION, text=figure_caption_text
)
doc.add_picture(
parent=self.parents["Figures"],
caption=fig_caption,
)
return
def _add_title(self, doc: DoclingDocument, xml_components: XMLComponents) -> None:
self.parents["Title"] = doc.add_text(
parent=None,
text=xml_components["title"],
label=DocItemLabel.TITLE,
)
return
def _add_authors(self, doc: DoclingDocument, xml_components: XMLComponents) -> None:
authors_affiliations: list = []
for author in xml_components["authors"]:
authors_affiliations.append(author["name"])
authors_affiliations.append(", ".join(author["affiliation_names"]))
authors_affiliations_str = "; ".join(authors_affiliations)
doc.add_text(
parent=self.parents["Title"],
text=authors_affiliations_str,
label=DocItemLabel.PARAGRAPH,
)
return
def _add_abstract(
self, doc: DoclingDocument, xml_components: XMLComponents
) -> None:
abstract_text: str = xml_components["abstract"]
self.parents["Abstract"] = doc.add_heading(
parent=self.parents["Title"], text="Abstract"
)
doc.add_text(
parent=self.parents["Abstract"],
text=abstract_text,
label=DocItemLabel.TEXT,
)
return
def _add_main_text(
self, doc: DoclingDocument, xml_components: XMLComponents
) -> None:
added_headers: list = []
for paragraph in xml_components["paragraphs"]:
if not (paragraph["headers"]):
continue
# Header
for i, header in enumerate(reversed(paragraph["headers"])):
if header in added_headers:
continue
added_headers.append(header)
if ((i - 1) >= 0) and list(reversed(paragraph["headers"]))[
i - 1
] in self.parents:
parent = self.parents[list(reversed(paragraph["headers"]))[i - 1]]
else:
parent = self.parents["Title"]
self.parents[header] = doc.add_heading(parent=parent, text=header)
# Paragraph text
if paragraph["headers"][0] in self.parents:
parent = self.parents[paragraph["headers"][0]]
else:
parent = self.parents["Title"]
doc.add_text(parent=parent, label=DocItemLabel.TEXT, text=paragraph["text"])
return
def _add_references(
self, doc: DoclingDocument, xml_components: XMLComponents
) -> None:
self.parents["References"] = doc.add_heading(
parent=self.parents["Title"], text="References"
)
current_list = doc.add_group(
parent=self.parents["References"], label=GroupLabel.LIST, name="list"
)
for reference in xml_components["references"]:
reference_text: str = ""
if reference["author_names"]:
reference_text += reference["author_names"] + ". "
if reference["title"]:
reference_text += reference["title"]
if reference["title"][-1] != ".":
reference_text += "."
reference_text += " "
if reference["journal"]:
reference_text += reference["journal"]
if reference["year"]:
reference_text += " (" + reference["year"] + ")"
if not (reference_text):
_log.debug(f"Skipping reference for: {str(self.file)}")
continue
doc.add_list_item(
text=reference_text, enumerated=False, parent=current_list
)
return
def _add_tables(self, doc: DoclingDocument, xml_components: XMLComponents) -> None:
self.parents["Tables"] = doc.add_heading(
parent=self.parents["Title"], text="Tables"
)
for table_xml_component in xml_components["tables"]:
try:
self._add_table(doc, table_xml_component)
except Exception as e:
_log.debug(f"Skipping unsupported table for: {str(self.file)}")
pass
return
def _add_table(self, doc: DoclingDocument, table_xml_component: Table) -> None:
soup = BeautifulSoup(table_xml_component["content"], "html.parser")
table_tag = soup.find("table")
nested_tables = table_tag.find("table")
if nested_tables:
_log.debug(f"Skipping nested table for: {str(self.file)}")
return
# Count the number of rows (number of <tr> elements)
num_rows = len(table_tag.find_all("tr"))
# Find the number of columns (taking into account colspan)
num_cols = 0
for row in table_tag.find_all("tr"):
col_count = 0
for cell in row.find_all(["td", "th"]):
colspan = int(cell.get("colspan", 1))
col_count += colspan
num_cols = max(num_cols, col_count)
grid = [[None for _ in range(num_cols)] for _ in range(num_rows)]
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])
# Iterate over the rows in the table
for row_idx, row in enumerate(table_tag.find_all("tr")):
# For each row, find all the column cells (both <td> and <th>)
cells = row.find_all(["td", "th"])
# Check if each cell in the row is a header -> means it is a column header
col_header = True
for j, html_cell in enumerate(cells):
if html_cell.name == "td":
col_header = False
# Extract and print the text content of each cell
col_idx = 0
for _, html_cell in enumerate(cells):
text = html_cell.text
col_span = int(html_cell.get("colspan", 1))
row_span = int(html_cell.get("rowspan", 1))
while grid[row_idx][col_idx] != None:
col_idx += 1
for r in range(row_span):
for c in range(col_span):
grid[row_idx + r][col_idx + c] = text
cell = TableCell(
text=text,
row_span=row_span,
col_span=col_span,
start_row_offset_idx=row_idx,
end_row_offset_idx=row_idx + row_span,
start_col_offset_idx=col_idx,
end_col_offset_idx=col_idx + col_span,
col_header=col_header,
row_header=((not col_header) and html_cell.name == "th"),
)
data.table_cells.append(cell)
table_caption = doc.add_text(
label=DocItemLabel.CAPTION,
text=table_xml_component["label"] + ": " + table_xml_component["caption"],
)
doc.add_table(data=data, parent=self.parents["Tables"], caption=table_caption)
return

View File

@ -14,7 +14,7 @@ from abc import ABC, abstractmethod
from enum import Enum, unique
from io import BytesIO
from pathlib import Path
from typing import Any, Final, Optional, Union
from typing import Final, Optional, Union
from bs4 import BeautifulSoup, Tag
from docling_core.types.doc import (
@ -389,7 +389,7 @@ class PatentUsptoIce(PatentUspto):
if name == self.Element.TITLE.value:
if text:
self.parents[self.level + 1] = self.doc.add_title(
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
text=text,
)
self.level += 1
@ -406,7 +406,7 @@ class PatentUsptoIce(PatentUspto):
abstract_item = self.doc.add_heading(
heading_text,
level=heading_level,
parent=self.parents[heading_level], # type: ignore[arg-type]
parent=self.parents[heading_level],
)
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
@ -434,7 +434,7 @@ class PatentUsptoIce(PatentUspto):
claims_item = self.doc.add_heading(
heading_text,
level=heading_level,
parent=self.parents[heading_level], # type: ignore[arg-type]
parent=self.parents[heading_level],
)
for text in self.claims:
self.doc.add_text(
@ -452,7 +452,7 @@ class PatentUsptoIce(PatentUspto):
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
text=text,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
self.text = ""
@ -460,7 +460,7 @@ class PatentUsptoIce(PatentUspto):
self.parents[self.level + 1] = self.doc.add_heading(
text=text,
level=self.level,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
self.level += 1
self.text = ""
@ -470,7 +470,7 @@ class PatentUsptoIce(PatentUspto):
empty_table = TableData(num_rows=0, num_cols=0, table_cells=[])
self.doc.add_table(
data=empty_table,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
def _apply_style(self, text: str, style_tag: str) -> str:
@ -721,7 +721,7 @@ class PatentUsptoGrantV2(PatentUspto):
if self.Element.TITLE.value in self.property and text.strip():
title = text.strip()
self.parents[self.level + 1] = self.doc.add_title(
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
text=title,
)
self.level += 1
@ -749,7 +749,7 @@ class PatentUsptoGrantV2(PatentUspto):
self.parents[self.level + 1] = self.doc.add_heading(
text=text.strip(),
level=self.level,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
self.level += 1
@ -769,7 +769,7 @@ class PatentUsptoGrantV2(PatentUspto):
claims_item = self.doc.add_heading(
heading_text,
level=heading_level,
parent=self.parents[heading_level], # type: ignore[arg-type]
parent=self.parents[heading_level],
)
for text in self.claims:
self.doc.add_text(
@ -787,7 +787,7 @@ class PatentUsptoGrantV2(PatentUspto):
abstract_item = self.doc.add_heading(
heading_text,
level=heading_level,
parent=self.parents[heading_level], # type: ignore[arg-type]
parent=self.parents[heading_level],
)
self.doc.add_text(
label=DocItemLabel.PARAGRAPH, text=abstract, parent=abstract_item
@ -799,7 +799,7 @@ class PatentUsptoGrantV2(PatentUspto):
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
text=paragraph,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
elif self.Element.CLAIM.value in self.property:
# we may need a space after a paragraph in claim text
@ -811,7 +811,7 @@ class PatentUsptoGrantV2(PatentUspto):
empty_table = TableData(num_rows=0, num_cols=0, table_cells=[])
self.doc.add_table(
data=empty_table,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
def _apply_style(self, text: str, style_tag: str) -> str:
@ -938,7 +938,7 @@ class PatentUsptoGrantAps(PatentUspto):
self.parents[self.level + 1] = self.doc.add_heading(
heading.value,
level=self.level,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
self.level += 1
@ -959,7 +959,7 @@ class PatentUsptoGrantAps(PatentUspto):
if field == self.Field.TITLE.value:
self.parents[self.level + 1] = self.doc.add_title(
parent=self.parents[self.level], text=value # type: ignore[arg-type]
parent=self.parents[self.level], text=value
)
self.level += 1
@ -971,14 +971,14 @@ class PatentUsptoGrantAps(PatentUspto):
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
text=value,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
elif field == self.Field.NUMBER.value and section == self.Section.CLAIMS.value:
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
text="",
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
elif (
@ -996,10 +996,10 @@ class PatentUsptoGrantAps(PatentUspto):
last_claim = self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
text="",
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
last_claim.text += f" {value}" if last_claim.text else value
last_claim.text += f" {value.strip()}" if last_claim.text else value.strip()
elif field == self.Field.CAPTION.value and section in (
self.Section.SUMMARY.value,
@ -1012,7 +1012,7 @@ class PatentUsptoGrantAps(PatentUspto):
self.parents[self.level + 1] = self.doc.add_heading(
value,
level=self.level,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
self.level += 1
@ -1029,7 +1029,7 @@ class PatentUsptoGrantAps(PatentUspto):
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
text=value,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
def parse(self, patent_content: str) -> Optional[DoclingDocument]:
@ -1283,7 +1283,7 @@ class PatentUsptoAppV1(PatentUspto):
title = text.strip()
if title:
self.parents[self.level + 1] = self.doc.add_text(
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
label=DocItemLabel.TITLE,
text=title,
)
@ -1301,7 +1301,7 @@ class PatentUsptoAppV1(PatentUspto):
abstract_item = self.doc.add_heading(
heading_text,
level=heading_level,
parent=self.parents[heading_level], # type: ignore[arg-type]
parent=self.parents[heading_level],
)
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
@ -1331,7 +1331,7 @@ class PatentUsptoAppV1(PatentUspto):
claims_item = self.doc.add_heading(
heading_text,
level=heading_level,
parent=self.parents[heading_level], # type: ignore[arg-type]
parent=self.parents[heading_level],
)
for text in self.claims:
self.doc.add_text(
@ -1350,14 +1350,14 @@ class PatentUsptoAppV1(PatentUspto):
self.parents[self.level + 1] = self.doc.add_heading(
text=text,
level=self.level,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
self.level += 1
else:
self.doc.add_text(
label=DocItemLabel.PARAGRAPH,
text=text,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
self.text = ""
@ -1366,7 +1366,7 @@ class PatentUsptoAppV1(PatentUspto):
empty_table = TableData(num_rows=0, num_cols=0, table_cells=[])
self.doc.add_table(
data=empty_table,
parent=self.parents[self.level], # type: ignore[arg-type]
parent=self.parents[self.level],
)
def _apply_style(self, text: str, style_tag: str) -> str:
@ -1406,6 +1406,10 @@ class XmlTable:
http://oasis-open.org/specs/soextblx.dtd
"""
class ColInfo(TypedDict):
ncols: int
colinfo: list[dict]
class MinColInfoType(TypedDict):
offset: list[int]
colwidth: list[int]
@ -1425,7 +1429,7 @@ class XmlTable:
self.empty_text = ""
self._soup = BeautifulSoup(input, features="xml")
def _create_tg_range(self, tgs: list[dict[str, Any]]) -> dict[int, ColInfoType]:
def _create_tg_range(self, tgs: list[ColInfo]) -> dict[int, ColInfoType]:
"""Create a unified range along the table groups.
Args:
@ -1532,19 +1536,26 @@ class XmlTable:
Returns:
A docling table object.
"""
tgs_align = []
tg_secs = table.find_all("tgroup")
tgs_align: list[XmlTable.ColInfo] = []
tg_secs = table("tgroup")
if tg_secs:
for tg_sec in tg_secs:
ncols = tg_sec.get("cols", None)
if ncols:
ncols = int(ncols)
tg_align = {"ncols": ncols, "colinfo": []}
cs_secs = tg_sec.find_all("colspec")
if not isinstance(tg_sec, Tag):
continue
col_val = tg_sec.get("cols")
ncols = (
int(col_val)
if isinstance(col_val, str) and col_val.isnumeric()
else 1
)
tg_align: XmlTable.ColInfo = {"ncols": ncols, "colinfo": []}
cs_secs = tg_sec("colspec")
if cs_secs:
for cs_sec in cs_secs:
colname = cs_sec.get("colname", None)
colwidth = cs_sec.get("colwidth", None)
if not isinstance(cs_sec, Tag):
continue
colname = cs_sec.get("colname")
colwidth = cs_sec.get("colwidth")
tg_align["colinfo"].append(
{"colname": colname, "colwidth": colwidth}
)
@ -1565,16 +1576,23 @@ class XmlTable:
table_data: list[TableCell] = []
i_row_global = 0
is_row_empty: bool = True
tg_secs = table.find_all("tgroup")
tg_secs = table("tgroup")
if tg_secs:
for itg, tg_sec in enumerate(tg_secs):
if not isinstance(tg_sec, Tag):
continue
tg_range = tgs_range[itg]
row_secs = tg_sec.find_all(["row", "tr"])
row_secs = tg_sec(["row", "tr"])
if row_secs:
for row_sec in row_secs:
entry_secs = row_sec.find_all(["entry", "td"])
is_header: bool = row_sec.parent.name in ["thead"]
if not isinstance(row_sec, Tag):
continue
entry_secs = row_sec(["entry", "td"])
is_header: bool = (
row_sec.parent is not None
and row_sec.parent.name == "thead"
)
ncols = 0
local_row: list[TableCell] = []
@ -1582,23 +1600,26 @@ class XmlTable:
if entry_secs:
wrong_nbr_cols = False
for ientry, entry_sec in enumerate(entry_secs):
if not isinstance(entry_sec, Tag):
continue
text = entry_sec.get_text().strip()
# start-end
namest = entry_sec.attrs.get("namest", None)
nameend = entry_sec.attrs.get("nameend", None)
if isinstance(namest, str) and namest.isnumeric():
namest = int(namest)
else:
namest = ientry + 1
namest = entry_sec.get("namest")
nameend = entry_sec.get("nameend")
start = (
int(namest)
if isinstance(namest, str) and namest.isnumeric()
else ientry + 1
)
if isinstance(nameend, str) and nameend.isnumeric():
nameend = int(nameend)
end = int(nameend)
shift = 0
else:
nameend = ientry + 2
end = ientry + 2
shift = 1
if nameend > len(tg_range["cell_offst"]):
if end > len(tg_range["cell_offst"]):
wrong_nbr_cols = True
self.nbr_messages += 1
if self.nbr_messages <= self.max_nbr_messages:
@ -1608,8 +1629,8 @@ class XmlTable:
break
range_ = [
tg_range["cell_offst"][namest - 1],
tg_range["cell_offst"][nameend - 1] - shift,
tg_range["cell_offst"][start - 1],
tg_range["cell_offst"][end - 1] - shift,
]
# add row and replicate cell if needed
@ -1668,7 +1689,7 @@ class XmlTable:
A docling table data.
"""
section = self._soup.find("table")
if section is not None:
if isinstance(section, Tag):
table = self._parse_table(section)
if table.num_rows == 0 or table.num_cols == 0:
_log.warning("The parsed USPTO table is empty")
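The span handling changed in the hunk above can be illustrated in isolation. This is a minimal sketch rather than the code from the patch: `resolve_span` is a hypothetical helper name, and it only reproduces the fallback rule visible in the diff. A numeric `namest` or `nameend` is used as-is; otherwise the entry index supplies the span, and a missing `nameend` shifts the end offset by one.

```python
from typing import Optional, Tuple


def resolve_span(
    namest: Optional[str], nameend: Optional[str], ientry: int
) -> Tuple[int, int, int]:
    """Return (start, end, shift) for a table entry (hypothetical helper)."""
    # A numeric "namest" attribute is the 1-based start column;
    # otherwise fall back to the position of the entry in its row.
    if isinstance(namest, str) and namest.isnumeric():
        start = int(namest)
    else:
        start = ientry + 1
    # A numeric "nameend" attribute is the 1-based end column;
    # otherwise use the next column and shift the offset back by one.
    if isinstance(nameend, str) and nameend.isnumeric():
        end, shift = int(nameend), 0
    else:
        end, shift = ientry + 2, 1
    return start, end, shift
```

Keeping `start` and `end` as separate integer variables, instead of reassigning `namest` and `nameend`, is what lets a type checker follow the code without `# type: ignore` comments.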


@ -1,21 +1,23 @@
import importlib
import json
import logging
import platform
import re
import sys
import tempfile
import time
import warnings
from enum import Enum
from pathlib import Path
from typing import Annotated, Dict, Iterable, List, Optional, Type
import rich.table
import typer
from docling_core.types.doc import ImageRefMode
from docling_core.utils.file import resolve_source_to_path
from pydantic import TypeAdapter, ValidationError
from pydantic import TypeAdapter
from docling.backend.docling_parse_backend import DoclingParseDocumentBackend
from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBackend
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.base_models import (
@ -29,18 +31,22 @@ from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrEngine,
OcrMacOptions,
OcrOptions,
PaginatedPipelineOptions,
PdfBackend,
PdfPipeline,
PdfPipelineOptions,
RapidOcrOptions,
TableFormerMode,
TesseractCliOcrOptions,
TesseractOcrOptions,
VlmModelType,
VlmPipelineOptions,
granite_vision_vlm_conversion_options,
smoldocling_vlm_conversion_options,
smoldocling_vlm_mlx_conversion_options,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
from docling.models.factories import get_ocr_factory
from docling.pipeline.vlm_pipeline import VlmPipeline
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
@ -48,8 +54,11 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr
_log = logging.getLogger(__name__)
from rich.console import Console
console = Console()
err_console = Console(stderr=True)
ocr_factory_internal = get_ocr_factory(allow_external_plugins=False)
ocr_engines_enum_internal = ocr_factory_internal.get_enum()
app = typer.Typer(
name="Docling",
@ -65,10 +74,33 @@ def version_callback(value: bool):
docling_core_version = importlib.metadata.version("docling-core")
docling_ibm_models_version = importlib.metadata.version("docling-ibm-models")
docling_parse_version = importlib.metadata.version("docling-parse")
platform_str = platform.platform()
py_impl_version = sys.implementation.cache_tag
py_lang_version = platform.python_version()
print(f"Docling version: {docling_version}")
print(f"Docling Core version: {docling_core_version}")
print(f"Docling IBM Models version: {docling_ibm_models_version}")
print(f"Docling Parse version: {docling_parse_version}")
print(f"Python: {py_impl_version} ({py_lang_version})")
print(f"Platform: {platform_str}")
raise typer.Exit()
def show_external_plugins_callback(value: bool):
if value:
ocr_factory_all = get_ocr_factory(allow_external_plugins=True)
table = rich.table.Table(title="Available OCR engines")
table.add_column("Name", justify="right")
table.add_column("Plugin")
table.add_column("Package")
for meta in ocr_factory_all.registered_meta.values():
if not meta.module.startswith("docling."):
table.add_row(
f"[bold]{meta.kind}[/bold]",
meta.plugin_name,
meta.module.split(".")[0],
)
rich.print(table)
raise typer.Exit()
@ -176,6 +208,14 @@ def convert(
help="Image export mode for the document (only in case of JSON, Markdown or HTML). With `placeholder`, only the position of the image is marked in the output. In `embedded` mode, the image is embedded as base64 encoded string. In `referenced` mode, the image is exported in PNG format and referenced from the main exported document.",
),
] = ImageRefMode.EMBEDDED,
pipeline: Annotated[
PdfPipeline,
typer.Option(..., help="Choose the pipeline to process PDF or image files."),
] = PdfPipeline.STANDARD,
vlm_model: Annotated[
VlmModelType,
typer.Option(..., help="Choose the VLM model to use with PDF or image files."),
] = VlmModelType.SMOLDOCLING,
ocr: Annotated[
bool,
typer.Option(
@ -190,8 +230,16 @@ def convert(
),
] = False,
ocr_engine: Annotated[
OcrEngine, typer.Option(..., help="The OCR engine to use.")
] = OcrEngine.EASYOCR,
str,
typer.Option(
...,
help=(
f"The OCR engine to use. When --allow-external-plugins is *not* set, the available values are: "
f"{', '.join((o.value for o in ocr_engines_enum_internal))}. "
f"Use the option --show-external-plugins to see the options allowed with external plugins."
),
),
] = EasyOcrOptions.kind,
ocr_lang: Annotated[
Optional[str],
typer.Option(
@ -205,17 +253,57 @@ def convert(
table_mode: Annotated[
TableFormerMode,
typer.Option(..., help="The mode to use in the table structure model."),
] = TableFormerMode.FAST,
] = TableFormerMode.ACCURATE,
enrich_code: Annotated[
bool,
typer.Option(..., help="Enable the code enrichment model in the pipeline."),
] = False,
enrich_formula: Annotated[
bool,
typer.Option(..., help="Enable the formula enrichment model in the pipeline."),
] = False,
enrich_picture_classes: Annotated[
bool,
typer.Option(
...,
help="Enable the picture classification enrichment model in the pipeline.",
),
] = False,
enrich_picture_description: Annotated[
bool,
typer.Option(..., help="Enable the picture description model in the pipeline."),
] = False,
artifacts_path: Annotated[
Optional[Path],
typer.Option(..., help="If provided, the location of the model artifacts."),
] = None,
enable_remote_services: Annotated[
bool,
typer.Option(
..., help="Must be enabled when using models connecting to remote services."
),
] = False,
allow_external_plugins: Annotated[
bool,
typer.Option(
..., help="Must be enabled for loading modules from third-party plugins."
),
] = False,
show_external_plugins: Annotated[
bool,
typer.Option(
...,
help="List the third-party plugins which are available when the option --allow-external-plugins is set.",
callback=show_external_plugins_callback,
is_eager=True,
),
] = False,
abort_on_error: Annotated[
bool,
typer.Option(
...,
"--abort-on-error/--no-abort-on-error",
help="If enabled, the bitmap content will be processed using OCR.",
help="If enabled, the processing will be aborted when the first error is encountered.",
),
] = False,
output: Annotated[
@ -337,59 +425,88 @@ def convert(
export_txt = OutputFormat.TEXT in to_formats
export_doctags = OutputFormat.DOCTAGS in to_formats
if ocr_engine == OcrEngine.EASYOCR:
ocr_options: OcrOptions = EasyOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.TESSERACT_CLI:
ocr_options = TesseractCliOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.TESSERACT:
ocr_options = TesseractOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.OCRMAC:
ocr_options = OcrMacOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.RAPIDOCR:
ocr_options = RapidOcrOptions(force_full_page_ocr=force_ocr)
else:
raise RuntimeError(f"Unexpected OCR engine type {ocr_engine}")
ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins)
ocr_options: OcrOptions = ocr_factory.create_options( # type: ignore
kind=ocr_engine,
force_full_page_ocr=force_ocr,
)
ocr_lang_list = _split_list(ocr_lang)
if ocr_lang_list is not None:
ocr_options.lang = ocr_lang_list
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
pipeline_options = PdfPipelineOptions(
accelerator_options=accelerator_options,
do_ocr=ocr,
ocr_options=ocr_options,
do_table_structure=True,
document_timeout=document_timeout,
)
pipeline_options.table_structure_options.do_cell_matching = (
True # do_cell_matching
)
pipeline_options.table_structure_options.mode = table_mode
pipeline_options: PaginatedPipelineOptions
if image_export_mode != ImageRefMode.PLACEHOLDER:
pipeline_options.generate_page_images = True
pipeline_options.generate_picture_images = (
True # FIXME: to be deprecated in version 3
if pipeline == PdfPipeline.STANDARD:
pipeline_options = PdfPipelineOptions(
allow_external_plugins=allow_external_plugins,
enable_remote_services=enable_remote_services,
accelerator_options=accelerator_options,
do_ocr=ocr,
ocr_options=ocr_options,
do_table_structure=True,
do_code_enrichment=enrich_code,
do_formula_enrichment=enrich_formula,
do_picture_description=enrich_picture_description,
do_picture_classification=enrich_picture_classes,
document_timeout=document_timeout,
)
pipeline_options.table_structure_options.do_cell_matching = (
True # do_cell_matching
)
pipeline_options.table_structure_options.mode = table_mode
if image_export_mode != ImageRefMode.PLACEHOLDER:
pipeline_options.generate_page_images = True
pipeline_options.generate_picture_images = (
True # FIXME: to be deprecated in version 3
)
pipeline_options.images_scale = 2
backend: Type[PdfDocumentBackend]
if pdf_backend == PdfBackend.DLPARSE_V1:
backend = DoclingParseDocumentBackend
elif pdf_backend == PdfBackend.DLPARSE_V2:
backend = DoclingParseV2DocumentBackend
elif pdf_backend == PdfBackend.DLPARSE_V4:
backend = DoclingParseV4DocumentBackend # type: ignore
elif pdf_backend == PdfBackend.PYPDFIUM2:
backend = PyPdfiumDocumentBackend # type: ignore
else:
raise RuntimeError(f"Unexpected PDF backend type {pdf_backend}")
pdf_format_option = PdfFormatOption(
pipeline_options=pipeline_options,
backend=backend, # pdf_backend
)
elif pipeline == PdfPipeline.VLM:
pipeline_options = VlmPipelineOptions()
if vlm_model == VlmModelType.GRANITE_VISION:
pipeline_options.vlm_options = granite_vision_vlm_conversion_options
elif vlm_model == VlmModelType.SMOLDOCLING:
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
if sys.platform == "darwin":
try:
import mlx_vlm
pipeline_options.vlm_options = (
smoldocling_vlm_mlx_conversion_options
)
except ImportError:
_log.warning(
"To run SmolDocling faster, please install mlx-vlm:\n"
"pip install mlx-vlm"
)
pdf_format_option = PdfFormatOption(
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
)
pipeline_options.images_scale = 2
if artifacts_path is not None:
pipeline_options.artifacts_path = artifacts_path
if pdf_backend == PdfBackend.DLPARSE_V1:
backend: Type[PdfDocumentBackend] = DoclingParseDocumentBackend
elif pdf_backend == PdfBackend.DLPARSE_V2:
backend = DoclingParseV2DocumentBackend
elif pdf_backend == PdfBackend.PYPDFIUM2:
backend = PyPdfiumDocumentBackend
else:
raise RuntimeError(f"Unexpected PDF backend type {pdf_backend}")
pdf_format_option = PdfFormatOption(
pipeline_options=pipeline_options,
backend=backend, # pdf_backend
)
format_options: Dict[InputFormat, FormatOption] = {
InputFormat.PDF: pdf_format_option,
InputFormat.IMAGE: pdf_format_option,
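The hunk above replaces an `if`/`elif` chain over OCR engines with a factory lookup keyed by the engine's `kind`. The sketch below shows that general pattern only; it is not Docling's factory. `OptionsRegistry`, `BaseOcrOptions`, `FakeEasyOcrOptions` and `FakeTesseractOptions` are hypothetical names introduced for illustration.

```python
from dataclasses import dataclass, field
from typing import ClassVar, Dict, List, Type


@dataclass
class BaseOcrOptions:
    """Hypothetical stand-in for an options base class with a class-level kind."""

    kind: ClassVar[str] = "base"
    force_full_page_ocr: bool = False
    lang: List[str] = field(default_factory=list)


@dataclass
class FakeEasyOcrOptions(BaseOcrOptions):
    kind: ClassVar[str] = "easyocr"


@dataclass
class FakeTesseractOptions(BaseOcrOptions):
    kind: ClassVar[str] = "tesseract"


class OptionsRegistry:
    """Hypothetical registry mapping a kind string to an options class."""

    def __init__(self) -> None:
        self._classes: Dict[str, Type[BaseOcrOptions]] = {}

    def register(self, cls: Type[BaseOcrOptions]) -> None:
        # The class-level "kind" is the lookup key.
        self._classes[cls.kind] = cls

    def create_options(self, kind: str, **kwargs) -> BaseOcrOptions:
        cls = self._classes.get(kind)
        if cls is None:
            raise ValueError(f"Unknown OCR engine kind: {kind!r}")
        return cls(**kwargs)
```

With a lookup like this, adding an engine means registering one more class, and the calling code stays unchanged.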

131
docling/cli/models.py Normal file

@ -0,0 +1,131 @@
import logging
import warnings
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional
import typer
from rich.console import Console
from rich.logging import RichHandler
from docling.datamodel.settings import settings
from docling.utils.model_downloader import download_models
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
console = Console()
err_console = Console(stderr=True)
app = typer.Typer(
name="Docling models helper",
no_args_is_help=True,
add_completion=False,
pretty_exceptions_enable=False,
)
class _AvailableModels(str, Enum):
LAYOUT = "layout"
TABLEFORMER = "tableformer"
CODE_FORMULA = "code_formula"
PICTURE_CLASSIFIER = "picture_classifier"
SMOLVLM = "smolvlm"
GRANITE_VISION = "granite_vision"
EASYOCR = "easyocr"
_default_models = [
_AvailableModels.LAYOUT,
_AvailableModels.TABLEFORMER,
_AvailableModels.CODE_FORMULA,
_AvailableModels.PICTURE_CLASSIFIER,
_AvailableModels.EASYOCR,
]
@app.command("download")
def download(
output_dir: Annotated[
Path,
typer.Option(
...,
"-o",
"--output-dir",
help="The directory where to download the models.",
),
] = (settings.cache_dir / "models"),
force: Annotated[
bool, typer.Option(..., help="If true, the download will be forced.")
] = False,
models: Annotated[
Optional[list[_AvailableModels]],
typer.Argument(
help="Models to download (default behavior: a predefined set of models will be downloaded).",
),
] = None,
all: Annotated[
bool,
typer.Option(
...,
"--all",
help="If true, all available models will be downloaded (mutually exclusive with passing specific models).",
show_default=True,
),
] = False,
quiet: Annotated[
bool,
typer.Option(
...,
"-q",
"--quiet",
help="No extra output is generated, the CLI prints only the directory with the cached models.",
),
] = False,
):
if models and all:
raise typer.BadParameter(
"Cannot simultaneously set 'all' parameter and specify models to download."
)
if not quiet:
FORMAT = "%(message)s"
logging.basicConfig(
level=logging.INFO,
format="[blue]%(message)s[/blue]",
datefmt="[%X]",
handlers=[RichHandler(show_level=False, show_time=False, markup=True)],
)
to_download = models or ([m for m in _AvailableModels] if all else _default_models)
output_dir = download_models(
output_dir=output_dir,
force=force,
progress=(not quiet),
with_layout=_AvailableModels.LAYOUT in to_download,
with_tableformer=_AvailableModels.TABLEFORMER in to_download,
with_code_formula=_AvailableModels.CODE_FORMULA in to_download,
with_picture_classifier=_AvailableModels.PICTURE_CLASSIFIER in to_download,
with_smolvlm=_AvailableModels.SMOLVLM in to_download,
with_granite_vision=_AvailableModels.GRANITE_VISION in to_download,
with_easyocr=_AvailableModels.EASYOCR in to_download,
)
if quiet:
typer.echo(output_dir)
else:
typer.secho(f"\nModels downloaded into: {output_dir}.", fg="green")
console.print(
"\n",
"Docling can now be configured for running offline using the local artifacts.\n\n",
"Using the CLI:",
f"`docling --artifacts-path={output_dir} FILE`",
"\n",
"Using Python: see the documentation at <https://docling-project.github.io/docling/usage>.",
)
click_app = typer.main.get_command(app)
if __name__ == "__main__":
app()
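The model selection rule in `download` above can be checked on its own. The following is a minimal sketch under stated assumptions: `Model`, `DEFAULT_MODELS` and `select_models` are hypothetical names, the enum is shortened to three members, and a plain `ValueError` stands in for `typer.BadParameter`.

```python
from enum import Enum
from typing import List, Optional


class Model(str, Enum):
    """Shortened, hypothetical version of the model enum."""

    LAYOUT = "layout"
    TABLEFORMER = "tableformer"
    SMOLVLM = "smolvlm"


DEFAULT_MODELS = [Model.LAYOUT, Model.TABLEFORMER]


def select_models(models: Optional[List[Model]], download_all: bool) -> List[Model]:
    """Pick explicit models, else every model, else the default set."""
    if models and download_all:
        # Mirrors the mutual-exclusion check in the command above.
        raise ValueError("Cannot combine explicit models with download_all.")
    return models or (list(Model) if download_all else DEFAULT_MODELS)
```

An empty list behaves like `None` here, because `models or ...` falls through on any falsy value.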

17
docling/cli/tools.py Normal file

@ -0,0 +1,17 @@
import typer
from docling.cli.models import app as models_app
app = typer.Typer(
name="Docling helpers",
no_args_is_help=True,
add_completion=False,
pretty_exceptions_enable=False,
)
app.add_typer(models_app, name="models")
click_app = typer.main.get_command(app)
if __name__ == "__main__":
app()


@ -4,10 +4,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
from docling_core.types.doc import (
BoundingBox,
DocItemLabel,
NodeItem,
PictureDataType,
Size,
TableCell,
)
from docling_core.types.doc.page import SegmentedPdfPage, TextCell
from docling_core.types.io import ( # DO NOT REMOVE; explicitly exposed from this location
DocumentStream,
)
@ -33,13 +35,15 @@ class InputFormat(str, Enum):
DOCX = "docx"
PPTX = "pptx"
HTML = "html"
XML_PUBMED = "xml_pubmed"
IMAGE = "image"
PDF = "pdf"
ASCIIDOC = "asciidoc"
MD = "md"
CSV = "csv"
XLSX = "xlsx"
XML_USPTO = "xml_uspto"
XML_JATS = "xml_jats"
JSON_DOCLING = "json_docling"
class OutputFormat(str, Enum):
@ -56,11 +60,13 @@ FormatToExtensions: Dict[InputFormat, List[str]] = {
InputFormat.PDF: ["pdf"],
InputFormat.MD: ["md"],
InputFormat.HTML: ["html", "htm", "xhtml"],
InputFormat.XML_PUBMED: ["xml", "nxml"],
InputFormat.XML_JATS: ["xml", "nxml"],
InputFormat.IMAGE: ["jpg", "jpeg", "png", "tif", "tiff", "bmp"],
InputFormat.ASCIIDOC: ["adoc", "asciidoc", "asc"],
InputFormat.CSV: ["csv"],
InputFormat.XLSX: ["xlsx"],
InputFormat.XML_USPTO: ["xml", "txt"],
InputFormat.JSON_DOCLING: ["json"],
}
FormatToMimeType: Dict[InputFormat, List[str]] = {
@ -74,7 +80,7 @@ FormatToMimeType: Dict[InputFormat, List[str]] = {
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
],
InputFormat.HTML: ["text/html", "application/xhtml+xml"],
InputFormat.XML_PUBMED: ["application/xml"],
InputFormat.XML_JATS: ["application/xml"],
InputFormat.IMAGE: [
"image/png",
"image/jpeg",
@ -85,10 +91,12 @@ FormatToMimeType: Dict[InputFormat, List[str]] = {
InputFormat.PDF: ["application/pdf"],
InputFormat.ASCIIDOC: ["text/asciidoc"],
InputFormat.MD: ["text/markdown", "text/x-markdown"],
InputFormat.CSV: ["text/csv"],
InputFormat.XLSX: [
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
],
InputFormat.XML_USPTO: ["application/xml", "text/plain"],
InputFormat.JSON_DOCLING: ["application/json"],
}
MimeTypeToFormat: dict[str, list[InputFormat]] = {
@ -116,14 +124,10 @@ class ErrorItem(BaseModel):
error_message: str
class Cell(BaseModel):
id: int
text: str
bbox: BoundingBox
class OcrCell(Cell):
confidence: float
# class Cell(BaseModel):
# id: int
# text: str
# bbox: BoundingBox
class Cluster(BaseModel):
@ -131,7 +135,7 @@ class Cluster(BaseModel):
label: DocItemLabel
bbox: BoundingBox
confidence: float = 1.0
cells: List[Cell] = []
cells: List[TextCell] = []
children: List["Cluster"] = [] # Add child cluster support
@ -147,6 +151,10 @@ class LayoutPrediction(BaseModel):
clusters: List[Cluster] = []
class VlmPrediction(BaseModel):
text: str = ""
class ContainerElement(
BasePageElement
): # Used for Form and Key-Value-Regions, only for typing.
@ -190,6 +198,7 @@ class PagePredictions(BaseModel):
tablestructure: Optional[TableStructurePrediction] = None
figures_classification: Optional[FigureClassificationPrediction] = None
equations_prediction: Optional[EquationPrediction] = None
vlm_response: Optional[VlmPrediction] = None
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
@ -201,13 +210,21 @@ class AssembledUnit(BaseModel):
headers: List[PageElement] = []
class ItemAndImageEnrichmentElement(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
item: NodeItem
image: Image
class Page(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
page_no: int
# page_hash: Optional[str] = None
size: Optional[Size] = None
cells: List[Cell] = []
cells: List[TextCell] = []
parsed_page: Optional[SegmentedPdfPage] = None
predictions: PagePredictions = PagePredictions()
assembled: Optional[AssembledUnit] = None
@ -219,12 +236,28 @@ class Page(BaseModel):
{}
) # Cache of images in different scales. By default it is cleared during assembling.
def get_image(self, scale: float = 1.0) -> Optional[Image]:
def get_image(
self, scale: float = 1.0, cropbox: Optional[BoundingBox] = None
) -> Optional[Image]:
if self._backend is None:
return self._image_cache.get(scale, None)
if scale not in self._image_cache:
self._image_cache[scale] = self._backend.get_page_image(scale=scale)
return self._image_cache[scale]
if cropbox is None:
self._image_cache[scale] = self._backend.get_page_image(scale=scale)
else:
return self._backend.get_page_image(scale=scale, cropbox=cropbox)
if cropbox is None:
return self._image_cache[scale]
else:
page_im = self._image_cache[scale]
assert self.size is not None
return page_im.crop(
cropbox.to_top_left_origin(page_height=self.size.height)
.scaled(scale=scale)
.as_tuple()
)
@property
def image(self) -> Optional[Image]:
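The cropping branch of `get_image` above converts the crop box to a top-left origin, scales it, and passes the resulting tuple to the image's `crop`. The arithmetic can be sketched without any Docling types. This is an illustration under assumptions: `crop_tuple` is a hypothetical helper, and the crop box is assumed to be given in bottom-left-origin page coordinates with `t` above `b`.

```python
from typing import Tuple


def crop_tuple(
    l: float, b: float, r: float, t: float, page_height: float, scale: float
) -> Tuple[float, float, float, float]:
    """Return (left, upper, right, lower) pixel coordinates (hypothetical helper)."""
    # Flip the vertical axis: in a top-left origin, y grows downwards.
    upper = page_height - t
    lower = page_height - b
    # Page units become pixels by multiplying with the render scale.
    return (l * scale, upper * scale, r * scale, lower * scale)
```

For a page 100 units high rendered at scale 2, a box from (10, 20) to (110, 70) maps to the pixel box (20, 60, 220, 160).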


@ -1,3 +1,4 @@
import csv
import logging
import re
from enum import Enum
@ -157,6 +158,8 @@ class InputDocument(BaseModel):
self.page_count = self._backend.page_count()
if not self.page_count <= self.limits.max_num_pages:
self.valid = False
elif self.page_count < self.limits.page_range[0]:
self.valid = False
except (FileNotFoundError, OSError) as e:
self.valid = False
@ -294,6 +297,7 @@ class _DocumentConversionInput(BaseModel):
mime = _DocumentConversionInput._mime_from_extension(ext)
mime = mime or _DocumentConversionInput._detect_html_xhtml(content)
mime = mime or _DocumentConversionInput._detect_csv(content)
mime = mime or "text/plain"
formats = MimeTypeToFormat.get(mime, [])
if formats:
@ -329,11 +333,11 @@ class _DocumentConversionInput(BaseModel):
):
input_format = InputFormat.XML_USPTO
if (
InputFormat.XML_PUBMED in formats
and "/NLM//DTD JATS" in xml_doctype
if InputFormat.XML_JATS in formats and (
"JATS-journalpublishing" in xml_doctype
or "JATS-archive" in xml_doctype
):
input_format = InputFormat.XML_PUBMED
input_format = InputFormat.XML_JATS
elif mime == "text/plain":
if InputFormat.XML_USPTO in formats and content_str.startswith("PATN\r\n"):
@ -350,6 +354,12 @@ class _DocumentConversionInput(BaseModel):
mime = FormatToMimeType[InputFormat.HTML][0]
elif ext in FormatToExtensions[InputFormat.MD]:
mime = FormatToMimeType[InputFormat.MD][0]
elif ext in FormatToExtensions[InputFormat.CSV]:
mime = FormatToMimeType[InputFormat.CSV][0]
elif ext in FormatToExtensions[InputFormat.JSON_DOCLING]:
mime = FormatToMimeType[InputFormat.JSON_DOCLING][0]
elif ext in FormatToExtensions[InputFormat.PDF]:
mime = FormatToMimeType[InputFormat.PDF][0]
return mime
@staticmethod
@ -386,3 +396,32 @@ class _DocumentConversionInput(BaseModel):
return "application/xml"
return None
@staticmethod
def _detect_csv(
content: bytes,
) -> Optional[Literal["text/csv"]]:
"""Guess the mime type of a CSV file from its content.
Args:
content: A short piece of a document from its beginning.
Returns:
The mime type of a CSV file, or None if the content does
not match the format.
"""
content_str = content.decode("ascii", errors="ignore").strip()
# Ensure there's at least one newline (CSV is usually multi-line)
if "\n" not in content_str:
return None
# Use csv.Sniffer to detect CSV characteristics
try:
dialect = csv.Sniffer().sniff(content_str)
if dialect.delimiter in {",", ";", "\t", "|"}: # Common delimiters
return "text/csv"
except csv.Error:
return None
return None
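The detection logic of `_detect_csv` relies only on the standard library, so it can be tried out as a free function. The following sketch repeats the same steps outside the class; `looks_like_csv` is a hypothetical name, and the return type is simplified to `Optional[str]`.

```python
import csv
from typing import Optional


def looks_like_csv(content: bytes) -> Optional[str]:
    """Return "text/csv" if the content sniffs as CSV, else None (hypothetical helper)."""
    text = content.decode("ascii", errors="ignore").strip()
    # A single line is not enough evidence for a delimited format.
    if "\n" not in text:
        return None
    try:
        # Sniffer raises csv.Error when it cannot determine a delimiter.
        dialect = csv.Sniffer().sniff(text)
    except csv.Error:
        return None
    # Accept only the common delimiters, as in the method above.
    return "text/csv" if dialect.delimiter in {",", ";", "\t", "|"} else None
```

Because `csv.Sniffer` is heuristic, multi-line plain text with a consistently repeated character can still be classified as CSV, which is presumably why the method runs only after the extension and HTML checks have returned nothing.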


@ -1,16 +1,19 @@
import logging
import os
import warnings
import re
from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Type, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
from pydantic import (
AnyUrl,
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import deprecated
_log = logging.getLogger(__name__)
@ -31,7 +34,19 @@ class AcceleratorOptions(BaseSettings):
)
num_threads: int = 4
device: AcceleratorDevice = AcceleratorDevice.AUTO
device: Union[str, AcceleratorDevice] = "auto"
cuda_use_flash_attention2: bool = False
@field_validator("device")
def validate_device(cls, value):
# "auto", "cpu", "cuda", "mps", or "cuda:N"
if value in {d.value for d in AcceleratorDevice} or re.match(
r"^cuda(:\d+)?$", value
):
return value
raise ValueError(
"Invalid device option. Use 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'."
)
@model_validator(mode="before")
@classmethod
@ -47,7 +62,6 @@ class AcceleratorOptions(BaseSettings):
"""
if isinstance(data, dict):
input_num_threads = data.get("num_threads")
# Check if to set the num_threads from the alternative envvar
if input_num_threads is None:
docling_num_threads = os.getenv("DOCLING_NUM_THREADS")
@ -63,6 +77,12 @@ class AcceleratorOptions(BaseSettings):
return data
class BaseOptions(BaseModel):
"""Base class for options."""
kind: ClassVar[str]
class TableFormerMode(str, Enum):
"""Modes for the TableFormer model."""
@ -79,13 +99,12 @@ class TableStructureOptions(BaseModel):
# are merged across table columns.
# False: Let table structure model define the text cells, ignore PDF cells.
)
mode: TableFormerMode = TableFormerMode.FAST
mode: TableFormerMode = TableFormerMode.ACCURATE
class OcrOptions(BaseModel):
class OcrOptions(BaseOptions):
"""OCR options."""
kind: str
lang: List[str]
force_full_page_ocr: bool = False # If enabled a full page OCR is always applied
bitmap_area_threshold: float = (
@ -96,7 +115,7 @@ class OcrOptions(BaseModel):
class RapidOcrOptions(OcrOptions):
"""Options for the RapidOCR engine."""
kind: Literal["rapidocr"] = "rapidocr"
kind: ClassVar[Literal["rapidocr"]] = "rapidocr"
# English and Chinese are the most commonly used models and have been tested with RapidOCR.
lang: List[str] = [
@ -125,6 +144,7 @@ class RapidOcrOptions(OcrOptions):
det_model_path: Optional[str] = None # same default as rapidocr
cls_model_path: Optional[str] = None # same default as rapidocr
rec_model_path: Optional[str] = None # same default as rapidocr
rec_keys_path: Optional[str] = None # same default as rapidocr
model_config = ConfigDict(
extra="forbid",
@ -134,12 +154,12 @@ class RapidOcrOptions(OcrOptions):
class EasyOcrOptions(OcrOptions):
"""Options for the EasyOCR engine."""
kind: Literal["easyocr"] = "easyocr"
kind: ClassVar[Literal["easyocr"]] = "easyocr"
lang: List[str] = ["fr", "de", "es", "en"]
use_gpu: Optional[bool] = None
confidence_threshold: float = 0.65
confidence_threshold: float = 0.5
model_storage_directory: Optional[str] = None
recog_network: Optional[str] = "standard"
@ -154,7 +174,7 @@ class EasyOcrOptions(OcrOptions):
class TesseractCliOcrOptions(OcrOptions):
"""Options for the TesseractCli engine."""
kind: Literal["tesseract"] = "tesseract"
kind: ClassVar[Literal["tesseract"]] = "tesseract"
lang: List[str] = ["fra", "deu", "spa", "eng"]
tesseract_cmd: str = "tesseract"
path: Optional[str] = None
@ -167,7 +187,7 @@ class TesseractCliOcrOptions(OcrOptions):
class TesseractOcrOptions(OcrOptions):
"""Options for the Tesseract engine."""
kind: Literal["tesserocr"] = "tesserocr"
kind: ClassVar[Literal["tesserocr"]] = "tesserocr"
lang: List[str] = ["fra", "deu", "spa", "eng"]
path: Optional[str] = None
@ -179,7 +199,7 @@ class TesseractOcrOptions(OcrOptions):
class OcrMacOptions(OcrOptions):
"""Options for the Mac OCR engine."""
kind: Literal["ocrmac"] = "ocrmac"
kind: ClassVar[Literal["ocrmac"]] = "ocrmac"
lang: List[str] = ["fr-FR", "de-DE", "es-ES", "en-US"]
recognition: str = "accurate"
framework: str = "vision"
@ -189,6 +209,110 @@ class OcrMacOptions(OcrOptions):
)
class PictureDescriptionBaseOptions(BaseOptions):
batch_size: int = 8
scale: float = 2
bitmap_area_threshold: float = (
0.2 # percentage of the area for a bitmap to be processed with the models
)
class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
kind: ClassVar[Literal["api"]] = "api"
url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
timeout: float = 20
prompt: str = "Describe this image in a few sentences."
provenance: str = ""
class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
kind: ClassVar[Literal["vlm"]] = "vlm"
repo_id: str
prompt: str = "Describe this image in a few sentences."
# Config from here https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig
generation_config: Dict[str, Any] = dict(max_new_tokens=200, do_sample=False)
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
smolvlm_picture_description = PictureDescriptionVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct"
)
# phi_picture_description = PictureDescriptionVlmOptions(repo_id="microsoft/Phi-3-vision-128k-instruct")
granite_picture_description = PictureDescriptionVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
prompt="What is shown in this image?",
)
class BaseVlmOptions(BaseModel):
kind: str
prompt: str
class ResponseFormat(str, Enum):
DOCTAGS = "doctags"
MARKDOWN = "markdown"
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
class HuggingFaceVlmOptions(BaseVlmOptions):
kind: Literal["hf_model_options"] = "hf_model_options"
repo_id: str
load_in_8bit: bool = True
llm_int8_threshold: float = 6.0
quantized: bool = False
inference_framework: InferenceFramework
response_format: ResponseFormat
@property
def repo_cache_folder(self) -> str:
return self.repo_id.replace("/", "--")
smoldocling_vlm_mlx_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview-mlx-bf16",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
)
smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS,
)
granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
repo_id="ibm-granite/granite-vision-3.1-2b-preview",
# prompt="OCR the full page to markdown.",
prompt="OCR this image.",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
GRANITE_VISION = "granite_vision"
# Define an enum for the backend options
class PdfBackend(str, Enum):
"""Enum of valid PDF backends."""
@ -196,9 +320,11 @@ class PdfBackend(str, Enum):
PYPDFIUM2 = "pypdfium2"
DLPARSE_V1 = "dlparse_v1"
DLPARSE_V2 = "dlparse_v2"
DLPARSE_V4 = "dlparse_v4"
# Define an enum for the ocr engines
@deprecated("Use ocr_factory.registered_enum")
class OcrEngine(str, Enum):
"""Enum of valid OCR engines."""
@ -217,23 +343,47 @@ class PipelineOptions(BaseModel):
)
document_timeout: Optional[float] = None
accelerator_options: AcceleratorOptions = AcceleratorOptions()
enable_remote_services: bool = False
allow_external_plugins: bool = False
class PdfPipelineOptions(PipelineOptions):
class PaginatedPipelineOptions(PipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
images_scale: float = 1.0
generate_page_images: bool = False
generate_picture_images: bool = False
class VlmPipelineOptions(PaginatedPipelineOptions):
generate_page_images: bool = True
force_backend_text: bool = (
False # (To be used with vlms, or other generative models)
)
# If True, text from backend will be used instead of generated text
vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options
class PdfPipelineOptions(PaginatedPipelineOptions):
"""Options for the PDF pipeline."""
artifacts_path: Optional[Union[Path, str]] = None
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
do_code_enrichment: bool = False # True: perform code OCR
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
do_picture_classification: bool = False # True: classify pictures in documents
do_picture_description: bool = False # True: describe pictures in documents
force_backend_text: bool = (
False # (To be used with vlms, or other generative models)
)
# If True, text from backend will be used instead of generated text
table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[
EasyOcrOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
OcrMacOptions,
RapidOcrOptions,
] = Field(EasyOcrOptions(), discriminator="kind")
ocr_options: OcrOptions = EasyOcrOptions()
picture_description_options: PictureDescriptionBaseOptions = (
smolvlm_picture_description
)
images_scale: float = 1.0
generate_page_images: bool = False
@ -246,3 +396,10 @@ class PdfPipelineOptions(PipelineOptions):
"before conversion and then use the `TableItem.get_image` function."
),
)
generate_parsed_pages: bool = False
class PdfPipeline(str, Enum):
STANDARD = "standard"
VLM = "vlm"

@ -1,13 +1,28 @@
import sys
from pathlib import Path
from typing import Annotated, Optional, Tuple
from pydantic import BaseModel
from pydantic import BaseModel, PlainValidator
from pydantic_settings import BaseSettings, SettingsConfigDict
def _validate_page_range(v: Tuple[int, int]) -> Tuple[int, int]:
if v[0] < 1 or v[1] < v[0]:
raise ValueError(
"Invalid page range: start must be ≥ 1 and end must be ≥ start."
)
return v
PageRange = Annotated[Tuple[int, int], PlainValidator(_validate_page_range)]
DEFAULT_PAGE_RANGE: PageRange = (1, sys.maxsize)
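Review note: the page range is 1-based and inclusive on both ends. The sketch below repeats the validation rule from `_validate_page_range` without pydantic and adds a hypothetical helper, `pages_to_process`, to show how such a range could be clamped to a document's length. The helper is an illustration only; how docling consumes `page_range` is not shown in this diff.

```python
import sys
from typing import List, Tuple

DEFAULT_PAGE_RANGE: Tuple[int, int] = (1, sys.maxsize)


def validate_page_range(v: Tuple[int, int]) -> Tuple[int, int]:
    # Same rule as in the diff: start is 1-based, end must not precede start.
    if v[0] < 1 or v[1] < v[0]:
        raise ValueError(
            "Invalid page range: start must be >= 1 and end must be >= start."
        )
    return v


def pages_to_process(page_range: Tuple[int, int], num_pages: int) -> List[int]:
    """Hypothetical helper: list the 1-based page numbers inside the range."""
    start, end = validate_page_range(page_range)
    # Clamp the (possibly sys.maxsize) end to the real page count.
    return list(range(start, min(end, num_pages) + 1))


if __name__ == "__main__":
    print(pages_to_process((2, 4), num_pages=10))
    print(pages_to_process(DEFAULT_PAGE_RANGE, num_pages=3))
```

Using `sys.maxsize` as the default upper bound lets "all pages" be expressed as an ordinary tuple rather than a special `None` case.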
class DocumentLimits(BaseModel):
max_num_pages: int = sys.maxsize
max_file_size: int = sys.maxsize
page_range: PageRange = DEFAULT_PAGE_RANGE
class BatchConcurrencySettings(BaseModel):
@ -46,5 +61,8 @@ class AppSettings(BaseSettings):
perf: BatchConcurrencySettings
debug: DebugSettings
cache_dir: Path = Path.home() / ".cache" / "docling"
artifacts_path: Optional[Path] = None
settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings())

@ -1,21 +1,25 @@
import hashlib
import logging
import math
import sys
import time
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, Iterator, List, Optional, Type, Union
from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, ConfigDict, model_validator, validate_call
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.asciidoc_backend import AsciiDocBackend
from docling.backend.docling_parse_v2_backend import DoclingParseV2DocumentBackend
from docling.backend.csv_backend import CsvDocumentBackend
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.backend.html_backend import HTMLDocumentBackend
from docling.backend.json.docling_json_backend import DoclingJSONBackend
from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.msexcel_backend import MsExcelDocumentBackend
from docling.backend.mspowerpoint_backend import MsPowerpointDocumentBackend
from docling.backend.msword_backend import MsWordDocumentBackend
from docling.backend.xml.pubmed_backend import PubMedDocumentBackend
from docling.backend.xml.jats_backend import JatsDocumentBackend
from docling.backend.xml.uspto_backend import PatentUsptoDocumentBackend
from docling.datamodel.base_models import (
ConversionStatus,
@ -30,7 +34,12 @@ from docling.datamodel.document import (
_DocumentConversionInput,
)
from docling.datamodel.pipeline_options import PipelineOptions
from docling.datamodel.settings import DocumentLimits, settings
from docling.datamodel.settings import (
DEFAULT_PAGE_RANGE,
DocumentLimits,
PageRange,
settings,
)
from docling.exceptions import ConversionError
from docling.pipeline.base_pipeline import BasePipeline
from docling.pipeline.simple_pipeline import SimplePipeline
@ -54,6 +63,11 @@ class FormatOption(BaseModel):
return self
class CsvFormatOption(FormatOption):
pipeline_cls: Type = SimplePipeline
backend: Type[AbstractDocumentBackend] = CsvDocumentBackend
class ExcelFormatOption(FormatOption):
pipeline_cls: Type = SimplePipeline
backend: Type[AbstractDocumentBackend] = MsExcelDocumentBackend
@ -89,23 +103,26 @@ class PatentUsptoFormatOption(FormatOption):
backend: Type[PatentUsptoDocumentBackend] = PatentUsptoDocumentBackend
class XMLPubMedFormatOption(FormatOption):
class XMLJatsFormatOption(FormatOption):
pipeline_cls: Type = SimplePipeline
backend: Type[AbstractDocumentBackend] = PubMedDocumentBackend
backend: Type[AbstractDocumentBackend] = JatsDocumentBackend
class ImageFormatOption(FormatOption):
pipeline_cls: Type = StandardPdfPipeline
backend: Type[AbstractDocumentBackend] = DoclingParseV2DocumentBackend
backend: Type[AbstractDocumentBackend] = DoclingParseV4DocumentBackend
class PdfFormatOption(FormatOption):
pipeline_cls: Type = StandardPdfPipeline
backend: Type[AbstractDocumentBackend] = DoclingParseV2DocumentBackend
backend: Type[AbstractDocumentBackend] = DoclingParseV4DocumentBackend
def _get_default_option(format: InputFormat) -> FormatOption:
format_to_default_options = {
InputFormat.CSV: FormatOption(
pipeline_cls=SimplePipeline, backend=CsvDocumentBackend
),
InputFormat.XLSX: FormatOption(
pipeline_cls=SimplePipeline, backend=MsExcelDocumentBackend
),
@ -127,14 +144,17 @@ def _get_default_option(format: InputFormat) -> FormatOption:
InputFormat.XML_USPTO: FormatOption(
pipeline_cls=SimplePipeline, backend=PatentUsptoDocumentBackend
),
InputFormat.XML_PUBMED: FormatOption(
pipeline_cls=SimplePipeline, backend=PubMedDocumentBackend
InputFormat.XML_JATS: FormatOption(
pipeline_cls=SimplePipeline, backend=JatsDocumentBackend
),
InputFormat.IMAGE: FormatOption(
pipeline_cls=StandardPdfPipeline, backend=DoclingParseV2DocumentBackend
pipeline_cls=StandardPdfPipeline, backend=DoclingParseV4DocumentBackend
),
InputFormat.PDF: FormatOption(
pipeline_cls=StandardPdfPipeline, backend=DoclingParseV2DocumentBackend
pipeline_cls=StandardPdfPipeline, backend=DoclingParseV4DocumentBackend
),
InputFormat.JSON_DOCLING: FormatOption(
pipeline_cls=SimplePipeline, backend=DoclingJSONBackend
),
}
if (options := format_to_default_options.get(format)) is not None:
@ -162,7 +182,14 @@ class DocumentConverter:
)
for format in self.allowed_formats
}
self.initialized_pipelines: Dict[Type[BasePipeline], BasePipeline] = {}
self.initialized_pipelines: Dict[
Tuple[Type[BasePipeline], str], BasePipeline
] = {}
def _get_pipeline_options_hash(self, pipeline_options: PipelineOptions) -> str:
"""Generate a hash of pipeline options to use as part of the cache key."""
options_str = str(pipeline_options.model_dump())
return hashlib.md5(options_str.encode("utf-8")).hexdigest()
def initialize_pipeline(self, format: InputFormat):
"""Initialize the conversion pipeline for the selected format."""
@ -180,6 +207,7 @@ class DocumentConverter:
raises_on_error: bool = True,
max_num_pages: int = sys.maxsize,
max_file_size: int = sys.maxsize,
page_range: PageRange = DEFAULT_PAGE_RANGE,
) -> ConversionResult:
all_res = self.convert_all(
source=[source],
@ -187,6 +215,7 @@ class DocumentConverter:
max_num_pages=max_num_pages,
max_file_size=max_file_size,
headers=headers,
page_range=page_range,
)
return next(all_res)
@ -198,10 +227,12 @@ class DocumentConverter:
raises_on_error: bool = True, # True: raises on first conversion error; False: does not raise on conv error
max_num_pages: int = sys.maxsize,
max_file_size: int = sys.maxsize,
page_range: PageRange = DEFAULT_PAGE_RANGE,
) -> Iterator[ConversionResult]:
limits = DocumentLimits(
max_num_pages=max_num_pages,
max_file_size=max_file_size,
page_range=page_range,
)
conv_input = _DocumentConversionInput(
path_or_stream_iterator=source, limits=limits, headers=headers
@ -256,31 +287,36 @@ class DocumentConverter:
yield item
def _get_pipeline(self, doc_format: InputFormat) -> Optional[BasePipeline]:
"""Retrieve or initialize a pipeline, reusing instances based on class and options."""
fopt = self.format_to_options.get(doc_format)
if fopt is None:
if fopt is None or fopt.pipeline_options is None:
return None
else:
pipeline_class = fopt.pipeline_cls
pipeline_options = fopt.pipeline_options
if pipeline_options is None:
return None
# TODO this will ignore if different options have been defined for the same pipeline class.
if (
pipeline_class not in self.initialized_pipelines
or self.initialized_pipelines[pipeline_class].pipeline_options
!= pipeline_options
):
self.initialized_pipelines[pipeline_class] = pipeline_class(
pipeline_class = fopt.pipeline_cls
pipeline_options = fopt.pipeline_options
options_hash = self._get_pipeline_options_hash(pipeline_options)
# Use a composite key to cache pipelines
cache_key = (pipeline_class, options_hash)
if cache_key not in self.initialized_pipelines:
_log.info(
f"Initializing pipeline for {pipeline_class.__name__} with options hash {options_hash}"
)
self.initialized_pipelines[cache_key] = pipeline_class(
pipeline_options=pipeline_options
)
return self.initialized_pipelines[pipeline_class]
else:
_log.debug(
f"Reusing cached pipeline for {pipeline_class.__name__} with options hash {options_hash}"
)
return self.initialized_pipelines[cache_key]
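Review note: the composite cache key can be sketched in isolation as below. `ToyPipeline` and `PipelineCache` are hypothetical stand-ins, and plain dicts replace the pydantic options (the diff hashes `str(pipeline_options.model_dump())`); this is an illustration of the keying idea, not the converter's implementation.

```python
import hashlib
from typing import Any, Dict, Tuple, Type


class ToyPipeline:
    """Hypothetical stand-in for a pipeline whose construction is costly."""

    def __init__(self, pipeline_options: Dict[str, Any]):
        self.pipeline_options = pipeline_options


class PipelineCache:
    """Caches one pipeline instance per (class, options hash) pair."""

    def __init__(self) -> None:
        self._pipelines: Dict[Tuple[Type, str], Any] = {}

    @staticmethod
    def options_hash(options: Dict[str, Any]) -> str:
        # MD5 serves only as a compact cache key here, not as a security measure.
        return hashlib.md5(str(options).encode("utf-8")).hexdigest()

    def get(self, pipeline_cls: Type, options: Dict[str, Any]) -> Any:
        key = (pipeline_cls, self.options_hash(options))
        if key not in self._pipelines:
            # First request for this class and options: build and remember it.
            self._pipelines[key] = pipeline_cls(pipeline_options=options)
        return self._pipelines[key]


if __name__ == "__main__":
    cache = PipelineCache()
    first = cache.get(ToyPipeline, {"do_ocr": True})
    again = cache.get(ToyPipeline, {"do_ocr": True})
    other = cache.get(ToyPipeline, {"do_ocr": False})
    print(first is again, first is other)
```

Compared with keying on the class alone, this keeps several differently configured pipelines of the same class alive instead of replacing one with another. One caveat of hashing a string representation: with plain dicts, two equal dicts built in a different key order produce different strings and therefore different keys.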
def _process_document(
self, in_doc: InputDocument, raises_on_error: bool
) -> ConversionResult:
valid = (
self.allowed_formats is not None and in_doc.format in self.allowed_formats
)
@ -322,7 +358,6 @@ class DocumentConverter:
else:
if raises_on_error:
raise ConversionError(f"Input document {in_doc.file} is not valid.")
else:
# invalid doc or not of desired format
conv_res = ConversionResult(

@ -4,3 +4,7 @@ class BaseError(RuntimeError):
class ConversionError(BaseError):
pass
class OperationNotAllowed(BaseError):
pass

@ -1,10 +1,20 @@
from abc import ABC, abstractmethod
from typing import Any, Iterable
from typing import Any, Generic, Iterable, Optional, Protocol, Type
from docling_core.types.doc import DoclingDocument, NodeItem
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from typing_extensions import TypeVar
from docling.datamodel.base_models import Page
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseOptions
from docling.datamodel.settings import settings
class BaseModelWithOptions(Protocol):
@classmethod
def get_options_type(cls) -> Type[BaseOptions]: ...
def __init__(self, *, options: BaseOptions, **kwargs): ...
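Review note: one way a protocol like `BaseModelWithOptions` can be used is to look up a model class from the `kind` of its options type. The sketch below is hypothetical throughout: the `Fast*`/`Slow*` classes, `build_registry`, and `create_model` are invented for illustration, and the real factory referenced elsewhere in this commit is not shown in this diff.

```python
from typing import ClassVar, Dict, Iterable, Type


class BaseOptions:
    """Simplified stand-in for the pydantic BaseOptions in the diff."""

    kind: ClassVar[str]


class FastOcrOptions(BaseOptions):
    kind: ClassVar[str] = "fast"


class SlowOcrOptions(BaseOptions):
    kind: ClassVar[str] = "slow"


class FastOcrModel:
    def __init__(self, *, options: BaseOptions, **kwargs):
        self.options = options

    @classmethod
    def get_options_type(cls) -> Type[BaseOptions]:
        return FastOcrOptions


class SlowOcrModel:
    def __init__(self, *, options: BaseOptions, **kwargs):
        self.options = options

    @classmethod
    def get_options_type(cls) -> Type[BaseOptions]:
        return SlowOcrOptions


def build_registry(model_classes: Iterable[type]) -> Dict[str, type]:
    # Each model class advertises its options type; index by that type's kind.
    return {cls.get_options_type().kind: cls for cls in model_classes}


def create_model(registry: Dict[str, type], options: BaseOptions, **kwargs):
    # Pick the model class registered for this options kind and build it.
    return registry[options.kind](options=options, **kwargs)


if __name__ == "__main__":
    registry = build_registry([FastOcrModel, SlowOcrModel])
    model = create_model(registry, SlowOcrOptions())
    print(type(model).__name__)
```

Declaring `kind` as a `ClassVar` means it identifies the options class itself rather than being a per-instance field, which is what makes this kind of lookup possible without instantiating anything.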
class BasePageModel(ABC):
@ -15,14 +25,71 @@ class BasePageModel(ABC):
pass
class BaseEnrichmentModel(ABC):
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
class GenericEnrichmentModel(ABC, Generic[EnrichElementT]):
elements_batch_size: int = settings.perf.elements_batch_size
@abstractmethod
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
pass
@abstractmethod
def __call__(
self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
) -> Iterable[Any]:
def prepare_element(
self, conv_res: ConversionResult, element: NodeItem
) -> Optional[EnrichElementT]:
pass
@abstractmethod
def __call__(
self, doc: DoclingDocument, element_batch: Iterable[EnrichElementT]
) -> Iterable[NodeItem]:
pass
class BaseEnrichmentModel(GenericEnrichmentModel[NodeItem]):
def prepare_element(
self, conv_res: ConversionResult, element: NodeItem
) -> Optional[NodeItem]:
if self.is_processable(doc=conv_res.document, element=element):
return element
return None
class BaseItemAndImageEnrichmentModel(
GenericEnrichmentModel[ItemAndImageEnrichmentElement]
):
images_scale: float
expansion_factor: float = 0.0
def prepare_element(
self, conv_res: ConversionResult, element: NodeItem
) -> Optional[ItemAndImageEnrichmentElement]:
if not self.is_processable(doc=conv_res.document, element=element):
return None
assert isinstance(element, DocItem)
element_prov = element.prov[0]
bbox = element_prov.bbox
width = bbox.r - bbox.l
height = bbox.t - bbox.b
# TODO: move to a utility in the BoundingBox class
expanded_bbox = BoundingBox(
l=bbox.l - width * self.expansion_factor,
t=bbox.t + height * self.expansion_factor,
r=bbox.r + width * self.expansion_factor,
b=bbox.b - height * self.expansion_factor,
coord_origin=bbox.coord_origin,
)
page_ix = element_prov.page_no - 1
cropped_image = conv_res.pages[page_ix].get_image(
scale=self.images_scale, cropbox=expanded_bbox
)
return ItemAndImageEnrichmentElement(item=element, image=cropped_image)
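Review note: the bounding-box expansion in `prepare_element` can be checked with a small standalone sketch. `Box` and `expand_box` are hypothetical names; the arithmetic mirrors the diff, and the coordinate origin is left out for brevity.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Box:
    """Minimal stand-in for a bounding box with left, top, right, bottom."""

    l: float
    t: float
    r: float
    b: float


def expand_box(box: Box, factor: float) -> Box:
    # Grow each side by `factor` times the box's width or height.
    width = box.r - box.l
    height = box.t - box.b
    return Box(
        l=box.l - width * factor,
        t=box.t + height * factor,
        r=box.r + width * factor,
        b=box.b - height * factor,
    )


if __name__ == "__main__":
    # Bottom-left origin: top is numerically larger than bottom.
    print(expand_box(Box(l=10, t=50, r=30, b=20), factor=0.5))
    # Top-left origin: top is numerically smaller than bottom.
    print(expand_box(Box(l=10, t=20, r=30, b=50), factor=0.5))
```

Because `height` is computed as `t - b`, it becomes negative for a top-left origin, and the same formulas still enlarge the box in both directions.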

@ -2,25 +2,33 @@ import copy
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Iterable, List
from typing import Iterable, List, Optional, Type
import numpy as np
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, PdfTextCell, TextCell
from PIL import Image, ImageDraw
from rtree import index
from scipy.ndimage import find_objects, label
from scipy.ndimage import binary_dilation, find_objects, label
from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import OcrOptions
from docling.datamodel.pipeline_options import AcceleratorOptions, OcrOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseModelWithOptions, BasePageModel
_log = logging.getLogger(__name__)
class BaseOcrModel(BasePageModel):
def __init__(self, enabled: bool, options: OcrOptions):
class BaseOcrModel(BasePageModel, BaseModelWithOptions):
def __init__(
self,
*,
enabled: bool,
artifacts_path: Optional[Path],
options: OcrOptions,
accelerator_options: AcceleratorOptions,
):
self.enabled = enabled
self.options = options
@ -43,6 +51,12 @@ class BaseOcrModel(BasePageModel):
np_image = np.array(image)
# Dilate the image by 10 pixels to merge nearby bitmap rectangles
structure = np.ones(
(20, 20)
) # Create a 20x20 structure element (10 pixels in all directions)
np_image = binary_dilation(np_image > 0, structure=structure)
# Find the connected components
labeled_image, num_features = label(
np_image > 0
@ -72,7 +86,7 @@ class BaseOcrModel(BasePageModel):
bitmap_rects = []
coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)
# return full-page rectangle if sufficiently covered with bitmaps
# return full-page rectangle if page is dominantly covered with bitmaps
if self.options.force_full_page_ocr or coverage > max(
BITMAP_COVERAGE_TRESHOLD, self.options.bitmap_area_threshold
):
@ -85,17 +99,11 @@ class BaseOcrModel(BasePageModel):
coord_origin=CoordOrigin.TOPLEFT,
)
]
# return individual rectangles if the bitmap coverage is smaller
else: # coverage <= BITMAP_COVERAGE_TRESHOLD:
# skip OCR if the bitmap area on the page is smaller than the options threshold
ocr_rects = [
rect
for rect in ocr_rects
if rect.area() / (page.size.width * page.size.height)
> self.options.bitmap_area_threshold
]
# return individual rectangles if the bitmap coverage is above the threshold
elif coverage > self.options.bitmap_area_threshold:
return ocr_rects
else: # overall coverage of bitmaps is too low, drop all bitmap rectangles.
return []
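Review note: the new three-way decision can be summarized in a standalone sketch. `select_ocr_rects` is a hypothetical function; `full_page_threshold` stands in for `BITMAP_COVERAGE_TRESHOLD`, whose value is not shown in this diff, and the numbers in the usage example are illustrative only.

```python
from typing import List, Tuple

Rect = Tuple[float, float, float, float]


def select_ocr_rects(
    coverage: float,
    rects: List[Rect],
    page_rect: Rect,
    *,
    force_full_page_ocr: bool,
    bitmap_area_threshold: float,
    full_page_threshold: float,
) -> List[Rect]:
    """Choose which regions to OCR from the page's total bitmap coverage."""
    if force_full_page_ocr or coverage > max(
        full_page_threshold, bitmap_area_threshold
    ):
        # Page is dominated by bitmaps (or OCR is forced): OCR the whole page.
        return [page_rect]
    elif coverage > bitmap_area_threshold:
        # Enough bitmap area overall: OCR the individual rectangles.
        return rects
    # Overall bitmap coverage is too low: skip OCR entirely.
    return []


if __name__ == "__main__":
    page = (0.0, 0.0, 600.0, 800.0)
    rects = [(10.0, 10.0, 200.0, 200.0)]
    for cov in (0.9, 0.3, 0.02):
        chosen = select_ocr_rects(
            cov,
            rects,
            page,
            force_full_page_ocr=False,
            bitmap_area_threshold=0.1,  # illustrative value
            full_page_threshold=0.75,  # illustrative value
        )
        print(cov, chosen)
```

The behavioral change relative to the removed lines is that the threshold is now applied once to the page's overall coverage instead of to each rectangle's own area.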
# Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
def _filter_ocr_cells(self, ocr_cells, programmatic_cells):
@ -104,11 +112,13 @@ class BaseOcrModel(BasePageModel):
p.dimension = 2
idx = index.Index(properties=p)
for i, cell in enumerate(programmatic_cells):
idx.insert(i, cell.bbox.as_tuple())
idx.insert(i, cell.rect.to_bounding_box().as_tuple())
def is_overlapping_with_existing_cells(ocr_cell):
# Query the R-tree to get overlapping rectangles
possible_matches_index = list(idx.intersection(ocr_cell.bbox.as_tuple()))
possible_matches_index = list(
idx.intersection(ocr_cell.rect.to_bounding_box().as_tuple())
)
return (
len(possible_matches_index) > 0
@ -125,10 +135,7 @@ class BaseOcrModel(BasePageModel):
"""
if self.options.force_full_page_ocr:
# If a full page OCR is forced, use only the OCR cells
cells = [
Cell(id=c_ocr.id, text=c_ocr.text, bbox=c_ocr.bbox)
for c_ocr in ocr_cells
]
cells = ocr_cells
return cells
## Remove OCR cells which overlap with programmatic cells.
@ -138,20 +145,35 @@ class BaseOcrModel(BasePageModel):
def draw_ocr_rects_and_cells(self, conv_res, page, ocr_rects, show: bool = False):
image = copy.deepcopy(page.image)
scale_x = image.width / page.size.width
scale_y = image.height / page.size.height
draw = ImageDraw.Draw(image, "RGBA")
# Draw OCR rectangles as yellow filled rect
for rect in ocr_rects:
x0, y0, x1, y1 = rect.as_tuple()
y0 *= scale_y
y1 *= scale_y
x0 *= scale_x
x1 *= scale_x
shade_color = (255, 255, 0, 40) # transparent yellow
draw.rectangle([(x0, y0), (x1, y1)], fill=shade_color, outline=None)
# Draw OCR and programmatic cells
for tc in page.cells:
x0, y0, x1, y1 = tc.bbox.as_tuple()
color = "red"
if isinstance(tc, OcrCell):
color = "magenta"
x0, y0, x1, y1 = tc.rect.to_bounding_box().as_tuple()
y0 *= scale_y
y1 *= scale_y
x0 *= scale_x
x1 *= scale_x
if y1 <= y0:
y1, y0 = y0, y1
color = "magenta" if tc.from_ocr else "gray"
draw.rectangle([(x0, y0), (x1, y1)], outline=color)
if show:
@ -171,3 +193,8 @@ class BaseOcrModel(BasePageModel):
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
pass
@classmethod
@abstractmethod
def get_options_type(cls) -> Type[OcrOptions]:
pass

@ -0,0 +1,330 @@
import re
from collections import Counter
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple, Union
import numpy as np
from docling_core.types.doc import (
CodeItem,
DocItemLabel,
DoclingDocument,
NodeItem,
TextItem,
)
from docling_core.types.doc.labels import CodeLanguageLabel
from PIL import Image, ImageOps
from pydantic import BaseModel
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.utils.accelerator_utils import decide_device
class CodeFormulaModelOptions(BaseModel):
"""
Configuration options for the CodeFormulaModel.
Attributes
----------
kind : str
Type of the model. Fixed value "code_formula".
do_code_enrichment : bool
True if code enrichment is enabled, False otherwise.
do_formula_enrichment : bool
True if formula enrichment is enabled, False otherwise.
"""
kind: Literal["code_formula"] = "code_formula"
do_code_enrichment: bool = True
do_formula_enrichment: bool = True
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
"""
Model for processing and enriching documents with code and formula predictions.
Attributes
----------
enabled : bool
True if the model is enabled, False otherwise.
options : CodeFormulaModelOptions
Configuration options for the CodeFormulaModel.
code_formula_model : CodeFormulaPredictor
The predictor model for code and formula processing.
Methods
-------
__init__(self, enabled, artifacts_path, options, accelerator_options)
Initializes the CodeFormulaModel with the given configuration options.
is_processable(self, doc, element)
Determines if a given element in a document can be processed by the model.
__call__(self, doc, element_batch)
Processes the given batch of elements and enriches them with predictions.
"""
_model_repo_folder = "ds4sd--CodeFormula"
elements_batch_size = 5
images_scale = 1.66 # = 120 dpi, aligned with training data resolution
expansion_factor = 0.18
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: CodeFormulaModelOptions,
accelerator_options: AcceleratorOptions,
):
"""
Initializes the CodeFormulaModel with the given configuration.
Parameters
----------
enabled : bool
True if the model is enabled, False otherwise.
artifacts_path : Optional[Path]
Path to the directory containing the model artifacts. If None, the models are downloaded.
options : CodeFormulaModelOptions
Configuration options for the model.
accelerator_options : AcceleratorOptions
Options specifying the device and number of threads for acceleration.
"""
self.enabled = enabled
self.options = options
if self.enabled:
device = decide_device(accelerator_options.device)
from docling_ibm_models.code_formula_model.code_formula_predictor import (
CodeFormulaPredictor,
)
if artifacts_path is None:
artifacts_path = self.download_models()
else:
artifacts_path = artifacts_path / self._model_repo_folder
self.code_formula_model = CodeFormulaPredictor(
artifacts_path=str(artifacts_path),
device=device,
num_threads=accelerator_options.num_threads,
)
@staticmethod
def download_models(
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/CodeFormula",
force_download=force,
local_dir=local_dir,
revision="v1.0.2",
)
return Path(download_path)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if a given element in a document can be processed by the model.
Parameters
----------
doc : DoclingDocument
The document being processed.
element : NodeItem
The element within the document to check.
Returns
-------
bool
True if the element can be processed, False otherwise.
"""
return self.enabled and (
(isinstance(element, CodeItem) and self.options.do_code_enrichment)
or (
isinstance(element, TextItem)
and element.label == DocItemLabel.FORMULA
and self.options.do_formula_enrichment
)
)
def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
"""Extracts a programming language from the beginning of a string.
This function checks if the input string starts with a pattern of the form
``<_some_language_>``. If it does, it extracts the language string and returns
a tuple of (remainder, language). Otherwise, it returns the original string
and `None`.
Args:
input_string (str): The input string, which may start with ``<_language_>``.
Returns:
Tuple[str, Optional[str]]:
A tuple where:
- The first element is either:
- The remainder of the string (everything after ``<_language_>``),
if a match is found; or
- The original string, if no match is found.
- The second element is the extracted language if a match is found;
otherwise, `None`.
"""
pattern = r"^<_([^_>]+)_>\s(.*)"
match = re.match(pattern, input_string, flags=re.DOTALL)
if match:
language = str(match.group(1)) # the captured programming language
remainder = str(match.group(2)) # everything after the <_language_>
return remainder, language
else:
return input_string, None
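Review note: the language-prefix parsing is easy to exercise on its own. The sketch below reuses the same pattern as `_extract_code_language` in a free function (`extract_code_language` is a hypothetical name); the sample strings are made up for illustration.

```python
import re
from typing import Optional, Tuple


def extract_code_language(input_string: str) -> Tuple[str, Optional[str]]:
    """Split a leading ``<_language_>`` tag from the rest of the string."""
    pattern = r"^<_([^_>]+)_>\s(.*)"
    match = re.match(pattern, input_string, flags=re.DOTALL)
    if match:
        # group(1) is the language, group(2) is everything after the tag.
        return match.group(2), match.group(1)
    return input_string, None


if __name__ == "__main__":
    print(extract_code_language("<_Python_> print('hi')\nprint('bye')"))
    print(extract_code_language("x = 1"))
```

Two properties follow from the pattern: `re.DOTALL` lets the remainder span multiple lines, and because the language group excludes `_` and `>`, a tag whose name itself contains an underscore is not recognized and the input is returned unchanged.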
def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
"""
Converts a string to a corresponding `CodeLanguageLabel` enum member.
If the provided string does not match any value in `CodeLanguageLabel`,
it defaults to `CodeLanguageLabel.UNKNOWN`.
Args:
value (Optional[str]): The string representation of the code language or None.
Returns:
CodeLanguageLabel: The corresponding enum member if the value is valid,
otherwise `CodeLanguageLabel.UNKNOWN`.
"""
if not isinstance(value, str):
return CodeLanguageLabel.UNKNOWN
try:
return CodeLanguageLabel(value)
except ValueError:
return CodeLanguageLabel.UNKNOWN
def _get_most_frequent_edge_color(self, pil_img: Image.Image):
"""
Compute the most frequent color along the outer edges of a PIL image.
Parameters
----------
pil_img : Image.Image
A PIL Image in any mode (L, RGB, RGBA, etc.).
Returns
-------
(int) or (tuple): The most common edge color as a scalar (for grayscale) or
tuple (for RGB/RGBA).
"""
# Convert to NumPy array for easy pixel access
img_np = np.array(pil_img)
if img_np.ndim == 2:
# Grayscale-like image: shape (H, W)
# Extract edges: top row, bottom row, left col, right col
top = img_np[0, :] # shape (W,)
bottom = img_np[-1, :] # shape (W,)
left = img_np[:, 0] # shape (H,)
right = img_np[:, -1] # shape (H,)
# Concatenate all edges
edges = np.concatenate([top, bottom, left, right])
# Count frequencies
freq = Counter(edges.tolist())
most_common_value, _ = freq.most_common(1)[0]
return int(most_common_value) # single channel color
else:
# Color image: shape (H, W, C)
top = img_np[0, :, :] # shape (W, C)
bottom = img_np[-1, :, :] # shape (W, C)
left = img_np[:, 0, :] # shape (H, C)
right = img_np[:, -1, :] # shape (H, C)
# Concatenate edges along first axis
edges = np.concatenate([top, bottom, left, right], axis=0)
# Convert each color to a tuple for counting
edges_as_tuples = [tuple(pixel) for pixel in edges]
freq = Counter(edges_as_tuples)
most_common_value, _ = freq.most_common(1)[0]
return most_common_value # e.g. (R, G, B) or (R, G, B, A)
def _pad_with_most_frequent_edge_color(
self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
):
"""
Pads an image (PIL or NumPy array) using the most frequent edge color.
Parameters
----------
img : Union[Image.Image, np.ndarray]
The original image.
padding : tuple
Padding (left, top, right, bottom) in pixels.
Returns
-------
Image.Image: A new PIL image with the specified padding.
"""
if isinstance(img, np.ndarray):
pil_img = Image.fromarray(img)
else:
pil_img = img
most_freq_color = self._get_most_frequent_edge_color(pil_img)
padded_img = ImageOps.expand(pil_img, border=padding, fill=most_freq_color)
return padded_img
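Review note: the edge-color heuristic can be illustrated without PIL or NumPy. The sketch below applies the same counting idea to a grayscale image represented as a list of rows; `most_frequent_edge_value` is a hypothetical helper, not the method from the diff.

```python
from collections import Counter
from typing import List


def most_frequent_edge_value(grid: List[List[int]]) -> int:
    """Return the most common value along the outer border of a 2D grid."""
    top = grid[0]
    bottom = grid[-1]
    left = [row[0] for row in grid]
    right = [row[-1] for row in grid]
    # As in the diff, corner pixels appear in two edges and are counted twice.
    counts = Counter(top + bottom + left + right)
    return counts.most_common(1)[0][0]


if __name__ == "__main__":
    image = [
        [0, 0, 0],
        [0, 9, 0],
        [1, 1, 1],
    ]
    print(most_frequent_edge_value(image))
```

Padding with the dominant border value rather than a fixed white keeps the added margin visually consistent with crops that have a dark or colored background.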
def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[ItemAndImageEnrichmentElement],
) -> Iterable[NodeItem]:
"""
Processes the given batch of elements and enriches them with predictions.
Parameters
----------
doc : DoclingDocument
The document being processed.
element_batch : Iterable[ItemAndImageEnrichmentElement]
A batch of elements to be processed.
Returns
-------
Iterable[NodeItem]
An iterable of enriched elements.
"""
if not self.enabled:
for element in element_batch:
yield element.item
return
labels: List[str] = []
images: List[Union[Image.Image, np.ndarray]] = []
elements: List[TextItem] = []
for el in element_batch:
assert isinstance(el.item, TextItem)
elements.append(el.item)
labels.append(el.item.label)
images.append(
self._pad_with_most_frequent_edge_color(el.image, (20, 10, 20, 10))
)
outputs = self.code_formula_model.predict(images, labels)
for item, output in zip(elements, outputs):
if isinstance(item, CodeItem):
output, code_language = self._extract_code_language(output)
item.code_language = self._get_code_language_enum(code_language)
item.text = output
yield item
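The edge-color padding used above can be illustrated without NumPy or Pillow. This is a minimal sketch, assuming the image is a nested list of pixel values; `most_frequent_edge_color` and `pad_with_color` are hypothetical names for illustration and are not part of Docling. As in the code above, the four edges are concatenated as-is, so corner pixels are counted twice.

```python
from collections import Counter
from typing import List, Tuple, Union

Pixel = Union[int, Tuple[int, ...]]


def most_frequent_edge_color(img: List[List[Pixel]]) -> Pixel:
    """Return the most common value among the border pixels of a 2-D grid."""
    top = img[0]
    bottom = img[-1]
    left = [row[0] for row in img]
    right = [row[-1] for row in img]
    # Corners appear in two edges each, mirroring the concatenation above.
    edges = list(top) + list(bottom) + left + right
    value, _count = Counter(edges).most_common(1)[0]
    return value


def pad_with_color(
    img: List[List[Pixel]], padding: Tuple[int, int, int, int], fill: Pixel
) -> List[List[Pixel]]:
    """Pad a grid with `fill`; padding is (left, top, right, bottom) in pixels."""
    left, top, right, bottom = padding
    width = len(img[0]) + left + right
    rows = [[fill] * width for _ in range(top)]
    rows += [[fill] * left + list(row) + [fill] * right for row in img]
    rows += [[fill] * width for _ in range(bottom)]
    return rows


# A 3x3 grayscale crop whose border is mostly white (255).
crop = [
    [255, 255, 255],
    [255, 0, 255],
    [0, 255, 255],
]
fill = most_frequent_edge_color(crop)
padded = pad_with_color(crop, (2, 1, 2, 1), fill)
```

Padding with the dominant border color rather than a fixed black or white keeps the added margin visually consistent with the crop's background.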


@ -0,0 +1,190 @@
from pathlib import Path
from typing import Iterable, List, Literal, Optional, Tuple, Union
import numpy as np
from docling_core.types.doc import (
DoclingDocument,
NodeItem,
PictureClassificationClass,
PictureClassificationData,
PictureItem,
)
from PIL import Image
from pydantic import BaseModel
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.models.base_model import BaseEnrichmentModel
from docling.utils.accelerator_utils import decide_device
class DocumentPictureClassifierOptions(BaseModel):
"""
Options for configuring the DocumentPictureClassifier.
Attributes
----------
kind : Literal["document_picture_classifier"]
Identifier for the type of classifier.
"""
kind: Literal["document_picture_classifier"] = "document_picture_classifier"
class DocumentPictureClassifier(BaseEnrichmentModel):
"""
A model for classifying pictures in documents.
This class enriches document pictures with predicted classifications
based on a predefined set of classes.
Attributes
----------
enabled : bool
Whether the classifier is enabled for use.
options : DocumentPictureClassifierOptions
Configuration options for the classifier.
document_picture_classifier : DocumentFigureClassifierPredictor
The underlying prediction model, loaded if the classifier is enabled.
Methods
-------
__init__(enabled, artifacts_path, options, accelerator_options)
Initializes the classifier with specified configurations.
is_processable(doc, element)
Checks if the given element can be processed by the classifier.
__call__(doc, element_batch)
Processes a batch of elements and adds classification annotations.
"""
_model_repo_folder = "ds4sd--DocumentFigureClassifier"
images_scale = 2
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: DocumentPictureClassifierOptions,
accelerator_options: AcceleratorOptions,
):
"""
Initializes the DocumentPictureClassifier.
Parameters
----------
enabled : bool
Indicates whether the classifier is enabled.
artifacts_path : Optional[Path]
Path to the directory containing model artifacts.
options : DocumentPictureClassifierOptions
Configuration options for the classifier.
accelerator_options : AcceleratorOptions
Options for configuring the device and parallelism.
"""
self.enabled = enabled
self.options = options
if self.enabled:
device = decide_device(accelerator_options.device)
from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
DocumentFigureClassifierPredictor,
)
if artifacts_path is None:
artifacts_path = self.download_models()
else:
artifacts_path = artifacts_path / self._model_repo_folder
self.document_picture_classifier = DocumentFigureClassifierPredictor(
artifacts_path=str(artifacts_path),
device=device,
num_threads=accelerator_options.num_threads,
)
@staticmethod
def download_models(
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/DocumentFigureClassifier",
force_download=force,
local_dir=local_dir,
revision="v1.0.1",
)
return Path(download_path)
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
"""
Determines if the given element can be processed by the classifier.
Parameters
----------
doc : DoclingDocument
The document containing the element.
element : NodeItem
The element to be checked.
Returns
-------
bool
True if the element is a PictureItem and processing is enabled; False otherwise.
"""
return self.enabled and isinstance(element, PictureItem)
def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[NodeItem],
) -> Iterable[NodeItem]:
"""
Processes a batch of elements and enriches them with classification predictions.
Parameters
----------
doc : DoclingDocument
The document containing the elements to be processed.
element_batch : Iterable[NodeItem]
A batch of pictures to classify.
Returns
-------
Iterable[NodeItem]
An iterable of NodeItem objects after processing. A PictureClassificationData
entry is appended to the 'annotations' list of each picture.
"""
if not self.enabled:
for element in element_batch:
yield element
return
images: List[Union[Image.Image, np.ndarray]] = []
elements: List[PictureItem] = []
for el in element_batch:
assert isinstance(el, PictureItem)
elements.append(el)
img = el.get_image(doc)
assert img is not None
images.append(img)
outputs = self.document_picture_classifier.predict(images)
for element, output in zip(elements, outputs):
element.annotations.append(
PictureClassificationData(
provenance="DocumentPictureClassifier",
predicted_classes=[
PictureClassificationClass(
class_name=pred[0],
confidence=pred[1],
)
for pred in output
],
)
)
yield element
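The annotation step above pairs each picture with one predictor output and turns every `(class_name, confidence)` pair into a classification entry. The following is a rough sketch of that mapping, assuming the predictor returns one list of pairs per image. The dataclasses are hypothetical stand-ins for `PictureItem`, `PictureClassificationData`, and `PictureClassificationClass`, and the class names in the sample data are made up.

```python
from dataclasses import dataclass, field
from typing import Iterable, Iterator, List, Tuple


@dataclass
class ClassPrediction:
    # Hypothetical stand-in for PictureClassificationClass.
    class_name: str
    confidence: float


@dataclass
class ClassificationAnnotation:
    # Hypothetical stand-in for PictureClassificationData.
    provenance: str
    predicted_classes: List[ClassPrediction]


@dataclass
class Picture:
    # Hypothetical stand-in for PictureItem; only the annotations list matters here.
    annotations: List[ClassificationAnnotation] = field(default_factory=list)


def annotate(
    pictures: Iterable[Picture],
    outputs: Iterable[List[Tuple[str, float]]],
) -> Iterator[Picture]:
    """Attach one annotation per picture, pairing pictures with outputs in order."""
    for picture, output in zip(pictures, outputs):
        picture.annotations.append(
            ClassificationAnnotation(
                provenance="DocumentPictureClassifier",
                predicted_classes=[
                    ClassPrediction(class_name=name, confidence=score)
                    for name, score in output
                ],
            )
        )
        yield picture


# Assumed output shape: one list of (class_name, confidence) pairs per image.
fake_outputs = [[("bar_chart", 0.9), ("logo", 0.1)], [("logo", 0.8)]]
pictures = list(annotate([Picture(), Picture()], fake_outputs))
```

Because `zip` stops at the shorter input, the pairing is only correct if the predictor returns exactly one output per image, in the same order.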


@ -1,328 +0,0 @@
import copy
import random
from pathlib import Path
from typing import List, Union
from deepsearch_glm.andromeda_nlp import nlp_model
from docling_core.types.doc import BoundingBox, CoordOrigin, DoclingDocument
from docling_core.types.legacy_doc.base import BoundingBox as DsBoundingBox
from docling_core.types.legacy_doc.base import (
Figure,
PageDimensions,
PageReference,
Prov,
Ref,
)
from docling_core.types.legacy_doc.base import Table as DsSchemaTable
from docling_core.types.legacy_doc.base import TableCell
from docling_core.types.legacy_doc.document import BaseText
from docling_core.types.legacy_doc.document import (
CCSDocumentDescription as DsDocumentDescription,
)
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
from PIL import ImageDraw
from pydantic import BaseModel, ConfigDict, TypeAdapter
from docling.datamodel.base_models import (
Cluster,
ContainerElement,
FigureElement,
Table,
TextElement,
)
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
from docling.datamodel.settings import settings
from docling.utils.glm_utils import to_docling_document
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import create_hash
class GlmOptions(BaseModel):
model_config = ConfigDict(protected_namespaces=())
model_names: str = "" # e.g. "language;term;reference"
class GlmModel:
def __init__(self, options: GlmOptions):
self.options = options
self.model = nlp_model(loglevel="error", text_ordering=True)
def _to_legacy_document(self, conv_res) -> DsDocument:
title = ""
desc: DsDocumentDescription = DsDocumentDescription(logs=[])
page_hashes = [
PageReference(
hash=create_hash(conv_res.input.document_hash + ":" + str(p.page_no)),
page=p.page_no + 1,
model="default",
)
for p in conv_res.pages
]
file_info = DsFileInfoObject(
filename=conv_res.input.file.name,
document_hash=conv_res.input.document_hash,
num_pages=conv_res.input.page_count,
page_hashes=page_hashes,
)
main_text: List[Union[Ref, BaseText]] = []
tables: List[DsSchemaTable] = []
figures: List[Figure] = []
page_no_to_page = {p.page_no: p for p in conv_res.pages}
for element in conv_res.assembled.elements:
# Convert bboxes to lower-left origin.
target_bbox = DsBoundingBox(
element.cluster.bbox.to_bottom_left_origin(
page_no_to_page[element.page_no].size.height
).as_tuple()
)
if isinstance(element, TextElement):
main_text.append(
BaseText(
text=element.text,
obj_type=layout_label_to_ds_type.get(element.label),
name=element.label,
prov=[
Prov(
bbox=target_bbox,
page=element.page_no + 1,
span=[0, len(element.text)],
)
],
)
)
elif isinstance(element, Table):
index = len(tables)
ref_str = f"#/tables/{index}"
main_text.append(
Ref(
name=element.label,
obj_type=layout_label_to_ds_type.get(element.label),
ref=ref_str,
),
)
# Initialise empty table data grid (only empty cells)
table_data = [
[
TableCell(
text="",
# bbox=[0,0,0,0],
spans=[[i, j]],
obj_type="body",
)
for j in range(element.num_cols)
]
for i in range(element.num_rows)
]
# Overwrite cells in table data for which there is actual cell content.
for cell in element.table_cells:
for i in range(
min(cell.start_row_offset_idx, element.num_rows),
min(cell.end_row_offset_idx, element.num_rows),
):
for j in range(
min(cell.start_col_offset_idx, element.num_cols),
min(cell.end_col_offset_idx, element.num_cols),
):
celltype = "body"
if cell.column_header:
celltype = "col_header"
elif cell.row_header:
celltype = "row_header"
elif cell.row_section:
celltype = "row_section"
def make_spans(cell):
for rspan in range(
min(cell.start_row_offset_idx, element.num_rows),
min(cell.end_row_offset_idx, element.num_rows),
):
for cspan in range(
min(
cell.start_col_offset_idx, element.num_cols
),
min(cell.end_col_offset_idx, element.num_cols),
):
yield [rspan, cspan]
spans = list(make_spans(cell))
if cell.bbox is not None:
bbox = cell.bbox.to_bottom_left_origin(
page_no_to_page[element.page_no].size.height
).as_tuple()
else:
bbox = None
table_data[i][j] = TableCell(
text=cell.text,
bbox=bbox,
# col=j,
# row=i,
spans=spans,
obj_type=celltype,
# col_span=[cell.start_col_offset_idx, cell.end_col_offset_idx],
# row_span=[cell.start_row_offset_idx, cell.end_row_offset_idx]
)
tables.append(
DsSchemaTable(
num_cols=element.num_cols,
num_rows=element.num_rows,
obj_type=layout_label_to_ds_type.get(element.label),
data=table_data,
prov=[
Prov(
bbox=target_bbox,
page=element.page_no + 1,
span=[0, 0],
)
],
)
)
elif isinstance(element, FigureElement):
index = len(figures)
ref_str = f"#/figures/{index}"
main_text.append(
Ref(
name=element.label,
obj_type=layout_label_to_ds_type.get(element.label),
ref=ref_str,
),
)
figures.append(
Figure(
prov=[
Prov(
bbox=target_bbox,
page=element.page_no + 1,
span=[0, 0],
)
],
obj_type=layout_label_to_ds_type.get(element.label),
payload={
"children": TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
)
}, # hack to channel child clusters through GLM
)
)
elif isinstance(element, ContainerElement):
main_text.append(
BaseText(
text="",
payload={
"children": TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
)
}, # hack to channel child clusters through GLM
obj_type=layout_label_to_ds_type.get(element.label),
name=element.label,
prov=[
Prov(
bbox=target_bbox,
page=element.page_no + 1,
span=[0, 0],
)
],
)
)
page_dimensions = [
PageDimensions(page=p.page_no + 1, height=p.size.height, width=p.size.width)
for p in conv_res.pages
if p.size is not None
]
ds_doc: DsDocument = DsDocument(
name=title,
description=desc,
file_info=file_info,
main_text=main_text,
tables=tables,
figures=figures,
page_dimensions=page_dimensions,
)
return ds_doc
def __call__(self, conv_res: ConversionResult) -> DoclingDocument:
with TimeRecorder(conv_res, "glm", scope=ProfilingScope.DOCUMENT):
ds_doc = self._to_legacy_document(conv_res)
ds_doc_dict = ds_doc.model_dump(by_alias=True, exclude_none=True)
glm_doc = self.model.apply_on_doc(ds_doc_dict)
docling_doc: DoclingDocument = to_docling_document(glm_doc) # Experimental
# DEBUG code:
def draw_clusters_and_cells(ds_document, page_no, show: bool = False):
clusters_to_draw = []
image = copy.deepcopy(conv_res.pages[page_no].image)
for ix, elem in enumerate(ds_document.main_text):
if isinstance(elem, BaseText):
prov = elem.prov[0] # type: ignore
elif isinstance(elem, Ref):
_, arr, index = elem.ref.split("/")
index = int(index) # type: ignore
if arr == "tables":
prov = ds_document.tables[index].prov[0]
elif arr == "figures":
prov = ds_document.pictures[index].prov[0]
else:
prov = None
if prov and prov.page == page_no:
clusters_to_draw.append(
Cluster(
id=ix,
label=elem.name,
bbox=BoundingBox.from_tuple(
coord=prov.bbox, # type: ignore
origin=CoordOrigin.BOTTOMLEFT,
).to_top_left_origin(conv_res.pages[page_no].size.height),
)
)
draw = ImageDraw.Draw(image)
for c in clusters_to_draw:
x0, y0, x1, y1 = c.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
draw.text((x0 + 2, y0 + 2), f"{c.id}:{c.label}", fill=(255, 0, 0, 255))
cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
)
for tc in c.cells: # [:1]:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
if show:
image.show()
else:
out_path: Path = (
Path(settings.debug.debug_output_path)
/ f"debug_{conv_res.input.file.stem}"
)
out_path.mkdir(parents=True, exist_ok=True)
out_file = out_path / f"doc_page_{page_no:05}.png"
image.save(str(out_file), format="png")
# for item in ds_doc.page_dimensions:
# page_no = item.page
# draw_clusters_and_cells(ds_doc, page_no)
return docling_doc


@ -1,34 +1,46 @@
import logging
import warnings
from typing import Iterable
import zipfile
from pathlib import Path
from typing import Iterable, List, Optional, Type
import numpy
import torch
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
from docling.utils.utils import download_url_with_progress
_log = logging.getLogger(__name__)
class EasyOcrModel(BaseOcrModel):
_model_repo_folder = "EasyOcr"
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: EasyOcrOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(enabled=enabled, options=options)
super().__init__(
enabled=enabled,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: EasyOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
@ -62,15 +74,55 @@ class EasyOcrModel(BaseOcrModel):
)
use_gpu = self.options.use_gpu
download_enabled = self.options.download_enabled
model_storage_directory = self.options.model_storage_directory
if artifacts_path is not None and model_storage_directory is None:
download_enabled = False
model_storage_directory = str(artifacts_path / self._model_repo_folder)
self.reader = easyocr.Reader(
lang_list=self.options.lang,
gpu=use_gpu,
model_storage_directory=self.options.model_storage_directory,
model_storage_directory=model_storage_directory,
recog_network=self.options.recog_network,
download_enabled=self.options.download_enabled,
download_enabled=download_enabled,
verbose=False,
)
@staticmethod
def download_models(
detection_models: List[str] = ["craft"],
recognition_models: List[str] = ["english_g2", "latin_g2"],
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
# Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
from easyocr.config import detection_models as det_models_dict
from easyocr.config import recognition_models as rec_models_dict
if local_dir is None:
local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
local_dir.mkdir(parents=True, exist_ok=True)
# Collect models to download
download_list = []
for model_name in detection_models:
if model_name in det_models_dict:
download_list.append(det_models_dict[model_name])
for model_name in recognition_models:
if model_name in rec_models_dict["gen2"]:
download_list.append(rec_models_dict["gen2"][model_name])
# Download models
for model_details in download_list:
buf = download_url_with_progress(model_details["url"], progress=progress)
with zipfile.ZipFile(buf, "r") as zip_ref:
zip_ref.extractall(local_dir)
return local_dir
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@ -103,18 +155,22 @@ class EasyOcrModel(BaseOcrModel):
del im
cells = [
OcrCell(
id=ix,
TextCell(
index=ix,
text=line[1],
orig=line[1],
from_ocr=True,
confidence=line[2],
bbox=BoundingBox.from_tuple(
coord=(
(line[0][0][0] / self.scale) + ocr_rect.l,
(line[0][0][1] / self.scale) + ocr_rect.t,
(line[0][2][0] / self.scale) + ocr_rect.l,
(line[0][2][1] / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
rect=BoundingRectangle.from_bounding_box(
BoundingBox.from_tuple(
coord=(
(line[0][0][0] / self.scale) + ocr_rect.l,
(line[0][0][1] / self.scale) + ocr_rect.t,
(line[0][2][0] / self.scale) + ocr_rect.l,
(line[0][2][1] / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
)
),
)
for ix, line in enumerate(result)
@ -130,3 +186,7 @@ class EasyOcrModel(BaseOcrModel):
self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects)
yield page
@classmethod
def get_options_type(cls) -> Type[OcrOptions]:
return EasyOcrOptions
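The cell construction in this diff maps a quad detected on an upscaled crop back to page coordinates by dividing by the scale and adding the crop offset. A small worked sketch of that arithmetic follows, assuming corner 0 of the quad is the top-left point and corner 2 the bottom-right, as the indexing in the diff suggests. `quad_to_page_bbox` is a hypothetical helper, not a Docling function.

```python
from typing import Sequence, Tuple

Quad = Sequence[Sequence[float]]  # four (x, y) corner points


def quad_to_page_bbox(
    quad: Quad, scale: float, rect_left: float, rect_top: float
) -> Tuple[float, float, float, float]:
    """Map a quad detected on a scaled crop back to page coordinates (l, t, r, b)."""
    # Corner 0 is taken as top-left and corner 2 as bottom-right.
    left = quad[0][0] / scale + rect_left
    top = quad[0][1] / scale + rect_top
    right = quad[2][0] / scale + rect_left
    bottom = quad[2][1] / scale + rect_top
    return (left, top, right, bottom)


# A detection on a crop rendered at 3x, where the crop starts at (100, 50) on the page.
quad = [[30, 60], [330, 60], [330, 120], [30, 120]]
bbox = quad_to_page_bbox(quad, scale=3, rect_left=100.0, rect_top=50.0)
print(bbox)  # (110.0, 70.0, 210.0, 90.0)
```

Using only two opposite corners assumes the detected quad is axis-aligned; a rotated quad would need the min and max over all four corners.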


@ -0,0 +1,27 @@
import logging
from functools import lru_cache
from docling.models.factories.ocr_factory import OcrFactory
from docling.models.factories.picture_description_factory import (
PictureDescriptionFactory,
)
logger = logging.getLogger(__name__)
@lru_cache()
def get_ocr_factory(allow_external_plugins: bool = False) -> OcrFactory:
factory = OcrFactory()
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
logger.info("Registered ocr engines: %r", factory.registered_kind)
return factory
@lru_cache()
def get_picture_description_factory(
allow_external_plugins: bool = False,
) -> PictureDescriptionFactory:
factory = PictureDescriptionFactory()
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
logger.info("Registered picture descriptions: %r", factory.registered_kind)
return factory
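The `lru_cache` decorators above mean each factory is built and populated once per distinct argument value. The sketch below shows that caching behavior with a hypothetical `DemoFactory` standing in for `OcrFactory`; it does not load real plugins.

```python
from functools import lru_cache


class DemoFactory:
    """Hypothetical stand-in for OcrFactory; counts how often plugins are loaded."""

    load_calls = 0

    def load_from_plugins(self, allow_external_plugins: bool = False) -> None:
        DemoFactory.load_calls += 1
        self.allow_external_plugins = allow_external_plugins


@lru_cache()
def get_demo_factory(allow_external_plugins: bool = False) -> DemoFactory:
    factory = DemoFactory()
    factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
    return factory


first = get_demo_factory(allow_external_plugins=False)
second = get_demo_factory(allow_external_plugins=False)
external = get_demo_factory(allow_external_plugins=True)
```

Here `first` and `second` are the same object, while `external` is a separate instance. Note that `lru_cache` keys on the call pattern, so calling with no arguments and calling with the default passed explicitly may create separate cache entries.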


@ -0,0 +1,122 @@
import enum
import logging
from abc import ABCMeta
from typing import Generic, Optional, Type, TypeVar
from pluggy import PluginManager
from pydantic import BaseModel
from docling.datamodel.pipeline_options import BaseOptions
from docling.models.base_model import BaseModelWithOptions
A = TypeVar("A", bound=BaseModelWithOptions)
logger = logging.getLogger(__name__)
class FactoryMeta(BaseModel):
kind: str
plugin_name: str
module: str
class BaseFactory(Generic[A], metaclass=ABCMeta):
default_plugin_name = "docling"
def __init__(self, plugin_attr_name: str, plugin_name=default_plugin_name):
self.plugin_name = plugin_name
self.plugin_attr_name = plugin_attr_name
self._classes: dict[Type[BaseOptions], Type[A]] = {}
self._meta: dict[Type[BaseOptions], FactoryMeta] = {}
@property
def registered_kind(self) -> list[str]:
return list(opt.kind for opt in self._classes.keys())
def get_enum(self) -> Type[enum.Enum]:
return enum.Enum(
self.plugin_attr_name + "_enum",
names={kind: kind for kind in self.registered_kind},
type=str,
module=__name__,
)
@property
def classes(self):
return self._classes
@property
def registered_meta(self):
return self._meta
def create_instance(self, options: BaseOptions, **kwargs) -> A:
try:
_cls = self._classes[type(options)]
return _cls(options=options, **kwargs)
except KeyError:
raise RuntimeError(self._err_msg_on_class_not_found(options.kind))
def create_options(self, kind: str, *args, **kwargs) -> BaseOptions:
for opt_cls, _ in self._classes.items():
if opt_cls.kind == kind:
return opt_cls(*args, **kwargs)
raise RuntimeError(self._err_msg_on_class_not_found(kind))
def _err_msg_on_class_not_found(self, kind: str):
msg = []
for opt, cls in self._classes.items():
msg.append(f"\t{opt.kind!r} => {cls!r}")
msg_str = "\n".join(msg)
return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str):
opt_type = cls.get_options_type()
if opt_type in self._classes:
raise ValueError(
f"{opt_type.kind!r} already registered to class {self._classes[opt_type]!r}"
)
self._classes[opt_type] = cls
self._meta[opt_type] = FactoryMeta(
kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name
)
def load_from_plugins(
self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False
):
plugin_name = plugin_name or self.plugin_name
plugin_manager = PluginManager(plugin_name)
plugin_manager.load_setuptools_entrypoints(plugin_name)
for plugin_name, plugin_module in plugin_manager.list_name_plugin():
plugin_module_name = str(plugin_module.__name__) # type: ignore
if not allow_external_plugins and not plugin_module_name.startswith(
"docling."
):
logger.warning(
"The plugin %s will not be loaded because Docling is being executed with allow_external_plugins=False.",
plugin_name,
)
continue
attr = getattr(plugin_module, self.plugin_attr_name, None)
if callable(attr):
logger.info("Loading plugin %r", plugin_name)
config = attr()
self.process_plugin(config, plugin_name, plugin_module_name)
def process_plugin(self, config, plugin_name: str, plugin_module_name: str):
for item in config[self.plugin_attr_name]:
try:
self.register(item, plugin_name, plugin_module_name)
except ValueError:
logger.warning("%r already registered", item)
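The factory above keeps a registry keyed by options type and resolves either an options class from a `kind` string or a model class from an options instance. A stripped-down sketch of that lookup follows, without pluggy or pydantic. `Registry`, `FastOcrOptions`, `FastOcrModel`, and the `"fast_ocr"` kind are all hypothetical names for illustration.

```python
from typing import ClassVar, Dict, Type


class BaseOptions:
    kind: ClassVar[str]


class FastOcrOptions(BaseOptions):
    # Hypothetical options type; "fast_ocr" is an illustrative kind.
    kind: ClassVar[str] = "fast_ocr"


class FastOcrModel:
    def __init__(self, options: BaseOptions, **kwargs):
        self.options = options
        self.kwargs = kwargs

    @classmethod
    def get_options_type(cls) -> Type[BaseOptions]:
        return FastOcrOptions


class Registry:
    def __init__(self) -> None:
        self._classes: Dict[Type[BaseOptions], type] = {}

    def register(self, cls: type) -> None:
        # The options type advertised by the model class is the registry key.
        opt_type = cls.get_options_type()
        if opt_type in self._classes:
            raise ValueError(f"{opt_type.kind!r} already registered")
        self._classes[opt_type] = cls

    def create_options(self, kind: str, **kwargs) -> BaseOptions:
        for opt_type in self._classes:
            if opt_type.kind == kind:
                return opt_type(**kwargs)
        raise RuntimeError(f"No class found with the name {kind!r}")

    def create_instance(self, options: BaseOptions, **kwargs):
        try:
            model_cls = self._classes[type(options)]
        except KeyError:
            raise RuntimeError(f"No class found with the name {options.kind!r}")
        return model_cls(options=options, **kwargs)


registry = Registry()
registry.register(FastOcrModel)
options = registry.create_options("fast_ocr")
model = registry.create_instance(options, enabled=True)
```

One deliberate difference from the diff: the sketch wraps only the dictionary lookup in `try`, so a `KeyError` raised inside a model constructor is not reported as a missing registration.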


@ -0,0 +1,11 @@
import logging
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.factories.base_factory import BaseFactory
logger = logging.getLogger(__name__)
class OcrFactory(BaseFactory[BaseOcrModel]):
def __init__(self, *args, **kwargs):
super().__init__("ocr_engines", *args, **kwargs)


@ -0,0 +1,11 @@
import logging
from docling.models.factories.base_factory import BaseFactory
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
logger = logging.getLogger(__name__)
class PictureDescriptionFactory(BaseFactory[PictureDescriptionBaseModel]):
def __init__(self, *args, **kwargs):
super().__init__("picture_description", *args, **kwargs)


@ -0,0 +1,137 @@
import logging
import time
from pathlib import Path
from typing import Iterable, List, Optional
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
HuggingFaceVlmOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceMlxModel(BasePageModel):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: HuggingFaceVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
try:
from mlx_vlm import generate, load # type: ignore
from mlx_vlm.prompt_utils import apply_chat_template # type: ignore
from mlx_vlm.utils import load_config, stream_generate # type: ignore
except ImportError:
raise ImportError(
"mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
)
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
self.apply_chat_template = apply_chat_template
self.stream_generate = stream_generate
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
## Load the model
self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path)
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
prompt = self.apply_chat_template(
self.processor, self.config, self.param_question, num_images=1
)
start_time = time.time()
# Call model to generate:
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=4096,
verbose=False,
):
output += token.text
if "</doctag>" in token.text:
break
generation_time = time.time() - start_time
page_tags = output
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page
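The generation loop above accumulates streamed text and stops once a chunk contains the closing `</doctag>` marker. Below is a minimal sketch of that early-stop pattern, using a fake stream instead of `mlx_vlm`. `Token`, `fake_stream`, and `collect_until` are hypothetical names, and the sample chunks are invented.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator


@dataclass
class Token:
    # Hypothetical stand-in for the chunks yielded by a streaming generator.
    text: str


def fake_stream(chunks: Iterable[str]) -> Iterator[Token]:
    for chunk in chunks:
        yield Token(text=chunk)


def collect_until(stream: Iterable[Token], stop: str = "</doctag>") -> str:
    """Concatenate streamed text, stopping after the chunk that contains `stop`."""
    output = ""
    for token in stream:
        output += token.text
        # The check looks at the current chunk only, as in the loop above.
        if stop in token.text:
            break
    return output


chunks = ["<doctag>", "<text>Hello</text>", "</doctag>", "trailing noise"]
page_tags = collect_until(fake_stream(chunks))
print(page_tags)  # <doctag><text>Hello</text></doctag>
```

Because the marker is searched for in each chunk rather than in the accumulated output, a marker split across two chunks would not trigger the stop; generation would then continue until the token limit.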


@ -0,0 +1,180 @@
import logging
import time
from pathlib import Path
from typing import Iterable, List, Optional
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
HuggingFaceVlmOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceVlmModel(BasePageModel):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: HuggingFaceVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
import torch
from transformers import ( # type: ignore
AutoModelForVision2Seq,
AutoProcessor,
BitsAndBytesConfig,
)
device = decide_device(accelerator_options.device)
self.device = device
_log.debug("Available device for HuggingFace VLM: {}".format(device))
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
# PARAMETERS:
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt # "Perform Layout Analysis."
self.param_quantization_config = BitsAndBytesConfig(
load_in_8bit=vlm_options.load_in_8bit, # True,
llm_int8_threshold=vlm_options.llm_int8_threshold, # 6.0
)
self.param_quantized = vlm_options.quantized # False
self.processor = AutoProcessor.from_pretrained(artifacts_path)
if not self.param_quantized:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
torch_dtype=torch.bfloat16,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
) # .to(self.device)
else:
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=device,
torch_dtype="auto",
quantization_config=self.param_quantization_config,
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
) # .to(self.device)
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
# revision="v0.0.1",
)
return Path(download_path)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
hi_res_image = page.get_image(scale=2.0) # 144dpi
# hi_res_image = page.get_image(scale=1.0) # 72dpi
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.param_question},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
**inputs, max_new_tokens=4096, use_cache=True
)
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
)[0]
num_tokens = len(generated_ids[0])
page_tags = generated_texts
# inference_time = time.time() - start_time
# tokens_per_second = num_tokens / generation_time
# print("")
# print(f"Page Inference Time: {inference_time:.2f} seconds")
# print(f"Total tokens on page: {num_tokens:.2f}")
# print(f"Tokens/sec: {tokens_per_second:.2f}")
# print("")
page.predictions.vlm_response = VlmPrediction(text=page_tags)
yield page
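Both VLM model classes above resolve the model directory the same way: download when no path is given, otherwise prefer a per-repo subfolder named after the repo id with `/` replaced by `--`, and fall back to the path as given. The following sketch isolates that resolution logic. `resolve_artifacts_path` and `fake_download` are hypothetical helpers, and the repo ids are made up.

```python
import tempfile
from pathlib import Path
from typing import Callable, Optional


def fake_download(repo_id: str) -> Path:
    # Stand-in for a real download call; returns a pretend cache location.
    return Path("cache") / repo_id.replace("/", "--")


def resolve_artifacts_path(
    repo_id: str,
    artifacts_path: Optional[Path],
    download: Callable[[str], Path] = fake_download,
) -> Path:
    """Pick the model directory: download, per-repo subfolder, or the path as given."""
    repo_cache_folder = repo_id.replace("/", "--")
    if artifacts_path is None:
        return download(repo_id)
    if (artifacts_path / repo_cache_folder).exists():
        return artifacts_path / repo_cache_folder
    return artifacts_path


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # "example-org/example-vlm" is a made-up repo id used only for illustration.
    (root / "example-org--example-vlm").mkdir()
    nested = resolve_artifacts_path("example-org/example-vlm", root)
    direct = resolve_artifacts_path("other-org/other-vlm", root)
    nested_name = nested.name
    direct_is_root = direct == root

downloaded = resolve_artifacts_path("example-org/example-vlm", None)
```

The silent fallback to the given path means a mistyped repo folder is not detected at this stage; any error would surface later when the model files are loaded.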


@ -1,33 +1,29 @@
import copy
import logging
import random
import time
import warnings
from pathlib import Path
from typing import Iterable, List
from typing import Iterable, Optional, Union
from docling_core.types.doc import CoordOrigin, DocItemLabel
from docling_core.types.doc import DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import Image, ImageDraw, ImageFont
from PIL import Image
from docling.datamodel.base_models import (
BoundingBox,
Cell,
Cluster,
LayoutPrediction,
Page,
)
from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.pipeline_options import AcceleratorOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder
from docling.utils.visualization import draw_clusters
_log = logging.getLogger(__name__)
class LayoutModel(BasePageModel):
_model_repo_folder = "ds4sd--docling-models"
_model_path = "model_artifacts/layout"
TEXT_ELEM_LABELS = [
DocItemLabel.TEXT,
@ -40,7 +36,7 @@ class LayoutModel(BasePageModel):
DocItemLabel.PAGE_FOOTER,
DocItemLabel.CODE,
DocItemLabel.LIST_ITEM,
# "Formula",
DocItemLabel.FORMULA,
]
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
@ -49,15 +45,56 @@ class LayoutModel(BasePageModel):
FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
def __init__(
self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions
):
device = decide_device(accelerator_options.device)
if artifacts_path is None:
artifacts_path = self.download_models() / self._model_path
else:
# will become the default in the future
if (artifacts_path / self._model_repo_folder).exists():
artifacts_path = (
artifacts_path / self._model_repo_folder / self._model_path
)
elif (artifacts_path / self._model_path).exists():
warnings.warn(
"The usage of artifacts_path containing directly "
f"{self._model_path} is deprecated. Please point "
"the artifacts_path to the parent containing "
f"the {self._model_repo_folder} folder.",
DeprecationWarning,
stacklevel=3,
)
artifacts_path = artifacts_path / self._model_path
self.layout_predictor = LayoutPredictor(
artifact_path=str(artifacts_path),
device=device,
num_threads=accelerator_options.num_threads,
)
@staticmethod
def download_models(
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.1.0",
)
return Path(download_path)
def draw_clusters_and_cells_side_by_side(
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
):
@ -67,29 +104,9 @@ class LayoutModel(BasePageModel):
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
Includes label names and confidence scores for each cluster.
"""
label_to_color = {
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
DocItemLabel.CAPTION: (255, 204, 153), # Light Orange
DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple
DocItemLabel.FORMULA: (192, 192, 192), # Gray
DocItemLabel.TABLE: (255, 204, 204), # Light Pink
DocItemLabel.PICTURE: (255, 204, 164), # Light Beige
DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red
DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green
DocItemLabel.PAGE_FOOTER: (
204,
255,
204,
), # Light Green (same as Page-Header)
DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header)
DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue
DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray
DocItemLabel.CODE: (125, 125, 125), # Gray
DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193), # Pale Green
DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink
DocItemLabel.FORM: (200, 255, 255), # Light Cyan
DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange
}
scale_x = page.image.width / page.size.width
scale_y = page.image.height / page.size.height
# Filter clusters for left and right images
exclude_labels = {
DocItemLabel.FORM,
@ -102,65 +119,9 @@ class LayoutModel(BasePageModel):
left_image = copy.deepcopy(page.image)
right_image = copy.deepcopy(page.image)
# Function to draw clusters on an image
def draw_clusters(image, clusters):
draw = ImageDraw.Draw(image, "RGBA")
# Create a smaller font for the labels
try:
font = ImageFont.truetype("arial.ttf", 12)
except OSError:
# Fallback to default font if arial is not available
font = ImageFont.load_default()
for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children]
for c in all_clusters:
# Draw cells first (underneath)
cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
draw.rectangle(
[(cx0, cy0), (cx1, cy1)],
outline=None,
fill=cell_color,
)
# Draw cluster rectangle
x0, y0, x1, y1 = c.bbox.as_tuple()
cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
draw.rectangle(
[(x0, y0), (x1, y1)],
outline=cluster_outline_color,
fill=cluster_fill_color,
)
# Add label name and confidence
label_text = f"{c.label.name} ({c.confidence:.2f})"
# Create semi-transparent background for text
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
text_bg_padding = 2
draw.rectangle(
[
(
text_bbox[0] - text_bg_padding,
text_bbox[1] - text_bg_padding,
),
(
text_bbox[2] + text_bg_padding,
text_bbox[3] + text_bg_padding,
),
],
fill=(255, 255, 255, 180), # Semi-transparent white
)
# Draw text
draw.text(
(x0, y0),
label_text,
fill=(0, 0, 0, 255), # Solid black
font=font,
)
# Draw clusters on both images
draw_clusters(left_image, left_clusters)
draw_clusters(right_image, right_clusters)
draw_clusters(left_image, left_clusters, scale_x, scale_y)
draw_clusters(right_image, right_clusters, scale_x, scale_y)
# Combine the images side by side
combined_width = left_image.width * 2
combined_height = left_image.height
@ -189,10 +150,12 @@ class LayoutModel(BasePageModel):
else:
with TimeRecorder(conv_res, "layout"):
assert page.size is not None
page_image = page.get_image(scale=1.0)
assert page_image is not None
clusters = []
for ix, pred_item in enumerate(
self.layout_predictor.predict(page.get_image(scale=1.0))
self.layout_predictor.predict(page_image)
):
label = DocItemLabel(
pred_item["label"]
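
The artifacts path handling added to `LayoutModel.__init__` above can be read as a small decision table. The following is a minimal sketch of that branch order only, not the library's implementation: `resolve_artifacts_path` and its `download` callback are hypothetical names introduced for illustration, while the two folder constants are copied from the diff.

```python
import warnings
from pathlib import Path
from typing import Callable, Optional

MODEL_REPO_FOLDER = "ds4sd--docling-models"  # value taken from the diff
MODEL_PATH = "model_artifacts/layout"  # value taken from the diff


def resolve_artifacts_path(
    artifacts_path: Optional[Path], download: Callable[[], Path]
) -> Path:
    """Hypothetical helper that mirrors the branch order shown in the diff."""
    if artifacts_path is None:
        # Nothing supplied: fetch the models and descend into the layout folder.
        return download() / MODEL_PATH
    if (artifacts_path / MODEL_REPO_FOLDER).exists():
        # Preferred layout: artifacts_path is the parent of the repo folder.
        return artifacts_path / MODEL_REPO_FOLDER / MODEL_PATH
    if (artifacts_path / MODEL_PATH).exists():
        # Legacy layout: still accepted, but flagged as deprecated.
        warnings.warn(
            f"artifacts_path containing {MODEL_PATH} directly is deprecated.",
            DeprecationWarning,
            stacklevel=2,
        )
        return artifacts_path / MODEL_PATH
    # Neither layout found: the path is passed through unchanged.
    return artifacts_path
```

Keeping the repo-folder check first means a directory that happens to contain both layouts resolves to the new one without a warning.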


@ -1,12 +1,19 @@
import logging
import sys
import tempfile
from typing import Iterable, Optional, Tuple
from pathlib import Path
from typing import Iterable, Optional, Tuple, Type
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.base_models import OcrCell, Page
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import OcrMacOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrMacOptions,
OcrOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.profiling import TimeRecorder
@ -15,18 +22,31 @@ _log = logging.getLogger(__name__)
class OcrMacModel(BaseOcrModel):
def __init__(self, enabled: bool, options: OcrMacOptions):
super().__init__(enabled=enabled, options=options)
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: OcrMacOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(
enabled=enabled,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: OcrMacOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
if self.enabled:
if "darwin" != sys.platform:
raise RuntimeError("OcrMac is only supported on Mac.")
install_errmsg = (
"ocrmac is not correctly installed. "
"Please install it via `pip install ocrmac` to use this OCR engine. "
"Alternatively, Docling has support for other OCR engines. See the documentation: "
"https://ds4sd.github.io/docling/installation/"
"https://docling-project.github.io/docling/installation/"
)
try:
from ocrmac import ocrmac
@ -94,13 +114,17 @@ class OcrMacModel(BaseOcrModel):
bottom = y2 / self.scale
cells.append(
OcrCell(
id=ix,
TextCell(
index=ix,
text=text,
orig=text,
from_ocr=True,
confidence=confidence,
bbox=BoundingBox.from_tuple(
coord=(left, top, right, bottom),
origin=CoordOrigin.TOPLEFT,
rect=BoundingRectangle.from_bounding_box(
BoundingBox.from_tuple(
coord=(left, top, right, bottom),
origin=CoordOrigin.TOPLEFT,
)
),
)
)
@ -116,3 +140,7 @@ class OcrMacModel(BaseOcrModel):
self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects)
yield page
@classmethod
def get_options_type(cls) -> Type[OcrOptions]:
return OcrMacOptions


@ -22,7 +22,7 @@ _log = logging.getLogger(__name__)
class PageAssembleOptions(BaseModel):
keep_images: bool = False
pass
class PageAssembleModel(BasePageModel):
@ -52,6 +52,14 @@ class PageAssembleModel(BasePageModel):
sanitized_text = "".join(lines)
# Text normalization
sanitized_text = sanitized_text.replace("⁄", "/")
sanitized_text = sanitized_text.replace("’", "'")
sanitized_text = sanitized_text.replace("‘", "'")
sanitized_text = sanitized_text.replace("“", '"')
sanitized_text = sanitized_text.replace("”", '"')
sanitized_text = sanitized_text.replace("•", "·")
return sanitized_text.strip() # Strip any leading or trailing whitespace
def __call__(
@ -135,31 +143,6 @@ class PageAssembleModel(BasePageModel):
)
elements.append(fig)
body.append(fig)
elif cluster.label == LayoutModel.FORMULA_LABEL:
equation = None
if page.predictions.equations_prediction:
equation = page.predictions.equations_prediction.equation_map.get(
cluster.id, None
)
if (
not equation
): # fallback: add empty formula, if it isn't present
text = self.sanitize_text(
[
cell.text.replace("\x02", "-").strip()
for cell in cluster.cells
if len(cell.text.strip()) > 0
]
)
equation = TextElement(
label=cluster.label,
id=cluster.id,
cluster=cluster,
page_no=page.page_no,
text=text,
)
elements.append(equation)
body.append(equation)
elif cluster.label in LayoutModel.CONTAINER_LABELS:
container_el = ContainerElement(
label=cluster.label,
@ -174,11 +157,4 @@ class PageAssembleModel(BasePageModel):
elements=elements, headers=headers, body=body
)
# Remove page images (can be disabled)
if not self.options.keep_images:
page._image_cache = {}
# Unload backend
page._backend.unload()
yield page
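
The text normalization step added to `sanitize_text` in this file maps typographic punctuation to plain equivalents, one `replace` call per character. As a hedged illustration of the same idea, the sketch below does it in a single pass with `str.translate`. The helper name `normalize_text` is hypothetical, and the exact set of source characters (fraction slash, curly quotes, bullet) is an assumption based on the replacement targets visible in the diff.

```python
# Assumed mapping from typographic characters to plain equivalents.
_NORMALIZE_TABLE = str.maketrans(
    {
        "\u2044": "/",  # fraction slash
        "\u2019": "'",  # right single quotation mark
        "\u2018": "'",  # left single quotation mark
        "\u201c": '"',  # left double quotation mark
        "\u201d": '"',  # right double quotation mark
        "\u2022": "\u00b7",  # bullet -> middle dot
    }
)


def normalize_text(text: str) -> str:
    """Hypothetical helper: normalize punctuation, then strip whitespace."""
    return text.translate(_NORMALIZE_TABLE).strip()
```

Writing the characters as escape sequences keeps the mapping readable even if the file is later viewed with the wrong encoding.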


@ -13,6 +13,7 @@ from docling.utils.profiling import TimeRecorder
class PagePreprocessingOptions(BaseModel):
images_scale: Optional[float]
create_parsed_page: bool
class PagePreprocessingModel(BasePageModel):
@ -55,11 +56,20 @@ class PagePreprocessingModel(BasePageModel):
page.cells = list(page._backend.get_text_cells())
if self.options.create_parsed_page:
page.parsed_page = page._backend.get_segmented_page()
# DEBUG code:
def draw_text_boxes(image, cells, show: bool = False):
draw = ImageDraw.Draw(image)
for c in cells:
x0, y0, x1, y1 = c.bbox.as_tuple()
x0, y0, x1, y1 = (
c.to_bounding_box().l,
c.to_bounding_box().t,
c.to_bounding_box().r,
c.to_bounding_box().b,
)
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
if show:
image.show()


@ -0,0 +1,125 @@
import base64
import io
import logging
from pathlib import Path
from typing import Iterable, List, Optional, Type, Union
import requests
from PIL import Image
from pydantic import BaseModel, ConfigDict
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionApiOptions,
PictureDescriptionBaseOptions,
)
from docling.exceptions import OperationNotAllowed
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
_log = logging.getLogger(__name__)
class ChatMessage(BaseModel):
role: str
content: str
class ResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str
class ResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ApiResponse(BaseModel):
model_config = ConfigDict(
protected_namespaces=(),
)
id: str
model: Optional[str] = None # returned by openai
choices: List[ResponseChoice]
created: int
usage: ResponseUsage
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
# elements_batch_size = 4
@classmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
return PictureDescriptionApiOptions
def __init__(
self,
enabled: bool,
enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionApiOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(
enabled=enabled,
enable_remote_services=enable_remote_services,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: PictureDescriptionApiOptions
if self.enabled:
if not enable_remote_services:
raise OperationNotAllowed(
"Connections to remote services are only allowed when set explicitly. "
"pipeline_options.enable_remote_services=True."
)
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
# Note: technically we could make a batch request here,
# but not all APIs will allow for it. For example, vllm won't allow more than 1.
for image in images:
img_io = io.BytesIO()
image.save(img_io, "PNG")
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": self.options.prompt,
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_base64}"
},
},
],
}
]
payload = {
"messages": messages,
**self.options.params,
}
r = requests.post(
str(self.options.url),
headers=self.options.headers,
json=payload,
timeout=self.options.timeout,
)
if not r.ok:
_log.error(f"Error calling the API. Response was {r.text}")
r.raise_for_status()
api_resp = ApiResponse.model_validate_json(r.text)
generated_text = api_resp.choices[0].message.content.strip()
yield generated_text
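
The request body built in `_annotate_images` above follows the chat-completions shape: one user message carrying the prompt as text and the image as a base64 data URL, with any extra parameters merged at the top level. The sketch below isolates that construction so it can be inspected without a network call. `build_chat_payload` is a hypothetical name, and it takes already-encoded PNG bytes rather than a PIL image to stay dependency-free.

```python
import base64
from typing import Any, Dict


def build_chat_payload(
    png_bytes: bytes, prompt: str, params: Dict[str, Any]
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the payload layout shown in the diff."""
    image_base64 = base64.b64encode(png_bytes).decode("utf-8")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                },
            ],
        }
    ]
    # Extra parameters (for example a model name) are merged next to "messages".
    return {"messages": messages, **params}
```

Because `params` is unpacked last, a `"messages"` key inside it would override the constructed list, which is worth keeping in mind when passing user-supplied parameters.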


@ -0,0 +1,80 @@
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Any, Iterable, List, Optional, Type, Union
from docling_core.types.doc import (
DoclingDocument,
NodeItem,
PictureClassificationClass,
PictureItem,
)
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from PIL import Image
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionBaseOptions,
)
from docling.models.base_model import (
BaseItemAndImageEnrichmentModel,
BaseModelWithOptions,
ItemAndImageEnrichmentElement,
)
class PictureDescriptionBaseModel(
BaseItemAndImageEnrichmentModel, BaseModelWithOptions
):
images_scale: float = 2.0
def __init__(
self,
*,
enabled: bool,
enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionBaseOptions,
accelerator_options: AcceleratorOptions,
):
self.enabled = enabled
self.options = options
self.provenance = "not-implemented"
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
return self.enabled and isinstance(element, PictureItem)
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
raise NotImplementedError
def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[ItemAndImageEnrichmentElement],
) -> Iterable[NodeItem]:
if not self.enabled:
for element in element_batch:
yield element.item
return
images: List[Image.Image] = []
elements: List[PictureItem] = []
for el in element_batch:
assert isinstance(el.item, PictureItem)
elements.append(el.item)
images.append(el.image)
outputs = self._annotate_images(images)
for item, output in zip(elements, outputs):
item.annotations.append(
PictureDescriptionData(text=output, provenance=self.provenance)
)
yield item
@classmethod
@abstractmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
pass


@ -0,0 +1,121 @@
from pathlib import Path
from typing import Iterable, Optional, Type, Union
from PIL import Image
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionBaseOptions,
PictureDescriptionVlmOptions,
)
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.utils.accelerator_utils import decide_device
class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
@classmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
return PictureDescriptionVlmOptions
def __init__(
self,
enabled: bool,
enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionVlmOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(
enabled=enabled,
enable_remote_services=enable_remote_services,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: PictureDescriptionVlmOptions
if self.enabled:
if artifacts_path is None:
artifacts_path = self.download_models(repo_id=self.options.repo_id)
else:
artifacts_path = Path(artifacts_path) / self.options.repo_cache_folder
self.device = decide_device(accelerator_options.device)
try:
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
except ImportError:
raise ImportError(
"transformers >=4.46 is not installed. Please install Docling with the required extras `pip install docling[vlm]`."
)
# Initialize processor and model
self.processor = AutoProcessor.from_pretrained(artifacts_path)
self.model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
torch_dtype=torch.bfloat16,
_attn_implementation=(
"flash_attention_2" if self.device.startswith("cuda") else "eager"
),
).to(self.device)
self.provenance = f"{self.options.repo_id}"
@staticmethod
def download_models(
repo_id: str,
local_dir: Optional[Path] = None,
force: bool = False,
progress: bool = False,
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id=repo_id,
force_download=force,
local_dir=local_dir,
)
return Path(download_path)
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
from transformers import GenerationConfig
# Create input messages
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": self.options.prompt},
],
},
]
# TODO: do batch generation
for image in images:
# Prepare inputs
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=True
)
inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
inputs = inputs.to(self.device)
# Generate outputs
generated_ids = self.model.generate(
**inputs,
generation_config=GenerationConfig(**self.options.generation_config),
)
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=True,
)
yield generated_texts[0].strip()


@ -0,0 +1,28 @@
from docling.models.easyocr_model import EasyOcrModel
from docling.models.ocr_mac_model import OcrMacModel
from docling.models.picture_description_api_model import PictureDescriptionApiModel
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel
def ocr_engines():
return {
"ocr_engines": [
EasyOcrModel,
OcrMacModel,
RapidOcrModel,
TesseractOcrModel,
TesseractOcrCliModel,
]
}
def picture_description():
return {
"picture_description": [
PictureDescriptionVlmModel,
PictureDescriptionApiModel,
]
}
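
The two functions above each return a dict with one category key and a list of model classes, and every class in this diff exposes a `get_options_type()` classmethod. One way a consumer could use that pairing is to index the classes by their options type. This is only a sketch under that assumption; the diff does not show how the registry is actually consumed, and all class and function names below are hypothetical stand-ins.

```python
from typing import Dict, List


class OptionsA:  # hypothetical stand-in for an options class
    pass


class OptionsB:  # hypothetical stand-in for an options class
    pass


class EngineA:  # hypothetical stand-in for a model class
    @classmethod
    def get_options_type(cls) -> type:
        return OptionsA


class EngineB:  # hypothetical stand-in for a model class
    @classmethod
    def get_options_type(cls) -> type:
        return OptionsB


def example_engines() -> Dict[str, List[type]]:
    # Same return shape as the registry functions in the diff.
    return {"ocr_engines": [EngineA, EngineB]}


def index_by_options_type(
    registry: Dict[str, List[type]], key: str
) -> Dict[type, type]:
    """Map each options type to the model class that declares it."""
    return {cls.get_options_type(): cls for cls in registry[key]}
```

With such an index, selecting a model from a user-supplied options object reduces to a lookup on `type(options)`.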


@ -1,14 +1,17 @@
import logging
from typing import Iterable
from pathlib import Path
from typing import Iterable, Optional, Type
import numpy
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.base_models import OcrCell, Page
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
OcrOptions,
RapidOcrOptions,
)
from docling.datamodel.settings import settings
@ -23,10 +26,16 @@ class RapidOcrModel(BaseOcrModel):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: RapidOcrOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(enabled=enabled, options=options)
super().__init__(
enabled=enabled,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: RapidOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
@ -59,6 +68,7 @@ class RapidOcrModel(BaseOcrModel):
det_model_path=self.options.det_model_path,
cls_model_path=self.options.cls_model_path,
rec_model_path=self.options.rec_model_path,
rec_keys_path=self.options.rec_keys_path,
)
def __call__(
@ -99,18 +109,26 @@ class RapidOcrModel(BaseOcrModel):
if result is not None:
cells = [
OcrCell(
id=ix,
TextCell(
index=ix,
text=line[1],
orig=line[1],
confidence=line[2],
bbox=BoundingBox.from_tuple(
coord=(
(line[0][0][0] / self.scale) + ocr_rect.l,
(line[0][0][1] / self.scale) + ocr_rect.t,
(line[0][2][0] / self.scale) + ocr_rect.l,
(line[0][2][1] / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
from_ocr=True,
rect=BoundingRectangle.from_bounding_box(
BoundingBox.from_tuple(
coord=(
(line[0][0][0] / self.scale)
+ ocr_rect.l,
(line[0][0][1] / self.scale)
+ ocr_rect.t,
(line[0][2][0] / self.scale)
+ ocr_rect.l,
(line[0][2][1] / self.scale)
+ ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
)
),
)
for ix, line in enumerate(result)
@ -125,3 +143,7 @@ class RapidOcrModel(BaseOcrModel):
self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects)
yield page
@classmethod
def get_options_type(cls) -> Type[OcrOptions]:
return RapidOcrOptions
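
The coordinate arithmetic in the `TextCell` construction above converts a detection made on an upscaled crop back into page coordinates: divide by the render scale, then add the crop's top-left offset. The sketch below isolates that step. `crop_quad_to_page_box` is a hypothetical name, and it assumes the detector returns four corner points ordered so that index 0 is the top-left and index 2 the bottom-right, as the indexing in the diff suggests.

```python
from typing import Sequence, Tuple


def crop_quad_to_page_box(
    quad: Sequence[Sequence[float]],
    scale: float,
    rect_left: float,
    rect_top: float,
) -> Tuple[float, float, float, float]:
    """Hypothetical helper: map a crop-space quad to a page-space (l, t, r, b)."""
    (x0, y0), (x1, y1) = quad[0], quad[2]  # top-left and bottom-right corners
    return (
        x0 / scale + rect_left,
        y0 / scale + rect_top,
        x1 / scale + rect_left,
        y1 / scale + rect_top,
    )
```

For example, with `scale=3` a corner at `(30, 60)` in the crop lands at `(10, 20)` relative to the OCR rectangle before the offset is added.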


@ -0,0 +1,389 @@
import copy
import random
from pathlib import Path
from typing import Dict, List
from docling_core.types.doc import (
BoundingBox,
CoordOrigin,
DocItem,
DocItemLabel,
DoclingDocument,
DocumentOrigin,
GroupLabel,
NodeItem,
ProvenanceItem,
RefItem,
TableData,
)
from docling_core.types.doc.document import ContentLayer
from docling_core.types.legacy_doc.base import Ref
from docling_core.types.legacy_doc.document import BaseText
from docling_ibm_models.reading_order.reading_order_rb import (
PageElement as ReadingOrderPageElement,
)
from docling_ibm_models.reading_order.reading_order_rb import ReadingOrderPredictor
from PIL import ImageDraw
from pydantic import BaseModel, ConfigDict
from docling.datamodel.base_models import (
BasePageElement,
Cluster,
ContainerElement,
FigureElement,
Table,
TextElement,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.settings import settings
from docling.utils.profiling import ProfilingScope, TimeRecorder
class ReadingOrderOptions(BaseModel):
model_config = ConfigDict(protected_namespaces=())
model_names: str = "" # e.g. "language;term;reference"
class ReadingOrderModel:
def __init__(self, options: ReadingOrderOptions):
self.options = options
self.ro_model = ReadingOrderPredictor()
def _assembled_to_readingorder_elements(
self, conv_res: ConversionResult
) -> List[ReadingOrderPageElement]:
elements: List[ReadingOrderPageElement] = []
page_no_to_pages = {p.page_no: p for p in conv_res.pages}
for element in conv_res.assembled.elements:
page_height = page_no_to_pages[element.page_no].size.height # type: ignore
bbox = element.cluster.bbox.to_bottom_left_origin(page_height)
text = element.text or ""
elements.append(
ReadingOrderPageElement(
cid=len(elements),
ref=RefItem(cref=f"#/{element.page_no}/{element.cluster.id}"),
text=text,
page_no=element.page_no,
page_size=page_no_to_pages[element.page_no].size,
label=element.label,
l=bbox.l,
r=bbox.r,
b=bbox.b,
t=bbox.t,
coord_origin=bbox.coord_origin,
)
)
return elements
def _add_child_elements(
self, element: BasePageElement, doc_item: NodeItem, doc: DoclingDocument
):
child: Cluster
for child in element.cluster.children:
c_label = child.label
c_bbox = child.bbox.to_bottom_left_origin(
doc.pages[element.page_no + 1].size.height
)
c_text = " ".join(
[
cell.text.replace("\x02", "-").strip()
for cell in child.cells
if len(cell.text.strip()) > 0
]
)
c_prov = ProvenanceItem(
page_no=element.page_no + 1, charspan=(0, len(c_text)), bbox=c_bbox
)
if c_label == DocItemLabel.LIST_ITEM:
# TODO: Infer if this is a numbered or a bullet list item
doc.add_list_item(parent=doc_item, text=c_text, prov=c_prov)
elif c_label == DocItemLabel.SECTION_HEADER:
doc.add_heading(parent=doc_item, text=c_text, prov=c_prov)
else:
doc.add_text(parent=doc_item, label=c_label, text=c_text, prov=c_prov)
def _readingorder_elements_to_docling_doc(
self,
conv_res: ConversionResult,
ro_elements: List[ReadingOrderPageElement],
el_to_captions_mapping: Dict[int, List[int]],
el_to_footnotes_mapping: Dict[int, List[int]],
el_merges_mapping: Dict[int, List[int]],
) -> DoclingDocument:
id_to_elem = {
RefItem(cref=f"#/{elem.page_no}/{elem.cluster.id}").cref: elem
for elem in conv_res.assembled.elements
}
cid_to_rels = {rel.cid: rel for rel in ro_elements}
origin = DocumentOrigin(
mimetype="application/pdf",
filename=conv_res.input.file.name,
binary_hash=conv_res.input.document_hash,
)
doc_name = Path(origin.filename).stem
out_doc: DoclingDocument = DoclingDocument(name=doc_name, origin=origin)
for page in conv_res.pages:
page_no = page.page_no + 1
size = page.size
assert size is not None
out_doc.add_page(page_no=page_no, size=size)
current_list = None
skippable_cids = {
cid
for mapping in (
el_to_captions_mapping,
el_to_footnotes_mapping,
el_merges_mapping,
)
for lst in mapping.values()
for cid in lst
}
page_no_to_pages = {p.page_no: p for p in conv_res.pages}
for rel in ro_elements:
if rel.cid in skippable_cids:
continue
element = id_to_elem[rel.ref.cref]
page_height = page_no_to_pages[element.page_no].size.height # type: ignore
if isinstance(element, TextElement):
if element.label == DocItemLabel.CODE:
cap_text = element.text
prov = ProvenanceItem(
page_no=element.page_no + 1,
charspan=(0, len(cap_text)),
bbox=element.cluster.bbox.to_bottom_left_origin(page_height),
)
code_item = out_doc.add_code(text=cap_text, prov=prov)
if rel.cid in el_to_captions_mapping.keys():
for caption_cid in el_to_captions_mapping[rel.cid]:
caption_elem = id_to_elem[cid_to_rels[caption_cid].ref.cref]
new_cap_item = self._add_caption_or_footnote(
caption_elem, out_doc, code_item, page_height
)
code_item.captions.append(new_cap_item.get_ref())
if rel.cid in el_to_footnotes_mapping.keys():
for footnote_cid in el_to_footnotes_mapping[rel.cid]:
footnote_elem = id_to_elem[
cid_to_rels[footnote_cid].ref.cref
]
new_footnote_item = self._add_caption_or_footnote(
footnote_elem, out_doc, code_item, page_height
)
code_item.footnotes.append(new_footnote_item.get_ref())
else:
new_item, current_list = self._handle_text_element(
element, out_doc, current_list, page_height
)
if rel.cid in el_merges_mapping.keys():
for merged_cid in el_merges_mapping[rel.cid]:
merged_elem = id_to_elem[cid_to_rels[merged_cid].ref.cref]
self._merge_elements(
element, merged_elem, new_item, page_height
)
elif isinstance(element, Table):
tbl_data = TableData(
num_rows=element.num_rows,
num_cols=element.num_cols,
table_cells=element.table_cells,
)
prov = ProvenanceItem(
page_no=element.page_no + 1,
charspan=(0, 0),
bbox=element.cluster.bbox.to_bottom_left_origin(page_height),
)
tbl = out_doc.add_table(
data=tbl_data, prov=prov, label=element.cluster.label
)
if rel.cid in el_to_captions_mapping.keys():
for caption_cid in el_to_captions_mapping[rel.cid]:
caption_elem = id_to_elem[cid_to_rels[caption_cid].ref.cref]
new_cap_item = self._add_caption_or_footnote(
caption_elem, out_doc, tbl, page_height
)
tbl.captions.append(new_cap_item.get_ref())
if rel.cid in el_to_footnotes_mapping.keys():
for footnote_cid in el_to_footnotes_mapping[rel.cid]:
footnote_elem = id_to_elem[cid_to_rels[footnote_cid].ref.cref]
new_footnote_item = self._add_caption_or_footnote(
footnote_elem, out_doc, tbl, page_height
)
tbl.footnotes.append(new_footnote_item.get_ref())
# TODO: Consider adding children of Table.
elif isinstance(element, FigureElement):
cap_text = ""
prov = ProvenanceItem(
page_no=element.page_no + 1,
charspan=(0, len(cap_text)),
bbox=element.cluster.bbox.to_bottom_left_origin(page_height),
)
pic = out_doc.add_picture(prov=prov)
if rel.cid in el_to_captions_mapping.keys():
for caption_cid in el_to_captions_mapping[rel.cid]:
caption_elem = id_to_elem[cid_to_rels[caption_cid].ref.cref]
new_cap_item = self._add_caption_or_footnote(
caption_elem, out_doc, pic, page_height
)
pic.captions.append(new_cap_item.get_ref())
if rel.cid in el_to_footnotes_mapping.keys():
for footnote_cid in el_to_footnotes_mapping[rel.cid]:
footnote_elem = id_to_elem[cid_to_rels[footnote_cid].ref.cref]
new_footnote_item = self._add_caption_or_footnote(
footnote_elem, out_doc, pic, page_height
)
pic.footnotes.append(new_footnote_item.get_ref())
self._add_child_elements(element, pic, out_doc)
elif isinstance(element, ContainerElement): # Form, KV region
label = element.label
group_label = GroupLabel.UNSPECIFIED
if label == DocItemLabel.FORM:
group_label = GroupLabel.FORM_AREA
elif label == DocItemLabel.KEY_VALUE_REGION:
group_label = GroupLabel.KEY_VALUE_AREA
container_el = out_doc.add_group(label=group_label)
self._add_child_elements(element, container_el, out_doc)
return out_doc
def _add_caption_or_footnote(self, elem, out_doc, parent, page_height):
assert isinstance(elem, TextElement)
text = elem.text
prov = ProvenanceItem(
page_no=elem.page_no + 1,
charspan=(0, len(text)),
bbox=elem.cluster.bbox.to_bottom_left_origin(page_height),
)
new_item = out_doc.add_text(
label=elem.label, text=text, prov=prov, parent=parent
)
return new_item
def _handle_text_element(self, element, out_doc, current_list, page_height):
cap_text = element.text
prov = ProvenanceItem(
page_no=element.page_no + 1,
charspan=(0, len(cap_text)),
bbox=element.cluster.bbox.to_bottom_left_origin(page_height),
)
label = element.label
if label == DocItemLabel.LIST_ITEM:
if current_list is None:
current_list = out_doc.add_group(label=GroupLabel.LIST, name="list")
# TODO: Infer if this is a numbered or a bullet list item
new_item = out_doc.add_list_item(
text=cap_text, enumerated=False, prov=prov, parent=current_list
)
elif label == DocItemLabel.SECTION_HEADER:
current_list = None
new_item = out_doc.add_heading(text=cap_text, prov=prov)
elif label == DocItemLabel.FORMULA:
current_list = None
new_item = out_doc.add_text(
label=DocItemLabel.FORMULA, text="", orig=cap_text, prov=prov
)
else:
current_list = None
content_layer = ContentLayer.BODY
if element.label in [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]:
content_layer = ContentLayer.FURNITURE
new_item = out_doc.add_text(
label=element.label,
text=cap_text,
prov=prov,
content_layer=content_layer,
)
return new_item, current_list
def _merge_elements(self, element, merged_elem, new_item, page_height):
assert isinstance(
merged_elem, type(element)
), "Merged element must be of same type as element."
assert (
merged_elem.label == new_item.label
), "Labels of merged elements must match."
prov = ProvenanceItem(
page_no=element.page_no + 1,
charspan=(
len(new_item.text) + 1,
len(new_item.text) + 1 + len(merged_elem.text),
),
bbox=element.cluster.bbox.to_bottom_left_origin(page_height),
)
new_item.text += f" {merged_elem.text}"
new_item.orig += f" {merged_elem.text}" # TODO: This is incomplete, we don't have the `orig` field of the merged element.
new_item.prov.append(prov)
def __call__(self, conv_res: ConversionResult) -> DoclingDocument:
with TimeRecorder(conv_res, "glm", scope=ProfilingScope.DOCUMENT):
page_elements = self._assembled_to_readingorder_elements(conv_res)
# Apply reading order
sorted_elements = self.ro_model.predict_reading_order(
page_elements=page_elements
)
el_to_captions_mapping = self.ro_model.predict_to_captions(
sorted_elements=sorted_elements
)
el_to_footnotes_mapping = self.ro_model.predict_to_footnotes(
sorted_elements=sorted_elements
)
el_merges_mapping = self.ro_model.predict_merges(
sorted_elements=sorted_elements
)
docling_doc: DoclingDocument = self._readingorder_elements_to_docling_doc(
conv_res,
sorted_elements,
el_to_captions_mapping,
el_to_footnotes_mapping,
el_merges_mapping,
)
return docling_doc
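
The `charspan` computed in `_merge_elements` above accounts for the single space that joins the two texts: the merged fragment starts one character past the end of the existing text. The sketch below reproduces just that arithmetic so the offsets can be checked by slicing. `merged_charspan` is a hypothetical helper, not part of the module.

```python
from typing import Tuple


def merged_charspan(
    current_text: str, merged_text: str
) -> Tuple[Tuple[int, int], str]:
    """Hypothetical helper: span of merged_text inside the joined string."""
    start = len(current_text) + 1  # +1 for the joining space
    end = start + len(merged_text)
    return (start, end), f"{current_text} {merged_text}"
```

Slicing the joined string with the returned span yields the merged fragment again, which is a quick way to validate provenance offsets.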


@ -1,9 +1,15 @@
import copy
import warnings
from pathlib import Path
from typing import Iterable
from typing import Iterable, Optional, Union
import numpy
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
from docling_core.types.doc.page import (
BoundingRectangle,
SegmentedPdfPage,
TextCellUnit,
)
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw
@ -22,10 +28,13 @@ from docling.utils.profiling import TimeRecorder
class TableStructureModel(BasePageModel):
_model_repo_folder = "ds4sd--docling-models"
_model_path = "model_artifacts/tableformer"
def __init__(
self,
enabled: bool,
artifacts_path: Path,
artifacts_path: Optional[Path],
options: TableStructureOptions,
accelerator_options: AcceleratorOptions,
):
@ -35,6 +44,26 @@ class TableStructureModel(BasePageModel):
self.enabled = enabled
if self.enabled:
if artifacts_path is None:
artifacts_path = self.download_models() / self._model_path
else:
# will become the default in the future
if (artifacts_path / self._model_repo_folder).exists():
artifacts_path = (
artifacts_path / self._model_repo_folder / self._model_path
)
elif (artifacts_path / self._model_path).exists():
warnings.warn(
"The usage of artifacts_path containing directly "
f"{self._model_path} is deprecated. Please point "
"the artifacts_path to the parent containing "
f"the {self._model_repo_folder} folder.",
DeprecationWarning,
stacklevel=3,
)
artifacts_path = artifacts_path / self._model_path
if self.mode == TableFormerMode.ACCURATE:
artifacts_path = artifacts_path / "accurate"
else:
@ -58,6 +87,24 @@ class TableStructureModel(BasePageModel):
)
self.scale = 2.0 # Scale up table input images to 144 dpi
@staticmethod
def download_models(
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
if not progress:
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.2.0",
)
return Path(download_path)
def draw_table_and_cells(
self,
conv_res: ConversionResult,
@ -66,23 +113,43 @@ class TableStructureModel(BasePageModel):
show: bool = False,
):
assert page._backend is not None
assert page.size is not None
image = (
page._backend.get_page_image()
) # make new image to avoid drawing on the saved ones
scale_x = image.width / page.size.width
scale_y = image.height / page.size.height
draw = ImageDraw.Draw(image)
for table_element in tbl_list:
x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
y0 *= scale_y
y1 *= scale_y
x0 *= scale_x
x1 *= scale_x
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
for cell in table_element.cluster.cells:
x0, y0, x1, y1 = cell.bbox.as_tuple()
x0, y0, x1, y1 = cell.rect.to_bounding_box().as_tuple()
x0 *= scale_x
x1 *= scale_x
y0 *= scale_y
y1 *= scale_y
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
for tc in table_element.table_cells:
if tc.bbox is not None:
x0, y0, x1, y1 = tc.bbox.as_tuple()
x0 *= scale_x
x1 *= scale_x
y0 *= scale_y
y1 *= scale_y
if tc.column_header:
width = 3
else:
@ -155,17 +222,36 @@ class TableStructureModel(BasePageModel):
if len(table_bboxes):
for table_cluster, tbl_box in in_tables:
# Check if word-level cells are available from backend:
sp = page._backend.get_segmented_page()
if sp is not None:
tcells = sp.get_cells_in_bbox(
cell_unit=TextCellUnit.WORD,
bbox=table_cluster.bbox,
)
if len(tcells) == 0:
# In case word-level cells yield empty
tcells = table_cluster.cells
else:
# Otherwise - we use normal (line/phrase) cells
tcells = table_cluster.cells
tokens = []
for c in table_cluster.cells:
for c in tcells:
# Only allow non-empty strings (not just whitespace) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(
scale=self.scale
new_cell.rect = BoundingRectangle.from_bounding_box(
new_cell.rect.to_bounding_box().scaled(
scale=self.scale
)
)
tokens.append(
{
"id": new_cell.index,
"text": new_cell.text,
"bbox": new_cell.rect.to_bounding_box().model_dump(),
}
)
tokens.append(new_cell.model_dump())
page_input["tokens"] = tokens
tf_output = self.tf_predictor.multi_table_predict(
@ -189,12 +275,16 @@ class TableStructureModel(BasePageModel):
tc.bbox = tc.bbox.scaled(1 / self.scale)
table_cells.append(tc)
assert "predict_details" in table_out
# Retrieving cols/rows, after post processing:
num_rows = table_out["predict_details"]["num_rows"]
num_cols = table_out["predict_details"]["num_cols"]
otsl_seq = table_out["predict_details"]["prediction"][
"rs_seq"
]
num_rows = table_out["predict_details"].get("num_rows", 0)
num_cols = table_out["predict_details"].get("num_cols", 0)
otsl_seq = (
table_out["predict_details"]
.get("prediction", {})
.get("rs_seq", [])
)
tbl = Table(
otsl_seq=otsl_seq,
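
The folder lookup added to `TableStructureModel.__init__` above can be sketched in isolation. This is a minimal sketch under stated assumptions: the two constants are copied from the diff, while `resolve_model_dir` is a hypothetical helper name used only for illustration and is not part of docling.

```python
import tempfile
import warnings
from pathlib import Path

# Constants copied from the diff above.
MODEL_REPO_FOLDER = "ds4sd--docling-models"
MODEL_PATH = "model_artifacts/tableformer"


def resolve_model_dir(artifacts_path: Path) -> Path:
    """Hypothetical helper mirroring the lookup order shown in the diff."""
    # Preferred layout: <artifacts_path>/<repo folder>/<model path>
    if (artifacts_path / MODEL_REPO_FOLDER).exists():
        return artifacts_path / MODEL_REPO_FOLDER / MODEL_PATH
    # Legacy layout: <artifacts_path>/<model path>, accepted with a warning
    if (artifacts_path / MODEL_PATH).exists():
        warnings.warn(
            f"Pointing artifacts_path directly at {MODEL_PATH} is deprecated.",
            DeprecationWarning,
            stacklevel=2,
        )
        return artifacts_path / MODEL_PATH
    # Neither layout found: the path is left unchanged, as in the diff
    return artifacts_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / MODEL_REPO_FOLDER / MODEL_PATH).mkdir(parents=True)
        print(resolve_model_dir(root))
```

Keeping the preferred layout first means a folder that contains both layouts resolves to the new one without emitting the deprecation warning.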


@ -3,36 +3,56 @@ import io
import logging
import os
import tempfile
from pathlib import Path
from subprocess import DEVNULL, PIPE, Popen
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Type
import pandas as pd
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import TesseractCliOcrOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrOptions,
TesseractCliOcrOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.ocr_utils import map_tesseract_script
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class TesseractOcrCliModel(BaseOcrModel):
def __init__(self, enabled: bool, options: TesseractCliOcrOptions):
super().__init__(enabled=enabled, options=options)
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: TesseractCliOcrOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(
enabled=enabled,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: TesseractCliOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
self._name: Optional[str] = None
self._version: Optional[str] = None
self._tesseract_languages: Optional[List[str]] = None
self._script_prefix: Optional[str] = None
if self.enabled:
try:
self._get_name_and_version()
self._set_languages_and_prefix()
except Exception as exc:
raise RuntimeError(
@ -74,12 +94,20 @@ class TesseractOcrCliModel(BaseOcrModel):
return name, version
def _run_tesseract(self, ifilename: str):
r"""
Run tesseract CLI
"""
cmd = [self.options.tesseract_cmd]
if self.options.lang is not None and len(self.options.lang) > 0:
if "auto" in self.options.lang:
lang = self._detect_language(ifilename)
if lang is not None:
cmd.append("-l")
cmd.append(lang)
elif self.options.lang is not None and len(self.options.lang) > 0:
cmd.append("-l")
cmd.append("+".join(self.options.lang))
if self.options.path is not None:
cmd.append("--tessdata-dir")
cmd.append(self.options.path)
@ -103,10 +131,69 @@ class TesseractOcrCliModel(BaseOcrModel):
# _log.info("df: ", df.head())
# Filter rows that contain actual text (ignore header or empty rows)
df_filtered = df[df["text"].notnull() & (df["text"].str.strip() != "")]
df_filtered = df[
df["text"].notnull() & (df["text"].apply(str).str.strip() != "")
]
return df_filtered
def _detect_language(self, ifilename: str):
r"""
Run tesseract in PSM 0 mode to detect the language
"""
assert self._tesseract_languages is not None
cmd = [self.options.tesseract_cmd]
cmd.extend(["--psm", "0", "-l", "osd", ifilename, "stdout"])
_log.info("command: {}".format(" ".join(cmd)))
proc = Popen(cmd, stdout=PIPE, stderr=DEVNULL)
output, _ = proc.communicate()
decoded_data = output.decode("utf-8")
df = pd.read_csv(
io.StringIO(decoded_data), sep=":", header=None, names=["key", "value"]
)
scripts = df.loc[df["key"] == "Script"].value.tolist()
if len(scripts) == 0:
_log.warning("Tesseract cannot detect the script of the page")
return None
script = map_tesseract_script(scripts[0].strip())
lang = f"{self._script_prefix}{script}"
# Check if the detected language has been installed
if lang not in self._tesseract_languages:
msg = f"Tesseract detected the script '{script}' and language '{lang}'."
msg += " However, this language is not installed on your system and will be ignored."
_log.warning(msg)
return None
_log.debug(
f"Using tesseract model for the detected script '{script}' and language '{lang}'"
)
return lang
def _set_languages_and_prefix(self):
r"""
Read and set the languages installed in tesseract and decide the script prefix
"""
# Get all languages
cmd = [self.options.tesseract_cmd]
cmd.append("--list-langs")
_log.info("command: {}".format(" ".join(cmd)))
proc = Popen(cmd, stdout=PIPE, stderr=DEVNULL)
output, _ = proc.communicate()
decoded_data = output.decode("utf-8")
df = pd.read_csv(io.StringIO(decoded_data), header=None)
self._tesseract_languages = df[0].tolist()[1:]
# Decide the script prefix
if any([l.startswith("script/") for l in self._tesseract_languages]):
script_prefix = "script/"
else:
script_prefix = ""
self._script_prefix = script_prefix
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@ -121,7 +208,6 @@ class TesseractOcrCliModel(BaseOcrModel):
yield page
else:
with TimeRecorder(conv_res, "ocr"):
ocr_rects = self.get_ocr_rects(page)
all_ocr_cells = []
@ -159,18 +245,22 @@ class TesseractOcrCliModel(BaseOcrModel):
t = b + h
r = l + w
cell = OcrCell(
id=ix,
text=text,
cell = TextCell(
index=ix,
text=str(text),
orig=text,
from_ocr=True,
confidence=conf / 100.0,
bbox=BoundingBox.from_tuple(
coord=(
(l / self.scale) + ocr_rect.l,
(b / self.scale) + ocr_rect.t,
(r / self.scale) + ocr_rect.l,
(t / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
rect=BoundingRectangle.from_bounding_box(
BoundingBox.from_tuple(
coord=(
(l / self.scale) + ocr_rect.l,
(b / self.scale) + ocr_rect.t,
(r / self.scale) + ocr_rect.l,
(t / self.scale) + ocr_rect.t,
),
origin=CoordOrigin.TOPLEFT,
)
),
)
all_ocr_cells.append(cell)
@ -183,3 +273,7 @@ class TesseractOcrCliModel(BaseOcrModel):
self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects)
yield page
@classmethod
def get_options_type(cls) -> Type[OcrOptions]:
return TesseractCliOcrOptions
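
The language auto-detection added to the CLI model above combines two steps: reading the installed languages and mapping the detected script to a language name. The following is a minimal stdlib-only sketch of that logic. The sample outputs are illustrative assumptions (the exact text printed by a given Tesseract build may differ), the function names are hypothetical, and the `map_tesseract_script` normalisation used in the diff is deliberately left out.

```python
from typing import List, Optional

# Illustrative samples only; real Tesseract output may differ in detail.
SAMPLE_LIST_LANGS = """List of available languages (3):
eng
osd
script/Latin
"""

SAMPLE_OSD = """Page number: 0
Orientation in degrees: 0
Rotate: 0
Orientation confidence: 15.33
Script: Latin
Script confidence: 2.00
"""


def parse_languages(output: str) -> List[str]:
    # Skip the header line, as the diff does with `[1:]`.
    return [line.strip() for line in output.splitlines()[1:] if line.strip()]


def script_prefix(languages: List[str]) -> str:
    # Some installations expose script models as "script/<Name>".
    return "script/" if any(l.startswith("script/") for l in languages) else ""


def detect_language(osd_output: str, languages: List[str]) -> Optional[str]:
    script = None
    for line in osd_output.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "Script":
            script = value.strip()
            break
    if script is None:
        return None  # no script reported, e.g. a page without text
    lang = f"{script_prefix(languages)}{script}"
    # Only return languages that are actually installed.
    return lang if lang in languages else None


if __name__ == "__main__":
    langs = parse_languages(SAMPLE_LIST_LANGS)
    print(detect_language(SAMPLE_OSD, langs))
```

Returning `None` in both failure cases matches the diff, where the caller then runs Tesseract without an explicit `-l` argument.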


@ -1,25 +1,45 @@
import logging
from typing import Iterable
from pathlib import Path
from typing import Iterable, Optional, Type
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import BoundingRectangle, TextCell
from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.base_models import Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import TesseractOcrOptions
from docling.datamodel.pipeline_options import (
AcceleratorOptions,
OcrOptions,
TesseractOcrOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.ocr_utils import map_tesseract_script
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class TesseractOcrModel(BaseOcrModel):
def __init__(self, enabled: bool, options: TesseractOcrOptions):
super().__init__(enabled=enabled, options=options)
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: TesseractOcrOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(
enabled=enabled,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: TesseractOcrOptions
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
self.reader = None
self.osd_reader = None
self.script_readers: dict[str, tesserocr.PyTessBaseAPI] = {}
if self.enabled:
install_errmsg = (
@ -28,14 +48,14 @@ class TesseractOcrModel(BaseOcrModel):
"Note that tesserocr might have to be manually compiled for working with "
"your Tesseract installation. The Docling documentation provides examples for it. "
"Alternatively, Docling has support for other OCR engines. See the documentation: "
"https://ds4sd.github.io/docling/installation/"
"https://docling-project.github.io/docling/installation/"
)
missing_langs_errmsg = (
"tesserocr is not correctly configured. No language models have been detected. "
"Please ensure that the TESSDATA_PREFIX envvar points to tesseract languages dir. "
"You can find more information on how to set up other OCR engines in the Docling "
"documentation: "
"https://ds4sd.github.io/docling/installation/"
"https://docling-project.github.io/docling/installation/"
)
try:
@ -47,27 +67,36 @@ class TesseractOcrModel(BaseOcrModel):
except:
raise ImportError(install_errmsg)
_, tesserocr_languages = tesserocr.get_languages()
if not tesserocr_languages:
_, self._tesserocr_languages = tesserocr.get_languages()
if not self._tesserocr_languages:
raise ImportError(missing_langs_errmsg)
# Initialize the tesseractAPI
_log.debug("Initializing TesserOCR: %s", tesseract_version)
lang = "+".join(self.options.lang)
if any([l.startswith("script/") for l in self._tesserocr_languages]):
self.script_prefix = "script/"
else:
self.script_prefix = ""
tesserocr_kwargs = {
"psm": tesserocr.PSM.AUTO,
"init": True,
"oem": tesserocr.OEM.DEFAULT,
}
if self.options.path is not None:
self.reader = tesserocr.PyTessBaseAPI(
path=self.options.path,
lang=lang,
psm=tesserocr.PSM.AUTO,
init=True,
oem=tesserocr.OEM.DEFAULT,
tesserocr_kwargs["path"] = self.options.path
if lang == "auto":
self.reader = tesserocr.PyTessBaseAPI(**tesserocr_kwargs)
self.osd_reader = tesserocr.PyTessBaseAPI(
**{"lang": "osd", "psm": tesserocr.PSM.OSD_ONLY} | tesserocr_kwargs
)
else:
self.reader = tesserocr.PyTessBaseAPI(
lang=lang,
psm=tesserocr.PSM.AUTO,
init=True,
oem=tesserocr.OEM.DEFAULT,
**{"lang": lang} | tesserocr_kwargs,
)
self.reader_RIL = tesserocr.RIL
@ -75,11 +104,12 @@ class TesseractOcrModel(BaseOcrModel):
if self.reader is not None:
# Finalize the tesseractAPI
self.reader.End()
for script in self.script_readers:
self.script_readers[script].End()
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
if not self.enabled:
yield from page_batch
return
@ -90,8 +120,8 @@ class TesseractOcrModel(BaseOcrModel):
yield page
else:
with TimeRecorder(conv_res, "ocr"):
assert self.reader is not None
assert self._tesserocr_languages is not None
ocr_rects = self.get_ocr_rects(page)
@ -104,35 +134,73 @@ class TesseractOcrModel(BaseOcrModel):
scale=self.scale, cropbox=ocr_rect
)
# Retrieve text snippets with their bounding boxes
self.reader.SetImage(high_res_image)
boxes = self.reader.GetComponentImages(
local_reader = self.reader
if "auto" in self.options.lang:
assert self.osd_reader is not None
self.osd_reader.SetImage(high_res_image)
osd = self.osd_reader.DetectOrientationScript()
# No text, probably
if osd is None:
continue
script = osd["script_name"]
script = map_tesseract_script(script)
lang = f"{self.script_prefix}{script}"
# Check if the detected language is present in the system
if lang not in self._tesserocr_languages:
msg = f"Tesseract detected the script '{script}' and language '{lang}'."
msg += " However, this language is not installed on your system and will be ignored."
_log.warning(msg)
else:
if script not in self.script_readers:
import tesserocr
self.script_readers[script] = (
tesserocr.PyTessBaseAPI(
path=self.reader.GetDatapath(),
lang=lang,
psm=tesserocr.PSM.AUTO,
init=True,
oem=tesserocr.OEM.DEFAULT,
)
)
local_reader = self.script_readers[script]
local_reader.SetImage(high_res_image)
boxes = local_reader.GetComponentImages(
self.reader_RIL.TEXTLINE, True
)
cells = []
for ix, (im, box, _, _) in enumerate(boxes):
# Set the area of interest. Tesseract uses Bottom-Left for the origin
self.reader.SetRectangle(
local_reader.SetRectangle(
box["x"], box["y"], box["w"], box["h"]
)
# Extract text within the bounding box
text = self.reader.GetUTF8Text().strip()
confidence = self.reader.MeanTextConf()
text = local_reader.GetUTF8Text().strip()
confidence = local_reader.MeanTextConf()
left = box["x"] / self.scale
bottom = box["y"] / self.scale
right = (box["x"] + box["w"]) / self.scale
top = (box["y"] + box["h"]) / self.scale
cells.append(
OcrCell(
id=ix,
TextCell(
index=ix,
text=text,
orig=text,
from_ocr=True,
confidence=confidence,
bbox=BoundingBox.from_tuple(
coord=(left, top, right, bottom),
origin=CoordOrigin.TOPLEFT,
rect=BoundingRectangle.from_bounding_box(
BoundingBox.from_tuple(
coord=(left, top, right, bottom),
origin=CoordOrigin.TOPLEFT,
),
),
)
)
@ -148,3 +216,7 @@ class TesseractOcrModel(BaseOcrModel):
self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects)
yield page
@classmethod
def get_options_type(cls) -> Type[OcrOptions]:
return TesseractOcrOptions
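
The reader construction above merges keyword dictionaries with the `|` operator (Python 3.9+). With that operator, the right-hand operand wins on duplicate keys, so the operand order decides which `psm` value reaches `PyTessBaseAPI`. The sketch below uses plain strings as stand-ins for the `tesserocr` enum values, which is an assumption made only to keep it runnable without the library.

```python
# Strings stand in for tesserocr.PSM / tesserocr.OEM values (assumption).
base_kwargs = {"psm": "AUTO", "init": True, "oem": "DEFAULT"}
osd_kwargs = {"lang": "osd", "psm": "OSD_ONLY"}

# Same operand order as in the diff: base_kwargs is on the right,
# so its "psm" value replaces the one from osd_kwargs.
as_in_diff = osd_kwargs | base_kwargs

# Swapped order: the OSD-specific "psm" value is kept.
swapped = base_kwargs | osd_kwargs

if __name__ == "__main__":
    print(as_in_diff["psm"])
    print(swapped["psm"])
```

This is an observation about dictionary semantics rather than a confirmed defect. If `PSM.OSD_ONLY` is meant to take effect for the OSD reader, the operand order may be worth a second look.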


@ -3,7 +3,7 @@ import logging
import time
import traceback
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from typing import Any, Callable, Iterable, List
from docling_core.types.doc import DoclingDocument, NodeItem
@ -18,7 +18,7 @@ from docling.datamodel.base_models import (
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import PipelineOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BaseEnrichmentModel
from docling.models.base_model import GenericEnrichmentModel
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify
@ -28,8 +28,9 @@ _log = logging.getLogger(__name__)
class BasePipeline(ABC):
def __init__(self, pipeline_options: PipelineOptions):
self.pipeline_options = pipeline_options
self.keep_images = False
self.build_pipe: List[Callable] = []
self.enrichment_pipe: List[BaseEnrichmentModel] = []
self.enrichment_pipe: List[GenericEnrichmentModel[Any]] = []
def execute(self, in_doc: InputDocument, raises_on_error: bool) -> ConversionResult:
conv_res = ConversionResult(input=in_doc)
@ -40,7 +41,7 @@ class BasePipeline(ABC):
conv_res, "pipeline_total", scope=ProfilingScope.DOCUMENT
):
# These steps are building and assembling the structure of the
# output DoclingDocument
# output DoclingDocument.
conv_res = self._build_document(conv_res)
conv_res = self._assemble_document(conv_res)
# From this stage, all operations should rely only on conv_res.output
@ -50,6 +51,8 @@ class BasePipeline(ABC):
conv_res.status = ConversionStatus.FAILURE
if raises_on_error:
raise e
finally:
self._unload(conv_res)
return conv_res
@ -62,21 +65,22 @@ class BasePipeline(ABC):
def _enrich_document(self, conv_res: ConversionResult) -> ConversionResult:
def _filter_elements(
doc: DoclingDocument, model: BaseEnrichmentModel
def _prepare_elements(
conv_res: ConversionResult, model: GenericEnrichmentModel[Any]
) -> Iterable[NodeItem]:
for element, _level in doc.iterate_items():
if model.is_processable(doc=doc, element=element):
yield element
for doc_element, _level in conv_res.document.iterate_items():
prepared_element = model.prepare_element(
conv_res=conv_res, element=doc_element
)
if prepared_element is not None:
yield prepared_element
with TimeRecorder(conv_res, "doc_enrich", scope=ProfilingScope.DOCUMENT):
for model in self.enrichment_pipe:
for element_batch in chunkify(
_filter_elements(conv_res.document, model),
settings.perf.elements_batch_size,
_prepare_elements(conv_res, model),
model.elements_batch_size,
):
# TODO: currently we assume the element itself is modified, because
# we don't have an interface to save the element back to the document
for element in model(
doc=conv_res.document, element_batch=element_batch
): # Must exhaust!
@ -88,6 +92,9 @@ class BasePipeline(ABC):
def _determine_status(self, conv_res: ConversionResult) -> ConversionStatus:
pass
def _unload(self, conv_res: ConversionResult):
pass
@classmethod
@abstractmethod
def get_default_options(cls) -> PipelineOptions:
@ -107,6 +114,10 @@ class BasePipeline(ABC):
class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
def __init__(self, pipeline_options: PipelineOptions):
super().__init__(pipeline_options)
self.keep_backend = False
def _apply_on_pages(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@ -130,7 +141,9 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT):
for i in range(0, conv_res.input.page_count):
conv_res.pages.append(Page(page_no=i))
start_page, end_page = conv_res.input.limits.page_range
if (start_page - 1) <= i <= (end_page - 1):
conv_res.pages.append(Page(page_no=i))
try:
# Iterate batches of pages (page_batch_size) in the doc
@ -148,7 +161,14 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
pipeline_pages = self._apply_on_pages(conv_res, init_pages)
for p in pipeline_pages: # Must exhaust!
pass
# Cleanup cached images
if not self.keep_images:
p._image_cache = {}
# Cleanup page backends
if not self.keep_backend and p._backend is not None:
p._backend.unload()
end_batch_time = time.monotonic()
total_elapsed_time += end_batch_time - start_batch_time
@ -177,10 +197,15 @@ class PaginatedPipeline(BasePipeline): # TODO this is a bad name.
)
raise e
finally:
# Always unload the PDF backend, even in case of failure
if conv_res.input._backend:
conv_res.input._backend.unload()
return conv_res
def _unload(self, conv_res: ConversionResult) -> ConversionResult:
for page in conv_res.pages:
if page._backend is not None:
page._backend.unload()
if conv_res.input._backend:
conv_res.input._backend.unload()
return conv_res
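
The page filter added to `PaginatedPipeline._build_document` above converts a 1-based, inclusive page range into 0-based page indices. A minimal sketch of that comparison follows; `selected_page_indices` is a hypothetical helper name, and only the inequality itself is taken from the diff.

```python
from typing import List, Tuple


def selected_page_indices(page_count: int, page_range: Tuple[int, int]) -> List[int]:
    """Return the 0-based indices that fall inside a 1-based inclusive range."""
    start_page, end_page = page_range
    return [
        i
        for i in range(page_count)
        # Same comparison as in the diff: shift both bounds down by one.
        if (start_page - 1) <= i <= (end_page - 1)
    ]


if __name__ == "__main__":
    print(selected_page_indices(10, (2, 4)))
```

An upper bound beyond the document length simply selects every remaining page, and a range that starts past the last page selects nothing.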


@ -1,5 +1,6 @@
import logging
import sys
import warnings
from pathlib import Path
from typing import Optional
@ -9,141 +10,166 @@ from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import AssembledUnit, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
OcrMacOptions,
PdfPipelineOptions,
RapidOcrOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.ds_glm_model import GlmModel, GlmOptions
from docling.models.easyocr_model import EasyOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.document_picture_classifier import (
DocumentPictureClassifier,
DocumentPictureClassifierOptions,
)
from docling.models.factories import get_ocr_factory, get_picture_description_factory
from docling.models.layout_model import LayoutModel
from docling.models.ocr_mac_model import OcrMacModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.model_downloader import download_models
from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__)
class StandardPdfPipeline(PaginatedPipeline):
_layout_model_path = "model_artifacts/layout"
_table_model_path = "model_artifacts/tableformer"
_layout_model_path = LayoutModel._model_path
_table_model_path = TableStructureModel._model_path
def __init__(self, pipeline_options: PdfPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: PdfPipelineOptions
if pipeline_options.artifacts_path is None:
self.artifacts_path = self.download_models_hf()
else:
self.artifacts_path = Path(pipeline_options.artifacts_path)
artifacts_path: Optional[Path] = None
if pipeline_options.artifacts_path is not None:
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
elif settings.artifacts_path is not None:
artifacts_path = Path(settings.artifacts_path).expanduser()
keep_images = (
if artifacts_path is not None and not artifacts_path.is_dir():
raise RuntimeError(
f"The value of {artifacts_path=} is not valid. "
"When defined, it must point to a folder containing all models required by the pipeline."
)
self.keep_images = (
self.pipeline_options.generate_page_images
or self.pipeline_options.generate_picture_images
or self.pipeline_options.generate_table_images
)
self.glm_model = GlmModel(options=GlmOptions())
self.glm_model = ReadingOrderModel(options=ReadingOrderOptions())
if (ocr_model := self.get_ocr_model()) is None:
raise RuntimeError(
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
)
ocr_model = self.get_ocr_model(artifacts_path=artifacts_path)
self.build_pipe = [
# Pre-processing
PagePreprocessingModel(
options=PagePreprocessingOptions(
images_scale=pipeline_options.images_scale
images_scale=pipeline_options.images_scale,
create_parsed_page=pipeline_options.generate_parsed_pages,
)
),
# OCR
ocr_model,
# Layout model
LayoutModel(
artifacts_path=self.artifacts_path
/ StandardPdfPipeline._layout_model_path,
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
),
# Table structure model
TableStructureModel(
enabled=pipeline_options.do_table_structure,
artifacts_path=self.artifacts_path
/ StandardPdfPipeline._table_model_path,
artifacts_path=artifacts_path,
options=pipeline_options.table_structure_options,
accelerator_options=pipeline_options.accelerator_options,
),
# Page assemble
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
PageAssembleModel(options=PageAssembleOptions()),
]
# Picture description model
if (
picture_description_model := self.get_picture_description_model(
artifacts_path=artifacts_path
)
) is None:
raise RuntimeError(
f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
)
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
# Code Formula Enrichment Model
CodeFormulaModel(
enabled=pipeline_options.do_code_enrichment
or pipeline_options.do_formula_enrichment,
artifacts_path=artifacts_path,
options=CodeFormulaModelOptions(
do_code_enrichment=pipeline_options.do_code_enrichment,
do_formula_enrichment=pipeline_options.do_formula_enrichment,
),
accelerator_options=pipeline_options.accelerator_options,
),
# Document Picture Classifier
DocumentPictureClassifier(
enabled=pipeline_options.do_picture_classification,
artifacts_path=artifacts_path,
options=DocumentPictureClassifierOptions(),
accelerator_options=pipeline_options.accelerator_options,
),
# Document Picture description
picture_description_model,
]
if (
self.pipeline_options.do_formula_enrichment
or self.pipeline_options.do_code_enrichment
or self.pipeline_options.do_picture_description
):
self.keep_backend = True
@staticmethod
def download_models_hf(
local_dir: Optional[Path] = None, force: bool = False
) -> Path:
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
disable_progress_bars()
download_path = snapshot_download(
repo_id="ds4sd/docling-models",
force_download=force,
local_dir=local_dir,
revision="v2.1.0",
warnings.warn(
"The usage of StandardPdfPipeline.download_models_hf() is deprecated. "
"Use the utility `docling-tools models download` instead, or "
"the upstream method docling.utils.model_downloader.download_models()",
DeprecationWarning,
stacklevel=3,
)
return Path(download_path)
output_dir = download_models(output_dir=local_dir, force=force, progress=False)
return output_dir
def get_ocr_model(self) -> Optional[BaseOcrModel]:
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
return EasyOcrModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions):
return TesseractOcrCliModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
elif isinstance(self.pipeline_options.ocr_options, TesseractOcrOptions):
return TesseractOcrModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
elif isinstance(self.pipeline_options.ocr_options, RapidOcrOptions):
return RapidOcrModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
elif isinstance(self.pipeline_options.ocr_options, OcrMacOptions):
if "darwin" != sys.platform:
raise RuntimeError(
f"The specified OCR type is only supported on Mac: {self.pipeline_options.ocr_options.kind}."
)
return OcrMacModel(
enabled=self.pipeline_options.do_ocr,
options=self.pipeline_options.ocr_options,
)
return None
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance(
options=self.pipeline_options.ocr_options,
enabled=self.pipeline_options.do_ocr,
artifacts_path=artifacts_path,
accelerator_options=self.pipeline_options.accelerator_options,
)
def get_picture_description_model(
self, artifacts_path: Optional[Path] = None
) -> Optional[PictureDescriptionBaseModel]:
factory = get_picture_description_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance(
options=self.pipeline_options.picture_description_options,
enabled=self.pipeline_options.do_picture_description,
enable_remote_services=self.pipeline_options.enable_remote_services,
artifacts_path=artifacts_path,
accelerator_options=self.pipeline_options.accelerator_options,
)
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
with TimeRecorder(conv_res, "page_init"):
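
The new `StandardPdfPipeline.__init__` resolves the artifacts folder in a fixed order: the pipeline option first, then the global setting, otherwise `None` so that each model falls back to its own download. A minimal sketch of that precedence and the directory check follows; `resolve_artifacts_path` and its two parameters are hypothetical names standing in for `pipeline_options.artifacts_path` and `settings.artifacts_path`.

```python
import tempfile
from pathlib import Path
from typing import Optional, Union

PathLike = Union[str, Path]


def resolve_artifacts_path(
    option_value: Optional[PathLike],
    settings_value: Optional[PathLike],
) -> Optional[Path]:
    """Hypothetical helper mirroring the precedence shown in the diff."""
    artifacts_path: Optional[Path] = None
    if option_value is not None:
        # An explicit pipeline option takes priority
        artifacts_path = Path(option_value).expanduser()
    elif settings_value is not None:
        # Otherwise fall back to the global setting
        artifacts_path = Path(settings_value).expanduser()

    # A configured path must be an existing directory
    if artifacts_path is not None and not artifacts_path.is_dir():
        raise RuntimeError(
            f"The value of {artifacts_path=} is not valid. "
            "When defined, it must point to a folder containing all models."
        )
    return artifacts_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(resolve_artifacts_path(None, tmp))
```

Validating the path once in the pipeline means the individual models can assume that any non-`None` value they receive is an existing directory.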


@ -0,0 +1,214 @@
import logging
import warnings
from io import BytesIO
from pathlib import Path
from typing import List, Optional, Union, cast
# from docling_core.types import DoclingDocument
from docling_core.types.doc import BoundingBox, DocItem, ImageRef, PictureItem, TextItem
from docling_core.types.doc.document import DocTagsDocument
from PIL import Image as PILImage
from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
InferenceFramework,
ResponseFormat,
VlmPipelineOptions,
)
from docling.datamodel.settings import settings
from docling.models.hf_mlx_model import HuggingFaceMlxModel
from docling.models.hf_vlm_model import HuggingFaceVlmModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
_log = logging.getLogger(__name__)
class VlmPipeline(PaginatedPipeline):
def __init__(self, pipeline_options: VlmPipelineOptions):
super().__init__(pipeline_options)
self.keep_backend = True
self.pipeline_options: VlmPipelineOptions
artifacts_path: Optional[Path] = None
if pipeline_options.artifacts_path is not None:
artifacts_path = Path(pipeline_options.artifacts_path).expanduser()
elif settings.artifacts_path is not None:
artifacts_path = Path(settings.artifacts_path).expanduser()
if artifacts_path is not None and not artifacts_path.is_dir():
raise RuntimeError(
f"The value of {artifacts_path=} is not valid. "
"When defined, it must point to a folder containing all models required by the pipeline."
)
# force_backend_text = False - use text that is coming from VLM response
# force_backend_text = True - get text from backend using bounding boxes predicted by SmolDocling doctags
self.force_backend_text = (
pipeline_options.force_backend_text
and pipeline_options.vlm_options.response_format == ResponseFormat.DOCTAGS
)
self.keep_images = self.pipeline_options.generate_page_images
if (
self.pipeline_options.vlm_options.inference_framework
== InferenceFramework.MLX
):
self.build_pipe = [
HuggingFaceMlxModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
),
]
else:
self.build_pipe = [
HuggingFaceVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=self.pipeline_options.vlm_options,
),
]
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
]
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
with TimeRecorder(conv_res, "page_init"):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
return page
def extract_text_from_backend(
self, page: Page, bbox: Union[BoundingBox, None]
) -> str:
# Convert bounding box normalized to 0-100 into page coordinates for cropping
text = ""
if bbox:
if page.size:
if page._backend:
text = page._backend.get_text_in_rect(bbox)
return text
def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
with TimeRecorder(conv_res, "doc_assemble", scope=ProfilingScope.DOCUMENT):
if (
self.pipeline_options.vlm_options.response_format
== ResponseFormat.DOCTAGS
):
doctags_list = []
image_list = []
for page in conv_res.pages:
predicted_doctags = ""
img = PILImage.new("RGB", (1, 1), "rgb(255,255,255)")
if page.predictions.vlm_response:
predicted_doctags = page.predictions.vlm_response.text
if page.image:
img = page.image
image_list.append(img)
doctags_list.append(predicted_doctags)
doctags_list_c = cast(List[Union[Path, str]], doctags_list)
image_list_c = cast(List[Union[Path, PILImage.Image]], image_list)
doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
doctags_list_c, image_list_c
)
conv_res.document.load_from_doctags(doctags_doc)
# If forced backend text, replace model predicted text with backend one
if page.size:
if self.force_backend_text:
scale = self.pipeline_options.images_scale
for element, _level in conv_res.document.iterate_items():
if (
not isinstance(element, TextItem)
or len(element.prov) == 0
):
continue
crop_bbox = (
element.prov[0]
.bbox.scaled(scale=scale)
.to_top_left_origin(
page_height=page.size.height * scale
)
)
txt = self.extract_text_from_backend(page, crop_bbox)
element.text = txt
element.orig = txt
elif (
self.pipeline_options.vlm_options.response_format
== ResponseFormat.MARKDOWN
):
conv_res.document = self._turn_md_into_doc(conv_res)
else:
raise RuntimeError(
f"Unsupported VLM response format {self.pipeline_options.vlm_options.response_format}"
)
# Generate images of the requested element types
if self.pipeline_options.generate_picture_images:
scale = self.pipeline_options.images_scale
for element, _level in conv_res.document.iterate_items():
if not isinstance(element, DocItem) or len(element.prov) == 0:
continue
if (
isinstance(element, PictureItem)
and self.pipeline_options.generate_picture_images
):
page_ix = element.prov[0].page_no - 1
page = conv_res.pages[page_ix]
assert page.size is not None
assert page.image is not None
crop_bbox = (
element.prov[0]
.bbox.scaled(scale=scale)
.to_top_left_origin(page_height=page.size.height * scale)
)
cropped_im = page.image.crop(crop_bbox.as_tuple())
element.image = ImageRef.from_pil(
cropped_im, dpi=int(72 * scale)
)
return conv_res
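The crop-box computation in `_assemble_document` (scale the box, then flip it to a top-left origin) comes down to a few lines of arithmetic. The following is a minimal sketch and not Docling's `BoundingBox` API: `crop_box_for_image` is a hypothetical helper, and it assumes the input box is given in page units with a bottom-left origin.

```python
from typing import Tuple

# (left, top, right, bottom); assumption: bottom-left origin, page units
Box = Tuple[float, float, float, float]


def crop_box_for_image(box: Box, page_height: float, scale: float) -> Box:
    """Scale a box to image pixels and flip its y axis to a top-left origin."""
    left, top, right, bottom = box
    return (
        left * scale,
        (page_height - top) * scale,  # distance from the top edge of the page
        right * scale,
        (page_height - bottom) * scale,
    )
```

With a page height of 800 and a scale of 2, a box whose top edge sits at y=700 (bottom-left origin) ends up 200 pixels below the top of the rendered image.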
def _turn_md_into_doc(self, conv_res):
predicted_text = ""
for pg_idx, page in enumerate(conv_res.pages):
if page.predictions.vlm_response:
predicted_text += page.predictions.vlm_response.text + "\n\n"
response_bytes = BytesIO(predicted_text.encode("utf8"))
out_doc = InputDocument(
path_or_stream=response_bytes,
filename=conv_res.input.file.name,
format=InputFormat.MD,
backend=MarkdownDocumentBackend,
)
backend = MarkdownDocumentBackend(
in_doc=out_doc,
path_or_stream=response_bytes,
)
return backend.convert()
@classmethod
def get_default_options(cls) -> VlmPipelineOptions:
return VlmPipelineOptions()
@classmethod
def is_backend_supported(cls, backend: AbstractDocumentBackend):
return isinstance(backend, PdfDocumentBackend)

View File

@ -7,36 +7,62 @@ from docling.datamodel.pipeline_options import AcceleratorDevice
_log = logging.getLogger(__name__)
def decide_device(accelerator_device: AcceleratorDevice) -> str:
def decide_device(accelerator_device: str) -> str:
r"""
Resolve the device based on the acceleration options and the available devices in the system
Resolve the device based on the acceleration options and the available devices in the system.
Rules:
1. AUTO: Check for the best available device on the system.
2. User-defined: Check if the device actually exists, otherwise fall-back to CPU
"""
cuda_index = 0
device = "cpu"
has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
if accelerator_device == AcceleratorDevice.AUTO:
if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto'
if has_cuda:
device = f"cuda:{cuda_index}"
device = "cuda:0"
elif has_mps:
device = "mps"
elif accelerator_device.startswith("cuda"):
if has_cuda:
# if cuda device index specified extract device id
parts = accelerator_device.split(":")
if len(parts) == 2 and parts[1].isdigit():
# select cuda device's id
cuda_index = int(parts[1])
if cuda_index < torch.cuda.device_count():
device = f"cuda:{cuda_index}"
else:
_log.warning(
"CUDA device 'cuda:%d' is not available. Fall back to 'CPU'.",
cuda_index,
)
elif len(parts) == 1: # just "cuda"
device = "cuda:0"
else:
_log.warning(
"Invalid CUDA device format '%s'. Fall back to 'CPU'",
accelerator_device,
)
else:
_log.warning("CUDA is not available in the system. Fall back to 'CPU'")
elif accelerator_device == AcceleratorDevice.MPS.value:
if has_mps:
device = "mps"
else:
_log.warning("MPS is not available in the system. Fall back to 'CPU'")
elif accelerator_device == AcceleratorDevice.CPU.value:
device = "cpu"
else:
if accelerator_device == AcceleratorDevice.CUDA:
if has_cuda:
device = f"cuda:{cuda_index}"
else:
_log.warning("CUDA is not available in the system. Fall back to 'CPU'")
elif accelerator_device == AcceleratorDevice.MPS:
if has_mps:
device = "mps"
else:
_log.warning("MPS is not available in the system. Fall back to 'CPU'")
_log.warning(
"Unknown device option '%s'. Fall back to 'CPU'", accelerator_device
)
_log.info("Accelerator device: '%s'", device)
return device
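The parsing rules for `cuda` device strings can be exercised without a GPU. The sketch below mirrors the branch structure of the string-based `decide_device` but is not the function itself: `resolve_cuda` is a hypothetical helper, and `device_count` stands in for `torch.cuda.device_count()` (with 0 meaning CUDA is unavailable).

```python
def resolve_cuda(spec: str, device_count: int) -> str:
    """Resolve 'cuda' or 'cuda:N' to a concrete device, falling back to 'cpu'."""
    if device_count == 0:
        return "cpu"  # CUDA not available at all
    parts = spec.split(":")
    if len(parts) == 2 and parts[1].isdigit():
        index = int(parts[1])
        # Only accept indices that exist on this machine
        return f"cuda:{index}" if index < device_count else "cpu"
    if len(parts) == 1:
        return "cuda:0"  # plain "cuda" selects the first device
    return "cpu"  # malformed spec such as "cuda:x" or "cuda:1:2"
```

On a hypothetical two-GPU machine, `resolve_cuda("cuda:1", 2)` keeps the requested device, while `resolve_cuda("cuda:3", 2)` falls back to `"cpu"`.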

View File

@ -2,9 +2,9 @@ import logging
from typing import Any, Dict, Iterable, List, Tuple, Union
from docling_core.types.doc import BoundingBox, CoordOrigin
from docling_core.types.doc.page import TextCell
from docling_core.types.legacy_doc.base import BaseCell, BaseText, Ref, Table
from docling.datamodel.base_models import OcrCell
from docling.datamodel.document import ConversionResult, Page
_log = logging.getLogger(__name__)
@ -86,11 +86,13 @@ def generate_multimodal_pages(
if page.size is None:
return cells
for cell in page.cells:
new_bbox = cell.bbox.to_top_left_origin(
page_height=page.size.height
).normalized(page_size=page.size)
is_ocr = isinstance(cell, OcrCell)
ocr_confidence = cell.confidence if isinstance(cell, OcrCell) else 1.0
new_bbox = (
cell.rect.to_bounding_box()
.to_top_left_origin(page_height=page.size.height)
.normalized(page_size=page.size)
)
is_ocr = cell.from_ocr
ocr_confidence = cell.confidence
cells.append(
{
"text": cell.text,

View File

@ -15,6 +15,7 @@ from docling_core.types.doc import (
TableCell,
TableData,
)
from docling_core.types.doc.document import ContentLayer
def resolve_item(paths, obj):
@ -270,7 +271,6 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
container_el = doc.add_group(label=group_label)
_add_child_elements(container_el, doc, obj, pelem)
elif "text" in obj:
text = obj["text"][span_i:span_j]
@ -304,6 +304,23 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
current_list = None
doc.add_heading(text=text, prov=prov)
elif label == DocItemLabel.CODE:
current_list = None
doc.add_code(text=text, prov=prov)
elif label == DocItemLabel.FORMULA:
current_list = None
doc.add_text(label=DocItemLabel.FORMULA, text="", orig=text, prov=prov)
elif label in [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]:
current_list = None
doc.add_text(
label=DocItemLabel(name_label),
text=text,
prov=prov,
content_layer=ContentLayer.FURNITURE,
)
else:
current_list = None

View File

@ -5,9 +5,10 @@ from collections import defaultdict
from typing import Dict, List, Set, Tuple
from docling_core.types.doc import DocItemLabel, Size
from docling_core.types.doc.page import TextCell
from rtree import index
from docling.datamodel.base_models import BoundingBox, Cell, Cluster, OcrCell
from docling.datamodel.base_models import BoundingBox, Cluster
_log = logging.getLogger(__name__)
@ -198,11 +199,12 @@ class LayoutPostprocessor:
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
}
def __init__(self, cells: List[Cell], clusters: List[Cluster], page_size: Size):
def __init__(self, cells: List[TextCell], clusters: List[Cluster], page_size: Size):
"""Initialize processor with cells and clusters."""
"""Initialize processor with cells and spatial indices."""
self.cells = cells
self.page_size = page_size
self.all_clusters = clusters
self.regular_clusters = [
c for c in clusters if c.label not in self.SPECIAL_TYPES
]
@ -217,7 +219,7 @@ class LayoutPostprocessor:
[c for c in self.special_clusters if c.label in self.WRAPPER_TYPES]
)
def postprocess(self) -> Tuple[List[Cluster], List[Cell]]:
def postprocess(self) -> Tuple[List[Cluster], List[TextCell]]:
"""Main processing pipeline."""
self.regular_clusters = self._process_regular_clusters()
self.special_clusters = self._process_special_clusters()
@ -267,18 +269,16 @@ class LayoutPostprocessor:
# Handle orphaned cells
unassigned = self._find_unassigned_cells(clusters)
if unassigned:
next_id = max((c.id for c in clusters), default=0) + 1
next_id = max((c.id for c in self.all_clusters), default=0) + 1
orphan_clusters = []
for i, cell in enumerate(unassigned):
conf = 1.0
if isinstance(cell, OcrCell):
conf = cell.confidence
conf = cell.confidence
orphan_clusters.append(
Cluster(
id=next_id + i,
label=DocItemLabel.TEXT,
bbox=cell.bbox,
bbox=cell.to_bounding_box(),
confidence=conf,
cells=[cell],
)
@ -556,13 +556,13 @@ class LayoutPostprocessor:
return current_best if current_best else clusters[0]
def _deduplicate_cells(self, cells: List[Cell]) -> List[Cell]:
def _deduplicate_cells(self, cells: List[TextCell]) -> List[TextCell]:
"""Ensure each cell appears only once, maintaining order of first appearance."""
seen_ids = set()
unique_cells = []
for cell in cells:
if cell.id not in seen_ids:
seen_ids.add(cell.id)
if cell.index not in seen_ids:
seen_ids.add(cell.index)
unique_cells.append(cell)
return unique_cells
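Order-preserving deduplication keyed on `index` can be tried in isolation. This is a minimal sketch: `ToyCell` is a hypothetical stand-in for `TextCell` that carries only the two fields needed here.

```python
from dataclasses import dataclass
from typing import List, Set


@dataclass(frozen=True)
class ToyCell:  # hypothetical stand-in, not docling_core's TextCell
    index: int
    text: str


def deduplicate_by_index(cells: List[ToyCell]) -> List[ToyCell]:
    """Keep the first cell seen for each index, preserving input order."""
    seen: Set[int] = set()
    unique: List[ToyCell] = []
    for cell in cells:
        if cell.index not in seen:
            seen.add(cell.index)
            unique.append(cell)
    return unique
```

Because the key is the index rather than the cell object, a later cell with the same index but different text is dropped.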
@ -581,11 +581,13 @@ class LayoutPostprocessor:
best_cluster = None
for cluster in clusters:
if cell.bbox.area() <= 0:
if cell.rect.to_bounding_box().area() <= 0:
continue
overlap = cell.bbox.intersection_area_with(cluster.bbox)
overlap_ratio = overlap / cell.bbox.area()
overlap = cell.rect.to_bounding_box().intersection_area_with(
cluster.bbox
)
overlap_ratio = overlap / cell.rect.to_bounding_box().area()
if overlap_ratio > best_overlap:
best_overlap = overlap_ratio
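The overlap ratio used above is the intersection area divided by the cell's own area, so a small cell fully inside a large cluster scores 1.0. The sketch below works on plain `(l, t, r, b)` tuples with a top-left origin. `best_cluster_index` and its `min_overlap` parameter are hypothetical names for illustration, not part of `LayoutPostprocessor`.

```python
from typing import List, Optional, Tuple

Box = Tuple[float, float, float, float]  # (l, t, r, b), top-left origin


def intersection_area(a: Box, b: Box) -> float:
    """Area of the axis-aligned intersection of two boxes (0.0 if disjoint)."""
    width = min(a[2], b[2]) - max(a[0], b[0])
    height = min(a[3], b[3]) - max(a[1], b[1])
    return width * height if width > 0 and height > 0 else 0.0


def best_cluster_index(
    cell: Box, clusters: List[Box], min_overlap: float = 0.0
) -> Optional[int]:
    """Return the index of the cluster covering the largest share of the cell."""
    cell_area = (cell[2] - cell[0]) * (cell[3] - cell[1])
    if cell_area <= 0:
        return None  # degenerate cells are skipped
    best, best_ratio = None, min_overlap
    for i, cluster in enumerate(clusters):
        ratio = intersection_area(cell, cluster) / cell_area
        if ratio > best_ratio:
            best, best_ratio = i, ratio
    return best
```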
@ -600,11 +602,13 @@ class LayoutPostprocessor:
return clusters
def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]:
def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[TextCell]:
"""Find cells not assigned to any cluster."""
assigned = {cell.id for cluster in clusters for cell in cluster.cells}
assigned = {cell.index for cluster in clusters for cell in cluster.cells}
return [
cell for cell in self.cells if cell.id not in assigned and cell.text.strip()
cell
for cell in self.cells
if cell.index not in assigned and cell.text.strip()
]
def _adjust_cluster_bboxes(self, clusters: List[Cluster]) -> List[Cluster]:
@ -614,10 +618,10 @@ class LayoutPostprocessor:
continue
cells_bbox = BoundingBox(
l=min(cell.bbox.l for cell in cluster.cells),
t=min(cell.bbox.t for cell in cluster.cells),
r=max(cell.bbox.r for cell in cluster.cells),
b=max(cell.bbox.b for cell in cluster.cells),
l=min(cell.rect.to_bounding_box().l for cell in cluster.cells),
t=min(cell.rect.to_bounding_box().t for cell in cluster.cells),
r=max(cell.rect.to_bounding_box().r for cell in cluster.cells),
b=max(cell.rect.to_bounding_box().b for cell in cluster.cells),
)
if cluster.label == DocItemLabel.TABLE:
@ -633,9 +637,9 @@ class LayoutPostprocessor:
return clusters
def _sort_cells(self, cells: List[Cell]) -> List[Cell]:
def _sort_cells(self, cells: List[TextCell]) -> List[TextCell]:
"""Sort cells in native reading order."""
return sorted(cells, key=lambda c: (c.id))
return sorted(cells, key=lambda c: (c.index))
def _sort_clusters(
self, clusters: List[Cluster], mode: str = "id"
@ -646,7 +650,7 @@ class LayoutPostprocessor:
clusters,
key=lambda cluster: (
(
min(cell.id for cell in cluster.cells)
min(cell.index for cell in cluster.cells)
if cluster.cells
else sys.maxsize
),

3
docling/utils/locks.py Normal file
View File

@ -0,0 +1,3 @@
import threading
pypdfium2_lock = threading.Lock()

View File

@ -0,0 +1,97 @@
import logging
from pathlib import Path
from typing import Optional
from docling.datamodel.pipeline_options import (
granite_picture_description,
smolvlm_picture_description,
)
from docling.datamodel.settings import settings
from docling.models.code_formula_model import CodeFormulaModel
from docling.models.document_picture_classifier import DocumentPictureClassifier
from docling.models.easyocr_model import EasyOcrModel
from docling.models.layout_model import LayoutModel
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
from docling.models.table_structure_model import TableStructureModel
_log = logging.getLogger(__name__)
def download_models(
output_dir: Optional[Path] = None,
*,
force: bool = False,
progress: bool = False,
with_layout: bool = True,
with_tableformer: bool = True,
with_code_formula: bool = True,
with_picture_classifier: bool = True,
with_smolvlm: bool = False,
with_granite_vision: bool = False,
with_easyocr: bool = True,
):
if output_dir is None:
output_dir = settings.cache_dir / "models"
# Make sure the folder exists
output_dir.mkdir(exist_ok=True, parents=True)
if with_layout:
_log.info("Downloading layout model...")
LayoutModel.download_models(
local_dir=output_dir / LayoutModel._model_repo_folder,
force=force,
progress=progress,
)
if with_tableformer:
_log.info("Downloading tableformer model...")
TableStructureModel.download_models(
local_dir=output_dir / TableStructureModel._model_repo_folder,
force=force,
progress=progress,
)
if with_picture_classifier:
_log.info("Downloading picture classifier model...")
DocumentPictureClassifier.download_models(
local_dir=output_dir / DocumentPictureClassifier._model_repo_folder,
force=force,
progress=progress,
)
if with_code_formula:
_log.info("Downloading code formula model...")
CodeFormulaModel.download_models(
local_dir=output_dir / CodeFormulaModel._model_repo_folder,
force=force,
progress=progress,
)
if with_smolvlm:
_log.info("Downloading SmolVlm model...")
PictureDescriptionVlmModel.download_models(
repo_id=smolvlm_picture_description.repo_id,
local_dir=output_dir / smolvlm_picture_description.repo_cache_folder,
force=force,
progress=progress,
)
if with_granite_vision:
_log.info("Downloading Granite Vision model...")
PictureDescriptionVlmModel.download_models(
repo_id=granite_picture_description.repo_id,
local_dir=output_dir / granite_picture_description.repo_cache_folder,
force=force,
progress=progress,
)
if with_easyocr:
_log.info("Downloading easyocr models...")
EasyOcrModel.download_models(
local_dir=output_dir / EasyOcrModel._model_repo_folder,
force=force,
progress=progress,
)
return output_dir

View File

@ -0,0 +1,9 @@
def map_tesseract_script(script: str) -> str:
"""Map a detected script name to the script name Tesseract expects; other names pass through unchanged."""
if script == "Katakana" or script == "Hiragana":
script = "Japanese"
elif script == "Han":
script = "HanS"
elif script == "Korean":
script = "Hangul"
return script
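The same mapping can be written as a lookup table, which keeps the aliases in one place if more scripts need to be added. This is an equivalent sketch under that design choice, not the code from the commit. `map_script` and `_SCRIPT_ALIASES` are hypothetical names.

```python
# Aliases taken from the branches of map_tesseract_script
_SCRIPT_ALIASES = {
    "Katakana": "Japanese",
    "Hiragana": "Japanese",
    "Han": "HanS",
    "Korean": "Hangul",
}


def map_script(script: str) -> str:
    """Return the aliased script name, or the input unchanged if no alias exists."""
    return _SCRIPT_ALIASES.get(script, script)
```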

View File

@ -4,6 +4,9 @@ from itertools import islice
from pathlib import Path
from typing import List, Union
import requests
from tqdm import tqdm
def chunkify(iterator, chunk_size):
"""Yield successive chunks of chunk_size from the iterable."""
@ -39,3 +42,24 @@ def create_hash(string: str):
hasher.update(string.encode("utf-8"))
return hasher.hexdigest()
def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
buf = BytesIO()
with requests.get(url, stream=True, allow_redirects=True) as response:
total_size = int(response.headers.get("content-length", 0))
progress_bar = tqdm(
total=total_size,
unit="B",
unit_scale=True,
unit_divisor=1024,
disable=(not progress),
)
for chunk in response.iter_content(10 * 1024):
buf.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
buf.seek(0)
return buf
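The chunked read loop in `download_url_with_progress` can be tried offline by swapping the HTTP response for any binary stream. The sketch below assumes nothing about `requests` or `tqdm`. `copy_in_chunks` and its `on_chunk` callback are hypothetical names, with the callback playing the role of `progress_bar.update`.

```python
from io import BytesIO
from typing import BinaryIO, Callable, Optional


def copy_in_chunks(
    src: BinaryIO,
    chunk_size: int = 10 * 1024,
    on_chunk: Optional[Callable[[int], None]] = None,
) -> BytesIO:
    """Copy a stream into memory chunk by chunk, reporting each chunk's size."""
    buf = BytesIO()
    while True:
        chunk = src.read(chunk_size)
        if not chunk:  # empty bytes means end of stream
            break
        buf.write(chunk)
        if on_chunk is not None:
            on_chunk(len(chunk))
    buf.seek(0)  # rewind so the caller can read from the start
    return buf
```

Passing `sizes.append` as `on_chunk` records the size of every chunk, which makes the final short chunk visible.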

View File

@ -0,0 +1,85 @@
from docling_core.types.doc import DocItemLabel
from PIL import Image, ImageDraw, ImageFont
from PIL.ImageFont import FreeTypeFont
from docling.datamodel.base_models import Cluster
def draw_clusters(
image: Image.Image, clusters: list[Cluster], scale_x: float, scale_y: float
) -> None:
"""
Draw clusters on an image
"""
draw = ImageDraw.Draw(image, "RGBA")
# Create a smaller font for the labels
font: ImageFont.ImageFont | FreeTypeFont
try:
font = ImageFont.truetype("arial.ttf", 12)
except OSError:
# Fallback to default font if arial is not available
font = ImageFont.load_default()
for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children]
for c in all_clusters:
# Draw cells first (underneath)
cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.rect.to_bounding_box().as_tuple()
cx0 *= scale_x
cx1 *= scale_x
cy0 *= scale_y
cy1 *= scale_y
draw.rectangle(
[(cx0, cy0), (cx1, cy1)],
outline=None,
fill=cell_color,
)
# Draw cluster rectangle
x0, y0, x1, y1 = c.bbox.as_tuple()
x0 *= scale_x
x1 *= scale_x
y0 *= scale_y
y1 *= scale_y
if y1 <= y0:
y1, y0 = y0, y1
if x1 <= x0:
x1, x0 = x0, x1
cluster_fill_color = (*list(DocItemLabel.get_color(c.label)), 70)
cluster_outline_color = (
*list(DocItemLabel.get_color(c.label)),
255,
)
draw.rectangle(
[(x0, y0), (x1, y1)],
outline=cluster_outline_color,
fill=cluster_fill_color,
)
# Add label name and confidence
label_text = f"{c.label.name} ({c.confidence:.2f})"
# Create semi-transparent background for text
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
text_bg_padding = 2
draw.rectangle(
[
(
text_bbox[0] - text_bg_padding,
text_bbox[1] - text_bg_padding,
),
(
text_bbox[2] + text_bg_padding,
text_bbox[3] + text_bg_padding,
),
],
fill=(255, 255, 255, 180), # Semi-transparent white
)
# Draw text
draw.text(
(x0, y0),
label_text,
fill=(0, 0, 0, 255), # Solid black
font=font,
)
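Per-axis scaling followed by corner normalization can be isolated in a small helper. This is a sketch under the assumption that x coordinates are multiplied by `scale_x` and y coordinates by `scale_y`. `scale_box` is a hypothetical function, not part of the visualization module.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (x0, y0, x1, y1)


def scale_box(box: Box, scale_x: float, scale_y: float) -> Box:
    """Scale a box per axis and order its corners so that x0 <= x1 and y0 <= y1."""
    x0, y0, x1, y1 = box
    x0, x1 = x0 * scale_x, x1 * scale_x
    y0, y1 = y0 * scale_y, y1 * scale_y
    # PIL's rectangle drawing expects the first corner to be the top-left one
    if x1 < x0:
        x0, x1 = x1, x0
    if y1 < y0:
        y0, y1 = y1, y0
    return (x0, y0, x1, y1)
```

Normalizing the corner order once keeps the drawing code free of repeated swap checks.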

View File

@ -1,5 +1,18 @@
## Introduction
!!! note "Chunking approaches"
Starting from a `DoclingDocument`, there are in principle two possible chunking
approaches:
1. exporting the `DoclingDocument` to Markdown (or similar format) and then
performing user-defined chunking as a post-processing step, or
2. using native Docling chunkers, i.e. operating directly on the `DoclingDocument`
This page is about the latter, i.e. using native Docling chunkers.
For an example of using approach (1) check out e.g.
[this recipe](../examples/rag_langchain.ipynb) looking at the Markdown export mode.
A *chunker* is a Docling abstraction that, given a
[`DoclingDocument`](./docling_document.md), returns a stream of chunks, each of which
captures some part of the document as a string accompanied by respective metadata.
@ -54,12 +67,12 @@ tokens), &
chunks with same headings & captions) — users can opt out of this step via param
`merge_peers` (by default `True`)
👉 Example: see [here](../../examples/hybrid_chunking).
👉 Example: see [here](../examples/hybrid_chunking.ipynb).
## Hierarchical Chunker
The `HierarchicalChunker` implementation uses the document structure information from
the [`DoclingDocument`](../docling_document) to create one chunk for each individual
the [`DoclingDocument`](./docling_document.md) to create one chunk for each individual
detected document element, by default only merging together list items (can be opted out
via param `merge_list_items`). It also takes care of attaching all relevant document
metadata, including headers and captions.
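The behaviour described above can be illustrated with a toy model: one chunk per element, consecutive list items merged, and the current heading attached as metadata. This is a sketch of the idea only, not the `HierarchicalChunker` implementation. `Chunk`, `toy_hierarchical_chunks`, and the `(kind, text)` element tuples are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple


@dataclass
class Chunk:
    text: str
    headings: List[str] = field(default_factory=list)


def toy_hierarchical_chunks(
    elements: List[Tuple[str, str]], merge_list_items: bool = True
) -> Iterator[Chunk]:
    """Yield one chunk per element, optionally merging runs of list items."""
    heading: List[str] = []
    pending: List[str] = []  # list items collected so far
    for kind, text in elements:
        if kind != "list_item" and pending:
            # A non-list element ends the current run of list items
            yield Chunk("\n".join(pending), list(heading))
            pending = []
        if kind == "heading":
            heading = [text]
        elif kind == "list_item" and merge_list_items:
            pending.append(text)
        else:
            yield Chunk(text, list(heading))
    if pending:
        yield Chunk("\n".join(pending), list(heading))
```

With `merge_list_items=False`, every list item becomes its own chunk, which mirrors the opt-out described in the paragraph.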

View File

@ -7,7 +7,7 @@ pydantic datatype, which can express several features common to documents, such
* Layout information (i.e. bounding boxes) for all items, if available
* Provenance information
The definition of the Pydantic types is implemented in the module `docling_core.types.doc`, more details in [source code definitions](https://github.com/DS4SD/docling-core/tree/main/docling_core/types/doc).
The definition of the Pydantic types is implemented in the module `docling_core.types.doc`, more details in [source code definitions](https://github.com/docling-project/docling-core/tree/main/docling_core/types/doc).
It also brings a set of document construction APIs to build up a `DoclingDocument` from scratch.

View File

@ -0,0 +1,80 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Conversion of CSV files\n",
"\n",
"This example shows how to convert CSV files to a structured Docling Document.\n",
"\n",
"* Multiple delimiters are supported: `,` `;` `|` `[tab]`\n",
"* Additional CSV dialect settings are detected automatically (e.g. quotes, line separator, escape character)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example Code"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"from docling.document_converter import DocumentConverter\n",
"\n",
"# Convert CSV to Docling document\n",
"converter = DocumentConverter()\n",
"result = converter.convert(Path(\"../../tests/data/csv/csv-comma.csv\"))\n",
"output = result.document.export_to_markdown()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This code generates the following output:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Index | Customer Id | First Name | Last Name | Company | City | Country | Phone 1 | Phone 2 | Email | Subscription Date | Website |\n",
"|---------|-----------------|--------------|-------------|---------------------------------|-------------------|----------------------------|------------------------|-----------------------|-----------------------------|---------------------|-----------------------------|\n",
"| 1 | DD37Cf93aecA6Dc | Sheryl | Baxter | Rasmussen Group | East Leonard | Chile | 229.077.5154 | 397.884.0519x718 | zunigavanessa@smith.info | 2020-08-24 | http://www.stephenson.com/ |\n",
"| 2 | 1Ef7b82A4CAAD10 | Preston | Lozano, Dr | Vega-Gentry | East Jimmychester | Djibouti | 5153435776 | 686-620-1820x944 | vmata@colon.com | 2021-04-23 | http://www.hobbs.com/ |\n",
"| 3 | 6F94879bDAfE5a6 | Roy | Berry | Murillo-Perry | Isabelborough | Antigua and Barbuda | +1-539-402-0259 | (496)978-3969x58947 | beckycarr@hogan.com | 2020-03-25 | http://www.lawrence.com/ |\n",
"| 4 | 5Cef8BFA16c5e3c | Linda | Olsen | Dominguez, Mcmillan and Donovan | Bensonview | Dominican Republic | 001-808-617-6467x12895 | +1-813-324-8756 | stanleyblackwell@benson.org | 2020-06-02 | http://www.good-lyons.com/ |\n",
"| 5 | 053d585Ab6b3159 | Joanna | Bender | Martin, Lang and Andrade | West Priscilla | Slovakia (Slovak Republic) | 001-234-203-0635x76146 | 001-199-446-3860x3486 | colinalvarado@miles.net | 2021-04-17 | https://goodwin-ingram.com/ |"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "docling-TtEIaPrw-py3.12",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,931 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/docling-project/docling/blob/main/docs/examples/backend_xml_rag.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Conversion of custom XML"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"| Step | Tech | Execution | \n",
"| --- | --- | --- |\n",
"| Embedding | Hugging Face / Sentence Transformers | 💻 Local |\n",
"| Vector store | Milvus | 💻 Local |\n",
"| Gen AI | Hugging Face Inference API | 🌐 Remote | "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is an example of using [Docling](https://docling-project.github.io/docling/) for converting structured data (XML) into a unified document\n",
"representation format, `DoclingDocument`, and leveraging its rich structured content for RAG applications.\n",
"\n",
"Data used in this example consist of patents from the [United States Patent and Trademark Office (USPTO)](https://www.uspto.gov/) and medical\n",
"articles from [PubMed Central® (PMC)](https://pmc.ncbi.nlm.nih.gov/).\n",
"\n",
"In this notebook, we accomplish the following:\n",
"- [Simple conversion](#simple-conversion) of supported XML files in a nutshell\n",
"- An [end-to-end application](#end-to-end-application) using public collections of XML files supported by Docling\n",
" - [Setup](#setup) the API access for generative AI\n",
" - [Fetch the data](#fetch-the-data) from USPTO and PubMed Central® sites, using Docling custom backends\n",
" - [Parse, chunk, and index](#parse-chunk-and-index) the documents in a vector database\n",
" - [Perform RAG](#question-answering-with-rag) using [LlamaIndex Docling extension](../../integrations/llamaindex/)\n",
"\n",
"For more details on document chunking with Docling, refer to the [Chunking](../../concepts/chunking/) documentation. For RAG with Docling and LlamaIndex, also check the example [RAG with LlamaIndex](../rag_llamaindex/)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple conversion\n",
"\n",
"The XML file format defines and stores data in a format that is both human-readable and machine-readable.\n",
"Because of this flexibility, Docling requires custom backend processors to interpret XML definitions and convert them into `DoclingDocument` objects.\n",
"\n",
"Some public data collections in XML format are already supported by Docling (USPTO patents and PMC articles). In these cases, the document conversion is straightforward and the same as with any other supported format, such as PDF or HTML. The execution example in [Simple Conversion](../minimal/) is the recommended usage of Docling for a single file:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ConversionStatus.SUCCESS\n"
]
}
],
"source": [
"from docling.document_converter import DocumentConverter\n",
"\n",
"# a sample PMC article:\n",
"source = \"../../tests/data/jats/elife-56337.nxml\"\n",
"converter = DocumentConverter()\n",
"result = converter.convert(source)\n",
"print(result.status)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the document is converted, it can be exported to any format supported by Docling. For instance, to markdown (showing here the first lines only):"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# KRAB-zinc finger protein gene expansion in response to active retrotransposons in the murine lineage\n",
"\n",
"Gernot Wolf, Alberto de Iaco, Ming-An Sun, Melania Bruno, Matthew Tinkham, Don Hoang, Apratim Mitra, Sherry Ralls, Didier Trono, Todd S Macfarlan\n",
"\n",
"The Eunice Kennedy Shriver National Institute of Child Health and Human Development, The National Institutes of Health, Bethesda, United States; School of Life Sciences, École Polytechnique Fédérale de Lausanne (EPFL), Lausanne, Switzerland\n",
"\n",
"## Abstract\n",
"\n"
]
}
],
"source": [
"md_doc = result.document.export_to_markdown()\n",
"\n",
"delim = \"\\n\"\n",
"print(delim.join(md_doc.split(delim)[:8]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the XML file is not supported, a `ConversionError` is raised."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Input document docling_test.xml does not match any allowed format.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"File format not allowed: docling_test.xml\n"
]
}
],
"source": [
"from io import BytesIO\n",
"\n",
"from docling.datamodel.base_models import DocumentStream\n",
"from docling.exceptions import ConversionError\n",
"\n",
"xml_content = (\n",
" b'<?xml version=\"1.0\" encoding=\"UTF-8\"?><!DOCTYPE docling_test SYSTEM '\n",
" b'\"test.dtd\"><docling>Random content</docling>'\n",
")\n",
"stream = DocumentStream(name=\"docling_test.xml\", stream=BytesIO(xml_content))\n",
"try:\n",
" result = converter.convert(stream)\n",
"except ConversionError as ce:\n",
" print(ce)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can always refer to the [Usage](../../usage/#supported-formats) documentation page for a list of supported formats."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## End-to-end application\n",
"\n",
"This section describes a step-by-step application for processing XML files from supported public collections and using them for question answering."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Requirements can be installed as shown below. The `--no-warn-conflicts` argument is meant for Colab's pre-populated Python environment; feel free to remove it for stricter usage."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -q --progress-bar off --no-warn-conflicts llama-index-core llama-index-readers-docling llama-index-node-parser-docling llama-index-embeddings-huggingface llama-index-llms-huggingface-api llama-index-vector-stores-milvus llama-index-readers-file python-dotenv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook uses HuggingFace's Inference API. For an increased LLM quota, a token can be provided via the environment variable `HF_TOKEN`.\n",
"\n",
"If you're running this notebook in Google Colab, make sure you [add](https://medium.com/@parthdasawant/how-to-use-secrets-in-google-colab-450c38e3ec75) your API key as a secret."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from warnings import filterwarnings\n",
"\n",
"from dotenv import load_dotenv\n",
"\n",
"\n",
"def _get_env_from_colab_or_os(key):\n",
" try:\n",
" from google.colab import userdata\n",
"\n",
" try:\n",
" return userdata.get(key)\n",
" except userdata.SecretNotFoundError:\n",
" pass\n",
" except ImportError:\n",
" pass\n",
" return os.getenv(key)\n",
"\n",
"\n",
"load_dotenv()\n",
"\n",
"filterwarnings(action=\"ignore\", category=UserWarning, module=\"pydantic\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now define the main parameters:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from tempfile import mkdtemp\n",
"\n",
"from llama_index.embeddings.huggingface import HuggingFaceEmbedding\n",
"from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI\n",
"\n",
"EMBED_MODEL_ID = \"BAAI/bge-small-en-v1.5\"\n",
"EMBED_MODEL = HuggingFaceEmbedding(model_name=EMBED_MODEL_ID)\n",
"TEMP_DIR = Path(mkdtemp())\n",
"MILVUS_URI = str(TEMP_DIR / \"docling.db\")\n",
"GEN_MODEL = HuggingFaceInferenceAPI(\n",
" token=_get_env_from_colab_or_os(\"HF_TOKEN\"),\n",
" model_name=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
")\n",
"embed_dim = len(EMBED_MODEL.get_text_embedding(\"hi\"))\n",
"# https://github.com/huggingface/transformers/issues/5486:\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fetch the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we will use XML data from collections supported by Docling:\n",
"- Medical articles from [PubMed Central® (PMC)](https://pmc.ncbi.nlm.nih.gov/). They are available on an [FTP server](https://ftp.ncbi.nlm.nih.gov/pub/pmc/) as `.tar.gz` files. Each file contains the full article data in XML format, along with supplementary files such as images or spreadsheets.\n",
"- Patents from the [United States Patent and Trademark Office](https://www.uspto.gov/). They are available in the [Bulk Data Storage System (BDSS)](https://bulkdata.uspto.gov/) as zip files. Each zip file may contain several patents in XML format.\n",
"\n",
"The raw files will be downloaded from the source and saved in a temporary directory."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### PMC articles\n",
"\n",
"The [OA file](https://ftp.ncbi.nlm.nih.gov/pub/pmc/oa_file_list.csv) is a manifest of all the PMC articles, including the URL path to download the source files. In this notebook we will use as an example the article [Pathogens spread by high-altitude windborne mosquitoes](https://pmc.ncbi.nlm.nih.gov/articles/PMC11703268/), which is available in the archive file [PMC11703268.tar.gz](https://ftp.ncbi.nlm.nih.gov/pub/pmc/oa_package/e3/6b/PMC11703268.tar.gz)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://ftp.ncbi.nlm.nih.gov/pub/pmc/oa_package/e3/6b/PMC11703268.tar.gz...\n",
"Extracting and storing the XML file containing the article text...\n",
"Stored XML file nihpp-2024.12.26.630351v1.nxml\n"
]
}
],
"source": [
"import tarfile\n",
"from io import BytesIO\n",
"\n",
"import requests\n",
"\n",
"# PMC article PMC11703268\n",
"url: str = \"https://ftp.ncbi.nlm.nih.gov/pub/pmc/oa_package/e3/6b/PMC11703268.tar.gz\"\n",
"\n",
"print(f\"Downloading {url}...\")\n",
"buf = BytesIO(requests.get(url).content)\n",
"print(\"Extracting and storing the XML file containing the article text...\")\n",
"with tarfile.open(fileobj=buf, mode=\"r:gz\") as tar_file:\n",
"    for tarinfo in tar_file:\n",
"        if tarinfo.isreg():\n",
"            file_path = Path(tarinfo.name)\n",
"            if file_path.suffix == \".nxml\":\n",
"                with open(TEMP_DIR / file_path.name, \"wb\") as file_obj:\n",
"                    file_obj.write(tar_file.extractfile(tarinfo).read())\n",
"                print(f\"Stored XML file {file_path.name}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### USPTO patents\n",
"\n",
"Since each USPTO file is a concatenation of several patents, we need to split its content into valid XML pieces. The following code downloads a sample zip file, splits its content into sections, and dumps each section as an XML file. For simplicity, this pipeline is shown here in a sequential manner, but it could be parallelized."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading https://bulkdata.uspto.gov/data/patent/grant/redbook/fulltext/2024/ipg241217.zip...\n",
"Parsing zip file, splitting into XML sections, and exporting to files...\n"
]
}
],
"source": [
"import zipfile\n",
"\n",
"# Patent grants from December 17-23, 2024\n",
"url: str = (\n",
"    \"https://bulkdata.uspto.gov/data/patent/grant/redbook/fulltext/2024/ipg241217.zip\"\n",
")\n",
"XML_SPLITTER: str = '<?xml version=\"1.0\"'\n",
"doc_num: int = 0\n",
"\n",
"print(f\"Downloading {url}...\")\n",
"buf = BytesIO(requests.get(url).content)\n",
"print(f\"Parsing zip file, splitting into XML sections, and exporting to files...\")\n",
"with zipfile.ZipFile(buf) as zf:\n",
"    res = zf.testzip()\n",
"    if res:\n",
"        print(\"Error validating zip file\")\n",
"    else:\n",
"        with zf.open(zf.namelist()[0]) as xf:\n",
"            is_patent = False\n",
"            patent_buffer = BytesIO()\n",
"            for xf_line in xf:\n",
"                decoded_line = xf_line.decode(errors=\"ignore\").rstrip()\n",
"                xml_index = decoded_line.find(XML_SPLITTER)\n",
"                if xml_index != -1:\n",
"                    if (\n",
"                        xml_index > 0\n",
"                    ):  # cases like </sequence-cwu><?xml version=\"1.0\"...\n",
"                        patent_buffer.write(xf_line[:xml_index])\n",
"                        patent_buffer.write(b\"\\r\\n\")\n",
"                        xf_line = xf_line[xml_index:]\n",
"                    if patent_buffer.getbuffer().nbytes > 0 and is_patent:\n",
"                        doc_num += 1\n",
"                        patent_id = f\"ipg241217-{doc_num}\"\n",
"                        with open(TEMP_DIR / f\"{patent_id}.xml\", \"wb\") as file_obj:\n",
"                            file_obj.write(patent_buffer.getbuffer())\n",
"                    is_patent = False\n",
"                    patent_buffer = BytesIO()\n",
"                elif decoded_line.startswith(\"<!DOCTYPE\"):\n",
"                    is_patent = True\n",
"                patent_buffer.write(xf_line)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fetched and exported 4014 documents.\n"
]
}
],
"source": [
"print(f\"Fetched and exported {doc_num} documents.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using the backend converter (optional)\n",
"\n",
"- The custom backend converters `JatsDocumentBackend` and `PatentUsptoDocumentBackend` handle the parsing of PMC articles and USPTO patents, respectively.\n",
"- As with any other backend, you can use the function `is_valid()` to check whether the input document is supported by the backend.\n",
"- Note that some XML sections in the original USPTO zip file, such as sequence listings, do not represent patents, so the backend will report them as invalid."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Document nihpp-2024.12.26.630351v1.nxml is a valid PMC article? True\n",
"Document ipg241217-1.xml is a valid patent? True\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "316241ca89a843bda3170f2a5c76c639",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/4014 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 3928 patents out of 4014 XML files.\n"
]
}
],
"source": [
"from tqdm.notebook import tqdm\n",
"\n",
"from docling.backend.xml.jats_backend import JatsDocumentBackend\n",
"from docling.backend.xml.uspto_backend import PatentUsptoDocumentBackend\n",
"from docling.datamodel.base_models import InputFormat\n",
"from docling.datamodel.document import InputDocument\n",
"\n",
"# check PMC\n",
"in_doc = InputDocument(\n",
"    path_or_stream=TEMP_DIR / \"nihpp-2024.12.26.630351v1.nxml\",\n",
"    format=InputFormat.XML_JATS,\n",
"    backend=JatsDocumentBackend,\n",
")\n",
"backend = JatsDocumentBackend(\n",
"    in_doc=in_doc, path_or_stream=TEMP_DIR / \"nihpp-2024.12.26.630351v1.nxml\"\n",
")\n",
"print(f\"Document {in_doc.file.name} is a valid PMC article? {backend.is_valid()}\")\n",
"\n",
"# check USPTO\n",
"in_doc = InputDocument(\n",
"    path_or_stream=TEMP_DIR / \"ipg241217-1.xml\",\n",
"    format=InputFormat.XML_USPTO,\n",
"    backend=PatentUsptoDocumentBackend,\n",
")\n",
"backend = PatentUsptoDocumentBackend(\n",
"    in_doc=in_doc, path_or_stream=TEMP_DIR / \"ipg241217-1.xml\"\n",
")\n",
"print(f\"Document {in_doc.file.name} is a valid patent? {backend.is_valid()}\")\n",
"\n",
"patent_valid = 0\n",
"pbar = tqdm(TEMP_DIR.glob(\"*.xml\"), total=doc_num)\n",
"for in_path in pbar:\n",
"    in_doc = InputDocument(\n",
"        path_or_stream=in_path,\n",
"        format=InputFormat.XML_USPTO,\n",
"        backend=PatentUsptoDocumentBackend,\n",
"    )\n",
"    backend = PatentUsptoDocumentBackend(in_doc=in_doc, path_or_stream=in_path)\n",
"    patent_valid += int(backend.is_valid())\n",
"\n",
"print(f\"Found {patent_valid} patents out of {doc_num} XML files.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calling the function `convert()` converts the input document into a `DoclingDocument`:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Patent \"Semiconductor package\" has 19 claims\n"
]
}
],
"source": [
"doc = backend.convert()\n",
"\n",
"claims_sec = [item for item in doc.texts if item.text == \"CLAIMS\"][0]\n",
"print(f'Patent \"{doc.texts[0].text}\" has {len(claims_sec.children)} claims')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"✏️ **Tip**: in general, there is no need to use the backend converters to parse USPTO or JATS (PubMed) XML files. The generic `DocumentConverter` object tries to guess the input document format and applies the corresponding backend parser. The conversion shown in [Simple Conversion](#simple-conversion) is the recommended usage for the supported XML files."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Parse, chunk, and index"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `DoclingDocument` format of the converted patents has a rich hierarchical structure, inherited from the original XML document and preserved by the Docling custom backend.\n",
"In this notebook, we will leverage:\n",
"- The `SimpleDirectoryReader` pattern to iterate over the exported XML files created in section [Fetch the data](#fetch-the-data).\n",
"- The LlamaIndex extensions, `DoclingReader` and `DoclingNodeParser`, to ingest the patent chunks into a Milvus vector store.\n",
"- The `HierarchicalChunker` implementation, which applies document-based hierarchical chunking, to leverage patent structures such as sections and the paragraphs within them.\n",
"\n",
"Refer to other possible implementations and usage patterns in the [Chunking](../../concepts/chunking/) documentation and the [RAG with LlamaIndex](../rag_llamaindex/) notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Set the Docling reader and the directory reader\n",
"\n",
"Note that `DoclingReader` uses Docling's `DocumentConverter` by default and therefore it will recognize the format of the XML files and leverage the `PatentUsptoDocumentBackend` automatically.\n",
"\n",
"For demonstration purposes, we limit the scope of the analysis to the first 100 patents."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core import SimpleDirectoryReader\n",
"from llama_index.readers.docling import DoclingReader\n",
"\n",
"reader = DoclingReader(export_type=DoclingReader.ExportType.JSON)\n",
"dir_reader = SimpleDirectoryReader(\n",
" input_dir=TEMP_DIR,\n",
" exclude=[\"docling.db\", \"*.nxml\"],\n",
" file_extractor={\".xml\": reader},\n",
" filename_as_id=True,\n",
" num_files_limit=100,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Set the node parser\n",
"\n",
"Note that the `HierarchicalChunker` is the default chunking implementation of the `DoclingNodeParser`."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.node_parser.docling import DoclingNodeParser\n",
"\n",
"node_parser = DoclingNodeParser()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Set a local Milvus database and run the ingestion"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core import StorageContext, VectorStoreIndex\n",
"from llama_index.vector_stores.milvus import MilvusVectorStore\n",
"\n",
"vector_store = MilvusVectorStore(\n",
" uri=MILVUS_URI,\n",
" dim=embed_dim,\n",
" overwrite=True,\n",
")\n",
"\n",
"index = VectorStoreIndex.from_documents(\n",
" documents=dir_reader.load_data(show_progress=True),\n",
" transformations=[node_parser],\n",
" storage_context=StorageContext.from_defaults(vector_store=vector_store),\n",
" embed_model=EMBED_MODEL,\n",
" show_progress=True,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, add the PMC article to the vector store directly from the reader."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<llama_index.core.indices.vector_store.base.VectorStoreIndex at 0x373a7f7d0>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"index.from_documents(\n",
" documents=reader.load_data(TEMP_DIR / \"nihpp-2024.12.26.630351v1.nxml\"),\n",
" transformations=[node_parser],\n",
" storage_context=StorageContext.from_defaults(vector_store=vector_store),\n",
" embed_model=EMBED_MODEL,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Question-answering with RAG"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The retriever can be used to identify highly relevant documents:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Node ID: 5afd36c0-a739-4a88-a51c-6d0f75358db5\n",
"Text: The portable fitness monitoring device 102 may be a device such\n",
"as, for example, a mobile phone, a personal digital assistant, a music\n",
"file player (e.g. and MP3 player), an intelligent article for wearing\n",
"(e.g. a fitness monitoring garment, wrist band, or watch), a dongle\n",
"(e.g. a small hardware device that protects software) that includes a\n",
"fitn...\n",
"Score: 0.772\n",
"\n",
"Node ID: f294b5fd-9089-43cb-8c4e-d1095a634ff1\n",
"Text: US Patent Application US 20120071306 entitled “Portable\n",
"Multipurpose Whole Body Exercise Device” discloses a portable\n",
"multipurpose whole body exercise device which can be used for general\n",
"fitness, Pilates-type, core strengthening, therapeutic, and\n",
"rehabilitative exercises as well as stretching and physical therapy\n",
"and which includes storable acc...\n",
"Score: 0.749\n",
"\n",
"Node ID: 8251c7ef-1165-42e1-8c91-c99c8a711bf7\n",
"Text: Program products, methods, and systems for providing fitness\n",
"monitoring services of the present invention can include any software\n",
"application executed by one or more computing devices. A computing\n",
"device can be any type of computing device having one or more\n",
"processors. For example, a computing device can be a workstation,\n",
"mobile device (e.g., ...\n",
"Score: 0.744\n",
"\n"
]
}
],
"source": [
"retriever = index.as_retriever(similarity_top_k=3)\n",
"results = retriever.retrieve(\"What patents are related to fitness devices?\")\n",
"\n",
"for item in results:\n",
"    print(item)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the query engine, we can run the question-answering with the RAG pattern on the set of indexed documents.\n",
"\n",
"First, we can prompt the LLM directly:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">╭──────────────────────────────────────────────────── Prompt ─────────────────────────────────────────────────────╮</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">│</span> Do mosquitoes in high altitude expand viruses over large distances? <span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;31m╭─\u001b[0m\u001b[1;31m───────────────────────────────────────────────────\u001b[0m\u001b[1;31m Prompt \u001b[0m\u001b[1;31m────────────────────────────────────────────────────\u001b[0m\u001b[1;31m─╮\u001b[0m\n",
"\u001b[1;31m│\u001b[0m Do mosquitoes in high altitude expand viruses over large distances? \u001b[1;31m│\u001b[0m\n",
"\u001b[1;31m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">╭─────────────────────────────────────────────── Generated Content ───────────────────────────────────────────────╮</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> Mosquitoes can be found at high altitudes, but their ability to transmit viruses over long distances is not <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> primarily dependent on altitude. Mosquitoes are vectors for various diseases, such as malaria, dengue fever, <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> and Zika virus, and their transmission range is more closely related to their movement, the presence of a host, <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> and environmental conditions that support their survival and reproduction. <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> At high altitudes, the environment can be less suitable for mosquitoes due to factors such as colder <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> temperatures, lower humidity, and stronger winds, which can limit their population size and distribution. <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> However, some species of mosquitoes have adapted to high-altitude environments and can still transmit diseases <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> in these areas. <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> It is possible for mosquitoes to be transported by wind or human activities to higher altitudes, but this is <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> not a significant factor in their ability to transmit viruses over long distances. Instead, long-distance <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> transmission of viruses is more often associated with human travel and transportation, which can rapidly spread <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> infected mosquitoes or humans to new areas, leading to the spread of disease. <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;32m╭─\u001b[0m\u001b[1;32m──────────────────────────────────────────────\u001b[0m\u001b[1;32m Generated Content \u001b[0m\u001b[1;32m──────────────────────────────────────────────\u001b[0m\u001b[1;32m─╮\u001b[0m\n",
"\u001b[1;32m│\u001b[0m Mosquitoes can be found at high altitudes, but their ability to transmit viruses over long distances is not \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m primarily dependent on altitude. Mosquitoes are vectors for various diseases, such as malaria, dengue fever, \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m and Zika virus, and their transmission range is more closely related to their movement, the presence of a host, \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m and environmental conditions that support their survival and reproduction. \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m At high altitudes, the environment can be less suitable for mosquitoes due to factors such as colder \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m temperatures, lower humidity, and stronger winds, which can limit their population size and distribution. \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m However, some species of mosquitoes have adapted to high-altitude environments and can still transmit diseases \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m in these areas. \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m It is possible for mosquitoes to be transported by wind or human activities to higher altitudes, but this is \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m not a significant factor in their ability to transmit viruses over long distances. Instead, long-distance \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m transmission of viruses is more often associated with human travel and transportation, which can rapidly spread \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m infected mosquitoes or humans to new areas, leading to the spread of disease. \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from llama_index.core.base.llms.types import ChatMessage, MessageRole\n",
"from rich.console import Console\n",
"from rich.panel import Panel\n",
"\n",
"console = Console()\n",
"query = \"Do mosquitoes in high altitude expand viruses over large distances?\"\n",
"\n",
"usr_msg = ChatMessage(role=MessageRole.USER, content=query)\n",
"response = GEN_MODEL.chat(messages=[usr_msg])\n",
"\n",
"console.print(Panel(query, title=\"Prompt\", border_style=\"bold red\"))\n",
"console.print(\n",
" Panel(\n",
" response.message.content.strip(),\n",
" title=\"Generated Content\",\n",
" border_style=\"bold green\",\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can compare the response when the model is prompted with the indexed PMC article as supporting context:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">╭────────────────────────────────────────── Generated Content with RAG ───────────────────────────────────────────╮</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> Yes, mosquitoes in high altitude can expand viruses over large distances. A study intercepted 1,017 female <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> mosquitoes at altitudes of 120-290 m above ground over Mali and Ghana and screened them for infection with <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> arboviruses, plasmodia, and filariae. The study found that 3.5% of the mosquitoes were infected with <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> flaviviruses, and 1.1% were infectious. Additionally, the study identified 19 mosquito-borne pathogens, <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> including three arboviruses that affect humans (dengue, West Nile, and MPoko viruses). The study provides <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span> compelling evidence that mosquito-borne pathogens are often spread by windborne mosquitoes at altitude. <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">│</span>\n",
"<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[1;32m╭─\u001b[0m\u001b[1;32m─────────────────────────────────────────\u001b[0m\u001b[1;32m Generated Content with RAG \u001b[0m\u001b[1;32m──────────────────────────────────────────\u001b[0m\u001b[1;32m─╮\u001b[0m\n",
"\u001b[1;32m│\u001b[0m Yes, mosquitoes in high altitude can expand viruses over large distances. A study intercepted 1,017 female \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m mosquitoes at altitudes of 120-290 m above ground over Mali and Ghana and screened them for infection with \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m arboviruses, plasmodia, and filariae. The study found that 3.5% of the mosquitoes were infected with \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m flaviviruses, and 1.1% were infectious. Additionally, the study identified 19 mosquito-borne pathogens, \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m including three arboviruses that affect humans (dengue, West Nile, and MPoko viruses). The study provides \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m│\u001b[0m compelling evidence that mosquito-borne pathogens are often spread by windborne mosquitoes at altitude. \u001b[1;32m│\u001b[0m\n",
"\u001b[1;32m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters\n",
"\n",
"filters = MetadataFilters(\n",
" filters=[\n",
" ExactMatchFilter(key=\"filename\", value=\"nihpp-2024.12.26.630351v1.nxml\"),\n",
" ]\n",
")\n",
"\n",
"query_engine = index.as_query_engine(llm=GEN_MODEL, filter=filters, similarity_top_k=3)\n",
"result = query_engine.query(query)\n",
"\n",
"console.print(\n",
" Panel(\n",
" result.response.strip(),\n",
" title=\"Generated Content with RAG\",\n",
" border_style=\"bold green\",\n",
" )\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -5,16 +5,19 @@ from pathlib import Path
from typing import Iterable
import yaml
from docling_core.types.doc import ImageRefMode
from docling.datamodel.base_models import ConversionStatus
from docling.backend.docling_parse_v4_backend import DoclingParseV4DocumentBackend
from docling.datamodel.base_models import ConversionStatus, InputFormat
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter
from docling.document_converter import DocumentConverter, PdfFormatOption
_log = logging.getLogger(__name__)
USE_V2 = True
USE_LEGACY = True
USE_LEGACY = False
def export_documents(
@ -33,9 +36,26 @@ def export_documents(
doc_filename = conv_res.input.file.stem
if USE_V2:
# Export Docling document format to JSON:
with (output_dir / f"{doc_filename}.json").open("w") as fp:
fp.write(json.dumps(conv_res.document.export_to_dict()))
conv_res.document.save_as_json(
output_dir / f"{doc_filename}.json",
image_mode=ImageRefMode.PLACEHOLDER,
)
conv_res.document.save_as_html(
output_dir / f"{doc_filename}.html",
image_mode=ImageRefMode.EMBEDDED,
)
conv_res.document.save_as_document_tokens(
output_dir / f"{doc_filename}.doctags.txt"
)
conv_res.document.save_as_markdown(
output_dir / f"{doc_filename}.md",
image_mode=ImageRefMode.PLACEHOLDER,
)
conv_res.document.save_as_markdown(
output_dir / f"{doc_filename}.txt",
image_mode=ImageRefMode.PLACEHOLDER,
strict_text=True,
)
# Export Docling document format to YAML:
with (output_dir / f"{doc_filename}.yaml").open("w") as fp:
@ -103,10 +123,10 @@ def main():
logging.basicConfig(level=logging.INFO)
input_doc_paths = [
Path("./tests/data/2206.01062.pdf"),
Path("./tests/data/2203.01017v2.pdf"),
Path("./tests/data/2305.03393v1.pdf"),
Path("./tests/data/redp5110_sampled.pdf"),
Path("./tests/data/pdf/2206.01062.pdf"),
Path("./tests/data/pdf/2203.01017v2.pdf"),
Path("./tests/data/pdf/2305.03393v1.pdf"),
Path("./tests/data/pdf/redp5110_sampled.pdf"),
]
# buf = BytesIO(Path("./test/data/2206.01062.pdf").open("rb").read())
@ -119,7 +139,16 @@ def main():
# settings.debug.visualize_tables = True
# settings.debug.visualize_cells = True
doc_converter = DocumentConverter()
pipeline_options = PdfPipelineOptions()
pipeline_options.generate_page_images = True
doc_converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_options=pipeline_options, backend=DoclingParseV4DocumentBackend
)
}
)
start_time = time.time()

View File

@ -5,7 +5,11 @@ from pathlib import Path
from docling.backend.pypdfium2_backend import PyPdfiumDocumentBackend
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
PdfPipelineOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.models.ocr_mac_model import OcrMacOptions
from docling.models.tesseract_ocr_cli_model import TesseractCliOcrOptions
@ -17,7 +21,7 @@ _log = logging.getLogger(__name__)
def main():
logging.basicConfig(level=logging.INFO)
input_doc_path = Path("./tests/data/2206.01062.pdf")
input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
###########################################################################
@ -76,7 +80,7 @@ def main():
pipeline_options.table_structure_options.do_cell_matching = True
pipeline_options.ocr_options.lang = ["es"]
pipeline_options.accelerator_options = AcceleratorOptions(
num_threads=4, device=Device.AUTO
num_threads=4, device=AcceleratorDevice.AUTO
)
doc_converter = DocumentConverter(


@ -0,0 +1,92 @@
# WARNING
# This example demonstrates only how to develop a new enrichment model.
# It does not run the actual formula understanding model.
import logging
from pathlib import Path
from typing import Iterable
from docling_core.types.doc import DocItemLabel, DoclingDocument, NodeItem, TextItem
from docling.datamodel.base_models import InputFormat, ItemAndImageEnrichmentElement
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.models.base_model import BaseItemAndImageEnrichmentModel
from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
class ExampleFormulaUnderstandingPipelineOptions(PdfPipelineOptions):
do_formula_understanding: bool = True
# A new enrichment model using both the document element and its image as input
class ExampleFormulaUnderstandingEnrichmentModel(BaseItemAndImageEnrichmentModel):
images_scale = 2.6
def __init__(self, enabled: bool):
self.enabled = enabled
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
return (
self.enabled
and isinstance(element, TextItem)
and element.label == DocItemLabel.FORMULA
)
def __call__(
self,
doc: DoclingDocument,
element_batch: Iterable[ItemAndImageEnrichmentElement],
) -> Iterable[NodeItem]:
if not self.enabled:
return
for enrich_element in element_batch:
enrich_element.image.show()
yield enrich_element.item
# How the pipeline can be extended.
class ExampleFormulaUnderstandingPipeline(StandardPdfPipeline):
def __init__(self, pipeline_options: ExampleFormulaUnderstandingPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: ExampleFormulaUnderstandingPipelineOptions
self.enrichment_pipe = [
ExampleFormulaUnderstandingEnrichmentModel(
enabled=self.pipeline_options.do_formula_understanding
)
]
if self.pipeline_options.do_formula_understanding:
self.keep_backend = True
@classmethod
def get_default_options(cls) -> ExampleFormulaUnderstandingPipelineOptions:
return ExampleFormulaUnderstandingPipelineOptions()
# Example main. In the final version, we simply have to set do_formula_understanding to True.
def main():
logging.basicConfig(level=logging.INFO)
input_doc_path = Path("./tests/data/pdf/2203.01017v2.pdf")
pipeline_options = ExampleFormulaUnderstandingPipelineOptions()
pipeline_options.do_formula_understanding = True
doc_converter = DocumentConverter(
format_options={
InputFormat.PDF: PdfFormatOption(
pipeline_cls=ExampleFormulaUnderstandingPipeline,
pipeline_options=pipeline_options,
)
}
)
result = doc_converter.convert(input_doc_path)
if __name__ == "__main__":
main()
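
The enrichment model in the file above splits its work into two steps: `is_processable` selects the elements of interest, and `__call__` receives the selected batch and yields each item back. The following is a minimal, dependency-free sketch of that filter-then-yield pattern. `Item`, `FormulaEnrichmentSketch`, and `run_enrichment` are hypothetical stand-ins invented for illustration; they are not docling types, and the real pipeline also passes the document and a cropped image alongside each item.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, List


@dataclass
class Item:
    """Hypothetical stand-in for a document element (not a docling type)."""

    label: str
    text: str


class FormulaEnrichmentSketch:
    """Mirrors the is_processable / __call__ split without importing docling."""

    def __init__(self, enabled: bool):
        self.enabled = enabled

    def is_processable(self, element: Item) -> bool:
        # Only formula items are handed to the model, and only when enabled.
        return self.enabled and element.label == "formula"

    def __call__(self, element_batch: Iterable[Item]) -> Iterator[Item]:
        # A disabled model yields nothing at all.
        if not self.enabled:
            return
        for element in element_batch:
            # A real model would run inference here; this sketch only tags the text.
            element.text = f"[enriched] {element.text}"
            yield element


def run_enrichment(model: FormulaEnrichmentSketch, items: List[Item]) -> List[Item]:
    """Filter with is_processable, then pass the selected batch through the model."""
    batch = [item for item in items if model.is_processable(item)]
    return list(model(batch))
```

Calling `run_enrichment(FormulaEnrichmentSketch(enabled=True), items)` returns only the formula items, each tagged by the model; items with other labels are never passed to `__call__`.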


@ -1,3 +1,7 @@
# WARNING
# This example demonstrates only how to develop a new enrichment model.
# It does not run the actual picture classifier model.
import logging
from pathlib import Path
from typing import Any, Iterable
@ -22,7 +26,6 @@ class ExamplePictureClassifierPipelineOptions(PdfPipelineOptions):
class ExamplePictureClassifierEnrichmentModel(BaseEnrichmentModel):
def __init__(self, enabled: bool):
self.enabled = enabled
@ -54,7 +57,6 @@ class ExamplePictureClassifierEnrichmentModel(BaseEnrichmentModel):
class ExamplePictureClassifierPipeline(StandardPdfPipeline):
def __init__(self, pipeline_options: ExamplePictureClassifierPipelineOptions):
super().__init__(pipeline_options)
self.pipeline_options: ExamplePictureClassifierPipelineOptions
@ -73,7 +75,7 @@ class ExamplePictureClassifierPipeline(StandardPdfPipeline):
def main():
logging.basicConfig(level=logging.INFO)
input_doc_path = Path("./tests/data/2206.01062.pdf")
input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
pipeline_options = ExamplePictureClassifierPipelineOptions()
pipeline_options.images_scale = 2.0


@ -16,7 +16,7 @@ IMAGE_RESOLUTION_SCALE = 2.0
def main():
logging.basicConfig(level=logging.INFO)
input_doc_path = Path("./tests/data/2206.01062.pdf")
input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
output_dir = Path("scratch")
# Important: For operating with page images, we must keep them, otherwise the DocumentConverter


@ -19,7 +19,7 @@ IMAGE_RESOLUTION_SCALE = 2.0
def main():
logging.basicConfig(level=logging.INFO)
input_doc_path = Path("./tests/data/2206.01062.pdf")
input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
output_dir = Path("scratch")
# Important: For operating with page images, we must keep them, otherwise the DocumentConverter


@ -12,7 +12,7 @@ _log = logging.getLogger(__name__)
def main():
logging.basicConfig(level=logging.INFO)
input_doc_path = Path("./tests/data/2206.01062.pdf")
input_doc_path = Path("./tests/data/pdf/2206.01062.pdf")
output_dir = Path("scratch")
doc_converter = DocumentConverter()
