Conversation
Code Review
This pull request corrects the batch size calculation logic in `cuda_graph.py` by ensuring the `mtp_step` multiplier is applied correctly within the range. A review comment suggests using a more idiomatic `range()` with a step to improve readability and reuse existing variables.
```diff
  graph_grow_step_size = self.args.graph_grow_step_size * (self.mtp_step + 1)

- batch_sizes = [i * (self.mtp_step + 1) for i in range(1, graph_split_batch_size + 1)]
+ batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)]
```
While this line correctly fixes the bug, its logic can be expressed more directly and readably. You are generating a sequence of batch sizes that are multiples of `(self.mtp_step + 1)`, up to `graph_split_batch_size`. This can be achieved more concisely using `range()` with a step. This approach also has the benefit of reusing the `graph_split_batch_size` variable defined on line 33, improving code clarity.
```suggestion
- batch_sizes = [i * (self.mtp_step + 1) for i in range(1, self.args.graph_split_batch_size + 1)]
+ batch_sizes = list(range(self.mtp_step + 1, graph_split_batch_size + 1, self.mtp_step + 1))
```
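A small standalone sketch of why the two forms are equivalent. The values below are made up for illustration, and the sketch assumes (per the review comment) that `graph_split_batch_size` is the pre-scaled upper bound defined on line 33 of the diff, i.e. `self.args.graph_split_batch_size * (self.mtp_step + 1)`:

```python
# Hypothetical values standing in for the PR's configuration.
mtp_step = 1
args_graph_split_batch_size = 4  # stands in for self.args.graph_split_batch_size

# Assumed definition from line 33 of the diff: the scaled upper bound.
graph_split_batch_size = args_graph_split_batch_size * (mtp_step + 1)

# Fixed line from the PR: multiply each index by (mtp_step + 1).
via_comprehension = [i * (mtp_step + 1)
                     for i in range(1, args_graph_split_batch_size + 1)]

# Reviewer's suggestion: let range() step by (mtp_step + 1) directly.
via_step = list(range(mtp_step + 1, graph_split_batch_size + 1, mtp_step + 1))

print(via_comprehension)  # [2, 4, 6, 8]
print(via_step)           # [2, 4, 6, 8]
```

Both produce the same multiples of `(mtp_step + 1)`; the stepped `range()` simply encodes the multiplication in the iteration itself rather than in a per-element expression.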
No description provided.