-
使用 MLX 探索分布式推理和训练
使用 MLX 将机器学习工作负载扩展到多台 Mac。了解如何解决互连效率、大模型推理、请求批处理和分布式训练方面的难题。探索如何只用几台 Mac 来替代昂贵的云基础设施,从而满足 AI 工作负载的高要求。
章节
- 0:00 - Introduction
- 2:09 - Distributed communication
- 4:32 - Setting up your cluster
- 10:33 - Distributed inference and fine-tuning
- 13:35 - Model parallelism strategies
- 15:53 - Distributed fine-tuning
- 18:34 - CLI, Python, Swift, and C++ APIs
- 20:45 - Next steps
资源
- MLX Swift LM on GitHub
- MLX Swift Examples
- MLX Examples
- MLX Swift
- MLX LM - Python API
- MLX Explore - Python API
- MLX Framework
- MLX
相关视频
WWDC26
WWDC25
-
搜索此视频…
你好 我是Tatiana MLX团队的研究科学家 本地LLM正处于 一个令人瞩目的时代 模型规模不断增大 并获得了惊人的新能力 变得更加智能 能够解决更难的问题 随着性能提升 我们使用它们 处理更多任务:更长上下文、更难的任务 以及更复杂的工作流 最终 单台机器的内存、算力 或带宽将成为瓶颈 在我们的WWDC 26视频 "Run local agentic AI on the Mac using MLX"中 展示了如何在本地运行AI智能体 但当你拥有多台设备时 可以将本地AI发挥到极致 运行更大的LLM或通过 分布式推理和训练加速它们 通过分布式 推理和训练来加速 今天 我们将深入探讨 如何用MLX跨多台Mac进行扩展 充分利用你桌上的硬件 我们将从命令行界面开始 让模型在你的机器上运行 在你的机器上运行 然后转向Python API 进行实验 最后用Swift将这些工作流 直接嵌入你的App 开始吧! 首先 我们来了解完整的 硬件和软件技术栈 以实现Apple Silicon上 的分布式工作负载 然后我们将所有内容整合起来 把四台M3 Ultra组成一个集群 我们将逐步演示:选择 合适的拓扑来连接机器 启用快速通信 并启动分布式任务 集群准备好后 我们就进入令人兴奋的环节 快速本地分布式LLM 推理和微调 我们将用MLX运行 并与单台Mac进行对比 并了解MLX如何在 集群中分发模型 大多数示例使用 命令行界面 最后我们将展示 分布式通信 如何通过Python、Swift 和C++ API向你开放 让我们先来了解Apple Silicon 的分布式通信 要实现快速数据收发 机器需要通过物理链路连接 即互连 在此之上 我们还需要 一个传输协议 一种将字节推送的机制 从一台机器的内存传输到另一台 从macOS 26.2开始 远程直接内存访问协议 即RDMA 支持通过Thunderbolt 5使用 RDMA将数据直接从一台机器 的内存传输到另一台 避免了大部分CPU 和操作系统的开销 基于Thunderbolt的RDMA 为我们提供了高带宽 低延迟通信 正是分布式工作负载所需的 但仅靠它 只能实现 两台机器之间的原始数据传输 因此分布式程序需要 更高层次的抽象 一个通信后端 提供用于发送数据的 通信原语 在各台机器之间传输 或在整个集群中 进行协调 这两种操作是分布式 训练和推理的基础模块 这就是JACCL发挥作用的地方
JACCL是一个开源的 集合通信库 由Apple构建 它利用基于Thunderbolt的RDMA 为你提供集合 通信原语 用于在机器之间发送数据 并在集群中合并结果 无需自行管理任何 底层传输 它不仅限于机器学习 Apple Silicon上的任何分布式工作负载 都可以基于它构建 技术栈的最后一块 是机器学习框架 使用通信后端 进行分布式推理和训练 这就是MLX MLX是一个开源 机器学习库 由Apple为Apple Silicon构建 它利用JACCL实现 低延迟分布式通信 并提供工具用于跨集群 编排分布式任务 如果你是MLX新手 请查看我们的视频 "Getting Started with MLX on Apple Silicon" 来自WWDC25
现在我们了解了完整的技术栈 让我们将所有内容整合起来 构建一个集群 一组共同协作 完成同一任务的机器 我们将使用4台M3 Ultra 要搭建集群 需要用 Thunderbolt 5线缆连接机器 有多种连接方式 拓扑结构直接 影响通信时间 首先 我们来了解 决定通信时间的因素 接下来 我们了解如何 实际连接这些机器 JACCL支持哪些拓扑 以及它们之间的权衡 之后 我们将展示如何在 机器上启用RDMA以实现快速通信 最后 我们将用MLX在 集群上启动分布式任务
通信时间由两个部分组成 延迟和传输时间 延迟是每次通信操作 需要支付的固定成本 与发送的数据量 无关
传输时间是通过链路 移动数据的成本 随消息大小增长 并取决于链路的带宽
对于小消息 数据移动成本很小 因此延迟占主导
对于大消息 权衡正好相反 根据通信是受延迟限制 还是受带宽限制 我们可能会偏好不同的拓扑
JACCL支持其中两种 网格和环形 在全网格中 每台机器 与其他所有机器直接连接 因此任何集群通信 具有最低的可能延迟 在环形拓扑中 每个节点 只与其两个邻居连接 非相邻节点之间的通信 必须经过中间机器 这会增加延迟 但环形拓扑每台机器 所需的线缆和端口更少 更易于扩展到更多节点 由于每个节点只有两个连接 可以使用额外的Thunderbolt端口 每个邻居使用两到三根线缆 (取决于Mac型号) 从而增加每条链路的带宽 并减少传输时间 当机器连接成网格后 我们可以灵活地为 每次通信选择路由 通过网格拓扑 或环形拓扑
JACCL的优点在于 它会自动选择最佳拓扑 根据消息大小 和通信操作 延迟敏感时选择网格 带宽敏感时选择环形 为获得这种灵活性 让我们将所有M3 Ultra连接成网格
当我们将所有M3 Ultra连接好后 需要在所有机器上 启用RDMA 打开机器上的设置 搜索"RDMA"
点击"Enable RDMA over Thunderbolt"
启用RDMA 然后重启
太好了! Mac已通过 Thunderbolt 5线缆连接 并且RDMA已启用 现在我们需要一种 启动分布式程序的方法
一种方式是通过局域网 例如通过Wi-Fi或以太网 从任何可以SSH访问 集群的机器上 比如我的MacBook 我们连接到每台Mac 启动程序 从那时起 所有机器通过 Thunderbolt链路直接通信 MLX提供了一个启动助手 可以为你完成所有这些操作!
你在MacBook上运行mlx.launch 它负责编排集群 你提供想要运行的可执行文件 以及描述集群的 JSON主机文件 它通过SSH连接到每个节点 使用提供的主机文件中的主机名 并在每台机器上 启动可执行文件 让我们看看描述集群的 主机文件应该是什么样子 它是一个JSON数组 每个节点一条记录 "ssh"是mlx.launch 用于连接机器的主机名 "ips"是机器在 局域网上的IP地址 由JACCL用于节点间 的初始协调 "rdma"是RDMA 设备名称的列表 对应每个Thunderbolt对等连接
你可以手动编写 但MLX也提供了 一个助手脚本`mlx.distributed_config` 可以自动生成 你提供主机名列表 和输出路径 你还可以在配置中 嵌入环境变量 它们将在启动时 自动设置在每个节点上 这里我们设置MLX_METAL_FAST_SYNCH=1 这能启用更快的 GPU到CPU同步 这对分布式任务至关重要 因为计算在GPU上运行 在GPU上运行 而通信在CPU上运行 你还可以传递--auto-setup标志 自动配置 Thunderbolt网络 Communication的--backend参数 定义是网格还是环形 对于网格 --backend设置为 jaccl 如本例所示 对于环形 我们将 其改为jaccl-ring 让我们运行此命令 为集群生成主机文件
首先 它检查所有主机 是否可以通过SSH访问 然后探测每台机器 的Thunderbolt端口 以发现哪些机器 与哪些机器物理连接 从而构建拓扑图 由于我们传递了--auto-setup 它会在所有机器上禁用Thunderbolt Bridge 在所有机器上 并为每条Thunderbolt 链路配置RDMA 最后 它写入一个JSON主机文件 包含mlx.launch所需的一切 注意 不带--auto-setup标志时 脚本会打印配置命令 方便你审查后自行运行
现在集群已准备就绪 让我们进入令人兴奋的部分 分布式语言模型推理 和微调 最简单的入门方式是通过 命令行界面 和MLX LM MLX LM是一个开源Python包 基于MLX构建 提供命令行工具 以及用于在本地运行 语言模型的Python API 在Apple Silicon上运行 请查看我们的视频 "Explore large language models on Apple Silicon with MLX" 来自WWDC25 以在单台设备上入门
正如我们去年展示的 在单台Mac上与模型对话 可以通过命令行界面 使用mlx_lm.chat实现 我们在终端中运行它 指定要使用的模型 例如Qwen 3.6 以及响应的 最大Token数 在后台 MLX LM会在 单台机器上加载并运行模型
要通过命令行界面在 集群上与同一模型对话 我们用mlx.launch包装命令 在MacBook上 我们在终端 中运行mlx.launch 使用--hostfile指向 我们的集群配置 在双横线后 我们传递 完全相同的mlx_lm.chat命令 但使用每个节点上 可执行文件的远程路径 命令几乎完全相同 MLX LM会对模型分片 并为你协调分布式推理 请记住 所有必要的库 如MLX必须安装在每台Mac上 可执行文件必须在 所有机器上都可访问 通过命令行界面一行命令 我们就让模型运行起来了 分布在整个集群上! 让我们并排比较一下 用Qwen 3.6对话 一个拥有270亿参数的模型 在单台M3 Ultra和4台上分别运行 我已在两侧启动了 mlx_lm.chat 左侧 模型加载在 单台M3 Ultra上 右侧 它分片在 四台机器上 让我们用"Implement a transformer model in MLX." 同时提示两边
速度提升相当惊人! 集群生成Token的速度 接近单台机器的三倍 相比单台机器 对于Qwen 3.6模型 正如我们所见 在多台Mac上 运行模型 可以显著提升推理速度 具体加速效果取决于 模型大小和架构 但时间改善并非 使用分布式的唯一原因 有时模型对于单台机器 来说太大了 例如Kimi 2.6拥有 1万亿个总参数 即使进行8-bit量化 仅权重本身就需要 约1TB的内存 这无法放入单台M3 Ultra 但可以分布在四台上 那么我们如何实际将权重 和计算拆分到各机器上?
MLX和MLX LM支持两种方式 流水线并行和张量并行
流水线并行 按深度拆分模型 在这种情况下 每台机器持有一组层 数据按顺序 流经各台机器 它不会加速推理 因为每个Token仍需依次通过 各组层 一个接一个 但优点是通信简单 机器只在层组边界处交换激活值 在层组的边界处 张量并行 按宽度拆分模型 在这种情况下 每台机器持有每层的一部分 因此所有机器同时 处理同一Token 由于每层计算并行化 推理速度得以提升 但代价是通信 频率大幅增加 在每层和每个Token时 都会发生通信 这使得低延迟变得重要 这就是为什么网格拓扑 对这种情况至关重要 每台机器都能通过 单跳访问其他任何机器
张量并行是MLX LM中 的默认分片策略 要用流水线并行 对模型进行分片 只需在命令中 附加--pipeline标志 注意 并非所有模型 都支持流水线并行 现在 让我们在集群上 与万亿参数的Kimi 2.6对话 在我们的集群上
为此我们像之前一样 从MacBook使用mlx.launch 指向主机文件 我没有传递--pipeline标志 所以我们使用张量并行 我们需要等待片刻 mlx.launch正在连接每台机器 MLX LM加载并分片模型 然后启动对话
太好了 模型已加载! 让我们向模型提问 "Implement machine learning architecture for GPT in Python with MLX"
就这样 仅凭一条命令 一个庞大的万亿参数模型 正在你的Mac上本地运行 回答你的问题
使用MLX和MLX LM 不仅可以运行语言模型推理 还可以在你的硬件上 对模型进行微调 快速、高效、完全私密 数据从不离开你的机器 让我们从单台Mac开始 然后扩展到集群 在单台机器上进行 微调或训练时 我们将训练数据分成批次 即多个样本的集合 对于每个批次 Mac计算梯度 并更新模型权重 我们对训练数据集 重复此过程一次或多次 直到模型达到 期望的质量 处理训练数据的速度越快 微调完成得越早 那么如何使用多台机器 来加速这一过程? 思路很简单 在每台Mac上复制模型 每台机器接收不同的数据批次 并在本地计算梯度 然后我们对梯度取平均 使模型更新使用来自 所有批次的信息 这称为数据并行训练 因为模型被复制 而数据在各机器上 并行处理 这就是加速的来源 因此 N台机器可以将 数据处理速度提升至多N倍 听起来很棒! 让我们看看如何在 MLX LM中使用数据并行 和之前一样 与单设备的 唯一区别 是用mlx.launch启动任务 从你的MacBook 指定远程机器上 mlx_lm.lora的路径 数据分片由MLX LM处理 命令几乎相同 我们将--batch-size 乘以设备数量 这样每台机器仍然处理 与之前每步 相同数量的样本 让我们对拥有90亿参数的 Qwen 3.5进行微调 在单台机器和集群上分别运行 并比较模型每秒 处理的Token数量 我们在左侧的单台设备上 启动微调 右侧在集群上启动 使用mlx.launch和主机文件 指定远程机器上 mlx_lm.lora的路径 首先加载数据和模型 然后训练开始 单台M3 Ultra每秒处理 约180个Token 而在集群上每秒处理 约600个Token 微调速度提升 超过3倍 现在 使用MLX 你可以将 设备变成本地训练集群 进行高效微调 无需迁移到云端 到目前为止 我们使用命令行界面 进行分布式推理 和MLX LM内的微调 然而 MLX提供了 细粒度的控制 用于分片和分布式操作 通过灵活的Python、Swift和C++ API 这允许你在Python和C++中 对模型进行实验 或用Swift将模型 嵌入你的App 让我们看看示例 要使用Python API 和MLX LM运行分布式推理 首先初始化 用于通信的分布式组 然后定义我们想要的 并行类型 例如张量并行 最后使用sharded_load函数 对模型进行分片 之后 我们像在单台设备上 一样使用该模型 MLX LM在底层处理 所有分布式通信
要对模型及其分片 有更多控制 可以使用MLX本身的 底层原语 例如 定义一个简单的 Linear层后 可以使用shard_linear函数 对其进行张量并行分片 你甚至可以控制基本的 分布式操作 如all reduce 在Python、Swift或C++中 通过JACCL初始化分布式组后 我们对张量在所有Mac上 执行集合分布式求和 使用对应的MLX原语 正如我们在会话开始时 所指出的 JACCL本身也是独立可用的 你可以将其用于任何应用 需要分布式通信的应用 甚至是非机器学习应用 JACCL可以不依赖MLX独立构建 它提供C++ API 包含通信原语 初始化JACCL组后 我们再次对所有Mac的 张量执行集合分布式求和 但这次直接通过JACCL 而非MLX 现在你已了解高层和低层API 用于使用MLX和JACCL 进行分布式推理和训练 你已准备好用MLX 构建高级分布式工作流
在本次会话中 我们了解了完整的技术栈 使分布式训练 和推理成为可能 在Apple Silicon上 从基于Thunderbolt的RDMA 一直到MLX和MLX LM 我们展示了从单台设备 扩展到多台设备是多么简单 以及它带来的好处 更快的推理 运行万亿 参数模型的能力 以及更快的微调 只需对单设备代码 做极少改动 支持命令行界面 Python、Swift和C++ API 有了分布式集群 现在你可以运行 完全由MLX驱动的本地AI智能体 快速、私密 运行在你自己的硬件上 了解更多 请查看我们的WWDC 2026视频 "Run local agentic AI on the Mac using MLX" 要深入了解 高级分布式功能 包括自定义并行策略 和训练循环 请查看我们的文档 你还可以使用MLX LM 通过内置服务器分布式提供模型服务 我们迫不及待地想看到你用 MLX在Apple Silicon上构建的作品!
-
-
8:31 - Hostfile format for a 4-node MLX cluster
[ { "ssh": "m3-ultra-0", "ips": ["192.168.1.10"], "rdma": [null, "rdma_en5", "rdma_en4", "rdma_en3"] }, { "ssh": "m3-ultra-1", "ips": ["192.168.1.11"], "rdma": ["rdma_en5", null, "rdma_en4", "rdma_en3"] }, { "ssh": "m3-ultra-2", "ips": ["192.168.1.12"], "rdma": ["rdma_en5", "rdma_en4", null, "rdma_en3"] }, { "ssh": "m3-ultra-3", "ips": ["192.168.1.13"], "rdma": ["rdma_en5", "rdma_en4", "rdma_en3", null] } ] -
8:56 - Generate the cluster hostfile with mlx.distributed_config
mlx.distributed_config \ --hosts m3-ultra-0,m3-ultra-1,m3-ultra-2,m3-ultra-3 \ --output "m3-ultra-jaccl.json" \ --env MLX_METAL_FAST_SYNCH=1 \ --auto-setup \ --backend jaccl -
11:04 - Run distributed LLM inference with mlx_lm.chat
# Single-device LLM inference mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048 # Distributed LLM inference across the cluster mlx.launch --hostfile "m3-ultra-jaccl.json" -- \ /remote/path/to/mlx_lm.chat --model "Qwen/Qwen3.6-27B" --max-tokens 2048 -
15:03 - Run distributed inference with pipeline parallelism
# Tensor parallelism (default) mlx.launch --hostfile "m3-ultra-jaccl.json" -- \ /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \ --max-tokens 2048 # Pipeline parallelism — append --pipeline flag mlx.launch --hostfile "m3-ultra-jaccl.json" -- \ /remote/path/to/mlx_lm.chat --model "moonshotai/Kimi-K2.6" \ --max-tokens 2048 \ --pipeline -
17:18 - Run distributed fine-tuning with mlx_lm.lora
# Single-device fine-tuning mlx_lm.lora --model "Qwen/Qwen3.5-9B" \ --data "mlx-community/wikisql" \ --train --batch-size 4 # Distributed fine-tuning (scale --batch-size by number of devices) mlx.launch --hostfile "hostfile.json" -- \ /remote/path/to/mlx_lm.lora --model "Qwen/Qwen3.5-9B" \ --data "mlx-community/wikisql" \ --train --batch-size 16 -
19:01 - Distributed inference with the MLX LM Python API
import mlx.core as mx from mlx_lm import stream_generate from mlx_lm.utils import sharded_load # Initialise distributed backend group = mx.distributed.init(strict=True, backend="jaccl") # Define parallelism tensor_group, pipeline_group = group, None # Shard the model model, tokenizer = sharded_load("moonshotai/Kimi-K2.6", pipeline_group, tensor_group) for response in stream_generate(model, tokenizer, prompt, max_tokens=1024): if group.rank() == 0: print(response.text, end="", flush=True) -
19:31 - Shard a layer with the MLX Python API
import mlx.core as mx import mlx.nn as nn # Initialise distributed backend group = mx.distributed.init(strict=True, backend="jaccl") # Define layer and shard it column-wise layer = nn.Linear(1024, 1024) sharded_layer = nn.layers.distributed.shard_linear( layer, strategy="all-to-sharded", group=group ) data = mx.random.normal((1, 1, 1024)) output = sharded_layer(data) mx.eval(output) -
19:47 - All-reduce across devices in Python, Swift, and C++
# Python import mlx.core as mx world = mx.distributed.init(strict=True, backend="jaccl") data = mx.full((4,), float(world.rank()), dtype=mx.float32) result = mx.distributed.all_sum(data, group=world) mx.eval(result) # Swift let group = try DistributedGroup(strict: .ring) let data = rank == 0 ? MLXArray(converting: [1.0, 2.0, 3.0]) : MLXArray(converting: [5.0, 6.0, 7.0]) let result = try group.allSum(data) // C++ namespace mx = mlx::core; auto world = mx::distributed::init(/* strict */ true, "jaccl"); mx::array data = mx::full({4}, static_cast<float>(world.rank()), mx::float32); mx::array result = mx::distributed::all_sum(data, world); mx::eval(result); -
20:06 - Standalone distributed sum with the JACCL C++ API
#include <jaccl/jaccl.h> #include <iostream> int main() { // Initialize JACCL group auto group = jaccl::init(); std::cout << "Rank " << group->rank() << " of " << group->size() << std::endl; // Perform all-reduce sum float data[10] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f}; float output[10]; group->all_sum(data, output, sizeof(data), jaccl::Float32); std::cout << "Result: " << output[0] << std::endl; return 0; }
-
-
- 0:00 - Introduction
Overview of why distributed AI becomes necessary as models grow larger, and a preview of what the session covers: CLI tools, Python API, and Swift for embedding distributed workflows in your apps.
- 2:09 - Distributed communication
A walkthrough of the full hardware and software stack enabling distributed workloads on Apple silicon: RDMA over Thunderbolt 5 for low-latency data movement, JACCL (open-source collective communication library), and MLX as the ML framework that ties them together.
- 4:32 - Setting up your cluster
How to physically connect four M3 Ultras into a cluster — understanding latency vs. bandwidth trade-offs, choosing between mesh and ring topologies, enabling RDMA in System Settings, and using mlx.distributed_config and mlx.launch to configure and orchestrate the cluster.
- 10:33 - Distributed inference and fine-tuning
How to run distributed LLM inference with MLX LM using a single CLI command — wrapping mlx_lm.chat with mlx.launch to shard a 27B-parameter Qwen model across four M3 Ultras, achieving nearly 3x the token generation rate of a single machine.
- 13:35 - Model parallelism strategies
How MLX LM splits large models across machines using tensor parallelism (splitting by width for faster inference) and pipeline parallelism (splitting by depth for simpler communication) — including a demo running the 1-trillion-parameter Kimi 2.6 model across four Macs.
- 15:53 - Distributed fine-tuning
How data-parallel training accelerates fine-tuning by replicating the model across machines, processing different data batches in parallel, and averaging gradients — demonstrated fine-tuning Qwen 3.5 (9B) at over 3x throughput on the cluster versus a single M3 Ultra.
- 18:34 - CLI, Python, Swift, and C++ APIs
How to use MLX's fine-grained Python, Swift, and C++ APIs for distributed inference — initializing a distributed group, sharding models with tensor parallelism, using low-level all_reduce primitives, and leveraging JACCL standalone for non-ML distributed workloads.
- 20:45 - Next steps
Summary of the full distributed stack — from RDMA over Thunderbolt to MLX and MLX LM — and next steps including the companion session on local agentic AI, documentation on custom parallelism strategies, and the built-in MLX LM distributed server.