LoRA原理与实现--PyTorch自己搭建LoRA模型

admin1年前笔记121

一、前言

在AIGC领域频繁出现着一个特殊名词“LoRA”,听上去有点像人名,但是这是一种模型训练的方法。LoRA全称Low-Rank Adaptation of Large Language Models,中文叫做大语言模型的低阶适应。如今在stable diffusion中用地非常频繁。

由于大语言模型的参数量巨大,许多大公司都需要训练数月,由此提出了各种资源消耗较小的训练方法,LoRA就是其中一种。

本文将详细介绍LoRA的原理,并使用PyTorch实现小模型的LoRA训练。

二、模型训练

现在大多数模型训练都是采用梯度下降算法。梯度下降算法可以分为下面4个步骤:

  1. 正向传播计算损失值

  2. 反向传播计算梯度

  3. 利用梯度更新参数

  4. 重复1、2、3的步骤,直到获取较小的损失

以线性模型为例,模型参数为W,输入输出为x、y,损失函数以均方误差为例。那么各个步骤的计算如下,首先是正向传播,对于线性模型来说就是做一个矩阵乘法:

L=MSE(Wx,y)

在求出损失后,可以计算L对W的梯度,得到dW:

dW=LW

dW是一个矩阵,它会指向L上升最快的方向,但是我们的目的是让L下降,因此让W减去dW。为了调整更新的步伐,还会乘上一个学习率η,计算如下:

W=WηdW

最后一直重复即刻。上述三个步骤的伪代码如下:

image.png

在更新完成后,得到新的参数W'。此时我们使用模型预测时,计算如下:

pred=Wx

三、引入LoRA

我们可以来思考一下W和W'之间的关系。W通常指基础模型的参数,而W'是在基础模型的基础上,经过几次矩阵加减得到的。假设在训练的过程中更新了10次,每次的dW分别为dW1、dW2、....、dW10,那么完整的更新过程可以写为一次运算:

W=WηdW1ηdW2...ηdW10令:dW=i=110dWiW=WηdW

其中dW是一个形状与W'一致的矩阵。我们把-ηdW写成矩阵R,那么更新后的参数就是:

W=W+R

此时训练的过程就被简化为原矩阵加上另一个矩阵R。但是求解矩阵R并没有更简单,而且也没有节约资源,此时就引出LoRA了这一思想。

一个训练充分的矩阵,通常是满秩或者基本满足秩的,即矩阵中没有一列是多余的。在论文《Scaling Laws for Neural Language Model》中提出了数据集与参数大小之间的关系,满足该关系且训练良好,得到的模型是基本满秩的。在微调模型时,我们会选取一个底模,该底模就是基本满秩的。而更新矩阵R秩的情况是如何的呢?

我们假定R矩阵是一个低秩矩阵,低秩矩阵有许多重复的列,因此可以分解为两个更小的矩阵。假如W的形状为m×n,那么A的形状也是m×n,我们把矩阵R分解为AB(其中A形状为m×r,B形状为r×N),r通常会选取一个远小于m、n的值,如图所示:

image.png

将低秩矩阵分解为两个矩阵几点好处,首先是参数量明显减少。假设R矩阵的形状为100×100,那么R的参数量为10000。当我们选取秩为10时,此时矩阵A的形状为100×10,矩阵B的形状为10×100,此时参数量为2000,比R矩阵少了80%。

而且由于R是低秩矩阵,所以在训练充分的情况下,A和B矩阵可以达到R的效果。这里的矩阵AB就是我们常说的LoRA模型。

在引入LoRA后,我们的预测需要将x分别输入W和AB,此时预测的计算为:

pred=Wx+ABx

在预测时会比原始模型稍慢,但是在大模型中基本感觉不到差异。

四、实战

为了把握各个细节,这里不使用大模型作为lora的实战,而是选择使用vgg19这种小型网络来训练lora模型。导入需要用到的模块:

image.png

4.1 数据集准备

这里使用vgg19在imagenet上的预训练权重作为底模,因此需要准备分类数据集。为了方便,这里只准备了一个类别,且只准备了5张图片,图片在项目下的data/goldfish下:

image.png

在imagenet中包含了goldfish类别,但是这里选取的是插画版的goldfish,经过测试,预训练模型不能将上述图片正确分类。我们的目的就是训练LoRA,让模型正确分类。

我们创建一个LoraDataset:

image.png

4.2 创建LoRA模型

我们把LoRA封装成一个层,LoRA中只有两个需要训练的矩阵,LoRA的代码如下:

image.png

其中m是输入的大小,n是输出的大小,rank是秩的大小,我们可以设置一个较小的值。

在权重初始化时,我们把A用高斯噪声初始化,而B用0矩阵初始化,这样的目的是保证从底模开始训练。因为AB是0矩阵,所以初始状态下,LoRA不起作用。

4.3 设置超参数并训练

接下来就是训练了,这里和PyTorch常规训练代码基本一致,先看代码:

image.png

这里有两点需要注意,第一点是我们把vgg19的权重设置为不可训练,这和迁移学习很像,但其实是不一样的。

第二点则是正向传播时,我们使用了下面代码:

image.png

4.4 测试

下面来简单测试一下:

image.png

[object Object]

五、总结

LoRA是针对大模型的一种高效的训练方法,而本文则将LoRA使用在小型的分类网络中,旨在让读者更清晰认识LoRA的详细实现(同时也因为跑不动大模型)。限于数据量,对LoRA的精度效率等问题没有详细讨论,读者可以参考相关资料深入了解。


相关文章

25个Linux服务器安全小技巧

大家都认为Linux 默认是安全的,我大体是认可的(这是个有争议的话题)。Linux默认确实有内置的安全模型。你需要打开它并且对其进行定制,这样才能得到更安全的系统。Linux更难管理,不过相应也更灵...

详解linux目录结构

详解linux目录结构

[root@bogon /]# ls -l total 94 dr-xr-xr-x.   2 root r...

网络地址转换(NAT)的报文跟踪

网络地址转换(NAT)的报文跟踪

这是有关网络地址转换network address translation(NAT)的系列文章中的第一篇。这一部分将展示如何使用 iptables/nftables 报文跟踪功能来定位 NAT 相关的...

教你搭建你自己的Git服务器

教你搭建你自己的Git服务器

直到现在,我们主要讨论的还是以一个使用者的身份与 Git 进行交互。这篇文章中我将讨论 Git 的管理,并且设计一个灵活的 Git 框架。你可能会觉得这听起来是 “高阶 Git 技术” 或者 “只有狂...

gdb 调试利器

GDB是一个由GNU开源组织发布的、UNIX/LINUX操作系统下的、基于命令行的、功能强大的程序调试工具。 对于一名Linux下工作的c++程序员,gdb是必不可少的工具;启动gdb对C/C++程序...

Linux下防御ddos攻击

Linux下防御ddos攻击

SYN攻击是利用TCP/IP协议3次握手的原理,发送大量的建立连接的网络包,但不实际建立连接,最终导致被攻击服务器的网络队列被占满,无法被正常用户访问。 Linux内核提供了若干SYN相关的配置,加大...

发表评论    

◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。