RecIS 开发了一个低内存开销、高效查表、动态扩展、强大稳定的哈希表,规避了 torch.nn.Embedding 的问题。RecIS 内部叫 HashTable,也就是广义上的 Embedding。在本文中,会对 HashTable 的底层实现一探究竟,看看它如何达到的:
- 降低 Embedding 表的内存、通信开销
- 快速查表
- 当 id 增加时,需要动态扩表,且无明显消耗
- 整个系统需要稳定、简洁
表名的唯一性
通常一个搜广推模型拥有很多 Embedding 表,如用户特征(含年龄、性别和身高这些 id)位于 user 表,app 使用习惯(含购物时间、消费区间这些 id)位于 behave 表。这些表囊括了各式各样的 id。且表之间不能重复,也就是不需要两张 user 表。为了实现这个目的,RecIS 实现了 HashTableRegister 单例类,这部分代码位于 recis/nn/modules/hashtable.py。
1 | class SingletonMeta(type): |
正常创建类对象时,会调用元类的 __call__ 方法,使用 __new__ 创建实例,并调用类对象的 __init__ 方法初始化对象。但 HashTableRegister 声明了元类 SingletonMeta,并重新声明了 __call__ 方法,所以改写了类对象的创建方式。通过 cls 捕获类对象 HashTableRegister,使这个类只有一个实例,所以在任何地方声明的 HashTableRegister() 都是同一个实例,即:
1 | a = HashTableRegister() |
所以无论创建多少 HashTable,每个表名都只会被创建一次,否则会报错退出。
底层设计
在注册完 HashTable 后会创建一个 HashTable 的对象,这部分代码位于 csrc/Embedding/HashTable.cc 文件。HashTable 的创建和底层部分都由 C++ 实现,并把接口暴露给 torch。
1 | self._HashTable_impl = torch.ops.recis.make_HashTable(...) |
HashTable::Make 会返回一个指向 HashTable 的侵入式智能指针,这个指针管理着实际的 HashTable 对象。在 HashTable 的构造过程中,除了存储切片大小、长度、设备、数据类型、数据生成器等基本信息外,还会创建两个核心组件:slotgroup 和 idmap。
IDMap 设计
这部分代码位于 /csrc/embedding/cpu_id_map.h。idmap 负责 id 的插入,其核心数据结构是 ska::flat_hash_map,将原始特征 ID 映射到内部索引 index。使用内部索引 index 去访问对应的 Embedding。所以 idmap 可以理解为 std::map,key 是原始ID,value 是内部索引。
- 在插入 ids 时实际会调用 Lookup 查表函数去遍历 idmap,这部分的代码实现在
IndexLookupFunctor函数中。- 如果 ids 在 idmap 中,返回对应的 index
- 如果 ids 不在 idmap 中,标记为缺失的 ids。之后为缺失的 ids 生成 index,插入到 idmap 中。在插入时,这里的 index 是连续的,即 0, 1, 2, …, ids_num 连续分布
- 在删除 ids 时,像 std::map 一样 erase 掉不需要的 ids,并存储对应的 index,用于下次插入时的内存复用
id_allocator
当再次插入 ids 时,为了保持 index 的连续性需要知道上次的 index 到哪里了;删除 ids 时需要知道哪些 index 被空了出来可以复用。在这个过程中,index 的生成和删除由 id_allocator 管理。这部分代码位于 csrc/embedding/id_allocator.cc,使用 cur_size_ 表示分配了多少 index,使用 free_size_ 表示删除了多少 ids,使用 free_blocks_ 存储可以被复用的 index,说起来有些抽象,可以看个例子。
- 生成 index:由 generate_ids_op 函数实现。假如第一次插入 5 个 id,由于没有 index,所以 index 从 0 开始一直递增到 4。如下所示,左侧是 id,右侧是 index。
1 | ids_map = { |
- 删除 index:由 free_ids_op 函数实现。假如现在删除 655922 这个 id,它对应的 index 是 2,那么把 2 存储到 free_blocks_ 中
- 当再次插入一个取值为 328637 的 id 时,检测到 free_blocks_ 中的 2 可以被复用,此时的 ids_map 取值为:
1 | ids_map = { |
高效查表
假设 ids = [1180210, 721458, 655922, 1000000, 2000000],实际上 ids 并不连续。如果直接用 ID 作为数组索引:
1 | Embedding[1180210] # ❌ 需要分配 1180210 个元素,浪费大量内存! |
但是可以借助 ids_map 将 ids 映射为连续的 index,此时只需要分配 5 个 Embedding 的位置:
1 | Embedding[0] # ✅ 对应 ID 1180210 |
ids_map 的优势不言而喻:
- 内存效率:只分配实际需要的空间
- 访问效率:连续内存,缓存友好
- 索引复用:删除 ID 后可以复用索引
Slot 设计
slot 一般翻译为槽,如信号槽,数据槽等,在 RecIS 中的作用是数据槽,用于存储 Embedding、优化器状态等,这部分代码位于 csrc/Embedding/slot_group.cc。
slotgroup 以 vector 的形式管理多个类型的 slot,如 Embedding、优化器状态就是两个类型的 slot。在每个 slot 中,以 vector 的形式管理 torch::Tensor,而 torch::Tensor 负责存储实际的数据。
以 Embedding 为例,在查询 index 对应的 Embedding 时,会调用 HashTable::EmbeddingLookup 函数,最终调用到:
1 | void HashTable::IncrementBlocknum(int64_t ids_num) { |
BlockSize 初始化为 5,所以当插入的 ids_num 为 100 时,需要创建 20 个 slot,即执行 20 次:
1 | void Slot::IncrementBlock() { values_->push_back(generator_->Generate()); } |
生成器 generator_ 一开始就传递给了 HashTable,假设 Embedding shape 是 128,数据类型是 float32,那么一次 IncrementBlock 就会生成一个 [5, 128] 的 float32 的 Embedding,并按照指定的初始化方式进行初始化。初始化方式可以参考 csrc/Embedding/initializer.h。
1 | // 每个 Tensor 是一个 block,大小为 [block_size, Embedding_dim] |
在查表时,会调用 block_gather 函数,计算 index 数据哪个 block ,进而访问最终的 Embedding 数据。
1 | auto src_block_index = src_index / block_size_; // 哪个 block |
这种设计形式除了高效查表外,还可以减少内存碎片,并且支持增量扩展。总结一下,HashTable 的数据的存储形式为:

接入 torch 层
那么 C++ 实现的 HashTable,是如何参与 torch 的前向计算和反向传播呢?补充一些额外知识,torch.autograd.Function 可以自定义一个自动求导的算子,它允许用户同时定义:
- 正向计算(forward)
- 在反向传播(backward)时,这个算子的求导公式
- 如果 forward() 返回 N 个值,backward() 会接收 N 个梯度(按顺序对应)
在集成自定义的 HashTable 时(非 torch 代码的 C++、cuda 算子),就需要使用这个类完成对自定义算子的自动求导。举个例子:
1 | class HashTableLookupHelpFunction(torch.autograd.Function): |
在上述代码中,会自动调用 forward(ctx, x, table) 得到 y;在需要梯度时,记录运算图;之后反向传播时,会自动调用 backward(ctx, grad_output)。
前向计算
回到 recis/nn/modules/hashtable.py 这个文件。在训练模式下,使用可以自动求导的 HashTableLookupHelpFunction 在前向阶段完成查表,得到内部索引 index 和对应的 Embedding。之后,考虑到输入的 batch 中会有重复的 ids,所以使用 GradWorkerMeanFunction 完成对 index 的 gather 聚合:
1 | # 示例数据 |
反向传播
反向传播的顺序和前向计算完全相反,首先是 GradWorkerMeanFunction 的反向传播:
1 | reduce_grad, None, None = GradWorkerMeanFunction.backward(ctx, grad_from_upstream) |
根据当前 worker 收到的上游梯度 grad_from_upstream 除以 worker 数量,得到平均梯度。最后使用 index_add_ 函数累加相同索引的梯度,并传递给 HashTableLookupHelpFunction:
1 | # grad_output_index = None (因为 index 没有参与后续计算) |
在 HashTableLookupHelpFunction 的 backward 方法中,会调用 accept_grad 函数存储 index 以及对应的梯度。目前梯度只是被存了起来,像所有其他 torch 模型一样,在调用 optimizer.step 后,才会更新这些梯度。
至此,已经实现了一个类 Embedding 设计,并且可以在 torch 端调用并参与 AI 的前向计算与反向传播。在后续的内容中,将一睹 HashTable 上层的设计,看看特征处理、数据通信的实现,如何基于 HashTable 搭建出一个完整的稀疏模型。
未完待续
在 GradWorkerMeanFunction 的反向传播函数中求了梯度的均值,说明这个梯度可能已经过 all_reduce 或其他聚合处理。而这就涉及了多卡通信,暂时还没读到这部分源码,未来某天在写。由于内容过多,一些比较有意思的实现细节只是一带而过,这些有时间在慢慢看看。
- 侵入式智能指针
- ska::flat_hash_map 底层实现
- torch.autograd.Function 自动求导实现原理