Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
Forked from neural_network_and_data_loading.ipynb
Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary
Let's combine everything we showed in the quickstart notebook to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use the tensorflow/datasets data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\.//' -e 's/\..*//')/jaxlib-0.1.13-cp36-none-linux_x86_64.whl
!pip install --upgrade -q jax
from __future__ import print_function, division, absolute_import
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
Let's get a few bookkeeping items out of the way.
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
layer_sizes = [784, 512, 512, 10]
param_scale = 0.1
step_size = 0.0001
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
/usr/local/google/home/rsepassi/python/fresh/lib/python3.6/site-packages/jax/lib/xla_bridge.py:146: UserWarning: No GPU found, falling back to CPU. warnings.warn('No GPU found, falling back to CPU.')
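To make the layout of params concrete, here is a minimal sketch that prints the weight and bias shapes of each layer. It restates the two helpers above so that it runs on its own; the `shapes` variable is only there for illustration.

```python
from jax import random

def random_layer_params(m, n, key, scale=1e-2):
  # One key for the weights, one for the biases
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

def init_network_params(sizes, key):
  # zip stops at the shorter sequence, so one of the split keys goes unused
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

params = init_network_params([784, 512, 512, 10], random.PRNGKey(0))

# Each layer is a (weights, biases) pair; weights are stored as (outputs, inputs)
shapes = [(w.shape, b.shape) for w, b in params]
for w_shape, b_shape in shapes:
  print(w_shape, b_shape)
```

Storing the weights as (outputs, inputs) is what lets a per-example forward pass compute a matrix-vector product with the weights on the left.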
Let us first define our prediction function. Note that we're defining this for a single image example. We're going to use JAX's vmap function to automatically handle mini-batches, with no performance penalty.
# In recent JAX releases logsumexp lives in jax.scipy.special (formerly jax.scipy.misc)
from jax.scipy.special import logsumexp
def relu(x):
  return np.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = np.dot(w, activations) + b
    activations = relu(outputs)

  final_w, final_b = params[-1]
  logits = np.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
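The last line of predict subtracts logsumexp(logits) from the logits, so the function returns log-probabilities (a log-softmax) rather than raw scores. A small sketch with made-up logits checks that exponentiating the result gives a distribution that sums to one. The import path `jax.scipy.special` is the one used by recent JAX releases, and the comparison against `jax.nn.log_softmax` is only a sanity check, not part of the model above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Made-up logits for a three-class problem (illustrative values only)
logits = jnp.array([2.0, -1.0, 0.5])

# Same normalization as the last line of predict
log_probs = logits - logsumexp(logits)
probs = jnp.exp(log_probs)
total = float(jnp.sum(probs))

# jax.nn.log_softmax computes the same quantity
reference = jax.nn.log_softmax(logits)
print(total, bool(jnp.allclose(log_probs, reference, atol=1e-6)))
```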
Let's check that our prediction function only works on single images.
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)
(10,)
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')
Invalid shapes!
# Let's upgrade it to handle batches using `vmap`
# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))
# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)
(10, 10)
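To see what in_axes=(None, 0) means, here is a minimal sketch on a toy one-layer function. `affine` and the array sizes are hypothetical stand-ins for predict and the MNIST shapes. None tells vmap not to map over the parameters, and 0 tells it to map over the leading axis of the inputs. The result should match both an explicit Python loop and a hand-written batched matrix product.

```python
import jax.numpy as jnp
from jax import random, vmap

def affine(params, x):
  # Hypothetical single-example function: one dense layer, no nonlinearity
  w, b = params
  return jnp.dot(w, x) + b

key_w, key_x = random.split(random.PRNGKey(0))
w = random.normal(key_w, (3, 4))
b = jnp.zeros(3)
xs = random.normal(key_x, (5, 4))  # a batch of 5 examples

# params are shared across the batch (None); xs is mapped over axis 0
batched = vmap(affine, in_axes=(None, 0))((w, b), xs)

# Two reference computations for comparison
looped = jnp.stack([affine((w, b), x) for x in xs])
matmul = xs @ w.T + b

print(batched.shape)
```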
At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of predict, which we should be able to use in a loss function. We should be able to use grad to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use jit to speed up everything.
def one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

def accuracy(params, images, targets):
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(batched_predict(params, images), axis=1)
  return np.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -np.sum(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
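The update function relies on grad returning gradients with the same nested structure as params, which is why the two can be zipped together. Here is a minimal sketch of that pattern on a toy problem. The squared-error `toy_loss`, the data, and `STEP_SIZE` are all hypothetical and chosen so the numbers are easy to check by hand.

```python
import jax.numpy as jnp
from jax import grad, jit

STEP_SIZE = 0.1  # hypothetical value, chosen for the toy problem below

def toy_loss(params, x, y):
  # Hypothetical loss: summed squared error of a linear model
  w, b = params
  pred = jnp.dot(x, w) + b
  return jnp.sum((pred - y) ** 2)

params = (jnp.array([1.0, 2.0]), jnp.array(0.5))
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 2.0])

# grad differentiates with respect to the first argument and
# returns a (dw, db) tuple mirroring the (w, b) structure of params
grads = grad(toy_loss)(params, x, y)

@jit
def toy_update(params, x, y):
  grads = grad(toy_loss)(params, x, y)
  return tuple(p - STEP_SIZE * g for p, g in zip(params, grads))

new_params = toy_update(params, x, y)
loss_before = float(toy_loss(params, x, y))
loss_after = float(toy_loss(new_params, x, y))
print(loss_before, loss_after)
```

One gradient step should lower the loss here. The same zip-over-parameters pattern is what update applies to the list of (w, b) layer tuples.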
tensorflow/datasets
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the tensorflow/datasets data loader.
# Install tensorflow-datasets
# TODO(rsepassi): Switch to stable version on release
!pip install -q --upgrade tfds-nightly tf-nightly
import tensorflow_datasets as tfds
data_dir = '/tmp/tfds'
# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c
# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = np.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)
# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = np.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)
Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
import time
def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)
for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = np.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Epoch 0 in 4.93 sec
Training set accuracy 0.9690666794776917
Test set accuracy 0.9631999731063843
Epoch 1 in 3.91 sec
Training set accuracy 0.9807999730110168
Test set accuracy 0.97079998254776
Epoch 2 in 4.02 sec
Training set accuracy 0.9878833293914795
Test set accuracy 0.9763000011444092
Epoch 3 in 4.03 sec
Training set accuracy 0.992733359336853
Test set accuracy 0.9787999987602234
Epoch 4 in 3.95 sec
Training set accuracy 0.9907500147819519
Test set accuracy 0.9745000004768372
Epoch 5 in 4.01 sec
Training set accuracy 0.9953666925430298
Test set accuracy 0.9782000184059143
Epoch 6 in 3.90 sec
Training set accuracy 0.9984833598136902
Test set accuracy 0.9815000295639038
Epoch 7 in 3.93 sec
Training set accuracy 0.9991166591644287
Test set accuracy 0.9824000000953674
Epoch 8 in 4.16 sec
Training set accuracy 0.999833345413208
Test set accuracy 0.982200026512146
Epoch 9 in 4.03 sec
Training set accuracy 0.999916672706604
Test set accuracy 0.9829999804496765
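The training loop only needs an iterable that yields (images, labels) batches as NumPy arrays, so the data source is interchangeable. As a minimal sketch, assuming the arrays already fit in memory, a plain-NumPy batcher could stand in for get_train_batches. The name `iterate_batches`, its `seed` argument, and the tiny synthetic arrays are hypothetical and only serve the illustration.

```python
import numpy as onp  # plain NumPy; onp keeps it distinct from jax.numpy

def iterate_batches(images, labels, batch_size, seed=0):
  """Yield shuffled (images, labels) mini-batches covering the data once."""
  rng = onp.random.RandomState(seed)
  perm = rng.permutation(len(images))
  for start in range(0, len(images), batch_size):
    idx = perm[start:start + batch_size]
    yield images[idx], labels[idx]

# Tiny synthetic data set: row i of `images` starts with the value 4 * i
images = onp.arange(10 * 4, dtype=onp.float32).reshape(10, 4)
labels = onp.arange(10)

batches = list(iterate_batches(images, labels, batch_size=4))
batch_sizes = [len(x) for x, _ in batches]
print(batch_sizes)
```

The last batch is smaller than the others when the data set size is not a multiple of the batch size. With a jit-compiled update, a new batch shape triggers one extra compilation, which is harmless here.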
We've now used the whole of the JAX API: grad for derivatives, jit for speedups and vmap for auto-vectorization.
We used NumPy to specify all of our computation, borrowed the great data loaders from tensorflow/datasets, and ran the whole thing on the accelerator backend JAX selected (a GPU when one is available; the run recorded above fell back to CPU).