Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary
With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. It can differentiate through a large subset of Python’s features, including loops, ifs, recursion, and closures, and it can even take derivatives of derivatives of derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.
What’s new is that JAX uses XLA to compile and run your NumPy code on accelerators, like GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX even lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without having to leave Python.
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.21-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
from __future__ import print_function, division
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see the readme.
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
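To make the difference from NumPy concrete, here is a minimal sketch of explicit key handling. It assumes the common pattern of deriving fresh keys with random.split, which this notebook does not otherwise use; the variable names are illustrative only.

```python
from jax import random

key = random.PRNGKey(0)

# The same key always yields the same numbers: there is no hidden global state.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# For independent draws, split the key and use each subkey only once.
key, subkey1, subkey2 = random.split(key, 3)
c = random.normal(subkey1, (3,))
d = random.normal(subkey2, (3,))

print(a, b)  # two identical arrays
print(c, d)  # two different arrays
```

Reusing a key, as a and b do here, is usually a mistake in real code, because the "random" values are then perfectly correlated.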
Let's dive right in and multiply two big matrices.
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T).block_until_ready() # runs on the GPU
JAX NumPy functions work on regular NumPy arrays.
import numpy as onp # original CPU-backed NumPy
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()
That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using device_put.
from jax import device_put
x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()
The output of device_put still acts like an NDArray. By the way, the implementation of device_put is just device_put = jit(lambda x: x).
If you have a GPU (or TPU!) these calls run on the accelerator and have the potential to be much faster than on CPU.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit onp.dot(x, x.T)
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:
- jit, for speeding up your code
- grad, for taking derivatives
- vmap, for automatic vectorization or batching.
Let's go over these one by one. We'll also end up composing them in interesting ways.
Using jit to speed up functions
JAX runs transparently on the GPU (or CPU, if you don't have one, and TPU coming soon!). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the @jit decorator to compile multiple operations together using XLA. Let's try that.
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
We can speed it up with @jit, which compiles selu the first time it is called and reuses the cached compiled version on later calls.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
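The caching behavior can be observed directly. The sketch below is illustrative and uses a hypothetical function scale_and_shift with a Python side effect; it assumes that jit re-runs the Python body only when it sees a new input shape or dtype.

```python
import jax.numpy as np
from jax import jit

trace_log = []

def scale_and_shift(x):
    # Python side effects run only while JAX traces the function,
    # not when the cached compiled version is executed.
    trace_log.append(x.shape)
    return 2.0 * x + 1.0

fast = jit(scale_and_shift)

out1 = fast(np.array([0.0, 1.0, 2.0]))  # first call: traces and compiles
out2 = fast(np.array([1.0, 2.0, 3.0]))  # same shape and dtype: reuses the cache
out3 = fast(np.ones(5))                 # new shape: traces and compiles again

print(len(trace_log))  # number of times the Python body actually ran
```

This is also why timing a jitted function should exclude the first call, which pays the compilation cost.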
Taking derivatives with grad
In addition to evaluating numerical functions, we also want to transform them. One transformation is automatic differentiation. In JAX, just like in Autograd, you can compute gradients with the grad function.
def sum_logistic(x):
    return np.sum(1.0 / (1.0 + np.exp(-x)))
x_small = np.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
Let's verify with finite differences that our result is correct.
def first_finite_differences(f, x):
    eps = 1e-3
    return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                     for v in onp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
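As a second, independent check, the logistic function s(x) = 1 / (1 + exp(-x)) has the closed-form derivative s(x)(1 - s(x)). The sketch below compares that formula with the output of grad; the helper logistic_derivative is introduced here for illustration and is not part of the original notebook.

```python
import jax.numpy as np
from jax import grad

def sum_logistic(x):
    return np.sum(1.0 / (1.0 + np.exp(-x)))

def logistic_derivative(x):
    # Closed form: d/dx s(x) = s(x) * (1 - s(x)), applied elementwise.
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 - s)

x_small = np.arange(3.)
autodiff = grad(sum_logistic)(x_small)
analytic = logistic_derivative(x_small)

print(autodiff)
print(analytic)
```

The two results should agree to within floating-point rounding, whereas the finite-difference estimate carries an additional truncation error that depends on eps.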
Taking derivatives is as easy as calling grad. grad and jit compose and can be mixed arbitrarily. In the above example we took the derivative of sum_logistic directly, without jit. We can go further and interleave the two:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
For more advanced autodiff, you can use jax.vjp for reverse-mode vector-Jacobian products and jax.jvp for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:
from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
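To see these pieces in action, here is a small sketch that calls jax.jvp, jax.vjp, and the hessian helper on functions whose derivatives are known in closed form. The test functions f and cubic are made up for this illustration.

```python
import jax.numpy as np
from jax import jit, jvp, vjp, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

def f(x):
    # Elementwise function, so its Jacobian is diagonal:
    # d/dx (x * sin(x)) = sin(x) + x * cos(x)
    return x * np.sin(x)

x = np.array([0.5, 1.0, 2.0])
v = np.array([1.0, 0.0, -1.0])

# Forward mode: push the tangent v through f to get J @ v.
y, jv = jvp(f, (x,), (v,))

# Reverse mode: pull the cotangent v back through f to get v @ J.
y_again, f_vjp = vjp(f, x)
(vj,) = f_vjp(v)

jac_diag = np.sin(x) + x * np.cos(x)
print(jv, vj, jac_diag * v)  # all three agree because J is diagonal

def cubic(x):
    return np.sum(x ** 3)

# The Hessian of sum(x**3) is diag(6 * x).
H = hessian(cubic)(x)
print(H)
```

Applying forward mode over reverse mode, as hessian does, is one reasonable choice for scalar-valued functions; other orderings such as jacrev(jacrev(fun)) also work but may differ in cost.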
Auto-vectorization with vmap
JAX has one more transformation in its API that you might find useful: vmap, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance. When composed with jit, it can be just as fast as adding the batch dimensions by hand.
We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using vmap. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))
def apply_matrix(v):
    return np.dot(mat, v)
Given a function such as apply_matrix, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.
def naively_batched_apply_matrix(v_batched):
    return np.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
We know how to batch this operation manually. In this case, np.dot handles extra batch dimensions transparently.
@jit
def batched_apply_matrix(v_batched):
    return np.dot(v_batched, mat.T)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
However, suppose we had a more complicated function without batching support. We can use vmap to add batching support automatically.
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Of course, vmap can be arbitrarily composed with jit, grad, and any other JAX transformation.
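One composition that comes up often is per-example gradients. The sketch below combines jit, vmap, and grad on a made-up squared-error loss; the loss function, the array shapes, and the in_axes choice are assumptions made for illustration.

```python
import jax.numpy as np
from jax import grad, jit, vmap, random

def loss(w, x, y):
    # Squared error for a single example (x, y) under a linear model.
    return (np.dot(w, x) - y) ** 2

key = random.PRNGKey(0)
k_w, k_x, k_y = random.split(key, 3)
w = random.normal(k_w, (4,))
xs = random.normal(k_x, (8, 4))  # batch of 8 inputs
ys = random.normal(k_y, (8,))    # batch of 8 targets

# grad differentiates with respect to w (the first argument);
# vmap maps over the batch axis of xs and ys while sharing w (in_axes=None).
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
g = per_example_grads(w, xs, ys)
print(g.shape)  # one gradient row per example

# Hand-derived gradient for comparison: 2 * (w . x - y) * x
manual = 2.0 * (np.dot(xs, w) - ys)[:, None] * xs
```

Writing loss for a single example and letting vmap add the batch dimension keeps the function easy to read, at no obvious cost compared with the hand-batched formula.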
This is just a taste of what JAX can do. We're really excited to see what you do with it!