-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild.rs
More file actions
136 lines (111 loc) · 4 KB
/
build.rs
File metadata and controls
136 lines (111 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
// Copyright (c) 2025 Zixiao Han
// SPDX-License-Identifier: MIT
use std::collections::HashMap;
use std::fs;
use std::io::Write;
/// Number of neurons in the network's input layer. Fixed here; it is not read
/// from the config file.
const FIXED_INPUT_LAYER_SIZE: usize = 768;
/// Number of neurons in the network's output layer. Fixed here; it is not read
/// from the config file.
const FIXED_OUTPUT_LAYER_SIZE: usize = 1;

/// Build-script entry point.
///
/// Reads `config/network.cfg`, looks up its `hidden_layer_size` entry and hands
/// it to [`load_weights`], which generates the weights source module.
///
/// # Panics
///
/// Panics if the config file cannot be read, or if it has no
/// `hidden_layer_size` entry whose value parses as a `usize`.
fn main() {
    const CONFIG_PATH: &str = "config/network.cfg";

    // Declare every input explicitly. With no `rerun-if-changed` directive at
    // all, Cargo reruns the script whenever any file in the package changes --
    // including the `src/generated` file this script writes itself.
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-changed={}", CONFIG_PATH);
    // Watch the whole directory: the config value decides which weights file
    // is read, so any of them may become an input.
    println!("cargo:rerun-if-changed=resources/quantized_weights");

    let config_content = fs::read_to_string(CONFIG_PATH).expect("Failed to read config file");
    // Each `key = value` line whose value parses as `usize` becomes an entry.
    // Lines without `=` or with a non-numeric value are skipped. On duplicate
    // keys the last occurrence wins.
    let config: HashMap<String, usize> = config_content
        .lines()
        .filter_map(|line| {
            let (key, value) = line.split_once('=')?;
            let value = value.trim().parse().ok()?;
            Some((key.trim().to_string(), value))
        })
        .collect();
    let hidden_layer_size = config
        .get("hidden_layer_size")
        .copied()
        .expect("hidden_layer_size not found in config");
    load_weights(hidden_layer_size);
}
fn load_weights(hidden_layer_size: usize) {
let weights_file = &format!(
"resources/quantized_weights/{}.quantized_weights",
hidden_layer_size
);
let out_file = "src/generated/network_weights.rs";
let file_content = fs::read_to_string(weights_file).expect("Failed to read model weights file");
let values: Vec<f32> = file_content
.split(',')
.map(|s| s.trim().parse::<f32>().unwrap())
.collect();
let input_to_hidden_size = FIXED_INPUT_LAYER_SIZE * hidden_layer_size;
let hidden_biases_size = hidden_layer_size;
let hidden_to_output_size = hidden_layer_size * FIXED_OUTPUT_LAYER_SIZE;
let output_biases_size = FIXED_OUTPUT_LAYER_SIZE;
let expected_total_size =
input_to_hidden_size + hidden_biases_size + hidden_to_output_size + output_biases_size + 1;
assert_eq!(
values.len(),
expected_total_size,
"Unmatched weights size, expected {}, got {}",
expected_total_size,
values.len()
);
let mut start = 0;
let input_layer_to_hidden_layer_weights = &values[start..start + input_to_hidden_size];
start += input_to_hidden_size;
let hidden_layer_biases = &values[start..start + hidden_biases_size];
start += hidden_biases_size;
let hidden_layer_to_output_layer_weights = &values[start..start + hidden_to_output_size];
start += hidden_to_output_size;
let output_bias = values[start];
start += output_biases_size;
let scaling_factor = values[start];
let mut code = String::new();
code.push_str(&format!(
"pub const INPUT_LAYER_SIZE: usize = {};\n",
FIXED_INPUT_LAYER_SIZE
));
code.push_str(&format!(
"pub const HIDDEN_LAYER_SIZE: usize = {};\n",
hidden_layer_size
));
code.push_str(&format!(
"pub const INPUT_LAYER_TO_HIDDEN_LAYER_WEIGHTS: [i16; {}] = [\n",
input_to_hidden_size
));
for &value in input_layer_to_hidden_layer_weights {
code.push_str(&format!("{}, ", value));
}
code.push_str("];\n\n");
code.push_str(&format!(
"pub const HIDDEN_LAYER_BIASES: [f32; {}] = [\n",
hidden_biases_size
));
for &value in hidden_layer_biases {
code.push_str(&format!("{}, ", string_float(value)));
}
code.push_str("];\n\n");
code.push_str(&format!(
"pub const HIDDEN_LAYER_TO_OUTPUT_LAYER_WEIGHTS: [f32; {}] = [\n",
hidden_layer_size
));
for hidden_idx in 0..hidden_layer_size {
let value = hidden_layer_to_output_layer_weights[hidden_idx];
code.push_str(&format!("{}, ", string_float(value)));
}
code.push_str("];\n\n");
code.push_str(&format!(
"pub const OUTPUT_BIAS: f32 = {};\n\n",
string_float(output_bias)
));
code.push_str(&format!(
"pub const SCALING_FACTOR: f32 = {};\n",
string_float(scaling_factor)
));
let mut file = fs::File::create(out_file).expect("Failed to create weights source file");
file.write_all(code.as_bytes())
.expect("Failed to write weights source file");
}
/// Formats `value` as a Rust `f32` expression for use in generated source code.
///
/// Finite values use `Debug` formatting. It always yields a float literal
/// (`2.0`, `-0.0`, `0.5`, or exponent form such as `1e-7`) that parses back to
/// the same `f32`.
///
/// Plain `Display` would render integral values as `2`. That is an integer
/// literal, and an integer literal does not type-check where an `f32` is
/// expected.
///
/// Rust has no literal syntax for non-finite values. They are emitted as the
/// corresponding `f32` associated constants instead.
fn string_float(value: f32) -> String {
    if value.is_nan() {
        "f32::NAN".to_string()
    } else if value.is_infinite() {
        let name = if value > 0.0 {
            "f32::INFINITY"
        } else {
            "f32::NEG_INFINITY"
        };
        name.to_string()
    } else {
        format!("{:?}", value)
    }
}