# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import copy
from typing import Any, List, Tuple

import torch
from torch import nn
import whisper

from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables


@tables.register("decoder_classes", "OpenAIWhisperDecoderWarp")
class OpenAIWhisperDecoderWarp(nn.Module):
    """Transformer-based Speech-to-Text Decoder from OpenAI's Whisper model.

    URL: https://github.com/openai/whisper
    """

    def __init__(
        self,
        dropout_rate: float = 0.0,
        whisper_model: str = "small",
        download_dir: str = None,
        use_padmask: bool = False,
    ):
        super().__init__()

        assert whisper_model in whisper.available_models()
        _model = whisper.load_model(whisper_model, download_root=download_dir, device="cpu")
        # keep only the decoder of the pretrained Whisper model
        self.decoders = copy.deepcopy(_model.decoder)
        attention_dim = self.decoders.token_embedding.embedding_dim  # decoder (embedding) width

        # note that the original Whisper model does not use dropout
        self.dropout = torch.nn.Dropout(dropout_rate)

        self.decoders.train()
        del _model
        self.use_padmask = use_padmask

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: lengths of encoded memory (batch,)
            ys_in_pad: input token ids, int64 (batch, maxlen_out)
            ys_in_lens: lengths of input token sequences (batch,)
        Returns:
            (tuple): tuple containing:

            x: decoded token scores before softmax (batch, maxlen_out, vocab)
            ys_in_lens: (batch,)
        """
        tgt, memory = ys_in_pad, hs_pad
        # token embedding plus learned positional embedding, truncated to the target length
        tgt = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
        tgt = self.dropout(tgt)

        x = tgt.to(memory.dtype)

        if self.use_padmask:
            # padding mask over the encoder memory: True for valid frames, False for padding
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None

        for layer, block in enumerate(self.decoders.blocks):
            # NOTE: the extra mask keyword arguments require decoder blocks that accept them
            # (e.g., a patched whisper build); the stock openai-whisper ResidualAttentionBlock
            # only takes (x, xa, mask, kv_cache)
            x = block(
                x,
                memory,
                mask=self.decoders.mask,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
            )

            if layer < len(self.decoders.blocks) - 1:
                x = self.dropout(x)

        x = self.decoders.ln(x)
        # output projection is tied to the token embedding, as in the original Whisper
        x = (x @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return x, ys_in_lens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask, (batch, maxlen_out)
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)
            memory: encoded memory, float32 (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
                `y.shape` is (batch, n_vocab): next-token log-probabilities
        NOTE (Shih-Lun):
            the cache implementation is ignored for now
            for simplicity & correctness
        """
        x = self.decoders.token_embedding(tgt) + self.decoders.positional_embedding[: tgt.size(1)]
        x = self.dropout(x)
        x = x.to(memory.dtype)

        for layer, block in enumerate(self.decoders.blocks):
            x = block(x, memory, mask=self.decoders.mask)
            if layer < len(self.decoders.blocks) - 1:
                x = self.dropout(x)

        x = self.decoders.ln(x)
        # score only the last position and return next-token log-probabilities
        y = x[:, -1]
        y = (y @ torch.transpose(self.decoders.token_embedding.weight.to(x.dtype), 0, 1)).float()
        y = torch.log_softmax(y, dim=-1)

        return y, None

    def score(self, ys, state, x):
        """Score."""
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), torch.empty(0), x.unsqueeze(0), cache=state  # dummy mask
        )
        return logp.squeeze(0), state

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape `(n_batch, n_vocab)`
                and the next state list for ys (None here, since caching is not used).

        """
        # batch decoding; a dummy mask is passed because forward_one_step ignores it
        logp, states = self.forward_one_step(ys, torch.empty(0), xs, cache=None)

        return logp, None
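

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it shows how the
# wrapped decoder can be instantiated and queried for next-token scores via
# `batch_score`. The model name "base", the dummy tensor shapes, and the
# assumption that a Whisper checkpoint can be downloaded are illustrative
# only; `batch_score` goes through `forward_one_step`, which relies solely on
# the stock whisper decoder-block call signature.
if __name__ == "__main__":
    decoder = OpenAIWhisperDecoderWarp(whisper_model="base", use_padmask=False)

    n_state = decoder.decoders.token_embedding.embedding_dim
    n_vocab = decoder.decoders.token_embedding.num_embeddings
    batch, enc_len, prefix_len = 2, 100, 5

    memory = torch.randn(batch, enc_len, n_state)            # fake encoder output
    prefix = torch.randint(0, n_vocab, (batch, prefix_len))  # fake token prefixes

    logp, _ = decoder.batch_score(prefix, states=[None] * batch, xs=memory)
    print(logp.shape)  # (batch, n_vocab) log-probabilities for the next token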