Skip to content

Commit b842c70

Browse files
committed
Adding framework for trainer, preprocessing data, and running train context from config
1 parent d2205c7 commit b842c70

2,040 files changed

Lines changed: 36375 additions & 0 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/
130+
131+
# mac
132+
.DS_Store
133+
134+
# self-defined
135+
questions.md

configs/config.yml

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
2+
trainer: property
3+
4+
task:
5+
# run_mode: train
6+
name: "my_train_job"
7+
8+
reprocess: "False"
9+
10+
11+
parallel: "True"
12+
seed: 0
13+
#seed=0 means random initialization
14+
15+
16+
write_output: "True"
17+
parallel: "True"
18+
#Training printout frequency (print every n epochs)
19+
verbosity: 5
20+
21+
#Ratios for train/val/test split out of a total of 1
22+
train_ratio: 0.8
23+
val_ratio: 0.05
24+
test_ratio: 0.15
25+
26+
27+
28+
model:
29+
name: CGCNN
30+
load_model: "False"
31+
save_model: "True"
32+
model_path: "my_model.pth"
33+
#model attributes
34+
dim1: 100
35+
dim2: 150
36+
pre_fc_count: 1
37+
gc_count: 4
38+
post_fc_count: 3
39+
pool: "global_mean_pool"
40+
pool_order: "early"
41+
42+
43+
optim:
44+
batch_norm: "True"
45+
batch_track_stats: "True"
46+
act: "relu"
47+
dropout_rate: 0.0
48+
epochs: 250
49+
lr: 0.002
50+
#Loss functions (from pytorch) examples: l1_loss, mse_loss, binary_cross_entropy
51+
loss: "l1_loss"
52+
batch_size: 100
53+
optimizer: "AdamW"
54+
optimizer_args: {}
55+
scheduler: "ReduceLROnPlateau"
56+
scheduler_args: {"mode":"min", "factor":0.8, "patience":10, "min_lr":0.00001, "threshold":0.0002}
57+
58+
dataset:
59+
processed: False # if False, need to process data and generate .pt file
60+
# Whether to use "inmemory" or "large" format for pytorch-geometric dataset. Recommend inmemory unless the dataset is too large.
61+
# dataset_type: "inmemory"
62+
#Path to data files
63+
src: "/Users/shuyijia/Documents/GitHub/Fung-Lab/MatDeepLearn/data/dup_test_data/"
64+
#Path to target file (located within the src data directory)
65+
target_path: "/Users/shuyijia/Documents/GitHub/Fung-Lab/MatDeepLearn/data/dup_test_data/targets.csv"
66+
#Format of data files (limit to those supported by ASE)
67+
data_format: "json"
68+
#Method of obtaining atom dictionary; available: (onehot)
69+
node_representation: "onehot"
70+
#Print out processing info
71+
verbose: "True"
72+
73+
#Loading dataset params
74+
#Index of target column in targets.csv
75+
target_index: 0
76+
77+
#graph specific settings
78+
cutoff_radius : 8.0
79+
n_neighbors : 12
80+
edge_steps : 50
81+
26.8 MB
Binary file not shown.

data/dup_test_data/raw/14816.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{"1": {
2+
"cell": {"array": {"__ndarray__": [[3, 3], "float64", [20.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0]]}, "__ase_objtype__": "cell"},
3+
"ctime": 21.058740541934128,
4+
"mtime": 21.058740541934128,
5+
"numbers": {"__ndarray__": [[10], "int64", [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]]},
6+
"pbc": {"__ndarray__": [[3], "bool", [false, false, false]]},
7+
"positions": {"__ndarray__": [[10, 3], "float64", [10.84811456205555, 11.13622580219429, 7.83612602879444, 8.89212837067458, 9.54658153686299, 7.21241514873573, 7.65634346100067, 11.06984873647254, 10.79899682518365, 6.83723125191032, 9.76361614088538, 8.8238894254297, 6.94521070137696, 7.92049302211125, 7.03157270638096, 10.65936710090347, 7.83630971766916, 9.02232011850644, 8.04508065127254, 7.57294548770338, 9.27638766368227, 9.58923786887345, 6.47050530236025, 10.83219447852401, 11.36764325142721, 8.9041126777983, 6.8505801585817, 9.61821695155211, 10.05997145673351, 9.73862657964421]]},
8+
"unique_id": "44cedb81e4b6f2905e53196a5dbc7149",
9+
"user": "vfung"},
10+
"ids": [1],
11+
"nextid": 2}

data/dup_test_data/raw/14817.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{"1": {
2+
"cell": {"array": {"__ndarray__": [[3, 3], "float64", [20.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0]]}, "__ase_objtype__": "cell"},
3+
"ctime": 21.058740541993963,
4+
"mtime": 21.058740541993963,
5+
"numbers": {"__ndarray__": [[10], "int64", [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]]},
6+
"pbc": {"__ndarray__": [[3], "bool", [false, false, false]]},
7+
"positions": {"__ndarray__": [[10, 3], "float64", [10.55209998008288, 11.12273451997437, 7.93469442813388, 8.88659050059301, 9.2077519640926, 7.43621294868226, 7.58157615497966, 11.31660688313961, 10.9051896993328, 6.87629035267318, 9.66768574333755, 9.18193276928948, 6.66738124240678, 7.95117558592068, 7.35907190445034, 10.64199363796908, 7.94142248773163, 9.45351590630791, 8.05832143804086, 7.36922560478509, 9.37331615074526, 9.59049918387847, 6.32329282872876, 11.00601946660024, 11.59106791318951, 8.92567725354054, 7.36744265111009, 9.55417941618642, 10.17442694874904, 9.98260443534766]]},
8+
"unique_id": "97217fe2fc33b3eed3108ccff9ae688f",
9+
"user": "vfung"},
10+
"ids": [1],
11+
"nextid": 2}

data/dup_test_data/raw/14818.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{"1": {
2+
"cell": {"array": {"__ndarray__": [[3, 3], "float64", [20.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0]]}, "__ase_objtype__": "cell"},
3+
"ctime": 21.058740542031256,
4+
"mtime": 21.058740542031256,
5+
"numbers": {"__ndarray__": [[10], "int64", [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]]},
6+
"pbc": {"__ndarray__": [[3], "bool", [false, false, false]]},
7+
"positions": {"__ndarray__": [[10, 3], "float64", [11.29046010021573, 10.83951109080352, 8.78356895780009, 8.76461637715562, 8.68463836793248, 7.02523909732477, 7.15423020732798, 10.92714064976973, 11.27338161281467, 7.47245844617973, 9.71444322041805, 9.07972226919126, 6.33363470913941, 8.16272475690982, 7.38169763150429, 10.56661092703278, 8.42524297663398, 9.33109837260677, 7.98561880411683, 7.24711330506961, 9.11678345406328, 9.74478841096819, 6.45650133463814, 10.61287025098539, 11.25979612757008, 9.06964509742759, 6.95151107676987, 9.42778607029368, 10.47303920039719, 10.44412727693962]]},
8+
"unique_id": "8f29e3614cae5dac3cc876f507882c81",
9+
"user": "vfung"},
10+
"ids": [1],
11+
"nextid": 2}

data/dup_test_data/raw/14819.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{"1": {
2+
"cell": {"array": {"__ndarray__": [[3, 3], "float64", [20.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0]]}, "__ase_objtype__": "cell"},
3+
"ctime": 21.05874054209235,
4+
"mtime": 21.05874054209235,
5+
"numbers": {"__ndarray__": [[10], "int64", [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]]},
6+
"pbc": {"__ndarray__": [[3], "bool", [false, false, false]]},
7+
"positions": {"__ndarray__": [[10, 3], "float64", [10.58697034870614, 11.32570080861459, 7.8531655433036, 8.88203235512252, 9.47449623311041, 7.39016445707521, 7.37131874305553, 11.00853492036741, 11.26965535190918, 6.97478546544625, 9.37641285837044, 9.42611312613394, 6.93475606023258, 7.86658725388662, 7.38917256500832, 10.20026338308245, 8.1565296806003, 9.22041881442984, 8.01438597117835, 7.00698945977101, 9.52948310023797, 10.05070336340693, 6.12059294992494, 10.64617112200399, 11.63602029599972, 9.11149171272604, 7.39018251905607, 9.34876419376952, 10.55266412262823, 9.88547322084188]]},
8+
"unique_id": "189385f6a1f34ea43e6a0b24eb304b5a",
9+
"user": "vfung"},
10+
"ids": [1],
11+
"nextid": 2}

data/dup_test_data/raw/14820.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{"1": {
2+
"cell": {"array": {"__ndarray__": [[3, 3], "float64", [20.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0]]}, "__ase_objtype__": "cell"},
3+
"ctime": 21.058740542131783,
4+
"mtime": 21.058740542131783,
5+
"numbers": {"__ndarray__": [[10], "int64", [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]]},
6+
"pbc": {"__ndarray__": [[3], "bool", [false, false, false]]},
7+
"positions": {"__ndarray__": [[10, 3], "float64", [10.32881754423678, 11.35058580204961, 7.76994277162286, 8.80950616004137, 9.26816986079739, 7.44596393314179, 7.20814690601561, 11.03755905261314, 10.84806410171449, 6.74551470577747, 9.34242420623981, 9.06899993072359, 7.10097357339577, 7.38274218737191, 7.46993075109567, 10.7303096153175, 8.26584646391771, 9.50682140111885, 8.25197879832853, 7.3695585855024, 9.71398630514723, 10.08427757040348, 6.44215070986084, 11.07247994935158, 11.40641766064006, 9.15284805377567, 7.26477737472839, 9.3340574658433, 10.38811489787136, 9.83903348135529]]},
8+
"unique_id": "518f6c562e6bb890f9b1b92833a5662b",
9+
"user": "vfung"},
10+
"ids": [1],
11+
"nextid": 2}

data/dup_test_data/raw/14821.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{"1": {
2+
"cell": {"array": {"__ndarray__": [[3, 3], "float64", [20.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0]]}, "__ase_objtype__": "cell"},
3+
"ctime": 21.058740542168554,
4+
"mtime": 21.058740542168554,
5+
"numbers": {"__ndarray__": [[10], "int64", [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]]},
6+
"pbc": {"__ndarray__": [[3], "bool", [false, false, false]]},
7+
"positions": {"__ndarray__": [[10, 3], "float64", [10.10803710234159, 11.50743629820141, 7.84927342342025, 8.77458847596841, 9.30628333122176, 7.5339626875287, 7.09528877205127, 10.67175797622389, 11.13083121823173, 6.80746529952517, 9.04573633660961, 9.27479696534825, 7.03447773446927, 7.46044993960281, 7.28311692550707, 10.77896394653345, 8.4897768554759, 9.52149221766082, 8.52332189289585, 7.12359732139326, 9.29854393134033, 10.29380085622592, 6.49264153190104, 10.88833491052535, 11.35174514271483, 9.40236930363699, 7.25950635843704, 9.23231059727422, 10.49995092573332, 9.96014136200034]]},
8+
"unique_id": "61863842e04794083c736b368e778a4a",
9+
"user": "vfung"},
10+
"ids": [1],
11+
"nextid": 2}

data/dup_test_data/raw/14822.json

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{"1": {
2+
"cell": {"array": {"__ndarray__": [[3, 3], "float64", [20.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 20.0]]}, "__ase_objtype__": "cell"},
3+
"ctime": 21.05874054220892,
4+
"mtime": 21.05874054220892,
5+
"numbers": {"__ndarray__": [[10], "int64", [78, 78, 78, 78, 78, 78, 78, 78, 78, 78]]},
6+
"pbc": {"__ndarray__": [[3], "bool", [false, false, false]]},
7+
"positions": {"__ndarray__": [[10, 3], "float64", [10.24319485445007, 11.54787920895129, 8.05438714696011, 8.79514374618761, 9.53458956129494, 7.39237251198056, 7.17066023823562, 10.66400252920393, 11.317993642804, 7.00235080611994, 9.17105503580481, 9.32627655563444, 7.12734370656666, 7.64755046468221, 7.28643577802192, 10.33275230999235, 8.27529187605358, 9.16645282584009, 8.2827409631189, 6.90621935834306, 9.44243947135026, 10.34155967299444, 6.28325712963245, 10.69661086680073, 11.451585259049, 9.51161833211171, 7.26687658440996, 9.25266862328552, 10.45853668392191, 10.05015479619768]]},
8+
"unique_id": "e8212baf15a4ef28903ef7fce2c1db38",
9+
"user": "vfung"},
10+
"ids": [1],
11+
"nextid": 2}

0 commit comments

Comments
 (0)