#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
import os
import sys
import torch
import torch.nn as nn
import hydra
import logging
import time
import argparse
import functools
from io import BytesIO
from contextlib import nullcontext
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

from funasr.train_utils.average_nbest_models import average_checkpoints
from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
from funasr.download.download_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.misc import prepare_model_dir
from funasr.train_utils.model_summary import model_summary
from funasr import AutoModel


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
    # enable tf32
    torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()

    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", False)
    # use_ddp = False if use_fsdp else use_fsdp
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)

    logging.info("Build model, frontend, tokenizer")
    device = kwargs.get("device", "cuda")
    kwargs["device"] = "cpu"  # build the model on CPU first, then move it to the target device
    model = AutoModel(**kwargs)

    # save config.yaml (only on rank 0)
    if (
        (use_ddp or use_fsdp)
        and dist.get_rank() == 0
        or not (use_ddp or use_fsdp)
        and local_rank == 0
    ):
        prepare_model_dir(**kwargs)

    # parse kwargs
    kwargs = model.kwargs
    kwargs["device"] = device
    tokenizer = kwargs["tokenizer"]
    frontend = kwargs["frontend"]
    model = model.model
    del kwargs["model"]

    # freeze parameters listed in freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if "," in freeze_param:
            freeze_param = eval(freeze_param)
        if not isinstance(freeze_param, (list, tuple)):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

    if local_rank == 0:
        logging.info(f"{model_summary(model)}")

    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get(
                "find_unused_parameters", False
            ),
        )
    elif use_fsdp:
        # model = FSDP(model).cuda(local_rank)
        def custom_auto_wrap_policy(
            module: nn.Module,
            recurse: bool,
            nonwrapped_numel: int,
            # Additional custom arguments
            min_num_params: int = int(1e8),
        ) -> bool:
            # Wrap a module only if it is large enough and all of its parameters
            # share the same requires_grad setting.
            is_large = nonwrapped_numel >= min_num_params
            requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
            return is_large and requires_grad_uniform

        # Configure a custom `min_num_params`
        my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        torch.cuda.set_device(local_rank)
        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrap_policy,
            mixed_precision=None,
            device_id=torch.cuda.current_device(),
        )
    else:
        model = model.to(device=kwargs.get("device", "cuda"))
    kwargs["device"] = next(model.parameters()).device

    # optim
    logging.info("Build optim")
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

    # scheduler
    logging.info("Build scheduler")
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

    # dataloader
    logging.info("Build dataloader")
    dataloader_class = tables.dataloader_classes.get(
        kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")
    )
    dataloader = dataloader_class(**kwargs)
    # dataloader_tr, dataloader_val = dataloader_class(**kwargs)

    trainer = Trainer(
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        device=kwargs["device"],
        output_dir=kwargs.get("output_dir", "./exp"),
        **kwargs.get("train_conf"),
    )

    scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler

    trainer.resume_checkpoint(
        model=model,
        optim=optim,
        scheduler=scheduler,
        scaler=scaler,
    )

    tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
    os.makedirs(tensorboard_dir, exist_ok=True)
    try:
        from tensorboardX import SummaryWriter

        writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
    except Exception:
        # tensorboardX is optional; continue without tensorboard logging
        writer = None

    dataloader_tr, dataloader_val = None, None
    for epoch in range(trainer.start_epoch, trainer.max_epoch):
        time1 = time.perf_counter()
        for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
            dataloader_tr, dataloader_val = dataloader.build_iter(
                epoch, data_split_i=data_split_i, start_step=trainer.start_step
            )

            trainer.train_epoch(
                model=model,
                optim=optim,
                scheduler=scheduler,
                scaler=scaler,
                dataloader_train=dataloader_tr,
                dataloader_val=dataloader_val,
                epoch=epoch,
                writer=writer,
                data_split_i=data_split_i,
                data_split_num=dataloader.data_split_num,
                start_step=trainer.start_step,
            )
            trainer.start_step = 0

            torch.cuda.empty_cache()

        trainer.start_data_split_i = 0

        trainer.validate_epoch(
            model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
        )
        scheduler.step()
        trainer.step_in_epoch = 0
        trainer.save_checkpoint(
            epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
        )

        time2 = time.perf_counter()
        time_elapsed = (time2 - time1) / 3600.0
        logging.info(
            f"rank: {local_rank}, "
            f"time_elapsed_epoch: {time_elapsed:.3f} hours, "
            f"estimated time to finish {trainer.max_epoch} "
            f"epochs: {(trainer.max_epoch - epoch) * time_elapsed:.3f} hours\n"
        )
        trainer.train_acc_avg = 0.0
        trainer.train_loss_avg = 0.0

    if trainer.rank == 0:
        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)

    trainer.close()


if __name__ == "__main__":
    main_hydra()
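
# Example launch (a sketch, not part of the script): this entry point is driven by
# hydra-style `key=value` overrides, and multi-GPU training is enabled by launching it
# with torchrun, which sets WORLD_SIZE/LOCAL_RANK so the DDP branch above is taken
# (add ++use_fsdp=true to use the FSDP branch instead). The script filename, model
# identifier, and output path below are placeholders, not values defined in this file:
#
#   torchrun --nproc_per_node=2 train.py \
#       ++model=<model_name_or_local_dir> \
#       ++output_dir=./exp/my_run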