import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange


def identity(t, *args, **kwargs):
    # no-op passthrough; any extra arguments are ignored
    return t


def append_dims(x, num_dims):
    # append `num_dims` trailing singleton dimensions to `x`
    if num_dims <= 0:
        return x
    return x.view(*x.shape, *((1,) * num_dims))


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def padding_to_multiple_of(n, mult):
    # amount of right-padding needed to make `n` a multiple of `mult`
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder


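# Added note (illustrative, not from the original file): padding_to_multiple_of is used
# below to pad sequences up to a whole number of attention groups, e.g.
# padding_to_multiple_of(300, 256) == 212 and padding_to_multiple_of(512, 256) == 0.

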
class Transpose(nn.Module):
    """Wrapper around torch.transpose() so it can be used inside an nn.Sequential."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.transpose(*self.shape)


class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature as depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input sequence

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be a constant multiple of in_channels"
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs):
        return self.conv(inputs)


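# Minimal shape-check sketch (an added example, not part of the original file); the helper
# name `_demo_depthwise_conv1d` is hypothetical. With out_channels == in_channels and
# groups == in_channels, each channel is convolved independently over time.
def _demo_depthwise_conv1d():
    conv = DepthwiseConv1d(in_channels=64, out_channels=64, kernel_size=17, padding=8)
    x = torch.randn(2, 64, 100)      # (batch, channels, time)
    y = conv(x)
    assert y.shape == (2, 64, 100)   # padding of (kernel_size - 1) // 2 keeps the time axis
    return y.shape

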
class ConvModule(nn.Module):
    """
    Convolution module in the Conformer style. This implementation is a simplified variant:
    it applies a single 1-D depthwise convolution over the time axis and adds the result
    back to the input as a residual connection (no pointwise convolutions, GLU, or batch
    norm as in the full Conformer convolution module).

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int, optional): Size of the convolving kernel. Default: 17
        expansion_factor (int, optional): Kept for API compatibility; only 2 is supported
        dropout_p (float, optional): Probability of dropout (currently unused)

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input sequences

    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the convolution module
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 17,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
    ) -> None:
        super(ConvModule, self).__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only expansion_factor 2 is supported"

        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(
                in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2
            ),
        )

    def forward(self, inputs):
        # (batch, time, dim) -> (batch, dim, time) -> depthwise conv -> transpose back, plus residual
        return inputs + self.sequential(inputs).transpose(1, 2)


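# Brief usage sketch (an added example, not from the original file); `_demo_conv_module`
# is a hypothetical helper. ConvModule keeps the (batch, time, dim) shape and adds a
# depthwise temporal convolution as a residual.
def _demo_conv_module():
    module = ConvModule(in_channels=256, kernel_size=17)
    x = torch.randn(2, 100, 256)
    y = module(x)
    assert y.shape == (2, 100, 256)
    return y.shape

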
class OffsetScale(nn.Module):
    """Applies a learned per-head scale (gamma) and offset (beta) to the last dimension
    and returns one tensor per head."""

    def __init__(self, dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std=0.02)

    def forward(self, x):
        out = einsum("... d, h d -> ... h d", x, self.gamma) + self.beta
        return out.unbind(dim=-2)


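# Illustrative sketch (an added example, not part of the original file); `_demo_offset_scale`
# is a hypothetical helper. With heads=4, one shared qk projection is turned into four cheap
# affine variants, used below as quad_q, lin_q, quad_k and lin_k.
def _demo_offset_scale():
    scale = OffsetScale(dim=128, heads=4)
    qk = torch.randn(2, 100, 128)             # (batch, time, query_key_dim)
    quad_q, lin_q, quad_k, lin_k = scale(qk)  # four tensors, each (batch, time, 128)
    assert quad_q.shape == (2, 100, 128)
    return quad_q.shape

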
class FFConvM(nn.Module):
    """Feed-forward block with convolution: norm -> Linear -> SiLU -> ConvModule -> Dropout."""

    def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        output = self.mdl(x)
        return output


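# Minimal usage sketch (an added example, not from the original file); `_demo_ffconvm` is a
# hypothetical helper. FFConvM maps (batch, time, dim_in) -> (batch, time, dim_out) and is
# used for every projection inside FLASH_ShareA_FFConvM below.
def _demo_ffconvm():
    block = FFConvM(dim_in=512, dim_out=256)
    x = torch.randn(2, 100, 512)
    y = block(x)
    assert y.shape == (2, 100, 256)
    return y.shape

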
class FLASH_ShareA_FFConvM(nn.Module):
    """
    Gated attention unit in the FLASH style with a shared query/key projection.
    FFConvM blocks produce the value/gate branches (v, u) and a single qk projection;
    attention combines a quadratic (within-group) component with a linear (across-group)
    component, and the gated result is added back to the input as a residual.
    """

    def __init__(
        self,
        *,
        dim,
        group_size=256,
        query_key_dim=128,
        expansion_factor=1.0,
        causal=False,
        dropout=0.1,
        rotary_pos_emb=None,
        norm_klass=nn.LayerNorm,
        shift_tokens=True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb

        # dropout applied to the quadratic attention weights
        self.dropout = nn.Dropout(dropout)

        # projections
        self.to_hidden = FFConvM(
            dim_in=dim,
            dim_out=hidden_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.to_qk = FFConvM(
            dim_in=dim,
            dim_out=query_key_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )

        # shared qk is offset/scaled into four views: quad_q, lin_q, quad_k, lin_k
        self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)

        self.to_out = FFConvM(
            dim_in=dim * 2,
            dim_out=dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )

        self.gateActivate = nn.Sigmoid()

    def forward(self, x, *, mask=None):
        """
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """

        normed_x = x
        residual = x

        # token shift - a great, costless trick from an independent AI researcher in Shenzhen:
        # half of the features are shifted one step along the sequence before the projections
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim=-1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.0)
            normed_x = torch.cat((x_shift, x_pass), dim=-1)

        # initial projections
        v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
        qk = self.to_qk(normed_x)

        # offset and scale the shared qk into quadratic and linear queries/keys
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)

        # attention, gating and residual connection
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, mask=mask)
        out = (att_u * v) * self.gateActivate(att_v * u)
        x = residual + self.to_out(out)
        return x

    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        if exists(mask):
            lin_mask = rearrange(mask, "... -> ... 1")
            lin_k = lin_k.masked_fill(~lin_mask, 0.0)

        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(
                self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k)
            )

        # padding for groups
        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(
                lambda t: F.pad(t, (0, 0, 0, padding), value=0.0),
                (quad_q, quad_k, lin_q, lin_k, v, u),
            )

            mask = default(mask, torch.ones((b, n), device=device, dtype=torch.bool))
            mask = F.pad(mask, (0, padding), value=False)

        # group along sequence
        quad_q, quad_k, lin_q, lin_k, v, u = map(
            lambda t: rearrange(t, "b (g n) d -> b g n d", n=self.group_size),
            (quad_q, quad_k, lin_q, lin_k, v, u),
        )

        if exists(mask):
            mask = rearrange(mask, "b (g j) -> b g 1 j", j=g)

        # calculate quadratic attention output (ReLU^2 attention within each group)
        sim = einsum("... i d, ... j d -> ... i j", quad_q, quad_k) / g

        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.0)

        if self.causal:
            causal_mask = torch.ones((g, g), dtype=torch.bool, device=device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.0)

        quad_out_v = einsum("... i j, ... j d -> ... i d", attn, v)
        quad_out_u = einsum("... i j, ... j d -> ... i d", attn, u)

        # calculate linear attention output (global mixing across groups)
        if self.causal:
            lin_kv = einsum("b g n d, b g n e -> b g d e", lin_k, v) / g
            # exclusive cumulative sum along the group dimension
            lin_kv = lin_kv.cumsum(dim=1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.0)
            lin_out_v = einsum("b g d e, b g n d -> b g n e", lin_kv, lin_q)

            lin_ku = einsum("b g n d, b g n e -> b g d e", lin_k, u) / g
            # exclusive cumulative sum along the group dimension
            lin_ku = lin_ku.cumsum(dim=1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.0)
            lin_out_u = einsum("b g d e, b g n d -> b g n e", lin_ku, lin_q)
        else:
            lin_kv = einsum("b g n d, b g n e -> b d e", lin_k, v) / n
            lin_out_v = einsum("b g n d, b d e -> b g n e", lin_q, lin_kv)

            lin_ku = einsum("b g n d, b g n e -> b d e", lin_k, u) / n
            lin_out_u = einsum("b g n d, b d e -> b g n e", lin_q, lin_ku)

        # fold groups back into the full sequence, and excise the padding
        return map(
            lambda t: rearrange(t, "b g n d -> b (g n) d")[:, :n],
            (quad_out_v + lin_out_v, quad_out_u + lin_out_u),
        )
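

# Minimal end-to-end usage sketch (an added example, not from the original file). The
# hyperparameters are illustrative: with expansion_factor=4.0, to_hidden produces 2 * dim
# features, so v and u each carry dim features and the gated output matches to_out's
# expected dim_in of dim * 2.
if __name__ == "__main__":
    block = FLASH_ShareA_FFConvM(
        dim=512,
        group_size=256,
        query_key_dim=128,
        expansion_factor=4.0,
        causal=False,
    )
    x = torch.randn(2, 300, 512)  # (batch, time, dim); 300 is not a multiple of group_size
    y = block(x)                  # cal_attention pads the sequence internally
    assert y.shape == x.shape     # the block is residual and shape-preserving
    print(y.shape)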