FunASR/funasr/models/sense_voice/rwkv_v5.py

598 lines
21 KiB
Python

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F
def __nop(ob):
return ob
MyModule = nn.Module
MyFunction = __nop
if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
########################################################################################################
# CUDA Kernel
########################################################################################################
wkv5_cuda = None
def load_rwkv_kernel(
HEAD_SIZE: int = 64,
RWKV_CTXLEN: int = 512,
):
from torch.utils.cpp_extension import load
global wkv5_cuda
if wkv5_cuda is not None:
return
absolute_file_path = os.path.abspath(__file__)
cur_dir = os.path.dirname(absolute_file_path)
wkv5_cuda = load(
name="wkv5",
sources=[f"{cur_dir}/cuda/wkv5_op.cpp", f"{cur_dir}/cuda/wkv5_cuda.cu"],
verbose=True,
extra_cuda_cflags=[
"-res-usage",
"--use_fast_math",
"-O3",
"-Xptxas -O3",
"--extra-device-vectorization",
f"-D_N_={HEAD_SIZE}",
],
)
# dtype = torch.float
dtype = torch.bfloat16
class WKV_5(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, H, r, k, v, w, u):
with torch.no_grad():
# assert r.dtype == torch.bfloat16
# assert k.dtype == torch.bfloat16
# assert v.dtype == torch.bfloat16
# assert w.dtype == torch.bfloat16
# assert u.dtype == torch.bfloat16
# assert HEAD_SIZE == C // H
ctx.B = B
ctx.T = T
ctx.C = C
ctx.H = H
assert r.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert w.is_contiguous()
assert u.is_contiguous()
ew = (-torch.exp(w.float())).contiguous()
eew = (torch.exp(ew)).contiguous()
ctx.save_for_backward(r, k, v, eew, ew, u)
y = torch.empty(
(B, T, C),
device=r.device,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
return y
@staticmethod
def backward(ctx, gy):
with torch.no_grad():
assert gy.dtype == torch.bfloat16
B = ctx.B
T = ctx.T
C = ctx.C
H = ctx.H
assert gy.is_contiguous()
r, k, v, eew, ew, u = ctx.saved_tensors
gr = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gk = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gv = torch.empty(
(B, T, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gw = torch.empty(
(B, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
gu = torch.empty(
(B, C),
device=gy.device,
requires_grad=False,
dtype=torch.bfloat16,
memory_format=torch.contiguous_format,
) # .uniform_(-1, 1)
wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
gw = torch.sum(gw, 0).view(H, C // H)
gu = torch.sum(gu, 0).view(H, C // H)
return (None, None, None, None, gr, gk, gv, gw, gu)
def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
return WKV_5.apply(B, T, C, H, r, k, v, w, u)
class RWKV_Tmix_x052(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
load_rwkv_kernel(args.head_size_a, args.ctx_len)
self.layer_id = layer_id
self.head_size = args.head_size_a
# assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
self.n_head = args.dim_att // self.head_size
assert args.dim_att % self.n_head == 0
self.head_size_divisor = args.head_size_divisor
with torch.no_grad():
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
# fancy time_mix
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for n in range(args.dim_att):
decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size))
# print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
tmp = torch.zeros(args.dim_att)
for n in range(args.dim_att):
zigzag = ((n + 1) % 3 - 1) * 0.1
tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
@MyFunction
def jit_func(self, x):
B, T, C = x.size()
xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
r = self.receptance(xr)
k = self.key(xk)
v = self.value(xv)
g = F.silu(self.gate(xg))
return r, k, v, g
@MyFunction
def jit_func_2(self, x, g):
B, T, C = x.size()
x = x.view(B * T, C)
x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
x = self.output(x * g)
return x
def forward(self, x, **kwargs):
B, T, C = x.size()
H = self.n_head
r, k, v, g = self.jit_func(x)
x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
return self.jit_func_2(x, g)
#
# class RWKV_Tmix_x060(MyModule):
# def __init__(self, args, layer_id):
# super().__init__()
# self.args = args
#
# load_rwkv_kernel(args.head_size_a, args.ctx_len)
#
# self.layer_id = layer_id
#
# self.head_size = args.head_size_a
# self.n_head = args.dim_att // self.head_size
# assert args.dim_att % self.n_head == 0
#
# with torch.no_grad():
# ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
# ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
# ddd = torch.ones(1, 1, args.n_embd)
# for i in range(args.n_embd):
# ddd[0, 0, i] = i / args.n_embd
#
# # fancy time_mix
# self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
# self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
# self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
# self.time_maa_v = nn.Parameter(
# 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
# )
# self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
# self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
#
# D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
# self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
# self.time_maa_w2 = nn.Parameter(
# torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
# )
#
# # fancy time_decay
# decay_speed = torch.ones(args.dim_att)
# for n in range(args.dim_att):
# decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
# self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
#
# D_DECAY_LORA = 64
# self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
# self.time_decay_w2 = nn.Parameter(
# torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
# )
#
# tmp = torch.zeros(args.dim_att)
# for n in range(args.dim_att):
# zigzag = ((n + 1) % 3 - 1) * 0.1
# tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
#
# self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
#
# self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
# self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
# self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
#
# self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
# self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
# self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
# self.ln_x = nn.GroupNorm(
# self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
# )
#
# @MyFunction
# def jit_func(self, x):
# B, T, C = x.size()
#
# xx = self.time_shift(x) - x
#
# xxx = x + xx * self.time_maa_x
# xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
# xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
# mw, mk, mv, mr, mg = xxx.unbind(dim=0)
#
# xw = x + xx * (self.time_maa_w + mw)
# xk = x + xx * (self.time_maa_k + mk)
# xv = x + xx * (self.time_maa_v + mv)
# xr = x + xx * (self.time_maa_r + mr)
# xg = x + xx * (self.time_maa_g + mg)
#
# r = self.receptance(xr)
# k = self.key(xk)
# v = self.value(xv)
# g = F.silu(self.gate(xg))
#
# ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
# w = self.time_decay + ww
#
# return r, k, v, g, w
#
# @MyFunction
# def jit_func_2(self, x, g):
# B, T, C = x.size()
# x = x.view(B * T, C)
#
# x = self.ln_x(x).view(B, T, C)
# x = self.output(x * g)
# return x
#
# def forward(self, x):
# B, T, C = x.size()
# H = self.n_head
#
# r, k, v, g, w = self.jit_func(x)
# x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
#
# return self.jit_func_2(x, g)
class RWKV_CMix_x052(MyModule):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.n_embd)
for i in range(args.n_embd):
ddd[0, 0, i] = i / args.n_embd
self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@MyFunction
def forward(self, x):
xx = self.time_shift(x)
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
k = self.key(xk)
k = torch.relu(k) ** 2
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv
# class RWKV_CMix_x060(MyModule):
# def __init__(self, args, layer_id):
# super().__init__()
# self.args = args
# self.layer_id = layer_id
# self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
#
# with torch.no_grad(): # fancy init of time_mix
# ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
# ddd = torch.ones(1, 1, args.n_embd)
# for i in range(args.n_embd):
# ddd[0, 0, i] = i / args.n_embd
# self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
# self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
#
# self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
# self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
# self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
#
# @MyFunction
# def forward(self, x):
# xx = self.time_shift(x) - x
# xk = x + xx * self.time_maa_k
# xr = x + xx * self.time_maa_r
#
# k = self.key(xk)
# k = torch.relu(k) ** 2
# kv = self.value(k)
# return torch.sigmoid(self.receptance(xr)) * kv
# class Block(nn.Module):
# def __init__(self, args, layer_id):
# super().__init__()
# self.args = args
# self.layer_id = layer_id
#
# self.ln1 = nn.LayerNorm(args.n_embd)
# self.ln2 = nn.LayerNorm(args.n_embd)
#
# if self.layer_id == 0:
# self.ln0 = nn.LayerNorm(args.n_embd)
#
# self.att = RWKV_Tmix_x060(args, layer_id)
#
# self.ffn = RWKV_CMix_x060(args, layer_id)
#
# if args.dropout > 0:
# self.drop0 = nn.Dropout(p=args.dropout)
# self.drop1 = nn.Dropout(p=args.dropout)
#
# def forward(self, x, x_emb=None):
# args = self.args
# B, T, C = x.size()
# if self.layer_id == 0:
# x = self.ln0(x)
#
# if self.args.dropout == 0:
# if self.layer_id == 0 and args.pre_ffn > 0:
# x = x + self.ffnPre(self.ln1(x))
# else:
# x = x + self.att(self.ln1(x))
# x = x + self.ffn(self.ln2(x))
# else:
# if self.layer_id == 0 and args.pre_ffn > 0:
# x = self.drop0(x + self.ffnPre(self.ln1(x)))
# else:
# x = self.drop0(x + self.att(self.ln1(x)))
# x = self.drop1(x + self.ffn(self.ln2(x)))
#
# return x
class RWKVLayer(nn.Module):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
if args.dim_ffn is None:
args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
self.ln0 = None
if self.layer_id == 0 and args.get("ln0", True):
self.ln0 = nn.LayerNorm(args.n_embd)
self.ln1 = None
if args.get("ln1", True):
self.ln1 = nn.LayerNorm(args.n_embd)
self.att = RWKV_Tmix_x052(args, layer_id)
self.ln2 = None
self.ffn = None
if args.get("use_rwkv_ffn", True):
self.ln2 = nn.LayerNorm(args.n_embd)
self.ffn = RWKV_CMix_x052(args, layer_id)
if args.dropout > 0:
self.drop0 = nn.Dropout(p=args.dropout)
self.drop1 = nn.Dropout(p=args.dropout)
# init
if args.get("init_rwkv", True):
print("init_rwkv")
nn.init.orthogonal_(self.att.receptance.weight, gain=1)
nn.init.orthogonal_(self.att.key.weight, gain=0.1)
nn.init.orthogonal_(self.att.value.weight, gain=1)
nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
nn.init.zeros_(self.att.output.weight)
nn.init.orthogonal_(self.ffn.key.weight, gain=1)
nn.init.zeros_(self.ffn.value.weight)
nn.init.zeros_(self.ffn.receptance.weight)
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
nn.init.constant_(self.ln2.weight, scale)
if self.ln0 is not None:
nn.init.constant_(self.ln0.weight, scale)
if self.ln1 is not None:
nn.init.constant_(self.ln1.weight, scale)
def forward(self, x, x_emb=None, mask=None, **kwargs):
args = self.args
if args.get("datatype", "bf16") == "bf16":
x = x.bfloat16()
B, T, C = x.size()
if self.layer_id == 0 and self.ln0 is not None:
x = self.ln0(x)
if self.args.dropout == 0:
if self.ln1 is None:
x = x + self.att(x)
else:
x = x + self.att(self.ln1(x))
if self.ffn is not None:
x = x + self.ffn(self.ln2(x))
else:
if self.ln1 is None:
x = self.drop0(x + self.att(x))
else:
x = self.drop0(x + self.att(self.ln1(x)))
if self.ffn is not None:
x = self.drop1(x + self.ffn(self.ln2(x)))
if args.get("datatype", "bf16") == "bf16":
x = x.to(torch.float32)
return x
# class RWKV(nn.Module):
# def __init__(self, args):
# super().__init__()
# self.args = args
# if not hasattr(args, "dim_att"):
# args.dim_att = args.n_embd
# if not hasattr(args, "dim_ffn"):
# if "-f4" in os.environ["RWKV_MY_TESTING"]:
# args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
# else:
# args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
# if not hasattr(args, "tiny_att_layer"):
# args.tiny_att_layer = -1
# if not hasattr(args, "tiny_att_dim"):
# args.tiny_att_dim = -1
# assert args.n_embd % 32 == 0
# assert args.dim_att % 32 == 0
# assert args.dim_ffn % 32 == 0
#
# self.emb = nn.Embedding(args.vocab_size, args.n_embd)
#
# self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
#
# self.ln_out = nn.LayerNorm(args.n_embd)
# self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
#
# if args.dropout > 0:
# self.drop0 = nn.Dropout(p=args.dropout)
#
# def forward(self, idx):
# args = self.args
# B, T = idx.size()
# assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
#
# x = self.emb(idx)
# x_emb = x
#
# if args.dropout > 0:
# x = self.drop0(x)
# if args.tiny_att_dim > 0:
# for block in self.blocks:
# if args.grad_cp == 1:
# x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
# else:
# x = block(x, x_emb)
# else:
# for block in self.blocks:
# if args.grad_cp == 1:
# x = deepspeed.checkpointing.checkpoint(block, x)
# else:
# x = block(x)
#
# x = self.ln_out(x)
#
# if args.head_qk > 0:
# q = self.head_q(x)[:, :T, :]
# k = self.head_k(x)[:, :T, :]
# c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
# c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
#
# if "32" in os.environ["RWKV_FLOAT_MODE"]:
# c = c @ F.one_hot(idx, num_classes=args.vocab_size)
# elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
# c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
# elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
# c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
#
# x = self.head(x) + c
# else:
# x = self.head(x)
#
# return x