This repository was archived by the owner on Mar 17, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 436
Expand file tree
/
Copy path keras_multiple_inputs_saved_model.rs
More file actions
54 lines (47 loc) · 2.01 KB
/
keras_multiple_inputs_saved_model.rs
File metadata and controls
54 lines (47 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
use tensorflow::{Graph, SavedModelBundle, SessionOptions, SessionRunArgs, Tensor};
/// Loads a two-input Keras model exported as a TensorFlow SavedModel, feeds
/// it two input tensors through the `serving_default` signature, and prints
/// the first element of the output tensor.
///
/// # Panics
///
/// Panics if the SavedModel cannot be loaded from `save_dir`, if the
/// `serving_default` signature or any of the named signature inputs/outputs
/// is missing, if a signature tensor has no matching graph operation, if the
/// session run fails, or if the output tensor cannot be fetched or is empty.
fn main() {
    // Signature parameter names as exported by the Python script that saved
    // the Keras model. Unlike the single-input example, no "_input" suffix is
    // appended to signature input names when the model has multiple inputs,
    // so the Keras input names ("test_in1", "test_in2") are used as-is.
    let signature_input_1_parameter_name = "test_in1";
    let signature_input_2_parameter_name = "test_in2";
    let signature_output_parameter_name = "test_out";
    let save_dir = "examples/keras_multiple_inputs_saved_model";

    // Two 1-D input tensors of five f32 values each.
    // NOTE(review): the shape must match what the exported model expects — confirm.
    let tensor1: Tensor<f32> = Tensor::from(&[0.1, 0.2, 0.3, 0.4, 0.5][..]);
    let tensor2: Tensor<f32> = Tensor::from(&[0.6, 0.7, 0.8, 0.9, 0.1][..]);

    // Loading the bundle tagged "serve" populates `graph` with the model's operations.
    let mut graph = Graph::new();
    let bundle = SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, save_dir)
        .expect("Can't load saved model");
    let session = &bundle.session;

    // Map the signature's logical parameter names to tensor info.
    let signature = bundle
        .meta_graph_def()
        .get_signature("serving_default")
        .expect("Can't find signature \"serving_default\" in saved model");
    let input_info1 = signature
        .get_input(signature_input_1_parameter_name)
        .expect("Can't find first input in signature");
    let input_info2 = signature
        .get_input(signature_input_2_parameter_name)
        .expect("Can't find second input in signature");
    let output_info = signature
        .get_output(signature_output_parameter_name)
        .expect("Can't find output in signature");

    // Resolve the tensor names to operations in the loaded graph.
    let input_op1 = graph
        .operation_by_name_required(&input_info1.name().name)
        .expect("Can't find graph operation for first input");
    let input_op2 = graph
        .operation_by_name_required(&input_info2.name().name)
        .expect("Can't find graph operation for second input");
    let output_op = graph
        .operation_by_name_required(&output_info.name().name)
        .expect("Can't find graph operation for output");

    // Feed both inputs (output index 0 of each op) and request output 0 of the output op.
    let mut args = SessionRunArgs::new();
    args.add_feed(&input_op1, 0, &tensor1);
    args.add_feed(&input_op2, 0, &tensor2);
    let out = args.request_fetch(&output_op, 0);

    // `expect` already appends the error's Debug representation, so the
    // message must not contain a format placeholder.
    session
        .run(&mut args)
        .expect("Error occurred during calculations");

    let output: Tensor<f32> = args.fetch(out).expect("Can't fetch output tensor");
    // Only the first element of the output tensor is reported.
    let out_res: f32 = *output.first().expect("Output tensor is empty");
    println!("Results: {:?}", out_res);
}