import logging
from typing import Iterable, Any, List
from pathlib import Path
from dataclasses import dataclass, field
from .base import KeyManagement, CallStorage, AmbiguousDigestError
from ..call import Call, QueryCall
from ..digest import Digest, DIGEST_LENGTH, digest
from pyiron_snippets.import_alarm import ImportAlarm
# Module logger. It uses the fixed name "fleche.storage" rather than __name__,
# so this backend's messages go to that logger.
logger = logging.getLogger("fleche.storage")
with ImportAlarm(
"Sql requires 'sqlalchemy' to be installed. "
"Install it with `pip install fleche[sqlalchemy]`.",
raise_exception=True,
) as sqlalchemy_alarm:
from sqlalchemy import (
create_engine,
String,
Integer,
ForeignKey,
UniqueConstraint,
select,
and_,
)
from sqlalchemy import event
from sqlalchemy.orm import (
declarative_base,
sessionmaker,
relationship,
aliased,
Mapped,
mapped_column,
)
from sqlalchemy.types import JSON
# Declarative base shared by the ORM models below. Sql.__post_init__ calls
# Base.metadata.create_all(engine) to create their tables.
Base = declarative_base()
class CallModel(Base):
    """ORM row for one cached call, keyed by its lookup digest.

    ``Sql.put`` writes one row per call and ``Sql.get`` rebuilds a ``Call`` from
    it. Argument digests are stored in the ``arguments`` table (``ArgumentModel``)
    and metadata in the ``metadata`` table (``MetaModel``).
    """

    __tablename__ = "calls"

    # Lookup digest under which the call is stored (the ``key`` passed to Sql.put).
    key: Mapped[str] = mapped_column(String(DIGEST_LENGTH), primary_key=True)
    name: Mapped[str] = mapped_column(String, nullable=False)
    module: Mapped[str] = mapped_column(String, nullable=True)
    version: Mapped[int] = mapped_column(Integer, nullable=True)
    # Stored and compared verbatim by Sql query filters (presumably a digest of
    # the function's code -- confirm).
    code_digest: Mapped[str] = mapped_column(String(DIGEST_LENGTH), nullable=True)
    # ``str(call.result)``, or NULL when the call has no result.
    result: Mapped[str] = mapped_column(String(DIGEST_LENGTH), nullable=True)
    # Argument rows in their original order. Deleting a CallModel through the
    # ORM also deletes these rows ("delete-orphan" cascade).
    arguments = relationship(
        "ArgumentModel",
        back_populates="call",
        cascade="all, delete-orphan",
        order_by="ArgumentModel.position",
    )
class ArgumentModel(Base):
    """ORM row for one argument of a cached call: its name and value digest."""

    __tablename__ = "arguments"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning call. The database also removes these rows when the call row is
    # deleted, provided SQLite foreign keys are enabled on the connection.
    call_key: Mapped[str] = mapped_column(
        String(DIGEST_LENGTH),
        ForeignKey("calls.key", ondelete="CASCADE"),
        nullable=False,
    )
    # 0-based index in call.arguments iteration order. CallModel.arguments is
    # ordered by it, so the argument mapping is rebuilt in the same order.
    position: Mapped[int] = mapped_column(Integer, nullable=False)
    name: Mapped[str] = mapped_column(String, nullable=False)
    # Argument value as written by Sql.put (``str(v)``).
    value: Mapped[str] = mapped_column(String(DIGEST_LENGTH), nullable=False)
    # An argument name may appear at most once per call.
    __table_args__ = (
        UniqueConstraint("call_key", "name", name="uq_arguments_call_name"),
    )
    call = relationship("CallModel", back_populates="arguments")
class MetaModel(Base):
    """ORM row for one named metadata entry of a cached call, stored as JSON.

    There is no ORM relationship from CallModel to this table. When a call row
    is deleted, these rows are removed only by the database-level
    ``ON DELETE CASCADE``, which requires SQLite foreign keys to be enabled.
    """

    __tablename__ = "metadata"

    # Surrogate primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning call; indexed because Sql.get selects metadata rows by call_key.
    call_key: Mapped[str] = mapped_column(
        String(DIGEST_LENGTH),
        ForeignKey("calls.key", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Metadata name: a key of call.metadata.
    name: Mapped[str] = mapped_column(String, nullable=False, index=True)
    # JSON payload: the value of call.metadata[name].
    data: Mapped[dict] = mapped_column(JSON, nullable=False)
    # Each metadata name appears at most once per call.
    __table_args__ = (
        UniqueConstraint("call_key", "name", name="uq_metadata_call_name"),
    )
[docs]
def _coerce_sqlite_url(path_or_url: str | None) -> str:
if path_or_url is None:
return "sqlite:///:memory:"
if isinstance(path_or_url, str) and path_or_url.startswith("sqlite:"):
url = path_or_url
else:
abs_path = Path(str(path_or_url)).absolute()
url = f"sqlite:///{abs_path}"
if url.startswith("sqlite:///"):
db_path = url[len("sqlite:///") :]
if db_path and db_path != ":memory:":
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
return url
# Pragma run on every new SQLite DBAPI connection by _enable_sqlite_foreign_keys.
# It is a fixed literal (no user input is interpolated); naming it keeps the
# statement text in one place.
SQLITE_FOREIGN_KEYS_ON = "PRAGMA foreign_keys=ON"
def _enable_sqlite_foreign_keys(engine) -> None:
    """Install a hook that runs ``SQLITE_FOREIGN_KEYS_ON`` on each new connection.

    SQLite enforces foreign keys, and therefore the ``ON DELETE CASCADE``
    clauses declared on the ``arguments`` and ``metadata`` tables, only on
    connections where this pragma has been enabled. Engines of any other
    dialect are left untouched, because the pragma is SQLite-specific.

    Args:
        engine: A SQLAlchemy ``Engine``; the hook is attached to its
            ``"connect"`` event.
    """
    if engine.dialect.name != "sqlite":
        return

    @event.listens_for(engine, "connect")
    def _set_sqlite_pragma(dbapi_connection, connection_record):
        # This runs on the raw DBAPI connection at "connect" time, so the
        # statement goes through a plain DBAPI cursor rather than sqlalchemy.text().
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute(SQLITE_FOREIGN_KEYS_ON)
        finally:
            cursor.close()
@dataclass(frozen=True)
class Sql(CallStorage):
    """SQLAlchemy-backed CallStorage with JSON metadata and DB-backed expand().

    Calls are stored in three tables:
      - ``calls``: one row per call (``CallModel``).
      - ``arguments``: one row per argument (``ArgumentModel``).
      - ``metadata``: one JSON row per metadata name (``MetaModel``).

    Each database access opens its own session from the ``session`` factory and
    closes it before returning. The dataclass is frozen, so ``__post_init__``
    installs derived attributes with ``object.__setattr__``. The engine and
    session factory are not pickled; ``__setstate__`` rebuilds them.
    """

    # NOTE(review): ``url`` is read and normalized in place by ``__post_init__``
    # but is not declared among the fields below. It is presumably a field
    # inherited from ``CallStorage`` -- confirm.

    # Passed to create_engine(echo=...); excluded from equality.
    echo: bool = field(default=False, compare=False)
    # Engine created in __post_init__; removed by __getstate__.
    engine: Any = field(init=False, repr=False, compare=False)
    # sessionmaker bound to ``engine``; call it to open a new session.
    session: Any = field(init=False, repr=False, compare=False)

    @sqlalchemy_alarm
    def __post_init__(self) -> None:
        """Normalize ``url`` and build the engine, tables and session factory.

        Runs on construction and again from ``__setstate__`` after unpickling.
        """
        coerced_url = _coerce_sqlite_url(self.url)
        object.__setattr__(self, "url", coerced_url)
        engine = create_engine(coerced_url, echo=self.echo, future=True)
        # Register the PRAGMA hook before create_all opens the first connection.
        _enable_sqlite_foreign_keys(engine)
        Base.metadata.create_all(engine)
        object.__setattr__(self, "engine", engine)
        object.__setattr__(
            self,
            "session",
            sessionmaker(bind=engine, expire_on_commit=False, future=True),
        )

    def __getstate__(self):
        """Return the instance state without the unpicklable engine and session factory."""
        state = self.__dict__.copy()
        state.pop("engine", None)
        state.pop("session", None)
        return state

    def __setstate__(self, state):
        """Restore pickled state, then rebuild the engine and session factory."""
        self.__dict__.update(state)
        # Re-initialize the fields dropped by __getstate__.
        self.__post_init__()

    def put(self, call: Any, key: Digest) -> Digest:
        """Store *call* under *key*, replacing a different call stored there.

        If a row for *key* already exists and decodes (via ``self.get``) to a
        call equal to *call*, nothing is written. Otherwise the old row is
        deleted and the call, its arguments (in order) and its metadata are
        inserted, all in one transaction. The argument rows of the old row are
        removed through the ORM cascade; its metadata rows are removed through
        the database's ON DELETE CASCADE.

        Returns:
            *key*, unchanged.
        """
        session = self.session()
        try:
            existing = session.get(CallModel, str(key))
            if existing is not None:
                # self.get uses its own, separate session.
                if self.get(key) == call:
                    return key
                session.delete(existing)
                # Flush the DELETE before re-inserting the same primary key.
                session.flush()
            call_model = CallModel(
                key=str(key),
                name=call.name,
                module=call.module,
                version=call.version,
                code_digest=call.code_digest,
                result=call.result if call.result is None else str(call.result),
            )
            session.add(call_model)
            for i, (k, v) in enumerate(call.arguments.items()):
                session.add(
                    ArgumentModel(
                        call_key=str(key), position=i, name=str(k), value=str(v)
                    )
                )
            if call.metadata:
                session.add_all(
                    [
                        MetaModel(call_key=str(key), name=name, data=data)
                        for name, data in call.metadata.items()
                    ]
                )
            session.commit()
            return key
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def get(self, key: Digest) -> Call:
        """Load the call stored under the full digest *key*.

        Argument values and the result are wrapped in ``Digest``. Metadata is
        rebuilt from the JSON rows, with an empty dict substituted for falsy
        payloads.

        Raises:
            KeyError: no call is stored under *key*.
        """
        session = self.session()
        try:
            call_model = session.execute(
                select(CallModel).where(CallModel.key == str(key))
            ).scalar_one_or_none()
            if call_model is None:
                raise KeyError(key)
            # Lazy-loaded while the session is still open; ordered by position.
            arguments = {arg.name: Digest(arg.value) for arg in call_model.arguments}
            meta_rows = (
                session.execute(select(MetaModel).where(MetaModel.call_key == str(key)))
                .scalars()
                .all()
            )
            return Call(
                name=call_model.name,
                arguments=arguments,
                metadata={row.name: (row.data or {}) for row in meta_rows},
                module=call_model.module,
                version=call_model.version,
                code_digest=call_model.code_digest,
                result=(
                    Digest(call_model.result) if call_model.result is not None else None
                ),
            )
        finally:
            session.close()

    def _contains(self, key: Digest) -> bool:
        """Return True if a call row exists for the full digest *key*."""
        session = self.session()
        try:
            return (
                session.execute(
                    select(CallModel.key).where(CallModel.key == str(key))
                ).first()
                is not None
            )
        finally:
            session.close()

    # NOTE: this method shadows the builtin ``list`` inside the class body.
    # Class-level code below it (e.g. signature annotations, which are
    # evaluated at definition time) must use typing.List, not list[...].
    def list(self) -> Iterable[Digest]:
        """Return the keys of all stored calls as ``Digest`` instances."""
        session = self.session()
        try:
            return [Digest(row[0]) for row in session.execute(select(CallModel.key))]
        finally:
            session.close()

    def expand(self, key: Digest | str) -> Digest:
        """Resolve a possibly abbreviated digest to the full stored key.

        Keys at least ``DIGEST_LENGTH`` long are returned as ``Digest`` without a
        database lookup. Shorter keys must have at least 4 characters and must
        prefix exactly one stored key.

        Raises:
            KeyError: the prefix is shorter than 4 characters or matches no key.
            AmbiguousDigestError: the prefix matches more than one key.
        """
        if len(key) >= DIGEST_LENGTH:
            return Digest(str(key))
        if len(key) < 4:
            raise KeyError(key)
        prefix = str(key)
        session = self.session()
        try:
            # autoescape=True makes "%" and "_" in the prefix match literally
            # instead of acting as LIKE wildcards.
            # NOTE(review): SQLite's LIKE is ASCII case-insensitive by default,
            # so an upper-case prefix can still match a lower-case key.
            rows = session.execute(
                select(CallModel.key)
                .where(CallModel.key.startswith(prefix, autoescape=True))
                .order_by(CallModel.key)
                .limit(2)
            ).all()
        finally:
            session.close()
        if not rows:
            raise KeyError(key)
        if len(rows) == 1:
            return Digest(rows[0][0])
        # Two matches: report how many characters would tell them apart.
        m1, m2 = rows[0][0], rows[1][0]
        for i, (c1, c2) in enumerate(zip(m1, m2)):
            if c1 != c2:
                break
        else:
            i = min(len(m1), len(m2))
        raise AmbiguousDigestError(
            f"Short digest {key} is ambiguous; need at least {i+1} characters."
        )

    def _evict(self, key: Digest) -> None:
        """Delete the call stored under *key*; do nothing if it is absent.

        Argument rows go with it via the ORM cascade and metadata rows via the
        database's ON DELETE CASCADE.
        """
        session = self.session()
        try:
            instance = session.get(CallModel, str(key))
            if instance is None:
                return
            session.delete(instance)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    def save(self, call: Call) -> Digest:
        """Store *call* under its lookup key, replacing any existing entry.

        Returns:
            The key the call was stored under.
        """
        key = call.to_lookup_key()
        logger.debug("Saving call %s", key)
        if self.contains(str(key)):
            self.evict(str(key))
        return self.put(call, key)

    def load(self, key: Digest | str) -> Call:
        """Load a call by full or abbreviated digest.

        Raises:
            KeyError: no matching call (see ``expand`` and ``get``).
            AmbiguousDigestError: an abbreviated key matches several calls.
        """
        if len(key) < DIGEST_LENGTH:
            key = self.expand(key)
        else:
            key = Digest(key)
        logger.debug("Loading call with key %s", key)
        return self.get(key)

    def _normalize_value(self, v: Any) -> str:
        """Return the stored form used in SQL for argument/result matching.

        We must match the generic CallStorage.query semantics, which compare
        digest(template_value) == digest(stored_call_value). In this backend,
        stored argument/result values are hex-digest strings, and
        digest(Digest(x)) == x. Therefore we always compare
        Arg.value/CallModel.result to str(digest(template_value)).
        """
        return str(digest(v))

    def _build_call_conditions(self, template: QueryCall) -> List[Any]:
        """Build WHERE conditions for the non-None scalar fields of *template*.

        ``name``, ``module``, ``version`` and ``code_digest`` are compared
        verbatim. ``result`` is compared after ``_normalize_value``.
        """
        conditions = []
        if template.name is not None:
            conditions.append(CallModel.name == template.name)
        if template.module is not None:
            conditions.append(CallModel.module == template.module)
        if template.version is not None:
            conditions.append(CallModel.version == template.version)
        if template.code_digest is not None:
            conditions.append(CallModel.code_digest == template.code_digest)
        if template.result is not None:
            conditions.append(
                CallModel.result == self._normalize_value(template.result)
            )
        return conditions

    def _apply_argument_filters(self, stmt: Any, arguments: dict[str, Any] | None) -> Any:
        """Join one aliased ``arguments`` row per requested argument.

        Each requested argument name must be present on the call. When the
        value is not None, it must also equal ``_normalize_value(value)``.
        A falsy *arguments* leaves *stmt* unchanged.
        """
        if not arguments:
            return stmt
        for k, v in arguments.items():
            # A separate alias per argument so each join constrains its own row.
            Arg = aliased(ArgumentModel)
            on_clause = and_(
                Arg.call_key == CallModel.key,
                Arg.name == str(k),
            )
            if v is not None:
                on_clause = and_(on_clause, Arg.value == self._normalize_value(v))
            stmt = stmt.join(Arg, on_clause)
        return stmt

    def query(self, template: QueryCall) -> Iterable[Call]:
        """Find cached calls matching a template using SQL-side filtering.

        Semantics match CallStorage.query:
          - Fields set to None are wildcards.
          - Arguments and result are compared by
            digest(template_value) == digest(stored_value).
          - Metadata can be filtered with template.metadata, a mapping of
            metadata name -> dict of key/value filters. An empty dict for a
            name means "that metadata name is present". A None filter value
            means "that key is present". Any other value must compare equal.
            SQL-side narrowing of metadata is delegated to
            ``_apply_metadata_filters``. Every candidate is then re-checked
            client-side after loading.

        This is a generator: the SELECT over ``calls`` (joined with
        ``arguments`` and, via the metadata helper, ``metadata``) runs when
        iteration starts. Each matching key is then loaded with ``self.get``.

        Args:
            template: A QueryCall used as a template. None-valued fields are
                wildcards.

        Yields:
            Call: Matching calls including their decoded metadata.
        """
        session = self.session()
        try:
            stmt = select(CallModel.key).select_from(CallModel)
            conditions = self._build_call_conditions(template)
            if conditions:
                stmt = stmt.where(and_(*conditions))
            stmt = self._apply_argument_filters(stmt, template.arguments)
            # NOTE(review): _apply_metadata_filters is not among the methods
            # defined above; it must exist on this class or CallStorage -- confirm.
            stmt = self._apply_metadata_filters(stmt, template.metadata)
            # Distinct to avoid duplicate keys if multiple argument joins could overlap.
            stmt = stmt.distinct()
            keys = [Digest(k) for (k,) in session.execute(stmt).all()]
        finally:
            session.close()

        def meta_matches(call: Call) -> bool:
            """Client-side check of template.metadata against a loaded call."""
            specs = template.metadata
            if not specs:
                return True
            for mname, filters in specs.items():
                data = (call.metadata or {}).get(mname)
                if data is None:
                    return False
                for kk, vv in (filters or {}).items():
                    if vv is None:
                        if kk not in data:
                            return False
                    else:
                        if data.get(kk) != vv:
                            return False
            return True

        # Load each candidate with get() so metadata is included in the result.
        for k in keys:
            c = self.get(k)
            if meta_matches(c):
                yield c