Skip to content

Commit a8146a9

Browse files
committed
refactor: integrate GraphClass in Settings
- Integrate GraphClass into settings This change hopefully helps change Labels in the future if needed.
1 parent ce2e80d commit a8146a9

3 files changed

Lines changed: 5 additions & 8 deletions

File tree

CustomDataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from DataformatUtils import convert_edge_dim, convert_list_to_float_tensor, convert_list_to_long_tensor, \
1010
convert_hashed_names_to_float
1111
from Encoder import multi_hot_encoding
12-
from GraphClasses import defined_labels
12+
from settings import CONFIG
1313

1414

1515
class RepositoryDataset(Dataset):
@@ -30,7 +30,8 @@ def __init__(self, directory, label_list=None):
3030
print(e)
3131
# nodes have 11 features, their one hot encoded node type, hashed name, and one hot encoded library flag
3232
self.num_node_features = 11
33-
self.num_classes = len(defined_labels)
33+
self.defined_labels = CONFIG['graph']['defined_labels']
34+
self.num_classes = len(self.defined_labels)
3435
self.directory = directory
3536
self.graph_names = []
3637
self.graph_dir = os.listdir(directory)
@@ -162,7 +163,7 @@ def convert_labeled_graphs(self, labels):
162163
graph_labels) # count how many repos are in each class
163164

164165
# encode labels
165-
encoded_nodes = multi_hot_encoding(defined_labels, graph_labels)
166+
encoded_nodes = multi_hot_encoding(self.defined_labels, graph_labels)
166167
file = zip(graph_names, encoded_nodes)
167168
return file
168169

GraphClasses.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from CustomDataset import RepositoryDataset
1313
from GCN import GCN
14-
from GraphClasses import defined_labels
1514
from settings import CONFIG
1615

1716
'''please prepare the dataset you want to train the tool with by using prepareDataset.py,
@@ -27,6 +26,7 @@
2726
threshold = CONFIG['training']['threshold']
2827
save_classification_reports = CONFIG['training']['save_classification_reports']
2928
experiment_name = CONFIG['training']['experiment_name']
29+
defined_labels = CONFIG['graph']['defined_labels']
3030

3131

3232
def train():

0 commit comments

Comments
 (0)