"""
Result cache manager that coordinates in-memory and persistent storage.
"""
from __future__ import annotations
import atexit
import logging
import os
import queue
import threading
import time
import weakref
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Optional, Sequence
from .memory_cache import LRUCache, SingleFlight
from .sqlite_backend import CacheEntry, PersistentCache
logger = logging.getLogger(__name__)
# Global registry for cleanup on exit
_active_cache_managers: weakref.WeakSet[
"ResultCacheManager"
] = weakref.WeakSet()
def _cleanup_all_cache_managers():
"""Cleanup all active cache managers on Python exit."""
for manager in list(_active_cache_managers):
try:
manager.close()
except Exception:
pass
# Register cleanup function to run on Python exit
atexit.register(_cleanup_all_cache_managers)
@dataclass
class CacheRecord:
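    """A single cached value plus the metadata needed for expiry checks."""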
value: Any
expires_at: Optional[float]
namespace: str
version: str
class ResultCacheManager:
"""Facade around memory + persistent cache layers."""
def __init__(
self,
*,
memory_size: int = 256,
persistent_path: Optional[str] = None,
enabled: bool = True,
persistence_enabled: bool = True,
singleflight: bool = True,
default_ttl: Optional[int] = None,
async_persist: Optional[bool] = None,
async_queue_size: int = 10000,
):
self.enabled = enabled
self.default_ttl = default_ttl
self.memory = LRUCache(max_size=memory_size)
persistence_path = persistent_path
if persistence_path is None:
cache_dir = os.environ.get("TOOLUNIVERSE_CACHE_DIR")
if cache_dir:
persistence_path = os.path.join(cache_dir, "tooluniverse_cache.sqlite")
self.persistent = None
if persistence_enabled and persistence_path:
try:
self.persistent = PersistentCache(persistence_path, enable=True)
except Exception as exc:
logger.warning("Failed to initialize persistent cache: %s", exc)
self.persistent = None
self.singleflight = SingleFlight() if singleflight else None
self._init_async_persistence(async_persist, async_queue_size)
# Register this instance for cleanup on exit
_active_cache_managers.add(self)
# ------------------------------------------------------------------
# Helper methods
# ------------------------------------------------------------------
@staticmethod
def compose_key(namespace: str, version: str, cache_key: str) -> str:
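        # e.g. compose_key("uniprot", "v1", "abc123") -> "uniprot::v1::abc123"
        # (the namespace/version/key values above are illustrative)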
return f"{namespace}::{version}::{cache_key}"
def _now(self) -> float:
return time.time()
def _ttl_or_default(self, ttl: Optional[int]) -> Optional[int]:
return ttl if ttl is not None else self.default_ttl
def _init_async_persistence(
self, async_persist: Optional[bool], async_queue_size: int
) -> None:
if async_persist is None:
async_persist = os.getenv(
"TOOLUNIVERSE_CACHE_ASYNC_PERSIST", "true"
).lower() in ("true", "1", "yes")
self.async_persist = (
async_persist and self.persistent is not None and self.enabled
)
self._persist_queue: Optional["queue.Queue[tuple[str, Dict[str, Any]]]"] = None
self._worker_thread: Optional[threading.Thread] = None
# Always initialize shutdown event for safe cleanup
self._shutdown_event = threading.Event()
if not self.async_persist:
return
queue_size = max(1, async_queue_size)
self._persist_queue = queue.Queue(maxsize=queue_size)
self._worker_thread = threading.Thread(
target=self._async_worker,
name="ResultCacheWriter",
daemon=True,
)
self._worker_thread.start()
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def get(self, *, namespace: str, version: str, cache_key: str) -> Optional[Any]:
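        """Return the cached value for the composed key, or ``None`` on a miss.

        The in-memory LRU layer is consulted first; on a memory miss, the
        persistent layer is checked and a still-valid hit is promoted back
        into memory. Expired entries are deleted from the layer where they
        are found.
        """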
if not self.enabled:
return None
composed = self.compose_key(namespace, version, cache_key)
record = self.memory.get(composed)
if record:
if record.expires_at and record.expires_at <= self._now():
self.memory.delete(composed)
else:
return record.value
entry = self._get_from_persistent(composed)
if entry:
# Use expires_at from entry (stored in database) or calculate from ttl
expires_at = entry.expires_at
if expires_at is None and entry.ttl:
expires_at = entry.created_at + entry.ttl
# Check if entry has expired before returning
if expires_at and expires_at <= self._now():
# Entry has expired, delete from persistent cache and return None
if self.persistent:
try:
self.persistent.delete(composed)
except Exception:
pass
return None
# Entry is still valid, restore to memory cache and return
self.memory.set(
composed,
CacheRecord(
value=entry.value,
expires_at=expires_at,
namespace=namespace,
version=version,
),
)
return entry.value
return None
def set(
self,
*,
namespace: str,
version: str,
cache_key: str,
value: Any,
ttl: Optional[int] = None,
):
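        """Store ``value`` in the memory cache and, when a persistent backend
        is available, schedule (or synchronously perform) the durable write.
        """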
if not self.enabled:
return
effective_ttl = self._ttl_or_default(ttl)
expires_at = self._now() + effective_ttl if effective_ttl else None
composed = self.compose_key(namespace, version, cache_key)
self.memory.set(
composed,
CacheRecord(
value=value,
expires_at=expires_at,
namespace=namespace,
version=version,
),
)
if self.persistent:
# Calculate expires_at and created_at here to ensure consistency
# between memory and persistent cache
now = self._now()
payload = {
"composed": composed,
"value": value,
"namespace": namespace,
"version": version,
"ttl": effective_ttl,
"created_at": now,
"expires_at": expires_at,
}
if not self._schedule_persist("set", payload):
self._perform_persist_set(**payload)
def delete(self, *, namespace: str, version: str, cache_key: str):
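        """Remove the entry from both the memory and persistent layers."""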
composed = self.compose_key(namespace, version, cache_key)
self.memory.delete(composed)
if self.persistent:
try:
self.persistent.delete(composed)
except Exception as exc:
logger.warning("Persistent cache delete failed: %s", exc)
def clear(self, namespace: Optional[str] = None):
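        """Clear cached entries, optionally restricted to ``namespace``."""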
if namespace:
# Clear matching namespace in memory
keys_to_remove = [
key
for key, record in self.memory.items()
if hasattr(record, "namespace") and record.namespace == namespace
]
for key in keys_to_remove:
self.memory.delete(key)
else:
self.memory.clear()
if self.persistent:
try:
self.flush()
self.persistent.clear(namespace=namespace)
except Exception as exc:
logger.warning("Persistent cache clear failed: %s", exc)
def bulk_get(self, requests: Sequence[Dict[str, str]]) -> Dict[str, Any]:
"""Fetch multiple cache entries at once.
Args:
requests: Iterable of dicts containing ``namespace``, ``version`` and ``cache_key``.
Returns
Mapping of composed cache keys to cached values.
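
        Example (illustrative keys)::

            hits = manager.bulk_get([
                {"namespace": "tools", "version": "v1", "cache_key": "k1"},
                {"namespace": "tools", "version": "v1", "cache_key": "k2"},
            ])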
"""
if not self.enabled:
return {}
hits: Dict[str, Any] = {}
for request in requests:
namespace = request["namespace"]
version = request["version"]
cache_key = request["cache_key"]
value = self.get(
namespace=namespace,
version=version,
cache_key=cache_key,
)
if value is not None:
composed = self.compose_key(namespace, version, cache_key)
hits[composed] = value
return hits
def stats(self) -> Dict[str, Any]:
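        """Return a snapshot of cache state, covering both layers and the
        depth of the async write queue."""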
return {
"enabled": self.enabled,
"memory": self.memory.stats(),
"persistent": (
self.persistent.stats() if self.persistent else {"enabled": False}
),
"async_persist": self.async_persist,
"pending_writes": (
self._persist_queue.qsize()
if self.async_persist and self._persist_queue is not None
else 0
),
}
def dump(self, namespace: Optional[str] = None) -> Iterator[Dict[str, Any]]:
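        """Yield persistent entries (optionally filtered by ``namespace``) as
        plain dicts, flushing pending asynchronous writes first."""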
if not self.persistent:
return iter([])
self.flush()
return (
{
"cache_key": entry.key,
"namespace": entry.namespace,
"version": entry.version,
"ttl": entry.ttl,
"created_at": entry.created_at,
"last_accessed": entry.last_accessed,
"hit_count": entry.hit_count,
"value": entry.value,
}
for entry in self._iter_persistent(namespace=namespace)
)
def _get_from_persistent(self, composed_key: str) -> Optional[CacheEntry]:
if not self.persistent:
return None
try:
return self.persistent.get(composed_key)
except Exception as exc:
logger.warning("Persistent cache read failed: %s", exc)
self.persistent = None
return None
def _iter_persistent(self, namespace: Optional[str]):
if not self.persistent:
return iter([])
try:
return self.persistent.iter_entries(namespace=namespace)
except Exception as exc:
logger.warning("Persistent cache iterator failed: %s", exc)
return iter([])
# ------------------------------------------------------------------
# Context manager for singleflight
# ------------------------------------------------------------------
def singleflight_guard(self, composed_key: str):
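        """Return a context manager that deduplicates concurrent computations
        of ``composed_key``; a no-op when singleflight is disabled."""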
if self.singleflight:
return self.singleflight.acquire(composed_key)
return _DummyContext()
def close(self):
"""Close the cache manager and cleanup resources."""
self.flush()
self._shutdown_async_worker()
if self.persistent:
try:
self.persistent.close()
except Exception as exc:
logger.warning("Persistent cache close failed: %s", exc)
# Remove from global registry
_active_cache_managers.discard(self)
def __del__(self):
"""Ensure cleanup happens even if close() is not called explicitly."""
try:
# Only shutdown if attributes exist (object not partially constructed)
if hasattr(self, '_shutdown_event'):
self._shutdown_event.set()
if hasattr(self, '_worker_thread') and self._worker_thread is not None:
if self._worker_thread.is_alive():
# Signal shutdown and wait briefly
if hasattr(self, '_persist_queue') and self._persist_queue is not None:
try:
self._persist_queue.put_nowait(("__STOP__", {}))
except Exception:
pass
self._worker_thread.join(timeout=0.5)
except Exception:
# Ignore errors during destruction - Python is shutting down anyway
pass
# ------------------------------------------------------------------
# Async persistence helpers
# ------------------------------------------------------------------
def flush(self):
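        """Block until every queued asynchronous write has been persisted."""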
if self.async_persist and self._persist_queue is not None:
self._persist_queue.join()
def _schedule_persist(self, op: str, payload: Dict[str, Any]) -> bool:
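        """Enqueue an async write; return ``False`` when the caller should
        fall back to synchronous persistence (async disabled or queue full).
        """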
if not self.async_persist or self._persist_queue is None:
return False
try:
self._persist_queue.put_nowait((op, payload))
return True
except queue.Full:
logger.warning(
"Async cache queue full; falling back to synchronous persistence"
)
return False
    def _async_worker(self):
        queue_ref = self._persist_queue
        if queue_ref is None:
            return
        # Block on the queue so writes are persisted as soon as they are
        # enqueued; wake at least once per second to notice a shutdown
        # signal even if no "__STOP__" sentinel arrives. (The previous
        # event-wait loop handled at most one write per second and added
        # up to a second of latency to every write.)
        TIMEOUT = 1.0
        while not self._shutdown_event.is_set():
            try:
                op, payload = queue_ref.get(timeout=TIMEOUT)
            except queue.Empty:
                # No work arrived within the timeout; re-check shutdown
                continue
            if op == "__STOP__":
                queue_ref.task_done()
                break
            try:
                if op == "set":
                    self._perform_persist_set(**payload)
                else:
                    logger.warning("Unknown async cache operation: %s", op)
            except Exception as exc:
                logger.warning("Async cache write failed: %s", exc)
                # Disable async persistence to avoid repeated failures
                self.async_persist = False
            finally:
                queue_ref.task_done()
def _perform_persist_set(
self,
*,
composed: str,
value: Any,
namespace: str,
version: str,
ttl: Optional[int],
created_at: Optional[float] = None,
expires_at: Optional[float] = None,
):
if not self.persistent:
return
try:
# Pass created_at and expires_at to ensure consistency
# between memory and persistent cache
self.persistent.set(
composed,
value,
namespace=namespace,
version=version,
ttl=ttl,
created_at=created_at,
expires_at=expires_at,
)
except Exception as exc:
logger.warning("Persistent cache write failed: %s", exc)
self.persistent = None
raise
def _shutdown_async_worker(self) -> None:
if not hasattr(self, '_worker_thread') or self._worker_thread is None:
return
if not hasattr(self, '_persist_queue') or self._persist_queue is None:
return
# Signal shutdown first - this will cause worker to exit on next timeout check
if hasattr(self, '_shutdown_event'):
self._shutdown_event.set()
# Try to send stop message to worker thread (non-blocking)
try:
self._persist_queue.put_nowait(("__STOP__", {}))
except queue.Full:
# Queue is full - worker will exit due to shutdown_event being set
pass
except Exception:
# Queue might be closed or in invalid state - worker will exit due to shutdown_event
pass
# Wait for thread to finish
if self._worker_thread.is_alive():
self._worker_thread.join(timeout=2.0)
if self._worker_thread.is_alive():
logger.warning("Cache worker thread did not terminate within timeout, but shutdown was signaled")
# Clean up
self._worker_thread = None
self._persist_queue = None
if hasattr(self, '_shutdown_event'):
self._shutdown_event.clear()
class _DummyContext:
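    """No-op context manager used when singleflight is disabled."""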
def __enter__(self):
return None
def __exit__(self, exc_type, exc_val, exc_tb):
return False