Flex Checkpoint工作记录
1. Flex Checkpoint关键组件
1.1 reshard_sharded_state_dict
def reshard_sharded_state_dict(
src_sharded_state_dict: ShardedStateDict,
dst_sharded_state_dict: ShardedStateDict,
process_group: Group,
coordinator_rank: int | None = 0,
offload: bool | None = False,
aoa_config: dist[str, list[str]] | None = None,
) -> None:
local_src_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in src_sharded_state_dict.items()
}
global_src_state_dict_shard_info = []
dist.all_gather_object(
global_src_state_dict_shard_info,
local_src_state_dict_shard_info,
group=process_group,
)
src_state_dict_shard_info = {}
for rank_shard_info in global_src_state_dict_shard_info:
for key, tensor_shard_info in rank_shard_info.items():
if key not in src_state_dict_shard_info:
src_state_dict_shard_info[key] = []
src_state_dict_shard_info[key].append(tensor_shard_info)
# check validity
check_src_state_dict_validity(src_state_dict_shard_info)
local_dst_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in dst_sharded_state_dict.items()
}
global_dst_state_dict_shard_info = []
dist.all_gather_object(
global_dst_state_dict_shard_info,
local_dst_state_dict_shard_info,
group=process_group,
)
dst_state_dict_shard_info = {}
for rank_shard_info in global_dst_state_dict_shard_info:
for key, tensor_shard_info in rank_shard_info.items():
if key not in dst_state_dict_shard_info:
dst_state_dict_shard_info[key] = []
dst_state_dict_shard_info[key].append(tensor_shard_info)
# check validity
check_dst_state_dict_validity(dst_state_dict_shard_info)
check_src_dst_state_dict_validity(
src_state_dict_shard_info, dst_state_dict_shard_info
)
# build metadata
state_dict_metadata = {
tensor_name: [
LocalTensorMetadata(
global_offset=shard_info[0],
local_shape=shard_info[1],
dtype=shard_info[2],
)
for shard_info in shard_infos
]
for tensor_name, shard_infos in src_state_dict_shard_info.items()
}
virtual_file_path = f"vfile_{dist.get_rank()}"
local_storage_metadata = {
LocalTensorIndex(
tensor_key=value.key,
global_offset=value.global_offset,
): virtual_file_path
for key, value in src_sharded_state_dict.items()
}
global_storage_metadata: list[dict[LocalTensorIndex, str]] = []
dist.all_gather_object(
global_storage_metadata,
local_storage_metadata,
group=process_group,
)
# Merge storage metadata
storage_metadata: dict[LocalTensorIndex, str] = {}
for rank_storage_metadata in global_storage_metadata:
storage_metadata.update(rank_storage_metadata)
# Prepare metadata for loading
metadata = Metadata(
state_dict_metadata=state_dict_metadata,
storage_metadata=storage_metadata,
flat_mapping=None,
)
# Extract local tensors
src_state_dict = {
key: value.local_tensor for key, value in src_sharded_state_dict.items()
}
dst_state_dict = dst_sharded_state_dict
# reshard using _load_state_dict
_load_state_dict(
target_state_dict=dst_state_dict,
source_state_dict={virtual_file_path: src_state_dict},
metadata_list=[metadata],
coordinator_rank=coordinator_rank,
process_group=process_group,
offload=offload,
)
这个函数实际是为了构建reshard过程中需要的metadata,实际的reshard操作,在load_state_dict里面。state_dict_metadata 和 storage_metadata 最终都包含了所有 rank 的分片信息,是全局的完整信息。
这里使用virtual_file_path是因为此时实际的数据已经可以取到,即每个rank上local_tensor的实际值,无需再从文件中读取,这么做是为了整个格式上的对齐。
1.1.1 全局信息的构建过程
state_dict_metadata 的构建,state_dict_metadata用来保存Tensor的全局元数据信息
# 步骤1:每个 rank 收集自己的分片信息
local_src_state_dict_shard_info = {
key: (
value.global_offset,
value.local_shape,
str(value.local_tensor.dtype).split(".")[-1],
value.global_shape,
value.is_flattened,
)
for key, value in src_sharded_state_dict.items()
}
# 步骤2:全局收集所有 rank 的信息
global_src_state_dict_shard_info = []
dist.all_gather_object(
global_src_state_dict_shard_info,
local_src_state_dict_shard_info,
group=process_group,
)
# 结果:每个 rank 都有所有 rank 的信息
global_src_state_dict_shard_info = [
# rank 0 的信息
{"linear.weight": ((0, 0), (256, 512), "float32", (1024, 512), False)},
# rank 1 的信息
{"linear.weight": ((256, 0), (256, 512), "float32", (1024, 512), False)},
# rank 2 的信息
{"linear.weight": ((512, 0), (256, 512), "float32", (1024, 512), False)},
# rank 3 的信息
{"linear.weight": ((768, 0), (256, 512), "float32", (1024, 512), False)},
]
# 步骤3:重组为按张量分组的全局信息
src_state_dict_shard_info = {
"linear.weight": [
((0, 0), (256, 512), "float32", (1024, 512), False), # rank 0
((256, 0), (256, 512), "float32", (1024, 512), False), # rank 1
((512, 0), (256, 512), "float32", (1024, 512), False), # rank 2
((768, 0), (256, 512), "float32", (1024, 512), False), # rank 3
]
}
# 步骤4:构建全局的 state_dict_metadata
state_dict_metadata = {
"linear.weight": [
LocalTensorMetadata(global_offset=(0, 0), local_shape=(256, 512), dtype="float32"), # rank 0
LocalTensorMetadata(global_offset=(256, 0), local_shape=(256, 512), dtype="float32"), # rank 1
LocalTensorMetadata(global_offset=(512, 0), local_shape=(256, 512), dtype="float32"), # rank 2
LocalTensorMetadata(global_offset=(768, 0), local_shape=(256, 512), dtype="float32"), # rank 3
]
}
storage_metadata 的构建,storage_metadata 用来保存Tensor实际数据保存的位置信息
# 步骤1:每个 rank 构建自己的存储映射
virtual_file_path = f"vfile_{dist.get_rank()}"
local_storage_metadata = {
LocalTensorIndex(
tensor_key=value.key,
global_offset=value.global_offset,
): virtual_file_path
for key, value in src_sharded_state_dict.items()
}
# rank 0 的本地映射
local_storage_metadata = {
LocalTensorIndex("linear.weight", (0, 0)): "vfile_0",
}
# 步骤2:全局收集所有 rank 的存储映射
global_storage_metadata: list[dict[LocalTensorIndex, str]] = []
dist.all_gather_object(
global_storage_metadata,
local_storage_metadata,
group=process_group,
)
# 结果:每个 rank 都有所有 rank 的存储映射
global_storage_metadata = [
# rank 0 的映射
{LocalTensorIndex("linear.weight", (0, 0)): "vfile_0"},
# rank 1 的映射
{LocalTensorIndex("linear.weight", (256, 0)): "vfile_1"},
# rank 2 的映射
{LocalTensorIndex("linear.weight", (512, 0)): "vfile_2"},
# rank 3 的映射
{LocalTensorIndex("linear.weight", (768, 0)): "vfile_3"},
]
# 步骤3:合并为全局的 storage_metadata
storage_metadata: dict[LocalTensorIndex, str] = {}
for rank_storage_metadata in global_storage_metadata:
storage_metadata.update(rank_storage_metadata)
# 最终的全局 storage_metadata
storage_metadata = {
LocalTensorIndex("linear.weight", (0, 0)): "vfile_0", # rank 0
LocalTensorIndex("linear.weight", (256, 0)): "vfile_1", # rank 1
LocalTensorIndex("linear.weight", (512, 0)): "vfile_2", # rank 2
LocalTensorIndex("linear.weight", (768, 0)): "vfile_3", # rank 3
}
1.1.2 为什么需要全局信息?
重分片需要完整的分片信息
# 重分片过程:
# 源:4 个分片 -> 目标:2 个分片
# 需要知道所有源分片的信息才能正确重分片
source_shards = [
((0, 0), (256, 512)), # rank 0
((256, 0), (256, 512)), # rank 1
((512, 0), (256, 512)), # rank 2
((768, 0), (256, 512)), # rank 3
]
# 目标分片需要从多个源分片组合数据
target_shard_0 = combine(source_shards[0], source_shards[1]) # 需要 rank 0 和 rank 1 的数据
target_shard_1 = combine(source_shards[2], source_shards[3]) # 需要 rank 2 和 rank 3 的数据
数据访问需要全局映射
# _load_state_dict 需要知道:
# 1. 每个分片在哪里(storage_metadata)
# 2. 每个分片的形状和位置(state_dict_metadata)
def load_shard(tensor_name, global_offset):
# 根据全局信息找到对应的分片
index = LocalTensorIndex(tensor_name, global_offset)
file_path = storage_metadata[index] # "vfile_0"
# 从对应的数据源获取数据
if file_path in source_state_dict:
return source_state_dict[file_path][tensor_name]
验证需要全局视图
# 验证分片完整性需要全局信息
def validate_completeness():
# 检查是否所有分片都存在
expected_shards = [
(0, 0), (256, 0), (512, 0), (768, 0)
]
for offset in expected_shards:
index = LocalTensorIndex("linear.weight", offset)
if index not in storage_metadata:
raise ValueError(f"Missing shard at {offset}")
1.2 utils相关工具组件总结
1. 索引转换工具
ravel_index(indices, shape)
def ravel_index(indices, shape):
idx = 0
for i, dim in zip(indices, shape):
idx = idx * dim + i
return idx
作用:将多维索引转换为线性索引(行优先顺序)
详细解释:
# 例子:shape = (2, 3, 4)
# 多维索引 (1, 2, 3) 转换为线性索引
# 计算过程:
# i=0: idx = 0 * 2 + 1 = 1
# i=1: idx = 1 * 3 + 2 = 5
# i=2: idx = 5 * 4 + 3 = 23
# 结果:线性索引 = 23
# 验证:在2×3×4的张量中,位置(1,2,3)的线性索引确实是23
应用场景:
- 将多维张量的位置转换为内存中的线性地址
- 在分片计算中定位元素在全局张量中的位置
unravel_index(idx, shape)
def unravel_index(idx, shape):
indices = []
for dim in reversed(shape):
indices.append(idx % dim)
idx //= dim
return tuple(reversed(indices))
作用:将线性索引转换为多维索引
详细解释:
# 例子:shape = (2, 3, 4), idx = 23
# 线性索引 23 转换为多维索引
# 计算过程(从右到左):
# dim=4: indices.append(23 % 4 = 3), idx = 23 // 4 = 5
# dim=3: indices.append(5 % 3 = 2), idx = 5 // 3 = 1
# dim=2: indices.append(1 % 2 = 1), idx = 1 // 2 = 0
# 结果:多维索引 = (1, 2, 3)
应用场景:
- 从内存地址恢复多维张量的位置
- 在分片重建时确定元素在全局张量中的坐标
2. 切片计算工具
minimal_nd_slice(shape, flat_start, flat_end)
def minimal_nd_slice(shape, flat_start, flat_end):
start_idx = unravel_index(flat_start, shape)
end_idx = unravel_index(flat_end - 1, shape)
min_slices = []
for axis in range(len(shape)):
if axis == 0:
s = start_idx[axis]
e = end_idx[axis] + 1
else:
if start_idx[axis - 1] == end_idx[axis - 1]:
s = min(start_idx[axis], end_idx[axis])
e = max(start_idx[axis], end_idx[axis]) + 1
else:
s = 0
e = shape[axis]
min_slices.append((s, e))
return min_slices, start_idx, end_idx
作用:计算包含给定扁平化范围的最小N维切片
详细解释:
# 例子:shape = (4, 3), flat_start = 5, flat_end = 8
# 扁平化范围 [5, 8) 转换为最小切片
# 计算过程:
# start_idx = unravel_index(5, (4, 3)) = (1, 2)
# end_idx = unravel_index(7, (4, 3)) = (2, 1)
# 对于axis=0:
# s = 1, e = 2 + 1 = 3
# 对于axis=1:
# start_idx[0] = 1, end_idx[0] = 2, 不相等
# 所以 s = 0, e = 3
# 结果:min_slices = [(1, 3), (0, 3)]
# 这表示需要切片 [1:3, 0:3]
应用场景:
- 将扁平化的索引范围转换为最优的多维切片
- 减少数据传输量,提高效率
flat_range_in_min_slice(shape, min_slices, flat_start, flat_end)
def flat_range_in_min_slice(shape, min_slices, flat_start, flat_end):
min_starts = tuple(s[0] for s in min_slices)
min_flat_start = ravel_index(min_starts, shape)
return flat_start - min_flat_start, flat_end - min_flat_start
作用:计算在最小切片中的相对扁平化范围
详细解释:
# 例子:shape = (4, 3), min_slices = [(1, 3), (0, 3)]
# flat_start = 5, flat_end = 8
# 计算过程:
# min_starts = (1, 0)
# min_flat_start = ravel_index((1, 0), (4, 3)) = 3
# 相对范围 = (5 - 3, 8 - 3) = (2, 5)
# 这表示在最小切片内的相对位置
应用场景:
- 计算在切片内的相对偏移
- 用于精确的数据提取和复制
3. 状态字典检查工具
is_sharded_state_dict(o)
def is_sharded_state_dict(o):
if not isinstance(o, dict):
return False
values = list(o.values())
has_sharded_weight = any(isinstance(v, ShardedWeight) for v in values)
if has_sharded_weight:
if not all(isinstance(v, ShardedWeight) for v in values):
raise TypeError(
"All values must be ShardedWeight if any value is ShardedWeight."
)
return True
else:
return False
作用:检查字典是否为分片状态字典
详细解释:
# 检查规则:
# 1. 必须是字典类型
# 2. 如果任何值是ShardedWeight,则所有值都必须是ShardedWeight
# 3. 不允许混合类型
# 例子:
valid_dict = {
"weight": ShardedWeight(...),
"bias": ShardedWeight(...)
} # 返回 True
invalid_dict = {
"weight": ShardedWeight(...),
"bias": paddle.Tensor(...)
} # 抛出TypeError
应用场景:
- 验证检查点格式的正确性
- 确保状态字典的一致性
4. 重叠区域计算工具
get_overlap_region(desc_offset, desc_shape, shard_offset, shard_shape)
def get_overlap_region(desc_offset, desc_shape, shard_offset, shard_shape):
ndim = len(desc_offset)
overlap_offset = []
overlap_shape = []
desc_starts = []
shard_starts = []
for i in range(ndim):
desc_lo = desc_offset[i]
desc_hi = desc_offset[i] + desc_shape[i]
shard_lo = shard_offset[i]
shard_hi = shard_offset[i] + shard_shape[i]
# overlap
lo = max(desc_lo, shard_lo)
hi = min(desc_hi, shard_hi)
if lo >= hi:
return False, None, None, None, None
overlap_offset.append(lo)
overlap_shape.append(hi - lo)
desc_starts.append(lo - desc_lo)
shard_starts.append(lo - shard_lo)
return True, overlap_offset, overlap_shape, desc_starts, shard_starts
作用:计算两个分片之间的重叠区域
详细解释:
# 例子:2D张量
# desc: offset=(0,0), shape=(4,4)
# shard: offset=(2,2), shape=(4,4)
# 计算过程:
# 维度0:
# desc_lo=0, desc_hi=4, shard_lo=2, shard_hi=6
# lo = max(0,2) = 2, hi = min(4,6) = 4
# overlap_offset[0] = 2, overlap_shape[0] = 2
# desc_starts[0] = 2-0 = 2, shard_starts[0] = 2-2 = 0
# 维度1:
# desc_lo=0, desc_hi=4, shard_lo=2, shard_hi=6
# lo = max(0,2) = 2, hi = min(4,6) = 4
# overlap_offset[1] = 2, overlap_shape[1] = 2
# desc_starts[1] = 2-0 = 2, shard_starts[1] = 2-2 = 0
# 结果:
# 重叠区域:offset=(2,2), shape=(2,2)
# 在desc中的起始:(2,2)
# 在shard中的起始:(0,0)
应用场景:
- 计算不同分片策略间的数据重叠
- 为数据复制提供精确的范围信息
5. 分片数据复制工具
assign_sharded_slice(src_desc, src_shard, dst_desc, dst_shard)
def assign_sharded_slice(src_desc, src_shard, dst_desc, dst_shard):
# 1. 计算源分片的重叠区域
src_has, _, overlap_shape, src_desc_starts, src_shard_starts = (
get_overlap_region(
src_desc.global_offset,
src_desc.local_shape,
src_shard.global_offset,
src_shard.local_shape,
)
)
# 2. 计算目标分片的重叠区域
dst_has, _, overlap_shape2, dst_desc_starts, dst_shard_starts = (
get_overlap_region(
dst_desc.global_offset,
dst_desc.local_shape,
dst_shard.global_offset,
dst_shard.local_shape,
)
)
# 3. 验证重叠区域一致性
assert src_has or dst_has, "no overlap!"
assert overlap_shape == overlap_shape2, "overlap shape mismatch!"
# 4. 执行数据复制
axes = list(range(len(overlap_shape)))
src_tensor_slice = paddle.slice(
src_shard.local_tensor,
axes=axes,
starts=src_shard_starts,
ends=[s + o for s, o in zip(src_shard_starts, overlap_shape)],
)
dst_tensor_slice = paddle.slice(
dst_shard.local_tensor,
axes=axes,
starts=dst_shard_starts,
ends=[s + o for s, o in zip(dst_shard_starts, overlap_shape)],
)
paddle.assign(src_tensor_slice, dst_tensor_slice)
作用:在不同分片间复制重叠数据
详细解释:
# 完整流程:
# 1. 计算源分片与描述符的重叠区域
# 2. 计算目标分片与描述符的重叠区域
# 3. 验证两个重叠区域的一致性
# 4. 从源分片提取重叠部分
# 5. 复制到目标分片
# 例子:从tp2转换到tp4
# src_desc: 描述tp2时的分片布局
# src_shard: tp2时的实际数据
# dst_desc: 描述tp4时的分片布局
# dst_shard: tp4时的目标数据
# 函数会:
# 1. 找到tp2和tp4分片的重叠部分
# 2. 将tp2的重叠数据复制到tp4的对应位置
应用场景:
- 分片策略转换时的数据重分布
- 检查点加载时的数据恢复
6. 信息合并工具
merge_shard_info_list(list_of_dicts)
def merge_shard_info_list(list_of_dicts):
merged = defaultdict(list)
for info in list_of_dicts:
for k, v in info.items():
merged[k].extend(v)
return dict(merged)
作用:合并多个分片信息字典
详细解释:
# 例子:
list_of_dicts = [
{"param1": [info1, info2]},
{"param1": [info3], "param2": [info4]},
{"param2": [info5, info6]}
]
# 合并结果:
merged = {
"param1": [info1, info2, info3],
"param2": [info4, info5, info6]
}
应用场景:
- 收集所有rank的分片信息
- 构建全局的分片视图
7. 描述符构建工具
build_shard_desc(val)
def build_shard_desc(val):
return ShardedWeightDesc(
key=val.key,
local_shape=tuple(val.local_shape),
global_shape=tuple(val.global_shape),
global_offset=tuple(val.global_offset),
)
作用:从ShardedWeight构建ShardedWeightDesc
详细解释:
# 转换过程:
# 输入:ShardedWeight对象(包含实际数据)
# 输出:ShardedWeightDesc对象(仅包含元数据)
# 例子:
sharded_weight = ShardedWeight(
key="linear.weight",
local_tensor=paddle.Tensor(...), # 实际数据
local_shape=(1024, 512),
global_shape=(1024, 2048),
global_offset=(0, 0)
)
# 转换为:
shard_desc = ShardedWeightDesc(
key="linear.weight",
local_shape=(1024, 512),
global_shape=(1024, 2048),
global_offset=(0, 0)
)
应用场景:
- 提取分片权重的元数据信息
- 用于分片信息的传输和存储
1.3 sharded_tensor的关键组件
class ShardedTensor:
"""
Represents a local shard of a distributed tensor parameter.
Args:
key (str): The name of the parameter.
local_tensor (Tensor): The local shard of the parameter.
local_shape (Tuple[int, ...]): The shape of the local shard.
global_shape (Tuple[int, ...]): The global logical shape of the parameter.
global_offset (Tuple[int, ...]): The offset of the local shard in the global parameter.
is_flattened (bool, optional): Whether the parameter has been flattened (used in sharding_v2 scenarios). Default is False.
flattened_range (slice, optional): If the parameter is flattened, this indicates the index range of the actual local shard within the local_tensor.
"""
def __init__(
self,
key: str,
local_tensor: Tensor,
local_shape: tuple[int, ...],
global_shape: tuple[int, ...],
global_offset: tuple[int, ...],
is_flattened: bool = False,
flattened_range: slice | None = None,
) -> None:
self.key = key
self.local_tensor = local_tensor
self.local_shape = local_shape
self.global_shape = global_shape
self.global_offset = global_offset
self.is_flattened = is_flattened
self.flattened_range = flattened_range
def __str__(self) -> str:
"""Returns a formatted string representation of the sharded tensor."""
return (
f"ShardedTensor(\n"
f" key={self.key},\n"
f" local_tensor={type(self.local_tensor).__name__}(shape={self.local_tensor.shape}),\n"
f" local_shape={self.local_shape},\n"
f" global_shape={self.global_shape},\n"
f" global_offset={self.global_offset},\n"
f" flattened_range={self.flattened_range}\n"
f")"
)
def shard_weight(
key: str,
weight: Tensor,
axis: int,
group: Group,
) -> ShardedTensor:
"""Creates a ShardedTensor by splitting the input tensor along a specified axis.
Args:
key: Unique identifier for the tensor.
weight: The input tensor to be sharded.
axis: The axis along which to shard the tensor.
group: The process group used for distributed communication.
Returns:
A ShardedTensor representing the local portion of the global tensor.
"""
if axis < 0 or axis >= len(weight.shape):
raise ValueError(
f"Shard axis {axis} is invalid for tensor with shape {weight.shape}"
)
# Get hybrid communication group and rank information
hcg = fleet.get_hybrid_communicate_group()
current_rank = group.rank
world_size = group.nranks
# Calculate shapes and offsets
local_shape = weight.shape
global_shape = deepcopy(local_shape)
global_shape[axis] = local_shape[axis] * world_size
global_shape = tuple(global_shape)
local_shape = tuple(local_shape)
global_offset = [0] * len(global_shape)
if world_size > 1:
global_offset[axis] = current_rank * local_shape[axis]
global_offset = tuple(global_offset)
return ShardedTensor(
key=key,
local_tensor=weight,
local_shape=local_shape,
global_shape=global_shape,
global_offset=global_offset,
)
def build_sharded_state_dict(
state_dict: dict[str, Tensor],
shard_rules: dict[str, int] | None = None,
prefix: str = "",
) -> dict[str, ShardedTensor]:
"""Converts a regular state dict to a sharded state dict based on sharding rules.
Args:
state_dict: The original state dictionary containing tensors
shard_rules: Dictionary mapping tensor names to their sharding axes.
If None, treated as empty dict (no tensor parallelism).
prefix: Optional prefix to prepend to all tensor keys
Returns:
Dictionary with the same keys as input but values converted to ShardedTensor
or regular Tensor based on sharding rules.
Note:
Tensors not in shard_rules will be wrapped as non-sharded ShardedTensors.
"""
shard_rules = shard_rules or {}
sharded_state_dict = {}
for key, tensor in state_dict.items():
full_key = f"{prefix}{key}" if prefix else key
if key in shard_rules:
# Apply tensor parallelism sharding
sharded_state_dict[full_key] = (
make_tp_sharded_tensor_for_checkpoint(
key=full_key,
tensor=tensor,
tensor_parallel_axis=shard_rules[key],
)
)
else:
# Create regular sharded tensor (non-tensor-parallel)
sharded_state_dict[full_key] = make_replicated_sharded_tensor(
key=full_key,
tensor=tensor,
)
return sharded_state_dict
主要是ShardedTensor类和build_sharded_state_dict、shard_weight两个接口,ShardedTensor主要是作为后续shard_state_dict中的基础单元,即{key:ShardedTensor},原来版本是普通的Tensor,而现在的ShardedTensor携带了Tensor切分的信息,主要是local_shape、global_shape、global_offset则可以据此对local_tensor进行全局tensor的重建,再对齐进行reshard。build_sharded_state_dict是在普通的state_dict的基础上,对于需要做分布式处理的(即shard)tensor进行切分标记,将tensor转化为ShardedTensor,make_tp_sharded_tensor_for_checkpoint其实就是做mp参数并行,里面调用的就是shard_weight接口,返回一个ShardedTensor;对于不需要切分的,也要用make_replicated_sharded_tensor处理,将其转化为统一的ShardedTensor类,这部分处理无需调用shard_weight,直接返回ShardedTensor,local_shape=global_shape,因为每个rank上保存的这部分数据都一样。而shard_weight,传入进来的tensor,对应切分的那个维度的数据,每个rank都不一样(对于shard组来说),因此将每个rank上该tensor的对应维度的shape加起来,即可得到global_shape,从而构造出具有分布式信息的Tensor。
1.4 load_state_dict的关键组件
1.4.1 get_rank_to_files(与原来的一致)
ef get_rank_to_files(
metadata_list,
local_data_files,
state_dict,
process_group,
use_dist,
mw_name_compatibility=True,
):
"""
Get the mapping of rank to its accessible files.
"""
# The necessary files to be read
tensor_key_list = []
necessary_files = []
mw_name_compatibility_mapping = {}
for metadata in metadata_list:
for local_tensor_index, file_name in metadata.storage_metadata.items():
assert (
local_tensor_index not in tensor_key_list
), f"Duplicate tensor_key:{local_tensor_index} found. Check whether the metadata."
tensor_key_list.append(local_tensor_index.tensor_key)
if local_tensor_index.tensor_key in state_dict:
necessary_files.append(file_name)
all_necessary_files = []
if use_dist:
paddle.distributed.all_gather_object(
all_necessary_files, necessary_files, process_group
)
else:
all_necessary_files.append(necessary_files)
global_necessary_files = [
file for files in all_necessary_files for file in files
]
global_necessary_files_set = set(global_necessary_files)
if len(global_necessary_files_set) <= 0:
logger.warning(
"No necessary data files found in the checkpoint directory. Please check the metadata."
)
missing_keys = set(state_dict.keys())
return {}, missing_keys, mw_name_compatibility_mapping
# allgather all accessible files
global_data_files = []
if use_dist:
paddle.distributed.all_gather_object(
global_data_files, local_data_files, process_group
)
else:
global_data_files.append(local_data_files)
tmp = []
for files in global_data_files:
tmp += files
global_data_files_set = set(tmp)
logger.debug(
f"necessary_data_files_set:{global_necessary_files_set}, global_data_files_set:{global_data_files_set}"
)
# check necessary files in global_data_files
assert (
global_data_files_set & global_necessary_files_set
== global_necessary_files_set
), f"The checkpoint files are not complete. Please check the checkpoint directory. global_data_files_set:{global_data_files_set}, necessary_data_files_set:{global_necessary_files_set}"
missing_keys = set(state_dict.keys()) - set(tensor_key_list)
if len(missing_keys) > 0:
if mw_name_compatibility:
mw_name_compatibility_mapping = _modify_mw_name_for_compatibility(
state_dict, missing_keys, tensor_key_list
)
if len(missing_keys) > 0:
logger.warning(
f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
)
else:
logger.warning(
f"Missing keys:{missing_keys}, check whether the checkpoint is complete."
)
rank_to_files = {}
for rank, need_files in enumerate(all_necessary_files):
seen = set()
unique_need_files = [
f for f in need_files if not (f in seen or seen.add(f))
]
rank_to_files[rank] = unique_need_files
logger.debug(f"mapping rank_to_files:{rank_to_files}")
return rank_to_files, missing_keys, mw_name_compatibility_mapping
根据保存的storage_metadata,遍历当前rank上的state_dict,根据local_tensor_index.tensor_key是否在state_dict中,来确定是否需要当前local_tensor_index对应的文件,如果需要就添加到necessary_files中,all_necessary_files保存的是所有rank的necessary_files,如下:
all_necessary_files = [
["0_0.distcp", "1_0.distcp"], # rank 0 需要的文件
["2_0.distcp", "3_0.distcp"], # rank 1 需要的文件
["4_0.distcp", "5_0.distcp"], # rank 2 需要的文件
["6_0.distcp", "7_0.distcp"], # rank 3 需要的文件
]
即key就是rank id,value就是该rank需要的文件列表,seen是用来去重的。
1.5 paddlenlp适配
结论:因为只有 LlamaLMHead 的分片保存规则在本文件里需要“特殊约定”,其它层要么已经在各自实现里内建了 sharded_state_dict,要么可以用默认递归收集;而 LM Head 需要显式告诉检查点系统“按哪一维切”。
为什么只有 LlamaLMHead 和 PipelinePretrainedModel 需要适配?
1.LlamaLMHead
LM Head 的权重轴不固定:
LlamaLMHead支持transpose_y和词表并行(vocab parallel)。这会改变权重逻辑形状与“被切分的维度”:当
transpose_y=True且tie_word_embeddings时,weight形状是[vocab_size, hidden_size],切分轴应为axis=0。否则通常是
[hidden_size, vocab_size],切分轴应为axis=1。代码中专门计算了
axis = 0 if self.transpose_y else 1,然后:# L2000-L2006
state_dict = self.state_dict(structured_name_prefix="")
return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)这一步确保统一检查点能正确记录“词表维度”的切分方式,便于跨并行策略重构权重。
其它模块已有分片实现或可用默认机制:
- 注意力/MLP里用的
ColumnParallelLinear、RowParallelLinear(以及对应的 Sequence Parallel 版本)在它们各自的实现里已经处理了分片参数保存;模型其他权重(如LlamaRMSNorm.weight)不涉及并行切分轴的歧义,默认递归即可。 - 词嵌入
VocabParallelEmbedding也在并行库里有自己的分布式属性与导出路径。
- 注意力/MLP里用的
LM Head 还涉及权重共享与并行输出:
tie_word_embeddings时和Embedding共享权重,且is_distributed/split_axis被设置用于张量并行。- 因此 LM Head 成为“需要显式声明切分轴”的最特殊一层,避免统一检查点在重构/重分片(如从 TP2 切换到 TP4)时出错。
2.PipelinePretrainedModel
结论:因为只有 LlamaLMHead 的分片保存规则在本文件里需要“特殊约定”,其它层要么已经在各自实现里内建了 sharded_state_dict,要么可以用默认递归收集;而 LM Head 需要显式告诉检查点系统“按哪一维切”。
LM Head 的权重轴不固定:
LlamaLMHead支持transpose_y和词表并行(vocab parallel)。这会改变权重逻辑形状与“被切分的维度”:当
transpose_y=True且tie_word_embeddings时,weight形状是[vocab_size, hidden_size],切分轴应为axis=0。否则通常是
[hidden_size, vocab_size],切分轴应为axis=1。代码中专门计算了
axis = 0 if self.transpose_y else 1,然后:# L2000-L2006
state_dict = self.state_dict(structured_name_prefix="")
return build_sharded_state_dict(state_dict, {"weight": axis}, structured_name_prefix)这一步确保统一检查点能正确记录“词表维度”的切分方式,便于跨并行策略重构权重。
其它模块已有分片实现或可用默认机制:
- 注意力/MLP里用的
ColumnParallelLinear、RowParallelLinear(以及对应的 Sequence Parallel 版本)在它们各自的实现里已经处理了分片参数保存;模型其他权重(如LlamaRMSNorm.weight)不涉及并行切分轴的歧义,默认递归即可。 - 词嵌入
VocabParallelEmbedding也在并行库里有自己的分布式属性与导出路径。
- 注意力/MLP里用的
LM Head 还涉及权重共享与并行输出:
tie_word_embeddings时和Embedding共享权重,且is_distributed/split_axis被设置用于张量并行。- 因此 LM Head 成为“需要显式声明切分轴”的最特殊一层,避免统一检查点在重构/重分片(如从 TP2 切换到 TP4)时出错。
2.对相关的分布式API添加shard_state_dict处理
2.1 VocabParallelEmbedding
2.1.1 接收的输入
文本输入
用户输入: "Hello world, how are you?"
分词(Tokenization)
分词结果: ["Hello", "world", ",", "how", "are", "you", "?"]
词汇表映射(Vocabulary Mapping)
词汇表: {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3,
"Hello": 4, "world": 5, ",": 6, "how": 7, "are": 8, "you": 9, "?": 10, ...}
映射结果: [4, 5, 6, 7, 8, 9, 10]
输入到模型为词汇ID序列
模型接收的输入: x = [4, 5, 6, 7, 8, 9, 10] (词汇ID序列)
因此,VocabParallelEmbedding接收到的输入x是[batch_size,seqlenth],即多组词汇ID序列。
2.1.2 处理输入
假设vocab_size=50000,embedding_dim=1024,即有50000个词,映射成向量用1024个特征表示,每个词对应一个1024长度的特征向量:
每个词汇ID对应矩阵中的一行:
word_id=0 -> W[0, :] = [0.1, 0.2, 0.3, ..., 0.1024]
word_id=1 -> W[1, :] = [0.5, 0.1, 0.8, ..., 0.2048]
word_id=2 -> W[2, :] = [0.3, 0.7, 0.2, ..., 0.3072]
...
word_id=499999 -> W[499999, :] = [0.9, 0.4, 0.6, ..., 0.1024]
输入为:
# 输入: x = [batch_size, seq_len] (词汇ID)
# 例如: x = [[100, 250000, 500000, 750000],
# [150, 250100, 500100, 750100]]
Vocab分割后:
GPU0: W[0:250000, :] (250000行,1024列)
GPU1: W[250000:500000, :] (250000行,1024列)
GPU2: W[500000:750000, :] (250000行,1024列)
GPU3: W[750000:1000000, :] (250000行,1024列)
并行化后的查找过程:
对于输入词汇ID,每个GPU的处理:
GPU0 (负责词汇0-249999):
- 输入ID=100: 查找 W[100, :] = [0.1, 0.2, ..., 0.1024]
- 输入ID=150: 查找 W[150, :] = [0.3, 0.4, ..., 0.1024]
- 输入ID=250000: 不在范围内,返回零向量或特殊处理
- 输入ID=500000: 不在范围内,返回零向量或特殊处理
GPU1 (负责词汇250000-499999):
- 输入ID=100: 不在范围内,返回零向量
- 输入ID=250000: 查找 W[250000, :] = [0.5, 0.6, ..., 0.1024]
- 输入ID=250100: 查找 W[250100, :] = [0.7, 0.8, ..., 0.1024]
- 输入ID=500000: 不在范围内,返回零向量
GPU2 (负责词汇500000-749999):
- 输入ID=500000: 查找 W[500000, :] = [0.9, 0.1, ..., 0.1024]
- 输入ID=500100: 查找 W[500100, :] = [0.2, 0.3, ..., 0.1024]
GPU3 (负责词汇750000-999999):
- 输入ID=750000: 查找 W[750000, :] = [0.4, 0.5, ..., 0.1024]
- 输入ID=750100: 查找 W[750100, :] = [0.6, 0.7, ..., 0.1024]
最终将每张卡的结果做allreduce合并,则得到最终结果,输出为:[batch_size, seq_len, embedding_dim]。
一开始该层权重是随机初始化的,即,每个词虽然都用向量表示,但此时是无意义的,经过训练后,相近的词embedding的数据会逐渐相似,从而在推理时,正确找到每个词的embedding。
2.2 ColumnParallelLayer与RowParallelLayer同时使用的关系
2.2.1 ColumnParallelLayer

2.2.2 RowParallelLayer

可以看到,RowParallelLayer在计算的过程中,需要把输入拆分成两列分别在两张卡上做计算,最终两张卡都得到Parital状态的数据,而如果上一层是ColumnParallel则其计算的结果刚好分配到两个设备上(即结果被按列切分),而此结果正是RowParallelLayer需要的输入,那么就无需做通信,直接继续计算最后再做allreduce即可。
2.2.3 ColumnParallelLayer与RowParallelLayer的w和bias的切分

注意,在做y=x*W^T+b的计算时,首先乘积得到的数据是[batchsize,output_size],每一行表示一个数据,而bias是分别和每一行相加,因此bias是一个一维的向量,因此,当W按列切分时,bias需要按行切分,从而保持正确的计算关系。
当添加了bias的时候,做RowParallelLayer和ColumnParallelLayer情况如下:
RowParallelLayer:

RowParallelLayer只切w,不切bias
ColumnParallelLayer:

ColumnParallelLayer切w的axis=1,切bias的axis=0
2.3 DygraphShardingOptimizerV2
核心目标
sharded_state_dict 是为了解决不同并行策略间状态转换的问题,以及V2情境下,optimizer被展开铺平的问题:
- 例如从 tp2 切换到 tp4:需要重新划分参数
- 保持数据完整性:确保参数和优化器状态正确转换
- 支持断点续训:在不同并行配置间无缝切换
实现方法总结
1. 分片信息收集阶段
# 第一步:收集当前分片策略的信息
for comm_group, buffers in comm_group_buffers.items():
for buffer in buffers:
for param_name, grad_view in buffer._sharding_param_grad_view.items():
# 记录每个参数在当前rank的分片范围
param_slice_info[param_name] = (
grad_view._param_begin, # 分片起始位置
grad_view._param_end, # 分片结束位置
)
# 记录参数的完整形状信息
param_shape_info[param_name] = (
grad_view._param.shape, # 原始形状
grad_view._param.numel().item(), # 元素总数
grad_view._index, # 分片索引
grad_view._padded_size, # 填充大小
)
目的:记录当前分片策略下每个参数如何被分配到各个rank。
2. 全局信息同步阶段
# 第二步:收集所有rank的分片信息
for comm_group, buffers in comm_group_buffers.items():
# 从当前rank收集信息
param_slice_info["sharding_rank"] = comm_group.rank
# 通过all_gather收集所有rank的信息
gathered_info = []
paddle.distributed.all_gather_object(
gathered_info, param_slice_info, group=comm_group
)
all_rank_slice_info.extend(gathered_info)
目的:让每个rank都知道完整的分片分布情况,为后续重建做准备。
3. 部分分片张量识别阶段
# 第三步:识别哪些张量是部分分片的
for param_key, tensor in optim_state_dict.items():
base_name, _ = _generate_base_static_name(param_key)
if int(tensor.numel()) > 1: # 非标量张量
begin, end = merged_slice_info[base_name]
shape_info = merged_shape_info[base_name]
# 判断是否为部分分片:分片大小 < 原始大小
if shape_info and end > begin and end - begin < shape_info[1]:
partial_tensor_names.append(base_name)
目的:区分完全分片和部分分片的张量,它们需要不同的处理策略。
4. 偏移映射计算阶段
# 第四步:计算每个rank在完整张量中的偏移位置
for tensor_name in partial_tensor_names:
offset_mapping[tensor_name] = [0] * world_size
# 记录每个rank的分片大小
for info in all_rank_slice_info:
if tensor_name in info:
begin, end = info[tensor_name]
if end > begin:
offset_mapping[tensor_name][info["sharding_rank"]] = end - begin
# 转换为累积偏移
running_total = 0
for rank in range(world_size):
current_size = offset_mapping[tensor_name][rank]
offset_mapping[tensor_name][rank] = running_total
running_total += current_size
目的:为每个rank计算其在完整张量中的起始位置,用于重建完整张量。
5. 状态字典构建阶段
# 第五步:构建分片状态字典
for param_key, tensor in optim_state_dict.items():
base_name, optim_state_type = _generate_base_static_name(param_key)
struct_name = static_to_struct[base_name]
sharded_param = model_sharded_state_dict[struct_name]
unified_name = f"{struct_name}.{optim_state_type}"
# 处理三种不同类型的张量
if int(tensor.numel()) == 1:
# 标量参数:直接保存
sharded_weight = ShardedWeight(...)
elif base_name in partial_tensor_names:
# 部分分片张量:记录在完整张量中的位置
flattened_offset = offset_mapping[base_name][sharding_rank]
sharded_weight = ShardedWeight(
flattened_range=slice(flattened_offset, flattened_offset + int(tensor.numel()))
)
else:
# 完全分片张量:当前rank拥有完整分片
sharded_weight = ShardedWeight(
flattened_range=slice(0, int(tensor.numel()))
)
目的:为每个优化器状态创建包含完整分片信息的 ShardedWeight 对象。
关键设计思想
1. 分层信息记录
# 记录三个层次的信息:
# 1. 参数级:param_slice_info - 分片范围
# 2. 形状级:param_shape_info - 完整形状
# 3. 全局级:offset_mapping - 全局偏移
2. 分类处理策略
# 三种处理策略:
# 1. 标量参数:直接保存,无需分片信息
# 2. 部分分片张量:记录在完整张量中的位置
# 3. 完全分片张量:当前rank拥有完整分片
3. 全局视角构建
# 每个rank都收集全局信息:
# 1. 所有rank的分片范围
# 2. 完整的参数形状
# 3. 全局偏移映射
2.4 SP(序列并行)
与ColumnParallel、RowParallel类似,只是维度发生在seq_len,且伴随tp(mp)使用。
2.5 关于shared_state_dict方法中structured_name_prefix为空的问题

实际上在这里会递归调用sub_layer的shared_state_dict方法,从而将当前层的name传递到sub_layer作为前缀。
3.测试Ernie中的一些问题
1.self.args.offload_optim
_offlad_optimizer导致保存的转换后的optimizer.pdopt中的动量都没保存成功
模型转换时,如dp2->dp4,offload_optimizer处理后,此时state_dice()中只有master_params和shceduler的数据,动量都被卸载到cpu上了,导致保存失败。

2.sharding4转纯dp2时,文件名不对应,无法加载

shading对应保存的文件名是model_state_shardxx,但是纯dp加载的model文件名是model_state.paparams,因此保存shading4的ckpt,而此时换成纯dp2训练时,无法正确加载ckpt文件。
3.checkpoint文件路径问题

- checkpoint文件路径修改后,paddleformers得同步更新导入得load_state_dict和save_state_dict
4.MoElayer找不到config属性

原因:

这块是因为还没定义就用了那个config的一些参数,我直接给注释掉了。
5.纯sharding出错,原因是在梯度累加时累加的数据类型有问题


注释掉的为原来的代码,然而测试了最新的paddle发现,add_已经适配了fp16和float32两个不同精度的数相加的场景,估计是当时合入的pr造成的bug,已经被修复。
6.t2(ep2)->pp4,报源ckpt加载后的数据,缺少某个参数的优化器状态
lm_head与embedding共享一份weight,因此优化器内部的优化器状态也只有一份
这会导致,在加载ckpt的时候,报错:

主要原因

在加载ckpt时,需要初始化model和opt,而使用flex_ckpt框架时,对应的init_opt中是根据model里面的每个key来创建对应的opt状态,,此时embed_layer和lm_head_layer共用同一个参数,因此优化器状态只有一份,所以导致在ckpt中找不到初始化时创建的embed_tokens,导致报错。
问题追溯:
打印出的model及其对应的value:


可以看到都指向同一个tensor,而具体实现在tie_weight:



可以看到在这里面将lm_head 直接赋值为embedding对应的tensor
为什么共用一个参数,他们也共用一份优化器状态?




以上是创建optimizer涉及到得流程,可以看到,optimizer中包含的参数,是根据params来去重的,即直接根据Tensor去重,而不是key,因此共享tensor的参数,只会有一份保留在optimizer的参数列表中,并且是第一次出现的参数。



在创建累加器时,此时只有{key:embedding_0.w_0,shared_tensor}保留下来了,所以只有embedding的优化器状态创建了,就不会再创建lm_head的了,打印出来如下:
为什么报错提示找不到embedding的优化器状态,而不是lm_head的优化器状态?

在这里,因为两者的v.local_tensor.name一致,前者被覆盖了。
问题总结:
tp2(ep2)->pp4问题总结: 遇到的问题: 在pp4 load tp2(ep2)保存的ckpt时,加载AOAEngine,调用shape_propagation函数时,未被AOA规则改写的参数会做补全映射,而此时会判断补全的这个key是否在源策略(tp2ep2)中出现过,若没出现过则会报错,而此处就报错:找不到 ernie.embed_tokens.weight.moment1_0(其实所有的embed_tokens.weight相关的优化器状态都找不到)。 原因总结: 在初始化opt的函数中即init_optmizer();会根据当前加载的model参数初始化优化器状态,每一个参数都会为其创建优化器状态,而在ernie4.5非pp的组网中,会使用tie_weight函数使得lm_head与embed相关的两个参数共享同一份tensor,而在训练tp2(ep2)创建优化器状态时,相同param.name的param,只会创建一份优化器状态,并且以第一次出现的key来创建优化器状态参数对应的名称,因此确实lm_head与embed仅仅只有一份权重才对,因此此处是需要优化init_opt部分的逻辑。 然而针对上述逻辑,最终应该是能找到embed相关的优化器状态,而找不到lm相关的优化器状态才对,经过查证,问题在于,AdamW的sharded_state_dict在创建static_to_struct_mapping映射时,未对共用同一个tensor的参数做判断,导致对于共享同一个weight的layer来说,后面layer的参数名(即key)会把前面layer的参数名给覆盖,lm_head在后面,因此覆盖了embed,导致我们在优化器看到的是只有lm_head的优化器状态。因此这里需要优化的是,dygraph_sharding_optimizer和AdamW内的sharded_state_dict函数的逻辑。 但针对ernie的pp组网,查证后发现,并未支持tie_weights操作,lm_head和embed分别独立一份weight;而非pp组网,默认一定调用tie_weights操作,因此在当前情况下,无法做tp2(ep2)->pp4的转换。
7.bias开false时,会遇到报错

主要原因是,这里直接对bias做scale,然而当bias为None时,是无法做scale的,导致出错。

做如下修改即可:

8.测tp2(ep2)->tp4,有一个参数的md5未对齐
最终总结:FLAGS_shard_bypass_dygraph_optimizer 标志位只能控制优化器本身的参数更新,但无法阻止回调函数中的直接参数操作更新。在 on_optimizer_end 阶段,OrthogonalCallback 会计算正交损失并直接更新 ernie.layers.1.mlp.gate.weight 参数,这种直接参数修改绕过了优化器控制机制,导致该参数在 save/load 转换过程中被意外更新,从而造成 MD5 校验失败。证明了我们的FlexCheckpoint框架逻辑没问题。

9. 测tp2(ep2)->pp4,有多个参数的md5未对齐

发现是因为moe模式下,开了的原因,导致moe模式下的模型,会多出一个moe_statics.e_score_correction_bias的参数。
10.tp2(ep2)转vpp4,暂时有问题,num_hidden_layer配8层,9层都不对

主要是:multi_token_pred_depth参数开启时,会在hidden_layer中多加一层MTP层,导致以下断言不支持,当前ernie4.5的vpp不支持加MTP这一层:
assert sum(weights) % actual_num_parts == 0
但是关掉multi_token_pred_depth参数后,又报p2p通信错误,看起来是ernie4.5跑vpp自身的bug:

11.tp2(ep2)转tp2(ep2)+sd2 和 dp2转tp2(ep2)+sd2时,会出现,某些优化器状态消失得现象

也是开了opitimizer_offload的原因
12.DP2转DP4

在ernie下跑会报错,因为opt没被封装,它没有inner_opt,而llama下跑不会报错,因为llama下封装了一层:
paddle.distributed.fleet.utils.mix_precision_utils.MixPrecisionOptimizer
因此需要加个判断:
inner_opt = getattr(optimizer, "_inner_opt", None)
if DygraphShardingOptimizer is not None and isinstance(inner_opt, DygraphShardingOptimizer):
local_params = optimizer._rank2params[optimizer._sharding_rank]
optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), local_params)
return
elif DygraphShardingOptimizerV2 is not None and isinstance(inner_opt, DygraphShardingOptimizerV2):
注意:ernie4.5训练时,train函数中调用的self._wrap_model是/home/ERNIE/examples/pre-training/ernie/src/trainers/pretraining_trainer.py内的方法。
13.DP2转Sharding4_V1的时候(开dp_group和sharding_group)

总共64个专家,2卡sharding的时候,只有32个专家有优化器状态,4卡sharding的时候只有16个专家有优化器状态;原因是每个rank上的experts组在训练一次后,所有专家的参数被同步了,未具体定位,但训练5步,4张卡上的16个专家参数的md5完全相同;但初始化时,4张卡上的专家参数是不同的。

每张卡都做了同样的操作,deepcopy fn,而对每个rank来说,这个fn都是同一个layer。
而DP的时候,每个rank

14.DP2转Sharding4_V2的时候
报错如下:

15.DP2转Sharding2_EP2
报错如下,主要是在sharding2_ep2转回dp2时报错,但是接续loss 1E-5对齐:

其实是开了这个的原因
16.DP2、ShardingV1、V2转TP2(EP2)_PP2
存在下面的报错:

主要原因是SequenceParallelLayer没有适配sharded_state_dict
![]()
4.测试LLAMA中的一些问题
1.路径需要更换

2.sharding__stage_1_overlap不支持

3.纯tp2时报错:

AdamW需要适配sharded_state_dict
4.纯DP下,fleet显示没有正确初始化

报错的原因是,在纯dp的情况下比较特殊,要开unified_checkpoint这参数,才能用fleet.init初始化,不然用的动半的初始化,这样的话self._hcg没有初始化,就不能调用get_hybrid_communicate_group,加上就好了。

5.纯dp会hang住
现象:

原因:





在调用check_unique_id函数时,会调用all_gather获取所有rank的unique_id,而因为纯dp下,should_save被设置为只在0卡保存权重,因此1卡是不会进入save_state_dict中的,而在调用all_gather时,如果process_group为None,则会调用global_group,纯dp2下,即ranks:2 rank_id:0,1;而此时0卡向1卡all_gather请求,1卡却没有做all_gather,0卡就一直等待,最终导致hang住。
解决方案:
添加一个条件,在纯DP时,此时use_hybrid_parallel为false(这是每个rank共同的特征),因此,添加个判断条件,让1卡也进入即可。
6.flash_attention无法正确传入,要手动修改

7.测tp时,fused_qkv, fused_ffn打开后loss接续不符合预期
在测tp策略转换的过程中,发现fused_qkv, fused_ffn打开后loss差距不符合预期;经验证,当前存在fused_qkv与old_fused_qkv两套逻辑,llama当前默认使用的是old_fused_qkv此时无需配置aoa与tp自洽,而ernie使用的是fused_qkv,需要配置aoa。
llama下的aoa配置:
--aoa_config '{
"aoa_statements": [
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0, fused_ffn"
]
}' \
ernie下的aoa配置:
aoa_config: {
"aoa_statements": [
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0 -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0 -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.w_0 -> ernie.layers.\$LAYER_ID.self_attn.qkv_proj.weight.w_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0 -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment1_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0 -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.moment2_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.w_0 -> ernie.mtp_block.\$LAYER_ID.self_attn.qkv_proj.weight.w_0, fused_qkv, num_heads=20, num_key_value_groups=5",
"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight, fused_ffn",
"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0 -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0, fused_ffn",
"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0 -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0, fused_ffn",
"ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.w_0 -> ernie.layers.\$LAYER_ID.mlp.up_gate_proj.weight.w_0, fused_ffn",
"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight, fused_ffn",
"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0 -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment1_0, fused_ffn",
"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0 -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.moment2_0, fused_ffn",
"ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.w_0 -> ernie.mtp_block.\$LAYER_ID.mlp.up_gate_proj.weight.w_0, fused_ffn",
"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight, fused_ffn",
"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment1_0 -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment1_0, fused_ffn",
"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment2_0 -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.moment2_0, fused_ffn",
"ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.w_0 -> ernie.layers.1.mlp.shared_experts.up_gate_proj.weight.w_0, fused_ffn",
]
}
fused_qkv(llama)实现逻辑图:

tp2->tp4,num_heads=k_v_nums:

tp2->tp4,num_heads>k_v_nums:

old_fused_qkv(ernie)实现逻辑图:

tp2->tp4,num_heads=k_v_nums: 此时逻辑同上,也是均分最后一维。
tp2->tp4,num_heads>k_v_nums:

5.unified_checkpoint与flex_check_point的区别
以tp2为例,flex_check_point保存的权重,是按照参数的部分分片保存的,并没有在最后做allgather:


可以看到,保存下来的embed_tokens参数,仍然是按照vocab_size的大小,切成两份的形式,是一种shard的状态,注意,这里看起来像是batch被切分了,其实是因为,参数一般都以[vocab_size,batch_size]的形式排列,为了后续方便计算。
而unified_ckpt保存的权重,最后会做allgather,即所有rank上的参数都是完整参数,而保存的时候,是将所有参数划分成tp_degree份保存到多个文件中:

6.代码改进意见
1.关于Layer的shared_state_dcit
因为ColumnParallelLinear、VocabParallelEmbedding、RowParallelLinear本质也是继承nn.Layer,因此可以统一用同一个shared_state_dict。添加一个如下的切分方法即可:
def _get_shard_rules(self):
"""子类可重写此方法来提供分片规则"""
return None

7.一些代码记录
1.ShardingGradView
class ShardingGradView:
def __init__(
self,
param,
param_buffer,
grad_buffer,
index,
padded_size,
sharding_degree,
rank,
use_main_grad=False,
release_grad=False,
):
self._param = param
self._param_buffer = param_buffer
self._grad_buffer = grad_buffer
self._index = index
self._padded_size = padded_size
self._sharding_degree = sharding_degree
self._rank = rank
self._use_main_grad = use_main_grad
self._release_grad = release_grad
shard_size = param_buffer._numel() // sharding_degree
rank_begin = max(rank, 0) * shard_size
rank_end = rank_begin + shard_size
param_begin = max(self._index, rank_begin)
param_end = min(self._index + self._padded_size, rank_end)
self._param_begin = param_begin
self._param_end = param_end
self._rank_begin = rank_begin
self._slice_grad = None
if not self._release_grad:
self._link_grad_to_buffer()
# share param buffer
self._share_param_buffer()
def _get_padding(self):
if self._param_begin < self._param_end and self._slice_grad is not None:
padding_start = self._index + self._param._numel()
padding_end = self._index + self._padded_size
padding_start = max(self._param_begin, padding_start)
padding_end = min(self._param_end, padding_end)
if padding_start >= padding_end:
return None
padding = padding_end - padding_start
grad_numel = self._slice_grad._numel()
assert grad_numel >= padding, f"{grad_numel} vs {padding}"
padding_grad = self._slice_grad._slice(
grad_numel - padding, grad_numel
)
return padding_grad
else:
return None
def _slice_grad_from_buffer(self):
assert self._grad_buffer is not None
if self._param_begin < self._param_end:
self._slice_grad = self._grad_buffer._slice(
self._param_begin, self._param_end
)
tmp_grad = self._grad_buffer._slice(
self._index, self._index + self._param._numel()
)
return tmp_grad
def _link_grad_to_buffer(self):
tmp_grad = self._slice_grad_from_buffer()
tmp_grad.get_tensor()._set_dims(self._param.shape)
if not self._use_main_grad:
self._param._copy_gradient_from(tmp_grad)
else:
self._param.main_grad = tmp_grad
这里是根据padded_param减去param的大小,从而得到参数的padding大小,又由于grad和param大小是一致的,因此,切出grad尾部的padding大小的这一部分,就是padding_grad。
可以看到,self._slice_grad在初始化的时候,用的self._param_begin, self._param_end这一段,而这是经过了padded的大小,因此self._slice_grad.numel()等于padded后的param的大小。
而tmp_grad是实际的未padded的大小。也就是参数实际的大小,并且self._grad_buffer和self._param_buffer都是用paddle.zeros创建的全0矩阵,大小为当前param组所有param经过padded之后的大小的和。
因此在_link_grad_to_buffer中,实际是将tmp_grad赋值给param的grad属性,如果有main_grad就给main_grad,这里是浅拷贝,所以就和buffer共享内存了。
def _share_param_buffer(self):
param_shape = self._param.shape
stop_gradient = self._param.stop_gradient
self._param.stop_gradient = True
self._param.flatten_()
paddle.assign(
self._param,
self._param_buffer._slice(
self._index, self._index + self._param._numel()
),
)
self._param.get_tensor()._set_dims(param_shape)
self._param.stop_gradient = stop_gradient
self._param_buffer._slice(
self._index, self._index + self._param._numel()
)._share_buffer_to(self._param)
def fill_slice_param(self, slice_param):
slice_begin = self._param_begin
slice_end = self._param_end
if slice_param._is_initialized():
assert self._param_buffer._is_shared_buffer_with(slice_param)
assert len(slice_param.shape) == 1
assert slice_param.shape[0] == (slice_end - slice_begin)
slice_begin = self._param_begin
slice_end = self._param_end
slice_buffer = self._param_buffer._slice(slice_begin, slice_end)
slice_buffer._share_buffer_to(slice_param)
slice_param.get_tensor()._set_dims([slice_end - slice_begin])
def assign_slice_grad(self, slice_param):
assert self._param_buffer._is_shared_buffer_with(self._param)
slice_grad = self._slice_grad
if slice_grad is None:
return
self.fill_slice_param(slice_param)
if hasattr(self._param, "main_grad"):
if not hasattr(slice_param, "main_grad"):
slice_param.main_grad = slice_grad
else:
assert slice_param.main_grad is slice_grad
elif slice_grad is not None:
if slice_param.grad is None:
slice_param._copy_gradient_from(slice_grad)
else:
assert slice_param.grad._is_shared_buffer_with(slice_grad)
二者的区别是,前者是初始化的时候,将param_buffer的内存与param共享,注意是未padded的,而后者则是padded后的param,并且在当超过当前rank分配的buffer大小时会被截断,而前者不会。
为什么有这两个?
前者是给初始化param使用的,将param和buffer的内存共享,而后者是针对slice_param = EagerParamBase(shape=[1], dtype=param.dtype) 使用得,初始时,不知道大小,通过fill_slice_param来设置实际大小,即,因为每个优化器负责更新部分参数,大多数参数是完整的,但当self._index + self._padded_size超过rank_end时,此时param_end会被设置为rank_end,即将参数截断,一部分在当前rank,另一部分在下一个rank上。这里用EagerParamBase实际是想要实现深拷贝原始的param,原始也是此类型,这样可以使得能够将所有属性完全相同的复制过来。
所以model的param其实是没有被切分的,而slice_param在不跨rank的时候,是完整param的区域+padding区域,而如果出现跨rank的情况,则会被截断,前者用于forward完整计算不切,后者用于仅更新当前rank上的optimizer负责更新的参数,因此做跨rank切分。对应的self.param.main_grad和slice_param.main_grad也是上述区别。

可以看到sharding_param_grad_view中保存了多个ShardingGradView实体,而每个ShardingGradView实体的信息,都能体现当前param的是如何被sharding到每个rank上的优化器上的,即其在param_buffer中的位置信息,因此我们利用这个信息,去获取每个param在当前rank上的切片信息(注意这里是opitimizer要更新的参数的切片信息,而不是对model的param做了切分,实际是复制了一份,可见前文介绍),即flattened_range。
2.opitmizer中的param在每个rank上的划分

可以看到,只要是不属于当前rank的切片,要么①param_begin>param_end;②param_end-index<0。则slice的star,end都取param_begin-index,做空切片,从而控制param_slice_info只保留自己rank上的param的有效切片信息。
3.FusedCommBuffer的逻辑
注意单个fuedCommBuffer并不对应全部参数,optimizer的参数列表会被划分到多个buffer中

4.FusedCommBuffer在DygraphSharding_Optimizer中的应用(v1,v2都用到了,但是只有V2对optimizer的参数做了摊平切分,而V1只是对完整参数做了划分。)
DygraphSharding与DygraphShardingV2的区别是,前者是参数级划分,将opitimizer的参数划分到不同rank上(且为了均衡负载,是依次分配参数,及分配下一个参数是,是看哪个rank上当前分配的参数大小最小,则分配,参数是乱序的);而后者是参数内划分,会将参数列表根据color分成多个buffer,然后每个buffer内的所有参数flatten后,sharding到多个rank上,所以会存在有的参数一部分在该rank,另一部分在下一个rank的现象。DygraphShardingV2不支持fused_param
注意,会根据group_size创建多个FusedCommBuffer,同时这里的tensor_fusion,指的是,比如把多个参数放到同一片存储区域,多个梯度放到同一片存储区域。

8.合参UC测试
1.uc下optimizer的格式是
因此需要做一个格式转换,把斜杠改成.
2.uc下跑dp,只会保存optimizer和master,具体在如下位置:

3.uc下跑SD2EP2时,md5未对齐

经过验证,主要原因是,fc下,会把expert,转换成key不同的情况,也就是说,比如有64个expert,分到两个rank上的时候,编号都是0-31,而fc下,我们会把2卡的参数名修改,做数字偏移,比如0变成rank_id*(per_device_expert_nums),从而区别不同rank上的专家,而uc不会,导致对比时,uc上的expert被覆盖,比较出错,修改脚本验证后,对齐。
4.涉及TP的都会报错
追溯原因:
问题1,uc把不同rank的专家当同一个参数合并


在tp下,如果moe_group是tp,则做恒等映射,是不切分专家的。只是均分到整个moe_group中,如下,fc就是ErnieMoeMLP:

但是注意:


这里moe_group已经被parse_moe_group解析成了group格式,即如图,所以这里moe_in_mp始终为false,所以默认所有参数都按tp合并。所以在这里,用一个moe_group_name来提前接收moe_group字符串。
tp合并参数的逻辑如下
:

然后 action(ret) 执行,但 experts 的 action 是 partial(fn,is_column)列切,因此按列合并。
因此,两个rank上的expert会被错误合并成一个大的tensor。
问题2,未给专家参数设置mp_moe的标志,导致专家被allgather,而实际应该是直接获取本rank的,非本rank的专家参数设置为None
而,当moe_group直接设置为True的时候,action(ret) 执行,但 experts 的 action 是 lambda x: x,直接返回收集到的张量列表,所以这时候返回的tensor就是一个列表,包含rank0的expert tensor和rank1的expert tensor。仍然会导致报错。

注意这里会得到一个tensor列表的主要原因是,如图处丢掉了expert的p.mp_moe的属性,导致expert无法被识别出有mp_moe属性,导致保存时,仍然保存的不是本地的expert,而是一个expert列表,即本来应该走绿色的这条分支,而现在走了红色的这条分支。

问题3,ernie_moe的_get_tensor_parallel_mappings中,未设置mtp_block层的映射,导致在save_ckpt合参时,该参数未被按tp切分维度合并

因此,需要加入如下映射,标记着其处于切分状态:

问题4,load_state_dict和_handle_aoa未考虑到多卡转单机的情况
最后,load_state_dict和_handle_aoa要适配一下多卡转单机的情况,例如加载tp4的ckpt,到单卡时,也需要用到_handle_aoa。

9.AOAEngine学习记录
1.AOAShardInfoContext

这个主要用于记录上下文信息,保留一些信息,给后续操作可调用。
source_state_shard_info和destination_state_shard_info分别表示需要load下来的ckpt对应策略的参数分片信息,和当前正在执行的策略的参数分片信息,格式为_ShardInfo = dict[str, list[ShardedWeightDesc]],即包含,同一个key,再不同rank上的参数分片状态,如果是类似dp这样的,同一个key只会在单个distcp文件中保存,因此只有一个参数分片状态。
get_all_dst_state_keys与get_all_src_state_keys则是辅助函数获取其中所有的key,get_num_hidden_layers通过aoa_config中是否配置了$LAYER_ID,来正则匹配dst中所有key中的layer_id,例如下:
"ernie.layers.$LAYER_ID.self_attn.qkv_proj.weight -> ernie.layers.$LAYER_ID.self_attn.qkv_proj.weight, fused_qkv_old, num_heads=20, num_key_value_groups=5"
会以$LAYER_ID为分隔符,分成两份,然后中间以\d匹配,从而匹配到layer_id,遍历所有key,得到的做大ID+1,则为num_hidden_layer的层数。
get_src_state_shard_num和get_dst_state_shard_num这两个主要是查看当前key对应参数的分片数,即tp数。
为什么要把optmizer的key也都转换成model的key来算呢,原因是,当做sharding的时候,opt的参数分片数=tp_nums*sharding_nums,直接求就有问题了。

2.Lexer

核心目标:为 AoA 表达式做词法分析(Lexing),并在词法分析前先应用已注册的宏展开,最终生成供解析器使用的 token 序列。
首先传入的参数expressions是aoa_conifg["aoa_statements"],这是一个字符串列表,形状如下:
--aoa_config '{
"aoa_statements": [
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment1_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.moment2_0, fused_ffn",
"llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0 -> llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight.w_0, fused_ffn"
]
}' \
每一个expression会使用apply_macros,即对每个expression,遍历使用所有的已经注册好的macro。
在进入macro之前,会使用tokenize方法将expression解析成多个token,按照token_specification中的正则项进行匹配,name作为key,匹配到的实际内容作为value,比如上述的aoa_config的第一条,首先会根据identifier获取到第一个token:llama.layers.$LAYER_ID.mlp.gate_up_fused_proj.weight,遇到空格会skip,然后根据rarrow匹配到->,紧接着再根据identifier获取到下一个token,知道最终结束,而每个text都会判断一下后面有没有\n,没有就补充,从而得到NEWLINE,标志着一条text匹配结束。
被所有macro处理后,会得到一个results列表,列表里面也都是expression样子的表达式,最终Lexer会把result_expression再次调用tokenized解析成token返回,给到parser里面做处理。