diff --git a/docs/src/format/index/scalar/fts.md b/docs/src/format/index/scalar/fts.md index adc7f94d65e..fc1e97f678c 100644 --- a/docs/src/format/index/scalar/fts.md +++ b/docs/src/format/index/scalar/fts.md @@ -43,6 +43,8 @@ An FTS index may contain multiple partitions. Each partition has its own set of | `_length` | UInt32 | false | Number of documents containing the token | | `_compressed_position` | List> | true | Optional compressed position lists for phrase queries | +The posting-list file schema metadata includes `posting_block_size`, the number of documents encoded per compressed posting block. Older indexes that do not have this metadata use the legacy block size `128`. + ### Metadata File Schema The metadata file contains JSON-serialized configuration and partition information: @@ -67,6 +69,7 @@ The metadata file contains JSON-serialized configuration and partition informati | `min_gram` | UInt32 | 2 | Minimum n-gram length (only for ngram tokenizer) | | `max_gram` | UInt32 | 15 | Maximum n-gram length (only for ngram tokenizer) | | `prefix_only` | Boolean | false | Generate only prefix n-grams (only for ngram tokenizer) | +| `block_size` | UInt32 | 256 | Documents per compressed posting block. Must be 128 or 256. Missing values from older indexes read as 128. | ## Tokenizers diff --git a/docs/src/quickstart/full-text-search.md b/docs/src/quickstart/full-text-search.md index f990b2bd589..777a8310ad2 100644 --- a/docs/src/quickstart/full-text-search.md +++ b/docs/src/quickstart/full-text-search.md @@ -98,6 +98,7 @@ ds.create_scalar_index( remove_stop_words=True, # Remove stop words (language-dependent) custom_stop_words=None, # Optional additional stop words (only used if remove_stop_words=True) ascii_folding=True, # Fold accents to ASCII when possible (e.g., "é" -> "e") + block_size=256, # Posting block size: 128 or 256 ) ``` diff --git a/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java b/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java index f509efd922b..97d75bfc20e 100755 --- a/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java +++ b/java/src/main/java/org/lance/index/scalar/InvertedIndexParams.java @@ -53,6 +53,7 @@ public static final class Builder { private Integer minNgramLength; private Integer maxNgramLength; private Boolean prefixOnly; + private Integer blockSize = 256; private Boolean skipMerge; /** @@ -225,6 +226,24 @@ public Builder prefixOnly(boolean prefixOnly) { return this; } + /** + * Configure the number of documents in each compressed posting block. + * + *

Supported values are {@code 128} and {@code 256}. New indexes default to {@code 256} when + * this is not set. + * + * @param blockSize posting block size + * @return this builder + * @throws IllegalArgumentException if {@code blockSize} is unsupported + */ + public Builder blockSize(int blockSize) { + if (blockSize != 128 && blockSize != 256) { + throw new IllegalArgumentException("blockSize must be one of 128 or 256"); + } + this.blockSize = blockSize; + return this; + } + /** * Configure whether to skip the partition merge stage after indexing. If true, skip the * partition merge stage after indexing. This can be useful for distributed indexing where merge @@ -282,6 +301,9 @@ public ScalarIndexParams build() { if (prefixOnly != null) { params.put("prefix_only", prefixOnly); } + if (blockSize != null) { + params.put("block_size", blockSize); + } if (skipMerge != null) { params.put("skip_merge", skipMerge); } diff --git a/java/src/test/java/org/lance/index/scalar/InvertedIndexParamsTest.java b/java/src/test/java/org/lance/index/scalar/InvertedIndexParamsTest.java index e5024a95c2a..20001f04e4e 100644 --- a/java/src/test/java/org/lance/index/scalar/InvertedIndexParamsTest.java +++ b/java/src/test/java/org/lance/index/scalar/InvertedIndexParamsTest.java @@ -13,19 +13,49 @@ */ package org.lance.index.scalar; +import org.lance.util.JsonUtils; + import org.junit.jupiter.api.Test; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -public class InvertedIndexParamsTest { +class InvertedIndexParamsTest { @Test - public void testIcuSplitTokenizerVariant() { + void testIcuSplitTokenizerVariant() { ScalarIndexParams params = InvertedIndexParams.builder().baseTokenizer("icu/split").build(); assertEquals("inverted", params.getIndexType()); String jsonParams = params.getJsonParams().orElseThrow(AssertionError::new); assertTrue(jsonParams.contains("\"base_tokenizer\":\"icu/split\"")); } + + @Test + void defaultBlockSizeIsSerialized() { + ScalarIndexParams params = InvertedIndexParams.builder().build(); + + Map json = JsonUtils.fromJson(params.getJsonParams().orElseThrow()); + assertEquals(256, ((Number) json.get("block_size")).intValue()); + } + + @Test + void blockSizeIsSerialized() { + ScalarIndexParams params = InvertedIndexParams.builder().blockSize(128).build(); + + assertEquals("inverted", params.getIndexType()); + Map json = JsonUtils.fromJson(params.getJsonParams().orElseThrow()); + assertEquals(128, ((Number) json.get("block_size")).intValue()); + } + + @Test + void invalidBlockSizeIsRejected() { + assertThrows( + IllegalArgumentException.class, () -> InvertedIndexParams.builder().blockSize(129)); + assertThrows( + IllegalArgumentException.class, () -> InvertedIndexParams.builder().blockSize(512)); + } } diff --git a/protos/index_old.proto b/protos/index_old.proto index 601aa2681da..94429e8d964 100644 --- a/protos/index_old.proto +++ b/protos/index_old.proto @@ -39,4 +39,5 @@ message InvertedIndexDetails { uint32 min_ngram_length = 9; uint32 max_ngram_length = 10; bool prefix_only = 11; + optional uint32 block_size = 12; } diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 833d0f1321d..03ec1c52d04 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -3315,6 +3315,9 @@ def create_scalar_index( query. This will significantly increase the index size. It won't impact the performance of non-phrase queries even if it is set to True. + block_size: int, default 256 + This is for the ``INVERTED`` index. Number of documents per compressed + posting block. Must be one of ``128`` or ``256``. memory_limit: int, optional This is for the ``INVERTED`` index. Total build-time memory limit in MiB. If set, Lance divides this budget evenly across the workers. If unset, diff --git a/python/python/tests/compat/test_scalar_indices.py b/python/python/tests/compat/test_scalar_indices.py index 35022df3b12..67f07299e6e 100644 --- a/python/python/tests/compat/test_scalar_indices.py +++ b/python/python/tests/compat/test_scalar_indices.py @@ -9,6 +9,7 @@ and written by other versions. """ +import os import shutil from pathlib import Path @@ -320,7 +321,12 @@ def create(self): max_rows_per_file=100, data_storage_version=safe_data_storage_version(self.compat_version), ) - dataset.create_scalar_index("text", "INVERTED", with_position=True) + kwargs = {"with_position": True} + # Downgrade reads use older wheels, so current-created FTS indexes must + # stay on the legacy posting block layout. + if os.environ.get("LANCE_COMPAT_FTS_LEGACY_BLOCK_SIZE") == "1": + kwargs["block_size"] = 128 + dataset.create_scalar_index("text", "INVERTED", **kwargs) def check_read(self): """Verify FTS index can be queried.""" @@ -351,7 +357,10 @@ def skip_downgrade(self, version: str) -> bool: def current_env(self, method_name: str) -> dict[str, str]: if method_name == "create": - return {"LANCE_FTS_FORMAT_VERSION": "1"} + return { + "LANCE_COMPAT_FTS_LEGACY_BLOCK_SIZE": "1", + "LANCE_FTS_FORMAT_VERSION": "1", + } if method_name == "check_write": return {"LANCE_FTS_FORMAT_VERSION": "2"} return {} diff --git a/python/python/tests/test_scalar_index.py b/python/python/tests/test_scalar_index.py index 4450f6821cc..ec7c4985c3b 100644 --- a/python/python/tests/test_scalar_index.py +++ b/python/python/tests/test_scalar_index.py @@ -949,6 +949,26 @@ def test_create_scalar_index_fts_alias(dataset): assert any(idx.index_type == "Inverted" for idx in dataset.describe_indices()) +def test_create_scalar_index_fts_block_size(dataset): + dataset.create_scalar_index( + "doc", index_type="INVERTED", with_position=False, block_size=256 + ) + row = dataset.take(indices=[0], columns=["doc"]) + query = row.column(0)[0].as_py().split(" ")[0] + results = dataset.scanner(columns=["doc"], full_text_query=query).to_table() + assert results.num_rows > 0 + + with pytest.raises(ValueError, match="block_size"): + dataset.create_scalar_index( + "doc", index_type="INVERTED", name="doc_invalid_129", block_size=129 + ) + + with pytest.raises(ValueError, match="block_size"): + dataset.create_scalar_index( + "doc", index_type="INVERTED", name="doc_invalid_512", block_size=512 + ) + + def test_multi_index_create(tmp_path): dataset = lance.write_dataset( pa.table({"ints": range(1024)}), tmp_path, max_rows_per_file=100 diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 7c525b0aefe..c2c936911d5 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -2429,6 +2429,11 @@ impl Dataset { if let Some(prefix_only) = kwargs.get_item("prefix_only")? { params = params.ngram_prefix_only(prefix_only.extract()?); } + if let Some(block_size) = kwargs.get_item("block_size")? { + params = params + .block_size(block_size.extract()?) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + } if let Some(memory_limit) = kwargs.get_item("memory_limit")? { params = params.memory_limit_mb(memory_limit.extract()?); } diff --git a/rust/compression/bitpacking/src/bitpacker_internal/bitpacker8x.rs b/rust/compression/bitpacking/src/bitpacker_internal/bitpacker8x.rs index 188a1f4ce2a..b17edacabb2 100644 --- a/rust/compression/bitpacking/src/bitpacker_internal/bitpacker8x.rs +++ b/rust/compression/bitpacking/src/bitpacker_internal/bitpacker8x.rs @@ -420,12 +420,11 @@ enum InstructionSet { Scalar, } -/// Internal 8-wide bitpacker implementation. +/// 8-wide bitpacker implementation. /// -/// One block contains 256 integers. This stays private to avoid exposing a new -/// block-size choice through the public Lance bitpacking API. +/// One block contains 256 integers. #[derive(Clone, Copy)] -pub(crate) struct BitPacker8x(InstructionSet); +pub struct BitPacker8x(InstructionSet); impl BitPacker8x { #[cfg(target_arch = "x86_64")] diff --git a/rust/compression/bitpacking/src/bitpacker_internal/mod.rs b/rust/compression/bitpacking/src/bitpacker_internal/mod.rs index c287a29da0b..80803e50ec8 100644 --- a/rust/compression/bitpacking/src/bitpacker_internal/mod.rs +++ b/rust/compression/bitpacking/src/bitpacker_internal/mod.rs @@ -20,6 +20,7 @@ mod bitpacker4x; mod bitpacker8x; pub use bitpacker4x::BitPacker4x; +pub use bitpacker8x::BitPacker8x; pub(crate) trait Available { fn available() -> bool; diff --git a/rust/compression/bitpacking/src/lib.rs b/rust/compression/bitpacking/src/lib.rs index df2313bed5d..c4163a5324d 100644 --- a/rust/compression/bitpacking/src/lib.rs +++ b/rust/compression/bitpacking/src/lib.rs @@ -18,7 +18,7 @@ use core::mem::size_of; mod bitpacker_internal; -pub use bitpacker_internal::{BitPacker, BitPacker4x}; +pub use bitpacker_internal::{BitPacker, BitPacker4x, BitPacker8x}; pub const FL_ORDER: [usize; 8] = [0, 4, 2, 6, 1, 5, 3, 7]; diff --git a/rust/lance-index/protos-cache/cache.proto b/rust/lance-index/protos-cache/cache.proto index b24a27055d7..e92da05815f 100644 --- a/rust/lance-index/protos-cache/cache.proto +++ b/rust/lance-index/protos-cache/cache.proto @@ -28,6 +28,9 @@ message CompressedPostingHeader { PositionStorage position_storage = 4; // Only meaningful when position_storage == POSITION_STORAGE_SHARED. PositionStreamCodec position_stream_codec = 5; + // Number of documents in each compressed posting block. Older cache entries + // omit this field and decode as the legacy 128-doc block size. + uint32 block_size = 6; } // Header for a serialized `PlainPostingList` cache entry. Followed by an Arrow diff --git a/rust/lance-index/src/scalar/inverted.rs b/rust/lance-index/src/scalar/inverted.rs index d0bb0e40d3a..8300dc1043c 100644 --- a/rust/lance-index/src/scalar/inverted.rs +++ b/rust/lance-index/src/scalar/inverted.rs @@ -215,7 +215,7 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { .into())) } - let params = serde_json::from_str::(params)?; + let params = InvertedIndexParams::from_training_json(params)?; Ok(Box::new(InvertedIndexTrainingRequest::new(params))) } @@ -308,6 +308,7 @@ impl ScalarIndexPlugin for InvertedIndexPlugin { #[cfg(test)] mod tests { use super::*; + use crate::scalar::{BuiltinIndexType, ScalarIndexParams}; #[test] fn test_plugin_version_tracks_max_supported_format() { @@ -317,4 +318,36 @@ mod tests { max_supported_fts_format_version().index_version() ); } + + #[test] + fn test_new_training_request_defaults_missing_block_size_to_256() { + let plugin = InvertedIndexPlugin; + let field = Field::new("text", DataType::Utf8, true); + + let cases = [ + ( + ScalarIndexParams::for_builtin(BuiltinIndexType::Inverted), + false, + ), + (ScalarIndexParams::new("inverted".to_string()), false), + ( + ScalarIndexParams::new("inverted".to_string()) + .with_params(&serde_json::json!({ "with_position": true })), + true, + ), + ]; + + for (params, expected_with_position) in cases { + let request = plugin + .new_training_request(params.params.as_deref().unwrap_or("{}"), &field) + .unwrap(); + let request = request + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(request.parameters.posting_block_size(), DEFAULT_BLOCK_SIZE); + assert_eq!(request.parameters.has_positions(), expected_with_position); + } + } } diff --git a/rust/lance-index/src/scalar/inverted/builder.rs b/rust/lance-index/src/scalar/inverted/builder.rs index fe72540b422..504f1ab1b8a 100644 --- a/rust/lance-index/src/scalar/inverted/builder.rs +++ b/rust/lance-index/src/scalar/inverted/builder.rs @@ -6,6 +6,7 @@ use super::{InvertedIndexParams, index::*}; use crate::scalar::inverted::document_tokenizer::DocType; use crate::scalar::inverted::json::JsonTextStream; use crate::scalar::inverted::tokenizer::document_tokenizer::LanceTokenizer; +use crate::scalar::inverted::tokenizer::{LEGACY_BLOCK_SIZE, validate_block_size}; #[cfg(test)] use crate::scalar::lance_format::LanceIndexStore; use crate::scalar::{IndexFile, IndexStore, OldIndexDataFilter}; @@ -41,9 +42,9 @@ use std::task::{Context, Poll}; use std::{fmt::Debug, sync::atomic::AtomicU64}; use tracing::instrument; -// the number of elements in each block -// each block contains 128 row ids and 128 frequencies -// WARNING: changing this value will break the compatibility with existing indexes +// The legacy bitpacking block size. Position streams still use this block size; +// FTS posting blocks choose their physical bitpacker from the configured +// InvertedIndexParams::block_size. pub const BLOCK_SIZE: usize = BitPacker4x::BLOCK_LEN; // The default number of workers to use for FTS builds. @@ -448,6 +449,7 @@ impl InvertedIndexBuilder { fragment_mask: self.fragment_mask, token_set_format: self.token_set_format, worker_memory_limit_bytes, + block_size: self.params.block_size, }; let next_id = self.next_partition_id(); let id_alloc = Arc::new(AtomicU64::new(next_id)); @@ -794,6 +796,7 @@ pub struct InnerBuilder { token_set_format: TokenSetFormat, format_version: InvertedListFormatVersion, posting_tail_codec: PostingTailCodec, + block_size: usize, pub(crate) tokens: TokenSet, pub(crate) posting_lists: Vec, pub(crate) docs: DocSet, @@ -816,12 +819,45 @@ impl InnerBuilder { token_set_format: TokenSetFormat, format_version: InvertedListFormatVersion, ) -> Self { + Self::new_with_format_version_and_block_size( + id, + with_position, + token_set_format, + format_version, + LEGACY_BLOCK_SIZE, + ) + } + + pub fn new_with_block_size( + id: u64, + with_position: bool, + token_set_format: TokenSetFormat, + block_size: usize, + ) -> Self { + Self::new_with_format_version_and_block_size( + id, + with_position, + token_set_format, + current_fts_format_version(), + block_size, + ) + } + + pub fn new_with_format_version_and_block_size( + id: u64, + with_position: bool, + token_set_format: TokenSetFormat, + format_version: InvertedListFormatVersion, + block_size: usize, + ) -> Self { + validate_block_size(block_size).expect("invalid posting list block size"); Self { id, with_position, token_set_format, format_version, posting_tail_codec: format_version.posting_tail_codec(), + block_size, tokens: TokenSet::default(), posting_lists: Vec::new(), docs: DocSet::default(), @@ -834,14 +870,35 @@ impl InnerBuilder { with_position: bool, token_set_format: TokenSetFormat, posting_tail_codec: PostingTailCodec, + ) -> Self { + Self::new_with_posting_tail_codec_and_block_size( + id, + with_position, + token_set_format, + posting_tail_codec, + LEGACY_BLOCK_SIZE, + ) + } + + pub fn new_with_posting_tail_codec_and_block_size( + id: u64, + with_position: bool, + token_set_format: TokenSetFormat, + posting_tail_codec: PostingTailCodec, + block_size: usize, ) -> Self { let format_version = if posting_tail_codec == PostingTailCodec::Fixed32 { InvertedListFormatVersion::V1 } else { InvertedListFormatVersion::V2 }; - let mut builder = - Self::new_with_format_version(id, with_position, token_set_format, format_version); + let mut builder = Self::new_with_format_version_and_block_size( + id, + with_position, + token_set_format, + format_version, + block_size, + ); builder.posting_tail_codec = posting_tail_codec; builder } @@ -922,6 +979,7 @@ impl InnerBuilder { token_set_format, format_version, posting_tail_codec, + block_size, tokens, posting_lists, docs, @@ -952,6 +1010,12 @@ impl InnerBuilder { self.posting_tail_codec, posting_tail_codec ))); } + if self.block_size != block_size { + return Err(Error::index(format!( + "cannot merge partitions with mismatched FTS block sizes: {} vs {}", + self.block_size, block_size + ))); + } let mut token_id_map = vec![u32::MAX; posting_lists.len()]; match tokens.tokens { @@ -977,7 +1041,11 @@ impl InnerBuilder { self.docs.append(*row_id, *num_tokens); } self.posting_lists.resize_with(self.tokens.len(), || { - PostingListBuilder::new_with_posting_tail_codec(with_position, self.posting_tail_codec) + PostingListBuilder::new_with_posting_tail_codec_and_block_size( + with_position, + self.posting_tail_codec, + self.block_size, + ) }); for (token_id, posting_list) in posting_lists.into_iter().enumerate() { @@ -1044,7 +1112,11 @@ impl InnerBuilder { let mut writer = store .new_index_file( path, - inverted_list_schema_for_version(self.with_position, self.format_version), + inverted_list_schema_for_version_with_block_size( + self.with_position, + self.format_version, + self.block_size, + ), ) .await?; let posting_lists = std::mem::take(&mut self.posting_lists); @@ -1058,7 +1130,11 @@ impl InnerBuilder { let with_position = self.with_position; let format_version = self.format_version; let group_config = self.group_config; - let schema = inverted_list_schema_for_version(self.with_position, self.format_version); + let schema = inverted_list_schema_for_version_with_block_size( + self.with_position, + self.format_version, + self.block_size, + ); let docs_for_batches = docs.clone(); let schema_for_batches = schema.clone(); let batch_rows = *LANCE_FTS_POSTING_BATCH_ROWS; @@ -1214,6 +1290,7 @@ struct IndexWorkerConfig { fragment_mask: Option, token_set_format: TokenSetFormat, worker_memory_limit_bytes: u64, + block_size: usize, } impl IndexWorker { @@ -1257,17 +1334,22 @@ impl IndexWorker { id_alloc: Arc, config: IndexWorkerConfig, ) -> Result { - let schema = inverted_list_schema_for_version(config.with_position, config.format_version); + let schema = inverted_list_schema_for_version_with_block_size( + config.with_position, + config.format_version, + config.block_size, + ); Ok(Self { tokenizer, dest_store, - builder: InnerBuilder::new_with_format_version( + builder: InnerBuilder::new_with_format_version_and_block_size( id_alloc.fetch_add(1, std::sync::atomic::Ordering::Relaxed) | config.fragment_mask.unwrap_or(0), config.with_position, config.token_set_format, config.format_version, + config.block_size, ), partitions: Vec::new(), files: Vec::new(), @@ -1326,9 +1408,10 @@ impl IndexWorker { * std::mem::size_of::()) as u64; builder.posting_lists.push( - PostingListBuilder::new_with_posting_tail_codec( + PostingListBuilder::new_with_posting_tail_codec_and_block_size( true, posting_tail_codec, + builder.block_size, ), ); let new_posting_lists_overhead_size = (builder.posting_lists.capacity() @@ -1374,9 +1457,10 @@ impl IndexWorker { self.builder .posting_lists .resize_with(self.builder.tokens.len(), || { - PostingListBuilder::new_with_posting_tail_codec( + PostingListBuilder::new_with_posting_tail_codec_and_block_size( false, self.builder.posting_tail_codec, + self.builder.block_size, ) }); let new_posting_lists_overhead_size = self.posting_lists_overhead_size(); @@ -1481,15 +1565,17 @@ impl IndexWorker { self.memory_size = self.temporary_memory_size(); let with_position = self.has_position(); let format_version = self.builder.format_version; + let block_size = self.builder.block_size; let builder = std::mem::replace( &mut self.builder, - InnerBuilder::new_with_format_version( + InnerBuilder::new_with_format_version_and_block_size( self.id_alloc .fetch_add(1, std::sync::atomic::Ordering::Relaxed) | self.fragment_mask.unwrap_or(0), with_position, self.token_set_format, format_version, + block_size, ), ); let written_partition_id = builder.id(); @@ -1609,17 +1695,31 @@ pub fn inverted_list_schema_for_version( with_position: bool, format_version: InvertedListFormatVersion, ) -> SchemaRef { + inverted_list_schema_for_version_with_block_size( + with_position, + format_version, + LEGACY_BLOCK_SIZE, + ) +} + +pub fn inverted_list_schema_for_version_with_block_size( + with_position: bool, + format_version: InvertedListFormatVersion, + block_size: usize, +) -> SchemaRef { + validate_block_size(block_size).expect("invalid posting list block size"); match format_version { - InvertedListFormatVersion::V1 => inverted_list_schema_v1(with_position), + InvertedListFormatVersion::V1 => inverted_list_schema_v1(with_position, block_size), InvertedListFormatVersion::V2 => inverted_list_schema_with_tail_codec_and_position_codec( with_position, PostingTailCodec::VarintDelta, Some(PositionStreamCodec::PackedDelta), + block_size, ), } } -fn inverted_list_schema_v1(with_position: bool) -> SchemaRef { +fn inverted_list_schema_v1(with_position: bool, block_size: usize) -> SchemaRef { let mut fields = vec![ arrow_schema::Field::new( POSTING_COL, @@ -1648,7 +1748,10 @@ fn inverted_list_schema_v1(with_position: bool) -> SchemaRef { false, )); } - Arc::new(arrow_schema::Schema::new(fields)) + Arc::new(arrow_schema::Schema::new_with_metadata( + fields, + HashMap::from([(POSTING_BLOCK_SIZE_KEY.to_owned(), block_size.to_string())]), + )) } pub fn inverted_list_schema_with_tail_codec( @@ -1659,6 +1762,7 @@ pub fn inverted_list_schema_with_tail_codec( with_position, posting_tail_codec, Some(PositionStreamCodec::PackedDelta), + LEGACY_BLOCK_SIZE, ) } @@ -1666,6 +1770,7 @@ fn inverted_list_schema_with_tail_codec_and_position_codec( with_position: bool, posting_tail_codec: PostingTailCodec, position_codec: Option, + block_size: usize, ) -> SchemaRef { let mut fields = vec![ // we compress the posting lists (including row ids and frequencies), @@ -1702,6 +1807,7 @@ fn inverted_list_schema_with_tail_codec_and_position_codec( POSTING_TAIL_CODEC_KEY.to_owned(), posting_tail_codec.as_str().to_owned(), )]); + metadata.insert(POSTING_BLOCK_SIZE_KEY.to_owned(), block_size.to_string()); if let Some(position_codec) = position_codec.filter(|_| with_position) { metadata.insert( POSITIONS_LAYOUT_KEY.to_owned(), @@ -3184,6 +3290,7 @@ mod tests { fragment_mask: None, token_set_format, worker_memory_limit_bytes: u64::MAX, + block_size: params.block_size, }, ) .await?; @@ -3207,6 +3314,7 @@ mod tests { fragment_mask: None, token_set_format, worker_memory_limit_bytes: u64::MAX, + block_size: params.block_size, }, ) .await?; @@ -3680,6 +3788,7 @@ mod tests { fragment_mask: None, token_set_format: TokenSetFormat::default(), worker_memory_limit_bytes: u64::MAX, + block_size: InvertedIndexParams::default().block_size, }, ) .await?; @@ -3711,6 +3820,7 @@ mod tests { fragment_mask: None, token_set_format: TokenSetFormat::default(), worker_memory_limit_bytes: u64::MAX, + block_size: InvertedIndexParams::default().block_size, }, ) .await?; @@ -3749,6 +3859,7 @@ mod tests { fragment_mask: None, token_set_format: TokenSetFormat::default(), worker_memory_limit_bytes: u64::MAX, + block_size: InvertedIndexParams::default().block_size, }, ) .await?; diff --git a/rust/lance-index/src/scalar/inverted/cache_codec.rs b/rust/lance-index/src/scalar/inverted/cache_codec.rs index a676455d5c9..fffb080bda1 100644 --- a/rust/lance-index/src/scalar/inverted/cache_codec.rs +++ b/rust/lance-index/src/scalar/inverted/cache_codec.rs @@ -43,6 +43,7 @@ use super::index::{ CompressedPositionStorage, CompressedPostingList, PlainPostingList, PositionStreamCodec, Positions, PostingList, PostingListGroup, PostingTailCodec, SharedPositionStream, }; +use super::tokenizer::{LEGACY_BLOCK_SIZE, validate_block_size}; // --------------------------------------------------------------------------- // Tags @@ -291,6 +292,7 @@ fn serialize_compressed( posting_tail_codec: posting_tail_codec_to_proto(posting.posting_tail_codec) as i32, position_storage: position_storage as i32, position_stream_codec: position_stream_codec as i32, + block_size: posting.block_size as u32, }; w.write_header(&header)?; @@ -322,12 +324,18 @@ fn deserialize_compressed(r: &mut CacheEntryReader<'_>) -> Result { assert_eq!(restored.max_score, posting.max_score); assert_eq!(restored.length, posting.length); assert_eq!(restored.posting_tail_codec, posting.posting_tail_codec); + assert_eq!(restored.block_size, posting.block_size); assert_eq!(restored.blocks, posting.blocks); assert!(restored.positions.is_none()); } @@ -555,6 +564,7 @@ mod tests { 1.25, 5, PostingTailCodec::Fixed32, + crate::scalar::inverted::LEGACY_BLOCK_SIZE, Some(CompressedPositionStorage::LegacyPerDoc(legacy_positions( &[&[0, 4, 8]], ))), @@ -589,6 +599,7 @@ mod tests { 7.0, 3, PostingTailCodec::VarintDelta, + 256, Some(CompressedPositionStorage::SharedStream(stream)), ); let entry = PostingList::Compressed(posting.clone()); @@ -617,6 +628,7 @@ mod tests { 7.0, 3, PostingTailCodec::VarintDelta, + 256, Some(CompressedPositionStorage::SharedStream( expected_stream.clone(), )), @@ -651,6 +663,7 @@ mod tests { 2.5, 7, PostingTailCodec::VarintDelta, + 256, None, )); @@ -760,6 +773,7 @@ mod tests { 7.0, 3, PostingTailCodec::VarintDelta, + 256, Some(CompressedPositionStorage::SharedStream(stream)), )) } @@ -810,6 +824,7 @@ mod tests { 7.0, 3, PostingTailCodec::VarintDelta, + 256, None, )) }; diff --git a/rust/lance-index/src/scalar/inverted/encoding.rs b/rust/lance-index/src/scalar/inverted/encoding.rs index d75fa742ee7..28ece2a0d32 100644 --- a/rust/lance-index/src/scalar/inverted/encoding.rs +++ b/rust/lance-index/src/scalar/inverted/encoding.rs @@ -5,10 +5,15 @@ use std::io::Write; use super::builder::BLOCK_SIZE; use super::index::{PositionStreamCodec, PostingTailCodec}; +#[cfg(test)] +use super::tokenizer::LEGACY_BLOCK_SIZE; +use super::tokenizer::validate_block_size; use arrow::array::LargeBinaryBuilder; -use lance_bitpacking::{BitPacker, BitPacker4x}; +use lance_bitpacking::{BitPacker, BitPacker4x, BitPacker8x}; use lance_core::{Error, Result}; +pub const MAX_POSTING_BLOCK_SIZE: usize = BitPacker8x::BLOCK_LEN; + // we compress the posting list to multiple blocks of fixed number of elements (BLOCK_SIZE), // returns a LargeBinaryArray, where each binary is a compressed block (128 row ids + 128 frequencies) // each block is: @@ -41,13 +46,33 @@ pub fn compress_posting_list<'a>( #[cfg(test)] pub fn compress_posting_list_with_tail_codec<'a>( + length: usize, + doc_ids: impl Iterator, + frequencies: impl Iterator, + block_max_scores: impl Iterator, + tail_codec: PostingTailCodec, +) -> Result { + compress_posting_list_with_tail_codec_and_block_size( + length, + doc_ids, + frequencies, + block_max_scores, + tail_codec, + LEGACY_BLOCK_SIZE, + ) +} + +#[cfg(test)] +pub fn compress_posting_list_with_tail_codec_and_block_size<'a>( length: usize, doc_ids: impl Iterator, frequencies: impl Iterator, mut block_max_scores: impl Iterator, tail_codec: PostingTailCodec, + block_size: usize, ) -> Result { - if length < BLOCK_SIZE { + let block_size = validate_block_size(block_size)?; + if length < block_size { // directly do remainder compression to avoid overhead of creating buffer let mut builder = LargeBinaryBuilder::with_capacity(1, length * 4 * 2 + 1); // write the max score of the block @@ -63,27 +88,23 @@ pub fn compress_posting_list_with_tail_codec<'a>( return Ok(builder.finish()); } - let mut builder = LargeBinaryBuilder::with_capacity(length.div_ceil(BLOCK_SIZE), length * 3); - let mut buffer = [0u8; BLOCK_SIZE * 4 + 5]; - let mut doc_id_buffer = Vec::with_capacity(BLOCK_SIZE); - let mut freq_buffer = Vec::with_capacity(BLOCK_SIZE); + let mut builder = LargeBinaryBuilder::with_capacity(length.div_ceil(block_size), length * 3); + let mut doc_id_buffer = Vec::with_capacity(block_size); + let mut freq_buffer = Vec::with_capacity(block_size); for (doc_id, freq) in std::iter::zip(doc_ids, frequencies) { doc_id_buffer.push(*doc_id); freq_buffer.push(*freq); - if doc_id_buffer.len() < BLOCK_SIZE { + if doc_id_buffer.len() < block_size { continue; } - assert_eq!(doc_id_buffer.len(), BLOCK_SIZE); + assert_eq!(doc_id_buffer.len(), block_size); // write the max score of the block let max_score = block_max_scores.next().unwrap(); let _ = builder.write(max_score.to_le_bytes().as_ref())?; - // delta encoding + bitpacking for doc ids - compress_sorted_block(&doc_id_buffer, &mut buffer, &mut builder)?; - // bitpacking for frequencies - compress_block(&freq_buffer, &mut buffer, &mut builder)?; + encode_posting_block_payload(&doc_id_buffer, &freq_buffer, &mut builder)?; builder.append_value(""); doc_id_buffer.clear(); freq_buffer.clear(); @@ -105,12 +126,32 @@ pub fn encode_full_posting_block_into( frequencies: &[u32], block: &mut Vec, ) -> Result<()> { - debug_assert_eq!(doc_ids.len(), BLOCK_SIZE); - debug_assert_eq!(frequencies.len(), BLOCK_SIZE); + validate_block_size(doc_ids.len())?; + debug_assert_eq!(doc_ids.len(), frequencies.len()); block.extend_from_slice(&0f32.to_le_bytes()); - let mut buffer = [0u8; BLOCK_SIZE * 4 + 5]; - compress_sorted_block(doc_ids, &mut buffer, block)?; - compress_block(frequencies, &mut buffer, block)?; + encode_posting_block_payload(doc_ids, frequencies, block)?; + Ok(()) +} + +fn encode_posting_block_payload( + doc_ids: &[u32], + frequencies: &[u32], + block: &mut impl Write, +) -> Result<()> { + debug_assert_eq!(doc_ids.len(), frequencies.len()); + validate_block_size(doc_ids.len())?; + let mut buffer = [0u8; MAX_POSTING_BLOCK_SIZE * 4 + 5]; + match doc_ids.len() { + BitPacker4x::BLOCK_LEN => { + compress_sorted_block_with::(doc_ids, &mut buffer, block)?; + compress_block_with::(frequencies, &mut buffer, block)?; + } + BitPacker8x::BLOCK_LEN => { + compress_sorted_block_with::(doc_ids, &mut buffer, block)?; + compress_block_with::(frequencies, &mut buffer, block)?; + } + _ => unreachable!("validated posting block size should be supported"), + } Ok(()) } @@ -128,7 +169,17 @@ pub fn encode_remainder_posting_block_into( #[inline] fn compress_sorted_block(data: &[u32], buffer: &mut [u8], builder: &mut impl Write) -> Result<()> { - let compressor = BitPacker4x::new(); + compress_sorted_block_with::(data, buffer, builder) +} + +#[inline] +fn compress_sorted_block_with( + data: &[u32], + buffer: &mut [u8], + builder: &mut impl Write, +) -> Result<()> { + debug_assert_eq!(data.len(), P::BLOCK_LEN); + let compressor = P::new(); let num_bits = compressor.num_bits_sorted(data[0], data); let num_bytes = compressor.compress_sorted(data[0], data, buffer, num_bits); let _ = builder.write(data[0].to_le_bytes().as_ref())?; @@ -139,7 +190,17 @@ fn compress_sorted_block(data: &[u32], buffer: &mut [u8], builder: &mut impl Wri #[inline] fn compress_block(data: &[u32], buffer: &mut [u8], builder: &mut impl Write) -> Result<()> { - let compressor = BitPacker4x::new(); + compress_block_with::(data, buffer, builder) +} + +#[inline] +fn compress_block_with( + data: &[u32], + buffer: &mut [u8], + builder: &mut impl Write, +) -> Result<()> { + debug_assert_eq!(data.len(), P::BLOCK_LEN); + let compressor = P::new(); let num_bits = compressor.num_bits(data); let num_bytes = compressor.compress(data, buffer, num_bits); let _ = builder.write(&[num_bits])?; @@ -650,17 +711,39 @@ pub fn decompress_posting_list_with_tail_codec( posting_list: &arrow::array::LargeBinaryArray, tail_codec: PostingTailCodec, ) -> Result<(Vec, Vec)> { + decompress_posting_list_with_tail_codec_and_block_size( + num_docs, + posting_list, + tail_codec, + LEGACY_BLOCK_SIZE, + ) +} + +#[cfg(test)] +pub fn decompress_posting_list_with_tail_codec_and_block_size( + num_docs: u32, + posting_list: &arrow::array::LargeBinaryArray, + tail_codec: PostingTailCodec, + block_size: usize, +) -> Result<(Vec, Vec)> { + let block_size = validate_block_size(block_size)?; let mut doc_ids: Vec = Vec::with_capacity(num_docs as usize); let mut frequencies: Vec = Vec::with_capacity(num_docs as usize); - let mut buffer = [0u32; BLOCK_SIZE]; - let bitpacking_blocks = num_docs as usize / BLOCK_SIZE; + let mut buffer = [0u32; MAX_POSTING_BLOCK_SIZE]; + let bitpacking_blocks = num_docs as usize / block_size; for compressed in posting_list.iter().take(bitpacking_blocks) { let compressed = compressed.unwrap(); - decompress_posting_block(compressed, &mut buffer, &mut doc_ids, &mut frequencies); + decompress_posting_block( + compressed, + &mut buffer, + &mut doc_ids, + &mut frequencies, + block_size, + ); } - let remainder = num_docs as usize % BLOCK_SIZE; + let remainder = num_docs as usize % block_size; if remainder > 0 { let compressed = posting_list.value(bitpacking_blocks); decompress_posting_remainder( @@ -701,14 +784,35 @@ pub fn read_num_positions(compressed: &arrow::array::LargeBinaryArray) -> u32 { pub fn decompress_posting_block( block: &[u8], - buffer: &mut [u32; BLOCK_SIZE], + buffer: &mut [u32], doc_ids: &mut Vec, frequencies: &mut Vec, + block_size: usize, ) { + debug_assert!(validate_block_size(block_size).is_ok()); + debug_assert!(buffer.len() >= block_size); // skip the first 4 bytes for the max block score - let block = &block[4..]; - let num_bytes = decompress_sorted_block(block, buffer, doc_ids); - decompress_block(&block[num_bytes..], buffer, frequencies); + let mut block = &block[4..]; + match block_size { + BitPacker4x::BLOCK_LEN => { + let num_bytes = decompress_sorted_block_with::(block, buffer, doc_ids); + block = &block[num_bytes..]; + let num_bytes = decompress_block_with::(block, buffer, frequencies); + block = &block[num_bytes..]; + } + BitPacker8x::BLOCK_LEN => { + let num_bytes = decompress_sorted_block_with::(block, buffer, doc_ids); + block = &block[num_bytes..]; + let num_bytes = decompress_block_with::(block, buffer, frequencies); + block = &block[num_bytes..]; + } + _ => unreachable!("validated posting block size should be supported"), + } + debug_assert!( + block.is_empty(), + "posting block has {} trailing bytes after decoding", + block.len() + ); } pub fn decompress_posting_remainder( @@ -755,9 +859,14 @@ pub fn decompress_posting_remainder( } } -pub fn decode_full_posting_block(block: &[u8], doc_ids: &mut Vec, frequencies: &mut Vec) { - let mut buffer = [0u32; BLOCK_SIZE]; - decompress_posting_block(block, &mut buffer, doc_ids, frequencies); +pub fn decode_full_posting_block( + block: &[u8], + doc_ids: &mut Vec, + frequencies: &mut Vec, + block_size: usize, +) { + let mut buffer = [0u32; MAX_POSTING_BLOCK_SIZE]; + decompress_posting_block(block, &mut buffer, doc_ids, frequencies, block_size); } pub fn decompress_sorted_block( @@ -765,7 +874,17 @@ pub fn decompress_sorted_block( buffer: &mut [u32; BLOCK_SIZE], res: &mut Vec, ) -> usize { - let compressor = BitPacker4x::new(); + decompress_sorted_block_with::(block, buffer, res) +} + +fn decompress_sorted_block_with( + block: &[u8], + buffer: &mut [u32], + res: &mut Vec, +) -> usize { + debug_assert!(buffer.len() >= P::BLOCK_LEN); + let buffer = &mut buffer[..P::BLOCK_LEN]; + let compressor = P::new(); let initial = u32::from_le_bytes(block[0..4].try_into().unwrap()); let num_bits = block[4]; let num_bytes = compressor.decompress_sorted(initial, &block[5..], buffer, num_bits); @@ -773,11 +892,18 @@ pub fn decompress_sorted_block( 5 + num_bytes } -fn decompress_block(block: &[u8], buffer: &mut [u32; BLOCK_SIZE], res: &mut Vec) { - let compressor = BitPacker4x::new(); +fn decompress_block_with( + block: &[u8], + buffer: &mut [u32], + res: &mut Vec, +) -> usize { + debug_assert!(buffer.len() >= P::BLOCK_LEN); + let buffer = &mut buffer[..P::BLOCK_LEN]; + let compressor = P::new(); let num_bits = block[0]; - compressor.decompress(&block[1..], buffer, num_bits); + let num_bytes = compressor.decompress(&block[1..], buffer, num_bits); res.extend_from_slice(&buffer[..]); + 1 + num_bytes } pub fn decompress_raw_remainder(compressed: &[u8], n: usize, dest: &mut Vec) { @@ -867,6 +993,78 @@ mod tests { Ok(()) } + #[test] + fn test_compress_posting_list_supported_block_sizes() -> Result<()> { + for block_size in [128, 256] { + let num_rows: usize = block_size * 2 + 7; + let doc_ids = (0..num_rows as u32).collect::>(); + let frequencies = (0..num_rows as u32) + .map(|value| value % 7 + 1) + .collect::>(); + let block_max_scores = + (0..num_rows.div_ceil(block_size)).map(|value| value as f32 + 1.0); + + let posting_list = compress_posting_list_with_tail_codec_and_block_size( + doc_ids.len(), + doc_ids.iter(), + frequencies.iter(), + block_max_scores, + PostingTailCodec::VarintDelta, + block_size, + )?; + assert_eq!(posting_list.len(), num_rows.div_ceil(block_size)); + + let (decoded_doc_ids, decoded_frequencies) = + decompress_posting_list_with_tail_codec_and_block_size( + num_rows as u32, + &posting_list, + PostingTailCodec::VarintDelta, + block_size, + )?; + assert_eq!(decoded_doc_ids, doc_ids); + assert_eq!(decoded_frequencies, frequencies); + } + Ok(()) + } + + #[test] + fn test_256_posting_block_uses_single_physical_bitpack_chunk() -> Result<()> { + let block_size = BitPacker8x::BLOCK_LEN; + let doc_ids = (0..block_size as u32).collect::>(); + let frequencies = (0..block_size as u32) + .map(|value| value % 13 + 1) + .collect::>(); + + let posting_list = compress_posting_list_with_tail_codec_and_block_size( + doc_ids.len(), + doc_ids.iter(), + frequencies.iter(), + std::iter::once(1.0), + PostingTailCodec::VarintDelta, + block_size, + )?; + assert_eq!(posting_list.len(), 1); + + let block = posting_list.value(0); + let doc_num_bits = block[8]; + let doc_bytes = BitPacker8x::compressed_block_size(doc_num_bits); + let freq_header_offset = 9 + doc_bytes; + let freq_num_bits = block[freq_header_offset]; + let freq_bytes = BitPacker8x::compressed_block_size(freq_num_bits); + assert_eq!(block.len(), freq_header_offset + 1 + freq_bytes); + + let (decoded_doc_ids, decoded_frequencies) = + decompress_posting_list_with_tail_codec_and_block_size( + doc_ids.len() as u32, + &posting_list, + PostingTailCodec::VarintDelta, + block_size, + )?; + assert_eq!(decoded_doc_ids, doc_ids); + assert_eq!(decoded_frequencies, frequencies); + Ok(()) + } + #[test] fn test_compress_posting_list_fixed32_tail_still_roundtrips() -> Result<()> { let doc_ids = vec![3_u32, 10_u32, 24_u32]; diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index fc0b02ec209..48695d624e2 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -55,11 +55,12 @@ use tracing::{info, instrument}; use super::encoding::{PositionBlockBuilder, decode_group_starts}; use super::iter::PostingListIterator; use super::lazy_docset::LazyDocSet; +use super::tokenizer::{LEGACY_BLOCK_SIZE, validate_block_size}; use super::{InvertedIndexBuilder, InvertedIndexParams, wand::*}; use super::{ builder::{ BLOCK_SIZE, PostingGroupAccumulator, PostingGroupConfig, ScoredDoc, doc_file_path, - inverted_list_schema_for_version, posting_file_path, token_file_path, + inverted_list_schema_for_version_with_block_size, posting_file_path, token_file_path, }, iter::PlainPostingListIterator, query::*, @@ -111,6 +112,7 @@ pub const TOKEN_SET_FORMAT_KEY: &str = "token_set_format"; pub const POSTING_TAIL_CODEC_KEY: &str = "posting_tail_codec"; pub const POSITIONS_LAYOUT_KEY: &str = "positions_layout"; pub const POSITIONS_CODEC_KEY: &str = "positions_codec"; +pub const POSTING_BLOCK_SIZE_KEY: &str = "posting_block_size"; /// Schema-metadata key holding the 1-indexed global-buffer id of the /// varint-delta-encoded posting-list cache-group boundaries (issue #7040). /// Absent on indexes written before grouping was introduced, which fall back @@ -365,6 +367,21 @@ pub(super) fn parse_posting_tail_codec( .unwrap_or(PostingTailCodec::Fixed32)) } +pub(super) fn parse_posting_block_size(metadata: &HashMap) -> Result { + metadata + .get(POSTING_BLOCK_SIZE_KEY) + .map(|value| { + let block_size = value.parse::().map_err(|err| { + Error::index(format!( + "invalid {POSTING_BLOCK_SIZE_KEY} metadata value {value:?}: {err}" + )) + })?; + validate_block_size(block_size) + }) + .transpose() + .map(|block_size| block_size.unwrap_or(LEGACY_BLOCK_SIZE)) +} + impl PositionStreamCodec { pub fn as_str(self) -> &'static str { match self { @@ -1556,6 +1573,13 @@ impl InvertedPartition { postings: Vec, docs: &DocSet, ) -> Result { + let block_size = postings + .iter() + .find_map(|posting| match posting { + PostingList::Compressed(posting) => Some(posting.block_size), + PostingList::Plain(_) => None, + }) + .unwrap_or(LEGACY_BLOCK_SIZE); let mut freqs_by_doc_id = BTreeMap::new(); for posting in postings { for (doc_id, freq, _) in posting.iter() { @@ -1580,7 +1604,7 @@ impl InvertedPartition { ))); } - let mut builder = PostingListBuilder::new(false); + let mut builder = PostingListBuilder::new_with_block_size(false, block_size); let mut doc_ids = Vec::with_capacity(freqs_by_doc_id.len()); let mut frequencies = Vec::with_capacity(freqs_by_doc_id.len()); for (doc_id, freq) in freqs_by_doc_id { @@ -1588,7 +1612,11 @@ impl InvertedPartition { doc_ids.push(doc_id); frequencies.push(freq); } - let block_max_scores = docs.calculate_block_max_scores(doc_ids.iter(), frequencies.iter()); + let block_max_scores = docs.calculate_block_max_scores_with_block_size( + doc_ids.iter(), + frequencies.iter(), + block_size, + ); let batch = builder.to_batch(block_max_scores)?; let max_score = batch[MAX_SCORE_COL].as_primitive::().value(0); let length = batch[LENGTH_COL].as_primitive::().value(0); @@ -1819,11 +1847,12 @@ impl InvertedPartition { } pub async fn into_builder(self) -> Result { - let mut builder = InnerBuilder::new_with_posting_tail_codec( + let mut builder = InnerBuilder::new_with_posting_tail_codec_and_block_size( self.id, self.inverted_list.has_positions(), self.token_set_format, self.inverted_list.posting_tail_codec(), + self.inverted_list.block_size(), ); builder.tokens = self.tokens.into_mutable(); // into_builder rewrites every doc, so materialize the full @@ -2233,6 +2262,7 @@ pub struct PostingListReader { has_position: bool, posting_tail_codec: PostingTailCodec, + block_size: usize, positions_layout: PositionsLayout, /// First row of each posting-list cache group, decoded at open from the @@ -2331,6 +2361,7 @@ impl PostingListReader { PositionsLayout::None }; let posting_tail_codec = parse_posting_tail_codec(&reader.schema().metadata)?; + let block_size = parse_posting_block_size(&reader.schema().metadata)?; let has_position = positions_layout != PositionsLayout::None; let metadata = if reader.schema().field(POSTING_COL).is_none() { let (offsets, max_scores) = Self::load_metadata(reader.schema())?; @@ -2351,6 +2382,7 @@ impl PostingListReader { metadata, has_position, posting_tail_codec, + block_size, positions_layout, group_starts, index_cache: WeakLanceCache::from(index_cache), @@ -2413,6 +2445,10 @@ impl PostingListReader { self.posting_tail_codec } + pub(crate) fn block_size(&self) -> usize { + self.block_size + } + fn is_legacy_layout(&self) -> bool { matches!(self.metadata, PostingMetadata::LegacyV1 { .. }) } @@ -2704,6 +2740,7 @@ impl PostingListReader { max_score: Option, length: Option, posting_tail_codec: PostingTailCodec, + block_size: usize, positions_layout: PositionsLayout, ) -> Result { let posting_list = PostingList::from_batch_with_tail_codec_and_positions_layout( @@ -2711,6 +2748,7 @@ impl PostingListReader { max_score, length, posting_tail_codec, + block_size, positions_layout, )?; Ok(posting_list) @@ -2727,6 +2765,7 @@ impl PostingListReader { max_score, length, self.posting_tail_codec, + self.block_size, self.positions_layout, ) } @@ -2763,6 +2802,7 @@ impl PostingListReader { ctx.max_scores.map(|scores| scores[global]), ctx.lengths.map(|lengths| lengths[global]), ctx.posting_tail_codec, + ctx.block_size, ctx.positions_layout, )?; posting_lists.push((global as u32, posting_list)); @@ -2899,6 +2939,7 @@ impl PostingListReader { max_scores: max_scores.map(Arc::new), lengths: lengths.map(Arc::new), posting_tail_codec: self.posting_tail_codec, + block_size: self.block_size, positions_layout: self.positions_layout, } } @@ -2932,12 +2973,14 @@ impl PostingListReader { let max_scores = state.max_scores.clone(); let lengths = state.lengths.clone(); let posting_tail_codec = state.posting_tail_codec; + let block_size = state.block_size; let positions_layout = state.positions_layout; let posting_lists = spawn_blocking(move || { let ctx = PrewarmBuildCtx { max_scores: max_scores.as_deref().map(|v| v.as_slice()), lengths: lengths.as_deref().map(|v| v.as_slice()), posting_tail_codec, + block_size, positions_layout, }; let chunk = PrewarmChunk { @@ -3188,6 +3231,7 @@ struct ChunkBuildState { max_scores: Option>>, lengths: Option>>, posting_tail_codec: PostingTailCodec, + block_size: usize, positions_layout: PositionsLayout, } @@ -3198,6 +3242,7 @@ struct PrewarmBuildCtx<'a> { max_scores: Option<&'a [f32]>, lengths: Option<&'a [u32]>, posting_tail_codec: PostingTailCodec, + block_size: usize, positions_layout: PositionsLayout, } @@ -3461,7 +3506,8 @@ impl PostingList { length: Option, ) -> Result { let posting_tail_codec = parse_posting_tail_codec(batch.schema_ref().metadata())?; - Self::from_batch_with_tail_codec(batch, max_score, length, posting_tail_codec) + let block_size = parse_posting_block_size(batch.schema_ref().metadata())?; + Self::from_batch_with_tail_codec(batch, max_score, length, posting_tail_codec, block_size) } pub fn from_batch_with_tail_codec( @@ -3469,6 +3515,7 @@ impl PostingList { max_score: Option, length: Option, posting_tail_codec: PostingTailCodec, + block_size: usize, ) -> Result { let positions_layout = if batch.column_by_name(COMPRESSED_POSITION_COL).is_some() { PositionsLayout::SharedStream(parse_shared_position_codec( @@ -3484,6 +3531,7 @@ impl PostingList { max_score, length, posting_tail_codec, + block_size, positions_layout, ) } @@ -3493,6 +3541,7 @@ impl PostingList { max_score: Option, length: Option, posting_tail_codec: PostingTailCodec, + block_size: usize, positions_layout: PositionsLayout, ) -> Result { match batch.column_by_name(POSTING_COL) { @@ -3507,6 +3556,7 @@ impl PostingList { max_score.unwrap(), length.unwrap(), posting_tail_codec, + block_size, shared_position_codec, ); Ok(Self::Compressed(posting)) @@ -3578,9 +3628,14 @@ impl PostingList { Self::Plain(_) => PostingTailCodec::Fixed32, Self::Compressed(posting) => posting.posting_tail_codec, }; - let mut builder = PostingListBuilder::new_with_posting_tail_codec( + let block_size = match &self { + Self::Plain(_) => LEGACY_BLOCK_SIZE, + Self::Compressed(posting) => posting.block_size, + }; + let mut builder = PostingListBuilder::new_with_posting_tail_codec_and_block_size( self.has_position(), posting_tail_codec, + block_size, ); match self { // legacy format @@ -3738,9 +3793,11 @@ pub struct CompressedPostingList { pub max_score: f32, pub length: u32, // each binary is a block of compressed data - // that contains `BLOCK_SIZE` doc ids and then `BLOCK_SIZE` frequencies + // that contains `block_size` doc ids and then `block_size` frequencies, + // packed by the physical bitpacker matching that block size. pub blocks: LargeBinaryArray, pub posting_tail_codec: PostingTailCodec, + pub block_size: usize, pub positions: Option, } @@ -3761,6 +3818,7 @@ impl CompressedPostingList { max_score: f32, length: u32, posting_tail_codec: PostingTailCodec, + block_size: usize, positions: Option, ) -> Self { Self { @@ -3768,6 +3826,7 @@ impl CompressedPostingList { length, blocks, posting_tail_codec, + block_size, positions, } } @@ -3777,6 +3836,7 @@ impl CompressedPostingList { max_score: f32, length: u32, posting_tail_codec: PostingTailCodec, + block_size: usize, shared_position_codec: Option, ) -> Self { debug_assert_eq!(batch.num_rows(), 1); @@ -3813,6 +3873,7 @@ impl CompressedPostingList { length, blocks, posting_tail_codec, + block_size, positions, } } @@ -3823,6 +3884,7 @@ impl CompressedPostingList { self.blocks.clone(), self.posting_tail_codec, self.positions.clone(), + self.block_size, ) } @@ -3833,7 +3895,7 @@ impl CompressedPostingList { pub fn block_least_doc_id(&self, block_idx: usize) -> u32 { let block = self.blocks.value(block_idx); - let remainder = self.length as usize % BLOCK_SIZE; + let remainder = self.length as usize % self.block_size; let is_remainder_block = remainder > 0 && block_idx + 1 == self.blocks.len(); if is_remainder_block { super::encoding::read_posting_tail_first_doc(block, self.posting_tail_codec) @@ -3964,6 +4026,7 @@ pub struct PostingListBuilder { open_doc_id: Option, open_doc_frequency: u32, open_doc_last_position: Option, + block_size: usize, memory_size_bytes: u32, len: u32, } @@ -3993,6 +4056,7 @@ enum BatchPositionsBuilder { struct PostingListParts<'a> { with_positions: bool, posting_tail_codec: PostingTailCodec, + block_size: usize, length: usize, encoded_blocks: EncodedBlocks, encoded_position_blocks: EncodedPositionBlocks, @@ -4155,9 +4219,10 @@ impl PostingListBuilder { } pub fn new(with_position: bool) -> Self { - Self::new_with_posting_tail_codec( + Self::new_with_posting_tail_codec_and_block_size( with_position, current_fts_format_version().posting_tail_codec(), + LEGACY_BLOCK_SIZE, ) } @@ -4165,6 +4230,27 @@ impl PostingListBuilder { with_position: bool, posting_tail_codec: PostingTailCodec, ) -> Self { + Self::new_with_posting_tail_codec_and_block_size( + with_position, + posting_tail_codec, + LEGACY_BLOCK_SIZE, + ) + } + + pub fn new_with_block_size(with_position: bool, block_size: usize) -> Self { + Self::new_with_posting_tail_codec_and_block_size( + with_position, + current_fts_format_version().posting_tail_codec(), + block_size, + ) + } + + pub fn new_with_posting_tail_codec_and_block_size( + with_position: bool, + posting_tail_codec: PostingTailCodec, + block_size: usize, + ) -> Self { + validate_block_size(block_size).expect("invalid posting list block size"); Self { with_positions: with_position, posting_tail_codec, @@ -4175,6 +4261,7 @@ impl PostingListBuilder { open_doc_id: None, open_doc_frequency: 0, open_doc_last_position: None, + block_size, len: 0, memory_size_bytes: 0, } @@ -4196,8 +4283,8 @@ impl PostingListBuilder { &self, mut visit: impl FnMut(u32, u32, Option>) -> std::result::Result<(), E>, ) -> std::result::Result<(), E> { - let mut doc_ids = Vec::with_capacity(BLOCK_SIZE); - let mut frequencies = Vec::with_capacity(BLOCK_SIZE); + let mut doc_ids = Vec::with_capacity(self.block_size); + let mut frequencies = Vec::with_capacity(self.block_size); let mut decoded_positions = Vec::new(); let mut position_block_index = 0usize; @@ -4205,7 +4292,12 @@ impl PostingListBuilder { for block in encoded_blocks.iter() { doc_ids.clear(); frequencies.clear(); - super::encoding::decode_full_posting_block(block, &mut doc_ids, &mut frequencies); + super::encoding::decode_full_posting_block( + block, + &mut doc_ids, + &mut frequencies, + self.block_size, + ); decoded_positions.clear(); if self.with_positions { let position_blocks = self @@ -4285,7 +4377,7 @@ impl PostingListBuilder { } self.len += 1; - if self.tail_entries.len() == BLOCK_SIZE { + if self.tail_entries.len() == self.block_size { self.flush_tail_block() .expect("posting list block compression should succeed"); } @@ -4344,7 +4436,7 @@ impl PostingListBuilder { self.open_doc_id = None; self.open_doc_frequency = 0; self.open_doc_last_position = None; - if self.tail_entries.len() == BLOCK_SIZE { + if self.tail_entries.len() == self.block_size { self.flush_tail_block()?; } Ok(()) @@ -4395,13 +4487,17 @@ impl PostingListBuilder { self.open_doc_id.is_none(), "cannot flush a posting block while a document is still open" ); - debug_assert_eq!(self.tail_entries.len(), BLOCK_SIZE); - let mut doc_ids = [0u32; BLOCK_SIZE]; - let mut frequencies = [0u32; BLOCK_SIZE]; - for (index, entry) in self.tail_entries.iter().enumerate() { - doc_ids[index] = entry.doc_id; - frequencies[index] = entry.frequency; - } + debug_assert_eq!(self.tail_entries.len(), self.block_size); + let doc_ids = self + .tail_entries + .iter() + .map(|entry| entry.doc_id) + .collect::>(); + let frequencies = self + .tail_entries + .iter() + .map(|entry| entry.frequency) + .collect::>(); let encoded_blocks_size_before = self .encoded_blocks .as_ref() @@ -4572,6 +4668,7 @@ impl PostingListBuilder { open_doc_id, open_doc_frequency, open_doc_last_position, + block_size, len, .. } = self; @@ -4581,6 +4678,7 @@ impl PostingListBuilder { let parts = PostingListParts { with_positions, posting_tail_codec, + block_size, length: len as usize, encoded_blocks: encoded_blocks .map(|encoded_blocks| *encoded_blocks) @@ -4619,6 +4717,7 @@ impl PostingListBuilder { with_positions, posting_tail_codec, length, + block_size, mut encoded_blocks, mut encoded_position_blocks, tail_entries, @@ -4627,14 +4726,19 @@ impl PostingListBuilder { let avgdl = docs.average_length(); let idf_scale = idf(length, docs.len()) * (K1 + 1.0); let mut max_score = f32::MIN; - let mut doc_ids = Vec::with_capacity(BLOCK_SIZE); - let mut frequencies = Vec::with_capacity(BLOCK_SIZE); + let mut doc_ids = Vec::with_capacity(block_size); + let mut frequencies = Vec::with_capacity(block_size); for index in 0..encoded_blocks.len() { let block = encoded_blocks.block(index); doc_ids.clear(); frequencies.clear(); - super::encoding::decode_full_posting_block(block, &mut doc_ids, &mut frequencies); + super::encoding::decode_full_posting_block( + block, + &mut doc_ids, + &mut frequencies, + block_size, + ); let block_score = compute_block_score( docs, avgdl, @@ -4733,7 +4837,11 @@ impl PostingListBuilder { } else { InvertedListFormatVersion::V2 }; - let schema = inverted_list_schema_for_version(self.has_positions(), format_version); + let schema = inverted_list_schema_for_version_with_block_size( + self.has_positions(), + format_version, + self.block_size, + ); let legacy_positions = if self.with_positions && !format_version.uses_shared_position_stream() { Some(self.build_legacy_positions()?) @@ -4750,6 +4858,7 @@ impl PostingListBuilder { open_doc_id, open_doc_frequency, open_doc_last_position, + block_size, len, .. } = self; @@ -4780,6 +4889,7 @@ impl PostingListBuilder { open_doc_id: None, open_doc_frequency: 0, open_doc_last_position: None, + block_size, memory_size_bytes: 0, len, }; @@ -4814,6 +4924,7 @@ impl PostingListBuilder { open_doc_id, open_doc_frequency, open_doc_last_position, + block_size, len, .. } = self; @@ -4823,6 +4934,7 @@ impl PostingListBuilder { let parts = PostingListParts { with_positions, posting_tail_codec, + block_size, length: len as usize, encoded_blocks: encoded_blocks .map(|encoded_blocks| *encoded_blocks) @@ -4845,6 +4957,7 @@ impl PostingListBuilder { open_doc_id: None, open_doc_frequency: 0, open_doc_last_position: None, + block_size, memory_size_bytes: 0, len, }; @@ -4857,8 +4970,11 @@ impl PostingListBuilder { pub fn remap(&mut self, removed: &[u32]) { let mut cursor = 0; - let mut new_builder = - Self::new_with_posting_tail_codec(self.has_positions(), self.posting_tail_codec); + let mut new_builder = Self::new_with_posting_tail_codec_and_block_size( + self.has_positions(), + self.posting_tail_codec, + self.block_size, + ); for (doc_id, freq, positions) in self.iter() { while cursor < removed.len() && removed[cursor] < doc_id { cursor += 1; @@ -5079,9 +5195,19 @@ impl DocSet { doc_ids: impl Iterator, freqs: impl Iterator, ) -> Vec { + self.calculate_block_max_scores_with_block_size(doc_ids, freqs, LEGACY_BLOCK_SIZE) + } + + pub fn calculate_block_max_scores_with_block_size<'a>( + &self, + doc_ids: impl Iterator, + freqs: impl Iterator, + block_size: usize, + ) -> Vec { + validate_block_size(block_size).expect("invalid posting list block size"); let avgdl = self.average_length(); let length = doc_ids.size_hint().0; - let num_blocks = length.div_ceil(BLOCK_SIZE); + let num_blocks = length.div_ceil(block_size); let mut block_max_scores = Vec::with_capacity(num_blocks); let idf_scale = idf(length, self.len()) * (K1 + 1.0); let mut max_score = f32::MIN; @@ -5092,13 +5218,13 @@ impl DocSet { if score > max_score { max_score = score; } - if (i + 1) % BLOCK_SIZE == 0 { + if (i + 1) % block_size == 0 { max_score *= idf_scale; block_max_scores.push(max_score); max_score = f32::MIN; } } - if !length.is_multiple_of(BLOCK_SIZE) { + if !length.is_multiple_of(block_size) { max_score *= idf_scale; block_max_scores.push(max_score); } @@ -5756,15 +5882,20 @@ mod tests { token: &str, row_id: u64, ) -> Result> { - let mut partition = InnerBuilder::new_with_format_version( + let block_size = params.posting_block_size(); + let mut partition = InnerBuilder::new_with_format_version_and_block_size( 0, false, token_set_format, InvertedListFormatVersion::V1, + block_size, ); partition.tokens.add(token.to_owned()); - let mut posting_list = - PostingListBuilder::new_with_posting_tail_codec(false, PostingTailCodec::Fixed32); + let mut posting_list = PostingListBuilder::new_with_posting_tail_codec_and_block_size( + false, + PostingTailCodec::Fixed32, + block_size, + ); posting_list.add(0, PositionRecorder::Count(1)); partition.posting_lists.push(posting_list); partition.docs.append(row_id, 1); @@ -5800,6 +5931,84 @@ mod tests { )) } + #[test] + fn test_posting_block_size_schema_metadata() { + assert_eq!(parse_posting_block_size(&HashMap::new()).unwrap(), 128); + + let metadata = HashMap::from([(POSTING_BLOCK_SIZE_KEY.to_owned(), "512".to_owned())]); + let err = parse_posting_block_size(&metadata).unwrap_err(); + assert!(err.to_string().contains("block_size")); + + let metadata = HashMap::from([(POSTING_BLOCK_SIZE_KEY.to_owned(), "129".to_owned())]); + let err = parse_posting_block_size(&metadata).unwrap_err(); + assert!(err.to_string().contains("block_size")); + } + + #[tokio::test] + async fn test_build_search_uses_configured_posting_block_size() { + let tmpdir = TempObjDir::default(); + let store = Arc::new(LanceIndexStore::new( + ObjectStore::local().into(), + tmpdir.clone(), + Arc::new(LanceCache::no_cache()), + )); + + let params = InvertedIndexParams::default().block_size(256).unwrap(); + let block_size = params.posting_block_size(); + let num_docs = block_size + 7; + + let mut builder = InnerBuilder::new_with_format_version_and_block_size( + 0, + false, + TokenSetFormat::default(), + InvertedListFormatVersion::V1, + block_size, + ); + builder.tokens.add("needle".to_owned()); + let mut posting_list = PostingListBuilder::new_with_posting_tail_codec_and_block_size( + false, + PostingTailCodec::Fixed32, + block_size, + ); + for doc_id in 0..num_docs { + posting_list.add(doc_id as u32, PositionRecorder::Count(1)); + builder.docs.append(1_000 + doc_id as u64, 1); + } + builder.posting_lists.push(posting_list); + builder.write(store.as_ref()).await.unwrap(); + write_test_metadata(&store, vec![0], params).await; + + let cache = Arc::new(LanceCache::with_capacity(4096)); + let index = InvertedIndex::load(store.clone(), None, cache.as_ref()) + .await + .unwrap(); + assert_eq!(index.partitions[0].inverted_list.block_size(), block_size); + + let posting = index.partitions[0] + .inverted_list + .posting_list(0, false, &NoOpMetricsCollector) + .await + .unwrap(); + let PostingList::Compressed(posting) = posting else { + panic!("expected compressed posting list"); + }; + assert_eq!(posting.block_size, block_size); + assert_eq!(posting.blocks.len(), num_docs.div_ceil(block_size)); + + let tokens = Arc::new(Tokens::new(vec!["needle".to_owned()], DocType::Text)); + let params = Arc::new(FtsSearchParams::new().with_limit(Some(10))); + let prefilter = Arc::new(NoFilter); + let metrics = Arc::new(NoOpMetricsCollector); + let (row_ids, scores) = index + .bm25_search(tokens, params, Operator::Or, prefilter, metrics, None) + .await + .unwrap(); + + assert_eq!(row_ids.len(), 10); + assert_eq!(scores.len(), 10); + assert!(row_ids.iter().all(|row_id| *row_id >= 1_000)); + } + #[tokio::test] async fn test_posting_builder_remap() { let posting_tail_codec = PostingTailCodec::Fixed32; @@ -7397,6 +7606,7 @@ mod tests { 1.0, SLICE_LEN as u32, PostingTailCodec::Fixed32, + LEGACY_BLOCK_SIZE, None, ); diff --git a/rust/lance-index/src/scalar/inverted/iter.rs b/rust/lance-index/src/scalar/inverted/iter.rs index dc07b15769c..51970c46c33 100644 --- a/rust/lance-index/src/scalar/inverted/iter.rs +++ b/rust/lance-index/src/scalar/inverted/iter.rs @@ -6,10 +6,9 @@ use arrow_array::{Array, LargeBinaryArray}; use super::{ CompressedPositionStorage, PostingList, PostingTailCodec, - builder::BLOCK_SIZE, encoding::{ - decode_position_stream_block, decompress_positions, decompress_posting_block, - decompress_posting_remainder, + MAX_POSTING_BLOCK_SIZE, decode_position_stream_block, decompress_positions, + decompress_posting_block, decompress_posting_remainder, }, }; @@ -28,6 +27,7 @@ impl<'a> PostingListIterator<'a> { posting.blocks.clone(), posting.posting_tail_codec, posting.positions.clone(), + posting.block_size, ))) } } @@ -82,13 +82,14 @@ pub struct CompressedPostingListIterator { next_block_idx: usize, posting_tail_codec: PostingTailCodec, positions: Option, + block_size: usize, idx: usize, doc_ids: Vec, frequencies: Vec, doc_idx_in_block: usize, decoded_positions: Vec, position_offsets: Vec, - buffer: [u32; BLOCK_SIZE], + buffer: [u32; MAX_POSTING_BLOCK_SIZE], } impl CompressedPostingListIterator { @@ -97,10 +98,11 @@ impl CompressedPostingListIterator { blocks: LargeBinaryArray, posting_tail_codec: PostingTailCodec, positions: Option, + block_size: usize, ) -> Self { debug_assert!(length > 0, "length: {}", length); debug_assert_eq!( - length.div_ceil(BLOCK_SIZE), + length.div_ceil(block_size), blocks.len(), "length: {}, num_blocks: {}", length, @@ -108,18 +110,19 @@ impl CompressedPostingListIterator { ); Self { - remainder: length % BLOCK_SIZE, + remainder: length % block_size, blocks, next_block_idx: 0, posting_tail_codec, positions, + block_size, idx: 0, doc_ids: Vec::new(), frequencies: Vec::new(), doc_idx_in_block: 0, decoded_positions: Vec::new(), position_offsets: Vec::new(), - buffer: [0; BLOCK_SIZE], + buffer: [0; MAX_POSTING_BLOCK_SIZE], } } } @@ -172,6 +175,7 @@ impl Iterator for CompressedPostingListIterator { &mut self.buffer, &mut self.doc_ids, &mut self.frequencies, + self.block_size, ); } self.doc_idx_in_block = 0; diff --git a/rust/lance-index/src/scalar/inverted/tokenizer.rs b/rust/lance-index/src/scalar/inverted/tokenizer.rs index 1785a23fedd..3cfce1ae279 100644 --- a/rust/lance-index/src/scalar/inverted/tokenizer.rs +++ b/rust/lance-index/src/scalar/inverted/tokenizer.rs @@ -2,7 +2,7 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use lance_core::{Error, Result}; -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use std::{env, path::PathBuf}; #[cfg(feature = "tokenizer-jieba")] @@ -29,6 +29,10 @@ use lance_tokenizer::{ WhitespaceTokenizer, }; +pub const LEGACY_BLOCK_SIZE: usize = 128; +pub const DEFAULT_BLOCK_SIZE: usize = 256; +pub const VALID_BLOCK_SIZES: [usize; 2] = [128, 256]; + /// Tokenizer configs #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct InvertedIndexParams { @@ -98,6 +102,17 @@ pub struct InvertedIndexParams { #[serde(default)] pub(crate) prefix_only: bool, + /// Number of documents in each compressed posting block. + /// + /// Missing serialized values come from indexes written before this + /// parameter existed and must read as 128 for backwards compatibility. New + /// indexes default to 256. + #[serde( + default = "legacy_block_size", + deserialize_with = "deserialize_block_size" + )] + pub(crate) block_size: usize, + /// Total memory limit in MiB for the build stage. /// /// This is split evenly across FTS workers at build time. By default Lance @@ -138,6 +153,7 @@ impl TryFrom<&InvertedIndexParams> for pbold::InvertedIndexDetails { min_ngram_length: params.min_ngram_length, max_ngram_length: params.max_ngram_length, prefix_only: params.prefix_only, + block_size: Some(params.block_size as u32), }) } } @@ -165,6 +181,10 @@ impl TryFrom<&pbold::InvertedIndexDetails> for InvertedIndexParams { min_ngram_length: details.min_ngram_length, max_ngram_length: details.max_ngram_length, prefix_only: details.prefix_only, + block_size: match details.block_size { + Some(block_size) => validate_block_size(block_size as usize)?, + None => LEGACY_BLOCK_SIZE, + }, memory_limit_mb: defaults.memory_limit_mb, num_workers: defaults.num_workers, }) @@ -183,6 +203,30 @@ fn default_max_ngram_length() -> u32 { 3 } +fn legacy_block_size() -> usize { + LEGACY_BLOCK_SIZE +} + +fn invalid_block_size_message(block_size: usize) -> String { + format!("FTS inverted index block_size must be one of 128 or 256, got {block_size}") +} + +pub fn validate_block_size(block_size: usize) -> Result { + if VALID_BLOCK_SIZES.contains(&block_size) { + Ok(block_size) + } else { + Err(Error::invalid_input(invalid_block_size_message(block_size))) + } +} + +fn deserialize_block_size<'de, D>(deserializer: D) -> std::result::Result +where + D: Deserializer<'de>, +{ + let block_size = usize::deserialize(deserializer)?; + validate_block_size(block_size).map_err(serde::de::Error::custom) +} + impl Default for InvertedIndexParams { fn default() -> Self { Self::new("simple".to_owned(), Language::English) @@ -220,6 +264,7 @@ impl InvertedIndexParams { min_ngram_length: default_min_ngram_length(), max_ngram_length: default_max_ngram_length(), prefix_only: false, + block_size: DEFAULT_BLOCK_SIZE, memory_limit_mb: None, num_workers: None, } @@ -310,6 +355,20 @@ impl InvertedIndexParams { self } + /// Set the compressed posting block size. + /// + /// Supported values are 128 and 256. Larger values reduce block-max metadata + /// and WAND skip granularity; smaller values preserve the legacy layout. + pub fn block_size(mut self, block_size: usize) -> Result { + self.block_size = validate_block_size(block_size)?; + Ok(self) + } + + /// Get the compressed posting block size. + pub fn posting_block_size(&self) -> usize { + self.block_size + } + pub fn memory_limit_mb(mut self, memory_limit_mb: u64) -> Self { self.memory_limit_mb = Some(memory_limit_mb); self @@ -345,6 +404,23 @@ impl InvertedIndexParams { Ok(value) } + /// Deserialize params for new index training, using current creation defaults + /// for omitted fields. + pub(crate) fn from_training_json(params: &str) -> Result { + let supplied = serde_json::from_str::(params)?; + let mut value = serde_json::to_value(Self::default())?; + + let supplied = supplied.as_object().ok_or_else(|| { + Error::invalid_input("FTS inverted index params must be a JSON object".to_string()) + })?; + let object = value + .as_object_mut() + .expect("inverted index params should serialize to a JSON object"); + object.extend(supplied.clone()); + + Ok(serde_json::from_value(value)?) + } + pub fn build(&self) -> Result> { let mut builder = self.build_base_tokenizer()?; if let Some(max_token_length) = self.max_token_length { @@ -450,8 +526,10 @@ pub fn language_model_home() -> Option { #[cfg(test)] mod tests { + use crate::pbold; + use super::InvertedIndexParams; - use lance_tokenizer::TokenStream; + use lance_tokenizer::{Language, TokenStream}; use rstest::rstest; #[test] @@ -502,6 +580,77 @@ mod tests { assert_eq!(json.get("num_workers"), Some(&serde_json::Value::from(3))); } + #[test] + fn test_block_size_new_default_serializes() { + let params = InvertedIndexParams::default(); + assert_eq!(params.block_size, 256); + let json = serde_json::to_value(¶ms).unwrap(); + assert_eq!(json.get("block_size"), Some(&serde_json::Value::from(256))); + } + + #[test] + fn test_block_size_missing_metadata_falls_back_to_128() { + let mut json = serde_json::to_value(InvertedIndexParams::default()).unwrap(); + json.as_object_mut().unwrap().remove("block_size"); + + let params: InvertedIndexParams = serde_json::from_value(json).unwrap(); + assert_eq!(params.block_size, 128); + } + + #[test] + fn test_block_size_details_conversion() { + let params = InvertedIndexParams::default().block_size(256).unwrap(); + let details = pbold::InvertedIndexDetails::try_from(¶ms).unwrap(); + assert_eq!(details.block_size, Some(256)); + + let old_details = pbold::InvertedIndexDetails { + base_tokenizer: Some("simple".to_string()), + language: serde_json::to_string(&Language::English).unwrap(), + with_position: false, + max_token_length: Some(40), + lower_case: true, + stem: true, + remove_stop_words: true, + ascii_folding: true, + min_ngram_length: 3, + max_ngram_length: 3, + prefix_only: false, + block_size: None, + }; + let params = InvertedIndexParams::try_from(&old_details).unwrap(); + assert_eq!(params.block_size, 128); + } + + #[test] + fn test_block_size_accepts_supported_values() { + for block_size in [128, 256] { + let params = InvertedIndexParams::default() + .block_size(block_size) + .unwrap(); + assert_eq!(params.block_size, block_size); + + let roundtrip: InvertedIndexParams = + serde_json::from_value(serde_json::to_value(¶ms).unwrap()).unwrap(); + assert_eq!(roundtrip.block_size, block_size); + } + } + + #[test] + fn test_block_size_rejects_invalid_values() { + let err = InvertedIndexParams::default().block_size(129).unwrap_err(); + assert!(err.to_string().contains("block_size")); + + let err = InvertedIndexParams::default().block_size(512).unwrap_err(); + assert!(err.to_string().contains("128 or 256")); + + let mut json = serde_json::to_value(InvertedIndexParams::default()).unwrap(); + json.as_object_mut() + .unwrap() + .insert("block_size".to_string(), serde_json::Value::from(1024)); + let err = serde_json::from_value::(json).unwrap_err(); + assert!(err.to_string().contains("128 or 256")); + } + #[test] fn test_build_icu_tokenizer() { let mut tokenizer = InvertedIndexParams::default() diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs index 0fc8b95cddb..b94ac61aa0a 100644 --- a/rust/lance-index/src/scalar/inverted/wand.rs +++ b/rust/lance-index/src/scalar/inverted/wand.rs @@ -29,8 +29,8 @@ use super::{ CompressedPostingList, DocSet, PostingList, RawDocInfo, builder::ScoredDoc, encoding::{ - decode_position_stream_block, decompress_positions, decompress_posting_block, - decompress_posting_remainder, + MAX_POSTING_BLOCK_SIZE, decode_position_stream_block, decompress_positions, + decompress_posting_block, decompress_posting_remainder, }, query::FtsSearchParams, scorer::Scorer, @@ -66,7 +66,7 @@ struct CompressedState { block_idx: usize, doc_ids: Vec, freqs: Vec, - buffer: Box<[u32; BLOCK_SIZE]>, + buffer: Box<[u32; MAX_POSTING_BLOCK_SIZE]>, position_block_idx: Option, position_values: Vec, position_offsets: Vec, @@ -74,12 +74,12 @@ struct CompressedState { } impl CompressedState { - fn new() -> Self { + fn new(block_size: usize) -> Self { Self { block_idx: 0, - doc_ids: Vec::with_capacity(BLOCK_SIZE), - freqs: Vec::with_capacity(BLOCK_SIZE), - buffer: Box::new([0; BLOCK_SIZE]), + doc_ids: Vec::with_capacity(block_size), + freqs: Vec::with_capacity(block_size), + buffer: Box::new([0; MAX_POSTING_BLOCK_SIZE]), position_block_idx: None, position_values: Vec::new(), position_offsets: Vec::new(), @@ -95,11 +95,12 @@ impl CompressedState { num_blocks: usize, length: u32, tail_codec: super::PostingTailCodec, + block_size: usize, ) { self.doc_ids.clear(); self.freqs.clear(); - let remainder = length as usize % BLOCK_SIZE; + let remainder = length as usize % block_size; if block_idx + 1 == num_blocks && remainder != 0 { decompress_posting_remainder( block, @@ -109,7 +110,13 @@ impl CompressedState { &mut self.freqs, ); } else { - decompress_posting_block(block, &mut self.buffer, &mut self.doc_ids, &mut self.freqs); + decompress_posting_block( + block, + &mut self.buffer[..], + &mut self.doc_ids, + &mut self.freqs, + block_size, + ); } self.block_idx = block_idx; self.position_block_idx = None; @@ -279,6 +286,7 @@ impl PostingIterator { list.blocks.len(), list.length, list.posting_tail_codec, + list.block_size, ); } compressed as *mut CompressedState @@ -307,8 +315,12 @@ impl PostingIterator { Some(max_score) => max_score, None => idf(list.len(), num_doc) * (K1 + 1.0), }; - - let is_compressed = matches!(list, PostingList::Compressed(_)); + let compressed = match &list { + PostingList::Compressed(list) => { + Some(UnsafeCell::new(CompressedState::new(list.block_size))) + } + PostingList::Plain(_) => None, + }; Self { token, @@ -319,7 +331,7 @@ impl PostingIterator { index: 0, block_idx: 0, approximate_upper_bound, - compressed: is_compressed.then(|| UnsafeCell::new(CompressedState::new())), + compressed, } } @@ -361,8 +373,8 @@ impl PostingIterator { match self.list { PostingList::Compressed(ref list) => { - let block_idx = self.index / BLOCK_SIZE; - let block_offset = self.index % BLOCK_SIZE; + let block_idx = self.index / list.block_size; + let block_offset = self.index % list.block_size; let compressed = unsafe { &mut *self.ensure_compressed_block_ptr(list, block_idx) }; // Read from the decompressed block @@ -400,8 +412,8 @@ impl PostingIterator { )) } CompressedPositionStorage::SharedStream(stream) => { - let block_idx = self.index / BLOCK_SIZE; - let block_offset = self.index % BLOCK_SIZE; + let block_idx = self.index / list.block_size; + let block_offset = self.index % list.block_size; let compressed = unsafe { &mut *self.ensure_compressed_block_ptr(list, block_idx) }; if compressed.position_block_idx != Some(block_idx) { @@ -441,33 +453,33 @@ impl PostingIterator { PostingList::Compressed(ref list) => { debug_assert!(least_id <= u32::MAX as u64); let least_id = least_id as u32; - let mut block_idx = self.index / BLOCK_SIZE; + let mut block_idx = self.index / list.block_size; while block_idx + 1 < list.blocks.len() && list.block_least_doc_id(block_idx + 1) <= least_id { block_idx += 1; } - self.index = self.index.max(block_idx * BLOCK_SIZE); + self.index = self.index.max(block_idx * list.block_size); let length = list.length as usize; while self.index < length { - let block_idx = self.index / BLOCK_SIZE; - let block_offset = self.index % BLOCK_SIZE; + let block_idx = self.index / list.block_size; + let block_offset = self.index % list.block_size; let compressed = unsafe { &mut *self.ensure_compressed_block_ptr(list, block_idx) }; let in_block = &compressed.doc_ids[block_offset..]; let offset_in_block = in_block.partition_point(|&doc_id| doc_id < least_id); let new_offset = block_offset + offset_in_block; if new_offset < compressed.doc_ids.len() { - self.index = block_idx * BLOCK_SIZE + new_offset; + self.index = block_idx * list.block_size + new_offset; break; } if block_idx + 1 >= list.blocks.len() { self.index = length; break; } - self.index = (block_idx + 1) * BLOCK_SIZE; + self.index = (block_idx + 1) * list.block_size; } - self.block_idx = self.index / BLOCK_SIZE; + self.block_idx = self.index / list.block_size; } PostingList::Plain(ref list) => { self.index += list.row_ids[self.index..].partition_point(|&id| id < least_id); @@ -2146,6 +2158,7 @@ mod tests { max_score, doc_ids.len() as u32, crate::scalar::inverted::PostingTailCodec::VarintDelta, + crate::scalar::inverted::LEGACY_BLOCK_SIZE, None, )) } else { diff --git a/rust/lance/src/dataset/mem_wal/index/fts.rs b/rust/lance/src/dataset/mem_wal/index/fts.rs index 522c4c98231..7d0b70944ea 100644 --- a/rust/lance/src/dataset/mem_wal/index/fts.rs +++ b/rust/lance/src/dataset/mem_wal/index/fts.rs @@ -1699,6 +1699,7 @@ impl FtsMemIndex { let st = self.state.load_full(); let with_position = self.params.has_positions(); + let block_size = self.params.posting_block_size(); let total_rows_u64 = total_rows as u64; // Step 1: collect (original_pos, num_tokens) for every doc across all @@ -1716,10 +1717,11 @@ impl FtsMemIndex { } } if all_docs.is_empty() { - return Ok(InnerBuilder::new( + return Ok(InnerBuilder::new_with_block_size( partition_id, with_position, Default::default(), + block_size, )); } @@ -1805,7 +1807,10 @@ impl FtsMemIndex { docs_for_term.sort_by_key(|(doc_id, _, _)| *doc_id); let token_id = tokens.add(token) as usize; debug_assert_eq!(token_id, posting_lists.len()); - posting_lists.push(PostingListBuilder::new(with_position)); + posting_lists.push(PostingListBuilder::new_with_block_size( + with_position, + block_size, + )); let plb = &mut posting_lists[token_id]; for (doc_id, freq, pos) in docs_for_term { let recorder = if with_position { @@ -1817,7 +1822,12 @@ impl FtsMemIndex { } } - let mut builder = InnerBuilder::new(partition_id, with_position, Default::default()); + let mut builder = InnerBuilder::new_with_block_size( + partition_id, + with_position, + Default::default(), + block_size, + ); builder.set_tokens(tokens); builder.set_docs(docs); builder.set_posting_lists(posting_lists);