Skip to content
10 changes: 10 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,13 @@ Project page: https://github.com/SalesforceAIResearch/uni2ts
License: https://github.com/SalesforceAIResearch/uni2ts/blob/main/LICENSE.txt

--------------------------------------------------------------------------------

The following files include code modified from PatchTST project.

./iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/*

PatchTST is open source software licensed under the Apache License 2.0
Project page: https://github.com/ibm-research/patchtst
License: https://github.com/ibm-research/patchtst/blob/main/LICENSE

--------------------------------------------------------------------------------
22 changes: 20 additions & 2 deletions iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
model_type: str = "",
pipeline_cls: str = "",
repo_id: str = "",
download_weights: bool = True,
auto_map: Optional[Dict] = None,
hub_mixin_cls: Optional[str] = None,
transformers_registered: bool = False,
Expand All @@ -40,9 +41,12 @@ def __init__(
self.state = state
self.pipeline_cls = pipeline_cls
self.repo_id = repo_id
self.auto_map = auto_map
self.download_weights = download_weights
self.auto_map = auto_map # If exists, indicates it's a Transformers model
self.hub_mixin_cls = hub_mixin_cls
self.transformers_registered = transformers_registered
self.transformers_registered = (
transformers_registered # Internal flag: whether registered to Transformers
)

def __repr__(self):
return (
Expand Down Expand Up @@ -173,4 +177,18 @@ def __repr__(self):
},
transformers_registered=True,
),
"patchtst_fm": ModelInfo(
model_id="patchtst_fm",
category=ModelCategory.BUILTIN,
state=ModelStates.INACTIVE,
model_type="patchtst_fm",
pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline",
repo_id="ibm-research/patchtst-fm-r1",
download_weights=False,
auto_map={
"AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig",
"AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction",
},
transformers_registered=True,
),
}
33 changes: 20 additions & 13 deletions iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,28 @@ def _process_builtin_model_directory(self, model_dir: str, model_id: str):

def _download_model_if_necessary() -> bool:
"""Returns: True if the model is existed or downloaded successfully, False otherwise."""
repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id
model_info = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id]
repo_id = model_info.repo_id
weights_path = os.path.join(model_dir, MODEL_SAFETENSORS)
config_path = os.path.join(model_dir, CONFIG_JSON)
if not os.path.exists(weights_path):
try:
hf_hub_download(
repo_id=repo_id,
filename=MODEL_SAFETENSORS,
local_dir=model_dir,
)
except Exception as e:
logger.error(
f"Failed to download model weights from HuggingFace: {e}"
)
return False

if getattr(model_info, "download_weights", True):
if not os.path.exists(weights_path):
try:
hf_hub_download(
repo_id=repo_id,
filename=MODEL_SAFETENSORS,
local_dir=model_dir,
)
except Exception as e:
logger.error(
f"Failed to download model weights from HuggingFace: {e}"
)
return False
else:
logger.info(
f"Skipping weight download for {model_id} due to configuration."
)
if not os.path.exists(config_path):
try:
hf_hub_download(
Expand Down
16 changes: 16 additions & 0 deletions iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading
Loading