mas_storage_pg/
repository.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
6
7use std::ops::{Deref, DerefMut};
8
9use async_trait::async_trait;
10use futures_util::{FutureExt, TryFutureExt, future::BoxFuture};
11use mas_storage::{
12    BoxRepository, BoxRepositoryFactory, MapErr, Repository, RepositoryAccess, RepositoryError,
13    RepositoryFactory, RepositoryTransaction,
14    app_session::AppSessionRepository,
15    compat::{
16        CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
17        CompatSsoLoginRepository,
18    },
19    oauth2::{
20        OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2ClientRepository,
21        OAuth2DeviceCodeGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository,
22    },
23    personal::PersonalSessionRepository,
24    policy_data::PolicyDataRepository,
25    queue::{QueueJobRepository, QueueScheduleRepository, QueueWorkerRepository},
26    upstream_oauth2::{
27        UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository,
28        UpstreamOAuthSessionRepository,
29    },
30    user::{
31        BrowserSessionRepository, UserEmailRepository, UserPasswordRepository,
32        UserRecoveryRepository, UserRegistrationRepository, UserRegistrationTokenRepository,
33        UserRepository, UserTermsRepository,
34    },
35};
36use sqlx::{PgConnection, PgPool, Postgres, Transaction};
37use tracing::Instrument;
38
39use crate::{
40    DatabaseError,
41    app_session::PgAppSessionRepository,
42    compat::{
43        PgCompatAccessTokenRepository, PgCompatRefreshTokenRepository, PgCompatSessionRepository,
44        PgCompatSsoLoginRepository,
45    },
46    oauth2::{
47        PgOAuth2AccessTokenRepository, PgOAuth2AuthorizationGrantRepository,
48        PgOAuth2ClientRepository, PgOAuth2DeviceCodeGrantRepository,
49        PgOAuth2RefreshTokenRepository, PgOAuth2SessionRepository,
50    },
51    personal::{PgPersonalAccessTokenRepository, PgPersonalSessionRepository},
52    policy_data::PgPolicyDataRepository,
53    queue::{
54        job::PgQueueJobRepository, schedule::PgQueueScheduleRepository,
55        worker::PgQueueWorkerRepository,
56    },
57    telemetry::DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM,
58    upstream_oauth2::{
59        PgUpstreamOAuthLinkRepository, PgUpstreamOAuthProviderRepository,
60        PgUpstreamOAuthSessionRepository,
61    },
62    user::{
63        PgBrowserSessionRepository, PgUserEmailRepository, PgUserPasswordRepository,
64        PgUserRecoveryRepository, PgUserRegistrationRepository, PgUserRegistrationTokenRepository,
65        PgUserRepository, PgUserTermsRepository,
66    },
67};
68
69/// An implementation of the [`RepositoryFactory`] trait backed by a PostgreSQL
70/// connection pool.
71#[derive(Clone)]
72pub struct PgRepositoryFactory {
73    pool: PgPool,
74}
75
76impl PgRepositoryFactory {
77    /// Create a new [`PgRepositoryFactory`] from a PostgreSQL connection pool.
78    #[must_use]
79    pub fn new(pool: PgPool) -> Self {
80        Self { pool }
81    }
82
83    /// Box the factory
84    #[must_use]
85    pub fn boxed(self) -> BoxRepositoryFactory {
86        Box::new(self)
87    }
88
89    /// Get the underlying PostgreSQL connection pool
90    #[must_use]
91    pub fn pool(&self) -> PgPool {
92        self.pool.clone()
93    }
94}
95
96#[async_trait]
97impl RepositoryFactory for PgRepositoryFactory {
98    async fn create(&self) -> Result<BoxRepository, RepositoryError> {
99        let start = std::time::Instant::now();
100        let repo = PgRepository::from_pool(&self.pool)
101            .await
102            .map_err(RepositoryError::from_error)?
103            .boxed();
104
105        // Measure the time it took to create the connection
106        let duration = start.elapsed();
107        let duration_ms = duration.as_millis().try_into().unwrap_or(u64::MAX);
108        DB_CLIENT_CONNECTIONS_CREATE_TIME_HISTOGRAM.record(duration_ms, &[]);
109
110        Ok(repo)
111    }
112}
113
114/// An implementation of the [`Repository`] trait backed by a PostgreSQL
115/// transaction.
116pub struct PgRepository<C = Transaction<'static, Postgres>> {
117    conn: C,
118}
119
120impl PgRepository {
121    /// Create a new [`PgRepository`] from a PostgreSQL connection pool,
122    /// starting a transaction.
123    ///
124    /// # Errors
125    ///
126    /// Returns a [`DatabaseError`] if the transaction could not be started.
127    pub async fn from_pool(pool: &PgPool) -> Result<Self, DatabaseError> {
128        let txn = pool.begin().await?;
129        Ok(Self::from_conn(txn))
130    }
131
132    /// Transform the repository into a type-erased [`BoxRepository`]
133    pub fn boxed(self) -> BoxRepository {
134        Box::new(MapErr::new(self, RepositoryError::from_error))
135    }
136}
137
138impl<C> PgRepository<C> {
139    /// Create a new [`PgRepository`] from an existing PostgreSQL connection
140    /// with a transaction
141    pub fn from_conn(conn: C) -> Self {
142        PgRepository { conn }
143    }
144
145    /// Consume this [`PgRepository`], returning the underlying connection.
146    pub fn into_inner(self) -> C {
147        self.conn
148    }
149}
150
151impl<C> AsRef<C> for PgRepository<C> {
152    fn as_ref(&self) -> &C {
153        &self.conn
154    }
155}
156
157impl<C> AsMut<C> for PgRepository<C> {
158    fn as_mut(&mut self) -> &mut C {
159        &mut self.conn
160    }
161}
162
163impl<C> Deref for PgRepository<C> {
164    type Target = C;
165
166    fn deref(&self) -> &Self::Target {
167        &self.conn
168    }
169}
170
171impl<C> DerefMut for PgRepository<C> {
172    fn deref_mut(&mut self) -> &mut Self::Target {
173        &mut self.conn
174    }
175}
176
177impl Repository<DatabaseError> for PgRepository {}
178
179impl RepositoryTransaction for PgRepository {
180    type Error = DatabaseError;
181
182    fn save(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
183        let span = tracing::info_span!("db.save");
184        self.conn
185            .commit()
186            .map_err(DatabaseError::from)
187            .instrument(span)
188            .boxed()
189    }
190
191    fn cancel(self: Box<Self>) -> BoxFuture<'static, Result<(), Self::Error>> {
192        let span = tracing::info_span!("db.cancel");
193        self.conn
194            .rollback()
195            .map_err(DatabaseError::from)
196            .instrument(span)
197            .boxed()
198    }
199}
200
201impl<C> RepositoryAccess for PgRepository<C>
202where
203    C: AsMut<PgConnection> + Send,
204{
205    type Error = DatabaseError;
206
207    fn upstream_oauth_link<'c>(
208        &'c mut self,
209    ) -> Box<dyn UpstreamOAuthLinkRepository<Error = Self::Error> + 'c> {
210        Box::new(PgUpstreamOAuthLinkRepository::new(self.conn.as_mut()))
211    }
212
213    fn upstream_oauth_provider<'c>(
214        &'c mut self,
215    ) -> Box<dyn UpstreamOAuthProviderRepository<Error = Self::Error> + 'c> {
216        Box::new(PgUpstreamOAuthProviderRepository::new(self.conn.as_mut()))
217    }
218
219    fn upstream_oauth_session<'c>(
220        &'c mut self,
221    ) -> Box<dyn UpstreamOAuthSessionRepository<Error = Self::Error> + 'c> {
222        Box::new(PgUpstreamOAuthSessionRepository::new(self.conn.as_mut()))
223    }
224
225    fn user<'c>(&'c mut self) -> Box<dyn UserRepository<Error = Self::Error> + 'c> {
226        Box::new(PgUserRepository::new(self.conn.as_mut()))
227    }
228
229    fn user_email<'c>(&'c mut self) -> Box<dyn UserEmailRepository<Error = Self::Error> + 'c> {
230        Box::new(PgUserEmailRepository::new(self.conn.as_mut()))
231    }
232
233    fn user_password<'c>(
234        &'c mut self,
235    ) -> Box<dyn UserPasswordRepository<Error = Self::Error> + 'c> {
236        Box::new(PgUserPasswordRepository::new(self.conn.as_mut()))
237    }
238
239    fn user_recovery<'c>(
240        &'c mut self,
241    ) -> Box<dyn UserRecoveryRepository<Error = Self::Error> + 'c> {
242        Box::new(PgUserRecoveryRepository::new(self.conn.as_mut()))
243    }
244
245    fn user_terms<'c>(&'c mut self) -> Box<dyn UserTermsRepository<Error = Self::Error> + 'c> {
246        Box::new(PgUserTermsRepository::new(self.conn.as_mut()))
247    }
248
249    fn user_registration<'c>(
250        &'c mut self,
251    ) -> Box<dyn UserRegistrationRepository<Error = Self::Error> + 'c> {
252        Box::new(PgUserRegistrationRepository::new(self.conn.as_mut()))
253    }
254
255    fn user_registration_token<'c>(
256        &'c mut self,
257    ) -> Box<dyn UserRegistrationTokenRepository<Error = Self::Error> + 'c> {
258        Box::new(PgUserRegistrationTokenRepository::new(self.conn.as_mut()))
259    }
260
261    fn browser_session<'c>(
262        &'c mut self,
263    ) -> Box<dyn BrowserSessionRepository<Error = Self::Error> + 'c> {
264        Box::new(PgBrowserSessionRepository::new(self.conn.as_mut()))
265    }
266
267    fn app_session<'c>(&'c mut self) -> Box<dyn AppSessionRepository<Error = Self::Error> + 'c> {
268        Box::new(PgAppSessionRepository::new(self.conn.as_mut()))
269    }
270
271    fn oauth2_client<'c>(
272        &'c mut self,
273    ) -> Box<dyn OAuth2ClientRepository<Error = Self::Error> + 'c> {
274        Box::new(PgOAuth2ClientRepository::new(self.conn.as_mut()))
275    }
276
277    fn oauth2_authorization_grant<'c>(
278        &'c mut self,
279    ) -> Box<dyn OAuth2AuthorizationGrantRepository<Error = Self::Error> + 'c> {
280        Box::new(PgOAuth2AuthorizationGrantRepository::new(
281            self.conn.as_mut(),
282        ))
283    }
284
285    fn oauth2_session<'c>(
286        &'c mut self,
287    ) -> Box<dyn OAuth2SessionRepository<Error = Self::Error> + 'c> {
288        Box::new(PgOAuth2SessionRepository::new(self.conn.as_mut()))
289    }
290
291    fn oauth2_access_token<'c>(
292        &'c mut self,
293    ) -> Box<dyn OAuth2AccessTokenRepository<Error = Self::Error> + 'c> {
294        Box::new(PgOAuth2AccessTokenRepository::new(self.conn.as_mut()))
295    }
296
297    fn oauth2_refresh_token<'c>(
298        &'c mut self,
299    ) -> Box<dyn OAuth2RefreshTokenRepository<Error = Self::Error> + 'c> {
300        Box::new(PgOAuth2RefreshTokenRepository::new(self.conn.as_mut()))
301    }
302
303    fn oauth2_device_code_grant<'c>(
304        &'c mut self,
305    ) -> Box<dyn OAuth2DeviceCodeGrantRepository<Error = Self::Error> + 'c> {
306        Box::new(PgOAuth2DeviceCodeGrantRepository::new(self.conn.as_mut()))
307    }
308
309    fn compat_session<'c>(
310        &'c mut self,
311    ) -> Box<dyn CompatSessionRepository<Error = Self::Error> + 'c> {
312        Box::new(PgCompatSessionRepository::new(self.conn.as_mut()))
313    }
314
315    fn compat_sso_login<'c>(
316        &'c mut self,
317    ) -> Box<dyn CompatSsoLoginRepository<Error = Self::Error> + 'c> {
318        Box::new(PgCompatSsoLoginRepository::new(self.conn.as_mut()))
319    }
320
321    fn compat_access_token<'c>(
322        &'c mut self,
323    ) -> Box<dyn CompatAccessTokenRepository<Error = Self::Error> + 'c> {
324        Box::new(PgCompatAccessTokenRepository::new(self.conn.as_mut()))
325    }
326
327    fn compat_refresh_token<'c>(
328        &'c mut self,
329    ) -> Box<dyn CompatRefreshTokenRepository<Error = Self::Error> + 'c> {
330        Box::new(PgCompatRefreshTokenRepository::new(self.conn.as_mut()))
331    }
332
333    fn personal_access_token<'c>(
334        &'c mut self,
335    ) -> Box<dyn mas_storage::personal::PersonalAccessTokenRepository<Error = Self::Error> + 'c>
336    {
337        Box::new(PgPersonalAccessTokenRepository::new(self.conn.as_mut()))
338    }
339
340    fn personal_session<'c>(
341        &'c mut self,
342    ) -> Box<dyn PersonalSessionRepository<Error = Self::Error> + 'c> {
343        Box::new(PgPersonalSessionRepository::new(self.conn.as_mut()))
344    }
345
346    fn queue_worker<'c>(&'c mut self) -> Box<dyn QueueWorkerRepository<Error = Self::Error> + 'c> {
347        Box::new(PgQueueWorkerRepository::new(self.conn.as_mut()))
348    }
349
350    fn queue_job<'c>(&'c mut self) -> Box<dyn QueueJobRepository<Error = Self::Error> + 'c> {
351        Box::new(PgQueueJobRepository::new(self.conn.as_mut()))
352    }
353
354    fn queue_schedule<'c>(
355        &'c mut self,
356    ) -> Box<dyn QueueScheduleRepository<Error = Self::Error> + 'c> {
357        Box::new(PgQueueScheduleRepository::new(self.conn.as_mut()))
358    }
359
360    fn policy_data<'c>(&'c mut self) -> Box<dyn PolicyDataRepository<Error = Self::Error> + 'c> {
361        Box::new(PgPolicyDataRepository::new(self.conn.as_mut()))
362    }
363}