19 lines
521 B
Python
19 lines
521 B
Python
|
import torch
|
||
|
|
||
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||
|
from funasr.modules.nets_utils import make_pad_mask
|
||
|
|
||
|
|
||
|
class SimpleAvg(AbsEncoder):
    """Parameter-free encoder that averages features over the unpadded frames.

    The layer holds no learnable weights; ``feat_dim`` is stored only so
    that ``output_size`` can report it.
    """

    def __init__(self, feat_dim):
        """Remember the feature dimension reported by ``output_size``.

        Args:
            feat_dim: Value returned unchanged by ``output_size``.
        """
        super().__init__()
        self.feat_dim = feat_dim

    def forward(self, x, ilens):
        """Average ``x`` along dim 1, counting only non-padded positions.

        Args:
            x: Input tensor; dim 1 is the axis that is masked and reduced.
                Assumes a (batch, time, feat) layout -- TODO confirm
                against callers.
            ilens: Per-sequence lengths handed to ``make_pad_mask``.

        Returns:
            The masked sum over dim 1 divided by the number of unmasked
            positions of each sequence.
        """
        # Presumably make_pad_mask is True at padded positions, so the
        # inverse selects the valid frames -- verify against its docs.
        pad_mask = make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device)
        valid = ~pad_mask

        # A trailing singleton axis lets the 2-D mask broadcast over the
        # feature axis of x.
        masked_sum = (x * valid.unsqueeze(-1)).sum(dim=1)
        # NOTE(review): a sequence with no valid frames gives a zero
        # divisor here; behaviour intentionally left unchanged.
        num_valid = valid.sum(dim=-1).unsqueeze(-1)
        return masked_sum / num_valid

    def output_size(self) -> int:
        """Return the feature dimension given at construction time."""
        return self.feat_dim
|