|
1 | 1 | package com.google.showcase.v1beta1.it; |
2 | 2 |
|
3 | | -import static com.google.common.truth.Truth.assertThat; |
4 | | - |
5 | 3 | import org.junit.jupiter.api.Test; |
6 | 4 | import org.tensorflow.Graph; |
7 | 5 | import org.tensorflow.Session; |
|
11 | 9 | import org.tensorflow.proto.GraphDef; |
12 | 10 | import org.tensorflow.types.TInt32; |
13 | 11 |
|
14 | | -// Tensorflow depends on protobuf 3.x gen code and runtime, we test it in showcase module to prove |
15 | | -// that it works with |
16 | | -// protobuf 4.33+ gen code and runtime that comes with client libraries. |
| 12 | +import static com.google.common.truth.Truth.assertThat; |
| 13 | + |
| 14 | +/** |
| 15 | + * Tensorflow depends on protobuf 3.x gen code and runtime, we test it in showcase module to prove that it works with |
| 16 | + * protobuf 4.33+ gen code and runtime that comes with client libraries. |
| 17 | + */ |
17 | 18 | public class ITProtobuf3Compatibility { |
18 | 19 |
|
19 | | - @Test |
20 | | - void testTensorflow_helloWorldExample() { |
21 | | - try (Graph graph = new Graph()) { |
22 | | - // Hello world example for "10 + 32" operation. |
23 | | - Ops tf = Ops.create(graph); |
| 20 | + @Test |
| 21 | + void testTensorflow_helloWorldExample() { |
| 22 | + try (Graph graph = new Graph()) { |
| 23 | + // Hello world example for "10 + 32" operation. |
| 24 | + Ops tf = Ops.create(graph); |
24 | 25 |
|
25 | | - Constant<TInt32> expectedValue1 = tf.constant(10); |
26 | | - Constant<TInt32> expectedValue2 = tf.constant(32); |
| 26 | + int expectedValue1 = 10; |
| 27 | + int expectedValue2 = 32; |
| 28 | + int expectedSum = 42; |
27 | 29 |
|
28 | | - Add<TInt32> sum = tf.math.add(expectedValue1, expectedValue2); |
| 30 | + String name1 = "constant1"; |
| 31 | + String name2 = "constant2"; |
29 | 32 |
|
30 | | - try (Session s = new Session(graph)) { |
31 | | - try (TInt32 result = (TInt32) s.runner().fetch(sum).run().get(0)) { |
32 | | - System.out.println("10 + 32 = " + result.getInt()); |
33 | | - } |
34 | | - } |
| 33 | + Constant<TInt32> constant1 = tf.withName(name1).constant(expectedValue1); |
| 34 | + Constant<TInt32> constant2 = tf.withName(name2).constant(expectedValue2); |
35 | 35 |
|
36 | | - // GraphDef is a protobuf gen code. |
37 | | - GraphDef graphDef = graph.toGraphDef(); |
| 36 | + Add<TInt32> sum = tf.math.add(constant1, constant2); |
38 | 37 |
|
39 | | - // Inspect the protobuf gen code |
40 | | - Integer actual1 = |
41 | | - graphDef.getNode(0).getAttrOrThrow("value").getTensor().getIntValList().get(0); |
42 | | - Integer actual2 = |
43 | | - graphDef.getNode(1).getAttrOrThrow("value").getTensor().getIntValList().get(0); |
| 38 | + try (Session s = new Session(graph)) { |
| 39 | + try (TInt32 result = (TInt32) s.runner().fetch(sum).run().get(0)) { |
| 40 | + int actualResult = result.getInt(); |
| 41 | + assertThat(actualResult).isEqualTo(expectedSum); |
| 42 | + } |
| 43 | + } |
| 44 | + |
| 45 | + //GraphDef is a protobuf gen code. |
| 46 | + GraphDef graphDef = graph.toGraphDef(); |
| 47 | + |
| 48 | + //Inspect the protobuf gen code |
| 49 | + Integer actual1 = getValueFromGraphDefByName(graphDef, name1); |
| 50 | + Integer actual2 = getValueFromGraphDefByName(graphDef, name2); |
| 51 | + |
| 52 | + assertThat(actual1).isEqualTo(expectedValue1); |
| 53 | + assertThat(actual2).isEqualTo(expectedValue2); |
| 54 | + } |
| 55 | + } |
44 | 56 |
|
45 | | - assertThat(actual1).isEqualTo(expectedValue1); |
46 | | - assertThat(actual2).isEqualTo(expectedValue2); |
| 57 | + private static Integer getValueFromGraphDefByName(GraphDef graphDef, String name1) { |
| 58 | + return graphDef.getNodeList() |
| 59 | + .stream() |
| 60 | + .filter(nodeDef -> nodeDef.getName().equals(name1)) |
| 61 | + .findFirst() |
| 62 | + .get() |
| 63 | + .getAttrOrThrow("value") |
| 64 | + .getTensor().getIntValList().get(0); |
47 | 65 | } |
48 | | - } |
49 | 66 | } |
0 commit comments