import importlib
import pickle
import logging
import gzip
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from .file import FileStorage
from .base import ValueMixin, CallMixin, DestructuringMixin
from ..security import get_secret_key, SignedBytes, SignatureError
from pyiron_snippets.import_alarm import ImportAlarm
logger = logging.getLogger("fleche.storage.pickle_file")

# Minimum accepted length, in bytes, of every HMAC signing key; enforced by
# ``_normalize_secret_key``.
# NOTE(review): 32 presumably matches a SHA-256 digest size (RFC 2104
# recommends keys at least as long as the hash output) — confirm against the
# hash used by ``SignedBytes``.
_HMAC_MIN_KEY_BYTES = 32
def _normalize_secret_key(key) -> list[bytes]:
    """
    Coerce a user-supplied secret key into a list of byte strings.

    ``key`` may be:

    - ``bytes`` — used as a single key;
    - ``str`` — split on ``":"``; every part becomes one UTF-8 encoded key
      (so a colon can never appear inside a key given as text);
    - a ``list`` whose items are ``bytes`` (kept as-is) or ``str`` (split on
      ``":"`` and UTF-8 encoded, as above).

    Keys are returned in input order; an empty list yields an empty result.
    Every resulting key must be at least ``_HMAC_MIN_KEY_BYTES`` bytes long.

    Raises:
        TypeError: ``key`` is not ``bytes``, ``str`` or ``list``, or a list
            item is neither ``bytes`` nor ``str``.
        ValueError: a resulting key is shorter than ``_HMAC_MIN_KEY_BYTES``.
    """
    keys = [key] if isinstance(key, (bytes, str)) else key
    if not isinstance(keys, list):
        raise TypeError(
            f"secret_key must be bytes, str, or list, got {type(keys).__name__}"
        )

    def _checked(candidate: bytes) -> bytes:
        # Reject keys that are too short for a meaningful HMAC.
        size = len(candidate)
        if size < _HMAC_MIN_KEY_BYTES:
            raise ValueError(
                f"Each secret key must be at least {_HMAC_MIN_KEY_BYTES} bytes, "
                f"got {size}"
            )
        return candidate

    normalized: list[bytes] = []
    for item in keys:
        if isinstance(item, bytes):
            normalized.append(_checked(item))
        elif isinstance(item, str):
            # Text keys may bundle several keys separated by ":".
            normalized.extend(
                _checked(piece.encode("utf-8")) for piece in item.split(":")
            )
        else:
            raise TypeError(
                f"Each element of secret_key must be bytes or str, "
                f"got {type(item).__name__}"
            )
    return normalized
# Optional serializer back-ends.  Each import is wrapped in an ``ImportAlarm``
# context; the alarm objects bound here (``cloudpickle_alarm``, ``dill_alarm``)
# are used as decorators on the matching ``PickleFileBackend`` constructors.
# NOTE(review): with ``raise_exception=True`` the alarm presumably raises its
# message when a decorated constructor is called without the package
# installed — confirm against ``pyiron_snippets.import_alarm``.
with ImportAlarm(
    (
        "PickleFile.with_cloudpickle requires 'cloudpickle' "
        "to be installed. Install it with "
        "`pip install fleche[cloudpickle]`."
    ),
    raise_exception=True,
) as cloudpickle_alarm:
    import cloudpickle

with ImportAlarm(
    (
        "PickleFile.with_dill requires 'dill' "
        "to be installed. Install it with "
        "`pip install fleche[dill]`."
    ),
    raise_exception=True,
) as dill_alarm:
    import dill
@dataclass(frozen=True, kw_only=True)
class PickleFileBackend(FileStorage):
    """
    Store values as files on the filesystem using a serialization module.

    On write, a value is serialized with ``serializer.dumps`` and signed with
    :class:`SignedBytes` using ``secret_key``. It is then gzip-compressed when
    the (inherited) ``compress`` attribute is truthy.  On read these steps are
    reversed, and the signature is checked before ``serializer.loads`` runs.

    The alternate constructors :meth:`with_pickle`, :meth:`with_cloudpickle`
    and :meth:`with_dill` fill in ``serializer`` for you.
    """

    # Signing key(s).  When left empty, ``get_secret_key()`` supplies them in
    # ``__post_init__``; otherwise the given value is passed through
    # ``_normalize_secret_key``, so bytes, str or list are accepted at
    # construction time despite the annotation.
    secret_key: list[bytes] = field(default_factory=list)
    # Object exposing ``dumps``/``loads`` (e.g. the ``pickle`` module).
    # Required keyword argument; excluded from ``repr``.
    serializer: Any = field(repr=False)

    def __post_init__(self):
        super().__post_init__()
        if self.secret_key:
            normalized_key = _normalize_secret_key(self.secret_key)
        else:
            normalized_key = get_secret_key()
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "secret_key", normalized_key)

    @classmethod
    def with_pickle(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the standard pickle module."""
        return cls(*args, serializer=pickle, **kwargs)

    @classmethod
    @cloudpickle_alarm
    def with_cloudpickle(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the cloudpickle module."""
        return cls(*args, serializer=cloudpickle, **kwargs)

    @classmethod
    @dill_alarm
    def with_dill(cls, *args, **kwargs):
        """Construct a PickleFileBackend using the dill module."""
        return cls(*args, serializer=dill, **kwargs)

    def __getstate__(self):
        """Return picklable state, storing a module serializer by name.

        Module objects cannot be pickled, so a module-valued ``serializer`` is
        replaced by its ``__name__`` and re-imported in :meth:`__setstate__`.

        NOTE(review): the returned state includes ``secret_key``, so pickling
        this backend also serializes the signing key(s).  Confirm that is
        intended wherever instances cross process or machine boundaries.
        """
        state = self.__dict__.copy()
        serializer = state.get("serializer")
        if isinstance(serializer, types.ModuleType):
            state["serializer"] = serializer.__name__
        return state

    def __setstate__(self, state):
        """Restore state produced by :meth:`__getstate__`.

        A string ``serializer`` is treated as a module name and imported.
        """
        serializer_name = state.get("serializer")
        if isinstance(serializer_name, str):
            state = dict(state)
            state["serializer"] = importlib.import_module(serializer_name)
        # Update __dict__ directly: the frozen dataclass forbids setattr.
        self.__dict__.update(state)

    def _to_file(self, value: Any, path: Path) -> None:
        """Serialize and sign *value*, optionally gzip it, and write to *path*.

        NOTE(review): ``Path.write_bytes`` is not atomic, so a reader racing
        this write may observe a truncated file.
        """
        signer = SignedBytes(self.secret_key)
        data = signer.dumps(self.serializer.dumps(value))
        if self.compress:
            data = gzip.compress(data)
        path.write_bytes(data)

    def _from_file(self, path: Path) -> Any:
        """Read, verify and deserialize the value stored at *path*.

        Each ``try`` covers only the call that can raise the handled error.
        A ``FileNotFoundError`` raised while *deserializing* therefore
        propagates as-is rather than being mistaken for a missing entry.

        Raises:
            KeyError: ``path`` does not exist, or its payload fails the
                signature check.
        """
        try:
            content = path.read_bytes()
        except FileNotFoundError:
            raise KeyError(path) from None
        if self.compress:
            # NOTE(review): the signed payload is compressed, so decompression
            # happens before verification.  A tampered file is fully
            # decompressed before it is rejected.
            content = gzip.decompress(content)
        signer = SignedBytes(self.secret_key)
        try:
            data = signer.loads(content)
        except SignatureError as err:
            raise KeyError(
                path, "Value present but failed signature check."
            ) from err
        # Only verified bytes reach the loader; pickle-style loaders can
        # execute arbitrary code on untrusted input.
        return self.serializer.loads(data)
class ValuePickleFile(ValueMixin, DestructuringMixin, PickleFileBackend):
    """Pickle-file storage composed from ``ValueMixin`` and
    ``DestructuringMixin`` on top of :class:`PickleFileBackend`.

    Adds no members of its own; all behavior comes from the base classes.
    """
class CallPickleFile(CallMixin, PickleFileBackend):
    """Pickle-file storage composed from ``CallMixin`` on top of
    :class:`PickleFileBackend`.

    Adds no members of its own; all behavior comes from the base classes.
    """