renn.metaopt package

Submodules

renn.metaopt.api module

Meta-optimization framework.

renn.metaopt.api.build_metaobj(problem_fun, optimizer_fun, num_inner_steps, meta_loss=<function mean>, l2_penalty=0.0, decorator=<function checkpoint>)

Builds a meta-objective function.

Parameters:
  • problem_fun – callable, Takes a PRNGKey argument and returns initial parameters and a loss function.
  • optimizer_fun – callable, Takes a PRNGKey argument and returns an optimizer tuple (as in jax.experimental.optimizers).
  • num_inner_steps – int, Number of optimization steps.
  • meta_loss – callable, Function to use to compute a scalar meta-loss.
  • l2_penalty – float, L2 penalty to apply to the meta-parameters.
  • decorator – callable, Optional function to wrap the apply_fun argument to lax.scan. By default, this is jax.checkpoint (also exposed as jax.remat), which rematerializes the forward computation when computing the gradient, trading extra computation for lower memory use. Passing the identity function turns off remat.
Returns:

meta_objective – Function that takes meta-parameters and a PRNGKey and returns a scalar meta-objective and the inner loss history.

Return type:

callable
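The sketch below illustrates, under stated assumptions, how a meta-objective like the one build_metaobj returns can be assembled. The inner optimizer is unrolled with lax.scan, meta_loss reduces the loss curve to a scalar, and l2_penalty times the squared norm of the meta-parameters is added. It is not renn's implementation: build_metaobj_sketch, toy_problem and sgd_optimizer_fun are hypothetical names. It also assumes that optimizer_fun maps meta-parameters (rather than a PRNGKey) to an optimizer tuple, as the optimizer_fun returned by renn.metaopt.models.cwrnn does, and that loss_fun takes (params, step).

```python
import jax
import jax.numpy as jnp


def build_metaobj_sketch(problem_fun, optimizer_fun, num_inner_steps,
                         meta_loss=jnp.mean, l2_penalty=0.0,
                         decorator=jax.checkpoint):
    """Hypothetical re-creation of build_metaobj; not renn's actual code."""

    def meta_objective(meta_params, key):
        # Assumption: problem_fun returns (initial_params, loss_fun) and
        # loss_fun has the (params, step) signature used by unroll_scan.
        initial_params, loss_fun = problem_fun(key)
        # Assumption: optimizer_fun maps meta-parameters to an
        # (init_opt, update_opt, get_params) tuple.
        init_opt, update_opt, get_params = optimizer_fun(meta_params)

        def step(opt_state, k):
            loss, grads = jax.value_and_grad(loss_fun)(get_params(opt_state), k)
            return update_opt(k, grads, opt_state), loss

        # decorator (jax.checkpoint by default) recomputes instead of storing.
        _, fs = jax.lax.scan(decorator(step), init_opt(initial_params),
                             jnp.arange(num_inner_steps))
        l2 = sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(meta_params))
        return meta_loss(fs) + l2_penalty * l2, fs

    return meta_objective


# Toy usage: meta-learn the log learning rate of plain gradient descent.
def toy_problem(key):
    x0 = jax.random.normal(key, (5,))
    return x0, lambda x, step: 0.5 * jnp.sum(x ** 2)


def sgd_optimizer_fun(meta_params):
    lr = jnp.exp(meta_params["log_lr"])
    return (lambda x: x, lambda k, g, x: x - lr * g, lambda x: x)


metaobj = build_metaobj_sketch(toy_problem, sgd_optimizer_fun, num_inner_steps=20)
meta_params = {"log_lr": jnp.log(0.1)}
(obj, fs), meta_grads = jax.value_and_grad(metaobj, has_aux=True)(
    meta_params, jax.random.PRNGKey(0))
```

Because the whole unroll is differentiable, jax.value_and_grad with has_aux=True returns the meta-objective, its loss history, and the meta-gradient in one call.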

renn.metaopt.api.clip(x, value=inf)

Clips elements of x to have magnitude less than or equal to value.
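A minimal equivalent of the described behavior, written with jnp.clip and applied across a gradient pytree with jax.tree_util.tree_map. The name clip_sketch is hypothetical, and renn's own implementation may differ in detail.

```python
import jax
import jax.numpy as jnp


def clip_sketch(x, value=jnp.inf):
    """Clip each element of x into [-value, value]; value=inf is a no-op."""
    return jnp.clip(x, -value, value)


# Elementwise clipping of every leaf in a (hypothetical) gradient pytree.
grads = {"w": jnp.array([5.0, -0.2]), "b": jnp.array(-3.0)}
clipped = jax.tree_util.tree_map(lambda g: clip_sketch(g, 1.0), grads)
```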

renn.metaopt.api.evaluate(opt, problem_fun, num_steps, eval_key, num_repeats=64)

Evaluates an optimizer on a given problem.

Parameters:
  • opt – An optimizer tuple of functions (init_opt, update_opt, get_params) to evaluate.
  • problem_fun – A function that returns an (initial_params, loss_fun, fetch_data) tuple given a PRNGKey.
  • num_steps – Number of steps to run the optimizer for.
  • eval_key – Base PRNGKey used for evaluation.
  • num_repeats – Number of different evaluation seeds to use.
Returns:

losses – Array of loss values with shape (num_repeats, num_steps) containing the training loss curve for each random seed.
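A minimal sketch of such an evaluation loop, assuming each repeat uses its own key split from eval_key and that the repeats are batched with jax.vmap. For brevity it assumes problem_fun returns only (initial_params, loss_fun) and omits fetch_data. evaluate_sketch and toy_problem are hypothetical names, not renn's code.

```python
import jax
import jax.numpy as jnp


def evaluate_sketch(opt, problem_fun, num_steps, eval_key, num_repeats=64):
    """Hypothetical re-creation of evaluate; not renn's actual code."""
    init_opt, update_opt, get_params = opt

    def run_one(key):
        # Simplification: problem_fun returns (initial_params, loss_fun) only.
        initial_params, loss_fun = problem_fun(key)

        def step(opt_state, k):
            loss, grads = jax.value_and_grad(loss_fun)(get_params(opt_state), k)
            return update_opt(k, grads, opt_state), loss

        _, fs = jax.lax.scan(step, init_opt(initial_params), jnp.arange(num_steps))
        return fs  # shape (num_steps,)

    keys = jax.random.split(eval_key, num_repeats)  # one seed per repeat
    return jax.vmap(run_one)(keys)                  # (num_repeats, num_steps)


def toy_problem(key):
    x0 = jax.random.normal(key, (5,))
    return x0, lambda x, step: 0.5 * jnp.sum(x ** 2)


sgd = (lambda x: x, lambda k, g, x: x - 0.1 * g, lambda x: x)
losses = evaluate_sketch(sgd, toy_problem, num_steps=30,
                         eval_key=jax.random.PRNGKey(1), num_repeats=8)
```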

renn.metaopt.api.outer_loop(key, initial_meta_params, meta_objective, meta_optimizer, steps, batch_size=1, save_every=None, clip_value=inf)

Meta-trains an optimizer.

Parameters:
  • key – JAX PRNG key, used for initializing the inner problem.
  • initial_meta_params – pytree, Initial meta-parameters.
  • meta_objective – function, Computes a (scalar) loss given meta-parameters and an array (batch) of random seeds.
  • meta_optimizer – tuple of functions, Defines the meta-optimizer to use (for example, a jax.experimental.optimizers Optimizer tuple).
  • steps – An iterable that yields the integer step indices 0, 1, ..., num_steps - 1 (for example, range(num_steps)).
  • batch_size – int, Number of problems to train per batch.
  • save_every – int, Specifies how often to store auxiliary information. If None, then information is never stored (Default: None).
  • clip_value – float, Specifies the gradient clipping value (maximum gradient norm) (Default: np.inf).
Returns:

  • final_params – Final optimized parameters.
  • store – Dict containing auxiliary information saved during optimization.
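The loop below is a hedged sketch of what meta-training with these arguments might look like; it is not renn's code. It assumes the meta-optimizer is a jax.example_libraries.optimizers tuple (the current home of the former jax.experimental.optimizers module), clips the meta-gradient by its global norm with optimizers.clip_grads to match the clip_value description, and records the step and meta-loss as hypothetical auxiliary information in store. outer_loop_sketch and toy_meta_objective are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers


def outer_loop_sketch(key, initial_meta_params, meta_objective, meta_optimizer,
                      steps, batch_size=1, save_every=None, clip_value=jnp.inf):
    """Hypothetical re-creation of outer_loop; not renn's actual code."""
    init_opt, update_opt, get_params = meta_optimizer
    opt_state = init_opt(initial_meta_params)
    store = {"step": [], "loss": []}  # hypothetical choice of auxiliary info
    loss_and_grad = jax.jit(jax.value_and_grad(meta_objective))

    for k in steps:
        key, subkey = jax.random.split(key)
        batch_keys = jax.random.split(subkey, batch_size)  # one seed per problem
        loss, grads = loss_and_grad(get_params(opt_state), batch_keys)
        grads = optimizers.clip_grads(grads, clip_value)   # global-norm clipping
        opt_state = update_opt(k, grads, opt_state)
        if save_every is not None and k % save_every == 0:
            store["step"].append(k)
            store["loss"].append(float(loss))

    return get_params(opt_state), store


def toy_meta_objective(meta_params, keys):
    """Mean loss of gradient descent on random quadratics, averaged over keys."""
    lr = jnp.exp(meta_params["log_lr"])

    def one_problem(key):
        x0 = jax.random.normal(key, (5,))

        def body(x, _):
            return x - lr * x, 0.5 * jnp.sum(x ** 2)  # gradient of the loss is x

        _, fs = jax.lax.scan(body, x0, None, length=10)
        return jnp.mean(fs)

    return jnp.mean(jax.vmap(one_problem)(keys))


final_meta_params, store = outer_loop_sketch(
    jax.random.PRNGKey(0), {"log_lr": jnp.log(0.05)}, toy_meta_objective,
    optimizers.adam(0.1), steps=range(50), batch_size=4, save_every=10,
    clip_value=1.0)
```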

renn.metaopt.api.unroll_for(initial_params, loss_fun, optimizer, extract_state, steps)

Runs an optimizer on a given problem, using a for loop.

Note: this is slower to compile than unroll_scan, but can be used to store intermediate computations (such as the optimizer state or problem parameters) at every iteration, for further analysis.

Parameters:
  • initial_params – Initial parameters.
  • loss_fun – A function that takes (params, step) and returns a loss.
  • optimizer – A tuple containing an optimizer init function, an update function, and a get_params function.
  • extract_state – A function that, given an optimizer state, returns the parts of that state to store. Because each optimizer's state is structured differently, this function is specific to a particular optimizer.
  • steps – An iterable that yields the integer step indices 0, 1, ..., num_steps - 1 (for example, range(num_steps)).
Returns:

results – Dictionary containing results to save.

Return type:

dict
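A minimal Python-loop sketch showing why this variant can record per-step state: every iteration runs in Python (with a jitted loss-and-gradient call), so extract_state can pull any part of the optimizer state out before each update. The heavy-ball momentum_optimizer, whose state is a (params, velocity) pair, the result keys, and unroll_for_sketch itself are hypothetical.

```python
import jax
import jax.numpy as jnp


def unroll_for_sketch(initial_params, loss_fun, optimizer, extract_state, steps):
    """Hypothetical re-creation of unroll_for; not renn's actual code."""
    init_opt, update_opt, get_params = optimizer
    opt_state = init_opt(initial_params)
    loss_and_grad = jax.jit(jax.value_and_grad(loss_fun))
    results = {"loss": [], "state": []}  # hypothetical result keys

    for k in steps:  # a Python loop, so anything can be recorded per step
        loss, grads = loss_and_grad(get_params(opt_state), k)
        results["loss"].append(float(loss))
        results["state"].append(extract_state(opt_state))
        opt_state = update_opt(k, grads, opt_state)

    return results


def momentum_optimizer(lr, beta):
    """Hypothetical heavy-ball optimizer whose state is (params, velocity)."""
    def init_opt(x):
        return (x, jnp.zeros_like(x))

    def update_opt(k, g, state):
        x, v = state
        v = beta * v + g
        return (x - lr * v, v)

    def get_params(state):
        return state[0]

    return init_opt, update_opt, get_params


results = unroll_for_sketch(
    initial_params=jnp.ones(3),
    loss_fun=lambda x, step: 0.5 * jnp.sum(x ** 2),
    optimizer=momentum_optimizer(0.1, 0.9),
    extract_state=lambda state: state[1],  # keep only the velocity
    steps=range(25),
)
```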

renn.metaopt.api.unroll_scan(initial_params, loss_fun, optimizer, num_steps, decorator)

Runs an optimizer on a given problem, using lax.scan.

Note: this caches the parameters at every step of the unrolled loop and therefore uses a lot of device memory, so it is not well suited to simply evaluating (testing) an optimizer. It is useful when we need to compute a derivative of some final loss with respect to the optimizer parameters.

Parameters:
  • initial_params – Initial parameters.
  • loss_fun – A function that takes (params, step) and returns a loss.
  • optimizer – A tuple containing an optimizer init function, an update function, and a get_params function.
  • num_steps – int, number of steps to run the optimizer.
  • decorator – callable, Optional decorator function used to wrap the apply_fun argument to lax.scan.
Returns:

  • final_params – Problem parameters after running the optimizer.
  • fs – Loss at every step of the loop.
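The sketch below illustrates the lax.scan variant and the reason it exists: the final inner loss is differentiated with respect to the inner optimizer's learning rate. It is an illustration under stated assumptions, not renn's implementation; unroll_scan_sketch, quadratic and final_loss are hypothetical names, and decorator is jax.checkpoint in the differentiated call and the identity function elsewhere.

```python
import jax
import jax.numpy as jnp


def unroll_scan_sketch(initial_params, loss_fun, optimizer, num_steps, decorator):
    """Hypothetical re-creation of unroll_scan; not renn's actual code."""
    init_opt, update_opt, get_params = optimizer

    def step(opt_state, k):
        loss, grads = jax.value_and_grad(loss_fun)(get_params(opt_state), k)
        return update_opt(k, grads, opt_state), loss

    # Differentiating through lax.scan keeps per-step intermediates for the
    # backward pass; wrapping step with jax.checkpoint recomputes them instead.
    final_state, fs = jax.lax.scan(decorator(step), init_opt(initial_params),
                                   jnp.arange(num_steps))
    return get_params(final_state), fs


def quadratic(x, step):
    return 0.5 * jnp.sum(x ** 2)


def final_loss(lr):
    """Loss at the last unrolled step, as a differentiable function of lr."""
    sgd = (lambda x: x, lambda k, g, x: x - lr * g, lambda x: x)
    _, fs = unroll_scan_sketch(jnp.ones(3), quadratic, sgd, 10, jax.checkpoint)
    return fs[-1]


identity = lambda f: f  # turns off rematerialization
final_params, fs = unroll_scan_sketch(
    jnp.ones(3), quadratic, (lambda x: x, lambda k, g, x: x - 0.1 * g, lambda x: x),
    10, identity)
dloss_dlr = jax.grad(final_loss)(0.1)
```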

renn.metaopt.common module

Update functions for common optimizers.

renn.metaopt.common.adagrad(alpha, beta)
renn.metaopt.common.adam(alpha, beta1=0.9, beta2=0.999, eps=1e-05)
renn.metaopt.common.cwrnn(cell_apply, readout_apply)
renn.metaopt.common.momentum(alpha, beta)
renn.metaopt.common.nesterov(alpha, beta)
renn.metaopt.common.rmsprop(alpha, beta=0.9, eps=1e-05)
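The module lists only signatures. As an illustration of what an update function with adam's arguments typically computes, here is a textbook Adam step packaged as an (init_opt, update_opt, get_params) tuple with the same defaults. It assumes alpha is the step size and applies the usual bias correction; whether renn.metaopt.common.adam returns this exact structure or uses bias correction is not stated here, and adam_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def adam_sketch(alpha, beta1=0.9, beta2=0.999, eps=1e-05):
    """Hypothetical textbook Adam as an optimizer tuple; not renn's code."""

    def init_opt(x):
        return (x, jnp.zeros_like(x), jnp.zeros_like(x))  # params, 1st, 2nd moment

    def update_opt(k, g, state):
        x, m, v = state
        m = beta1 * m + (1 - beta1) * g           # running mean of gradients
        v = beta2 * v + (1 - beta2) * g ** 2      # running mean of squared gradients
        m_hat = m / (1 - beta1 ** (k + 1))        # bias correction (k is 0-based)
        v_hat = v / (1 - beta2 ** (k + 1))
        return (x - alpha * m_hat / (jnp.sqrt(v_hat) + eps), m, v)

    def get_params(state):
        return state[0]

    return init_opt, update_opt, get_params


init_opt, update_opt, get_params = adam_sketch(alpha=0.1)
state = update_opt(0, jnp.array([2.0, -3.0]), init_opt(jnp.zeros(2)))
```

On the first step the bias-corrected ratio is close to sign(g), so each coordinate moves by roughly alpha regardless of the gradient's magnitude.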

renn.metaopt.losses module

Functions for computing a scalar objective from a loss curve.

renn.metaopt.losses.final(fs)

Returns the final loss value.

renn.metaopt.losses.mean(fs)

Returns the average over the loss values.

renn.metaopt.losses.nanmin(fs)

Computes the NaN-aware minimum over the loss curve.
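Assuming fs is a one-dimensional loss curve, the three reductions map directly onto jax.numpy operations; the _sketch names are hypothetical. nanmin is the forgiving choice when an inner run diverges: it still reports the best loss reached before the curve turned into NaN, whereas final and mean propagate the NaN.

```python
import jax.numpy as jnp


def final_sketch(fs):
    return fs[-1]          # last entry of the loss curve


def mean_sketch(fs):
    return jnp.mean(fs)    # average over all steps


def nanmin_sketch(fs):
    return jnp.nanmin(fs)  # smallest loss, ignoring NaN entries


fs = jnp.array([3.0, 1.0, 2.0, jnp.nan])  # a run that diverges at the end
```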

renn.metaopt.models module

Define simple learned optimizer models.

renn.metaopt.models.aggmo(key, num_terms)

Aggregated momentum (aggmo).
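The documentation gives only the name. For reference, this is a sketch of the standard Aggregated Momentum update (Lucas et al., 2019) with fixed, user-supplied damping coefficients: one velocity buffer per beta, averaged into a single step. How renn parameterizes and initializes these coefficients from key and num_terms is not documented here, and aggmo_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def aggmo_sketch(alpha, betas):
    """Hypothetical AggMo optimizer tuple with fixed betas; not renn's code."""
    betas = jnp.asarray(betas)  # shape (K,), one damping coefficient per velocity

    def init_opt(x):
        return (x, jnp.zeros((betas.shape[0],) + x.shape))

    def update_opt(k, g, state):
        x, vs = state
        # v_i <- beta_i * v_i - g for every i, broadcasting g over the K buffers.
        vs = betas.reshape((-1,) + (1,) * g.ndim) * vs - g
        return (x + alpha * jnp.mean(vs, axis=0), vs)  # average the K velocities

    def get_params(state):
        return state[0]

    return init_opt, update_opt, get_params


init_opt, update_opt, get_params = aggmo_sketch(alpha=0.1, betas=[0.0, 0.9])
g = jnp.array([1.0, -1.0])
state = update_opt(0, g, init_opt(jnp.zeros(2)))
state = update_opt(1, g, state)
```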

renn.metaopt.models.append_to_sequence(sequence, element)

Appends an element to a rolling sequence buffer.

Parameters:
  • sequence – a sequence of ndarrays, concatenated along the first dimension.
  • element – an ndarray to add to the sequence.
Returns:

sequence – The updated sequence, with the first element removed, the remaining elements shifted over, and the new element appended.
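A minimal sketch of this rolling-buffer behavior with jnp.concatenate, assuming sequence is a single array whose leading axis indexes time and element has the shape of one slice; append_to_sequence_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def append_to_sequence_sketch(sequence, element):
    """Drop the oldest entry and append element along the leading axis."""
    return jnp.concatenate([sequence[1:], element[None]], axis=0)


buffer = jnp.arange(6.0).reshape(3, 2)  # three past entries of shape (2,)
updated = append_to_sequence_sketch(buffer, jnp.array([10.0, 11.0]))
```

The buffer length stays fixed, which keeps array shapes static across iterations of a jitted or scanned loop.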

renn.metaopt.models.cwrnn(key, cell, input_scale='raw', output_scale=0.001)

Component-wise RNN Optimizer.

This optimizer applies an RNN to update each parameter of the problem independently (hence the name, component-wise). Following previous work (Andrychowicz et al. 2016; Wichrowska et al. 2017), it distributes the parameters along the batch dimension of the RNN, which allows all parameters to be updated in parallel.

Parameters:
  • key – JAX PRNG key to use for initializing parameters.
  • cell – An RNNCell to use (see renn/rnn/cells.py).
  • input_scale – str, Specifies how to scale gradient inputs to the RNN. If ‘raw’, then the gradients are not scaled. If ‘log1p’, then the magnitude and the sign of each input are split into a length-2 vector, [log1p(abs(g)), sign(g)].
  • output_scale – float, Constant used to multiply (rescale) the RNN output.
Returns:

  • meta_parameters – A tuple containing the RNN parameters and the readout parameters. The RNN parameters are a namedtuple; the readout parameters are a tuple containing weights and a bias.
  • optimizer_fun – A function that takes a set of meta_parameters and returns an optimizer tuple containing functions to initialize the optimizer state, update the optimizer state, and get parameters from the optimizer state.
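To make the component-wise idea concrete, the sketch below flattens a gradient so that each parameter component becomes one row of the RNN batch, applies the two input_scale options described above, and runs one step of a shared vanilla tanh cell with a linear readout scaled by output_scale. The tanh cell, the plain-tuple parameter layout, and all function names are hypothetical stand-ins for renn's RNNCell and namedtuple parameters; how the output is added to the parameters is left out.

```python
import jax
import jax.numpy as jnp


def scale_inputs(g, input_scale="raw"):
    """Turn a gradient array into per-component RNN inputs of shape (n, features)."""
    g = jnp.ravel(g)[:, None]  # each parameter component becomes one batch row
    if input_scale == "raw":
        return g  # (n, 1): gradients passed through unscaled
    if input_scale == "log1p":
        # (n, 2): [log1p(abs(g)), sign(g)], separating magnitude from sign
        return jnp.concatenate([jnp.log1p(jnp.abs(g)), jnp.sign(g)], axis=1)
    raise ValueError(f"Unknown input_scale: {input_scale}")


def cw_rnn_step(rnn_params, readout_params, h, g, input_scale="log1p",
                output_scale=1e-3):
    """One hypothetical component-wise step with a vanilla tanh RNN cell."""
    W, U, b = rnn_params
    w_out, b_out = readout_params
    x = scale_inputs(g, input_scale)         # (n, features)
    h = jnp.tanh(x @ W + h @ U + b)          # same weights for every component
    dx = output_scale * (h @ w_out + b_out)  # (n, 1) per-component output
    return h, dx.reshape(g.shape)


hidden = 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
rnn_params = (0.1 * jax.random.normal(k1, (2, hidden)),
              0.1 * jax.random.normal(k2, (hidden, hidden)),
              jnp.zeros(hidden))
readout_params = (0.1 * jax.random.normal(k3, (hidden, 1)), jnp.zeros(1))

g = jnp.array([[1.0, -2.0, 0.0], [0.5, -0.5, 3.0]])  # gradient of a (2, 3) parameter
h0 = jnp.zeros((g.size, hidden))                     # one hidden state per component
h1, dx = cw_rnn_step(rnn_params, readout_params, h0, g)
```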

renn.metaopt.models.gradgram(key, tau, scale_grad, scale_gram, base_grad=0, base_gram=0)

Optimizer that is a function of gradient history and inner products.

renn.metaopt.models.lds(key, num_units, h0_init=<function zeros>, w_init=<function variance_scaling.<locals>.init>)

Linear dynamical system (LDS) optimizer.

renn.metaopt.models.linear(key, tau, scale, base=0)

Optimizer that is a linear function of gradient history.

renn.metaopt.models.linear_dx(key, tau, scale_grad, scale_dx, base_grad=0, base_gram=0)

Optimizer that is a linear function of gradient and parameter history.

renn.metaopt.models.momentum(key)

Wrapper for the momentum optimizer.

renn.metaopt.tasks module

Load tasks from the library.

renn.metaopt.tasks.quad(n, lambda_min, lambda_max, precision=<PrecisionConfig_Precision.HIGHEST: 2>)

Quadratic loss function with loguniform eigenvalues.

The loss is: f(x) = (1/2) x^T H x + x^T v + b.

The eigenvalues of the Hessian (H) are sampled uniformly on a logarithmic grid from lambda_min to lambda_max.

Parameters:
  • n – int, Problem dimension (number of parameters).
  • lambda_min – float, Minimum eigenvalue of the Hessian.
  • lambda_max – float, Maximum eigenvalue of the Hessian.
  • precision – Which lax precision to use (default: HIGHEST).
Returns:

problem_fun – Function that takes a JAX PRNGKey and a precision argument and returns an (initial_params, loss_fun) tuple.

Return type:

callable
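A sketch of the construction, assuming "logarithmic grid" means eigenvalues evenly spaced in log10 (jnp.logspace) and that the Hessian is rotated by a random orthogonal matrix from a QR decomposition. The offset b is fixed at 0, the precision argument is omitted, and quad_sketch is a hypothetical name.

```python
import math

import jax
import jax.numpy as jnp


def quad_sketch(n, lambda_min, lambda_max):
    """Hypothetical quadratic task: f(x) = (1/2) x^T H x + x^T v + b."""

    def problem_fun(key):
        k_rot, k_v, k_x = jax.random.split(key, 3)
        # Eigenvalues evenly spaced in log10 between lambda_min and lambda_max.
        eigs = jnp.logspace(math.log10(lambda_min), math.log10(lambda_max), n)
        # Random orthogonal basis from the QR decomposition of a Gaussian matrix.
        q, _ = jnp.linalg.qr(jax.random.normal(k_rot, (n, n)))
        hess = (q * eigs) @ q.T  # H = Q diag(eigs) Q^T
        v = jax.random.normal(k_v, (n,))
        b = 0.0

        def loss_fun(x, step):
            return 0.5 * x @ hess @ x + x @ v + b

        initial_params = jax.random.normal(k_x, (n,))
        return initial_params, loss_fun

    return problem_fun


problem_fun = quad_sketch(8, 0.1, 10.0)
initial_params, loss_fun = problem_fun(jax.random.PRNGKey(0))
hessian = jax.hessian(loss_fun)(initial_params, 0)
```

Because the Hessian of this loss is H itself, jax.hessian recovers it, and its eigenvalues should match the grid.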

renn.metaopt.tasks.two_moons(model, num_samples=1024, l2_pen=0.005, seed=0)
renn.metaopt.tasks.logistic_regression(model, features, targets, l2_pen=0.0)

Helper function for logistic regression.

renn.metaopt.tasks.softmax_regression(model, features, targets, num_classes, l2_pen=0.0)

Helper function for softmax regression.
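The helpers' model argument is not described here. Assuming model is a callable mapping (params, features) to logits and l2_pen multiplies the squared norm of all parameters, the two losses reduce to the familiar cross-entropies below. All names, including linear_model, are hypothetical, and renn's scaling of the penalty may differ.

```python
import jax
import jax.numpy as jnp


def _sq_norm(params):
    return sum(jnp.sum(p ** 2) for p in jax.tree_util.tree_leaves(params))


def logistic_regression_sketch(model, features, targets, l2_pen=0.0):
    """Binary cross-entropy on model logits; targets are 0/1 floats."""
    def loss_fun(params, step):
        logits = model(params, features)                               # (N,)
        nll = jnp.mean(jnp.logaddexp(0.0, logits) - targets * logits)  # -log p(y)
        return nll + l2_pen * _sq_norm(params)
    return loss_fun


def softmax_regression_sketch(model, features, targets, num_classes, l2_pen=0.0):
    """Categorical cross-entropy; targets are integer class labels."""
    def loss_fun(params, step):
        logits = model(params, features)                               # (N, C)
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        one_hot = jax.nn.one_hot(targets, num_classes)
        nll = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
        return nll + l2_pen * _sq_norm(params)
    return loss_fun


def linear_model(params, features):  # hypothetical model: features @ w + b
    return features @ params["w"] + params["b"]


X = jax.random.normal(jax.random.PRNGKey(0), (10, 3))
y_binary = (X[:, 0] > 0).astype(jnp.float32)
y_class = jnp.arange(10) % 4

binary_loss = logistic_regression_sketch(linear_model, X, y_binary, l2_pen=5e-3)
multi_loss = softmax_regression_sketch(linear_model, X, y_class, num_classes=4)
zero_binary = {"w": jnp.zeros(3), "b": jnp.zeros(())}
zero_multi = {"w": jnp.zeros((3, 4)), "b": jnp.zeros(4)}
```

With all-zero parameters every prediction is uniform, so the losses equal log 2 and log(num_classes), which is a quick sanity check.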

Module contents

Meta-optimization framework.