diff --git a/docs/training/torch.mdx b/docs/training/torch.mdx
index 9a9dbe5..e9c571f 100644
--- a/docs/training/torch.mdx
+++ b/docs/training/torch.mdx
@@ -96,3 +96,80 @@ dataloader = torch.utils.data.DataLoader(
 for batch in dataloader:
     print(batch.schema)
 ```
+
+## Using multiple DataLoader workers
+
+PyTorch's `DataLoader` can fan out reads across worker processes by setting `num_workers > 0`. LanceDB tables and `Permutation` objects are picklable, so each worker reopens its own connection after the worker process starts.
+
+Because LanceDB is multi-threaded internally, use the `spawn` start method (not `fork`) when running with multiple workers. See [the performance guide](/performance) for more on safe multiprocessing patterns.
+
+```py Python icon=Python
+from lancedb.permutation import Permutation
+
+permutation = Permutation.identity(table)
+dataloader = torch.utils.data.DataLoader(
+    permutation,
+    batch_size=1024,
+    shuffle=True,
+    num_workers=4,
+    multiprocessing_context="spawn",
+    persistent_workers=True,
+)
+```
+
+### Remote tables in DataLoader workers
+
+Tables opened from a remote LanceDB Enterprise connection (`db://...`) also work with multi-worker DataLoaders. The connection details needed to reopen the table — `db_url`, `api_key`, `region`, `host_override`, and the serializable parts of `client_config` — travel with the pickled table and are used to rebuild the connection in each worker.
+
+```py Python icon=Python
+import lancedb
+from lancedb.permutation import Permutation
+
+db = lancedb.connect(
+    "db://my-database",
+    api_key="sk-...",
+    region="us-east-1",
+)
+table = db.open_table("my_table")
+
+permutation = Permutation.identity(table).select_columns(["id", "image"])
+dataloader = torch.utils.data.DataLoader(
+    permutation,
+    batch_size=512,
+    num_workers=4,
+    multiprocessing_context="spawn",
+)
+```
+
+
+This embeds the API key in the pickle sent to each worker. If you'd rather load credentials inside the worker — for example, from an environment variable or a secret manager — use the connection factory escape hatch described below. A factory is also required when your `client_config` uses a non-serializable `header_provider`.
+
+
+### Providing a custom connection factory
+
+`Permutation.with_connection_factory` lets you control how each worker reopens the base table. The factory takes the base table name and returns a LanceDB table. It must be picklable, which in practice means a top-level function, a `functools.partial` of one, or an instance of a picklable class with `__call__` — lambdas and closures over local variables will not work.
+
+```py Python icon=Python
+import os
+import lancedb
+from lancedb.permutation import Permutation
+
+def open_table(name: str):
+    db = lancedb.connect(
+        "db://my-database",
+        api_key=os.environ["LANCEDB_API_KEY"],
+        region="us-east-1",
+    )
+    return db.open_table(name)
+
+permutation = (
+    Permutation.identity(table)
+    .with_connection_factory(open_table)
+)
+dataloader = torch.utils.data.DataLoader(
+    permutation,
+    batch_size=512,
+    num_workers=4,
+    multiprocessing_context="spawn",
+)
+```
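The picklability rule for connection factories can be checked with the standard library alone, before a `DataLoader` is involved. The following is a minimal sketch, not LanceDB code: `open_table_stub`, `TableOpener`, `make_closure`, and `is_picklable` are hypothetical names introduced for illustration, and the stub returns a string where a real factory would return a LanceDB table.

```python
import functools
import pickle


def open_table_stub(name: str, region: str = "us-east-1") -> str:
    """Hypothetical stand-in for a factory; a real one would return a table."""
    return f"{region}/{name}"


class TableOpener:
    """Callable class whose state is plain attributes, so instances pickle."""

    def __init__(self, region: str) -> None:
        self.region = region

    def __call__(self, name: str) -> str:
        return open_table_stub(name, region=self.region)


def make_closure(region: str):
    """Return a function defined in a local scope (not importable by name)."""

    def inner(name: str) -> str:
        return open_table_stub(name, region=region)

    return inner


def is_picklable(obj) -> bool:
    """Report whether pickle.dumps succeeds for obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


candidates = {
    "top-level function": open_table_stub,
    "functools.partial": functools.partial(open_table_stub, region="eu-west-1"),
    "callable class instance": TableOpener("eu-west-1"),
    "lambda": lambda name: open_table_stub(name),
    "closure": make_closure("eu-west-1"),
}
results = {label: is_picklable(obj) for label, obj in candidates.items()}

# A picklable factory survives a round trip and still works afterwards.
restored = pickle.loads(pickle.dumps(candidates["functools.partial"]))
restored_value = restored("my_table")

if __name__ == "__main__":
    for label, ok in results.items():
        print(f"{label}: {'picklable' if ok else 'not picklable'}")
```

Run as a script, the first three candidates should report as picklable and the lambda and closure should not, which matches the rule stated in the docs. Running a check like this on a real factory before starting the workers can surface the problem in the parent process rather than inside a `spawn`ed worker.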