from abc import ABC, abstractmethod
from dataclasses import dataclass
from numbers import Number
from typing import Any, Callable
from . import base
from .. import digest
class Digested(ABC):
    """Abstract wrapper around a collection whose children may be stored separately.

    Subclasses expose the wrapped container via :meth:`underlying`, reassemble
    the plain value in :meth:`mend` and take a value apart in :meth:`sunder`.
    """

    @abstractmethod
    def underlying(self):
        """Return plain underlying value, ie. list/dict/etc of nested values or their partial digests"""

    # The digest is taken from the underlying container rather than from the
    # wrapper itself, so that swapping the 'real' container for its 'digested'
    # counterpart stays invisible to caches: both must hash to the same value.
    def __digest__(self):
        plain = self.underlying()
        return digest.digest(plain)

    @abstractmethod
    def mend(self, storage: 'DestructuringMixin'):
        """Rebuild the plain value, resolving children through *storage*."""

    @classmethod
    @abstractmethod
    def sunder(cls, intern: Callable[[Any], tuple[Any, int | float]], value: Any):
        """Take *value* apart by passing each child to *intern*.

        *intern* maps a child to a ``(result, depth)`` pair.
        """

    @staticmethod
    def get(storage, key):
        """Return ``storage.get(key)`` if *key* is a digest, otherwise *key* itself."""
        return storage.get(key) if isinstance(key, digest.Digest) else key
@dataclass
class DigestedIterable(Digested):
    """:class:`Digested` wrapper for lists and tuples.

    ``items`` keeps the container type of the original value; its elements are
    either inlined values or :class:`digest.Digest` keys of separately stored
    children.
    """

    # Declared as a dataclass field so the generated ``__init__`` accepts it;
    # ``sunder`` constructs instances via ``cls(items)``.
    items: list | tuple

    def underlying(self):
        """Return the wrapped list/tuple, possibly containing digests."""
        return self.items

    def mend(self, storage: 'DestructuringMixin') -> list | tuple:
        """Rebuild the original list/tuple, loading digested elements from *storage*."""
        return type(self.items)(self.get(storage, v) for v in self.items)

    @classmethod
    def sunder(cls, intern: Callable[[Any], tuple[Any, int | float]], value: list | tuple):
        """Intern every element of *value* and wrap the result if anything was stored.

        Returns ``(result, depth)``. *depth* is one more than the largest child
        depth reported by *intern*. *result* is a plain container of the same
        type as *value* when no child was replaced by a digest, otherwise a
        :class:`DigestedIterable`. An empty *value* is returned unchanged with
        depth 0.
        """
        if not value:
            # ``zip(*...)`` over nothing cannot be unpacked and ``max`` of an
            # empty sequence raises; an empty container has nothing to intern.
            return value, 0
        children, depths = zip(*(intern(v) for v in value))
        depth = 1 + max(depths)
        items = type(value)(children)
        if not any(isinstance(r, digest.Digest) for r in items):
            # intern did not do anything to our children because we're out of depth, return them verbatim
            return items, depth
        return cls(items), depth
@dataclass
class DigestedDict(Digested):
    """:class:`Digested` wrapper for dicts.

    Keys and values of ``items`` are either inlined values or
    :class:`digest.Digest` keys of separately stored children.
    """

    # Declared as a dataclass field so the generated ``__init__`` accepts it;
    # ``sunder`` constructs instances via ``cls(items)``.
    items: dict

    def underlying(self):
        """Return the wrapped dict, possibly containing digests."""
        return self.items

    def mend(self, storage: 'DestructuringMixin') -> dict:
        """Rebuild the original dict, loading digested keys and values from *storage*."""
        return {self.get(storage, k): self.get(storage, v)
                for k, v in self.items.items()}

    @classmethod
    def sunder(cls, intern: Callable[[Any], tuple[Any, int | float]], value: dict):
        """Intern every key and value of *value* and wrap the result if anything was stored.

        Returns ``(result, depth)``. *depth* is one more than the largest depth
        reported by *intern* for any key or value. *result* is a plain ``dict``
        when no key or value was replaced by a digest, otherwise a
        :class:`DigestedDict`. An empty *value* is returned unchanged with
        depth 0.
        """
        if not value:
            # ``zip(*...)`` over nothing cannot be unpacked and ``max`` of an
            # empty sequence raises; an empty dict has nothing to intern.
            return value, 0
        kk, k_depths = zip(*(intern(k) for k in value))
        vv, v_depths = zip(*(intern(v) for v in value.values()))
        depth = 1 + max(*k_depths, *v_depths)
        items = dict(zip(kk, vv))
        if not any(isinstance(r, digest.Digest) for r in (*items.keys(), *items.values())):
            # intern did not do anything to our children because we're out of depth, return them verbatim
            return items, depth
        return cls(items), depth
@dataclass(frozen=True)
class DestructuringMixin(base.StorageBackend):
    """Mixin that recursively destructures collections on save/load.

    Place before a concrete :class:`StorageBackend` in the MRO to add
    destructuring behavior. Lists, tuples, and dicts are broken apart so each
    element is stored independently; on load the original structure is
    reassembled.

    Example::

        @dataclass(frozen=True)
        class ValueMemory(ValueMixin, DestructuringMixin, MemoryBackend): ...

        vm = ValueMemory(storage={})
        key = vm.save([1, [2, 3]])
        assert vm.load(key) == [1, [2, 3]]
    """

    # Nodes whose depth is strictly below this threshold are inlined in their
    # parent instead of being written to storage on their own.
    remaining_depth: int = 0

    @staticmethod
    def _is_trojan_tuple(value):
        """Return True if *value* carries the ``_fields``/``_field_defaults`` attributes of a namedtuple."""
        return all(hasattr(value, attr) for attr in ("_fields", "_field_defaults"))

    def _intern_rec(self, value: Any, key: digest.Digest | None = None) -> tuple[Any, int | float]:
        """Post-order traversal: descend to the leaves, then decide inline-vs-store while unwinding.

        Returns ``(result, depth)``. *result* is the value itself when
        ``depth < remaining_depth`` (it is then inlined by its parent) or the
        return value of the backend's ``put`` when the node was written to
        storage. Children are visited once each through the ``sunder``
        callbacks, so no separate depth-counting pass is needed.
        """
        if isinstance(value, digest.Digest):
            return value, -1

        unbounded = float("inf")
        if isinstance(value, (Number, str, bytes)):
            depth = 0
        elif isinstance(value, (dict, list, tuple)) and not value:
            # Empty containers are handled exactly like scalars. Strictly this
            # breaks the "1 + max child depth" recursion, but there is no use
            # for destructuring an empty container.
            depth = 0
        elif isinstance(value, tuple) and self._is_trojan_tuple(value):
            # namedtuples do not accept a single iterable in their constructor,
            # so they cannot be rebuilt element-wise and are kept whole.
            depth = unbounded
        elif isinstance(value, (list, tuple)):
            value, depth = DigestedIterable.sunder(self._intern_rec, value)
        elif isinstance(value, dict):
            value, depth = DigestedDict.sunder(self._intern_rec, value)
        else:
            # Anything unrecognised is never inlined.
            depth = unbounded

        if depth < self.remaining_depth:
            return value, depth
        storage_key = key or digest.digest(value)
        return super().put(value, storage_key), depth

    def put(self, value: Any, key: digest.Digest) -> digest.Digest:
        """Store *value* under *key*, destructuring non-empty lists, tuples and dicts."""
        is_collection = isinstance(value, (list, tuple, dict))
        if not is_collection or not value:
            return super().put(value, key)

        value_or_digest, depth = self._intern_rec(value, key)
        # If the value is nominally too shallow to have been saved during the
        # recursion, store it here: this is the recursion's base case.
        if depth < self.remaining_depth:
            return super().put(value_or_digest, key)
        return value_or_digest

    def get(self, key: digest.Digest | Any) -> Any:
        """Load the value for *key*, reassembling it if it was stored destructured."""
        value = super().get(key)
        if isinstance(value, Digested):
            return value.mend(self)
        return value