0%

阿里 RecIS:分布式高性能哈希表

RecIS 开发了一个低内存开销、高效查表、动态扩展、强大稳定的哈希表,规避了 torch.nn.Embedding 的问题。RecIS 内部叫 HashTable,也就是广义上的 Embedding。在本文中,会对 HashTable 的底层实现一探究竟,看看它如何达到的:

  • 降低 Embedding 表的内存、通信开销
  • 快速查表
  • 当 id 增加时,需要动态扩表,且无明显消耗
  • 整个系统需要稳定、简洁

表名的唯一性

通常一个搜广推模型拥有很多 Embedding 表,如用户特征(含年龄、性别和身高这些 id)位于 user 表,app 使用习惯(含购物时间、消费区间这些 id)位于 behave 表。这些表囊括了各式各样的 id。且表之间不能重复,也就是不需要两张 user 表。为了实现这个目的,RecIS 实现了 HashTableRegister 单例类,这部分代码位于 recis/nn/modules/hashtable.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class SingletonMeta(type):
_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super().__call__(*args, **kwargs)
return cls._instances[cls]

class HashTableRegister(metaclass=SingletonMeta):
def __init__(self) -> None:
self._HashTables = {}

def register(self, name: str, info: str):
if name in self._HashTables:
raise ValueError(
f"Duplicate HashTable shard name: {name}, before: {self._HashTables[name]}, now: {info}"
)
self._HashTables[name] = info

class HashTable(torch.nn.Module):
def __init__(self):
HashTableRegister().register(table_name)

正常创建类对象时,会调用元类的 __call__ 方法,使用 __new__ 创建实例,并调用类对象的 __init__ 方法初始化对象。但 HashTableRegister 声明了元类 SingletonMeta,并重新声明了 __call__ 方法,所以改写了类对象的创建方式。通过 cls 捕获类对象 HashTableRegister,使这个类只有一个实例,所以在任何地方声明的 HashTableRegister() 都是同一个实例,即:

1
2
3
a = HashTableRegister()
b = HashTableRegister()
print(a is b) # True

所以无论创建多少 HashTable,每个表名都只会被创建一次,否则会报错退出。

底层设计

在注册完 HashTable 后会创建一个 HashTable 的对象,这部分代码位于 csrc/Embedding/HashTable.cc 文件。HashTable 的创建和底层部分都由 C++ 实现,并把接口暴露给 torch。

1
2
3
self._HashTable_impl = torch.ops.recis.make_HashTable(...)
...
m.def("make_HashTable", HashTable::Make);

HashTable::Make 会返回一个指向 HashTable 的侵入式智能指针,这个指针管理着实际的 HashTable 对象。在 HashTable 的构造过程中,除了存储切片大小、长度、设备、数据类型、数据生成器等基本信息外,还会创建两个核心组件:slotgroup 和 idmap。

IDMap 设计

这部分代码位于 /csrc/embedding/cpu_id_map.h。idmap 负责 id 的插入,其核心数据结构是 ska::flat_hash_map,将原始特征 ID 映射到内部索引 index。使用内部索引 index 去访问对应的 Embedding。所以 idmap 可以理解为 std::map,key 是原始ID,value 是内部索引。

  • 在插入 ids 时实际会调用 Lookup 查表函数去遍历 idmap,这部分的代码实现在 IndexLookupFunctor 函数中。
    • 如果 ids 在 idmap 中,返回对应的 index
    • 如果 ids 不在 idmap 中,标记为缺失的 ids。之后为缺失的 ids 生成 index,插入到 idmap 中。在插入时,这里的 index 是连续的,即 0, 1, 2, …, ids_num 连续分布
  • 在删除 ids 时,像 std::map 一样 erase 掉不需要的 ids,并存储对应的 index,用于下次插入时的内存复用

id_allocator

当再次插入 ids 时,为了保持 index 的连续性需要知道上次的 index 到哪里了;删除 ids 时需要知道哪些 index 被空了出来可以复用。在这个过程中,index 的生成和删除由 id_allocator 管理。这部分代码位于 csrc/embedding/id_allocator.cc,使用 cur_size_ 表示分配了多少 index,使用 free_size_ 表示删除了多少 ids,使用 free_blocks_ 存储可以被复用的 index,说起来有些抽象,可以看个例子。

  • 生成 index:由 generate_ids_op 函数实现。假如第一次插入 5 个 id,由于没有 index,所以 index 从 0 开始一直递增到 4。如下所示,左侧是 id,右侧是 index。
1
2
3
4
5
6
7
ids_map = {
1180210: 0, # 原始 ID → 内部索引
721458: 1,
655922: 2,
1000000: 3,
2000000: 4
}
  • 删除 index:由 free_ids_op 函数实现。假如现在删除 655922 这个 id,它对应的 index 是 2,那么把 2 存储到 free_blocks_ 中
  • 当再次插入一个取值为 328637 的 id 时,检测到 free_blocks_ 中的 2 可以被复用,此时的 ids_map 取值为:
1
2
3
4
5
6
7
ids_map = {
1180210: 0, # 原始 ID → 内部索引
721458: 1,
1000000: 3,
2000000: 4
328637: 2,
}

高效查表

假设 ids = [1180210, 721458, 655922, 1000000, 2000000],实际上 ids 并不连续。如果直接用 ID 作为数组索引:

1
2
Embedding[1180210]  # ❌ 需要分配 1180210 个元素,浪费大量内存!
Embedding[721458] # ❌ 大部分位置都是空的

但是可以借助 ids_map 将 ids 映射为连续的 index,此时只需要分配 5 个 Embedding 的位置:

1
2
3
4
5
Embedding[0]  # ✅ 对应 ID 1180210
Embedding[1] # ✅ 对应 ID 721458
Embedding[2] # ✅ 对应 ID 655922
Embedding[3] # ✅ 对应 ID 1000000
Embedding[4] # ✅ 对应 ID 2000000

ids_map 的优势不言而喻:

  • 内存效率:只分配实际需要的空间
  • 访问效率:连续内存,缓存友好
  • 索引复用:删除 ID 后可以复用索引

Slot 设计

slot 一般翻译为槽,如信号槽,数据槽等,在 RecIS 中的作用是数据槽,用于存储 Embedding、优化器状态等,这部分代码位于 csrc/Embedding/slot_group.cc

slotgroup 以 vector 的形式管理多个类型的 slot,如 Embedding、优化器状态就是两个类型的 slot。在每个 slot 中,以 vector 的形式管理 torch::Tensor,而 torch::Tensor 负责存储实际的数据。

以 Embedding 为例,在查询 index 对应的 Embedding 时,会调用 HashTable::EmbeddingLookup 函数,最终调用到:

1
2
3
4
5
6
7
void HashTable::IncrementBlocknum(int64_t ids_num) {
size_t block_num =
(ids_num + slot_group_->BlockSize()) / slot_group_->BlockSize();
while (slot_group_->BlockNum() < block_num) {
slot_group_->IncrementBlock();
}
}

BlockSize 初始化为 5,所以当插入的 ids_num 为 100 时,需要创建 20 个 slot,即执行 20 次:

1
void Slot::IncrementBlock() { values_->push_back(generator_->Generate()); }

生成器 generator_ 一开始就传递给了 HashTable,假设 Embedding shape 是 128,数据类型是 float32,那么一次 IncrementBlock 就会生成一个 [5, 128] 的 float32 的 Embedding,并按照指定的初始化方式进行初始化。初始化方式可以参考 csrc/Embedding/initializer.h

1
2
3
4
5
6
// 每个 Tensor 是一个 block,大小为 [block_size, Embedding_dim]
// 默认 block_size = 5, Embedding_dim = 128
// block_0: shape [5, 128],存储 index 0-5 的 Embedding
// block_1: shape [5, 128],存储 index 5-10 的 Embedding
// block_2: shape [5, 128],存储 index 10-15 的 Embedding
// ...

在查表时,会调用 block_gather 函数,计算 index 数据哪个 block ,进而访问最终的 Embedding 数据。

1
2
3
4
auto src_block_index = src_index / block_size_;  // 哪个 block
auto src_row_index = src_index % block_size_; // block 内的行
...
Embedding = emb_blocks[src_block_index][src_row_index] // 成功访问

这种设计形式除了高效查表外,还可以减少内存碎片,并且支持增量扩展。总结一下,HashTable 的数据的存储形式为:

接入 torch 层

那么 C++ 实现的 HashTable,是如何参与 torch 的前向计算和反向传播呢?补充一些额外知识,torch.autograd.Function 可以自定义一个自动求导的算子,它允许用户同时定义:

  • 正向计算(forward)
  • 在反向传播(backward)时,这个算子的求导公式
  • 如果 forward() 返回 N 个值,backward() 会接收 N 个梯度(按顺序对应)

在集成自定义的 HashTable 时(非 torch 代码的 C++、cuda 算子),就需要使用这个类完成对自定义算子的自动求导。举个例子:

1
2
3
4
5
6
7
8
9
10
11
12
class HashTableLookupHelpFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, table):
...
return y

@staticmethod
def backward(ctx, grad_output):
...
return grad_x, grad_table

y = HashTableLookupHelpFunction.apply(x, table)

在上述代码中,会自动调用 forward(ctx, x, table) 得到 y;在需要梯度时,记录运算图;之后反向传播时,会自动调用 backward(ctx, grad_output)。

前向计算

回到 recis/nn/modules/hashtable.py 这个文件。在训练模式下,使用可以自动求导的 HashTableLookupHelpFunction 在前向阶段完成查表,得到内部索引 index 和对应的 Embedding。之后,考虑到输入的 batch 中会有重复的 ids,所以使用 GradWorkerMeanFunction 完成对 index 的 gather 聚合:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 示例数据
Embedding = [
[0.1, 0.2], # 唯一 ID 0 的 Embedding
[0.3, 0.4], # 唯一 ID 1 的 Embedding
[0.5, 0.6], # 唯一 ID 2 的 Embedding
] # Shape: [3, 2]

index = [0, 1, 2, 1, 0] # 5 个原始位置到唯一 ID 的映射

# gather 操作:根据 index 从 Embedding 中选择
output = torch.ops.recis.gather(index, Embedding)
# output = [
# [0.1, 0.2], # index[0] = 0 → Embedding[0]
# [0.3, 0.4], # index[1] = 1 → Embedding[1]
# [0.5, 0.6], # index[2] = 2 → Embedding[2]
# [0.3, 0.4], # index[3] = 1 → Embedding[1]
# [0.1, 0.2], # index[4] = 0 → Embedding[0]
# ] # Shape: [5, 2]

反向传播

反向传播的顺序和前向计算完全相反,首先是 GradWorkerMeanFunction 的反向传播:

1
reduce_grad, None, None = GradWorkerMeanFunction.backward(ctx, grad_from_upstream)

根据当前 worker 收到的上游梯度 grad_from_upstream 除以 worker 数量,得到平均梯度。最后使用 index_add_ 函数累加相同索引的梯度,并传递给 HashTableLookupHelpFunction

1
2
3
# grad_output_index = None  (因为 index 没有参与后续计算)
# grad_output_emb = reduce_grad (因为 Embedding 参与了后续计算)
HashTableLookupHelpFunction.backward(ctx, None, reduce_grad)

HashTableLookupHelpFunction 的 backward 方法中,会调用 accept_grad 函数存储 index 以及对应的梯度。目前梯度只是被存了起来,像所有其他 torch 模型一样,在调用 optimizer.step 后,才会更新这些梯度。

至此,已经实现了一个类 Embedding 设计,并且可以在 torch 端调用并参与 AI 的前向计算与反向传播。在后续的内容中,将一睹 HashTable 上层的设计,看看特征处理、数据通信的实现,如何基于 HashTable 搭建出一个完整的稀疏模型。

未完待续

GradWorkerMeanFunction 的反向传播函数中求了梯度的均值,说明这个梯度可能已经过 all_reduce 或其他聚合处理。而这就涉及了多卡通信,暂时还没读到这部分源码,未来某天在写。由于内容过多,一些比较有意思的实现细节只是一带而过,这些有时间在慢慢看看。

  • 侵入式智能指针
  • ska::flat_hash_map 底层实现
  • torch.autograd.Function 自动求导实现原理
感谢上学期间打赏我的朋友们。赛博乞讨:我,秦始皇,打钱。

欢迎订阅我的文章