import logging
from pathlib import Path
from typing import Optional
from typing import Sequence
from typing import Union
import warnings
import os
from io import BytesIO
import torch
from typing import Collection
import os
import torch
import re
from collections import OrderedDict
from functools import cmp_to_key


def _get_checkpoint_paths(output_dir: str, last_n: int = 5):
    """Return the paths of up to ``last_n`` checkpoints to average.

    Two strategies are tried in order:

    1. Load ``<output_dir>/model.pt`` and read its
       ``"avg_keep_nbest_models_type"`` entry plus the matching
       ``"val_<type>_step_or_eoch"`` mapping (checkpoint file name -> score).
       Entries are sorted by score in descending order; the first ``last_n``
       are kept when the type is ``"acc"``, otherwise the last ``last_n``
       (i.e. the smallest scores) are kept.
    2. If step 1 fails for any reason (file missing, keys missing, ...),
       fall back to listing ``output_dir`` for files whose name starts with
       ``"model.pt.e"`` and take the ``last_n`` with the largest first number
       found in the file name.

    Args:
        output_dir: Directory that holds the checkpoints.
        last_n: Maximum number of checkpoint paths to return.

    Returns:
        A list of checkpoint paths (each joined onto ``output_dir``).
    """
    model_path = os.path.join(output_dir, "model.pt")
    try:
        checkpoint = torch.load(model_path, map_location="cpu")
        avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
        val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
        sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
        sorted_items = (
            sorted_items[:last_n]
            if avg_keep_nbest_models_type == "acc"
            else sorted_items[-last_n:]
        )
        # The second slice is kept on purpose: it makes ``last_n == 0`` yield
        # an empty list even on the ``[-0:]`` branch above.
        checkpoint_paths = [os.path.join(output_dir, key) for key, _ in sorted_items[:last_n]]
    except Exception:
        # Report the path rather than the ``checkpoint`` object: that name is
        # unbound when ``torch.load`` itself failed, which previously turned
        # this best-effort fallback into an UnboundLocalError.
        print(f"{model_path} does not exist, avg the lastet checkpoint.")

        # Map each candidate file to the first number in its name; files
        # without any digits cannot be ordered and are skipped.
        epoch_by_file = {}
        for file_name in os.listdir(output_dir):
            if not file_name.startswith("model.pt.e"):
                continue
            match = re.search(r"(\d+)", file_name)
            if match is None:
                continue
            epoch_by_file[file_name] = int(match.group())

        # Highest number first, then keep the first ``last_n`` entries.
        checkpoint_files = sorted(epoch_by_file, key=epoch_by_file.get, reverse=True)
        checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
    return checkpoint_paths


@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
    """Average the ``state_dict`` of the selected checkpoints and save it.

    The checkpoints are chosen by ``_get_checkpoint_paths``. For every key of
    the first loaded ``state_dict``, tensors whose dtype string starts with
    ``"torch.int"`` are summed across checkpoints; all other tensors are
    averaged element-wise.

    Args:
        output_dir: Directory that holds the checkpoints; the result is
            written here as ``model.pt.avg<last_n>``.
        last_n: Maximum number of checkpoints to combine.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        The path of the saved averaged checkpoint.

    Raises:
        RuntimeError: If none of the selected checkpoint files exists.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
    print(f"average_checkpoints: {checkpoint_paths}")

    # Load the state_dict of every checkpoint that is actually on disk.
    state_dicts = []
    for path in checkpoint_paths:
        if os.path.isfile(path):
            state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
        else:
            print(f"Checkpoint file {path} not found.")

    if not state_dicts:
        raise RuntimeError("No checkpoints found for averaging.")

    avg_state_dict = OrderedDict()
    for key in state_dicts[0]:
        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
        if str(tensors[0].dtype).startswith("torch.int"):
            # Integer tensors are summed rather than averaged.
            avg_state_dict[key] = sum(tensors)
        else:
            avg_state_dict[key] = torch.mean(torch.stack(tensors), dim=0)

    checkpoint_outpath = os.path.join(output_dir, f"model.pt.avg{last_n}")
    torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
    return checkpoint_outpath