Quickstart

This notebook walks through some of the basic functionality provided by the renn package.

[1]:
# Imports
from functools import partial

import jax
import jax.numpy as jnp

import renn

base_key = jax.random.PRNGKey(0)

Build and train RNNs

First, we will use the provided RNN cell classes to build different RNN architectures.

[2]:
# Here, we build an RNN composed of a single GRU cell.
cell = renn.GRU(32)
print(f'Made a GRU cell with {cell.num_units} units.')
Made a GRU cell with 32 units.

We can initialize the hidden state for this cell as follows:

[3]:
key, base_key = jax.random.split(base_key)
current_state = cell.init_initial_state(key)
print(f'Initialized state with shape: {current_state.shape}')
Initialized state with shape: (32,)

We can initialize the cell’s trainable parameters using cell.init:

[4]:
num_timesteps = 100
input_dim = 2
input_shape = (num_timesteps, input_dim)

key, base_key = jax.random.split(base_key)
output_shape, params = cell.init(key, input_shape)

print(f'Outputs have shape: {output_shape}')
Outputs have shape: (100, 32)

The GRU cell is a subclass of RNNCell. All RNNCells have an apply method that computes a single RNN step.

[5]:
key, base_key = jax.random.split(base_key)
inputs = jax.random.normal(key, (input_dim,))

next_state = cell.apply(params, inputs, current_state)
print(f'Next state has shape: {next_state.shape}')
Next state has shape: (32,)
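To make "a single RNN step" concrete, here is a sketch of one GRU update written directly in JAX. It is not renn's implementation: the parameter layout (the hypothetical `init_gru_params` and its dict keys) is an assumption made for illustration. It follows one common GRU convention, in which the update gate `z` interpolates between the previous state and a `tanh` candidate state.

```python
import jax
import jax.numpy as jnp


def init_gru_params(key, input_dim, num_units):
    """Hypothetical parameter layout: one (W, U, b) triple per gate."""
    keys = jax.random.split(key, 6)
    scale_in = 1.0 / jnp.sqrt(input_dim)
    scale_rec = 1.0 / jnp.sqrt(num_units)
    params = {}
    for i, gate in enumerate(["update", "reset", "candidate"]):
        params[gate] = {
            "W": scale_in * jax.random.normal(keys[2 * i], (num_units, input_dim)),
            "U": scale_rec * jax.random.normal(keys[2 * i + 1], (num_units, num_units)),
            "b": jnp.zeros(num_units),
        }
    return params


def gru_step(params, x, h):
    """One GRU update h_next = F(h, x) using standard GRU gate equations."""
    def affine(gate, h_in):
        p = params[gate]
        return p["W"] @ x + p["U"] @ h_in + p["b"]

    z = jax.nn.sigmoid(affine("update", h))         # update gate
    r = jax.nn.sigmoid(affine("reset", h))          # reset gate
    h_tilde = jnp.tanh(affine("candidate", r * h))  # candidate state
    return (1.0 - z) * h + z * h_tilde              # mix of old state and candidate


key = jax.random.PRNGKey(0)
params = init_gru_params(key, input_dim=2, num_units=32)
h0 = jnp.zeros(32)
x = jnp.ones(2)
h1 = gru_step(params, x, h0)
print(h1.shape)  # (32,)
```

Because each new state is a gated mix of the old state and a `tanh` output, a state that starts inside \([-1, 1]\) stays there.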

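To apply the RNN across an entire batch of sequences, we first create a batch of initial states with cell.get_initial_state, then call renn.unroll_rnn with the batched single-step function partial(cell.batch_apply, params):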

[6]:
batch_size = 8
key, base_key = jax.random.split(base_key)
batched_inputs = jax.random.normal(key, (batch_size,) + input_shape)
batch_initial_states = cell.get_initial_state(params, batch_size=batch_size)

states = renn.unroll_rnn(batch_initial_states, batched_inputs, partial(cell.batch_apply, params))

print(f'Applied RNN to a batch of sequences, got back states with shape: {states.shape}')
Applied RNN to a batch of sequences, got back states with shape: (8, 100, 32)
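Conceptually, unrolling repeatedly applies the single-step update along the time axis and stacks the resulting states. The sketch below illustrates that idea with `jax.lax.scan` over time and `jax.vmap` over the batch. It assumes a hypothetical toy `tanh` cell (`vanilla_step`) in place of the GRU and is not how renn.unroll_rnn is actually implemented, but it produces the same (batch, time, units) output shape as above.

```python
import jax
import jax.numpy as jnp


def vanilla_step(params, x, h):
    """Toy tanh RNN cell standing in for any single-step apply function."""
    W, U, b = params
    return jnp.tanh(W @ x + U @ h + b)


def unroll(params, initial_state, inputs):
    """Unroll one sequence: inputs (T, input_dim) -> states (T, num_units)."""
    def scan_fn(h, x):
        h_next = vanilla_step(params, x, h)
        return h_next, h_next  # carry the state forward and also record it

    _, states = jax.lax.scan(scan_fn, initial_state, inputs)
    return states


# Map over the leading batch axis of the initial states and inputs; share params.
batched_unroll = jax.vmap(unroll, in_axes=(None, 0, 0))

num_units, input_dim, num_timesteps, batch_size = 32, 2, 100, 8
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    jax.random.normal(k1, (num_units, input_dim)) / jnp.sqrt(input_dim),
    jax.random.normal(k2, (num_units, num_units)) / jnp.sqrt(num_units),
    jnp.zeros(num_units),
)
inputs = jax.random.normal(k3, (batch_size, num_timesteps, input_dim))
h0 = jnp.zeros((batch_size, num_units))

states = batched_unroll(params, h0, inputs)
print(states.shape)  # (8, 100, 32)
```

Using `lax.scan` rather than a Python loop keeps the whole unroll as a single traced computation, which is what makes it cheap to `jit` and differentiate.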

These pieces (cells, initial states, and unroll_rnn) can be combined to train RNNs on many kinds of sequential data.

Analyzing RNNs

The RNN cells in renn are amenable to analysis. One useful tool is linearizing the RNN: computing a first-order (linear) Taylor approximation of the nonlinear RNN update.

Mathematically, we can approximate the RNN update \(F\) around a particular expansion point (\(h\), \(x\)) as follows:

\[F(h + \Delta h, x + \Delta x) \approx F(h, x) + \frac{\partial F}{\partial h} \left(\Delta h\right) + \frac{\partial F}{\partial x} \left(\Delta x\right)\]

In the above equation, the term \(\frac{\partial F}{\partial h}\) is the recurrent Jacobian of the RNN and the term \(\frac{\partial F}{\partial x}\) is the input Jacobian, both evaluated at the expansion point (\(h\), \(x\)).
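A common choice of expansion point is a fixed point \((h^*, x^*)\) of the update, where \(F(h^*, x^*) = h^*\). There the constant term of the expansion is simply \(h^*\). Writing \(h_t = h^* + \Delta h_t\) and \(x_t = x^* + \Delta x_t\) for the update \(h_{t+1} = F(h_t, x_t)\), the deviations follow approximately linear dynamics, with both Jacobians evaluated at \((h^*, x^*)\):

\[
\begin{aligned}
h^* + \Delta h_{t+1} = F(h^* + \Delta h_t,\, x^* + \Delta x_t)
  &\approx h^* + \frac{\partial F}{\partial h} \left(\Delta h_t\right) + \frac{\partial F}{\partial x} \left(\Delta x_t\right) \\
\Rightarrow \quad \Delta h_{t+1} &\approx \frac{\partial F}{\partial h} \left(\Delta h_t\right) + \frac{\partial F}{\partial x} \left(\Delta x_t\right)
\end{aligned}
\]

With the input held at \(x^*\) (\(\Delta x_t = 0\)), a deviation evolves as \(\Delta h_t \approx \left(\frac{\partial F}{\partial h}\right)^t \Delta h_0\). Roughly, it decays along eigenvectors of the recurrent Jacobian whose eigenvalues have magnitude below 1 and grows along those with magnitude above 1.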

The rec_jac and inp_jac methods on the cell compute these two Jacobians for our GRU cell at a given expansion point (here, inputs and current_state):

[7]:
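# Recurrent Jacobian dF/dh, evaluated at (inputs, current_state).
rec_jacobian = cell.rec_jac(params, inputs, current_state)
print(f'Recurrent Jacobian has shape: {rec_jacobian.shape}')

# Input Jacobian dF/dx, evaluated at the same point.
inp_jacobian = cell.inp_jac(params, inputs, current_state)
print(f'Input Jacobian has shape: {inp_jacobian.shape}')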
Recurrent Jacobian has shape: (32, 32)
Input Jacobian has shape: (32, 2)
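The same two Jacobians can be obtained for any differentiable update using JAX's automatic differentiation. The sketch below applies `jax.jacfwd` to a hypothetical toy `tanh` cell (`rnn_step`, not part of renn) and then checks the first-order approximation numerically: for a small perturbation, the gap between the true update and its linearization should be much smaller than the perturbation itself.

```python
import jax
import jax.numpy as jnp


def rnn_step(params, x, h):
    """Toy tanh RNN update F(h, x); a stand-in for cell.apply."""
    W, U, b = params
    return jnp.tanh(W @ x + U @ h + b)


num_units, input_dim = 32, 2
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = (
    jax.random.normal(k1, (num_units, input_dim)) / jnp.sqrt(input_dim),
    jax.random.normal(k2, (num_units, num_units)) / jnp.sqrt(num_units),
    jnp.zeros(num_units),
)
x = jax.random.normal(k3, (input_dim,))
h = 0.1 * jax.random.normal(k4, (num_units,))

# Jacobians of F with respect to h (argument 2) and x (argument 1).
rec_jac = jax.jacfwd(rnn_step, argnums=2)(params, x, h)  # shape (32, 32)
inp_jac = jax.jacfwd(rnn_step, argnums=1)(params, x, h)  # shape (32, 2)

# Compare the exact update with its first-order Taylor approximation.
dh = 1e-3 * jnp.ones(num_units)
dx = 1e-3 * jnp.ones(input_dim)
exact = rnn_step(params, x + dx, h + dh)
linear = rnn_step(params, x, h) + rec_jac @ dh + inp_jac @ dx

print(rec_jac.shape, inp_jac.shape)
# The error is second order in the perturbation, so far below 1e-3.
print(float(jnp.max(jnp.abs(exact - linear))))
```

`jax.jacfwd` is a natural fit here because the input is low-dimensional; `jax.jacrev` would return the same matrices.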

renn also contains helper functions for numerically finding fixed points of the RNN, for building and training different RNN architectures, and for training and analyzing RNN optimizers.
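renn's fixed-point helpers are not shown in this notebook, so the following is only a sketch of one common approach, not renn's API. We define \(q(h) = \tfrac{1}{2}\lVert F(h, x) - h \rVert^2\), often called the speed, which is zero exactly at fixed points, and minimize it by gradient descent with the input held constant. The toy cell `rnn_step`, the function `find_fixed_point`, and its hyperparameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def rnn_step(params, x, h):
    """Toy tanh RNN update F(h, x); a stand-in for a trained cell's apply."""
    W, U, b = params
    return jnp.tanh(W @ x + U @ h + b)


def speed(h, params, x):
    """q(h) = 0.5 * ||F(h, x) - h||^2, which is zero exactly at fixed points."""
    return 0.5 * jnp.sum((rnn_step(params, x, h) - h) ** 2)


def find_fixed_point(params, x, h_init, learning_rate=0.2, num_steps=2000):
    """Plain gradient descent on q(h), holding the input x fixed."""
    grad_q = jax.grad(speed)  # gradient with respect to h (argument 0)

    def body(_, h):
        return h - learning_rate * grad_q(h, params, x)

    return jax.lax.fori_loop(0, num_steps, body, h_init)


num_units, input_dim = 32, 2
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = (
    jax.random.normal(k1, (num_units, input_dim)) / jnp.sqrt(input_dim),
    # Small recurrent weights keep this toy example well behaved.
    0.3 * jax.random.normal(k2, (num_units, num_units)) / jnp.sqrt(num_units),
    jnp.zeros(num_units),
)
x_const = jax.random.normal(k3, (input_dim,))
h_init = 0.5 * jax.random.normal(k4, (num_units,))

h_star = find_fixed_point(params, x_const, h_init)
residual = jnp.max(jnp.abs(rnn_step(params, x_const, h_star) - h_star))
print(float(speed(h_star, params, x_const)), float(residual))
```

For a trained RNN, this search is typically run from many initial states (for example, states visited while unrolling the network) to locate multiple fixed points, each of which can then be linearized as described above.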

In future tutorials, we will explore some of these additional use cases!