Skip to content

Commit 014383b

Browse files
committed
improve gpu memroy efficiency
1 parent eb5fad9 commit 014383b

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

test/test_cav_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def test_all(self):
279279
cav_trainer.set_control(builder.control_concepts[0], num_samples=100)
280280

281281
cav_trainer.train_concepts(
282-
builder.concepts, 100, output_dir="data/cavs/", num_processes=1, backend='torch'
282+
builder.concepts, 100, output_dir="data/cavs/", num_processes=2, backend='torch'
283283
)
284284
cav_trainer.train_concepts(
285285
builder.concepts, 100, output_dir="data/cavs/", num_processes=2

tpcav/cavs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import logging
77
import os
8+
import gc
89
from pathlib import Path
910
from typing import Iterable, List, Optional, Tuple, Union
1011
import time
@@ -134,6 +135,10 @@ def fit(self, train_val_avs: np.ndarray, train_val_ls: np.ndarray):
134135
if (best_loss is None) or (loss < best_loss):
135136
best_loss = loss
136137
best_state_dict = state_dict
138+
139+
del model
140+
gc.collect()
141+
torch.cuda.empty_cache()
137142

138143
self.best_model = _TorchLinear(self.input_dim, self.num_class)
139144
self.best_model.load_state_dict(best_state_dict)
@@ -277,6 +282,12 @@ def _eval(avs, l, name: str):
277282
assert len(weights.shape) == 2 and weights.shape[0] == 2
278283
torch.save(weights, output_dir / "classifier_weights.pt")
279284

285+
if backend == 'torch':
286+
# release gpu memroy
287+
del clf.best_model
288+
gc.collect()
289+
torch.cuda.empty_cache()
290+
280291
return test_fscore, weights[0]
281292

282293

0 commit comments

Comments
 (0)