@@ -4,18 +4,18 @@ use std::fmt::{Display, Formatter};
44use std:: str;
55
66use column:: Column ;
7- use common:: FieldDefinitionExpression ;
87use common:: {
98 as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference,
109 unsigned_number,
1110} ;
11+ use common:: { sql_identifier, FieldDefinitionExpression } ;
1212use compound_select:: nested_compound_selection;
1313use condition:: { condition_expr, ConditionExpression } ;
1414use join:: { join_operator, JoinConstraint , JoinOperator , JoinRightSide } ;
1515use nom:: branch:: alt;
1616use nom:: bytes:: complete:: { tag, tag_no_case} ;
1717use nom:: combinator:: { map, opt} ;
18- use nom:: multi:: many0;
18+ use nom:: multi:: { many0, separated_list1 } ;
1919use nom:: sequence:: { delimited, preceded, terminated, tuple} ;
2020use nom:: IResult ;
2121use order:: { order_clause, OrderClause } ;
@@ -106,8 +106,66 @@ impl From<CompoundSelectStatement> for Selection {
106106 }
107107}
108108
109+ #[ derive( Clone , Debug , Eq , Hash , PartialEq , Serialize , Deserialize ) ]
110+ pub struct WithClause {
111+ pub recursive : bool ,
112+ pub subclauses : Vec < WithSubclause > ,
113+ }
114+
115+ impl fmt:: Display for WithClause {
116+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
117+ write ! ( f, "WITH " ) ?;
118+
119+ if self . recursive {
120+ write ! ( f, "RECURSIVE " ) ?;
121+ }
122+
123+ write ! (
124+ f,
125+ "{}" ,
126+ self . subclauses
127+ . iter( )
128+ . map( |c| format!( "{}" , c) )
129+ . collect:: <Vec <_>>( )
130+ . join( ", " )
131+ ) ?;
132+
133+ Ok ( ( ) )
134+ }
135+ }
136+
137+ #[ derive( Clone , Debug , Eq , Hash , PartialEq , Serialize , Deserialize ) ]
138+ pub struct WithSubclause {
139+ pub name : String ,
140+ pub columns : Vec < Column > ,
141+ pub selection : Box < Selection > ,
142+ }
143+
144+ impl fmt:: Display for WithSubclause {
145+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
146+ write ! ( f, "{} " , self . name) ?;
147+
148+ if self . columns . len ( ) > 0 {
149+ write ! (
150+ f,
151+ "({}) " ,
152+ self . columns
153+ . iter( )
154+ . map( |c| format!( "{}" , c) )
155+ . collect:: <Vec <_>>( )
156+ . join( ", " )
157+ ) ?;
158+ }
159+
160+ write ! ( f, "AS ({})" , self . selection) ?;
161+
162+ Ok ( ( ) )
163+ }
164+ }
165+
109166#[ derive( Clone , Debug , Default , Eq , Hash , PartialEq , Serialize , Deserialize ) ]
110167pub struct SelectStatement {
168+ pub with : Option < WithClause > ,
111169 pub tables : Vec < Table > ,
112170 pub distinct : bool ,
113171 pub fields : Vec < FieldDefinitionExpression > ,
@@ -120,6 +178,10 @@ pub struct SelectStatement {
120178
121179impl fmt:: Display for SelectStatement {
122180 fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
181+ if let Some ( ref with_clause) = self . with {
182+ write ! ( f, "{}" , with_clause) ?;
183+ }
184+
123185 write ! ( f, "SELECT " ) ?;
124186 if self . distinct {
125187 write ! ( f, "DISTINCT " ) ?;
@@ -318,8 +380,10 @@ pub fn simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
318380pub fn nested_simple_selection ( i : & [ u8 ] ) -> IResult < & [ u8 ] , SelectStatement > {
319381 let (
320382 remaining_input,
321- ( _, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit) ,
383+ ( with , _ , _, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit) ,
322384 ) = tuple ( (
385+ opt ( with_clause) ,
386+ multispace0,
323387 tag_no_case ( "select" ) ,
324388 multispace1,
325389 opt ( tag_no_case ( "distinct" ) ) ,
@@ -336,6 +400,7 @@ pub fn nested_simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
336400 Ok ( (
337401 remaining_input,
338402 SelectStatement {
403+ with,
339404 tables,
340405 distinct : distinct. is_some ( ) ,
341406 fields,
@@ -348,6 +413,60 @@ pub fn nested_simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
348413 ) )
349414}
350415
416+ pub fn with_clause ( i : & [ u8 ] ) -> IResult < & [ u8 ] , WithClause > {
417+ map (
418+ tuple ( (
419+ tag_no_case ( "with" ) ,
420+ multispace1,
421+ opt ( tag_no_case ( "recursive" ) ) ,
422+ multispace0,
423+ separated_list1 ( tuple ( ( multispace0, tag ( "," ) , multispace0) ) , with_subclause) ,
424+ ) ) ,
425+ |( _, _, recursive, _, subclauses) | WithClause {
426+ recursive : recursive. is_some ( ) ,
427+ subclauses,
428+ } ,
429+ ) ( i)
430+ }
431+
432+ pub fn with_subclause ( i : & [ u8 ] ) -> IResult < & [ u8 ] , WithSubclause > {
433+ map (
434+ tuple ( (
435+ sql_identifier,
436+ multispace1,
437+ opt ( with_clause_column_list) ,
438+ multispace0,
439+ tag_no_case ( "as" ) ,
440+ multispace1,
441+ tag ( "(" ) ,
442+ multispace0,
443+ nested_selection,
444+ multispace0,
445+ tag ( ")" ) ,
446+ ) ) ,
447+ |( name, _, columns, _, _, _, _, _, selection, _, _) | WithSubclause {
448+ name : str:: from_utf8 ( name) . unwrap ( ) . to_string ( ) ,
449+ columns : columns. unwrap_or ( vec ! [ ] ) ,
450+ selection : Box :: new ( selection) ,
451+ } ,
452+ ) ( i)
453+ }
454+
455+ pub fn with_clause_column_list ( i : & [ u8 ] ) -> IResult < & [ u8 ] , Vec < Column > > {
456+ let ( i, ( _, _, columns, _, _) ) = tuple ( (
457+ tag ( "(" ) ,
458+ multispace0,
459+ separated_list1 (
460+ tuple ( ( multispace0, tag ( "," ) , multispace0) ) ,
461+ map ( sql_identifier, |si| str:: from_utf8 ( si) . unwrap ( ) . into ( ) ) ,
462+ ) ,
463+ multispace0,
464+ tag ( ")" ) ,
465+ ) ) ( i) ?;
466+
467+ Ok ( ( i, columns) )
468+ }
469+
351470#[ cfg( test) ]
352471mod tests {
353472 use super :: * ;
@@ -1454,4 +1573,113 @@ mod tests {
14541573
14551574 assert_eq ! ( res. unwrap( ) . 1 , expected. into( ) ) ;
14561575 }
1576+
1577+ #[ test]
1578+ fn with ( ) {
1579+ let qstr0 = "WITH cte1 AS (SELECT a, b FROM table1)" ;
1580+ let qstr1 = "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2)" ;
1581+ let qstr2 =
1582+ "WITH cte1 (e, f) AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2)" ;
1583+ let qstr3 = "WITH RECURSIVE cte1 AS (SELECT a, b FROM table1)" ;
1584+ let res0 = with_clause ( qstr0. as_bytes ( ) ) ;
1585+ let res1 = with_clause ( qstr1. as_bytes ( ) ) ;
1586+ let res2 = with_clause ( qstr2. as_bytes ( ) ) ;
1587+ let res3 = with_clause ( qstr3. as_bytes ( ) ) ;
1588+
1589+ let expected_ss0 = Box :: new ( Selection :: Statement ( SelectStatement {
1590+ with : None ,
1591+ tables : vec ! [ Table {
1592+ name: "table1" . to_string( ) ,
1593+ alias: None ,
1594+ schema: None ,
1595+ } ] ,
1596+ distinct : false ,
1597+ fields : vec ! [
1598+ FieldDefinitionExpression :: Col ( Column :: from( "a" ) ) ,
1599+ FieldDefinitionExpression :: Col ( Column :: from( "b" ) ) ,
1600+ ] ,
1601+ join : vec ! [ ] ,
1602+ where_clause : None ,
1603+ group_by : None ,
1604+ order : None ,
1605+ limit : None ,
1606+ } ) ) ;
1607+ let expected_ss1 = Box :: new ( Selection :: Statement ( SelectStatement {
1608+ tables : vec ! [ Table {
1609+ name: "table2" . to_string( ) ,
1610+ alias: None ,
1611+ schema: None ,
1612+ } ] ,
1613+ fields : vec ! [
1614+ FieldDefinitionExpression :: Col ( Column :: from( "c" ) ) ,
1615+ FieldDefinitionExpression :: Col ( Column :: from( "d" ) ) ,
1616+ ] ,
1617+ ..Default :: default ( )
1618+ } ) ) ;
1619+
1620+ let expected0 = WithClause {
1621+ recursive : false ,
1622+ subclauses : vec ! [ WithSubclause {
1623+ name: "cte1" . to_string( ) ,
1624+ columns: vec![ ] ,
1625+ selection: expected_ss0. clone( ) ,
1626+ } ] ,
1627+ } ;
1628+ let expected1 = WithClause {
1629+ recursive : false ,
1630+ subclauses : vec ! [
1631+ WithSubclause {
1632+ name: "cte1" . to_string( ) ,
1633+ columns: vec![ ] ,
1634+ selection: expected_ss0. clone( ) ,
1635+ } ,
1636+ WithSubclause {
1637+ name: "cte2" . to_string( ) ,
1638+ columns: vec![ ] ,
1639+ selection: expected_ss1. clone( ) ,
1640+ } ,
1641+ ] ,
1642+ } ;
1643+ let expected2 = WithClause {
1644+ recursive : false ,
1645+ subclauses : vec ! [
1646+ WithSubclause {
1647+ name: "cte1" . to_string( ) ,
1648+ columns: vec![
1649+ Column {
1650+ name: "e" . to_string( ) ,
1651+ alias: None ,
1652+ table: None ,
1653+ function: None ,
1654+ } ,
1655+ Column {
1656+ name: "f" . to_string( ) ,
1657+ alias: None ,
1658+ table: None ,
1659+ function: None ,
1660+ } ,
1661+ ] ,
1662+ selection: expected_ss0. clone( ) ,
1663+ } ,
1664+ WithSubclause {
1665+ name: "cte2" . to_string( ) ,
1666+ columns: vec![ ] ,
1667+ selection: expected_ss1. clone( ) ,
1668+ } ,
1669+ ] ,
1670+ } ;
1671+ let expected3 = WithClause {
1672+ recursive : true ,
1673+ subclauses : vec ! [ WithSubclause {
1674+ name: "cte1" . to_string( ) ,
1675+ columns: vec![ ] ,
1676+ selection: expected_ss0. clone( ) ,
1677+ } ] ,
1678+ } ;
1679+
1680+ assert_eq ! ( res0. unwrap( ) . 1 , expected0) ;
1681+ assert_eq ! ( res1. unwrap( ) . 1 , expected1) ;
1682+ assert_eq ! ( res2. unwrap( ) . 1 , expected2) ;
1683+ assert_eq ! ( res3. unwrap( ) . 1 , expected3) ;
1684+ }
14571685}
0 commit comments