mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-25 19:44:34 +00:00
Create gpu_utils.py
Signed-off-by: jane-temcious <jane.temcious@iterate.ai>
This commit is contained in:
parent
46700e9f29
commit
5ab8792697
52
docling/utils/gpu_utils.py
Normal file
52
docling/utils/gpu_utils.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""Add commentMore actions
|
||||
Utilities for managing GPU resources.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import importlib.util
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
def clear_gpu_memory():
|
||||
"""
|
||||
Clear GPU memory cache for supported frameworks.
|
||||
Currently supports PyTorch and TensorFlow.
|
||||
"""
|
||||
# Try to clear PyTorch CUDA cache
|
||||
try:
|
||||
if importlib.util.find_spec("torch") is not None:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
_log.info("Clearing PyTorch CUDA memory cache")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Handle Apple MPS (Metal Performance Shaders)
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
_log.info("MPS device detected - note that MPS may not support explicit memory clearing")
|
||||
# MPS doesn't have an explicit memory clearing mechanism like CUDA
|
||||
# But we can try to force garbage collection
|
||||
import gc
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
_log.warning(f"Failed to clear PyTorch GPU memory: {e}")
|
||||
|
||||
# Try to clear TensorFlow GPU memory
|
||||
try:
|
||||
if importlib.util.find_spec("tensorflow") is not None:
|
||||
import tensorflow as tf
|
||||
if len(tf.config.list_physical_devices('GPU')) > 0:
|
||||
_log.info("Clearing TensorFlow GPU memory")
|
||||
for device in tf.config.list_physical_devices('GPU'):
|
||||
try:
|
||||
tf.config.experimental.reset_memory_stats(device)
|
||||
except Exception as e:
|
||||
_log.warning(f"Failed to reset memory stats for device {device}: {e}")
|
||||
except Exception as e:
|
||||
_log.warning(f"Failed to clear TensorFlow GPU memory: {e}")
|
||||
|
||||
# Force Python garbage collection
|
||||
try:
|
||||
import gc
|
||||
gc.collect()
|
||||
except Exception as e:
|
||||
_log.warning(f"Failed to run garbage collection: {e}")
|
Loading…
Reference in New Issue
Block a user