@@ -157,7 +157,7 @@ class ModelBase:
157157 _cache : dict [tuple , Any ]
158158 table_name : str
159159 primary_key_names : tuple [str , ...]
160- array_safe_insert : bool
160+ insert_multiple_mode : str
161161
162162 def __init_subclass__ (
163163 cls ,
@@ -169,12 +169,9 @@ def __init_subclass__(
169169 ):
170170 cls ._cache = {}
171171 cls .table_name = table_name
172- if insert_multiple_mode == "array_safe" :
173- cls .array_safe_insert = True
174- elif insert_multiple_mode == "unnest" :
175- cls .array_safe_insert = False
176- else :
172+ if insert_multiple_mode not in ("array_safe" , "unnest" , "executemany" ):
177173 raise ValueError ("Unknown `insert_multiple_mode`" )
174+ cls .insert_multiple_mode = insert_multiple_mode
178175 if isinstance (primary_key , str ):
179176 cls .primary_key_names = (primary_key ,)
180177 else :
@@ -357,19 +354,37 @@ def select_sql(
357354 return query
358355
359356 @classmethod
360- async def select_cursor (
357+ async def cursor_from (
358+ cls : type [T ],
359+ connection : Connection ,
360+ query : Fragment ,
361+ prefetch : int = 1000 ,
362+ ) -> AsyncGenerator [T , None ]:
363+ async for row in connection .cursor (* query , prefetch = prefetch ):
364+ yield cls .from_mapping (row )
365+
366+ @classmethod
367+ def select_cursor (
361368 cls : type [T ],
362369 connection : Connection ,
363370 order_by : Union [FieldNames , str ] = (),
364371 for_update : bool = False ,
365372 where : Where = (),
366373 prefetch : int = 1000 ,
367374 ) -> AsyncGenerator [T , None ]:
368- async for row in connection .cursor (
369- * cls .select_sql (order_by = order_by , for_update = for_update , where = where ),
375+ return cls .cursor_from (
376+ connection ,
377+ cls .select_sql (order_by = order_by , for_update = for_update , where = where ),
370378 prefetch = prefetch ,
371- ):
372- yield cls .from_mapping (row )
379+ )
380+
381+ @classmethod
382+ async def fetch_from (
383+ cls : type [T ],
384+ connection_or_pool : Union [Connection , Pool ],
385+ query : Fragment ,
386+ ) -> list [T ]:
387+ return [cls .from_mapping (row ) for row in await connection_or_pool .fetch (* query )]
373388
374389 @classmethod
375390 async def select (
@@ -379,12 +394,10 @@ async def select(
379394 for_update : bool = False ,
380395 where : Where = (),
381396 ) -> list [T ]:
382- return [
383- cls .from_mapping (row )
384- for row in await connection_or_pool .fetch (
385- * cls .select_sql (order_by = order_by , for_update = for_update , where = where )
386- )
387- ]
397+ return await cls .fetch_from (
398+ connection_or_pool ,
399+ cls .select_sql (order_by = order_by , for_update = for_update , where = where ),
400+ )
388401
389402 @classmethod
390403 def create_sql (cls : type [T ], ** kwargs : Any ) -> Fragment :
@@ -506,6 +519,37 @@ def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
506519 ),
507520 )
508521
522+ @classmethod
523+ def insert_multiple_executemany_chunk_sql (
524+ cls : type [T ], chunk_size : int
525+ ) -> Fragment :
526+ def generate () -> Fragment :
527+ columns = len (cls .column_info ())
528+ values = ", " .join (
529+ f"({ ', ' .join (f'${ i } ' for i in chunk )} )"
530+ for chunk in chunked (range (1 , columns * chunk_size + 1 ), columns )
531+ )
532+ return sql (
533+ "INSERT INTO {table} ({fields}) VALUES {values}" ,
534+ table = cls .table_name_sql (),
535+ fields = sql .list (cls .field_names_sql ()),
536+ values = sql .literal (values ),
537+ ).flatten ()
538+
539+ return cls ._cached (
540+ ("insert_multiple_executemany_chunk" , chunk_size ),
541+ generate ,
542+ )
543+
544+ @classmethod
545+ async def insert_multiple_executemany (
546+ cls : type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
547+ ) -> None :
548+ args = [r .field_values () for r in rows ]
549+ query = cls .insert_multiple_executemany_chunk_sql (1 ).query ()[0 ]
550+ if args :
551+ await connection_or_pool .executemany (query , args )
552+
509553 @classmethod
510554 async def insert_multiple_unnest (
511555 cls : type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
@@ -527,11 +571,28 @@ async def insert_multiple_array_safe(
527571 async def insert_multiple (
528572 cls : type [T ], connection_or_pool : Union [Connection , Pool ], rows : Iterable [T ]
529573 ) -> str :
530- if cls .array_safe_insert :
574+ if cls .insert_multiple_mode == "executemany" :
575+ await cls .insert_multiple_executemany (connection_or_pool , rows )
576+ return "INSERT"
577+ elif cls .insert_multiple_mode == "array_safe" :
531578 return await cls .insert_multiple_array_safe (connection_or_pool , rows )
532579 else :
533580 return await cls .insert_multiple_unnest (connection_or_pool , rows )
534581
582+ @classmethod
583+ async def upsert_multiple_executemany (
584+ cls : type [T ],
585+ connection_or_pool : Union [Connection , Pool ],
586+ rows : Iterable [T ],
587+ insert_only : FieldNamesSet = (),
588+ ) -> None :
589+ args = [r .field_values () for r in rows ]
590+ query = cls .upsert_sql (
591+ cls .insert_multiple_executemany_chunk_sql (1 ), exclude = insert_only
592+ ).query ()[0 ]
593+ if args :
594+ await connection_or_pool .executemany (query , args )
595+
535596 @classmethod
536597 async def upsert_multiple_unnest (
537598 cls : type [T ],
@@ -566,7 +627,12 @@ async def upsert_multiple(
566627 rows : Iterable [T ],
567628 insert_only : FieldNamesSet = (),
568629 ) -> str :
569- if cls .array_safe_insert :
630+ if cls .insert_multiple_mode == "executemany" :
631+ await cls .upsert_multiple_executemany (
632+ connection_or_pool , rows , insert_only = insert_only
633+ )
634+ return "INSERT"
635+ elif cls .insert_multiple_mode == "array_safe" :
570636 return await cls .upsert_multiple_array_safe (
571637 connection_or_pool , rows , insert_only = insert_only
572638 )
0 commit comments