!pip install -q --upgrade jax jaxlib
from __future__ import print_function, division
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
alexbw@, mattjj@
JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry-pick for your own work, starting with the basics.
grad
You can differentiate a function with grad:
grad_tanh = grad(np.tanh)
print(grad_tanh(2.0))
0.070650816
grad takes a function and returns a function. If you have a Python function f that evaluates the mathematical function $f$, then grad(f) is a Python function that evaluates the mathematical function $\nabla f$. That means grad(f)(x) represents the value $\nabla f(x)$.
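As a quick sanity check on the "grad(f)(x) is $\nabla f(x)$" reading, here is a small sketch comparing jax.grad(jnp.tanh) with the closed-form derivative $1 - \tanh^2(x)$. It assumes a recent JAX release and uses the conventional jnp alias for jax.numpy (the notebook imports the same module as np).

```python
import jax
import jax.numpy as jnp

# jax.grad(jnp.tanh) is a new Python function that evaluates the derivative of tanh.
dtanh = jax.grad(jnp.tanh)

x = 2.0
autodiff_value = dtanh(x)                 # derivative evaluated at x
analytic_value = 1.0 - jnp.tanh(x) ** 2   # closed form: tanh'(x) = 1 - tanh(x)^2

print(autodiff_value, analytic_value)
```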
Since grad operates on functions, you can apply it to its own output to differentiate as many times as you like:
print(grad(grad(np.tanh))(2.0))
print(grad(grad(grad(np.tanh)))(2.0))
-0.13621867
0.25265405
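Stacking grad calls can be checked the same way. The sketch below uses a hypothetical helper, nth_derivative, that simply applies jax.grad n times, and compares the second and third derivatives of tanh at 2.0 against their closed forms written in terms of $t = \tanh(x)$.

```python
import jax
import jax.numpy as jnp

def nth_derivative(f, n):
    """Hypothetical helper: apply jax.grad to a scalar-to-scalar f, n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f

x = 2.0
t = jnp.tanh(x)
second = nth_derivative(jnp.tanh, 2)(x)
third = nth_derivative(jnp.tanh, 3)(x)

# Closed forms, with t = tanh(x):
#   tanh''(x)  = -2 t (1 - t^2)
#   tanh'''(x) = -2 (1 - t^2) (1 - 3 t^2)
second_closed = -2.0 * t * (1.0 - t ** 2)
third_closed = -2.0 * (1.0 - t ** 2) * (1.0 - 3.0 * t ** 2)

print(second, second_closed)
print(third, third_closed)
```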
Let's look at computing gradients with grad in a linear logistic regression model. First, the setup:
def sigmoid(x):
    return 0.5 * (np.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(np.dot(inputs, W) + b)
# Build a toy dataset.
inputs = np.array([[0.52, 1.12, 0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
Use the grad function with its argnums argument to differentiate a function with respect to positional arguments.
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965586 -0.8774649 -1.4901347 ]
W_grad [-0.16965586 -0.8774649 -1.4901347 ]
b_grad -0.2922725
W_grad [-0.16965586 -0.8774649 -1.4901347 ]
b_grad -0.2922725
This grad API has a direct correspondence to the excellent notation in Spivak's classic Calculus on Manifolds (1965), also used in Sussman and Wisdom's Structure and Interpretation of Classical Mechanics (2015) and their Functional Differential Geometry (2013). Both of Sussman and Wisdom's books are open-access. See in particular the "Prologue" section of Functional Differential Geometry for a defense of this notation.
Essentially, when using the argnums argument, if f is a Python function for evaluating the mathematical function $f$, then the Python expression grad(f, i) evaluates to a Python function for evaluating $\partial_i f$.
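To make the $\partial_i f$ reading concrete, here is a sketch with a toy function $f(x, y) = x^2 y$ (chosen only for illustration), whose partial derivatives are $\partial_0 f = 2xy$ and $\partial_1 f = x^2$. Passing a tuple to argnums returns a tuple of partials.

```python
import jax

def f(x, y):
    return x ** 2 * y                   # f(x, y) = x^2 y

d0_f = jax.grad(f, argnums=0)           # partial_0 f(x, y) = 2 x y
d1_f = jax.grad(f, argnums=1)           # partial_1 f(x, y) = x^2
both = jax.grad(f, argnums=(0, 1))      # tuple (partial_0 f, partial_1 f)

print(d0_f(3.0, 5.0), d1_f(3.0, 5.0), both(3.0, 5.0))
```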
Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'b': array(-0.2922725, dtype=float32), 'W': array([-0.16965586, -0.8774649 , -1.4901347 ], dtype=float32)}
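The nesting can be arbitrary. In the sketch below the parameter names and structure are made up for illustration; the point is that the gradient comes back with exactly the same nesting of dicts, tuples, and lists as the input.

```python
import jax
import jax.numpy as jnp

# Parameters as an arbitrarily nested mix of a dict, a tuple, and a list.
params = {'w': jnp.array([1.0, 2.0]),
          'pair': (jnp.array(3.0), [jnp.array(4.0)])}

def nested_loss(p):
    a = p['pair'][0]
    b = p['pair'][1][0]
    return jnp.sum(p['w'] ** 2) + a * b

grads = jax.grad(nested_loss)(params)
# Same nesting as params, with d/dw = 2w, d/da = b, d/db = a.
print(grads)
```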
You can register your own container types to work with not just grad but all the JAX transformations (jit, vmap, etc.).
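A minimal sketch of registering a custom container, assuming the jax.tree_util.register_pytree_node API found in recent JAX releases. The Params class is hypothetical: the registration tells JAX how to flatten it into leaves (its arrays) and how to rebuild it, after which grad returns a Params and jit accepts one as an argument.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class Params:
    """Hypothetical user-defined container for a weight vector and a bias."""
    def __init__(self, W, b):
        self.W = W
        self.b = b

# flatten: Params -> (children, aux_data); unflatten: (aux_data, children) -> Params
tree_util.register_pytree_node(
    Params,
    lambda p: ((p.W, p.b), None),
    lambda aux, children: Params(*children),
)

def custom_loss(p):
    return jnp.sum(p.W ** 2) + p.b ** 2

p = Params(jnp.array([1.0, -2.0]), jnp.array(0.5))
g = jax.grad(custom_loss)(p)           # gradient is itself a Params
print(type(g).__name__, g.W, g.b)
print(jax.jit(custom_loss)(p))         # jit accepts the container too
```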
value_and_grad
Another convenient function is value_and_grad for efficiently computing both a function's value and its gradient's value:
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519395
loss value 3.0519395
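A small sketch of value_and_grad on a toy function of our choosing, $f(x) = \sum_i \sin^2 x_i$, whose gradient is $\sin(2x)$ elementwise. The first output should match a plain call to the function and the second should match the closed-form gradient.

```python
import jax
import jax.numpy as jnp

def sum_sin_sq(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.1, 0.2, 0.3])
value, gradient = jax.value_and_grad(sum_sin_sq)(x)

# d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x)
print(value, gradient, jnp.sin(2.0 * x))
```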
A great thing about derivatives is that they're straightforward to check with finite differences:
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / np.sqrt(np.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', np.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.2922725
W_dirderiv_numerical -0.19550323
W_dirderiv_autodiff -0.19909078
JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
grad-of-grad
One thing we can do with higher-order grad is build a Hessian-vector product function. (Later on we'll write an even more efficient implementation that mixes both forward- and reverse-mode, but this one will use pure reverse-mode.)
A Hessian-vector product function can be useful in a truncated Newton Conjugate-Gradient algorithm for minimizing smooth convex functions, or for studying the curvature of neural network training objectives (e.g. 1, 2, 3, 4).
For a scalar-valued function $f : \mathbb{R}^n \to \mathbb{R}$, the Hessian at a point $x \in \mathbb{R}^n$ is written as $\partial^2 f(x)$. A Hessian-vector product function is then able to evaluate
$\qquad v \mapsto \partial^2 f(x) \cdot v$
for any $v \in \mathbb{R}^n$.
The trick is not to instantiate the full Hessian matrix: if $n$ is large, perhaps in the millions or billions in the context of neural networks, then that might be impossible to store.
Luckily, grad already gives us a way to write an efficient Hessian-vector product function. We just have to use the identity
$\qquad \partial^2 f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)$,
where $g(x) = \partial f(x) \cdot v$ is a new scalar-valued function that dots the gradient of $f$ at $x$ with the vector $v$. Notice that we're only ever differentiating scalar-valued functions of vector-valued arguments, which is exactly where we know grad is efficient.
In JAX code, we can just write this:
def hvp(f, x, v):
    # Differentiate x -> <grad f(x), v>, then evaluate that gradient at x.
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)
This example shows that you can freely use lexical closure, and JAX will never get perturbed or confused.
We'll check this implementation a few cells down, once we see how to compute dense Hessian matrices.
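If you'd like a check right away, here is a self-contained sketch that tests hvp on a quadratic $f(x) = \tfrac{1}{2} x^\mathsf{T} A x$ with symmetric $A$. Its Hessian is exactly $A$, so the Hessian-vector product should equal $A v$. The matrix and vectors are arbitrary illustrative values, and jax.hessian is assumed to be available (it is in recent JAX releases).

```python
import jax
import jax.numpy as jnp

def hvp(f, x, v):
    # Reverse-over-reverse: differentiate y -> <grad f(y), v>, evaluate at x.
    return jax.grad(lambda y: jnp.vdot(jax.grad(f)(y), v))(x)

A = jnp.array([[2.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 4.0]])   # symmetric

def quadratic(x):
    # f(x) = 1/2 x^T A x, whose Hessian is A when A is symmetric.
    return 0.5 * jnp.dot(x, A @ x)

x = jnp.array([1.0, -1.0, 2.0])
v = jnp.array([0.5, 0.0, -1.0])
print(hvp(quadratic, x, v), A @ v)
```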
jacfwd and jacrev
You can compute full Jacobian matrices using the jacfwd and jacrev functions:
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981753  0.12883775  0.08857596]
 [ 0.04015912 -0.0492862   0.0068453 ]
 [ 0.1218829   0.01406341 -0.30470726]
 [ 0.00140427 -0.00472519  0.00263776]]
jacrev result, with shape (4, 3)
[[ 0.05981753  0.12883775  0.08857595]
 [ 0.04015912 -0.0492862   0.00684531]
 [ 0.1218829   0.01406341 -0.30470726]
 [ 0.00140427 -0.00472519  0.00263776]]
These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
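To see the tall/wide distinction in shapes, the sketch below defines two toy functions (made up for illustration): one from $\mathbb{R}^2$ to $\mathbb{R}^6$ with a tall $6 \times 2$ Jacobian, and one from $\mathbb{R}^6$ to $\mathbb{R}^2$ with a wide $2 \times 6$ Jacobian. Both jacfwd and jacrev return the same matrices; what differs is the work, roughly one JVP per input dimension for jacfwd versus one VJP per output dimension for jacrev.

```python
import jax
import jax.numpy as jnp

def tall(x):   # R^2 -> R^6: Jacobian is 6 x 2 ("tall")
    return jnp.concatenate([x, x ** 2, jnp.sin(x)])

def wide(x):   # R^6 -> R^2: Jacobian is 2 x 6 ("wide")
    return jnp.stack([jnp.sum(x ** 2), jnp.sum(jnp.cos(x))])

x_small = jnp.array([0.3, -1.2])
x_big = jnp.linspace(-1.0, 1.0, 6)

# For `tall`, jacfwd needs 2 JVPs while jacrev would need 6 VJPs;
# for `wide`, jacrev needs 2 VJPs while jacfwd would need 6 JVPs.
J_tall_fwd = jax.jacfwd(tall)(x_small)
J_tall_rev = jax.jacrev(tall)(x_small)
J_wide_fwd = jax.jacfwd(wide)(x_big)
J_wide_rev = jax.jacrev(wide)(x_big)

print(J_tall_fwd.shape, J_wide_rev.shape)
```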
You can also use jacfwd and jacrev with container types:
def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)
Jacobian from b to logits is
[0.11503371 0.04563536 0.2343902  0.00189767]
Jacobian from W to logits is
[[ 0.05981753  0.12883775  0.08857595]
 [ 0.04015912 -0.0492862   0.00684531]
 [ 0.1218829   0.01406341 -0.30470726]
 [ 0.00140427 -0.00472519  0.00263776]]
For more details on forward- and reverse-mode, as well as how to implement jacfwd and jacrev as efficiently as possible, read on!
Using a composition of two of these functions gives us a way to compute dense Hessian matrices:
def hessian(f):
    return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285464  0.04922539  0.03384245]
  [ 0.04922538  0.10602392  0.07289143]
  [ 0.03384245  0.07289144  0.05011286]]
 [[-0.03195212  0.03921397 -0.00544638]
  [ 0.03921397 -0.04812624  0.0066842 ]
  [-0.00544638  0.0066842  -0.00092836]]
 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898178]]
 [[-0.00103521  0.00348334 -0.00194452]
  [ 0.00348334 -0.01172098  0.00654304]
  [-0.00194452  0.00654304 -0.00365254]]]
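This shape makes sense: if we start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ we expect to get the shapes $f(x) \in \mathbb{R}^m$ for the value of $f$ at $x$, $\partial f(x) \in \mathbb{R}^{m \times n}$ for the Jacobian matrix at $x$, $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$ for the Hessian at $x$, and so on. Here $f$ maps $W \in \mathbb{R}^3$ to four predictions, so $m = 4$ and $n = 3$, matching the (4, 3, 3) shape above.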
To implement hessian, we could have used jacrev(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function with a wide Jacobian (maybe like a loss function $f : \mathbb{R}^n \to \mathbb{R}$), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since $\nabla f : \mathbb{R}^n \to \mathbb{R}^n$), which is where forward-mode wins out.
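A sketch of the forward-over-reverse hessian on a toy function $f : \mathbb{R}^3 \to \mathbb{R}^4$, chosen so that some Hessian slices are easy to write down by hand (the last output $\sum_i x_i^3$ has Hessian $\mathrm{diag}(6x)$). It checks the $(m, n, n)$ shape, that slice, and agreement with reverse-over-reverse and with jax.hessian, which recent JAX releases provide as a built-in.

```python
import jax
import jax.numpy as jnp

def hessian(f):
    # Forward-over-reverse: jacrev for the (often wide) inner Jacobian,
    # jacfwd for the square outer Jacobian.
    return jax.jacfwd(jax.jacrev(f))

def f(x):  # R^3 -> R^4, a toy example
    return jnp.stack([x[0] * x[1], x[1] * x[2], x[0] ** 2, jnp.sum(x ** 3)])

x = jnp.array([1.0, 2.0, 3.0])
H = hessian(f)(x)
print(H.shape)      # (m, n, n) with m = 4, n = 3
print(H[3])         # Hessian of sum(x^3) is diag(6 x)
```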
JAX includes efficient and general implementations of both forward- and reverse-mode automatic differentiation. The familiar grad function is built on reverse-mode, but to explain the difference between the two modes, and when each can be useful, we need a bit of math background.
Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian matrix of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^m \times \mathbb{R}^n$:
$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$.
But we can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$):
$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$.
This map is called the pushforward map of $f$ at $x$. The Jacobian matrix is just the matrix for this linear map in a standard basis.
If we don't commit to one specific input point $x$, then we can think of the function $\partial f$ as first taking an input point and returning the Jacobian linear map at that input point:
$\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$.
In particular, we can uncurry things so that given input point $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, we get back an output tangent vector in $\mathbb{R}^m$. We call that mapping, from $(x, v)$ pairs to output tangent vectors, the Jacobian-vector product, and write it as
$\qquad (x, v) \mapsto \partial f(x) v$
Back in Python code, JAX's jvp function models this transformation. Given a Python function that evaluates $f$, JAX's jvp is a way to get a Python function for evaluating $(x, v) \mapsto (f(x), \partial f(x) v)$.
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
In terms of Haskell-like type signatures, we could write
jvp :: (a -> b) -> a -> T a -> (b, T b)
where we use T a to denote the type of the tangent space for a. In words, jvp takes as arguments a function of type a -> b, a value of type a, and a tangent vector value of type T a. It gives back a pair consisting of a value of type b and an output tangent vector of type T b.
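A small sketch of jvp against a dense Jacobian, using a toy function $\mathbb{R}^3 \to \mathbb{R}^2$ chosen for illustration. In actual JAX code the primal point and the tangent are each passed as a tuple (one entry per positional argument), so the call reads jvp(f, (x,), (v,)). The dense Jacobian is built only to check that the tangent output equals $\partial f(x) v$.

```python
import jax
import jax.numpy as jnp

def f(x):  # R^3 -> R^2, a toy example
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([0.5, 2.0, -1.0])   # primal point
v = jnp.array([1.0, 0.0, 3.0])    # tangent vector

y, jv = jax.jvp(f, (x,), (v,))    # (f(x), ∂f(x) v) in one forward pass
J = jax.jacfwd(f)(x)              # dense 2 x 3 Jacobian, for comparison only

print(y, jv, J @ v)
```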
The jvp-transformed function is evaluated much like the original function, but paired up with each primal value of type a it pushes along tangent values of type T a. For each primitive numerical operation that the original function would have applied, the jvp-transformed function executes a "JVP rule" for that primitive that both evaluates the primitive on the primals and applies the primitive's JVP at those primal values.
That evaluation strategy has some immediate implications about computational complexity: since we evaluate JVPs as we go, we don't need to store anything for later, and so the memory cost is independent of the depth of the computation. In addition, the FLOP cost of the jvp-transformed function is about 2x the cost of just evaluating the function. Put another way, for a fixed primal point $x$, we can evaluate $v \mapsto \partial f(x) \cdot v$ for about the same cost as evaluating $f$.
That memory complexity sounds pretty compelling! So why don't we see forward-mode very often in machine learning?
To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians.
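The column-at-a-time construction can be sketched directly: feed each one-hot basis vector through jvp and stack the resulting tangents as columns. The helper name jacobian_by_columns and the toy function are hypothetical; the helper makes $n$ JVP calls for an input in $\mathbb{R}^n$, which is why this approach is cheap for tall Jacobians and expensive for wide ones.

```python
import jax
import jax.numpy as jnp

def jacobian_by_columns(f, x):
    """Hypothetical helper: build the Jacobian of f at x with one JVP per column."""
    n = x.shape[0]
    basis = jnp.eye(n, dtype=x.dtype)
    # jvp on the i-th one-hot tangent returns the i-th column of ∂f(x).
    columns = [jax.jvp(f, (x,), (basis[i],))[1] for i in range(n)]
    return jnp.stack(columns, axis=1)   # shape (m, n)

def f(x):  # R^3 -> R^2, a toy example
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([0.5, 2.0, -1.0])
J_cols = jacobian_by_columns(f, x)
print(J_cols)
```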
If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.
To do better for functions like this, we just need to use reverse-mode.
Where forward-mode gives us back a function for evaluating Jacobian-vector products, which we can then use to build Jacobian matrices one column at a time, reverse-mode is a way to get back a function for evaluating vector-Jacobian products (equivalently Jacobian-transpose-vector products), which we can use to build Jacobian matrices one row at a time.
Let's again consider a function $f : \mathbb{R}^n \to \mathbb{R}^m$. Starting from our notation for JVPs, the notation for VJPs is pretty simple:
$\qquad (x, v) \mapsto v \partial f(x)$,
where $v$ is an element of the cotangent space of the codomain of $f$ at the point $f(x)$ (isomorphic to another copy of $\mathbb{R}^m$). When being rigorous, we should think of $v$ as a linear map $v : \mathbb{R}^m \to \mathbb{R}$, and when we write $v \partial f(x)$ we mean the function composition $v \circ \partial f(x)$. But in the common case we can identify $v$ with a vector in $\mathbb{R}^m$ and use the two almost interchangeably, just like we might sometimes flip between "column vectors" and "row vectors" without much comment.
With that identification, we can alternatively think of the linear part of a VJP as the transpose (or, for complex values, the conjugate transpose) of the linear part of a JVP:
$\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v$.
For a given point $x$, we can write the signature as
$\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n$.
The corresponding map on cotangent spaces is often called the pullback of $f$ at $x$. The key for our purposes is that it goes from something that looks like the output of $f$ to something that looks like the input of $f$, just like we might expect from a transposed linear function.
Switching from math back to Python, the JAX function vjp can take a Python function for evaluating $f$ and give us back a Python function for evaluating the VJP $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$.
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
In terms of Haskell-like type signatures, we could write
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
where we use CT a to denote the type for the cotangent space for a. In words, vjp takes as arguments a function of type a -> b and a point of type a, and gives back a pair consisting of a value of type b and a linear map of type CT b -> CT a.
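A matching sketch for vjp, on the same kind of toy function $\mathbb{R}^3 \to \mathbb{R}^2$ (made up for illustration): the returned function maps an output cotangent $u \in \mathbb{R}^2$ to $u^\mathsf{T} \partial f(x) \in \mathbb{R}^3$. In JAX the returned vjp function gives back a tuple with one entry per primal input, hence the one-element unpacking.

```python
import jax
import jax.numpy as jnp

def f(x):  # R^3 -> R^2, a toy example
    return jnp.stack([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([0.5, 2.0, -1.0])
u = jnp.array([1.0, -2.0])        # cotangent in the output space

y, f_vjp = jax.vjp(f, x)          # primal output plus the pullback
(ut_J,) = f_vjp(u)                # u^T ∂f(x), shape (3,)
J = jax.jacrev(f)(x)              # dense 2 x 3 Jacobian, for comparison only

print(ut_J, u @ J)
```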
This is great because it lets us build Jacobian matrices one row at a time, and the FLOP cost for evaluating $(x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))$ is only about twice the cost of evaluating $f$. In particular, if we want the gradient of a function $f : \mathbb{R}^n \to \mathbb{R}$, we can do it in just one call. That's how grad is efficient for gradient-based optimization, even for objectives like neural network training loss functions on millions or billions of parameters.
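For a scalar-valued function, a single VJP with the cotangent 1.0 recovers the whole gradient. The toy loss below is made up for illustration; the sketch checks that pulling back 1.0 through jax.vjp gives the same vector as jax.grad.

```python
import jax
import jax.numpy as jnp

def toy_loss(x):  # R^n -> R
    return jnp.sum(jnp.tanh(x) ** 2)

x = jnp.linspace(-1.0, 1.0, 5)
_, loss_vjp = jax.vjp(toy_loss, x)
(g_from_vjp,) = loss_vjp(jnp.ones(()))   # pull back the scalar cotangent 1.0

print(g_from_vjp, jax.grad(toy_loss)(x))
```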
There's a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation. Also, the implementation is traditionally more complex than that of forward-mode, though JAX has some tricks up its sleeve (that's a story for a future notebook!).
For more on how reverse-mode works, see this tutorial video from the Deep Learning Summer School in 2017.
vmap
Now that we have jvp and vjp transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's vmap transformation to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    # vjp_fun returns a 1-tuple (one entry per primal input), so take [0].
    return np.vstack([vjp_fun(mi)[0] for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs
key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
10 loops, best of 3: 146 ms per loop
Vmapped Matrix-Jacobian product
10 loops, best of 3: 6.85 ms per loop
def loop_jmp(f, x, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return np.vstack([jvp(f, (x,), (mi,))[1] for mi in M])
def vmap_jmp(f, x, M):
    _jvp = lambda s: jvp(f, (x,), (s,))[1]
    return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
10 loops, best of 3: 525 ms per loop
Vmapped Jacobian-Matrix product
10 loops, best of 3: 5.6 ms per loop
jacfwd and jacrev
Now that we've seen fast Jacobian-matrix and matrix-Jacobian products, it's not hard to guess how to write jacfwd and jacrev. We just use the same technique to push-forward or pull-back an entire standard basis (isomorphic to an identity matrix) at once.
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = vmap(vjp_fun, in_axes=0)(np.eye(len(y)))
        return J
    return jacfun
assert np.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(np.eye(len(x)))
        return np.transpose(Jt)
    return jacfun
assert np.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
Interestingly, Autograd couldn't do this. Our implementation of reverse-mode jacobian in Autograd had to pull back one vector at a time with an outer-loop map. Pushing one vector at a time through the computation is much less efficient than batching it all together with vmap.
Another thing that Autograd couldn't do is jit. Interestingly, no matter how much Python dynamism you use in the function being differentiated, we can always use jit on the linear part of the computation. For example:
def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return np.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(array(3.1415927, dtype=float32),)
JAX is great at complex numbers and differentiation. To support both holomorphic and non-holomorphic differentiation, JAX follows Autograd's convention for encoding complex derivatives.
Consider a complex-to-complex function $f: \mathbb{C} \to \mathbb{C}$ that we break down into its component real-to-real functions:
def f(z):
    x, y = real(z), imag(z)
    return u(x, y) + v(x, y) * 1j
That is, we've decomposed $f(z) = u(x, y) + v(x, y) i$ where $z = x + y i$. We define grad(f) to correspond to
def grad_f(z):
    x, y = real(z), imag(z)
    return grad(u, 0)(x, y) + grad(u, 1)(x, y) * 1j
In math symbols, that means we define $\partial f(z) \triangleq \partial_0 u(x, y) + \partial_1 u(x, y) i$. So we throw out $v$, ignoring the imaginary component function of $f$ entirely!
This convention covers three important cases:
1. If f evaluates a holomorphic function, then we get the usual complex derivative, since $\partial_0 u = \partial_1 v$ and $\partial_1 u = - \partial_0 v$.
2. If f evaluates the real-valued loss function of a complex parameter x, then we get a result that we can use in gradient-based optimization by taking steps in the direction of the conjugate of grad(f)(x).
3. If f evaluates a real-to-real function, but its implementation uses complex primitives internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions), then we get the same result that an implementation that only used real primitives would have given.
By throwing away v entirely, this convention does not handle the case where f evaluates a non-holomorphic function and you want to evaluate all of $\partial_0 u$, $\partial_1 u$, $\partial_0 v$, and $\partial_1 v$ at once. But in that case the answer would have to contain four real values, and so there's no way to express it as a single complex number.
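A sketch of the holomorphic case, assuming a recent JAX release: there, grad refuses complex-valued outputs unless you pass holomorphic=True, and the input must have a complex dtype. For $f = \sin$ the result should match the usual complex derivative $\cos z$.

```python
import jax
import jax.numpy as jnp

z = jnp.array(1.0 + 2.0j)   # complex64 scalar

# For a holomorphic f, grad(f, holomorphic=True) gives the usual derivative f'(z).
dsin = jax.grad(jnp.sin, holomorphic=True)(z)

print(dsin, jnp.cos(z))
```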
You should expect complex numbers to work everywhere in JAX. Here's differentiating through a Cholesky decomposition of a complex matrix:
A = np.array([[5.,    2.+3j,  5j],
              [2.-3j, 7.,     1.+7j],
              [-5j,   1.-7j,  12.]])
def f(X):
    L = np.linalg.cholesky(X)
    return np.sum((L - np.sin(L))**2)
grad(f)(A)
array([[ 1.6623291e+01+0.j       , -1.3631370e+00-5.6038527j, -1.8995690e+00+9.700885j ],
       [-1.3631370e+00+5.6038527j, -8.9385948e+00+0.j       , -5.1351528e+00-6.5743794j],
       [-1.8995690e+00-9.700885j , -5.1351528e+00+6.5743794j,  1.3204219e-02+0.j       ]], dtype=complex64)
For primitives' JVP rules, writing the primals as $z = a + bi$ and the tangents as $t = c + di$, we define the Jacobian-vector product $t \mapsto \partial f(z) \cdot t$ as
$t \mapsto \begin{matrix} \begin{bmatrix} 1 & 1 \end{bmatrix} \\ ~ \end{matrix} \begin{bmatrix} \partial_0 u(a, b) & -\partial_0 v(a, b) \\ - \partial_1 u(a, b) i & \partial_1 v(a, b) i \end{bmatrix} \begin{bmatrix} c \\ d \end{bmatrix}$.
See Chapter 4 of Dougal's PhD thesis for more details.
In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.
There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an "Advanced Autodiff Cookbook", include: