1515"""Benchmark JetStream online serving.
1616
1717On the server side, run one of the following commands:
18- * For real server, you need to pass correct server config (include the model config that
19- being passed into your engine impl) to the command below. Refer to config_lib.py and
20- implementations/mock/config.py for config impl detail.
18+ * For a real server, you need to pass the correct server config
19+ (including the model config that is being passed into your engine
20+ impl) to the command below. Refer to config_lib.py and
21+ implementations/mock/config.py for config impl detail.
2122
2223 (run with real server)
2324 python -m jetstream.core.implementations.<your_impl>.server \
2728 python -m jetstream.core.implementations.mock.server
2829
2930On the client side, run:
30- * For real server and shareGPT dataset, you need to pass the tokenizer, server config, and
31- dataset flags to the command below, and make some changes to the tokenizer logic in the
32- benchmark script (get_tokenizer and sample_requests func) to use your tokenizer correctly.
33- * Add `--save-result` flag to save the benchmark result to a json file in current folder.
34- * Add `--threads` flag to set the maximum number of threads used for request dispatching.
31+ * For real server and shareGPT dataset, you need to pass the tokenizer,
32+ server config, and dataset flags to the command below, and make some
33+ changes to the tokenizer logic in the benchmark script (get_tokenizer
34+ and sample_requests func) to use your tokenizer correctly.
35+ * Add `--save-result` flag to save the benchmark result to a json file in
36+ current folder.
3537
3638 (run with real model and engines)
3739 python -m benchmarks.benchmark_serving \
7476
7577@dataclass
7678class BenchmarkMetrics :
79+ """Data class to store benchmark metrics."""
80+
7781 completed : int
7882 total_input : int
7983 total_output : int
@@ -136,7 +140,7 @@ def load_sharegpt_dataset(
136140 conversation_starter : str ,
137141) -> List [tuple [str ]]:
138142 # Load the dataset.
139- with open (dataset_path ) as f :
143+ with open (dataset_path , "r" , encoding = "utf-8" ) as f :
140144 dataset = json .load (f )
141145 # Filter out the conversations with less than 2 turns.
142146 dataset = [data for data in dataset if len (data ["conversations" ]) >= 2 ]
@@ -159,7 +163,7 @@ def load_sharegpt_dataset(
159163
160164def load_openorca_dataset (dataset_path : str ) -> List [tuple [str ]]:
161165 # Load the dataset.
162- with open (dataset_path ) as f :
166+ with open (dataset_path , "r" , encoding = "utf-8" ) as f :
163167 dataset = json .load (f )
164168
165169 # Tokenize the prompts and completions.
@@ -211,7 +215,7 @@ def filter_dataset(
211215 filtered_dataset : List [InputRequest ] = []
212216 for (
213217 prompt ,
214- prompt_token_ids ,
218+ _ ,
215219 output ,
216220 prompt_len ,
217221 output_len ,
@@ -255,7 +259,7 @@ def sample_requests(
255259 print (
256260 f"Number of requests { num_requests } is larger than size of dataset"
257261 f" { n } .\n " ,
258- f "Repeating data to meet number of requests.\n " ,
262+ "Repeating data to meet number of requests.\n " ,
259263 )
260264 sampled_indices = sampled_indices * int (
261265 np .ceil (num_requests / len (sampled_indices ))
@@ -361,7 +365,6 @@ async def send_request(
361365 pbar : tqdm ,
362366 session_cache : str ,
363367 priority : int ,
364- threads : int ,
365368) -> RequestFuncOutput :
366369 """Send the request to JetStream server."""
367370 request = jetstream_pb2 .DecodeRequest (
@@ -394,7 +397,6 @@ async def benchmark(
394397 disable_tqdm : bool ,
395398 session_cache : str ,
396399 priority : int ,
397- threads : int ,
398400):
399401 """Benchmark the online serving performance."""
400402 pbar = None if disable_tqdm else tqdm (total = len (input_requests ))
@@ -412,7 +414,6 @@ async def benchmark(
412414 pbar = pbar ,
413415 session_cache = session_cache ,
414416 priority = priority ,
415- threads = threads ,
416417 )
417418 )
418419 )
@@ -519,8 +520,8 @@ def main(args: argparse.Namespace):
519520 )
520521
521522 # A given args.max_output_length value is the max generation step,
522- # when the args.max_output_length is default to None, the sample's golden output length
523- # will be used to decide the generation step
523+ # when args.max_output_length defaults to None, the sample's golden
524+ # output length will be used to decide the generation step.
524525 input_requests = sample_requests (
525526 dataset = dataset ,
526527 tokenizer = tokenizer ,
@@ -540,7 +541,6 @@ def main(args: argparse.Namespace):
540541 disable_tqdm = args .disable_tqdm ,
541542 session_cache = args .session_cache ,
542543 priority = args .priority ,
543- threads = args .threads ,
544544 )
545545 )
546546 print ("Warm up done" )
@@ -554,7 +554,6 @@ def main(args: argparse.Namespace):
554554 disable_tqdm = args .disable_tqdm ,
555555 session_cache = args .session_cache ,
556556 priority = args .priority ,
557- threads = args .threads ,
558557 )
559558 )
560559
@@ -582,12 +581,12 @@ def main(args: argparse.Namespace):
582581 file_name = (
583582 f"JetStream-{ args .request_rate } qps-{ base_model_id } -{ current_dt } .json"
584583 )
585- with open (file_name , "w" ) as outfile :
584+ with open (file_name , "w" , encoding = "utf-8" ) as outfile :
586585 json .dump (result_json , outfile )
587586
588587 if args .save_request_outputs :
589588 file_path = args .request_outputs_file_path
590- with open (file_path , "w" ) as output_file :
589+ with open (file_path , "w" , encoding = "utf-8" ) as output_file :
591590 json .dump (
592591 [output .to_dict () for output in request_outputs ],
593592 output_file ,
@@ -653,12 +652,6 @@ def main(args: argparse.Namespace):
653652 "the request arrival times."
654653 ),
655654 )
656- parser .add_argument (
657- "--threads" ,
658- type = int ,
659- default = 110 ,
660- help = "The maximum number of threads used for request dispatching." ,
661- )
662655 parser .add_argument (
663656 "--total-mock-requests" ,
664657 type = int ,
@@ -736,5 +729,5 @@ def main(args: argparse.Namespace):
736729 help = "What entity should be the one starting the conversations." ,
737730 )
738731
739- args = parser .parse_args ()
740- main (args )
732+ parsed_args = parser .parse_args ()
733+ main (parsed_args )
0 commit comments