renn package
Subpackages
Submodules
renn.analysis_utils module
Utilities for analysis.
renn.analysis_utils.pseudogrid(coordinates, dimension)
Constructs a pseudogrid (‘pseudo’ in that it is not necessarily evenly spaced) of points in ‘dimension’-dimensional space from the specified coordinates.
Parameters: - coordinates – a mapping from dimension indices to the coordinates along those dimensions.
- dimension – total number of dimensions.
For all dimensions that are not specified, the coordinate is taken to be 0.
Example
If coordinates = {0: [0, 1, 2], 2: [1]} and dimension = 4, the coordinates in dimensions 1 and 3 are taken as [0], yielding the effective coordinate dictionary
coordinates = {0: [0, 1, 2], 1: [0], 2: [1], 3: [0]}
The resulting pseudogrid is then constructed as:
[[0, 0, 1, 0], [1, 0, 1, 0], [2, 0, 1, 0]]
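A minimal pure-Python sketch of the construction described above, assuming the pseudogrid is the Cartesian product of the per-dimension coordinate lists, with unspecified dimensions filled by [0]. The name pseudogrid_sketch is hypothetical, and the point ordering (later dimensions varying fastest, as in itertools.product) is an assumption that may differ from renn's own implementation.

```python
import itertools
from typing import Dict, List, Sequence


def pseudogrid_sketch(coordinates: Dict[int, Sequence[float]],
                      dimension: int) -> List[List[float]]:
    """Hypothetical stand-in for pseudogrid, not renn's implementation."""
    # Dimensions missing from `coordinates` get the single coordinate 0.
    axes = [list(coordinates.get(d, [0])) for d in range(dimension)]
    # Cartesian product over all axes; later dimensions vary fastest.
    return [list(point) for point in itertools.product(*axes)]


grid = pseudogrid_sketch({0: [0, 1, 2], 2: [1]}, dimension=4)
print(grid)  # [[0, 0, 1, 0], [1, 0, 1, 0], [2, 0, 1, 0]]
```

Because the grid is a full product, its size is the product of the per-dimension list lengths, so specifying many coordinates in several dimensions grows the point count quickly.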
renn.losses module
Functions for computing loss.
renn.losses.binary_xent(logits, labels)
Cross-entropy loss for a two-class classification problem in which the model outputs a single logit.
Parameters: - logits – array of shape (batch_size, 1) or (batch_size,).
- labels – array of length batch_size whose elements are either 0 or 1.
Returns: loss – scalar cross-entropy loss.
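A minimal pure-Python sketch of binary cross-entropy computed directly from logits, assuming the scalar loss is the mean over the batch and that the logits have already been flattened to shape (batch_size,). It uses the numerically stable identity max(x, 0) - x*y + log(1 + exp(-|x|)), which equals -[y log σ(x) + (1 - y) log(1 - σ(x))] without ever exponentiating a large positive number. binary_xent_sketch is a hypothetical name, not renn's code.

```python
import math
from typing import Sequence


def binary_xent_sketch(logits: Sequence[float], labels: Sequence[int]) -> float:
    """Mean binary cross-entropy from logits (hypothetical sketch)."""
    total = 0.0
    for x, y in zip(logits, labels):
        # Stable form of -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))].
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    # Average over the batch (assumed reduction).
    return total / len(logits)


print(binary_xent_sketch([0.0, 2.0], [1, 0]))
```

A logit of 0 gives log 2 regardless of the label, and very large logits such as 1000.0 are handled without overflow.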
renn.serialize module
Serialization of pytrees.
renn.utils module
Utilities for optimization.
renn.utils.batch_mean(fun, in_axes)
Converts a function into a batched version that maps over multiple inputs.
This takes a function that returns a scalar (such as a loss function) and returns a new function that maps the function over multiple arguments (such as over multiple random seeds) and returns the average of the results.
It is useful for generating a batched version of a loss function, where the loss function has stochasticity that depends on a random seed argument.
Parameters: - fun – function, the function to batch.
- in_axes – tuple specifying which arguments of fun to batch over. For example, in_axes=(None, 0) batches over the second argument.
Returns: batch_fun – function that computes the average of fun over a batch.
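In JAX this pattern is commonly written as jnp.mean(jax.vmap(fun, in_axes=in_axes)(*args)). The pure-Python sketch below mimics that behavior without JAX, assuming every in_axes entry is either 0 (map over the leading axis) or None (pass the argument through unchanged). batch_mean_sketch and noisy_loss are hypothetical names used only for illustration.

```python
from typing import Callable, Optional, Sequence


def batch_mean_sketch(fun: Callable[..., float],
                      in_axes: Sequence[Optional[int]]) -> Callable[..., float]:
    """Hypothetical stand-in for batch_mean; in_axes entries must be 0 or None."""
    def batch_fun(*args):
        # The batch size comes from the first argument mapped over axis 0.
        batch_size = next(len(a) for a, ax in zip(args, in_axes) if ax == 0)
        results = []
        for i in range(batch_size):
            # Index batched arguments; pass unbatched (None) arguments as-is.
            call_args = [a[i] if ax == 0 else a for a, ax in zip(args, in_axes)]
            results.append(fun(*call_args))
        # Average the scalar outputs over the batch.
        return sum(results) / len(results)
    return batch_fun


def noisy_loss(params: float, seed: int) -> float:
    # Toy loss whose value depends on a "random seed" argument.
    return (params - seed) ** 2


batched_loss = batch_mean_sketch(noisy_loss, in_axes=(None, 0))
print(batched_loss(1.0, [0, 1, 2]))  # mean of 1.0, 0.0, 1.0
```

Here the parameters are shared across the batch (None) while the seeds are mapped over (0), matching the in_axes=(None, 0) example in the parameter description.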
renn.utils.optimize(loss_fun, x0, optimizer, steps, stop_tol=-inf)
Runs an optimizer on a given loss function.
Parameters: - loss_fun – scalar loss function to optimize.
- x0 – initial parameters.
- optimizer – a tuple of optimizer functions (init_opt, update_opt, get_params), such as those returned by jax.experimental.optimizers.
- steps – iterator over steps.
- stop_tol – stop if the loss falls below this value (Default: -np.inf).
Returns: loss_hist – array of losses during training. final_params – optimized parameters.
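A minimal sketch of the optimization loop described above, assuming optimize evaluates the loss and gradient at the current parameters, applies one update per element of steps, and checks stop_tol after each update. To stay dependency-free it works on a single float parameter, replaces jax.value_and_grad with a central finite difference, and uses a hypothetical plain-SGD triple, sgd_sketch, that follows the (init, update(step, grads, state), get_params) calling convention of jax.experimental.optimizers. None of these names are part of renn.

```python
from typing import Callable, Iterable, List, Tuple


def sgd_sketch(step_size: float):
    """Hypothetical (init, update, get_params) triple for plain SGD."""
    def init(params):
        return params

    def update(step, grad, state):
        # Ignores `step`; a schedule could use it to vary the step size.
        return state - step_size * grad

    def get_params(state):
        return state

    return init, update, get_params


def _value_and_grad(f: Callable[[float], float], x: float,
                    eps: float = 1e-6) -> Tuple[float, float]:
    # Central finite difference; stands in for jax.value_and_grad.
    return f(x), (f(x + eps) - f(x - eps)) / (2 * eps)


def optimize_sketch(loss_fun: Callable[[float], float], x0: float, optimizer,
                    steps: Iterable[int], stop_tol: float = float("-inf")):
    init_opt, update_opt, get_params = optimizer
    state = init_opt(x0)
    loss_hist: List[float] = []
    for k in steps:
        # Loss and gradient at the current parameters, then one update.
        loss, grad = _value_and_grad(loss_fun, get_params(state))
        state = update_opt(k, grad, state)
        loss_hist.append(loss)
        if loss < stop_tol:
            break
    return loss_hist, get_params(state)


hist, x_final = optimize_sketch(lambda x: (x - 3.0) ** 2, 0.0,
                                sgd_sketch(0.1), range(100))
print(hist[0], x_final)  # hist[0] is 9.0; x_final is close to 3.0
```

Passing a finite stop_tol, such as 1e-4 here, ends the loop as soon as a recorded loss drops below it, so loss_hist can be shorter than steps.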
renn.utils.one_hot(labels, num_classes, dtype=jax.numpy.float32)
Creates a one-hot encoding of an array of labels.
Parameters: - labels – array of integers with shape (num_examples,).
- num_classes – int, total number of classes.
- dtype – optional, JAX datatype of the returned array (Default: float32).
Returns: one_hot_labels – array with shape (num_examples, num_classes).
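In JAX the same encoding is commonly written as (labels[:, None] == jnp.arange(num_classes)).astype(dtype), or with jax.nn.one_hot. The pure-Python sketch below only illustrates the output layout; one_hot_sketch is a hypothetical name, and using Python float in place of float32 is a simplification.

```python
from typing import Callable, List, Sequence


def one_hot_sketch(labels: Sequence[int], num_classes: int,
                   dtype: Callable = float) -> List[List[float]]:
    # Row i holds 1 in column labels[i] and 0 in every other column.
    return [[dtype(1 if c == label else 0) for c in range(num_classes)]
            for label in labels]


print(one_hot_sketch([0, 2, 1], num_classes=3))
# [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
```

Each of the num_examples rows has length num_classes, giving the (num_examples, num_classes) shape listed under Returns.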
renn.version module
Module contents
RENN core.