Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions docs/training/torch.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,80 @@ dataloader = torch.utils.data.DataLoader(
for batch in dataloader:
print(batch.schema)
```

## Using multiple DataLoader workers

PyTorch's `DataLoader` can fan out reads across worker processes by setting `num_workers > 0`. LanceDB tables and `Permutation` objects are picklable, so each worker reopens its own connection after the worker process starts.

Because LanceDB is multi-threaded internally, use the `spawn` start method (not `fork`) when running with multiple workers. See [the performance guide](/performance) for more on safe multiprocessing patterns.

```py Python icon=Python
from lancedb.permutation import Permutation

permutation = Permutation.identity(table)
dataloader = torch.utils.data.DataLoader(
permutation,
batch_size=1024,
shuffle=True,
num_workers=4,
multiprocessing_context="spawn",
persistent_workers=True,
)
```

### Remote tables in DataLoader workers

Tables opened from a remote LanceDB Enterprise connection (`db://...`) also work with multi-worker DataLoaders. The connection details needed to reopen the table — `db_url`, `api_key`, `region`, `host_override`, and the serializable parts of `client_config` — travel with the pickled table and are used to rebuild the connection in each worker.

```py Python icon=Python
import lancedb
from lancedb.permutation import Permutation

db = lancedb.connect(
"db://my-database",
api_key="sk-...",
region="us-east-1",
)
table = db.open_table("my_table")

permutation = Permutation.identity(table).select_columns(["id", "image"])
dataloader = torch.utils.data.DataLoader(
permutation,
batch_size=512,
num_workers=4,
multiprocessing_context="spawn",
)
```

<Note>
This embeds the API key in the pickle sent to each worker. If you'd rather load credentials inside the worker — for example, from an environment variable or a secret manager — use the connection factory escape hatch described below. A factory is also required when your `client_config` uses a non-serializable `header_provider`.
</Note>

### Providing a custom connection factory

`Permutation.with_connection_factory` lets you control how each worker reopens the base table. The factory takes the base table name and returns a LanceDB table. It must be picklable, which in practice means a top-level function, a `functools.partial` of one, or an instance of a picklable class with `__call__` — lambdas and closures over local variables will not work.

```py Python icon=Python
import os
import lancedb
from lancedb.permutation import Permutation

def open_table(name: str):
db = lancedb.connect(
"db://my-database",
api_key=os.environ["LANCEDB_API_KEY"],
region="us-east-1",
)
return db.open_table(name)

permutation = (
Permutation.identity(table)
.with_connection_factory(open_table)
)
dataloader = torch.utils.data.DataLoader(
permutation,
batch_size=512,
num_workers=4,
multiprocessing_context="spawn",
)
```
Loading