11import json
22import re
3- from string import Formatter
3+ import traceback
44
55from IPython .core .magic import (
66 Magics ,
1010 needs_local_scope ,
1111)
1212from IPython .core .magic_arguments import argument , magic_arguments , parse_argstring
13- from IPython .display import display_javascript
1413from sqlalchemy .exc import OperationalError , ProgrammingError , DatabaseError
1514
1615import sql .connection
@@ -46,7 +45,8 @@ class SqlMagic(Magics, Configurable):
4645 style = Unicode (
4746 "DEFAULT" ,
4847 config = True ,
49- help = "Set the table printing style to any of prettytable's defined styles (currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)" ,
48+ help = "Set the table printing style to any of prettytable's defined styles "
49+ "(currently DEFAULT, MSWORD_FRIENDLY, PLAIN_COLUMNS, RANDOM)" ,
5050 )
5151 short_errors = Bool (
5252 True ,
@@ -72,17 +72,17 @@ class SqlMagic(Magics, Configurable):
7272 "odbc.ini" ,
7373 config = True ,
7474 help = "Path to DSN file. "
75- "When the first argument is of the form [section], "
76- "a sqlalchemy connection string is formed from the "
77- "matching section in the DSN file." ,
75+ "When the first argument is of the form [section], "
76+ "a sqlalchemy connection string is formed from the "
77+ "matching section in the DSN file." ,
7878 )
7979 autocommit = Bool (True , config = True , help = "Set autocommit mode" )
8080
8181 def __init__ (self , shell ):
8282 Configurable .__init__ (self , config = shell .config )
8383 Magics .__init__ (self , shell = shell )
8484
85- # Add ourself to the list of module configurable via %config
85+ # Add ourselves to the list of module configurable via %config
8686 self .shell .configurables .append (self )
8787
8888 @needs_local_scope
@@ -121,7 +121,7 @@ def __init__(self, shell):
121121 help = "specify dictionary of connection arguments to pass to SQL driver" ,
122122 )
123123 @argument ("-f" , "--file" , type = str , help = "Run SQL from file at this path" )
124- def execute (self , line = "" , cell = "" , local_ns = {} ):
124+ def execute (self , line = "" , cell = "" , local_ns = None ):
125125 """Runs SQL statement against a database, specified by SQLAlchemy connect string.
126126
127127 If no database connection has been established, first word
@@ -147,15 +147,17 @@ def execute(self, line="", cell="", local_ns={}):
147147
148148 """
149149 # Parse variables (words wrapped in {}) for %%sql magic (for %sql this is done automatically)
150+ if local_ns is None :
151+ local_ns = {}
150152 cell = self .shell .var_expand (cell )
151153 line = sql .parse .without_sql_comment (parser = self .execute .parser , line = line )
152154 args = parse_argstring (self .execute , line )
153155 if args .connections :
154156 return sql .connection .Connection .connections
155157 elif args .close :
156- return sql .connection .Connection ._close (args .close )
158+ return sql .connection .Connection .close (args .close )
157159
158- # save globals and locals so they can be referenced in bind vars
160+ # save globals and locals, so they can be referenced in bind vars
159161 user_ns = self .shell .user_ns .copy ()
160162 user_ns .update (local_ns )
161163
@@ -173,7 +175,7 @@ def execute(self, line="", cell="", local_ns={}):
173175
174176 if args .connection_arguments :
175177 try :
176- # check for string deliniators , we need to strip them for json parse
178+ # check for string delineators , we need to strip them for json parse
177179 raw_args = args .connection_arguments
178180 if len (raw_args ) > 1 :
179181 targets = ['"' , "'" ]
@@ -183,7 +185,7 @@ def execute(self, line="", cell="", local_ns={}):
183185 raw_args = raw_args [1 :- 1 ]
184186 args .connection_arguments = json .loads (raw_args )
185187 except Exception as e :
186- print (e )
188+ print (traceback . format_exc () )
187189 raise e
188190 else :
189191 args .connection_arguments = {}
@@ -197,8 +199,10 @@ def execute(self, line="", cell="", local_ns={}):
197199 connect_args = args .connection_arguments ,
198200 creator = args .creator ,
199201 )
200- except Exception as e :
201- print (e )
202+ # Rollback just in case there was an error in previous statement
203+ conn .internal_connection .rollback ()
204+ except Exception :
205+ print (traceback .format_exc ())
202206 print (sql .connection .Connection .tell_format ())
203207 return None
204208
@@ -220,7 +224,7 @@ def execute(self, line="", cell="", local_ns={}):
220224 and self .column_local_vars
221225 ):
222226 # Instead of returning values, set variables directly in the
223- # users namespace. Variable names given by column names
227+ # user's namespace. Variable names given by column names
224228
225229 if self .autopandas :
226230 keys = result .keys ()
@@ -253,7 +257,8 @@ def execute(self, line="", cell="", local_ns={}):
253257 if self .short_errors :
254258 print (e )
255259 else :
256- raise
260+ print (traceback .format_exc ())
261+ raise e
257262
258263 legal_sql_identifier = re .compile (r"^[A-Za-z0-9#_$]+" )
259264
@@ -279,7 +284,7 @@ def _persist_dataframe(self, raw, conn, user_ns, append=False):
279284 table_name = self .legal_sql_identifier .search (table_name ).group (0 )
280285
281286 if_exists = "append" if append else "fail"
282- frame .to_sql (table_name , conn .session .engine , if_exists = if_exists )
287+ frame .to_sql (table_name , conn .internal_connection .engine , if_exists = if_exists )
283288 return "Persisted %s" % table_name
284289
285290
0 commit comments