Skip to content

Commit bb75be5

Browse files
committed
tests: Update per AI feedback.
1 parent 9e9b5c6 commit bb75be5

1 file changed

Lines changed: 45 additions & 28 deletions

File tree

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package com.google.showcase.v1beta1.it;
22

3-
import static com.google.common.truth.Truth.assertThat;
4-
53
import org.junit.jupiter.api.Test;
64
import org.tensorflow.Graph;
75
import org.tensorflow.Session;
@@ -11,39 +9,58 @@
119
import org.tensorflow.proto.GraphDef;
1210
import org.tensorflow.types.TInt32;
1311

14-
// Tensorflow depends on protobuf 3.x gen code and runtime, we test it in showcase module to prove
15-
// that it works with
16-
// protobuf 4.33+ gen code and runtime that comes with client libraries.
12+
import static com.google.common.truth.Truth.assertThat;
13+
14+
/**
15+
* Tensorflow depends on protobuf 3.x gen code and runtime, we test it in showcase module to prove that it works with
16+
* protobuf 4.33+ gen code and runtime that comes with client libraries.
17+
*/
1718
public class ITProtobuf3Compatibility {
1819

19-
@Test
20-
void testTensorflow_helloWorldExample() {
21-
try (Graph graph = new Graph()) {
22-
// Hello world example for "10 + 32" operation.
23-
Ops tf = Ops.create(graph);
20+
@Test
21+
void testTensorflow_helloWorldExample() {
22+
try (Graph graph = new Graph()) {
23+
// Hello world example for "10 + 32" operation.
24+
Ops tf = Ops.create(graph);
2425

25-
Constant<TInt32> expectedValue1 = tf.constant(10);
26-
Constant<TInt32> expectedValue2 = tf.constant(32);
26+
int expectedValue1 = 10;
27+
int expectedValue2 = 32;
28+
int expectedSum = 42;
2729

28-
Add<TInt32> sum = tf.math.add(expectedValue1, expectedValue2);
30+
String name1 = "constant1";
31+
String name2 = "constant2";
2932

30-
try (Session s = new Session(graph)) {
31-
try (TInt32 result = (TInt32) s.runner().fetch(sum).run().get(0)) {
32-
System.out.println("10 + 32 = " + result.getInt());
33-
}
34-
}
33+
Constant<TInt32> constant1 = tf.withName(name1).constant(expectedValue1);
34+
Constant<TInt32> constant2 = tf.withName(name2).constant(expectedValue2);
3535

36-
// GraphDef is a protobuf gen code.
37-
GraphDef graphDef = graph.toGraphDef();
36+
Add<TInt32> sum = tf.math.add(constant1, constant2);
3837

39-
// Inspect the protobuf gen code
40-
Integer actual1 =
41-
graphDef.getNode(0).getAttrOrThrow("value").getTensor().getIntValList().get(0);
42-
Integer actual2 =
43-
graphDef.getNode(1).getAttrOrThrow("value").getTensor().getIntValList().get(0);
38+
try (Session s = new Session(graph)) {
39+
try (TInt32 result = (TInt32) s.runner().fetch(sum).run().get(0)) {
40+
int actualResult = result.getInt();
41+
assertThat(actualResult).isEqualTo(expectedSum);
42+
}
43+
}
44+
45+
//GraphDef is a protobuf gen code.
46+
GraphDef graphDef = graph.toGraphDef();
47+
48+
//Inspect the protobuf gen code
49+
Integer actual1 = getValueFromGraphDefByName(graphDef, name1);
50+
Integer actual2 = getValueFromGraphDefByName(graphDef, name2);
51+
52+
assertThat(actual1).isEqualTo(expectedValue1);
53+
assertThat(actual2).isEqualTo(expectedValue2);
54+
}
55+
}
4456

45-
assertThat(actual1).isEqualTo(expectedValue1);
46-
assertThat(actual2).isEqualTo(expectedValue2);
57+
private static Integer getValueFromGraphDefByName(GraphDef graphDef, String name1) {
58+
return graphDef.getNodeList()
59+
.stream()
60+
.filter(nodeDef -> nodeDef.getName().equals(name1))
61+
.findFirst()
62+
.get()
63+
.getAttrOrThrow("value")
64+
.getTensor().getIntValList().get(0);
4765
}
48-
}
4966
}

0 commit comments

Comments
 (0)