# Notebook-only setup. IPython shell magics are not valid Python, so run these
# in a shell (or a notebook cell) before executing this script:
#   !pip install --upgrade -q git+https://github.com/google/jax.git
#   !pip install --upgrade -q jaxlib
import numpy as onp
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams

rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False

# --- NumPy arrays are mutable -------------------------------------------------
numpy_array = onp.zeros((3, 3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

# --- JAX arrays are immutable -------------------------------------------------
jax_array = np.zeros((3, 3), dtype=np.float32)

# In place update of JAX's array will yield an error!
# The error is caught and printed so the rest of the script keeps running.
try:
    jax_array[1, :] = 1.0
except TypeError as e:
    print("ERROR:", e)

# Functional updates: `arr.at[idx].set(v)` / `arr.at[idx].add(v)` return a NEW
# array and leave `arr` untouched. They replace the jax.ops.index_update /
# jax.ops.index_add helpers, which no longer exist in current JAX.
jax_array = np.zeros((3, 3))
print("original array:")
print(jax_array)

new_jax_array = jax_array.at[1, :].set(1.)
print("old array unchanged:")
print(jax_array)
print("new array:")
print(new_jax_array)

print("original array:")
jax_array = np.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)

# --- NumPy's global, stateful PRNG -------------------------------------------
print(onp.random.random())
print(onp.random.random())
print(onp.random.random())

onp.random.seed(0)
rng_state = onp.random.get_state()
# print(rng_state)
# --> ('MT19937', array([0, 1, 1812433255, 1900727105, 1208447044,
#       2481403966, 4042607538, 337614300, ... 614 more numbers...,
#       3048484911, 1796872496], dtype=uint32), 624, 0, 0.0)
# One uniform() draw moves the state position (last-but-three field) from 624
# to 2, i.e. it consumes two 32-bit words of the MT19937 state vector.
_ = onp.random.uniform()
rng_state = onp.random.get_state()
# print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 2, 0, 0.0)

# Let's exhaust the entropy in this PRNG statevector:
# 311 more draws advance the position by 622, from 2 to 624.
for _ in range(311):
    onp.random.uniform()
rng_state = onp.random.get_state()
# print(rng_state)
# --> ('MT19937', array([2443250962, 1093594115, 1878467924,
#       ..., 2648828502, 1678096082], dtype=uint32), 624, 0, 0.0)

# Next call iterates the RNG state for a new batch of fake "entropy".
_ = onp.random.uniform()
rng_state = onp.random.get_state()
# print(rng_state)
# --> ('MT19937', array([1499117434, 2949980591, 2242547484,
#       4162027047, 3277342478], dtype=uint32), 2, 0, 0.0)

# --- JAX's explicit, functional PRNG -----------------------------------------
from jax import random

key = random.PRNGKey(0)
print(key)

print(random.normal(key, shape=(1,)))
print(key)
# No no no!  Reusing the same key draws the same "random" number again.
print(random.normal(key, shape=(1,)))
print(key)

# Split the key before every use: keep `key` for future splits, consume `subkey`.
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \\---SPLIT --> new key ", key)
print(" \\--> new subkey", subkey, "--> normal", normal_pseudorandom)

print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
print(" \\---SPLIT --> new key ", key)
print(" \\--> new subkey", subkey, "--> normal", normal_pseudorandom)

key, *subkeys = random.split(key, 4)
for subkey in subkeys:
    print(random.normal(subkey, shape=(1,)))


# --- Python control flow under grad vs. jit ----------------------------------
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x


print(grad(f)(2.))  # ok!
print(grad(f)(4.))  # ok!


@jit
def f(x):
    # The loop bound is a Python constant, so it is simply unrolled at trace time.
    for i in range(3):
        x = 2 * x
    return x


print(f(3))


@jit
def g(x):
    # x.shape is static under jit, so this loop is also unrolled.
    y = 0.
    for i in range(x.shape[0]):
        y = y + x[i]
    return y


print(g(np.array([1., 2., 3.])))


@jit
def f(x):
    # Branching on the traced *value* of x is not allowed under jit.
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x


# This will fail!
try:
    f(2)
except Exception as e:
    print("ERROR:", e)


def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x


# Marking x static makes it a compile-time constant, so the branch is legal.
f = jit(f, static_argnums=(0,))
print(f(2.))


def f(x, n):
    y = 0.
    for i in range(n):
        y = y + x[i]
    return y


f = jit(f, static_argnums=(1,))
print(f(np.array([2., 3., 4.]), 2))


def example_fun(length, val):
    return np.ones((length,)) * val


# un-jit'd works fine
print(example_fun(5, 4))

bad_example_jit = jit(example_fun)
# this will fail:
try:
    print(bad_example_jit(10, 4))
except Exception as e:
    print("error!", e)

# static_argnums tells JAX to recompile on changes at these argument positions:
good_example_jit = jit(example_fun, static_argnums=(0,))
# first compile
print(good_example_jit(10, 4))
# recompiles
print(good_example_jit(5, 4))


@jit
def f(x):
    # Python print runs while tracing, so it shows abstract tracers, not values.
    print(x)
    y = 2 * x
    print(y)
    return y


f(2)

# --- Structured control-flow primitives ---------------------------------------
from jax import lax

# Signature is lax.cond(pred, true_fun, false_fun, *operands). The original
# notebook used an older per-branch-operand call form.
operand = np.array([0.])
print(lax.cond(True, lambda x: x + 1, lambda x: x - 1, operand))
# --> array([1.], dtype=float32)
print(lax.cond(False, lambda x: x + 1, lambda x: x - 1, operand))
# --> array([-1.], dtype=float32)

init_val = 0
cond_fun = lambda x: x < 10
body_fun = lambda x: x + 1
print(lax.while_loop(cond_fun, body_fun, init_val))
# --> array(10, dtype=int32)

init_val = 0
start = 0
stop = 10
body_fun = lambda i, x: x + i
print(lax.fori_loop(start, stop, body_fun, init_val))
# --> array(45, dtype=int32)

# --- Convolutions -------------------------------------------------------------
# 2D kernel - HWIO layout (the same 3x3 pattern for every I/O channel pair)
kernel = onp.zeros((3, 3, 3, 3), dtype=np.float32)
kernel += onp.array([[1, 1, 0],
                     [1, 0, -1],
                     [0, -1, -1]])[:, :, onp.newaxis, onp.newaxis]

print("Edge Conv kernel:")
plt.imshow(kernel[:, :, 0, 0])

# NHWC layout: one 10x10 square per colour channel
img = onp.zeros((1, 200, 198, 3), dtype=np.float32)
for k in range(3):
    x = 30 + 60 * k
    y = 20 + 60 * k
    img[0, x:x + 10, y:y + 10, k] = 1.0

print("Original Image:")
plt.imshow(img[0])

# lax.conv / lax.conv_with_general_padding use the default NCHW / OIHW / NCHW
# layouts, so the HWIO kernel is permuted with (3, 2, 0, 1) to get OIHW.
# (Because this kernel is identical across I and O, an I<->O swap would go
# unnoticed here, but it matters for any general kernel.)
out = lax.conv(np.transpose(img, [0, 3, 1, 2]),     # lhs = NCHW image tensor
               np.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
               (1, 1),   # window strides
               'SAME')   # padding mode
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10, 10))
plt.imshow(onp.array(out)[0, 0, :, :])

out = lax.conv_with_general_padding(
    np.transpose(img, [0, 3, 1, 2]),     # lhs = NCHW image tensor
    np.transpose(kernel, [3, 2, 0, 1]),  # rhs = OIHW conv kernel tensor
    (1, 1),            # window strides
    ((2, 2), (2, 2)),  # general padding 2x2
    (1, 1),            # lhs/image dilation
    (1, 1))            # rhs/kernel dilation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10, 10))
plt.imshow(onp.array(out)[0, 0, :, :])

dn = lax.conv_dimension_numbers(img.shape,     # only ndim matters, not shape
                                kernel.shape,  # only ndim matters, not shape
                                ('NHWC', 'HWIO', 'NHWC'))  # the important bit
print(dn)

out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1, 1),  # window strides
                               'SAME',  # padding mode
                               (1, 1),  # lhs/image dilation
                               (1, 1),  # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
print("First output channel:")
plt.figure(figsize=(10, 10))
plt.imshow(onp.array(out)[0, :, :, 0])

out = lax.conv_general_dilated(img,      # lhs = image tensor
                               kernel,   # rhs = conv kernel tensor
                               (1, 1),   # window strides
                               'VALID',  # padding mode
                               (1, 1),   # lhs/image dilation
                               (1, 1),   # rhs/kernel dilation
                               dn)       # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "DIFFERENT from above!")
print("First output channel:")
plt.figure(figsize=(10, 10))
plt.imshow(onp.array(out)[0, :, :, 0])

out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (2, 2),  # window strides
                               'SAME',  # padding mode
                               (1, 1),  # lhs/image dilation
                               (1, 1),  # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, " <-- half the size of above")
plt.figure(figsize=(10, 10))
print("First output channel:")
plt.imshow(onp.array(out)[0, :, :, 0])

out = lax.conv_general_dilated(img,       # lhs = image tensor
                               kernel,    # rhs = conv kernel tensor
                               (1, 1),    # window strides
                               'VALID',   # padding mode
                               (1, 1),    # lhs/image dilation
                               (12, 12),  # rhs/kernel dilation
                               dn)        # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape)
plt.figure(figsize=(10, 10))
print("First output channel:")
plt.imshow(onp.array(out)[0, :, :, 0])

# LHS (input) dilation inserts zeros between input pixels, enlarging the output.
out = lax.conv_general_dilated(img,     # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1, 1),  # window strides
                               'SAME',  # padding mode
                               (2, 2),  # lhs/image dilation
                               (1, 1),  # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- larger than original!")
plt.figure(figsize=(10, 10))
print("First output channel:")
plt.imshow(onp.array(out)[0, :, :, 0])

# The following is equivalent to tensorflow:
# N,H,W,C = img.shape
# out = tf.nn.conv2d_transpose(img, kernel, (N,2*H,2*W,C), (1,2,2,1))

# transposed conv = 180deg kernel rotation plus LHS dilation
# rotate kernel 180deg (two quarter turns in the spatial H/W plane):
kernel_rot = np.rot90(kernel, k=2, axes=(0, 1))
# need a custom output padding:
padding = ((2, 1), (2, 1))
out = lax.conv_general_dilated(img,         # lhs = image tensor
                               kernel_rot,  # rhs = conv kernel tensor
                               (1, 1),      # window strides
                               padding,     # padding mode
                               (2, 2),      # lhs/image dilation
                               (1, 1),      # rhs/kernel dilation
                               dn)          # dimension_numbers = lhs, rhs, out dimension permutation
print("out shape: ", out.shape, "<-- transposed_conv")
plt.figure(figsize=(10, 10))
print("First output channel:")
plt.imshow(onp.array(out)[0, :, :, 0])

# 1D kernel - WIO layout: built as (O=2, I=2, W=3), then transposed to (W, I, O)
kernel = onp.array([[[1, 0, -1], [-1, 0, 1]],
                    [[1, 1, 1], [-1, -1, -1]]],
                   dtype=np.float32).transpose([2, 1, 0])
# 1D data - NWC layout: two 30-sample pulses per channel
data = onp.zeros((1, 200, 2), dtype=np.float32)
for i in range(2):
    for k in range(2):
        x = 35 * i + 30 + 60 * k
        data[0, x:x + 30, k] = 1.0

print("in shapes:", data.shape, kernel.shape)

plt.figure(figsize=(10, 5))
plt.plot(data[0])
dn = lax.conv_dimension_numbers(data.shape, kernel.shape,
                                ('NWC', 'WIO', 'NWC'))
print(dn)

out = lax.conv_general_dilated(data,    # lhs = image tensor
                               kernel,  # rhs = conv kernel tensor
                               (1,),    # window strides
                               'SAME',  # padding mode
                               (1,),    # lhs/image dilation
                               (1,),    # rhs/kernel dilation
                               dn)      # dimension_numbers = lhs, rhs, out dimension permutation
dimension_numbers = lhs, rhs, out dimension permutation print("out shape: ", out.shape) plt.figure(figsize=(10,5)) plt.plot(out[0]); # Random 3D kernel - HWDIO layout kernel = onp.array([ [[0, 0, 0], [0, 1, 0], [0, 0, 0]], [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], dtype=np.float32)[:, :, :, onp.newaxis, onp.newaxis] # 3D data - NHWDC layout data = onp.zeros((1, 30, 30, 30, 1), dtype=np.float32) x, y, z = onp.mgrid[0:1:30j, 0:1:30j, 0:1:30j] data += (onp.sin(2*x*np.pi)*onp.cos(2*y*np.pi)*onp.cos(2*z*np.pi))[None,:,:,:,None] print("in shapes:", data.shape, kernel.shape) dn = lax.conv_dimension_numbers(data.shape, kernel.shape, ('NHWDC', 'HWDIO', 'NHWDC')) print(dn) out = lax.conv_general_dilated(data, # lhs = image tensor kernel, # rhs = conv kernel tensor (1,1,1), # window strides 'SAME', # padding mode (1,1,1), # lhs/image dilation (1,1,1), # rhs/kernel dilation dn) # dimension_numbers print("out shape: ", out.shape) # Make some simple 3d density plots: from mpl_toolkits.mplot3d import Axes3D def make_alpha(cmap): my_cmap = cmap(np.arange(cmap.N)) my_cmap[:,-1] = np.linspace(0, 1, cmap.N)**3 return mpl.colors.ListedColormap(my_cmap) my_cmap = make_alpha(plt.cm.viridis) fig = plt.figure() ax = fig.gca(projection='3d') ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=data.ravel(), cmap=my_cmap) ax.axis('off') ax.set_title('input') fig = plt.figure() ax = fig.gca(projection='3d') ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=out.ravel(), cmap=my_cmap) ax.axis('off') ax.set_title('3D conv output'); x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64) x.dtype from jax import numpy as np, random x = random.uniform(random.PRNGKey(0), (1000,), dtype=np.float64) x.dtype # --> dtype('float64')