
Custom/gradients dispatch #632

Open
nfeybesse wants to merge 12 commits into tensorflow:master from nfeybesse:custom/gradients-dispatch

Conversation

@nfeybesse
Contributor

  1. Initial Problem

You wanted to register multiple custom gradients in Java using
TensorFlow.registerCustomGradient(...).

Observed symptom:

After registering a few gradients (≈ 5–10),

TFJ_RegisterCustomGradient(opType, adapter) received adapter_ptr = 0 on the C++ side,

which resulted in:

either a refusal to register the gradient,

or a SIGSEGV later during backpropagation.

Key observation:

If the “important” gradient was registered first, it worked.

Subsequent ones failed → this was a cumulative issue, not related to the specific op.

  2. Actual Root Cause

It was not:

a JNI signature bug,

an InfoMap issue,

nor a casting or ABI problem.

👉 The real cause was a limitation in JavaCPP FunctionPointer callbacks:

each TFJ_GradFuncAdapter allocates a native thunk,

after a certain number of such allocations, JavaCPP silently passes a null pointer (0),

the TensorFlow C++ runtime then receives an invalid callback pointer.

👉 Conclusion:
Creating one native callback per gradient is not scalable.

  3. Principle of the Definitive Fix

Instead of:

1 gradient = 1 native callback

We switched to:

1 single native callback

with dispatching in Java based on opType

This is exactly how TensorFlow Python does it on its C++ side.

  4. Final Architecture
    A. A Single Native Callback (Singleton)

A single TFJ_GradFuncAdapter instance

Registered with TensorFlow C++ for all ops

As a result:

no more adapter_ptr = 0

no practical limit on the number of custom gradients

B. Java-side Dispatch by opType

A Java dispatcher selects the correct gradient during backpropagation:

TensorFlow C++ → CustomGradFunc (C++) → TFJ_GradFuncAdapter.call(...) → DispatchingGradientAdapter.apply(...) → CustomGradient / RawCustomGradient for the corresponding op
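The single-callback-plus-dispatch idea can be sketched in plain Java. Everything below (class names, the string-based handler type) is an illustrative stand-in, not the actual TensorFlow Java API:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.UnaryOperator;

// Sketch of opType-based dispatch behind one shared entry point.
// REGISTRY and dispatch() are hypothetical stand-ins for the real
// DispatchingGradientAdapter internals.
public class DispatchSketch {

  // One registry consulted by the single callback on every invocation.
  static final Map<String, UnaryOperator<String>> REGISTRY = new ConcurrentHashMap<>();

  // Stand-in for the body of the one native callback: route by opType.
  static String dispatch(String opType, String gradInput) {
    UnaryOperator<String> handler = REGISTRY.get(opType);
    if (handler == null) {
      return null; // no registered gradient: treated as NoGradient upstream
    }
    return handler.apply(gradInput);
  }

  public static void main(String[] args) {
    // Registering many gradients only grows the map; the native side
    // still sees exactly one callback.
    REGISTRY.put("MyOp", g -> "grad(" + g + ")");
    System.out.println(dispatch("MyOp", "dy"));
  }
}
```

The key property is that adding a gradient never allocates a new native thunk, which is what sidesteps the JavaCPP limit.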

  5. Proper Handling of Visibility Constraints
    Problem

NativeScope and Ops have package-private constructors

They are only accessible from org.tensorflow.op

Solution

DispatchingGradientAdapter is package-private and lives in org.tensorflow.op

A public GradientDispatch class acts as a bridge

TensorFlow.java only sees the public TFJ_GradFuncAdapter type

➡️ This strictly respects TensorFlow Java’s internal design, with no hacks.
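The bridge shape described above can be sketched roughly as follows. In the real code the package-private class sits in org.tensorflow.op; here both classes share one file so the example stays self-contained, and all names are illustrative:

```java
// Package-private: only code in the same package can reference this type,
// mirroring DispatchingGradientAdapter's placement in org.tensorflow.op.
class PackagePrivateDispatcher {
  String route(String opType) {
    return "handled:" + opType;
  }
}

// Public bridge: the only type that outside code (the TensorFlow.java
// analogue) ever sees, mirroring the role of GradientDispatch.
public final class GradientBridge {

  private static final PackagePrivateDispatcher DISPATCHER = new PackagePrivateDispatcher();

  private GradientBridge() {}

  public static String dispatch(String opType) {
    return DISPATCHER.route(opType);
  }

  public static void main(String[] args) {
    System.out.println(dispatch("AddN"));
  }
}
```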

  6. Correct Support for “NoGradient”
    Problem

Returning null on the Java side caused a NullPointerException

The native code did not correctly support TF_Output.oper == nullptr

Fixes

Java side (AbstractGradientAdapter):

null is now translated into:

TF_Output { oper = nullptr, index = 0 }

C++ side (CustomGradFunc):

out.oper == nullptr is interpreted as NoGradient

No dangerous dereference

No crashes / no SIGSEGV
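A minimal Java sketch of the null-to-sentinel translation. The Out class below is a hypothetical stand-in for TF_Output; the real translation lives in AbstractGradientAdapter:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: translate null Java gradients into a (oper = null, index = 0)
// sentinel that the native side can read as NoGradient, instead of
// dereferencing a null operation and crashing.
public class NoGradientSketch {

  // Hypothetical stand-in for a TF_Output-like (operation, index) pair.
  static final class Out {
    final Object oper;
    final int index;

    Out(Object oper, int index) {
      this.oper = oper;
      this.index = index;
    }
  }

  static List<Out> toNativeOutputs(List<Object> javaGrads) {
    List<Out> outs = new ArrayList<>(javaGrads.size());
    for (Object grad : javaGrads) {
      // null gradient flows through as oper = null, which the native
      // side interprets as NoGradient; non-null keeps its operation
      outs.add(new Out(grad, 0));
    }
    return outs;
  }
}
```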

  7. Cleanup of the C++ Bridge (CustomGradFunc)

Applied corrections:

Removed a double loop that was adding gradients twice

Consistent handling of NoGradient

Single, safe memory deallocation (free(outputs))

Preserved defensive hardening:

checks on num_outputs

outputs == nullptr

etc.

  8. Final State
    What now works

✔ Registering dozens (or hundreds) of custom gradients

✔ Registration order no longer matters

✔ No more adapter_ptr = 0

✔ No JNI crashes / no SIGSEGV

✔ Proper support for partial gradients (NoGradient)

✔ Architecture aligned with native TensorFlow

What was avoided

❌ Fragile JavaCPP patches

❌ Dependency on internal allocation details

❌ Workarounds based on registration order

  9. In One Sentence

We replaced a non-scalable architecture (“N gradients = N native callbacks”) with a scalable one (“1 native callback + Java dispatch”), while properly fixing NoGradient handling and strictly respecting TensorFlow Java’s internal constraints.

unordered_map<string, TFJ_GradFuncAdapter> g_grad_func_adapters;

// Cast helper (inspired by TF C-API)
template <typename T, typename U>
Collaborator

Can you fix this diff to remove all the formatting changes so we can see just the functional changes to CustomGradFunc?

return false;
}

bool registered = GradOpRegistry::Global()->Register(op_type, CustomGradFunc);
Collaborator
Please fix the formatting to reduce the diff.

@Craigacp
Collaborator

This looks like a fairly complicated fix to work around a bug in JavaCPP? Is it not better to fix it there?

@nfeybesse
Contributor Author

Thanks for the question — it’s a fair concern.

This change is indeed a workaround for a limitation in JavaCPP (bytedeco/javacpp#1205), where multiple native callbacks of the same kind cannot be reliably registered and invoked. In practice, only the last registered gradient adapter survives, which makes it impossible to support more than one Java custom gradient per process.

Fixing this directly in JavaCPP would be ideal in theory, but in practice it is not a viable short- or medium-term option for TensorFlow Java:

The issue is deep in JavaCPP’s native callback and lifetime management.

TensorFlow Java depends on JavaCPP as an external project, and cannot reasonably block feature development or correctness fixes on changes there.

Even with a JavaCPP fix, TensorFlow Java would still need a stable, deterministic way to manage gradient dispatch per op type.

For these reasons, this PR follows the same architectural pattern already used by TensorFlow itself.

TensorFlow Python does not register one native callback per op.
Instead, it registers a single C++ gradient hook and performs runtime dispatch based on the op type (via the gradient registry). In other words, Python also uses a centralized dispatcher rather than relying on multiple independent native callbacks.

This PR mirrors that design on the Java side:

A single native CustomGradFunc is registered with TensorFlow.

That function dispatches to the appropriate Java gradient implementation based on op_type.

This avoids the JavaCPP limitation entirely, while matching TensorFlow’s own gradient architecture.

As a result, the solution is:

robust and deterministic,

consistent with TensorFlow’s Python design,

backward-compatible,

and does not require changes to JavaCPP or TensorFlow C++.

In short: while the root cause is a JavaCPP limitation, centralizing gradient dispatch is not a hack — it is the same model TensorFlow already uses, adapted to the Java runtime constraints.

@nfeybesse
Contributor Author

bytedeco/javacpp#648

@nfeybesse force-pushed the custom/gradients-dispatch branch from d7bc382 to 8d80312 on February 11, 2026 at 10:30

final String opType = operation.type();

RawCustomGradient rg = raw.get(opType);
Collaborator

This logic prefers raw gradients over typed ones, but there isn't anything documented about why it prefers them or if it makes sense to add both raw and typed gradients for the same op. It would be good to clarify this, and if it doesn't make sense to have both kinds of gradients the adapter should reject them in the puts.

Collaborator

Can you add javadoc to the top of this class noting the overall purpose of it (to provide Java side dispatching for gradients mirroring TF-Python), that it only accepts either raw or typed gradients for a given op, and that it rejects duplicate assignments.

…, remove inline ifs, add license headers and imports

- Prevent dual registration of raw and typed gradients for the same op type
- Use putIfAbsent and explicit exceptions to avoid silent overwrites
- Replace inline if statements in tfj_gradients_impl.cc with brace blocks
- Add Apache 2.0 headers to new files
- Replace fully-qualified GradientDispatch reference with import
@nfeybesse
Contributor Author

Thanks for the review!

I’ve pushed an update that:

  • Enforces mutual exclusion between raw and typed gradient registrations
  • Prevents silent overwrites via putIfAbsent
  • Replaces inline if statements with brace blocks in tfj_gradients_impl.cc
  • Adds the standard Apache 2.0 headers
  • Uses an import for GradientDispatch

Let me know if anything else should be adjusted.
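The mutual-exclusion and no-overwrite rules from this update can be sketched like so. Map values are plain Objects standing in for the real gradient types, and the names are illustrative; a real implementation would also need to make the cross-map check atomic:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Sketch: raw and typed gradients are mutually exclusive per op type,
// and duplicate registrations fail loudly via putIfAbsent rather than
// silently overwriting an earlier entry.
public class RegistrationSketch {

  static final Map<String, Object> RAW = new ConcurrentHashMap<>();
  static final Map<String, Object> TYPED = new ConcurrentHashMap<>();

  static void registerRaw(String opType, Object gradient) {
    if (TYPED.containsKey(opType)) {
      throw new IllegalArgumentException("typed gradient already registered for " + opType);
    }
    if (RAW.putIfAbsent(opType, gradient) != null) {
      throw new IllegalArgumentException("raw gradient already registered for " + opType);
    }
  }

  static void registerTyped(String opType, Object gradient) {
    if (RAW.containsKey(opType)) {
      throw new IllegalArgumentException("raw gradient already registered for " + opType);
    }
    if (TYPED.putIfAbsent(opType, gradient) != null) {
      throw new IllegalArgumentException("typed gradient already registered for " + opType);
    }
  }
}
```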

Document purpose as Java-side gradient dispatcher mirroring TF-Python,
clarify raw vs typed gradient registration contract, and note duplicate
registration rejection.
@Craigacp
Collaborator

I'm trying to figure out why it needs to make a fresh anonymous subclass of TFJ_GradFuncAdapter rather than returning the subclass it already has. While I was doing that I noticed that there isn't actually a test that the original bug is fixed, by adding more than 10 gradients to a session and checking they all exist. Can you add such a test?

@nfeybesse
Contributor Author

Good question.

The reason we return a fresh anonymous subclass of TFJ_GradFuncAdapter (instead of reusing the original subclass instance) is related to how FunctionPointer instances are managed on the JVM side and how their native address is captured.

TFJ_GradFuncAdapter (see

TFJ_GradFuncAdapter

) extends FunctionPointer. Each instance is associated with a native trampoline pointer allocated by JavaCPP. That native pointer is what TensorFlow C actually stores internally when we register the gradient.

If we were to reuse the same adapter instance:

We would share a single native function pointer across multiple gradient registrations.

The lifetime of that pointer would be tied to the original Java object.

In some scenarios (especially with multiple graphs or repeated registrations), this can lead to subtle issues:

pointer reuse across registrations,

unexpected deallocation / GC interaction,

or native-side bookkeeping assuming distinct callbacks per registration.

By returning a fresh anonymous subclass each time, we guarantee:

A distinct FunctionPointer instance.

A distinct native trampoline pointer.

No accidental sharing of callback state across gradient registrations.

In other words, the anonymous subclass is not about Java polymorphism — it's about forcing allocation of a fresh native callback binding, so the C side never sees a reused function pointer instance.

That said, if you think it's clearer, I can refactor the code to make that intent explicit (e.g. by adding a short comment explaining that we deliberately allocate a new FunctionPointer instance per registration to avoid native pointer reuse).

nfeybesse added a commit to nfeybesse/tensorflow that referenced this pull request Feb 24, 2026
…ence

This adds a regression test for PR tensorflow#632.

The test dynamically discovers op types with no registered gradient
(using TF_GetAllOpList + TensorFlow.hasGradient), registers 11 custom
gradients, and verifies that all are present in the native gradient
registry.

This directly validates that registering more than 10 gradients works
and that all entries are correctly stored in the native registry,
without relying on Graph.addGradients() execution.

Addresses reviewer comment about missing test for >10 gradients.
opType,
(tf, op, gradInputs) -> {
int n = op.numInputs();
java.util.ArrayList<org.tensorflow.Operand<?>> grads = new java.util.ArrayList<>(n);
Collaborator

ArrayList is imported, don't fully qualify it.

@Craigacp
Collaborator

I thought the point of this change was that there was only ever a single native function pointer, which points to the DispatchingGradientAdapter which then dispatches on the Java side to the gradient op?

@nfeybesse
Contributor Author

Yes, that’s correct — the intent of this change is that there is a single native function pointer registered, which points to the DispatchingGradientAdapter, and all gradient dispatching happens on the Java side based on opType.

The anonymous subclass is not meant to introduce multiple native callbacks. It is only used to ensure we get a properly allocated and strongly reachable FunctionPointer instance on the JavaCPP side, avoiding any lifecycle or pointer ownership issues. Conceptually, there is still just one native entry point, and all routing is done in Java.

In parallel, I’m also working on stabilizing and refining the gradients for branching operations (e.g., If, Switch, etc.), which rely heavily on this dispatching mechanism.

@nfeybesse force-pushed the custom/gradients-dispatch branch from b0dbea2 to 56539b9 on February 24, 2026 at 16:26
@Craigacp
Collaborator

Yes, that’s correct — the intent of this change is that there is a single native function pointer registered, which points to the DispatchingGradientAdapter, and all gradient dispatching happens on the Java side based on opType.

The anonymous subclass is not meant to introduce multiple native callbacks. It is only used to ensure we get a properly allocated and strongly reachable FunctionPointer instance on the JavaCPP side, avoiding any lifecycle or pointer ownership issues. Conceptually, there is still just one native entry point, and all routing is done in Java.

In parallel, I’m also working on stabilizing and refining the gradients for branching operations (e.g., If, Switch, etc.), which rely heavily on this dispatching mechanism.

That doesn't make any sense to me. If the logic is all Java side, then we don't need to worry about JavaCPP. It'll be reachable because it's stored in the ConcurrentHashMap inside DispatchingGradientAdapter. The new gradient adapters don't interact with the native code at all as the native code calls up into them.

Consequently the native code in tfj_gradients_impl can be simplified so it doesn't contain a map for the gradient adapters because that logic lives in Java now, it only needs a reference to the DispatchingGradientAdapter which can be set on the first call.

Replace the native per-op unordered_map of TFJ_GradFuncAdapter with a
single global dispatch adapter.

The native layer now registers CustomGradFunc per op type in the
GradOpRegistry, but always calls the same TFJ_GradFuncAdapter instance.
All opType-based routing is handled on the Java side by
DispatchingGradientAdapter.

This aligns the native implementation with the intended design:
there is only one native function pointer registered, and dispatch
logic lives entirely in Java.

Also fixes unsafe casting of Scope* to TFJ_Scope* by constructing a
temporary TFJ_Scope wrapper instead.
@Craigacp
Collaborator

Craigacp commented Mar 6, 2026

I think we can simplify this fix even further. Removing the adapter methods on CustomGradient and RawCustomGradient as those are no longer called allows us to remove TypedGradientAdapter and RawGradientAdapter completely. I tested this out in this branch - https://github.com/Craigacp/tensorflow-java/tree/minimal-fix, and the tests pass. Can you work this into your PR and we'll merge it?

@nfeybesse
Contributor Author

nfeybesse commented Mar 7, 2026

I’ve just tested this change and it seems to work well with my programs, including the custom gradient tests involving branching.

Programs that use my smartIf still require a small improvement in Graph.java. In particular, the concrete functions need to be cached before the gradient of the StatefulIf is executed. During gradient execution, we must avoid calling TFJ APIs that manipulate functions, since those parts of the API are locked at that stage (for example function registration, lookup, or any operation that may trigger function creation).

This approach more or less mirrors what is done on the Python side, where the concrete functions are prepared ahead of time before the gradient logic runs.

With this small adjustment in Graph.java, the test passes correctly. I’ve included the test and the corresponding diff below.

In my framework, smartIf is implemented in a fairly elegant way using lambdas, rather than requiring users to manually define concrete functions. Nested conditionals also work correctly, and I have tests covering that case as well. In addition, tests involving side effects behave as expected.

It would be great if this change could be included in the PR, as I would prefer not to lose support for StatefulIf.

Diff:

index 488434c56..4d3fa73c9 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java
@@ -43,6 +43,7 @@ import java.util.NoSuchElementException;
 import java.util.Queue;
 import java.util.Set;
 import java.util.WeakHashMap;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.stream.Collectors;
 import org.bytedeco.javacpp.BytePointer;
 import org.bytedeco.javacpp.Pointer;
@@ -398,6 +399,10 @@ public final class Graph implements ExecutionEnvironment, AutoCloseable {
 
   @Override
   public void attachFunction(ConcreteFunction function) {
+    String name = function.getDefinedName();
+    if (functionCache.putIfAbsent(name, function) != null) {
+      return;
+    }
     try (Reference ref = ref();
         PointerScope scope = new PointerScope()) {
       TF_Status status = TF_Status.newStatus();
@@ -455,6 +460,10 @@ public final class Graph implements ExecutionEnvironment, AutoCloseable {
    *     name
    */
   public ConcreteFunction getFunction(String key) {
+    ConcreteFunction cached = functionCache.get(key);
+    if (cached != null) {
+      return cached;
+    }
     try (Reference ref = ref();
         PointerScope scope = new PointerScope()) {
       List<NativeFunction> funcs = getNativeFunctions(scope);
@@ -881,6 +890,15 @@ public final class Graph implements ExecutionEnvironment, AutoCloseable {
   private final Set<Operation> initializers = Collections.synchronizedSet(new LinkedHashSet<>());
   private int newInitializersMarker = -1;
 
+  private final ConcurrentHashMap<String, ConcreteFunction> functionCache = new ConcurrentHashMap<>();
+
+  public ConcreteFunction getFunctionCached(String prefix) {
+    for (String name : functionCache.keySet()) {
+      if (name.startsWith(prefix)) {
+        return functionCache.get(name);
+      }
+    }
+    return null;
+  }
+
   /**
    * Use builders without locking. This should only be used during custom gradient building.
    *

/*
 Copyright 2026 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
 */
package org.tensorflow;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Graph;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Result;
import org.tensorflow.Signature;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Gradients;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.StatefulIf;
import org.tensorflow.op.core.StatefulPartitionedCall;
import org.tensorflow.op.core.StatelessIf;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TType;

public class IfGradientTest {

	private static ConcreteFunction thenFn() {
		return ConcreteFunction.create((Ops tf) -> {
			Placeholder<TFloat32> x = tf.placeholder(TFloat32.class);
			Operand<TFloat32> y = tf.math.mul(x, tf.constant(3.0f));
			return Signature.builder("thenBranch").input("x", x).output("y", y).build();
		});
	}

	private static ConcreteFunction elseFn() {
		return ConcreteFunction.create((Ops tf) -> {
			Placeholder<TFloat32> x = tf.placeholder(TFloat32.class);
			Operand<TFloat32> y = tf.math.mul(x, tf.constant(5.0f));
			return Signature.builder("elseBranch").input("x", x).output("y", y).build();
		});
	}

	private static void assertClose(float got, float expected, float eps, String msg) {
		if (Math.abs(got - expected) > eps) {
			throw new AssertionError(msg + " (got=" + got + ", expected=" + expected + ")");
		}
	}
	
	private static void primeIfGradFunctions(Graph g) {
		
		Iterator<GraphOperation> operations = g.operations();
		while (operations.hasNext()) {
			GraphOperation op = operations.next();
			String type = op.type();
			if (!StatefulIf.OP_NAME.equals(type) && !StatelessIf.OP_NAME.equals(type))
				continue;

			ConcreteFunction thenFwd = op.attributes().getAttrFunction("then_branch");
			ConcreteFunction elseFwd = op.attributes().getAttrFunction("else_branch");

			int nInputs = op.inputListLength("input");
			int nOut = op.numOutputs();

			List<Class<? extends TType>> tin = new ArrayList<>(nInputs);
			for (int i = 0; i < nInputs; i++) {
				Class<? extends TType> c = op.input(1 + i).asOutput().type();
				tin.add(c);
			}

			List<Class<? extends TType>> tout = new ArrayList<>(nOut);
			for (int i = 0; i < nOut; i++) {
				Class<? extends TType> c = op.output(i).type();
				tout.add(c);
			}

			ConcreteFunction thenGrad = buildBranchGradFn(op.name() + "/then_grad",thenFwd, tin, tout);
			ConcreteFunction elseGrad = buildBranchGradFn(op.name() + "/else_grad",elseFwd, tin, tout);

			g.attachFunction(thenGrad);
			g.attachFunction(elseGrad);
		}
	}

	@SuppressWarnings({ "rawtypes", "unchecked" })
	private static ConcreteFunction buildBranchGradFn(String prefix,ConcreteFunction branchFn, List<Class<? extends TType>> tin, List<Class<? extends TType>> toutForward) {

		return ConcreteFunction.create((Ops tf) -> {
			Signature.Builder sig = Signature.builder(prefix);

			List<Operand<?>> x = new ArrayList<>(tin.size());
			for (int i = 0; i < tin.size(); i++) {
				Placeholder<? extends TType> ph = tf.placeholder((Class) tin.get(i));
				x.add(ph);
				sig.input("x" + i, ph);
			}

			List<Operand<?>> dy = new ArrayList<>(toutForward.size());
			for (int i = 0; i < toutForward.size(); i++) {
				Placeholder<? extends TType> ph = tf.placeholder((Class) toutForward.get(i));
				dy.add(ph);
				sig.input("dy" + i, ph);
			}

			StatefulPartitionedCall yCall = StatefulPartitionedCall.create(tf.scope(), x, toutForward, branchFn);

			Operand<?> L = tf.constant(0.0f);
			for (int i = 0; i < toutForward.size(); i++) {
				Operand<?> prod = tf.math.mul((Operand) yCall.output().get(i), (Operand) dy.get(i));
				L = tf.math.add((Operand) L, (Operand) sumAll(tf, prod));
			}

			Gradients g = tf.gradients((Iterable) List.of((Operand) L), x);

			for (int i = 0; i < tin.size(); i++) {
				Operand<?> dx = g.dy(i);
				sig.output("dx" + i, dx);
			}

			return sig.build();
		});
	}

	@SuppressWarnings({ "rawtypes", "unchecked" })
	private static Operand<?> sumAll(Ops tf, Operand<?> v) {
		Operand<TInt32> r = tf.rank(v);
		Operand<TInt32> axes = tf.range(tf.constant(0), r, tf.constant(1));
		return tf.reduceSum((Operand) v, axes);
	}

	@Test
	public void testStatefulIfGradient() {
		TensorFlow.registerCustomGradient(StatefulIf.OP_NAME, (tf, op, gradOutputs) -> {
			
			OperationAttributeInspector attrs = op.attributes();
			ConcreteFunction thenBranch = attrs.getAttrFunction("then_branch");
			ConcreteFunction elseBranch = attrs.getAttrFunction("else_branch");

			if (thenBranch == null || elseBranch == null) {
				int n = 1 + op.inputListLength("input");
				List<Operand<?>> no = new ArrayList<>(n);
				for (int i = 0; i < n; i++) {
					no.add(null);
				}
				return no;
			}

			Operand<? extends TType> cond = op.input(0);
			int nInputs = op.inputListLength("input");
			List<Operand<?>> inputs = new ArrayList<>(nInputs);
			for (int i = 0; i < nInputs; i++) {
				inputs.add(op.input(1 + i));
			}

			int nOut = op.numOutputs();
			List<Class<? extends TType>> toutForward = new ArrayList<>(nOut);
			for (int i = 0; i < nOut; i++) {
				toutForward.add(op.output(i).type());
			}

			List<Class<? extends TType>> tin = inputs.stream().map(input->input.asOutput().type()).collect(Collectors.toList());
			List<Operand<?>> dys = new ArrayList<>(nOut);
			for (int i = 0; i < nOut; i++) {
				Operand<?> dy = null;
				if (gradOutputs != null && i < gradOutputs.size()) {
					dy = gradOutputs.get(i);
				}
				if (dy == null) {
					dy = gradOutputs == null || gradOutputs.isEmpty() ? tf.onesLike((Operand) op.output(i)) : tf.zerosLike((Operand) op.output(i));
				}
				dys.add(dy);
			}

			List<Operand<?>> input = new ArrayList<>(nInputs + nOut);
			input.addAll(inputs);
			input.addAll(dys);

			final String thenPrefix = op.name() + "/then_grad"; //op has unique name
			final String elsePrefix = op.name() + "/else_grad";

			ConcreteFunction thenGrad = op.env().getFunctionCached(thenPrefix);
			ConcreteFunction elseGrad = op.env().getFunctionCached(elsePrefix);

			if (thenGrad == null || elseGrad == null) {
				throw new IllegalStateException("If grad functions not primed for op=" + op.name());
			}
			StatefulIf dInputsIf = StatefulIf.create(tf.scope(), cond, input, tin, thenGrad, elseGrad);
			List<Operand<?>> result = new ArrayList<>(1 + nInputs);
			result.add(null); // no gradient for condition
			result.addAll( dInputsIf.output());
			return result;
		});

		Graph g = new Graph();
		Ops tf = Ops.create(g);

		var x = tf.placeholder(TFloat32.class); // scalar
		var cond = tf.placeholder(TBool.class); // scalar

		try (ConcreteFunction thenBranch = thenFn(); ConcreteFunction elseBranch = elseFn()) {

			StatefulIf ifOp = StatefulIf.create(tf.scope(), cond, List.of((Operand) x), List.of(TFloat32.class), thenBranch, elseBranch);

			var y = ifOp.output().get(0);

			primeIfGradFunctions(g);

			var dy_dx = g.addGradients(y, new Output[] {x.asOutput()})[0];
			
			try(Session session = new Session(g)){

				try (Result r =session.runner().feed(x, TFloat32.scalarOf(2.0f)).feed(cond, TBool.scalarOf(true)).fetch(y).fetch(dy_dx).run()) {

					float yVal = ((TFloat32) r.get(0)).getFloat();
					float gVal = ((TFloat32) r.get(1)).getFloat();

					assertClose(yVal, 6.0f, 1e-6f, "y mismatch for cond=true");
					assertClose(gVal, 3.0f, 1e-6f, "grad mismatch for cond=true");
				}

				// ---- cond=false
				try (Result r = session.runner().feed(x, TFloat32.scalarOf(2.0f)).feed(cond, TBool.scalarOf(false)).fetch(y).fetch(dy_dx).run()) {

					float yVal = ((TFloat32) r.get(0)).getFloat();
					float gVal = ((TFloat32) r.get(1)).getFloat();
					assertClose(yVal, 10.0f, 1e-6f, "y mismatch for cond=false");
					assertClose(gVal, 5.0f, 1e-6f, "grad mismatch for cond=false");
				}
			}
		}
	}
}
