# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
|
|
Used for EMA tracking a given pytorch module. The user is responsible for calling step()
|
|
and setting the appropriate decay
|
|
"""
|
|
|
|
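# A minimal usage sketch (illustrative only; `model`, `optimizer`, and
# `dataloader` below are hypothetical stand-ins, not part of this module):
#
#   ema = EMAModule(model, ema_decay=0.999, ema_fp32=True)
#   for batch in dataloader:
#       loss = model(batch)        # hypothetical forward pass returning a loss
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()
#       ema.step(model)            # EMA update after the weights change
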
import copy
import logging

import torch


class EMAModule:
    """Exponential Moving Average of Fairseq Models"""

    def __init__(self, model, ema_decay=0.9999, ema_fp32=False, device=None, skip_keys=None):
        """
        @param model model to initialize the EMA with
        @param ema_decay decay rate used for the exponential moving average
        @param ema_fp32 if True, store a fp32 copy of the EMA params
        @param device If provided, copy EMA to this device (e.g. gpu).
            Otherwise EMA is in the same device as the model.
        @param skip_keys state dict keys that are copied verbatim rather than averaged
        """
        self.decay = ema_decay
        self.ema_fp32 = ema_fp32
        self.model = copy.deepcopy(model)
        self.model.requires_grad_(False)
        self.skip_keys = skip_keys or set()
        self.fp32_params = {}

        if device is not None:
            logging.info(f"Copying EMA model to device {device}")
            self.model = self.model.to(device=device)

        if self.ema_fp32:
            self.build_fp32_params()

        self.update_freq_counter = 0

    def build_fp32_params(self, state_dict=None):
        """
        Store a copy of the EMA params in fp32.
        If a state dict is passed, the EMA params are copied from
        the provided state dict. Otherwise, they are copied from the
        current EMA model parameters.
        """
        if not self.ema_fp32:
            raise RuntimeError(
                "build_fp32_params should not be called if ema_fp32=False. "
                "Use ema_fp32=True if this is really intended."
            )

        if state_dict is None:
            state_dict = self.model.state_dict()

        def _to_float(t):
            return t.float() if torch.is_floating_point(t) else t

        for param_key in state_dict:
            if param_key in self.fp32_params:
                self.fp32_params[param_key].copy_(state_dict[param_key])
            else:
                self.fp32_params[param_key] = _to_float(state_dict[param_key])

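    # Note: keeping the EMA accumulator in fp32 matters when decay is close to
    # 1, since the per-step update (1 - decay) * param can underflow if it is
    # accumulated directly in half precision.
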
    def restore(self, state_dict, build_fp32_params=False):
        """Load data from a model spec into EMA model"""
        self.model.load_state_dict(state_dict, strict=False)
        if build_fp32_params:
            self.build_fp32_params(state_dict)

    def set_decay(self, decay):
        self.decay = decay

    def get_decay(self):
        return self.decay

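    # A hedged sketch of decay warmup via set_decay() (the schedule below,
    # decay = min(target, (1 + step) / (10 + step)), is one common choice and
    # is an assumption here, not something this module prescribes):
    #
    #   ema.set_decay(min(0.9999, (1 + step) / (10 + step)))
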
    def _step_internal(self, new_model):
        """One update of the EMA model based on new model weights"""
        decay = self.decay

        ema_state_dict = {}
        ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
        for key, param in new_model.state_dict().items():
            if isinstance(param, dict):
                continue
            try:
                ema_param = ema_params[key]
            except KeyError:
                ema_param = (
                    param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
                )

            if param.shape != ema_param.shape:
                raise ValueError(
                    "incompatible tensor shapes between model param and ema param "
                    + "{} vs. {}".format(param.shape, ema_param.shape)
                )

            if "version" in key:
                # Do not decay a model.version pytorch param
                continue

            if key in self.skip_keys or (
                "num_batches_tracked" in key and ema_param.dtype == torch.int64
            ):
                # Skipped keys and integer buffers are copied, not averaged.
                ema_param = param.to(dtype=ema_param.dtype).clone()
                ema_params[key].copy_(ema_param)
            else:
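                # Standard EMA update, computed in the dtype of the stored
                # copy (fp32 when ema_fp32 is set):
                #   ema_param <- decay * ema_param + (1 - decay) * param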
                ema_param.mul_(decay)
                ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
            ema_state_dict[key] = ema_param
        self.restore(ema_state_dict, build_fp32_params=False)

    def step(self, new_model):
        self._step_internal(new_model)

    def reverse(self, model):
        """
        Load the model parameters from EMA model.
        Useful for inference or fine-tuning from the EMA model.
        """
        d = self.model.state_dict()
        if "_ema" in d:
            del d["_ema"]

        model.load_state_dict(d, strict=False)
        return model
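
# A hedged inference sketch (assumes `model` is the live training model, as in
# the usage sketch above; copying the EMA weights into a fresh module for
# evaluation is one common pattern, not a requirement of this class):
#
#   eval_model = ema.reverse(copy.deepcopy(model))
#   eval_model.eval()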