Skip to content

Commit ce8bac4

Browse files
committed
fix cuda device mismatch
1 parent cb54822 commit ce8bac4

5 files changed

Lines changed: 81 additions & 69 deletions

File tree

src/graphnet/deployment/deployment_module.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ def _resolve_prediction_columns(
112112
)
113113
else:
114114
# Only Take First Label
115-
resolved_prediction_columns.append(model.prediction_labels[0])
115+
resolved_prediction_columns.append(
116+
model.prediction_labels[0]
117+
)
116118
return resolved_prediction_columns
117119

118120
def _inference(self, data: Union[Data, Batch]) -> List[np.ndarray]:
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
"""Deployment modules specific to IceCube."""
22

3-
from .inference_module import I3InferenceModule, I3ParticleInferenceModule, I3MultipleModelInferenceModule
3+
from .inference_module import (
4+
I3InferenceModule,
5+
I3ParticleInferenceModule,
6+
I3MultipleModelInferenceModule,
7+
)
48
from .cleaning_module import I3PulseCleanerModule
59
from .i3deployer import I3Deployer

src/graphnet/deployment/icecube/inference_module.py

Lines changed: 62 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -149,42 +149,10 @@ def __call__(self, frame: I3Frame) -> bool:
149149
return True
150150
# inference
151151
memory_watch = False
152-
if self._inference_speed_check is True:
153-
# create log file if it does not exist
154-
data_repr_start = time()
155152
try:
156-
if not self.multiple_models:
157-
data = self._create_data_representation(frame=frame).to(
158-
self._device
159-
)
160-
if self._inference_speed_check is True:
161-
data_repr_end = time()
162-
data_repr_time = data_repr_end - data_repr_start
163-
inference_start = time()
164-
predictions = self._apply_model(data=data)
165-
else:
166-
features = self._extract_feature_array_from_frame(frame=frame)
167-
if self._inference_speed_check is True:
168-
data_repr_end = time()
169-
data_repr_time = data_repr_end - data_repr_start
170-
inference_start = time()
171-
model_input_data = []
172-
for _, graph_definition in enumerate(self._graph_definitions):
173-
data = graph_definition(
174-
input_features=features[_],
175-
input_feature_names=self.features_list[_],
176-
)
177-
model_input_data.append(Batch.from_data_list([data]))
178-
179-
predictions = self._apply_model(data=model_input_data)
180-
181-
if self._inference_speed_check is True:
182-
inference_end = time()
183-
inference_time = inference_end - inference_start
184-
self._logger.info(
185-
f"Data representation time: {data_repr_time:.4f} s\n"
186-
f"Inference time: {inference_time:.4f} s\n"
187-
)
153+
predictions, data_repr_time, inference_time = (
154+
self._create_data_and_apply(frame=frame)
155+
)
188156

189157
except OutOfMemoryError:
190158
self.error(
@@ -195,22 +163,12 @@ def __call__(self, frame: I3Frame) -> bool:
195163
save_device = self._device
196164
self._device = "cpu"
197165
self.model.to(self._device)
198-
data = self._create_data_representation(frame=frame)
199-
if self._inference_speed_check is True:
200-
data_repr_end = time()
201-
data_repr_time = data_repr_end - data_repr_start
202-
inference_start = time()
203-
predictions = self._apply_model(data=data)
204-
if self._inference_speed_check is True:
205-
inference_end = time()
206-
inference_time = inference_end - inference_start
207-
self._logger.info(
208-
f"Data representation time: {data_repr_time:.4f} s\n"
209-
f"Inference time: {inference_time:.4f} s\n"
210-
)
211-
self._device = save_device
166+
167+
predictions, data_repr_time, inference_time = (
168+
self._create_data_and_apply(frame=frame)
169+
)
212170
memory_watch = True
213-
del data
171+
self._device = save_device
214172

215173
if self._inference_speed_check is True:
216174
write_start = time()
@@ -238,6 +196,49 @@ def __call__(self, frame: I3Frame) -> bool:
238196
self.model.to(self._device)
239197
return True
240198

199+
def _create_data_and_apply(self, frame: I3Frame) -> tuple:
200+
data_repr_time = -1
201+
inference_time = -1
202+
if self._inference_speed_check is True:
203+
            # start timing the data-representation step
204+
data_repr_start = time()
205+
if not self.multiple_models:
206+
data = self._create_data_representation(frame=frame).to(
207+
self._device
208+
)
209+
if self._inference_speed_check is True:
210+
data_repr_end = time()
211+
data_repr_time = data_repr_end - data_repr_start
212+
inference_start = time()
213+
predictions = self._apply_model(data=data)
214+
else:
215+
features = self._extract_feature_array_from_frame(frame=frame)
216+
if self._inference_speed_check is True:
217+
data_repr_end = time()
218+
data_repr_time = data_repr_end - data_repr_start
219+
inference_start = time()
220+
model_input_data = []
221+
for _, graph_definition in enumerate(self._graph_definitions):
222+
data = graph_definition(
223+
input_features=features[_],
224+
input_feature_names=self.features_list[_],
225+
)
226+
model_input_data.append(
227+
Batch.from_data_list([data.to(self._device)])
228+
)
229+
230+
predictions = self._apply_model(data=model_input_data)
231+
232+
if self._inference_speed_check is True:
233+
inference_end = time()
234+
inference_time = inference_end - inference_start
235+
self._logger.info(
236+
f"Data representation time: {data_repr_time:.4f} s\n"
237+
f"Inference time: {inference_time:.4f} s\n"
238+
)
239+
del data
240+
return predictions, data_repr_time, inference_time
241+
241242
def _check_dimensions(self, predictions: np.ndarray) -> int:
242243
if len(predictions.shape) > 1:
243244
dim = predictions.shape[1]
@@ -284,7 +285,7 @@ def _apply_model(self, data: Data) -> np.ndarray:
284285
"""Apply model to `Data` and case-handling."""
285286
if data is not None:
286287
predictions = self._inference(data)
287-
#print(predictions, type(predictions), type(predictions[0]))
288+
# print(predictions, type(predictions), type(predictions[0]))
288289
if isinstance(predictions, list):
289290
predictions = np.concatenate(
290291
[pred.flatten() for pred in predictions]
@@ -316,8 +317,8 @@ def _create_data_representation(self, frame: I3Frame) -> Data:
316317
data = self._graph_definition(
317318
input_features=input_features,
318319
input_feature_names=self._features,
319-
)
320-
return Batch.from_data_list([data.to(self._device)])
320+
).to(self._device)
321+
return Batch.from_data_list([data])
321322
else:
322323
return None
323324

@@ -421,10 +422,10 @@ def __init__(
421422
len(self._positions) == 3
422423
), "positions must be a list of 3 elements"
423424

424-
def _get_min_time(self, frame: I3Frame) -> float:
425+
def _get_min_time(self, frame: I3Frame, pulsemap: str) -> float:
425426
"""Get the minimum time of the first pulse in the frame."""
426427
min_time = np.inf
427-
doms = frame[self._pulsemap].apply(frame).values()
428+
doms = frame[pulsemap].apply(frame).values()
428429
        # search for the minimum time
429430
for dom in doms:
430431
if dom[0].time < min_time:
@@ -451,10 +452,11 @@ def _add_to_frame(self, frame, data):
451452

452453
if self._shift_time:
453454
# Shift time to be relative to the first pulse
454-
shift_time = (
455-
self._get_min_time(frame)
456-
- frame["CVStatistics"].min_pulse_time
457-
)
455+
shift_time = self._get_min_time(frame, self._pulsemap)
456+
if "CVStatistics" in frame:
457+
shift_time -= frame["CVStatistics"].min_pulse_time
458+
else:
459+
shift_time -= self._get_min_time(frame, "InIcePulses")
458460
particle.time = data[self._time].value + shift_time
459461
else:
460462
particle.time = data[self._time].value
@@ -479,6 +481,7 @@ def _add_to_frame(self, frame, data):
479481
super()._add_to_frame(frame=frame, data=data)
480482
return
481483

484+
482485
class I3MultipleModelInferenceModule(I3InferenceModule):
483486
"""I3InferenceModule for I3Particle data."""
484487

@@ -499,4 +502,3 @@ def _add_to_frame(self, frame, data):
499502

500503
i3_score_container = dataclasses.I3MapStringDouble(data)
501504
frame.Put(self._key_name, i3_score_container)
502-

src/graphnet/models/components/layers.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,12 @@ def forward(
6767
x = super().forward(x, edge_index)
6868

6969
# Recompute adjacency
70-
edge_index = knn_graph(
71-
x=x[:, self.features_subset],
72-
k=self.nb_neighbors,
73-
batch=batch,
74-
).to(self.device)
70+
with torch.cuda.device(x.device):
71+
edge_index = knn_graph(
72+
x=x[:, self.features_subset],
73+
k=self.nb_neighbors,
74+
batch=batch,
75+
).to(x.device)
7576

7677
return x, edge_index
7778

src/graphnet/models/gnn/dynedge.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,11 @@ def _global_pooling(self, x: Tensor, batch: LongTensor) -> Tensor:
254254
assert self._global_pooling_schemes
255255
pooled = []
256256
for pooling_scheme in self._global_pooling_schemes:
257-
pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
258-
pooled_x = pooling_fn(x, index=batch, dim=0)
257+
with torch.cuda.device(
258+
x.device
259+
): # Ensure pooling is performed on the same device as x
260+
pooling_fn = GLOBAL_POOLINGS[pooling_scheme]
261+
pooled_x = pooling_fn(x, index=batch, dim=0)
259262
if isinstance(pooled_x, tuple) and len(pooled_x) == 2:
260263
# `scatter_{min,max}`, which return also an argument, vs.
261264
# `scatter_{mean,sum}`

0 commit comments

Comments
 (0)