#!/usr/bin/env python3 # -*- encoding: utf-8 -*- import os import sys import torch import torch.nn as nn import hydra import logging import time import argparse from io import BytesIO from contextlib import nullcontext import torch.distributed as dist from omegaconf import DictConfig, OmegaConf from torch.cuda.amp import autocast, GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.algorithms.join import Join from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler from funasr.train_utils.average_nbest_models import average_checkpoints from funasr.register import tables from funasr.optimizers import optim_classes from funasr.train_utils.trainer_ds import Trainer from funasr.schedulers import scheduler_classes from funasr.train_utils.initialize import initialize from funasr.download.download_from_hub import download_model from funasr.models.lora.utils import mark_only_lora_as_trainable from funasr.train_utils.set_all_random_seed import set_all_random_seed from funasr.train_utils.load_pretrained_model import load_pretrained_model from funasr.utils.misc import prepare_model_dir from funasr.train_utils.model_summary import model_summary from funasr import AutoModel try: import deepspeed except: deepspeed = None @hydra.main(config_name=None, version_base=None) def main_hydra(kwargs: DictConfig): if kwargs.get("debug", False): import pdb pdb.set_trace() assert "model" in kwargs if "model_conf" not in kwargs: logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms"))) kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs) main(**kwargs) def main(**kwargs): # set random seed set_all_random_seed(kwargs.get("seed", 0)) torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True) # open tf32 torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True) rank = int(os.environ.get("RANK", 0)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) world_size = int(os.environ.get("WORLD_SIZE", 1)) if local_rank == 0: tables.print() use_ddp = world_size > 1 use_fsdp = kwargs.get("use_fsdp", False) use_deepspeed = kwargs.get("use_deepspeed", False) if use_deepspeed: logging.info(f"use_deepspeed: {use_deepspeed}") deepspeed.init_distributed(dist_backend=kwargs.get("backend", "nccl")) elif use_ddp or use_fsdp: logging.info(f"use_ddp: {use_ddp}, use_fsdp: {use_fsdp}") dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://") torch.cuda.set_device(local_rank) logging.info("Build model, frontend, tokenizer") device = kwargs.get("device", "cuda") kwargs["device"] = "cpu" model = AutoModel(**kwargs) # save config.yaml if rank == 0: prepare_model_dir(**kwargs) # parse kwargs kwargs = model.kwargs kwargs["device"] = device tokenizer = kwargs["tokenizer"] frontend = kwargs["frontend"] model = model.model del kwargs["model"] # freeze_param freeze_param = kwargs.get("freeze_param", None) if freeze_param is not None: if "," in freeze_param: freeze_param = eval(freeze_param) if not isinstance(freeze_param, (list, tuple)): freeze_param = (freeze_param,) logging.info("freeze_param is not None: %s", freeze_param) for t in freeze_param: for k, p in model.named_parameters(): if k.startswith(t + ".") or k == t: logging.info(f"Setting {k}.requires_grad = False") p.requires_grad = False if local_rank == 0: logging.info(f"{model_summary(model)}") trainer = Trainer( rank=rank, local_rank=local_rank, world_size=world_size, use_ddp=use_ddp, use_fsdp=use_fsdp, device=kwargs["device"], output_dir=kwargs.get("output_dir", "./exp"), **kwargs.get("train_conf"), ) model = trainer.warp_model(model) kwargs["device"] = next(model.parameters()).device trainer.device = kwargs["device"] # optim logging.info("Build optim") optim = kwargs.get("optim", "adam") assert optim in optim_classes optim_class = optim_classes.get(optim) optim = optim_class(model.parameters(), **kwargs.get("optim_conf")) # scheduler logging.info("Build scheduler") scheduler = kwargs.get("scheduler", "warmuplr") assert scheduler in scheduler_classes scheduler_class = scheduler_classes.get(scheduler) scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf")) if use_deepspeed: args = OmegaConf.create({"deepspeed_config": kwargs.get("deepspeed_config", "")}) model, optimizer, _, scheduler = deepspeed.initialize( args=args, model=model, optimizer=optim, lr_scheduler=scheduler, model_parameters=model.parameters(), ) # dataset logging.info("Build dataloader") dataloader_class = tables.dataloader_classes.get( kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle") ) dataloader = dataloader_class(**kwargs) # dataloader_tr, dataloader_val = dataloader_class(**kwargs) scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler trainer.resume_checkpoint( model=model, optim=optim, scheduler=scheduler, scaler=scaler, ) tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard") os.makedirs(tensorboard_dir, exist_ok=True) try: from tensorboardX import SummaryWriter writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None except: writer = None dataloader_tr, dataloader_val = None, None for epoch in range(trainer.start_epoch, trainer.max_epoch): time1 = time.perf_counter() for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num): dataloader_tr, dataloader_val = dataloader.build_iter( epoch, data_split_i=data_split_i, start_step=trainer.start_step ) trainer.train_epoch( model=model, optim=optim, scheduler=scheduler, scaler=scaler, dataloader_train=dataloader_tr, dataloader_val=dataloader_val, epoch=epoch, writer=writer, data_split_i=data_split_i, data_split_num=dataloader.data_split_num, start_step=trainer.start_step, ) trainer.start_step = 0 torch.cuda.empty_cache() trainer.start_data_split_i = 0 trainer.validate_epoch( model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer ) scheduler.step() trainer.step_in_epoch = 0 trainer.save_checkpoint( epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler ) time2 = time.perf_counter() time_escaped = (time2 - time1) / 3600.0 logging.info( f"rank: {local_rank}, " f"time_escaped_epoch: {time_escaped:.3f} hours, " f"estimated to finish {trainer.max_epoch} " f"epoch: {(trainer.max_epoch - epoch) * time_escaped:.3f} hours\n" ) trainer.train_acc_avg = 0.0 trainer.train_loss_avg = 0.0 if trainer.rank == 0: average_checkpoints(trainer.output_dir, trainer.avg_nbest_model) trainer.close() if __name__ == "__main__": main_hydra()