# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import copy
from typing import Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
import whisper

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.specaug.specaug import SpecAug
from funasr.register import tables


@tables.register("encoder_classes", "OpenAIWhisperEncoderWarp")
class OpenAIWhisperEncoderWarp(nn.Module):
    """Transformer-based speech encoder from OpenAI's Whisper model.

    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: str = None,
        use_specaug: bool = False,
        use_padmask: bool = False,
        specaug_conf: Union[dict, None] = None,
    ):
        super().__init__()

        # note that the original Whisper model does not use dropout
        self.dropout = torch.nn.Dropout(dropout_rate)

        # load the pretrained Whisper checkpoint on CPU and keep only its audio encoder
        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
        self.encoders = copy.deepcopy(_model.encoder)
        self.encoders.train()

        del _model

        if use_specaug:
            self.specaug = SpecAug(**specaug_conf)
        else:
            self.specaug = None
        self.use_padmask = use_padmask

    def whisper_encode(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # two 1-D convolutions (conv2 with stride 2) downsample the log-mel features
        x = F.gelu(self.encoders.conv1(input))
        x = F.gelu(self.encoders.conv2(x))
        x = x.permute(0, 2, 1)

        n_frames = x.size(1)
        max_pos = self.encoders.positional_embedding.size(0)
        if n_frames <= max_pos:
            x = (x + self.encoders.positional_embedding[: x.size(1), :]).to(x.dtype)
        else:
            # due to the fixed-size positional embedding, audio longer than 30 sec
            # won't be accepted; the extra frames are truncated
            x = x[:, :max_pos, :] + self.encoders.positional_embedding

        if ilens is not None:
            # standard conv output-length formula applied to conv2
            # (conv1 keeps the input length, so only conv2 changes it)
            olens = (
                1
                + (ilens - self.encoders.conv2.kernel_size[0] + 2 * self.encoders.conv2.padding[0])
                // self.encoders.conv2.stride[0]
            )
            olens = torch.clamp(olens, max=max_pos)
        else:
            olens = None

        if self.use_padmask:
            # built from olens; note it is not passed to the attention blocks below
            padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        else:
            padding_mask = None

        x = self.dropout(x)

        for layer, block in enumerate(self.encoders.blocks):
            x = block(x)
            if layer < len(self.encoders.blocks) - 1:
                x = self.dropout(x)

        x = self.encoders.ln_post(x)

        return x, olens

    def output_size(self) -> int:
        # output size inferred from conv2: its number of output channels (Whisper's model dimension)
        return self.encoders.conv2.weight.shape[0]

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        feats, feats_lens = xs_pad, ilens

        if self.specaug is not None and self.encoders.training:
            # SpecAug expects (batch, time, feat); Whisper features come in as (batch, feat, time)
            feats = torch.transpose(feats, 1, 2)
            feats, feats_lens = self.specaug(feats, feats_lens)
            feats = torch.transpose(feats, 1, 2)

        xs_pad, olens = self.whisper_encode(feats, feats_lens)

        return xs_pad, olens, None
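

# Minimal usage sketch: assumes the `whisper` package can fetch the "base"
# checkpoint locally, and simply runs random log-mel features through the
# encoder to show the expected shapes; the (batch, n_mels=80, frames<=3000)
# layout and the 30-second limit follow Whisper's own front end.
if __name__ == "__main__":
    encoder = OpenAIWhisperEncoderWarp(whisper_model="base")
    feats = torch.randn(2, 80, 3000)    # (batch, n_mels, frames), 3000 frames ~ 30 s
    ilens = torch.tensor([3000, 2400])  # valid (unpadded) frames per utterance
    with torch.no_grad():
        out, olens, _ = encoder(feats, ilens)
    print(out.shape)  # (2, 1500, encoder.output_size()): conv2 halves the frame count
    print(olens)      # tensor([1500, 1200]), clamped to the positional-embedding size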