Skip to content

Commit 38b1d74

Browse files
authored
minor fixes for ibims orm creation (#99)
* relative imports +snakecasing * fixed linting * fixed uuid refs * many refs fixed uuid ref * remove unused function
1 parent d810b17 commit 38b1d74

9 files changed

Lines changed: 83 additions & 57 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ classifiers = [
1919
dependencies = [
2020
"datamodel-code-generator",
2121
"click",
22+
"autoflake"
2223
] # add all the dependencies from requirements.in here, too
2324
dynamic = ["readme", "version"]
2425

requirements.in

Lines changed: 0 additions & 2 deletions
This file was deleted.

requirements.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,24 @@
22
# This file is autogenerated by pip-compile with Python 3.12
33
# by the following command:
44
#
5-
# pip-compile requirements.in
5+
# pip-compile '.\pyproject.toml'
66
#
77
annotated-types==0.6.0
88
# via pydantic
99
argcomplete==3.1.2
1010
# via datamodel-code-generator
11+
autoflake==2.3.1
12+
# via BO4E-Python-Generator (pyproject.toml)
1113
black==24.8.0
1214
# via datamodel-code-generator
1315
click==8.1.7
1416
# via
15-
# -r requirements.in
17+
# BO4E-Python-Generator (pyproject.toml)
1618
# black
1719
datamodel-code-generator==0.25.9
1820
# via -r requirements.in
21+
colorama==0.4.6
22+
# via click
1923
dnspython==2.4.2
2024
# via email-validator
2125
email-validator==2.0.0.post2
@@ -46,6 +50,8 @@ pydantic[email]==2.4.2
4650
# via datamodel-code-generator
4751
pydantic-core==2.10.1
4852
# via pydantic
53+
pyflakes==3.2.0
54+
# via autoflake
4955
pyyaml==6.0.1
5056
# via datamodel-code-generator
5157
typing-extensions==4.8.0

src/bo4e_generator/__main__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
parse_bo4e_schemas,
1717
)
1818
from bo4e_generator.schema import get_namespace, get_version
19+
from bo4e_generator.sqlparser import remove_unused_imports
1920

2021

2122
def resolve_paths(input_directory: Path, output_directory: Path) -> tuple[Path, Path]:
@@ -52,6 +53,11 @@ def generate_bo4e_schemas(
5253
for relative_file_path, file_content in file_contents.items():
5354
file_path = output_directory / relative_file_path
5455
file_path.parent.mkdir(parents=True, exist_ok=True)
56+
if (
57+
relative_file_path.name not in ["__init__.py", "__version__.py"]
58+
and OutputType[output_type] == OutputType.SQL_MODEL
59+
):
60+
file_content = remove_unused_imports(file_content)
5561
file_content = formatter.format_code(file_content)
5662
file_path.write_text(file_content, encoding="utf-8")
5763
print(f"Created {file_path}")

src/bo4e_generator/custom_templates/BaseModel.jinja2

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
{%- if SQL and SQL['imports']%}
2-
{%- for class_name, module_path in SQL['imports'].items() %}
2+
{%- for import_class_name, module_path in SQL['imports'].items() %}
33
{%- if module_path[:4] == 'enum'%}
4-
from borm.models.{{ module_path }} import {{ class_name }}
4+
from ..{{ module_path }} import {{ import_class_name }}
55
{%- elif module_path == 'Link'%}
6-
from borm.models.many import {{ class_name }}
6+
{% if class_name == 'ZusatzAttribut'%}
7+
from .many import {{ import_class_name }}
8+
{% else %}
9+
from ..many import {{ import_class_name }}
10+
{% endif %}
711
{%- else %}
8-
from {{ module_path }} import {{ class_name }}
12+
from {{ module_path }} import {{ import_class_name }}
913
{%- endif %}
1014
{%- endfor -%}
1115
{%- endif %}
1216
{%- if SQL and SQL['relationimports']%}
1317
from typing import TYPE_CHECKING
1418
if TYPE_CHECKING:
15-
{%- for class_name, module_path in SQL['relationimports'].items() %}
16-
from borm.models.{{ module_path }} import {{ class_name }}
19+
{% if class_name == 'ZusatzAttribut'%}
20+
{%- for import_class_name, module_path in SQL['relationimports'].items() %}
21+
from .{{ module_path }} import {{ import_class_name }}
1722
{%- endfor -%}
23+
{% else %}
24+
{%- for import_class_name, module_path in SQL['relationimports'].items() %}
25+
from ..{{ module_path }} import {{ import_class_name }}
26+
{%- endfor -%}
27+
{%- endif %}
1828
{%- endif %}
1929
{% for decorator in decorators -%}
2030
{{ decorator }}

src/bo4e_generator/custom_templates/ManyLinks.jinja2

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ class {{class1}}{{class2[1]}}Link(SQLModel, table=True):
1414
"""
1515
class linking m-n relation of tables {{class1}} and {{class2[0]}} for field {{ class2[1]}}.
1616
"""
17-
{{class1.lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class1.lower()}}.{{class1.lower()}}_sqlid", ondelete="CASCADE"), primary_key=True))
18-
{{class2[0].lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class2[0].lower()}}.{{class2[0].lower()}}_sqlid", ondelete="CASCADE"), primary_key=True))
17+
{{class1.lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class1.lower()}}.id", ondelete="CASCADE"), primary_key=True))
18+
{{class2[0].lower()}}_id: Optional[uuid_pkg.UUID] = Field(sa_column=Column(UUID(as_uuid=True), ForeignKey("{{class2[0].lower()}}.id", ondelete="CASCADE"), primary_key=True))
1919

2020
{%- endfor -%}
2121
{%- endfor -%}

src/bo4e_generator/sqlparser.py

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
"""
55

66
import json
7+
import os
78
import re
9+
import subprocess
10+
import tempfile
811
from collections import defaultdict
912
from pathlib import Path
1013
from typing import Any, DefaultDict, Union
1114

12-
import black
13-
import isort
1415
from jinja2 import Environment, FileSystemLoader
1516

16-
from bo4e_generator.schema import SchemaMetadata
17+
from bo4e_generator.schema import SchemaMetadata, camel_to_snake
1718

1819

1920
def remove_pydantic_field_import(python_code: str) -> str:
@@ -42,29 +43,32 @@ def adapt_parse_for_sql(
4243
for schema_metadata in namespace.values():
4344
if schema_metadata.module_path[0] != "enum":
4445
# list of fields which will be replaced by modified versions
45-
del_fields = []
46+
del_fields = set()
4647
for field, val in schema_metadata.schema_parsed["properties"].items():
4748
# type Any field
4849
if "type" not in str(val):
4950
add_relation, relation_imports = create_sql_any(
5051
field, schema_metadata.class_name, namespace, add_relation, relation_imports
5152
)
52-
del_fields.append(field)
53+
del_fields.add(field)
5354
# modify decimal fields
5455
if "number" in str(val) and "string" in str(val):
5556
relation_imports[schema_metadata.class_name + "ADD"]["Decimal"] = "decimal"
5657
if "array" in str(val) and "$ref" not in str(val):
5758
add_relation, relation_imports = create_sql_list(
5859
field, schema_metadata.class_name, namespace, add_relation, relation_imports
5960
)
60-
del_fields.append(field)
61+
del_fields.add(field)
6162
if "$ref" in str(val): # or "array" in str(val):
6263
add_relation, relation_imports = create_sql_field(
6364
field, schema_metadata.class_name, namespace, add_relation, relation_imports
6465
)
65-
del_fields.append(field)
66+
del_fields.add(field)
6667
for field in del_fields:
6768
del schema_metadata.schema_parsed["properties"][field]
69+
# delete id field as it is replaced below
70+
if schema_metadata.schema_parsed["properties"].get("_id"):
71+
del schema_metadata.schema_parsed["properties"]["_id"]
6872
# store the reduced version. The modified fields will be added in the BaseModel.jinja2 schema
6973
schema_metadata.schema_text = json.dumps(schema_metadata.schema_parsed, indent=2, ensure_ascii=False)
7074

@@ -104,9 +108,8 @@ def additional_sql_arguments(
104108
if schema_metadata.module_path[0] != "enum":
105109
# add primary key
106110
additional_sql_data[schema_metadata.class_name]["SQL"] = {
107-
"primary": schema_metadata.class_name.lower()
108-
+ "_sqlid: uuid_pkg.UUID = Field( default_factory=uuid_pkg.uuid4, primary_key=True, index=True, "
109-
"nullable=False )"
111+
"primary": "id: uuid_pkg.UUID = Field( default_factory=uuid_pkg.uuid4, primary_key=True, index=True, "
112+
'nullable=False, alias="_id", title=" Id" )'
110113
}
111114
if schema_metadata.class_name in add_relation:
112115
additional_sql_data[schema_metadata.class_name]["SQL"]["relations"] = add_relation[
@@ -184,7 +187,7 @@ def create_sql_list(
184187
add_imports[class_name + "ADD"]["Column, ARRAY"] = "sqlalchemy"
185188
add_imports[class_name + "ADD"][sa_type] = "sqlalchemy"
186189

187-
add_fields[class_name][f"{field_name}"] = (
190+
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
188191
f"List[{type_hint}] "
189192
+ is_optional
190193
+ f' = Field({default}, title="{field_name}", sa_column=Column( ARRAY( {sa_type} )))'
@@ -209,12 +212,14 @@ def sql_reference_enum(
209212
returns field which references enums.
210213
"""
211214
if is_list:
212-
add_fields[class_name][f"{field_name}"] = (
215+
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
213216
f"List[{reference_name}]" + is_optional + f" = Field({default},"
214217
f' sa_column=Column( ARRAY( Enum( {reference_name}, name="{reference_name.lower()}"))))'
215218
)
216219
else:
217-
add_fields[class_name][f"{field_name}"] = f"{reference_name}" + is_optional + f"= Field({default})"
220+
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
221+
f"{reference_name}" + is_optional + f"= Field({default})"
222+
)
218223

219224
# import enums
220225
if is_list:
@@ -265,23 +270,23 @@ def create_sql_field(
265270
add_fields["MANY"][class_name] = [[reference_name, field_name]]
266271
elif reference_name not in add_fields["MANY"][class_name]:
267272
add_fields["MANY"][class_name].append([reference_name, field_name])
268-
add_fields[class_name][f"{field_name}"] = (
273+
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
269274
f'List["{reference_name}"] ='
270275
f' Relationship(back_populates="{class_name.lower()}_{field_name.lower()}_link", '
271276
f"link_model={class_name}{field_name}Link)"
272277
)
273278
add_fields[reference_name][f"{class_name.lower()}_{field_name.lower()}_link"] = (
274279
f'List["{class_name}"] ='
275-
f' Relationship(back_populates="{field_name}", '
280+
f' Relationship(back_populates="{camel_to_snake(field_name)}", '
276281
f"link_model={class_name}{field_name}Link)"
277282
)
278283
add_imports[class_name + "ADD"][f"{class_name}{field_name}Link)"] = "Link"
279284
add_imports[reference_name + "ADD"][f"{class_name}{field_name}Link)"] = "Link"
280285
else:
281286
# cf. https://github.com/tiangolo/sqlmodel/pull/610
282-
add_fields[class_name][f"{field_name}_id"] = (
287+
add_fields[class_name][f"{camel_to_snake(field_name)}_id"] = (
283288
"uuid_pkg.UUID " + is_optional + f" = Field(sa_column=Column(UUID(as_uuid=True),"
284-
f' ForeignKey("{reference_name.lower()}.{reference_name.lower()}_sqlid"'
289+
f' ForeignKey("{reference_name.lower()}.id"'
285290
f', ondelete="SET NULL")))'
286291
)
287292
add_imports[class_name + "ADD"]["Column"] = "sqlalchemy"
@@ -291,20 +296,20 @@ def create_sql_field(
291296
# pylint: disable= fixme
292297
# todo: check default
293298

294-
add_fields[class_name][f"{field_name}"] = (
299+
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
295300
f'"{reference_name}" ='
296-
f' Relationship(back_populates="{class_name.lower()}_{field_name}",'
297-
f' sa_relationship_kwargs= {{ "foreign_keys":"[{class_name}.{field_name}_id]" }})'
301+
f' Relationship(back_populates="{class_name.lower()}_{camel_to_snake(field_name)}",'
302+
f' sa_relationship_kwargs= {{ "foreign_keys":"[{class_name}.{camel_to_snake(field_name)}_id]" }})'
298303
)
299304

300305
# cf. https://github.com/tiangolo/sqlmodel/issues/10
301306
# https://github.com/tiangolo/sqlmodel/issues/213
302307
# https://dev.to/whchi/disable-sqlmodel-foreign-key-constraint-55kp
303-
add_fields[reference_name][f"{class_name.lower()}_{field_name}"] = (
304-
f'List["{class_name}"] = Relationship(back_populates="{field_name}",'
308+
add_fields[reference_name][f"{class_name.lower()}_{camel_to_snake(field_name)}"] = (
309+
f'List["{class_name}"] = Relationship(back_populates="{camel_to_snake(field_name)}",'
305310
f"sa_relationship_kwargs="
306311
f'{{"primaryjoin":'
307-
f' "{class_name}.{field_name}_id=={reference_name}.{reference_name.lower()}_sqlid",'
312+
f' "{class_name}.{camel_to_snake(field_name)}_id=={reference_name}.id",'
308313
f' "lazy": "joined"}})'
309314
)
310315
# add_relation_import
@@ -346,11 +351,11 @@ def create_sql_any(
346351
if is_list:
347352
add_imports[class_name + "ADD"]["List"] = "typing"
348353
add_imports[class_name + "ADD"]["ARRAY"] = "sqlalchemy"
349-
add_fields[class_name][f"{field_name}"] = (
354+
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
350355
"List[Any]" + is_optional + f" = Field({default}," f" sa_column=Column( ARRAY( PickleType)))"
351356
)
352357
else:
353-
add_fields[class_name][f"{field_name}"] = (
358+
add_fields[class_name][f"{camel_to_snake(field_name)}"] = (
354359
"Any" + is_optional + f" = Field({default}," f" sa_column=Column( PickleType))"
355360
)
356361

@@ -365,13 +370,27 @@ def write_many_many_links(links: dict[str, str]) -> str:
365370
environment = Environment(loader=FileSystemLoader(template_path))
366371
template = environment.get_template("ManyLinks.jinja2")
367372
python_code = template.render({"class": links})
368-
python_code = format_code(python_code)
373+
# python_code = format_code(python_code)
369374
return python_code
370375

371376

372-
def format_code(code: str) -> str:
377+
def remove_unused_imports(code):
373378
"""
374-
perform isort and black on code
379+
Removes unused imports from the given code using autoflake.
375380
"""
376-
code = black.format_str(code, mode=black.Mode())
377-
return isort.code(code, known_local_folder=["borm"])
381+
# Create a temporary file
382+
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as tmp_file:
383+
tmp_file_name = tmp_file.name
384+
tmp_file.write(code.encode("utf-8"))
385+
386+
# Run autoflake to remove unused imports
387+
subprocess.run(["autoflake", "--remove-all-unused-imports", "--in-place", tmp_file_name], check=True)
388+
389+
# Read the cleaned code from the temporary file
390+
with open(tmp_file_name, "r", encoding="utf-8") as tmp_file:
391+
cleaned_code = tmp_file.read()
392+
393+
# Clean up the temporary file
394+
os.remove(tmp_file_name)
395+
396+
return cleaned_code

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ deps =
7878
pre-commit
7979
commands =
8080
python -m pip install --upgrade pip
81-
pip-compile requirements.in
81+
pip-compile .\pyproject.toml
8282
pip install -r requirements.txt
8383
pre-commit install
8484

unittests/test_sqlparser.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from bo4e_generator.sqlparser import (
99
adapt_parse_for_sql,
1010
create_sql_field,
11-
format_code,
1211
remove_pydantic_field_import,
1312
return_ref,
1413
write_many_many_links,
@@ -76,16 +75,3 @@ def test_write_many_many_links(self) -> None:
7675
file_contents = write_many_many_links(links)
7776
keywords = ["AngebotzusatzAttributeLink", "angebot_id", "zusatzattribut_id"]
7877
assert all(substring in file_contents for substring in keywords)
79-
80-
def test_format_code(self) -> None:
81-
unsorted = (
82-
"from sqlmodel import Field, Relationship, SQLModel\n"
83-
"from typing import TYPE_CHECKING, List\n"
84-
"from borm.models.enum.anrede import Anrede"
85-
)
86-
resorted = (
87-
"from typing import TYPE_CHECKING, List\n\n"
88-
"from sqlmodel import Field, Relationship, SQLModel\n\n"
89-
"from borm.models.enum.anrede import Anrede\n"
90-
)
91-
assert resorted == format_code(unsorted)

0 commit comments

Comments
 (0)