Source code for tooluniverse.cache.result_cache_manager

"""
Result cache manager that coordinates in-memory and persistent storage.
"""

from __future__ import annotations

import logging
import os
import queue
import threading
import time
from dataclasses import dataclass
from typing import Any, Dict, Iterator, Optional, Sequence

from .memory_cache import LRUCache, SingleFlight
from .sqlite_backend import CacheEntry, PersistentCache

logger = logging.getLogger(__name__)


@dataclass
class CacheRecord:
    value: Any
    expires_at: Optional[float]
    namespace: str
    version: str


class ResultCacheManager:
    """Facade around memory + persistent cache layers."""

    def __init__(
        self,
        *,
        memory_size: int = 256,
        persistent_path: Optional[str] = None,
        enabled: bool = True,
        persistence_enabled: bool = True,
        singleflight: bool = True,
        default_ttl: Optional[int] = None,
        async_persist: Optional[bool] = None,
        async_queue_size: int = 10000,
    ):
        self.enabled = enabled
        self.default_ttl = default_ttl
        self.memory = LRUCache(max_size=memory_size)
        persistence_path = persistent_path
        if persistence_path is None:
            cache_dir = os.environ.get("TOOLUNIVERSE_CACHE_DIR")
            if cache_dir:
                persistence_path = os.path.join(
                    cache_dir, "tooluniverse_cache.sqlite"
                )
        self.persistent = None
        if persistence_enabled and persistence_path:
            try:
                self.persistent = PersistentCache(persistence_path, enable=True)
            except Exception as exc:
                logger.warning("Failed to initialize persistent cache: %s", exc)
                self.persistent = None
        self.singleflight = SingleFlight() if singleflight else None
        self._init_async_persistence(async_persist, async_queue_size)

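    # Illustrative construction (sketch; the path below is made up). With no
    # explicit ``persistent_path``, the SQLite file falls back to
    # ``$TOOLUNIVERSE_CACHE_DIR/tooluniverse_cache.sqlite`` when that
    # environment variable is set:
    #
    #   manager = ResultCacheManager(
    #       memory_size=512,
    #       persistent_path="/tmp/example_cache.sqlite",  # hypothetical path
    #       default_ttl=3600,
    #   )
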
    # ------------------------------------------------------------------
    # Helper methods
    # ------------------------------------------------------------------

    @staticmethod
    def compose_key(namespace: str, version: str, cache_key: str) -> str:
        return f"{namespace}::{version}::{cache_key}"

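    # Composed-key layout (illustrative values):
    #
    #   ResultCacheManager.compose_key("mytool", "v1", "abc123")
    #   # -> "mytool::v1::abc123"
    #
    # The parts are joined with "::", so namespace or version strings that
    # themselves contain "::" would produce ambiguous keys.
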
    def _now(self) -> float:
        return time.time()

    def _ttl_or_default(self, ttl: Optional[int]) -> Optional[int]:
        return ttl if ttl is not None else self.default_ttl

    def _init_async_persistence(
        self, async_persist: Optional[bool], async_queue_size: int
    ) -> None:
        if async_persist is None:
            async_persist = os.getenv(
                "TOOLUNIVERSE_CACHE_ASYNC_PERSIST", "true"
            ).lower() in ("true", "1", "yes")
        self.async_persist = (
            async_persist and self.persistent is not None and self.enabled
        )
        self._persist_queue: Optional["queue.Queue[tuple[str, Dict[str, Any]]]"] = None
        self._worker_thread: Optional[threading.Thread] = None
        if not self.async_persist:
            return
        queue_size = max(1, async_queue_size)
        self._persist_queue = queue.Queue(maxsize=queue_size)
        self._worker_thread = threading.Thread(
            target=self._async_worker,
            name="ResultCacheWriter",
            daemon=True,
        )
        self._worker_thread.start()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get(self, *, namespace: str, version: str, cache_key: str) -> Optional[Any]:
        if not self.enabled:
            return None
        composed = self.compose_key(namespace, version, cache_key)
        record = self.memory.get(composed)
        if record:
            if record.expires_at and record.expires_at <= self._now():
                # Expired in memory; drop it and fall through to the
                # persistent layer.
                self.memory.delete(composed)
            else:
                return record.value
        entry = self._get_from_persistent(composed)
        if entry:
            # Promote the persistent hit into the in-memory layer.
            expires_at = entry.created_at + entry.ttl if entry.ttl else None
            self.memory.set(
                composed,
                CacheRecord(
                    value=entry.value,
                    expires_at=expires_at,
                    namespace=namespace,
                    version=version,
                ),
            )
            return entry.value
        return None

    def set(
        self,
        *,
        namespace: str,
        version: str,
        cache_key: str,
        value: Any,
        ttl: Optional[int] = None,
    ):
        if not self.enabled:
            return
        effective_ttl = self._ttl_or_default(ttl)
        expires_at = self._now() + effective_ttl if effective_ttl else None
        composed = self.compose_key(namespace, version, cache_key)
        self.memory.set(
            composed,
            CacheRecord(
                value=value,
                expires_at=expires_at,
                namespace=namespace,
                version=version,
            ),
        )
        if self.persistent:
            payload = {
                "composed": composed,
                "value": value,
                "namespace": namespace,
                "version": version,
                "ttl": effective_ttl,
            }
            if not self._schedule_persist("set", payload):
                self._perform_persist_set(**payload)

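    # Illustrative write (made-up values): the value lands in the in-memory
    # layer immediately; the SQLite write is queued for the background worker
    # when async persistence is on, and performed inline otherwise.
    #
    #   manager.set(
    #       namespace="mytool", version="v1", cache_key="abc123",
    #       value={"result": 42}, ttl=600,
    #   )
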
    def delete(self, *, namespace: str, version: str, cache_key: str):
        composed = self.compose_key(namespace, version, cache_key)
        self.memory.delete(composed)
        if self.persistent:
            try:
                self.persistent.delete(composed)
            except Exception as exc:
                logger.warning("Persistent cache delete failed: %s", exc)

    def clear(self, namespace: Optional[str] = None):
        if namespace:
            # Clear matching namespace in memory
            keys_to_remove = [
                key
                for key, record in self.memory.items()
                if hasattr(record, "namespace") and record.namespace == namespace
            ]
            for key in keys_to_remove:
                self.memory.delete(key)
        else:
            self.memory.clear()
        if self.persistent:
            try:
                self.flush()
                self.persistent.clear(namespace=namespace)
            except Exception as exc:
                logger.warning("Persistent cache clear failed: %s", exc)

    def bulk_get(self, requests: Sequence[Dict[str, str]]) -> Dict[str, Any]:
        """Fetch multiple cache entries at once.

        Args:
            requests: Iterable of dicts containing ``namespace``, ``version``
                and ``cache_key``.

        Returns:
            Mapping of composed cache keys to cached values.
        """
        if not self.enabled:
            return {}
        hits: Dict[str, Any] = {}
        for request in requests:
            namespace = request["namespace"]
            version = request["version"]
            cache_key = request["cache_key"]
            value = self.get(
                namespace=namespace,
                version=version,
                cache_key=cache_key,
            )
            if value is not None:
                composed = self.compose_key(namespace, version, cache_key)
                hits[composed] = value
        return hits

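    # Illustrative request shape (made-up values); only entries that were
    # cached appear in the result:
    #
    #   hits = manager.bulk_get([
    #       {"namespace": "mytool", "version": "v1", "cache_key": "abc123"},
    #       {"namespace": "mytool", "version": "v1", "cache_key": "def456"},
    #   ])
    #   # -> {"mytool::v1::abc123": <cached value>, ...}
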
    def stats(self) -> Dict[str, Any]:
        return {
            "enabled": self.enabled,
            "memory": self.memory.stats(),
            "persistent": (
                self.persistent.stats() if self.persistent else {"enabled": False}
            ),
            "async_persist": self.async_persist,
            "pending_writes": (
                self._persist_queue.qsize()
                if self.async_persist and self._persist_queue is not None
                else 0
            ),
        }

    def dump(self, namespace: Optional[str] = None) -> Iterator[Dict[str, Any]]:
        if not self.persistent:
            return iter([])
        self.flush()
        return (
            {
                "cache_key": entry.key,
                "namespace": entry.namespace,
                "version": entry.version,
                "ttl": entry.ttl,
                "created_at": entry.created_at,
                "last_accessed": entry.last_accessed,
                "hit_count": entry.hit_count,
                "value": entry.value,
            }
            for entry in self._iter_persistent(namespace=namespace)
        )

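    # Illustrative export of one namespace as JSON lines (sketch; assumes the
    # cached values are JSON-serializable):
    #
    #   import json
    #   for row in manager.dump(namespace="mytool"):
    #       print(json.dumps(row))
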
    def _get_from_persistent(self, composed_key: str) -> Optional[CacheEntry]:
        if not self.persistent:
            return None
        try:
            return self.persistent.get(composed_key)
        except Exception as exc:
            logger.warning("Persistent cache read failed: %s", exc)
            self.persistent = None
            return None

    def _iter_persistent(self, namespace: Optional[str]):
        if not self.persistent:
            return iter([])
        try:
            return self.persistent.iter_entries(namespace=namespace)
        except Exception as exc:
            logger.warning("Persistent cache iterator failed: %s", exc)
            return iter([])

    # ------------------------------------------------------------------
    # Context manager for singleflight
    # ------------------------------------------------------------------

    def singleflight_guard(self, composed_key: str):
        if self.singleflight:
            return self.singleflight.acquire(composed_key)
        return _DummyContext()

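    # Typical compute-on-miss pattern (sketch; ``expensive_call`` is a
    # hypothetical helper). The guard collapses concurrent misses for the
    # same composed key so only one caller recomputes:
    #
    #   composed = manager.compose_key("mytool", "v1", "abc123")
    #   with manager.singleflight_guard(composed):
    #       value = manager.get(namespace="mytool", version="v1",
    #                           cache_key="abc123")
    #       if value is None:
    #           value = expensive_call()
    #           manager.set(namespace="mytool", version="v1",
    #                       cache_key="abc123", value=value)
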
    def close(self):
        self.flush()
        self._shutdown_async_worker()
        if self.persistent:
            try:
                self.persistent.close()
            except Exception as exc:
                logger.warning("Persistent cache close failed: %s", exc)

    # ------------------------------------------------------------------
    # Async persistence helpers
    # ------------------------------------------------------------------

    def flush(self):
        if self.async_persist and self._persist_queue is not None:
            self._persist_queue.join()

    def _schedule_persist(self, op: str, payload: Dict[str, Any]) -> bool:
        if not self.async_persist or self._persist_queue is None:
            return False
        try:
            self._persist_queue.put_nowait((op, payload))
            return True
        except queue.Full:
            logger.warning(
                "Async cache queue full; falling back to synchronous persistence"
            )
            return False

    def _async_worker(self):
        queue_ref = self._persist_queue
        if queue_ref is None:
            return
        while True:
            try:
                op, payload = queue_ref.get()
            except Exception:
                continue
            if op == "__STOP__":
                # Sentinel enqueued by _shutdown_async_worker().
                queue_ref.task_done()
                break
            try:
                if op == "set":
                    self._perform_persist_set(**payload)
                else:
                    logger.warning("Unknown async cache operation: %s", op)
            except Exception as exc:
                logger.warning("Async cache write failed: %s", exc)
                # Disable async persistence to avoid repeated failures
                self.async_persist = False
            finally:
                queue_ref.task_done()

    def _perform_persist_set(
        self,
        *,
        composed: str,
        value: Any,
        namespace: str,
        version: str,
        ttl: Optional[int],
    ):
        if not self.persistent:
            return
        try:
            self.persistent.set(
                composed,
                value,
                namespace=namespace,
                version=version,
                ttl=ttl,
            )
        except Exception as exc:
            logger.warning("Persistent cache write failed: %s", exc)
            self.persistent = None
            raise

    def _shutdown_async_worker(self) -> None:
        if not self.async_persist or self._persist_queue is None:
            return
        try:
            self._persist_queue.put_nowait(("__STOP__", {}))
        except queue.Full:
            self._persist_queue.put(("__STOP__", {}))
        if self._worker_thread is not None:
            self._worker_thread.join(timeout=5)
        self._worker_thread = None
        self._persist_queue = None


class _DummyContext:
    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_val, exc_tb):
        return False
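

# Minimal end-to-end sketch (illustrative, not part of the public API): runs
# with the in-memory layer only, so no SQLite file is needed. Namespace and
# key values below are made up.
if __name__ == "__main__":
    manager = ResultCacheManager(persistence_enabled=False)
    manager.set(namespace="demo", version="v1", cache_key="k", value=123)
    assert manager.get(namespace="demo", version="v1", cache_key="k") == 123
    print(manager.stats())
    manager.close()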