|
6 | 6 | # for details. |
7 | 7 | ########################################################################## |
8 | 8 |
|
| 9 | +import os |
| 10 | + |
9 | 11 | import numpy as np |
10 | 12 | import torch |
11 | 13 | import traits.api as traits |
@@ -46,6 +48,20 @@ def validate(self, objekt, name, value): |
46 | 48 | return value |
47 | 49 |
|
48 | 50 |
|
| 51 | +class File(traits.TraitType): |
| 52 | + def validate(self, objekt, name, value): |
| 53 | + if not (isinstance(value, str) and os.path.isfile(value)): |
| 54 | + self.error(objekt, name, value) |
| 55 | + return value |
| 56 | + |
| 57 | + |
| 58 | +class Directory(traits.TraitType): |
| 59 | + def validate(self, objekt, name, value): |
| 60 | + if not (isinstance(value, str) and os.path.isdir(value)): |
| 61 | + self.error(objekt, name, value) |
| 62 | + return value |
| 63 | + |
| 64 | + |
49 | 65 | class Sequence(traits.List): |
50 | 66 | def validate(self, objekt, name, value): |
51 | 67 | if not isinstance(value, (tuple, list)): |
@@ -81,12 +97,15 @@ def validate(self, objekt, name, value): |
81 | 97 | "int": "traits.Int", |
82 | 98 | "float": "traits.Float", |
83 | 99 | "bool": "traits.Bool", |
84 | | - "Tensor": "traits.Tensor", |
| 100 | + "torch.Tensor": "traits.Tensor", |
85 | 101 | "list": "traits.List", |
86 | 102 | "tuple": "traits.Tuple", |
87 | | - "Sequence": "traits.Sequence", |
88 | | - "array": "traits.Array", |
89 | | - "Union": "traits.Union", |
90 | | - "Optional": "traits.Either", |
| 103 | + "collections.abc.Sequence": "traits.Sequence", |
| 104 | + "typing.Sequence": "traits.Sequence", |
| 105 | + "numpy.array": "traits.Array", |
| 106 | + "typing.Union": "traits.Union", |
| 107 | + "typing.Optional": "traits.Either", |
91 | 108 | "NoneType": "traits.Undefined", |
| 109 | + "typex.typing_extensions.File": "traits.File", |
| 110 | + "typex.typing_extensions.Directory": "traits.Directory" |
92 | 111 | } |
0 commit comments