Skip to content

Commit 57599d2

Browse files
committed
Satisfy newer mypy
1 parent b1dbbd3 commit 57599d2

1 file changed

Lines changed: 22 additions & 15 deletions

File tree

sql_athame/dataclasses.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def model_field_metadata(
5454
def model_field(
5555
*, type: str, constraints: Union[str, Iterable[str]] = (), **kwargs: Any
5656
) -> Any:
57-
return field(**kwargs, metadata=model_field_metadata(type, constraints)) # type: ignore
57+
return field(**kwargs, metadata=model_field_metadata(type, constraints))
5858

5959

6060
sql_create_type_map = {
@@ -111,6 +111,13 @@ def __init_subclass__(
111111
else:
112112
cls.primary_key_names = tuple(primary_key)
113113

114+
@classmethod
115+
def _fields(cls):
116+
# wrapper to ignore typing weirdness: 'Argument 1 to "fields"
117+
# has incompatible type "..."; expected "DataclassInstance |
118+
# type[DataclassInstance]"'
119+
return fields(cls) # type: ignore
120+
114121
@classmethod
115122
def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
116123
try:
@@ -139,7 +146,7 @@ def column_info(cls, column: str) -> ColumnInfo:
139146
try:
140147
return cls._column_info[column] # type: ignore
141148
except AttributeError:
142-
cls._column_info = {f.name: column_info_for_field(f) for f in fields(cls)}
149+
cls._column_info = {f.name: column_info_for_field(f) for f in cls._fields()}
143150
return cls._column_info[column]
144151

145152
@classmethod
@@ -152,7 +159,7 @@ def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> List[Fragment
152159

153160
@classmethod
154161
def field_names(cls, *, exclude: FieldNamesSet = ()) -> List[str]:
155-
return [f.name for f in fields(cls) if f.name not in exclude]
162+
return [f.name for f in cls._fields() if f.name not in exclude]
156163

157164
@classmethod
158165
def field_names_sql(
@@ -171,7 +178,7 @@ def _get_field_values_fn(
171178
) -> Callable[[T], List[Any]]:
172179
env: Dict[str, Any] = dict()
173180
func = ["def get_field_values(self): return ["]
174-
for f in fields(cls):
181+
for f in cls._fields():
175182
if f.name not in exclude:
176183
func.append(f"self.{f.name},")
177184
func += ["]"]
@@ -200,23 +207,23 @@ def field_values_sql(
200207
def from_tuple(
201208
cls: Type[T], tup: tuple, *, offset: int = 0, exclude: FieldNamesSet = ()
202209
) -> T:
203-
names = (f.name for f in fields(cls) if f.name not in exclude)
210+
names = (f.name for f in cls._fields() if f.name not in exclude)
204211
kwargs = {name: tup[offset] for offset, name in enumerate(names, start=offset)}
205-
return cls(**kwargs) # type: ignore
212+
return cls(**kwargs)
206213

207214
@classmethod
208215
def from_dict(
209216
cls: Type[T], dct: Dict[str, Any], *, exclude: FieldNamesSet = ()
210217
) -> T:
211-
names = {f.name for f in fields(cls) if f.name not in exclude}
218+
names = {f.name for f in cls._fields() if f.name not in exclude}
212219
kwargs = {k: v for k, v in dct.items() if k in names}
213-
return cls(**kwargs) # type: ignore
220+
return cls(**kwargs)
214221

215222
@classmethod
216223
def ensure_model(cls: Type[T], row: Union[T, Mapping[str, Any]]) -> T:
217224
if isinstance(row, cls):
218225
return row
219-
return cls(**row) # type: ignore
226+
return cls(**row)
220227

221228
@classmethod
222229
def create_table_sql(cls) -> Fragment:
@@ -226,7 +233,7 @@ def create_table_sql(cls) -> Fragment:
226233
sql.identifier(f.name),
227234
sql.literal(cls.column_info(f.name).create_table_string()),
228235
)
229-
for f in fields(cls)
236+
for f in cls._fields()
230237
]
231238
if cls.primary_key_names:
232239
entries += [sql("PRIMARY KEY ({})", sql.list(cls.primary_key_names_sql()))]
@@ -278,7 +285,7 @@ async def select_cursor(
278285
*cls.select_sql(order_by=order_by, for_update=for_update, where=where),
279286
prefetch=prefetch,
280287
):
281-
yield cls(**row) # type: ignore
288+
yield cls(**row)
282289

283290
@classmethod
284291
async def select(
@@ -289,7 +296,7 @@ async def select(
289296
where: Where = (),
290297
) -> List[T]:
291298
return [
292-
cls(**row) # type: ignore
299+
cls(**row)
293300
for row in await connection_or_pool.fetch(
294301
*cls.select_sql(order_by=order_by, for_update=for_update, where=where)
295302
)
@@ -310,7 +317,7 @@ async def create(
310317
cls: Type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
311318
) -> T:
312319
row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
313-
return cls(**row) # type: ignore
320+
return cls(**row)
314321

315322
def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
316323
cached = self._cached(
@@ -419,7 +426,7 @@ def _get_equal_ignoring_fn(
419426
) -> Callable[[T, T], bool]:
420427
env: Dict[str, Any] = dict()
421428
func = ["def equal_ignoring(a, b):"]
422-
for f in fields(cls):
429+
for f in cls._fields():
423430
if f.name not in ignore:
424431
func.append(f" if a.{f.name} != b.{f.name}: return False")
425432
func += [" return True"]
@@ -473,7 +480,7 @@ def _get_differences_ignoring_fn(
473480
"def differences_ignoring(a, b):",
474481
" diffs = []",
475482
]
476-
for f in fields(cls):
483+
for f in cls._fields():
477484
if f.name not in ignore:
478485
func.append(
479486
f" if a.{f.name} != b.{f.name}: diffs.append({repr(f.name)})"

0 commit comments

Comments
 (0)