Skip to content

Commit 93bfb05

Browse files
authored
Cache graph functions and add If gradient test (#637)
* Cache graph functions and add If gradient test * Expose cached graph function names
1 parent 1bbd7d9 commit 93bfb05

3 files changed

Lines changed: 336 additions & 0 deletions

File tree

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import java.util.Queue;
4444
import java.util.Set;
4545
import java.util.WeakHashMap;
46+
import java.util.concurrent.ConcurrentHashMap;
4647
import java.util.stream.Collectors;
4748
import org.bytedeco.javacpp.BytePointer;
4849
import org.bytedeco.javacpp.Pointer;
@@ -396,8 +397,21 @@ public GraphOperationBuilder opBuilder(String type, String name, Scope scope) {
396397
return new GraphOperationBuilder(this, type, name, scope, dangerousGradientBuilder);
397398
}
398399

400+
/**
401+
* Attaches a {@link ConcreteFunction} to this graph.
402+
*
403+
* <p>If a function with the same defined name has already been attached, this method returns
404+
* immediately without re-registering it.
405+
*
406+
* <p>The function is also stored in an internal cache to speed up subsequent lookups performed by
407+
* {@link #getFunction(String)}.
408+
*/
399409
@Override
400410
public void attachFunction(ConcreteFunction function) {
411+
String name = function.getDefinedName();
412+
if (functionCache.putIfAbsent(name, function) != null) {
413+
return;
414+
}
401415
try (Reference ref = ref();
402416
PointerScope scope = new PointerScope()) {
403417
TF_Status status = TF_Status.newStatus();
@@ -455,6 +469,10 @@ List<NativeFunction> getNativeFunctions(PointerScope outerScope) {
455469
* name
456470
*/
457471
public ConcreteFunction getFunction(String key) {
472+
ConcreteFunction cached = functionCache.get(key);
473+
if (cached != null) {
474+
return cached;
475+
}
458476
try (Reference ref = ref();
459477
PointerScope scope = new PointerScope()) {
460478
List<NativeFunction> funcs = getNativeFunctions(scope);
@@ -881,6 +899,33 @@ Set<Operation> initializers() {
881899
private final Set<Operation> initializers = Collections.synchronizedSet(new LinkedHashSet<>());
882900
private int newInitializersMarker = -1;
883901

902+
/**
903+
* Cache of {@link ConcreteFunction}s attached to this graph, indexed by their defined name.
904+
*
905+
* <p>This cache avoids repeatedly scanning the native function library when resolving functions
906+
* during gradient construction or control-flow expansion.
907+
*
908+
* <p>The cache is populated lazily when {@link #attachFunction(ConcreteFunction)} is called and
909+
* consulted first by {@link #getFunction(String)}.
910+
*
911+
* <p>A {@link ConcurrentHashMap} is used to allow concurrent reads during graph building without
912+
* additional synchronization.
913+
*/
914+
private final ConcurrentHashMap<String, ConcreteFunction> functionCache =
915+
new ConcurrentHashMap<>();
916+
917+
/**
918+
* Returns a read-only view of the function names cached by this graph.
919+
*
920+
* <p>This exposes only the function names so callers can resolve ambiguous matches themselves
921+
* before calling {@link #getFunction(String)} with an exact name.
922+
*
923+
* @return a read-only view of cached function names
924+
*/
925+
public Set<String> functionNames() {
926+
return Collections.unmodifiableSet(functionCache.keySet());
927+
}
928+
884929
/**
885930
* Use builders without locking. This should only be used during custom gradient building.
886931
*

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/AttributeMetadata.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public class AttributeMetadata {
3030

3131
/** The size of the list if this attribute is a list, undefined otherwise. */
3232
public final long listSize;
33+
3334
/**
3435
* The type of this attribute, or the type of the list values if it is a list.
3536
*
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
/*
2+
Copyright 2026 The TensorFlow Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
==============================================================================
16+
*/
17+
package org.tensorflow;
18+
19+
import java.util.ArrayList;
20+
import java.util.Iterator;
21+
import java.util.List;
22+
import java.util.stream.Collectors;
23+
import org.junit.jupiter.api.Test;
24+
import org.tensorflow.op.Ops;
25+
import org.tensorflow.op.core.Gradients;
26+
import org.tensorflow.op.core.Placeholder;
27+
import org.tensorflow.op.core.StatefulIf;
28+
import org.tensorflow.op.core.StatefulPartitionedCall;
29+
import org.tensorflow.op.core.StatelessIf;
30+
import org.tensorflow.types.TBool;
31+
import org.tensorflow.types.TFloat32;
32+
import org.tensorflow.types.TInt32;
33+
import org.tensorflow.types.family.TType;
34+
35+
public class IfGradientTest {
36+
37+
private static ConcreteFunction thenFn() {
38+
return ConcreteFunction.create(
39+
(Ops tf) -> {
40+
Placeholder<TFloat32> x = tf.placeholder(TFloat32.class);
41+
Operand<TFloat32> y = tf.math.mul(x, tf.constant(3.0f));
42+
return Signature.builder("thenBranch").input("x", x).output("y", y).build();
43+
});
44+
}
45+
46+
private static ConcreteFunction elseFn() {
47+
return ConcreteFunction.create(
48+
(Ops tf) -> {
49+
Placeholder<TFloat32> x = tf.placeholder(TFloat32.class);
50+
Operand<TFloat32> y = tf.math.mul(x, tf.constant(5.0f));
51+
return Signature.builder("elseBranch").input("x", x).output("y", y).build();
52+
});
53+
}
54+
55+
private static void assertClose(float got, float expected, float eps, String msg) {
56+
if (Math.abs(got - expected) > eps) {
57+
throw new AssertionError(msg + " (got=" + got + ", expected=" + expected + ")");
58+
}
59+
}
60+
61+
private static void primeIfGradFunctions(Graph g) {
62+
63+
Iterator<GraphOperation> operations = g.operations();
64+
while (operations.hasNext()) {
65+
GraphOperation op = operations.next();
66+
String type = op.type();
67+
if (!StatefulIf.OP_NAME.equals(type) && !StatelessIf.OP_NAME.equals(type)) continue;
68+
69+
ConcreteFunction thenFwd = op.attributes().getAttrFunction("then_branch");
70+
ConcreteFunction elseFwd = op.attributes().getAttrFunction("else_branch");
71+
72+
int nInputs = op.inputListLength("input");
73+
int nOut = op.numOutputs();
74+
75+
List<Class<? extends TType>> tin = new ArrayList<>(nInputs);
76+
for (int i = 0; i < nInputs; i++) {
77+
Class<? extends TType> c = op.input(1 + i).asOutput().type();
78+
tin.add(c);
79+
}
80+
81+
List<Class<? extends TType>> tout = new ArrayList<>(nOut);
82+
for (int i = 0; i < nOut; i++) {
83+
Class<? extends TType> c = op.output(i).type();
84+
tout.add(c);
85+
}
86+
87+
ConcreteFunction thenGrad = buildBranchGradFn(op.name() + "/then_grad", thenFwd, tin, tout);
88+
ConcreteFunction elseGrad = buildBranchGradFn(op.name() + "/else_grad", elseFwd, tin, tout);
89+
90+
g.attachFunction(thenGrad);
91+
g.attachFunction(elseGrad);
92+
}
93+
}
94+
95+
@SuppressWarnings({"rawtypes", "unchecked"})
96+
private static ConcreteFunction buildBranchGradFn(
97+
String prefix,
98+
ConcreteFunction branchFn,
99+
List<Class<? extends TType>> tin,
100+
List<Class<? extends TType>> toutForward) {
101+
102+
return ConcreteFunction.create(
103+
(Ops tf) -> {
104+
Signature.Builder sig = Signature.builder(prefix);
105+
106+
List<Operand<?>> x = new ArrayList<>(tin.size());
107+
for (int i = 0; i < tin.size(); i++) {
108+
Placeholder<? extends TType> ph = tf.placeholder((Class) tin.get(i));
109+
x.add(ph);
110+
sig.input("x" + i, ph);
111+
}
112+
113+
List<Operand<?>> dy = new ArrayList<>(toutForward.size());
114+
for (int i = 0; i < toutForward.size(); i++) {
115+
Placeholder<? extends TType> ph = tf.placeholder((Class) toutForward.get(i));
116+
dy.add(ph);
117+
sig.input("dy" + i, ph);
118+
}
119+
120+
StatefulPartitionedCall yCall =
121+
StatefulPartitionedCall.create(tf.scope(), x, toutForward, branchFn);
122+
123+
Operand<?> L = tf.constant(0.0f);
124+
for (int i = 0; i < toutForward.size(); i++) {
125+
Operand<?> prod = tf.math.mul((Operand) yCall.output().get(i), (Operand) dy.get(i));
126+
L = tf.math.add((Operand) L, (Operand) sumAll(tf, prod));
127+
}
128+
129+
Gradients g = tf.gradients((Iterable) List.of((Operand) L), x);
130+
131+
for (int i = 0; i < tin.size(); i++) {
132+
Operand<?> dx = g.dy(i);
133+
sig.output("dx" + i, dx);
134+
}
135+
136+
return sig.build();
137+
});
138+
}
139+
140+
@SuppressWarnings({"rawtypes", "unchecked"})
141+
private static Operand<?> sumAll(Ops tf, Operand<?> v) {
142+
Operand<TInt32> r = tf.rank(v);
143+
Operand<TInt32> axes = tf.range(tf.constant(0), r, tf.constant(1));
144+
return tf.reduceSum((Operand) v, axes);
145+
}
146+
147+
private static ConcreteFunction getSingleFunctionByPrefix(Graph graph, String prefix) {
148+
List<String> matches =
149+
graph.functionNames().stream()
150+
.filter(name -> name.startsWith(prefix))
151+
.collect(Collectors.toList());
152+
if (matches.size() != 1) {
153+
throw new IllegalStateException(
154+
"Expected one cached function for prefix=" + prefix + ", found=" + matches);
155+
}
156+
return graph.getFunction(matches.get(0));
157+
}
158+
159+
@Test
160+
public void testStatefullIfGradient() {
161+
TensorFlow.registerCustomGradient(
162+
StatefulIf.OP_NAME,
163+
(tf, op, gradOutputs) -> {
164+
OperationAttributeInspector attrs = op.attributes();
165+
ConcreteFunction thenBranch = attrs.getAttrFunction("then_branch");
166+
ConcreteFunction elseBranch = attrs.getAttrFunction("else_branch");
167+
168+
if (thenBranch == null || elseBranch == null) {
169+
int n = 1 + op.inputListLength("input");
170+
List<Operand<?>> no = new ArrayList<>(n);
171+
for (int i = 0; i < n; i++) {
172+
no.add(null);
173+
}
174+
return no;
175+
}
176+
177+
Operand<? extends TType> cond = op.input(0);
178+
int nInputs = op.inputListLength("input");
179+
List<Operand<?>> inputs = new ArrayList<>(nInputs);
180+
for (int i = 0; i < nInputs; i++) {
181+
inputs.add(op.input(1 + i));
182+
}
183+
184+
int nOut = op.numOutputs();
185+
List<Class<? extends TType>> toutForward = new ArrayList<>(nOut);
186+
for (int i = 0; i < nOut; i++) {
187+
toutForward.add(op.output(i).type());
188+
}
189+
190+
List<Class<? extends TType>> tin =
191+
inputs.stream().map(input -> input.asOutput().type()).collect(Collectors.toList());
192+
List<Operand<?>> dys = new ArrayList<>(nOut);
193+
for (int i = 0; i < nOut; i++) {
194+
Operand<?> dy = null;
195+
if (gradOutputs != null && i < gradOutputs.size()) {
196+
dy = gradOutputs.get(i);
197+
}
198+
if (dy == null) {
199+
dy =
200+
gradOutputs == null || gradOutputs.isEmpty()
201+
? tf.onesLike((Operand) op.output(i))
202+
: tf.zerosLike((Operand) op.output(i));
203+
}
204+
dys.add(dy);
205+
}
206+
207+
List<Operand<?>> input = new ArrayList<>(nInputs + nOut);
208+
input.addAll(inputs);
209+
input.addAll(dys);
210+
211+
final String thenPrefix = op.name() + "/then_grad"; // op has unique name
212+
final String elsePrefix = op.name() + "/else_grad";
213+
214+
ConcreteFunction thenGrad = getSingleFunctionByPrefix(op.env(), thenPrefix);
215+
ConcreteFunction elseGrad = getSingleFunctionByPrefix(op.env(), elsePrefix);
216+
217+
if (thenGrad == null || elseGrad == null) {
218+
throw new IllegalStateException("If grad functions not primed for op=" + op.name());
219+
}
220+
StatefulIf dInputsIf =
221+
StatefulIf.create(tf.scope(), cond, input, tin, thenGrad, elseGrad);
222+
List<Operand<?>> result = new ArrayList<>(1 + nInputs);
223+
result.add(null); // no gradient for condition
224+
result.addAll(dInputsIf.output());
225+
return result;
226+
});
227+
228+
Graph g = new Graph();
229+
Ops tf = Ops.create(g);
230+
231+
var x = tf.placeholder(TFloat32.class); // scalar
232+
var cond = tf.placeholder(TBool.class); // scalar
233+
234+
try (ConcreteFunction thenBranch = thenFn();
235+
ConcreteFunction elseBranch = elseFn()) {
236+
237+
StatefulIf ifOp =
238+
StatefulIf.create(
239+
tf.scope(),
240+
cond,
241+
List.of((Operand) x),
242+
List.of(TFloat32.class),
243+
thenBranch,
244+
elseBranch);
245+
246+
var y = ifOp.output().get(0);
247+
248+
primeIfGradFunctions(g);
249+
250+
var dy_dx = g.addGradients(y, new Output[] {x.asOutput()})[0];
251+
252+
try (Session session = new Session(g)) {
253+
254+
try (Result r =
255+
session
256+
.runner()
257+
.feed(x, TFloat32.scalarOf(2.0f))
258+
.feed(cond, TBool.scalarOf(true))
259+
.fetch(y)
260+
.fetch(dy_dx)
261+
.run()) {
262+
263+
float yVal = ((TFloat32) r.get(0)).getFloat();
264+
float gVal = ((TFloat32) r.get(1)).getFloat();
265+
266+
assertClose(yVal, 6.0f, 1e-6f, "y mismatch for cond=true");
267+
assertClose(gVal, 3.0f, 1e-6f, "grad mismatch for cond=true");
268+
}
269+
270+
// ---- cond=false
271+
try (Result r =
272+
session
273+
.runner()
274+
.feed(x, TFloat32.scalarOf(2.0f))
275+
.feed(cond, TBool.scalarOf(false))
276+
.fetch(y)
277+
.fetch(dy_dx)
278+
.run()) {
279+
280+
float yVal = ((TFloat32) r.get(0)).getFloat();
281+
float gVal = ((TFloat32) r.get(1)).getFloat();
282+
assertClose(yVal, 10.0f, 1e-6f, "y mismatch for cond=false");
283+
assertClose(gVal, 5.0f, 1e-6f, "grad mismatch for cond=false");
284+
}
285+
}
286+
;
287+
}
288+
}
289+
;
290+
}

0 commit comments

Comments
 (0)