This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
Inspired by a notebook by @davmre.
!pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.23-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import itertools
import re
import sys
import time
import matplotlib.pyplot as plt
import jax
from jax import lax
from jax import numpy as np
from jax import scipy
from jax import random
import numpy as onp
import scipy as oscipy
onp.random.seed(10009)
num_features = 10
num_points = 100
true_beta = onp.random.randn(num_features).astype(np.float32)
all_x = onp.random.randn(num_points, num_features).astype(np.float32)
y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)
y
array([0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0], dtype=int32)
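For reference, the model we've just sampled from is a Bayesian logistic regression: a standard normal prior on each coefficient and a Bernoulli likelihood for each label. The log-joint density the code below evaluates is

$$\log p(\beta, y) = \sum_j \log \mathcal{N}(\beta_j \mid 0, 1) + \sum_i \log \sigma\big((2 y_i - 1)\, x_i^\top \beta\big),$$

where $\sigma(z) = 1/(1 + e^{-z})$ is the logistic sigmoid; the $-\log(1 + e^{-z})$ expression in the code is just $\log \sigma(z)$.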
We'll write this log-joint three ways: a non-batched version, a manually batched version, and an autobatched version built with `jax.vmap`.
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `np.sum`.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
    return result
log_joint(onp.random.randn(num_features))
array(-213.23558, dtype=float32)
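As a quick sanity check, the same density can be computed with plain NumPy and should agree with the JAX version up to float32 precision (`log_joint_onp` is just an illustrative helper, not used later):

# Plain-NumPy version of the same log-joint, for comparison only.
def log_joint_onp(beta):
    # log N(beta_j | 0, 1), summed over the features
    result = onp.sum(-0.5 * onp.log(2 * onp.pi) - 0.5 * beta**2)
    # Bernoulli log-likelihood, written as a log-sigmoid
    result += onp.sum(-onp.log1p(onp.exp(-(2 * y - 1) * all_x.dot(beta))))
    return result

test_beta = onp.random.randn(num_features)
print(log_joint(test_beta), log_joint_onp(test_beta))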
# This doesn't work, because we didn't write `log_joint()` to handle batching.
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)
log_joint(onp.random.randn(batch_size, num_features))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-c7ddbb18b4cb> in <module>()
      3 batched_test_beta = onp.random.randn(batch_size, num_features)
      4
----> 5 log_joint(onp.random.randn(batch_size, num_features))

<ipython-input-4-fff01ffe382a> in log_joint(beta)
      3     # Note that no `axis` parameter is provided to `np.sum`.
      4     result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))
----> 5     result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
      6     return result

ValueError: Incompatible shapes for broadcasting: ((100, 10), (1, 100))
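To see where this fails, it helps to look at the intermediate shapes (a quick diagnostic only, using the variables defined above):

# Inside log_joint, np.dot(all_x, beta) contracts all_x's feature axis against the
# *batch* axis of the batched beta (it only type-checks here because batch_size
# happens to equal num_features). The resulting (100, 10) array then can't
# broadcast against (2*y-1), which has shape (100,).
print(np.dot(all_x, batched_test_beta).shape)  # (100, 10)
print((2 * y - 1).shape)                       # (100,)

The manually batched version below fixes this by threading an explicit batch dimension through the computation.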
def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.),
                             axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta.T).T)),
                             axis=-1)
    return result
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)
batched_log_joint(batched_test_beta)
array([-147.84033 , -207.02205 , -109.26075 , -243.8083 , -163.02911 , -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ], dtype=float32)
With `jax.vmap`, it just works: we apply `vmap` to the original, non-batched `log_joint` and get the batched behavior automatically.
vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
array([-147.84033 , -207.02205 , -109.26075 , -243.8083 , -163.02911 , -143.84848 , -160.28772 , -113.77169 , -126.605446, -190.81989 ], dtype=float32)
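We can confirm that the manually batched and vmap-batched versions agree (a one-line check using only the functions defined above):

# Both batched implementations should match to within float32 tolerance.
print(np.allclose(batched_log_joint(batched_test_beta),
                  vmap_batched_log_joint(batched_test_beta)))  # True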
To pull everything together, here's a small self-contained variational inference example. A little code is copied from above.
@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `np.sum`.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
    return result
batched_log_joint = jax.jit(jax.vmap(log_joint))
def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon
    return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
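For reference, `elbo` is a Monte Carlo estimate (averaged over the rows of `epsilon`) of the evidence lower bound under a mean-field Gaussian $q(\beta) = \mathcal{N}(\mu, \operatorname{diag}(e^{s})^2)$, writing $\mu$ for `beta_loc` and $s$ for `beta_log_scale`, with the reparameterization $\beta = \mu + e^{s} \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$:

$$\mathrm{ELBO}(\mu, s) = \mathbb{E}_{\epsilon}\big[\log p(\mu + e^{s} \odot \epsilon,\, y)\big] + \sum_j s_j + \text{const},$$

where $\sum_j s_j$ is the entropy of $q$ up to an additive constant that doesn't affect the gradients.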
def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

normal_sample = jax.jit(normal_sample, static_argnums=(1,))
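Each call consumes a key and returns a fresh one along with the sample, so the PRNG state is threaded through explicitly rather than mutated in place. For example (illustrative only; `demo_key` isn't used below):

# Split-and-sample pattern: pass a key in, get a new key back with the sample.
demo_key = random.PRNGKey(0)
demo_key, eps = normal_sample(demo_key, (3, num_features))
print(eps.shape)  # (3, 10)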
key = random.PRNGKey(10003)
beta_loc = np.zeros(num_features, np.float32)
beta_log_scale = np.zeros(num_features, np.float32)
step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))
0	-180.853881836
10	-113.060455322
20	-102.737258911
30	-99.7873535156
40	-98.9089889526
50	-98.297454834
60	-98.1863174438
70	-97.5797195435
80	-97.2860031128
90	-97.4699630737
...
970	-97.3353042603
980	-97.349609375
990	-97.0967559814
Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact.
plt.figure(figsize=(7, 7))
plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plt.plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
plt.plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')
plot_scale = 3
plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
plt.xlabel('True beta')
plt.ylabel('Estimated beta')
plt.legend(loc='best')
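As a rough numeric check on the coverage claim above (purely illustrative; `covered` isn't defined anywhere else):

# Count how many true coefficients fall inside the approximate 2-sigma intervals.
covered = int(np.sum(np.abs(true_beta - beta_loc) < 2 * np.exp(beta_log_scale)))
print('{} of {} true coefficients inside the 2-sigma error bars'.format(covered, num_features))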