Source code for renn.metaopt.task_lib.quadratic

# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines quadratic loss functions."""

import jax
import jax.numpy as jnp

HIGHEST = jax.lax.Precision.HIGHEST


[docs]def quadform(hess, x, precision): """Computes a quadratic form (x^T @ H @ x).""" u = jnp.dot(hess, x, precision=precision) # u = Hx return jnp.inner(x, u, precision=precision)
[docs]def loguniform(n, lambda_min, lambda_max, precision=HIGHEST): """Quadratic loss function with loguniform eigenvalues. The loss is: f(x) = (1/2) x^T H x + x^T v + b. The eigenvalues of the Hessian (H) are sampled uniformly on a logarithmic grid from lambda_min to lambda_max. Args: n: int, Problem dimension (number of parameters). lambda_min: float, Minimum eigenvalue of the Hessian. lambda_max: float, Maximum eigenvalue of the Hessian. precision: Which lax precision to use (default: HIGHEST). Returns: problem_fun: Function that takes a jax PRNGkey and a precision argument and returns an (initial_params, loss_fun) tuple. """ def problem_fun(key): """Builds a quadratic loss problem.""" pkey, ekey, qkey, vkey = jax.random.split(key, 4) # Sample eigenvalues. log_eigenvalues = jax.random.uniform(ekey, shape=(n,), minval=lambda_min, maxval=lambda_max) eigenvalues = 10**log_eigenvalues # Build orthonormal basis. basis = jax.nn.initializers.orthogonal()(qkey, shape=(n, n)) # Define hessian. hess = jnp.dot(jnp.dot(basis, jnp.diag(eigenvalues), precision=precision), basis.T, precision=precision) # Random vector for the linear term in the loss. v = jax.random.normal(vkey, shape=(n,)) # Compute an offset such that the global minimum has a loss of zero. xstar = jnp.linalg.solve(hess, -v) offset = -0.5 * quadform(hess, xstar, precision=precision) - jnp.inner(v, xstar, precision=precision) # pylint: disable=line-too-long def loss_fun(x, _): return 0.5 * quadform(hess, x, precision=precision) + jnp.inner(v, x, precision=precision) + offset # pylint: disable=line-too-long x_init = jax.random.normal(pkey, shape=(n,)) return x_init, loss_fun return problem_fun