In big-data scenarios, efficient approximate nearest neighbor (ANN) search is key to many applications, such as recommender systems and image retrieval. The traditional single-machine HNSWlib becomes slow on large-scale data, so we experimented with the distributed solution HNSWlib-PySpark for candidate recall.
Background
HNSW (Hierarchical Navigable Small World) is an efficient ANN algorithm that accelerates search by building a hierarchical graph structure. HNSWlib is a library implementing it, but its performance is limited when processing large-scale data on a single machine. HNSWlib-PySpark integrates the HNSW algorithm with PySpark, using distributed computing to handle massive datasets more efficiently.
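For reference, this is roughly what the single-machine baseline looks like with the hnswlib Python package; the random data and parameter values below are illustrative only:

import hnswlib
import numpy as np

dim = 64
num_elements = 10000

# Illustrative data: random vectors standing in for user embeddings
data = np.random.rand(num_elements, dim).astype(np.float32)

# Build an HNSW index using inner-product distance
index = hnswlib.Index(space='ip', dim=dim)
index.init_index(max_elements=num_elements, ef_construction=200, M=48)
index.add_items(data, np.arange(num_elements))

# ef controls the search-time accuracy/speed trade-off
index.set_ef(50)

# Query the 10 nearest neighbors of the first 100 vectors
labels, distances = index.knn_query(data[:100], k=10)

This works well until the index no longer fits on one machine or index construction takes too long, which is where the distributed version comes in.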
Testing HNSWlib-PySpark
Installation
First, make sure HNSWlib-PySpark is installed:
pip install pyspark-hnsw --upgrade
When submitting the PySpark job, add the following configuration so the matching JAR is pulled in:
--conf spark.jars.packages=com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0
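Alternatively, the same package can be supplied when the SparkSession is built in Python; a minimal sketch (the application name is illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("hnsw-recall-test") \
    .config("spark.jars.packages", "com.github.jelmerk:hnswlib-spark_2.3_2.11:1.1.0") \
    .getOrCreate()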
Test Code
The complete test code is as follows:
import os
import logging

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.linalg import Vectors
from pyspark_hnsw.knn import HnswSimilarity, BruteForceSimilarity
from pyspark_hnsw.evaluation import KnnSimilarityEvaluator
from pyspark_hnsw.linalg import Normalizer
from pyspark_hnsw.conversion import VectorConverter

# Path to the hadoop binary, used by system_command for shell-level HDFS operations
hadoop = os.path.join(os.environ['HADOOP_COMMON_HOME'], 'bin/hadoop')


def init_spark():
    spark = SparkSession.builder \
        .config("spark.sql.caseSensitive", "false") \
        .config("spark.shuffle.spill", "true") \
        .config("spark.shuffle.spill.compress", "true") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("metastore.catalog.default", "hive") \
        .config("spark.sql.hive.convertMetastoreOrc", "true") \
        .config("spark.kryoserializer.buffer.max", "1024m") \
        .config("spark.kryoserializer.buffer", "64m") \
        .config("spark.driver.maxResultSize", "4g") \
        .config("spark.sql.broadcastTimeout", "36000") \
        .enableHiveSupport() \
        .getOrCreate()
    return spark


def system_command(command):
    # Run a shell command and log whether it succeeded
    code = os.system(command)
    if code != 0:
        logging.error(f"Command: ({command}) execution failed.")
    else:
        logging.info(f"Command: ({command}) execution succeeded.")


if __name__ == "__main__":
    spark = init_spark()

    # Load user embeddings from Hive
    df = spark.sql(
        """select user_id_zm, user_embedding
           from algo.dssm_user_embedding
           where pt='2025-05-18'
        """
    )

    # Convert the embedding column into dense vectors
    df_user_id = df.rdd.map(lambda row: row.user_id_zm)
    df_embedding = df.rdd.map(lambda row: Vectors.dense(row.user_embedding))
    new_df = df_user_id.zip(df_embedding).toDF(schema=['user_id_zm', 'user_embedding'])

    # Inspect the data
    new_df.show()

    # Preprocessing: convert to ML vectors and normalize
    converter = VectorConverter(inputCol='user_embedding', outputCol='features')
    normalizer = Normalizer(inputCol='features', outputCol='normalized_features')

    # Approximate top-k similarity with HNSW
    hnsw = HnswSimilarity(identifierCol='user_id_zm', queryIdentifierCol='user_id_zm',
                          featuresCol='normalized_features', distanceFunction='inner-product',
                          m=48, ef=15, k=10, efConstruction=200, numPartitions=2,
                          excludeSelf=True, similarityThreshold=0.4, predictionCol='approximate')

    # Exact top-k similarity via brute force, used as ground truth
    brute_force = BruteForceSimilarity(identifierCol='user_id_zm', queryIdentifierCol='user_id_zm',
                                       featuresCol='normalized_features', distanceFunction='inner-product',
                                       k=10, numPartitions=2, excludeSelf=True,
                                       similarityThreshold=0.4, predictionCol='exact')

    # Build and fit the pipeline
    pipeline = Pipeline(stages=[converter, normalizer, hnsw, brute_force])
    model = pipeline.fit(new_df)

    # Query with a 1% sample of the data
    query_items = new_df.sample(0.01)
    output = model.transform(query_items)

    # Evaluate how well the approximate neighbors match the exact ones
    evaluator = KnnSimilarityEvaluator(approximateNeighborsCol='approximate', exactNeighborsCol='exact')
    accuracy = evaluator.evaluate(output)
    print("accuracy: ", accuracy)

    # Stop the SparkSession
    spark.stop()
    del spark
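For downstream use of the recalled candidates, the prediction columns can be flattened into one row per (query, neighbor) pair. The sketch below assumes the 'approximate' column holds an array of structs with neighbor and distance fields, which is how hnswlib-spark reports results in this setup; adjust the field names if your version's schema differs.

from pyspark.sql.functions import col, explode

# Flatten the approximate top-k lists produced above into one row per
# (query user, recalled neighbor) pair. The 'approximate' column is assumed
# to be an array of structs carrying 'neighbor' and 'distance' fields.
neighbors = (
    output
    .select("user_id_zm", explode(col("approximate")).alias("nn"))
    .select(
        col("user_id_zm"),
        col("nn.neighbor").alias("recalled_user_id"),
        col("nn.distance").alias("distance"),
    )
)
neighbors.show(truncate=False)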
Results
In our test, the recall of HNSWlib-PySpark measured against the brute-force results was between 0.8 and 0.9, which is acceptable for a large-scale data scenario. The main advantage of HNSWlib-PySpark is its distributed architecture, which lets it handle massive datasets and makes candidate recall much more efficient.
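If higher recall is required, the usual HNSW levers are ef (size of the search-time candidate list) and efConstruction (size of the build-time candidate list): raising them generally pushes recall closer to the exact result, at the cost of slower queries and index construction. The values below are illustrative, not tuned for this dataset:

from pyspark_hnsw.knn import HnswSimilarity

# Same setup as above, but with larger candidate lists to trade speed for recall
hnsw_high_recall = HnswSimilarity(identifierCol='user_id_zm', queryIdentifierCol='user_id_zm',
                                  featuresCol='normalized_features', distanceFunction='inner-product',
                                  m=48, ef=200, k=10, efConstruction=400, numPartitions=2,
                                  excludeSelf=True, similarityThreshold=0.4, predictionCol='approximate')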