• 我的订阅
  • 科技

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

类别:科技 发布时间:2024-10-09 09:52:00 来源:机器之心Pro

随着 AI 模型的参数量越来越大,对算力的需求也水涨船高。

比如最近,Llama-3.1 登上了最强开源大模型的宝座,但超大杯 405B 版本的内存就高达 900 多 GB,这对算力构成了更加苛刻的挑战。

如何降低算力的使用成本和使用门槛,已经成为许多公司寻求突破的关键。Felafax 就是其中的一家创业公司,致力于简化 AI 训练集群的搭建流程。

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

Nikhil Sonti 和 Nikhin Sonti 创立了 Felafax,他们的口号是在构建开源 AI 平台,为下一代 AI 硬件服务,将机器学习的训练成本降低 30%。

与英伟达相比,AMD 的 GPU,尤其是 MI300X 系列,提供了更高的性价比,按每美元计算,其性能表现更为出色。

最近,Felafax 的联合创始人 Nikhil Sonti 发布了一篇博客,详细分享了如何通过 8 张 AMD MI300X GPU 和 JAX 微调 LLaMA 3.1 405B 模型的方法,所有代码现已开源。

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

Github 链接:https://github.com/felafax/felafax

机器之心对博客内容进行了不改变原意的编译、整理,以下是博客内容:

JAX 尤其适合非英伟达硬件

JAX 是一个强大的机器学习库,结合了类似 NumPy 的 API、自动微分功能以及 Google 的 XLA 编译器。它在模型并行化方面提供了优秀的 API,因此非常适合像 LLaMA 3.1 405B 这样的超大模型训练。

在使用 AMD 硬件时,JAX 有几个明显的优势:

多硬件并行支持:JAX 采用 XLA(加速线性代数)编译器,将计算编译为硬件无关的中间表示(HLO),这意味着同样的 JAX 代码无需修改便可高效运行在不同硬件后端,包括 AMD GPU。 独立于底层硬件:XLA 编译器的优化策略是通用的,不针对某个特定的硬件平台。这使得任何支持 XLA 的硬件设备(如 CPU、GPU、TPU)都能受益于这些优化,获得更好的性能表现。 极高的适应性:从 NVIDIA 转移到 AMD(或其他硬件)时,JAX 只需做极少的代码改动。而相较之下,PyTorch 与英伟达的 CUDA 生态系统紧密耦合,迁移过程相对复杂。

因此,JAX 成为了我们在非英伟达硬件上的最佳选择。

拉取 Docker 镜像:

docker pull rocm/jax:latest

启动 Docker 容器:

# Pull the Docker Image:

docker pull rocm/jax:latest

# Start the Docker Container:

docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \

--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest

# Verify the Installation:

python3 -c 'import jax; print(jax.devices())'

验证安装

python3 -c 'import jax; print (jax.devices ())'

训练使用了一个配备了 8 张 AMD MI300x GPU 的 AMD 节点。每张 MI300x 拥有 192GB 的 HBM3 内存,性能表现与最新的英伟达 H100 GPU 相比非常出色。

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

与英伟达 H100 的比较,来源:TensorWave

训练 LLaMA 405B:性能与可扩展性

使用 JAX,可以成功地在 AMD GPU 上训练 LLaMA 405B 模型。我们使用 LoRA 微调,将所有模型权重和 LoRA 参数都设为 bfloat16,LoRA rank 设为 8,LoRA alpha 设为 16:

模型大小:LLaMA 模型的权重占用了约 800GB 的显存。 LoRA 权重 + 优化器状态:大约占用了 400GB 的显存。 显存总使用量:占总显存的 77%,约 1200GB。 限制:由于 405B 模型的规模过大,batch 大小和序列长度的空间有限,使用的 batch size 为 16,序列长度为 64。 JIT 编译:由于空间限制,无法运行 JIT 编译版本;它可能需要比急切模式稍多的空间。 训练速度:使用 JAX 急切模式,约为 35 tokens / 秒。 内存效率:稳定在约 70% 左右。 扩展性:在 8 张 GPU 上,使用 JAX 的扩展性接近线性。

由于硬件和显存的限制,我们无法运行 JIT 编译版本的 405B 模型,整个训练过程是在 JAX 的急切模式下执行的,因此还有很大的进步空间。

下图中显示了在一次微调训练步骤中,8 张 GPU 的显存利用率和 rocm-smi 输出:

GPU 利用率:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

训练设置

将 LLaMA 3.1 从 PyTorch 移植到 JAX

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

此前,Nikhil Sonti 分享过如何将 LLaMA 3.1 从 PyTorch 移植到 JAX。他指出,目前 90% 的大型语言模型(LLM)都运行在 NVIDIA GPU 上,但实际上还有一些同样强大且性价比更高的替代方案。例如,在 Google TPU 上训练和部署 Llama 3.1 的成本比 NVIDIA GPU 低约 30%。

然而,支持非 NVIDIA 硬件的开发工具较为匮乏。Sonti 最初尝试使用 PyTorch XLA 在 TPU 上训练 Llama 3.1,但过程并不顺利。XLA 与 PyTorch 的集成不够完善,缺少一些关键的库(如 bitsandbytes 无法正常运行),同时还遇到了一些难以解决的 HuggingFace 错误。

为此,他决定调整策略,将 Llama 3.1 从 PyTorch 移植到 JAX,成功解决了这些问题。Sonti 还录制了详细的教程视频,并开源了所有代码:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

方法演示:https://dub.sh/felafax-demo 代码仓库:https://github.com/felafax/felafax

加载模型,并把模型参数分片

处理像 LLaMA 405B 这样的超大模型,需要在多个设备之间高效地进行参数分片。以下是如何通过 JAX 实现这一点的。

在 JAX 中进行参数分片

为了将巨大的 LLaMA 405B 模型高效地分布到 8 张 AMD GPU 上,需要使用 JAX 的设备网格(device mesh)功能。

部署代码:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69

JAX 的设备网格可以帮助我们把可用的设备组织成一个网格,让我们可以指定如何把模型的参数和计算分配到不同的 GPU 上。

在本文的设置中,需要创建一个形状为(1, 8, 1)的网格,并将轴分别命名为数据并行(dp)、全分片数据并行(fsdp)和模型并行(mp)。然后,为模型的每个张量定义特定的分片规则,指定这些维度如何沿着这些网格轴进行分片。

DEVICES = jax.devices ()

DEVICE_COUNT = len (DEVICES)

DEVICE_MESH = mesh_utils.create_device_mesh ((1, 8, 1))

MESH = Mesh (devices=DEVICE_MESH, axis_names=("dp", "fsdp", "mp"))

可视化分片

可以使用以下代码来可视化分片结果,从而方便地验证分片规则是否按预期应用。

jax.debug.visualize_array_sharding

分片规则

模型不同组件的分片规则如下所示:

参数如何分片:

参数要在 8 个 GPU 之间分配。例如,LM head(lm_head/kernel)张量有两个轴,按照 PS ("fsdp", "mp") 进行分片。在本例中是 8 和 1,因此可以看到该张量在第一个轴上沿着 8 个 GPU 被拆分。

Non-Replicated 参数:

没有任何分片规范的参数会在所有设备上进行复制。例如,层归一化(attention_norm/kernel 和 ffn_norm/kernel)没有设置分片规范,是 PS (None)。

应用分片函数

在加载模型时,使用以下分片函数逐步对模型权重进行分片:

def make_shard_and_gather_fns (partition_specs):

def make_shard_fn (partition_spec):

out_sharding = NamedSharding (mesh, partition_spec)

def shard_fn (tensor):

return jax.device_put (tensor, out_sharding).block_until_ready ()

return shard_fn

shard_fns = jax.tree_util.tree_map (make_shard_fn, partition_specs)

return shard_fns

# Create shard functions based on partitioning rules

shard_fns = make_shard_and_gather_fns (partitioning_rules)

这使得我们能够将每个参数放置在指定的设备上,并按照设定的分片进行处理。

分片训练 Batch

最初,训练 Batch 是正常创建的,但在输入模型之前,需要按照下面的代码在 GPU 上进行分片:

train_batch = jax.device_put ( train_batch,

NamedSharding (self.mesh, PS ("dp", "fsdp")))

在这里,我们指定训练 Batch 应该在 "dp" 和 "fsdp" 轴上进行分片,在本例中分别对应于被分成 1 和 8 份,如果把结果可视化出来,如下所示:

分片前:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

在调用 jax.device_put 之后:

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

加入 LoRA

LoRA 通过将权重更新分解为低秩矩阵,减少了可训练参数的数量,这对于微调大型模型特别有效。以下是在 AMD GPU 上微调 Llama 3.1-405 的 LoRA 的要点:

将 LoRA 参数(lora_a 和 lora_b)与主模型参数分开。 使用 jax.lax.stop_gradient (kernel) 来防止对主模型权重的更新。 使用 lax.dot_general 进行快速、精确控制的矩阵运算。 LoRA 输出在添加到主输出之前会被缩放为 (self.lora_alpha/self.lora_rank)。

LoRADense 层

在此设定一个自定义的 LoRADense 层,该层集成了 LoRA 参数:

class LoRADense (nn.Module):

features: int

lora_rank: int = 8

lora_alpha: float = 16.0

@nn.compact

def __call__(self, inputs: Any) -> Any:

# Original kernel parameter (frozen)

kernel = self.param ('kernel', ...)

y = lax.dot_general (inputs, jax.lax.stop_gradient (kernel), ...)

# LoRA parameters (trainable)

lora_a = self.variable ('lora_params', 'lora_a', ..., ...)

lora_b = self.variable ('lora_params', 'lora_b', ..., ...)

# Compute LoRA output

lora_output = lax.dot_general (inputs, lora_a.value, ...)

lora_output = lax.dot_general (lora_output, lora_b.value, ...)

# Combine original output with LoRA modifications

y += (self.lora_alpha/self.lora_rank) * lora_output

return y.astype (self.dtype)

分片 LoRA 参数

为了高效地在设备之间分配 LoRA 参数,我们也通过 JAX 设定了分片规则,这确保了 LoRA 参数与主模型参数的分片一致,优化了内存使用和计算效率。

LoRA A matrices (lora_a)

LoRA A 矩阵(lora_a)

分片规则:PS ("fsdp", "mp") 可视化结果:如下图所示,lora_a 参数被分片为 (8, 1),这意味着第一个轴在 8 个设备上进行分片("fsdp" 轴),而第二个轴未进行分片。

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

LoRA B 矩阵(lora_b)

分片规则:PS ("mp", "fsdp") 可视化结果:如下图所示,lora_b 参数被分片为 (1, 8),这意味着第二个轴在 8 个设备上进行分片(fsdp 轴),而第一个轴未进行分片。

微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B

这种分片策略优化了参数的分配,减少了通信开销,并在训练过程中增强了并行性。它确保每个设备仅持有一部分 LoRA 参数,使得大模型如 LLaMA 405B 的高效扩展成为可能。

仅更新 LoRA 参数

为了优化训练,在微调 LLaMA 405B 模型,只计算 LoRA 参数的梯度,保持主模型参数不变。这个方法减少了内存使用,并加速了训练,因为只更新较少的参数。可以移步 GitHub 仓库,查看实现细节。

在训练过程中,每一步都涉及将一批输入数据通过模型进行处理。由于只有 LoRA 参数是可训练的,因此模型的预测和计算的损失仅依赖于这些参数,然后对 LoRA 参数进行反向传播。只更新这些参数简化了训练过程,使得在多个 GPU 上高效微调像 LLaMA 405B 这样的大型模型成为可能。

更多研究细节,请参考原博客。

以上内容为资讯信息快照,由td.fyun.cc爬虫进行采集并收录,本站未对信息做任何修改,信息内容不代表本站立场。

快照生成时间:2024-10-09 11:45:09

本站信息快照查询为非营利公共服务,如有侵权请联系我们进行删除。

信息原文地址:

马斯克承诺开源版大模型 来了!Grok-1:3140亿参数迄今最大,权重架构全开放
...之心开源社区有福了。说到做到,马斯克承诺的开源版大模型 Grok 终于来了!今天凌晨,马斯克旗下大模型公司 xAI 宣布正式开源 3140 亿参数的混合专家(MoE)模型‘Grok-1’
2024-03-18 11:51:00
前谷歌科学家Yi Tay「LLM演义」系列博客第一弹:BERT为何匿迹江湖?
【新智元导读】前谷歌科学家Yi Tay重磅推出「LLM时代的模型架构」系列博客,首篇博文的话题关于:基于encoder-only架构的BERT是如何被基于encoder-decoder架构的T5所取代的
2024-07-22 09:39:00
Scaling Law百度最早提出!OpenAI/Claude受它启发,致谢中有Ilya
...统团队。他们探讨了深度学习中训练集大小、计算规模和模型精度之间的关系,并且通过大规模实证研究揭示了深度学习泛化误差和模型大小的缩放规律,还在图像和音频上进行了测试。只不过他们
2024-11-28 09:57:00
Meta首次公布AI芯片细节 功耗低于英伟达
...)计划的一部分,主要用于提升广告投放和其他内容推荐模型的效率。据Meta介绍,首个MTIA芯片将专注于AI推理。Meta软件工程师Joel Coburn表示
2023-05-19 14:00:00
Meta大模型LLaMA 3即将登场,参数量或达1400亿
在推出开源大模型LLaMA2近一年之后,Meta的新一代大模型LLaMA3即将面世。在4月9日伦敦举行的一次活动中,Meta确认计划在下个月内首次发布LLaMA3
2024-04-10 22:40:00
全球最强开源大模型Llama 3发布:使用15T数据预训练,最大模型参数将超4000亿
就在刚刚,Meta 发布了其最先进开源大型语言模型的下一代产品——Llama 3。据介绍,Llama 3 在 24K GPU 集群上训练
2024-04-20 11:03:00
美欧亚三洲开发者联手,全球首个组团训练的大模型,全流程开源
...,Prime Intellect 宣布通过去中心化方式训练完成了一个 10B 模型。30 号,他们开源了一切,包括基础模型
2024-12-03 13:34:00
谷歌「诚意之作」,开源9B、27B版Gemma2,主打高效、经济!
...著的能力跟进,可见其技术发展与创新的潜力。除了Gemini模型外,Gemma这一系列轻量级的SOTA开放模型似乎与我们距离更近
2024-06-29 09:37:00
大模型是否有推理能力?DeepMind数月前的论文让AI社区吵起来了
最近一段时间,随着 OpenAI o1 模型的推出,关于大型语言模型是否拥有推理能力的讨论又多了起来。比如苹果在前段时间的一篇论文中指出,只要给模型一些干扰,最聪明的模型也会犯最
2024-10-23 12:05:00
更多关于科技的资讯:
“产业炬光灯”聚焦厦企笃正新能源 紧跟市场谋创新
“产业炬光灯”聚焦笃正新能源。厦门网讯(厦门日报记者 林露虹)把阳光“存”起来,变成随时可用的电能。厦门企业笃正新能源在离网光伏储能领域持续深耕
2025-11-04 08:07:00
需求释放结构升级,消费市场涌动“焕新”潮
“还有咖啡节”在玄武湖公园打造了时尚潮流集市,吸引许多市民前来消费打卡,在明媚秋光中度过惬意周末。 通讯员 常成 南京日报/紫金山新闻记者 孙中元 摄今日关注数字4
2025-11-04 07:41:00
机器人“派上用场”,课堂知识接轨产业需求
105支高校战队在宁角逐“埃斯顿杯”——机器人“派上用场”,课堂知识接轨产业需求机器人正在进行自主抓取挑战︐参赛选手紧张调试
2025-11-04 07:41:00
木里木外x Designwire设计腕儿 《心居未来·豪宅艺术与生活方式》趋势论坛圆满举行
智能高定探索美好关系,2025年10月22日,木里木外·故宫·源邸迎来了一场思想与美学的盛宴。以“心居未来·豪宅艺术与生活方式”为主题的趋势论坛论坛在此隆重启幕
2025-11-03 14:31:00
以 AI 科技赋能中超 ,铸强品牌区域影响力
鲁网11月3日讯2025年2月22日,联想集团与中国足球职业联赛联合会正式达成战略合作,成为中超联赛官方合作伙伴,以全栈 AI 技术为核心为联赛注入科技活力
2025-11-03 14:37:00
中新经纬11月3日电 据“北京市场监管”微信号3日消息,京津冀三地市场监督管理部门联合发布“双十一”消费提示和平台合规指导
2025-11-03 15:00:00
第26届金牌服务季战略升级,金牌家居赋能装企决胜“品质交付”
10月29日上午,金牌服务季——“品质交付”装企交付保障誓师大会在金牌家居总部研发大楼隆重举行,金牌家居总裁潘孝贞、轮值执行副总裁陈少华等多位领导出席
2025-11-03 15:01:00
近日,致力于智慧、安全、健康和可持续建筑解决方案的全球性企业江森自控宣布推出EasyIO Neo Series楼宇自动化系统的新版本
2025-11-03 15:01:00
远大生命科学与图灵量子共建AI赋能平台|量子智能重塑GUTX益次方研发范式
近日,远大生命科学有限公司(以下简称“远大生命科学”)与图灵量子正式签署战略合作协议,双方将共同建设“益生菌数据库及量子AI赋能系统”
2025-11-03 15:01:00
2025世界物联网博览会|思特奇全栈物联能力筑基万物智联
10月31日,2025世界物联网博览会(2025WIOT)已在无锡市太湖国际博览中心盛大启幕。作为国内领先的数字化转型与智能化服务提供商
2025-11-03 15:09:00
2025年优质少儿编程教育机构:妙小程少儿编程-赛事出口+精准匹配!
在少儿编程教育领域,家长选择机构时通常会围绕课程体系的科学性、教学模式的适配性、竞赛支持的实用性及师资团队的稳定性展开考量
2025-11-03 15:10:00
2025年51Talk口碑好:收费透明+约课便捷+系统稳定!
在为孩子选择在线英语课程时,很多家长都会关心课程的收费价格是否合理、能否长期坚持。作为2011年成立、2016年在美国上市的首家中国在线教育企业
2025-11-03 15:12:00
张朝阳谈物理课开讲四周年:人生任何阶段都可以学习
11月2日14:00,搜狐创始人、董事局主席兼首席执行官、物理学博士张朝阳携《张朝阳的物理课》四周年特别直播如约而至。以“以矢量微积分的精神表达微分几何”为主题
2025-11-03 15:31:00
无人车配送场景再“上新”,顺丰同城携手收钱吧推出无人车校园餐配
近日,国内规模最大的第三方即配平台顺丰同城与数字化门店综合服务商收钱吧达成合作,联合在四川汽车职业技术学院推出无人车校园智能餐配服务
2025-11-03 15:31:00
鲁网11月3日讯近日,山东省地矿局第七地质大队与之江实验室科学数据枢纽研究中心,聚焦AI技术与地矿行业的深度融合,通过视频会议形式召开专题交流会
2025-11-03 15:47:00