This implements the fixed-beta version, which performs similarly to the hand-tuned-schedule variant.
The full implementation is:
# Schedule-free update callable found two `.fn` attribute levels inside
# C.update_by_schedule_free; amuse_sf calls it directly with
# (group, update, grad, param, z).
# NOTE(review): presumably `.fn.fn` reaches through the decorator wrappers to
# the undecorated function — confirm against C's guard implementations.
_raw_sf = C.update_by_schedule_free.fn.fn
def _amuse_beta(group):
step = group.get('_group_step')
if step is None: return group['amuse_beta1']
step, w = max(int(step), 1), group['warmup_steps']
if step <= w or w <= 1: return group['amuse_beta1']
return 1 - ((w - 1) / (step - 1)) ** group['amuse_rho'] * (1 - group['amuse_beta1'])
@C.zero_guard('momentum')
@C.no_state
def muon_ema(group, update, grad, param, momentum):
    """Run `utils.nesterov_ema` on the update with the group's 'muon_mu'.

    `grad` and `param` are accepted as part of the chain signature but are
    not used here. Returns whatever `utils.nesterov_ema` returns.

    NOTE(review): `momentum` is presumably a zero-initialised state buffer
    supplied by `C.zero_guard('momentum')` — confirm against C.
    """
    mu = group['muon_mu']
    return utils.nesterov_ema(momentum, update, mu)
@C.copy_guard(2, 'z')
@C.no_state
def amuse_sf(group, update, grad, param, z):
    """Schedule-free step that refreshes the group's beta first.

    Stores the value computed by `_amuse_beta(group)` under ``group['beta']``
    and then hands all arguments to `_raw_sf`, returning its result.

    NOTE(review): `_raw_sf` presumably reads ``group['beta']``, which is why
    it must be written before the call — confirm against C.
    """
    beta_now = _amuse_beta(group)
    group['beta'] = beta_now
    return _raw_sf(group, update, grad, param, z)
class AMUSE(C.ScheduleFree):
    """`C.ScheduleFree` optimizer running muon_ema, C.orthogonalize_update, amuse_sf.

    Keyword arguments override the built-in defaults. Keys not in the default
    set are passed through to the base class unchanged.
    """

    def __init__(self, params, **kw):
        """Merge `kw` over the defaults and initialise the base optimizer.

        Of the defaults, the code in this file reads 'muon_mu' (muon_ema) and
        'warmup_steps', 'amuse_beta1', 'amuse_rho' (_amuse_beta). 'beta' is
        overwritten by amuse_sf every time it runs. The remaining keys are
        consumed by code outside this file.
        """
        defaults = {
            'lr': .02,
            'beta': .6,
            'muon_mu': .95,
            'weight_decay': 0,
            'warmup_steps': 0,
            'weight_lr_power': 2.,
            'r': 0.,
            'amuse_beta1': .6,
            'amuse_rho': .8,
            **kw,  # caller-supplied values take precedence
        }
        chain = (muon_ema, C.orthogonalize_update, amuse_sf)
        super().__init__(params, defaults, fns=chain)