#!/usr/bin/env python
# coding: utf-8

# # Autobatching log-densities example
#
# This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
#
# Inspired by a notebook by @davmre.

# In[ ]:


get_ipython().system("pip install --upgrade -q https://storage.googleapis.com/jax-releases/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.23-cp36-none-linux_x86_64.whl")
get_ipython().system('pip install --upgrade -q jax')


# In[1]:


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import itertools
import re
import sys
import time

from matplotlib.pyplot import *

import jax
from jax import lax
from jax import numpy as np
from jax import scipy
from jax import random

import numpy as onp
import scipy as oscipy


# # Generate a fake binary classification dataset

# In[2]:


onp.random.seed(10009)

num_features = 10
num_points = 100

true_beta = onp.random.randn(num_features).astype(np.float32)
all_x = onp.random.randn(num_points, num_features).astype(np.float32)
y = (onp.random.rand(num_points) < oscipy.special.expit(all_x.dot(true_beta))).astype(np.int32)


# In[3]:


y


# # Write the log-joint function for the model
#
# We'll write a non-batched version, a manually batched version, and an autobatched version.

# ## Non-batched

# In[4]:


def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `np.sum`.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.))
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
    return result


# In[5]:


log_joint(onp.random.randn(num_features))


# In[6]:


# This doesn't work, because we didn't write `log_joint()` to handle batching.
batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)

log_joint(onp.random.randn(batch_size, num_features))


# ## Manually batched

# In[7]:


def batched_log_joint(beta):
    result = 0.
    # Here (and below) `sum` needs an `axis` parameter. At best, forgetting to set axis
    # or setting it incorrectly yields an error; at worst, it silently changes the
    # semantics of the model.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=1.), axis=-1)
    # Note the multiple transposes. Getting this right is not rocket science,
    # but it's also not totally mindless. (I didn't get it right on the first
    # try.)
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta.T).T)), axis=-1)
    return result


# In[8]:


batch_size = 10
batched_test_beta = onp.random.randn(batch_size, num_features)

batched_log_joint(batched_test_beta)


# ## Autobatched with vmap
#
# It just works.

# In[9]:


vmap_batched_log_joint = jax.vmap(log_joint)
vmap_batched_log_joint(batched_test_beta)
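# As a quick sanity check (an addition to the original notebook), we can confirm that the
# manually batched and `vmap`-batched log-joints agree on the same batch of inputs. The
# tolerance below is an arbitrary choice.

# In[ ]:


# Both implementations compute the same per-example log-joint, so they should agree up
# to floating-point error on the batch defined above.
print(np.allclose(batched_log_joint(batched_test_beta),
                  vmap_batched_log_joint(batched_test_beta),
                  atol=1e-4))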
# # Self-contained variational inference example
#
# A little code is copied from above.

# ## Set up the (batched) log-joint function

# In[10]:


@jax.jit
def log_joint(beta):
    result = 0.
    # Note that no `axis` parameter is provided to `np.sum`.
    result = result + np.sum(scipy.stats.norm.logpdf(beta, loc=0., scale=10.))
    result = result + np.sum(-np.log(1 + np.exp(-(2*y-1) * np.dot(all_x, beta))))
    return result

batched_log_joint = jax.jit(jax.vmap(log_joint))


# ## Define the ELBO and its gradient

# In[11]:


def elbo(beta_loc, beta_log_scale, epsilon):
    beta_sample = beta_loc + np.exp(beta_log_scale) * epsilon
    return np.mean(batched_log_joint(beta_sample), 0) + np.sum(beta_log_scale - 0.5 * onp.log(2*onp.pi))

elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))


# ## Optimize the ELBO using SGD

# In[12]:


def normal_sample(key, shape):
    """Convenience function for quasi-stateful RNG."""
    new_key, sub_key = random.split(key)
    return new_key, random.normal(sub_key, shape)

normal_sample = jax.jit(normal_sample, static_argnums=(1,))

key = random.PRNGKey(10003)

beta_loc = np.zeros(num_features, np.float32)
beta_log_scale = np.zeros(num_features, np.float32)

step_size = 0.01
batch_size = 128
epsilon_shape = (batch_size, num_features)
for i in range(1000):
    key, epsilon = normal_sample(key, epsilon_shape)
    elbo_val, (beta_loc_grad, beta_log_scale_grad) = elbo_val_and_grad(
        beta_loc, beta_log_scale, epsilon)
    beta_loc += step_size * beta_loc_grad
    beta_log_scale += step_size * beta_log_scale_grad
    if i % 10 == 0:
        print('{}\t{}'.format(i, elbo_val))


# ## Display the results
#
# Coverage isn't quite as good as we might like, but it's not bad, and nobody said variational inference was exact.

# In[13]:


figure(figsize=(7, 7))
plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
plot(true_beta, beta_loc + 2*np.exp(beta_log_scale), 'r.',
     label=r'Approximated Posterior $2\sigma$ Error Bars')
plot(true_beta, beta_loc - 2*np.exp(beta_log_scale), 'r.')
plot_scale = 3
plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
xlabel('True beta')
ylabel('Estimated beta')
legend(loc='best')
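# As an optional follow-up (an addition to the original notebook), we can put a rough number
# on the "coverage" remark above by counting how many of the true coefficients fall inside
# the approximate posterior's 2-sigma intervals.

# In[ ]:


# Fraction of true betas covered by the +/- 2 sigma intervals of the variational posterior.
lower = beta_loc - 2 * np.exp(beta_log_scale)
upper = beta_loc + 2 * np.exp(beta_log_scale)
coverage = np.mean((true_beta >= lower) & (true_beta <= upper))
print('2-sigma coverage: {:.2f}'.format(float(coverage)))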