Skip to content

Commit 8488a64

Browse files
authored
[AINode] Integrate moirai2 as builtin forecasting model (#17056)
1 parent 8455b8c commit 8488a64

28 files changed

Lines changed: 3862 additions & 8 deletions

LICENSE

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,4 +349,14 @@ The chronos-forecasting is open source software licensed under the Apache Licens
349349
Project page: https://github.com/amazon-science/chronos-forecasting
350350
License: https://github.com/amazon-science/chronos-forecasting/blob/main/LICENSE
351351

352-
--------------------------------------------------------------------------------
352+
--------------------------------------------------------------------------------
353+
354+
The following files include code modified from uni2ts project.
355+
356+
./iotdb-core/ainode/iotdb/ainode/core/model/moirai2/*
357+
358+
The uni2ts is open source software licensed under the Apache License 2.0
359+
Project page: https://github.com/SalesforceAIResearch/uni2ts
360+
License: https://github.com/SalesforceAIResearch/uni2ts/blob/main/LICENSE.txt
361+
362+
--------------------------------------------------------------------------------

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ public class AINodeTestUtils {
5656
new AbstractMap.SimpleEntry<>(
5757
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
5858
new AbstractMap.SimpleEntry<>(
59-
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")))
59+
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
60+
new AbstractMap.SimpleEntry<>(
61+
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")))
6062
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
6163

6264
public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;

iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ def _estimate_shared_pool_size_by_total_mem(
6868
# Seize memory usage for each model
6969
mem_usages: Dict[str, float] = {}
7070
for model_info in all_models:
71+
if model_info.model_id not in MODEL_MEM_USAGE_MAP:
72+
logger.error(
73+
f"[Inference] Model '{model_info.model_id}' not found in MODEL_MEM_USAGE_MAP. "
74+
f"Available types: {list(MODEL_MEM_USAGE_MAP.keys())}"
75+
)
76+
raise KeyError(
77+
f"Model '{model_info.model_id}' not found in MODEL_MEM_USAGE_MAP. "
78+
f"Please add memory usage configuration for '{model_info.model_id}' in constant.py"
79+
)
7180
mem_usages[model_info.model_id] = (
7281
MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
7382
)

iotdb-core/ainode/iotdb/ainode/core/model/chronos2/chronos_bolt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,7 @@ def predict(
596596
context_tensor = torch.cat([context_tensor, prediction], dim=-1)[
597597
..., -self.model_context_length :
598598
]
599-
(batch_size, n_quantiles, context_length) = context_tensor.shape
599+
batch_size, n_quantiles, context_length = context_tensor.shape
600600

601601
with torch.no_grad():
602602
# Reshape (batch, n_quantiles, context_length) -> (batch * n_quantiles, context_length)

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,17 @@ def __repr__(self):
145145
"AutoModelForCausalLM": "model.Chronos2Model",
146146
},
147147
),
148+
"moirai2": ModelInfo(
149+
model_id="moirai2",
150+
category=ModelCategory.BUILTIN,
151+
state=ModelStates.INACTIVE,
152+
model_type="moirai",
153+
pipeline_cls="pipeline_moirai2.Moirai2Pipeline",
154+
repo_id="Salesforce/moirai-2.0-R-small",
155+
auto_map={
156+
"AutoConfig": "configuration_moirai2.Moirai2Config",
157+
"AutoModelForCausalLM": "modeling_moirai2.Moirai2ForPrediction",
158+
},
159+
transformers_registered=True,
160+
),
148161
}

iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
6161

6262
def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
6363
device_map = model_kwargs.get("device_map", "cpu")
64-
trust_remote_code = model_kwargs.get("trust_remote_code", True)
6564
train_from_scratch = model_kwargs.get("train_from_scratch", False)
6665

6766
model_path = os.path.join(
@@ -107,11 +106,9 @@ def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
107106
model_cls = AutoModelForCausalLM
108107

109108
if train_from_scratch:
110-
model = model_cls.from_config(config_cls, trust_remote_code=trust_remote_code)
109+
model = model_cls.from_config(config_cls)
111110
else:
112-
model = model_cls.from_pretrained(
113-
model_path, trust_remote_code=trust_remote_code
114-
)
111+
model = model_cls.from_pretrained(model_path)
115112

116113
return BACKEND.move_model(model, device_map)
117114

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
# This file is part of the Apache IoTDB project.
19+
#
20+
# This file includes code modified from the uni2ts project (https://github.com/salesforce/uni2ts).
21+
# The original code is licensed under the Apache License 2.0.
22+
#
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
from typing import Optional
20+
21+
import numpy as np
22+
import torch
23+
from jaxtyping import Bool, Float, Int
24+
25+
# Lookup table mapping NumPy scalar types (plus the builtin ``bool``) to the
# corresponding torch dtypes; used when converting ndarray-backed data into
# tensors.
# NOTE(review): the boolean key is the builtin ``bool``, not ``np.bool_`` —
# a lookup keyed on an ndarray's ``dtype.type`` (which is ``np.bool_``) would
# miss; confirm callers normalize boolean dtypes before indexing.
numpy_to_torch_dtype_dict = {
    bool: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
38+
39+
40+
def packed_attention_mask(
    sample_id: torch.Tensor,
) -> torch.Tensor:
    """Build a block-diagonal self-attention mask for packed sequences.

    Two positions may attend to each other exactly when they carry the same
    sample id, i.e. belong to the same packed sequence.

    Args:
        sample_id: integer tensor of shape (*batch, seq_len).

    Returns:
        Boolean tensor of shape (*batch, seq_len, seq_len).
    """
    ids = sample_id.unsqueeze(-1)
    return ids.eq(ids.mT)
46+
47+
48+
def packed_causal_attention_mask(
    sample_id: torch.Tensor,
    time_id: torch.Tensor,
) -> torch.Tensor:
    """Causal attention mask restricted to positions of the same packed sample.

    Position i may attend to position j only when both belong to the same
    sample and j's time id does not exceed i's (no peeking into the future).

    Args:
        sample_id: integer tensor of shape (*batch, seq_len).
        time_id: integer tensor of shape (*batch, seq_len).

    Returns:
        Boolean tensor of shape (*batch, seq_len, seq_len).
    """
    # Same-sample block mask (inlined packed_attention_mask).
    ids = sample_id.unsqueeze(-1)
    same_sample = ids.eq(ids.mT)
    # causal[i, j] is True when time_id[j] <= time_id[i].
    causal = time_id.unsqueeze(-2) <= time_id.unsqueeze(-1)
    return same_sample * causal
58+
59+
60+
def mask_fill(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    value: torch.Tensor,
) -> torch.Tensor:
    """Replace masked rows of ``tensor`` with ``value``.

    Args:
        tensor: float tensor of shape (*batch, dim).
        mask: boolean tensor of shape (*batch); True selects rows to replace.
        value: float tensor of shape (dim,), broadcast into masked rows.

    Returns:
        Tensor of shape (*batch, dim) — arithmetic blend, so it stays
        differentiable with respect to both ``tensor`` and ``value``.
    """
    selector = mask.unsqueeze(-1)
    kept = tensor * ~selector
    filled = value * selector
    return kept + filled
67+
68+
69+
def safe_div(
    numer: torch.Tensor,
    denom: torch.Tensor,
) -> torch.Tensor:
    """Elementwise division that never divides by zero.

    Zero entries in ``denom`` are replaced with 1 before dividing, so the
    result at those positions equals ``numer`` (typically 0 for masked sums)
    instead of inf/NaN.
    """
    guarded = torch.where(denom == 0, 1.0, denom)
    return numer / guarded
78+
79+
80+
def size_to_mask(
    max_size: int,
    sizes: torch.Tensor,
) -> torch.Tensor:
    """Validity mask for a padded trailing axis.

    Args:
        max_size: length of the padded axis.
        sizes: integer tensor of shape (*batch) with each sample's true length.

    Returns:
        Boolean tensor of shape (*batch, max_size); entry [.., i] is True
        when i < sizes[..].
    """
    positions = torch.arange(max_size, device=sizes.device)
    return positions < sizes.unsqueeze(-1)
86+
87+
88+
def fixed_size(
    value: torch.Tensor,
) -> torch.Tensor:
    """Per-sample size tensor assuming every sample uses the full last axis.

    Args:
        value: float tensor of shape (*batch, max_size).

    Returns:
        Long tensor of shape (*batch), every entry equal to max_size.
    """
    return torch.full_like(value[..., 0], value.shape[-1], dtype=torch.long)
93+
94+
95+
def sized_mean(
    value: Float[torch.Tensor, "*batch max_size"],
    sizes: Optional[Int[torch.Tensor, "*batch"]],
    dim: Optional[int | tuple[int, ...]] = None,
    keepdim: bool = False,
    size_keepdim: bool = False,
    correction: int = 0,
) -> Float[torch.Tensor, "..."]:
    """Mean over a padded trailing axis, counting only the first ``sizes`` entries.

    Padding positions (index >= sizes) are zeroed before summation and the
    divisor is the true number of valid elements (minus ``correction``), so
    padded values never influence the result. ``dim``/``keepdim`` apply to the
    reduction performed after the padded axis is collapsed.

    NOTE(review): ``sizes`` is annotated Optional but is dereferenced
    unconditionally below (``sizes.sum``) — passing None raises. Confirm
    callers always supply a tensor.
    """
    # Zero out positions past each sample's valid length.
    value = value * size_to_mask(value.shape[-1], sizes)
    div_val = safe_div(
        # Collapse the padded axis first, then reduce over the requested dims.
        value.sum(dim=-1).sum(dim, keepdim=keepdim),
        # Valid-element count; clamp so a corrected count can't go negative
        # (safe_div then maps a 0 count to division by 1, yielding 0).
        torch.clamp(sizes.sum(dim, keepdim=keepdim) - correction, min=0),
    )
    if size_keepdim:
        # Re-append a singleton axis in place of the reduced size axis.
        div_val = div_val.unsqueeze(-1)
    return div_val
111+
112+
113+
def masked_mean(
114+
value: Float[torch.Tensor, "..."],
115+
mask: Bool[torch.Tensor, "..."],
116+
dim: Optional[int | tuple[int, ...]] = None,
117+
keepdim: bool = False,
118+
correction: int = 0,
119+
) -> Float[torch.Tensor, "..."]:
120+
return safe_div(
121+
(value * mask).sum(dim=dim, keepdim=keepdim),
122+
torch.clamp(mask.float().sum(dim, keepdim=keepdim) - correction, min=0),
123+
)
124+
125+
126+
def unsqueeze_trailing_dims(x: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Append singleton dims to ``x`` so it broadcasts against ``shape``.

    Args:
        x: tensor whose shape must be a prefix of ``shape``.
        shape: target shape to broadcast against.

    Returns:
        View of ``x`` with shape ``x.shape + (1,) * (len(shape) - x.ndim)``.

    Raises:
        ValueError: if ``x.shape`` is not a prefix of ``shape``.
    """
    if x.ndim > len(shape) or x.shape != shape[: x.ndim]:
        # Original raised a bare ValueError; include context for debugging.
        raise ValueError(
            f"x.shape={tuple(x.shape)} is not a prefix of target shape={tuple(shape)}"
        )
    dim = (...,) + (None,) * (len(shape) - x.ndim)
    return x[dim]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
from typing import List, Tuple
20+
21+
from transformers import PretrainedConfig
22+
23+
24+
class Moirai2Config(PretrainedConfig):
25+
model_type = "moirai2"
26+
27+
def __init__(
28+
self,
29+
d_model: int = 384,
30+
d_ff: int = 1024,
31+
num_layers: int = 6,
32+
patch_size: int = 16,
33+
max_seq_len: int = 512,
34+
attn_dropout_p: float = 0.0,
35+
dropout_p: float = 0.0,
36+
scaling: bool = True,
37+
num_predict_token: int = 4,
38+
quantile_levels: Tuple[float, ...] = (
39+
0.1,
40+
0.2,
41+
0.3,
42+
0.4,
43+
0.5,
44+
0.6,
45+
0.7,
46+
0.8,
47+
0.9,
48+
),
49+
**kwargs,
50+
):
51+
self.d_model = d_model
52+
self.d_ff = d_ff
53+
self.num_layers = num_layers
54+
self.patch_size = patch_size
55+
self.max_seq_len = max_seq_len
56+
self.attn_dropout_p = attn_dropout_p
57+
self.dropout_p = dropout_p
58+
self.scaling = scaling
59+
self.num_predict_token = num_predict_token
60+
self.quantile_levels = quantile_levels
61+
super().__init__(**kwargs)

0 commit comments

Comments (0)