renn.metaopt.task_lib package

Submodules

renn.metaopt.task_lib.quadratic module

Defines quadratic loss functions.

renn.metaopt.task_lib.quadratic.loguniform(n, lambda_min, lambda_max, precision=<PrecisionConfig_Precision.HIGHEST: 2>)

Quadratic loss function with loguniform eigenvalues.

The loss is: f(x) = (1/2) x^T H x + x^T v + b.
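The following sketch evaluates the loss formula on a small hand-picked 2-D instance. H, v and b are illustrative values, not library defaults. It also uses `jax.grad` to check that, for symmetric H, the gradient is H x + v, so the minimizer x* solves H x* = -v.

```python
import jax
import jax.numpy as jnp

# Illustrative 2-D instance of f(x) = (1/2) x^T H x + x^T v + b.
H = jnp.array([[2.0, 0.0], [0.0, 4.0]])
v = jnp.array([1.0, -2.0])
b = 3.0


def loss(x):
  """Quadratic loss f(x) = 0.5 * x^T H x + x^T v + b."""
  return 0.5 * x @ H @ x + x @ v + b


x = jnp.array([1.0, 1.0])
print(loss(x))  # 0.5*(2 + 4) + (1 - 2) + 3 = 5.0

# For symmetric H the gradient is H x + v = [3, 2] at this x.
g = jax.grad(loss)(x)

# Setting the gradient to zero gives the minimizer x* = -H^{-1} v.
x_star = jnp.linalg.solve(H, -v)
```

Because H here is positive definite, x* is the unique global minimizer and the gradient vanishes there.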

The eigenvalues of the Hessian (H) are sampled uniformly on a logarithmic grid from lambda_min to lambda_max.
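One way to build a Hessian with a prescribed spectrum is to rotate a diagonal matrix of eigenvalues by a random orthogonal basis, H = Q diag(lambda) Q^T. The minimal sketch below assumes a deterministic log-spaced grid produced by `jnp.logspace`. This section does not say whether renn spaces the eigenvalues deterministically or samples them at random. The helper name `loguniform_hessian` is hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def loguniform_hessian(key, n, lambda_min, lambda_max):
  """Hypothetical helper: random symmetric H with log-spaced eigenvalues.

  Assumes a deterministic log-spaced grid of eigenvalues.
  """
  # n eigenvalues evenly spaced in log10 space from lambda_min to lambda_max.
  eigvals = jnp.logspace(math.log10(lambda_min), math.log10(lambda_max), n)
  # Random orthogonal basis: the Q factor of a QR decomposition of a Gaussian matrix.
  q, _ = jnp.linalg.qr(jax.random.normal(key, (n, n)))
  # (q * eigvals) scales column j of Q by eigvals[j], so this is Q diag(eigvals) Q^T.
  return (q * eigvals) @ q.T


hess = loguniform_hessian(jax.random.PRNGKey(0), n=5, lambda_min=1e-2, lambda_max=1e2)
print(jnp.linalg.eigvalsh(hess))  # approximately [0.01, 0.1, 1, 10, 100]
```

With these settings the condition number of H is lambda_max / lambda_min = 1e4. Widening that ratio is the usual way to make the problem harder for first-order optimizers.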

Parameters:
  • n – int, Problem dimension (number of parameters).
  • lambda_min – float, Minimum eigenvalue of the Hessian.
  • lambda_max – float, Maximum eigenvalue of the Hessian.
  • precision – Which lax precision to use (default: HIGHEST).
Returns:

Function that takes a JAX PRNG key and a precision argument and returns an (initial_params, loss_fun) tuple.

Return type:

problem_fun
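The return type above describes a factory pattern: `loguniform` returns a `problem_fun`, and calling that with a key yields fresh initial parameters and a loss closure. The following is a minimal sketch of that shape only. `make_quadratic_problem` is a hypothetical name. For brevity it uses a diagonal Hessian (no random rotation) and b = 0, which are simplifications and not necessarily what renn does.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax


def make_quadratic_problem(n, lambda_min, lambda_max):
  """Hypothetical factory returning a problem_fun(key, precision) callable.

  Simplifications: H is diagonal with log-spaced entries, and b = 0.
  """
  # Log-spaced eigenvalues, used directly as the diagonal of H.
  eigvals = jnp.logspace(math.log10(lambda_min), math.log10(lambda_max), n)

  def problem_fun(key, precision=lax.Precision.HIGHEST):
    k_v, k_x = jax.random.split(key)
    v = jax.random.normal(k_v, (n,))  # linear term of the loss
    initial_params = jax.random.normal(k_x, (n,))

    def loss_fun(x):
      # 0.5 * x^T diag(eigvals) x + x^T v, using the requested matmul precision.
      quad = jnp.dot(x, eigvals * x, precision=precision)
      return 0.5 * quad + jnp.dot(x, v, precision=precision)

    return initial_params, loss_fun

  return problem_fun


problem_fun = make_quadratic_problem(n=4, lambda_min=0.1, lambda_max=10.0)
params, loss_fun = problem_fun(jax.random.PRNGKey(0))
grads = jax.grad(loss_fun)(params)
```

Returning a closure keeps the sampled problem data (v, H) out of the optimizer's parameter pytree, so `loss_fun` can be passed directly to `jax.grad` or `jax.jit`.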

renn.metaopt.task_lib.quadratic.quadform(hess, x, precision)

Computes a quadratic form (x^T @ H @ x).
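The function body is not shown here. The following minimal sketch assumes the form is evaluated as two `jnp.dot` calls that both receive `precision`. The name `quadform_sketch` and its default precision are illustrative, not the library's. The precision setting matters mainly on accelerators such as TPUs, where lower settings can trade float32 matmul accuracy for speed.

```python
import jax.numpy as jnp
from jax import lax


def quadform_sketch(hess, x, precision=lax.Precision.HIGHEST):
  """Computes x^T @ H @ x, passing `precision` to both matrix products."""
  hx = jnp.dot(hess, x, precision=precision)  # H @ x
  return jnp.dot(x, hx, precision=precision)  # x^T @ (H @ x)


hess = jnp.array([[2.0, 1.0], [1.0, 3.0]])
x = jnp.array([1.0, 2.0])
print(quadform_sketch(hess, x))  # 1*(2 + 2) + 2*(1 + 6) = 18.0
```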

Module contents