@@ -100,12 +100,24 @@ class ModelBase(Mapping[str, Any]):
100100 _cache : Dict [tuple , Any ]
101101 table_name : str
102102 primary_key_names : Tuple [str , ...]
103+ array_safe_insert : bool
103104
104105 def __init_subclass__ (
105- cls , * , table_name : str , primary_key : Union [FieldNames , str ] = (), ** kwargs : Any
106+ cls ,
107+ * ,
108+ table_name : str ,
109+ primary_key : Union [FieldNames , str ] = (),
110+ insert_multiple_mode : str = "unnest" ,
111+ ** kwargs : Any ,
106112 ):
107113 cls ._cache = {}
108114 cls .table_name = table_name
115+ if insert_multiple_mode == "array_safe" :
116+ cls .array_safe_insert = True
117+ elif insert_multiple_mode == "unnest" :
118+ cls .array_safe_insert = False
119+ else :
120+ raise ValueError ("Unknown `insert_multiple_mode`" )
109121 if isinstance (primary_key , str ):
110122 cls .primary_key_names = (primary_key ,)
111123 else :
@@ -407,19 +419,69 @@ def insert_multiple_sql(cls: Type[T], rows: Iterable[T]) -> Fragment:
407419 )
408420
409421 @classmethod
410- async def insert_multiple (
422+ def insert_multiple_array_safe_sql (cls : Type [T ], rows : Iterable [T ]) -> Fragment :
423+ return sql (
424+ "INSERT INTO {table} ({fields}) VALUES {values}" ,
425+ table = cls .table_name_sql (),
426+ fields = sql .list (cls .field_names_sql ()),
427+ values = sql .list (
428+ sql ("({})" , sql .list (row .field_values_sql (default_none = True )))
429+ for row in rows
430+ ),
431+ )
432+
433+ @classmethod
434+ async def insert_multiple_unnest (
411435 cls : Type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
412436 ) -> str :
413437 return await connection_or_pool .execute (* cls .insert_multiple_sql (rows ))
414438
415439 @classmethod
416- async def upsert_multiple (
440+ async def insert_multiple_array_safe (
441+ cls : Type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
442+ ) -> str :
443+ for chunk in chunked (rows , 100 ):
444+ last = await connection_or_pool .execute (
445+ * cls .insert_multiple_array_safe_sql (chunk )
446+ )
447+ return last
448+
449+ @classmethod
450+ async def insert_multiple (
451+ cls : Type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
452+ ) -> str :
453+ if cls .array_safe_insert :
454+ return await cls .insert_multiple_array_safe (connection_or_pool , rows )
455+ else :
456+ return await cls .insert_multiple_unnest (connection_or_pool , rows )
457+
458+ @classmethod
459+ async def upsert_multiple_unnest (
417460 cls : Type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
418461 ) -> str :
419462 return await connection_or_pool .execute (
420463 * cls .upsert_sql (cls .insert_multiple_sql (rows ))
421464 )
422465
466+ @classmethod
467+ async def upsert_multiple_array_safe (
468+ cls : Type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
469+ ) -> str :
470+ for chunk in chunked (rows , 100 ):
471+ last = await connection_or_pool .execute (
472+ * cls .upsert_sql (cls .insert_multiple_array_safe_sql (chunk ))
473+ )
474+ return last
475+
476+ @classmethod
477+ async def upsert_multiple (
478+ cls : Type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
479+ ) -> str :
480+ if cls .array_safe_insert :
481+ return await cls .upsert_multiple_array_safe (connection_or_pool , rows )
482+ else :
483+ return await cls .upsert_multiple_unnest (connection_or_pool , rows )
484+
423485 @classmethod
424486 def _get_equal_ignoring_fn (
425487 cls : Type [T ], ignore : FieldNamesSet = ()
@@ -530,3 +592,8 @@ async def replace_multiple_reporting_differences(
530592 await cls .delete_multiple (connection , deleted )
531593
532594 return created , updated_triples , deleted
595+
596+
597+ def chunked (lst , n ):
598+ for i in range (0 , len (lst ), n ):
599+ yield lst [i : i + n ]
0 commit comments