在这一期中,我们延续上一期 Bert 中文短句相似度计算 Docker CPU镜像,继续使用 huggingface transformersentence-transformer 类库,并将英语句子生成 bert embedding,然后引入 faiss 类库来建立索引,最后查询最接近的句子。

Docker 镜像获取方式

本期 docker 镜像获取方式为,关注 MyEncyclopedia 公众号后回复 docker-faiss-transformer 即可获取如下完整命令。

1
docker run -p 8888:8888 myencyclopedia/faiss-demo bash -c 'jupyter notebook --allow-root --port 8888 --NotebookApp.token= --ip 0.0.0.0'

然后打开浏览器,输入 http://localhost:8888/notebooks/faiss_demo.ipynb

faiss 简介

Faiss 的全称是Facebook AI Similarity Search,是由 Facebook 开发的适用于稠密向量匹配的开源库,作为向量化检索开山鼻祖,Faiss 提供了一套查询海量高维数据集的解决方案,它从两个方面改善了暴力搜索算法存在的问题:降低空间占用和加快检索速度。此外,Faiss 提供了若干种方法实现数据压缩,包括 PCA、Product-Quantization等。

Faiss 主要特性:

  • 支持相似度检索和聚类;
  • 支持多种索引方式;
  • 支持CPU和GPU计算;
  • 支持Python和C++调用;

Faiss 使用流程

使用 faiss 分成两部,第一步需要对原始向量建立索引文件,第二步再对索引文件进行向量 search 操作。

在第一次建立索引文件的时候,需要经过 trainadd 两个过程;后续如果有新的向量需要被添加到索引文件,只需要一个 add 操作来实现增量索引更新,但是如果增量的量级与原始索引差不多的话,整个向量空间就可能发生了一些变化,这个时候就需要重新建立整个索引文件,也就是再用全部的向量来走一遍 trainadd,至于具体是如何 trainadd的,就和特定的索引类型有关了。

1. IndexFlatL2和indexFlatIP

对于精确搜索,例如欧式距离 faiss.indexFlatL2 或 内积距离 faiss.indexFlatIP,没有 train 过程,add 完直接可以 search

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import faiss 


index = faiss.IndexFlatL2(d)


index.add(xb)
print(index.ntotal)



k = 4

D, I = index.search(xq, k)


print(I[:5])

 2.公司简介

IndexFlatL2 的结果虽然精确,但当数据集比较大的时候,暴力搜索的时间复杂度很高,因此我们一般会使用其他方式的索引来加速。比如 IndexIVFFlat,将数据集在 train 阶段分割为几部分,技术术语为 Voronoi Cells,每个数据向量只能落在一个cell中。Search 时只需要查询query向量落在cell中的数据了,降低了距离计算次数。这个过程本质就是高维 KNN 聚类算法。search 阶段使用倒排索引来。

IndexIVFFlat 需要一个训练的阶段,其与另外一个索引 quantizer 有关,通过 quantizer 来判断属于哪个cell。IndexIVFFlat 在搜索阶段,引入了nlist(cell的数量)与nprob(执行搜索的cell数)参数。增大nprobe可以得到与brute-force更为接近的结果,nprobe就是速度与精度的调节器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import faiss
nlist = 100
k = 4


quantizer = faiss.IndexFlatL2(d)


index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)


index.train(xb)
index.add(xb)
index.nprobe = 10
D, I = index.search(xq, k)
print(I[-5:])

3. IndexIVFPQ

IndexFlatL2 和 IndexIVFFlat都要存储所有的向量数据。对于超大规模数据集来说,可能会不大现实。因此IndexIVFPQ 索引可以用来压缩向量,具体的压缩算法就是 Product-Quantization,注意,由于高维向量被压缩,因此 search 时候返回也是近似的结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import faiss

nlist = 100

m = 8

k = 4
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)
index.train(xb)
index.add(xb)
index.nprobe = 10
D, I = index.search(xq, k)
print(I[-5:])

在本期中,我们仅使用基本的 IndexIVFFlat 和 IndexFlatIP 完成 bert embedding 的索引和搜索,后续会有篇幅来解读 Product-Quantization 的论文原理和代码实践。

ag_news 新闻数据集

ag_news 新闻数据集 3.0 包含了英语新闻标题,training 部分包含 120000条数据, test 部分包含 7600条数据。

ag_news 可以通过 huggingface datasets API 自动下载

1
2
3
4
5
6
7
8
9
10
def load_dataset(part='test') -> List[str]:
ds = datasets.load_dataset("ag_news")
list_str = [r['text'] for r in ds[part]]
return list_str

list_str = load_dataset(part='train')
print(f'{len(list_str)}')
for s in list_str[:3]:
print(s)
print('\n')

显示前三条新闻标题为

1
2
3
4
5
6
7
8
9
120000
Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.


Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.


Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.

 变压器

和上一期一样,我们利用sentence-transformer 生成句子级别的embedding。其原理基于 Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (https://arxiv.org/abs/1908.10084)这篇论文。基本思想很直接,将句子中的每个词的 bert embedding ,输进入一个池化层(pooling),例如选择最简单的平均池化层,将所有token embedding 的均值作为输出,便得到跟输入句子长度无关的一个定长的 sentence embedding。



结果展示

数据集 train 部分由于包含的样本比较多,需要一段时间生成 bert embedding,大家可以使用 load_dataset(part='test') 来快速体验。下面我们演示一个查询 how to make money 的最接近结果。

1
2
3
4
index = load_index('news_train.index')
list_id = query(model, index, 'how to make money')
for id in list_id:
print(list_str[id])
1
2
3
4
5
6
7
8
9
Profit From That Traffic Ticket Got a traffic ticket? Can't beat 'em? Join 'em by investing in the company that processes those tickets.

Answers in the Margins By just looking at operating margins, investors can find profitable industry leaders.

Types of Investors: Which Are You? Learn a little about yourself, and it may improve your performance.

Target Can Aim High Target can maintain its discount image while offering pricier services and merchandise.

Finance moves Ford into the black US carmaker Ford Motor returns to profit, as the money it makes from lending to customers outweighs losses from selling vehicles.

核心代码

所有可运行代码和数据都已经包含在 docker 镜像中了,下面列出核心代码

建立索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def train_flat(index_name, id_list, embedding_list, num_clusters):
import numpy as np
import faiss

dim = 768
m = 16

embeddings = np.asarray(embedding_list)

quantiser = faiss.IndexFlatIP(dim)
index = faiss.IndexIVFFlat(quantiser, dim, num_clusters, faiss.METRIC_INNER_PRODUCT)
index.train(embeddings)

ids = np.arange(0, len(id_list))
ids = np.asarray(ids.astype('int64'))

index.add_with_ids(embeddings, ids)
print(index.is_trained)
print("Total Number of Embeddings in the index", index.ntotal)
faiss.write_index(index, index_name)

查询结果

1
2
3
4
5
6
7
def query(model, index, query_str: str) -> List[int]:
topk = 5
q_embed = model.encode([query_str])
D, I = index.search(q_embed, topk)
print(D)
print(I)
return I[0].tolist()