"""Fastformer attention definition.
|
|
|
|
Reference:
|
|
Wu et al., "Fastformer: Additive Attention Can Be All You Need"
|
|
https://arxiv.org/abs/2108.09084
|
|
https://github.com/wuch15/Fastformer
|
|
|
|
"""
|
|
|
|
import numpy
|
|
import torch
|
|
|
|
|
|
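# Computation sketch (a reader's summary of the forward pass below; see the
# paper cited in the module docstring for the formal description):
#   1. Project the input into query and key spaces.
#   2. Additive attention (one scalar score per head and position) pools the
#      queries over time into a single global query vector.
#   3. The global query modulates each key element-wise; a second additive
#      attention pools the result into a global key vector.
#   4. The global key modulates the per-position queries (which also serve as
#      values through parameter sharing), followed by a linear transform,
#      dropout, and a residual connection to the query projection.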
class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer."""

    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads

        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, torch.nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.

        Args:
            x: (batch, time, size = n_heads * attn_dim)

        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)

    def forward(self, xs_pad, mask):
        """Forward method.

        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0

        Returns:
            torch.Tensor: (batch, time, size)
        """
        batch_size, seq_len, _ = xs_pad.shape
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        # (batch, n_heads, time)
        query_for_score = (
            self.query_att(mixed_query_layer).transpose(1, 2)
            / self.attention_head_size**0.5
        )
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
                ).min
            )
            query_for_score = query_for_score.masked_fill(mask, min_value)
            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
        else:
            query_weight = torch.softmax(query_for_score, dim=-1)

        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        query_layer = self.transpose_for_scores(
            mixed_query_layer
        )  # (batch, n_heads, time, attn_dim)

        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        # (batch, time, size)
        mixed_query_key_layer = mixed_key_layer * pooled_query_repeat

        # (batch, n_heads, time)
        query_key_score = (
            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
        ).transpose(1, 2)
        if mask is not None:
            min_value = float(
                numpy.finfo(
                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
                ).min
            )
            query_key_score = query_key_score.masked_fill(mask, min_value)
            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
                mask, 0.0
            )
        else:
            query_key_weight = torch.softmax(query_key_score, dim=-1)

        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
        key_layer = self.transpose_for_scores(
            mixed_query_key_layer
        )  # (batch, n_heads, time, attn_dim)
        pooled_key = torch.matmul(
            query_key_weight, key_layer
        )  # (batch, n_heads, 1, attn_dim)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        weighted_value = (pooled_key * query_layer).transpose(
            1, 2
        )  # (batch, time, n_heads, attn_dim)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = self.dropout(self.transform(weighted_value)) + mixed_query_layer

        return weighted_value
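
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The sizes below are arbitrary assumptions; `size` must be divisible by
# `attention_heads`, and the mask marks nonpadding frames with 1.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    attn = FastSelfAttention(size=256, attention_heads=4, dropout_rate=0.1)
    attn.espnet_initialization_fn()

    xs_pad = torch.randn(2, 50, 256)  # (batch, time, size)
    mask = torch.ones(2, 1, 50)  # (batch, 1, time)
    mask[:, :, 40:] = 0  # pretend the last 10 frames of each utterance are padding

    out = attn(xs_pad, mask)
    print(out.shape)  # torch.Size([2, 50, 256])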