Skip to content

RimoChan/RO-Merge

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

17 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

RO-Merge: Reward-Optimized Model Merging

事情是这样的,大家平时会想要自己训练扩散模型吗?

如果想要提升模型的整体性能,传统上,要么做全参数微调、要么做强化学习,这些方法通常都要训很久,而且很贵。家里如果只有1张显卡的话就只能训1训LoRA了!

不过好在开源社区里,有很多由同1个Base Model训来的模型。这些模型各自分别有1些优点。

那么,比起从头训练,不如直接在这些的开源里拼出最好的权重组合,将它们结合在1起,我们就可以获得超越所有模型的模型!

原理

这是1个零阶优化算法,是把权重当作黑盒,将模型推理结果加上1个评分函数送进贝叶斯优化来跑的。

具体来说是这样:

  1. 假设有n个待融合模型、每个模型有m个层,使用贝叶斯优化器,初始化1组参数a[1, 1] ~ a[n, m] 。还需要约束同1层内各模型的权重和为 1:sum(a[i, k] for i in range(n)) = 1,也就是这些参数是(n-1)*m个浮点数。

  2. 根据上1步的参数组,建立1个新模型b,其中的每个层的权重为 b[k] = sum(a[i, k]*Model[i, k])

  3. 评测模型性能。首先从预定义的列表里随机抽取Danbooru标签组合成c个Prompt,然后使用模型b生成c张图像,将生成的图像送入Danbooru标签模型来反推标签。输入标签中能被预测出的个数除以总标签数就是是模型b的准确度了。

  4. 贝叶斯优化器自己根据分数调整下1轮的权重分布,回到步骤2继续循环,直到达到设定的循环次数。最终保留得分排名前top k的融合模型。

此外,在训练SDXL时也加入了1些小妙招:

  • 评估时还有1个惩罚机制,因为在实验中观察到,过大的负权重(例如A模型为1.8倍,B模型为-0.8倍,它们加起来也等于1)容易导致生成图像在视觉上出现崩坏或噪点。因此最终的Reward里加了1项,让它减去负权重的平均值乘以惩罚因子。

  • SDXL的层数对于贝叶斯优化器来说还是太多了,因此实际上并不是直接用torch的层,而是做了分组,Unet的参数我按IN/MID/OUT blocks手工分了10组,加上剩下的部分1共11组。这样1来参数空间会小很多,贝叶斯会比较高兴。

  • 融合前的模型都可以放在RAM里,这样可以VRAM压到和推理基本1致,24G的4090就可以跑,不过比较吃RAM,建议用48G的,不过现在RAM太贵,不够的话就别买了,直接把Swap开大一点吧。虽然应该4060ti也可以跑,不过4090要跑1天,4060ti可能就要跑1周了。

使用方法

首先你需要有torch,这个代码应该不挑版本,反正你有装哪个就用哪个。

先安装1下依赖——

pip install -r requirements.txt

然后用Python运行就可以了——

python ro_merge.py --output_dir="./savedata" --models="['model_a.safetensors','model_b.safetensors','model_c.safetensors']"

还有一些可选参数,比如:

  • --n_iter=100: 设置贝叶斯优化的轮数。

  • --eval_n_iter=100: 设置每1轮评估的样本数。

其他的参数也都不复杂,有兴趣仔细调1调的话可以进ro_merge.py自己看1下就懂了。

此外,output_dir里有1个log文件夹,可以用tensorboard --logdir=log打开来看。

模型下载

SD1.5版: https://civitai.com/models/249129

XL版: https://civitai.com/models/358055

这个模型是2024年发布的,那个时候效果还算是挺好的。不过后来Illustrious出来之后,因为Illustrious的效果太好,打不过它是正常的,毕竟便宜没好货!

1些问题

  1. 负权重惩罚看起来很弱,为什么模型不会出现Reward Hacking?

我的理解是这样,因为这个做法不是传统的RL,传统RL有几B的参数可以修改,这里只有几十个参数,本身就是一个非常强的约束了。模型没法做太多的Hacking,加了负权重惩罚也就够了。

  1. 如果我不想用2次元模型,想跑真实系模型怎么办?

什么,竟敢背叛2次元,直接砍头!

  1. 支持断点续训吗?

其实不支持。这个1次训1天也就差不多了,万一停电了那就重训吧,或者直接从log里挑1个最好的模型凑合用也行。

结束

好,就这样,大家88,我要回去和1girl亲热了!

About

【RO-Merge】基于奖励优化的模型融合算法!

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages