# Source code for renn.metaopt.models

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define simple learned optimizer models."""

import numbers

import jax
from jax.experimental import optimizers
from jax.experimental import stax
import jax.numpy as jnp

# Aliases for standard initializers.
# `fan_in`: variance-scaling normal initializer (scale 1.0, mode 'fan_in').
# `zeros`: all-zeros initializer, used below so readouts start as no-op maps.
fan_in = jax.nn.initializers.variance_scaling(1., 'fan_in', 'normal')
zeros = jax.nn.initializers.zeros


def append_to_sequence(sequence, element):
  """Pushes an element onto the end of a fixed-length rolling buffer.

  Args:
    sequence: a sequence of ndarrays, concatenated along the first dimension.
    element: an ndarray to add to the sequence (gains a leading length-1
      axis before being stacked).

  Returns:
    The buffer with its oldest (first) entry dropped, the remaining entries
    shifted toward the front, and `element` appended — so the length along
    the first axis is unchanged.
  """
  shifted = sequence[1:]
  newest = element[jnp.newaxis, ...]
  return jnp.vstack((shifted, newest))
def cwrnn(key, cell, input_scale='raw', output_scale=1e-3):
  """Component-wise RNN Optimizer.

  This optimizer applies an RNN to update the parameters of each problem
  variable independently (hence the name, component-wise). It follows the
  same approach as in previous work (Andrychowicz et al 2016, Wichrowska
  et al 2017) that distribute the parameters along the batch dimension of
  the RNN. This allows us to easily update each parameter in parallel.

  Args:
    key: Jax PRNG key to use for initializing parameters.
    cell: An RNNCell to use (see renn/rnn/cells.py)
    input_scale: str, Specifies how to scale gradient inputs to the RNN.
      If 'raw', then the gradients are not scaled. If 'log1p', then the
      scale and the sign of the inputs are split into a length 2 vector,
      [log1p(abs(g)), sign(g)]. A plain number is also accepted and used
      as a constant multiplier on the gradients.
    output_scale: float, Constant used to multiply (rescale) the RNN output.

  Returns:
    meta_parameters: A tuple containing the RNN parameters and the readout
      parameters. The RNN parameters themselves are a namedtuple. The
      readout parameters are also a tuple containing weights and a bias.
    optimizer_fun: A function that takes a set of meta_parameters and
      initializes an optimizer tuple containing functions to initialize
      the optimizer state, update the optimizer state, and get parameters
      from the optimizer state.
  """
  # Input and output shapes ('log1p' feeds [scale, sign] pairs to the RNN).
  n_in = 2 if input_scale == 'log1p' else 1
  n_out = 1

  # Initialize the readout. Zero-initialized weights and biases mean the
  # optimizer's update starts at exactly zero before meta-training.
  readout_init, readout_apply = stax.Dense(n_out, W_init=zeros, b_init=zeros)

  # Initialize parameters.
  rnn_key, readout_key = jax.random.split(key)
  rnn_shape, rnn_params = cell.init(rnn_key, (None, n_in))
  _, readout_params = readout_init(readout_key, rnn_shape)
  initial_theta = (rnn_params, readout_params)

  # Consistency fix: use the module-level `optimizers` alias (as every other
  # optimizer in this file does) rather than the fully-qualified
  # `jax.experimental.optimizers.optimizer` path.
  @optimizers.optimizer
  def optimizer_fun(theta):
    """Builds a component-wise RNN optimizer."""
    rnn_params, readout_params = theta

    def init_state(x):
      # One RNN "batch element" per scalar parameter in x.
      n = jnp.ravel(x).size
      return (x, cell.get_initial_state(rnn_params, batch_size=n))

    def update_opt(_, grads, state):
      x, h = state
      grad_vec = jnp.reshape(grads, (-1, 1))

      # Inputs are scaled by a constant factor.
      if isinstance(input_scale, numbers.Number):
        inputs = input_scale * grad_vec

      # Inputs are raw (unmodified) gradients.
      elif input_scale == 'raw':
        inputs = grad_vec

      # Inputs are the log-scale and sign of the gradient.
      elif input_scale == 'log1p':
        scale = jnp.log1p(jnp.abs(grad_vec))
        sign = jnp.sign(grad_vec)
        inputs = jnp.hstack((scale, sign))

      else:
        raise ValueError(f'Invalid input scale {input_scale}.')

      h_next = cell.batch_apply(rnn_params, inputs, h)
      outputs = readout_apply(readout_params, h_next)
      x_next = x + output_scale * jnp.reshape(outputs, x.shape)

      return (x_next, h_next)

    def get_params(state):
      return state[0]

    return (init_state, update_opt, get_params)

  return initial_theta, optimizer_fun
def lds(key, num_units, h0_init=zeros, w_init=fan_in):
  """Linear dynamical system (LDS) optimizer.

  Initializes meta-parameters for an optimizer whose hidden state evolves
  as a linear dynamical system driven by the gradients, with a
  zero-initialized linear readout producing the parameter update.

  Args:
    key: Jax PRNG key used to initialize all meta-parameters.
    num_units: int, dimensionality of the latent LDS state.
    h0_init: initializer for the initial state (default: zeros).
    w_init: initializer for the recurrent and input Jacobians (default:
      fan-in scaled normal).

  Returns:
    initial_meta_params: tuple (h0, rec_jac, inp_jac, readout_params).
    optimizer_fun: optimizer-decorated builder (see NOTE below about its
      argument list).
  """
  hstar_key, rec_key, inp_key, readout_key = jax.random.split(key, 4)

  # Initialize linear dynamical system.
  h0 = h0_init(hstar_key, (num_units,))
  rec_jac = w_init(rec_key, (num_units, num_units))
  inp_jac = w_init(inp_key, (num_units, num_units))

  # Initialize the readout (zeros, so updates start at exactly zero).
  readout_init, readout_apply = stax.Dense(1, W_init=zeros, b_init=zeros)
  _, readout_params = readout_init(readout_key, (None, num_units))

  initial_meta_params = (h0, rec_jac, inp_jac, readout_params)

  # NOTE(review): `optimizer_fun` takes six arguments (h_star, g_star,
  # h_init, rec_jac, inp_jac, readout_params) but `initial_meta_params`
  # above is a 4-tuple that contains neither h_star nor g_star. The caller
  # must supply the linearization points separately — confirm against the
  # meta-training harness. Also note `cwrnn`/`linear` optimizer builders
  # take a single meta-parameter tuple instead.
  @optimizers.optimizer
  def optimizer_fun(h_star, g_star, h_init, rec_jac, inp_jac, readout_params):
    """Linear dynamical system optimizer.

    Args:
      h_star: The state around which to linearize.
      g_star: The input around which to linearize.
      h_init: The initial state.
      rec_jac: Defines the recurrent dynamics.
      inp_jac: Multiplies the input gradients.
      readout_params: Tuple of (weights, biases).

    Returns:
      init_state: Initialize the optimizer state.
      update_opt: Updates the optimizer state variables given the current
        step, gradients, and current state.
      get_params: Gets parameters from the optimizer state.
    """

    def init_state(x):
      # One latent-state row per scalar parameter of x.
      batch_size = jnp.ravel(x).size
      h = jnp.ones((batch_size, 1)) * jnp.reshape(h_init, (1, -1))
      return (x, h)

    def update_opt(_, grads, state):
      x, h = state
      g = jnp.reshape(grads, (-1, 1))
      # First-order dynamics linearized around (h_star, g_star).
      h_next = h_star + jnp.dot(h - h_star, rec_jac.T) + jnp.dot(g - g_star, inp_jac.T)  # pylint: disable=line-too-long
      outputs = readout_apply(readout_params, h_next)
      x_next = x + jnp.reshape(outputs, x.shape)
      return (x_next, h_next)

    def get_params(state):
      return state[0]

    return (init_state, update_opt, get_params)

  return initial_meta_params, optimizer_fun
def linear(key, tau, scale, base=0):
  """Optimizer that is a linear function of gradient history.

  The meta-parameters are `tau` coefficients; each step subtracts the
  inner product of those coefficients with the last `tau` gradients.

  Args:
    key: Jax PRNG key used to draw the initial coefficients.
    tau: int, number of past gradients kept in the rolling buffer.
    scale: float, multiplier on the uniform random initialization.
    base: float, additive offset for the initial coefficients.

  Returns:
    initial_meta_params: ndarray of shape (tau,), the coefficients.
    optimizer_fun: builds (init_fun, update_fun, get_params) from a set
      of meta-parameters.
  """
  initial_meta_params = base + scale * jax.random.uniform(key, (tau,))

  @optimizers.optimizer
  def optimizer_fun(meta_params):
    """Builds a linear optimizer with the given meta_params."""

    def init_fun(params):
      """Initialize optimizer state."""
      # Rolling buffer of the most recent `tau` gradients, oldest first.
      empty_history = jnp.zeros((tau,) + params.shape)
      return (params, empty_history)

    def update_fun(step, grads, state):
      """Apply a step of the optimizer."""
      del step  # Unused.
      params, history = state
      history = append_to_sequence(history, grads)
      # Contract the coefficients against the history's time axis.
      update = jnp.tensordot(meta_params, history, axes=1)
      return (params - update, history)

    def get_params(state):
      """Get parameters from the optimizer."""
      return state[0]

    return (init_fun, update_fun, get_params)

  return initial_meta_params, optimizer_fun
def linear_dx(key, tau, scale_grad, scale_dx, base_grad=0, base_gram=0):
  """Optimizer that is a linear function of gradient and parameter history.

  Args:
    key: Jax PRNG key used to draw the initial coefficients.
    tau: int, number of past gradients/parameters kept in rolling buffers.
    scale_grad: float, multiplier on the gradient-coefficient init.
    scale_dx: float, multiplier on the parameter-difference-coefficient
      init.
    base_grad: float, additive offset for the gradient coefficients.
    base_gram: float, additive offset for the parameter-difference
      coefficients. (NOTE(review): the name suggests "gram" but it scales
      the dx coefficients — kept for interface compatibility.)

  Returns:
    initial_meta_params: tuple of coefficient arrays with shapes (tau,)
      and (tau - 1,).
    optimizer_fun: builds (init_fun, update_fun, get_params).
  """
  key0, key1 = jax.random.split(key, 2)
  grad_coeffs = base_grad + scale_grad * jax.random.uniform(key0, (tau,))
  dx_coeffs = base_gram + scale_dx * jax.random.uniform(key1, (tau - 1,))
  initial_meta_params = (grad_coeffs, dx_coeffs)

  @optimizers.optimizer
  def optimizer_fun(meta_params):
    """Builds a linear_dx optimizer."""
    theta_grad, theta_dx = meta_params

    def init_fun(params):
      """Initialize optimizer state."""
      buffer_shape = (tau,) + params.shape
      return (params, jnp.zeros(buffer_shape), jnp.zeros(buffer_shape))

    def update_fun(step, grads, state):
      """Apply a step of the optimizer."""
      del step  # Unused.
      params, grad_seq, param_seq = state
      grad_seq = append_to_sequence(grad_seq, grads)
      param_seq = append_to_sequence(param_seq, params)

      # Differences in parameters.
      # TODO(nirum): This recomputes differences at every iteration. Should
      # time this to ensure that the repeated jnp.diff call is not too slow.
      delta_params = jnp.diff(param_seq, axis=0)

      # Total update: weighted gradient history plus weighted dx history.
      step_update = (jnp.tensordot(theta_grad, grad_seq, axes=1) +
                     jnp.tensordot(theta_dx, delta_params, axes=1))
      return (params - step_update, grad_seq, param_seq)

    def get_params(state):
      return state[0]

    return init_fun, update_fun, get_params

  return initial_meta_params, optimizer_fun
def gradgram(key, tau, scale_grad, scale_gram, base_grad=0, base_gram=0):
  """Optimizer that is a function of gradient history and inner products.

  In addition to a linear function of the last `tau` gradients, the update
  includes an attention-like term computed from the (negated, normalized)
  Gram matrix of pairwise gradient inner products.

  Args:
    key: Jax PRNG key used to draw the initial coefficients.
    tau: int, number of past gradients kept in the rolling buffer.
    scale_grad: float, multiplier on the gradient-coefficient init.
    scale_gram: float, multiplier on the gram-coefficient init.
    base_grad: float, additive offset for the gradient coefficients.
    base_gram: float, additive offset for the gram coefficients.

  Returns:
    initial_meta_params: tuple of two (tau,)-shaped coefficient arrays.
    optimizer_fun: builds (init_fun, update_fun, get_params).
  """
  # Initialize meta-parameters.
  key0, key1 = jax.random.split(key, 2)
  initial_meta_params = (
      base_grad + scale_grad * jax.random.uniform(key0, (tau,)),
      base_gram + scale_gram * jax.random.uniform(key1, (tau,)))

  # Pairwise (negated) inner products between all rows of two histories.
  def _neg_dot(x, y):
    return -jnp.sum(x * y)

  innerprod = jax.jit(
      jax.vmap(jax.vmap(_neg_dot, in_axes=(0, None)), in_axes=(None, 0)))

  # Per-row norms along the leading (time) axis.
  norms = jax.jit(jax.vmap(jnp.linalg.norm, in_axes=0))

  @optimizers.optimizer
  def optimizer_fun(meta_params):
    """An optimizer that uses gradient-gradient correlations."""
    theta_grad, theta_gram = meta_params

    def init_fun(params):
      """Initialize the optimizer state."""
      return (params, jnp.zeros((tau,) + params.shape))

    def update_fun(step, grads, state):
      """Apply a step of the optimizer."""
      del step  # Unused.
      params, grad_seq = state

      # Update gradient history.
      grad_seq = append_to_sequence(grad_seq, grads)

      # Compute normalized gram matrix (epsilon guards divide-by-zero).
      gram = innerprod(grad_seq, grad_seq)
      grad_norm = norms(grad_seq)
      gram /= (jnp.outer(grad_norm, grad_norm) + 1e-6)

      # Compute update terms: softmax attention over history plus a plain
      # linear term in the gradients.
      attn_weights = jnp.dot(stax.softmax(gram, axis=0), theta_gram)
      attn_term = jnp.tensordot(attn_weights, grad_seq, axes=1)
      grad_term = jnp.tensordot(theta_grad, grad_seq, axes=1)

      return (params - (grad_term + attn_term), grad_seq)

    def get_params(state):
      return state[0]

    return init_fun, update_fun, get_params

  return initial_meta_params, optimizer_fun
def momentum(key):
  """Wrapper for the momentum optimizer.

  Args:
    key: Jax PRNG key (ignored; momentum has no random meta-parameters).

  Returns:
    A tuple of initial meta-parameters (learning_rate, mass) and a
    function that builds the corresponding jax momentum optimizer from a
    (learning_rate, mass) pair.
  """
  del key  # Unused.

  def optimizer_fun(optimizer_params):
    # optimizer_params is a (learning_rate, mass) pair.
    return optimizers.momentum(*optimizer_params)

  return (1e-3, 0.8), optimizer_fun
def aggmo(key, num_terms):
  """Aggregated momentum (aggmo).

  Initializes meta-parameters (a scalar learning rate of 0.0 and
  `num_terms` zero-initialized momentum masses) for a momentum optimizer
  with multiple timescales.

  Args:
    key: Jax PRNG key passed to the zeros initializer.
    num_terms: int, number of momentum timescales (modes).

  Returns:
    initial_meta_params: tuple (learning_rate, masses).
    optimizer_fun: optimizer-decorated builder (see NOTE below about its
      argument list).
  """
  initial_learning_rate = 0.0
  initial_masses = zeros(key, (num_terms,))
  initial_meta_params = (initial_learning_rate, initial_masses)

  # NOTE(review): `optimizer_fun` takes three arguments (v0, alphas, betas)
  # but `initial_meta_params` is a 2-tuple and contains no v0 — confirm how
  # the caller supplies the initial velocity.
  @optimizers.optimizer
  def optimizer_fun(v0, alphas, betas):
    """Aggregated momentum optimizer.

    Defines an aggregated momentum optimizer (momentum with multiple
    timescales). Instead of a single learning rate and momentum mass, this
    optimizer includes `n` of them.

    Args:
      v0: Initial velocity, with shape (1,) or (n,). If it is a single
        number, this will be broadcast along each of the n modes.
      alphas: Learning rate hyperparameters with shape (n,).
      betas: Momentum hyperparameters with shape (n,).

    Returns:
      init_state: Initialize the optimizer state.
      update_opt: Updates the optimizer state variables given the current
        step, gradients, and current state.
      get_params: Gets parameters from the optimizer state.
    """
    # Reshape to rows so they broadcast against the (n_params, n_modes)
    # velocity matrix.
    alphas = jnp.reshape(alphas, (1, -1))
    betas = jnp.reshape(betas, (1, -1))

    def init_state(x):
      # One velocity row per scalar parameter, one column per mode.
      n = jnp.ravel(x).size
      v = jnp.ones((n, 1)) * jnp.reshape(v0, (1, -1))
      return (x, v)

    def update_opt(_, grads, state):
      x, v = state
      inputs = jnp.reshape(grads, (-1, 1))
      v_next = betas * v - alphas * inputs
      # NOTE(review): jnp.real suggests complex-valued hyperparameters are
      # anticipated — confirm. Also, jnp.sum(v_next, axis=1) has shape
      # (n,), so this addition assumes x is a flat vector — TODO confirm
      # callers never pass multi-dimensional params here.
      x_next = x + jnp.real(jnp.sum(v_next, axis=1))
      return (x_next, v_next)

    def get_params(state):
      return state[0]

    return (init_state, update_opt, get_params)

  return initial_meta_params, optimizer_fun