import pickle
import logging
import gzip
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
from .file import FileStorage, file_write_lock
from .base import ValueMixin, CallMixin
from .thread_safe import PerKeyLockMixin
from .destructuring import DestructuringMixin
from ..security import get_secret_key, normalize_secret_key, SignedBytes, SignatureError
from pyiron_snippets.import_alarm import ImportAlarm
[docs]
# Module-level logger. The name is spelled out explicitly rather than derived
# from ``__name__``.
logger = logging.getLogger("fleche.storage.pickle_file")
# Optional serializer: the import runs inside an ``ImportAlarm`` context and the
# object bound by ``as`` is later applied as a decorator to
# ``PickleFileBackend.with_cloudpickle``.
# NOTE(review): presumably a failed import is suppressed here and resurfaces,
# with the message below, only when the decorated constructor is called --
# confirm against the pyiron_snippets documentation.
with ImportAlarm(
    "PickleFile.with_cloudpickle requires 'cloudpickle' to be installed. "
    "Install it with `pip install fleche[cloudpickle]`.",
    raise_exception=True,
) as cloudpickle_alarm:
    import cloudpickle
# Same guard for the optional ``dill`` serializer; ``dill_alarm`` decorates
# ``PickleFileBackend.with_dill``.
with ImportAlarm(
    "PickleFile.with_dill requires 'dill' to be installed. "
    "Install it with `pip install fleche[dill]`.",
    raise_exception=True,
) as dill_alarm:
    import dill
@dataclass(frozen=True, kw_only=True)
class PickleFileBackend(FileStorage):
    """
    Store values as files on the filesystem using a serialization module.

    Writing a value serializes it with ``dumps``, signs the resulting bytes
    with ``SignedBytes(secret_key)`` and, when ``self.compress`` is truthy,
    gzip-compresses the signed payload. Reading reverses those steps; the
    signature is checked *before* the payload is handed to ``loads``.

    Attributes:
        secret_key: Key material given to ``SignedBytes``. When left empty,
            ``get_secret_key()`` supplies it in ``__post_init__``; otherwise
            the given value is passed through ``normalize_secret_key``.
        dumps: Callable turning a value into bytes (e.g. ``pickle.dumps``).
        loads: Inverse of ``dumps`` (e.g. ``pickle.loads``).

    Note:
        ``compress``, ``list`` and ``_path`` are used below but not defined in
        this class; presumably they come from ``FileStorage``.
    """

    # First two bytes of every gzip stream. Compression of a stored file is
    # detected from these bytes alone, never from configuration. Deliberately
    # left unannotated so the dataclass machinery does not treat it as a field.
    # NOTE(review): an *uncompressed* signed payload that happened to start
    # with these bytes would be misread as gzip -- confirm that the
    # ``SignedBytes`` framing rules this out.
    _GZIP_MAGIC = b"\x1f\x8b"

    secret_key: tuple[bytes, ...] = field(default_factory=tuple)
    dumps: Callable = field(repr=False)
    loads: Callable = field(repr=False)

    def __post_init__(self):
        """Resolve ``secret_key`` to its final tuple form."""
        super().__post_init__()
        if self.secret_key:
            raw = normalize_secret_key(self.secret_key)
        else:
            raw = get_secret_key()
        # The dataclass is frozen, so regular attribute assignment is blocked;
        # go through ``object.__setattr__`` instead.
        object.__setattr__(self, "secret_key", tuple(raw))

    @classmethod
    def with_pickle(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the standard pickle module."""
        return cls(*args, dumps=pickle.dumps, loads=pickle.loads, **kwargs)

    @classmethod
    @cloudpickle_alarm
    def with_cloudpickle(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the cloudpickle module.

        ``cloudpickle`` is an optional dependency; this constructor is wrapped
        by ``cloudpickle_alarm``.
        """
        return cls(*args, dumps=cloudpickle.dumps, loads=cloudpickle.loads, **kwargs)

    @classmethod
    @dill_alarm
    def with_dill(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the dill module.

        ``dill`` is an optional dependency; this constructor is wrapped by
        ``dill_alarm``.
        """
        return cls(*args, dumps=dill.dumps, loads=dill.loads, **kwargs)

    def _to_file(self, value: Any, path: Path) -> None:
        """Serialize, sign and (optionally) compress ``value`` into ``path``.

        Args:
            value: Object to store; must be accepted by ``self.dumps``.
            path: Destination file, overwritten if it already exists.
        """
        payload = SignedBytes(self.secret_key).dumps(self.dumps(value))
        if self.compress:
            payload = gzip.compress(payload)
        path.write_bytes(payload)

    def _from_file(self, path: Path) -> Any:
        """Read, verify and deserialize the value stored at ``path``.

        Compressed files are recognised by their gzip magic bytes, so both
        compressed and uncompressed files can be read regardless of the
        current ``compress`` setting.

        Args:
            path: File previously written by ``_to_file``.

        Returns:
            The deserialized value.

        Raises:
            KeyError: If the file does not exist, or if it exists but its
                signature check fails (the ``SignatureError`` is attached as
                the cause). Other errors raised by ``gzip.decompress`` or
                ``self.loads`` are not translated.
        """
        try:
            content = path.read_bytes()
            if content.startswith(self._GZIP_MAGIC):
                content = gzip.decompress(content)
            data = SignedBytes(self.secret_key).loads(content)
            # SECURITY: ``loads`` is typically pickle-based and can execute
            # arbitrary code. It is only reached after the signature check
            # above succeeded, so the trust boundary is the secret key.
            return self.loads(data)
        except FileNotFoundError:
            raise KeyError(path) from None
        except SignatureError as err:
            raise KeyError(path, "Value present but failed signature check.") from err

    def _rewrite_all(self, *, compressed: bool) -> None:
        """Rewrite every stored file so it is gzip-compressed or not.

        Shared implementation of ``compress_all`` and ``decompress_all``.
        Files already in the requested form are left untouched, and files
        that disappear between listing and reading are skipped.

        Args:
            compressed: ``True`` to compress uncompressed files, ``False`` to
                decompress compressed ones.
        """
        recode = gzip.compress if compressed else gzip.decompress
        # Snapshot the keys first so the rewrite cannot disturb the listing.
        for key in list(self.list()):
            path = self._path(key)
            lock_path = self._path(f"{key}.lock")
            with file_write_lock(lock_path):
                try:
                    content = path.read_bytes()
                except FileNotFoundError:
                    continue
                # Only touch files whose current form differs from the target.
                if content.startswith(self._GZIP_MAGIC) != compressed:
                    path.write_bytes(recode(content))

    def compress_all(self) -> None:
        """Rewrite all stored files in gzip-compressed form.

        Does not consult or change ``self.compress``; later writes still
        follow that flag.
        """
        self._rewrite_all(compressed=True)

    def decompress_all(self) -> None:
        """Rewrite all stored files in uncompressed form.

        Does not consult or change ``self.compress``; later writes still
        follow that flag.
        """
        self._rewrite_all(compressed=False)
@dataclass(frozen=True)
class ValuePickleFile(
    PerKeyLockMixin,
    DestructuringMixin,
    ValueMixin,
    PickleFileBackend,
):
    """Pickle-file storage combined with the value-oriented mixins.

    Stacks ``PerKeyLockMixin``, ``DestructuringMixin`` and ``ValueMixin`` on
    top of ``PickleFileBackend`` (in that method-resolution order). The class
    adds no fields or methods of its own.
    """
@dataclass(frozen=True)
class CallPickleFile(
    PerKeyLockMixin,
    CallMixin,
    PickleFileBackend,
):
    """Pickle-file storage combined with the call-oriented mixins.

    Stacks ``PerKeyLockMixin`` and ``CallMixin`` on top of
    ``PickleFileBackend`` (in that method-resolution order). The class adds
    no fields or methods of its own.
    """