Class 1 Byte-Pair Encoding (BPE) Tokenizer

写在前面

好久不见。今天开始我的博客会更新cs336这门课程的学习笔记。这门课算是目前为止对大模型介绍最为详细,最贴近前沿的一门课,对代码手操要求也是比较高的。那么这个系列的笔记就是我个人学习中的一些记录和思考,目的是让⑨看了也能轻松学会。

❄️「バカ!バカ!こんな簡単なこと、あたいでも覚えられるのに!」

Unicode 标准

在python中,我们可以很容易地使用ord(str)来获取一个字符的unicode编码,用chr(code)来获取unicode编码对应的字符,就像这样:

1
2
3
4
>>> ord('尻')
23611
>>> chr(23611)
'尻'

让我们思考以下几个问题:

  1. chr(0)会返回哪个字符?
点击展开 / 折叠内容 >>> chr(0)
'\x00'
  1. 这个字符的字符表示(__repr__())和他的打印表示(__str())有何不同?
点击展开 / 折叠内容 >>> repr(a)
"'\\x00'"
>>> str(a)
'\x00'
>>> a
'\x00'
repr是面向开发者的表示方式,命令行默认使用repr来表示
由于a包含不可见字符(\x00),repr会用**转义序列**显示它
当输入单个a时,解释器实际调用的就是repr(a),但是外层引号不再重复显示
str(a)返回字符本身,但由于字符不可打印,解释器用\x00的形式显示
实际上,\x00表示的是unicode中的空字符NUL
  1. 当这个字符在文本段中会发生什么?
点击展开 / 折叠内容 我们可以用一个例子来做测试:
>>> chr(0)
'\x00'
>>> print(chr(0))
>>> "this is a test" + chr(0) + "string"
'this is a test\x00string'
>>> print("this is a test" + chr(0) + "string")
this is a teststring
正如预料的,print()实际上效果也等同str(),所以会将可见的字符打出
而不可见的字符不会显示,也不会占用宽度,但它依然存在

Unicode 编码

Unicode标准为我们提供了从Unicode码到字符的映射关系,但是这个关系无法适用于模型的构建,因为它的体量过于巨大,而且含有很多的生僻字符。所以,我们需要进行Unicode的编码,以便训练出一个简单的tokenizer(分词器)。我们一般将Unicode编码为含有几个字节的序列,而Unicode自带了多个编码序列标准:UTF-8,UTF-16,UTF-32。由于现在使用率最高的是UTF-8,并且UTF-8表示占用空间比较小,因此我们用UTF-8来进行编码会更加方便。

可以用下面这个例子来说明UTF-8与Unicode的关系:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> test_string = "你好!hello!"
>>> utf8_encoded = test_string.encode("utf-8")
>>> print(utf8_encoded)
b'\xe4\xbd\xa0\xe5\xa5\xbd\xef\xbc\x81hello!'
>>> print(type(utf8_encoded))
<class 'bytes'>
>>> list(utf8_encoded)
[228, 189, 160, 229, 165, 189, 239, 188, 129, 104, 101, 108, 108, 111, 33]
>>> print(len(test_string))
9
>>> print(len(utf8_encoded))
15
>>> print(utf8_encoded.decode("utf-8"))
你好!hello!

在python中,可以用encode()decode()来相互转换Unicode和UTF-8。要访问Python字节对象的底层字节值,我们可以对其进行迭代(例如,调用list())。

你可能会有疑问:为什么我们需要对一整条字符串进行统一的encode,再进行序列化,最后才能获得可以重新decode的字符串呢?一个具体的例子如下:

1
2
3
4
5
>>> def decode_utf8_bytes_to_str_wrong(bytestring: bytes): 
>>> return "".join([bytes([b]).decode("utf-8") for b in bytestring])
>>> print(decode_utf8_bytes_to_str_wrong('你好'.encode('utf-8')))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe4 in position 0: unexpected end of data

这是因为,除了ASCII字符在UTF-8中占用一个字节,类似于中文字符的其他字符,都或多或少需要占用超过一个字节的位置(中文需要4个)。所以,上面这个函数对每个encode后得到的序列字节单独进行decode,就会发生无法还原的问题。

这也是Unicode字符集的一大弊端,由于每个语言的字符集位置、大小都是分配好的,导致当文字过多时,就有可能会导致多个文字实际上的编码是同一个,以便节省空间。这种情况在中文和日文中最为明显,因为这两种语言都存在很多相近字形的字,但是笔画并不完全一致。这也是很多游戏一旦缺失中文字体,就会自动显示该文字的日文字体的原因。


子词(Subword)分词

虽然字节级(byte-level)分词可以缓解词级(word-level)分词器所面临的不在词表(out-of-vocabulary, OOV)问题,但将文本划分为字节会导致输入序列极度冗长
这会减慢模型训练速度,例如:
在词级语言模型中,一个包含 10 个单词的句子可能只对应 10 个词元(token);
而在字符级模型中,根据单词的长度,这个句子可能会对应 50 个或更多的词元,这种更长的输入序列会导致每一步的模型计算量增加。
此外,在字节序列上进行语言建模也更加困难,因为更长的输入序列会在数据中引入更长程的依赖关系(long-term dependencies)

为此,提出子词分词(Subword tokenization)这一介于词级分词和字节级分词之间的折中方案。
请注意,字节级分词器的词表大小为 256(因为字节值范围是 0 到 255)。
而子词分词器通过使用更大的词表,来换取对输入字节序列的更高压缩率
例如,如果字节序列 b'the' 在原始文本训练数据中频繁出现,那么将它作为一个独立的词汇条目加入词表,可以将原本的 3 个字节合并为单一词元。

那么,我们该如何选择这些要加入词表的子词单元呢?
Sennrich 等人(2016)提出使用 字节对编码(Byte Pair Encoding, BPE; Gage, 1994),这是一种压缩算法,
它通过迭代地将最常出现的字节对替换(“合并”)为一个新的、尚未使用的索引来实现。
请注意,这种算法会将子词条目加入词表,以最大化输入序列的压缩程度——
如果一个单词在输入文本中出现得足够频繁,它最终就会被表示为一个单独的子词单元。

使用 BPE 构建词汇表的子词分词器通常被称为 BPE 分词器(BPE tokenizer)
在本次作业中,我们将实现一种字节级 BPE 分词器,其中的词汇项可以是单个字节,也可以是若干字节的合并序列,
从而在词汇外处理能力可管理的输入序列长度之间取得良好的平衡。
构建 BPE 分词器词表的过程被称为训练BPE 分词器

训练BPE分词器

训练一个BPE分词器主要分为以下几个步骤:

初始化词表

分词器词表是一个从字节化token到常数ID的一对一映射表。因为我们要训练的是一个字节级分词器,因此我们初始化的词表就是一个简单的全ASCII字节集合,众所周知ASCII字符有256个(我们这里用的是扩展ascii字符表),我们可以初始化词表大小为256

imagepng

预分词

一旦我们已经拥有一个词汇表,从原则上讲,就可以统计文本中各个字节相邻出现的频率,并从最频繁的一对字节开始进行合并。然而,这种做法在计算上代价非常高,因为每当我们进行一次合并,就必须对整个语料库重新遍历一遍。

此外,直接在整个语料库上对字节进行合并还可能导致出现仅在标点符号上不同的词被视为完全不同的标记(例如 dog!dog.)。这些词会被分配到完全不同的 token ID,即使它们在语义上极为相似(因为差别仅在于标点)。

为避免这种情况,我们会对语料库进行预分词(pre-tokenization)。可以将其理解为一种粗粒度的分词过程,用于帮助我们统计字符对共同出现的频率。例如,单词 'text' 可能在语料库中出现了 10 次。在这种情况下,当我们统计字符 't''e' 相邻出现的次数时,可以直接知道 'text''t''e' 是相邻的,于是可以将它们的计数增加 10,而无需再次遍历整个语料库。

由于我们正在训练的是一个字节级别的 BPE 模型(byte-level BPE model),每个预分词单元(pre-token)都会被表示为一串 UTF-8 字节序列。

为了方便,我们使用GPT-2使用的分词方式来进行预分词,也就是正则匹配分词法:

1
2
3
4
5
6
7
8
import regex as re
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
RES = re.finditer(PAT, "some text that i'll pre-tokenize")
# 不使用re.findall,因为finditer返回迭代器,适合节省内存,并且包含位置信息
print([res.group() for res in RES])

$ python tests/sample.py
['some', ' text', ' that', ' i', "'ll", ' pre', '-', 'tokenize']

计算BPE合并

现在我们已将我们输入的文本转换为预分词,每个预分词都表示一串UTF-8的字节序列,接下来我们就可以计算BPE合并了,也就是训练BPE分词器。

从整体上看,BPE(Byte-Pair Encoding)算法的核心思想是:迭代地统计文本中所有相邻字节对的出现次数,并找到出现频率最高的一对字节(记作 “A”、“B”)。
然后,将所有该最频繁字节对 (“A”, “B”) 的出现位置合并为一个新的 token,即替换为新的符号 “AB”。

这个新的合并 token 会被加入到我们的词汇表中;因此,BPE 训练完成后的最终词汇表大小等于初始词汇表大小(在我们的例子中为 256)+ 训练过程中执行的合并操作次数

在 BPE 训练的过程中,为了提高计算效率,算法不会考虑跨越预分词边界的字节对

在统计并选择要合并的字节对时,如果多个字节对的出现频率相同,则通过字典序确定优先级,即选择字典序较大的那一对

例如下面的这个例子,(“A”, “B”), (“A”, “C”), (“B”, “ZZ”) 和 (“BA”, “A”) 具有相同的最高出现频率,那么这种情况下我们选择合并(“BA”,“A”),因为它的字典序最大。

特殊token的处理

一些特殊的字符是用来编码元数据的,如<|endoftext|>用于表示文本到达结尾。当我们编码文本时,不应将这些特殊字符分词,而应该视为一个单独的token。这些特殊的token必须加入词表中,他们会获得对应的固定token ID。

不过,在本次的项目中我们不需要这个过程,因为原始的BPE计算已经包含了特殊字符。换言之,所有的特殊字符都已经加入模型的词表中,不需要额外的添加。

来个例子

接下来,我们用一个实际的例子来演示以上的所有步骤,这个例子来源于 Algorithm 1 of Sennrich et al. [2016] 。

假设我们的语料库内容如下:

Text
1
2
3
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest

首先初始化词表。词表包含上文中的特殊token<|endoftext|>,还有256字节值。

然后进行预分词。为了方便,我们就认为只靠词与词之间的空白字符来分词就行,然后我们对首次分词的个数进行统计,得到以下结果:{low: 5, lower: 2, widest: 3, newest: 6}

在python中要建立键值对关系,可以用dict[tuple[bytes], int]类型来实现,比如记录上述结果为{(l,o,w): 5}

python中单个或多个字符都使用bytes作为类型,因为python没有byte类型也没有char类型

接下来进行合并。首先两两组合单个字符进行统计,得到以下统计结果:{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 9, st: 9, ne: 6, ew: 6 }

组合中出现最高的频率的是esst,按照字典序合并st,于是进行第一次合并,结果如下:{(l,o,w) : 5, (l,o,w,e,r) : 2, (w,i,d,e,st) : 3, (n,e,w,e,st) : 6}

在第二个轮次,继续两两组合,统计结果如下:{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, (e,st): 9, ne: 6, ew: 6 }

组合中出现频率最高的是(e,st),因此进行第二次合并得到:{(l,o,w) : 5, (l,o,w,e,r) : 2, (w,i,d,est) : 3, (n,e,w,est) : 6}

继续这个过程直到无法合并下去,最后我们会得到下面的几个最终合并token:{st, est, ow, low, west ,ne},将它们加入词表中,最后的词表就是[<|endoftext|>,[256个字符],st, est, ow, low, west ,ne]

这个就是我们需要的词表了,在这个词表下,单词newest就会被分词为[ne, west]


实验:BPE分词器的实现

资料和部分代码来源这里,这个博客还有很多关于分词器的优化策略介绍和代码实现,在此感谢

在了解BPE分词器的基本实现流程后,我们来动手实现一个最基本的版本。

为了方便,我们首先创建一个BPE_Trainer类,并定义train()方法,参数和返回值按照adpaters.py里面的run_train_bpe()填写,之后只需要通过这个类就能开始训练一个分词器

1
2
3
4
class BPE_Trainer:
def train(self, input_path: str | os.PathLike, vocab_size: int, special_tokens: list[str]) :
# do something..
return vocabulary, merges

接下来我们开始按照流程进行各个模块的编写

首先是初始化词表,这个可以通过python中的dict类型来实现,我们初始化的词表是dict[int, bytes]类型的。需要注意的是对special_tokens需要编码为utf-8类型的序列。

1
2
3
4
5
6
7
8
9
10
11
12
# ASCII码表大小
BYTES_NUM = 256
class BPE_Trainer:
def train(self, input_path: str | os.PathLike, vocab_size: int, special_tokens: list[str]) :
# 初始化词表
vocabulary = {i: bytes([i]) for i in range(BYTES_NUM)}
# 加入special_tokens
for i, sp_token in enumerate(special_tokens):
vocabulary[BYTES_NUM + i] = sp_token.encode("utf-8")
# 计算词表大小
size = BYTES_NUM + len(special_tokens)
return vocabulary,merges

接下来是预分词。我们需要对输入的文本进行正则匹配,获取单词列表。但是,输入文本的体量可能非常大,一次性读入的话很有可能会把内存撑爆。可以采用流式读取的方法逐步处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
@staticmethod
def _get_document_chuck_streaming(input_path: str | os.PathLike, chunk_size: int, special_token: str = "<|endoftext|>") -> \
Generator[str, Any, None]:
* *left = ""
token_len = len(special_token)
with open(input_path, "r", encoding="utf-8") as f:
while True:
block = f.read(chunk_size)
if not block:
break
# 之前剩下的加到新加载的block头位置
block = left + block
left = ""
# 为了不把特殊token给切分,这里找到距离尾部最近的一个特殊符号
last_special_index = block.rfind(special_token)
if last_special_index == -1:
# 没有找到,直接整块留到下次读取
left = block
else:
# 返回特殊token前(包含特殊token)的内容,剩下的下次返回
yield block[:last_special_index+token_len]
left = block[last_special_index+token_len:]
# 最后将剩下的left返回
if left != "":
yield left

yield提供了一个generator,用于在每次申请时动态生成,而不是一次性全部返回结果,这样就能实现以chunk为单位的流式读取

1
2
3
4
5
6
7
8
9
10
11
12
13
14
@staticmethod
def _pretokenize_and_count(input_path: str | os.PathLike,special_tokens: list[str]) -> defaultdict:
pattern = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
special_pattern = "|".join(re.escape(token) for token in special_tokens)
word_counts = defaultdict(int)

for chunk in BPE_Trainer._get_document_chuck_streaming(input_path, chunk_size=CHUNK_SIZE):
# 首先进行特殊token的分词
blocks = re.split(special_pattern,chunk)
# 然后进行普通正则分词
for block in blocks:
for match in re.finditer(pattern, block):
word_counts[match.group(0)] += 1
return word_counts

pattern就是上文提供的分词正则式,而special_pattern是以特殊token为分界线进行分词。

然后,用word_counts实现了各个token的计数,这样预分词就实现好了。

train()函数中加入对预分词函数的调用,同时为了提高训练质量,要把每个分词转换为utf-8编码,正如之前对词表做的那样。

1
2
3
4
5
6
7
8
9
10
11
# 获取单词频率
word_counts = self._pretokenize_and_count(input_path, special_tokens)
# 初始化词表
vocabulary = {i: bytes([i]) for i in range(BYTES_NUM)}
for i, sp_token in enumerate(special_tokens):
vocabulary[BYTES_NUM + i] = sp_token.encode("utf8")
size = BYTES_NUM + len(special_tokens)
# 将词转为utf-8编码
word_encodings = {}
for word in word_counts:
word_encodings[word] = list(word.encode("utf-8"))

接下来进行合并。

首先我们需要将token打碎为单个编码,两两组合并进行统计,用以下的函数来实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
@staticmethod
def _count_pairs(vocabulary, word_counts, word_encodings, pair_strings) :
# 初始化一个defaultdict[Any, int]
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
# 依次将编码两两组合
for i in range(0, len(encoding) - 1):
pair = encoding[i], encoding[i + 1]
pair_counts[pair] += count
if pair not in pair_strings:
pair_strings[pair] = (vocabulary[pair[0]],vocabulary[pair[1]])

return pair_counts

可以看到_count_pairs()这个函数在最后让pair_strings加入之前未加入的pair,这个pair_strings的作用是什么呢?不妨看看train()函数接下来的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 初始化merge表为空
merges = []
pair_strings = {}

while size < vocab_size:
# 调用_count_pairs(),获得两两组合的频数表
pair_counts = BPE_Trainer._count_pairs(vocabulary, word_counts, word_encodings, pair_strings)
# 选择频数最大的合并对,将其合并
merge_pair, max_count = max(pair_counts.items(), key=lambda x:(x[1], pair_strings[x[0]]))
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
# 将新合并得到的token加入到词表中
vocabulary[size] = merge_bytes
new_id = size
size = size + 1

可以看到调用了_count_pairs()后获取到了两两组合的频数表,接下来的这句代码用到了lambda语句,等价于下面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# merge_pair, max_count = max(pair_counts.items(), key=lambda x:(x[1], pair_strings[x[0]]))
# 1. 获取字典的所有 (key, value) 对
items = pair_counts.items()
# 2. 定义排序/比较函数(相当于 key=lambda x: (x[1], pair_strings[x[0]]) )
def sort_key(x):
# pair_counts的值
count = x[1]
# 根据key去pair_strings找对应字符串
pair_string = pair_strings[x[0]]
return (count, pair_string)
# 3. 通过max()找到“最大”的元素(按照上面定义的规则来比较,优先比较count,若相同比较pair_string)
best_item = max(items, key=sort_key)
# 4. 拆包得到两个变量
merge_pair, max_count = best_item

这句代码实际上实现了从 pair_counts 这个字典中,找到一个“最佳键值对”(key, value),其选择标准是:

  1. 先比较 value(x[1])的大小;

  2. 如果 value 相同,则比较 pair_strings[x[0]] 的字典序

所以,pair_strings的作用就是保存当次循环中两两合并得到的全部字符串,便于后续根据字典序找到最大的合并对。

在做完合并后,还需要更新编码,也就是将word_encodings的值进行更新。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
while size < vocab_size:
pair_counts = BPE_Trainer._count_pairs(vocabulary, word_counts, word_encodings, pair_strings)
merge_pair, max_count = max(pair_counts.items(), key=lambda x:(x[1], pair_strings[x[0]]))
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
vocabulary[size] = merge_bytes
new_id = size
size = size + 1

# 还需要更新编码
for word, word_tokens in word_encodings.items():
i = 0
new_tokens = []
has_new_id = false
while i < len(word_tokens):
if i < len(word_tokens) - 1 and (word_tokens[i], word_tokens[i + 1]) == merge_pair:
new_tokens.append(new_id)
i += 2
has_new_id = True
else:
new_tokens.append(word_tokens[i])
i += 1

if has_new_id:
word_encodings[word] = new_tokens
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))

这部分的代码作用就是对每个键原来的值进行重新的筛选,如果前后两个token都是合并项的组成部分,那么就替换为合并项,否则还是按原来的内容放入,通过这样实现了每个单词的组成token的更新。

相当于现在的word_encodings变成了有合并项(如est)的列表:{(l,o,w) : 5, (l,o,w,e,r) : 2, (w,i,d,est) : 3 },这主要是为了下个大循环进一步合并而准备的

最后将合并项导出到merges中,再return vocabulary, merges就实现了这个train()函数的全部功能。

在完成这个类的编写后,我们在adapters.py中的run_train_bpe函数声明这个类,便可以实现完整的功能了:

1
2
3
4
5
6
7
8
def run_train_bpe(
input_path: str | os.PathLike,
vocab_size: int,
special_tokens: list[str],
**kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
* *trainer = bpe.BPE_Trainer()
return trainer.train(input_path, vocab_size, special_tokens)

使用python -m pytest tests进行测试,可以发现test_train_bpetest_train_bpe_special_tokens的测试都能顺利通过,说明我们的基本功能都是没有问题的。但是,test_train_bpe_speed的测试没有通过,因为分词器的运行时间约为3s,远远超出了1.5s的最慢时间限制,这说明这个分词器的分词算法还有待完善。


BPE分词器的优化

可以通过在不同环节加入time.perf_counter()来测量每个步骤所用时长。经过测算,分词器主要的时间损耗集中在_count_pairs()计算合并对频数阶段和最后的更新编码阶段,主要是因为这两个过程都是对列表中的全部项进行了更新,可以用下面的例子来说明这样的全量更新造成的性能损失。

假设现在的word_counts是这样的:

{low: 5, lower: 2, widest: 3, newest: 6, es: 2, st: 2}

第一个循环

计算得到合并对频数pair_counts

{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 11, st: 11, ne: 6, ew: 6}

然后选择其中频数最大的合并对,根据字典序选择st,将其加入到词表中

最后需要更新编码word_encodings,事实上是迭代所有单词,并为包含st的单词更新编码:

1
2
3
word_encodings['widest'] = ['w','i','d','e','st']
word_encodings['newest'] = ['n','e','s','e','st']
word_encodings['st'] = ['st']

第二个循环

计算得到合并对频数pair_counts

{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 2, st: 2, ne: 6, ew: 6, est: 9}

然后选择其中频数最大的合并对,根据字典序选择('e', 'st'),将其合并为est加入到词表中

最后需要更新编码word_encodings,事实上是迭代所有单词,并为包含est的单词更新编码:

1
2
word_encodings['widest'] = ['w','i','d','est']
word_encodings['newest'] = ['n','e','s','est']

通过比较可以发现,实际上前后两次循环涉及合并对频数pair_counts的变动非常小,只有esst的频数因为est的合并发生了变化。并且,它们两个减少的数目(-9)和est增加的数目(+9)是完全对应的,这说明了两件事:

  1. 前后循环中,两个token减少的数量就是他们合并后的新token增加的数量;

  2. 前后循环中,只有三个token(2个被合并为1个新的)的数量发生了变化;并且,只有包含这两个token的单词的重新编码发生了改变

这意味着整个列表中大部分的迭代判断逻辑都是没有任何作用的,程序只会在那上面白白浪费时间,如果能够避免这个方面的大量无用计算和判断,我们就能节省下大量的时间。

我们得到了具体的优化思路:通过某种手段找到每个合并token对应的单词,并单独进行频数计算和重新编码,而不需要迭代全部单词。很容易想到倒排索引就能够满足我们的需要,而且编程上也容易实现。

倒排索引常用于寻找一个值(value)对应的全部键(key),可以通过另外定义一个dict类型的变量来实现。用上文中的第一个循环为例子来解释倒排索引的作用:

假设我们已经得到了这次循环的最大频数合并对(s,t),接下来我们需要对包含这个合并对的所有单词更新编码。

首先,使用倒排索引找出包含这个合并对的所有单词:

1
2
'widest'
'newest'

这两个词中所有对的计数如下:

Text
1
2
{wi: 3, id: 3, de: 3, es: 3, st: 3}
{ne: 6, ew: 6, we: 6, es: 6, st: 6}

当前全部合并对的计数(即pair_counts)如下:

Text
1
{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 11, st: 11, ne: 6, ew: 6}

减去两个词中合并对的数量,结果如下:

Text
1
{lo: 7, ow: 7, we: 2, er: 2, wi: 0, id: 0, de: 0, es: 2, st: 2, ne: 0, ew: 0}

接下来,这两个词的标记变为:

Text
1
2
word_encodings['widest'] = ['w','i','d','e','st']
word_encodings['newest'] = ['n','e','s','e','st']

然后,我们根据这些新编码计算新对的频数:

Text
1
2
{wi: 3, id: 3, de: 3, est: 3}
{ne: 6, ew: 6, we: 6, est: 6}

将这些添加回全部合并对的计数(即pair_counts)会得到:

Text
1
{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 2, st: 2, ne: 6, ew: 6, est: 9}

以上的处理方式实际上非常简单,却让原本的计算规模明显降低了,接下来可以看看代码方面的修改。

train()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def train(self, input_path: str | os.PathLike, vocab_size: int, special_tokens: list[str]):
# 获取单词频率
word_counts = self._pretokenize_and_count(input_path, special_tokens)
# 初始化词表
vocabulary = {i: bytes([i]) for i in range(BYTES_NUM)}
for i, sp_token in enumerate(special_tokens):
vocabulary[BYTES_NUM + i] = sp_token.encode("utf8")
size = BYTES_NUM + len(special_tokens)
# 初始化merge表为空
merges = []
# 将词转为utf-8编码
word_encodings = {}
for word in word_counts:
word_encodings[word] = list(word.encode("utf-8"))

pair_strings = {}
# 初始化倒排索引
pair_to_words = defaultdict(set)
pair_counts = BPE_Trainer._count_pairs(vocabulary, word_counts, word_encodings, pair_strings, pair_to_words)
while size < vocab_size:
# 避免vocab_size过大导致pair_counts为空报错
if not pair_counts:
print(f"[Info] Training stopped early: reached max mergeable pairs at vocab size {size}.")
break
BPE_Trainer._merge_a_pair(pair_counts, pair_strings, vocabulary,
pair_to_words, word_counts, word_encodings,
merges, size)
size += 1

return vocabulary, merges

train()函数的改动集中在初始化倒排索引pair_to_words和后面的循环部分。由于算法已经更新为依靠增量更新而不是迭代更新的模式,因此我们不需要每次都进行_count_pairs(),此外,编写了一个新函数_merge_a_pair()用来代替先前版本的挑选最大合并项和更新编码,这使得代码更加简洁。

_count_pairs()

1
2
3
4
5
6
7
8
9
10
11
12
13
@staticmethod
def _count_pairs(vocabulary, word_counts, word_encodings, pair_strings, pair_to_words) :
pair_counts = defaultdict(int)
for word, count in word_counts.items():
encoding = word_encodings[word]
for i in range(0, len(encoding) - 1):
pair = encoding[i], encoding[i + 1]
pair_counts[pair] += count
if pair not in pair_strings:
pair_strings[pair] = (vocabulary[pair[0]],vocabulary[pair[1]])
# 向倒排索引加入对应键值对
pair_to_words[pair].add(word)
return pair_counts

对于这个计算合并对频数的函数改动并不多,只是加入了倒排索引的处理。注意到上面我们使用的是defaultdict(set)类型,因此不需要进行重复项的判断即可直接add(word),这样倒排索引就能记录下每个合并项对应的所有单词了。

_merge_a_pair()

1
2
3
4
5
6
7
8
9
10
11
12
13
@staticmethod
def _merge_a_pair(pair_counts, pair_strings, vocabulary,
pair_to_words, word_counts, word_encodings,
merges, size):
merge_pair, max_count = max(pair_counts.items(), key=lambda x: (x[1], pair_strings[x[0]]))
merge_bytes = vocabulary[merge_pair[0]] + vocabulary[merge_pair[1]]
vocabulary[size] = merge_bytes
new_id = size
affected_words = pair_to_words[merge_pair]
BPE_Trainer._updated_affected_word_count(merge_pair, affected_words, word_encodings,
word_counts, pair_counts,
pair_to_words, new_id, pair_strings, vocabulary)
merges.append((vocabulary[merge_pair[0]], vocabulary[merge_pair[1]]))

这个函数实际上就是将先前版本的挑选最大合并项和更新编码的功能集合到了一起。

挑选最大合并项的语句和之前没有差异,而由于我们提出的新算法,需要首先挑选出含有最大合并项的所有单词,因此使用affected_words来记录这些单词。接下来将更新词频、重新编码和更新pair_to_words的工作交给了updated_affected_word_count来进行。在完成上面的工作后,最后将合并项如先前代码一样返回。

_updated_affected_word_count()

这个函数篇幅比较长,分几段来解释功能:

1
2
3
4
5
@staticmethod
def _updated_affected_word_count(merge_pair, affected_words, word_encodings,
word_counts, pair_counts, pair_to_words,
new_id, pair_strings, vocabulary):
affected_words = affected_words.copy()

首先对传入这个函数的affected_words进行浅拷贝,原因是后续会对pair_counts在迭代中进行更新,而affected_words来自pair_counts[merge_pair]。由于str类型是不会改变值的,所以只需要浅层拷贝即可。

后续的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
for word in affected_words:
word_tokens = word_encodings[word]
wc = word_counts[word]
# 在所有合并项频数中减去两个词中合并项的数量
for i in range(len(word_tokens) - 1):
old_pair = (word_tokens[i], word_tokens[i + 1])
pair_counts[old_pair] -= wc
if pair_counts[old_pair] <= 0:
del pair_counts[old_pair]
pair_to_words.pop(old_pair)
else:
pair_to_words[old_pair].discard(word)
# 重新编码
new_tokens = []
i = 0
while i < len(word_tokens):
if i < len(word_tokens) - 1 and (word_tokens[i], word_tokens[i + 1]) == merge_pair:
new_tokens.append(new_id)
i += 2
else:
new_tokens.append(word_tokens[i])
i += 1
word_encodings[word] = new_tokens
# 在所有合并项频数中增加两个词中合并项的数量
for i in range(len(new_tokens) - 1):
new_pair = (new_tokens[i], new_tokens[i + 1])
pair_counts[new_pair] += wc
# 同时更新倒排索引
pair_to_words[new_pair].add(word)
if new_pair not in pair_strings:
pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])

可以看到,接下来的操作是对每一个包含最大合并项的单词进行的,算法的主要步骤正如上文文字描述的一样,分为三个主要过程。

1
2
3
4
5
6
7
8
9
# 在所有合并项频数中减去两个词中合并项的数量
for i in range(len(word_tokens) - 1):
old_pair = (word_tokens[i], word_tokens[i + 1])
pair_counts[old_pair] -= wc
if pair_counts[old_pair] <= 0:
del pair_counts[old_pair]
pair_to_words.pop(old_pair)
else:
pair_to_words[old_pair].discard(word)

第一步是在所有合并项频数中减去两个词中合并对的数量。这一步的实现很简单,但是不要忘了对pair_to_words的处理。如果一个合并项的计数减到0,那么就直接删去倒排索引中的这一项,同时清除pair_counts中的这一项;如果还存在,那么就只将倒排索引中这一合并项对应的单词清除以防止重复统计。

1
2
3
4
5
6
7
8
9
10
11
# 重新编码
new_tokens = []
i = 0
while i < len(word_tokens):
if i < len(word_tokens) - 1 and (word_tokens[i], word_tokens[i + 1]) == merge_pair:
new_tokens.append(new_id)
i += 2
else:
new_tokens.append(word_tokens[i])
i += 1
word_encodings[word] = new_tokens

第二步是重新编码,这一步和先前代码的作用是一样的,就是负责将单词对应的token更新。

1
2
3
4
5
6
7
8
# 在所有合并对频数中增加两个词中合并项的数量
for i in range(len(new_tokens) - 1):
new_pair = (new_tokens[i], new_tokens[i + 1])
pair_counts[new_pair] += wc
# 同时更新倒排索引
pair_to_words[new_pair].add(word)
if new_pair not in pair_strings:
pair_strings[new_pair] = (vocabulary[new_pair[0]], vocabulary[new_pair[1]])

第三步是在所有合并项频数中增加两个词中合并项的数量,同时还需要更新倒排索引。这一步的代码也是很清晰易懂的。

将这个新版本命名为BPE_Trainer_v2.py,并将adapter.py中引入这个包的语句改为

1
from cs336_basics import BPE_Trainer_v2 as bpe

使用python -m pytest tests进行测试,可以发现test_train_bpe_speed的测试顺利通过,分词器的运行时间约为0.27s,性能显著优化了。

imagepng

实际上,如果在文本测试集TinyStories上测试,前一个版本的运行时间是2187s,后一个版本的时间仅仅是757s,时间差异是3倍之多;如果不计算训练过程前半段的_pretokenize_and_count()过程,时间差异将是惊人的13倍!

总结

在这次的课程中,我们了解了BPE Tokenizer的基本工作流程,并编写了一个简单的BPE Tokenizer并对其进行了训练,通过了基本的功能测试。我们还成功对这个BPE Tokenizer进行了初步的优化,使用倒排索引让它的性能提升了10倍以上。

如果你还需要更快的性能提升,不妨看看这里,这个博客对BPE Tokenizer进行了层层优化,从倒排索引到并行搜索,再到使用c语言进行底层加速,并利用堆查找提高性能,最后使用Cython和PyPy加速Python代码,成功地将原来需要训练10+h的大文本数据集open_web优化到只需要200s完成,性能提高了100倍以上!这篇博客同时还将各个优化版本的代码都放在仓库中,非常适合额外的学习。