renn package

Submodules

renn.analysis_utils module

Utilities for analysis.

renn.analysis_utils.pseudogrid(coordinates, dimension)[source]

Constructs a pseudogrid (‘pseudo’ in that it is not necessarily evenly spaced) of points in ‘dimension’-dimensional space from the specified coordinates.

Parameters:
  • coordinates – a mapping between dimensions and the coordinates in those dimensions
  • dimension – number of dimensions

For all dimensions that are not specified, the coordinate is taken to be 0.

Example

If coordinates = {0: [0, 1, 2], 2: [1]} and dimension = 4, the coordinates in dimensions 1 and 3 will be taken as [0], yielding the effective coordinate dictionary

coordinates = {0: [0, 1, 2], 1: [0], 2: [1], 3: [0]}

Then the resulting pseudogrid will be constructed as:

[[0, 0, 1, 0], [1, 0, 1, 0], [2, 0, 1, 0]]
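
For reference, a minimal sketch of this construction (illustrative only; the library's implementation may differ): unspecified dimensions default to [0], and the grid is the Cartesian product of the per-dimension coordinate lists.

import itertools

def pseudogrid_sketch(coordinates, dimension):
    # Unspecified dimensions default to the single coordinate [0].
    axes = [coordinates.get(d, [0]) for d in range(dimension)]
    # The pseudogrid is the Cartesian product of the per-dimension coordinates.
    return [list(point) for point in itertools.product(*axes)]

# pseudogrid_sketch({0: [0, 1, 2], 2: [1]}, 4)
# -> [[0, 0, 1, 0], [1, 0, 1, 0], [2, 0, 1, 0]]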

renn.losses module

Functions for computing loss.

renn.losses.binary_xent(logits, labels)[source]

Cross-entropy loss for a two-class classification problem, where the model output is a single logit.

Parameters:
  • logits – array of shape (batch_size, 1) or (batch_size,)
  • labels – array of length batch_size, whose elements are either 0 or 1
Returns:

scalar cross entropy loss

Return type:

loss
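
For illustration only, a self-contained sketch of this computation using jax.nn.log_sigmoid (the library's implementation may differ in details such as how the loss is reduced over the batch):

import jax.numpy as jnp
from jax.nn import log_sigmoid

def binary_xent_sketch(logits, labels):
    # Accept logits of shape (batch_size, 1) or (batch_size,).
    logits = jnp.squeeze(logits)
    # Mean sigmoid cross-entropy; note 1 - sigmoid(z) == sigmoid(-z).
    return -jnp.mean(labels * log_sigmoid(logits) +
                     (1 - labels) * log_sigmoid(-logits))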

renn.losses.multiclass_xent(logits, labels)[source]

renn.serialize module

Serialization of pytrees.

renn.serialize.dump(pytree, file)[source]
renn.serialize.load(file)[source]
renn.serialize.dumps(pytree)[source]
renn.serialize.loads(bytes)[source]
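
These functions appear to mirror the standard pickle/json interface (dump/load for file objects, dumps/loads for bytes). A hypothetical round-trip under that assumption:

import jax.numpy as jnp
from renn import serialize

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}  # an arbitrary pytree

data = serialize.dumps(params)       # pytree -> bytes
restored = serialize.loads(data)     # bytes -> pytree

with open('params.bin', 'wb') as f:  # file-based variants
    serialize.dump(params, f)
with open('params.bin', 'rb') as f:
    restored = serialize.load(f)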

renn.utils module

Utilities for optimization.

renn.utils.batch_mean(fun, in_axes)[source]

Converts a function to a batched version (maps over multiple inputs).

This takes a function that returns a scalar (such as a loss function) and returns a new function that maps the function over multiple arguments (such as over multiple random seeds) and returns the average of the results.

It is useful for generating a batched version of a loss function, where the loss function has stochasticity that depends on a random seed argument.

Parameters:
  • fun – function, Function to batch.
  • in_axes – tuple, Specifies the arguments to fun to batch over. For example, in_axes=(None, 0) would batch over the second argument.
Returns:

function, computes the average over a batch.

Return type:

batch_fun
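
A usage sketch, averaging a stochastic loss over a batch of PRNG keys (the helper noisy_loss below is hypothetical, chosen only to illustrate the in_axes convention described above):

import jax
import jax.numpy as jnp
from renn.utils import batch_mean

def noisy_loss(params, key):
    # Toy scalar loss whose value depends on a random seed.
    noise = jax.random.normal(key, ())
    return jnp.sum(params ** 2) + 0.1 * noise

# Batch over the second argument (the PRNG key), not the parameters.
mean_loss = batch_mean(noisy_loss, in_axes=(None, 0))

params = jnp.array([1.0, -2.0])
keys = jax.random.split(jax.random.PRNGKey(0), 8)
loss = mean_loss(params, keys)  # average of noisy_loss over 8 seeds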

renn.utils.norm(params, order=2)[source]

Computes the (flattened) norm of a pytree.
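
Roughly equivalent to the following sketch (the actual implementation may use different tree utilities):

import jax
import jax.numpy as jnp

def norm_sketch(params, order=2):
    # Flatten every leaf of the pytree into one vector, then take its norm.
    leaves = jax.tree_util.tree_leaves(params)
    flat = jnp.concatenate([jnp.ravel(leaf) for leaf in leaves])
    return jnp.linalg.norm(flat, ord=order)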

renn.utils.identity(x)[source]

Identity function.

renn.utils.fst(xs)[source]

Returns the first element from a list.

renn.utils.snd(xs)[source]

Returns the second element from a list.

renn.utils.optimize(loss_fun, x0, optimizer, steps, stop_tol=-inf)[source]

Run an optimizer on a given loss function.

Parameters:
  • loss_fun – Scalar loss function to optimize.
  • x0 – Initial parameters.
  • optimizer – A tuple of optimizer functions (init_opt, update_opt, get_params) from jax.experimental.optimizers.
  • steps – Iterator over steps.
  • stop_tol – Stop if the loss is below this value (Default: -np.inf).
Returns:

  • loss_hist – Array of losses during training.
  • final_params – Optimized parameters.
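
An illustrative call, using the (init_opt, update_opt, get_params) triple returned by jax.experimental.optimizers (moved to jax.example_libraries.optimizers in newer JAX releases); the quadratic loss below is only a placeholder:

import jax.numpy as jnp
from jax.experimental import optimizers
from renn.utils import optimize

def loss_fun(x):
    return jnp.sum((x - 3.0) ** 2)  # simple quadratic, minimized at x = 3

x0 = jnp.zeros(5)
optimizer = optimizers.adam(step_size=0.1)  # (init_opt, update_opt, get_params)

loss_hist, final_params = optimize(loss_fun, x0, optimizer, steps=range(500))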

renn.utils.one_hot(labels, num_classes, dtype=jnp.float32)[source]

Creates a one-hot encoding of an array of labels.

Parameters:
  • labels – array of integers with shape (num_examples,).
  • num_classes – int, Total number of classes.
  • dtype – optional, jax datatype for the return array (Default: float32).
Returns:

array with shape (num_examples, num_classes).

Return type:

one_hot_labels
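
For example (shapes follow the description above):

import jax.numpy as jnp
from renn.utils import one_hot

labels = jnp.array([0, 2, 1])
encoded = one_hot(labels, num_classes=3)
# [[1., 0., 0.],
#  [0., 0., 1.],
#  [0., 1., 0.]]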

renn.utils.compose(*funcs)[source]

Returns a function that is the composition of multiple functions.
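
A minimal sketch of one common convention (right-to-left application, as in mathematical composition); the ordering used by this function is not stated here, so treat the sketch as an assumption:

import functools

def compose_sketch(*funcs):
    # Right-to-left composition: compose_sketch(f, g)(x) == f(g(x)).
    def composed(x):
        return functools.reduce(lambda acc, f: f(acc), reversed(funcs), x)
    return composed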

renn.version module

Module contents

RENN core.