#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
from typing import Dict, List, Optional, Tuple

from funasr.register import tables
from funasr.models.rwkv_bat.rwkv import RWKV
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.nets_utils import make_source_mask
from funasr.models.rwkv_bat.rwkv_subsampling import RWKVConvInput


@tables.register("encoder_classes", "RWKVEncoder")
class RWKVEncoder(torch.nn.Module):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature size.
        output_size: Input/Output size of the RWKV blocks.
        context_size: Context size for WKV computation.
        linear_size: FeedForward hidden size.
        attention_size: SelfAttention hidden size.
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate passed to each RWKV block.
        subsampling_factor: Subsampling factor of the convolutional frontend.
        time_reduction_factor: Additional frame-rate reduction applied to the encoder output.
        kernel: Convolution kernel size of the subsampling frontend.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
        **kwargs,
    ) -> None:
        """Construct a RWKVEncoder object."""
        super().__init__()

        self.embed = RWKVConvInput(
            input_size,
            [output_size // 4, output_size // 2, output_size],
            subsampling_factor,
            conv_kernel_size=kernel,
            output_size=output_size,
        )

        self.subsampling_factor = subsampling_factor

        linear_size = output_size * 4 if linear_size is None else linear_size
        attention_size = output_size if attention_size is None else attention_size

        self.rwkv_blocks = torch.nn.ModuleList(
            [
                RWKV(
                    output_size,
                    linear_size,
                    attention_size,
                    context_size,
                    block_id,
                    num_blocks,
                    att_dropout_rate=att_dropout_rate,
                    ffn_dropout_rate=ffn_dropout_rate,
                    dropout_rate=dropout_rate,
                )
                for block_id in range(num_blocks)
            ]
        )

        self.embed_norm = LayerNorm(output_size)
        self.final_norm = LayerNorm(output_size)

        self._output_size = output_size
        self.context_size = context_size

        self.num_blocks = num_blocks
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Encode source feature sequences.

        Args:
            x: Encoder input sequences. (B, T, D_feat)
            x_len: Lengths of the input sequences. (B,)

        Returns:
            x: Encoder output sequences. (B, U, D_out)
            olens: Lengths of the output sequences. (B,)

        """
        _, length, _ = x.size()

        assert (
            length <= self.context_size * self.subsampling_factor
        ), "Context size is too short for current length: %d versus %d" % (
            length,
            self.context_size * self.subsampling_factor,
        )

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)

        if self.training:
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            x = self.rwkv_infer(x)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
        """Encode the subsampled sequences frame by frame with the recurrent RWKV formulation.

        Args:
            xs_pad: Subsampled and embedded input sequences. (B, U, D)

        Returns:
            xs_out: Encoder output sequences. (B, U, D)

        """
        batch_size = xs_pad.shape[0]

        hidden_sizes = [self._output_size for i in range(5)]

        # Five recurrent state tensors, with one slice per block along the last dimension.
        state = [
            torch.zeros(
                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                dtype=torch.float32,
                device=xs_pad.device,
            )
            for i in range(5)
        ]

        state[4] -= 1e-30

        xs_out = []
        for t in range(xs_pad.shape[1]):
            # Keep the time dimension (B, 1, D) so the outputs can be concatenated along dim=1.
            x_t = xs_pad[:, t : t + 1, :]
            for idx, block in enumerate(self.rwkv_blocks):
                x_t, state = block(x_t, state=state)
            xs_out.append(x_t)
        xs_out = torch.cat(xs_out, dim=1)
        return xs_out
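

# Minimal usage sketch (not part of the original module): instantiates the encoder
# on random features and runs one forward pass. Assumes funasr is installed; the
# 80-dim input and all hyper-parameters below are illustrative, not prescribed values.
if __name__ == "__main__":
    encoder = RWKVEncoder(input_size=80, output_size=256, num_blocks=4)

    feats = torch.randn(2, 200, 80)        # (batch, frames, feature_dim)
    feats_len = torch.tensor([200, 160])   # valid number of frames per utterance

    encoder.eval()                         # eval mode routes through the stepwise rwkv_infer() path
    with torch.no_grad():
        out, out_lens, _ = encoder(feats, feats_len)

    # Roughly (2, 200 // 4, 256) with the default subsampling_factor of 4.
    print(out.shape, out_lens)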