343 lines
12 KiB
Python
343 lines
12 KiB
Python
# ------------------------------------------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
|
# ------------------------------------------------------------------------------------------
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import math
|
|
from typing import Optional, List
|
|
|
|
|
|
class LoRALayer:
|
|
def __init__(
|
|
self,
|
|
r: int,
|
|
lora_alpha: int,
|
|
lora_dropout: float,
|
|
merge_weights: bool,
|
|
):
|
|
self.r = r
|
|
self.lora_alpha = lora_alpha
|
|
# Optional dropout
|
|
if lora_dropout > 0.0:
|
|
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
|
else:
|
|
self.lora_dropout = lambda x: x
|
|
# Mark the weight as unmerged
|
|
self.merged = False
|
|
self.merge_weights = merge_weights
|
|
|
|
|
|
class Embedding(nn.Embedding, LoRALayer):
|
|
# LoRA implemented in a dense layer
|
|
def __init__(
|
|
self,
|
|
num_embeddings: int,
|
|
embedding_dim: int,
|
|
r: int = 0,
|
|
lora_alpha: int = 1,
|
|
merge_weights: bool = True,
|
|
**kwargs
|
|
):
|
|
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
|
|
LoRALayer.__init__(
|
|
self, r=r, lora_alpha=lora_alpha, lora_dropout=0, merge_weights=merge_weights
|
|
)
|
|
# Actual trainable parameters
|
|
if r > 0:
|
|
self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
|
|
self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
|
|
self.scaling = self.lora_alpha / self.r
|
|
# Freezing the pre-trained weight matrix
|
|
self.weight.requires_grad = False
|
|
self.reset_parameters()
|
|
|
|
def reset_parameters(self):
|
|
nn.Embedding.reset_parameters(self)
|
|
if hasattr(self, "lora_A"):
|
|
# initialize A the same way as the default for nn.Linear and B to zero
|
|
nn.init.zeros_(self.lora_A)
|
|
nn.init.normal_(self.lora_B)
|
|
|
|
def train(self, mode: bool = True):
|
|
nn.Embedding.train(self, mode)
|
|
if self.merge_weights and self.merged:
|
|
# Make sure that the weights are not merged
|
|
if self.r > 0:
|
|
self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
|
|
self.merged = False
|
|
|
|
def eval(self):
|
|
nn.Linear.eval(self)
|
|
if self.merge_weights and not self.merged:
|
|
# Merge the weights and mark it
|
|
if self.r > 0:
|
|
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
|
self.merged = True
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
if self.r > 0 and not self.merged:
|
|
result = nn.Embedding.forward(self, x)
|
|
if self.r > 0:
|
|
after_A = F.embedding(
|
|
x,
|
|
self.lora_A.T,
|
|
self.padding_idx,
|
|
self.max_norm,
|
|
self.norm_type,
|
|
self.scale_grad_by_freq,
|
|
self.sparse,
|
|
)
|
|
result += (after_A @ self.lora_B.T) * self.scaling
|
|
return result
|
|
else:
|
|
return nn.Embedding.forward(self, x)
|
|
|
|
|
|
class Linear(nn.Linear, LoRALayer):
|
|
# LoRA implemented in a dense layer
|
|
def __init__(
|
|
self,
|
|
in_features: int,
|
|
out_features: int,
|
|
r: int = 0,
|
|
lora_alpha: int = 1,
|
|
lora_dropout: float = 0.0,
|
|
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
|
merge_weights: bool = True,
|
|
**kwargs
|
|
):
|
|
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
|
LoRALayer.__init__(
|
|
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
|
|
)
|
|
|
|
self.fan_in_fan_out = fan_in_fan_out
|
|
# Actual trainable parameters
|
|
if r > 0:
|
|
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
|
|
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
|
|
self.scaling = self.lora_alpha / self.r
|
|
# Freezing the pre-trained weight matrix
|
|
self.weight.requires_grad = False
|
|
self.reset_parameters()
|
|
if fan_in_fan_out:
|
|
self.weight.data = self.weight.data.T
|
|
|
|
def reset_parameters(self):
|
|
nn.Linear.reset_parameters(self)
|
|
if hasattr(self, "lora_A"):
|
|
# initialize A the same way as the default for nn.Linear and B to zero
|
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
|
nn.init.zeros_(self.lora_B)
|
|
|
|
def train(self, mode: bool = True):
|
|
def T(w):
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
nn.Linear.train(self, mode)
|
|
if self.merge_weights and self.merged:
|
|
# Make sure that the weights are not merged
|
|
if self.r > 0:
|
|
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
|
self.merged = False
|
|
|
|
def eval(self):
|
|
def T(w):
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
nn.Linear.eval(self)
|
|
if self.merge_weights and not self.merged:
|
|
# Merge the weights and mark it
|
|
if self.r > 0:
|
|
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
|
self.merged = True
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
def T(w):
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
if self.r > 0 and not self.merged:
|
|
result = F.linear(x, T(self.weight), bias=self.bias)
|
|
if self.r > 0:
|
|
result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
|
|
return result
|
|
else:
|
|
return F.linear(x, T(self.weight), bias=self.bias)
|
|
|
|
|
|
class MergedLinear(nn.Linear, LoRALayer):
|
|
# LoRA implemented in a dense layer
|
|
def __init__(
|
|
self,
|
|
in_features: int,
|
|
out_features: int,
|
|
r: int = 0,
|
|
lora_alpha: int = 1,
|
|
lora_dropout: float = 0.0,
|
|
enable_lora: List[bool] = [False],
|
|
fan_in_fan_out: bool = False,
|
|
merge_weights: bool = True,
|
|
**kwargs
|
|
):
|
|
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
|
LoRALayer.__init__(
|
|
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
|
|
)
|
|
assert (
|
|
out_features % len(enable_lora) == 0
|
|
), "The length of enable_lora must divide out_features"
|
|
self.enable_lora = enable_lora
|
|
self.fan_in_fan_out = fan_in_fan_out
|
|
# Actual trainable parameters
|
|
if r > 0 and any(enable_lora):
|
|
self.lora_A = nn.Parameter(self.weight.new_zeros((r * sum(enable_lora), in_features)))
|
|
self.lora_B = nn.Parameter(
|
|
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
|
|
) # weights for Conv1D with groups=sum(enable_lora)
|
|
self.scaling = self.lora_alpha / self.r
|
|
# Freezing the pre-trained weight matrix
|
|
self.weight.requires_grad = False
|
|
# Compute the indices
|
|
self.lora_ind = self.weight.new_zeros((out_features,), dtype=torch.bool).view(
|
|
len(enable_lora), -1
|
|
)
|
|
self.lora_ind[enable_lora, :] = True
|
|
self.lora_ind = self.lora_ind.view(-1)
|
|
self.reset_parameters()
|
|
if fan_in_fan_out:
|
|
self.weight.data = self.weight.data.T
|
|
|
|
def reset_parameters(self):
|
|
nn.Linear.reset_parameters(self)
|
|
if hasattr(self, "lora_A"):
|
|
# initialize A the same way as the default for nn.Linear and B to zero
|
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
|
nn.init.zeros_(self.lora_B)
|
|
|
|
def zero_pad(self, x):
|
|
result = x.new_zeros((*x.shape[:-1], self.out_features))
|
|
result = result.view(-1, self.out_features)
|
|
result[:, self.lora_ind] = x.reshape(
|
|
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
|
|
)
|
|
return result.view((*x.shape[:-1], self.out_features))
|
|
|
|
def train(self, mode: bool = True):
|
|
def T(w):
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
nn.Linear.train(self, mode)
|
|
if self.merge_weights and self.merged:
|
|
# Make sure that the weights are not merged
|
|
if self.r > 0 and any(self.enable_lora):
|
|
delta_w = F.conv1d(
|
|
self.lora_A.data.unsqueeze(0),
|
|
self.lora_B.data.unsqueeze(-1),
|
|
groups=sum(self.enable_lora),
|
|
).squeeze(0)
|
|
self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
|
|
self.merged = False
|
|
|
|
def eval(self):
|
|
def T(w):
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
nn.Linear.eval(self)
|
|
if self.merge_weights and not self.merged:
|
|
# Merge the weights and mark it
|
|
if self.r > 0 and any(self.enable_lora):
|
|
delta_w = F.conv1d(
|
|
self.lora_A.data.unsqueeze(0),
|
|
self.lora_B.data.unsqueeze(-1),
|
|
groups=sum(self.enable_lora),
|
|
).squeeze(0)
|
|
self.weight.data += self.zero_pad(T(delta_w * self.scaling))
|
|
self.merged = True
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
def T(w):
|
|
return w.T if self.fan_in_fan_out else w
|
|
|
|
if self.merged:
|
|
return F.linear(x, T(self.weight), bias=self.bias)
|
|
else:
|
|
result = F.linear(x, T(self.weight), bias=self.bias)
|
|
if self.r > 0:
|
|
after_A = F.linear(self.lora_dropout(x), self.lora_A)
|
|
after_B = F.conv1d(
|
|
after_A.transpose(-2, -1),
|
|
self.lora_B.unsqueeze(-1),
|
|
groups=sum(self.enable_lora),
|
|
).transpose(-2, -1)
|
|
result += self.zero_pad(after_B) * self.scaling
|
|
return result
|
|
|
|
|
|
class Conv2d(nn.Conv2d, LoRALayer):
|
|
# LoRA implemented in a dense layer
|
|
def __init__(
|
|
self,
|
|
in_channels: int,
|
|
out_channels: int,
|
|
kernel_size: int,
|
|
r: int = 0,
|
|
lora_alpha: int = 1,
|
|
lora_dropout: float = 0.0,
|
|
merge_weights: bool = True,
|
|
**kwargs
|
|
):
|
|
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
|
|
LoRALayer.__init__(
|
|
self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights
|
|
)
|
|
assert type(kernel_size) is int
|
|
# Actual trainable parameters
|
|
if r > 0:
|
|
self.lora_A = nn.Parameter(
|
|
self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
|
|
)
|
|
self.lora_B = nn.Parameter(
|
|
self.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
|
|
)
|
|
self.scaling = self.lora_alpha / self.r
|
|
# Freezing the pre-trained weight matrix
|
|
self.weight.requires_grad = False
|
|
self.reset_parameters()
|
|
|
|
def reset_parameters(self):
|
|
nn.Conv2d.reset_parameters(self)
|
|
if hasattr(self, "lora_A"):
|
|
# initialize A the same way as the default for nn.Linear and B to zero
|
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
|
nn.init.zeros_(self.lora_B)
|
|
|
|
def train(self, mode: bool = True):
|
|
nn.Conv2d.train(self, mode)
|
|
if self.merge_weights and self.merged:
|
|
# Make sure that the weights are not merged
|
|
self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
|
|
self.merged = False
|
|
|
|
def eval(self):
|
|
nn.Conv2d.eval(self)
|
|
if self.merge_weights and not self.merged:
|
|
# Merge the weights and mark it
|
|
self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
|
|
self.merged = True
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
if self.r > 0 and not self.merged:
|
|
return F.conv2d(
|
|
x,
|
|
self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
|
|
self.bias,
|
|
self.stride,
|
|
self.padding,
|
|
self.dilation,
|
|
self.groups,
|
|
)
|
|
return nn.Conv2d.forward(self, x)
|