Skip to content

Commit

Permalink
[feat][store]Add calculate hamming distance
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuRuoyu01 authored and ketor committed Jan 20, 2025
1 parent 463fcee commit a5b1c14
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 21 deletions.
2 changes: 1 addition & 1 deletion contrib/faiss
46 changes: 37 additions & 9 deletions src/engine/storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "engine/storage.h"

#include <climits>
#include <cstdint>
#include <limits>
#include <string>
Expand Down Expand Up @@ -839,23 +840,50 @@ butil::Status Storage::VectorCalcDistance(const ::dingodb::pb::index::VectorCalc
}

int64_t dimension = 0;
auto value_type = op_left_vectors[0].value_type();

auto lambda_op_vector_check_function = [&dimension](const auto& op_vector, const std::string& name) {
auto lambda_op_vector_check_function = [&dimension,&value_type](const auto& op_vector, const std::string& name) {
if (!op_vector.empty()) {
size_t i = 0;
for (const auto& vector : op_vector) {
int64_t current_dimension = static_cast<int64_t>(vector.float_values().size());
if (0 == dimension) {
dimension = current_dimension;
if(vector.value_type() != value_type) {
std::string s = fmt::format("{} index : {} value_type : {} unequal value_type : {}", name, i,
::dingodb::pb::common::ValueType_Name(value_type),
::dingodb::pb::common::ValueType_Name(vector.value_type()));
LOG(ERROR) << s;
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
}

if (dimension != current_dimension) {
std::string s = fmt::format("{} index : {} dimension : {} unequal current_dimension : {}", name, i,
dimension, current_dimension);
if (vector.value_type() == ::dingodb::pb::common::ValueType::FLOAT) {
int64_t current_dimension = static_cast<int64_t>(vector.float_values().size());
if (0 == dimension) {
dimension = current_dimension;
}

if (dimension != current_dimension) {
std::string s = fmt::format("{} float index : {} dimension : {} unequal current_dimension : {}", name, i,
dimension, current_dimension);
LOG(ERROR) << s;
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
}
i++;
} else if (vector.value_type() == ::dingodb::pb::common::ValueType::UINT8) {
int64_t current_dimension = static_cast<int64_t>(vector.binary_values().size());
if (0 == dimension) {
dimension = current_dimension;
}

if (dimension != current_dimension) {
std::string s = fmt::format("{} binary index : {} dimension : {} unequal current_dimension : {}", name, i,
dimension * CHAR_BIT, current_dimension * CHAR_BIT);
LOG(ERROR) << s;
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
}
i++;
} else {
std::string s = fmt::format("{} index : {} value_type : VALUE_TYPE_NONE", name, i);
LOG(ERROR) << s;
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
}
i++;
}
}

Expand Down
91 changes: 82 additions & 9 deletions src/vector/vector_index_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ butil::Status VectorIndexUtils::CalcDistanceByFaiss(
return CalcCosineDistanceByFaiss(op_left_vectors, op_right_vectors, is_return_normlize, distances,
result_op_left_vectors, result_op_right_vectors);
}
case pb::common::METRIC_TYPE_HAMMING: {
return CalcHammingDistanceByFaiss(op_left_vectors, op_right_vectors, is_return_normlize, distances,
result_op_left_vectors, result_op_right_vectors);
}
case pb::common::METRIC_TYPE_NONE:
case pb::common::MetricType_INT_MIN_SENTINEL_DO_NOT_USE_:
case pb::common::MetricType_INT_MAX_SENTINEL_DO_NOT_USE_: {
Expand Down Expand Up @@ -213,6 +217,17 @@ butil::Status VectorIndexUtils::CalcCosineDistanceByFaiss(
result_op_right_vectors, DoCalcCosineDistanceByFaiss);
}

// Hamming-metric entry point for the batch distance calculation.
//
// This is a thin wrapper: every argument is forwarded unchanged to CalcDistanceCore, with
// DoCalcHammingDistanceByFaiss supplied as the routine used for a single left/right pair.
// The returned status is whatever CalcDistanceCore returns.
//
// NOTE(review): how CalcDistanceCore iterates the operands and lays out `distances` /
// `result_op_*_vectors` is not visible here - confirm against CalcDistanceCore.
butil::Status VectorIndexUtils::CalcHammingDistanceByFaiss(
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
std::vector<std::vector<float>>& distances, // NOLINT
std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors, // NOLINT
std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors) // NOLINT
{ // NOLINT
return CalcDistanceCore(op_left_vectors, op_right_vectors, is_return_normlize, distances, result_op_left_vectors,
result_op_right_vectors, DoCalcHammingDistanceByFaiss);
}

butil::Status VectorIndexUtils::CalcL2DistanceByHnswlib(
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
Expand Down Expand Up @@ -307,6 +322,33 @@ butil::Status VectorIndexUtils::DoCalcCosineDistanceByFaiss(
return butil::Status();
}

// Computes the Hamming distance between one pair of binary vectors using faiss.
//
// Each entry of Vector::binary_values() is treated as one byte: only the first character of
// the entry is used (an empty entry contributes a zero byte).
//
// @param op_left_vectors          left operand; its binary_values() are the packed bytes.
// @param op_right_vectors         right operand; must hold the same number of bytes as the left one.
// @param is_return_normlize       when true, the operands are copied into the result vectors
//                                 (only if the corresponding result vector is still empty).
// @param distance                 [out] receives the computed distance.
// @param result_op_left_vectors   [out] see is_return_normlize.
// @param result_op_right_vectors  [out] see is_return_normlize.
// @return OK on success; EILLEGAL_PARAMTETERS when the operands differ in byte count. In that
//         case `distance` and the result vectors are left untouched.
butil::Status VectorIndexUtils::DoCalcHammingDistanceByFaiss(
    const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
    bool is_return_normlize,
    float& distance,                                       // NOLINT
    dingodb::pb::common::Vector& result_op_left_vectors,   // NOLINT
    dingodb::pb::common::Vector& result_op_right_vectors)  // NOLINT
{                                                          // NOLINT
  const auto& left_values = op_left_vectors.binary_values();
  const auto& right_values = op_right_vectors.binary_values();

  // The distance functor is sized from the left operand only, so a shorter right operand
  // would be read past the end of its buffer. Reject mismatched operands up front.
  if (left_values.size() != right_values.size()) {
    std::string s = fmt::format("binary vector size not match, left size : {} right size : {}", left_values.size(),
                                right_values.size());
    DINGO_LOG(ERROR) << s;
    return butil::Status(pb::error::Errno::EILLEGAL_PARAMTETERS, s);
  }

  // Flattens the repeated one-byte strings into a contiguous byte buffer for faiss.
  auto pack_bytes = [](const auto& values) {
    std::vector<uint8_t> bytes(values.size());
    for (int j = 0; j < values.size(); j++) {
      bytes[j] = static_cast<uint8_t>(values[j][0]);
    }
    return bytes;
  };

  const std::vector<uint8_t> left_bytes = pack_bytes(left_values);
  const std::vector<uint8_t> right_bytes = pack_bytes(right_values);

  faiss::VectorDistance<faiss::MetricType::METRIC_HAMMING> vector_distance;
  vector_distance.d = left_values.size();

  distance = vector_distance(left_bytes.data(), right_bytes.data());

  ResultOpBinaryVectorAssignmentWrapper(op_left_vectors, op_right_vectors, is_return_normlize, result_op_left_vectors,
                                        result_op_right_vectors);

  return butil::Status();
}

butil::Status VectorIndexUtils::DoCalcL2DistanceByHnswlib(
const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
bool is_return_normlize,
Expand Down Expand Up @@ -386,6 +428,13 @@ void VectorIndexUtils::ResultOpVectorAssignment(dingodb::pb::common::Vector& res
result_op_vectors.set_value_type(::dingodb::pb::common::ValueType::FLOAT);
}

// Copies `op_vectors` into `result_op_vectors` and stamps the copy as a binary vector:
// value_type becomes UINT8 and dimension becomes the number of bits held, i.e. the number
// of binary_values() entries multiplied by CHAR_BIT.
void VectorIndexUtils::ResultOpBinaryVectorAssignment(dingodb::pb::common::Vector& result_op_vectors,
                                                      const ::dingodb::pb::common::Vector& op_vectors) {
  result_op_vectors = op_vectors;

  // Measured on the copy, after the assignment above.
  const auto bit_count = result_op_vectors.binary_values().size() * CHAR_BIT;

  result_op_vectors.set_value_type(::dingodb::pb::common::ValueType::UINT8);
  result_op_vectors.set_dimension(bit_count);
}

void VectorIndexUtils::ResultOpVectorAssignmentWrapper(const ::dingodb::pb::common::Vector& op_left_vectors,
const ::dingodb::pb::common::Vector& op_right_vectors,
bool is_return_normlize,
Expand All @@ -403,6 +452,23 @@ void VectorIndexUtils::ResultOpVectorAssignmentWrapper(const ::dingodb::pb::comm
}
}

// Optionally echoes the binary operands back to the caller.
//
// Nothing happens unless `is_return_normlize` is set. When it is, each result vector that
// does not yet carry any binary_values() is filled from the matching operand via
// ResultOpBinaryVectorAssignment; result vectors that already hold data are left as they are.
void VectorIndexUtils::ResultOpBinaryVectorAssignmentWrapper(
    const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
    bool is_return_normlize,
    dingodb::pb::common::Vector& result_op_left_vectors,   // NOLINT
    dingodb::pb::common::Vector& result_op_right_vectors)  // NOLINT
{                                                          // NOLINT
  if (!is_return_normlize) {
    return;
  }

  // Left side first, then right - same order as the checks were originally written.
  const bool left_is_unset = result_op_left_vectors.binary_values().empty();
  if (left_is_unset) {
    ResultOpBinaryVectorAssignment(result_op_left_vectors, op_left_vectors);
  }

  const bool right_is_unset = result_op_right_vectors.binary_values().empty();
  if (right_is_unset) {
    ResultOpBinaryVectorAssignment(result_op_right_vectors, op_right_vectors);
  }
}

void VectorIndexUtils::NormalizeVectorForFaiss(float* x, int32_t d) {
static const float kFloatAccuracy = 0.00001;

Expand Down Expand Up @@ -446,6 +512,10 @@ butil::Status VectorIndexUtils::CheckVectorDimension(const std::vector<pb::commo
DINGO_LOG(ERROR) << s;
return butil::Status(pb::error::Errno::EVECTOR_INVALID, s);
}
if (vector_with_id.vector().dimension() != dimension) {
std::string s = fmt::format("vector dimension not match, {} {}", vector_with_id.vector().dimension(), dimension);
return butil::Status(pb::error::Errno::EVECTOR_INVALID, s);
}
}

return butil::Status::OK();
Expand Down Expand Up @@ -486,9 +556,9 @@ template <typename T>
std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::common::VectorWithId>& vector_with_ids,
faiss::idx_t dimension, bool normalize) {
std::unique_ptr<T[]> vectors = nullptr;
if (std::is_same<T, float>::value) {
if constexpr (std::is_same<T, float>::value) {
vectors = std::make_unique<T[]>(vector_with_ids.size() * dimension);
} else if (std::is_same<T, uint8_t>::value) {
} else if constexpr (std::is_same<T, uint8_t>::value) {
vectors = std::make_unique<T[]>(vector_with_ids.size() * dimension / CHAR_BIT);
} else {
std::string s = fmt::format("invalid value typename type");
Expand All @@ -497,8 +567,8 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
}

for (size_t i = 0; i < vector_with_ids.size(); ++i) {
if (vector_with_ids[i].vector().value_type() == pb::common::ValueType::FLOAT) {
if (!std::is_same<T, float>::value) {
if constexpr (std::is_same<T, float>::value) {
if (vector_with_ids[i].vector().value_type() != pb::common::ValueType::FLOAT) {
std::string s = fmt::format("template not match vectors value_type : {}",
pb::common::ValueType_Name(vector_with_ids[i].vector().value_type()));
DINGO_LOG(ERROR) << s;
Expand All @@ -509,15 +579,17 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
if (normalize) {
VectorIndexUtils::NormalizeVectorForFaiss(reinterpret_cast<float*>(vectors.get()) + i * dimension, dimension);
}
} else if (vector_with_ids[i].vector().value_type() == pb::common::ValueType::UINT8) {
if (!std::is_same<T, uint8_t>::value) {
} else if constexpr (std::is_same<T, uint8_t>::value) {
if (vector_with_ids[i].vector().value_type() != pb::common::ValueType::UINT8) {
std::string s = fmt::format("template not match vectors value_type : {}",
pb::common::ValueType_Name(vector_with_ids[i].vector().value_type()));
DINGO_LOG(ERROR) << s;
return nullptr;
}
const auto& vector_value = vector_with_ids[i].vector().binary_values();
memcpy(vectors.get() + i * dimension / CHAR_BIT, vector_value.data(), dimension / CHAR_BIT);
for (int j = 0; j < vector_value.size(); j++) {
vectors.get()[i * dimension / CHAR_BIT + j] = static_cast<uint8_t>(vector_value[j][0]);
}
} else {
std::string s =
fmt::format("invalid value type : {}", pb::common::ValueType_Name(vector_with_ids[i].vector().value_type()));
Expand Down Expand Up @@ -855,8 +927,9 @@ butil::Status VectorIndexUtils::ValidateVectorIndexParameter(
!(ivf_flat_parameter.metric_type() == pb::common::METRIC_TYPE_INNER_PRODUCT) &&
!(ivf_flat_parameter.metric_type() == pb::common::METRIC_TYPE_L2)) {
DINGO_LOG(ERROR) << "ivf_flat_parameter.metric_type is illegal " << ivf_flat_parameter.metric_type();
return butil::Status(pb::error::Errno::EILLEGAL_PARAMTETERS,
"ivf_flat_parameter.metric_type is illegal " + std::to_string(ivf_flat_parameter.metric_type()));
return butil::Status(
pb::error::Errno::EILLEGAL_PARAMTETERS,
"ivf_flat_parameter.metric_type is illegal " + std::to_string(ivf_flat_parameter.metric_type()));
}

// check ivf_flat_parameter.ncentroids
Expand Down
25 changes: 23 additions & 2 deletions src/vector/vector_index_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ class VectorIndexUtils {
std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors,
std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors);

// Hamming-metric counterpart of the other Calc*DistanceByFaiss entry points.
// The definition forwards all arguments to CalcDistanceCore together with
// DoCalcHammingDistanceByFaiss as the per-pair routine.
// NOTE(review): layout of `distances` is determined by CalcDistanceCore - confirm there.
static butil::Status CalcHammingDistanceByFaiss(
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors,
bool is_return_normlize, std::vector<std::vector<float>>& distances,
std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors,
std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors);

static butil::Status CalcL2DistanceByHnswlib(
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors,
Expand Down Expand Up @@ -132,6 +139,12 @@ class VectorIndexUtils {
dingodb::pb::common::Vector& result_op_left_vectors,
dingodb::pb::common::Vector& result_op_right_vectors);

// Computes the Hamming distance for a single left/right pair of binary vectors.
// The definition packs the first character of every binary_values() entry into a byte
// buffer per operand, evaluates faiss::VectorDistance<METRIC_HAMMING> on the two buffers
// and stores the result in `distance`. When `is_return_normlize` is true the operands are
// also copied into the (still empty) result vectors.
static butil::Status DoCalcHammingDistanceByFaiss(const ::dingodb::pb::common::Vector& op_left_vectors,
const ::dingodb::pb::common::Vector& op_right_vectors,
bool is_return_normlize, float& distance,
dingodb::pb::common::Vector& result_op_left_vectors,
dingodb::pb::common::Vector& result_op_right_vectors);

static butil::Status DoCalcL2DistanceByHnswlib(const ::dingodb::pb::common::Vector& op_left_vectors,
const ::dingodb::pb::common::Vector& op_right_vectors,
bool is_return_normlize, float& distance,
Expand All @@ -152,18 +165,26 @@ class VectorIndexUtils {

static void ResultOpVectorAssignment(dingodb::pb::common::Vector& result_op_vectors,
const ::dingodb::pb::common::Vector& op_vectors);
// Binary-vector variant: copies `op_vectors` into `result_op_vectors`, then sets the copy's
// dimension to binary_values().size() * CHAR_BIT and its value_type to UINT8.
static void ResultOpBinaryVectorAssignment(dingodb::pb::common::Vector& result_op_vectors,
const ::dingodb::pb::common::Vector& op_vectors);

static void ResultOpVectorAssignmentWrapper(const ::dingodb::pb::common::Vector& op_left_vectors,
const ::dingodb::pb::common::Vector& op_right_vectors,
bool is_return_normlize,
dingodb::pb::common::Vector& result_op_left_vectors,
dingodb::pb::common::Vector& result_op_right_vectors);

// Binary-vector variant: when `is_return_normlize` is true, fills each result vector whose
// binary_values() is still empty from the matching operand via ResultOpBinaryVectorAssignment.
// Does nothing when `is_return_normlize` is false.
static void ResultOpBinaryVectorAssignmentWrapper(const ::dingodb::pb::common::Vector& op_left_vectors,
const ::dingodb::pb::common::Vector& op_right_vectors,
bool is_return_normlize,
dingodb::pb::common::Vector& result_op_left_vectors,
dingodb::pb::common::Vector& result_op_right_vectors);

static void NormalizeVectorForFaiss(float* x, int32_t d);
static void NormalizeVectorForHnsw(const float* data, uint32_t dimension, float* norm_array);

static butil::Status CheckVectorDimension(
const std::vector<pb::common::VectorWithId>& vector_with_ids, int dimension);
static butil::Status CheckVectorDimension(const std::vector<pb::common::VectorWithId>& vector_with_ids,
int dimension);

static std::unique_ptr<faiss::idx_t[]> CastVectorId(const std::vector<int64_t>& delete_ids);

Expand Down
76 changes: 76 additions & 0 deletions test/unit_test/vector/test_vector_index_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
#include <gtest/gtest.h>

#include <array>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <random>
#include <string>
#include <vector>

#include "butil/status.h"
Expand Down Expand Up @@ -4222,6 +4224,80 @@ TEST_F(VectorIndexUtilsTest, DoCalcCosineDistanceByFaiss) {
}
}

// Exercises VectorIndexUtils::DoCalcHammingDistanceByFaiss on one pair of random binary
// vectors and verifies that, with is_return_normlize set, both operands are echoed back
// intact: value_type UINT8, dimension equal to the bit count, payload bytes unchanged.
// The distance itself is only logged; its exact value depends on the faiss Hamming functor.
TEST_F(VectorIndexUtilsTest, DoCalcHammingDistanceByFaiss) {
  // ok
  {
    constexpr uint32_t kDimension = 16;
    constexpr size_t kByteCount = kDimension / CHAR_BIT;
    using ByteArray = std::array<uint8_t, kByteCount>;

    std::mt19937 rng;  // default-seeded, so the generated bytes are reproducible
    // uniform_int_distribution is not specified for 8-bit types; draw ints and narrow explicitly.
    std::uniform_int_distribution<int> distrib(0, 255);

    // Fills `data` with random bytes and logs them under `title`.
    auto fill_and_log = [&rng, &distrib](ByteArray& data, const std::string& title) {
      for (auto& elem : data) {
        elem = static_cast<uint8_t>(distrib(rng));
      }

      LOG(INFO) << title;
      for (const auto elem : data) {
        LOG(INFO) << std::setw(3) << static_cast<int32_t>(elem) << " ";
      }
    };

    ByteArray data_left{};
    fill_and_log(data_left, "left_data : \t");

    ByteArray data_right{};
    fill_and_log(data_right, "right_data : \t");

    ::dingodb::pb::common::Vector result_op_left_vectors;
    ::dingodb::pb::common::Vector result_op_right_vectors;
    ::dingodb::pb::common::Vector op_left_vectors;
    ::dingodb::pb::common::Vector op_right_vectors;
    bool is_return_normlize = true;
    float distance = 0.0f;

    op_left_vectors.set_value_type(::dingodb::pb::common::ValueType::UINT8);
    op_right_vectors.set_value_type(::dingodb::pb::common::ValueType::UINT8);

    // Each byte travels as its own one-character string entry.
    for (const auto elem : data_left) {
      op_left_vectors.add_binary_values(std::string(1, static_cast<char>(elem)));
    }

    for (const auto elem : data_right) {
      op_right_vectors.add_binary_values(std::string(1, static_cast<char>(elem)));
    }

    butil::Status ok =
        VectorIndexUtils::DoCalcHammingDistanceByFaiss(op_left_vectors, op_right_vectors, is_return_normlize, distance,
                                                       result_op_left_vectors, result_op_right_vectors);

    EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);
    LOG(INFO) << "DoCalcHammingDistanceByFaiss:distance:" << distance;

    // Checks and logs one echoed result vector against the bytes that were sent in.
    auto verify_and_log = [](const ::dingodb::pb::common::Vector& result, const ByteArray& expected,
                             const std::string& side) {
      EXPECT_EQ(result.value_type(), ::dingodb::pb::common::ValueType::UINT8);
      EXPECT_EQ(static_cast<int64_t>(result.dimension()), static_cast<int64_t>(kDimension));
      ASSERT_EQ(static_cast<size_t>(result.binary_values_size()), expected.size());

      LOG(INFO) << "DoCalcHammingDistanceByFaiss:" << side;
      LOG(INFO) << "DoCalcHammingDistanceByFaiss:value_type : " << result.value_type();
      LOG(INFO) << "DoCalcHammingDistanceByFaiss:dimension : " << result.dimension();
      LOG(INFO) << "DoCalcHammingDistanceByFaiss:data : \t\t";

      for (size_t idx = 0; idx < expected.size(); ++idx) {
        const std::string& entry = result.binary_values(static_cast<int>(idx));
        ASSERT_EQ(entry.size(), 1U);
        // Go through uint8_t so bytes >= 128 are compared and logged as 128..255, not negatives.
        const auto byte = static_cast<uint8_t>(entry[0]);
        EXPECT_EQ(byte, expected[idx]);
        LOG(INFO) << static_cast<int32_t>(byte) << " ";
      }
    };

    verify_and_log(result_op_left_vectors, data_left, "left");
    verify_and_log(result_op_right_vectors, data_right, "right");
  }
}

TEST_F(VectorIndexUtilsTest, DoCalcL2DistanceByHnswlib) {
// ok
{
Expand Down

0 comments on commit a5b1c14

Please sign in to comment.