Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
levskaya@ mattjj@
When walking about the countryside of Italy, the people will not hesitate to tell you that JAX has "una anima di pura programmazione funzionale" ("a soul of pure functional programming").
JAX is a language for expressing and composing transformations of numerical programs. As such it needs to control the unwanted proliferation of side-effects in its programs so that analysis and transformation of its computations remain tractable!
This requires us to write code in a functional style with explicit descriptions of how the state of a program changes, which results in several important differences from how you might be used to programming in Numpy, Tensorflow or Pytorch.
Herein we try to cover the most frequent points of trouble that users encounter when starting out in JAX.
!pip install --upgrade -q git+https://github.com/google/jax.git
!pip install --upgrade -q jaxlib
import numpy as onp
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
In Numpy you're used to doing this:
numpy_array = onp.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)
# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
updated array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
If we try to update a JAX device array in-place, however, we get an error! (☉_☉)
jax_array = np.zeros((3,3), dtype=np.float32)
# In place update of JAX's array will yield an error!
jax_array[1, :] = 1.0
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-189-a717a200f584> in <module>()
      2 
      3 # In place update of JAX's array will yield an error!
----> 4 jax_array[1, :] = 1.0

TypeError: '_FilledConstant' object does not support item assignment
What gives?!
Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program.
Instead, JAX offers functional update functions, index_update and index_add, together with the index helper for specifying which elements to update.
NB: Fancy Indexing is not yet supported, but will likely be added to JAX soon.
⚠️ Inside jit'd code and lax.while_loop or lax.fori_loop, the size of slices can't be a function of argument values, only of argument shapes; the slice start indices have no such restriction. See the Control Flow section below for more information on this limitation.
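To make the semantics of these functional updates concrete, here is a minimal NumPy sketch. The helpers functional_update and functional_add are hypothetical names used only for illustration: they model what index_update and index_add return (a new array, with the input left untouched), not how JAX implements them. onp.add.at is used for the addition so that repeated indices accumulate, as a scatter-add would.

```python
import numpy as onp

def functional_update(arr, idx, value):
  """Hypothetical helper: like index_update, return a copy with arr[idx] = value."""
  out = onp.array(arr, copy=True)
  out[idx] = value
  return out

def functional_add(arr, idx, value):
  """Hypothetical helper: like index_add, return a copy with value added at arr[idx]."""
  out = onp.array(arr, copy=True)
  onp.add.at(out, idx, value)   # unbuffered add: repeated indices accumulate
  return out

orig = onp.zeros((3, 3))
updated = functional_update(orig, onp.s_[1, :], 1.)            # cf. index[1, :]
added = functional_add(onp.ones((5, 6)), onp.s_[::2, 3:], 7.)  # cf. index[::2, 3:]
print(orig)      # still all zeros: the input was never mutated
print(updated)
print(added)
```

Copying first is what makes the update "functional"; under jit, XLA is free to skip the copy when the input is not reused, as noted below.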
from jax.ops import index, index_add, index_update
If the input values of index_update aren't reused, jit-compiled code will perform these operations in-place.
jax_array = np.zeros((3, 3))
print("original array:")
print(jax_array)
new_jax_array = index_update(jax_array, index[1, :], 1.)
print("old array unchanged:")
print(jax_array)
print("new array:")
print(new_jax_array)
original array:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
old array unchanged:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
new array:
[[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]
If the input values of index_add aren't reused, jit-compiled code will perform these operations in-place.
print("original array:")
jax_array = np.ones((5, 6))
print(jax_array)
new_jax_array = index_add(jax_array, index[::2, 3:], 7.)
print("new array post-addition:")
print(new_jax_array)
original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]
"If all scientific papers whose results are in doubt because of bad rand()s were to disappear from library shelves, there would be a gap on each shelf about as big as your fist." - Numerical Recipes
You're used to stateful pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
print(onp.random.random())
print(onp.random.random())
print(onp.random.random())
0.7117779558041075
0.014396253746679077
0.7717174868106601
Underneath the hood, numpy uses the Mersenne Twister PRNG to power its pseudorandom functions. The PRNG has a period of $2^{19937}-1$ and at any point can be described by 624 32-bit unsigned ints and a position indicating how much of this "entropy" has been used up.
onp.random.seed(0)
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
# 2481403966, 4042607538, 337614300, ... 614 more numbers...,
# 3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
This pseudorandom state vector is automagically updated behind the scenes every time a random number is needed, "consuming" 2 of the uint32s in the Mersenne twister state vector:
_ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)
# Let's exhaust the entropy in this PRNG statevector
for i in range(311):
  _ = onp.random.uniform()
rng_state = onp.random.get_state()
#print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
# ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)
# Next call iterates the RNG state for a new batch of fake "entropy".
_ = onp.random.uniform()
rng_state = onp.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
# 4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)
The problem with magic PRNG state is that it's hard to reason about how it's being used and updated across different threads, processes, and devices, and it's very easy to screw up when the details of entropy production and consumption are hidden from the end user.
The Mersenne Twister PRNG is also known to have a number of problems: it has a large 2.5 kB state, which leads to problematic initialization issues; it fails modern BigCrush tests; and it is generally slow.
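The 2.5 kB figure follows from the state layout described above: 624 words of 32 bits each. A quick check with NumPy's legacy global RNG, whose get_state() returns the ('MT19937', keys, pos, has_gauss, cached_gaussian) tuple seen earlier:

```python
import numpy as onp

onp.random.seed(0)
name, mt_keys, pos, _, _ = onp.random.get_state()
state_bytes = mt_keys.size * mt_keys.itemsize   # 624 uint32 words * 4 bytes each
print(name, mt_keys.size, pos, state_bytes)     # MT19937 624 624 2496

_ = onp.random.uniform()                        # one double consumes two 32-bit words
pos_after = onp.random.get_state()[2]
print(pos_after)                                # 2
```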
JAX instead implements an explicit PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern Threefry counter-based PRNG that's splittable. That is, its design allows us to fork the PRNG state into new PRNGs for use with parallel stochastic generation.
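To illustrate what "counter-based and splittable" means, here is a toy sketch in pure Python. It uses SHA-256 from hashlib as a stand-in for the Threefry block cipher, and the functions toy_split and toy_uniform are hypothetical: this is not JAX's implementation and is not bit-compatible with it. The key is a pair of 32-bit words, and every output is a pure function of (key, counter), so there is no hidden state to mutate.

```python
import hashlib
import struct

def _hash_words(key, counter, n_words):
  """Deterministically map (key, counter) to n_words uint32s.

  SHA-256 stands in for Threefry here; both are pure functions of their input.
  """
  words, block = [], 0
  while len(words) < n_words:
    digest = hashlib.sha256(struct.pack("<4I", key[0], key[1], counter, block)).digest()
    words.extend(struct.unpack("<8I", digest))
    block += 1
  return words[:n_words]

def toy_split(key, num=2):
  """Derive `num` new keys from `key`; the parent key itself is never modified."""
  words = _hash_words(key, 0, 2 * num)
  return [(words[2 * i], words[2 * i + 1]) for i in range(num)]

def toy_uniform(key, n=1):
  """Return n floats in [0, 1) that depend only on `key` (counters 1..n)."""
  return [_hash_words(key, 1 + i, 1)[0] / 2**32 for i in range(n)]

toy_key = (0, 0)
toy_key, subkey = toy_split(toy_key)
print(toy_uniform(subkey, 3))
print(toy_uniform(subkey, 3))   # same subkey -> exactly the same numbers again
toy_key, subkey = toy_split(toy_key)
print(toy_uniform(subkey, 3))   # fresh subkey -> fresh numbers
```

Because splitting only hashes the parent key, independent streams can be derived on different devices without any coordination, which is the property the paragraph above refers to.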
The random state is described by two unsigned 32-bit integers that we call a key:
from jax import random
key = random.PRNGKey(0)
key
array([0, 0], dtype=uint32)
JAX's random functions produce pseudorandom numbers from the PRNG state, but do not change the state!
Reusing the same state will cause sadness and monotony, depriving the end user of life-giving chaos:
print(random.normal(key, shape=(1,)))
print(key)
# No no no!
print(random.normal(key, shape=(1,)))
print(key)
[-0.20584233]
[0 0]
[-0.20584233]
[0 0]
Instead, we split the PRNG to get usable subkeys every time we need a new pseudorandom number:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [0 0]
 \---SPLIT --> new key [4146024105 967050713]
 \--> new subkey [2718843009 1272950319] --> normal [-1.2515389]
We propagate the key and make new subkeys whenever we need a new random number:
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(r" \---SPLIT --> new key ", key)
print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
old key [4146024105 967050713]
 \---SPLIT --> new key [2384771982 3928867769]
 \--> new subkey [1278412471 2182328957] --> normal [-0.5866507]
We can generate more than one subkey at a time:
key, *subkeys = random.split(key, 4)
for subkey in subkeys:
  print(random.normal(subkey, shape=(1,)))
[-0.37533447]
[0.9864503]
[0.1455319]
✔ python control_flow + autodiff ✔
If you just want to apply grad to your python functions, you can use regular python control-flow constructs with no problems, as if you were using Autograd (or Pytorch or TF Eager).
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
print(grad(f)(2.)) # ok!
print(grad(f)(4.)) # ok!
12.0
-4.0
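grad differentiates whichever branch Python actually takes at the given input, so the results above are the derivatives of 3x² (that is, 6x, or 12 at x = 2) and of -4x (that is, -4). A quick central-difference check in plain Python agrees; note this is a numerical approximation, not autodiff, and central_diff is a hypothetical helper name:

```python
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

def central_diff(fun, x, h=1e-5):
  """Approximate fun'(x) with a central finite difference."""
  return (fun(x + h) - fun(x - h)) / (2 * h)

print(round(central_diff(f, 2.), 4))   # 12.0
print(round(central_diff(f, 4.), 4))   # -4.0
```

Near the branch point x = 3 the finite difference would straddle both branches and give a meaningless value, whereas grad simply reports the derivative of the branch taken.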
python control flow + JIT
Using control flow with jit is more complicated, and by default it has more constraints.
This works:
@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x
print(f(3))
24
So does this:
@jit
def g(x):
  y = 0.
  for i in range(x.shape[0]):
    y = y + x[i]
  return y
print(g(np.array([1., 2., 3.])))
6.0
But this doesn't, at least by default:
@jit
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
# This will fail!
try:
  f(2)
except Exception as e:
  print("ERROR:", e)
ERROR: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
What gives!?
When we jit-compile a function, we usually want to compile a version of the function that works for many different argument values, so that we can cache and reuse the compiled code. That way we don't have to re-compile on each function evaluation.
For example, if we evaluate an @jit function on the array np.array([1., 2., 3.], np.float32), we might want to compile code that we can reuse to evaluate the function on np.array([4., 5., 6.], np.float32) to save on compile time.
To get a view of your Python code that is valid for many different argument values, JAX traces it on abstract values that represent sets of possible inputs. There are multiple different levels of abstraction, and different transformations use different abstraction levels.
By default, jit traces your code on the ShapedArray abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value ShapedArray((3,), np.float32), we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
But there's a tradeoff here: if we trace a Python function on a ShapedArray((), np.float32) that isn't committed to a specific concrete value, when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), np.bool_) that represents the set {True, False}. When Python attempts to coerce that to a concrete True or False, we get an error: we don't know which branch to take, and can't continue tracing! The tradeoff is that with higher levels of abstraction we gain a more general view of the Python code (and thus save on re-compilations), but we require more constraints on the Python code to complete the trace.
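The failure can be mimicked with toy abstract values in pure Python. The classes AbstractScalar and AbstractBool below are hypothetical stand-ins for ShapedArray-level tracers, not JAX classes: comparing an abstract scalar yields an abstract boolean, arithmetic stays abstract, and Python's if statement calls __bool__, which has no single answer to give.

```python
class AbstractBool:
  """Hypothetical stand-in for ShapedArray((), bool): the set {True, False}."""
  def __bool__(self):
    raise TypeError("Abstract value passed to `bool`, which requires a concrete value.")

class AbstractScalar:
  """Hypothetical stand-in for ShapedArray((), float32): any float32 scalar."""
  def __lt__(self, other):
    return AbstractBool()     # x < 3 is True for some members of the set, False for others
  def __mul__(self, other):
    return AbstractScalar()   # arithmetic stays abstract, so tracing can continue
  __rmul__ = __mul__
  __pow__ = __mul__

def f(x):
  if x < 3:                   # `if` calls AbstractBool.__bool__, which raises
    return 3. * x ** 2
  else:
    return -4 * x

print(type(-4 * AbstractScalar()).__name__)   # AbstractScalar: no branch needed
try:
  f(AbstractScalar())
except TypeError as e:
  print("ERROR:", e)
```

Only the branch decision fails; straight-line arithmetic traces fine, which is why the earlier loop examples over fixed ranges and shapes worked under jit.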
The good news is that you can control this tradeoff yourself. By having jit trace on more refined abstract values, you can relax the traceability constraints. For example, using the static_argnums argument to jit, we can specify to trace on concrete values of some arguments. Here's that example function again:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x
f = jit(f, static_argnums=(0,))
print(f(2.))
12.0
Here's another example, this time involving a loop:
def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y
f = jit(f, static_argnums=(1,))
f(np.array([2., 3., 4.]), 2)
array(5., dtype=float32)
In effect, the loop gets statically unrolled. JAX can also trace at higher levels of abstraction, like Unshaped, but that's not currently the default for any transformation.
⚠️ functions with argument-value dependent shapes
These control-flow issues also come up in a more subtle way: numerical functions we want to jit can't specialize the shapes of internal arrays on argument values (specializing on argument shapes is ok). As a trivial example, let's make a function whose output happens to depend on the input variable length.
def example_fun(length, val):
  return np.ones((length,)) * val
# un-jit'd works fine
print(example_fun(5, 4))
bad_example_jit = jit(example_fun)
# this will fail:
try:
  print(bad_example_jit(10, 4))
except Exception as e:
  print("error!", e)
# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))
[4. 4. 4. 4. 4.]
error! `full` requires shapes to be concrete. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
[4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
[4. 4. 4. 4. 4.]
static_argnums can be handy if length in our example rarely changes, but it would be disastrous if it changed a lot!
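A toy model of this caching behaviour may help. The decorator toy_jit below is hypothetical and only counts "compilations": it keys a cache on the shape and dtype of ordinary arguments and on the concrete value of static arguments. This roughly mirrors why changing a static argument triggers a recompile while changing an array's values does not; it is not how jit is implemented.

```python
import numpy as onp

def toy_jit(fun, static_argnums=()):
  """Hypothetical stand-in for jit: one cache entry per (abstract signature, static values)."""
  cache = {}

  def signature(args):
    sig = []
    for i, a in enumerate(args):
      if i in static_argnums:
        sig.append(("static", a))                      # the concrete value is part of the key
      else:
        a = onp.asarray(a)
        sig.append(("shaped", a.shape, a.dtype.str))   # only shape and dtype matter
    return tuple(sig)

  def wrapped(*args):
    key = signature(args)
    if key not in cache:
      wrapped.compile_count += 1   # a real jit would trace and compile here
      cache[key] = fun
    return cache[key](*args)

  wrapped.compile_count = 0
  return wrapped

def toy_example_fun(length, val):
  return onp.ones((length,)) * val

jitted = toy_jit(toy_example_fun, static_argnums=(0,))
jitted(10, 4)
jitted(10, 5)   # same length, new val: cache hit
jitted(5, 4)    # new length: "recompiles"
print(jitted.compile_count)   # 2
```

Every distinct value of a static argument costs one more entry, so a static argument that takes many different values leads to many compilations.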
Lastly, if your function has global side-effects, JAX's tracer can cause weird things to happen. A common gotcha is trying to print arrays inside jit'd functions:
@jit
def f(x):
  print(x)
  y = 2 * x
  print(y)
  return y
f(2)
Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>
Traced<ShapedArray(int32[]):JaxprTrace(level=-1/1)>
array(4, dtype=int32)
There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable and that avoids unrolling large loops. Then you can use these 4 structured control flow primitives:
lax.cond: will be differentiable soon
lax.while_loop: non-differentiable*
lax.fori_loop: non-differentiable*
lax.scan: will be differentiable soon
*These can in principle be made forward-differentiable, but this isn't on the current roadmap.
python equivalent:
def cond(pred, true_operand, true_fun, false_operand, false_fun):
  if pred:
    return true_fun(true_operand)
  else:
    return false_fun(false_operand)
from jax import lax
operand = np.array([0.])
lax.cond(True, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([1.], dtype=float32)
lax.cond(False, operand, lambda x: x+1, operand, lambda x: x-1)
# --> array([-1.], dtype=float32)
array([-1.], dtype=float32)
python equivalent:
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
init_val = 0
cond_fun = lambda x: x<10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
array(10, dtype=int32)
python equivalent:
def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
init_val = 0
start = 0
stop = 10
body_fun = lambda i,x: x+i
lax.fori_loop(start, stop, body_fun, init_val)
# --> array(45, dtype=int32)
array(45, dtype=int32)
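lax.scan is the one primitive above without a Python equivalent. Its semantics for lax.scan(f, init, xs) are commonly described by the loop below, a NumPy reference sketch that ignores options such as reversing or scanning without xs: f maps (carry, x) to (new carry, per-step output), the carry is threaded through the loop, and the per-step outputs are stacked.

```python
import numpy as onp

def scan(f, init, xs):
  """Python reference semantics for lax.scan(f, init, xs) (sketch)."""
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, onp.stack(ys)

# Example: a running (cumulative) sum
def step(total, x):
  total = total + x
  return total, total   # new carry, per-step output

final, partial_sums = scan(step, 0, onp.arange(5))
print(final)          # 10
print(partial_sums)   # [ 0  1  3  6 10]
```

Unlike fori_loop, scan returns every intermediate output as well as the final carry.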
JAX and XLA offer the very general N-dimensional conv_general_dilated function, but it's not very obvious how to use it. We'll give some examples of the common use-cases. There are also the convenience functions lax.conv and lax.conv_with_general_padding for the most common kinds of convolutions.
For a survey of the family of convolutional operators, "A guide to convolution arithmetic for deep learning" is highly recommended reading!
Let's define a simple diagonal edge kernel:
# 2D kernel - HWIO layout
kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)
kernel += onp.array([[1, 1, 0],
[1, 0,-1],
[0,-1,-1]])[:, :, onp.newaxis, onp.newaxis]
print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0]);
Edge Conv kernel:
And we'll make a simple synthetic image:
# NHWC layout
img = onp.zeros((1, 200, 198, 3), dtype=np.float32)
for k in range(3):
  x = 30 + 60*k
  y = 20 + 60*k
  img[0, x:x+10, y:y+10, k] = 1.0
print("Original Image:")
plt.imshow(img[0]);
Original Image:
These are the simple convenience functions for convolutions:
⚠️ The convenience helpers lax.conv and lax.conv_with_general_padding assume NCHW images and OIHW kernels.
out = lax.conv(np.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
               np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
               (1, 1),  # window strides
               'SAME')  # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
out shape: (1, 3, 200, 198)
First output channel:
out = lax.conv_with_general_padding(
  np.transpose(img,[0,3,1,2]),    # lhs = NCHW image tensor
  np.transpose(kernel,[3,2,0,1]), # rhs = OIHW conv kernel tensor
  (1, 1),         # window strides
  ((2,2),(2,2)),  # general padding 2x2
  (1,1),          # lhs/image dilation
  (1,1))          # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,0,:,:]);
out shape: (1, 3, 202, 200)
First output channel:
The important argument is the 3-tuple of axis layout arguments: (Input Layout, Kernel Layout, Output Layout).
⚠️ To demonstrate the flexibility of dimension numbers, we choose an NHWC image and HWIO kernel convention for lax.conv_general_dilated below.
dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape
kernel.shape, # only ndim matters, not shape
('NHWC', 'HWIO', 'NHWC')) # the important bit
print(dn)
ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))
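Reading the printed ConvDimensionNumbers: each spec lists the axis position of the batch (or output-feature) dimension first, then the feature (or input-feature) dimension, then the spatial dimensions. The helper layout_to_spec below is a hypothetical pure-Python sketch, not JAX's code; it assumes the spatial letters appear in the same relative order in all three layout strings (as they do throughout this notebook), in which case it reproduces the specs printed here and in the 1D and 3D examples.

```python
def layout_to_spec(layout, kind):
  """Hypothetical helper: map a layout string to a (major, minor, *spatial) axis tuple.

  kind 'lhs'/'out': major = 'N' (batch), minor = 'C' (feature).
  kind 'rhs':       major = 'O' (out features), minor = 'I' (in features).
  """
  major, minor = ('N', 'C') if kind in ('lhs', 'out') else ('O', 'I')
  spatial = [i for i, c in enumerate(layout) if c not in (major, minor)]
  return (layout.index(major), layout.index(minor), *spatial)

print(layout_to_spec('NHWC', 'lhs'))   # (0, 3, 1, 2)
print(layout_to_spec('HWIO', 'rhs'))   # (3, 2, 0, 1)
print(layout_to_spec('NHWC', 'out'))   # (0, 3, 1, 2)
```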
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 200, 198, 3)
First output channel:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10,10))
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 198, 196, 3) DIFFERENT from above!
First output channel:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(2,2), # window strides
'SAME', # padding mode
(1,1), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 100, 99, 3) <-- half the size of above
First output channel:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'VALID', # padding mode
(1,1), # lhs/image dilation
(12,12), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 176, 174, 3)
First output channel:
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1), # window strides
'SAME', # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 399, 395, 3) <-- larger than original!
First output channel:
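The output shapes printed above follow standard convolution arithmetic. The function conv_output_size below is a hypothetical sketch of it for one spatial dimension (not a JAX function). LHS dilation d_l stretches an input of size n to n' = (n - 1)·d_l + 1, RHS dilation d_r stretches a kernel of size k to k' = (k - 1)·d_r + 1; 'SAME' is assumed to give ceil(n' / stride), and 'VALID' or explicit (low, high) padding gives floor((n' + low + high - k') / stride) + 1. These assumptions reproduce every shape shown so far.

```python
import math

def conv_output_size(n, k, stride=1, padding='SAME', lhs_dilation=1, rhs_dilation=1):
  """Output length along one spatial dimension (hypothetical helper, not a JAX API)."""
  n_eff = (n - 1) * lhs_dilation + 1   # input with lhs_dilation-1 zeros inserted between samples
  k_eff = (k - 1) * rhs_dilation + 1   # kernel with rhs_dilation-1 holes inserted between taps
  if padding == 'SAME':
    return math.ceil(n_eff / stride)
  lo, hi = (0, 0) if padding == 'VALID' else padding
  return (n_eff + lo + hi - k_eff) // stride + 1

# The (H, W) = (200, 198) image with a 3x3 kernel, for the cases above:
cases = {
  'SAME': dict(),
  'VALID': dict(padding='VALID'),
  'SAME, stride 2': dict(stride=2),
  'VALID, rhs dilation 12': dict(padding='VALID', rhs_dilation=12),
  'SAME, lhs dilation 2': dict(lhs_dilation=2),
  'explicit (2, 2) padding': dict(padding=(2, 2)),
}
for case, kw in cases.items():
  print(case, conv_output_size(200, 3, **kw), conv_output_size(198, 3, **kw))
```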
We can use the last of these, for instance, to implement transposed convolutions:
# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))
# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg:
kernel_rot = np.rot90(np.rot90(kernel, axes=(0,1)), axes=(0,1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img, # lhs = image tensor
kernel_rot, # rhs = conv kernel tensor
(1,1), # window strides
padding, # padding mode
(2,2), # lhs/image dilation
(1,1), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10,10))
print("First output channel:")
plt.imshow(onp.array(out)[0,:,:,0]);
out shape: (1, 400, 396, 3) <-- transposed_conv
First output channel:
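The identity behind this trick can be checked in 1D with plain NumPy. The two hypothetical functions below compute a stride-s transposed convolution in two ways: directly, by letting each input sample add a scaled copy of the kernel to the output, and as an ordinary cross-correlation (which is what lax convolutions compute) of the LHS-dilated, fully padded input with the 180°-rotated kernel, which in 1D is simply the reversed kernel. This sketch uses "full" padding of k - 1 on each side; the custom (2, 1) padding above is instead chosen to produce the 2H × 2W output of the TensorFlow call quoted in the comments.

```python
import numpy as onp

def transposed_conv1d_scatter(x, k, stride):
  """Direct definition: each input sample adds a scaled copy of the kernel."""
  out = onp.zeros((len(x) - 1) * stride + len(k))
  for i, xi in enumerate(x):
    out[i * stride:i * stride + len(k)] += xi * k
  return out

def transposed_conv1d_dilated(x, k, stride):
  """Same result via LHS dilation + full padding + the 180-degree-rotated kernel."""
  dilated = onp.zeros((len(x) - 1) * stride + 1)
  dilated[::stride] = x                                # insert stride-1 zeros between samples
  padded = onp.pad(dilated, (len(k) - 1, len(k) - 1))  # "full" padding
  return onp.correlate(padded, k[::-1], mode='valid')  # ordinary cross-correlation

rng = onp.random.default_rng(0)
x_1d, k_1d = rng.normal(size=7), rng.normal(size=3)
print(onp.allclose(transposed_conv1d_scatter(x_1d, k_1d, 2),
                   transposed_conv1d_dilated(x_1d, k_1d, 2)))   # True
```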
You aren't limited to 2D convolutions; a simple 1D demo is below:
# 1D kernel - WIO layout
kernel = onp.array([[[1, 0, -1], [-1, 0, 1]],
[[1, 1, 1], [-1, -1, -1]]],
dtype=np.float32).transpose([2,1,0])
# 1D data - NWC layout
data = onp.zeros((1, 200, 2), dtype=np.float32)
for i in range(2):
  for k in range(2):
    x = 35*i + 30 + 60*k
    data[0, x:x+30, k] = 1.0
print("in shapes:", data.shape, kernel.shape)
plt.figure(figsize=(10,5))
plt.plot(data[0]);
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NWC', 'WIO', 'NWC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,), # window strides
'SAME', # padding mode
(1,), # lhs/image dilation
(1,), # rhs/kernel dilation
dn) # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10,5))
plt.plot(out[0]);
in shapes: (1, 200, 2) (3, 2, 2)
ConvDimensionNumbers(lhs_spec=(0, 2, 1), rhs_spec=(2, 1, 0), out_spec=(0, 2, 1))
out shape: (1, 200, 2)
# 3D kernel - HWDIO layout
kernel = onp.array([
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
[[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
[[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis]
# 3D data - NHWDC layout
data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32)
x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j]
data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None]
print("in shapes:", data.shape, kernel.shape)
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
('NHWDC', 'HWDIO', 'NHWDC'))
print(dn)
out = lax.conv_general_dilated(data, # lhs = image tensor
kernel, # rhs = conv kernel tensor
(1,1,1), # window strides
'SAME', # padding mode
(1,1,1), # lhs/image dilation
(1,1,1), # rhs/kernel dilation
dn) # dimension_numbers
print("out shape: ", out.shape)
# Make some simple 3d density plots:
from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
  my_cmap = cmap(onp.arange(cmap.N))
  my_cmap[:,-1] = onp.linspace(0, 1, cmap.N)**3
  return mpl.colors.ListedColormap(my_cmap)
my_cmap = make_alpha(plt.cm.viridis)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('input')
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap)
ax.axis('off')
ax.set_title('3D conv output');
in shapes: (1, 30, 30, 30, 1) (3, 3, 3, 1, 1)
ConvDimensionNumbers(lhs_spec=(0, 4, 1, 2, 3), rhs_spec=(4, 3, 0, 1, 2), out_spec=(0, 4, 1, 2, 3))
out shape: (1, 30, 30, 30, 1)
If you want to trace where NaNs are occurring in your functions or gradients, you can turn on the NaN-checker by:
- setting the JAX_DEBUG_NANS=True environment variable;
- adding from jax.config import config and config.update("jax_debug_nans", True) near the top of your main file;
- adding from jax.config import config and config.parse_flags_with_absl() to your main file, then setting the option using a command-line flag like --jax_debug_nans=True.
This will cause computations to error out immediately on production of a NaN.
⚠️ You shouldn't have the NaN-checker on if you're not debugging, as it can introduce lots of device-host round-trips and performance regressions!
At the moment, XLA's CPU backend defaults to enabling fast math mode, which does not preserve nan/inf semantics. (The GPU backend does not use fast math by default!) If fast math mode is enabled, the semantics of inf and nan are not preserved by XLA/LLVM, and the behavior of inf/nan values is unpredictable.
To disable fast math mode on CPU, set the environment variable:
XLA_FLAGS=--xla_cpu_enable_fast_math=false
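If you would rather do this from Python than from the shell, one option is to modify os.environ before JAX is imported, on the assumption that XLA reads XLA_FLAGS once, when its backend is first initialized, so a later change may have no effect. A minimal sketch that appends the flag without clobbering any other XLA flags that are already set:

```python
import os

FAST_MATH_OFF = "--xla_cpu_enable_fast_math=false"

# Append to any existing XLA_FLAGS rather than overwriting them.
existing = os.environ.get("XLA_FLAGS", "")
if FAST_MATH_OFF not in existing.split():
  os.environ["XLA_FLAGS"] = (existing + " " + FAST_MATH_OFF).strip()

# Import jax only after the environment has been configured:
# import jax
```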
At the moment, JAX by default enforces single-precision numbers to mitigate the Numpy API's tendency to aggressively promote operands to double
. This is the desired behavior for many machine-learning applications, but it may catch you by surprise!
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype
dtype('float32')
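For comparison, plain NumPy defaults to float64 when creating arrays and promotes float32 operands to float64 as soon as they are mixed with a float64 array, which is the behaviour JAX's single-precision default guards against:

```python
import numpy as onp

a32 = onp.ones(3, dtype=onp.float32)
print(onp.zeros(3).dtype)                 # float64: NumPy's default float type
print(onp.random.uniform(size=3).dtype)   # float64
print((a32 + onp.zeros(3)).dtype)         # float64: float32 promoted by a float64 array
print((a32 + a32).dtype)                  # float32: stays single precision
```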
To use double-precision numbers, you need to set the jax_enable_x64 configuration variable at startup.
There are a few ways to do this:
1. You can enable 64-bit mode by setting the environment variable JAX_ENABLE_X64=True.
2. You can manually set the jax_enable_x64 configuration flag at startup:
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)
3. You can parse command-line flags with absl.app.run(main):
from jax.config import config
config.config_with_absl()
4. If you want JAX to run absl parsing for you, i.e. you don't want to do absl.app.run(main), you can instead use:
from jax.config import config
if __name__ == '__main__':
  # calls config.config_with_absl() *and* runs absl parsing
  config.parse_flags_with_absl()
Note that #2-#4 work for any of JAX's configuration options.
We can then confirm that x64 mode is enabled:
from jax import numpy as np, random
x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64)
x.dtype # --> dtype('float64')
⚠️ XLA doesn't support 64-bit convolutions on all backends!
If something's not covered here that has caused you weeping and gnashing of teeth, please let us know and we'll extend these introductory advisories!