跳转至

分布式训练场景下ModelEMA的优化

本文写于2024年9月9号晚22点

一、前言

有一天白天喝茶饮料喝多了,怎么也睡不着。于是尝试想一想ModelEMA的分布式优化版本,由于不满足于这种系统实现上的优化,手推公式一顿近似化简想把ModelEMA的行为放到优化器中,结果第二天一早实现后Loss NaN。

就这样拖了一周,再到后面重新思考ModelEMA的分布式实现,刚好看到torch官方zero2的源码,所以将其参数平均分配到各个rank的算法移植过来,再加上将不连续内存合并的优化,效果真的很惊艳。

最开始的想法启发于zero2,即针对EMA运算,每个计算卡分别存储和计算模型整体参数的一小部分,在模型评估阶段再对所有参数进行all gather操作。

以8卡分布式数据并行为例,记原始yolov5 EMA更新操作单个training step 55ms,开启分布式EMA后达到7ms(接近55/8),开启内存合并后降为0.5ms,速度提升了100多倍!

由于小模型参数量不大,所以没有计算节省的参数量。

本文相关代码开源在

https://github.com/thb1314/distributed_modelema

二、实现原理

ModelEMA可以有效地缓解过拟合问题并提升泛化性,在自监督学习任务,比如MoCO、DINO、BYOL等,ModelEMA在该类训练任务中起重要作用。

ModelEMA的更新过程如下

\[ \zeta_{t+1}=\rho \zeta_t+(1-\rho) \theta_t \]

其中\(\zeta\)表示EMA模型参数(初始化为模型参数),\(\theta_t\)表示optimizer更新后的模型参数。

启发于ZERO2,针对上述运算,假设总共的计算单元有world_size个,我们可以将原版模型的参数分为world_size组,每组参数分别在各自的计算单元中执行EMA操作,最后在有需要的时候再将参数all gather到所有机器中。

那么什么时候是“有需要的时候”呢?即“需要采用EMA后的模型做评估的时候”

三、实现细节

若要实现分布式版本EMA,参数分配算法和参数同步直观重要。本文提出的实现,首先将state_dict中的参数转换为parameter中的参数,接着采用参数分配算法将parameter划分到各个rank中,然后在有必要的时候执行EMA同步操作,同时可以有选择性地采用Tensor合并算法对EMA过程进行优化。

3.1 参数接口转换

由于state_dict中参数都是detach后的,如下代码片段实现将model.state_dict()中的参数转换为(来不及解释了,看如下源码吧)

# 收集原模型 state_dict
self._ori_state_dict:Dict[str, nn.Parameter] = de_parallel(model).state_dict()
# replace to original parameter
# 原模型 state_dict 与 pamameter和buffer中的参数data_ptr相同
ori_param_dict = {param.data_ptr():param for param in de_parallel(model).parameters()}
ori_param_dict.update({buffer.data_ptr():buffer for buffer in de_parallel(model).buffers()})

# 统计不需要ema的参数
self._no_need_ema_dict = dict()
for name, param in self._ori_state_dict.items():
    if param.data_ptr() in ori_param_dict and param.dtype.is_floating_point:
        self._ori_state_dict[name] = ori_param_dict[param.data_ptr()]
        else:
            self._no_need_ema_dict[name] = param
            for rm_name in self._no_need_ema_dict:
                self._ori_state_dict.pop(rm_name)

3.2 参数分配算法

参考ZeroRedundancyOptimizer的实现,partition_parameters 方法会将参数进行分区,根据参数大小(而不是使用顺序)以排序贪婪(sorted-greedy)算法来对优化器状态进行分片,在每个rank中打包一些参数,这样每个参数都属于一个rank,不在ranks之间划分。分区是任意的,可能与参数注册或使用顺序不匹配。这是为了确保每个rank具有几乎相同大小的显存。

def partition_parameters(self) -> List[Dict[str, nn.Parameter]]:
    r"""
        Partitions parameters across distributed data parallel ranks.

        Returns:
            a list of ``param_groups`` (which is a list of dict) where each
            element of the list contains the param_groups for a rank. Element 0
            corresponds to rank 0, etc. We need all the ranks for the broadcast
            inside ``get_model_state_dict()``.
        """
    if len(self._partition_parameters_cache) == 0:
        self._partition_parameters_cache = [dict() for _ in range(self.world_size)]
        # 生成一个数组,用来记录每个rank的大小,一共有world size个rank
        sizes = [0] * self.world_size

        # 遍历参数组
        param_lists: List[List[Tuple[str, nn.Parameter]]] = [list() for _ in range(self.world_size)]
            for name, param in self._ori_state_dict.items():
                # add this param to rank with smallest size
                # 找到最小的那个rank
                rank = sizes.index(min(sizes))
                # 把参数放到最小rank之中
                param_lists[rank].append((name, param))
                # 增加rank的大小
                sizes[rank] += param.numel()

                # 遍历list
                for rank, param_tuple_list in enumerate(param_lists):
                    for name, param in param_tuple_list:
                        self._partition_parameters_cache[rank][name] = param

                        return self._partition_parameters_cache

这里就分区好了,最终返回一个param_groups 的列表(这是一个dict列表),列表的每个元素都包含一个rank的param_groups,比如元素0对应于rank 0,每个rank的group的参数有差不多大小。

3.3 同步EMA参数

需要注意的是get_model_state_dict需要每个rank都得执行,通过判断参数是在当前rank下还是其他rank下来获取源头的rank地址,之后执行dist.broadcast来广播tensor到其他rank。

    def get_model_state_dict(self, strict=True):
        ema_state_dict = OrderedDict()
        ori_state_dict = OrderedDict()
        handles = []

        for key in self._ori_state_dict:
            if key in self._no_need_ema_dict:
                if not strict:
                    continue
                # adopt its original reference
                ema_state_dict[key] = self._no_need_ema_dict[key]
                ori_state_dict[key] = self._no_need_ema_dict[key]
            elif key in self._ori_state_dict:
                # send parameters
                if key in self._cur_rank_param:
                    param_value = self._cur_rank_param[key]
                    ema_state_dict[key] = param_value
                    ori_state_dict[key] = self._ori_cur_rank_param[key].detach().clone()
                    if self.world_size > 1:
                        handles.append(dist.broadcast(tensor=param_value.data, src=self.rank, group=self.group, async_op=True))
                elif key in self._other_rank_param:
                    param_value = self._other_rank_param[key]
                    src_rank = self._other_param2rank[param_value]
                    ori_state_dict[key] = param_value.detach().clone()
                    param_value = param_value.detach().clone()
                    ema_state_dict[key] = param_value
                    if self.world_size > 1:
                        handles.append(dist.broadcast(tensor=param_value.data, src=src_rank, group=self.group, async_op=True))
                else:
                    raise RuntimeError(f"{key} not in parameter list")
            else:
                raise RuntimeError(f"{key} not in parameter list")

        _ = list(map(lambda x: x.wait(), handles))
        return ema_state_dict, ori_state_dict

这里需要注意的是,broadcast操作是异步的。

3.4 Tensor合并

如果设置了parameters_as_bucket_view,则调用建立若干buffer。同样设备上同样rank的张量合并一个buffer,这里需要注意的是个别处理的字节对齐问题,本文实现的是8字节对齐版本。

        if parameters_as_bucket_view and self._ori_cur_rank_param:
            device = next(iter(self._ori_cur_rank_param.values())).device
            dtype = next(iter(self._ori_cur_rank_param.values())).dtype
            buffer_size = 0
            # 8 bytes aligned
            grid_size = 8 // item_size_dict[dtype]

            # 统计参数排序信息
            for key, param in self._ori_cur_rank_param.items():
                offset_start = buffer_size
                buffer_size += (param.numel() + grid_size - 1) // grid_size * grid_size
                self._bucket_data_info_dict[key] = {
                    "offset_start": offset_start,
                    "offset_end": buffer_size,
                    "real_size": param.numel()
                }
               # 初始化 bucket 参数大小
            bucket = nn.Parameter(torch.empty(buffer_size, dtype=dtype, device=device), requires_grad=False)
            self._ori_cur_rank_bucket = bucket

            # 根据偏移 copy 原始数据
            for key, param in self._ori_cur_rank_param.items():
                data_info_dict = self._bucket_data_info_dict[key]
                offset = data_info_dict['offset_start']
                offset_next = offset + data_info_dict['real_size']
                bucket[offset:offset_next].copy_(param.data.flatten(), non_blocking=False)
                param.data = bucket[offset:offset_next].view_as(param.data)

        self._cur_rank_param:Dict[str, nn.Parameter] = dict()
        self._cur_rank_bucket:Optional[nn.Parameter] = None
        if self._ori_cur_rank_bucket is not None:
            self._cur_rank_bucket = self._ori_cur_rank_bucket.detach().clone()

        # 如果设置了bucket则将param.data指向buffer中的区域,从而param的更新会自动更新到buffer
        for name, param in self._ori_cur_rank_param.items():
            param = param.detach().clone()
            self._cur_rank_param[name] = param
            param.requires_grad_(False)
            if self._cur_rank_bucket is not None:
                data_info_dict = self._bucket_data_info_dict[name]
                offset = data_info_dict["offset_start"]
                offset_next = offset + data_info_dict["real_size"]
                param.data = self._cur_rank_bucket[offset:offset_next].view_as(param.data)

开启Tensor合并后,大大提高的程序的并行化程度,从而获得极致的加速效果。

总结

ModelEMA的分布式实现不算很难,举一反一,很朴素的想法。

尽管思路简单,但是效果真的很惊艳。

参考文献与链接


最后更新: September 17, 2024
创建日期: September 17, 2024

评论