Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions bindings/cpp/include/svs/runtime/vamana_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ struct SVS_RUNTIME_API VamanaIndex {
IDFilter* filter = nullptr
) const noexcept = 0;

// Compute distance between stored vector `id` and `query` (dim floats).
virtual Status
get_distance(double* distance, size_t id, const float* query) const noexcept = 0;

// Reconstruct `n` vectors by ID into `output` buffer (n * dim floats).
virtual Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept = 0;

// Utility function to check storage kind support
static Status check_storage_kind(StorageKind storage_kind) noexcept;

Expand Down
9 changes: 9 additions & 0 deletions bindings/cpp/src/dynamic_vamana_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ struct DynamicVamanaIndexManagerBase : public DynamicVamanaIndex {
Status save(std::ostream& out) const noexcept override {
return runtime_error_wrapper([&] { impl_->save(out); });
}

Status
get_distance(double* distance, size_t id, const float* query) const noexcept override {
return runtime_error_wrapper([&] { *distance = impl_->get_distance(id, query); });
}

Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); });
}
};
} // namespace

Expand Down
17 changes: 17 additions & 0 deletions bindings/cpp/src/dynamic_vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,23 @@ class DynamicVamanaIndexImpl {
return remove(ids_to_delete);
}

double get_distance(size_t id, const float* query) const {
if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
auto query_span = std::span<const float>(query, dim_);
return impl_->get_distance(id, query_span);
}

void reconstruct_at(size_t n, const size_t* ids, float* output) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of raw pointers here, I would recommend to follow approaches for existing methods' signatures defined in this class where arguments are typed.
See: void DynamicVamanaIndexImpl::add(data::ConstSimpleDataView<float> data, std::span<const size_t> labels)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, changed

if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
svs::data::SimpleDataView<float> dst{output, n, dim_};
std::span<const uint64_t> id_span{reinterpret_cast<const uint64_t*>(ids), n};
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if on x86-64 CPU size_t is equal to uint64_t, there is still the risk of reinterpret_cast here.
Would suggest type-safe conversion:

Suggested change
std::span<const uint64_t> id_span{reinterpret_cast<const uint64_t*>(ids), n};
std::vector<const uint64_t> id_vec(ids, ids + n);

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx, changed

impl_->reconstruct_at(dst, id_span);
}

void reset() {
impl_.reset();
ntotal_soft_deleted = 0;
Expand Down
9 changes: 9 additions & 0 deletions bindings/cpp/src/vamana_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ struct VamanaIndexManagerBase : public VamanaIndex {
Status save(std::ostream& out) const noexcept override {
return runtime_error_wrapper([&] { impl_->save(out); });
}

Status
get_distance(double* distance, size_t id, const float* query) const noexcept override {
return runtime_error_wrapper([&] { *distance = impl_->get_distance(id, query); });
}

Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override {
return runtime_error_wrapper([&] { impl_->reconstruct_at(n, ids, output); });
}
};
} // namespace

Expand Down
17 changes: 17 additions & 0 deletions bindings/cpp/src/vamana_index_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,23 @@ class VamanaIndexImpl {
}
}

double get_distance(size_t id, const float* query) const {
if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
auto query_span = std::span<const float>(query, dim_);
return get_impl()->get_distance(id, query_span);
}

void reconstruct_at(size_t n, const size_t* ids, float* output) {
if (!impl_) {
throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"};
}
svs::data::SimpleDataView<float> dst{output, n, dim_};
std::span<const uint64_t> id_span{reinterpret_cast<const uint64_t*>(ids), n};
get_impl()->reconstruct_at(dst, id_span);
}

void reset() { impl_.reset(); }

void save(std::ostream& out) const {
Expand Down
138 changes: 138 additions & 0 deletions bindings/cpp/tests/runtime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,141 @@ CATCH_TEST_CASE("RangeSearchFunctionalStatic", "[runtime][static_vamana]") {

svs::runtime::v0::VamanaIndex::destroy(index);
}

CATCH_TEST_CASE("GetDistanceDynamic", "[runtime]") {
const auto& test_data = get_test_data();
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::DynamicVamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

std::vector<size_t> labels(test_n);
std::iota(labels.begin(), labels.end(), 0);
status = index->add(test_n, labels.data(), test_data.data());
CATCH_REQUIRE(status.ok());

// Self-distance should be approximately 0
double dist = -1.0;
const float* vec0 = test_data.data();
status = index->get_distance(&dist, 0, vec0);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist < 1e-6);

// Distance to a different vector should be positive
const float* vec1 = test_data.data() + test_d;
status = index->get_distance(&dist, 0, vec1);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist > 0.0);

svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

CATCH_TEST_CASE("GetDistanceStatic", "[runtime][static_vamana]") {
const auto& test_data = get_test_data();
svs::runtime::v0::VamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::VamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

status = index->add(test_n, test_data.data());
CATCH_REQUIRE(status.ok());

// Self-distance should be approximately 0
double dist = -1.0;
const float* vec0 = test_data.data();
status = index->get_distance(&dist, 0, vec0);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist < 1e-6);

// Distance to a different vector should be positive
const float* vec1 = test_data.data() + test_d;
status = index->get_distance(&dist, 0, vec1);
CATCH_REQUIRE(status.ok());
CATCH_REQUIRE(dist > 0.0);

svs::runtime::v0::VamanaIndex::destroy(index);
}

CATCH_TEST_CASE("ReconstructAtDynamic", "[runtime]") {
const auto& test_data = get_test_data();
svs::runtime::v0::DynamicVamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::DynamicVamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

std::vector<size_t> labels(test_n);
std::iota(labels.begin(), labels.end(), 0);
status = index->add(test_n, labels.data(), test_data.data());
CATCH_REQUIRE(status.ok());

// Reconstruct first 5 vectors
constexpr size_t nrecon = 5;
std::vector<size_t> ids(nrecon);
std::iota(ids.begin(), ids.end(), 0);
std::vector<float> output(nrecon * test_d, 0.0f);

status = index->reconstruct_at(nrecon, ids.data(), output.data());
CATCH_REQUIRE(status.ok());

// For FP32 storage, reconstructed vectors should match originals exactly
for (size_t i = 0; i < nrecon; ++i) {
for (size_t j = 0; j < test_d; ++j) {
CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]);
}
}

svs::runtime::v0::DynamicVamanaIndex::destroy(index);
}

CATCH_TEST_CASE("ReconstructAtStatic", "[runtime][static_vamana]") {
const auto& test_data = get_test_data();
svs::runtime::v0::VamanaIndex* index = nullptr;
svs::runtime::v0::VamanaIndex::BuildParams build_params{64};
auto status = svs::runtime::v0::VamanaIndex::build(
&index,
test_d,
svs::runtime::v0::MetricType::L2,
svs::runtime::v0::StorageKind::FP32,
build_params
);
CATCH_REQUIRE(status.ok());

status = index->add(test_n, test_data.data());
CATCH_REQUIRE(status.ok());

// Reconstruct first 5 vectors
constexpr size_t nrecon = 5;
std::vector<size_t> ids(nrecon);
std::iota(ids.begin(), ids.end(), 0);
std::vector<float> output(nrecon * test_d, 0.0f);

status = index->reconstruct_at(nrecon, ids.data(), output.data());
CATCH_REQUIRE(status.ok());

// For FP32 storage, reconstructed vectors should match originals exactly
for (size_t i = 0; i < nrecon; ++i) {
for (size_t j = 0; j < test_d; ++j) {
CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]);
}
}

svs::runtime::v0::VamanaIndex::destroy(index);
}
Loading