import copy
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.transformer.utils.nets_utils import make_pad_mask

def sense_voice_encode_forward(
    self,
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    **kwargs,
):
    use_padmask = self.use_padmask

    # Conv downsampling stem (Whisper-style): two Conv1d layers with GELU,
    # then (batch, channels, frames) -> (batch, frames, channels).
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)

    # Add positional embeddings, truncating the sequence to the size of the
    # positional-embedding table if the input is longer.
    n_frames = x.size(1)
    max_pos = self.positional_embedding.size(0)
    max_pos = n_frames if n_frames < max_pos else max_pos
    x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)

    if ilens is not None:
        # Propagate the valid (unpadded) lengths through the conv stem using
        # the Conv1d output-length formula: 1 + (L + 2*pad - kernel) // stride.
        # The conv1 term matters only when the stem downsamples by 4, i.e.
        # when conv1 itself strides; otherwise conv1 preserves the length.
        if self.downsample_rate == 4:
            olens = (
                1
                + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
                // self.conv1.stride[0]
            )
        else:
            olens = ilens
        olens = (
            1
            + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
            // self.conv2.stride[0]
        )
        olens = torch.clamp(olens, max=max_pos)
    else:
        olens = None

    # Boolean padding mask of shape (batch, 1, frames), True = valid frame,
    # so that attention inside the blocks ignores padded positions.
    if use_padmask and olens is not None:
        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
    else:
        padding_mask = None

    # Transformer encoder blocks.
    for layer, block in enumerate(self.blocks):
        x = block(x, mask=padding_mask, is_pad_mask=True)

    # Final layer norm; return lengths only when input lengths were given.
    x = self.ln_post(x)

    if ilens is None:
        return x
    else:
        return x, olens
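

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original file). Because the
# function above takes `self`, it is meant to be bound onto a Whisper-style
# encoder that exposes `conv1`, `conv2`, `positional_embedding`, `blocks`,
# `ln_post`, `use_padmask`, and `downsample_rate`. `_DummyBlock` and
# `_DummyEncoder` below are hypothetical stand-ins that satisfy this interface;
# the real encoder comes from the SenseVoice/Whisper model definition in FunASR.
if __name__ == "__main__":
    import types

    class _DummyBlock(nn.Module):
        # Stand-in transformer block accepting the keyword arguments that
        # sense_voice_encode_forward passes through.
        def forward(self, x, mask=None, is_pad_mask=True):
            return x

    class _DummyEncoder(nn.Module):
        def __init__(self, n_mels=80, d_model=16, n_ctx=1500):
            super().__init__()
            self.conv1 = nn.Conv1d(n_mels, d_model, kernel_size=3, padding=1)
            self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
            self.positional_embedding = torch.zeros(n_ctx, d_model)
            self.blocks = nn.ModuleList([_DummyBlock()])
            self.ln_post = nn.LayerNorm(d_model)
            self.use_padmask = True
            self.downsample_rate = 2  # conv2 (stride 2) halves the frame rate

    enc = _DummyEncoder()
    # Bind the free function as the encoder's forward method.
    enc.forward = types.MethodType(sense_voice_encode_forward, enc)

    feats = torch.randn(2, 80, 100)  # (batch, n_mels, frames)
    lens = torch.tensor([100, 60])   # valid frames per utterance
    hidden, out_lens = enc(feats, lens)
    print(hidden.shape, out_lens)    # torch.Size([2, 50, 16]) tensor([50, 30])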