From 8fb9a1cc88e98ab4308f3d16071f0af6e6d84b7c Mon Sep 17 00:00:00 2001 From: Glyph Date: Sat, 1 Jun 2019 02:30:09 -0700 Subject: [PATCH 1/6] Revert "remove untested SQL/Alchimia backend implementation" This reverts commit b8efc966a499ad9237648a97cc5f88ac130b077d. --- setup.py | 9 +- src/klein/interfaces.py | 13 +- src/klein/storage/_istorage.py | 49 ++ src/klein/storage/_security.py | 89 ++++ src/klein/storage/_sql.py | 668 ++++++++++++++++++++++++ src/klein/storage/_sql_generic.py | 373 +++++++++++++ src/klein/storage/interfaces.py | 14 + src/klein/storage/sql.py | 19 + src/klein/storage/test/test_security.py | 13 + src/klein/storage/test/test_sql.py | 12 + 10 files changed, 1254 insertions(+), 5 deletions(-) create mode 100644 src/klein/storage/_istorage.py create mode 100644 src/klein/storage/_security.py create mode 100644 src/klein/storage/_sql.py create mode 100644 src/klein/storage/_sql_generic.py create mode 100644 src/klein/storage/interfaces.py create mode 100644 src/klein/storage/sql.py create mode 100644 src/klein/storage/test/test_security.py create mode 100644 src/klein/storage/test/test_sql.py diff --git a/setup.py b/setup.py index be7a28642..c74409160 100644 --- a/setup.py +++ b/setup.py @@ -38,11 +38,18 @@ "Werkzeug", "zope.interface", ], + extras_require={ + "sql": [ + "alchimia", + "passlib", + "bcrypt", + ] + }, keywords="twisted flask werkzeug web", license="MIT", name="klein", packages=["klein", "klein.storage", - "klein.test"], + "klein.test", "klein.storage.test"], package_dir={"": "src"}, package_data=dict( klein=[ diff --git a/src/klein/interfaces.py b/src/klein/interfaces.py index e40218bd7..16f2f596e 100644 --- a/src/klein/interfaces.py +++ b/src/klein/interfaces.py @@ -25,6 +25,8 @@ if TYPE_CHECKING: # pragma: no cover from ._storage.memory import MemorySessionStore, MemorySession + from ._storage.sql import (SessionStore, SQLAccount, IPTrackingProcurer, + AccountSessionBinding) from ._session import SessionProcurer, Authorization from ._form import Field, RenderableFormParam, FieldInjector from ._isession import IRequestLifecycleT as _IRequestLifecycleT @@ -32,11 +34,14 @@ from typing import Union - ISessionStore = Union[_ISessionStore, MemorySessionStore] - ISessionProcurer = Union[_ISessionProcurer, SessionProcurer] + ISessionStore = Union[_ISessionStore, MemorySessionStore, + SessionStore] + ISessionProcurer = Union[_ISessionProcurer, SessionProcurer, + IPTrackingProcurer] ISession = Union[_ISession, MemorySession] - ISimpleAccount = _ISimpleAccount - ISimpleAccountBinding = _ISimpleAccountBinding + ISimpleAccount = Union[_ISimpleAccount, SQLAccount] + ISimpleAccountBinding = Union[_ISimpleAccountBinding, + AccountSessionBinding] IDependencyInjector = Union[_IDependencyInjector, Authorization, RenderableFormParam, FieldInjector, RequestURL, RequestComponent] diff --git a/src/klein/storage/_istorage.py b/src/klein/storage/_istorage.py new file mode 100644 index 000000000..28f5b32d8 --- /dev/null +++ b/src/klein/storage/_istorage.py @@ -0,0 +1,49 @@ + +from typing import TYPE_CHECKING + +from zope.interface import Attribute, Interface + +from .._typing import ifmethod + +if TYPE_CHECKING: # pragma: no cover + from twisted.internet.defer import Deferred + from ..interfaces import ISessionStore, ISession + from ._sql_generic import Transaction + ISession, ISessionStore, Deferred, Transaction + + +class ISQLAuthorizer(Interface): + """ + An add-on for an L{AlchimiaDataStore} that can populate data on an Alchimia + session. + """ + + authorizationInterface = Attribute( + """ + The interface or class for which a session can be authorized by this + L{ISQLAuthorizer}. + """ + ) + + @ifmethod + def authorizationForSession(sessionStore, transaction, session): + # type: (ISessionStore, Transaction, ISession) -> Deferred + """ + Get a data object that the session has access to. + + If necessary, load related data first. + + @param sessionStore: the store where the session is stored. + @type sessionStore: L{ISessionStore} + + @param transaction: The transaction that loaded the session. + @type transaction: L{klein.storage.sql.Transaction} + + @param session: The session that said this data will be attached to. + @type session: L{ISession} + + @return: the object the session is authorized to access + @rtype: a providier of C{self.authorizationInterface}, or a L{Deferred} + firing the same. + """ + # session_store is a _sql.SessionStore but it's not documented as such. diff --git a/src/klein/storage/_security.py b/src/klein/storage/_security.py new file mode 100644 index 000000000..c3ff2fdef --- /dev/null +++ b/src/klein/storage/_security.py @@ -0,0 +1,89 @@ + +from functools import partial +from typing import Any, Callable, Optional, TYPE_CHECKING, Text, Tuple +from unicodedata import normalize + +from passlib.context import CryptContext + +from twisted.internet.defer import Deferred, inlineCallbacks, returnValue +from twisted.internet.threads import deferToThread + +if TYPE_CHECKING: # pragma: no cover + Text, Callable, Deferred, Optional, Tuple, Any + + +passlibContextWithGoodDefaults = partial(CryptContext, schemes=['bcrypt']) + +def _verifyAndUpdate(secret, hash, ctxFactory=passlibContextWithGoodDefaults): + # type: (Text, Text, Callable[[], CryptContext]) -> Deferred + """ + Asynchronous wrapper for L{CryptContext.verify_and_update}. + """ + @deferToThread + def theWork(): + # type: () -> Tuple[bool, Optional[str]] + return ctxFactory().verify_and_update(secret, hash) + return theWork + + +def _hashSecret(secret, ctxFactory=passlibContextWithGoodDefaults): + # type: (Text, Callable[[], CryptContext]) -> Deferred + """ + Asynchronous wrapper for L{CryptContext.hash}. + """ + @deferToThread + def theWork(): + # type: () -> str + return ctxFactory().hash(secret) + return theWork + + + +@inlineCallbacks +def checkAndReset(storedPasswordText, providedPasswordText, resetter): + # type: (Text, Text, Callable[[str], Any]) -> Any + """ + Check the given stored password text against the given provided password + text. + + @param storedPasswordText: opaque (text) from the account database. + @type storedPasswordText: L{unicode} + + @param providedPasswordText: the plain-text password provided by the + user. + @type providedPasswordText: L{unicode} + + @return: L{Deferred} firing with C{True} if the password matches and + C{False} if the password does not match. + """ + providedPasswordText = normalize('NFD', providedPasswordText) + valid, newHash = yield _verifyAndUpdate(providedPasswordText, + storedPasswordText) + if valid: + # Password migration! Does our passlib context have an awesome *new* + # hash it wants to give us? Store it. + if newHash is not None: + if isinstance(newHash, bytes): + newHash = newHash.decode("charmap") + yield resetter(newHash) + returnValue(True) + else: + returnValue(False) + + + +@inlineCallbacks +def computeKeyText(passwordText): + # type: (Text) -> Any + """ + Compute some text to store for a given plain-text password. + + @param passwordText: The text of a new password, as entered by a user. + + @return: a L{Deferred} firing with L{unicode}. + """ + normalized = normalize('NFD', passwordText) + hashed = yield _hashSecret(normalized) + if isinstance(hashed, bytes): + hashed = hashed.decode("charmap") + return hashed diff --git a/src/klein/storage/_sql.py b/src/klein/storage/_sql.py new file mode 100644 index 000000000..be79d4103 --- /dev/null +++ b/src/klein/storage/_sql.py @@ -0,0 +1,668 @@ +from binascii import hexlify +from datetime import datetime +from functools import reduce +from os import urandom +from typing import ( + Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Text, + Type, TypeVar, cast +) +from uuid import uuid4 + +import attr +from attr import Factory +from attr.validators import instance_of as an + +from six import text_type + +from sqlalchemy import ( + Boolean, Column, DateTime, ForeignKey, MetaData, Table, + Unicode, UniqueConstraint, true +) +from sqlalchemy.exc import IntegrityError +from sqlalchemy.schema import CreateTable +from sqlalchemy.sql.expression import select + +from twisted.internet.defer import ( + gatherResults, inlineCallbacks, maybeDeferred, returnValue +) +from twisted.python.compat import unicode + +from zope.interface import implementer +from zope.interface.interfaces import IInterface + +from ._security import checkAndReset, computeKeyText +from ._sql_generic import Transaction, requestBoundTransaction +from .interfaces import ISQLAuthorizer +from .. import SessionProcurer +from ..interfaces import ( + ISession, ISessionProcurer, ISessionStore, ISimpleAccount, + ISimpleAccountBinding, NoSuchSession, SessionMechanism +) + +if TYPE_CHECKING: # pragma: no cover + import sqlalchemy + from twisted.internet.defer import Deferred + from twisted.internet.interfaces import IReactorThreads + from twisted.web.iweb import IRequest + from ._sql_generic import DataStore + (Any, Callable, Deferred, Type, Iterable, IReactorThreads, Text, + List, sqlalchemy, Dict, IRequest, IInterface, Optional, DataStore) + T = TypeVar('T') + +@implementer(ISession) +@attr.s +class SQLSession(object): + _sessionStore = attr.ib(type='SessionStore') + identifier = attr.ib(type=Text) + isConfidential = attr.ib(type=bool) + authenticatedBy = attr.ib(type=SessionMechanism) + + def authorize(self, interfaces): + # type: (Iterable[IInterface]) -> Any + interfaces = set(interfaces) + result = {} # type: Dict[IInterface, Deferred] + ds = [] # type: List[Deferred] + txn = self._sessionStore._transaction + for a in self._sessionStore._authorizers: + # This should probably do something smart with interface + # priority, checking isOrExtends or something similar. + if a.authorizationInterface in interfaces: + v = maybeDeferred(a.authorizationForSession, + self._sessionStore, txn, self) + ds.append(v) + result[a.authorizationInterface] = v + v.addCallback( + lambda value, ai: result.__setitem__(ai, value), + ai=a.authorizationInterface + ) + + def r(ignored): + # type: (T) -> Dict[str, Any] + return result + return (gatherResults(ds).addCallback(r)) + + + +@attr.s +class SessionIPInformation(object): + """ + Information about a session being used from a given IP address. + """ + id = attr.ib(validator=an(text_type), type=Text) + ip = attr.ib(validator=an(text_type), type=Text) + when = attr.ib(validator=an(datetime), type=datetime) + +@implementer(ISessionStore) +@attr.s() +class SessionStore(object): + """ + An implementation of L{ISessionStore} based on a L{DataStore}, that + stores sessions in a SQLAlchemy database. + """ + + _transaction = attr.ib(type=Transaction) + _authorizers = attr.ib(type=List[ISQLAuthorizer], default=Factory(list)) + + def sentInsecurely(self, tokens): + # type: (List[str]) -> Deferred + """ + Tokens have been sent insecurely; delete any tokens expected to be + confidential. + + @param tokens: L{list} of L{str} + + @return: a L{Deferred} that fires when the tokens have been + invalidated. + """ + s = sessionSchema.session + return gatherResults([ + self._transaction.execute( + s.delete().where((s.c.session_id == token) & + (s.c.confidential == true())) + ) for token in tokens + ]) + + + @inlineCallbacks + def newSession(self, isConfidential, authenticatedBy): + # type: (bool, SessionMechanism) -> Deferred + identifier = hexlify(urandom(32)).decode('ascii') + s = sessionSchema.session + yield self._transaction.execute(s.insert().values( + session_id=identifier, + confidential=isConfidential, + )) + returnValue(SQLSession(self, + identifier=identifier, + isConfidential=isConfidential, + authenticatedBy=authenticatedBy)) + + + @inlineCallbacks + def loadSession(self, identifier, isConfidential, authenticatedBy): + # type: (Text, bool, SessionMechanism) -> Deferred + s = sessionSchema.session + result = yield self._transaction.execute( + s.select((s.c.session_id == identifier) & + (s.c.confidential == isConfidential))) + results = yield result.fetchall() + if not results: + raise NoSuchSession(u"Session not present in SQL store.") + fetched_identifier = results[0][s.c.session_id] + returnValue(SQLSession(self, + identifier=fetched_identifier, + isConfidential=isConfidential, + authenticatedBy=authenticatedBy)) + + + +@implementer(ISimpleAccountBinding) +@attr.s +class AccountSessionBinding(object): + """ + (Stateless) binding between an account and a session, so that sessions can + attach to, detach from, . + """ + _session = attr.ib(type=ISession) + _transaction = attr.ib(type=Transaction) + + def _account(self, accountID, username, email): + # type: (Text, Text, Text) -> SQLAccount + """ + Construct an L{SQLAccount} bound to this plugin & dataStore. + """ + return SQLAccount(self._transaction, accountID, username, + email) + + + @inlineCallbacks + def createAccount(self, username, email, password): + # type: (Text, Text, Text) -> Any + """ + Create a new account with the given username, email and password. + + @return: an L{Account} if one could be created, L{None} if one could + not be. + """ + computedHash = yield computeKeyText(password) + newAccountID = unicode(uuid4()) + insert = (sessionSchema.account.insert() + .values(account_id=newAccountID, + username=username, email=email, + password_blob=computedHash)) + try: + yield self._transaction.execute(insert) + except IntegrityError: + returnValue(None) + else: + accountID = newAccountID + account = self._account(accountID, username, email) + returnValue(account) + + + @inlineCallbacks + def bindIfCredentialsMatch(self, username, password): + # type: (Text, Text) -> Any + """ + Associate this session with a given user account, if the password + matches. + + @param username: The username input by the user. + @type username: L{text_type} + + @param password: The plain-text password input by the user. + @type password: L{text_type} + + @rtype: L{Deferred} firing with L{IAccount} if we succeeded and L{None} + if we failed. + """ + acc = sessionSchema.account + + result = yield self._transaction.execute( + acc.select(acc.c.username == username) + ) + accountsInfo = (yield result.fetchall()) + if not accountsInfo: + # no account, bye + returnValue(None) + [row] = accountsInfo + stored_password_text = row[acc.c.password_blob] + accountID = row[acc.c.account_id] + + def reset_password(newPWText): + # type: (Text) -> Any + a = sessionSchema.account + return self._transaction.execute( + a.update(a.c.account_id == accountID) + .values(password_blob=newPWText) + ) + + if (yield checkAndReset(stored_password_text, + password, + reset_password)): + account = self._account(accountID, row[acc.c.username], + row[acc.c.email]) + yield account.bindSession(self._session) + returnValue(account) + + + @inlineCallbacks + def boundAccounts(self): + # type: () -> Deferred + """ + Retrieve the accounts currently associated with this session. + + @return: L{Deferred} firing with a L{list} of accounts. + """ + ast = sessionSchema.sessionAccount + acc = sessionSchema.account + result = (yield (yield self._transaction.execute( + ast.join(acc, ast.c.account_id == acc.c.account_id) + .select(ast.c.session_id == self._session.identifier, + use_labels=True) + )).fetchall()) + returnValue([ + self._account(it[ast.c.account_id], it[acc.c.username], + it[acc.c.email]) + for it in result + ]) + + + @inlineCallbacks + def boundSessionInformation(self): + # type: () -> Any + """ + Retrieve information about all sessions attached to the same account + that this session is. + + @return: L{Deferred} firing a L{list} of L{SessionIPInformation} + """ + acs = sessionSchema.sessionAccount + sipt = sessionSchema.sessionIP + + acs2 = acs.alias() + result = yield self._transaction.execute( + select([sipt], use_labels=True) + .where( + (acs.c.session_id == self._session.identifier) & + (acs.c.account_id == acs2.c.account_id) & + (acs2.c.session_id == sipt.c.session_id) + ) + ) + returnValue([ + SessionIPInformation( + id=row[sipt.c.session_id], + ip=row[sipt.c.ip_address], + when=row[sipt.c.last_used]) + for row in (yield result.fetchall()) + ]) + + + def unbindThisSession(self): + # type: () -> Any + """ + Disassociate this session from any accounts it's logged in to. + + @return: a L{Deferred} that fires when the account is logged out. + """ + ast = sessionSchema.sessionAccount + return self._transaction.execute(ast.delete( + ast.c.session_id == self._session.identifier + )) + + + +@implementer(ISimpleAccount) +@attr.s +class SQLAccount(object): + """ + An implementation of L{ISimpleAccount} backed by an Alchimia data store. + """ + + _transaction = attr.ib(type=Transaction) + accountID = attr.ib(type=Text) + username = attr.ib(type=Text) + email = attr.ib(type=Text) + + + def bindSession(self, session): + # type: (ISession) -> Deferred + """ + Add a session to the database. + """ + return self._transaction.execute( + sessionSchema.sessionAccount + .insert().values(account_id=self.accountID, + session_id=session.identifier) + ) + + + @inlineCallbacks + def changePassword(self, newPassword): + # type: (Text) -> Any + """ + @param newPassword: The text of the new password. + @type newPassword: L{unicode} + """ + computedHash = yield computeKeyText(newPassword) + result = yield self._transaction.execute( + sessionSchema.account.update() + .where(account_id=self.accountID) + .values(password_blob=computedHash) + ) + returnValue(result) + + + + +@inlineCallbacks +def upsert( + engine, # type: Transaction + table, # type: sqlalchemy.schema.Table + to_query, # type: Dict[str, Any] + to_change # type: Dict[str, Any] +): + # type: (...) -> Any + """ + Try inserting, if inserting fails, then update. + """ + try: + result = yield engine.execute( + table.insert().values(**dict(to_query, **to_change)) + ) + except IntegrityError: + from operator import and_ as And + update = table.update().where( + reduce(And, ( + (getattr(table.c, cname) == cvalue) + for (cname, cvalue) in to_query.items() + )) + ).values(**to_change) + result = yield engine.execute(update) + returnValue(result) + + +@attr.s +class SessionSchema(object): + """ + Schema for SQL session features. + + This is exposed as public API so that you can have tables which relate + against it in your own code, and integrate with your schema management + system. + + However, while Klein uses Alchimia itself, it does not want to be in the + business of managing your schema migrations or your database access. As + such, this class exposes the schema in several formats: + + - via SQLAlchemy metadata, if you want to use something like Alembic or + SQLAlchemy-Migrate + + - via a single SQL string, if you manage your SQL migrations manually + """ + session = attr.ib(type=Table) + account = attr.ib(type=Table) + sessionAccount = attr.ib(type=Table) + sessionIP = attr.ib(type=Table) + + @classmethod + def withMetadata(cls, metadata=None): + # type: (Optional[MetaData]) -> SessionSchema + """ + Create a new L{SQLSessionSchema} with the given metadata, defaulting to + new L{MetaData} if none is supplied. + """ + if metadata is None: + metadata = MetaData() + session = Table( + "session", metadata, + Column("session_id", Unicode(), primary_key=True, + nullable=False), + Column("confidential", Boolean(), nullable=False), + ) + account = Table( + "account", metadata, + Column("account_id", Unicode(), primary_key=True, + nullable=False), + Column("username", Unicode(), unique=True, nullable=False), + Column("email", Unicode(), nullable=False), + Column("password_blob", Unicode(), nullable=False), + ) + sessionAccount = Table( + "session_account", metadata, + Column("account_id", Unicode(), + ForeignKey(account.c.account_id, ondelete="CASCADE")), + Column("session_id", Unicode(), + ForeignKey(session.c.session_id, ondelete="CASCADE")), + UniqueConstraint("account_id", "session_id"), + ) + sessionIP = Table( + "session_ip", metadata, + Column("session_id", Unicode(), + ForeignKey(session.c.session_id, ondelete="CASCADE")), + Column("ip_address", Unicode(), nullable=False), + Column("address_family", Unicode(), nullable=False), + Column("last_used", DateTime(), nullable=False), + UniqueConstraint("session_id", "ip_address", "address_family"), + ) + return cls(session, account, sessionAccount, sessionIP) + + + def tables(self): + # type: () -> Iterable[Table] + """ + Yield all tables that need to be created in order for sessions to be + enabled in a SQLAlchemy database, in the order they need to be created. + """ + yield self.session + yield self.account + yield self.sessionAccount + yield self.sessionIP + + + @inlineCallbacks + def create(self, transaction): + # type: (Transaction) -> Deferred + """ + Given a L{Transaction}, create this schema in the database and return a + L{Deferred} that fires with C{None} when done. + + This method will handle any future migration concerns. + """ + for table in self.tables(): + yield transaction.execute(CreateTable(table)) + + + def migrationSQL(self): + # type: () -> Text + """ + Return some SQL to run in order to create the tables necessary for the + SQL session and account store. Currently there is only one version, + but in the future, sections of this will be clearly delineated by '-- + Klein Session Schema Version X' comments. + + This SQL will not attempt to discern whether the tables exist already + or whether the migrations should be run. + """ + return (u"\n-- Klein Session Schema Version 1\n" + + (u";".join(str(CreateTable(table)) + for table in self.tables()))) + + + +sessionSchema = SessionSchema.withMetadata(MetaData()) + +procurerFromTransactionT = Callable[[Transaction], ISessionProcurer] + +@implementer(ISessionProcurer) +class IPTrackingProcurer(object): + """ + An implementation of L{ISessionProcurer} that keeps track of the source IP + of the originating session. + """ + + def __init__( + self, + dataStore, # type: DataStore + procurerFromTransaction # type: procurerFromTransactionT + ): + # type: (...) -> None + """ + Create an L{IPTrackingProcurer} from SQLAlchemy metadata, an alchimia + data store, and an existing L{ISessionProcurer}. + """ + self._dataStore = dataStore + self._procurerFromTransaction = procurerFromTransaction + + + @inlineCallbacks + def procureSession(self, request, forceInsecure=False): + # type: (IRequest, bool) -> Deferred + """ + Procure a session from the underlying procurer, keeping track of the IP + of the request object. + """ + alreadyProcured = request.getComponent(ISession) + if alreadyProcured is not None: + returnValue(alreadyProcured) + # if getattr(request, 'requesting', False): + # raise RuntimeError("what are you doing!?") + # request.requesting = True + transaction = yield requestBoundTransaction(request, self._dataStore) + procurer = yield self._procurerFromTransaction(transaction) + session = yield procurer.procureSession(request, forceInsecure) + try: + ipAddress = (request.client.host or b"").decode("ascii") + except BaseException: + ipAddress = u"" + sip = sessionSchema.sessionIP + yield upsert( + transaction, sip, + dict(session_id=session.identifier, ip_address=ipAddress, + address_family=(u"AF_INET6" if u":" in ipAddress + else u"AF_INET")), + dict(last_used=datetime.utcnow()) + ) + # XXX This should set a savepoint because we don't want application + # logic to be able to roll back the IP access log. + returnValue(session) + + + +procurerFromStoreT = Callable[[ISessionStore], ISessionProcurer] + +def procurerFromDataStore( + dataStore, # type: DataStore + authorizers, # type: List[ISQLAuthorizer] + procurerFromStore=SessionProcurer # type: procurerFromStoreT +): + # type: (...) -> ISessionProcurer + """ + Open a session store, returning a procurer that can procure sessions from + it. + + @param databaseURL: an SQLAlchemy database URL. + + @param procurerFromStore: A callable that takes an L{ISessionStore} and + returns an L{ISessionProcurer}. + + @return: L{Deferred} firing with L{ISessionProcurer} + """ + allAuthorizers = [simpleAccountBinding.authorizer, + logMeIn.authorizer] + list(authorizers) + return IPTrackingProcurer( + dataStore, + lambda transaction: procurerFromStore(SessionStore( + transaction, allAuthorizers + )) + ) + + + +class _FunctionWithAuthorizer(object): + + authorizer = None # type: Any + + def __call__( + self, + sessionStore, # type: SessionStore + transaction, # type: Transaction + session # type: ISession + ): + # type: (...) -> Any + """ + Signature for a function that can have an authorizer attached to it. + """ + +_authorizerFunction = Callable[ + [SessionStore, Transaction, ISession], + Any +] + +@implementer(ISQLAuthorizer) +@attr.s +class SimpleSQLAuthorizer(object): + authorizationInterface = attr.ib(type=Type) + _decorated = attr.ib(type=_authorizerFunction) + + def authorizationForSession(self, sessionStore, transaction, session): + # type: (SessionStore, Transaction, ISession) -> Any + cb = cast(_authorizerFunction, self._decorated) # type: ignore + return cb(sessionStore, transaction, session) + + +def authorizerFor( + authorizationInterface, # type: IInterface +): + # type: (...) -> Callable[[Callable], _FunctionWithAuthorizer] + """ + Declare an SQL authorizer, implemented by a given function. Used like so:: + + @authorizerFor(Foo, tables(foo=[Column("bar", Unicode())])) + def authorizeFoo(dataStore, sessionStore, transaction, session): + return Foo(metadata, metadata.tables["foo"]) + + @param authorizationInterface: The type we are creating an authorizer for. + + @return: a decorator that can decorate a function with the signature + C{(metadata, dataStore, sessionStore, transaction, session)} + """ + def decorator(decorated): + # type: (_authorizerFunction) -> _FunctionWithAuthorizer + result = cast(_FunctionWithAuthorizer, decorated) + result.authorizer = SimpleSQLAuthorizer(authorizationInterface, + decorated) + return result + return decorator + + + +@authorizerFor(ISimpleAccountBinding) +def simpleAccountBinding( + sessionStore, # type: SessionStore + transaction, # type: Transaction + session # type: ISession +): + # type: (...) -> AccountSessionBinding + """ + All sessions are authorized for access to an L{ISimpleAccountBinding}. + """ + return AccountSessionBinding(session, transaction) + + + +@authorizerFor(ISimpleAccount) +@inlineCallbacks +def logMeIn( + sessionStore, # type: SessionStore + transaction, # type: Transaction + session # type: ISession +): + # type: (...) -> Deferred + """ + Retrieve an L{ISimpleAccount} authorization. + """ + binding = ((yield session.authorize([ISimpleAccountBinding])) + [ISimpleAccountBinding]) + returnValue(next(iter((yield binding.boundAccounts())), + None)) diff --git a/src/klein/storage/_sql_generic.py b/src/klein/storage/_sql_generic.py new file mode 100644 index 000000000..5f48ff01a --- /dev/null +++ b/src/klein/storage/_sql_generic.py @@ -0,0 +1,373 @@ +""" +Generic SQL data storage stuff; the substrate for session-storage stuff. +""" + +from collections import deque +from sys import exc_info +from typing import Any, Optional, TYPE_CHECKING, Text, TypeVar + +from alchimia import TWISTED_STRATEGY + +import attr +from attr import Factory + +from sqlalchemy import create_engine + +from twisted.internet.defer import (Deferred, gatherResults, inlineCallbacks, + returnValue, succeed) + +from zope.interface import Interface, implementer + +from ..interfaces import TransactionEnded + +_sqlAlchemyConnection = Any +_sqlAlchemyTransaction = Any + +COMMITTING = "committing" +COMMITTED = "committed" +COMMIT_FAILED = "commit failed" +ROLLING_BACK = "rolling back" +ROLLED_BACK = "rolled back" +ROLLBACK_FAILED = "rollback failed" + +if TYPE_CHECKING: # pragma: no cover + T = TypeVar('T') + from twisted.internet.interfaces import IReactorThreads + IReactorThreads + from typing import Iterable + Iterable + from typing import Callable + Callable + from twisted.web.iweb import IRequest + IRequest + + +@attr.s +class Transaction(object): + """ + Wrapper around a SQLAlchemy connection which is invalidated when the + transaction is committed or rolled back. + """ + _connection = attr.ib(type=_sqlAlchemyConnection) + _transaction = attr.ib(type=_sqlAlchemyTransaction) + _parent = attr.ib(type='Optional[Transaction]', default=None) + _stopped = attr.ib(type=Text, default=u"") + _completeDeferred = attr.ib(type=Deferred, default=Factory(Deferred)) + + def _checkStopped(self): + # type: () -> None + """ + Raise an exception if the transaction has been stopped for any reason. + """ + if self._stopped: + raise TransactionEnded(self._stopped) + if self._parent is not None: + self._parent._checkStopped() + + + def execute(self, statement, *multiparams, **params): + # type: (Any, *Any, **Any) -> Deferred + """ + Execute a statement unless this transaction has been stopped, otherwise + raise L{TransactionEnded}. + """ + self._checkStopped() + return self._connection.execute(statement, *multiparams, **params) + + + def commit(self): + # type: () -> Deferred + """ + Commit this transaction. + """ + self._checkStopped() + self._stopped = COMMITTING + return self._transaction.commit().addCallbacks( + (lambda commitResult: self._finishWith(COMMITTED)), + (lambda commitFailure: self._finishWith(COMMIT_FAILED)) + ) + + + def rollback(self): + # type: () -> Deferred + """ + Roll this transaction back. + """ + self._checkStopped() + self._stopped = ROLLING_BACK + return self._transaction.rollback().addCallbacks( + (lambda commitResult: self._finishWith(ROLLED_BACK)), + (lambda commitFailure: self._finishWith(ROLLBACK_FAILED)) + ) + + + def _finishWith(self, stopStatus): + # type: (Text) -> None + """ + Complete this transaction. + """ + self._stopped = stopStatus + self._completeDeferred.callback(stopStatus) + + + @inlineCallbacks + def savepoint(self): + # type: () -> Deferred + """ + Create a L{Savepoint} which can be treated as a sub-transaction. + + @note: As long as this L{Savepoint} has not been rolled back or + committed, this transaction's C{execute} method will execute within + the context of that savepoint. + """ + returnValue(Transaction( + self._connection, (yield self._connection.begin_nested()), + self + )) + + + def subtransact(self, logic): + # type: (Callable[[Transaction], Deferred]) -> Deferred + """ + Run the given C{logic} in a subtransaction. + """ + return Transactor(self.savepoint).transact(logic) + + + def maybeCommit(self): + # type: () -> Deferred + """ + Commit this transaction if it hasn't been finished (committed or rolled + back) yet; otherwise, do nothing. + """ + if self._stopped: + return succeed(None) + return self.commit() + + + def maybeRollback(self): + # type: () -> Deferred + """ + Roll this transaction back if it hasn't been finished (committed or + rolled back) yet; otherwise, do nothing. + """ + if self._stopped: + return succeed(None) + return self.rollback() + + +@attr.s +class Transactor(object): + """ + A context manager that represents the lifecycle of a transaction when + paired with application code. + """ + + _newTransaction = attr.ib(type='Callable[[], Deferred]') + _transaction = attr.ib(type=Optional[Transaction], default=None) + + @inlineCallbacks + def __aenter__(self): + # type: () -> Deferred + """ + Start a transaction. + """ + self._transaction = yield self._newTransaction() + # ^ https://github.com/python/mypy/issues/4688 + returnValue(self._transaction) + + @inlineCallbacks + def __aexit__(self, exc_type, exc_value, traceback): + # type: (type, Exception, Any) -> Deferred + """ + End a transaction. + """ + assert self._transaction is not None + if exc_type is None: + yield self._transaction.commit() + else: + yield self._transaction.rollback() + self._transaction = None + + @inlineCallbacks + def transact(self, logic): + # type: (Callable) -> Deferred + """ + Run the given logic within this L{TransactionContext}, starting and + stopping as usual. + """ + try: + transaction = yield self.__aenter__() + result = yield logic(transaction) + finally: + yield self.__aexit__(*exc_info()) + returnValue(result) + + + +@attr.s(hash=False) +class DataStore(object): + """ + L{DataStore} is a generic storage object that connect to an SQL + database, run transactions, and manage schema metadata. + """ + + _engine = attr.ib(type=_sqlAlchemyConnection) + _freeConnections = attr.ib(default=Factory(deque), type=deque) + + @inlineCallbacks + def newTransaction(self): + # type: () -> Deferred + """ + Create a new Klein transaction. + """ + alchimiaConnection = ( + self._freeConnections.popleft() if self._freeConnections + else (yield self._engine.connect()) + ) + alchimiaTransaction = yield alchimiaConnection.begin() + kleinTransaction = Transaction(alchimiaConnection, alchimiaTransaction) + + @kleinTransaction._completeDeferred.addBoth + def recycleTransaction(anything): + # type: (T) -> T + self._freeConnections.append(alchimiaConnection) + return anything + returnValue(kleinTransaction) + + + def transact(self, callable): + # type: (Callable[[Transaction], Any]) -> Any + """ + Run the given C{callable} within a transaction. + + @param callable: A callable object that encapsulates application logic + that needs to run in a transaction. + @type callable: callable taking a L{Transaction} and returning a + L{Deferred}. + + @return: a L{Deferred} firing with the result of C{callable} + @rtype: L{Deferred} that fires when the transaction is complete, or + fails when the transaction is rolled back. + """ + return Transactor(self.newTransaction).transact(callable) + + + @classmethod + def open(cls, reactor, dbURL): + # type: (IReactorThreads, Text) -> DataStore + """ + Open an L{DataStore}. + + @param reactor: the reactor that this store should be opened on. + @type reactor: L{IReactorThreads} + + @param dbURL: the SQLAlchemy database URI to connect to. + @type dbURL: L{str} + """ + return cls(create_engine(dbURL, reactor=reactor, + strategy=TWISTED_STRATEGY)) + + +class ITransactionRequestAssociator(Interface): + """ + Associates transactions with requests. + """ + +@implementer(ITransactionRequestAssociator) +@attr.s +class TransactionRequestAssociator(object): + """ + Does the thing the interface says. + """ + _map = attr.ib(type=dict, default=Factory(dict)) + committing = attr.ib(type=bool, default=False) + + @inlineCallbacks + def transactionForStore(self, dataStore): + # type: (DataStore) -> Deferred + """ + Get a transaction for the given datastore. + """ + if dataStore in self._map: + returnValue(self._map[dataStore]) + txn = yield dataStore.newTransaction() + self._map[dataStore] = txn + returnValue(txn) + + def commitAll(self): + # type: () -> Deferred + """ + Commit all associated transactions. + """ + self.committing = True + return gatherResults([value.maybeCommit() + for value in self._map.values()]) + +@inlineCallbacks +def requestBoundTransaction(request, dataStore): + # type: (IRequest, DataStore) -> Deferred + """ + Retrieve a transaction that is bound to the lifecycle of the given request. + + There are three use-cases for this lifecycle: + + 1. 'normal CRUD' - a request begins, a transaction is associated with + it, and the transaction completes when the request completes. The + appropriate time to commit the transaction is the moment before the + first byte goes out to the client. The appropriate moment to + interpose this commit is in `Request.write`, at the moment where + it's about to call channel.writeHeaders, since the HTTP status code + should be an indicator of whether the transaction succeeded or + failed. + + 2. 'just the session please' - a request begins, a transaction is + associated with it in order to discover the session, and the + application code in question isn't actually using the database. + (Ideally as expressed through "the dependency-declaration decorator, + such as @authorized, did not indicate that a transaction will be + required"). + + 3. 'fancy API stuff' - a request begins, a transaction is associated + with it in order to discover the session, the application code needs + to then do I{something} with that transaction in-line with the + session discovery, but then needs to commit in order to relinquish + all database locks while doing some potentially slow external API + calls, then start a I{new} transaction later in the request flow. + """ + assoc = request.getComponent(ITransactionRequestAssociator) + if assoc is None: + assoc = TransactionRequestAssociator() + request.setComponent(ITransactionRequestAssociator, assoc) + + def finishCommit(result): + # type: (Any) -> Deferred + return assoc.commitAll() + request.notifyFinish().addBoth(finishCommit) + + # originalWrite = request.write + # buffer = [] + # def committed(result): + # for buf in buffer: + # if buf is None: + # originalFinish() + # else: + # originalWrite(buf) + + # def maybeWrite(data): + # if request.startedWriting: + # return originalWrite(data) + # buffer.append(data) + # if assoc.committing: + # return + # assoc.commitAll().addBoth(committed) + # def maybeFinish(): + # if not request.startedWriting: + # buffer.append(None) + # else: + # originalFinish() + # originalFinish = request.finish + # request.write = maybeWrite + # request.finish = maybeFinish + txn = yield assoc.transactionForStore(dataStore) + return txn diff --git a/src/klein/storage/interfaces.py b/src/klein/storage/interfaces.py new file mode 100644 index 000000000..5e1420821 --- /dev/null +++ b/src/klein/storage/interfaces.py @@ -0,0 +1,14 @@ + +from typing import TYPE_CHECKING, Union + +from ._istorage import ISQLAuthorizer as _ISQLAuthorizer + +if TYPE_CHECKING: # pragma: no cover + from ._sql import SimpleSQLAuthorizer + ISQLAuthorizer = Union[_ISQLAuthorizer, SimpleSQLAuthorizer] +else: + ISQLAuthorizer = _ISQLAuthorizer + +__all__ = [ + 'ISQLAuthorizer' +] diff --git a/src/klein/storage/sql.py b/src/klein/storage/sql.py new file mode 100644 index 000000000..68c898b54 --- /dev/null +++ b/src/klein/storage/sql.py @@ -0,0 +1,19 @@ + +from ._sql import ( + SessionSchema, authorizerFor, procurerFromDataStore +) +from ._sql_generic import ( + DataStore, Transaction +) + +__all__ = [ + "procurerFromDataStore", + "authorizerFor", + "SessionSchema", + "DataStore", + "Transaction", +] + +if __name__ == '__main__': + import sys + sys.stdout.write(SessionSchema.withMetadata().migrationSQL()) diff --git a/src/klein/storage/test/test_security.py b/src/klein/storage/test/test_security.py new file mode 100644 index 000000000..59651fedf --- /dev/null +++ b/src/klein/storage/test/test_security.py @@ -0,0 +1,13 @@ + +from twisted.trial.unittest import TestCase + +from klein.storage import security + +class SQLTests(TestCase): + + def test_security(self): + # type: () -> None + """ + Add tests. + """ + security diff --git a/src/klein/storage/test/test_sql.py b/src/klein/storage/test/test_sql.py new file mode 100644 index 000000000..8a0f73d4b --- /dev/null +++ b/src/klein/storage/test/test_sql.py @@ -0,0 +1,12 @@ + +from twisted.trial.unittest import TestCase + +from klein.storage import sql + +class SQLTests(TestCase): + def test_sql(self): + # type: () -> None + """ + Add tests. + """ + sql From 4a7f7d291e89848854a3de66d5c146bb010f6792 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Aug 2021 10:53:02 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- setup.py | 5 +- src/klein/interfaces.py | 43 ++- src/klein/storage/_istorage.py | 4 +- src/klein/storage/_security.py | 21 +- src/klein/storage/_sql.py | 442 ++++++++++++++---------- src/klein/storage/_sql_generic.py | 74 ++-- src/klein/storage/interfaces.py | 8 +- src/klein/storage/sql.py | 12 +- src/klein/storage/test/test_security.py | 3 +- src/klein/storage/test/test_sql.py | 2 +- 10 files changed, 354 insertions(+), 260 deletions(-) diff --git a/setup.py b/setup.py index 75f7ab392..ba7882c5a 100644 --- a/setup.py +++ b/setup.py @@ -46,13 +46,12 @@ "docs": [ "Sphinx==3.5.1", "sphinx-rtd-theme==0.5.1", - ] + ], }, keywords="twisted flask werkzeug web", license="MIT", name="klein", - packages=["klein", "klein.storage", - "klein.test", "klein.storage.test"], + packages=["klein", "klein.storage", "klein.test", "klein.storage.test"], package_dir={"": "src"}, package_data=dict( klein=["py.typed"], diff --git a/src/klein/interfaces.py b/src/klein/interfaces.py index 40f3b7d8f..b7f72738d 100644 --- a/src/klein/interfaces.py +++ b/src/klein/interfaces.py @@ -19,10 +19,14 @@ TransactionEnded, ) -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from ._storage.memory import MemorySessionStore, MemorySession - from ._storage.sql import (SessionStore, SQLAccount, IPTrackingProcurer, - AccountSessionBinding) + from ._storage.sql import ( + SessionStore, + SQLAccount, + IPTrackingProcurer, + AccountSessionBinding, + ) from ._session import SessionProcurer, Authorization from ._form import Field, RenderableFormParam, FieldInjector from ._isession import IRequestLifecycleT as _IRequestLifecycleT @@ -30,20 +34,29 @@ from typing import Union - ISessionStore = Union[_ISessionStore, MemorySessionStore, - SessionStore] - ISessionProcurer = Union[_ISessionProcurer, SessionProcurer, - IPTrackingProcurer] + ISessionStore = Union[_ISessionStore, MemorySessionStore, SessionStore] + ISessionProcurer = Union[ + _ISessionProcurer, SessionProcurer, IPTrackingProcurer + ] ISession = Union[_ISession, MemorySession] ISimpleAccount = Union[_ISimpleAccount, SQLAccount] - ISimpleAccountBinding = Union[_ISimpleAccountBinding, - AccountSessionBinding] - IDependencyInjector = Union[_IDependencyInjector, Authorization, - RenderableFormParam, FieldInjector, RequestURL, - RequestComponent] - IRequiredParameter = Union[_IRequiredParameter, Authorization, Field, - RenderableFormParam, RequestURL, - RequestComponent] + ISimpleAccountBinding = Union[_ISimpleAccountBinding, AccountSessionBinding] + IDependencyInjector = Union[ + _IDependencyInjector, + Authorization, + RenderableFormParam, + FieldInjector, + RequestURL, + RequestComponent, + ] + IRequiredParameter = Union[ + _IRequiredParameter, + Authorization, + Field, + RenderableFormParam, + RequestURL, + RequestComponent, + ] IRequestLifecycle = _IRequestLifecycleT else: ISession = _ISession diff --git a/src/klein/storage/_istorage.py b/src/klein/storage/_istorage.py index 28f5b32d8..2416d85b0 100644 --- a/src/klein/storage/_istorage.py +++ b/src/klein/storage/_istorage.py @@ -1,14 +1,14 @@ - from typing import TYPE_CHECKING from zope.interface import Attribute, Interface from .._typing import ifmethod -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from twisted.internet.defer import Deferred from ..interfaces import ISessionStore, ISession from ._sql_generic import Transaction + ISession, ISessionStore, Deferred, Transaction diff --git a/src/klein/storage/_security.py b/src/klein/storage/_security.py index c3ff2fdef..7a19b0d00 100644 --- a/src/klein/storage/_security.py +++ b/src/klein/storage/_security.py @@ -1,4 +1,3 @@ - from functools import partial from typing import Any, Callable, Optional, TYPE_CHECKING, Text, Tuple from unicodedata import normalize @@ -8,21 +7,24 @@ from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.internet.threads import deferToThread -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover Text, Callable, Deferred, Optional, Tuple, Any -passlibContextWithGoodDefaults = partial(CryptContext, schemes=['bcrypt']) +passlibContextWithGoodDefaults = partial(CryptContext, schemes=["bcrypt"]) + def _verifyAndUpdate(secret, hash, ctxFactory=passlibContextWithGoodDefaults): # type: (Text, Text, Callable[[], CryptContext]) -> Deferred """ Asynchronous wrapper for L{CryptContext.verify_and_update}. """ + @deferToThread def theWork(): # type: () -> Tuple[bool, Optional[str]] return ctxFactory().verify_and_update(secret, hash) + return theWork @@ -31,12 +33,13 @@ def _hashSecret(secret, ctxFactory=passlibContextWithGoodDefaults): """ Asynchronous wrapper for L{CryptContext.hash}. """ + @deferToThread def theWork(): # type: () -> str return ctxFactory().hash(secret) - return theWork + return theWork @inlineCallbacks @@ -56,9 +59,10 @@ def checkAndReset(storedPasswordText, providedPasswordText, resetter): @return: L{Deferred} firing with C{True} if the password matches and C{False} if the password does not match. """ - providedPasswordText = normalize('NFD', providedPasswordText) - valid, newHash = yield _verifyAndUpdate(providedPasswordText, - storedPasswordText) + providedPasswordText = normalize("NFD", providedPasswordText) + valid, newHash = yield _verifyAndUpdate( + providedPasswordText, storedPasswordText + ) if valid: # Password migration! Does our passlib context have an awesome *new* # hash it wants to give us? Store it. @@ -71,7 +75,6 @@ def checkAndReset(storedPasswordText, providedPasswordText, resetter): returnValue(False) - @inlineCallbacks def computeKeyText(passwordText): # type: (Text) -> Any @@ -82,7 +85,7 @@ def computeKeyText(passwordText): @return: a L{Deferred} firing with L{unicode}. """ - normalized = normalize('NFD', passwordText) + normalized = normalize("NFD", passwordText) hashed = yield _hashSecret(normalized) if isinstance(hashed, bytes): hashed = hashed.decode("charmap") diff --git a/src/klein/storage/_sql.py b/src/klein/storage/_sql.py index be79d4103..59ed7098b 100644 --- a/src/klein/storage/_sql.py +++ b/src/klein/storage/_sql.py @@ -3,8 +3,17 @@ from functools import reduce from os import urandom from typing import ( - Any, Callable, Dict, Iterable, List, Optional, TYPE_CHECKING, Text, - Type, TypeVar, cast + Any, + Callable, + Dict, + Iterable, + List, + Optional, + TYPE_CHECKING, + Text, + Type, + TypeVar, + cast, ) from uuid import uuid4 @@ -12,18 +21,27 @@ from attr import Factory from attr.validators import instance_of as an -from six import text_type from sqlalchemy import ( - Boolean, Column, DateTime, ForeignKey, MetaData, Table, - Unicode, UniqueConstraint, true + Boolean, + Column, + DateTime, + ForeignKey, + MetaData, + Table, + Unicode, + UniqueConstraint, + true, ) from sqlalchemy.exc import IntegrityError from sqlalchemy.schema import CreateTable from sqlalchemy.sql.expression import select from twisted.internet.defer import ( - gatherResults, inlineCallbacks, maybeDeferred, returnValue + gatherResults, + inlineCallbacks, + maybeDeferred, + returnValue, ) from twisted.python.compat import unicode @@ -35,24 +53,45 @@ from .interfaces import ISQLAuthorizer from .. import SessionProcurer from ..interfaces import ( - ISession, ISessionProcurer, ISessionStore, ISimpleAccount, - ISimpleAccountBinding, NoSuchSession, SessionMechanism + ISession, + ISessionProcurer, + ISessionStore, + ISimpleAccount, + ISimpleAccountBinding, + NoSuchSession, + SessionMechanism, ) -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover import sqlalchemy from twisted.internet.defer import Deferred from twisted.internet.interfaces import IReactorThreads from twisted.web.iweb import IRequest from ._sql_generic import DataStore - (Any, Callable, Deferred, Type, Iterable, IReactorThreads, Text, - List, sqlalchemy, Dict, IRequest, IInterface, Optional, DataStore) - T = TypeVar('T') + + ( + Any, + Callable, + Deferred, + Type, + Iterable, + IReactorThreads, + Text, + List, + sqlalchemy, + Dict, + IRequest, + IInterface, + Optional, + DataStore, + ) + T = TypeVar("T") + @implementer(ISession) @attr.s -class SQLSession(object): - _sessionStore = attr.ib(type='SessionStore') +class SQLSession: + _sessionStore = attr.ib(type="SessionStore") identifier = attr.ib(type=Text) isConfidential = attr.ib(type=bool) authenticatedBy = attr.ib(type=SessionMechanism) @@ -60,41 +99,44 @@ class SQLSession(object): def authorize(self, interfaces): # type: (Iterable[IInterface]) -> Any interfaces = set(interfaces) - result = {} # type: Dict[IInterface, Deferred] - ds = [] # type: List[Deferred] + result = {} # type: Dict[IInterface, Deferred] + ds = [] # type: List[Deferred] txn = self._sessionStore._transaction for a in self._sessionStore._authorizers: # This should probably do something smart with interface # priority, checking isOrExtends or something similar. if a.authorizationInterface in interfaces: - v = maybeDeferred(a.authorizationForSession, - self._sessionStore, txn, self) + v = maybeDeferred( + a.authorizationForSession, self._sessionStore, txn, self + ) ds.append(v) result[a.authorizationInterface] = v v.addCallback( lambda value, ai: result.__setitem__(ai, value), - ai=a.authorizationInterface + ai=a.authorizationInterface, ) def r(ignored): # type: (T) -> Dict[str, Any] return result - return (gatherResults(ds).addCallback(r)) + return gatherResults(ds).addCallback(r) @attr.s -class SessionIPInformation(object): +class SessionIPInformation: """ Information about a session being used from a given IP address. """ - id = attr.ib(validator=an(text_type), type=Text) - ip = attr.ib(validator=an(text_type), type=Text) + + id = attr.ib(validator=an(str), type=Text) + ip = attr.ib(validator=an(str), type=Text) when = attr.ib(validator=an(datetime), type=datetime) + @implementer(ISessionStore) @attr.s() -class SessionStore(object): +class SessionStore: """ An implementation of L{ISessionStore} based on a L{DataStore}, that stores sessions in a SQLAlchemy database. @@ -115,54 +157,69 @@ def sentInsecurely(self, tokens): invalidated. """ s = sessionSchema.session - return gatherResults([ - self._transaction.execute( - s.delete().where((s.c.session_id == token) & - (s.c.confidential == true())) - ) for token in tokens - ]) - + return gatherResults( + [ + self._transaction.execute( + s.delete().where( + (s.c.session_id == token) & (s.c.confidential == true()) + ) + ) + for token in tokens + ] + ) @inlineCallbacks def newSession(self, isConfidential, authenticatedBy): # type: (bool, SessionMechanism) -> Deferred - identifier = hexlify(urandom(32)).decode('ascii') + identifier = hexlify(urandom(32)).decode("ascii") s = sessionSchema.session - yield self._transaction.execute(s.insert().values( - session_id=identifier, - confidential=isConfidential, - )) - returnValue(SQLSession(self, - identifier=identifier, - isConfidential=isConfidential, - authenticatedBy=authenticatedBy)) - + yield self._transaction.execute( + s.insert().values( + session_id=identifier, + confidential=isConfidential, + ) + ) + returnValue( + SQLSession( + self, + identifier=identifier, + isConfidential=isConfidential, + authenticatedBy=authenticatedBy, + ) + ) @inlineCallbacks def loadSession(self, identifier, isConfidential, authenticatedBy): # type: (Text, bool, SessionMechanism) -> Deferred s = sessionSchema.session result = yield self._transaction.execute( - s.select((s.c.session_id == identifier) & - (s.c.confidential == isConfidential))) + s.select( + (s.c.session_id == identifier) + & (s.c.confidential == isConfidential) + ) + ) results = yield result.fetchall() if not results: - raise NoSuchSession(u"Session not present in SQL store.") + raise NoSuchSession("Session not present in SQL store.") fetched_identifier = results[0][s.c.session_id] - returnValue(SQLSession(self, - identifier=fetched_identifier, - isConfidential=isConfidential, - authenticatedBy=authenticatedBy)) - + returnValue( + SQLSession( + self, + identifier=fetched_identifier, + isConfidential=isConfidential, + authenticatedBy=authenticatedBy, + ) + ) @implementer(ISimpleAccountBinding) @attr.s -class AccountSessionBinding(object): +class AccountSessionBinding: """ (Stateless) binding between an account and a session, so that sessions can attach to, detach from, . """ + _session = attr.ib(type=ISession) _transaction = attr.ib(type=Transaction) @@ -171,9 +228,7 @@ def _account(self, accountID, username, email): """ Construct an L{SQLAccount} bound to this plugin & dataStore. """ - return SQLAccount(self._transaction, accountID, username, - email) - + return SQLAccount(self._transaction, accountID, username, email) @inlineCallbacks def createAccount(self, username, email, password): @@ -186,10 +241,12 @@ def createAccount(self, username, email, password): """ computedHash = yield computeKeyText(password) newAccountID = unicode(uuid4()) - insert = (sessionSchema.account.insert() - .values(account_id=newAccountID, - username=username, email=email, - password_blob=computedHash)) + insert = sessionSchema.account.insert().values( + account_id=newAccountID, + username=username, + email=email, + password_blob=computedHash, + ) try: yield self._transaction.execute(insert) except IntegrityError: @@ -199,7 +256,6 @@ def createAccount(self, username, email, password): account = self._account(accountID, username, email) returnValue(account) - @inlineCallbacks def bindIfCredentialsMatch(self, username, password): # type: (Text, Text) -> Any @@ -221,7 +277,7 @@ def bindIfCredentialsMatch(self, username, password): result = yield self._transaction.execute( acc.select(acc.c.username == username) ) - accountsInfo = (yield result.fetchall()) + accountsInfo = yield result.fetchall() if not accountsInfo: # no account, bye returnValue(None) @@ -233,19 +289,20 @@ def reset_password(newPWText): # type: (Text) -> Any a = sessionSchema.account return self._transaction.execute( - a.update(a.c.account_id == accountID) - .values(password_blob=newPWText) + a.update(a.c.account_id == accountID).values( + password_blob=newPWText + ) ) - if (yield checkAndReset(stored_password_text, - password, - reset_password)): - account = self._account(accountID, row[acc.c.username], - row[acc.c.email]) + if ( + yield checkAndReset(stored_password_text, password, reset_password) + ): + account = self._account( + accountID, row[acc.c.username], row[acc.c.email] + ) yield account.bindSession(self._session) returnValue(account) - @inlineCallbacks def boundAccounts(self): # type: () -> Deferred @@ -256,17 +313,22 @@ def boundAccounts(self): """ ast = sessionSchema.sessionAccount acc = sessionSchema.account - result = (yield (yield self._transaction.execute( - ast.join(acc, ast.c.account_id == acc.c.account_id) - .select(ast.c.session_id == self._session.identifier, - use_labels=True) - )).fetchall()) - returnValue([ - self._account(it[ast.c.account_id], it[acc.c.username], - it[acc.c.email]) - for it in result - ]) - + result = yield ( + yield self._transaction.execute( + ast.join(acc, ast.c.account_id == acc.c.account_id).select( + ast.c.session_id == self._session.identifier, + use_labels=True, + ) + ) + ).fetchall() + returnValue( + [ + self._account( + it[ast.c.account_id], it[acc.c.username], it[acc.c.email] + ) + for it in result + ] + ) @inlineCallbacks def boundSessionInformation(self): @@ -282,21 +344,22 @@ def boundSessionInformation(self): acs2 = acs.alias() result = yield self._transaction.execute( - select([sipt], use_labels=True) - .where( - (acs.c.session_id == self._session.identifier) & - (acs.c.account_id == acs2.c.account_id) & - (acs2.c.session_id == sipt.c.session_id) + select([sipt], use_labels=True).where( + (acs.c.session_id == self._session.identifier) + & (acs.c.account_id == acs2.c.account_id) + & (acs2.c.session_id == sipt.c.session_id) ) ) - returnValue([ - SessionIPInformation( - id=row[sipt.c.session_id], - ip=row[sipt.c.ip_address], - when=row[sipt.c.last_used]) - for row in (yield result.fetchall()) - ]) - + returnValue( + [ + SessionIPInformation( + id=row[sipt.c.session_id], + ip=row[sipt.c.ip_address], + when=row[sipt.c.last_used], + ) + for row in (yield result.fetchall()) + ] + ) def unbindThisSession(self): # type: () -> Any @@ -306,15 +369,14 @@ def unbindThisSession(self): @return: a L{Deferred} that fires when the account is logged out. """ ast = sessionSchema.sessionAccount - return self._transaction.execute(ast.delete( - ast.c.session_id == self._session.identifier - )) - + return self._transaction.execute( + ast.delete(ast.c.session_id == self._session.identifier) + ) @implementer(ISimpleAccount) @attr.s -class SQLAccount(object): +class SQLAccount: """ An implementation of L{ISimpleAccount} backed by an Alchimia data store. """ @@ -324,19 +386,17 @@ class SQLAccount(object): username = attr.ib(type=Text) email = attr.ib(type=Text) - def bindSession(self, session): # type: (ISession) -> Deferred """ Add a session to the database. """ return self._transaction.execute( - sessionSchema.sessionAccount - .insert().values(account_id=self.accountID, - session_id=session.identifier) + sessionSchema.sessionAccount.insert().values( + account_id=self.accountID, session_id=session.identifier + ) ) - @inlineCallbacks def changePassword(self, newPassword): # type: (Text) -> Any @@ -353,14 +413,12 @@ def changePassword(self, newPassword): returnValue(result) - - @inlineCallbacks def upsert( - engine, # type: Transaction - table, # type: sqlalchemy.schema.Table - to_query, # type: Dict[str, Any] - to_change # type: Dict[str, Any] + engine, # type: Transaction + table, # type: sqlalchemy.schema.Table + to_query, # type: Dict[str, Any] + to_change, # type: Dict[str, Any] ): # type: (...) -> Any """ @@ -372,18 +430,26 @@ def upsert( ) except IntegrityError: from operator import and_ as And - update = table.update().where( - reduce(And, ( - (getattr(table.c, cname) == cvalue) - for (cname, cvalue) in to_query.items() - )) - ).values(**to_change) + + update = ( + table.update() + .where( + reduce( + And, + ( + (getattr(table.c, cname) == cvalue) + for (cname, cvalue) in to_query.items() + ), + ) + ) + .values(**to_change) + ) result = yield engine.execute(update) returnValue(result) @attr.s -class SessionSchema(object): +class SessionSchema: """ Schema for SQL session features. @@ -400,6 +466,7 @@ class SessionSchema(object): - via a single SQL string, if you manage your SQL migrations manually """ + session = attr.ib(type=Table) account = attr.ib(type=Table) sessionAccount = attr.ib(type=Table) @@ -415,31 +482,42 @@ def withMetadata(cls, metadata=None): if metadata is None: metadata = MetaData() session = Table( - "session", metadata, - Column("session_id", Unicode(), primary_key=True, - nullable=False), + "session", + metadata, + Column("session_id", Unicode(), primary_key=True, nullable=False), Column("confidential", Boolean(), nullable=False), ) account = Table( - "account", metadata, - Column("account_id", Unicode(), primary_key=True, - nullable=False), + "account", + metadata, + Column("account_id", Unicode(), primary_key=True, nullable=False), Column("username", Unicode(), unique=True, nullable=False), Column("email", Unicode(), nullable=False), Column("password_blob", Unicode(), nullable=False), ) sessionAccount = Table( - "session_account", metadata, - Column("account_id", Unicode(), - ForeignKey(account.c.account_id, ondelete="CASCADE")), - Column("session_id", Unicode(), - ForeignKey(session.c.session_id, ondelete="CASCADE")), + "session_account", + metadata, + Column( + "account_id", + Unicode(), + ForeignKey(account.c.account_id, ondelete="CASCADE"), + ), + Column( + "session_id", + Unicode(), + ForeignKey(session.c.session_id, ondelete="CASCADE"), + ), UniqueConstraint("account_id", "session_id"), ) sessionIP = Table( - "session_ip", metadata, - Column("session_id", Unicode(), - ForeignKey(session.c.session_id, ondelete="CASCADE")), + "session_ip", + metadata, + Column( + "session_id", + Unicode(), + ForeignKey(session.c.session_id, ondelete="CASCADE"), + ), Column("ip_address", Unicode(), nullable=False), Column("address_family", Unicode(), nullable=False), Column("last_used", DateTime(), nullable=False), @@ -447,7 +525,6 @@ def withMetadata(cls, metadata=None): ) return cls(session, account, sessionAccount, sessionIP) - def tables(self): # type: () -> Iterable[Table] """ @@ -459,7 +536,6 @@ def tables(self): yield self.sessionAccount yield self.sessionIP - @inlineCallbacks def create(self, transaction): # type: (Transaction) -> Deferred @@ -472,7 +548,6 @@ def create(self, transaction): for table in self.tables(): yield transaction.execute(CreateTable(table)) - def migrationSQL(self): # type: () -> Text """ @@ -484,27 +559,27 @@ def migrationSQL(self): This SQL will not attempt to discern whether the tables exist already or whether the migrations should be run. """ - return (u"\n-- Klein Session Schema Version 1\n" + - (u";".join(str(CreateTable(table)) - for table in self.tables()))) - + return "\n-- Klein Session Schema Version 1\n" + ( + ";".join(str(CreateTable(table)) for table in self.tables()) + ) sessionSchema = SessionSchema.withMetadata(MetaData()) procurerFromTransactionT = Callable[[Transaction], ISessionProcurer] + @implementer(ISessionProcurer) -class IPTrackingProcurer(object): +class IPTrackingProcurer: """ An implementation of L{ISessionProcurer} that keeps track of the source IP of the originating session. """ def __init__( - self, - dataStore, # type: DataStore - procurerFromTransaction # type: procurerFromTransactionT + self, + dataStore, # type: DataStore + procurerFromTransaction, # type: procurerFromTransactionT ): # type: (...) -> None """ @@ -514,7 +589,6 @@ def __init__( self._dataStore = dataStore self._procurerFromTransaction = procurerFromTransaction - @inlineCallbacks def procureSession(self, request, forceInsecure=False): # type: (IRequest, bool) -> Deferred @@ -534,27 +608,30 @@ def procureSession(self, request, forceInsecure=False): try: ipAddress = (request.client.host or b"").decode("ascii") except BaseException: - ipAddress = u"" + ipAddress = "" sip = sessionSchema.sessionIP yield upsert( - transaction, sip, - dict(session_id=session.identifier, ip_address=ipAddress, - address_family=(u"AF_INET6" if u":" in ipAddress - else u"AF_INET")), - dict(last_used=datetime.utcnow()) + transaction, + sip, + dict( + session_id=session.identifier, + ip_address=ipAddress, + address_family=("AF_INET6" if ":" in ipAddress else "AF_INET"), + ), + dict(last_used=datetime.utcnow()), ) # XXX This should set a savepoint because we don't want application # logic to be able to roll back the IP access log. returnValue(session) - procurerFromStoreT = Callable[[ISessionStore], ISessionProcurer] + def procurerFromDataStore( - dataStore, # type: DataStore - authorizers, # type: List[ISQLAuthorizer] - procurerFromStore=SessionProcurer # type: procurerFromStoreT + dataStore, # type: DataStore + authorizers, # type: List[ISQLAuthorizer] + procurerFromStore=SessionProcurer, # type: procurerFromStoreT ): # type: (...) -> ISessionProcurer """ @@ -568,40 +645,40 @@ def procurerFromDataStore( @return: L{Deferred} firing with L{ISessionProcurer} """ - allAuthorizers = [simpleAccountBinding.authorizer, - logMeIn.authorizer] + list(authorizers) + allAuthorizers = [ + simpleAccountBinding.authorizer, + logMeIn.authorizer, + ] + list(authorizers) return IPTrackingProcurer( dataStore, - lambda transaction: procurerFromStore(SessionStore( - transaction, allAuthorizers - )) + lambda transaction: procurerFromStore( + SessionStore(transaction, allAuthorizers) + ), ) +class _FunctionWithAuthorizer: -class _FunctionWithAuthorizer(object): - - authorizer = None # type: Any + authorizer = None # type: Any def __call__( - self, - sessionStore, # type: SessionStore - transaction, # type: Transaction - session # type: ISession + self, + sessionStore, # type: SessionStore + transaction, # type: Transaction + session, # type: ISession ): # type: (...) -> Any """ Signature for a function that can have an authorizer attached to it. """ -_authorizerFunction = Callable[ - [SessionStore, Transaction, ISession], - Any -] + +_authorizerFunction = Callable[[SessionStore, Transaction, ISession], Any] + @implementer(ISQLAuthorizer) @attr.s -class SimpleSQLAuthorizer(object): +class SimpleSQLAuthorizer: authorizationInterface = attr.ib(type=Type) _decorated = attr.ib(type=_authorizerFunction) @@ -612,7 +689,7 @@ def authorizationForSession(self, sessionStore, transaction, session): def authorizerFor( - authorizationInterface, # type: IInterface + authorizationInterface, # type: IInterface ): # type: (...) -> Callable[[Callable], _FunctionWithAuthorizer] """ @@ -627,21 +704,23 @@ def authorizeFoo(dataStore, sessionStore, transaction, session): @return: a decorator that can decorate a function with the signature C{(metadata, dataStore, sessionStore, transaction, session)} """ + def decorator(decorated): # type: (_authorizerFunction) -> _FunctionWithAuthorizer result = cast(_FunctionWithAuthorizer, decorated) - result.authorizer = SimpleSQLAuthorizer(authorizationInterface, - decorated) + result.authorizer = SimpleSQLAuthorizer( + authorizationInterface, decorated + ) return result - return decorator + return decorator @authorizerFor(ISimpleAccountBinding) def simpleAccountBinding( - sessionStore, # type: SessionStore - transaction, # type: Transaction - session # type: ISession + sessionStore, # type: SessionStore + transaction, # type: Transaction + session, # type: ISession ): # type: (...) -> AccountSessionBinding """ @@ -650,19 +729,18 @@ def simpleAccountBinding( return AccountSessionBinding(session, transaction) - @authorizerFor(ISimpleAccount) @inlineCallbacks def logMeIn( - sessionStore, # type: SessionStore - transaction, # type: Transaction - session # type: ISession + sessionStore, # type: SessionStore + transaction, # type: Transaction + session, # type: ISession ): # type: (...) -> Deferred """ Retrieve an L{ISimpleAccount} authorization. """ - binding = ((yield session.authorize([ISimpleAccountBinding])) - [ISimpleAccountBinding]) - returnValue(next(iter((yield binding.boundAccounts())), - None)) + binding = (yield session.authorize([ISimpleAccountBinding]))[ + ISimpleAccountBinding + ] + returnValue(next(iter((yield binding.boundAccounts())), None)) diff --git a/src/klein/storage/_sql_generic.py b/src/klein/storage/_sql_generic.py index 5f48ff01a..b2b6c99b3 100644 --- a/src/klein/storage/_sql_generic.py +++ b/src/klein/storage/_sql_generic.py @@ -13,8 +13,13 @@ from sqlalchemy import create_engine -from twisted.internet.defer import (Deferred, gatherResults, inlineCallbacks, - returnValue, succeed) +from twisted.internet.defer import ( + Deferred, + gatherResults, + inlineCallbacks, + returnValue, + succeed, +) from zope.interface import Interface, implementer @@ -30,28 +35,33 @@ ROLLED_BACK = "rolled back" ROLLBACK_FAILED = "rollback failed" -if TYPE_CHECKING: # pragma: no cover - T = TypeVar('T') +if TYPE_CHECKING: # pragma: no cover + T = TypeVar("T") from twisted.internet.interfaces import IReactorThreads + IReactorThreads from typing import Iterable + Iterable from typing import Callable + Callable from twisted.web.iweb import IRequest + IRequest @attr.s -class Transaction(object): +class Transaction: """ Wrapper around a SQLAlchemy connection which is invalidated when the transaction is committed or rolled back. """ + _connection = attr.ib(type=_sqlAlchemyConnection) _transaction = attr.ib(type=_sqlAlchemyTransaction) - _parent = attr.ib(type='Optional[Transaction]', default=None) - _stopped = attr.ib(type=Text, default=u"") + _parent = attr.ib(type="Optional[Transaction]", default=None) + _stopped = attr.ib(type=Text, default="") _completeDeferred = attr.ib(type=Deferred, default=Factory(Deferred)) def _checkStopped(self): @@ -64,7 +74,6 @@ def _checkStopped(self): if self._parent is not None: self._parent._checkStopped() - def execute(self, statement, *multiparams, **params): # type: (Any, *Any, **Any) -> Deferred """ @@ -74,7 +83,6 @@ def execute(self, statement, *multiparams, **params): self._checkStopped() return self._connection.execute(statement, *multiparams, **params) - def commit(self): # type: () -> Deferred """ @@ -84,10 +92,9 @@ def commit(self): self._stopped = COMMITTING return self._transaction.commit().addCallbacks( (lambda commitResult: self._finishWith(COMMITTED)), - (lambda commitFailure: self._finishWith(COMMIT_FAILED)) + (lambda commitFailure: self._finishWith(COMMIT_FAILED)), ) - def rollback(self): # type: () -> Deferred """ @@ -97,10 +104,9 @@ def rollback(self): self._stopped = ROLLING_BACK return self._transaction.rollback().addCallbacks( (lambda commitResult: self._finishWith(ROLLED_BACK)), - (lambda commitFailure: self._finishWith(ROLLBACK_FAILED)) + (lambda commitFailure: self._finishWith(ROLLBACK_FAILED)), ) - def _finishWith(self, stopStatus): # type: (Text) -> None """ @@ -109,7 +115,6 @@ def _finishWith(self, stopStatus): self._stopped = stopStatus self._completeDeferred.callback(stopStatus) - @inlineCallbacks def savepoint(self): # type: () -> Deferred @@ -120,11 +125,11 @@ def savepoint(self): committed, this transaction's C{execute} method will execute within the context of that savepoint. """ - returnValue(Transaction( - self._connection, (yield self._connection.begin_nested()), - self - )) - + returnValue( + Transaction( + self._connection, (yield self._connection.begin_nested()), self + ) + ) def subtransact(self, logic): # type: (Callable[[Transaction], Deferred]) -> Deferred @@ -133,7 +138,6 @@ def subtransact(self, logic): """ return Transactor(self.savepoint).transact(logic) - def maybeCommit(self): # type: () -> Deferred """ @@ -144,7 +148,6 @@ def maybeCommit(self): return succeed(None) return self.commit() - def maybeRollback(self): # type: () -> Deferred """ @@ -157,13 +160,13 @@ def maybeRollback(self): @attr.s -class Transactor(object): +class Transactor: """ A context manager that represents the lifecycle of a transaction when paired with application code. """ - _newTransaction = attr.ib(type='Callable[[], Deferred]') + _newTransaction = attr.ib(type="Callable[[], Deferred]") _transaction = attr.ib(type=Optional[Transaction], default=None) @inlineCallbacks @@ -204,9 +207,8 @@ def transact(self, logic): returnValue(result) - @attr.s(hash=False) -class DataStore(object): +class DataStore: """ L{DataStore} is a generic storage object that connect to an SQL database, run transactions, and manage schema metadata. @@ -222,7 +224,8 @@ def newTransaction(self): Create a new Klein transaction. """ alchimiaConnection = ( - self._freeConnections.popleft() if self._freeConnections + self._freeConnections.popleft() + if self._freeConnections else (yield self._engine.connect()) ) alchimiaTransaction = yield alchimiaConnection.begin() @@ -233,8 +236,8 @@ def recycleTransaction(anything): # type: (T) -> T self._freeConnections.append(alchimiaConnection) return anything - returnValue(kleinTransaction) + returnValue(kleinTransaction) def transact(self, callable): # type: (Callable[[Transaction], Any]) -> Any @@ -252,7 +255,6 @@ def transact(self, callable): """ return Transactor(self.newTransaction).transact(callable) - @classmethod def open(cls, reactor, dbURL): # type: (IReactorThreads, Text) -> DataStore @@ -265,8 +267,9 @@ def open(cls, reactor, dbURL): @param dbURL: the SQLAlchemy database URI to connect to. @type dbURL: L{str} """ - return cls(create_engine(dbURL, reactor=reactor, - strategy=TWISTED_STRATEGY)) + return cls( + create_engine(dbURL, reactor=reactor, strategy=TWISTED_STRATEGY) + ) class ITransactionRequestAssociator(Interface): @@ -274,12 +277,14 @@ class ITransactionRequestAssociator(Interface): Associates transactions with requests. """ + @implementer(ITransactionRequestAssociator) @attr.s -class TransactionRequestAssociator(object): +class TransactionRequestAssociator: """ Does the thing the interface says. """ + _map = attr.ib(type=dict, default=Factory(dict)) committing = attr.ib(type=bool, default=False) @@ -301,8 +306,10 @@ def commitAll(self): Commit all associated transactions. """ self.committing = True - return gatherResults([value.maybeCommit() - for value in self._map.values()]) + return gatherResults( + [value.maybeCommit() for value in self._map.values()] + ) + @inlineCallbacks def requestBoundTransaction(request, dataStore): @@ -343,6 +350,7 @@ def requestBoundTransaction(request, dataStore): def finishCommit(result): # type: (Any) -> Deferred return assoc.commitAll() + request.notifyFinish().addBoth(finishCommit) # originalWrite = request.write diff --git a/src/klein/storage/interfaces.py b/src/klein/storage/interfaces.py index 5e1420821..8bbd733ca 100644 --- a/src/klein/storage/interfaces.py +++ b/src/klein/storage/interfaces.py @@ -1,14 +1,12 @@ - from typing import TYPE_CHECKING, Union from ._istorage import ISQLAuthorizer as _ISQLAuthorizer -if TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover from ._sql import SimpleSQLAuthorizer + ISQLAuthorizer = Union[_ISQLAuthorizer, SimpleSQLAuthorizer] else: ISQLAuthorizer = _ISQLAuthorizer -__all__ = [ - 'ISQLAuthorizer' -] +__all__ = ["ISQLAuthorizer"] diff --git a/src/klein/storage/sql.py b/src/klein/storage/sql.py index 68c898b54..8e214cfbf 100644 --- a/src/klein/storage/sql.py +++ b/src/klein/storage/sql.py @@ -1,10 +1,5 @@ - -from ._sql import ( - SessionSchema, authorizerFor, procurerFromDataStore -) -from ._sql_generic import ( - DataStore, Transaction -) +from ._sql import SessionSchema, authorizerFor, procurerFromDataStore +from ._sql_generic import DataStore, Transaction __all__ = [ "procurerFromDataStore", @@ -14,6 +9,7 @@ "Transaction", ] -if __name__ == '__main__': +if __name__ == "__main__": import sys + sys.stdout.write(SessionSchema.withMetadata().migrationSQL()) diff --git a/src/klein/storage/test/test_security.py b/src/klein/storage/test/test_security.py index 59651fedf..60c73641e 100644 --- a/src/klein/storage/test/test_security.py +++ b/src/klein/storage/test/test_security.py @@ -1,10 +1,9 @@ - from twisted.trial.unittest import TestCase from klein.storage import security -class SQLTests(TestCase): +class SQLTests(TestCase): def test_security(self): # type: () -> None """ diff --git a/src/klein/storage/test/test_sql.py b/src/klein/storage/test/test_sql.py index 8a0f73d4b..040b5960a 100644 --- a/src/klein/storage/test/test_sql.py +++ b/src/klein/storage/test/test_sql.py @@ -1,8 +1,8 @@ - from twisted.trial.unittest import TestCase from klein.storage import sql + class SQLTests(TestCase): def test_sql(self): # type: () -> None From fa14353b4d0848d0263990f484e198bc1327b750 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 May 2022 22:40:21 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/klein/interfaces.py | 19 +++++++++--------- src/klein/storage/_istorage.py | 4 +++- src/klein/storage/_security.py | 5 +++-- src/klein/storage/_sql.py | 32 +++++++++++++++---------------- src/klein/storage/_sql_generic.py | 12 +++++------- src/klein/storage/interfaces.py | 1 + src/klein/storage/sql.py | 1 + 7 files changed, 39 insertions(+), 35 deletions(-) diff --git a/src/klein/interfaces.py b/src/klein/interfaces.py index 968a29dba..e634c8dba 100644 --- a/src/klein/interfaces.py +++ b/src/klein/interfaces.py @@ -16,20 +16,21 @@ TransactionEnded, ) + if TYPE_CHECKING: # pragma: no cover - from ._storage.memory import MemorySessionStore, MemorySession + from typing import Union + + from ._dihttp import RequestComponent, RequestURL + from ._form import Field, FieldInjector, RenderableFormParam + from ._isession import IRequestLifecycleT as _IRequestLifecycleT + from ._session import Authorization, SessionProcurer + from ._storage.memory import MemorySession, MemorySessionStore from ._storage.sql import ( + AccountSessionBinding, + IPTrackingProcurer, SessionStore, SQLAccount, - IPTrackingProcurer, - AccountSessionBinding, ) - from ._session import SessionProcurer, Authorization - from ._form import Field, RenderableFormParam, FieldInjector - from ._isession import IRequestLifecycleT as _IRequestLifecycleT - from ._dihttp import RequestURL, RequestComponent - - from typing import Union ISessionStore = Union[_ISessionStore, MemorySessionStore, SessionStore] ISessionProcurer = Union[ diff --git a/src/klein/storage/_istorage.py b/src/klein/storage/_istorage.py index 2416d85b0..2bfb1272f 100644 --- a/src/klein/storage/_istorage.py +++ b/src/klein/storage/_istorage.py @@ -4,9 +4,11 @@ from .._typing import ifmethod + if TYPE_CHECKING: # pragma: no cover from twisted.internet.defer import Deferred - from ..interfaces import ISessionStore, ISession + + from ..interfaces import ISession, ISessionStore from ._sql_generic import Transaction ISession, ISessionStore, Deferred, Transaction diff --git a/src/klein/storage/_security.py b/src/klein/storage/_security.py index 7a19b0d00..c36f53c1c 100644 --- a/src/klein/storage/_security.py +++ b/src/klein/storage/_security.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Callable, Optional, TYPE_CHECKING, Text, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Text, Tuple from unicodedata import normalize from passlib.context import CryptContext @@ -7,8 +7,9 @@ from twisted.internet.defer import Deferred, inlineCallbacks, returnValue from twisted.internet.threads import deferToThread + if TYPE_CHECKING: # pragma: no cover - Text, Callable, Deferred, Optional, Tuple, Any + str, Callable, Deferred, Optional, Tuple, Any passlibContextWithGoodDefaults = partial(CryptContext, schemes=["bcrypt"]) diff --git a/src/klein/storage/_sql.py b/src/klein/storage/_sql.py index 59ed7098b..ba3249b01 100644 --- a/src/klein/storage/_sql.py +++ b/src/klein/storage/_sql.py @@ -3,13 +3,13 @@ from functools import reduce from os import urandom from typing import ( + TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, - TYPE_CHECKING, Text, Type, TypeVar, @@ -20,8 +20,6 @@ import attr from attr import Factory from attr.validators import instance_of as an - - from sqlalchemy import ( Boolean, Column, @@ -36,6 +34,8 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.schema import CreateTable from sqlalchemy.sql.expression import select +from zope.interface import implementer +from zope.interface.interfaces import IInterface from twisted.internet.defer import ( gatherResults, @@ -45,12 +45,6 @@ ) from twisted.python.compat import unicode -from zope.interface import implementer -from zope.interface.interfaces import IInterface - -from ._security import checkAndReset, computeKeyText -from ._sql_generic import Transaction, requestBoundTransaction -from .interfaces import ISQLAuthorizer from .. import SessionProcurer from ..interfaces import ( ISession, @@ -61,12 +55,18 @@ NoSuchSession, SessionMechanism, ) +from ._security import checkAndReset, computeKeyText +from ._sql_generic import Transaction, requestBoundTransaction +from .interfaces import ISQLAuthorizer + if TYPE_CHECKING: # pragma: no cover import sqlalchemy + from twisted.internet.defer import Deferred from twisted.internet.interfaces import IReactorThreads from twisted.web.iweb import IRequest + from ._sql_generic import DataStore ( @@ -76,7 +76,7 @@ Type, Iterable, IReactorThreads, - Text, + str, List, sqlalchemy, Dict, @@ -92,7 +92,7 @@ @attr.s class SQLSession: _sessionStore = attr.ib(type="SessionStore") - identifier = attr.ib(type=Text) + identifier = attr.ib(type=str) isConfidential = attr.ib(type=bool) authenticatedBy = attr.ib(type=SessionMechanism) @@ -129,8 +129,8 @@ class SessionIPInformation: Information about a session being used from a given IP address. """ - id = attr.ib(validator=an(str), type=Text) - ip = attr.ib(validator=an(str), type=Text) + id = attr.ib(validator=an(str), type=str) + ip = attr.ib(validator=an(str), type=str) when = attr.ib(validator=an(datetime), type=datetime) @@ -382,9 +382,9 @@ class SQLAccount: """ _transaction = attr.ib(type=Transaction) - accountID = attr.ib(type=Text) - username = attr.ib(type=Text) - email = attr.ib(type=Text) + accountID = attr.ib(type=str) + username = attr.ib(type=str) + email = attr.ib(type=str) def bindSession(self, session): # type: (ISession) -> Deferred diff --git a/src/klein/storage/_sql_generic.py b/src/klein/storage/_sql_generic.py index b2b6c99b3..5c515c475 100644 --- a/src/klein/storage/_sql_generic.py +++ b/src/klein/storage/_sql_generic.py @@ -4,14 +4,13 @@ from collections import deque from sys import exc_info -from typing import Any, Optional, TYPE_CHECKING, Text, TypeVar - -from alchimia import TWISTED_STRATEGY +from typing import TYPE_CHECKING, Any, Optional, Text, TypeVar import attr +from alchimia import TWISTED_STRATEGY from attr import Factory - from sqlalchemy import create_engine +from zope.interface import Interface, implementer from twisted.internet.defer import ( Deferred, @@ -21,10 +20,9 @@ succeed, ) -from zope.interface import Interface, implementer - from ..interfaces import TransactionEnded + _sqlAlchemyConnection = Any _sqlAlchemyTransaction = Any @@ -61,7 +59,7 @@ class Transaction: _connection = attr.ib(type=_sqlAlchemyConnection) _transaction = attr.ib(type=_sqlAlchemyTransaction) _parent = attr.ib(type="Optional[Transaction]", default=None) - _stopped = attr.ib(type=Text, default="") + _stopped = attr.ib(type=str, default="") _completeDeferred = attr.ib(type=Deferred, default=Factory(Deferred)) def _checkStopped(self): diff --git a/src/klein/storage/interfaces.py b/src/klein/storage/interfaces.py index 8bbd733ca..fa083f677 100644 --- a/src/klein/storage/interfaces.py +++ b/src/klein/storage/interfaces.py @@ -2,6 +2,7 @@ from ._istorage import ISQLAuthorizer as _ISQLAuthorizer + if TYPE_CHECKING: # pragma: no cover from ._sql import SimpleSQLAuthorizer diff --git a/src/klein/storage/sql.py b/src/klein/storage/sql.py index 8e214cfbf..62288a6c8 100644 --- a/src/klein/storage/sql.py +++ b/src/klein/storage/sql.py @@ -1,6 +1,7 @@ from ._sql import SessionSchema, authorizerFor, procurerFromDataStore from ._sql_generic import DataStore, Transaction + __all__ = [ "procurerFromDataStore", "authorizerFor", From 40a72a3471919e72d13f5400ec9804b19feddeee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Apr 2023 20:56:45 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/klein/storage/_security.py | 2 +- src/klein/storage/_sql.py | 1 - src/klein/storage/_sql_generic.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/klein/storage/_security.py b/src/klein/storage/_security.py index c36f53c1c..45ceb6ce8 100644 --- a/src/klein/storage/_security.py +++ b/src/klein/storage/_security.py @@ -1,5 +1,5 @@ from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Optional, Text, Tuple +from typing import Any, Callable, Optional, TYPE_CHECKING, Tuple from unicodedata import normalize from passlib.context import CryptContext diff --git a/src/klein/storage/_sql.py b/src/klein/storage/_sql.py index ba3249b01..2c750eb53 100644 --- a/src/klein/storage/_sql.py +++ b/src/klein/storage/_sql.py @@ -10,7 +10,6 @@ Iterable, List, Optional, - Text, Type, TypeVar, cast, diff --git a/src/klein/storage/_sql_generic.py b/src/klein/storage/_sql_generic.py index 5c515c475..0750751ec 100644 --- a/src/klein/storage/_sql_generic.py +++ b/src/klein/storage/_sql_generic.py @@ -4,7 +4,7 @@ from collections import deque from sys import exc_info -from typing import TYPE_CHECKING, Any, Optional, Text, TypeVar +from typing import Any, Optional, TYPE_CHECKING, TypeVar import attr from alchimia import TWISTED_STRATEGY From dc4629f6097de671f3ae45b626f5efe4ecab8b53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Apr 2023 21:02:18 +0000 Subject: [PATCH 5/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/klein/storage/_security.py | 2 +- src/klein/storage/_sql_generic.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/klein/storage/_security.py b/src/klein/storage/_security.py index 45ceb6ce8..e35321e42 100644 --- a/src/klein/storage/_security.py +++ b/src/klein/storage/_security.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Callable, Optional, TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple from unicodedata import normalize from passlib.context import CryptContext diff --git a/src/klein/storage/_sql_generic.py b/src/klein/storage/_sql_generic.py index 0750751ec..44f9b1eea 100644 --- a/src/klein/storage/_sql_generic.py +++ b/src/klein/storage/_sql_generic.py @@ -4,7 +4,7 @@ from collections import deque from sys import exc_info -from typing import Any, Optional, TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, Optional, TypeVar import attr from alchimia import TWISTED_STRATEGY From fa8a069fe15963d48883c0cf1de38430bf0cbc3b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 May 2023 22:24:53 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/klein/storage/_sql.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/klein/storage/_sql.py b/src/klein/storage/_sql.py index 2c750eb53..a45a3c64d 100644 --- a/src/klein/storage/_sql.py +++ b/src/klein/storage/_sql.py @@ -657,7 +657,6 @@ def procurerFromDataStore( class _FunctionWithAuthorizer: - authorizer = None # type: Any def __call__(