diff --git a/setup.py b/setup.py index 5348d1382..1c08ae914 100644 --- a/setup.py +++ b/setup.py @@ -38,10 +38,17 @@ "Werkzeug", "zope.interface", ], + extra_requires={ + "sql": [ + "alchimia", + "passlib", + "bcrypt", + ], + }, keywords="twisted flask werkzeug web", license="MIT", name="klein", - packages=["klein", "klein.storage", "klein.test"], + packages=["klein", "klein.storage", "klein.test", "klein.storage.test"], package_dir={"": "src"}, package_data=dict( klein=["py.typed"], diff --git a/src/klein/interfaces.py b/src/klein/interfaces.py index 481799e8d..e634c8dba 100644 --- a/src/klein/interfaces.py +++ b/src/klein/interfaces.py @@ -17,6 +17,55 @@ ) +if TYPE_CHECKING: # pragma: no cover + from typing import Union + + from ._dihttp import RequestComponent, RequestURL + from ._form import Field, FieldInjector, RenderableFormParam + from ._isession import IRequestLifecycleT as _IRequestLifecycleT + from ._session import Authorization, SessionProcurer + from ._storage.memory import MemorySession, MemorySessionStore + from ._storage.sql import ( + AccountSessionBinding, + IPTrackingProcurer, + SessionStore, + SQLAccount, + ) + + ISessionStore = Union[_ISessionStore, MemorySessionStore, SessionStore] + ISessionProcurer = Union[ + _ISessionProcurer, SessionProcurer, IPTrackingProcurer + ] + ISession = Union[_ISession, MemorySession] + ISimpleAccount = Union[_ISimpleAccount, SQLAccount] + ISimpleAccountBinding = Union[_ISimpleAccountBinding, AccountSessionBinding] + IDependencyInjector = Union[ + _IDependencyInjector, + Authorization, + RenderableFormParam, + FieldInjector, + RequestURL, + RequestComponent, + ] + IRequiredParameter = Union[ + _IRequiredParameter, + Authorization, + Field, + RenderableFormParam, + RequestURL, + RequestComponent, + ] + IRequestLifecycle = _IRequestLifecycleT +else: + ISession = _ISession + ISessionStore = _ISessionStore + ISimpleAccount = _ISimpleAccount + ISessionProcurer = _ISessionProcurer + ISimpleAccountBinding = _ISimpleAccountBinding + IDependencyInjector = _IDependencyInjector + IRequiredParameter = _IRequiredParameter + IRequestLifecycle = _IRequestLifecycle + __all__ = ( "EarlyExit", "IDependencyInjector", diff --git a/src/klein/storage/_istorage.py b/src/klein/storage/_istorage.py new file mode 100644 index 000000000..2bfb1272f --- /dev/null +++ b/src/klein/storage/_istorage.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING + +from zope.interface import Attribute, Interface + +from .._typing import ifmethod + + +if TYPE_CHECKING: # pragma: no cover + from twisted.internet.defer import Deferred + + from ..interfaces import ISession, ISessionStore + from ._sql_generic import Transaction + + ISession, ISessionStore, Deferred, Transaction + + +class ISQLAuthorizer(Interface): + """ + An add-on for an L{AlchimiaDataStore} that can populate data on an Alchimia + session. + """ + + authorizationInterface = Attribute( + """ + The interface or class for which a session can be authorized by this + L{ISQLAuthorizer}. + """ + ) + + @ifmethod + def authorizationForSession(sessionStore, transaction, session): + # type: (ISessionStore, Transaction, ISession) -> Deferred + """ + Get a data object that the session has access to. + + If necessary, load related data first. + + @param sessionStore: the store where the session is stored. + @type sessionStore: L{ISessionStore} + + @param transaction: The transaction that loaded the session. + @type transaction: L{klein.storage.sql.Transaction} + + @param session: The session that said this data will be attached to. + @type session: L{ISession} + + @return: the object the session is authorized to access + @rtype: a providier of C{self.authorizationInterface}, or a L{Deferred} + firing the same. + """ + # session_store is a _sql.SessionStore but it's not documented as such. diff --git a/src/klein/storage/_security.py b/src/klein/storage/_security.py new file mode 100644 index 000000000..e35321e42 --- /dev/null +++ b/src/klein/storage/_security.py @@ -0,0 +1,93 @@ +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple +from unicodedata import normalize + +from passlib.context import CryptContext + +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue +from twisted.internet.threads import deferToThread + + +if TYPE_CHECKING: # pragma: no cover + str, Callable, Deferred, Optional, Tuple, Any + + +passlibContextWithGoodDefaults = partial(CryptContext, schemes=["bcrypt"]) + + +def _verifyAndUpdate(secret, hash, ctxFactory=passlibContextWithGoodDefaults): + # type: (Text, Text, Callable[[], CryptContext]) -> Deferred + """ + Asynchronous wrapper for L{CryptContext.verify_and_update}. + """ + + @deferToThread + def theWork(): + # type: () -> Tuple[bool, Optional[str]] + return ctxFactory().verify_and_update(secret, hash) + + return theWork + + +def _hashSecret(secret, ctxFactory=passlibContextWithGoodDefaults): + # type: (Text, Callable[[], CryptContext]) -> Deferred + """ + Asynchronous wrapper for L{CryptContext.hash}. + """ + + @deferToThread + def theWork(): + # type: () -> str + return ctxFactory().hash(secret) + + return theWork + + +@inlineCallbacks +def checkAndReset(storedPasswordText, providedPasswordText, resetter): + # type: (Text, Text, Callable[[str], Any]) -> Any + """ + Check the given stored password text against the given provided password + text. + + @param storedPasswordText: opaque (text) from the account database. + @type storedPasswordText: L{unicode} + + @param providedPasswordText: the plain-text password provided by the + user. + @type providedPasswordText: L{unicode} + + @return: L{Deferred} firing with C{True} if the password matches and + C{False} if the password does not match. + """ + providedPasswordText = normalize("NFD", providedPasswordText) + valid, newHash = yield _verifyAndUpdate( + providedPasswordText, storedPasswordText + ) + if valid: + # Password migration! Does our passlib context have an awesome *new* + # hash it wants to give us? Store it. + if newHash is not None: + if isinstance(newHash, bytes): + newHash = newHash.decode("charmap") + yield resetter(newHash) + returnValue(True) + else: + returnValue(False) + + +@inlineCallbacks +def computeKeyText(passwordText): + # type: (Text) -> Any + """ + Compute some text to store for a given plain-text password. + + @param passwordText: The text of a new password, as entered by a user. + + @return: a L{Deferred} firing with L{unicode}. + """ + normalized = normalize("NFD", passwordText) + hashed = yield _hashSecret(normalized) + if isinstance(hashed, bytes): + hashed = hashed.decode("charmap") + return hashed diff --git a/src/klein/storage/_sql.py b/src/klein/storage/_sql.py new file mode 100644 index 000000000..a45a3c64d --- /dev/null +++ b/src/klein/storage/_sql.py @@ -0,0 +1,744 @@ +from binascii import hexlify +from datetime import datetime +from functools import reduce +from os import urandom +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Type, + TypeVar, + cast, +) +from uuid import uuid4 + +import attr +from attr import Factory +from attr.validators import instance_of as an +from sqlalchemy import ( + Boolean, + Column, + DateTime, + ForeignKey, + MetaData, + Table, + Unicode, + UniqueConstraint, + true, +) +from sqlalchemy.exc import IntegrityError +from sqlalchemy.schema import CreateTable +from sqlalchemy.sql.expression import select +from zope.interface import implementer +from zope.interface.interfaces import IInterface + +from twisted.internet.defer import ( + gatherResults, + inlineCallbacks, + maybeDeferred, + returnValue, +) +from twisted.python.compat import unicode + +from .. import SessionProcurer +from ..interfaces import ( + ISession, + ISessionProcurer, + ISessionStore, + ISimpleAccount, + ISimpleAccountBinding, + NoSuchSession, + SessionMechanism, +) +from ._security import checkAndReset, computeKeyText +from ._sql_generic import Transaction, requestBoundTransaction +from .interfaces import ISQLAuthorizer + + +if TYPE_CHECKING: # pragma: no cover + import sqlalchemy + + from twisted.internet.defer import Deferred + from twisted.internet.interfaces import IReactorThreads + from twisted.web.iweb import IRequest + + from ._sql_generic import DataStore + + ( + Any, + Callable, + Deferred, + Type, + Iterable, + IReactorThreads, + str, + List, + sqlalchemy, + Dict, + IRequest, + IInterface, + Optional, + DataStore, + ) + T = TypeVar("T") + + +@implementer(ISession) +@attr.s +class SQLSession: + _sessionStore = attr.ib(type="SessionStore") + identifier = attr.ib(type=str) + isConfidential = attr.ib(type=bool) + authenticatedBy = attr.ib(type=SessionMechanism) + + def authorize(self, interfaces): + # type: (Iterable[IInterface]) -> Any + interfaces = set(interfaces) + result = {} # type: Dict[IInterface, Deferred] + ds = [] # type: List[Deferred] + txn = self._sessionStore._transaction + for a in self._sessionStore._authorizers: + # This should probably do something smart with interface + # priority, checking isOrExtends or something similar. + if a.authorizationInterface in interfaces: + v = maybeDeferred( + a.authorizationForSession, self._sessionStore, txn, self + ) + ds.append(v) + result[a.authorizationInterface] = v + v.addCallback( + lambda value, ai: result.__setitem__(ai, value), + ai=a.authorizationInterface, + ) + + def r(ignored): + # type: (T) -> Dict[str, Any] + return result + + return gatherResults(ds).addCallback(r) + + +@attr.s +class SessionIPInformation: + """ + Information about a session being used from a given IP address. + """ + + id = attr.ib(validator=an(str), type=str) + ip = attr.ib(validator=an(str), type=str) + when = attr.ib(validator=an(datetime), type=datetime) + + +@implementer(ISessionStore) +@attr.s() +class SessionStore: + """ + An implementation of L{ISessionStore} based on a L{DataStore}, that + stores sessions in a SQLAlchemy database. + """ + + _transaction = attr.ib(type=Transaction) + _authorizers = attr.ib(type=List[ISQLAuthorizer], default=Factory(list)) + + def sentInsecurely(self, tokens): + # type: (List[str]) -> Deferred + """ + Tokens have been sent insecurely; delete any tokens expected to be + confidential. + + @param tokens: L{list} of L{str} + + @return: a L{Deferred} that fires when the tokens have been + invalidated. + """ + s = sessionSchema.session + return gatherResults( + [ + self._transaction.execute( + s.delete().where( + (s.c.session_id == token) & (s.c.confidential == true()) + ) + ) + for token in tokens + ] + ) + + @inlineCallbacks + def newSession(self, isConfidential, authenticatedBy): + # type: (bool, SessionMechanism) -> Deferred + identifier = hexlify(urandom(32)).decode("ascii") + s = sessionSchema.session + yield self._transaction.execute( + s.insert().values( + session_id=identifier, + confidential=isConfidential, + ) + ) + returnValue( + SQLSession( + self, + identifier=identifier, + isConfidential=isConfidential, + authenticatedBy=authenticatedBy, + ) + ) + + @inlineCallbacks + def loadSession(self, identifier, isConfidential, authenticatedBy): + # type: (Text, bool, SessionMechanism) -> Deferred + s = sessionSchema.session + result = yield self._transaction.execute( + s.select( + (s.c.session_id == identifier) + & (s.c.confidential == isConfidential) + ) + ) + results = yield result.fetchall() + if not results: + raise NoSuchSession("Session not present in SQL store.") + fetched_identifier = results[0][s.c.session_id] + returnValue( + SQLSession( + self, + identifier=fetched_identifier, + isConfidential=isConfidential, + authenticatedBy=authenticatedBy, + ) + ) + + +@implementer(ISimpleAccountBinding) +@attr.s +class AccountSessionBinding: + """ + (Stateless) binding between an account and a session, so that sessions can + attach to, detach from, . + """ + + _session = attr.ib(type=ISession) + _transaction = attr.ib(type=Transaction) + + def _account(self, accountID, username, email): + # type: (Text, Text, Text) -> SQLAccount + """ + Construct an L{SQLAccount} bound to this plugin & dataStore. + """ + return SQLAccount(self._transaction, accountID, username, email) + + @inlineCallbacks + def createAccount(self, username, email, password): + # type: (Text, Text, Text) -> Any + """ + Create a new account with the given username, email and password. + + @return: an L{Account} if one could be created, L{None} if one could + not be. + """ + computedHash = yield computeKeyText(password) + newAccountID = unicode(uuid4()) + insert = sessionSchema.account.insert().values( + account_id=newAccountID, + username=username, + email=email, + password_blob=computedHash, + ) + try: + yield self._transaction.execute(insert) + except IntegrityError: + returnValue(None) + else: + accountID = newAccountID + account = self._account(accountID, username, email) + returnValue(account) + + @inlineCallbacks + def bindIfCredentialsMatch(self, username, password): + # type: (Text, Text) -> Any + """ + Associate this session with a given user account, if the password + matches. + + @param username: The username input by the user. + @type username: L{text_type} + + @param password: The plain-text password input by the user. + @type password: L{text_type} + + @rtype: L{Deferred} firing with L{IAccount} if we succeeded and L{None} + if we failed. + """ + acc = sessionSchema.account + + result = yield self._transaction.execute( + acc.select(acc.c.username == username) + ) + accountsInfo = yield result.fetchall() + if not accountsInfo: + # no account, bye + returnValue(None) + [row] = accountsInfo + stored_password_text = row[acc.c.password_blob] + accountID = row[acc.c.account_id] + + def reset_password(newPWText): + # type: (Text) -> Any + a = sessionSchema.account + return self._transaction.execute( + a.update(a.c.account_id == accountID).values( + password_blob=newPWText + ) + ) + + if ( + yield checkAndReset(stored_password_text, password, reset_password) + ): + account = self._account( + accountID, row[acc.c.username], row[acc.c.email] + ) + yield account.bindSession(self._session) + returnValue(account) + + @inlineCallbacks + def boundAccounts(self): + # type: () -> Deferred + """ + Retrieve the accounts currently associated with this session. + + @return: L{Deferred} firing with a L{list} of accounts. + """ + ast = sessionSchema.sessionAccount + acc = sessionSchema.account + result = yield ( + yield self._transaction.execute( + ast.join(acc, ast.c.account_id == acc.c.account_id).select( + ast.c.session_id == self._session.identifier, + use_labels=True, + ) + ) + ).fetchall() + returnValue( + [ + self._account( + it[ast.c.account_id], it[acc.c.username], it[acc.c.email] + ) + for it in result + ] + ) + + @inlineCallbacks + def boundSessionInformation(self): + # type: () -> Any + """ + Retrieve information about all sessions attached to the same account + that this session is. + + @return: L{Deferred} firing a L{list} of L{SessionIPInformation} + """ + acs = sessionSchema.sessionAccount + sipt = sessionSchema.sessionIP + + acs2 = acs.alias() + result = yield self._transaction.execute( + select([sipt], use_labels=True).where( + (acs.c.session_id == self._session.identifier) + & (acs.c.account_id == acs2.c.account_id) + & (acs2.c.session_id == sipt.c.session_id) + ) + ) + returnValue( + [ + SessionIPInformation( + id=row[sipt.c.session_id], + ip=row[sipt.c.ip_address], + when=row[sipt.c.last_used], + ) + for row in (yield result.fetchall()) + ] + ) + + def unbindThisSession(self): + # type: () -> Any + """ + Disassociate this session from any accounts it's logged in to. + + @return: a L{Deferred} that fires when the account is logged out. + """ + ast = sessionSchema.sessionAccount + return self._transaction.execute( + ast.delete(ast.c.session_id == self._session.identifier) + ) + + +@implementer(ISimpleAccount) +@attr.s +class SQLAccount: + """ + An implementation of L{ISimpleAccount} backed by an Alchimia data store. + """ + + _transaction = attr.ib(type=Transaction) + accountID = attr.ib(type=str) + username = attr.ib(type=str) + email = attr.ib(type=str) + + def bindSession(self, session): + # type: (ISession) -> Deferred + """ + Add a session to the database. + """ + return self._transaction.execute( + sessionSchema.sessionAccount.insert().values( + account_id=self.accountID, session_id=session.identifier + ) + ) + + @inlineCallbacks + def changePassword(self, newPassword): + # type: (Text) -> Any + """ + @param newPassword: The text of the new password. + @type newPassword: L{unicode} + """ + computedHash = yield computeKeyText(newPassword) + result = yield self._transaction.execute( + sessionSchema.account.update() + .where(account_id=self.accountID) + .values(password_blob=computedHash) + ) + returnValue(result) + + +@inlineCallbacks +def upsert( + engine, # type: Transaction + table, # type: sqlalchemy.schema.Table + to_query, # type: Dict[str, Any] + to_change, # type: Dict[str, Any] +): + # type: (...) -> Any + """ + Try inserting, if inserting fails, then update. + """ + try: + result = yield engine.execute( + table.insert().values(**dict(to_query, **to_change)) + ) + except IntegrityError: + from operator import and_ as And + + update = ( + table.update() + .where( + reduce( + And, + ( + (getattr(table.c, cname) == cvalue) + for (cname, cvalue) in to_query.items() + ), + ) + ) + .values(**to_change) + ) + result = yield engine.execute(update) + returnValue(result) + + +@attr.s +class SessionSchema: + """ + Schema for SQL session features. + + This is exposed as public API so that you can have tables which relate + against it in your own code, and integrate with your schema management + system. + + However, while Klein uses Alchimia itself, it does not want to be in the + business of managing your schema migrations or your database access. As + such, this class exposes the schema in several formats: + + - via SQLAlchemy metadata, if you want to use something like Alembic or + SQLAlchemy-Migrate + + - via a single SQL string, if you manage your SQL migrations manually + """ + + session = attr.ib(type=Table) + account = attr.ib(type=Table) + sessionAccount = attr.ib(type=Table) + sessionIP = attr.ib(type=Table) + + @classmethod + def withMetadata(cls, metadata=None): + # type: (Optional[MetaData]) -> SessionSchema + """ + Create a new L{SQLSessionSchema} with the given metadata, defaulting to + new L{MetaData} if none is supplied. + """ + if metadata is None: + metadata = MetaData() + session = Table( + "session", + metadata, + Column("session_id", Unicode(), primary_key=True, nullable=False), + Column("confidential", Boolean(), nullable=False), + ) + account = Table( + "account", + metadata, + Column("account_id", Unicode(), primary_key=True, nullable=False), + Column("username", Unicode(), unique=True, nullable=False), + Column("email", Unicode(), nullable=False), + Column("password_blob", Unicode(), nullable=False), + ) + sessionAccount = Table( + "session_account", + metadata, + Column( + "account_id", + Unicode(), + ForeignKey(account.c.account_id, ondelete="CASCADE"), + ), + Column( + "session_id", + Unicode(), + ForeignKey(session.c.session_id, ondelete="CASCADE"), + ), + UniqueConstraint("account_id", "session_id"), + ) + sessionIP = Table( + "session_ip", + metadata, + Column( + "session_id", + Unicode(), + ForeignKey(session.c.session_id, ondelete="CASCADE"), + ), + Column("ip_address", Unicode(), nullable=False), + Column("address_family", Unicode(), nullable=False), + Column("last_used", DateTime(), nullable=False), + UniqueConstraint("session_id", "ip_address", "address_family"), + ) + return cls(session, account, sessionAccount, sessionIP) + + def tables(self): + # type: () -> Iterable[Table] + """ + Yield all tables that need to be created in order for sessions to be + enabled in a SQLAlchemy database, in the order they need to be created. + """ + yield self.session + yield self.account + yield self.sessionAccount + yield self.sessionIP + + @inlineCallbacks + def create(self, transaction): + # type: (Transaction) -> Deferred + """ + Given a L{Transaction}, create this schema in the database and return a + L{Deferred} that fires with C{None} when done. + + This method will handle any future migration concerns. + """ + for table in self.tables(): + yield transaction.execute(CreateTable(table)) + + def migrationSQL(self): + # type: () -> Text + """ + Return some SQL to run in order to create the tables necessary for the + SQL session and account store. Currently there is only one version, + but in the future, sections of this will be clearly delineated by '-- + Klein Session Schema Version X' comments. + + This SQL will not attempt to discern whether the tables exist already + or whether the migrations should be run. + """ + return "\n-- Klein Session Schema Version 1\n" + ( + ";".join(str(CreateTable(table)) for table in self.tables()) + ) + + +sessionSchema = SessionSchema.withMetadata(MetaData()) + +procurerFromTransactionT = Callable[[Transaction], ISessionProcurer] + + +@implementer(ISessionProcurer) +class IPTrackingProcurer: + """ + An implementation of L{ISessionProcurer} that keeps track of the source IP + of the originating session. + """ + + def __init__( + self, + dataStore, # type: DataStore + procurerFromTransaction, # type: procurerFromTransactionT + ): + # type: (...) -> None + """ + Create an L{IPTrackingProcurer} from SQLAlchemy metadata, an alchimia + data store, and an existing L{ISessionProcurer}. + """ + self._dataStore = dataStore + self._procurerFromTransaction = procurerFromTransaction + + @inlineCallbacks + def procureSession(self, request, forceInsecure=False): + # type: (IRequest, bool) -> Deferred + """ + Procure a session from the underlying procurer, keeping track of the IP + of the request object. + """ + alreadyProcured = request.getComponent(ISession) + if alreadyProcured is not None: + returnValue(alreadyProcured) + # if getattr(request, 'requesting', False): + # raise RuntimeError("what are you doing!?") + # request.requesting = True + transaction = yield requestBoundTransaction(request, self._dataStore) + procurer = yield self._procurerFromTransaction(transaction) + session = yield procurer.procureSession(request, forceInsecure) + try: + ipAddress = (request.client.host or b"").decode("ascii") + except BaseException: + ipAddress = "" + sip = sessionSchema.sessionIP + yield upsert( + transaction, + sip, + dict( + session_id=session.identifier, + ip_address=ipAddress, + address_family=("AF_INET6" if ":" in ipAddress else "AF_INET"), + ), + dict(last_used=datetime.utcnow()), + ) + # XXX This should set a savepoint because we don't want application + # logic to be able to roll back the IP access log. + returnValue(session) + + +procurerFromStoreT = Callable[[ISessionStore], ISessionProcurer] + + +def procurerFromDataStore( + dataStore, # type: DataStore + authorizers, # type: List[ISQLAuthorizer] + procurerFromStore=SessionProcurer, # type: procurerFromStoreT +): + # type: (...) -> ISessionProcurer + """ + Open a session store, returning a procurer that can procure sessions from + it. + + @param databaseURL: an SQLAlchemy database URL. + + @param procurerFromStore: A callable that takes an L{ISessionStore} and + returns an L{ISessionProcurer}. + + @return: L{Deferred} firing with L{ISessionProcurer} + """ + allAuthorizers = [ + simpleAccountBinding.authorizer, + logMeIn.authorizer, + ] + list(authorizers) + return IPTrackingProcurer( + dataStore, + lambda transaction: procurerFromStore( + SessionStore(transaction, allAuthorizers) + ), + ) + + +class _FunctionWithAuthorizer: + authorizer = None # type: Any + + def __call__( + self, + sessionStore, # type: SessionStore + transaction, # type: Transaction + session, # type: ISession + ): + # type: (...) -> Any + """ + Signature for a function that can have an authorizer attached to it. + """ + + +_authorizerFunction = Callable[[SessionStore, Transaction, ISession], Any] + + +@implementer(ISQLAuthorizer) +@attr.s +class SimpleSQLAuthorizer: + authorizationInterface = attr.ib(type=Type) + _decorated = attr.ib(type=_authorizerFunction) + + def authorizationForSession(self, sessionStore, transaction, session): + # type: (SessionStore, Transaction, ISession) -> Any + cb = cast(_authorizerFunction, self._decorated) # type: ignore + return cb(sessionStore, transaction, session) + + +def authorizerFor( + authorizationInterface, # type: IInterface +): + # type: (...) -> Callable[[Callable], _FunctionWithAuthorizer] + """ + Declare an SQL authorizer, implemented by a given function. Used like so:: + + @authorizerFor(Foo, tables(foo=[Column("bar", Unicode())])) + def authorizeFoo(dataStore, sessionStore, transaction, session): + return Foo(metadata, metadata.tables["foo"]) + + @param authorizationInterface: The type we are creating an authorizer for. + + @return: a decorator that can decorate a function with the signature + C{(metadata, dataStore, sessionStore, transaction, session)} + """ + + def decorator(decorated): + # type: (_authorizerFunction) -> _FunctionWithAuthorizer + result = cast(_FunctionWithAuthorizer, decorated) + result.authorizer = SimpleSQLAuthorizer( + authorizationInterface, decorated + ) + return result + + return decorator + + +@authorizerFor(ISimpleAccountBinding) +def simpleAccountBinding( + sessionStore, # type: SessionStore + transaction, # type: Transaction + session, # type: ISession +): + # type: (...) -> AccountSessionBinding + """ + All sessions are authorized for access to an L{ISimpleAccountBinding}. + """ + return AccountSessionBinding(session, transaction) + + +@authorizerFor(ISimpleAccount) +@inlineCallbacks +def logMeIn( + sessionStore, # type: SessionStore + transaction, # type: Transaction + session, # type: ISession +): + # type: (...) -> Deferred + """ + Retrieve an L{ISimpleAccount} authorization. + """ + binding = (yield session.authorize([ISimpleAccountBinding]))[ + ISimpleAccountBinding + ] + returnValue(next(iter((yield binding.boundAccounts())), None)) diff --git a/src/klein/storage/_sql_generic.py b/src/klein/storage/_sql_generic.py new file mode 100644 index 000000000..44f9b1eea --- /dev/null +++ b/src/klein/storage/_sql_generic.py @@ -0,0 +1,379 @@ +""" +Generic SQL data storage stuff; the substrate for session-storage stuff. +""" + +from collections import deque +from sys import exc_info +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +import attr +from alchimia import TWISTED_STRATEGY +from attr import Factory +from sqlalchemy import create_engine +from zope.interface import Interface, implementer + +from twisted.internet.defer import ( + Deferred, + gatherResults, + inlineCallbacks, + returnValue, + succeed, +) + +from ..interfaces import TransactionEnded + + +_sqlAlchemyConnection = Any +_sqlAlchemyTransaction = Any + +COMMITTING = "committing" +COMMITTED = "committed" +COMMIT_FAILED = "commit failed" +ROLLING_BACK = "rolling back" +ROLLED_BACK = "rolled back" +ROLLBACK_FAILED = "rollback failed" + +if TYPE_CHECKING: # pragma: no cover + T = TypeVar("T") + from twisted.internet.interfaces import IReactorThreads + + IReactorThreads + from typing import Iterable + + Iterable + from typing import Callable + + Callable + from twisted.web.iweb import IRequest + + IRequest + + +@attr.s +class Transaction: + """ + Wrapper around a SQLAlchemy connection which is invalidated when the + transaction is committed or rolled back. + """ + + _connection = attr.ib(type=_sqlAlchemyConnection) + _transaction = attr.ib(type=_sqlAlchemyTransaction) + _parent = attr.ib(type="Optional[Transaction]", default=None) + _stopped = attr.ib(type=str, default="") + _completeDeferred = attr.ib(type=Deferred, default=Factory(Deferred)) + + def _checkStopped(self): + # type: () -> None + """ + Raise an exception if the transaction has been stopped for any reason. + """ + if self._stopped: + raise TransactionEnded(self._stopped) + if self._parent is not None: + self._parent._checkStopped() + + def execute(self, statement, *multiparams, **params): + # type: (Any, *Any, **Any) -> Deferred + """ + Execute a statement unless this transaction has been stopped, otherwise + raise L{TransactionEnded}. + """ + self._checkStopped() + return self._connection.execute(statement, *multiparams, **params) + + def commit(self): + # type: () -> Deferred + """ + Commit this transaction. + """ + self._checkStopped() + self._stopped = COMMITTING + return self._transaction.commit().addCallbacks( + (lambda commitResult: self._finishWith(COMMITTED)), + (lambda commitFailure: self._finishWith(COMMIT_FAILED)), + ) + + def rollback(self): + # type: () -> Deferred + """ + Roll this transaction back. + """ + self._checkStopped() + self._stopped = ROLLING_BACK + return self._transaction.rollback().addCallbacks( + (lambda commitResult: self._finishWith(ROLLED_BACK)), + (lambda commitFailure: self._finishWith(ROLLBACK_FAILED)), + ) + + def _finishWith(self, stopStatus): + # type: (Text) -> None + """ + Complete this transaction. + """ + self._stopped = stopStatus + self._completeDeferred.callback(stopStatus) + + @inlineCallbacks + def savepoint(self): + # type: () -> Deferred + """ + Create a L{Savepoint} which can be treated as a sub-transaction. + + @note: As long as this L{Savepoint} has not been rolled back or + committed, this transaction's C{execute} method will execute within + the context of that savepoint. + """ + returnValue( + Transaction( + self._connection, (yield self._connection.begin_nested()), self + ) + ) + + def subtransact(self, logic): + # type: (Callable[[Transaction], Deferred]) -> Deferred + """ + Run the given C{logic} in a subtransaction. + """ + return Transactor(self.savepoint).transact(logic) + + def maybeCommit(self): + # type: () -> Deferred + """ + Commit this transaction if it hasn't been finished (committed or rolled + back) yet; otherwise, do nothing. + """ + if self._stopped: + return succeed(None) + return self.commit() + + def maybeRollback(self): + # type: () -> Deferred + """ + Roll this transaction back if it hasn't been finished (committed or + rolled back) yet; otherwise, do nothing. + """ + if self._stopped: + return succeed(None) + return self.rollback() + + +@attr.s +class Transactor: + """ + A context manager that represents the lifecycle of a transaction when + paired with application code. + """ + + _newTransaction = attr.ib(type="Callable[[], Deferred]") + _transaction = attr.ib(type=Optional[Transaction], default=None) + + @inlineCallbacks + def __aenter__(self): + # type: () -> Deferred + """ + Start a transaction. + """ + self._transaction = yield self._newTransaction() + # ^ https://github.com/python/mypy/issues/4688 + returnValue(self._transaction) + + @inlineCallbacks + def __aexit__(self, exc_type, exc_value, traceback): + # type: (type, Exception, Any) -> Deferred + """ + End a transaction. + """ + assert self._transaction is not None + if exc_type is None: + yield self._transaction.commit() + else: + yield self._transaction.rollback() + self._transaction = None + + @inlineCallbacks + def transact(self, logic): + # type: (Callable) -> Deferred + """ + Run the given logic within this L{TransactionContext}, starting and + stopping as usual. + """ + try: + transaction = yield self.__aenter__() + result = yield logic(transaction) + finally: + yield self.__aexit__(*exc_info()) + returnValue(result) + + +@attr.s(hash=False) +class DataStore: + """ + L{DataStore} is a generic storage object that connect to an SQL + database, run transactions, and manage schema metadata. + """ + + _engine = attr.ib(type=_sqlAlchemyConnection) + _freeConnections = attr.ib(default=Factory(deque), type=deque) + + @inlineCallbacks + def newTransaction(self): + # type: () -> Deferred + """ + Create a new Klein transaction. + """ + alchimiaConnection = ( + self._freeConnections.popleft() + if self._freeConnections + else (yield self._engine.connect()) + ) + alchimiaTransaction = yield alchimiaConnection.begin() + kleinTransaction = Transaction(alchimiaConnection, alchimiaTransaction) + + @kleinTransaction._completeDeferred.addBoth + def recycleTransaction(anything): + # type: (T) -> T + self._freeConnections.append(alchimiaConnection) + return anything + + returnValue(kleinTransaction) + + def transact(self, callable): + # type: (Callable[[Transaction], Any]) -> Any + """ + Run the given C{callable} within a transaction. + + @param callable: A callable object that encapsulates application logic + that needs to run in a transaction. + @type callable: callable taking a L{Transaction} and returning a + L{Deferred}. + + @return: a L{Deferred} firing with the result of C{callable} + @rtype: L{Deferred} that fires when the transaction is complete, or + fails when the transaction is rolled back. + """ + return Transactor(self.newTransaction).transact(callable) + + @classmethod + def open(cls, reactor, dbURL): + # type: (IReactorThreads, Text) -> DataStore + """ + Open an L{DataStore}. + + @param reactor: the reactor that this store should be opened on. + @type reactor: L{IReactorThreads} + + @param dbURL: the SQLAlchemy database URI to connect to. + @type dbURL: L{str} + """ + return cls( + create_engine(dbURL, reactor=reactor, strategy=TWISTED_STRATEGY) + ) + + +class ITransactionRequestAssociator(Interface): + """ + Associates transactions with requests. + """ + + +@implementer(ITransactionRequestAssociator) +@attr.s +class TransactionRequestAssociator: + """ + Does the thing the interface says. + """ + + _map = attr.ib(type=dict, default=Factory(dict)) + committing = attr.ib(type=bool, default=False) + + @inlineCallbacks + def transactionForStore(self, dataStore): + # type: (DataStore) -> Deferred + """ + Get a transaction for the given datastore. + """ + if dataStore in self._map: + returnValue(self._map[dataStore]) + txn = yield dataStore.newTransaction() + self._map[dataStore] = txn + returnValue(txn) + + def commitAll(self): + # type: () -> Deferred + """ + Commit all associated transactions. + """ + self.committing = True + return gatherResults( + [value.maybeCommit() for value in self._map.values()] + ) + + +@inlineCallbacks +def requestBoundTransaction(request, dataStore): + # type: (IRequest, DataStore) -> Deferred + """ + Retrieve a transaction that is bound to the lifecycle of the given request. + + There are three use-cases for this lifecycle: + + 1. 'normal CRUD' - a request begins, a transaction is associated with + it, and the transaction completes when the request completes. The + appropriate time to commit the transaction is the moment before the + first byte goes out to the client. The appropriate moment to + interpose this commit is in `Request.write`, at the moment where + it's about to call channel.writeHeaders, since the HTTP status code + should be an indicator of whether the transaction succeeded or + failed. + + 2. 'just the session please' - a request begins, a transaction is + associated with it in order to discover the session, and the + application code in question isn't actually using the database. + (Ideally as expressed through "the dependency-declaration decorator, + such as @authorized, did not indicate that a transaction will be + required"). + + 3. 'fancy API stuff' - a request begins, a transaction is associated + with it in order to discover the session, the application code needs + to then do I{something} with that transaction in-line with the + session discovery, but then needs to commit in order to relinquish + all database locks while doing some potentially slow external API + calls, then start a I{new} transaction later in the request flow. + """ + assoc = request.getComponent(ITransactionRequestAssociator) + if assoc is None: + assoc = TransactionRequestAssociator() + request.setComponent(ITransactionRequestAssociator, assoc) + + def finishCommit(result): + # type: (Any) -> Deferred + return assoc.commitAll() + + request.notifyFinish().addBoth(finishCommit) + + # originalWrite = request.write + # buffer = [] + # def committed(result): + # for buf in buffer: + # if buf is None: + # originalFinish() + # else: + # originalWrite(buf) + + # def maybeWrite(data): + # if request.startedWriting: + # return originalWrite(data) + # buffer.append(data) + # if assoc.committing: + # return + # assoc.commitAll().addBoth(committed) + # def maybeFinish(): + # if not request.startedWriting: + # buffer.append(None) + # else: + # originalFinish() + # originalFinish = request.finish + # request.write = maybeWrite + # request.finish = maybeFinish + txn = yield assoc.transactionForStore(dataStore) + return txn diff --git a/src/klein/storage/interfaces.py b/src/klein/storage/interfaces.py new file mode 100644 index 000000000..fa083f677 --- /dev/null +++ b/src/klein/storage/interfaces.py @@ -0,0 +1,13 @@ +from typing import TYPE_CHECKING, Union + +from ._istorage import ISQLAuthorizer as _ISQLAuthorizer + + +if TYPE_CHECKING: # pragma: no cover + from ._sql import SimpleSQLAuthorizer + + ISQLAuthorizer = Union[_ISQLAuthorizer, SimpleSQLAuthorizer] +else: + ISQLAuthorizer = _ISQLAuthorizer + +__all__ = ["ISQLAuthorizer"] diff --git a/src/klein/storage/sql.py b/src/klein/storage/sql.py new file mode 100644 index 000000000..62288a6c8 --- /dev/null +++ b/src/klein/storage/sql.py @@ -0,0 +1,16 @@ +from ._sql import SessionSchema, authorizerFor, procurerFromDataStore +from ._sql_generic import DataStore, Transaction + + +__all__ = [ + "procurerFromDataStore", + "authorizerFor", + "SessionSchema", + "DataStore", + "Transaction", +] + +if __name__ == "__main__": + import sys + + sys.stdout.write(SessionSchema.withMetadata().migrationSQL()) diff --git a/src/klein/storage/test/test_security.py b/src/klein/storage/test/test_security.py new file mode 100644 index 000000000..60c73641e --- /dev/null +++ b/src/klein/storage/test/test_security.py @@ -0,0 +1,12 @@ +from twisted.trial.unittest import TestCase + +from klein.storage import security + + +class SQLTests(TestCase): + def test_security(self): + # type: () -> None + """ + Add tests. + """ + security diff --git a/src/klein/storage/test/test_sql.py b/src/klein/storage/test/test_sql.py new file mode 100644 index 000000000..040b5960a --- /dev/null +++ b/src/klein/storage/test/test_sql.py @@ -0,0 +1,12 @@ +from twisted.trial.unittest import TestCase + +from klein.storage import sql + + +class SQLTests(TestCase): + def test_sql(self): + # type: () -> None + """ + Add tests. + """ + sql