-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathconfig.py
More file actions
66 lines (61 loc) · 2.33 KB
/
config.py
File metadata and controls
66 lines (61 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config for JetStream Server (including engine init)."""
import functools
import os
from typing import Sequence, Type
import jax
from jetstream.core import config_lib
from jetstream.engine.implementations.maxtext.MaxText import maxengine_config, pyconfig
from jetstream_pt import config
def get_server_config(
config_str: str, argv: Sequence[str]
) -> config_lib.ServerConfig | Type[config_lib.ServerConfig]:
match config_str:
case "MaxtextInterleavedServer":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(argv)
server_config = config_lib.ServerConfig(
prefill_slices=(),
generate_slices=(),
interleaved_slices=("tpu=" + str(jax.device_count()),),
prefill_engine_create_fns=(),
generate_engine_create_fns=(),
interleaved_engine_create_fns=(
functools.partial(
maxengine_config.create_maxengine, config=pyconfig.config
),
),
)
case "PyTorchInterleavedServer":
os.environ["XLA_FLAGS"] = (
"--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"
)
engine = config.create_engine_from_config_flags()
server_config = config_lib.ServerConfig(
prefill_slices=(),
generate_slices=(),
interleaved_slices=("tpu=" + str(jax.device_count()),),
prefill_engine_create_fns=(),
generate_engine_create_fns=(),
interleaved_engine_create_fns=(lambda a: engine,),
)
case "InterleavedCPUTestServer":
server_config = config_lib.InterleavedCPUTestServer
case "CPUTestServer":
server_config = config_lib.CPUTestServer
case _:
raise NotImplementedError
return server_config