1use async_trait::async_trait;
12use mas_data_model::{Clock, User};
13use mas_storage::user::{UserFilter, UserRepository};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, extension::postgres::PgExpr as _};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use uuid::Uuid;
20
21use crate::{
22 DatabaseError,
23 filter::{Filter, StatementExt},
24 iden::Users,
25 pagination::QueryBuilderExt,
26 tracing::ExecuteExt,
27};
28
29mod email;
30mod password;
31mod recovery;
32mod registration;
33mod registration_token;
34mod session;
35mod terms;
36
37#[cfg(test)]
38mod tests;
39
40pub use self::{
41 email::PgUserEmailRepository, password::PgUserPasswordRepository,
42 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
43 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
44 terms::PgUserTermsRepository,
45};
46
/// PostgreSQL-backed implementation of [`UserRepository`].
///
/// Every query is issued through the borrowed connection, so the repository
/// lives no longer than that borrow.
pub struct PgUserRepository<'c> {
    /// Connection all user queries run on.
    conn: &'c mut PgConnection,
}
51
52impl<'c> PgUserRepository<'c> {
53 pub fn new(conn: &'c mut PgConnection) -> Self {
55 Self { conn }
56 }
57}
58
// Holds the raw row type for `users` queries. `#[enum_def]` below generates the
// `UserLookupIden` column enum that the parent module imports. The generated
// items carry no docs, which is presumably why `missing_docs` is allowed here.
mod priv_ {
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use mas_storage::pagination::Node;
    use sea_query::enum_def;
    use ulid::Ulid;
    use uuid::Uuid;

    /// One row of the `users` table, as selected by the repository queries.
    ///
    /// Field names double as the column aliases used by the dynamically built
    /// `list` query (via `UserLookupIden`), so they must match the SELECT list.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        /// Primary key, stored as a UUID and exposed as a ULID by the caller.
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        /// Set when the user is locked; `None` otherwise.
        pub(super) locked_at: Option<DateTime<Utc>>,
        /// Set when the user is deactivated; `None` otherwise.
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
        pub(super) is_guest: bool,
    }

    // Pagination cursor: the row's `user_id`, converted to a ULID.
    impl Node<Ulid> for UserLookup {
        fn cursor(&self) -> Ulid {
            self.user_id.into()
        }
    }
}
88
89use priv_::{UserLookup, UserLookupIden};
90
91impl From<UserLookup> for User {
92 fn from(value: UserLookup) -> Self {
93 let id = value.user_id.into();
94 Self {
95 id,
96 username: value.username,
97 sub: id.to_string(),
98 created_at: value.created_at,
99 locked_at: value.locked_at,
100 deactivated_at: value.deactivated_at,
101 can_request_admin: value.can_request_admin,
102 is_guest: value.is_guest,
103 }
104 }
105}
106
107impl Filter for UserFilter<'_> {
108 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
109 sea_query::Condition::all()
110 .add_option(self.state().map(|state| {
111 match state {
112 mas_storage::user::UserState::Deactivated => {
113 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
114 }
115 mas_storage::user::UserState::Locked => {
116 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
117 }
118 mas_storage::user::UserState::Active => {
119 Expr::col((Users::Table, Users::LockedAt))
120 .is_null()
121 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
122 }
123 }
124 }))
125 .add_option(self.can_request_admin().map(|can_request_admin| {
126 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
127 }))
128 .add_option(
129 self.is_guest()
130 .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
131 )
132 .add_option(self.search().map(|search| {
133 Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
134 }))
135 }
136}
137
138#[async_trait]
139impl UserRepository for PgUserRepository<'_> {
140 type Error = DatabaseError;
141
142 #[tracing::instrument(
143 name = "db.user.lookup",
144 skip_all,
145 fields(
146 db.query.text,
147 user.id = %id,
148 ),
149 err,
150 )]
151 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
152 let res = sqlx::query_as!(
153 UserLookup,
154 r#"
155 SELECT user_id
156 , username
157 , created_at
158 , locked_at
159 , deactivated_at
160 , can_request_admin
161 , is_guest
162 FROM users
163 WHERE user_id = $1
164 "#,
165 Uuid::from(id),
166 )
167 .traced()
168 .fetch_optional(&mut *self.conn)
169 .await?;
170
171 let Some(res) = res else { return Ok(None) };
172
173 Ok(Some(res.into()))
174 }
175
176 #[tracing::instrument(
177 name = "db.user.find_by_username",
178 skip_all,
179 fields(
180 db.query.text,
181 user.username = username,
182 ),
183 err,
184 )]
185 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
186 let res = sqlx::query_as!(
190 UserLookup,
191 r#"
192 SELECT user_id
193 , username
194 , created_at
195 , locked_at
196 , deactivated_at
197 , can_request_admin
198 , is_guest
199 FROM users
200 WHERE LOWER(username) = LOWER($1)
201 "#,
202 username,
203 )
204 .traced()
205 .fetch_all(&mut *self.conn)
206 .await?;
207
208 match &res[..] {
209 [user] => Ok(Some(user.clone().into())),
211 [] => Ok(None),
213 list => {
214 if let Some(user) = list.iter().find(|user| user.username == username) {
217 Ok(Some(user.clone().into()))
218 } else {
219 Ok(None)
221 }
222 }
223 }
224 }
225
226 #[tracing::instrument(
227 name = "db.user.add",
228 skip_all,
229 fields(
230 db.query.text,
231 user.username = username,
232 user.id,
233 ),
234 err,
235 )]
236 async fn add(
237 &mut self,
238 rng: &mut (dyn RngCore + Send),
239 clock: &dyn Clock,
240 username: String,
241 ) -> Result<User, Self::Error> {
242 let created_at = clock.now();
243 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
244 tracing::Span::current().record("user.id", tracing::field::display(id));
245
246 let res = sqlx::query!(
247 r#"
248 INSERT INTO users (user_id, username, created_at)
249 VALUES ($1, $2, $3)
250 ON CONFLICT (username) DO NOTHING
251 "#,
252 Uuid::from(id),
253 username,
254 created_at,
255 )
256 .traced()
257 .execute(&mut *self.conn)
258 .await?;
259
260 DatabaseError::ensure_affected_rows(&res, 1)?;
263
264 Ok(User {
265 id,
266 username,
267 sub: id.to_string(),
268 created_at,
269 locked_at: None,
270 deactivated_at: None,
271 can_request_admin: false,
272 is_guest: false,
273 })
274 }
275
276 #[tracing::instrument(
277 name = "db.user.exists",
278 skip_all,
279 fields(
280 db.query.text,
281 user.username = username,
282 ),
283 err,
284 )]
285 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
286 let exists = sqlx::query_scalar!(
287 r#"
288 SELECT EXISTS(
289 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
290 ) AS "exists!"
291 "#,
292 username
293 )
294 .traced()
295 .fetch_one(&mut *self.conn)
296 .await?;
297
298 Ok(exists)
299 }
300
301 #[tracing::instrument(
302 name = "db.user.lock",
303 skip_all,
304 fields(
305 db.query.text,
306 %user.id,
307 ),
308 err,
309 )]
310 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
311 if user.locked_at.is_some() {
312 return Ok(user);
313 }
314
315 let locked_at = clock.now();
316 let res = sqlx::query!(
317 r#"
318 UPDATE users
319 SET locked_at = $1
320 WHERE user_id = $2
321 "#,
322 locked_at,
323 Uuid::from(user.id),
324 )
325 .traced()
326 .execute(&mut *self.conn)
327 .await?;
328
329 DatabaseError::ensure_affected_rows(&res, 1)?;
330
331 user.locked_at = Some(locked_at);
332
333 Ok(user)
334 }
335
336 #[tracing::instrument(
337 name = "db.user.unlock",
338 skip_all,
339 fields(
340 db.query.text,
341 %user.id,
342 ),
343 err,
344 )]
345 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
346 if user.locked_at.is_none() {
347 return Ok(user);
348 }
349
350 let res = sqlx::query!(
351 r#"
352 UPDATE users
353 SET locked_at = NULL
354 WHERE user_id = $1
355 "#,
356 Uuid::from(user.id),
357 )
358 .traced()
359 .execute(&mut *self.conn)
360 .await?;
361
362 DatabaseError::ensure_affected_rows(&res, 1)?;
363
364 user.locked_at = None;
365
366 Ok(user)
367 }
368
369 #[tracing::instrument(
370 name = "db.user.deactivate",
371 skip_all,
372 fields(
373 db.query.text,
374 %user.id,
375 ),
376 err,
377 )]
378 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
379 if user.deactivated_at.is_some() {
380 return Ok(user);
381 }
382
383 let deactivated_at = clock.now();
384 let res = sqlx::query!(
385 r#"
386 UPDATE users
387 SET deactivated_at = $2
388 WHERE user_id = $1
389 AND deactivated_at IS NULL
390 "#,
391 Uuid::from(user.id),
392 deactivated_at,
393 )
394 .traced()
395 .execute(&mut *self.conn)
396 .await?;
397
398 DatabaseError::ensure_affected_rows(&res, 1)?;
399
400 user.deactivated_at = Some(deactivated_at);
401
402 Ok(user)
403 }
404
405 #[tracing::instrument(
406 name = "db.user.reactivate",
407 skip_all,
408 fields(
409 db.query.text,
410 %user.id,
411 ),
412 err,
413 )]
414 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
415 if user.deactivated_at.is_none() {
416 return Ok(user);
417 }
418
419 let res = sqlx::query!(
420 r#"
421 UPDATE users
422 SET deactivated_at = NULL
423 WHERE user_id = $1
424 "#,
425 Uuid::from(user.id),
426 )
427 .traced()
428 .execute(&mut *self.conn)
429 .await?;
430
431 DatabaseError::ensure_affected_rows(&res, 1)?;
432
433 user.deactivated_at = None;
434
435 Ok(user)
436 }
437
438 #[tracing::instrument(
439 name = "db.user.delete_unsupported_threepids",
440 skip_all,
441 fields(
442 db.query.text,
443 %user.id,
444 ),
445 err,
446 )]
447 async fn delete_unsupported_threepids(&mut self, user: &User) -> Result<usize, Self::Error> {
448 let res = sqlx::query!(
449 r#"
450 DELETE FROM user_unsupported_third_party_ids
451 WHERE user_id = $1
452 "#,
453 Uuid::from(user.id),
454 )
455 .traced()
456 .execute(&mut *self.conn)
457 .await?;
458
459 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
460 }
461
462 #[tracing::instrument(
463 name = "db.user.set_can_request_admin",
464 skip_all,
465 fields(
466 db.query.text,
467 %user.id,
468 user.can_request_admin = can_request_admin,
469 ),
470 err,
471 )]
472 async fn set_can_request_admin(
473 &mut self,
474 mut user: User,
475 can_request_admin: bool,
476 ) -> Result<User, Self::Error> {
477 let res = sqlx::query!(
478 r#"
479 UPDATE users
480 SET can_request_admin = $2
481 WHERE user_id = $1
482 "#,
483 Uuid::from(user.id),
484 can_request_admin,
485 )
486 .traced()
487 .execute(&mut *self.conn)
488 .await?;
489
490 DatabaseError::ensure_affected_rows(&res, 1)?;
491
492 user.can_request_admin = can_request_admin;
493
494 Ok(user)
495 }
496
497 #[tracing::instrument(
498 name = "db.user.list",
499 skip_all,
500 fields(
501 db.query.text,
502 ),
503 err,
504 )]
505 async fn list(
506 &mut self,
507 filter: UserFilter<'_>,
508 pagination: mas_storage::Pagination,
509 ) -> Result<mas_storage::Page<User>, Self::Error> {
510 let (sql, arguments) = Query::select()
511 .expr_as(
512 Expr::col((Users::Table, Users::UserId)),
513 UserLookupIden::UserId,
514 )
515 .expr_as(
516 Expr::col((Users::Table, Users::Username)),
517 UserLookupIden::Username,
518 )
519 .expr_as(
520 Expr::col((Users::Table, Users::CreatedAt)),
521 UserLookupIden::CreatedAt,
522 )
523 .expr_as(
524 Expr::col((Users::Table, Users::LockedAt)),
525 UserLookupIden::LockedAt,
526 )
527 .expr_as(
528 Expr::col((Users::Table, Users::DeactivatedAt)),
529 UserLookupIden::DeactivatedAt,
530 )
531 .expr_as(
532 Expr::col((Users::Table, Users::CanRequestAdmin)),
533 UserLookupIden::CanRequestAdmin,
534 )
535 .expr_as(
536 Expr::col((Users::Table, Users::IsGuest)),
537 UserLookupIden::IsGuest,
538 )
539 .from(Users::Table)
540 .apply_filter(filter)
541 .generate_pagination((Users::Table, Users::UserId), pagination)
542 .build_sqlx(PostgresQueryBuilder);
543
544 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
545 .traced()
546 .fetch_all(&mut *self.conn)
547 .await?;
548
549 let page = pagination.process(edges).map(User::from);
550
551 Ok(page)
552 }
553
554 #[tracing::instrument(
555 name = "db.user.count",
556 skip_all,
557 fields(
558 db.query.text,
559 ),
560 err,
561 )]
562 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
563 let (sql, arguments) = Query::select()
564 .expr(Expr::col((Users::Table, Users::UserId)).count())
565 .from(Users::Table)
566 .apply_filter(filter)
567 .build_sqlx(PostgresQueryBuilder);
568
569 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
570 .traced()
571 .fetch_one(&mut *self.conn)
572 .await?;
573
574 count
575 .try_into()
576 .map_err(DatabaseError::to_invalid_operation)
577 }
578
579 #[tracing::instrument(
580 name = "db.user.acquire_lock_for_sync",
581 skip_all,
582 fields(
583 db.query.text,
584 user.id = %user.id,
585 ),
586 err,
587 )]
588 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
589 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
597
598 sqlx::query!(
601 r#"
602 SELECT pg_advisory_xact_lock($1)
603 "#,
604 lock_id,
605 )
606 .traced()
607 .execute(&mut *self.conn)
608 .await?;
609
610 Ok(())
611 }
612}