# Source code for renn.utils

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for optimization, functional helpers, and sequence-model training."""

import jax
from jax import flatten_util
import jax.numpy as jnp

import numpy as np
import tqdm

__all__ = [
    'batch_mean', 'norm', 'identity', 'fst', 'snd', 'optimize', 'one_hot',
    'compose',
]

def batch_mean(fun, in_axes):
  """Wraps a scalar-valued function so that it averages over a batch.

  The returned function maps `fun` over the batched arguments with
  `jax.vmap` and returns the mean of the per-item results. A typical use is
  turning a stochastic loss, whose randomness is driven by a seed/key
  argument, into a loss averaged over many seeds.

  Args:
    fun: function, the function to batch (expected to return a scalar).
    in_axes: tuple, which arguments of `fun` to map over, in the format
      accepted by `jax.vmap`. For example, `in_axes=(None, 0)` maps over
      axis 0 of the second argument only.

  Returns:
    batch_fun: function taking the same arguments as `fun` (batched ones
      carrying the extra axis) and returning the batch average.
  """
  vectorized = jax.vmap(fun, in_axes=in_axes)

  def batch_fun(*args):
    per_item_values = vectorized(*args)
    return jnp.mean(per_item_values)

  return batch_fun
def norm(params, order=2):
  """Returns the norm of all pytree leaves, treated as one flat vector.

  Args:
    params: pytree of arrays (e.g. network parameters).
    order: norm order, forwarded to `jnp.linalg.norm` as `ord` (Default: 2).

  Returns:
    Scalar norm of the concatenated, flattened leaves.
  """
  flat_vector, _unravel = flatten_util.ravel_pytree(params)
  return jnp.linalg.norm(flat_vector, ord=order)
def identity(x):
  """Returns its argument, unchanged and uncopied."""
  result = x
  return result
def fst(xs):
  """Returns `xs[0]`, the first element of a list, tuple, or other indexable."""
  first = xs[0]
  return first
def snd(xs):
  """Returns `xs[1]`, the second element of a list, tuple, or other indexable."""
  second = xs[1]
  return second
def compose(*funcs):
  """Composes single-argument functions right-to-left.

  `compose(f, g, h)(x)` evaluates `f(g(h(x)))`. With no functions, the
  returned function simply returns its input.

  Args:
    *funcs: functions of a single argument.

  Returns:
    wrapper: function applying `funcs` from last to first.
  """
  right_to_left = funcs[::-1]

  def wrapper(x):
    value = x
    for fn in right_to_left:
      value = fn(value)
    return value

  return wrapper
def optimize(loss_fun, x0, optimizer, steps, stop_tol=-np.inf):
  """Minimizes `loss_fun` starting from `x0` with a JAX-style optimizer.

  Args:
    loss_fun: Scalar loss function of the parameters.
    x0: Initial parameters.
    optimizer: A tuple of optimizer functions (init_opt, update_opt,
      get_params) from a jax.experimental.optimizers instance (this module
      lives at jax.example_libraries.optimizers in newer JAX releases).
    steps: Iterable of step indices; each one is passed to `update_opt`.
    stop_tol: Stop once a recorded loss is <= this value
      (Default: -np.inf, i.e. never stop early).

  Returns:
    loss_hist: numpy array with one loss per executed step. Each entry is
      the loss of the parameters *before* that step's update.
    final_params: Parameters after the last executed update.
  """
  init_opt, update_opt, get_params = optimizer
  value_and_grad = jax.value_and_grad(loss_fun)

  @jax.jit
  def step(k, state):
    # Evaluate loss and gradient at the current parameters, then update.
    loss, grads = value_and_grad(get_params(state))
    return loss, update_opt(k, grads, state)

  opt_state = init_opt(x0)
  loss_hist = []
  for k in steps:
    loss, opt_state = step(k, opt_state)
    loss_hist.append(loss)
    if loss <= stop_tol:
      break

  return np.array(loss_hist), get_params(opt_state)
def one_hot(labels, num_classes, dtype=jnp.float32):
  """Creates a one-hot encoding of an array of integer labels.

  Args:
    labels: array-like of integers. Usually shape (num_examples,), but any
      shape (including a scalar) is accepted.
    num_classes: int, Total number of classes.
    dtype: optional, jax datatype for the return array (Default: float32).

  Returns:
    one_hot_labels: array with shape labels.shape + (num_classes,), i.e.
      (num_examples, num_classes) for 1-D labels. A label outside
      [0, num_classes) matches no class and yields an all-zero row.
  """
  labels = jnp.asarray(labels)
  # Compare against the class ids along a new trailing axis; `...` lets any
  # leading label shape broadcast (the old `[:, None]` required 1-D labels).
  matches = labels[..., None] == jnp.arange(num_classes)
  return jnp.array(matches, dtype)
def select(sequences, indices):
  """Selects one element from each sequence in a batch.

  Given an array of shape (number_of_sequences, sequence_length,
  element_dimension), and a 1D array specifying which index of each sequence
  to select, returns `sequences[i, indices[i]]` for every i, stacked into a
  (number_of_sequences, element_dimension)-shaped array.

  Args:
    sequences: array with shape
      (number_of_sequences, sequence_length, element_dimension). Any number
      of trailing element dimensions (including none) is supported.
    indices: 1D integer array-like with length number_of_sequences.

  Returns:
    selected_elements: array with shape
      (number_of_sequences, element_dimension).

  Raises:
    ValueError: if `indices` is not 1D or its length does not equal the
      number of sequences.
  """
  sequences = jnp.asarray(sequences)
  indices = jnp.asarray(indices)
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if indices.ndim != 1 or indices.shape[0] != sequences.shape[0]:
    raise ValueError(
        'indices must be 1D with one entry per sequence; got indices shape '
        '{} for sequences shape {}.'.format(indices.shape, sequences.shape))
  # take_along_axis needs an index array with the same rank as `sequences`:
  # (number_of_sequences, 1, 1, ...).
  indices_shaped = indices.reshape((-1,) + (1,) * (sequences.ndim - 1))
  selected_elements = jnp.take_along_axis(sequences, indices_shaped, axis=1)
  # Remove the (now length-1) sequence dimension.
  return jnp.squeeze(selected_elements, axis=1)


def _end_logits(network_apply_fun, params, batch):
  """Applies the network and returns its output at each sequence's last step."""
  all_time_logits = network_apply_fun(params, batch['inputs'])
  # batch['index'] stores sequence lengths, so the last valid step is index - 1.
  return select(all_time_logits, jnp.asarray(batch['index']) - 1)


def make_loss_function(network_apply_fun, basic_loss_fun, regularization_fun):
  """Builds a training loss from a network, a data loss, and a regularizer.

  The network is applied to whole input sequences, and only its output at
  the last valid step of each sequence (position `index - 1`) is passed to
  `basic_loss_fun`.

  Args:
    network_apply_fun: maps (network_params, batched_inputs) ->
      network_logits for every time step.
    basic_loss_fun: maps (logits, batched_labels) -> scalar loss value.
    regularization_fun: maps network_params -> scalar loss value.

  Returns:
    total_loss_fun: maps (network_params, batch) -> scalar loss value.
  """

  def total_loss_fun(params, batch):
    """Maps network parameters and a training batch to a loss value.

    Args:
      params: network parameters, passed to both `network_apply_fun` and
        `regularization_fun`.
      batch: a dictionary with keys ['inputs', 'index', 'labels']
        'inputs': sequence of inputs with shape (batch_size, max_sequence_length)
        'index' : 1d-array storing length of the corresponding input sequence
        'labels': 1d-array storing label of corresponding input sequence

    Returns:
      loss: `basic_loss_fun` on the final-step logits plus the
        regularization term.
    """
    end_logits = _end_logits(network_apply_fun, params, batch)
    return (basic_loss_fun(end_logits, batch['labels']) +
            regularization_fun(params))

  return total_loss_fun


def make_acc_fun(network_apply_fun, num_outputs=1):
  """Builds a jitted accuracy function for a sequence classifier.

  Predictions come from the logits at the last valid step of each sequence.
  With `num_outputs == 1` the single logit is thresholded at zero
  (logit >= 0 predicts class 1, otherwise class 0); otherwise the predicted
  class is the argmax over the last axis.

  Args:
    network_apply_fun: maps (network_params, batched_inputs) -> logits for
      every time step.
    num_outputs: int, number of network outputs per time step (Default: 1).

  Returns:
    accuracy_fun: jitted function mapping (params, batch) -> fraction of
      sequences whose prediction equals batch['labels'].
  """

  def prediction_function(logits):
    if num_outputs == 1:
      # Binary classification from a single logit.
      return (logits >= 0.).astype(jnp.int32)
    # Multi-class: pick the highest-scoring class.
    return logits.argmax(axis=-1).astype(jnp.int32)

  @jax.jit
  def accuracy_fun(params, batch):
    end_logits = _end_logits(network_apply_fun, params, batch)
    predictions = jnp.squeeze(prediction_function(end_logits))
    accuracies = (batch['labels'] == predictions).astype(jnp.int32)
    return jnp.mean(accuracies)

  return accuracy_fun