renn.rnn package

Submodules

renn.rnn.cells module

Recurrent neural network (RNN) cells.

class renn.rnn.cells.LinearRNN(A: jax.numpy.array, W: jax.numpy.array, b: jax.numpy.array)

Bases: object

Dataclass for storing parameters of a Linear RNN.

apply(x, h) → jax.numpy.array

Linear RNN update.
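As an illustration only, the update can be sketched by assuming the three stored parameters act as h_next = A h + W x + b. That functional form is inferred from the parameter names and is not stated in this documentation, and LinearRNNSketch is a hypothetical stand-in rather than the renn class.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass
class LinearRNNSketch:
    """Hypothetical stand-in for renn.rnn.cells.LinearRNN (assumed update form)."""
    A: jnp.ndarray  # (N, N) recurrent matrix
    W: jnp.ndarray  # (N, M) input matrix
    b: jnp.ndarray  # (N,) bias

    def apply(self, x, h):
        # Assumed linear update: h_next = A @ h + W @ x + b.
        return self.A @ h + self.W @ x + self.b


rnn = LinearRNNSketch(A=0.5 * jnp.eye(2), W=jnp.ones((2, 1)), b=jnp.zeros(2))
h_next = rnn.apply(jnp.array([1.0]), jnp.array([2.0, 4.0]))
print(h_next)  # [2. 3.]
```

The argument order (input first, then state) mirrors the documented apply(x, h) signature.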

flatten()

class renn.rnn.cells.RNNCell(num_units, h_init=<function zeros>)

Bases: object

Base class for all RNN Cells.

An RNNCell must implement the following methods:
  • init(PRNGKey, input_shape) -> output_shape, rnn_params
  • apply(params, inputs, state) -> next_state

apply(params, inputs, state)
get_initial_state(params, batch_size=None)

Gets initial RNN states.

Parameters:
  • params – RNN parameters.
  • batch_size – batch size of initial states to create.
Returns:

An ndarray with shape (batch size, num_units).

init(key, input_shape)
init_initial_state(key)
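To make the init/apply contract concrete, here is a minimal sketch of a cell that follows it. LeakyCellSketch and its leaky-integrator update are hypothetical and illustrative only; the class does not subclass renn's RNNCell, and the unbatched behavior of get_initial_state when batch_size is None is an assumption.

```python
import jax
import jax.numpy as jnp


class LeakyCellSketch:
    """Hypothetical cell following the init/apply contract (not a renn cell)."""

    def __init__(self, num_units, decay=0.9):
        self.num_units = num_units
        self.decay = decay

    def init(self, key, input_shape):
        # input_shape is assumed to be (batch, input_dim); only input_dim is used.
        input_dim = input_shape[-1]
        W = jax.random.normal(key, (input_dim, self.num_units)) / jnp.sqrt(input_dim)
        output_shape = tuple(input_shape[:-1]) + (self.num_units,)
        return output_shape, W

    def apply(self, params, inputs, state):
        # next_state = decay * state + inputs @ W (works on batched arrays).
        return self.decay * state + inputs @ params

    def get_initial_state(self, params, batch_size=None):
        # Assumed: an unbatched state when batch_size is None.
        shape = (self.num_units,) if batch_size is None else (batch_size, self.num_units)
        return jnp.zeros(shape)


cell = LeakyCellSketch(num_units=3)
out_shape, params = cell.init(jax.random.PRNGKey(0), (4, 2))
h0 = cell.get_initial_state(params, batch_size=4)
h1 = cell.apply(params, jnp.ones((4, 2)), h0)
print(out_shape, h1.shape)  # (4, 3) (4, 3)
```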

class renn.rnn.cells.StackedCell(layers)

Bases: renn.rnn.cells.RNNCell

Stacks multiple RNN cells together.

A stacked RNN cell is specified by a list of RNN cells, optionally with stax.Dense layers between them.

Note that the full hidden state for this cell is the concatenation of hidden states from all of the cells in the stack.
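Because the full state is a concatenation, recovering each layer's state amounts to splitting along the feature axis. The helper below is a hypothetical sketch that assumes per-layer states are concatenated along the last axis in stack order. renn's exact layout is not specified here.

```python
import jax.numpy as jnp
import numpy as np


def split_stacked_state(full_state, layer_sizes):
    """Hypothetical helper: split a concatenated stacked-cell state per layer.

    Assumes full_state has shape (batch, sum(layer_sizes)), with layer states
    concatenated along the last axis in stack order.
    """
    boundaries = np.cumsum(layer_sizes)[:-1].tolist()
    return jnp.split(full_state, boundaries, axis=-1)


h_layer1 = jnp.zeros((2, 3))
h_layer2 = jnp.ones((2, 5))
full = jnp.concatenate([h_layer1, h_layer2], axis=-1)  # shape (2, 8)
parts = split_stacked_state(full, [3, 5])
print([p.shape for p in parts])  # [(2, 3), (2, 5)]
```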

apply(params, inputs, state)

Applies a single step of a Stacked RNN.

init(key, input_shape)

Initializes parameters of a Stacked RNN.

class renn.rnn.cells.GRU(num_units, gate_bias=0.0, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)

Bases: renn.rnn.cells.RNNCell

Gated recurrent unit.
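For reference, one common GRU formulation can be sketched as a pure function. The params dict layout and the choice to add gate_bias to both gate pre-activations are assumptions of this sketch; gate conventions also vary between implementations (some swap the roles of z and 1 - z), so renn's GRU may differ.

```python
import jax
import jax.numpy as jnp


def gru_step(params, x, h, gate_bias=0.0):
    """One step of a common GRU formulation (hypothetical params layout)."""
    xh = jnp.concatenate([x, h], axis=-1)
    # Update (z) and reset (r) gates; gate_bias shifts their pre-activations.
    z = jax.nn.sigmoid(xh @ params["W_z"] + params["b_z"] + gate_bias)
    r = jax.nn.sigmoid(xh @ params["W_r"] + params["b_r"] + gate_bias)
    # Candidate state uses the reset-gated previous state.
    c = jnp.tanh(jnp.concatenate([x, r * h], axis=-1) @ params["W_c"] + params["b_c"])
    # Interpolate between the previous state and the candidate.
    return (1.0 - z) * h + z * c


input_dim, num_units = 2, 3
keys = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    name: 0.1 * jax.random.normal(k, (input_dim + num_units, num_units))
    for name, k in zip(["W_z", "W_r", "W_c"], keys)
}
params.update({b: jnp.zeros(num_units) for b in ["b_z", "b_r", "b_c"]})
h = gru_step(params, jnp.ones((4, input_dim)), jnp.zeros((4, num_units)))
print(h.shape)  # (4, 3)
```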

apply(params, inputs, state)
init(key, input_shape)

class renn.rnn.cells.LSTM(num_units, forget_bias=1.0, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)

Bases: renn.rnn.cells.RNNCell

Long short-term memory (LSTM).
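A standard LSTM step can be sketched as follows. renn's apply takes a single full_state argument whose layout is not documented here, so this sketch passes the hidden state h and cell state c separately; the fused weight layout and the placement of forget_bias on the forget-gate pre-activation are assumptions.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, x, h, c, forget_bias=1.0):
    """One step of a standard LSTM; returns (h_next, c_next).

    params is a hypothetical dict with a fused weight "W" of shape
    (input_dim + num_units, 4 * num_units) and a bias "b" of shape (4 * num_units,).
    """
    xh = jnp.concatenate([x, h], axis=-1)
    # Fused pre-activations for input (i), forget (f), candidate (g), output (o).
    i, f, g, o = jnp.split(xh @ params["W"] + params["b"], 4, axis=-1)
    # forget_bias shifts the forget gate toward retaining the previous cell state.
    c_next = jax.nn.sigmoid(f + forget_bias) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h_next = jax.nn.sigmoid(o) * jnp.tanh(c_next)
    return h_next, c_next


input_dim, num_units = 2, 3
params = {
    "W": 0.1 * jax.random.normal(jax.random.PRNGKey(0), (input_dim + num_units, 4 * num_units)),
    "b": jnp.zeros(4 * num_units),
}
h, c = lstm_step(params, jnp.ones((4, input_dim)),
                 jnp.zeros((4, num_units)), jnp.zeros((4, num_units)))
print(h.shape, c.shape)  # (4, 3) (4, 3)
```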

apply(params, inputs, full_state)
init(key, input_shape)

class renn.rnn.cells.VanillaRNN(num_units, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)

Bases: renn.rnn.cells.RNNCell

Vanilla RNN Cell.

apply(params, inputs, state)

Applies a single step of a Vanilla RNN.

init(key, input_shape)

Initializes the parameters of a Vanilla RNN.
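A single vanilla RNN step is commonly written as h_next = tanh(x W_in + h W_rec + b). The sketch below assumes that form and a hypothetical params dict; renn's parameter names and nonlinearity may differ.

```python
import jax.numpy as jnp


def vanilla_rnn_step(params, x, h):
    """Assumed vanilla RNN update: h_next = tanh(x @ W_in + h @ W_rec + b)."""
    return jnp.tanh(x @ params["W_in"] + h @ params["W_rec"] + params["b"])


params = {
    "W_in": jnp.zeros((2, 3)),
    "W_rec": jnp.eye(3),  # identity recurrence, for an easy-to-check example
    "b": jnp.zeros(3),
}
h_next = vanilla_rnn_step(params, jnp.ones((1, 2)), jnp.full((1, 3), 0.5))
print(h_next)  # each entry equals tanh(0.5), about 0.4621
```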

class renn.rnn.cells.UGRNN(num_units, gate_bias=0.0, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)

Bases: renn.rnn.cells.RNNCell

Update-gate RNN Cell.
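An update-gate RNN mixes the previous state with a tanh candidate through a single sigmoid gate. The sketch assumes h_next = g * h + (1 - g) * c, with gate_bias added to the gate pre-activation; both choices and the params layout are assumptions, not taken from renn.

```python
import jax
import jax.numpy as jnp


def ugrnn_step(params, x, h, gate_bias=0.0):
    """Assumed update-gate RNN step: h_next = g * h + (1 - g) * c."""
    xh = jnp.concatenate([x, h], axis=-1)
    c = jnp.tanh(xh @ params["W_c"] + params["b_c"])  # candidate state
    g = jax.nn.sigmoid(xh @ params["W_g"] + params["b_g"] + gate_bias)  # update gate
    return g * h + (1.0 - g) * c


num_in, num_units = 2, 3
zeros = {
    "W_c": jnp.zeros((num_in + num_units, num_units)),
    "W_g": jnp.zeros((num_in + num_units, num_units)),
    "b_c": jnp.zeros(num_units),
    "b_g": jnp.zeros(num_units),
}
h = jnp.ones((1, num_units))
# With zero weights the candidate is 0 and the gate equals sigmoid(gate_bias).
print(ugrnn_step(zeros, jnp.ones((1, num_in)), h))                 # 0.5 everywhere
print(ugrnn_step(zeros, jnp.ones((1, num_in)), h, gate_bias=5.0))  # close to 1.0
```

A large positive gate_bias pushes g toward 1, so the cell initially copies its previous state forward.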

apply(params, inputs, state)
init(key, input_shape)

renn.rnn.cells.embedding(vocab_size, embedding_size, initializer=<function orthogonal.<locals>.init>)

Builds a token embedding.

Parameters:
  • vocab_size – int, Size of the vocabulary (number of tokens).
  • embedding_size – int, Dimensionality of the embedding.
  • initializer – Initializer for the embedding (Default: orthogonal).
Returns:

  • init_fun – callable, initializes the embedding given a key and input_shape.
  • apply_fun – callable, converts a set of tokens to embedded vectors.
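A stax-style (init_fun, apply_fun) pair for a token embedding can be sketched as a table lookup. embedding_sketch is hypothetical; it uses jax.nn.initializers.orthogonal() to mirror the documented default, and the output-shape convention of appending embedding_size is an assumption.

```python
import jax
import jax.numpy as jnp


def embedding_sketch(vocab_size, embedding_size,
                     initializer=jax.nn.initializers.orthogonal()):
    """Hypothetical stax-style embedding returning (init_fun, apply_fun)."""

    def init_fun(key, input_shape):
        # input_shape: shape of the integer token array, e.g. (batch, time).
        emb = initializer(key, (vocab_size, embedding_size))
        return tuple(input_shape) + (embedding_size,), emb

    def apply_fun(params, tokens, **kwargs):
        # Row lookup: each integer token id selects one embedding vector.
        return jnp.take(params, tokens, axis=0)

    return init_fun, apply_fun


init_fun, apply_fun = embedding_sketch(vocab_size=10, embedding_size=4)
out_shape, emb = init_fun(jax.random.PRNGKey(0), (2, 5))
tokens = jnp.array([[0, 1, 2, 3, 4], [9, 8, 7, 6, 5]])
vectors = apply_fun(emb, tokens)
print(out_shape, vectors.shape)  # (2, 5, 4) (2, 5, 4)
```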

renn.rnn.fixed_points module

Fixed point finding routines.

renn.rnn.fixed_points.build_fixed_point_loss(rnn_cell, cell_params)

Builds a function that computes the speed of hidden states.

Parameters:
  • rnn_cell – an RNNCell instance.
  • cell_params – RNN parameters to use when applying the RNN.
Returns:

  • fixed_point_loss_fun – function that takes a batch of hidden states and inputs and computes the speed of the corresponding hidden states.
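A sketch of such a loss, assuming "speed" carries its common meaning in fixed point analysis, q(h) = 0.5 * ||F(h, x) - h||^2, averaged over the batch (the averaging is also an assumption). In renn the loss is built from an RNNCell and its parameters; here a plain update function stands in for the cell with parameters already bound.

```python
import jax.numpy as jnp


def build_fixed_point_loss_sketch(rnn_update):
    """Hypothetical variant taking rnn_update(h, x) -> next state directly."""

    def fixed_point_loss_fun(h, x):
        # Speed of each state: q(h) = 0.5 * ||F(h, x) - h||^2, then batch mean.
        speeds = 0.5 * jnp.sum((rnn_update(h, x) - h) ** 2, axis=-1)
        return jnp.mean(speeds)

    return fixed_point_loss_fun


# Toy linear update F(h, x) = 0.5 * h + x, whose fixed point is h* = 2 * x.
loss_fun = build_fixed_point_loss_sketch(lambda h, x: 0.5 * h + x)
x = jnp.ones((1, 2))
print(loss_fun(2.0 * x, x))            # 0.0 at the fixed point
print(loss_fun(jnp.zeros((1, 2)), x))  # 1.0 away from it
```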

renn.rnn.fixed_points.find_fixed_points(fp_loss_fun, initial_states, x_star, optimizer, tolerance, steps=range(0, 1000))

Run fixed point optimization.

Parameters:
  • fp_loss_fun – Function that computes fixed point speeds.
  • initial_states – Initial state seeds.
  • x_star – Input at which to compute fixed points.
  • optimizer – A jax.experimental.optimizers tuple.
  • tolerance – Stopping tolerance threshold.
  • steps – Iterator over steps.
Returns:

  • fixed_points – array of fixed points for each tolerance.
  • loss_hist – array containing the fixed point loss curve.
  • squared_speeds – array containing the squared speed of each fixed point.
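The optimization can be sketched with plain gradient descent on the speed, starting from each seed state. This is a simplification, not renn's implementation: the documented function takes an fp_loss_fun and a jax.experimental.optimizers tuple and also returns a loss history, whereas this hypothetical find_fixed_points_sketch hard-codes the update rule and returns only the surviving points and their squared speeds.

```python
import jax
import jax.numpy as jnp


def squared_speeds(rnn_update, h, x):
    # Per-state squared speed: 0.5 * ||F(h, x) - h||^2.
    return 0.5 * jnp.sum((rnn_update(h, x) - h) ** 2, axis=-1)


def find_fixed_points_sketch(rnn_update, initial_states, x_star,
                             learning_rate=0.5, tolerance=1e-6, num_steps=200):
    """Plain gradient descent on the speed from each seed state."""

    def loss_fun(h):
        # Summing (not averaging) keeps each state's gradient independent of batch size.
        return jnp.sum(squared_speeds(rnn_update, h, x_star))

    grad_fun = jax.jit(jax.grad(loss_fun))
    h = initial_states
    for _ in range(num_steps):
        h = h - learning_rate * grad_fun(h)
    speeds = squared_speeds(rnn_update, h, x_star)
    # Keep only candidates whose final speed is below the tolerance.
    return h[speeds < tolerance], speeds


# Toy update F(h, x) = 0.5 * h + x has a single fixed point at h* = 2 * x.
seeds = jnp.array([[0.0, 0.0], [5.0, -3.0], [1.0, 1.0]])
fixed_points, speeds = find_fixed_points_sketch(
    lambda h, x: 0.5 * h + x, seeds, x_star=jnp.ones(2))
print(fixed_points)  # every row is approximately [2., 2.]
```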

renn.rnn.fixed_points.exclude_outliers(points, threshold=inf, verbose=False)

Removes points that are not within threshold of at least one other point.
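A minimal sketch of this filtering, assuming Euclidean distance and a strict comparison against threshold; both choices, and the early return for the default threshold, are assumptions of this hypothetical helper rather than documented behavior.

```python
import numpy as np


def exclude_outliers_sketch(points, threshold=np.inf):
    """Keep points whose nearest other point is closer than threshold (Euclidean)."""
    points = np.asarray(points, dtype=float)
    if np.isinf(threshold):
        return points  # with the default threshold nothing is excluded
    diffs = points[:, None, :] - points[None, :, :]
    dists = np.sqrt(np.sum(diffs ** 2, axis=-1))
    np.fill_diagonal(dists, np.inf)  # a point is not its own neighbor
    return points[dists.min(axis=1) < threshold]


pts = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
print(exclude_outliers_sketch(pts, threshold=1.0))  # keeps the first two points
```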

renn.rnn.network module

Recurrent neural network (RNN) helper functions.

renn.rnn.network.build_rnn(num_tokens, emb_size, cell, num_outputs=1)

Builds an end-to-end recurrent neural network (RNN) model.

Parameters:
  • num_tokens – int, Number of different input tokens.
  • emb_size – int, Dimensionality of the embedding vectors.
  • cell – RNNCell to use as the core update function (see cells.py).
  • num_outputs – int, Number of outputs from the readout (Default: 1).
Returns:

  • init_fun – function that takes a PRNGKey and input_shape and returns the expected shapes and the initialized embedding, RNN, and readout parameters.
  • apply_fun – function that takes a tuple of network parameters and a batch of input tokens and applies the RNN to each sequence in the batch.
  • emb_apply – function that applies only the embedding.
  • readout_apply – function that applies only the readout.

renn.rnn.network.mse(y, yhat)

Mean squared error loss.

renn.rnn.network.eigsorted(jac)

Computes sorted eigenvalues and corresponding eigenvectors of a matrix.

Notes

The eigenvectors are stored in the columns of the returned matrices. The right and left eigenvectors are returned such that $J = R E L^\top$, where the columns of $R$ and $L$ are the right and left eigenvectors and $E$ is the diagonal matrix of eigenvalues.

Parameters:
  • jac – numpy array used to compute the eigendecomposition (must be square).
Returns:

  • rights – right eigenvectors, as columns in the returned array.
  • eigvals – numpy array of eigenvalues.
  • lefts – left eigenvectors, as columns in the returned array.
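The relation J = R E L^T holds when the left eigenvectors are taken as the columns of inv(R)^T, since then L^T R = I. The sketch below uses NumPy and assumes sorting by descending eigenvalue magnitude; the sort order renn uses is not stated here, and eigsorted_sketch is a hypothetical name.

```python
import numpy as np


def eigsorted_sketch(jac):
    """Eigendecomposition sorted by descending eigenvalue magnitude (assumed order)."""
    eigvals, rights = np.linalg.eig(jac)
    order = np.argsort(-np.abs(eigvals))
    eigvals, rights = eigvals[order], rights[:, order]
    # Left eigenvectors chosen so that lefts.T @ rights = I, hence J = R E L^T.
    lefts = np.linalg.inv(rights).T
    return rights, eigvals, lefts


J = np.array([[0.9, 0.2], [0.0, 0.5]])
R, E, L = eigsorted_sketch(J)
print(np.round(E.real, 3))                   # [0.9 0.5]
print(np.allclose(R @ np.diag(E) @ L.T, J))  # True
```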

renn.rnn.network.timescale(eigenvalues)

Converts eigenvalues into approximate time constants.
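For discrete-time dynamics with unit time steps, a mode with eigenvalue lambda decays as |lambda|^t = exp(-t / tau), which gives tau = -1 / log|lambda|. The sketch assumes that conversion; renn's exact formula is not stated here.

```python
import numpy as np


def timescale_sketch(eigenvalues):
    # |lambda|^t = exp(-t / tau)  =>  tau = -1 / log|lambda|.
    return -1.0 / np.log(np.abs(eigenvalues))


taus = timescale_sketch(np.array([np.exp(-0.1), np.exp(-1.0)]))
print(taus)  # approximately [10, 1]
```

An eigenvalue with magnitude exactly 1 yields an infinite time constant (NumPy emits a divide-by-zero warning), matching a mode that neither grows nor decays.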

renn.rnn.unroll module

Functions for unrolling recurrent neural networks (RNNs) over input sequences.

renn.rnn.unroll.unroll_rnn(initial_states, input_sequences, rnn_update, readout=<function identity>)

Unrolls an RNN on a batch of input sequences.

Given a batch of initial RNN states and a batch of input sequences, this function unrolls the RNN along each sequence. The RNN state is updated with the rnn_update function, and the readout converts the RNN state to outputs (by default the identity function).

B: batch size. N: number of RNN units. M: input dimensionality. T: sequence length.

Parameters:
  • initial_states – batch of initial states, with shape (B, N).
  • input_sequences – batch of inputs, with shape (B, T, M).
  • rnn_update – updates the RNN hidden state, given (inputs, current_states).
  • readout – applies the readout, given current states. If this is the identity function, then no readout is applied (returns the hidden states).
Returns:

  • outputs – batch of outputs, with shape (batch_size, sequence_length, num_outputs).
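Unrolling maps naturally onto jax.lax.scan, which iterates over the leading axis, so the sketch moves time to the front and back again. unroll_rnn_sketch is hypothetical; applying the readout to each updated state is an assumption, and M denotes the input dimensionality.

```python
import jax
import jax.numpy as jnp


def unroll_rnn_sketch(initial_states, input_sequences, rnn_update,
                      readout=lambda h: h):
    """initial_states: (B, N); input_sequences: (B, T, M) -> outputs (B, T, ...)."""

    def step(h, x_t):
        h_next = rnn_update(x_t, h)  # (inputs, current_states) order, as documented
        return h_next, readout(h_next)

    # lax.scan iterates over the leading axis, so move time to the front.
    inputs_time_major = jnp.swapaxes(input_sequences, 0, 1)  # (T, B, M)
    _, outputs = jax.lax.scan(step, initial_states, inputs_time_major)
    return jnp.swapaxes(outputs, 0, 1)  # back to (B, T, ...)


# Toy update that accumulates inputs: h_t = h_{t-1} + x_t.
outputs = unroll_rnn_sketch(
    initial_states=jnp.zeros((2, 3)),
    input_sequences=jnp.ones((2, 4, 3)),
    rnn_update=lambda x, h: h + x,
)
print(outputs.shape)     # (2, 4, 3)
print(outputs[0, :, 0])  # [1. 2. 3. 4.]
```

Using scan rather than a Python loop keeps the unrolled computation compact under jax.jit, since the step function is traced once.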

Module contents