Skip to content

Commit 596a82e

Browse files
authored
Merge pull request #64 from jcal-15/field-name
Adding option to support keeping underscores in argument names
2 parents d902568 + 3455ed7 commit 596a82e

2 files changed

Lines changed: 18 additions & 6 deletions

File tree

argparse_dataclass.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def _add_dataclass_options(
334334
raise TypeError("cls must be a dataclass")
335335

336336
for field in fields(options_class):
337-
args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"])
337+
args = field.metadata.get("args", [f"--{_get_arg_name(field)}"])
338338
positional = not args[0].startswith("-")
339339
kwargs = {
340340
"type": field.metadata.get("type", field.type),
@@ -448,7 +448,7 @@ def _handle_bool_type(field: Field, args: list, kwargs: dict):
448448
if field.default is True:
449449
kwargs["action"] = "store_false"
450450
if "args" not in field.metadata:
451-
args[0] = f"--no-{field.name.replace('_', '-')}"
451+
args[0] = f"--no-{_get_arg_name(field)}"
452452
kwargs["dest"] = field.name
453453
elif field.metadata.get("required") is True:
454454
kwargs["action"] = BooleanOptionalAction
@@ -479,6 +479,12 @@ def _handle_argument_group(
479479
group.add_argument(*args, **kwargs)
480480

481481

482+
def _get_arg_name(field: Field):
483+
if field.metadata.get("keep_underscores", False):
484+
return field.name
485+
return field.name.replace("_", "-")
486+
487+
482488
class ArgumentParser(argparse.ArgumentParser, Generic[OptionsType]):
483489
"""Command line argument parser that derives its options from a dataclass.
484490

tests/test_argumentparser.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,6 @@ class Options:
142142
params = ArgumentParser(Options).parse_args(["--name", "john doe"])
143143
self.assertEqual(params.name, "John Doe")
144144

145-
@unittest.skipIf(
146-
sys.version_info[:2] == (3, 6),
147-
"Python 3.6 does not have datetime.fromisoformat()",
148-
)
149145
def test_default_factory(self):
150146
@dataclass
151147
class Parameters:
@@ -302,6 +298,16 @@ class Options:
302298

303299
self.assertRaises(ValueError, lambda: ArgumentParser(Options))
304300

301+
def test_keep_underscores(self):
302+
@dataclass
303+
class Args:
304+
num_of_foo: int = field(metadata={"keep_underscores": True})
305+
is_fun: bool = field(default=True, metadata={"keep_underscores": True})
306+
307+
params = ArgumentParser(Args).parse_args(["--num_of_foo=10", "--no-is_fun"])
308+
self.assertEqual(10, params.num_of_foo)
309+
self.assertFalse(params.is_fun)
310+
305311

306312
if __name__ == "__main__":
307313
unittest.main()

0 commit comments

Comments
 (0)