1- import sys
1+ import logging
22from pathlib import Path
3+
4+ import lightning as L
35import torch
46import torch .nn as nn
57import torch .nn .functional as F
6- import lightning as L
7- import logging
8+ from scipy .stats import pearsonr
89
910try :
1011 import wandb
1920
2021logger = logging .getLogger (__name__ )
2122
23+ MAE = nn .L1Loss ()
24+
2225
2326class LogLowestMAE (L .Callback ):
2427 def __init__ (self , config ):
25- super (LogLowestMAE , self ).__init__ ()
28+ super ().__init__ ()
2629 self .bestMAE = float ("inf" )
2730 self .config = config
2831
@@ -40,7 +43,7 @@ def on_validation_end(self, trainer, pl_module):
4043
4144class LRelu_with_saturation (nn .Module ):
4245 def __init__ (self , negative_slope , saturation ):
43- super (LRelu_with_saturation , self ).__init__ ()
46+ super ().__init__ ()
4447 self .negative_slope = negative_slope
4548 self .saturation = saturation
4649 self .leaky_relu = nn .LeakyReLU (self .negative_slope )
@@ -61,7 +64,7 @@ def __init__(
6164 negative_slope ,
6265 saturation ,
6366 ):
64- super (Conv1dActivation , self ).__init__ ()
67+ super ().__init__ ()
6568 self .conv = nn .Conv1d (in_channels , out_channels , kernel_size , padding = padding )
6669 self .initializer = initializer
6770 self .activation = LRelu_with_saturation (
@@ -76,7 +79,7 @@ def forward(self, x):
7679
7780class DenseActivation (nn .Module ):
7881 def __init__ (self , in_features , out_features , initializer , negative_slope , saturation ):
79- super (DenseActivation , self ).__init__ ()
82+ super ().__init__ ()
8083 self .linear = nn .Linear (in_features , out_features )
8184 self .initializer = initializer
8285 self .activation = LRelu_with_saturation (
@@ -91,7 +94,7 @@ def forward(self, x):
9194
9295class SelfAttention (nn .Module ):
9396 def __init__ (self , feature_dim , heads = 1 ):
94- super (SelfAttention , self ).__init__ ()
97+ super ().__init__ ()
9598 self .feature_dim = feature_dim
9699 self .heads = heads
97100 # self.padded_dim = self.feature_dim + (self.feature_dim % self.heads)
@@ -152,7 +155,7 @@ def forward(self, x):
152155
153156class Branch (nn .Module ):
154157 def __init__ (self , input_size , output_size , add_layer = 1 , dropout_rate = 0.0 ):
155- super (Branch , self ).__init__ ()
158+ super ().__init__ ()
156159 self .add_layer = add_layer
157160 if self .add_layer :
158161 self .fc1 = nn .Linear (input_size , output_size )
@@ -172,7 +175,7 @@ def forward(self, x):
172175
173176class IM2Deep (L .LightningModule ):
174177 def __init__ (self , config , criterion ):
175- super (IM2Deep , self ).__init__ ()
178+ super ().__init__ ()
176179 self .config = config
177180 self .criterion = criterion
178181 self .mae = nn .L1Loss ()
@@ -628,7 +631,7 @@ def configure_init(self):
628631
629632class IM2DeepMulti (L .LightningModule ):
630633 def __init__ (self , config , criterion ):
631- super (IM2DeepMulti , self ).__init__ ()
634+ super ().__init__ ()
632635 self .config = config
633636 self .criterion = criterion
634637
@@ -1106,7 +1109,7 @@ def configure_init(self):
11061109
11071110class IM2DeepMultiTransfer (L .LightningModule ):
11081111 def __init__ (self , config , criterion ):
1109- super (IM2DeepMultiTransfer , self ).__init__ ()
1112+ super ().__init__ ()
11101113 # TODO: config should be adapted in config file
11111114 self .config = config
11121115 self .criterion = criterion
@@ -1123,7 +1126,7 @@ def __init__(self, config, criterion):
11231126 self .ConvGlobal = self .backbone .ConvGlobal
11241127 self .OneHot = self .backbone .OneHot
11251128
1126- if self .config .get ("add_X_mol" , False ) == True :
1129+ if self .config .get ("add_X_mol" , False ):
11271130 self .MolDesc = self .backbone .MolDesc
11281131
11291132 self .concat = list (self .backbone .Concat .children ())[:- 1 ]
@@ -1294,7 +1297,7 @@ def configure_optimizers(self):
12941297
12951298class IM2DeepTransfer (L .LightningModule ):
12961299 def __init__ (self , config , criterion ):
1297- super (IM2DeepTransfer , self ).__init__ ()
1300+ super ().__init__ ()
12981301
12991302 self .config = config
13001303 self .criterion = criterion
@@ -1312,7 +1315,7 @@ def __init__(self, config, criterion):
13121315 self .ConvGlobal = self .backbone .ConvGlobal
13131316 self .OneHot = self .backbone .OneHot
13141317
1315- if self .config .get ("add_X_mol" , False ) == True :
1318+ if self .config .get ("add_X_mol" , False ):
13161319 self .MolDesc = self .backbone .MolDesc
13171320
13181321 self .concat = self .backbone .Concat
@@ -1439,7 +1442,7 @@ def configure_optimizers(self):
14391442
14401443class FlexibleLossSorted (nn .Module ):
14411444 def __init__ (self , diversity_weight = 0.1 ):
1442- super (FlexibleLossSorted , self ).__init__ ()
1445+ super ().__init__ ()
14431446 self .diversity_weight = diversity_weight
14441447
14451448 def forward (self , y1 , y2 , y_hat1 , y_hat2 ):
@@ -1479,7 +1482,7 @@ def forward(self, y1, y2, y_hat1, y_hat2):
14791482
14801483class FlexibleLoss (nn .Module ):
14811484 def __init__ (self , diversity_weight = 0.1 ):
1482- super (FlexibleLoss , self ).__init__ ()
1485+ super ().__init__ ()
14831486 self .diversity_weight = diversity_weight
14841487
14851488 def forward (self , y1 , y2 , y_hat1 , y_hat2 ):
0 commit comments