Skip to content

Commit fb23076

Browse files
committed
pytorch: remove manual install of transformers
We now install a sufficiently recent (in fact newer) version of transformers globally, so this manual install is no longer needed.
1 parent 2b2046a commit fb23076

2 files changed

Lines changed: 4 additions & 56 deletions

File tree

ML-Frameworks/pytorch-aarch64/examples/gen_ai_utils/setup_local_packages.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

ML-Frameworks/pytorch-aarch64/examples/transformers_llm_text_gen.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,6 @@
3737
from torchao.quantization.granularity import PerGroup, PerAxis
3838
from torchao.quantization.quant_primitives import MappingType
3939

40-
# This script requires a fairly recent version of transformers
41-
42-
gen_ai_utils = os.getcwd() + "/gen_ai_utils/"
43-
local_packages = gen_ai_utils + "/genai_local_packages"
44-
install_script = gen_ai_utils + "/setup_local_packages.py"
45-
subprocess.run([sys.executable, install_script, local_packages, "transformers==4.47.1"])
46-
sys.path.insert(0, local_packages)
47-
4840
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, TextStreamer
4941
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler
5042

@@ -95,7 +87,7 @@ def load_model_components(model_folder_path):
9587

9688

9789
def get_quantized_model(args):
98-
model_name = f"{args.model.name}"
90+
model_name = f"{args.model}"
9991
print("Running model ", model_name)
10092
config, tokenizer, model = load_model_components(args.model)
10193
if model is None:
@@ -189,7 +181,7 @@ def eval_quantized_output(quantized_model, tokenizer, input_tensor, max_min_toke
189181

190182

191183
def main(args):
192-
name_string = f"{args.model.name}"
184+
name_string = f"{args.model}"
193185
quantized_model_, tokenizer_, config_ = get_quantized_model(args)
194186
input_tensor = tokenizer_.encode(args.prompt, return_tensors="pt")
195187
eval_quantized_output(
@@ -223,8 +215,8 @@ def main(args):
223215
)
224216
parser.add_argument(
225217
"--model",
226-
type=Path,
227-
default=Path("TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
218+
type=str,
219+
default="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
228220
help="Hugging Face model ID or Cloned model repository with model files",
229221
)
230222
parser.add_argument(

0 commit comments

Comments (0)