# Install the GPU-enabled jaxlib wheel matching the Colab CUDA version, then jax itself.
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-latest-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax

from __future__ import print_function, division

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

# JAX handles randomness explicitly: the PRNG state lives in a key that you pass around.
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

# Multiply two big matrices; jax.numpy dispatches to the GPU when one is available.
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T)  # runs on the GPU

import numpy as onp  # original CPU-backed NumPy

# JAX can consume NumPy arrays directly, but then each call copies the data to the device.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T)

# device_put transfers the array once up front, so the timing no longer
# includes the host-to-device copy.
from jax import device_put

x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T)

# Baseline: the same matmul with plain NumPy on the CPU.
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit onp.dot(x, x.T)

# selu is an elementwise activation; jit(selu) compiles it end-to-end with XLA,
# fusing the operations into a single kernel.
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x)

selu_jit = jit(selu)
%timeit selu_jit(x)

# grad returns a new function that computes the gradient of the original.
def sum_logistic(x):
  return np.sum(1.0 / (1.0 + np.exp(-x)))

x_small = np.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

# Check grad's result against central finite differences.
def first_finite_differences(f, x):
  eps = 1e-3
  return np.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in onp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

# grad and jit compose arbitrarily, so higher-order derivatives are one-liners.
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

# Full Hessians via forward-over-reverse-mode Jacobians.
from jax import jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

# vmap: automatic vectorization of a per-example function over a batch dimension.
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return np.dot(mat, v)

def naively_batched_apply_matrix(v_batched):
  return np.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x)

@jit
def batched_apply_matrix(v_batched):
  return np.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x)

@jit
def vmap_batched_apply_matrix(batched_x):
  return vmap(apply_matrix)(batched_x)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x)
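
# Not part of the original notebook: a quick sketch exercising the hessian
# helper defined above. sum_logistic sums elementwise sigmoids, so its Hessian
# is diagonal with entries s * (1 - s) * (1 - 2 * s), where s = sigmoid(x).
print(hessian(sum_logistic)(x_small))  # expect a (3, 3) diagonal matrix

s = 1.0 / (1.0 + np.exp(-x_small))
print(np.diag(s * (1 - s) * (1 - 2 * s)))  # analytic Hessian for comparison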
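
# A hedged caveat on the %timeit cells above (assumes a reasonably recent
# jax/jaxlib): JAX dispatches work to the accelerator asynchronously, so a
# bare np.dot(...) can return before the computation finishes. Blocking on a
# fresh matrix's result times the actual work rather than just the dispatch.
y = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(y, y.T).block_until_ready()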