!pip install -q --upgrade jax jaxlib

from __future__ import print_function, division

import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

# Higher-order derivatives by composing grad.
grad_tanh = grad(np.tanh)
print(grad_tanh(2.0))
print(grad(grad(np.tanh))(2.0))
print(grad(grad(grad(np.tanh)))(2.0))

def sigmoid(x):
    return 0.5 * (np.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(np.dot(inputs, W) + b)

# Build a toy dataset.
inputs = np.array([[0.52, 1.12, 0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))

# Initialize random model coefficients.
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values:
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)

# grad also works with standard Python containers (pytrees), like dicts.
def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -np.sum(np.log(label_probs))

print(grad(loss2)({'W': W, 'b': b}))

# Evaluate both the loss and its gradient in one call with value_and_grad.
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
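# Added usage sketch (not part of the original notebook): with the gradients in
# hand, a single gradient-descent step on the loss looks like this. The learning
# rate `lr` is a hypothetical value chosen only for illustration.
lr = 0.1
W_new = W - lr * W_grad
b_new = b - lr * b_grad
print('loss before step', loss(W, b))
print('loss after step ', loss(W_new, b_new))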
# Set a step size for finite differences calculations.
eps = 1e-4

# Check b_grad with scalar finite differences.
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction.
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / np.sqrt(np.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', np.vdot(grad(loss)(W, b), unitvec))

# JAX provides a convenience function that runs these checks for us.
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives

# A Hessian-vector product function built by composing grad with grad
# (reverse-over-reverse); it never materializes the full Hessian.
def hvp(f, x, v):
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)

# Full Jacobians with forward-mode (jacfwd) and reverse-mode (jacrev).
from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions.
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)

# Jacobians work with container (pytree) inputs too.
def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v)

# Composing the two modes gives a dense Hessian.
def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)

# Jacobian-vector products (forward-mode).
from jax import jvp

# Isolate the function from the weight matrix to the predictions.
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`.
y, u = jvp(f, (W,), (v,))

# Vector-Jacobian products (reverse-mode).
from jax import vjp

# Isolate the function from the weight matrix to the predictions.
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`.
v = vjp_fun(u)

# Hessian-vector products revisited: the same reverse-over-reverse version as above.
def hvp(f, x, v):
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)

from jax import jvp, grad

# Forward-over-reverse: push the tangents through the gradient function.
def hvp(f, primals, tangents):
    return jvp(grad(f), primals, tangents)[1]

def f(X):
    return np.sum(np.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = np.tensordot(hessian(f)(X), V, 2)

print(np.allclose(ans1, ans2, 1e-4, 1e-4))

# Reverse-over-forward.
def hvp_revfwd(f, primals, tangents):
    g = lambda primals: jvp(f, primals, tangents)[1]
    return grad(g)(primals)

# Reverse-over-reverse; only works for single arguments.
def hvp_revrev(f, primals, tangents):
    x, = primals
    v, = tangents
    return grad(lambda x: np.vdot(grad(f)(x), v))(x)

print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 np.tensordot(hessian(f)(X), V, 2)
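# Added sanity check (not in the original notebook): the three matrix-free HVP
# variants above should agree up to numerical tolerance. hvp_revfwd
# differentiates with respect to the `primals` tuple, so it returns a
# one-element tuple that we unpack before comparing.
print(np.allclose(hvp(f, (X,), (V,)), hvp_revfwd(f, (X,), (V,))[0], 1e-4, 1e-4))
print(np.allclose(hvp(f, (X,), (V,)), hvp_revrev(f, (X,), (V,)), 1e-4, 1e-4))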
# Isolate the function from the weight matrix to the predictions.
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return np.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'

def loop_jmp(f, x, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension.
    return np.vstack([jvp(f, (x,), (mi,))[1] for mi in M])

def vmap_jmp(f, x, M):
    _jvp = lambda s: jvp(f, (x,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)

vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert np.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'

# Reverse-mode Jacobians are just vmapped VJPs over the standard basis.
from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once.
        J, = vmap(vjp_fun, in_axes=0)(np.eye(len(y)))
        return J
    return jacfun

assert np.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'

# Forward-mode Jacobians are vmapped JVPs over the standard basis.
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt = vmap(_jvp, in_axes=1)(np.eye(len(x)))
        return np.transpose(Jt)
    return jacfun

assert np.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'

# Autodiff works through Python control flow, including exceptions: the trace
# taken at x = 4. follows the `except` branch, and the resulting VJP function
# can even be jit-compiled.
def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return np.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))

# Complex differentiation, schematically: view a complex-to-complex function
# through its real-valued component functions u and v (u and v are stand-ins,
# not defined here).
def f(z):
    x, y = np.real(z), np.imag(z)
    return u(x, y) + v(x, y) * 1j

def grad_f(z):
    x, y = np.real(z), np.imag(z)
    return grad(u, 0)(x, y) + grad(u, 1)(x, y) * 1j

# Gradients of a real-valued function of a complex Hermitian matrix, computed
# through a Cholesky factorization.
A = np.array([[5.,    2.+3j,  5j],
              [2.-3j, 7.,     1.+7j],
              [-5j,   1.-7j,  12.]])

def f(X):
    L = np.linalg.cholesky(X)
    return np.sum((L - np.sin(L))**2)

grad(f)(A)
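# Added example (not in the original notebook): when a function really is
# holomorphic, grad can be told so with the `holomorphic=True` flag and then
# returns the ordinary complex derivative. Here d/dz z**2 = 2z, so we expect 6+2j.
print(grad(lambda z: z ** 2, holomorphic=True)(3. + 1j))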