diff --git a/bindings/cpp/include/svs/runtime/vamana_index.h b/bindings/cpp/include/svs/runtime/vamana_index.h index 2149c0d3..a419228a 100644 --- a/bindings/cpp/include/svs/runtime/vamana_index.h +++ b/bindings/cpp/include/svs/runtime/vamana_index.h @@ -84,6 +84,13 @@ struct SVS_RUNTIME_API VamanaIndex { IDFilter* filter = nullptr ) const noexcept = 0; + // Compute distance between stored vector `id` and `query` (dim floats). + virtual Status + get_distance(float* distance, size_t id, const float* query) const noexcept = 0; + + // Reconstruct `n` vectors by ID into `output` buffer (n * dim floats). + virtual Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept = 0; + // Utility function to check storage kind support static Status check_storage_kind(StorageKind storage_kind) noexcept; diff --git a/bindings/cpp/src/dynamic_vamana_index.cpp b/bindings/cpp/src/dynamic_vamana_index.cpp index 0c1a6a89..0fb1e7b6 100644 --- a/bindings/cpp/src/dynamic_vamana_index.cpp +++ b/bindings/cpp/src/dynamic_vamana_index.cpp @@ -118,6 +118,22 @@ struct DynamicVamanaIndexManagerBase : public DynamicVamanaIndex { Status save(std::ostream& out) const noexcept override { return runtime_error_wrapper([&] { impl_->save(out); }); } + + Status + get_distance(float* distance, size_t id, const float* query) const noexcept override { + return runtime_error_wrapper([&] { + std::span q{query, impl_->dimensions()}; + *distance = static_cast(impl_->get_distance(id, q)); + }); + } + + Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override { + return runtime_error_wrapper([&] { + svs::data::SimpleDataView dst{output, n, impl_->dimensions()}; + std::span id_span{ids, n}; + impl_->reconstruct_at(dst, id_span); + }); + } }; } // namespace diff --git a/bindings/cpp/src/dynamic_vamana_index_impl.h b/bindings/cpp/src/dynamic_vamana_index_impl.h index 516a1653..ee282b27 100644 --- a/bindings/cpp/src/dynamic_vamana_index_impl.h +++ b/bindings/cpp/src/dynamic_vamana_index_impl.h @@ -344,6 +344,21 @@ class DynamicVamanaIndexImpl { return remove(ids_to_delete); } + double get_distance(size_t id, std::span query) const { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + return impl_->get_distance(id, query); + } + + void reconstruct_at(svs::data::SimpleDataView dst, std::span ids) { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + std::vector id_vec(ids.begin(), ids.end()); + impl_->reconstruct_at(dst, std::span{id_vec}); + } + void reset() { impl_.reset(); ntotal_soft_deleted = 0; diff --git a/bindings/cpp/src/vamana_index.cpp b/bindings/cpp/src/vamana_index.cpp index c015dd21..8e089ab5 100644 --- a/bindings/cpp/src/vamana_index.cpp +++ b/bindings/cpp/src/vamana_index.cpp @@ -88,6 +88,22 @@ struct VamanaIndexManagerBase : public VamanaIndex { Status save(std::ostream& out) const noexcept override { return runtime_error_wrapper([&] { impl_->save(out); }); } + + Status + get_distance(float* distance, size_t id, const float* query) const noexcept override { + return runtime_error_wrapper([&] { + std::span q{query, impl_->dimensions()}; + *distance = static_cast(impl_->get_distance(id, q)); + }); + } + + Status reconstruct_at(size_t n, const size_t* ids, float* output) noexcept override { + return runtime_error_wrapper([&] { + svs::data::SimpleDataView dst{output, n, impl_->dimensions()}; + std::span id_span{ids, n}; + impl_->reconstruct_at(dst, id_span); + }); + } }; } // namespace diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 550257d4..8a321620 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -307,6 +307,21 @@ class VamanaIndexImpl { } } + double get_distance(size_t id, std::span query) const { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + return get_impl()->get_distance(id, query); + } + + void reconstruct_at(svs::data::SimpleDataView dst, std::span ids) { + if (!impl_) { + throw StatusException{ErrorCode::NOT_INITIALIZED, "Index not initialized"}; + } + std::vector id_vec(ids.begin(), ids.end()); + get_impl()->reconstruct_at(dst, std::span{id_vec}); + } + void reset() { impl_.reset(); } void save(std::ostream& out) const { get_impl()->save(out); } diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index abd14296..db5ff99b 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -997,3 +997,141 @@ CATCH_TEST_CASE("RangeSearchFunctionalStatic", "[runtime][static_vamana]") { svs::runtime::v0::VamanaIndex::destroy(index); } + +CATCH_TEST_CASE("GetDistanceDynamic", "[runtime]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Self-distance should be approximately 0 + float dist = -1.0f; + const float* vec0 = test_data.data(); + status = index->get_distance(&dist, 0, vec0); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist < 1e-6); + + // Distance to a different vector should be positive + const float* vec1 = test_data.data() + test_d; + status = index->get_distance(&dist, 0, vec1); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist > 0.0); + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("GetDistanceStatic", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::VamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::VamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Self-distance should be approximately 0 + float dist = -1.0f; + const float* vec0 = test_data.data(); + status = index->get_distance(&dist, 0, vec0); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist < 1e-6); + + // Distance to a different vector should be positive + const float* vec1 = test_data.data() + test_d; + status = index->get_distance(&dist, 0, vec1); + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(dist > 0.0); + + svs::runtime::v0::VamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("ReconstructAtDynamic", "[runtime]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::DynamicVamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::DynamicVamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + std::vector labels(test_n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(test_n, labels.data(), test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Reconstruct first 5 vectors + constexpr size_t nrecon = 5; + std::vector ids(nrecon); + std::iota(ids.begin(), ids.end(), 0); + std::vector output(nrecon * test_d, 0.0f); + + status = index->reconstruct_at(nrecon, ids.data(), output.data()); + CATCH_REQUIRE(status.ok()); + + // For FP32 storage, reconstructed vectors should match originals exactly + for (size_t i = 0; i < nrecon; ++i) { + for (size_t j = 0; j < test_d; ++j) { + CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]); + } + } + + svs::runtime::v0::DynamicVamanaIndex::destroy(index); +} + +CATCH_TEST_CASE("ReconstructAtStatic", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + svs::runtime::v0::VamanaIndex* index = nullptr; + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + auto status = svs::runtime::v0::VamanaIndex::build( + &index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + CATCH_REQUIRE(status.ok()); + + status = index->add(test_n, test_data.data()); + CATCH_REQUIRE(status.ok()); + + // Reconstruct first 5 vectors + constexpr size_t nrecon = 5; + std::vector ids(nrecon); + std::iota(ids.begin(), ids.end(), 0); + std::vector output(nrecon * test_d, 0.0f); + + status = index->reconstruct_at(nrecon, ids.data(), output.data()); + CATCH_REQUIRE(status.ok()); + + // For FP32 storage, reconstructed vectors should match originals exactly + for (size_t i = 0; i < nrecon; ++i) { + for (size_t j = 0; j < test_d; ++j) { + CATCH_REQUIRE(output[i * test_d + j] == test_data[i * test_d + j]); + } + } + + svs::runtime::v0::VamanaIndex::destroy(index); +}