mas_storage_pg/personal/
session.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    Clock, User,
12    personal::session::{PersonalSession, PersonalSessionOwner, SessionState},
13};
14use mas_storage::{
15    Page, Pagination,
16    pagination::Node,
17    personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
18};
19use oauth2_types::scope::Scope;
20use rand::RngCore;
21use sea_query::{
22    Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
23    extension::postgres::PgExpr as _,
24};
25use sea_query_binder::SqlxBinder as _;
26use sqlx::PgConnection;
27use ulid::Ulid;
28use uuid::Uuid;
29
30use crate::{
31    DatabaseError,
32    errors::DatabaseInconsistencyError,
33    filter::{Filter, StatementExt as _},
34    iden::PersonalSessions,
35    pagination::QueryBuilderExt as _,
36    tracing::ExecuteExt as _,
37};
38
39/// An implementation of [`PersonalSessionRepository`] for a PostgreSQL
40/// connection
41pub struct PgPersonalSessionRepository<'c> {
42    conn: &'c mut PgConnection,
43}
44
45impl<'c> PgPersonalSessionRepository<'c> {
46    /// Create a new [`PgPersonalSessionRepository`] from an active PostgreSQL
47    /// connection
48    pub fn new(conn: &'c mut PgConnection) -> Self {
49        Self { conn }
50    }
51}
52
/// A raw row of the `personal_sessions` table, as read by the `lookup` and
/// `list` queries of the repository, before conversion to a `PersonalSession`.
///
/// `#[enum_def]` also generates `PersonalSessionLookupIden`, whose variants
/// (one per field) are used as column aliases in the `sea-query` based `list`
/// query; renaming a field renames the matching alias variant as well.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionLookup {
    // Primary key, stored as a UUID and converted to a `Ulid`
    personal_session_id: Uuid,
    // Owner columns: exactly one of the two is expected to be set; the
    // conversion to `PersonalSession` rejects rows where that does not hold
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    // Raw scope tokens; each one is parsed back into a `Scope` on conversion
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` while the session is still valid
    revoked_at: Option<DateTime<Utc>>,
    // Activity tracking; `add` initialises both of these to `None`
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
67
68impl Node<Ulid> for PersonalSessionLookup {
69    fn cursor(&self) -> Ulid {
70        self.personal_session_id.into()
71    }
72}
73
74impl TryFrom<PersonalSessionLookup> for PersonalSession {
75    type Error = DatabaseInconsistencyError;
76
77    fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
78        let id = Ulid::from(value.personal_session_id);
79        let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
80        let scope = scope.map_err(|e| {
81            DatabaseInconsistencyError::on("personal_sessions")
82                .column("scope")
83                .row(id)
84                .source(e)
85        })?;
86
87        let state = match value.revoked_at {
88            None => SessionState::Valid,
89            Some(revoked_at) => SessionState::Revoked { revoked_at },
90        };
91
92        let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
93            (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
94            (None, Some(owner_oauth2_client_id)) => {
95                PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
96            }
97            _ => {
98                // should be impossible (CHECK constraint in Postgres prevents it)
99                return Err(DatabaseInconsistencyError::on("personal_sessions")
100                    .column("owner_user_id, owner_oauth2_client_id")
101                    .row(id));
102            }
103        };
104
105        Ok(PersonalSession {
106            id,
107            state,
108            owner,
109            actor_user_id: Ulid::from(value.actor_user_id),
110            human_name: value.human_name,
111            scope,
112            created_at: value.created_at,
113            last_active_at: value.last_active_at,
114            last_active_ip: value.last_active_ip,
115        })
116    }
117}
118
119#[async_trait]
120impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
121    type Error = DatabaseError;
122
123    #[tracing::instrument(
124        name = "db.personal_session.lookup",
125        skip_all,
126        fields(
127            db.query.text,
128            session.id = %id,
129        ),
130        err,
131    )]
132    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
133        let res = sqlx::query_as!(
134            PersonalSessionLookup,
135            r#"
136                SELECT personal_session_id
137                     , owner_user_id
138                     , owner_oauth2_client_id
139                     , actor_user_id
140                     , scope_list
141                     , created_at
142                     , revoked_at
143                     , human_name
144                     , last_active_at
145                     , last_active_ip as "last_active_ip: IpAddr"
146                FROM personal_sessions
147
148                WHERE personal_session_id = $1
149            "#,
150            Uuid::from(id),
151        )
152        .traced()
153        .fetch_optional(&mut *self.conn)
154        .await?;
155
156        let Some(session) = res else { return Ok(None) };
157
158        Ok(Some(session.try_into()?))
159    }
160
161    #[tracing::instrument(
162        name = "db.personal_session.add",
163        skip_all,
164        fields(
165            db.query.text,
166            session.id,
167            session.scope = %scope,
168        ),
169        err,
170    )]
171    async fn add(
172        &mut self,
173        rng: &mut (dyn RngCore + Send),
174        clock: &dyn Clock,
175        owner: PersonalSessionOwner,
176        actor_user: &User,
177        human_name: String,
178        scope: Scope,
179    ) -> Result<PersonalSession, Self::Error> {
180        let created_at = clock.now();
181        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
182        tracing::Span::current().record("session.id", tracing::field::display(id));
183
184        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
185
186        let (owner_user_id, owner_oauth2_client_id) = match owner {
187            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
188            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
189        };
190
191        sqlx::query!(
192            r#"
193                INSERT INTO personal_sessions
194                    ( personal_session_id
195                    , owner_user_id
196                    , owner_oauth2_client_id
197                    , actor_user_id
198                    , human_name
199                    , scope_list
200                    , created_at
201                    )
202                VALUES ($1, $2, $3, $4, $5, $6, $7)
203            "#,
204            Uuid::from(id),
205            owner_user_id,
206            owner_oauth2_client_id,
207            Uuid::from(actor_user.id),
208            &human_name,
209            &scope_list,
210            created_at,
211        )
212        .traced()
213        .execute(&mut *self.conn)
214        .await?;
215
216        Ok(PersonalSession {
217            id,
218            state: SessionState::Valid,
219            owner,
220            actor_user_id: actor_user.id,
221            human_name,
222            scope,
223            created_at,
224            last_active_at: None,
225            last_active_ip: None,
226        })
227    }
228
229    #[tracing::instrument(
230        name = "db.personal_session.revoke",
231        skip_all,
232        fields(
233            db.query.text,
234            %session.id,
235            %session.scope,
236        ),
237        err,
238    )]
239    async fn revoke(
240        &mut self,
241        clock: &dyn Clock,
242        session: PersonalSession,
243    ) -> Result<PersonalSession, Self::Error> {
244        let finished_at = clock.now();
245        let res = sqlx::query!(
246            r#"
247                UPDATE personal_sessions
248                SET revoked_at = $2
249                WHERE personal_session_id = $1
250            "#,
251            Uuid::from(session.id),
252            finished_at,
253        )
254        .traced()
255        .execute(&mut *self.conn)
256        .await?;
257
258        DatabaseError::ensure_affected_rows(&res, 1)?;
259
260        session
261            .finish(finished_at)
262            .map_err(DatabaseError::to_invalid_operation)
263    }
264
265    #[tracing::instrument(
266        name = "db.personal_session.list",
267        skip_all,
268        fields(
269            db.query.text,
270        ),
271        err,
272    )]
273    async fn list(
274        &mut self,
275        filter: PersonalSessionFilter<'_>,
276        pagination: Pagination,
277    ) -> Result<Page<PersonalSession>, Self::Error> {
278        let (sql, arguments) = Query::select()
279            .expr_as(
280                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
281                PersonalSessionLookupIden::PersonalSessionId,
282            )
283            .expr_as(
284                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
285                PersonalSessionLookupIden::OwnerUserId,
286            )
287            .expr_as(
288                Expr::col((
289                    PersonalSessions::Table,
290                    PersonalSessions::OwnerOAuth2ClientId,
291                )),
292                PersonalSessionLookupIden::OwnerOauth2ClientId,
293            )
294            .expr_as(
295                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
296                PersonalSessionLookupIden::ActorUserId,
297            )
298            .expr_as(
299                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
300                PersonalSessionLookupIden::HumanName,
301            )
302            .expr_as(
303                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
304                PersonalSessionLookupIden::ScopeList,
305            )
306            .expr_as(
307                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
308                PersonalSessionLookupIden::CreatedAt,
309            )
310            .expr_as(
311                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
312                PersonalSessionLookupIden::RevokedAt,
313            )
314            .expr_as(
315                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
316                PersonalSessionLookupIden::LastActiveAt,
317            )
318            .expr_as(
319                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
320                PersonalSessionLookupIden::LastActiveIp,
321            )
322            .from(PersonalSessions::Table)
323            .apply_filter(filter)
324            .generate_pagination(
325                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
326                pagination,
327            )
328            .build_sqlx(PostgresQueryBuilder);
329
330        let edges: Vec<PersonalSessionLookup> = sqlx::query_as_with(&sql, arguments)
331            .traced()
332            .fetch_all(&mut *self.conn)
333            .await?;
334
335        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
336
337        Ok(page)
338    }
339
340    #[tracing::instrument(
341        name = "db.personal_session.count",
342        skip_all,
343        fields(
344            db.query.text,
345        ),
346        err,
347    )]
348    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
349        let (sql, arguments) = Query::select()
350            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
351            .from(PersonalSessions::Table)
352            .apply_filter(filter)
353            .build_sqlx(PostgresQueryBuilder);
354
355        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
356            .traced()
357            .fetch_one(&mut *self.conn)
358            .await?;
359
360        count
361            .try_into()
362            .map_err(DatabaseError::to_invalid_operation)
363    }
364}
365
366impl Filter for PersonalSessionFilter<'_> {
367    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
368        sea_query::Condition::all()
369            .add_option(self.owner_user().map(|user| {
370                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
371                    .eq(Uuid::from(user.id))
372            }))
373            .add_option(self.owner_oauth2_client().map(|client| {
374                Expr::col((
375                    PersonalSessions::Table,
376                    PersonalSessions::OwnerOAuth2ClientId,
377                ))
378                .eq(Uuid::from(client.id))
379            }))
380            .add_option(self.actor_user().map(|user| {
381                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
382                    .eq(Uuid::from(user.id))
383            }))
384            .add_option(self.device().map(|device| -> SimpleExpr {
385                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
386                    Condition::any()
387                        .add(
388                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
389                                PersonalSessions::Table,
390                                PersonalSessions::ScopeList,
391                            )))),
392                        )
393                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
394                            Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
395                        )))
396                        .into()
397                } else {
398                    // If the device ID can't be encoded as a scope token, match no rows
399                    Expr::val(false).into()
400                }
401            }))
402            .add_option(self.state().map(|state| match state {
403                PersonalSessionState::Active => {
404                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
405                }
406                PersonalSessionState::Revoked => {
407                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
408                }
409            }))
410            .add_option(self.scope().map(|scope| {
411                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
412                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
413            }))
414            .add_option(self.last_active_before().map(|last_active_before| {
415                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
416                    .lt(last_active_before)
417            }))
418            .add_option(self.last_active_after().map(|last_active_after| {
419                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
420                    .gt(last_active_after)
421            }))
422    }
423}