"""SQLAlchemy-backed call storage (module ``fleche.storage.sql``)."""

from typing import Iterable, Any, List
from pathlib import Path
from dataclasses import dataclass, field
from .base import CallStorage, AmbiguousDigestError
from ..call import Call, QueryCall
from ..digest import Digest, DIGEST_LENGTH, digest

from pyiron_snippets.import_alarm import ImportAlarm

with ImportAlarm(
    "Sql requires 'sqlalchemy' to be installed. "
    "Install it with `pip install fleche[sqlalchemy]`.",
    raise_exception=True,
) as sqlalchemy_alarm:
    from sqlalchemy import (
        create_engine,
        String,
        Integer,
        ForeignKey,
        UniqueConstraint,
        select,
        and_,
    )
    from sqlalchemy import event
    from sqlalchemy.orm import (
        declarative_base,
        sessionmaker,
        relationship,
        aliased,
        Mapped,
        mapped_column,
    )
    from sqlalchemy.types import JSON

# Declarative base shared by the three ORM models below; ``Sql`` creates their
# tables with ``Base.metadata.create_all``.
Base = declarative_base()


class CallModel(Base):
    """A stored call, keyed by the string form of its lookup digest."""

    __tablename__ = "calls"

    key: Mapped[str] = mapped_column(String(DIGEST_LENGTH), primary_key=True)
    name: Mapped[str] = mapped_column(String, nullable=False)
    module: Mapped[str | None] = mapped_column(String, nullable=True)
    version: Mapped[int | None] = mapped_column(Integer, nullable=True)
    code_digest: Mapped[str | None] = mapped_column(
        String(DIGEST_LENGTH), nullable=True
    )
    # String form of the result digest, or NULL when the call has no result.
    result: Mapped[str | None] = mapped_column(String(DIGEST_LENGTH), nullable=True)

    # Argument rows ordered by their position in the original call; deleting a
    # CallModel through the ORM deletes its arguments as well.
    arguments = relationship(
        "ArgumentModel",
        back_populates="call",
        cascade="all, delete-orphan",
        order_by="ArgumentModel.position",
    )


class ArgumentModel(Base):
    """One named argument of a call, stored as the string form of its digest."""

    __tablename__ = "arguments"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    call_key: Mapped[str] = mapped_column(
        String(DIGEST_LENGTH),
        ForeignKey("calls.key", ondelete="CASCADE"),
        nullable=False,
    )
    # Zero-based index of the argument within the call's argument mapping.
    position: Mapped[int] = mapped_column(Integer, nullable=False)
    name: Mapped[str] = mapped_column(String, nullable=False)
    value: Mapped[str] = mapped_column(String(DIGEST_LENGTH), nullable=False)

    # An argument name may appear at most once per call.
    __table_args__ = (
        UniqueConstraint("call_key", "name", name="uq_arguments_call_name"),
    )

    call = relationship("CallModel", back_populates="arguments")


class MetaModel(Base):
    """One named JSON metadata document attached to a call."""

    __tablename__ = "metadata"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # No ORM relationship to CallModel: rows are removed together with their
    # call by the database-level ON DELETE CASCADE.
    call_key: Mapped[str] = mapped_column(
        String(DIGEST_LENGTH),
        ForeignKey("calls.key", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String, nullable=False, index=True)
    data: Mapped[dict] = mapped_column(JSON, nullable=False)

    # Each metadata name may appear at most once per call.
    __table_args__ = (
        UniqueConstraint("call_key", "name", name="uq_metadata_call_name"),
    )
[docs] def _coerce_sqlite_url(path_or_url: str | None) -> str: if path_or_url is None: return "sqlite:///:memory:" if isinstance(path_or_url, str) and path_or_url.startswith("sqlite:"): url = path_or_url else: abs_path = Path(str(path_or_url)).absolute() url = f"sqlite:///{abs_path}" if url.startswith("sqlite:///"): db_path = url[len("sqlite:///") :] if db_path and db_path != ":memory:": Path(db_path).parent.mkdir(parents=True, exist_ok=True) return url
# Statement executed on every new SQLite DBAPI connection.  SQLite leaves
# foreign-key enforcement off unless it is enabled per connection, and the
# ``ondelete="CASCADE"`` foreign keys declared on the models only take effect
# when it is on.
SQLITE_FOREIGN_KEYS_ON = "PRAGMA foreign_keys=ON"


def _enable_sqlite_foreign_keys(engine) -> None:
    """Turn on SQLite foreign-key enforcement for every connection of *engine*.

    Registers a ``"connect"`` event listener on *engine* that executes
    ``SQLITE_FOREIGN_KEYS_ON`` on each newly opened DBAPI connection.
    """

    def _on_connect(dbapi_connection, connection_record):
        # Only the raw DBAPI connection exists at 'connect' time, so the fixed
        # statement goes through a plain cursor rather than sqlalchemy.text().
        cur = dbapi_connection.cursor()
        try:
            cur.execute(SQLITE_FOREIGN_KEYS_ON)
        finally:
            cur.close()

    event.listen(engine, "connect", _on_connect)
@dataclass
class Sql(CallStorage):
    """SQLAlchemy-backed CallStorage with JSON metadata and DB-backed expand().

    Calls are stored in three tables:

    * ``calls`` holds one row per call.
    * ``arguments`` holds one row per argument, with the argument's value digest.
    * ``metadata`` holds one JSON document per metadata name and call.

    Attributes:
        url: ``None`` (in-memory database), a filesystem path, or a SQLite URL.
            It is normalized by ``_coerce_sqlite_url`` in ``__post_init__``.
        echo: Forwarded to ``create_engine``.
        engine: SQLAlchemy engine, built in ``__post_init__``; not pickled.
        session: Session factory bound to ``engine``; not pickled.
    """

    url: str | None = None
    echo: bool = field(default=False, compare=False)
    engine: Any = field(init=False, repr=False, compare=False)
    session: Any = field(init=False, repr=False, compare=False)

    @sqlalchemy_alarm
    def __post_init__(self) -> None:
        """Normalize ``url``, open the engine, create tables and the session factory."""
        self.url = _coerce_sqlite_url(self.url)
        assert self.url is not None  # narrows the type for static checkers
        self.engine = create_engine(self.url, echo=self.echo, future=True)
        _enable_sqlite_foreign_keys(self.engine)
        Base.metadata.create_all(self.engine)
        # expire_on_commit=False keeps attributes of loaded objects readable
        # after the session that loaded them has committed.
        self.session = sessionmaker(
            bind=self.engine, expire_on_commit=False, future=True
        )

    def __getstate__(self):
        """Return the instance state without the unpicklable engine/session."""
        state = self.__dict__.copy()
        state.pop("engine", None)
        state.pop("session", None)
        return state

    def __setstate__(self, state):
        """Restore pickled state and rebuild the engine and session factory.

        NOTE: for an in-memory URL this opens a new, empty database.
        """
        self.__dict__.update(state)
        # Re-initialize unpickleable fields
        self.__post_init__()

    def _save(self, call: Call) -> Digest:
        """Persist *call* under its lookup key and return that key.

        If a row with the same key exists and loads back equal to *call*,
        nothing is written.  Otherwise the old row is deleted before the new
        one is inserted:

        * its argument rows are removed through the ORM cascade;
        * its metadata rows are removed through ``ON DELETE CASCADE``.

        The whole operation runs in one transaction, which is rolled back if
        anything raises.
        """
        key = call.to_lookup_key()
        skey = str(key)
        with self.session() as session, session.begin():
            existing = session.get(CallModel, skey)
            if existing is not None:
                if self.load(key) == call:
                    return key
                session.delete(existing)
                # Flush the delete now so the inserts below cannot collide with
                # the old rows on the primary key / unique constraints.
                session.flush()
            session.add(
                CallModel(
                    key=skey,
                    name=call.name,
                    module=call.module,
                    version=call.version,
                    code_digest=call.code_digest,
                    result=None if call.result is None else str(call.result),
                )
            )
            session.add_all(
                [
                    ArgumentModel(
                        call_key=skey, position=i, name=str(k), value=str(v)
                    )
                    for i, (k, v) in enumerate(call.arguments.items())
                ]
            )
            if call.metadata:
                session.add_all(
                    [
                        MetaModel(call_key=skey, name=name, data=data)
                        for name, data in call.metadata.items()
                    ]
                )
        # Always return a Digest instance, not a plain str
        return key

    def _load(self, key: Digest) -> Call:
        """Rebuild the Call stored under *key*.

        Raises:
            KeyError: If no call is stored under *key*.
        """
        skey = str(key)
        with self.session() as session:
            call_model = session.execute(
                select(CallModel).where(CallModel.key == skey)
            ).scalar_one_or_none()
            if call_model is None:
                raise KeyError(key)
            # The relationship is lazy-loaded, so read it while the session is open.
            arguments = {arg.name: Digest(arg.value) for arg in call_model.arguments}
            meta_rows = (
                session.execute(select(MetaModel).where(MetaModel.call_key == skey))
                .scalars()
                .all()
            )
            return Call(
                name=call_model.name,
                arguments=arguments,
                metadata={row.name: (row.data or {}) for row in meta_rows},
                module=call_model.module,
                version=call_model.version,
                code_digest=call_model.code_digest,
                result=(
                    Digest(call_model.result)
                    if call_model.result is not None
                    else None
                ),
            )

    def _contains(self, key: Digest) -> bool:
        """Return True if a call is stored under exactly *key*."""
        with self.session() as session:
            row = session.execute(
                select(CallModel.key).where(CallModel.key == str(key))
            ).first()
        return row is not None

    def list(self) -> Iterable[Digest]:
        """Return the keys of all stored calls as Digest instances."""
        with self.session() as session:
            return [Digest(k) for k in session.scalars(select(CallModel.key))]

    def expand(self, key: Digest | str) -> Digest:
        """Resolve a possibly abbreviated digest to a full stored key.

        A key of at least ``DIGEST_LENGTH`` characters is returned as a Digest
        without checking that it is stored.  A shorter key must have at least
        4 characters and be a prefix of exactly one stored key.

        Raises:
            KeyError: If the key is shorter than 4 characters or matches nothing.
            AmbiguousDigestError: If the prefix matches more than one stored key.
        """
        if len(key) >= DIGEST_LENGTH:
            return Digest(str(key))
        if len(key) < 4:
            raise KeyError(key)
        prefix = str(key)
        with self.session() as session:
            # autoescape=True makes '%' and '_' in the prefix match literally
            # instead of acting as LIKE wildcards.
            # NOTE(review): SQLite's LIKE is ASCII case-insensitive by default,
            # so the prefix match is too.
            rows = session.execute(
                select(CallModel.key)
                .where(CallModel.key.startswith(prefix, autoescape=True))
                .order_by(CallModel.key)
                .limit(2)
            ).all()
        if not rows:
            raise KeyError(key)
        if len(rows) == 1:
            return Digest(rows[0][0])
        first, second = rows[0][0], rows[1][0]
        # Length of the common prefix of the two smallest matches.
        i = next(
            (pos for pos, (c1, c2) in enumerate(zip(first, second)) if c1 != c2),
            min(len(first), len(second)),
        )
        raise AmbiguousDigestError(
            f"Short digest {key} is ambiguous; need at least {i+1} characters."
        )

    def _evict(self, key: Digest) -> None:
        """Delete the call stored under *key*, if any.

        Argument and metadata rows go with it via the cascades.
        """
        with self.session() as session, session.begin():
            instance = session.get(CallModel, str(key))
            if instance is not None:
                session.delete(instance)

    def _normalize_value(self, v: Any) -> str:
        """Return the stored form used in SQL for argument/result matching.

        We must match the generic CallStorage.query semantics, which compare
        ``digest(template_value) == digest(stored_call_value)``.

        In this backend, stored argument/result values are hex-digest strings,
        and ``digest(Digest(x)) == x``.  Therefore ``Arg.value`` and
        ``CallModel.result`` are always compared to ``str(digest(template_value))``.
        """
        return str(digest(v))

    def _build_call_conditions(self, template: QueryCall) -> List[Any]:
        """Return WHERE conditions for the non-None scalar fields of *template*."""
        direct = (
            (CallModel.name, template.name),
            (CallModel.module, template.module),
            (CallModel.version, template.version),
            (CallModel.code_digest, template.code_digest),
        )
        conditions = [column == value for column, value in direct if value is not None]
        if template.result is not None:
            conditions.append(
                CallModel.result == self._normalize_value(template.result)
            )
        return conditions

    def _apply_argument_filters(self, stmt: Any, arguments: dict[str, Any] | None) -> Any:
        """Join one aliased ``arguments`` row per requested argument name.

        A None value only requires the argument to exist.  Any other value
        must equal the stored digest (see ``_normalize_value``).
        """
        if not arguments:
            return stmt
        for arg_name, arg_value in arguments.items():
            arg_alias = aliased(ArgumentModel)
            on_clause = and_(
                arg_alias.call_key == CallModel.key,
                arg_alias.name == str(arg_name),
            )
            if arg_value is not None:
                on_clause = and_(
                    on_clause, arg_alias.value == self._normalize_value(arg_value)
                )
            stmt = stmt.join(arg_alias, on_clause)
        return stmt

    @staticmethod
    def _json_equals(data: Any, key: Any, value: Any) -> Any:
        """Return a SQL condition ``data[key] == value``, or None if not pushable.

        Only str/bool/int/float values are pushed down.  None (meaning "key
        present") and any other type yield None; they are left to the
        client-side check in ``query``.
        """
        if not isinstance(value, (str, bool, int, float)):
            return None
        element = data[key]
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return element.as_boolean() == value
        if isinstance(value, int):
            return element.as_integer() == value
        if isinstance(value, float):
            return element.as_float() == value
        return element.as_string() == value

    def _apply_metadata_filters(
        self, stmt: Any, meta_specs: dict[str, dict[str, Any]] | None
    ) -> Any:
        """Narrow *stmt* by the metadata part of a query template.

        Every requested metadata name is joined, so the name must be present.
        Each key filter whose value can be expressed in SQL is added as a JSON
        condition; the other filters are skipped here, not the whole query.
        This is only a pre-filter, because ``query`` re-checks all metadata
        filters client-side.
        """
        if not meta_specs:
            return stmt
        for mname, filters in meta_specs.items():
            meta_alias = aliased(MetaModel)
            stmt = stmt.join(
                meta_alias,
                and_(
                    meta_alias.call_key == CallModel.key,
                    meta_alias.name == mname,
                ),
            )
            for k, v in (filters or {}).items():
                condition = self._json_equals(meta_alias.data, k, v)
                if condition is not None:
                    stmt = stmt.where(condition)
        return stmt

    @staticmethod
    def _metadata_matches(call: Call, specs: dict[str, dict[str, Any]] | None) -> bool:
        """Client-side check of *call*'s metadata against the template *specs*.

        The rules are:

        * Every named metadata entry must exist.
        * A None filter value requires only that the key is present.
        * Any other filter value must compare equal to the stored value.
        """
        if not specs:
            return True
        metadata = call.metadata or {}
        for mname, filters in specs.items():
            data = metadata.get(mname)
            if data is None:
                return False
            for kk, vv in (filters or {}).items():
                if vv is None:
                    if kk not in data:
                        return False
                elif data.get(kk) != vv:
                    return False
        return True

    def query(self, template: QueryCall) -> Iterable[Call]:
        """Find cached calls matching a template using SQL-side filtering.

        Semantics match CallStorage.query:

        - Fields set to None are wildcards.
        - Arguments and result are compared by
          ``digest(template_value) == digest(stored_value)``.
        - Metadata can be filtered by providing ``template.metadata`` as a
          mapping of metadata name -> dict of key/value filters.  An empty
          dict for a name means "that metadata name is present".  A None value
          for a key means "that key is present".

        How the filtering is split between SQL and the client:

        - Presence of every requested metadata name is enforced in SQL.
        - Key filters with simple values (str, bool, int, float) are also
          pushed down to SQL via JSON-extract expressions.
        - Other values (e.g. lists) and None values are checked client-side
          only.
        - All metadata filters are re-validated client-side after loading.

        Args:
            template: A Call used as a template. None-valued fields are wildcards.

        Yields:
            Call: Matching calls including their decoded metadata.
        """
        with self.session() as session:
            stmt = select(CallModel.key).select_from(CallModel)
            conditions = self._build_call_conditions(template)
            if conditions:
                stmt = stmt.where(and_(*conditions))
            stmt = self._apply_argument_filters(stmt, template.arguments)
            stmt = self._apply_metadata_filters(stmt, template.metadata)
            # Defensive: never yield the same key twice even if joins multiply rows.
            stmt = stmt.distinct()
            keys = [Digest(k) for (k,) in session.execute(stmt).all()]

        # Load through the regular loader so the returned calls carry metadata.
        for k in keys:
            candidate = self.load(k)
            if self._metadata_matches(candidate, template.metadata):
                yield candidate