Source code for fleche.storage.pickle_file

import pickle
import logging
import gzip
from dataclasses import dataclass, field
from typing import Any

from .file import FileStorage
from ..digest import Digest
from ..security import get_secret_key, SignedBytes, SignatureError

from pyiron_snippets.import_alarm import ImportAlarm

# Module-level logger for the pickle-file storage backend.
logger: logging.Logger = logging.getLogger(name="fleche.storage.pickle_file")
with ImportAlarm(
    "PickleFile.with_cloudpickle requires 'cloudpickle' to be installed. "
    "Install it with `pip install fleche[cloudpickle]`.",
    raise_exception=True,
) as cloudpickle_alarm:
    import cloudpickle

with ImportAlarm(
    "PickleFile.with_dill requires 'dill' to be installed. "
    "Install it with `pip install fleche[dill]`.",
    raise_exception=True,
) as dill_alarm:
    import dill


@dataclass(kw_only=True)
class PickleFile(FileStorage):
    """
    Store values as files on the filesystem using a serialization module.

    On save, a value is serialized with ``serializer.dumps``, signed via
    :class:`SignedBytes` using ``secret_key``, optionally gzip-compressed,
    and written to ``self._path(key)``. On load the steps are reversed, and
    the signature check (``SignedBytes.loads``) runs before
    ``serializer.loads``, so data that fails the check is never deserialized.

    Prefer the ``with_pickle`` / ``with_cloudpickle`` / ``with_dill``
    constructors over passing ``serializer`` by hand.
    """

    # Key material handed to SignedBytes; replaced by get_secret_key() in
    # __post_init__ when left empty.
    secret_key: list[bytes] = field(default_factory=list)
    # Object exposing ``dumps``/``loads`` (e.g. the pickle, cloudpickle or
    # dill module). Required; allowed after defaulted fields by kw_only=True.
    serializer: Any = field(repr=False)
    # Whether the signed payload is gzip-compressed on disk. The setting is
    # not recorded in the file, so it must match the one used when writing.
    compress: bool = False

    def __post_init__(self):
        super().__post_init__()
        # Fall back to get_secret_key() when no key was supplied explicitly.
        if not self.secret_key:
            self.secret_key = get_secret_key()

    @classmethod
    def with_pickle(cls, *args, **kwargs):
        """Construct a PickleFile using the standard pickle module."""
        return cls(*args, serializer=pickle, **kwargs)

    @classmethod
    @cloudpickle_alarm
    def with_cloudpickle(cls, *args, **kwargs):
        """Construct a PickleFile using the cloudpickle module."""
        return cls(*args, serializer=cloudpickle, **kwargs)

    @classmethod
    @dill_alarm
    def with_dill(cls, *args, **kwargs):
        """Construct a PickleFile using the dill module."""
        return cls(*args, serializer=dill, **kwargs)

    def _save(self, value: Any, key: Digest) -> Digest:
        """Serialize, sign and optionally compress *value*, then write it.

        Returns:
            *key*, unchanged.
        """
        signer = SignedBytes(self.secret_key)
        data = signer.dumps(self.serializer.dumps(value))
        if self.compress:
            # Compression wraps the already-signed blob, so _load must
            # decompress before it can verify the signature.
            data = gzip.compress(data)
        self._path(key).write_bytes(data)
        return key

    def _load(self, key: Digest) -> Any:
        """Read, verify and deserialize the value stored under *key*.

        Raises:
            KeyError: if no file exists for *key*, or if the file exists but
                fails the signature check (its payload is then never passed
                to ``serializer.loads``).
        """
        # Keep each try narrow: only the file read means "missing key", and
        # only the signature check means "untrusted value". Errors raised by
        # decompression or deserialization propagate unchanged.
        try:
            content = self._path(key).read_bytes()
        except FileNotFoundError:
            raise KeyError(key) from None

        if self.compress:
            # NOTE(review): this decompresses bytes that have not been
            # authenticated yet, so a crafted file could act as a
            # decompression bomb. Reordering would change the on-disk format.
            content = gzip.decompress(content)

        signer = SignedBytes(self.secret_key)
        try:
            data = signer.loads(content)
        except SignatureError as err:
            # Refuse to deserialize, but leave a trace: a failed check can
            # indicate tampering or a secret-key mismatch.
            logger.warning(
                "Signature check failed for stored value %s; refusing to load it.",
                key,
            )
            raise KeyError(key, "Value present but failed signature check.") from err

        return self.serializer.loads(data)