38 lines
1.3 KiB
Python
38 lines
1.3 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# -*- encoding: utf-8 -*-
|
||
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||
|
|
||
|
import time
|
||
|
import torch
|
||
|
import logging
|
||
|
from contextlib import contextmanager
|
||
|
from typing import Dict, Optional, Tuple
|
||
|
from distutils.version import LooseVersion
|
||
|
|
||
|
from funasr.register import tables
|
||
|
from funasr.utils import postprocess_utils
|
||
|
from funasr.utils.datadir_writer import DatadirWriter
|
||
|
from funasr.models.transducer.model import Transducer
|
||
|
from funasr.train_utils.device_funcs import force_gatherable
|
||
|
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
|
||
|
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||
|
from funasr.models.transformer.scorers.length_bonus import LengthBonus
|
||
|
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
|
||
|
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||
|
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
|
||
|
|
||
|
|
||
|
# Resolve ``autocast`` once at import time so model code can always write
# ``with autocast(...):`` regardless of the installed torch version.
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
    # torch<1.6.0 has no ``torch.cuda.amp``; fall back to a context manager
    # that accepts the same ``enabled`` flag but does nothing.
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in for ``torch.cuda.amp.autocast`` on torch<1.6.0.

        Args:
            enabled: Accepted for signature compatibility and ignored.
        """
        yield
else:
    from torch.cuda.amp import autocast
|
||
|
|
||
|
|
||
|
@tables.register("model_classes", "BAT")  # TODO: BAT training
class BAT(Transducer):
    """Model class registered in ``tables`` under ``"model_classes"`` as ``"BAT"``.

    It adds nothing yet: construction, training and decoding are all inherited
    unchanged from ``Transducer``. The BAT-specific training logic is still
    outstanding (see the TODO on the registration line).
    """
|