FunASR/funasr/models/sense_voice/encoder.py

import copy
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.transformer.utils.nets_utils import make_pad_mask


def sense_voice_encode_forward(
    self,
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    **kwargs,
):
    """Encoder forward for SenseVoice: conv front-end, positional embedding,
    transformer blocks with an optional padding mask, and a final LayerNorm.

    Args:
        x: input features, shape (batch, input_dim, n_frames).
        ilens: valid frame count per utterance; when given, output lengths
            are tracked through the conv front-end and returned.

    Returns:
        Encoded features of shape (batch, T', d_model), plus the output
        lengths ``olens`` when ``ilens`` is provided.
    """
    use_padmask = self.use_padmask

    # Convolutional front-end: (B, C, T) -> (B, d_model, T'), then to (B, T', d_model).
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)

    # Add the learned positional embedding, truncated to whichever is shorter:
    # the number of frames or the size of the embedding table.
    n_frames = x.size(1)
    max_pos = self.positional_embedding.size(0)
    max_pos = n_frames if n_frames < max_pos else max_pos
    x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)

    if ilens is not None:
        # Track valid lengths through the conv front-end using the Conv1d
        # output-length formula: L_out = 1 + (L_in - kernel + 2 * padding) // stride.
        if self.downsample_rate == 4:
            olens = (
                1
                + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
                // self.conv1.stride[0]
            )
        else:
            olens = ilens
        olens = (
            1
            + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
            // self.conv2.stride[0]
        )
        olens = torch.clamp(olens, max=max_pos)
    else:
        olens = None

    # Key-padding mask for the attention blocks: True on valid frames, shape (B, 1, T').
    if use_padmask and olens is not None:
        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
    else:
        padding_mask = None

    for layer, block in enumerate(self.blocks):
        x = block(x, mask=padding_mask, is_pad_mask=True)

    x = self.ln_post(x)

    if ilens is None:
        return x
    else:
        return x, olens
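

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of FunASR): sense_voice_encode_forward
# is written to be bound onto a Whisper-style encoder instance that exposes
# conv1, conv2, positional_embedding, blocks, ln_post, use_padmask and
# downsample_rate. The _ToyEncoder/_ToyBlock classes and every hyperparameter
# below are hypothetical stand-ins used only to show the call pattern and the
# length bookkeeping; they do not reproduce the real SenseVoice encoder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    class _ToyBlock(nn.Module):
        """Stand-in attention block that accepts the mask/is_pad_mask kwargs."""

        def __init__(self, d_model: int):
            super().__init__()
            self.proj = nn.Linear(d_model, d_model)

        def forward(self, x, mask=None, is_pad_mask=False):
            return self.proj(x)

    class _ToyEncoder(nn.Module):
        """Stand-in encoder exposing the attributes the forward above expects."""

        def __init__(self, n_mels=80, d_model=256, n_ctx=1500, n_blocks=2):
            super().__init__()
            self.conv1 = nn.Conv1d(n_mels, d_model, kernel_size=3, padding=1)
            self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
            self.positional_embedding = nn.Parameter(torch.zeros(n_ctx, d_model))
            self.blocks = nn.ModuleList(_ToyBlock(d_model) for _ in range(n_blocks))
            self.ln_post = nn.LayerNorm(d_model)
            self.use_padmask = True
            self.downsample_rate = 2  # conv2 halves the frame rate

    encoder = _ToyEncoder()
    # Bind the module-level forward onto the instance (monkey-patch pattern).
    encoder.forward = types.MethodType(sense_voice_encode_forward, encoder)

    feats = torch.randn(2, 80, 600)   # (batch, n_mels, frames)
    ilens = torch.tensor([600, 480])  # valid frames per utterance
    out, olens = encoder(feats, ilens)
    print(out.shape, olens)  # torch.Size([2, 300, 256]) tensor([300, 240])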