@@ -4,7 +4,9 @@ mod sorter;
44
55use std:: fmt;
66use std:: fmt:: { Debug , Formatter } ;
7+ use std:: future:: Future ;
78
9+ pub use cot_macros:: migration_op;
810use sea_query:: { ColumnDef , StringLen } ;
911use thiserror:: Error ;
1012use tracing:: { Level , info} ;
@@ -20,9 +22,13 @@ pub enum MigrationEngineError {
2022 /// An error occurred while determining the correct order of migrations.
2123 #[ error( "error while determining the correct order of migrations" ) ]
2224 MigrationSortError ( #[ from] MigrationSorterError ) ,
25+ /// A custom error occurred during a migration.
26+ #[ error( "error running migration: {0}" ) ]
27+ Custom ( String ) ,
2328}
2429
25- /// A migration engine that can run migrations.
30+ /// A migration engine responsible for managing and applying database
31+ /// migrations.
2632///
2733/// # Examples
2834///
@@ -133,7 +139,7 @@ impl MigrationEngine {
133139 ///
134140 /// # Errors
135141 ///
136- /// Throws an error if any of the migrations fail to apply, or if there is
142+ /// Returns an error if any of the migrations fail to apply, or if there is
137143 /// an error while interacting with the database, or if there is an
138144 /// error while marking a migration as applied.
139145 ///
@@ -424,11 +430,37 @@ impl Operation {
424430 RemoveModelBuilder :: new ( )
425431 }
426432
433+ // TODO: docs
434+ pub const fn alter_field ( ) -> AlterFieldBuilder {
435+ AlterFieldBuilder :: new ( )
436+ }
437+
438+ /// Returns a builder for a custom operation.
439+ ///
440+ /// # Examples
441+ ///
442+ /// ```
443+ /// use cot::db::Result;
444+ /// use cot::db::migrations::{MigrationContext, Operation, migration_op};
445+ ///
446+ /// #[migration_op]
447+ /// async fn forwards(ctx: MigrationContext<'_>) -> Result<()> {
448+ /// // do something
449+ /// Ok(())
450+ /// }
451+ ///
452+ /// const OPERATION: Operation = Operation::custom(forwards).build();
453+ /// ```
454+ #[ must_use]
455+ pub const fn custom ( forwards : CustomOperationFn ) -> CustomBuilder {
456+ CustomBuilder :: new ( forwards)
457+ }
458+
427459 /// Runs the operation forwards.
428460 ///
429461 /// # Errors
430462 ///
431- /// Throws an error if the operation fails to apply.
463+ /// Returns an error if the operation fails to apply.
432464 ///
433465 /// # Examples
434466 ///
@@ -512,6 +544,13 @@ impl Operation {
512544 let query = sea_query:: Table :: drop ( ) . table ( * table_name) . to_owned ( ) ;
513545 database. execute_schema ( query) . await ?;
514546 }
547+ OperationInner :: Custom {
548+ forwards,
549+ backwards : _,
550+ } => {
551+ let context = MigrationContext :: new ( database) ;
552+ forwards ( context) . await ?;
553+ }
515554 }
516555 Ok ( ( ) )
517556 }
@@ -521,7 +560,7 @@ impl Operation {
521560 ///
522561 /// # Errors
523562 ///
524- /// Throws an error if the operation fails to apply.
563+ /// Returns an error if the operation fails to apply.
525564 ///
526565 /// # Examples
527566 ///
@@ -601,11 +640,50 @@ impl Operation {
601640 }
602641 database. execute_schema ( query) . await ?;
603642 }
643+ OperationInner :: Custom {
644+ forwards : _,
645+ backwards,
646+ } => {
647+ if let Some ( backwards) = backwards {
648+ let context = MigrationContext :: new ( database) ;
649+ backwards ( context) . await ?;
650+ } else {
651+ return Err ( crate :: db:: DatabaseError :: MigrationError (
652+ MigrationEngineError :: Custom ( "Backwards migration not implemented" . into ( ) ) ,
653+ ) ) ;
654+ }
655+ }
604656 }
605657 Ok ( ( ) )
606658 }
607659}
608660
661+ /// A context for a custom migration operation.
662+ ///
663+ /// This structure provides access to the database and other information that
664+ /// might be needed during a migration.
665+ #[ derive( Debug ) ]
666+ #[ non_exhaustive]
667+ pub struct MigrationContext < ' a > {
668+ /// The database connection to run the migration against.
669+ pub db : & ' a Database ,
670+ }
671+
672+ impl < ' a > MigrationContext < ' a > {
673+ fn new ( db : & ' a Database ) -> Self {
674+ Self { db }
675+ }
676+ }
677+
678+ /// A type alias for a custom migration operation function.
679+ ///
680+ /// Typically, you should use the [`migration_op`] attribute macro to define
681+ /// functions of this type.
682+ pub type CustomOperationFn =
683+ for <' a > fn (
684+ MigrationContext < ' a > ,
685+ ) -> std:: pin:: Pin < Box < dyn Future < Output = Result < ( ) > > + Send + ' a > > ;
686+
609687#[ derive( Debug , Copy , Clone ) ]
610688enum OperationInner {
611689 /// Create a new model with the given fields.
@@ -634,6 +712,10 @@ enum OperationInner {
634712 old_field : Field ,
635713 new_field : Field ,
636714 } ,
715+ Custom {
716+ forwards : CustomOperationFn ,
717+ backwards : Option < CustomOperationFn > ,
718+ } ,
637719}
638720
639721/// A field in a model.
@@ -687,6 +769,12 @@ impl Field {
687769
688770 /// Marks the field as a foreign key to the given model and field.
689771 ///
772+ /// # Panics
773+ ///
774+ /// This function will panic if `on_delete` or `on_update` is set to
775+ /// [`SetNone`](ForeignKeyOnDeletePolicy::SetNone) and the field is not
776+ /// nullable.
777+ ///
690778 /// # Cot CLI Usage
691779 ///
692780 /// Typically, you shouldn't need to use this directly. Instead, in most
@@ -1558,13 +1646,61 @@ impl RemoveModelBuilder {
15581646 }
15591647}
15601648
1561- /// Returns a builder for an operation that alters a field in a model.
1562- pub const fn alter_field ( ) -> AlterFieldBuilder {
1563- AlterFieldBuilder :: new ( )
1649+ /// A builder for a custom operation.
1650+ ///
1651+ /// # Examples
1652+ ///
1653+ /// ```
1654+ /// use cot::db::Result;
1655+ /// use cot::db::migrations::{MigrationContext, Operation, migration_op};
1656+ ///
1657+ /// #[migration_op]
1658+ /// async fn forwards(ctx: MigrationContext<'_>) -> Result<()> {
1659+ /// // do something
1660+ /// Ok(())
1661+ /// }
1662+ ///
1663+ /// #[migration_op]
1664+ /// async fn backwards(ctx: MigrationContext<'_>) -> Result<()> {
1665+ /// // undo something
1666+ /// Ok(())
1667+ /// }
1668+ ///
1669+ /// const OPERATION: Operation = Operation::custom(forwards).backwards(backwards).build();
1670+ /// ```
1671+ #[ derive( Debug , Copy , Clone ) ]
1672+ pub struct CustomBuilder {
1673+ forwards : CustomOperationFn ,
1674+ backwards : Option < CustomOperationFn > ,
1675+ }
1676+
1677+ impl CustomBuilder {
1678+ #[ must_use]
1679+ const fn new ( forwards : CustomOperationFn ) -> Self {
1680+ Self {
1681+ forwards,
1682+ backwards : None ,
1683+ }
1684+ }
1685+
1686+ /// Sets the backwards operation.
1687+ #[ must_use]
1688+ pub const fn backwards ( mut self , backwards : CustomOperationFn ) -> Self {
1689+ self . backwards = Some ( backwards) ;
1690+ self
1691+ }
1692+
1693+ /// Builds the operation.
1694+ #[ must_use]
1695+ pub const fn build ( self ) -> Operation {
1696+ Operation :: new ( OperationInner :: Custom {
1697+ forwards : self . forwards ,
1698+ backwards : self . backwards ,
1699+ } )
1700+ }
15641701}
15651702
1566- /// A builder for altering a field in a model.
1567- #[ must_use]
1703+ // TODO: docs
15681704#[ derive( Debug , Copy , Clone ) ]
15691705pub struct AlterFieldBuilder {
15701706 table_name : Option < Identifier > ,
@@ -1600,7 +1736,6 @@ impl AlterFieldBuilder {
16001736 }
16011737
16021738 /// Builds the operation.
1603- #[ must_use]
16041739 pub const fn build ( self ) -> Operation {
16051740 Operation :: new ( OperationInner :: AlterField {
16061741 table_name : unwrap_builder_option ! ( self , table_name) ,
@@ -2088,6 +2223,86 @@ mod tests {
20882223 }
20892224 }
20902225
2226+ #[ cot:: test]
2227+ #[ cfg_attr(
2228+ miri,
2229+ ignore = "unsupported operation: can't call foreign function `sqlite3_open_v2`"
2230+ ) ]
2231+ async fn operation_custom ( ) {
2232+ // test only on SQLite because we are using raw SQL
2233+ let test_db = TestDatabase :: new_sqlite ( ) . await . unwrap ( ) ;
2234+
2235+ #[ migration_op]
2236+ async fn forwards ( ctx : MigrationContext < ' _ > ) -> Result < ( ) > {
2237+ ctx. db
2238+ . raw ( "CREATE TABLE custom_test (id INTEGER PRIMARY KEY)" )
2239+ . await ?;
2240+ Ok ( ( ) )
2241+ }
2242+
2243+ let operation = Operation :: custom ( forwards) . build ( ) ;
2244+ operation. forwards ( & test_db. database ( ) ) . await . unwrap ( ) ;
2245+
2246+ let result = test_db. database ( ) . raw ( "SELECT * FROM custom_test" ) . await ;
2247+ assert ! ( result. is_ok( ) ) ;
2248+ }
2249+
2250+ #[ cot:: test]
2251+ #[ cfg_attr(
2252+ miri,
2253+ ignore = "unsupported operation: can't call foreign function `sqlite3_open_v2`"
2254+ ) ]
2255+ async fn operation_custom_backwards ( ) {
2256+ // test only on SQLite because we are using raw SQL
2257+ let test_db = TestDatabase :: new_sqlite ( ) . await . unwrap ( ) ;
2258+
2259+ #[ migration_op]
2260+ async fn forwards ( _ctx : MigrationContext < ' _ > ) -> Result < ( ) > {
2261+ panic ! ( "this should not be called" ) ;
2262+ }
2263+
2264+ #[ migration_op]
2265+ async fn backwards ( ctx : MigrationContext < ' _ > ) -> Result < ( ) > {
2266+ ctx. db . raw ( "DROP TABLE custom_test_back" ) . await ?;
2267+ Ok ( ( ) )
2268+ }
2269+
2270+ test_db
2271+ . database ( )
2272+ . raw ( "CREATE TABLE custom_test_back (id INTEGER PRIMARY KEY)" )
2273+ . await
2274+ . unwrap ( ) ;
2275+
2276+ let operation = Operation :: custom ( forwards) . backwards ( backwards) . build ( ) ;
2277+ operation. backwards ( & test_db. database ( ) ) . await . unwrap ( ) ;
2278+
2279+ let result = test_db
2280+ . database ( )
2281+ . raw ( "SELECT * FROM custom_test_back" )
2282+ . await ;
2283+ assert ! ( result. is_err( ) ) ;
2284+ }
2285+
2286+ #[ cot:: test]
2287+ #[ cfg_attr(
2288+ miri,
2289+ ignore = "unsupported operation: can't call foreign function `sqlite3_open_v2`"
2290+ ) ]
2291+ async fn operation_custom_backwards_not_implemented ( ) {
2292+ // test only on SQLite because we are using raw SQL
2293+ let test_db = TestDatabase :: new_sqlite ( ) . await . unwrap ( ) ;
2294+
2295+ #[ migration_op]
2296+ async fn forwards ( _ctx : MigrationContext < ' _ > ) -> Result < ( ) > {
2297+ Ok ( ( ) )
2298+ }
2299+
2300+ let operation = Operation :: custom ( forwards) . build ( ) ;
2301+ let result = operation. backwards ( & test_db. database ( ) ) . await ;
2302+
2303+ assert ! ( result. is_err( ) ) ;
2304+ }
2305+
20912306 #[ test]
20922307 fn field_new ( ) {
20932308 let field = Field :: new ( Identifier :: new ( "id" ) , ColumnType :: Integer )
0 commit comments