Skip to content

Commit e513343

Browse files
committed
feat(orm)!: transactions support
1 parent 5556616 commit e513343

13 files changed

Lines changed: 790 additions & 198 deletions

File tree

cot-macros/src/model.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ impl ModelBuilder {
176176
let fields_as_get_values = &self.fields_as_get_values;
177177

178178
quote! {
179-
#[#crate_ident::__private::async_trait]
180179
#[automatically_derived]
180+
#[#orm_ident::async_trait]
181181
impl #orm_ident::Model for #name {
182182
type Fields = #fields_struct_name;
183183
type PrimaryKey = #pk_type;
@@ -225,11 +225,11 @@ impl ModelBuilder {
225225
}
226226

227227
async fn get_by_primary_key<DB: #orm_ident::DatabaseBackend>(
228-
db: &DB,
228+
mut db: DB,
229229
pk: Self::PrimaryKey,
230230
) -> #orm_ident::Result<Option<Self>> {
231231
#orm_ident::query!(Self, $#pk_field_name == pk)
232-
.get(db)
232+
.get(&mut db)
233233
.await
234234
}
235235
}

cot/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ subtle = { workspace = true, features = ["std"] }
5959
swagger-ui-redist = { workspace = true, optional = true }
6060
thiserror.workspace = true
6161
time.workspace = true
62-
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "fs", "io-util"] }
62+
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "fs", "io-util", "sync"] }
6363
toml = { workspace = true, features = ["parse", "serde"] }
6464
tower = { workspace = true, features = ["util"] }
6565
tower-livereload = { workspace = true, optional = true }

cot/src/auth/db.rs

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ impl DatabaseUser {
9797
/// # }
9898
/// ```
9999
pub async fn create_user<DB: DatabaseBackend, T: Into<String>, U: Into<Password>>(
100-
db: &DB,
100+
mut db: DB,
101101
username: T,
102102
password: U,
103103
) -> Result<Self> {
@@ -108,7 +108,9 @@ impl DatabaseUser {
108108
})?;
109109

110110
let mut user = Self::new(Auto::auto(), username, &password.into());
111-
user.insert(db).await.map_err(AuthError::backend_error)?;
111+
user.insert(&mut db)
112+
.await
113+
.map_err(AuthError::backend_error)?;
112114

113115
Ok(user)
114116
}
@@ -153,9 +155,9 @@ impl DatabaseUser {
153155
/// # Ok(())
154156
/// # }
155157
/// ```
156-
pub async fn get_by_id<DB: DatabaseBackend>(db: &DB, id: i64) -> Result<Option<Self>> {
158+
pub async fn get_by_id<DB: DatabaseBackend>(mut db: DB, id: i64) -> Result<Option<Self>> {
157159
let db_user = query!(DatabaseUser, $id == id)
158-
.get(db)
160+
.get(&mut db)
159161
.await
160162
.map_err(AuthError::backend_error)?;
161163

@@ -199,14 +201,14 @@ impl DatabaseUser {
199201
/// # }
200202
/// ```
201203
pub async fn get_by_username<DB: DatabaseBackend>(
202-
db: &DB,
204+
mut db: DB,
203205
username: &str,
204206
) -> Result<Option<Self>> {
205207
let username = LimitedString::<MAX_USERNAME_LENGTH>::new(username).map_err(|_| {
206208
AuthError::backend_error(CreateUserError::UsernameTooLong(username.len()))
207209
})?;
208210
let db_user = query!(DatabaseUser, $username == username)
209-
.get(db)
211+
.get(&mut db)
210212
.await
211213
.map_err(AuthError::backend_error)?;
212214

@@ -219,7 +221,7 @@ impl DatabaseUser {
219221
///
220222
/// Returns an error if there was an error querying the database.
221223
pub async fn authenticate<DB: DatabaseBackend>(
222-
db: &DB,
224+
mut db: DB,
223225
credentials: &DatabaseUserCredentials,
224226
) -> Result<Option<Self>> {
225227
let username = credentials.username();
@@ -228,7 +230,7 @@ impl DatabaseUser {
228230
AuthError::backend_error(CreateUserError::UsernameTooLong(username.len()))
229231
})?;
230232
let user = query!(DatabaseUser, $username == username_limited)
231-
.get(db)
233+
.get(&mut db)
232234
.await
233235
.map_err(AuthError::backend_error)?;
234236

@@ -238,7 +240,7 @@ impl DatabaseUser {
238240
PasswordVerificationResult::Ok => Ok(Some(user)),
239241
PasswordVerificationResult::OkObsolete(new_hash) => {
240242
user.password = new_hash;
241-
user.save(db).await.map_err(AuthError::backend_error)?;
243+
user.save(&mut db).await.map_err(AuthError::backend_error)?;
242244
Ok(Some(user))
243245
}
244246
PasswordVerificationResult::Invalid => Ok(None),
@@ -620,7 +622,7 @@ mod tests {
620622
let username = "testuser".to_string();
621623
let password = Password::new("password123");
622624

623-
let user = DatabaseUser::create_user(&mock_db, username.clone(), &password)
625+
let user = DatabaseUser::create_user(&mut mock_db, username.clone(), &password)
624626
.await
625627
.unwrap();
626628
assert_eq!(user.username(), username);
@@ -640,7 +642,7 @@ mod tests {
640642
.expect_get::<DatabaseUser>()
641643
.returning(move |_| Ok(Some(user.clone())));
642644

643-
let result = DatabaseUser::get_by_id(&mock_db, 1).await.unwrap();
645+
let result = DatabaseUser::get_by_id(&mut mock_db, 1).await.unwrap();
644646
assert!(result.is_some());
645647
assert_eq!(result.unwrap().username(), "testuser");
646648
}
@@ -661,7 +663,7 @@ mod tests {
661663

662664
let credentials =
663665
DatabaseUserCredentials::new("testuser".to_string(), Password::new("password123"));
664-
let result = DatabaseUser::authenticate(&mock_db, &credentials)
666+
let result = DatabaseUser::authenticate(&mut mock_db, &credentials)
665667
.await
666668
.unwrap();
667669
assert!(result.is_some());
@@ -679,7 +681,7 @@ mod tests {
679681

680682
let credentials =
681683
DatabaseUserCredentials::new("testuser".to_string(), Password::new("password123"));
682-
let result = DatabaseUser::authenticate(&mock_db, &credentials)
684+
let result = DatabaseUser::authenticate(&mut mock_db, &credentials)
683685
.await
684686
.unwrap();
685687
assert!(result.is_none());
@@ -701,7 +703,7 @@ mod tests {
701703

702704
let credentials =
703705
DatabaseUserCredentials::new("testuser".to_string(), Password::new("invalid"));
704-
let result = DatabaseUser::authenticate(&mock_db, &credentials)
706+
let result = DatabaseUser::authenticate(&mut mock_db, &credentials)
705707
.await
706708
.unwrap();
707709
assert!(result.is_none());

0 commit comments

Comments
 (0)