renn.rnn package
Submodules
renn.rnn.cells module
Recurrent neural network (RNN) cells.
class renn.rnn.cells.LinearRNN(A: jax._src.numpy.lax_numpy.array, W: jax._src.numpy.lax_numpy.array, b: jax._src.numpy.lax_numpy.array)
Bases: object
Dataclass for storing parameters of a Linear RNN.
class renn.rnn.cells.RNNCell(num_units, h_init=<function zeros>)
Bases: object
Base class for all RNN Cells.
An RNNCell must implement the following methods:
- init(PRNGKey, input_shape) -> output_shape, rnn_params
- apply(params, inputs, state) -> next_state
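The init/apply contract above can be illustrated with a minimal sketch. The class below is hypothetical and not part of renn: it uses NumPy in place of JAX, a NumPy random generator in place of a PRNGKey, and an assumed tanh update with made-up parameter names (W_in, W_rec, b).

```python
import numpy as np


class ToyTanhCell:
    """Hypothetical cell following the init/apply contract (not part of renn)."""

    def __init__(self, num_units):
        self.num_units = num_units

    def init(self, rng, input_shape):
        # input_shape is (batch, num_inputs); only the last axis sets parameter sizes.
        num_inputs = input_shape[-1]
        output_shape = input_shape[:-1] + (self.num_units,)
        params = {
            "W_in": rng.normal(size=(num_inputs, self.num_units)) / np.sqrt(num_inputs),
            "W_rec": rng.normal(size=(self.num_units, self.num_units)) / np.sqrt(self.num_units),
            "b": np.zeros(self.num_units),
        }
        return output_shape, params

    def apply(self, params, inputs, state):
        # One recurrent step: next_state = tanh(x W_in + h W_rec + b).
        return np.tanh(inputs @ params["W_in"] + state @ params["W_rec"] + params["b"])


cell = ToyTanhCell(num_units=4)
# A NumPy generator stands in for the PRNGKey here (assumption).
output_shape, params = cell.init(np.random.default_rng(0), (2, 3))
state = np.zeros(output_shape)
next_state = cell.apply(params, np.ones((2, 3)), state)
```

Keeping parameters outside the cell object, as the contract suggests, means the same cell definition can be applied with any parameter set.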
class renn.rnn.cells.StackedCell(layers)
Bases: renn.rnn.cells.RNNCell
Stacks multiple RNN cells together.
A stacked RNN cell is specified by a list of RNN cells and (optional) stax.Dense layers in between them.
Note that the full hidden state for this cell is the concatenation of hidden states from all of the cells in the stack.
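To make the concatenated hidden state concrete, here is a small NumPy sketch. The layer sizes are hypothetical, and the assumption that states are concatenated along the last (unit) axis in layer order is not confirmed by this page.

```python
import numpy as np

# Hypothetical sizes for a two-cell stack.
layer_sizes = [3, 5]
batch_size = 2

# Per-layer hidden states, each with shape (batch, units).
layer_states = [np.full((batch_size, n), float(i)) for i, n in enumerate(layer_sizes)]

# Full state: concatenate along the unit axis (assumed layout).
full_state = np.concatenate(layer_states, axis=-1)

# Recover the per-layer states by splitting at the cumulative boundaries.
boundaries = np.cumsum(layer_sizes)[:-1]
recovered = np.split(full_state, boundaries, axis=-1)
```

Under this layout, the full state of a stack has as many units as the sum of the layer sizes, which matters when choosing the shape of initial states or fixed-point seeds.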
class renn.rnn.cells.GRU(num_units, gate_bias=0.0, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)
Bases: renn.rnn.cells.RNNCell
Gated recurrent unit.
class renn.rnn.cells.LSTM(num_units, forget_bias=1.0, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)
Bases: renn.rnn.cells.RNNCell
Long short-term memory (LSTM).
class renn.rnn.cells.VanillaRNN(num_units, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)
Bases: renn.rnn.cells.RNNCell
Vanilla RNN Cell.
class renn.rnn.cells.UGRNN(num_units, gate_bias=0.0, w_init=<function variance_scaling.<locals>.init>, b_init=<function zeros>, h_init=<function zeros>)
Bases: renn.rnn.cells.RNNCell
Update-gate RNN Cell.
renn.rnn.cells.embedding(vocab_size, embedding_size, initializer=<function orthogonal.<locals>.init>)
Builds a token embedding.
Parameters:
- vocab_size – int, Size of the vocabulary (number of tokens).
- embedding_size – int, Dimensionality of the embedding.
- initializer – Initializer for the embedding (Default: orthogonal).
Returns:
- init_fun – callable, Initializes the embedding given a key and input_shape.
- apply_fun – callable, Converts a set of tokens to embedded vectors.
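A token embedding amounts to a table lookup. The sketch below shows that idea in NumPy rather than calling renn: the table name and sizes are hypothetical, and the QR factorization is only a stand-in for an orthogonal initializer.

```python
import numpy as np

vocab_size, embedding_size = 6, 4
rng = np.random.default_rng(0)

# Orthonormal columns via QR, a NumPy stand-in for an orthogonal initializer.
table, _ = np.linalg.qr(rng.normal(size=(vocab_size, embedding_size)))

# Integer tokens with shape (batch, sequence_length).
tokens = np.array([[0, 2, 5], [1, 1, 3]])

# Lookup: each token index selects one row of the table.
embedded = table[tokens]  # shape (batch, sequence_length, embedding_size)
```

Repeated tokens map to identical vectors, so the embedding carries no positional information on its own.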
renn.rnn.fixed_points module
Fixed point finding routines.
renn.rnn.fixed_points.build_fixed_point_loss(rnn_cell, cell_params)
Builds function to compute speed of hidden states.
Parameters:
- rnn_cell – an RNNCell instance.
- cell_params – RNN parameters to use when applying the RNN.
Returns:
- fixed_point_loss_fun – function that takes a batch of hidden states and inputs and computes the speed of the corresponding hidden states.
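As an illustration of what a "speed" function can look like, the sketch below assumes the common definition q(h, x) = 0.5 * ||F(x, h) - h||^2, where F is one step of the RNN. Whether renn uses exactly this definition and scaling is an assumption, and build_speed_fun and the linear update (with the assumed form h' = A h + W x + b) are hypothetical names written in NumPy.

```python
import numpy as np


def build_speed_fun(update):
    """Returns a function computing 0.5 * ||update(x, h) - h||^2 per batch element."""
    def speed(h, x):
        delta = update(x, h) - h
        return 0.5 * np.sum(delta ** 2, axis=-1)
    return speed


# Hypothetical linear update h' = A h + W x + b, applied to row-vector batches.
A = np.array([[0.5, 0.0], [0.0, 0.2]])
W = np.eye(2)
b = np.zeros(2)
update = lambda x, h: h @ A.T + x @ W.T + b

speed = build_speed_fun(update)
x = np.zeros((1, 2))
q_origin = speed(np.zeros((1, 2)), x)  # the origin is a fixed point when x = 0 and b = 0
q_other = speed(np.ones((1, 2)), x)    # a state that still moves under the update
```

A speed of zero means the state does not change under the update, which is why minimizing this quantity locates fixed points.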
renn.rnn.fixed_points.find_fixed_points(fp_loss_fun, initial_states, x_star, optimizer, tolerance, steps=range(0, 1000))
Run fixed point optimization.
Parameters:
- fp_loss_fun – Function that computes fixed point speeds.
- initial_states – Initial state seeds.
- x_star – Input at which to compute fixed points.
- optimizer – A jax.experimental.optimizers tuple.
- tolerance – Stopping tolerance threshold.
- steps – Iterator over steps.
Returns:
- fixed_points – Array of fixed points for each tolerance.
- loss_hist – Array containing fixed point loss curve.
- squared_speeds – Array containing the squared speed of each fixed point.
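The optimization loop can be sketched as follows. This is not renn's implementation: it swaps the jax.experimental.optimizers tuple for plain gradient descent with a hand-derived gradient, and it assumes a hypothetical linear update h' = A h + W x + b so that the result can be checked against the closed-form fixed point. The step size and tolerance are illustrative values.

```python
import numpy as np

# Hypothetical linear RNN parameters and the input at which to find fixed points.
A = np.array([[0.5, 0.1], [0.0, 0.2]])
W = np.eye(2)
b = np.array([0.1, -0.3])
x_star = np.array([1.0, 0.5])

M = A - np.eye(2)      # update(h) - h = M h + c
c = W @ x_star + b


def speed_and_grad(h):
    """Speed 0.5 * ||M h + c||^2 and its gradient, for a batch of row vectors."""
    delta = h @ M.T + c
    return 0.5 * np.sum(delta ** 2, axis=-1), delta @ M


# Random seeds for the search (batch of 8 candidate states).
states = np.random.default_rng(0).normal(size=(8, 2))
tolerance, step_size = 1e-12, 1.0
loss_hist = []
for _ in range(1000):
    q, grad = speed_and_grad(states)
    loss_hist.append(q.mean())
    if q.max() < tolerance:  # stop once every candidate is slow enough
        break
    states = states - step_size * grad

# Closed-form fixed point of the linear system, for comparison.
analytic = np.linalg.solve(np.eye(2) - A, c)
```

For a linear system all seeds converge to the same point; a nonlinear cell can have several fixed points, which is why the search starts from a batch of different initial states.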
renn.rnn.network module
Recurrent neural network (RNN) helper functions.
renn.rnn.network.build_rnn(num_tokens, emb_size, cell, num_outputs=1)
Builds an end-to-end recurrent neural network (RNN) model.
Parameters:
- num_tokens – int, Number of different input tokens.
- emb_size – int, Dimensionality of the embedding vectors.
- cell – RNNCell to use as the core update function (see cells.py).
- num_outputs – int, Number of outputs from the readout (Default: 1).
Returns:
- init_fun – function that takes a PRNGkey and input_shape and returns expected shapes and initialized embedding, RNN, and readout parameters.
- apply_fun – function that takes a tuple of network parameters and batch of input tokens and applies the RNN to each sequence in the batch.
- emb_apply – function to just apply the embedding.
- readout_apply – function to just apply the readout.
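The pipeline described here (embedding, then recurrent core, then readout) can be sketched end to end in NumPy. Everything below is a hypothetical stand-in rather than renn's code: the parameter layout, the tanh update, the linear readout applied at every time step, and the name apply_network are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, emb_size, num_units, num_outputs = 10, 4, 6, 1

# Hypothetical parameter tuple: (embedding, RNN, readout).
emb_params = rng.normal(size=(num_tokens, emb_size))
rnn_params = (rng.normal(size=(emb_size, num_units)) * 0.1,
              rng.normal(size=(num_units, num_units)) * 0.1)
readout_params = (rng.normal(size=(num_units, num_outputs)), np.zeros(num_outputs))


def apply_network(params, tokens):
    """Embeds tokens, runs a tanh RNN over time, and reads out at every step."""
    emb, (w_in, w_rec), (w_out, b_out) = params
    inputs = emb[tokens]                                 # (B, T, emb_size)
    state = np.zeros((tokens.shape[0], w_rec.shape[0]))  # zero initial state
    outputs = []
    for t in range(tokens.shape[1]):
        state = np.tanh(inputs[:, t] @ w_in + state @ w_rec)
        outputs.append(state @ w_out + b_out)
    return np.stack(outputs, axis=1)                     # (B, T, num_outputs)


tokens = rng.integers(0, num_tokens, size=(3, 7))
logits = apply_network((emb_params, rnn_params, readout_params), tokens)
```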
renn.rnn.network.eigsorted(jac)
Computes sorted eigenvalues and corresponding eigenvectors of a matrix.
Notes
The eigenvectors are stored in the columns of the returned matrices. The right and left eigenvectors are returned, such that: J = R E L^T
Parameters:
- jac – numpy array used to compute the eigendecomposition (must be square).
Returns:
- rights – right eigenvectors, as columns in the returned array.
- eigvals – numpy array of eigenvalues.
- lefts – left eigenvectors, as columns in the returned array.
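The decomposition J = R E L^T can be reproduced with NumPy as a rough sketch. The function name eigsorted_sketch is hypothetical, sorting by descending eigenvalue magnitude is an assumed ordering, and taking the left eigenvectors from the inverse of R is one way (not necessarily the library's way) to satisfy the stated identity.

```python
import numpy as np


def eigsorted_sketch(jac):
    """Returns (rights, eigvals, lefts) with jac = rights @ diag(eigvals) @ lefts.T."""
    eigvals, rights = np.linalg.eig(jac)
    order = np.argsort(np.abs(eigvals))[::-1]  # descending magnitude (assumed order)
    eigvals, rights = eigvals[order], rights[:, order]
    lefts = np.linalg.inv(rights).T            # chosen so that lefts.T @ rights = I
    return rights, eigvals, lefts


jac = np.array([[2.0, 1.0], [0.0, 0.5]])
rights, eigvals, lefts = eigsorted_sketch(jac)
reconstructed = rights @ np.diag(eigvals) @ lefts.T
```

This construction requires the matrix to be diagonalizable, since it inverts the matrix of right eigenvectors.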
renn.rnn.unroll module
Functions for unrolling recurrent neural networks (RNNs).
renn.rnn.unroll.unroll_rnn(initial_states, input_sequences, rnn_update, readout=<function identity>)
Unrolls an RNN on a batch of input sequences.
Given a batch of initial RNN states, and a batch of input sequences, this function unrolls application of the RNN along the sequence. The RNN state is updated using the rnn_update function, and the readout is used to convert the RNN state to outputs (defaults to the identity function).
B: batch size. N: number of RNN units. T: sequence length.
Parameters:
- initial_states – batch of initial states, with shape (B, N).
- input_sequences – batch of inputs, with shape (B, T, N).
- rnn_update – updates the RNN hidden state, given (inputs, current_states).
- readout – applies the readout, given current states. If this is the identity function, then no readout is applied (returns the hidden states).
Returns:
- outputs – batch of outputs (batch_size, sequence_length, num_outputs).
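The unrolling described above can be sketched with an explicit Python loop in NumPy. This is an illustration, not the library's implementation: unroll_sketch is a hypothetical name, the toy update rule is made up, and emitting the output after each update (rather than before) is an assumption.

```python
import numpy as np


def unroll_sketch(initial_states, input_sequences, rnn_update, readout=lambda h: h):
    """Applies rnn_update along the time axis and stacks the readouts."""
    states = initial_states
    outputs = []
    for t in range(input_sequences.shape[1]):
        # rnn_update takes (inputs, current_states), as in the parameter list above.
        states = rnn_update(input_sequences[:, t], states)
        outputs.append(readout(states))
    return np.stack(outputs, axis=1)  # (B, T, num_outputs)


B, T, N = 2, 5, 3
update = lambda x, h: 0.5 * h + x  # toy leaky accumulator, for illustration only
inputs = np.ones((B, T, N))

# Identity readout: the outputs are the hidden states themselves.
hidden = unroll_sketch(np.zeros((B, N)), inputs, update)

# A custom readout that sums over units, giving one output per time step.
summed = unroll_sketch(np.zeros((B, N)), inputs, update,
                       readout=lambda h: h.sum(axis=-1, keepdims=True))
```

With the identity readout the last axis has N entries, while a custom readout sets its own output size.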