32 lines
876 B
Python
32 lines
876 B
Python
from funasr.register import tables
|
|
from funasr.models.whisper_lid.eres2net.ResNet import (
|
|
ERes2Net,
|
|
BasicBlockERes2Net,
|
|
BasicBlockERes2Net_diff_AFF,
|
|
)
|
|
|
|
|
|
@tables.register("lid_predictor_classes", "LidPredictor")
|
|
class LidPredictor(ERes2Net):
|
|
def __init__(
|
|
self,
|
|
block=BasicBlockERes2Net,
|
|
block_fuse=BasicBlockERes2Net_diff_AFF,
|
|
num_blocks=[3, 4, 6, 3],
|
|
m_channels=32,
|
|
feat_dim=80,
|
|
embedding_size=192,
|
|
pooling_func="TSTP",
|
|
two_emb_layer=False,
|
|
):
|
|
super(LidPredictor, self).__init__(
|
|
block=block,
|
|
block_fuse=block_fuse,
|
|
num_blocks=num_blocks,
|
|
m_channels=m_channels,
|
|
feat_dim=feat_dim,
|
|
embedding_size=embedding_size,
|
|
pooling_func=pooling_func,
|
|
two_emb_layer=two_emb_layer,
|
|
)
|