mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-26 20:14:47 +00:00
Create gpu_utils.py
Signed-off-by: jane-temcious <jane.temcious@iterate.ai>
This commit is contained in:
parent
46700e9f29
commit
5ab8792697
52
docling/utils/gpu_utils.py
Normal file
52
docling/utils/gpu_utils.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
"""Add commentMore actions
|
||||||
|
Utilities for managing GPU resources.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def clear_gpu_memory():
|
||||||
|
"""
|
||||||
|
Clear GPU memory cache for supported frameworks.
|
||||||
|
Currently supports PyTorch and TensorFlow.
|
||||||
|
"""
|
||||||
|
# Try to clear PyTorch CUDA cache
|
||||||
|
try:
|
||||||
|
if importlib.util.find_spec("torch") is not None:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
_log.info("Clearing PyTorch CUDA memory cache")
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Handle Apple MPS (Metal Performance Shaders)
|
||||||
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
_log.info("MPS device detected - note that MPS may not support explicit memory clearing")
|
||||||
|
# MPS doesn't have an explicit memory clearing mechanism like CUDA
|
||||||
|
# But we can try to force garbage collection
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
except Exception as e:
|
||||||
|
_log.warning(f"Failed to clear PyTorch GPU memory: {e}")
|
||||||
|
|
||||||
|
# Try to clear TensorFlow GPU memory
|
||||||
|
try:
|
||||||
|
if importlib.util.find_spec("tensorflow") is not None:
|
||||||
|
import tensorflow as tf
|
||||||
|
if len(tf.config.list_physical_devices('GPU')) > 0:
|
||||||
|
_log.info("Clearing TensorFlow GPU memory")
|
||||||
|
for device in tf.config.list_physical_devices('GPU'):
|
||||||
|
try:
|
||||||
|
tf.config.experimental.reset_memory_stats(device)
|
||||||
|
except Exception as e:
|
||||||
|
_log.warning(f"Failed to reset memory stats for device {device}: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
_log.warning(f"Failed to clear TensorFlow GPU memory: {e}")
|
||||||
|
|
||||||
|
# Force Python garbage collection
|
||||||
|
try:
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
except Exception as e:
|
||||||
|
_log.warning(f"Failed to run garbage collection: {e}")
|
Loading…
Reference in New Issue
Block a user