renn.metaopt.task_lib package
Submodules
renn.metaopt.task_lib.quadratic module
Defines quadratic loss functions.
renn.metaopt.task_lib.quadratic.loguniform(n, lambda_min, lambda_max, precision=lax.Precision.HIGHEST)
Quadratic loss function with loguniform eigenvalues.
The loss is: f(x) = (1/2) x^T H x + x^T v + b.
The eigenvalues of the Hessian (H) are sampled uniformly on a logarithmic grid from lambda_min to lambda_max.
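The construction above can be sketched in a few lines of JAX. This is a minimal illustration, not the renn implementation: make_quadratic and quadratic_loss are hypothetical names. It assumes the eigenvalues are evenly spaced in log space (jnp.logspace); drawing them at random from a log-uniform distribution would be an equally plausible reading. It also assumes a random orthogonal eigenbasis (QR of a Gaussian matrix) and Gaussian v and b.

```python
import jax
import jax.numpy as jnp


def make_quadratic(key, n, lambda_min, lambda_max):
  """Hypothetical helper: builds (H, v, b) with log-spaced Hessian eigenvalues."""
  k_q, k_v, k_b = jax.random.split(key, 3)
  # Assumption: n eigenvalues evenly spaced in log10 space.
  eigs = jnp.logspace(jnp.log10(lambda_min), jnp.log10(lambda_max), n)
  # Assumption: random orthogonal eigenbasis from the QR of a Gaussian matrix.
  q, _ = jnp.linalg.qr(jax.random.normal(k_q, (n, n)))
  hess = (q * eigs) @ q.T          # equals Q @ diag(eigs) @ Q^T
  hess = 0.5 * (hess + hess.T)     # remove floating-point asymmetry
  # Assumption: linear term and offset drawn from a standard normal.
  v = jax.random.normal(k_v, (n,))
  b = jax.random.normal(k_b, ())
  return hess, v, b


def quadratic_loss(x, hess, v, b):
  # f(x) = (1/2) x^T H x + x^T v + b
  return 0.5 * x @ hess @ x + x @ v + b
```

Because H = Q diag(λ) Qᵀ with orthogonal Q, its eigenvalues are exactly the chosen values. The condition number of the problem is therefore lambda_max / lambda_min, and the gradient of f is H x + v.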
Parameters: - n – int, Problem dimension (number of parameters).
- lambda_min – float, Minimum eigenvalue of the Hessian.
- lambda_max – float, Maximum eigenvalue of the Hessian.
- precision – lax.Precision, which lax precision to use (default: HIGHEST).
Returns: - problem_fun – Function that takes a JAX PRNGKey and a precision argument and returns an (initial_params, loss_fun) tuple.
Return type: callable
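The documented return contract, a problem_fun that maps a PRNG key to an (initial_params, loss_fun) tuple, might look roughly like the sketch below. Everything here is hypothetical. loguniform_sketch is not the renn function, and its internals are assumed. The way precision is threaded through, as an optional second argument of problem_fun that is passed to jnp.dot, is one reading of the docstring rather than the library's confirmed behavior.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loguniform_sketch(n, lambda_min, lambda_max, precision=lax.Precision.HIGHEST):
  """Hypothetical stand-in for the documented factory (not renn's code)."""
  # Assumption: eigenvalues evenly spaced in log space.
  eigs = jnp.logspace(jnp.log10(lambda_min), jnp.log10(lambda_max), n)

  def problem_fun(key, precision=precision):
    # Assumption: precision can be overridden per call, defaulting to the factory's.
    k_q, k_v, k_b, k_x = jax.random.split(key, 4)
    q, _ = jnp.linalg.qr(jax.random.normal(k_q, (n, n)))
    hess = jnp.dot(q * eigs, q.T, precision=precision)  # Q diag(eigs) Q^T
    v = jax.random.normal(k_v, (n,))
    b = jax.random.normal(k_b, ())
    initial_params = jax.random.normal(k_x, (n,))  # assumed Gaussian init

    def loss_fun(x):
      # f(x) = (1/2) x^T H x + x^T v + b, with every product at `precision`.
      hx = jnp.dot(hess, x, precision=precision)
      return (0.5 * jnp.dot(x, hx, precision=precision)
              + jnp.dot(x, v, precision=precision) + b)

    return initial_params, loss_fun

  return problem_fun


# Usage: draw one problem instance and differentiate its loss.
problem_fun = loguniform_sketch(n=3, lambda_min=0.01, lambda_max=1.0)
params, loss_fun = problem_fun(jax.random.PRNGKey(0))
grads = jax.grad(loss_fun)(params)
```

In this sketch, returning a closure keeps H, v, and b fixed for a given key. A meta-optimizer can then draw a fresh problem instance per key while each loss_fun stays a pure function of the parameters.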