# FTP_update.py
import torch
from torch.optim.optimizer import Optimizer, required
import copy
import math
class FTP(object):
    def __init__(self, k=1.0, exclude_set=None):
        # Parameters named in exclude_set bypass the FTP projection entirely
        self.exclude_set = exclude_set if exclude_set is not None else set()
        self.threshold = torch.nn.Hardtanh(0, 1)  # Clamps the projection ratio to [0, 1]
        self.k = k  # Gradient annealing factor
        self.j = 0  # Per-parameter buffer counter, reset every iteration
        # Adam hyperparameters for the gamma update (see _adam_util)
self.mu = 1e-2
self.beta1 = 0.9
self.beta2 = 0.999
self.t = 1
        # Per-parameter buffers, indexed by self.j in registration order
        self.gamma_buff = []      # Projection radius gamma for each parameter
        self.first_m_gamma = []   # Adam first moment of gamma's gradient
        self.second_m_gamma = []  # Adam second moment of gamma's gradient
        self.prev_c = []          # Previous displacement c_t for each parameter
        self.prev_scale = []      # Previous 1/norm scale for each parameter
@torch.no_grad()
def step(self, name, curr, pre, d_p):
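        """Project one parameter's update onto the FTP constraint.

        Args:
            name: parameter name, checked against ``exclude_set``.
            curr: current parameter tensor (must require grad).
            pre: frozen pretrained weights used as the projection anchor.
            d_p: the base optimizer's proposed update, already scaled by the learning rate.

        Returns the projected parameter, or None when the parameter is skipped
        (no grad or listed in ``exclude_set``).
        """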
        if curr.requires_grad and name not in self.exclude_set:  # Parameters in exclude_set are not projected
            c_t = (curr - d_p) - pre  # Displacement of the tentative update from the pretrained weights
norms = self._mars_norm(c_t)
# New: Apply spectral normalization to c_t
c_t = self._apply_spectral_norm(c_t)
            if self.t == 1:
                # First iteration: register buffers for this parameter and start from a near-zero radius
                gamma = torch.tensor(1e-8, device=norms.device)
                self._update_buffers(gamma)
else:
# Get previous values
gamma_prev = self.gamma_buff[self.j]
c_prev = self.prev_c[self.j]
scale_prev = self.prev_scale[self.j]
# Calculate gradient for gamma
gamma_grad = torch.sum(self._dot(curr.grad, c_prev, scale=scale_prev))
# Anneal positive gradient
if gamma_grad > 0:
gamma_grad = gamma_grad * self.k
gamma = self._adam_util(gamma_prev, gamma_grad)
# Clip gamma
gamma = self._clip(gamma, norms)
# Update
denom = 1/norms
ratio = gamma * denom
new_p = pre + self.threshold(ratio) * c_t
# Save updated values
self._update_buffers(gamma, c_t, denom)
self.j += 1
return new_p
else:
return None
    def _apply_spectral_norm(self, c_t):
        # torch.linalg.svd requires at least a 2-D tensor; leave vectors (e.g. biases) unchanged
        if c_t.dim() < 2:
            return c_t
        u, s, vh = torch.linalg.svd(c_t, full_matrices=False)
        spectral_norm = s.max()
        # Orthogonalize c_t (singular values -> 1) and scale by the inverse spectral norm
        return torch.matmul(u, vh) / spectral_norm
    def incre_counters(self):
        # Called once per optimizer step: advance the iteration counter and reset the parameter index
        self.t += 1
        self.j = 0
@torch.no_grad()
    def _mars_norm(self, tensor):
        # L1 norm over all but the first dimension, with an epsilon for numerical stability
        return torch.sum(torch.abs(tensor), dim=tuple(range(1, tensor.dim())), keepdim=True) + 1e-8
@torch.no_grad()
    def _clip(self, constraint, norms):
        # Keep gamma within (0, max norm]; .item() gives hardtanh a plain Python scalar bound
        return torch.nn.functional.hardtanh(constraint, 1e-8, norms.max().item())
@torch.no_grad()
    def _dot(self, tensor1, tensor2, scale=1):
        # Row-wise dot product (summed over all but the first dimension), optionally rescaled
        return torch.sum(torch.mul(tensor1, tensor2), dim=tuple(range(1, tensor1.dim())), keepdim=True) * scale
@torch.no_grad()
def _adam_util(self, prev, grad):
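        """One Adam step on the scalar projection radius gamma; moment buffers are kept per parameter at index self.j."""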
first_moment = self.beta1 * self.first_m_gamma[self.j] + (1 - self.beta1) * grad
second_moment = self.beta2 * self.second_m_gamma[self.j] + (1 - self.beta2) * grad**2
self.first_m_gamma[self.j] = first_moment
self.second_m_gamma[self.j] = second_moment
        # Bias-corrected estimates (the stored moments above remain uncorrected)
        first_moment = first_moment / (1 - self.beta1**self.t)
        second_moment = second_moment / (1 - self.beta2**self.t)
return prev - self.mu * first_moment / (torch.sqrt(second_moment) + 1e-8)
    def _update_buffers(self, gamma, c_t=None, denom=None):
        if c_t is None:
            # First sighting of a parameter: append fresh state
            self.first_m_gamma.append(0.0)
            self.second_m_gamma.append(0.0)
            self.gamma_buff.append(gamma)
            self.prev_c.append(0.0)
            self.prev_scale.append(0.0)
        else:
            # Later iterations: overwrite the state at the current parameter index
            self.gamma_buff[self.j] = gamma
            self.prev_c[self.j] = c_t
            self.prev_scale[self.j] = denom
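

# The wrapper optimizers below drive FTP once per training iteration:
# ftp.step(name, p, pre, d_p) is called for every named parameter in a fixed
# order, followed by a single ftp.incre_counters() call. Their param groups
# must therefore carry parallel 'params', 'name', and 'pre' lists, where 'pre'
# holds frozen copies of the pretrained weights.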
class SGDP(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, k=1.0, exclude_set=None):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGDP, self).__init__(params, defaults)
self.first_iter_flag = False
# initialize FTP
self.ftp = FTP(k, exclude_set=exclude_set)
def __setstate__(self, state):
super(SGDP, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
            for p, name, pre in zip(group['params'], group['name'], group['pre']):
if p.grad is None:
continue
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(p, alpha=weight_decay)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov:
d_p = d_p.add(buf, alpha=momentum)
else:
d_p = buf
                # FTP step: pass the lr-scaled update; fall back to a plain SGD update
                # if FTP returns None (e.g. for excluded parameters)
                d_p = group['lr'] * d_p
                new_p = self.ftp.step(name, p, pre, d_p)
                if new_p is not None:
                    p.copy_(new_p)
                else:
                    p.add_(d_p, alpha=-1)
# FTP increment internal counters
self.ftp.incre_counters()
return loss
class AdamP(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, k=1.0, exclude_set=None):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(AdamP, self).__init__(params, defaults)
# initialize FTP
self.ftp = FTP(k, exclude_set=exclude_set)
def __setstate__(self, state):
super(AdamP, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
max_exp_avg_sqs = []
state_steps = []
for p in group['params']:
if p.grad is not None:
params_with_grad.append(p)
if p.grad.is_sparse:
                        raise RuntimeError('AdamP does not support sparse gradients, please consider SparseAdam instead')
grads.append(p.grad)
state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if group['amsgrad']:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
if group['amsgrad']:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
beta1, beta2 = group['betas']
self.adam(group,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
group['amsgrad'],
beta1,
beta2,
group['lr'],
group['weight_decay'],
group['eps']
)
# FTP increment internal counters
self.ftp.incre_counters()
return loss
def adam(self, group,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad: bool,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
eps: float):
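        """Apply the Adam update to each parameter in the group, then project it with FTP.

        Expects the group to carry parallel 'params', 'name', and 'pre' lists (see AdamP.step).
        """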
i = 0
        for param, name, pre in zip(group['params'], group['name'], group['pre']):
if param.grad is None:
continue
grad = param.grad
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
if amsgrad:
max_exp_avg_sq = max_exp_avg_sqs[i]
bias_correction1 = 1 - beta1 ** step
bias_correction2 = 1 - beta2 ** step
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
else:
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1
            i += 1
            # FTP step: combine the Adam update with weight decay, then project; fall back
            # to the plain update if FTP returns None (e.g. for excluded parameters)
            d_p = step_size * exp_avg / denom + lr * weight_decay * param
            new_p = self.ftp.step(name, param, pre, d_p)
            if new_p is None:
                new_p = param - d_p
            param.copy_(new_p)
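

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file; the toy model and the way
    # the 'name'/'pre' lists are built here are illustrative assumptions). It shows the
    # expected param-group layout: parallel 'params', 'name', and 'pre' lists, with
    # 'pre' holding frozen copies of the pretrained weights. The bias is excluded
    # from projection so only 2-D weights reach the SVD-based spectral normalization.
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 2)
    names = [n for n, _ in model.named_parameters()]
    pre = [p.detach().clone() for p in model.parameters()]
    param_groups = [{'params': list(model.parameters()), 'name': names, 'pre': pre}]

    optimizer = AdamP(param_groups, lr=1e-3, k=1.0, exclude_set={'bias'})
    # SGDP is driven the same way:
    # optimizer = SGDP(param_groups, lr=1e-2, momentum=0.9, k=1.0, exclude_set={'bias'})

    for _ in range(3):
        x, target = torch.randn(8, 4), torch.randn(8, 2)
        loss = torch.nn.functional.mse_loss(model(x), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item():.4f}")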