Source code for fleche.storage.pickle_file

import gzip
import logging
import pickle
import zlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable

from .file import FileStorage, file_write_lock
from .base import ValueMixin, CallMixin
from .thread_safe import PerKeyLockMixin
from .destructuring import DestructuringMixin
from ..security import get_secret_key, normalize_secret_key, SignedBytes, SignatureError

from pyiron_snippets.import_alarm import ImportAlarm

# Module-level logger, named with an explicit dotted path rather than __name__.
logger: logging.Logger = logging.getLogger("fleche.storage.pickle_file")
# Optional serializers. The alarm objects decorate the matching constructors
# below.
# NOTE(review): assumes ImportAlarm suppresses a failed import here and only
# reports it (raising, given raise_exception=True) when the decorated
# constructor is called — confirm against pyiron_snippets.
with ImportAlarm(
    "PickleFile.with_cloudpickle requires 'cloudpickle' to be installed. "
    "Install it with `pip install fleche[cloudpickle]`.",
    raise_exception=True,
) as cloudpickle_alarm:
    import cloudpickle

with ImportAlarm(
    "PickleFile.with_dill requires 'dill' to be installed. "
    "Install it with `pip install fleche[dill]`.",
    raise_exception=True,
) as dill_alarm:
    import dill

# Leading two bytes of every gzip stream (RFC 1952). Used to sniff the format,
# so compressed and uncompressed files can be read from the same store.
_GZIP_MAGIC = b"\x1f\x8b"


def _gzip_compress(data: bytes) -> bytes:
    """Gzip *data* with a fixed header timestamp.

    ``gzip.compress`` otherwise embeds the current time in the header, so the
    same value would produce different file bytes on every write.
    """
    return gzip.compress(data, mtime=0)


def _maybe_gunzip(content: bytes) -> bytes:
    """Return *content* decompressed if it is a gzip stream, else unchanged.

    The magic-byte check is only a heuristic: an uncompressed payload can also
    begin with ``0x1f 0x8b``. If decompression fails, the bytes are returned
    as-is so that the caller's signature check decides whether they are valid.
    """
    if content[:2] != _GZIP_MAGIC:
        return content
    try:
        return gzip.decompress(content)
    except (OSError, EOFError, zlib.error):  # gzip.BadGzipFile is an OSError
        return content


@dataclass(frozen=True, kw_only=True)
class PickleFileBackend(FileStorage):
    """
    Store values as files on the filesystem using a serialization module.

    Each value is serialized with ``dumps``, wrapped with ``SignedBytes``
    using ``secret_key`` and, when ``compress`` is true, gzip-compressed
    before it is written. On read, the signature is checked *before*
    ``loads`` runs. A file that fails the check is reported as a missing
    key (``KeyError``) instead of being deserialized.

    Use :meth:`with_pickle`, :meth:`with_cloudpickle` or :meth:`with_dill`
    to build an instance with a matching ``dumps``/``loads`` pair.
    """

    # Signing key(s). An empty value means "use get_secret_key()"; anything
    # else goes through normalize_secret_key(). Stored as a tuple after
    # __post_init__.
    secret_key: tuple[bytes, ...] = field(default_factory=tuple)
    # Serializer applied to values before signing.
    dumps: Callable = field(repr=False)
    # Deserializer; only called on data whose signature has been verified.
    loads: Callable = field(repr=False)
    # Gzip-compress newly written files. Reading accepts both forms.
    compress: bool = False

    def __post_init__(self):
        super().__post_init__()
        raw = get_secret_key() if not self.secret_key else normalize_secret_key(self.secret_key)
        # The dataclass is frozen, so bypass its __setattr__ to store the
        # normalized key.
        object.__setattr__(self, "secret_key", tuple(raw))

    @classmethod
    def with_pickle(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the standard pickle module."""
        return cls(*args, dumps=pickle.dumps, loads=pickle.loads, **kwargs)

    @classmethod
    @cloudpickle_alarm
    def with_cloudpickle(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the cloudpickle module."""
        return cls(*args, dumps=cloudpickle.dumps, loads=cloudpickle.loads, **kwargs)

    @classmethod
    @dill_alarm
    def with_dill(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the dill module."""
        return cls(*args, dumps=dill.dumps, loads=dill.loads, **kwargs)

    def _to_file(self, value: Any, path: Path) -> None:
        """Serialize, sign and (optionally) compress *value* into *path*."""
        payload = SignedBytes(self.secret_key).dumps(self.dumps(value))
        if self.compress:
            payload = _gzip_compress(payload)
        path.write_bytes(payload)

    def _from_file(self, path: Path) -> Any:
        """Read, verify and deserialize the value stored at *path*.

        Raises:
            KeyError: if *path* does not exist, or if its content fails the
                signature check (the second argument explains which).
        """
        try:
            content = path.read_bytes()
        except FileNotFoundError:
            raise KeyError(path) from None
        signer = SignedBytes(self.secret_key)
        try:
            data = signer.loads(_maybe_gunzip(content))
        except SignatureError as err:
            raise KeyError(path, "Value present but failed signature check.") from err
        return self.loads(data)

    def _rewrite_stored_files(self, transform: Callable[[bytes], bytes]) -> None:
        """Apply *transform* to each stored file's bytes under its write lock.

        A file is rewritten only if *transform* changed its content.
        """
        for key in list(self.list()):
            path = self._path(key)
            with file_write_lock(self._path(f"{key}.lock")):
                try:
                    content = path.read_bytes()
                except FileNotFoundError:
                    # The key may have been removed since list() was taken.
                    continue
                new_content = transform(content)
                if new_content != content:
                    path.write_bytes(new_content)

    def compress_all(self) -> None:
        """Rewrite all stored files in gzip-compressed form."""
        self._rewrite_stored_files(
            lambda content: content if content[:2] == _GZIP_MAGIC else _gzip_compress(content)
        )

    def decompress_all(self) -> None:
        """Rewrite all stored files in uncompressed form."""
        self._rewrite_stored_files(_maybe_gunzip)
@dataclass(frozen=True)
class ValuePickleFile(PerKeyLockMixin, DestructuringMixin, ValueMixin, PickleFileBackend):
    """Value store persisted through PickleFileBackend.

    Adds no code of its own: it combines PerKeyLockMixin, DestructuringMixin
    and ValueMixin (resolved in that MRO order) on top of PickleFileBackend.
    """
@dataclass(frozen=True)
class CallPickleFile(PerKeyLockMixin, CallMixin, PickleFileBackend):
    """Call store persisted through PickleFileBackend.

    Adds no code of its own: it combines PerKeyLockMixin and CallMixin
    (resolved in that MRO order) on top of PickleFileBackend.
    """