Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions TraceLens/EventReplay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,21 @@
#
# See LICENSE for license information.
###############################################################################

from .event_replay import EventReplayer
from .custom_inits import (
CustomInit,
PagedAttentionInit,
MoeRoutingInit,
extract_batch_context,
)
from .utils import benchmark_func

__all__ = [
"EventReplayer",
"CustomInit",
"PagedAttentionInit",
"MoeRoutingInit",
"extract_batch_context",
"benchmark_func",
]
21 changes: 15 additions & 6 deletions TraceLens/EventReplay/batched_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,17 @@ def _get_args_kwargs_from_ir(
replayed_count = 0
errors = 0

for i, repro_info in enumerate(repro_data_list):
ops_to_replay = repro_data_list
if args.op_filter:
ops_to_replay = [r for r in ops_to_replay if args.op_filter in r["op_name"]]
if args.op_limit:
ops_to_replay = ops_to_replay[: args.op_limit]

for i, repro_info in enumerate(ops_to_replay):

op_name = repro_info["op_name"]
replay_ir = repro_info["replay_ir"]
print(f"\n[{replayed_count + 1}/{len(repro_data_list)}] Replaying: {op_name}")
print(f"\n[{replayed_count + 1}/{len(ops_to_replay)}] Replaying: {op_name}")

# Get the PyTorch operation function
try:
Expand Down Expand Up @@ -151,15 +157,16 @@ def _get_args_kwargs_from_ir(
errors += 1
continue
# --- Benchmark the function ---
mean_time_us = benchmark_func(
metrics = benchmark_func(
lambda: func(*pos_args, **kwargs), args.device, warmup=50, avg_steps=100
)
print(f" Average time taken: {mean_time_us:.2f} microseconds")
mean_time_us = metrics["mean_us"]
print(f" Average time taken: {mean_time_us:.2f} us (median: {metrics['median_us']:.2f} us)")
if "count" in repro_info:
count_workload = repro_info["count"]
total_time_us = mean_time_us * count_workload
print(f" Count in workload: {count_workload}")
print(f" Est time in workload: {total_time_us:.2f} microseconds")
print(f" Est time in workload: {total_time_us:.2f} us")
# --- Optionally sync again ---
if args.device == "cuda":
torch.cuda.synchronize()
Expand Down Expand Up @@ -190,7 +197,9 @@ def _get_args_kwargs_from_ir(
print("\n--- Replay Summary ---")
print(f"Total operations in file: {len(repro_data_list)}")
if args.op_filter:
print(f"Filter applied: '{args.op_filter}'")
print(f"Filter applied: '{args.op_filter}' ({len(ops_to_replay)} matched)")
if args.op_limit:
print(f"Limit applied: {args.op_limit}")
print(f"Attempted replays: {replayed_count}")
print(f"Successful replays: {replayed_count - errors}")
print(f"Errors encountered: {errors}")
Expand Down
Loading
Loading