33#![ allow( clippy:: useless_conversion) ]
44
55use pyo3:: prelude:: * ;
6+ use pyo3:: create_exception;
7+ use pyo3:: exceptions:: PyValueError ;
68use pyo3:: types:: { PyBytes , PyDict , PyList } ;
79use std:: io:: Cursor ;
810
@@ -12,6 +14,28 @@ use ggsql::validate::{validate as rust_validate, ValidationWarning};
1214use ggsql:: writer:: { VegaLiteWriter as RustVegaLiteWriter , Writer as RustWriter } ;
1315use ggsql:: GgsqlError ;
1416
17+ // ============================================================================
18+ // Custom Exception Classes
19+ // ============================================================================
20+
21+ // All subclass ValueError for backwards compatibility
22+ create_exception ! ( ggsql, ParseError , PyValueError , "Raised on query syntax errors." ) ;
23+ create_exception ! ( ggsql, ValidationError , PyValueError , "Raised on semantic validation errors." ) ;
24+ create_exception ! ( ggsql, ReaderError , PyValueError , "Raised on data source errors." ) ;
25+ create_exception ! ( ggsql, WriterError , PyValueError , "Raised on output generation errors." ) ;
26+
27+ /// Convert a GgsqlError to the appropriate typed Python exception.
28+ fn ggsql_err_to_py ( e : GgsqlError ) -> PyErr {
29+ let msg = e. to_string ( ) ;
30+ match e {
31+ GgsqlError :: ParseError ( _) => PyErr :: new :: < ParseError , _ > ( msg) ,
32+ GgsqlError :: ValidationError ( _) => PyErr :: new :: < ValidationError , _ > ( msg) ,
33+ GgsqlError :: ReaderError ( _) => PyErr :: new :: < ReaderError , _ > ( msg) ,
34+ GgsqlError :: WriterError ( _) => PyErr :: new :: < WriterError , _ > ( msg) ,
35+ GgsqlError :: InternalError ( _) => PyErr :: new :: < PyValueError , _ > ( msg) ,
36+ }
37+ }
38+
1539use polars:: prelude:: { DataFrame , IpcReader , IpcWriter , SerReader , SerWriter } ;
1640
1741// ============================================================================
@@ -175,7 +199,7 @@ macro_rules! try_native_readers {
175199 if let Ok ( native) = $reader. downcast:: <$native_type>( ) {
176200 return native. borrow( ) . inner. execute( $query)
177201 . map( |s| PySpec { inner: s } )
178- . map_err( |e| PyErr :: new :: <pyo3 :: exceptions :: PyValueError , _> ( e . to_string ( ) ) ) ;
202+ . map_err( ggsql_err_to_py ) ;
179203 }
180204 ) *
181205 } } ;
@@ -225,7 +249,7 @@ impl PyDuckDBReader {
225249 #[ new]
226250 fn new ( connection : & str ) -> PyResult < Self > {
227251 let inner = RustDuckDBReader :: from_connection_string ( connection)
228- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) ) ?;
252+ . map_err ( ggsql_err_to_py ) ?;
229253 Ok ( Self { inner } )
230254 }
231255
@@ -255,7 +279,7 @@ impl PyDuckDBReader {
255279 let rust_df = py_to_polars ( py, df) ?;
256280 self . inner
257281 . register ( name, rust_df, replace)
258- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
282+ . map_err ( ggsql_err_to_py )
259283 }
260284
261285 /// Unregister a previously registered table.
@@ -272,7 +296,7 @@ impl PyDuckDBReader {
272296 fn unregister ( & self , name : & str ) -> PyResult < ( ) > {
273297 self . inner
274298 . unregister ( name)
275- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
299+ . map_err ( ggsql_err_to_py )
276300 }
277301
278302 /// Execute a SQL query and return the result as a DataFrame.
@@ -295,7 +319,7 @@ impl PyDuckDBReader {
295319 let df = self
296320 . inner
297321 . execute_sql ( sql)
298- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) ) ?;
322+ . map_err ( ggsql_err_to_py ) ?;
299323 polars_to_py ( py, & df)
300324 }
301325
@@ -330,7 +354,7 @@ impl PyDuckDBReader {
330354 self . inner
331355 . execute ( query)
332356 . map ( |s| PySpec { inner : s } )
333- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
357+ . map_err ( ggsql_err_to_py )
334358 }
335359}
336360
@@ -393,7 +417,7 @@ impl PyVegaLiteWriter {
393417 fn render ( & self , spec : & PySpec ) -> PyResult < String > {
394418 self . inner
395419 . render ( & spec. inner )
396- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
420+ . map_err ( ggsql_err_to_py )
397421 }
398422}
399423
@@ -658,7 +682,7 @@ impl PySpec {
658682#[ pyfunction]
659683fn validate ( query : & str ) -> PyResult < PyValidated > {
660684 let v = rust_validate ( query)
661- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) ) ?;
685+ . map_err ( ggsql_err_to_py ) ?;
662686
663687 Ok ( PyValidated {
664688 sql : v. sql ( ) . to_string ( ) ,
@@ -739,7 +763,7 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
739763 bridge
740764 . execute ( query)
741765 . map ( |s| PySpec { inner : s } )
742- . map_err ( |e| PyErr :: new :: < pyo3 :: exceptions :: PyValueError , _ > ( e . to_string ( ) ) )
766+ . map_err ( ggsql_err_to_py )
743767}
744768
745769// ============================================================================
@@ -748,6 +772,12 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
748772
749773#[ pymodule]
750774fn _ggsql ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
775+ // Exceptions
776+ m. add ( "ParseError" , m. py ( ) . get_type :: < ParseError > ( ) ) ?;
777+ m. add ( "ValidationError" , m. py ( ) . get_type :: < ValidationError > ( ) ) ?;
778+ m. add ( "ReaderError" , m. py ( ) . get_type :: < ReaderError > ( ) ) ?;
779+ m. add ( "WriterError" , m. py ( ) . get_type :: < WriterError > ( ) ) ?;
780+
751781 // Classes
752782 m. add_class :: < PyDuckDBReader > ( ) ?;
753783 m. add_class :: < PyVegaLiteWriter > ( ) ?;
0 commit comments