Skip to content

Commit a70526b

Browse files
feature: pre-commit hook for linting (#5)
1 parent 8cd46d1 commit a70526b

34 files changed

Lines changed: 608 additions & 446 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,4 @@ server/
173173
main.py
174174

175175
test*.py
176-
test*.ipynb
176+
test*.ipynb

.isort.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[settings]
2+
known_third_party = ase,numpy,pandas,scipy,setuptools,torch,torch_geometric,torch_scatter,tqdm,yaml

.pre-commit-config.yaml

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v4.3.0
4+
hooks:
5+
# - id: check-yaml
6+
- id: end-of-file-fixer
7+
- id: trailing-whitespace
8+
# isort
9+
- repo: https://github.com/asottile/seed-isort-config
10+
rev: v2.2.0
11+
hooks:
12+
- id: seed-isort-config
13+
- repo: https://github.com/pre-commit/mirrors-isort
14+
rev: v5.10.1
15+
hooks:
16+
- id: isort
17+
args: ["--profile", "black"]
18+
# flake8
19+
- repo: https://github.com/pycqa/flake8
20+
rev: 5.0.4
21+
hooks:
22+
- id: flake8
23+
args: # arguments to configure flake8
24+
# making isort line length compatible with black
25+
- "--max-line-length=88"
26+
- "--max-complexity=18"
27+
- "--select=B,C,E,F,W,T4,B9"
28+
# these are errors that will be ignored by flake8
29+
# check out their meaning here
30+
# https://flake8.pycqa.org/en/latest/user/error-codes.html
31+
- "--ignore=E203,E266,E501,W503,F403,F401,E402"
32+
# black
33+
- repo: https://github.com/psf/black
34+
rev: 22.10.0
35+
hooks:
36+
- id: black
37+
args: # arguments to configure black
38+
- --line-length=88
39+
- --include='\.pyi?$'
40+
# these folders won't be formatted by black
41+
- --exclude="""\.git |
42+
\.__pycache__|
43+
\.hg|
44+
\.mypy_cache|
45+
\.tox|
46+
\.venv|
47+
_build|
48+
buck-out|
49+
build|
50+
dist"""
51+
language_version: python3.9

README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
1-
pip install -e .
1+
## Development
2+
3+
Install matdeeplearn with `pip install -e .`
4+
5+
#### Code Quality
6+
This project uses flake8, black, and isort for linting.
7+
To install the pre-commit git hook, run:
8+
```
9+
pre-commit install
10+
```
11+
By default, the hooks will run every time you run:
12+
```
13+
git commit -m "Commit message"
14+
```
15+
For more information, please see: https://pre-commit.com/#usage

data/pt_data_forces_500/pt_data_forces_500.json

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

matdeeplearn/README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-

matdeeplearn/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1+
from matdeeplearn.common.data import *
2+
13
from .models import *
24
from .preprocessor import *
3-
4-
from matdeeplearn.common.data import *

matdeeplearn/common/config/build_config.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
21
import ast
32
import copy
43
import logging
54
import os
5+
from pathlib import Path
6+
67
import yaml
78

8-
from pathlib import Path
99

1010
def merge_dicts(dict1: dict, dict2: dict):
1111
"""Recursively merge two dictionaries.
@@ -48,6 +48,7 @@ def merge_dicts(dict1: dict, dict2: dict):
4848

4949
return return_dict, duplicates
5050

51+
5152
def dict_set_recursively(dictionary, key_sequence, val):
5253
top_key = key_sequence.pop(0)
5354
if len(key_sequence) == 0:
@@ -84,9 +85,8 @@ def create_dict_from_args(args: list, sep: str = "."):
8485
return return_dict
8586

8687

87-
8888
def build_config(args, args_override):
89-
##Open provided config file
89+
# Open provided config file
9090
assert os.path.exists(args.config_path), (
9191
"Config file not found in " + args.config_path
9292
)
@@ -106,7 +106,7 @@ def build_config(args, args_override):
106106
config["submit"] = args.submit
107107
# config["summit"] = args.summit
108108
# Distributed
109-
#TODO: add distributed flags
109+
# TODO: add distributed flags
110110

111111
# if run_mode != "Hyperparameter":
112112
#
@@ -119,5 +119,4 @@ def build_config(args, args_override):
119119
# config["Processing"],
120120
# )
121121

122-
123122
return config

matdeeplearn/common/config/flags.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
21
import argparse
32
from pathlib import Path
43

54

65
class Flags:
76
def __init__(self):
8-
self.parser = argparse.ArgumentParser(
9-
description="MatDeepLearn inputs"
10-
)
7+
self.parser = argparse.ArgumentParser(description="MatDeepLearn inputs")
118
self.add_core_args()
129

1310
def get_parser(self):
@@ -118,12 +115,12 @@ def add_core_args(self):
118115
# parser.add_argument("--batch_size", default=None, type=int, help="batch size")
119116
# parser.add_argument("--lr", default=None, type=float, help="learning rate")
120117

121-
#TODO: add cluster args
118+
# TODO: add cluster args
122119
self.parser.add_argument(
123120
"--submit", action="store_true", help="Submit job to cluster"
124121
)
125-
#TODO: add checkpoint arg
126-
#TODO: timestamp id arg?
122+
# TODO: add checkpoint arg
123+
# TODO: timestamp id arg?
127124

128125

129126
flags = Flags()

matdeeplearn/common/data.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,73 @@
1-
import torch
21
import warnings
32

3+
import torch
44
from torch.utils.data import random_split
55
from torch_geometric.loader import DataLoader
66

7-
from matdeeplearn.preprocessor.transforms import *
8-
from matdeeplearn.preprocessor.datasets import StructureDataset, LargeStructureDataset
7+
from matdeeplearn.preprocessor.datasets import LargeStructureDataset, StructureDataset
8+
from matdeeplearn.preprocessor.transforms import GetY
9+
910

1011
# train test split
1112
def dataset_split(
1213
dataset,
1314
train_size: float = 0.8,
1415
valid_size: float = 0.05,
1516
test_size: float = 0.15,
16-
seed: int = 1234
17+
seed: int = 1234,
1718
):
18-
'''
19+
"""
1920
Splits an input dataset into 3 subsets: train, validation, test.
2021
Requires train_size + valid_size + test_size = 1
2122
2223
Parameters
2324
----------
2425
dataset: matdeeplearn.preprocessor.datasets.StructureDataset
2526
a dataset object that contains the target data
26-
27+
2728
train_size: float
2829
a float between 0.0 and 1.0 that represents the proportion
2930
of the dataset to use as the training set
3031
3132
valid_size: float
3233
a float between 0.0 and 1.0 that represents the proportion
3334
of the dataset to use as the validation set
34-
35+
3536
test_size: float
3637
a float between 0.0 and 1.0 that represents the proportion
3738
of the dataset to use as the test set
38-
'''
39+
"""
3940
if train_size + valid_size + test_size != 1:
4041
warnings.warn("Invalid sizes detected. Using default split of 80/5/15.")
4142
train_size, valid_size, test_size = 0.8, 0.05, 0.15
4243

4344
dataset_size = len(dataset)
44-
45+
4546
train_len = int(train_size * dataset_size)
4647
valid_len = int(valid_size * dataset_size)
4748
test_len = int(test_size * dataset_size)
4849
unused_len = dataset_size - train_len - valid_len - test_len
4950

50-
(
51-
train_dataset,
52-
val_dataset,
53-
test_dataset,
54-
unused_dataset
55-
) = random_split(
51+
(train_dataset, val_dataset, test_dataset, unused_dataset) = random_split(
5652
dataset,
5753
[train_len, valid_len, test_len, unused_len],
58-
generator=torch.Generator().manual_seed(seed)
54+
generator=torch.Generator().manual_seed(seed),
5955
)
6056

6157
return train_dataset, val_dataset, test_dataset
6258

59+
6360
def get_dataset(
64-
data_path,
65-
target_index: int = 0,
66-
transform_type='GetY',
67-
large_dataset=False
61+
data_path, target_index: int = 0, transform_type="GetY", large_dataset=False
6862
):
69-
'''
63+
"""
7064
get dataset according to data_path
7165
this assumes that the data has already been processed and
7266
data.pt file exists in data_path/processed/ folder
7367
7468
Parameters
7569
----------
76-
70+
7771
data_path: str
7872
path to the folder containing data.pt file
7973
@@ -85,13 +79,13 @@ def get_dataset(
8579
the current run/experiment
8680
8781
transform_type: transformation function/class to be applied
88-
'''
89-
82+
"""
83+
9084
# set transform method
91-
if transform_type == 'GetY':
85+
if transform_type == "GetY":
9286
T = GetY
9387
else:
94-
raise ValueError('No such transform found for {transform}')
88+
raise ValueError("No such transform found for {transform}")
9589

9690
# check if large dataset is needed
9791
if large_dataset:
@@ -101,37 +95,38 @@ def get_dataset(
10195

10296
transform = T(index=target_index)
10397

104-
return Dataset(data_path, processed_data_path='', transform=transform)
98+
return Dataset(data_path, processed_data_path="", transform=transform)
99+
105100

106101
def get_dataloader(
107102
dataset,
108103
batch_size: int,
109104
num_workers: int = 0,
110-
sampler = None,
105+
sampler=None,
111106
):
112-
'''
107+
"""
113108
Returns a single dataloader for a given dataset
114109
115110
Parameters
116111
----------
117112
dataset: matdeeplearn.preprocessor.datasets.StructureDataset
118113
a dataset object that contains the target data
119-
114+
120115
batch_size: int
121116
size of each batch
122117
123118
num_workers: int
124119
how many subprocesses to use for data loading. 0 means that
125120
the data will be loaded in the main process.
126-
'''
121+
"""
127122

128123
# load data
129124
loader = DataLoader(
130125
dataset,
131126
batch_size=batch_size,
132127
shuffle=(sampler is None),
133128
num_workers=num_workers,
134-
sampler=sampler
129+
sampler=sampler,
135130
)
136131

137-
return loader
132+
return loader

0 commit comments

Comments
 (0)