"""Load tasks from the library."""

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

from renn import utils
from sklearn.datasets import make_moons

from .task_lib import quadratic

__all__ = [
    'quad',
    'two_moons',
    'logistic_regression',
    'softmax_regression',
]

# Each task below is a function that takes problem parameters as
# arguments and returns a `problem_fun` function. This function takes a
# single argument, a PRNGKey, and returns an (x_init, loss_fun, data) tuple.
# x_init is a pytree of initial problem parameters. loss_fun is a function
# that returns a scalar loss given parameters and a batch of data. Finally,
# data is an iterable pytree. All leaves of the tree must have the same first
# dimension, which is the number of steps to optimize for. These slices
# of data will be passed to the loss_fun during optimization.
#
# NOTE(review): the regression helpers in this module (logistic_regression,
# softmax_regression, and two_moons via logistic_regression) currently return
# only (x_init, loss_fun), and their loss_fun ignores its second argument.
# Confirm which contract the optimizer expects.
quad = quadratic.loguniform

def two_moons(model, num_samples=1024, l2_pen=5e-3, seed=0, noise=0.1):
    """Binary logistic-regression task on scikit-learn's two-moons dataset.

    Args:
      model: An ``(init_fun, predict_fun)`` pair, forwarded unchanged to
        ``logistic_regression``.
      num_samples: Total number of points generated by ``make_moons``.
      l2_pen: Regularization coefficient forwarded to
        ``logistic_regression``.
      seed: ``random_state`` passed to ``make_moons``, so the dataset is
        fixed for a given seed.
      noise: ``noise`` argument of ``make_moons`` (standard deviation of the
        Gaussian noise added to the points). Defaults to the previously
        hard-coded 0.1.

    Returns:
      The ``problem_fun`` built by ``logistic_regression``.
    """
    x, y = make_moons(n_samples=num_samples, shuffle=True, noise=noise,
                      random_state=seed)
    features = jnp.array(x)
    targets = jnp.array(y)
    return logistic_regression(model, features, targets, l2_pen=l2_pen)
def logistic_regression(model, features, targets, l2_pen=0.):
    """Helper function for logistic regression.

    Args:
      model: A pair ``(init_fun, predict_fun)``. ``init_fun`` is called as
        ``init_fun(key, input_shape)`` and must return
        ``(output_shape, params)``; ``predict_fun(params, features)`` must
        return one logit per sample (after squeezing).
      features: 2-D array; its second dimension is used as the model's input
        size.
      targets: Per-sample labels, presumably in {0, 1} (TODO confirm for
        callers other than ``two_moons``).
      l2_pen: Coefficient multiplying ``utils.norm(params)`` in the loss.

    Returns:
      ``problem_fun(prng_key) -> (x0, loss_fun)``, where ``loss_fun(x, step)``
      ignores ``step`` and evaluates the full-batch loss.
    """
    _, n = features.shape

    def problem_fun(prng_key):
        # Only the first of the two split keys is consumed (for init).
        init_key, _ = jax.random.split(prng_key)
        init_fun, predict_fun = model
        _, x0 = init_fun(init_key, (-1, n))

        def loss_fun(x, step):
            del step  # Full-batch loss; the step index is unused.
            logits = jnp.squeeze(predict_fun(x, features))
            # Binary cross-entropy with logits: log(1 + e^z) - y * z.
            # logaddexp(0, z) equals log1p(exp(z)) but does not overflow to
            # inf for large positive logits.
            data_loss = jnp.mean(jnp.logaddexp(0., logits) - targets * logits)
            reg_loss = l2_pen * utils.norm(x)
            return data_loss + reg_loss

        return x0, loss_fun

    return problem_fun
def softmax_regression(model, features, targets, num_classes, l2_pen=0.):
    """Helper function for softmax regression.

    Args:
      model: A pair ``(init_fun, predict_fun)``. ``init_fun`` is called as
        ``init_fun(key, input_shape)`` and must return
        ``(output_shape, params)``; ``predict_fun(params, features)`` must
        return a 2-D array of logits (after squeezing) whose second axis
        indexes classes.
      features: 2-D array; its second dimension is used as the model's input
        size.
      targets: Integer class labels, one-hot encoded via
        ``utils.one_hot(targets, num_classes)``.
      num_classes: Number of classes used for the one-hot encoding.
      l2_pen: Coefficient multiplying ``utils.norm(params)`` in the loss.

    Returns:
      ``problem_fun(prng_key) -> (x0, loss_fun)``, where ``loss_fun(x, step)``
      ignores ``step`` and evaluates the full-batch loss.
    """
    _, n = features.shape

    def problem_fun(prng_key):
        # Only the first of the two split keys is consumed (for init).
        init_key, _ = jax.random.split(prng_key)
        init_fun, predict_fun = model
        _, x0 = init_fun(init_key, (-1, n))

        def loss_fun(x, step):
            del step  # Full-batch loss; the step index is unused.
            logits = jnp.squeeze(predict_fun(x, features))
            # Log-softmax: without subtracting logsumexp, the "cross-entropy"
            # is unbounded below (scaling the logits up always lowers it).
            log_probs = logits - logsumexp(logits, axis=1, keepdims=True)
            onehot_targets = utils.one_hot(targets, num_classes)
            data_loss = -jnp.mean(jnp.sum(log_probs * onehot_targets, axis=1))
            reg_loss = l2_pen * utils.norm(x)
            return data_loss + reg_loss

        return x0, loss_fun

    return problem_fun