from abc import ABC, abstractmethod
import logging
import random
import threading
from dataclasses import dataclass, replace, field
from typing import Iterable, Any, Callable, Literal, overload
import pandas as pd
from . import digest as _digest
from .digest import Digest # type hint convenience
from . import storage
from .storage.base import _longest_common_prefix_length
from .storage.destructuring import HasChildDigests
from .call import Call, DigestedCall, LazyCall, QueryCall
from . import call
from . import query
[docs]
# Module-level logger used by the cache classes in this file.
logger = logging.getLogger("fleche.cache")
[docs]
class Rejected(Exception):
    """Signal that a cache declined to store (or evict) a call.

    Raised by caches in this module when an operation is refused, e.g. when
    the value storage fails to save or when the cache is read-only.
    """
# backwards compat imports
# from breaking introduced in 0.4.0
[docs]
DigestedIterable = storage.destructuring.DigestedIterable
[docs]
# Backwards-compatibility alias for ``storage.destructuring.DigestedDict``.
DigestedDict = storage.destructuring.DigestedDict
[docs]
class BaseCache(ABC):
    """Abstract interface implemented by every cache in this module.

    Subclasses provide the primitives (``save``, ``load``, ``load_value``,
    ``evict``, ``expand``, ``_shrink`` and ``_query``); the remaining methods
    are conveniences built on top of those primitives.
    """

    @abstractmethod
    def save(self, call: Call) -> str:
        """Store ``call`` and return the key it was stored under."""
        ...

    @abstractmethod
    def load(self, key: str) -> LazyCall:
        """Return the call stored under ``key``.

        The default :meth:`contains` treats a :class:`KeyError` raised here
        as "not present".
        """
        ...

    @abstractmethod
    def load_value(self, key: str) -> Any:
        """Return the value stored under ``key``."""
        ...

    @abstractmethod
    def evict(self, key: str | Digest) -> None:
        """Remove the entry identified by ``key`` from the cache."""
        ...

    def contains(self, key: str) -> bool:
        """Return whether ``key`` can be loaded, i.e. ``load`` raises no ``KeyError``."""
        try:
            self.load(key)
        except KeyError:
            return False
        else:
            return True

    def transfer(self, other: "BaseCache", pop: bool = False, overwrite: bool = False) -> None:
        """Transfer all calls from this cache to another cache.

        Args:
            other: The destination cache.
            pop: If True, evict transferred keys from the source cache after moving.
            overwrite: If True, overwrite existing entries in the target cache.
                If False (default), skip entries that already exist in the target.
        """
        everything = self.query()
        everything.transfer(other, pop=pop, overwrite=overwrite)

    def readonly(self) -> "ReadOnlyCache":
        """Wrap this cache in a :class:`ReadOnlyCache` view."""
        return ReadOnlyCache(self)

    def push(self, cache: "BaseCache") -> "CacheStack":
        """Return a :class:`CacheStack` with ``cache`` placed in front of this cache."""
        layers = (cache, self)
        return CacheStack(layers)

    @abstractmethod
    def expand(self, key: Digest | str) -> Digest:
        """Expand a short digest prefix to its full-length digest.

        Args:
            key (str or :class:`Digest`): the short digest prefix to expand

        Returns:
            :class:`Digest`: the full-length digest

        Raises:
            KeyError: if the key is not found
            :class:`AmbiguousDigestError`: if the prefix matches more than one entry
        """
        ...

    @overload
    def shrink(self, key: Digest | str, /) -> Digest: ...

    @overload
    def shrink(self, key: Digest | str, /, *keys: Digest | str) -> "tuple[Digest, ...]": ...

    def shrink(self, *keys: Digest | str) -> "Digest | tuple[Digest, ...]":
        """Find the shortest substring(s) that unambiguously reference each call.

        With a single key, returns one :class:`Digest`. With multiple keys,
        returns a tuple of :class:`Digest` in the same order as the inputs;
        the batched form lets sub-storages list their keys once instead of
        per-key, which matters on backends where listing is expensive (e.g.
        SQL, filesystem).

        Each input key must belong to *one* of the sub-storages (call or
        value). Mixing call keys and value keys in a single call is
        undefined behaviour — the result depends on internal partitioning
        order and may change without notice.

        .. warning::
            This is a property of how many values there are in your storage!
            A key returned from this function may become ambiguous in the
            future when more values are added. Do not rely on this function
            in your programs, it is provided as a convenience for users only!

        Args:
            *keys (str or :class:`Digest`): one or more keys to shorten

        Returns:
            :class:`Digest` (single key) or tuple of :class:`Digest` (multiple)

        Raises:
            TypeError: if called without any key
            :class:`AmbiguousDigestError`: if no shorter key is possible for any input
        """
        if len(keys) == 0:
            raise TypeError("shrink() requires at least one key")
        shortened = self._shrink(*keys)
        if len(keys) == 1:
            # Single-key form unwraps the one-element tuple.
            return shortened[0]
        return shortened

    @abstractmethod
    def _shrink(self, *keys: Digest | str) -> "tuple[Digest, ...]":
        """Partition and shrink all keys; always returns a same-length tuple of short digests."""
        ...

    @abstractmethod
    def _query(self, call: call.QueryCall) -> Iterable[LazyCall]:
        """Yield the cached calls matching the template ``call``."""
        ...

    def query(self, template: "call.QueryCall | None" = None, **kwargs) -> query.QueryIterator:
        """Query the cache for matching calls.

        Accepts either a :class:`~fleche.call.QueryCall` as the first positional argument,
        or the same keyword arguments that :class:`~fleche.call.QueryCall` accepts.
        Omitted fields default to ``None`` (wildcard). Passing both a template and
        keyword arguments raises :class:`TypeError`.

        Examples::

            cache.query(name="my_func")
            cache.query(name="my_func", arguments={"x": 1})
            cache.query(QueryCall(name="my_func"))  # existing form still works
            cache.query()  # all calls

        Returns:
            :class:`~fleche.query.QueryIterator`
        """
        if template is not None:
            if kwargs:
                raise TypeError("Cannot pass keyword arguments when a QueryCall template is provided")
        else:
            template = call.QueryCall(**kwargs)

        def _safe_iter():
            # An indigestible query argument ends the iteration with a
            # warning instead of propagating the error to the caller.
            try:
                yield from self._query(template)
            except _digest.Indigestible as e:
                logger.warning("No hash for query argument: %s", e.args[0])

        # The generator function itself (not a generator) is handed over.
        return query.QueryIterator(_safe_iter, cache=self)

    def table(
        self,
        arguments: Iterable[str] | str | Literal[True] = (),
        results=False,
        shrink_keys: bool = True,
    ) -> pd.DataFrame:
        """Return a pandas DataFrame summarizing cached calls via query().

        This implementation uses a fully-wildcard Call template to retrieve
        all calls through ``self.query`` and then flattens metadata keys into
        top-level columns for convenience.

        By default, arguments and results are elided.
        The DataFrame index will be the lookup key (digest) of each call.

        Columns are:
            - `name`: the function name
            - `module`: the module name
            - `result`: if `results` argument is `True`
            - metadata fields are flattened and added as columns directly

        If given argument names collide with any of the above columns, they are prefixed by 'a_'.
        Only requested arguments are loaded from cache.

        Args:
            arguments: add the given arguments (of the queried calls) as columns to the table.
                Pass ``True`` to add all arguments, or a single string as a shortcut for a
                one-element tuple.
            results (bool): if True, add results of queried calls to table
            shrink_keys (bool): if True (default), shrink each index entry to
                its shortest unambiguous prefix. Set to ``False`` to keep
                full-length digests.

        Returns:
            :class:`pandas.DataFrame`: table of all calls on cache
        """
        wildcard = call.QueryCall(
            name=None,
            arguments=None,
            metadata=None,
            module=None,
            version=None,
            result=None,
        )
        matches = self.query(wildcard)
        return matches.table(arguments=arguments, results=results, shrink_keys=shrink_keys)

    def filter(self, predicate: Callable[[Call | LazyCall], bool] | QueryCall) -> 'FilteredCache':
        """Create a read-only view of this cache that only exposes calls matching the predicate.

        Args:
            predicate: A function that takes a Call or LazyCall and returns True
                if it should be included in the new cache, or a QueryCall object to
                use as a template.

        Returns:
            FilteredCache: A read-only view of the cache.
        """
        # A QueryCall template is turned into a predicate via its ``matches`` method.
        accept = predicate.matches if isinstance(predicate, QueryCall) else predicate
        return FilteredCache(self, accept)
[docs]
def _combine_expand(key: "Digest | str", results: "Iterable[Digest]") -> "Digest":
"""Reduce sub-storage expand results to a single resolved digest.
Raises:
KeyError: if no results were found.
AmbiguousDigestError: if results disagree on the full digest.
"""
unique = sorted(set(results))
if not unique:
raise KeyError(key)
if len(unique) == 1:
return unique[0]
lcp = _longest_common_prefix_length(unique[0], unique[1])
raise storage.AmbiguousDigestError(
f"Short digest {key} is ambiguous; expands to: {unique}; need at least {lcp + 1} characters."
)
[docs]
def _combine_shrink(key: "Digest | str", results: "Iterable[Digest]") -> "Digest":
"""Reduce sub-storage shrink results to the longest (safest) prefix.
Raises:
KeyError: if no results were found.
"""
all_results = list(results)
if not all_results:
raise KeyError(key)
return max(all_results, key=len)
@dataclass(frozen=True)
class Cache(BaseCache):
    """Cache composed of a value storage and a call storage.

    Calls are stashed into ``values`` and the resulting digested call record
    is saved in ``calls``; lookups by call key go through ``calls``.
    """

    # Storage holding the values referenced by call records.
    values: storage.ValueStorage
    # Storage holding the digested call records.
    calls: storage.CallStorage

    def load_value(self, key):
        """Return the value stored under ``key`` in the value storage."""
        return self.values.load(key)

    def save(self, call: Call) -> str:
        """Stash ``call`` into the value storage and save the digested record.

        Returns:
            The key returned by the call storage.

        Raises:
            Rejected: if stashing raises :class:`storage.SaveError`; the
                original error is chained as ``__cause__``.
        """
        try:
            digested = call.stash(self.values)
        except storage.SaveError as e:
            raise Rejected(e) from e
        return self.calls.save(digested)

    def load(self, key: str) -> LazyCall:
        """Load the call record for ``key`` and bind it to this cache via ``fetch``."""
        return self.calls.load(key).fetch(self)

    def contains(self, key: str) -> bool:
        """Return whether the call storage contains ``key``."""
        return self.calls.contains(key)

    def expand(self, key: Digest | str) -> Digest:
        """Expand ``key`` against both the call and the value storage.

        Raises:
            KeyError: if neither storage knows the key.
            AmbiguousDigestError: if the storages expand it to different digests.
        """
        results = []
        for sub in (self.calls, self.values):
            try:
                results.append(sub.expand(key))
            except KeyError:
                # A miss in one storage is fine as long as the other one hits.
                pass
        return _combine_expand(key, results)

    def _shrink(self, *keys: Digest | str) -> "tuple[Digest, ...]":
        """Shrink each key within the storage that contains it.

        Keys are partitioned into call keys and value keys (call storage is
        checked first) so that each storage is asked once for its batch.

        Raises:
            KeyError: if a key is in neither storage.
        """
        call_keys: list = []
        value_keys: list = []
        for k in keys:
            if self.calls.contains(k):
                call_keys.append(k)
            elif self.values.contains(k):
                value_keys.append(k)
            else:
                raise KeyError(k)
        results: dict = {}
        for sub, ks in ((self.calls, call_keys), (self.values, value_keys)):
            if not ks:
                continue
            r = sub.shrink(*ks)
            if len(ks) == 1:
                # The single-key form returns a bare digest; normalise to a tuple.
                r = (r,)
            for k, s in zip(ks, r):
                results[k] = s
        # Restore the caller's ordering.
        return tuple(results[k] for k in keys)

    def _query(self, call: call.QueryCall) -> Iterable[LazyCall]:
        """Query for cached calls that match a template and return decoded results.

        This delegates to the underlying :meth:`CallStorage.query` using the
        provided template ``call``. Digests given as argument values or as the
        result are first expanded against the value storage; a ``None``
        ``arguments`` template is replaced by an empty dict. Each match is
        bound to this cache via ``fetch`` before being yielded.

        Args:
            call: A ``Call`` instance used as a template; fields set to ``None``
                act as wildcards. For arguments and result, comparisons follow
                digest semantics (i.e., values are matched by their digest).

        Yields:
            Call | LazyCall: Matching calls with arguments and result decoded from digests
            where possible. Matches that fail to load are logged and skipped.
        """

        def maybe_expand(value):
            # Only digests are expanded; plain values pass through untouched.
            if isinstance(value, Digest):
                return self.values.expand(value)
            else:
                return value

        call = replace(
            call,
            arguments={
                k: maybe_expand(v)
                for k, v in call.arguments.items()
            } if call.arguments is not None else {},
            result=maybe_expand(call.result),
        )
        for c in self.calls.query(call):
            try:
                yield c.fetch(self)
            except Exception as err:
                # Best effort: one unreadable record must not abort the query.
                logger.error(
                    f"Failed to load matching call {c.to_lookup_key()} with {err}! Indicates corrupt cache."
                )

    def evict(self, key: str | Digest) -> None:
        """Evict the call record for ``key``; the value storage is not touched."""
        self.calls.evict(key)

    def redigest(self) -> None:
        """Ensures consistent cache keys in case digest function changed.

        Every call whose current lookup key differs from its stored key is
        saved again and the stale record is evicted.
        This may take time depending on cache size.
        """
        # Snapshot the keys first: the loop body saves and evicts call
        # records, so iterating a live listing could skip or repeat entries.
        for key in list(self.calls.list()):
            call = self.load(key).fetch()
            if call.to_lookup_key() != key:
                # instantiate values too
                self.save(call)
                self.calls.evict(key)

    def gc(self) -> set[Digest]:
        """Evict value entries not reachable from any stored call.

        Brute-force mark-and-sweep: walks every call record to build the set
        of directly-referenced value digests, then transitively follows
        destructured sub-references (via :meth:`DestructuringMixin.child_digests`
        on storages that satisfy :class:`HasChildDigests`), and evicts every
        ``values`` key outside the reachable set. Call records are left
        untouched.

        Returns:
            The set of digests that were evicted from value storage.
        """
        reachable: set[Digest] = set()
        for key in self.calls.list():
            try:
                dc = self.calls.load(key)
            except KeyError:
                # Listed but no longer loadable; nothing to mark.
                continue
            if isinstance(dc.result, Digest):
                reachable.add(dc.result)
            for v in dc.arguments.values():
                if isinstance(v, Digest):
                    reachable.add(v)
        if isinstance(self.values, HasChildDigests):
            # Transitive closure over child digests.
            frontier = set(reachable)
            while frontier:
                key = frontier.pop()
                try:
                    children = self.values.child_digests(key)
                except KeyError:
                    continue
                new = children - reachable
                reachable |= new
                frontier |= new
        evicted: set[Digest] = set()
        # Snapshot the listing because entries are evicted while looping.
        for key in list(self.values.list()):
            if key not in reachable:
                try:
                    self.values.evict(key)
                    evicted.add(key)
                except KeyError:
                    continue
        return evicted
@dataclass(frozen=True)
class CacheWrapper(BaseCache):
    """Forwarding base class: all BaseCache methods delegate to ``self.cache``.

    Combine with behaviour mixins (ReadOnlyMixin, FilteringMixin) to build
    concrete wrapper classes without redeclaring ``cache``.
    """

    # The wrapped cache. Declared here so that every wrapper subclass accepts
    # it as its first constructor argument (e.g. ``ReadOnlyCache(cache)``).
    cache: BaseCache

    def save(self, call: Call) -> str:
        """Forward to ``self.cache.save``."""
        return self.cache.save(call)

    def load(self, key: str) -> LazyCall:
        """Forward to ``self.cache.load``."""
        return self.cache.load(key)

    def load_value(self, key: str) -> Any:
        """Forward to ``self.cache.load_value``."""
        return self.cache.load_value(key)

    def contains(self, key: str) -> bool:
        """Forward to ``self.cache.contains``."""
        return self.cache.contains(key)

    def evict(self, key: str | Digest) -> None:
        """Forward to ``self.cache.evict``."""
        self.cache.evict(key)

    def expand(self, key: Digest | str) -> Digest:
        """Forward to ``self.cache.expand``."""
        return self.cache.expand(key)

    def _shrink(self, *keys: Digest | str) -> "tuple[Digest, ...]":
        """Shrink via the wrapped cache's public ``shrink``, always returning a tuple."""
        if len(keys) == 1:
            # Public ``shrink`` returns a bare digest for a single key.
            return (self.cache.shrink(keys[0]),)
        return self.cache.shrink(*keys)

    def _query(self, call: call.QueryCall) -> Iterable[LazyCall]:
        """Forward to the wrapped cache's public ``query``."""
        return self.cache.query(call)
[docs]
class ReadOnlyMixin(CacheWrapper):
    """Wrapper behaviour that refuses every mutation.

    Both ``save`` and ``evict`` raise :class:`Rejected`; all other methods
    keep the forwarding behaviour inherited from :class:`CacheWrapper`.
    """

    def save(self, call: Call):
        """Always raise :class:`Rejected`, carrying this cache and ``call``."""
        refusal = Rejected(self, call)
        raise refusal

    def evict(self, key: str | Digest) -> None:
        """Always raise :class:`Rejected`, carrying a message, this cache and ``key``."""
        refusal = Rejected("Cannot evict from a ReadOnlyCache", self, key)
        raise refusal
@dataclass(frozen=True)
class ReadOnlyCache(ReadOnlyMixin):
    """Read-only view of a cache: reads are forwarded, ``save``/``evict`` raise :class:`Rejected`."""
@dataclass(frozen=True)
class FilteringMixin(CacheWrapper):
    """Filters ``load``, ``contains`` and ``_query`` results by a predicate."""

    # Decides whether a call is visible through this wrapper.
    predicate: Callable[[Call | LazyCall], bool]

    def load(self, key: str) -> LazyCall:
        """Load ``key`` from the wrapped cache, hiding calls the predicate rejects.

        Raises:
            KeyError: if the key is missing or the call does not satisfy
                ``predicate``.
        """
        lc = self.cache.load(key)
        if not self.predicate(lc):
            raise KeyError(key)
        return lc

    def contains(self, key: str) -> bool:
        """Return whether ``key`` is visible through the filter.

        Defined in terms of :meth:`load` so that ``contains`` and ``load``
        agree; plain forwarding to the wrapped cache would report calls the
        predicate rejects as present.
        """
        try:
            self.load(key)
        except KeyError:
            return False
        return True

    def _query(self, call: call.QueryCall) -> Iterable[LazyCall]:
        """Yield the wrapped cache's query results that satisfy ``predicate``."""
        for c in self.cache.query(call):
            if self.predicate(c):
                yield c
@dataclass(frozen=True)
class FilteredCache(FilteringMixin, ReadOnlyMixin):
    """Read-only cache view restricted to the calls accepted by a predicate."""
@dataclass(frozen=True)
class RefreshingCache(CacheWrapper):
    """A cache wrapper that reports every call lookup as a miss.

    ``load`` always raises :class:`KeyError` and ``contains`` always returns
    ``False``, which forces re-execution. Saves and value loads are still
    forwarded to the underlying cache, so new results are stored while
    existing ones are ignored for the duration of its use.

    This is necessary to handle nested fleche calls during a rerun,
    otherwise forcing them to re-execute would be awkward.
    """

    def contains(self, key: str) -> bool:
        """Report every key as absent."""
        return False

    def load(self, key: str) -> LazyCall:
        """Always miss by raising ``KeyError(key)``."""
        raise KeyError(key)
@dataclass(frozen=True)
class CacheStack(BaseCache):
    """A combination of caches with a shared traversal policy.

    Saving always targets the lowest level (``stack[0]``); loading traverses
    from ``stack[0]`` upward and back-fills any hit into ``stack[0]``.

    Several fan-out operations are expressed through three private traversal
    helpers, each implementing one recurring pattern across the stack:

    - :meth:`_first_hit` — return on the first success; raise if all miss.
    - :meth:`_collect` — gather every success; caller combines the results.
    - :meth:`_foreach` — apply to every cache; swallow expected refusals.
    """

    # Caches in traversal order; ``stack[0]`` is the save/back-fill target.
    stack: tuple[BaseCache, ...]

    def __post_init__(self):
        """Reject stacks that contain another :class:`CacheStack`."""
        for c in self.stack:
            if isinstance(c, CacheStack):
                raise ValueError("CacheStack cannot be nested inside another CacheStack")

    def save(self, call: Call) -> str:
        """Save ``call`` into ``stack[0]`` and return the key it reports."""
        return self.stack[0].save(call)

    def load(self, key) -> LazyCall:
        """Load ``key`` from the first cache that has it.

        A hit in any cache other than ``stack[0]`` is fetched and saved into
        ``stack[0]``; a :class:`Rejected` during that back-fill is logged and
        otherwise ignored.

        Raises:
            KeyError: if no cache in the stack holds ``key``.
        """
        for i, cache in enumerate(self.stack):
            try:
                lc = cache.load(key)
                if i > 0:
                    try:
                        self.save(lc.fetch())
                        logger.info("Transferred hit for %s from higher cache to base cache", key)
                    except Rejected as e:
                        logger.warning("Failed to transfer hit for %s to base cache: %s", key, e)
                return lc
            except KeyError:
                continue
        raise KeyError(key)

    def load_value(self, key):
        """Return the value for ``key`` from the first cache that has it."""
        return self._first_hit(lambda c: c.load_value(key))

    def contains(self, key: str) -> bool:
        """Return whether any cache in the stack contains ``key``."""
        return any(cache.contains(key) for cache in self.stack)

    def push(self, cache: BaseCache) -> "CacheStack":
        """Return a new stack with ``cache`` placed in front of the existing ones."""
        return CacheStack((cache, *self.stack))

    def evict(self, key: str | Digest) -> None:
        """Evict ``key`` from every cache, ignoring ``Rejected`` and ``KeyError``."""
        self._foreach(lambda c: c.evict(key))

    def expand(self, key: Digest | str) -> Digest:
        """Expand ``key`` in every cache and combine the results.

        Raises:
            KeyError: if no cache knows the key.
            AmbiguousDigestError: if caches expand it to different digests.
        """
        return _combine_expand(key, self._collect(lambda c: c.expand(key)))

    def _shrink(self, *keys: Digest | str) -> "tuple[Digest, ...]":
        """Shrink keys in every cache that contains them and keep the longest result.

        Each cache is asked once for the batch of keys it contains.

        Raises:
            KeyError: if a key is contained in no cache of the stack.
        """
        per_key: dict = {k: [] for k in keys}
        for cache in self.stack:
            present = [k for k in keys if cache.contains(k)]
            if not present:
                continue
            r = cache.shrink(*present)
            if len(present) == 1:
                # Public ``shrink`` returns a bare digest for a single key.
                r = (r,)
            for k, s in zip(present, r):
                per_key[k].append(s)
        out_list = []
        for k in keys:
            if not per_key[k]:
                raise KeyError(k)
            out_list.append(_combine_shrink(k, per_key[k]))
        return tuple(out_list)

    def _query(self, call: call.QueryCall) -> Iterable[LazyCall]:
        """Aggregate query results across the stack, avoiding duplicates.

        The caches are queried in stack order. Results are deduplicated by
        their lookup key (via ``to_lookup_key()``) and yielded in the order
        they are first seen.

        Args:
            call: A template ``Call`` where ``None`` fields act as wildcards.

        Yields:
            Call | LazyCall: Matching calls from any cache in the stack, without duplicates.
        """
        seen = set()
        for cache in self.stack:
            for c in cache.query(call):
                k = c.to_lookup_key()
                if k in seen:
                    continue
                seen.add(k)
                yield c

    # ------------------------------------------------------------------
    # Private traversal helpers — three patterns that recur across the
    # public fan-out methods. New CacheStack operations should be
    # expressed as a one-liner over whichever helper fits.
    # ------------------------------------------------------------------

    def _first_hit(self, op: Callable[["BaseCache"], Any], *, exc: type[BaseException] = KeyError) -> Any:
        """Return the first successful result from iterating the stack.

        Invokes ``op(cache)`` on each cache in ``self.stack`` in order and
        returns immediately when a call does not raise *exc*.

        This is the **first-hit-wins** pattern: used when any single cache can
        satisfy the request and caches earlier in the stack are preferred (e.g.
        :meth:`load_value`).

        Args:
            op: Callable that accepts a single :class:`BaseCache` and returns
                the desired result. Called at most once per cache.
            exc: Exception *class* treated as a cache miss. Defaults to
                :class:`KeyError`. Must be a single type (not a tuple)
                because it is also raised when the stack is empty.

        Raises:
            exc: If every cache in the stack raises *exc*. The last caught
                exception is re-raised so its arguments (e.g. the missing
                key) are preserved; for an empty stack a bare *exc* is raised.
        """
        last_miss: BaseException | None = None
        for cache in self.stack:
            try:
                return op(cache)
            except exc as err:
                last_miss = err
        if last_miss is None:
            # Empty stack: nothing was attempted, so there is no miss to re-raise.
            raise exc
        raise last_miss

    def _collect(self, op: Callable[["BaseCache"], Any], *, exc: type[BaseException] = KeyError) -> list:
        """Collect one result per cache, skipping misses.

        Invokes ``op(cache)`` on every cache in ``self.stack`` and appends
        each non-raising result to a list. Caches that raise *exc* are
        silently skipped; all other exceptions propagate normally.

        This is the **collect-and-combine** pattern: used when all caches may
        hold relevant data and the caller needs to aggregate results before
        returning (e.g. :meth:`expand`, which passes the collected list to
        ``_combine_expand``).

        Args:
            op: Callable that accepts a single :class:`BaseCache` and returns
                a result to collect. Called exactly once per cache.
            exc: Exception *class* to treat as a miss and skip. Defaults to
                :class:`KeyError`.

        Returns:
            A list of all non-raising results in stack order. May be empty
            when every cache misses; the caller is responsible for handling
            that case (typically by raising :class:`KeyError`).
        """
        out = []
        for cache in self.stack:
            try:
                out.append(op(cache))
            except exc:
                pass
        return out

    def _foreach(
        self,
        op: Callable[["BaseCache"], None],
        *,
        exc: type[BaseException] | tuple[type[BaseException], ...] = (Rejected, KeyError),
    ) -> None:
        """Apply an operation to every cache in the stack, swallowing refusals.

        Invokes ``op(cache)`` on every cache in ``self.stack`` unconditionally.
        Exceptions of type *exc* are caught and discarded; any other exception
        propagates normally.

        This is the **apply-everywhere** pattern: used when an operation should
        be attempted on all caches regardless of whether individual caches
        support it (e.g. :meth:`evict`, where :class:`Rejected` and
        :class:`KeyError` are both expected non-fatal outcomes).

        Args:
            op: Callable that accepts a single :class:`BaseCache`. Its return
                value is ignored. Called exactly once per cache.
            exc: Exception type(s) to swallow. Defaults to
                ``(Rejected, KeyError)``. Pass a tuple to swallow
                multiple types.
        """
        for cache in self.stack:
            try:
                op(cache)
            except exc:
                pass
[docs]
class SizeLimitedMixin(BaseCache):
    """Mixin that caps the number of cached calls, evicting at random by default.

    Combine this with :class:`Cache` (mixin first in MRO) to get a size-limited
    cache::

        @dataclass
        class SizeLimitedCache(SizeLimitedMixin, Cache):
            max_size: int

    Whenever a save pushes the number of tracked calls above ``max_size``,
    :meth:`_pick_eviction_target` selects call records to evict until the
    limit holds again. Value storage is intentionally left untouched.

    The mixin reads ``self.max_size``; the concrete class has to supply it.
    """

    # Guards ``_keys`` and the save/evict sequences; set in ``__post_init__``.
    _lock: threading.RLock = field(init=False, repr=False, compare=False)
    # Keys of the calls currently tracked by this cache; set in ``__post_init__``.
    _keys: set[str] = field(init=False, repr=False, compare=False)

    def __post_init__(self, *args, **kwargs):
        """Chain to the next ``__post_init__`` (if any) and initialise the bookkeeping."""
        parent = super()
        if hasattr(parent, '__post_init__'):
            parent.__post_init__(*args, **kwargs)  # ty: ignore
        # ``object.__setattr__`` is used so this also works on frozen dataclasses.
        object.__setattr__(self, '_lock', threading.RLock())
        # Seed the tracked keys with everything already in the cache.
        known = {c.to_lookup_key() for c in self.query(call.QueryCall())}
        object.__setattr__(self, '_keys', known)

    # ------------------------------------------------------------------
    # Eviction policy – override this to generalise to other strategies
    # (e.g. LRU, LFU, …).
    # ------------------------------------------------------------------

    def _pick_eviction_target(self, keys: list[str]) -> str:
        """Choose which call to evict from the given tracked keys.

        The default picks uniformly at random. Override this method to
        implement a different eviction policy without touching any other
        part of the class.

        Args:
            keys: A non-empty list of all tracked call keys.

        Returns:
            The key of the call that should be evicted.
        """
        return random.choice(keys)

    def _enforce_size_limit(self) -> None:
        """Evict call records until the cache is within ``max_size``."""
        with self._lock:
            while self.max_size < len(self._keys):
                candidates = list(self._keys)
                victim = self._pick_eviction_target(candidates)
                self.evict(victim)

    def save(self, call: call.Call) -> str:
        """Save ``call``, track its key and enforce the size limit.

        Returns:
            The key returned by the underlying ``save``.
        """
        with self._lock:
            stored_key = super().save(call)
            self._keys.add(stored_key)
            self._enforce_size_limit()
        return stored_key

    def evict(self, key: str | _digest.Digest) -> None:
        """Evict ``key`` from the underlying cache and stop tracking it."""
        with self._lock:
            super().evict(key)
            # Tracked keys are compared via their string form.
            self._keys.discard(str(key))
@dataclass(frozen=True)
[docs]
class SizeLimitedCache(SizeLimitedMixin, Cache):
    """A :class:`Cache` that enforces a maximum number of cached calls.

    When a new call is saved and the number of cached calls exceeds ``max_size``,
    a call record is selected for eviction via :meth:`_pick_eviction_target`.
    The default policy evicts uniformly at random; override
    :meth:`_pick_eviction_target` to change this.

    Args:
        values: Value storage (forwarded to :class:`Cache`).
        calls: Call storage (forwarded to :class:`Cache`).
        max_size: Maximum number of calls to keep.
    """