Skip to content

Commit 46c152f

Browse files
authored
Cleanup orchestrator proto (#112)
* Cleanup orchestrator proto * Update JetStream based on proto cleanup
1 parent 196beda commit 46c152f

8 files changed

Lines changed: 23 additions & 76 deletions

File tree

benchmarks/benchmark_serving.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -426,18 +426,14 @@ async def send_request(
426426
tokenizer: Any,
427427
input_request: InputRequest,
428428
pbar: tqdm,
429-
session_cache: str,
430-
priority: int,
431429
) -> RequestFuncOutput:
432430
"""Send the request to JetStream server."""
433431
# Tokenization on client side following MLPerf standard.
434432
token_ids = tokenizer.encode(input_request.prompt)
435433
request = jetstream_pb2.DecodeRequest(
436-
session_cache=session_cache,
437434
token_content=jetstream_pb2.DecodeRequest.TokenContent(
438435
token_ids=token_ids
439436
),
440-
priority=priority,
441437
max_tokens=input_request.output_len,
442438
)
443439
output = RequestFuncOutput()
@@ -463,8 +459,6 @@ async def benchmark(
463459
input_requests: list[InputRequest],
464460
request_rate: float,
465461
disable_tqdm: bool,
466-
session_cache: str,
467-
priority: int,
468462
):
469463
"""Benchmark the online serving performance."""
470464
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -481,8 +475,6 @@ async def benchmark(
481475
tokenizer=tokenizer,
482476
input_request=request,
483477
pbar=pbar,
484-
session_cache=session_cache,
485-
priority=priority,
486478
)
487479
)
488480
)
@@ -614,8 +606,6 @@ def main(args: argparse.Namespace):
614606
input_requests=warmup_requests,
615607
request_rate=args.request_rate,
616608
disable_tqdm=args.disable_tqdm,
617-
session_cache=args.session_cache,
618-
priority=args.priority,
619609
)
620610
)
621611
print(f"{args.warmup_mode} warmup completed.")
@@ -631,8 +621,6 @@ def main(args: argparse.Namespace):
631621
input_requests=input_requests,
632622
request_rate=args.request_rate,
633623
disable_tqdm=args.disable_tqdm,
634-
session_cache=args.session_cache,
635-
priority=args.priority,
636624
)
637625
)
638626

@@ -790,24 +778,6 @@ def main(args: argparse.Namespace):
790778
" the form of a string."
791779
),
792780
)
793-
parser.add_argument(
794-
"--priority",
795-
type=int,
796-
default=0,
797-
help=(
798-
"Message priority. (currently no business logic implemented, use"
799-
" default 0)"
800-
),
801-
)
802-
parser.add_argument(
803-
"--session-cache",
804-
type=str,
805-
default="",
806-
help=(
807-
"Location of any pre-cached results. (currently _load_cache_history"
808-
" not implemented, use default empty str)"
809-
),
810-
)
811781
parser.add_argument(
812782
"--save-request-outputs",
813783
action="store_true",

jetstream/core/orchestrator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ class ActiveRequest:
133133
complete: Optional[np.ndarray] = None
134134
prefill_result: Any = None
135135
#################### Information relevant for prefill ########################
136-
history_path: Optional[str] = None
137136
prefill_content: Optional[str | list[int]] = None
138137
padded_token_length: Optional[int] = None
139138
################## Information relevant for detokenization ###################
@@ -491,14 +490,13 @@ def _prefill_thread(self, idx: int):
491490

492491
if request is None:
493492
break
494-
is_bos = not bool(request.history_path)
493+
is_bos = True
495494
logging.info(
496495
"Prefilling on prefill engine %d : prefill queue size, %d,"
497-
" is_bos: %s, history: %s",
496+
" is_bos: %s",
498497
idx,
499498
self._prefill_backlog.qsize(),
500499
is_bos,
501-
request.history_path,
502500
)
503501
# Tokenize and padding the text or token input.
504502
padded_tokens, true_length = self._process_prefill_content(
@@ -895,7 +893,6 @@ async def Decode( # pylint: disable=invalid-overridden-method
895893
# Wrap request as an ActiveRequest.
896894
active_request = ActiveRequest(
897895
max_tokens=request.max_tokens,
898-
history_path=request.session_cache,
899896
prefill_content=prefill_content,
900897
is_client_side_tokenization=is_client_side_tokenization,
901898
return_channel=return_channel,

jetstream/core/proto/jetstream.proto

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ service Orchestrator {
2626
}
2727

2828
message DecodeRequest {
29-
// Where to load any pre-existing kv cache from.
30-
string session_cache = 1;
31-
int32 priority = 3;
3229
// The maximum output length of a sequence. It's used in JetStream to control
3330
// the output/decode length of a sequence. It would not be used in the engine.
3431
// We should always set max_tokens <= (max_target_length -
@@ -51,7 +48,7 @@ message DecodeRequest {
5148
TextContent text_content = 5;
5249
TokenContent token_content = 6;
5350
}
54-
reserved 2;
51+
reserved 1, 2, 3;
5552
// Next ID: 7
5653
}
5754

jetstream/core/proto/jetstream_pb2.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
31-
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
31+
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\x8a\x02\n\rDecodeRequest\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03J\x04\x08\x03\x10\x04"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
3232
)
3333

3434
_globals = globals()
@@ -39,23 +39,23 @@
3939
if _descriptor._USE_C_DESCRIPTORS == False:
4040
DESCRIPTOR._options = None
4141
_globals["_DECODEREQUEST"]._serialized_start = 58
42-
_globals["_DECODEREQUEST"]._serialized_end = 353
43-
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 274
44-
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 301
45-
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 303
46-
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 336
47-
_globals["_DECODERESPONSE"]._serialized_start = 356
48-
_globals["_DECODERESPONSE"]._serialized_end = 687
49-
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 522
50-
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 538
51-
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 541
52-
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670
53-
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629
54-
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670
55-
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 689
56-
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 709
57-
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711
58-
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749
59-
_globals["_ORCHESTRATOR"]._serialized_start = 752
60-
_globals["_ORCHESTRATOR"]._serialized_end = 937
42+
_globals["_DECODEREQUEST"]._serialized_end = 324
43+
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_start = 233
44+
_globals["_DECODEREQUEST_TEXTCONTENT"]._serialized_end = 260
45+
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_start = 262
46+
_globals["_DECODEREQUEST_TOKENCONTENT"]._serialized_end = 295
47+
_globals["_DECODERESPONSE"]._serialized_start = 327
48+
_globals["_DECODERESPONSE"]._serialized_end = 658
49+
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_start = 493
50+
_globals["_DECODERESPONSE_INITIALCONTENT"]._serialized_end = 509
51+
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_start = 512
52+
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 641
53+
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 600
54+
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 641
55+
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 660
56+
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 680
57+
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 682
58+
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 720
59+
_globals["_ORCHESTRATOR"]._serialized_start = 723
60+
_globals["_ORCHESTRATOR"]._serialized_end = 908
6161
# @@protoc_insertion_point(module_scope)

jetstream/tests/core/test_orchestrator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ async def test_orchestrator_interleaved_mode(self):
7878
text = "AB"
7979

8080
request = jetstream_pb2.DecodeRequest(
81-
session_cache="",
8281
text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
83-
priority=1,
8482
max_tokens=3,
8583
)
8684
iterator = client.Decode(request)
@@ -109,11 +107,9 @@ async def test_orchestrator_interleaved_mode_client_tokenization(self):
109107
token_ids = [65, 66]
110108

111109
request = jetstream_pb2.DecodeRequest(
112-
session_cache="",
113110
token_content=jetstream_pb2.DecodeRequest.TokenContent(
114111
token_ids=token_ids
115112
),
116-
priority=1,
117113
max_tokens=3,
118114
)
119115
iterator = client.Decode(request)

jetstream/tests/core/test_server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,7 @@ async def test_server(
9393
# as BOS
9494
text = "AB"
9595
request = jetstream_pb2.DecodeRequest(
96-
session_cache="",
9796
text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
98-
priority=1,
9997
max_tokens=3,
10098
)
10199
iterator = stub.Decode(request)

jetstream/tools/load_tester.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,11 @@ def api_call(
5050
stub: jetstream_pb2_grpc.OrchestratorStub,
5151
text: str,
5252
max_tokens: int,
53-
session_cache: str = "",
5453
print_interim: bool = True,
5554
) -> str:
5655
"""Sends a request to server and returns text."""
5756
request = jetstream_pb2.DecodeRequest(
58-
session_cache=session_cache,
5957
text_content=jetstream_pb2.DecodeRequest.TextContent(text=text),
60-
priority=1,
6158
max_tokens=max_tokens,
6259
)
6360
response = stub.Decode(request)

jetstream/tools/requester.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,7 @@
2626

2727
_SERVER = flags.DEFINE_string("server", "0.0.0.0", "server address")
2828
_PORT = flags.DEFINE_string("port", "9000", "port to ping")
29-
_SESSION_CACHE = flags.DEFINE_string(
30-
"session_cache", "", "Location of any pre-cached results"
31-
)
3229
_TEXT = flags.DEFINE_string("text", "Today is a good day", "The message")
33-
_PRIORITY = flags.DEFINE_integer("priority", 0, "Message priority")
3430
_MAX_TOKENS = flags.DEFINE_integer(
3531
"max_tokens", 3, "Maximum number of output/decode tokens of a sequence"
3632
)
@@ -82,20 +78,16 @@ def main(argv: Sequence[str]) -> None:
8278
vocab = load_vocab(_TOKENIZER.value)
8379
token_ids = vocab.tokenizer.encode(_TEXT.value)
8480
request = jetstream_pb2.DecodeRequest(
85-
session_cache=_SESSION_CACHE.value,
8681
token_content=jetstream_pb2.DecodeRequest.TokenContent(
8782
token_ids=token_ids
8883
),
89-
priority=_PRIORITY.value,
9084
max_tokens=_MAX_TOKENS.value,
9185
)
9286
else:
9387
request = jetstream_pb2.DecodeRequest(
94-
session_cache=_SESSION_CACHE.value,
9588
text_content=jetstream_pb2.DecodeRequest.TextContent(
9689
text=_TEXT.value
9790
),
98-
priority=_PRIORITY.value,
9991
max_tokens=_MAX_TOKENS.value,
10092
)
10193
return _GetResponseAsync(stub, request)

0 commit comments

Comments (0)