from typing import Iterable, Any, List
from pathlib import Path
from dataclasses import dataclass, field
from .base import CallStorage, AmbiguousDigestError
from ..call import Call, QueryCall
from ..digest import Digest, DIGEST_LENGTH, digest
from pyiron_snippets.import_alarm import ImportAlarm
with ImportAlarm(
"Sql requires 'sqlalchemy' to be installed. "
"Install it with `pip install fleche[sqlalchemy]`.",
raise_exception=True,
) as sqlalchemy_alarm:
from sqlalchemy import (
create_engine,
String,
Integer,
ForeignKey,
UniqueConstraint,
select,
and_,
)
from sqlalchemy import event
from sqlalchemy.orm import (
declarative_base,
sessionmaker,
relationship,
aliased,
Mapped,
mapped_column,
)
from sqlalchemy.types import JSON
[docs]
# Declarative base shared by the ORM models below (CallModel, ArgumentModel,
# MetaModel); Sql.__post_init__ creates their tables via Base.metadata.create_all().
Base = declarative_base()
class CallModel(Base):
    """ORM row for one cached call (table ``calls``), keyed by its lookup digest.

    Arguments are stored as ``ArgumentModel`` rows and metadata as
    ``MetaModel`` rows; both reference ``calls.key`` with
    ``ondelete="CASCADE"``.
    """

    __tablename__ = "calls"
    # str() of Call.to_lookup_key(), as written by Sql._save.
    key = mapped_column(String(DIGEST_LENGTH), primary_key=True)
    name: Mapped[str] = mapped_column(String, nullable=False)
    # NOTE(review): the columns below are nullable=True but their Mapped[...]
    # annotations are not Optional; the explicit nullable= should decide the
    # schema, so this presumably only misleads type checkers -- confirm.
    module: Mapped[str] = mapped_column(String, nullable=True)
    version: Mapped[int] = mapped_column(Integer, nullable=True)
    code_digest: Mapped[str] = mapped_column(String(DIGEST_LENGTH), nullable=True)
    # str() of the call's result Digest, or NULL when the call has no result.
    result: Mapped[str] = mapped_column(String(DIGEST_LENGTH), nullable=True)
    # Ordered by position so arguments reload in the order they were saved;
    # "delete-orphan" makes an ORM delete of the call remove its argument rows.
    arguments = relationship(
        "ArgumentModel",
        back_populates="call",
        cascade="all, delete-orphan",
        order_by="ArgumentModel.position",
    )
class ArgumentModel(Base):
    """ORM row for one argument of a cached call (table ``arguments``).

    Stores the argument's name, its position in ``Call.arguments`` and its
    value as a digest string. A call cannot have two arguments with the
    same name (``uq_arguments_call_name``).
    """

    __tablename__ = "arguments"
    id = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning call; the database deletes these rows together with the call.
    call_key = mapped_column(
        String(DIGEST_LENGTH),
        ForeignKey("calls.key", ondelete="CASCADE"),
        nullable=False,
    )
    # Zero-based index from enumerate(call.arguments.items()) in Sql._save;
    # CallModel.arguments is ordered by it so load() restores the original order.
    position = mapped_column(Integer, nullable=False)
    name = mapped_column(String, nullable=False)
    # str() of the argument's Digest; queries compare it to str(digest(template_value)).
    value = mapped_column(String(DIGEST_LENGTH), nullable=False)
    __table_args__ = (
        UniqueConstraint("call_key", "name", name="uq_arguments_call_name"),
    )
    call = relationship("CallModel", back_populates="arguments")
class MetaModel(Base):
    """ORM row holding one named metadata entry of a call (table ``metadata``).

    ``data`` is a JSON column. There is no ORM relationship to CallModel,
    so these rows are not covered by the ORM delete cascade. They are
    removed by the database through ``ondelete="CASCADE"``, which SQLite
    enforces only with foreign keys switched on (this module issues
    ``PRAGMA foreign_keys=ON`` for each connection).
    """

    __tablename__ = "metadata"
    id = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning call's key; Sql._load selects metadata rows by this column.
    # NOTE(review): the (call_key, name) unique constraint likely already
    # provides an index led by call_key, so index=True may be redundant -- confirm.
    call_key = mapped_column(
        String(DIGEST_LENGTH),
        ForeignKey("calls.key", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Key of this entry in Call.metadata.
    name: Mapped[str] = mapped_column(String, nullable=False, index=True)
    # JSON-serialised value of Call.metadata[name].
    data: Mapped[dict] = mapped_column(JSON, nullable=False)
    # At most one entry per (call, name), mirroring the keys of Call.metadata.
    __table_args__ = (
        UniqueConstraint("call_key", "name", name="uq_metadata_call_name"),
    )
[docs]
def _coerce_sqlite_url(path_or_url: str | None) -> str:
if path_or_url is None:
return "sqlite:///:memory:"
if isinstance(path_or_url, str) and path_or_url.startswith("sqlite:"):
url = path_or_url
else:
abs_path = Path(str(path_or_url)).absolute()
url = f"sqlite:///{abs_path}"
if url.startswith("sqlite:///"):
db_path = url[len("sqlite:///") :]
if db_path and db_path != ":memory:":
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
return url
# Keep the PRAGMA as a fixed constant: nothing is interpolated into it, so executing it on a raw DBAPI cursor cannot become an SQL-injection vector.
[docs]
# Issued by _enable_sqlite_foreign_keys on each new DBAPI connection so that the
# ondelete="CASCADE" foreign keys declared on the models are actually enforced.
SQLITE_FOREIGN_KEYS_ON = "PRAGMA foreign_keys=ON"
[docs]
def _enable_sqlite_foreign_keys(engine) -> None:
    """Attach a ``"connect"`` listener to ``engine`` that turns on SQLite FKs.

    Every DBAPI connection the engine's pool opens runs
    ``SQLITE_FOREIGN_KEYS_ON`` once, right after it is created.

    NOTE(review): the listener is registered unconditionally; this module only
    builds SQLite URLs, but a non-SQLite engine would reject the PRAGMA.
    """

    def _fk_pragma_on_connect(raw_connection, _connection_record):
        # At "connect" time only the raw DBAPI connection exists, so
        # sqlalchemy.text() is not applicable; run the fixed constant directly.
        raw_cursor = raw_connection.cursor()
        try:
            raw_cursor.execute(SQLITE_FOREIGN_KEYS_ON)
        finally:
            raw_cursor.close()

    event.listen(engine, "connect", _fk_pragma_on_connect)
@dataclass
class Sql(CallStorage):
    """CallStorage persisted in a SQLite database through SQLAlchemy.

    Calls, their arguments and their metadata (as JSON) live in three
    tables. ``expand`` resolves abbreviated digests with a database query
    rather than scanning every key. The database location comes from
    ``self.url``, which is not declared here (presumably inherited from
    CallStorage -- TODO confirm).
    """

    # Forwarded to create_engine(echo=...) to log emitted SQL; ignored by ==.
    echo: bool = field(compare=False, default=False)
    # Built in __post_init__; never an __init__ argument and never pickled.
    engine: Any = field(compare=False, init=False, repr=False)
    # A sessionmaker factory (not a Session): each method calls self.session().
    session: Any = field(compare=False, init=False, repr=False)
@sqlalchemy_alarm
def __post_init__(self) -> None:
    """Normalise ``self.url`` and build the engine and session factory.

    Runs after the dataclass ``__init__`` and again from ``__setstate__``
    after unpickling. Missing tables are created on the target database.
    """
    self.url = _coerce_sqlite_url(self.url)
    assert self.url is not None  # _coerce_sqlite_url always returns a str
    engine = self.engine = create_engine(self.url, echo=self.echo, future=True)
    # FK enforcement is needed for the ON DELETE CASCADE clauses in the models.
    _enable_sqlite_foreign_keys(engine)
    Base.metadata.create_all(engine)
    # expire_on_commit=False keeps loaded attributes usable after commit().
    self.session = sessionmaker(bind=engine, expire_on_commit=False, future=True)
[docs]
def __getstate__(self):
    """Return a picklable copy of the instance dict.

    ``engine`` and ``session`` are left out because they cannot be pickled;
    ``__setstate__`` rebuilds them.
    """
    unpicklable = ("engine", "session")
    return {
        attr: value
        for attr, value in self.__dict__.items()
        if attr not in unpicklable
    }
[docs]
def __setstate__(self, state):
    """Restore pickled attributes, then recreate engine and session factory."""
    vars(self).update(state)
    # engine/session were dropped by __getstate__; rebuild them from url/echo.
    self.__post_init__()
[docs]
def _save(self, call: Call) -> Digest:
    """Persist ``call`` under its lookup key, replacing a differing stored copy.

    If a row with the same key exists and loads back equal to ``call``,
    nothing is written. Otherwise the old row is deleted and the call, its
    arguments and its metadata are inserted, all in one transaction. The
    ORM cascade removes the old argument rows; the database's ON DELETE
    CASCADE removes the old metadata rows.

    Returns:
        The ``call.to_lookup_key()`` digest (not a plain ``str``).
    """
    key = call.to_lookup_key()
    session = self.session()
    try:
        key_text = str(key)
        previous = session.get(CallModel, key_text)
        if previous is not None:
            # NOTE(review): self.load() opens a second session while this one
            # is still open; fine for file-backed SQLite, worth re-checking
            # for other pool configurations.
            if self.load(key) == call:
                return key
            session.delete(previous)
            # Emit the DELETE before inserting a new row with the same key.
            session.flush()

        stored_result = None if call.result is None else str(call.result)
        session.add(
            CallModel(
                key=key_text,
                name=call.name,
                module=call.module,
                version=call.version,
                code_digest=call.code_digest,
                result=stored_result,
            )
        )
        session.add_all(
            [
                ArgumentModel(
                    call_key=key_text,
                    position=position,
                    name=str(arg_name),
                    value=str(arg_value),
                )
                for position, (arg_name, arg_value) in enumerate(call.arguments.items())
            ]
        )
        if call.metadata:
            session.add_all(
                [
                    MetaModel(call_key=key_text, name=meta_name, data=payload)
                    for meta_name, payload in call.metadata.items()
                ]
            )
        session.commit()
        return key
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
[docs]
def _load(self, key: str) -> Call:
    """Rebuild the Call stored under exactly ``key`` (no prefix expansion).

    Raises:
        KeyError: If no call is stored under ``key``.
    """
    session = self.session()
    try:
        row = session.execute(
            select(CallModel).where(CallModel.key == key)
        ).scalar_one_or_none()
        if row is None:
            raise KeyError(key)

        # The relationship is ordered by position, so save order is preserved.
        arguments = {}
        for arg in row.arguments:
            arguments[arg.name] = Digest(arg.value)

        metadata = {}
        for entry in session.scalars(select(MetaModel).where(MetaModel.call_key == key)):
            # Falsy JSON payloads are normalised to an empty dict.
            metadata[entry.name] = entry.data or {}

        result = None
        if row.result is not None:
            result = Digest(row.result)

        return Call(
            name=row.name,
            arguments=arguments,
            metadata=metadata,
            module=row.module,
            version=row.version,
            code_digest=row.code_digest,
            result=result,
        )
    finally:
        session.close()
[docs]
def _contains(self, key: Digest) -> bool:
    """Return whether a call is stored under exactly ``key`` (no expansion)."""
    session = self.session()
    try:
        probe = select(CallModel.key).where(CallModel.key == str(key))
        hit = session.scalars(probe).first()
    finally:
        session.close()
    return hit is not None
[docs]
def list(self) -> Iterable[Digest]:
    """Return every stored key as a Digest, materialised as a list.

    Note: this method shadows the builtin ``list`` inside the class body
    (not inside method bodies), which is why annotations use ``typing.List``.
    """
    session = self.session()
    try:
        stored_keys = session.scalars(select(CallModel.key)).all()
        return [Digest(stored) for stored in stored_keys]
    finally:
        session.close()
[docs]
def expand(self, key: Digest | str) -> Digest:
    """Expand an abbreviated digest to the unique stored full digest.

    Args:
        key: A full digest, or a prefix of at least 4 characters.

    Returns:
        The full digest. Keys of length >= DIGEST_LENGTH are returned as a
        Digest without checking that they are stored.

    Raises:
        KeyError: If the prefix is shorter than 4 characters or matches no
            stored key.
        AmbiguousDigestError: If the prefix matches more than one stored key.
    """
    if len(key) >= DIGEST_LENGTH:
        return Digest(str(key))
    if len(key) < 4:
        raise KeyError(key)
    prefix = str(key)
    session = self.session()
    try:
        # Two rows are enough: one means unique, two means ambiguous.
        # startswith(..., autoescape=True) escapes LIKE wildcards ("%", "_")
        # in the caller-supplied prefix so they match literally; the former
        # raw f"{prefix}%" pattern let e.g. "ab_d" match "abcd...".
        # NOTE(review): SQLite's LIKE is ASCII case-insensitive, so an
        # upper-case prefix still matches lower-case keys.
        rows = session.execute(
            select(CallModel.key)
            .where(CallModel.key.startswith(prefix, autoescape=True))
            .order_by(CallModel.key)
            .limit(2)
        ).all()
    finally:
        session.close()
    if not rows:
        raise KeyError(key)
    if len(rows) == 1:
        return Digest(rows[0][0])

    first, second = rows[0][0], rows[1][0]
    # Length of the common prefix of the two lowest matches; one more
    # character is the minimum needed to separate them.
    shared = min(len(first), len(second))
    for pos, (a, b) in enumerate(zip(first, second)):
        if a != b:
            shared = pos
            break
    raise AmbiguousDigestError(
        f"Short digest {key} is ambiguous; need at least {shared+1} characters."
    )
[docs]
def _evict(self, key: str) -> None:
    """Delete the call stored under exactly ``key``; a missing key is a no-op.

    Argument rows go through the ORM cascade; metadata rows are removed by
    the database's ON DELETE CASCADE.
    """
    session = self.session()
    try:
        doomed = session.get(CallModel, key)
        if doomed is not None:
            session.delete(doomed)
            session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
[docs]
def _normalize_value(self, v: Any) -> str:
    """Map a template value to the string form stored for arguments/results.

    Stored argument values and results are hex digest strings, and
    ``digest(Digest(x)) == x``. Comparing a column against ``str(digest(v))``
    therefore reproduces the ``digest(template) == digest(stored)`` rule of
    the generic ``CallStorage.query``.
    """
    hashed = digest(v)
    return str(hashed)
[docs]
def _build_call_conditions(self, template: QueryCall) -> List[Any]:
    """Translate the template's scalar fields into conditions on CallModel.

    name, module, version and code_digest are compared verbatim; result is
    compared in its normalised digest form. Fields that are None add no
    condition.

    Returns:
        A possibly empty list of SQL expressions, to be AND-ed by the caller.
    """
    verbatim = (
        (CallModel.name, template.name),
        (CallModel.module, template.module),
        (CallModel.version, template.version),
        (CallModel.code_digest, template.code_digest),
    )
    conditions = [column == wanted for column, wanted in verbatim if wanted is not None]
    if template.result is not None:
        conditions.append(CallModel.result == self._normalize_value(template.result))
    return conditions
[docs]
def _apply_argument_filters(self, stmt: Any, arguments: dict[str, Any] | None) -> Any:
    """Add one inner join on the ``arguments`` table per requested argument.

    Each join uses its own alias of ArgumentModel, so several argument
    constraints can hold on the same call at once. A None value only
    requires an argument with that name to exist. Any other value must also
    equal ``self._normalize_value(value)``.

    Args:
        stmt: A SELECT over CallModel to extend.
        arguments: Argument name -> template value; None or empty means no
            argument filtering.

    Returns:
        The extended statement, or ``stmt`` itself when nothing is filtered.
    """
    for arg_name, wanted in (arguments or {}).items():
        arg_row = aliased(ArgumentModel)
        join_on = [
            arg_row.call_key == CallModel.key,
            arg_row.name == str(arg_name),
        ]
        if wanted is not None:
            join_on.append(arg_row.value == self._normalize_value(wanted))
        stmt = stmt.join(arg_row, and_(*join_on))
    return stmt
[docs]
def query(self, template: QueryCall) -> Iterable[Call]:
    """Yield stored calls matching ``template``, pre-filtering in SQL.

    Matching rules (intended to mirror ``CallStorage.query``):

    - A template field that is None matches anything.
    - Argument values and the result match when ``digest(template_value)``
      equals the stored digest. An argument whose template value is None
      only has to be present.
    - ``template.metadata`` maps a metadata name to a dict of key/value
      filters. An empty dict only requires that metadata name to exist. A
      filter value of None only requires the key to exist; any other value
      must compare equal.

    Candidate keys come from a single SELECT: call-column conditions, one
    join per argument filter, and whatever ``_apply_metadata_filters``
    (defined elsewhere) adds. Every candidate is then loaded and its
    metadata re-checked in Python, so results do not depend on how much
    filtering happened in SQL. This is a generator: the database is only
    queried once iteration starts.

    Args:
        template: The query template; None-valued fields are wildcards.

    Yields:
        Call: Each matching call, including its decoded metadata.
    """
    session = self.session()
    try:
        candidates = select(CallModel.key).select_from(CallModel)
        call_filters = self._build_call_conditions(template)
        if call_filters:
            candidates = candidates.where(and_(*call_filters))
        candidates = self._apply_argument_filters(candidates, template.arguments)
        candidates = self._apply_metadata_filters(candidates, template.metadata)
        # One row per call even when several argument joins match it.
        candidates = candidates.distinct()
        matched_keys = [Digest(raw_key) for raw_key in session.scalars(candidates).all()]
    finally:
        session.close()

    def _entry_satisfies(entry, checks) -> bool:
        # None => the key must be present; otherwise the value must be equal.
        for check_key, expected in (checks or {}).items():
            if expected is None:
                if check_key not in entry:
                    return False
            elif entry.get(check_key) != expected:
                return False
        return True

    def _metadata_ok(candidate: Call) -> bool:
        specs = template.metadata
        if not specs:
            return True
        for meta_name, checks in specs.items():
            entry = (candidate.metadata or {}).get(meta_name)
            if entry is None or not _entry_satisfies(entry, checks):
                return False
        return True

    # Lazily load each candidate (with metadata) and keep those that pass.
    yield from filter(_metadata_ok, map(self.load, matched_keys))