diff --git a/changelog.d/3-bug-fixes/pg-reconnect b/changelog.d/3-bug-fixes/pg-reconnect new file mode 100644 index 00000000000..4aa05fe5fc3 --- /dev/null +++ b/changelog.d/3-bug-fixes/pg-reconnect @@ -0,0 +1 @@ +Reconnect and retry queries when the PostgreSQL server restarts \ No newline at end of file diff --git a/flake.lock b/flake.lock index 87862f14021..50a36f10c4e 100644 --- a/flake.lock +++ b/flake.lock @@ -186,16 +186,16 @@ "hasql-migration": { "flake": false, "locked": { - "lastModified": 1777384964, - "narHash": "sha256-NRFZUDR4cW6jRihO311glqtlIlcRKSKeBXJErHuXf+k=", + "lastModified": 1777986637, + "narHash": "sha256-NdrqeecEdokSCqBm6BqZ9mCDnqgDzAiN5BMwjMcvv1Q=", "owner": "wireapp", "repo": "hasql-migration", - "rev": "ef03ac6410c94444bf1807fc4eda1db6b0974984", + "rev": "6fe20bfb145dde56254089902734e2fcb079fc19", "type": "github" }, "original": { "owner": "wireapp", - "ref": "allow-no-transaction", + "ref": "upgrade-hasql", "repo": "hasql-migration", "type": "github" } @@ -314,6 +314,23 @@ "type": "github" } }, + "postgresql-connection-string": { + "flake": false, + "locked": { + "lastModified": 1778144330, + "narHash": "sha256-w/TFQX7PIsLQTG+Es2yla584Y3T7Q6H1iIlqFap8qUk=", + "owner": "wireapp", + "repo": "postgresql-connection-string", + "rev": "ff98790cb3058545ea978ea143cd57b04360c48c", + "type": "github" + }, + "original": { + "owner": "wireapp", + "ref": "expose-from-key-value-params", + "repo": "postgresql-connection-string", + "type": "github" + } + }, "postie": { "flake": false, "locked": { @@ -345,6 +362,7 @@ "nixpkgs": "nixpkgs", "nixpkgs-unstable": "nixpkgs-unstable", "nixpkgs_24_11": "nixpkgs_24_11", + "postgresql-connection-string": "postgresql-connection-string", "postie": "postie", "sbomnix": "sbomnix", "servant-openapi3": "servant-openapi3", diff --git a/flake.nix b/flake.nix index b5c9a151a50..5a9b40ea81a 100644 --- a/flake.nix +++ b/flake.nix @@ -87,7 +87,12 @@ }; hasql-migration = { - url = "github:wireapp/hasql-migration?ref=allow-no-transaction"; + url = "github:wireapp/hasql-migration?ref=upgrade-hasql"; + flake = false; + }; + + postgresql-connection-string = { + url = "github:wireapp/postgresql-connection-string?ref=expose-from-key-value-params"; flake = false; }; }; diff --git a/libs/extended/default.nix b/libs/extended/default.nix index 3c544f93cd3..476c68fdfd4 100644 --- a/libs/extended/default.nix +++ b/libs/extended/default.nix @@ -32,6 +32,7 @@ , memory , metrics-wai , monad-control +, postgresql-connection-string , prometheus-client , retry , servant @@ -80,6 +81,7 @@ mkDerivation { memory metrics-wai monad-control + postgresql-connection-string prometheus-client retry servant diff --git a/libs/extended/extended.cabal b/libs/extended/extended.cabal index 006a74c703d..58d587bed59 100644 --- a/libs/extended/extended.cabal +++ b/libs/extended/extended.cabal @@ -112,6 +112,7 @@ library , memory , metrics-wai , monad-control + , postgresql-connection-string , prometheus-client , retry , servant diff --git a/libs/extended/src/Hasql/Pool/Extended.hs b/libs/extended/src/Hasql/Pool/Extended.hs index 8f86a5a153b..cf6e6d0f4f1 100644 --- a/libs/extended/src/Hasql/Pool/Extended.hs +++ b/libs/extended/src/Hasql/Pool/Extended.hs @@ -22,13 +22,12 @@ import Data.Map as Map import Data.Misc import Data.Set qualified as Set import Data.UUID -import Hasql.Connection.Setting qualified as HasqlSetting -import Hasql.Connection.Setting.Connection qualified as HasqlConn -import Hasql.Connection.Setting.Connection.Param qualified as HasqlConfig +import Hasql.Connection.Settings qualified as HasqlConnSettings import Hasql.Pool as HasqlPool import Hasql.Pool.Config qualified as HasqlPool import Hasql.Pool.Observation import Imports +import PostgresqlConnectionString qualified import Prometheus import Util.Options @@ -50,23 +49,21 @@ instance FromJSON PoolConfig where -- | Creates a pool from postgres config params -- --- HasqlConn.params translates pgParams into connection (which just holds the connection string and is not a real connection) --- HasqlSetting.connection unwraps the connection string out of connection --- HasqlPool.staticConnectionSettings translates the connection string to the pool settings +-- HasqlPool.staticConnectionSettings translates the connection settings to the pool settings -- HasqlPool.settings translates the pool settings into pool config -- HasqlPool.acquire creates the pool. -- ezpz. initPostgresPool :: PoolConfig -> Map Text Text -> Maybe FilePathSecrets -> IO HasqlPool.Pool initPostgresPool config pgConfig mFpSecrets = do mPw <- for mFpSecrets initCredentials - let pgConfigWithPw = maybe pgConfig (\pw -> Map.insert "password" pw pgConfig) mPw - pgParams = Map.foldMapWithKey (\k v -> [HasqlConfig.other k v]) pgConfigWithPw + let pgSettings = + HasqlConnSettings.connectionString (PostgresqlConnectionString.toUrl $ PostgresqlConnectionString.fromKeyValueParams pgConfig) + <> foldMap HasqlConnSettings.password mPw metrics <- initHasqlPoolMetrics connsRef <- newIORef $ Connections mempty mempty mempty HasqlPool.acquire $ HasqlPool.settings - [ HasqlPool.staticConnectionSettings $ - [HasqlSetting.connection $ HasqlConn.params pgParams], + [ HasqlPool.staticConnectionSettings pgSettings, HasqlPool.size config.size, HasqlPool.acquisitionTimeout config.acquisitionTimeout.duration, HasqlPool.agingTimeout config.agingTimeout.duration, diff --git a/libs/wire-subsystems/src/Wire/ConversationStore/Migration.hs b/libs/wire-subsystems/src/Wire/ConversationStore/Migration.hs index 4b32b9d0ccc..881ccd38ea3 100644 --- a/libs/wire-subsystems/src/Wire/ConversationStore/Migration.hs +++ b/libs/wire-subsystems/src/Wire/ConversationStore/Migration.hs @@ -263,7 +263,7 @@ saveConvToPostgres allConvData = do meta.cnvmCellsState, meta.cnvmParent ) - runTransaction ReadCommitted Write $ do + runTransactionWithRetry ReadCommitted Write $ do Transaction.statement convRow insertConv Transaction.statement localMemberColumns insertLocalMembers Transaction.statement remoteMemberColumns insertRemoteMembers @@ -468,7 +468,7 @@ getRemoteMemberStatusFromCassandra uid = withCassandra $ do saveRemoteMemberStatusToPostgres :: (PGConstraints r) => UserId -> Map (Remote ConvId) MemberStatus -> Sem r () saveRemoteMemberStatusToPostgres uid statusses = - runTransaction ReadCommitted Write $ do + runTransactionWithRetry ReadCommitted Write $ do Transaction.statement statusColumns insertStatuses Transaction.statement (DeleteUser, uid) markDeletionPendingStmt where diff --git a/libs/wire-subsystems/src/Wire/ConversationStore/Postgres.hs b/libs/wire-subsystems/src/Wire/ConversationStore/Postgres.hs index 7102b65d4e7..2db307aa8b1 100644 --- a/libs/wire-subsystems/src/Wire/ConversationStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/ConversationStore/Postgres.hs @@ -156,7 +156,7 @@ upsertConversationImpl lcnv nc = do meta.cnvmParent, fmap (.depth) hconfig ) - runTransaction ReadCommitted Write $ do + runTransactionWithRetry ReadCommitted Write $ do Transaction.statement convRow insertConvStatement upsertMembersTransaction storedConv.id_ $ UserList localUsers remoteUsers pure storedConv @@ -206,7 +206,7 @@ deleteConversationImpl cid = getConversationImpl :: (PGConstraints r) => ConvId -> Sem r (Maybe StoredConversation) getConversationImpl cid = - runTransaction ReadCommitted Read $ do + runTransactionWithRetry ReadCommitted Read $ do mConvRow <- Transaction.statement cid selectConvMetadata case mConvRow of Nothing -> pure Nothing @@ -624,7 +624,7 @@ deleteTeamConversationsImpl tid = -- MEMBER OPERATIONS upsertMembersImpl :: (PGConstraints r) => ConvId -> UserList (UserId, RoleName) -> Sem r ([LocalMember], [RemoteMember]) upsertMembersImpl convId users@(UserList lusers rusers) = do - runTransaction ReadCommitted Write $ upsertMembersTransaction convId users + runTransactionWithRetry ReadCommitted Write $ upsertMembersTransaction convId users pure (map newMemberWithRole lusers, map newRemoteMemberWithRole rusers) upsertMembersTransaction :: ConvId -> UserList (UserId, RoleName) -> Transaction () @@ -680,7 +680,7 @@ createBotMemberImpl serviceRef botId convId = do getLocalMemberImpl :: (PGConstraints r) => ConvId -> UserId -> Sem r (Maybe LocalMember) getLocalMemberImpl convId userId = do mRow <- - runSession $ do + runSessionWithRetry $ do mDirectMember <- HasqlSession.statement (convId, userId) selectMember case mDirectMember of Nothing -> HasqlSession.statement (convId, userId) selectParentMember @@ -769,7 +769,7 @@ type RemoteMemberRow = (ConvId, Domain, UserId, RoleName) getRemoteMemberImpl :: (PGConstraints r) => ConvId -> Remote UserId -> Sem r (Maybe RemoteMember) getRemoteMemberImpl convId (tUntagged -> Qualified uid domain) = do mRow <- - runSession $ do + runSessionWithRetry $ do mDirectMember <- HasqlSession.statement (convId, domain, uid) selectMember case mDirectMember of Nothing -> HasqlSession.statement (convId, domain, uid) selectParentMember @@ -958,7 +958,7 @@ setOtherRemoteMember cid (tUntagged -> Qualified uid domain) upd = deleteMembersImpl :: (PGConstraints r) => ConvId -> UserList UserId -> Sem r () deleteMembersImpl cid users = - runTransaction ReadCommitted Write $ do + runTransactionWithRetry ReadCommitted Write $ do Transaction.statement (cid, users.ulLocals) deleteLocalsStmt for_ (bucketRemote users.ulRemotes) $ \(tUntagged -> Qualified remotes domain) -> Transaction.statement (cid, domain, remotes) deleteRemotesStmt diff --git a/libs/wire-subsystems/src/Wire/MeetingsStore/Postgres.hs b/libs/wire-subsystems/src/Wire/MeetingsStore/Postgres.hs index 65f8f890626..6e83d61ff5d 100644 --- a/libs/wire-subsystems/src/Wire/MeetingsStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/MeetingsStore/Postgres.hs @@ -31,18 +31,17 @@ import Data.Time.Clock import Data.UUID (UUID, nil) import Data.Vector qualified as V import Hasql.Pool -import Hasql.Session import Hasql.Statement import Hasql.TH import Imports import Polysemy -import Polysemy.Error (Error, throw) +import Polysemy.Error (Error) import Polysemy.Input import Wire.API.Meeting (Recurrence) import Wire.API.PostgresMarshall (PostgresMarshall (..), PostgresUnmarshall (..), dimapPG) import Wire.API.User.Identity (EmailAddress, fromEmail) import Wire.MeetingsStore -import Wire.Postgres (PGConstraints) +import Wire.Postgres interpretMeetingsStoreToPostgres :: (PGConstraints r) => @@ -80,7 +79,6 @@ createMeetingImpl :: Bool -> Sem r StoredMeeting createMeetingImpl title creator startTime endTime recurrence convId emails trial = do - pool <- input now <- liftIO getCurrentTime let sm = StoredMeeting @@ -96,8 +94,7 @@ createMeetingImpl title creator startTime endTime recurrence convId emails trial createdAt = now, updatedAt = now } - result <- liftIO $ use pool $ statement sm insertStatement - either throw pure result + runStatement sm insertStatement insertStatement :: Statement StoredMeeting StoredMeeting insertStatement = @@ -187,18 +184,12 @@ updateMeetingImpl :: Maybe (Maybe Recurrence) -> Sem r (Maybe StoredMeeting) updateMeetingImpl meetingId mTitle mStartDate mEndDate mRecurrence = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + case mRecurrence of + Nothing -> + runStatement (mTitle, mStartDate, mEndDate, meetingId) updateWithoutRecurrenceStatement + Just recurrence -> + runStatement (mTitle, mStartDate, mEndDate, recurrence, meetingId) updateWithRecurrenceStatement where - session :: Session (Maybe StoredMeeting) - session = - case mRecurrence of - Nothing -> - statement (mTitle, mStartDate, mEndDate, meetingId) updateWithoutRecurrenceStatement - Just recurrence -> - statement (mTitle, mStartDate, mEndDate, recurrence, meetingId) updateWithRecurrenceStatement - updateWithRecurrenceStatement :: Statement UpdateMeetingWithRecurrenceTuple (Maybe StoredMeeting) updateWithRecurrenceStatement = dimapPG @@ -256,12 +247,8 @@ deleteMeetingImpl :: MeetingId -> Sem r () deleteMeetingImpl meetingId = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + runStatement (toUUID meetingId) deleteStatement where - session :: Session () - session = statement (toUUID meetingId) deleteStatement deleteStatement :: Statement UUID () deleteStatement = [resultlessStatement| @@ -276,9 +263,7 @@ getMeetingImpl :: MeetingId -> Sem r (Maybe StoredMeeting) getMeetingImpl meetingId = do - pool <- input - result <- liftIO $ use pool $ statement (toUUID meetingId) getMeetingStatement - either throw pure result + runStatement (toUUID meetingId) getMeetingStatement getMeetingStatement :: Statement UUID (Maybe StoredMeeting) getMeetingStatement = @@ -303,12 +288,8 @@ listMeetingsByUserImpl :: UTCTime -> Sem r [StoredMeeting] listMeetingsByUserImpl userId cutoffTime = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + runStatement (toUUID userId, cutoffTime) $ V.toList <$> listStatement where - session :: Session [StoredMeeting] - session = statement (toUUID userId, cutoffTime) $ V.toList <$> listStatement listStatement :: Statement (UUID, UTCTime) (V.Vector StoredMeeting) listStatement = refineResult @@ -331,12 +312,8 @@ listMeetingsByConversationImpl :: UTCTime -> Sem r [StoredMeeting] listMeetingsByConversationImpl convId cutoffTime = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + runStatement (toUUID convId, cutoffTime) $ V.toList <$> listStatement where - session :: Session [StoredMeeting] - session = statement (toUUID convId, cutoffTime) $ V.toList <$> listStatement listStatement :: Statement (UUID, UTCTime) (V.Vector StoredMeeting) listStatement = refineResult @@ -359,13 +336,8 @@ addInvitedEmailsImpl :: [EmailAddress] -> Sem r () addInvitedEmailsImpl meetingId emails = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + runStatement (V.fromList (fromEmail <$> emails), toUUID meetingId) addEmailStatement where - session :: Session () - session = statement (V.fromList (fromEmail <$> emails), toUUID meetingId) addEmailStatement - addEmailStatement :: Statement (V.Vector Text, UUID) () addEmailStatement = [resultlessStatement| @@ -381,12 +353,8 @@ removeInvitedEmailsImpl :: [EmailAddress] -> Sem r () removeInvitedEmailsImpl meetingId emails = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + runStatement (V.fromList (fromEmail <$> emails), toUUID meetingId) removeEmailStatement where - session :: Session () - session = statement (V.fromList (fromEmail <$> emails), toUUID meetingId) removeEmailStatement removeEmailStatement :: Statement (V.Vector Text, UUID) () removeEmailStatement = [resultlessStatement| diff --git a/libs/wire-subsystems/src/Wire/Postgres.hs b/libs/wire-subsystems/src/Wire/Postgres.hs index 84cca14b183..d18de5f10a8 100644 --- a/libs/wire-subsystems/src/Wire/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/Postgres.hs @@ -39,7 +39,9 @@ module Wire.Postgres -- * Runners runStatement, runSession, + runSessionWithRetry, runTransaction, + runTransactionWithRetry, runPipeline, parseCount, PGConstraints, @@ -68,11 +70,12 @@ where import Control.Monad.Trans.State import Data.Functor.Contravariant import Data.Id -import Data.Text qualified as T -import Data.Text.Encoding qualified as T +import Data.Text qualified as Text +import Data.Text.Encoding qualified as Text import Data.Time.Clock import Hasql.Decoders qualified as Dec import Hasql.Encoders qualified as Enc +import Hasql.Errors import Hasql.Pipeline (Pipeline) import Hasql.Pool import Hasql.Pool qualified as Hasql @@ -85,6 +88,7 @@ import Imports import Polysemy import Polysemy.Error (Error, throw) import Polysemy.Input +import PostgreSQL.ErrorCodes qualified as PostgreSQL import Wire.API.Pagination type PGConstraints r = @@ -93,22 +97,69 @@ type PGConstraints r = Member (Error Hasql.UsageError) r ) -runSession :: +-- | Resets the pool if it detects server errors due to admin intervention. +-- Things like server restart. Then retries the session. +-- +-- Inspired by https://github.com/nikita-volkov/hasql-pool/issues/27 +useWithResetAndRetry :: forall a. Pool -> Session a -> IO (Either UsageError a) +useWithResetAndRetry pool sess = go maxRetries + where + maxRetries :: Int + maxRetries = 5 + + resettableErrors :: [ByteString] + resettableErrors = [PostgreSQL.admin_shutdown, PostgreSQL.crash_shutdown, PostgreSQL.cannot_connect_now, PostgreSQL.database_dropped] + + go :: Int -> IO (Either UsageError a) + go 0 = use pool sess + go n = do + eithRes <- use pool sess + case eithRes of + Left (SessionUsageError (StatementSessionError _ _ _ _ _ (ServerStatementError (ServerError errCode _ _ _ _)))) -> do + if (Text.encodeUtf8 errCode `elem` resettableErrors) + then do + release pool + go (n - 1) + else pure eithRes + _ -> pure eithRes + +-- | Runs a 'Session' using the 'Hasql.Pool'. Retries on server errors due to +-- admin intervention. Things like server restart. +runSessionWithRetry :: (PGConstraints r) => Session a -> Sem r a +runSessionWithRetry sess = do + pool <- input + liftIO (useWithResetAndRetry pool sess) >>= either throw pure + +-- | Runs a 'Session' using the 'Hasql.Pool'. Unlike 'runSessionWithRetry' it +-- doesn't retry on server errors due to admin intervention. Things like server +-- restart. +-- +-- This should only be used if a session cannot be retried due to some other +-- 'IO' happening within the session, which cannot be repeated. +runSession :: (PGConstraints r) => Session a -> Sem r a runSession sess = do pool <- input liftIO (use pool sess) >>= either throw pure +-- | Runs a 'Statement' using the 'Hasql.Pool'. Always retries on server errors +-- due to admin intervention. Things like server restart. runStatement :: (PGConstraints r) => a -> Statement a b -> Sem r b runStatement a stmt = - runSession $ statement a stmt + runSessionWithRetry $ statement a stmt +-- | Runs a 'Transaction' using the 'Hasql.Pool'. Unlike +-- 'runTransactionWithRetry' it doesn't retry on server errors due to admin +-- intervention. Things like server restart. +-- +-- This should only be used if a transaction cannot be retried due to some other +-- 'IO' happening within the transaction, which cannot be repeated. runTransaction :: (PGConstraints r) => IsolationLevel -> @@ -118,6 +169,19 @@ runTransaction :: runTransaction isolationLevel mode t = runSession $ Transaction.transaction isolationLevel mode t +-- | Runs a 'Transaction' using the 'Hasql.Pool'. Retries on server errors due +-- to admin intervention. Things like server restart. +runTransactionWithRetry :: + (PGConstraints r) => + IsolationLevel -> + Mode -> + Transaction a -> + Sem r a +runTransactionWithRetry isolationLevel mode t = + runSessionWithRetry $ Transaction.transaction isolationLevel mode t + +-- | Runs a 'Pipeline' using the 'Hasql.Pool'. Always retries on server errors +-- due to admin intervention. Things like server restart. runPipeline :: (PGConstraints r) => Pipeline a -> @@ -200,7 +264,7 @@ paramLiteral encoder q = } argPattern0 :: Text -> Int -> Text -argPattern0 t i = "$" <> T.pack (show i) <> " :: " <> t +argPattern0 t i = "$" <> Text.pack (show i) <> " :: " <> t argPattern :: Text -> Int -> Text argPattern t i = "(" <> argPattern0 t i <> ")" @@ -266,7 +330,7 @@ clause op cl = } where wrap :: [Text] -> Text - wrap xs = "(" <> T.intercalate ", " xs <> ")" + wrap xs = "(" <> Text.intercalate ", " xs <> ")" -- | Fragment for a clause with a single value. clause1 :: forall a. (PostgresValue a) => Text -> Text -> a -> QueryFragment @@ -276,7 +340,7 @@ orderBy :: [(Text, SortOrder)] -> QueryFragment orderBy os = literal $ "order by " - <> T.intercalate ", " (map (\(field, o) -> field <> " " <> sortOrderClause o) os) + <> Text.intercalate ", " (map (\(field, o) -> field <> " " <> sortOrderClause o) os) limit :: forall a. (PostgresValue a) => a -> QueryFragment limit n = paramLiteral (valueEncoder n) $ \i -> @@ -288,11 +352,10 @@ offset n = paramLiteral (valueEncoder n) $ \i -> buildStatement :: QueryFragment -> Dec.Result b -> Statement () b buildStatement frag dec = - Statement - (T.encodeUtf8 (evalState frag.query 1)) + preparable + (evalState frag.query 1) frag.encoder dec - True nextIndex :: State Int Int nextIndex = get <* modify succ diff --git a/libs/wire-subsystems/src/Wire/PostgresMigrations.hs b/libs/wire-subsystems/src/Wire/PostgresMigrations.hs index 9681b98bd30..d8b022b1ffd 100644 --- a/libs/wire-subsystems/src/Wire/PostgresMigrations.hs +++ b/libs/wire-subsystems/src/Wire/PostgresMigrations.hs @@ -24,6 +24,7 @@ import Control.Exception import Data.FileEmbed import Data.Hashable qualified as Hashable import Data.Set qualified as Set +import Data.Text.Encoding qualified as Text import Hasql.Migration import Hasql.Pool import Hasql.Session @@ -37,7 +38,7 @@ import System.Logger qualified as Log import UnliftIO.Retry allMigrations :: [MigrationCommand] -allMigrations = map (uncurry MigrationScript) $(makeRelativeToProject "postgres-migrations" >>= embedDir) +allMigrations = map (\(name, contentBS) -> MigrationScript name (Text.decodeUtf8 contentBS)) $(makeRelativeToProject "postgres-migrations" >>= embedDir) -- | Scripts which cannot be run in a transaction nonTransactionMigrations :: Set ScriptName @@ -117,6 +118,6 @@ resetSchema :: Pool -> Logger -> IO () resetSchema pool logger = do Log.warn logger $ Log.msg (Log.val "resetting postgres schema") let session = do - sql "DROP SCHEMA IF EXISTS public CASCADE" - sql "CREATE SCHEMA IF NOT EXISTS public" + script "DROP SCHEMA IF EXISTS public CASCADE" + script "CREATE SCHEMA IF NOT EXISTS public" either throwIO pure =<< use pool session diff --git a/libs/wire-subsystems/src/Wire/TeamCollaboratorsStore/Postgres.hs b/libs/wire-subsystems/src/Wire/TeamCollaboratorsStore/Postgres.hs index 9942fede7a0..b454f1380d1 100644 --- a/libs/wire-subsystems/src/Wire/TeamCollaboratorsStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/TeamCollaboratorsStore/Postgres.hs @@ -30,15 +30,14 @@ import Data.Set qualified as Set import Data.UUID import Data.Vector hiding (mapM) import Hasql.Pool -import Hasql.Session import Hasql.Statement import Hasql.TH import Imports import Polysemy import Polysemy.Error (Error, throw) import Polysemy.Input -import PostgreSQL.ErrorCodes import Wire.API.Team.Collaborator +import Wire.Postgres import Wire.TeamCollaboratorsStore interpretTeamCollaboratorsStoreToPostgres :: @@ -59,21 +58,13 @@ interpretTeamCollaboratorsStoreToPostgres = RemoveTeamCollaborator userId teamId -> removeTeamCollaboratorImpl userId teamId getTeamCollaboratorImpl :: - ( Member (Input Pool) r, - Member (Embed IO) r, - Member (Error UsageError) r - ) => + (PGConstraints r) => TeamId -> UserId -> Sem r (Maybe TeamCollaborator) getTeamCollaboratorImpl teamId userId = do - pool <- input - eitherTeamCollaborator <- liftIO $ use pool session - either throw pure eitherTeamCollaborator + runStatement (teamId, userId) getTeamCollaboratorStatement where - session :: Session (Maybe TeamCollaborator) - session = statement (teamId, userId) getTeamCollaboratorStatement - getTeamCollaboratorStatement :: Statement (TeamId, UserId) (Maybe TeamCollaborator) getTeamCollaboratorStatement = dimap @@ -84,9 +75,7 @@ getTeamCollaboratorImpl teamId userId = do |] createTeamCollaboratorImpl :: - ( Member (Input Pool) r, - Member (Embed IO) r, - Member (Error UsageError) r, + ( PGConstraints r, Member (Error TeamCollaboratorsError) r ) => UserId -> @@ -94,33 +83,23 @@ createTeamCollaboratorImpl :: Set CollaboratorPermission -> Sem r () createTeamCollaboratorImpl userId teamId permissions = do - pool <- input - eitherErrorOrUnit <- liftIO $ use pool session - either errHandler pure eitherErrorOrUnit + mReturn <- runStatement (userId, teamId, permissions) insertStatement + case mReturn of + Just _ -> pure () + Nothing -> throw AlreadyExists where - session :: Session () - session = statement (userId, teamId, permissions) insertStatement - - insertStatement :: Statement (UserId, TeamId, Set CollaboratorPermission) () + insertStatement :: Statement (UserId, TeamId, Set CollaboratorPermission) (Maybe Int32) insertStatement = lmap ( \(uid, tid, pms) -> (toUUID uid, toUUID tid, collaboratorPermissionToPostgreslRep <$> (Data.Vector.fromList . toAscList) pms) ) - $ [resultlessStatement| + $ [maybeStatement| insert into collaborators (user_id, team_id, permissions) values ($1 :: uuid, $2 :: uuid, $3 :: smallint[]) + on conflict do nothing + returning (1 :: integer) |] - errHandler :: - ( Member (Error UsageError) r', - Member (Error TeamCollaboratorsError) r' - ) => - UsageError -> - Sem r' () - errHandler (SessionUsageError (QueryError _ _ (ResultError (ServerError code _ _ _ _)))) - | code == unique_violation = throw AlreadyExists - errHandler e = throw e - getAllTeamCollaboratorsImpl :: ( Member (Input Pool) r, Member (Embed IO) r, @@ -129,13 +108,8 @@ getAllTeamCollaboratorsImpl :: TeamId -> Sem r [TeamCollaborator] getAllTeamCollaboratorsImpl teamId = do - pool <- input - eitherTeamCollaborators <- liftIO $ use pool session - either throw pure eitherTeamCollaborators + runStatement teamId getAllTeamCollaboratorsStatement where - session :: Session [TeamCollaborator] - session = statement teamId getAllTeamCollaboratorsStatement - getAllTeamCollaboratorsStatement :: Statement TeamId [TeamCollaborator] getAllTeamCollaboratorsStatement = dimap toUUID (Data.Vector.toList . (toTeamCollaborator <$>)) $ @@ -153,13 +127,8 @@ updateTeamCollaboratorImpl :: Set CollaboratorPermission -> Sem r () updateTeamCollaboratorImpl userId teamId permissions = do - pool <- input - eitherErrorOrUnit <- liftIO $ use pool session - either throw pure eitherErrorOrUnit + runStatement (userId, teamId, permissions) updateStatement where - session :: Session () - session = statement (userId, teamId, permissions) updateStatement - updateStatement :: Statement (UserId, TeamId, Set CollaboratorPermission) () updateStatement = lmap @@ -179,13 +148,8 @@ removeTeamCollaboratorImpl :: TeamId -> Sem r () removeTeamCollaboratorImpl userId teamId = do - pool <- input - eitherErrorOrUnit <- liftIO $ use pool session - either throw pure eitherErrorOrUnit + runStatement (userId, teamId) deleteStatement where - session :: Session () - session = statement (userId, teamId) deleteStatement - deleteStatement :: Statement (UserId, TeamId) () deleteStatement = lmap @@ -224,13 +188,8 @@ getTeamCollaborationsImpl :: UserId -> Sem r [TeamCollaborator] getTeamCollaborationsImpl teamId = do - pool <- input - eitherTeamCollaborators <- liftIO $ use pool session - either throw pure eitherTeamCollaborators + runStatement teamId getAllCollaborationsByUserStatement where - session :: Session [TeamCollaborator] - session = statement teamId getAllCollaborationsByUserStatement - getAllCollaborationsByUserStatement :: Statement UserId [TeamCollaborator] getAllCollaborationsByUserStatement = dimap toUUID (Data.Vector.toList . (toTeamCollaborator <$>)) $ @@ -247,13 +206,8 @@ getTeamCollaboratorsWithIdsImpl :: Set UserId -> Sem r [TeamCollaborator] getTeamCollaboratorsWithIdsImpl teamIds userIds = do - pool <- input - eitherTeamCollaborators <- liftIO $ use pool session - either throw pure eitherTeamCollaborators + runStatement (Data.Set.toList teamIds, Data.Set.toList userIds) getTeamCollaboratorStatement where - session :: Session [TeamCollaborator] - session = statement (Data.Set.toList teamIds, Data.Set.toList userIds) getTeamCollaboratorStatement - getTeamCollaboratorStatement :: Statement ([TeamId], [UserId]) [TeamCollaborator] getTeamCollaboratorStatement = dimap diff --git a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs index ce4e3ae515e..6702f5a41f4 100644 --- a/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs +++ b/libs/wire-subsystems/src/Wire/UserGroupStore/Postgres.hs @@ -40,9 +40,10 @@ import Hasql.Transaction qualified as Tx import Hasql.Transaction.Sessions qualified as TxSessions import Imports import Polysemy -import Polysemy.Error (Error, throw) +import Polysemy.Error (Error) import Polysemy.Input import Wire.API.Pagination +import Wire.API.PostgresMarshall import Wire.API.User.Profile import Wire.API.UserGroup hiding (UpdateUserGroupChannels) import Wire.API.UserGroup.Pagination @@ -79,13 +80,8 @@ interpretUserGroupStoreToPostgres = getUserGroupsForConv :: (UserGroupStorePostgresEffectConstraints r) => ConvId -> Sem r (Vector UserGroup) getUserGroupsForConv convId = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + runStatement convId stmt where - session :: Session (Vector UserGroup) - session = statement convId stmt - decodeRow :: (UUID, Text, Int32, UTCTime, Vector UUID) -> Either Text UserGroup decodeRow (uuid, nameTxt, managedByInt, createdAtUtc, memberIds) = do name <- userGroupNameFromText nameTxt @@ -112,15 +108,9 @@ getUserGroupsForConv convId = do getUserGroupIdsForUsers :: (UserGroupStorePostgresEffectConstraints r) => [UserId] -> Sem r (Map UserId [UserGroupId]) getUserGroupIdsForUsers uidsList = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + rows <- runStatement (V.fromList (toUUID <$> uidsList)) stmt + pure $ Map.fromListWith (<>) [(Id u, [Id g]) | (u, g) <- V.toList rows] where - session :: Session (Map UserId [UserGroupId]) - session = do - rows <- statement (V.fromList (toUUID <$> uidsList)) stmt - pure $ Map.fromListWith (<>) [(Id u, [Id g]) | (u, g) <- V.toList rows] - stmt :: Statement (Vector UUID) (Vector (UUID, UUID)) stmt = [vectorStatement| @@ -131,15 +121,10 @@ getUserGroupIdsForUsers uidsList = do updateUsers :: (UserGroupStorePostgresEffectConstraints r) => UserGroupId -> Vector UserId -> Sem r () updateUsers gid uids = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + runTransactionWithRetry TxSessions.Serializable TxSessions.Write do + Tx.statement gid deleteAllUsersStatement + Tx.statement (toUUID gid, uids) insertGroupMembersStatement where - session :: Session () - session = TxSessions.transaction TxSessions.Serializable TxSessions.Write do - Tx.statement gid deleteAllUsersStatement - Tx.statement (toUUID gid, uids) insertGroupMembersStatement - deleteAllUsersStatement :: Statement UserGroupId () deleteAllUsersStatement = dimap (.toUUID) (const ()) $ @@ -155,10 +140,11 @@ getUserGroup :: Bool -> Sem r (Maybe UserGroup) getUserGroup team id_ includeChannels = do - pool <- input loc <- inputQualifyLocal () - eitherUserGroup <- liftIO $ use pool (if includeChannels then sessionWithChannels loc else session) - either throw pure eitherUserGroup + runSessionWithRetry $ + if includeChannels + then sessionWithChannels loc + else session where session :: Session (Maybe UserGroup) session = runMaybeT do @@ -228,7 +214,7 @@ getUserGroupsWithMembers :: UserGroupPageRequest -> Sem r UserGroupPageWithMembers getUserGroupsWithMembers tid req = - runTransaction TxSessions.ReadCommitted TxSessions.Read $ + runTransactionWithRetry TxSessions.ReadCommitted TxSessions.Read $ UserGroupPage <$> Tx.statement () (refineResult (mapM toUserGroup) $ buildStatement query rows) <*> getUserGroupCount tid req @@ -340,7 +326,7 @@ getUserGroups :: Sem r UserGroupPage getUserGroups tid req@(UserGroupPageRequest {..}) = do loc <- inputQualifyLocal () - runTransaction TxSessions.ReadCommitted TxSessions.Read $ + runTransactionWithRetry TxSessions.ReadCommitted TxSessions.Read $ UserGroupPage <$> getUserGroupsSession loc <*> getUserGroupCount tid req where getUserGroupsSession :: Local () -> Tx.Transaction [UserGroupMeta] @@ -358,7 +344,7 @@ getUserGroups tid req@(UserGroupPageRequest {..}) = do ) decodeRow - decodeRow :: HD.Result [(UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID))] + decodeRow :: HD.Result [(UUID, Text, Int32, UTCTime, Maybe Int64, Int64, Maybe (Vector UUID))] decodeRow = HD.rowList ( (,,,,,,) @@ -366,15 +352,15 @@ getUserGroups tid req@(UserGroupPageRequest {..}) = do <*> HD.column (HD.nonNullable HD.text) <*> HD.column (HD.nonNullable HD.int4) <*> HD.column (HD.nonNullable HD.timestamptz) - <*> (if req.includeMemberCount then Just <$> HD.column (HD.nonNullable HD.int4) else pure Nothing) - <*> HD.column (HD.nonNullable HD.int4) + <*> (if req.includeMemberCount then Just <$> HD.column (HD.nonNullable HD.int8) else pure Nothing) + <*> HD.column (HD.nonNullable HD.int8) <*> ( if req.includeChannels then Just <$> decodeUuidVector else pure Nothing ) ) - parseRow :: Local a -> (UUID, Text, Int32, UTCTime, Maybe Int32, Int32, Maybe (Vector UUID)) -> Either Text UserGroupMeta + parseRow :: Local a -> (UUID, Text, Int32, UTCTime, Maybe Int64, Int64, Maybe (Vector UUID)) -> Either Text UserGroupMeta parseRow loc (Id -> id_, namePre, managedByPre, toUTCTimeMillis -> createdAt, membersCountRaw, channelsCountRaw, maybeChannels) = do managedBy <- parseManagedBy managedByPre name <- userGroupNameFromText namePre @@ -400,26 +386,21 @@ createUserGroup :: ManagedBy -> Sem r UserGroup createUserGroup team newUserGroup managedBy = do - pool <- input - eitherUuid <- liftIO $ use pool session - either throw pure eitherUuid + runTransaction TxSessions.Serializable TxSessions.Write do + (id_, name, managedBy_, createdAt) <- Tx.statement (newUserGroup.name, team, managedBy) insertGroupStatement + Tx.statement (toUUID id_, newUserGroup.members) insertGroupMembersStatement + pure + UserGroup_ + { membersCount = Nothing, + members = Identity newUserGroup.members, + channels = mempty, + managedBy = managedBy_, + channelsCount = Nothing, + id_, + name, + createdAt + } where - session :: Session UserGroup - session = TxSessions.transaction TxSessions.Serializable TxSessions.Write do - (id_, name, managedBy_, createdAt) <- Tx.statement (newUserGroup.name, team, managedBy) insertGroupStatement - Tx.statement (toUUID id_, newUserGroup.members) insertGroupMembersStatement - pure - UserGroup_ - { membersCount = Nothing, - members = Identity newUserGroup.members, - channels = mempty, - managedBy = managedBy_, - channelsCount = Nothing, - id_, - name, - createdAt - } - decodeMetadataRow :: (UUID, Text, Int32, UTCTime) -> Either Text (UserGroupId, UserGroupName, ManagedBy, UTCTimeMillis) decodeMetadataRow (groupId, name, managedByInt, utcTime) = (Id groupId,,,toUTCTimeMillis utcTime) @@ -452,15 +433,9 @@ updateGroup :: UserGroupUpdate -> Sem r (Maybe ()) updateGroup tid gid gup = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + found <- isJust <$> runStatement (tid, gid, gup.name) updateGroupStatement + pure $ if found then Just () else Nothing where - session :: Session (Maybe ()) - session = TxSessions.transaction TxSessions.Serializable TxSessions.Write do - found <- isJust <$> Tx.statement (tid, gid, gup.name) updateGroupStatement - pure $ if found then Just () else Nothing - updateGroupStatement :: Statement (TeamId, UserGroupId, UserGroupName) (Maybe Bool) updateGroupStatement = lmap (\(t, g, n) -> (t.toUUID, g.toUUID, userGroupNameToText n)) $ @@ -477,15 +452,9 @@ deleteGroup :: UserGroupId -> Sem r (Maybe ()) deleteGroup tid gid = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + found <- isJust <$> runStatement (tid, gid) deleteGroupStatement + pure $ if found then Just () else Nothing where - session :: Session (Maybe ()) - session = TxSessions.transaction TxSessions.Serializable TxSessions.Write do - found <- isJust <$> Tx.statement (tid, gid) deleteGroupStatement - pure $ if found then Just () else Nothing - deleteGroupStatement :: Statement (TeamId, UserGroupId) (Maybe Bool) deleteGroupStatement = lmap (\(t, g) -> (t.toUUID, g.toUUID)) $ @@ -501,11 +470,15 @@ addUser :: UserGroupId -> UserId -> Sem r () -addUser = - crudUser - [resultlessStatement| - insert into user_group_member (user_group_id, user_id) values (($1 :: uuid), ($2 :: uuid)) - |] +addUser ugid uid = + runStatement (ugid, uid) insertStatement + where + insertStatement :: Statement (UserGroupId, UserId) () + insertStatement = + lmapPG + [resultlessStatement| + insert into user_group_member (user_group_id, user_id) values (($1 :: uuid), ($2 :: uuid)) + |] removeUser :: forall r. @@ -513,11 +486,15 @@ removeUser :: UserGroupId -> UserId -> Sem r () -removeUser = - crudUser - [resultlessStatement| - delete from user_group_member where user_group_id = ($1 :: uuid) and user_id = ($2 :: uuid) - |] +removeUser ugid uid = + runStatement (ugid, uid) deleteStatement + where + deleteStatement :: Statement (UserGroupId, UserId) () + deleteStatement = + lmapPG + [resultlessStatement| + delete from user_group_member where user_group_id = ($1 :: uuid) and user_id = ($2 :: uuid) + |] updateUserGroupChannels :: forall r. @@ -527,16 +504,11 @@ updateUserGroupChannels :: Vector ConvId -> Sem r () updateUserGroupChannels appendOnly gid convIds = do - pool <- input - eitherErrorOrUnit <- liftIO $ use pool session - either throw pure eitherErrorOrUnit + runTransaction TxSessions.Serializable TxSessions.Write $ do + unless appendOnly $ + Tx.statement (gid, convIds) deleteStatement + Tx.statement (gid, convIds) insertStatement where - session :: Session () - session = TxSessions.transaction TxSessions.Serializable TxSessions.Write $ do - unless appendOnly $ - Tx.statement (gid, convIds) deleteStatement - Tx.statement (gid, convIds) insertStatement - deleteStatement :: Statement (UserGroupId, Vector ConvId) () deleteStatement = lmap @@ -560,15 +532,9 @@ getUserGroupChannels :: UserGroupId -> Sem r (Maybe (Vector ConvId)) getUserGroupChannels tid gid = do - pool <- input - result <- liftIO $ use pool session - either throw pure result + mbUuids <- runStatement (gid, tid) getChannelsStatement + pure (fmap (fmap Id) mbUuids) where - session :: Session (Maybe (Vector ConvId)) - session = do - mbUuids <- statement (gid, tid) getChannelsStatement - pure (fmap (fmap Id) mbUuids) - getChannelsStatement :: Statement (UserGroupId, TeamId) (Maybe (Vector UUID)) getChannelsStatement = lmap (\(g, t) -> (g.toUUID, t.toUUID)) $ @@ -582,23 +548,5 @@ getUserGroupChannels tid gid = do where ug.id = ($1 :: uuid) and ug.team_id = ($2 :: uuid) |] -crudUser :: - forall r. - (UserGroupStorePostgresEffectConstraints r) => - Statement (UUID, UUID) () -> - UserGroupId -> - UserId -> - Sem r () -crudUser op gid uid = do - pool <- input - result <- liftIO $ use pool session - either throw pure result - where - session :: Session () - session = TxSessions.transaction TxSessions.Serializable TxSessions.Write do - Tx.statement - (gid, uid) - (lmap (\(gid_, uid_) -> (gid_.toUUID, uid_.toUUID)) op) - toRelationTable :: a -> Vector b -> (Vector a, Vector b) toRelationTable a bs = (a <$ bs, bs) diff --git a/nix/haskell-pins.nix b/nix/haskell-pins.nix index c7fced9fddb..cf7490f1b68 100644 --- a/nix/haskell-pins.nix +++ b/nix/haskell-pins.nix @@ -164,9 +164,15 @@ let }; # PR: https://github.com/tvh/hasql-migration/pull/19 + # and hasql-upgrade, no PR yet. hasql-migration = { src = inputs.hasql-migration; }; + + # PR: https://github.com/nikita-volkov/postgresql-connection-string/pull/4 + postgresql-connection-string = { + src = inputs.postgresql-connection-string; + }; }; hackagePins = { @@ -194,6 +200,31 @@ let version = "0.2.0"; sha256 = "sha256-kEalrs79uI8CMaVa7suYEzeer/YqFoJOqkV+LhiUwY4="; }; + + postgresql-binary = { + version = "0.15.0.1"; + sha256 = "sha256-q5t2OgiDxyt8WU+zHVxpyVhFF9PtDu2BlQRfuPpBkgk="; + }; + + hasql = { + version = "1.10.3"; + sha256 = "sha256-aJg6+oSWGkXm9pYLVv15d7M7HcnHhZpkw5c7ezxh2Yc="; + }; + + hasql-th = { + version = "0.5"; + sha256 = "sha256-qD9RljGDwMpPZ2epCxzL3Sbbn2Ce1472Vf2AGFroIW8="; + }; + + hasql-transaction = { + version = "1.2.2"; + sha256 = "sha256-o53h6ly2Kukhw9dcyAOvywzwlZDdgb+b/jqbw72lLHg="; + }; + + hasql-pool = { + version = "1.4.2"; + sha256 = "sha256-iQB2TD9hsPnqoVh5mR3Y2K8Cv67rWqBR0WHxOWZeiD8="; + }; }; # Name -> Source -> Maybe Subpath -> Drv mkGitDrv = name: src: subpath: