Skip to content
This repository was archived by the owner on Sep 20, 2023. It is now read-only.

Commit d693c59

Browse files
Merge pull request #102 from laughingman7743/fix_decimal_type_format
Fix parameter format of Decimal type
2 parents 0924f7c + dc40360 commit d693c59

4 files changed

Lines changed: 17 additions & 6 deletions

File tree

pyathenajdbc/formatter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ def _format_seq(formatter, escaper, val):
7272
return "({0})".format(", ".join(results))
7373

7474

75+
def _format_decimal(formatter, escaper, val):
76+
return "DECIMAL {0}".format(escaper("{0:f}".format(val)))
77+
78+
7579
class ParameterFormatter(object):
7680
def __init__(self):
7781
self.mappings = _DEFAULT_FORMATTERS
@@ -119,7 +123,7 @@ def register_formatter(self, type_, formatter):
119123
int: _format_default,
120124
float: _format_default,
121125
long: _format_default,
122-
Decimal: _format_default,
126+
Decimal: _format_decimal,
123127
bool: _format_bool,
124128
str: _format_str,
125129
unicode: _format_str,

pyathenajdbc/util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
from datetime import datetime
77

88

9-
def as_pandas(cursor):
9+
def as_pandas(cursor, coerce_float=False):
1010
from pandas import DataFrame
1111

1212
names = [metadata[0] for metadata in cursor.description]
13-
return DataFrame.from_records(cursor.fetchall(), columns=names)
13+
return DataFrame.from_records(
14+
cursor.fetchall(), columns=names, coerce_float=coerce_float
15+
)
1416

1517

1618
def synchronized(wrapped):

tests/test_cursor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ def test_unicode(self, cursor):
216216
cursor.execute("SELECT %(param)s FROM one_row", {"param": unicode_str})
217217
self.assertEqual(cursor.fetchall(), [(unicode_str,)])
218218

219+
@with_cursor
220+
def test_decimal(self, cursor):
221+
cursor.execute("SELECT %(decimal)s", {"decimal": Decimal("0.00000000001")})
222+
self.assertEqual(cursor.fetchall(), [(Decimal("0.00000000001"),)])
223+
219224
@with_cursor
220225
def test_null(self, cursor):
221226
cursor.execute("SELECT null FROM many_rows")

tests/test_formatter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_format_decimal(self):
172172
"""
173173
SELECT *
174174
FROM test_table
175-
WHERE col_decimal <= 0.0000000001
175+
WHERE col_decimal <= DECIMAL '0.0000000001'
176176
"""
177177
).strip()
178178

@@ -181,7 +181,7 @@ def test_format_decimal(self):
181181
"""
182182
SELECT *
183183
FROM test_table
184-
WHERE col_decimal <= %(param).10f
184+
WHERE col_decimal <= %(param)s
185185
"""
186186
).strip(),
187187
{"param": Decimal("0.0000000001")},
@@ -364,7 +364,7 @@ def test_format_decimal_list(self):
364364
"""
365365
SELECT *
366366
FROM test_table
367-
WHERE col_decimal IN (0.0000000001, 99.9999999999)
367+
WHERE col_decimal IN (DECIMAL '0.0000000001', DECIMAL '99.9999999999')
368368
"""
369369
).strip()
370370

0 commit comments

Comments
 (0)