"""Update functions for common optimizers."""
import jax.numpy as jnp
def momentum(alpha, beta):
    """Build a classical momentum update rule.

    Args:
        alpha: learning rate (step size).
        beta: velocity decay factor.

    Returns:
        ``update(g, v)`` mapping a gradient and the current velocity to a
        parameter delta.
    """
    def update(g, v):
        """Momentum update.

        Args:
            g: gradient
            v: velocity

        NOTE(review): the refreshed velocity ``beta * v + g`` is folded into
        the returned delta but not returned itself — presumably the caller
        tracks velocity separately; confirm against call sites.
        """
        refreshed = beta * v + g
        return -alpha * refreshed
    return update
def nesterov(alpha, beta):
    """Build a Nesterov (look-ahead) momentum update rule.

    Args:
        alpha: learning rate (step size).
        beta: velocity decay factor.

    Returns:
        ``update(g, v)`` mapping a gradient and the current velocity to a
        parameter delta.
    """
    def update(g, v):
        """Nesterov momentum update.

        Args:
            g: gradient
            v: velocity
        """
        # Refresh the velocity, then step using the looked-ahead velocity
        # rather than the velocity itself.
        v_next = beta * v + g
        lookahead = beta * v_next + g
        return -alpha * lookahead
    return update
def adagrad(alpha, beta):
    """Build an Adagrad-style normalized-gradient update rule.

    Args:
        alpha: learning rate (step size).
        beta: velocity decay factor for the normalized gradient.

    Returns:
        ``update(g, g_sq, v)`` mapping a gradient, the cumulative squared
        gradient, and the current velocity to a parameter delta.
    """
    def update(g, g_sq, v):
        """Adagrad update.

        Args:
            g: gradient
            g_sq: cumulative squared gradient
            v: velocity
        """
        total_sq = g_sq + jnp.square(g)
        # Guard against division by zero when no gradient signal has
        # accumulated yet.
        # NOTE(review): with jnp.where both branches are evaluated, so the
        # untaken g / sqrt(0) branch can poison gradients through this op
        # with NaNs — confirm whether this path is ever differentiated.
        safe_norm = jnp.where(total_sq > 0, g / jnp.sqrt(total_sq), 0.)
        blended = beta * v + (1. - beta) * safe_norm
        return -alpha * blended
    return update
def rmsprop(alpha, beta=0.9, eps=1e-5):
    """Build an RMSProp update rule.

    Args:
        alpha: learning rate (step size).
        beta: decay factor for the running second-moment average.
        eps: small constant added before the square root for numerical
            stability.

    Returns:
        ``update(g, m)`` mapping a gradient and the running second-moment
        average to a parameter delta.
    """
    def update(g, m):
        """RMSProp update.

        Args:
            g: gradient
            m: running average of the second moment
        """
        second_moment = (1. - beta) * jnp.square(g) + beta * m
        normalized = g / jnp.sqrt(second_moment + eps)
        return -alpha * normalized
    return update
def adam(alpha, beta1=0.9, beta2=0.999, eps=1e-5):
    """Build an Adam update rule (without bias correction).

    Args:
        alpha: learning rate (step size).
        beta1: decay factor for the first-moment (momentum) average.
        beta2: decay factor for the second-moment (normalization) average.
        eps: small constant added to the denominator for numerical stability.

    Returns:
        ``update(g, m, v)`` mapping a gradient and the two moment averages
        to a parameter delta.
    """
    def update(g, m, v):
        """Adam update.

        Note: this is uncorrected — the usual ``1 / (1 - beta**t)``
        bias-correction terms for the moment estimates are omitted.

        Args:
            g: gradient
            m: running average of the second moment (normalization)
            v: running average of the first moment (momentum)
        """
        v = (1 - beta1) * g + beta1 * v  # First moment.
        m = (1 - beta2) * jnp.square(g) + beta2 * m  # Second moment.
        return -alpha * v / (jnp.sqrt(m) + eps)
    return update
def cwrnn(cell_apply, readout_apply):
    """Build a component-wise RNN optimizer update rule.

    Args:
        cell_apply: callable ``(g, h) -> h_next`` advancing the RNN state
            from a gradient and the current state.
        readout_apply: callable ``h_next -> delta`` reading a parameter
            delta out of the advanced state.

    Returns:
        ``update(g, h)`` composing the cell step with the readout.
    """
    def update(g, h):
        """Component-wise RNN Optimizer update.

        Args:
            g: gradient
            h: RNN state
        """
        return readout_apply(cell_apply(g, h))
    return update