# Source code for renn.rnn.fixed_points

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fixed point finding routines."""

import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial import distance

from .. import utils

__all__ = ['build_fixed_point_loss', 'find_fixed_points', 'exclude_outliers']


def build_fixed_point_loss(rnn_cell, cell_params):
    """Builds function to compute speed of hidden states.

    Args:
      rnn_cell: an RNNCell instance. Only its `batch_apply` method is used
        here, called as `batch_apply(cell_params, x, h)`.
      cell_params: RNN parameters to use when applying the RNN.

    Returns:
      fixed_point_loss_fun: function that takes a batch of hidden states and
        inputs and computes the speed of the corresponding hidden states.
    """

    def fixed_point_loss_fun(h, x):
        """Computes the speed of a batch of hidden states.

        The speed is defined as the squared l2 distance between the current
        state and the next state, in response to a given input:

          Q = (1/2) || h - F(h, x) ||_2^2

        Args:
          h: Batch of current states. The reduction runs over axis 1, so `h`
            must have at least two dimensions (presumably
            (batch, hidden) -- confirm against callers).
          x: Batch of inputs, forwarded to `rnn_cell.batch_apply` together
            with `h`.

        Returns:
          Array of speeds obtained by summing the squared state change over
          axis 1 and halving it (one value per row for 2D inputs).
        """
        h_next = rnn_cell.batch_apply(cell_params, x, h)
        # Half the squared displacement, reduced over axis 1 of the batch.
        return 0.5 * jnp.sum((h - h_next) ** 2, axis=1)

    return fixed_point_loss_fun
def find_fixed_points(fp_loss_fun,
                      initial_states,
                      x_star,
                      optimizer,
                      tolerance,
                      steps=range(1000)):
    """Run fixed point optimization.

    Minimizes the mean of `fp_loss_fun(h, x_star)` over the states `h`,
    starting from `initial_states`, by delegating to `utils.optimize`.

    Args:
      fp_loss_fun: Function that computes fixed point speeds; called as
        `fp_loss_fun(h, x_star)`.
      initial_states: Initial state seeds.
      x_star: Input at which to compute fixed points.
      optimizer: A jax.experimental.optimizers tuple.
      tolerance: Stopping tolerance threshold, forwarded to
        `utils.optimize` as `stop_tol`.
      steps: Iterable over optimization steps (defaults to `range(1000)`).

    Returns:
      fixed_points: The final states produced by the optimization, fetched
        with `jax.device_get`.
      loss_hist: The first value returned by `utils.optimize` (presumably
        the fixed point loss curve -- confirm against `utils.optimize`).
      squared_speeds: `fp_loss_fun` evaluated at the final states and
        `x_star`, fetched with `jax.device_get`.
    """

    def mean_speed(h):
        # Scalar objective: average the per-state loss across all seeds.
        return jnp.mean(fp_loss_fun(h, x_star))

    loss_hist, fps = utils.optimize(mean_speed,
                                    initial_states,
                                    optimizer,
                                    steps,
                                    stop_tol=tolerance)

    fixed_points = jax.device_get(fps)
    # Re-evaluate the un-averaged loss so each final state gets its own value.
    squared_speeds = jax.device_get(fp_loss_fun(fps, x_star))

    return fixed_points, loss_hist, squared_speeds
def exclude_outliers(points, threshold=np.inf, verbose=False):
    """Remove points that are not within some threshold of another point.

    Args:
      points: Array of points, one per row. Must expose `.shape` and support
        integer-array indexing.
      threshold: Maximum allowed distance to the nearest neighbor. An
        infinite threshold (the default) disables filtering.
      verbose: If True, print how many points were kept.

    Returns:
      The rows of `points` whose nearest neighbor lies within `threshold`.
      `points` itself is returned unchanged when the threshold is infinite
      or when there are fewer than two points.
    """
    # Return all points if the threshold is infinite (filtering disabled).
    if np.isinf(threshold):
        return points

    # Return if there are less than two points (no neighbor to compare to).
    if points.shape[0] <= 1:
        return points

    # Compute pairwise distances between all points.
    distances = distance.squareform(distance.pdist(points))

    # Find distance to each nearest neighbor. After partitioning each column,
    # row 0 holds the smallest entry (the zero self-distance) and row 1 the
    # second smallest, i.e. the distance to the closest other point.
    nn_distance = np.partition(distances, 1, axis=0)[1]

    # Keep points whose nearest neighbor is within some distance threshold.
    keep_indices = np.where(nn_distance <= threshold)[0]

    # Log how many points were kept.
    if verbose:
        print(f'Keeping {len(keep_indices)} out of {len(points)} points.')

    return points[keep_indices]