Commit 385ecac

Enable pylint linter and pyink formatter (#29)
* Enable pylint linter and pyink formatter
* Update github action job name
1 parent 90b2a9d commit 385ecac

33 files changed: 873 additions & 424 deletions

.github/workflows/UnitTests.yaml

Lines changed: 9 additions & 1 deletion

@@ -28,7 +28,7 @@ on:
 
 jobs:
   py:
-    name: "Python type check"
+    name: "Python type/lint/format checks"
     strategy:
       matrix:
         os: [ubuntu-20.04]
@@ -43,10 +43,18 @@ jobs:
     - name: Install Dependencies
       run: |
         pip install pytype
+        pip install pylint
+        pip install pyink
         pip install -r requirements.txt
     - name: Typecheck the code with pytype
       run: |
         pytype --jobs auto --disable import-error --disable module-attr jetstream/
+    - name: Analysing the code with pylint
+      run: |
+        pylint jetstream/ benchmarks/
+    - name: Format check with pyink
+      run: |
+        pyink --pyink-indentation 2 --line-length 80 --check --verbose .
 
   cpu:
     name: "JetStream unit tests"

benchmarks/benchmark_serving.py

Lines changed: 22 additions & 29 deletions

@@ -15,9 +15,10 @@
 """Benchmark JetStream online serving.
 
 On the server side, run one of the following commands:
-    * For real server, you need to pass correct server config (include the model config that
-      being passed into your engine impl) to the command below. Refer to config_lib.py and
-      implementations/mock/config.py for config impl detail.
+    * For real server, you need to pass correct server config (include the
+      model config that being passed into your engine impl) to the command
+      below. Refer to config_lib.py and implementations/mock/config.py for
+      config impl detail.
 
 (run with real server)
 python -m jetstream.core.implementations.<your_impl>.server \
@@ -27,11 +28,12 @@
 python -m jetstream.core.implementations.mock.server
 
 On the client side, run:
-    * For real server and shareGPT dataset, you need to pass the tokenizer, server config, and
-      dataset flags to the command below, and make some changes to the tokenizer logic in the
-      benchmark script (get_tokenizer and sample_requests func) to use your tokenizer correctly.
-    * Add `--save-result` flag to save the benchmark result to a json file in current folder.
-    * Add `--threads` flag to set the maximum number of threads used for request dispatching.
+    * For real server and shareGPT dataset, you need to pass the tokenizer,
+      server config, and dataset flags to the command below, and make some
+      changes to the tokenizer logic in the benchmark script (get_tokenizer
+      and sample_requests func) to use your tokenizer correctly.
+    * Add `--save-result` flag to save the benchmark result to a json file in
+      current folder.
 
 (run with real model and engines)
 python -m benchmarks.benchmark_serving \
@@ -74,6 +76,8 @@
 
 @dataclass
 class BenchmarkMetrics:
+  """Data class to store benchmark metrics."""
+
   completed: int
   total_input: int
   total_output: int
@@ -136,7 +140,7 @@ def load_sharegpt_dataset(
     conversation_starter: str,
 ) -> List[tuple[str]]:
   # Load the dataset.
-  with open(dataset_path) as f:
+  with open(dataset_path, "r", encoding="utf-8") as f:
     dataset = json.load(f)
   # Filter out the conversations with less than 2 turns.
   dataset = [data for data in dataset if len(data["conversations"]) >= 2]
@@ -159,7 +163,7 @@ def load_sharegpt_dataset(
 
 def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
   # Load the dataset.
-  with open(dataset_path) as f:
+  with open(dataset_path, "r", encoding="utf-8") as f:
     dataset = json.load(f)
 
   # Tokenize the prompts and completions.
@@ -211,7 +215,7 @@ def filter_dataset(
   filtered_dataset: List[InputRequest] = []
   for (
       prompt,
-      prompt_token_ids,
+      _,
       output,
       prompt_len,
       output_len,
@@ -255,7 +259,7 @@ def sample_requests(
     print(
         f"Number of requests {num_requests} is larger than size of dataset"
         f" {n}.\n",
-        f"Repeating data to meet number of requests.\n",
+        "Repeating data to meet number of requests.\n",
     )
     sampled_indices = sampled_indices * int(
         np.ceil(num_requests / len(sampled_indices))
@@ -361,7 +365,6 @@ async def send_request(
     pbar: tqdm,
     session_cache: str,
     priority: int,
-    threads: int,
 ) -> RequestFuncOutput:
   """Send the request to JetStream server."""
   request = jetstream_pb2.DecodeRequest(
@@ -394,7 +397,6 @@ async def benchmark(
     disable_tqdm: bool,
     session_cache: str,
     priority: int,
-    threads: int,
 ):
   """Benchmark the online serving performance."""
   pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -412,7 +414,6 @@ async def benchmark(
             pbar=pbar,
             session_cache=session_cache,
             priority=priority,
-            threads=threads,
         )
     )
 )
@@ -519,8 +520,8 @@ def main(args: argparse.Namespace):
   )
 
   # A given args.max_output_length value is the max generation step,
-  # when the args.max_output_length is default to None, the sample's golden output length
-  # will be used to decide the generation step
+  # when the args.max_output_length is default to None, the sample's golden
+  # output length will be used to decide the generation step.
   input_requests = sample_requests(
       dataset=dataset,
       tokenizer=tokenizer,
@@ -540,7 +541,6 @@ def main(args: argparse.Namespace):
           disable_tqdm=args.disable_tqdm,
           session_cache=args.session_cache,
           priority=args.priority,
-          threads=args.threads,
       )
   )
   print("Warm up done")
@@ -554,7 +554,6 @@ def main(args: argparse.Namespace):
           disable_tqdm=args.disable_tqdm,
          session_cache=args.session_cache,
          priority=args.priority,
-          threads=args.threads,
      )
  )
 
@@ -582,12 +581,12 @@ def main(args: argparse.Namespace):
     file_name = (
         f"JetStream-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
     )
-    with open(file_name, "w") as outfile:
+    with open(file_name, "w", encoding="utf-8") as outfile:
       json.dump(result_json, outfile)
 
   if args.save_request_outputs:
     file_path = args.request_outputs_file_path
-    with open(file_path, "w") as output_file:
+    with open(file_path, "w", encoding="utf-8") as output_file:
       json.dump(
           [output.to_dict() for output in request_outputs],
           output_file,
@@ -653,12 +652,6 @@ def main(args: argparse.Namespace):
           "the request arrival times."
       ),
   )
-  parser.add_argument(
-      "--threads",
-      type=int,
-      default=110,
-      help="The maximum number of threads used for request dispatching.",
-  )
   parser.add_argument(
       "--total-mock-requests",
       type=int,
@@ -736,5 +729,5 @@ def main(args: argparse.Namespace):
       help="What entity should be the one starting the conversations.",
   )
 
-  args = parser.parse_args()
-  main(args)
+  parsed_args = parser.parse_args()
+  main(parsed_args)
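
Most of the open() edits above are the standard fix for pylint's unspecified-encoding warning (W1514): without an explicit encoding, open() falls back to the platform's locale default, so the same script can behave differently across machines. A minimal before/after sketch (hypothetical file name):

    # Before: text encoding depends on the platform locale; pylint flags this.
    with open("data.json") as f:
      data = f.read()

    # After: the encoding is explicit and identical on every machine.
    with open("data.json", "r", encoding="utf-8") as f:
      data = f.read()

The closing args -> parsed_args rename similarly silences redefined-outer-name (W0621): a module-level args assigned under if __name__ == "__main__" would shadow the args parameter of main().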

benchmarks/eval_accuracy.py

Lines changed: 45 additions & 36 deletions

@@ -12,58 +12,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Evaluate accuracy of JetStream online serving."""
+
 import argparse
 import nltk
 import evaluate
 import json
 
 import numpy as np
 
+
 def postprocess_text(preds, targets):
-    preds = [pred.strip() for pred in preds]
-    targets = [target.strip() for target in targets]
+  preds = [pred.strip() for pred in preds]
+  targets = [target.strip() for target in targets]
 
-    # rougeLSum expects newline after each sentence
-    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
-    targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]
+  # rougeLSum expects newline after each sentence
+  preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
+  targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]
 
-    return preds, targets
+  return preds, targets
 
 
 def eval_accuracy(request_outputs_dict):
-    metric = evaluate.load("rouge")
-    nltk.download('punkt')
-    preds = []
-    targets = []
-
-    for output in request_outputs_dict:
-        preds.append(output["generated_text"])
-        targets.append(output["original_output"])
-    preds, targets = postprocess_text(preds, targets)
-    result = metric.compute(
-        predictions=preds, references=targets, use_stemmer=True, use_aggregator=False)
-    result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
-    prediction_lens = [len(pred) for pred in preds]
-    result["gen_len"] = np.sum(prediction_lens)
-    result["gen_num"] = len(preds)
-    print("\nResults\n")
-    print(result)
+  metric = evaluate.load("rouge")
+  nltk.download("punkt")
+  preds = []
+  targets = []
+
+  for output in request_outputs_dict:
+    preds.append(output["generated_text"])
+    targets.append(output["original_output"])
+  preds, targets = postprocess_text(preds, targets)
+  result = metric.compute(
+      predictions=preds,
+      references=targets,
+      use_stemmer=True,
+      use_aggregator=False,
+  )
+  result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()}
+  prediction_lens = [len(pred) for pred in preds]
+  result["gen_len"] = np.sum(prediction_lens)
+  result["gen_num"] = len(preds)
+  print("\nResults\n")
+  print(result)
 
 
 def main(args):
-    with open(args.output_path) as f:
-        request_outputs_dict = json.load(f)
-
-    eval_accuracy(request_outputs_dict)
+  with open(args.output_path, "r", encoding="utf-8") as f:
+    request_outputs_dict = json.load(f)
 
+  eval_accuracy(request_outputs_dict)
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--output_path", type=str,
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--output_path",
+      type=str,
       default="/tmp/request-outputs.json",
-        help="File path which has original_output and inference generated_text.")
-
-    args = parser.parse_args()
-
-    main(args)
+      help="File path which has original_output and inference generated_text.",
+  )
+
+  parsed_args = parser.parse_args()
+
+  main(parsed_args)
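
For a quick end-to-end check of the reformatted script: its input is a JSON list of dicts with "generated_text" and "original_output" keys (running benchmark_serving.py with --save-request-outputs is intended to produce such a file). A minimal hand-built sample, with made-up texts for illustration:

    import json

    # Hypothetical request outputs in the shape eval_accuracy.py reads.
    outputs = [
        {
            "generated_text": "The cat sat on the mat.",
            "original_output": "A cat was sitting on the mat.",
        },
    ]

    with open("/tmp/request-outputs.json", "w", encoding="utf-8") as f:
      json.dump(outputs, f)

Then running `python benchmarks/eval_accuracy.py` from the repository root (the default --output_path is /tmp/request-outputs.json) prints the mean ROUGE scores plus gen_len and gen_num.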

jetstream/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-

jetstream/core/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -11,4 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
