# Source code for the recurrent neural network (RNN) helper module.

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recurrent neural network (RNN) helper functions."""

import functools

import jax
from jax.experimental import stax
import jax.numpy as jnp
import numpy as np

from . import cells
from . import unroll

# Public API of this module: the names exported by `from <module> import *`.
__all__ = ['build_rnn', 'mse', 'eigsorted', 'timescale']

def build_rnn(num_tokens, emb_size, cell, num_outputs=1):
  """Assembles an embedding -> recurrent cell -> dense readout network.

  Args:
    num_tokens: int, size of the input token vocabulary (passed to
      `cells.embedding`).
    emb_size: int, dimensionality of each embedding vector.
    cell: RNN cell object used as the per-step recurrent update; it must
      provide `init`, `get_initial_state` and `batch_apply`.
    num_outputs: int, width of the dense readout layer (default: 1).

  Returns:
    A 4-tuple ``(init_fun, apply_fun, emb_apply, readout_apply)``:
      init_fun: ``(key, input_shape) -> (shapes, network_params)``; splits
        ``key`` and initializes the embedding, cell and readout parameters.
      apply_fun: ``(network_params, tokens) -> outputs``; embeds a batch of
        token sequences and unrolls the RNN (with readout) over them.
      emb_apply: the embedding's apply function on its own.
      readout_apply: the readout's apply function on its own.
  """
  embed_init, embed_apply = cells.embedding(num_tokens, emb_size)
  head_init, head_apply = stax.Dense(num_outputs)

  def init_fun(key, input_shape):
    """Initializes the embedding, cell and readout parameters.

    Args:
      key: JAX PRNG key; split three ways, one subkey per component.
      input_shape: tuple, expected to be (batch_size, sequence_length).

    Returns:
      shapes: tuple ``(emb_shape, rnn_shape, output_shape)`` holding the
        output shape reported by each component's initializer.
      network_params: tuple ``(emb_params, rnn_params, readout_params)``.
    """
    k_embed, k_cell, k_head = jax.random.split(key, 3)

    embedded_shape, embed_params = embed_init(k_embed, input_shape)

    # The cell is initialized for a single time step, so the sequence axis
    # (index 1 of the embedding's output shape) is left out.
    step_shape = (embedded_shape[0], embedded_shape[2])
    hidden_shape, cell_params = cell.init(k_cell, step_shape)

    out_shape, head_params = head_init(k_head, hidden_shape)

    return ((embedded_shape, hidden_shape, out_shape),
            (embed_params, cell_params, head_params))

  def apply_fun(network_params, tokens):
    """Runs the network over a batch of token sequences.

    Args:
      network_params: tuple ``(emb_params, rnn_params, readout_params)`` as
        produced by ``init_fun``.
      tokens: batch of input tokens, shape (batch_size, sequence_length);
        ``tokens.shape[0]`` is used as the batch size for the initial state.

    Returns:
      The result of ``unroll.unroll_rnn``: the network outputs at every step
      along the sequence.
    """
    embed_params, cell_params, head_params = network_params

    embedded = embed_apply(embed_params, tokens)
    start_state = cell.get_initial_state(cell_params,
                                         batch_size=tokens.shape[0])

    step = functools.partial(cell.batch_apply, cell_params)
    readout = functools.partial(head_apply, head_params)
    return unroll.unroll_rnn(start_state, embedded, step, readout)

  return init_fun, apply_fun, embed_apply, head_apply


def mse(y, yhat):
  """Half mean squared error loss: ``0.5 * mean((y - yhat) ** 2)``.

  Note the factor of one half: this is half of the plain mean squared error,
  averaged over every element of the residual.

  Args:
    y: target values (array).
    yhat: predicted values, subtracted elementwise from ``y``.

  Returns:
    Scalar loss value.
  """
  residual = y - yhat
  squared = residual ** 2
  return 0.5 * jnp.mean(squared)


def eigsorted(jac):
  """Eigendecomposition with eigenpairs ordered by decreasing magnitude.

  Eigenvalues are sorted by absolute value, largest first, and the
  eigenvector columns are permuted to match. The left eigenvectors are the
  columns of the transposed pseudo-inverse of the right-eigenvector matrix.
  When that matrix is invertible this gives
  ``jac == rights @ np.diag(eigvals) @ lefts.T`` (up to rounding), i.e.
  J = R E L^T with E the diagonal matrix of eigenvalues.

  Args:
    jac: square numpy array to decompose.

  Returns:
    rights: array whose columns are the right eigenvectors, in sorted order.
    eigvals: 1-D array of eigenvalues, sorted by decreasing absolute value.
    lefts: array whose columns are the left eigenvectors, in sorted order.
  """
  raw_vals, raw_vecs = np.linalg.eig(jac)
  # argsort is ascending; reverse it so the largest magnitudes come first.
  order = np.argsort(np.abs(raw_vals))[::-1]
  rights = raw_vecs[:, order]
  lefts = np.linalg.pinv(rights).T
  return rights, raw_vals[order], lefts


def timescale(eigenvalues):
  """Converts eigenvalues into approximate time constants.

  A mode scaled by ``|lambda|`` each step evolves as ``exp(-t / tau)`` with
  ``tau = -1 / log|lambda|``. Therefore:
    * ``|lambda| < 1`` gives a positive time constant (decaying mode),
    * ``|lambda| > 1`` gives a negative one (growing mode),
    * ``|lambda| == 1`` gives ``+inf`` (no decay),
    * ``lambda == 0`` gives ``0``.

  Args:
    eigenvalues: scalar or array of (possibly complex) eigenvalues.

  Returns:
    Time constants with the same shape as ``eigenvalues``.
  """
  # log(0) and the reciprocal of log(1) are divisions by zero, but their
  # limits (0 and +inf respectively) are the intended answers, so numpy's
  # divide-by-zero warnings are silenced here.
  with np.errstate(divide='ignore'):
    # np.log(1.0) is +0.0, so the naive -1. / log|lambda| yields -inf for
    # unit-modulus eigenvalues. Negating first gives -0.0, and adding +0.0
    # normalizes it to +0.0 (IEEE 754), making the reciprocal +inf. For any
    # nonzero value the addition is exact, so other results are unchanged.
    decay_rate = -np.log(np.abs(eigenvalues)) + 0.0
    return 1. / decay_rate