From f2efe02e155ef95a0bdd539f3cb22b96a4186e59 Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Thu, 28 May 2026 22:26:22 -0500 Subject: [PATCH 1/8] docs(memory): add design doc for per-tenant memory isolation DD-0001 captures the scope, invariant, propagation path, partitioning tradeoff, threat model, test contract, and merge order for closing the multi-tenant memory leak in the unified Memory subsystem. Includes the source-of-truth Mermaid diagram for the identity propagation path under design-docs/diagrams/. The diagram is embedded in the design doc for inline review. --- .../0001-per-tenant-memory-isolation.md | 474 ++++++++++++++++++ design-docs/diagrams/identity-propagation.mmd | 69 +++ 2 files changed, 543 insertions(+) create mode 100644 design-docs/0001-per-tenant-memory-isolation.md create mode 100644 design-docs/diagrams/identity-propagation.mmd diff --git a/design-docs/0001-per-tenant-memory-isolation.md b/design-docs/0001-per-tenant-memory-isolation.md new file mode 100644 index 0000000000..fa067d03b5 --- /dev/null +++ b/design-docs/0001-per-tenant-memory-isolation.md @@ -0,0 +1,474 @@ +# DD-0001 — Per-Tenant Memory Isolation + +| | | +|---|---| +| **Status** | Draft — for maintainer review | +| **Scope** | `crewai.memory` (unified `Memory` + `StorageBackend`), `crewai.rag` (legacy `BaseRAGStorage`) | +| **Touches** | `lib/crewai/src/crewai/memory/`, `lib/crewai/src/crewai/rag/storage/`, `lib/cli/src/crewai_cli/` | +| **Backward compatible** | Yes (default tenant `_default`; existing call sites unchanged) | +| **Security-relevant** | Yes — see [Security Model](#security-model) | + +--- + +## TL;DR + +CrewAI's memory subsystem persists across sessions but does not partition by user or tenant identity. In any multi-user deployment, every user's memories land in the same collection and `recall()` mixes them. This is the single reason teams disable built-in memory in production and bolt on an external provider. + +This document specifies per-tenant memory isolation as a **security property** enforced at the storage boundary. It threads a `tenant_id` (and optional `user_id`) through the entire save → store → recall path, makes it a required predicate on every read in `StorageBackend`, and centralizes enforcement in a `ScopedStorage` proxy so no agent code path can bypass it. + +--- + +## The Invariant + +> **A `recall()` scoped to tenant A can never return a memory written by tenant B — under any ranking, embedding collision, query depth, deep-recall exploration round, or backend.** + +Everything in this document exists to make that one sentence true and to make it stay true. It is the **first** thing reviewers should evaluate any change against. + +--- + +## The Gap Today + +Two relevant code paths exist in the current tree: + +1. **Unified memory** (the live path): `lib/crewai/src/crewai/memory/unified_memory.py` + `StorageBackend` protocol at `lib/crewai/src/crewai/memory/storage/backend.py`. `MemoryRecord.source` (a free-form string) and `MemoryRecord.private` (a bool) exist as provenance hints. The recall path applies them as a **post-query Python filter**: + + ```python + # unified_memory.py:704-709 (current) + if not include_private: + raw = [(r, s) for r, s in raw + if not r.private or r.source == source] + ``` + + This is not isolation. A `private=False` row from user B with a strong semantic match is returned to user A and only `private=True` rows are filtered. Even for `private=True`, the row was *retrieved* from the vector store before being filtered — meaning it counted against ranking, against the oversample budget, and was visible inside `RecallFlow`'s exploration LLM prompts. + +2. **Legacy RAG memory**: `BaseRAGStorage` at `lib/crewai/src/crewai/rag/storage/base_rag_storage.py` has no per-user concept at all. Subclasses (ChromaDB, etc.) take an `embedder_config` and a `crew` but no identity. + +Both paths leak. Both must be closed in the same PR or one of them keeps shipping a vulnerability. This document targets the unified `StorageBackend` boundary primarily; the legacy `BaseRAGStorage` fix is mechanically the same (a `tenant_id` field in the `where` clause) and is sketched in [Legacy RAG Path](#legacy-rag-path). + +--- + +## Non-Goals (what this PR is *not*) + +The role this fix occupies sits between three things that look similar and are not: + +| Need | Solved by | Not this PR | +|---|---|---| +| Survives restart | SQLite + Chroma/LanceDB persistence | Already shipped | +| Shared across crews / cross-session knowledge graph | External memory provider (mem0, etc.) | Out of scope | +| **Per-user data isolation** | **This PR** | — | + +Specifically excluded from this PR (each is its own ticket): + +- **A mem0 replacement.** Long-term consolidation, decay heuristics, cross-crew knowledge sharing — that is the external provider's job. Adding it here turns a security fix into a 3000-line feature and it never merges. +- **Per-tenant encryption.** Different ticket. Isolation is necessary but not sufficient for "encrypted at rest per tenant." +- **Per-tenant collection sprawl (Option B partitioning).** Documented below as a follow-up; not the default. +- **Multi-tenant rate limiting / quota.** Not a memory concern. + +--- + +## Design + +### Identity model + +Two fields are added to `MemoryRecord`: + +| Field | Type | Default | Role | +|---|---|---|---| +| `tenant_id` | `str` | `"_default"` | **The security boundary.** Every record is owned by exactly one tenant. Storage filters on this unconditionally. | +| `user_id` | `str \| None` | `None` | Sub-tenant refinement *inside* a tenant. A `user_id` filter is a soft partition (org admin can recall across users in the same tenant); `tenant_id` is the hard wall. | + +Why two and not one: + +- A SaaS deployment has tenants (customers) and users (people inside the customer). They are different trust boundaries: the customer-admin role is allowed to query across their users; no customer is ever allowed to query across other customers. Collapsing them into a single field forces every operator to pick the wrong granularity. +- `tenant_id` is non-null because the invariant says "every row has an owner." `user_id` is nullable because plenty of records (system summaries, agent-emitted reflections) legitimately belong to a tenant but not to a user. +- Default `tenant_id="_default"` is what keeps single-user setups working unchanged. Existing rows read back as `_default`; the default `Memory()` constructor uses `_default`; no caller has to change. + +`source` and `private` stay on the record for provenance, but they are **no longer load-bearing for isolation**. The `recall()` post-filter at `unified_memory.py:704-709` is deleted. The migration guide tells anyone who used `source="user_42"` for isolation purposes to move to `tenant_id="user_42"`. + +### Identity propagation path + +```mermaid +flowchart TD + Caller["Caller
(API handler, Flow step, agent)"] + Resolve{"Resolve effective tenant
per-call > instance default > '_default'"} + + subgraph MemoryAPI["Memory — public API"] + Remember["remember(content, tenant_id=, user_id=)"] + Recall["recall(query, tenant_id=, user_id=)"] + Forget["forget(tenant_id=, user_id=)"] + Scoped["_scoped(tenant_id, user_id)
builds ScopedStorage proxy"] + end + + subgraph Enforcement["ScopedStorage — ENFORCEMENT CHOKEPOINT"] + Stamp["WRITE — stamp tenant_id on every record
cross-tenant record raises PermissionError"] + Inject["READ — inject tenant_id predicate
callers cannot omit it"] + Verify["VERIFY — re-check every returned row
foreign tenant row raises RuntimeError (loud)"] + end + + subgraph Protocol["StorageBackend Protocol"] + SaveAPI["save(records)"] + SearchAPI["search(*, tenant_id: str, …)
required keyword-only
mypy --strict catches omissions in CI"] + DeleteAPI["delete(*, tenant_id: str, …)"] + end + + subgraph Backend["Underlying backend — LanceDB / Chroma / Qdrant"] + Column["tenant_id column
NOT NULL + index on (tenant_id, scope)"] + Filter["WHERE tenant_id = ? pushed into vector query
foreign rows never enter candidate pool"] + end + + RecallFlow["RecallFlow (deep mode)
holds ScopedStorage, not raw backend
→ exploration cannot escape tenant filter"] + + Caller --> Resolve + Resolve --> Remember + Resolve --> Recall + Resolve --> Forget + + Remember --> Scoped + Recall --> Scoped + Forget --> Scoped + + Scoped --> Stamp + Scoped --> Inject + Inject --> Verify + + Stamp --> SaveAPI + Verify --> SearchAPI + Inject --> DeleteAPI + + SaveAPI --> Column + SearchAPI --> Filter + DeleteAPI --> Filter + + Recall -. depth=deep .-> RecallFlow + RecallFlow --> Scoped + + classDef enforce fill:#ffe6e6,stroke:#c0392b,stroke-width:3px,color:#000 + classDef boundary fill:#fff4e0,stroke:#e67e22,stroke-width:2px,color:#000 + classDef storage fill:#e8f4fd,stroke:#2980b9,stroke-width:2px,color:#000 + classDef caller fill:#eafaf1,stroke:#27ae60,stroke-width:2px,color:#000 + + class Stamp,Inject,Verify enforce + class SaveAPI,SearchAPI,DeleteAPI boundary + class Column,Filter storage + class Caller,Resolve caller +``` + +Two structural properties the diagram is meant to make obvious: + +1. **The red band is the only place isolation is enforced.** Everything above it routes through it; everything below it inherits a filter it cannot remove. Audit lives in one file. +2. **`RecallFlow` re-enters through `ScopedStorage`, not around it.** Deep-recall LLM exploration is safe by construction — the flow does not hold a reference to the raw backend. + +### The `StorageBackend` Protocol change + +`tenant_id` becomes a **required keyword-only parameter** on every read method: + +```python +@runtime_checkable +class StorageBackend(Protocol): + def save(self, records: list[MemoryRecord]) -> None: ... + + def search( + self, + query_embedding: list[float], + *, + tenant_id: str, # required, no default + user_id: str | None = None, + scope_prefix: str | None = None, + categories: list[str] | None = None, + metadata_filter: dict[str, Any] | None = None, + limit: int = 10, + min_score: float = 0.0, + ) -> list[tuple[MemoryRecord, float]]: ... + + def delete(self, *, tenant_id: str, user_id: str | None = None, …) -> int: ... + def get_record(self, record_id: str, *, tenant_id: str) -> MemoryRecord | None: ... + def list_records(self, *, tenant_id: str, …) -> list[MemoryRecord]: ... + def list_scopes(self, *, tenant_id: str, parent: str = "/") -> list[str]: ... + def list_categories(self, *, tenant_id: str, …) -> dict[str, int]: ... + def count(self, *, tenant_id: str, scope_prefix: str | None = None) -> int: ... + def reset(self, *, tenant_id: str, scope_prefix: str | None = None) -> None: ... + + async def asearch(self, …, *, tenant_id: str, …) -> …: ... + # … all async variants mirror the sync ones +``` + +Why keyword-only and required: + +- Required (no default) means **mypy `--strict` (which this repo already enforces) turns any forgotten caller into a CI failure.** Static enforcement of the invariant beats every code review. +- Keyword-only means it cannot accidentally swap with a positional `scope_prefix` arg. +- `save()` does **not** take `tenant_id`. The tenant is on the record itself; a separate parameter would invite "record says A, param says B" mismatches. Single source of truth. +- `reset()` cannot wipe everything. There is no "reset all tenants" path. An operator who wants that calls `reset(tenant_id="_default")` deliberately, and to wipe other tenants they iterate. This is intentional friction — accidental cross-tenant wipes are a recovery nightmare. + +### The `ScopedStorage` wrapper + +This is the single most important file in the PR. It is the chokepoint the invariant rides on. + +```python +# lib/crewai/src/crewai/memory/storage/scoped_storage.py + +class ScopedStorage: + """Wraps any StorageBackend and enforces tenant isolation. + + A ScopedStorage is bound to exactly one tenant_id (and optionally one user_id) + at construction. Every read it issues to the underlying backend carries that + tenant_id as a non-optional predicate. Every record it writes is stamped with + that tenant_id, overwriting whatever the caller put on the record. + + This class is the single chokepoint for the isolation invariant: + a recall bound to tenant A NEVER returns a row written by tenant B. + + If you add a new read method, it MUST go through the tenant predicate. + If you add a new write method, it MUST go through _stamp(). + """ +``` + +Three contracts the wrapper holds and any reviewer must confirm: + +1. **Stamp on write.** Every record passed to `save()` or `update()` is `model_copy()`'d with `tenant_id` set to the wrapper's tenant. If the caller supplies a record already stamped with a *different* tenant, the wrapper raises `PermissionError` — silent relabel is a footgun. + +2. **Inject on read.** Every read method passes `tenant_id` to the underlying backend. The wrapper has no API surface to omit it; callers cannot opt out. + +3. **Double-check on return.** After the backend returns rows, the wrapper re-verifies `r.tenant_id == self._tenant_id` for every row. If any row fails the check, it raises `RuntimeError` — not "silently filter and return the clean ones." A broken backend filter must be **loud** so the next test run catches it. Quiet filtering is exactly how the original bug shipped. + +The wrapper is cheap to construct (two strings; one reference). `Memory` builds a fresh one per call rather than caching, so a single long-lived `Memory` instance can serve many tenants concurrently without leaking state between them. + +### `Memory` API surface + +```python +class Memory(BaseModel): + tenant_id: str = Field(default="_default") + user_id: str | None = Field(default=None) + # … existing fields + + def remember( + self, + content: str, + *, + tenant_id: str | None = None, # per-call override + user_id: str | None = None, # per-call override + scope: str | None = None, + # … existing args (categories, metadata, importance, source, …) + ) -> MemoryRecord | None: ... + + def recall( + self, + query: str, + *, + tenant_id: str | None = None, + user_id: str | None = None, + # … existing args + ) -> list[MemoryMatch]: ... + + def forget( + self, + *, + tenant_id: str | None = None, + user_id: str | None = None, + # … existing args + ) -> int: ... +``` + +Resolution order for the effective tenant: **per-call kwarg > instance default > `"_default"`**. Same for `user_id`. The per-call override is the multi-tenant SaaS pattern (one `Memory` instance, many requests). The instance default is the single-user CLI pattern (set once in `Memory()`). + +`crew.kickoff()` grows an analogous pair of kwargs that thread down to the agents' memory access: + +```python +crew.kickoff(inputs={...}, tenant_id="customer_42", user_id="alice") +``` + +Agents do not see `tenant_id` directly. They get a memory handle that is already bound to the right tenant. This matters: prompt-injected agents cannot recover a tenant they were never given. + +### Partitioning strategy: Option A vs Option B + +Two ways to physically partition data: + +| | **Option A — Metadata filter (default)** | **Option B — Per-tenant namespace** | +|---|---|---| +| **Storage layout** | One collection / one table. Every row has `tenant_id`. Reads push `tenant_id = ?` into the backend's filter. | One collection / namespace per tenant. Reads route to the right namespace. | +| **Isolation strength** | Strong **with centralized enforcement** (the `ScopedStorage` wrapper + required protocol kwargs + the double-check). Weaker if enforcement is left to ad-hoc callers. | Stronger by construction: no shared index, cross-tenant queries impossible. | +| **"Right to be forgotten"** | A range delete. Cheap row-wise but the underlying vector index does not always shrink — periodic compaction needed. | Drop the namespace. O(1). | +| **Cold-start cost** | None. | A collection per tenant means N collections; many vector DBs amortize poorly at high N. | +| **Operational complexity** | Low. | High at scale — collection sprawl, per-collection embedder warmups, backend-specific quirks. | +| **Recommended for** | Default. Most CrewAI deployments. | Strict-isolation deployments (regulated industries, single-tenant-per-namespace contracts). | + +**Decision: ship Option A as the default, with the `ScopedStorage` wrapper holding the invariant.** Document Option B as a `Storage(strategy="namespace")` flavor that lands in a follow-up PR. Reviewers who want Option B as the default should engage on this section specifically — that is the one architectural call worth re-litigating. + +### Backend changes + +Each `StorageBackend` implementation gains: + +1. A `tenant_id` column / field with a NOT NULL constraint at the schema level wherever the backend supports it. The Protocol's required kwarg is the Python enforcement; the column constraint is the database enforcement. Defense in depth. +2. An index on `(tenant_id)` and a composite index on `(tenant_id, scope)` because the hot read path is always `WHERE tenant_id = ? AND scope LIKE ? || '%'`. +3. The `tenant_id = ?` predicate pushed into whatever the backend's filter syntax is: + + ```python + # LanceDB (current default) + where = f"tenant_id = {_quote(tenant_id)}" + if user_id is not None: + where += f" AND user_id = {_quote(user_id)}" + ``` + + ```python + # ChromaDB (legacy RAG path) + filter_clause = {"tenant_id": tenant_id} + if user_id is not None: + filter_clause["user_id"] = user_id + collection.query(query_embeddings=[…], where=filter_clause, …) + ``` + + ```python + # Qdrant (qdrant_edge_storage) + must = [FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id))] + if user_id is not None: + must.append(FieldCondition(key="user_id", match=MatchValue(value=user_id))) + ``` + +`_quote` is the escape helper. Never f-string a raw `tenant_id` from request context into SQL — Bandit's `S608` rule will catch it in CI. + +### Legacy RAG path + +`BaseRAGStorage` at `lib/crewai/src/crewai/rag/storage/base_rag_storage.py` is the older boundary used by entity/short-term memory built before the unified `Memory` class. The same changes apply: + +- Add `tenant_id` parameter (keyword-only, required) to `save`, `search`, `reset`. +- A subclass-side `ScopedRAGStorage` wrapper mirrors `ScopedStorage`. +- ChromaDB's `where={"tenant_id": tenant_id}` is the enforcement push-down. + +Combining the unified and legacy fixes in one PR keeps the changelog honest: "memory" is one user-visible concept and shipping isolation for half of it is misleading. + +### Migration + +Existing unscoped data must not be silently orphaned. Two mechanisms: + +1. **One-shot backfill CLI**: `crewai memory migrate` + ``` + crewai memory migrate \ + --storage-dir $CREWAI_STORAGE_DIR \ + --default-tenant _default \ + [--default-user-id ] \ + [--dry-run] + ``` + - Scans every storage file under `CREWAI_STORAGE_DIR`. + - For rows where `tenant_id IS NULL OR tenant_id = ''`, stamps `_default`. + - Idempotent. Prints `migrated N rows`. + - `--dry-run` is mandatory in the docs example. Somebody will run this against prod by accident. + +2. **Startup warning**: `Memory.model_post_init` issues a one-line `WARNING` log if it detects any unstamped rows in the underlying store. It does **not** auto-migrate. Auto-migration on a shared DB is how teams get paged at 3am. The warning is the nudge to run the CLI deliberately during a maintenance window. + +3. **Schema migration**: For LanceDB and SQLite-backed stores, adding the `tenant_id` column on existing tables uses each backend's `ALTER` path with `DEFAULT '_default' NOT NULL`. Existing rows pick up the default at the storage layer in a single transaction. + +--- + +## Security Model + +**Threat model.** + +| Threat | Mitigation | +|---|---| +| Caller forgets to pass `tenant_id` | Required keyword-only kwarg on every Protocol read method; mypy `--strict` fails CI. Plus runtime default `"_default"` so the **fallback is a non-leaking single-tenant bucket, never the union of all data**. | +| Caller passes wrong `tenant_id` | Out of scope — that is the caller's responsibility. We do not authenticate the caller; we enforce the predicate they pass. Auth lives upstream (FastAPI dep, agent runtime, etc.). | +| Backend filter is broken or miscompiled | `ScopedStorage` re-verifies `r.tenant_id == self._tenant_id` on every returned row; raises `RuntimeError` on mismatch. **Loud over silent.** | +| Prompt-injected agent tries to recall cross-tenant | Agents receive a `Memory` handle pre-bound to a tenant; the wrapper provides no API to widen scope. The agent can ask, but the storage refuses. | +| Embedding collision returns a near-identical neighbor from another tenant | Filter is pushed into the vector search itself (LanceDB `WHERE`, Chroma `where=`, Qdrant `must`), not applied post-retrieval. Cross-tenant rows never enter the candidate pool. | +| Operator runs `forget()` and wipes other tenants | `forget()` requires a tenant scope. There is no "wipe everything." | +| Backup / dump exposes one tenant to another | Out of scope for this PR; dump tooling lives elsewhere. Documented as a follow-up. | + +**Embedder is shared across tenants.** Embeddings are not a security boundary; they are a content-addressable hash function. Per-tenant embedders would be a meaningless cost without changing the threat surface. + +--- + +## Test Contract + +A test file titled `test_tenant_isolation.py` lives in `lib/crewai/tests/memory/`. The features below define "done." If any one of these is missing or passes vacuously, the PR is unfinished. + +| Test | What it pins | +|---|---| +| `test_cross_tenant_recall_returns_nothing` | The core invariant. Two tenants save near-identical content with the same embedding region, scope, and categories. Each tenant's `recall()` returns only its own row. | +| `test_default_tenant_backcompat` | A caller passing no `tenant_id` still gets `recall()` working and reads back rows with `tenant_id="_default"`. Single-user setups unchanged. | +| `test_deep_recall_honors_tenant` | The `depth="deep"` path goes through `RecallFlow` with LLM exploration. Isolation must hold there too, not just `depth="shallow"`. | +| `test_delete_is_scoped` | `forget(tenant_id="alice")` deletes only Alice's rows; Bob's survive. | +| `test_save_rejects_cross_tenant_record` | If a caller hands `ScopedStorage(tenant="alice")` a `MemoryRecord(tenant_id="bob")`, the wrapper raises `PermissionError`. No silent relabel. | +| `test_backend_leak_is_loud` | Monkeypatched backend returns a foreign-tenant row. `ScopedStorage` raises `RuntimeError`, not a quietly-filtered empty list. | +| `test_legacy_rag_storage_honors_tenant` | Mirror of the first test against `BaseRAGStorage` so the legacy ChromaDB path is covered. | + +These tests use the repo's existing VCR cassette infrastructure for the embedder/LLM calls (`pytest-recording`); the repo's `--block-network` default means a missing cassette is a hard failure, not a silent network hit. + +**Tests that do NOT count as the isolation contract:** + +- "stores and recalls" — that test already exists and never caught the bug. +- "private flag works" — `private` is being demoted to provenance; tests on it test the wrong thing. +- "save round-trips through serialization" — orthogonal. + +--- + +## Public API Diff (what changes for end users) + +```python +# Before +crew.kickoff(inputs={"topic": "X"}) + +# After (single-user; unchanged behavior) +crew.kickoff(inputs={"topic": "X"}) + +# After (multi-tenant SaaS pattern) +crew.kickoff(inputs={"topic": "X"}, tenant_id="customer_42", user_id="alice") +``` + +```python +# Direct Memory usage +mem = Memory() +mem.remember("note", tenant_id="alice") +mem.recall("question", tenant_id="alice") +mem.forget(tenant_id="alice") +``` + +No existing call site breaks. The `source` / `private` / `include_private` arguments still accept their old values; a `DeprecationWarning` fires if a caller is clearly using them for isolation (`source` non-null + `private=True` is the pattern that gets the warning). + +--- + +## Rollout / Merge Order + +Each row is a separate PR. Earlier PRs land green and ship value independently; the merge of #4 is the moment isolation becomes real. + +| # | PR | What it does | Risk | +|---|---|---|---| +| 1 | `feat(memory): add tenant_id/user_id to MemoryRecord` | Adds the two fields with `_default`. Backends ignore them. Schema migrations land. No behavior change. | Low — additive. | +| 2 | `refactor(memory): tenant_id keyword in StorageBackend protocol` | Adds required keyword-only `tenant_id` to every read method on `StorageBackend` and every implementation. mypy `--strict` catches forgotten call sites. Still no behavior change because `Memory` always passes `"_default"`. | Medium — touches every storage backend, but mechanical. | +| 3 | `feat(memory): ScopedStorage wrapper + isolation tests` | Adds `ScopedStorage` and the test file. Tests are written to **fail** until #4 lands. | Low. | +| 4 | `feat(memory): wire ScopedStorage through Memory and Flows` | Memory grows `tenant_id`/`user_id` fields and per-call kwargs. Every internal call routes through `_scoped(...)`. The `private`/`source` post-filter is deleted. Tests from #3 go green. This is the load-bearing PR. | High — security-relevant change to live code path. | +| 5 | `feat(cli): crewai memory migrate command` | One-shot backfill CLI + docs page. Startup warning lands here. | Low. | +| 6 | `docs(memory): per-tenant isolation guide` | Mintlify page at `docs/en/concepts/memory-isolation.mdx` (+ translations) with the API surface, threat model, and migration steps. | None. | +| 7 *(optional, follow-up)* | `feat(memory): namespace partitioning strategy` | `ScopedStorage(strategy="namespace")` for Option B deployments. | Medium. | + +`llm-generated` label on every PR per `.github/CONTRIBUTING.md`. + +--- + +## Open Questions + +1. **Is `"_default"` the right default tenant string?** Alternatives: empty string (collides with NULL handling), `None` (forces every backend to handle two-typed filters), the literal `"default"` (collision risk with a real tenant). `"_default"` is unambiguous and sorts predictably. Open to bikeshed; just want a maintainer to call it. + +2. **Per-call override on `Memory` vs. context-managed scope.** Current design uses a kwarg. An alternative is: + ```python + with memory.as_tenant("customer_42", user_id="alice"): + memory.recall("...") + ``` + I prefer the kwarg because it cannot leak across an exception or a forgotten `__exit__`. The context-manager form is sugar and could land in a follow-up. + +3. **Should `Crew` carry `tenant_id` at construction or only at `kickoff()`?** `kickoff()` is the right surface for SaaS (per-request); construction is the right surface for single-tenant. Proposal: both, with `kickoff()` winning if both are set. + +4. **Telemetry.** The repo emits anonymous telemetry; `tenant_id` values must never appear in telemetry payloads. A hash or a `"present"`/`"absent"` boolean is the most we should emit. Confirming with whoever owns `crewai/telemetry/`. + +5. **Legacy `BaseRAGStorage` deprecation timeline.** Now that the unified path is the live one, is there a release where legacy `BaseRAGStorage` can be removed entirely? Out of scope for this PR but the answer informs how much effort goes into legacy-path tests. + +--- + +## Appendix: Code references + +- Current leaky filter: `lib/crewai/src/crewai/memory/unified_memory.py:704-709` +- `StorageBackend` Protocol: `lib/crewai/src/crewai/memory/storage/backend.py` +- `MemoryRecord`: `lib/crewai/src/crewai/memory/types.py:20-73` +- Default storage (LanceDB): `lib/crewai/src/crewai/memory/storage/lancedb_storage.py` +- Qdrant storage: `lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py` +- Legacy RAG boundary: `lib/crewai/src/crewai/rag/storage/base_rag_storage.py` +- ChromaDB factory used by legacy path: `lib/crewai/src/crewai/rag/chromadb/` +- Storage dir env var: `CREWAI_STORAGE_DIR` (handled in `crewai-core`) diff --git a/design-docs/diagrams/identity-propagation.mmd b/design-docs/diagrams/identity-propagation.mmd new file mode 100644 index 0000000000..b0deef9b56 --- /dev/null +++ b/design-docs/diagrams/identity-propagation.mmd @@ -0,0 +1,69 @@ +%% Per-tenant memory isolation — identity propagation and enforcement path +%% Source of truth for the diagram embedded in +%% design-docs/0001-per-tenant-memory-isolation.md +%% Render: +%% npx -p @mermaid-js/mermaid-cli mmdc -i identity-propagation.mmd -o identity-propagation.svg +%% (or paste into https://mermaid.live) +flowchart TD + Caller["Caller
(API handler, Flow step, agent)"] + Resolve{"Resolve effective tenant
per-call > instance default > '_default'"} + + subgraph MemoryAPI["Memory — public API"] + Remember["remember(content, tenant_id=, user_id=)"] + Recall["recall(query, tenant_id=, user_id=)"] + Forget["forget(tenant_id=, user_id=)"] + Scoped["_scoped(tenant_id, user_id)
builds ScopedStorage proxy"] + end + + subgraph Enforcement["ScopedStorage — ENFORCEMENT CHOKEPOINT"] + Stamp["WRITE — stamp tenant_id on every record
cross-tenant record raises PermissionError"] + Inject["READ — inject tenant_id predicate
callers cannot omit it"] + Verify["VERIFY — re-check every returned row
foreign tenant row raises RuntimeError (loud)"] + end + + subgraph Protocol["StorageBackend Protocol"] + SaveAPI["save(records)"] + SearchAPI["search(*, tenant_id: str, …)
required keyword-only
mypy --strict catches omissions in CI"] + DeleteAPI["delete(*, tenant_id: str, …)"] + end + + subgraph Backend["Underlying backend — LanceDB / Chroma / Qdrant"] + Column["tenant_id column
NOT NULL + index on (tenant_id, scope)"] + Filter["WHERE tenant_id = ? pushed into vector query
foreign rows never enter candidate pool"] + end + + RecallFlow["RecallFlow (deep mode)
holds ScopedStorage, not raw backend
→ exploration cannot escape tenant filter"] + + Caller --> Resolve + Resolve --> Remember + Resolve --> Recall + Resolve --> Forget + + Remember --> Scoped + Recall --> Scoped + Forget --> Scoped + + Scoped --> Stamp + Scoped --> Inject + Inject --> Verify + + Stamp --> SaveAPI + Verify --> SearchAPI + Inject --> DeleteAPI + + SaveAPI --> Column + SearchAPI --> Filter + DeleteAPI --> Filter + + Recall -. depth=deep .-> RecallFlow + RecallFlow --> Scoped + + classDef enforce fill:#ffe6e6,stroke:#c0392b,stroke-width:3px,color:#000 + classDef boundary fill:#fff4e0,stroke:#e67e22,stroke-width:2px,color:#000 + classDef storage fill:#e8f4fd,stroke:#2980b9,stroke-width:2px,color:#000 + classDef caller fill:#eafaf1,stroke:#27ae60,stroke-width:2px,color:#000 + + class Stamp,Inject,Verify enforce + class SaveAPI,SearchAPI,DeleteAPI boundary + class Column,Filter storage + class Caller,Resolve caller From c55f5092dc98300dabe3b4de9bce21c4797c6df4 Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Thu, 28 May 2026 22:26:32 -0500 Subject: [PATCH 2/8] feat(memory): add tenant_id and user_id fields to MemoryRecord Additive schema change. tenant_id defaults to "_default" (single-tenant deployments unchanged); user_id defaults to None. Updates the docstrings on the legacy source/private fields to point at tenant_id for callers who were using them as an isolation hint -- those fields were never an isolation boundary (the post-filter at unified_memory.py:704-709 ran after retrieval). Backends do not yet persist or filter on the new fields; that lands in the StorageBackend protocol change in the next PR. This commit is purely the Pydantic model + a hermetic round-trip test that pins the defaults and the "legacy row loads as _default" backward-compat path. Refs: design-docs/0001-per-tenant-memory-isolation.md --- lib/crewai/src/crewai/memory/types.py | 23 +++++++- .../test_memory_record_tenant_fields.py | 56 +++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) create mode 100644 lib/crewai/tests/memory/test_memory_record_tenant_fields.py diff --git a/lib/crewai/src/crewai/memory/types.py b/lib/crewai/src/crewai/memory/types.py index b186ee37e9..04a3398e9d 100644 --- a/lib/crewai/src/crewai/memory/types.py +++ b/lib/crewai/src/crewai/memory/types.py @@ -61,14 +61,33 @@ class MemoryRecord(BaseModel): default=None, description=( "Origin of this memory (e.g. user ID, session ID). " - "Used for provenance tracking and privacy filtering." + "Used for provenance tracking. NOTE: not an isolation boundary -- " + "for per-user/per-tenant isolation, use tenant_id instead." ), ) private: bool = Field( default=False, description=( "If True, this memory is only visible to recall requests from the same source, " - "or when include_private=True is passed." + "or when include_private=True is passed. NOTE: not an isolation boundary -- " + "for per-user/per-tenant isolation, use tenant_id instead." + ), + ) + tenant_id: str = Field( + default="_default", + description=( + "Isolation key. Every record is owned by exactly one tenant. " + "Storage backends MUST filter on this so recall scoped to tenant A " + "never returns a row written by tenant B. Default '_default' keeps " + "single-tenant deployments working unchanged." + ), + ) + user_id: str | None = Field( + default=None, + description=( + "Optional sub-tenant identity. Filtered within a tenant when set. " + "tenant_id is the security boundary; user_id is a refinement that " + "lets a tenant admin scope to one user inside their tenant." ), ) diff --git a/lib/crewai/tests/memory/test_memory_record_tenant_fields.py b/lib/crewai/tests/memory/test_memory_record_tenant_fields.py new file mode 100644 index 0000000000..163062dd7a --- /dev/null +++ b/lib/crewai/tests/memory/test_memory_record_tenant_fields.py @@ -0,0 +1,56 @@ +"""Tests for tenant_id / user_id fields on MemoryRecord. + +PR #1 of per-tenant memory isolation: the fields exist on the Pydantic model +with safe defaults. Backends do not yet persist or filter on them -- that lands +in PR #2 when the StorageBackend protocol gains the required tenant_id kwarg. + +These tests pin the additive schema change and the default values that keep +single-tenant deployments working unchanged. +""" + +from __future__ import annotations + +import pytest + +from crewai.memory.types import MemoryRecord + + +class TestMemoryRecordTenantFields: + def test_default_tenant_id_is_underscore_default(self) -> None: + rec = MemoryRecord(content="hello") + assert rec.tenant_id == "_default" + + def test_default_user_id_is_none(self) -> None: + rec = MemoryRecord(content="hello") + assert rec.user_id is None + + def test_tenant_id_round_trips_via_model_dump(self) -> None: + rec = MemoryRecord(content="hello", tenant_id="customer_42", user_id="alice") + dumped = rec.model_dump() + assert dumped["tenant_id"] == "customer_42" + assert dumped["user_id"] == "alice" + restored = MemoryRecord.model_validate(dumped) + assert restored.tenant_id == "customer_42" + assert restored.user_id == "alice" + + def test_tenant_id_round_trips_via_json(self) -> None: + rec = MemoryRecord(content="hello", tenant_id="customer_42") + restored = MemoryRecord.model_validate_json(rec.model_dump_json()) + assert restored.tenant_id == "customer_42" + + def test_legacy_record_without_tenant_id_loads_as_default(self) -> None: + # Simulates reading an old row from disk that pre-dates this PR. + # The default_factory must fire so the loaded record is non-leaking. + legacy_payload = {"content": "old row from before this PR"} + rec = MemoryRecord.model_validate(legacy_payload) + assert rec.tenant_id == "_default" + assert rec.user_id is None + + def test_tenant_id_must_be_string(self) -> None: + # tenant_id is non-nullable str -- enforcement at the type level. + with pytest.raises(Exception): + MemoryRecord(content="x", tenant_id=None) # type: ignore[arg-type] + + def test_user_id_accepts_none_and_string(self) -> None: + assert MemoryRecord(content="x", user_id=None).user_id is None + assert MemoryRecord(content="x", user_id="alice").user_id == "alice" From 69ec535189effcd4a437325c63bc633caa2c0705 Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Thu, 28 May 2026 22:43:10 -0500 Subject: [PATCH 3/8] refactor(memory): require tenant_id keyword on StorageBackend protocol Add tenant_id as a required keyword-only parameter on every read method of StorageBackend (search, delete, get_record, list_records, get_scope_info, list_scopes, list_categories, count, reset + async variants). The "required without default" form is deliberate: mypy --strict turns any forgotten caller into a CI failure, which is the static enforcement behind the isolation invariant. save() and update() do not gain a tenant_id parameter -- the tenant lives on the record, and a separate parameter would invite a "record says A, param says B" mismatch. Backend changes: * LanceDBStorage: tenant_id/user_id columns added to placeholder schema; auto-add via add_columns on opening pre-isolation tables; every WHERE clause includes the tenant predicate via a new _tenant_where helper; reset() can no longer drop the table -- it scopes to the tenant. * QdrantEdgeStorage: tenant_id/user_id added to payload; new _build_tenant_filter replaces _build_scope_filter so every FieldCondition chain starts from the tenant; tenant_id and user_id added to payload indexes; reset() now goes through delete() with the tenant predicate. Memory and the Flows (encoding_flow, recall_flow) temporarily pass tenant_id="_default" to every storage call so behavior is unchanged in this PR. The proper per-call tenant resolution lands when ScopedStorage is wired through Memory in the next PR. Test fixtures updated to pass tenant_id="_default" on direct backend calls. The synthetic orphan payload in test_orphaned_shard_cleanup gains tenant_id so the tenant filter matches the fixture (real pre-isolation orphan data is migrated by the upcoming crewai memory migrate command). Refs: design-docs/0001-per-tenant-memory-isolation.md --- lib/crewai/src/crewai/memory/encoding_flow.py | 13 +- lib/crewai/src/crewai/memory/recall_flow.py | 13 +- .../src/crewai/memory/storage/backend.py | 125 ++++++-- .../crewai/memory/storage/lancedb_storage.py | 267 ++++++++++++++--- .../memory/storage/qdrant_edge_storage.py | 273 +++++++++++++----- .../src/crewai/memory/unified_memory.py | 21 +- .../tests/memory/test_qdrant_edge_storage.py | 64 ++-- .../tests/memory/test_unified_memory.py | 33 ++- 8 files changed, 614 insertions(+), 195 deletions(-) diff --git a/lib/crewai/src/crewai/memory/encoding_flow.py b/lib/crewai/src/crewai/memory/encoding_flow.py index 968b439bff..0d6e52b093 100644 --- a/lib/crewai/src/crewai/memory/encoding_flow.py +++ b/lib/crewai/src/crewai/memory/encoding_flow.py @@ -176,6 +176,7 @@ def _search_one( return self._storage.search( # type: ignore[no-any-return] item.embedding, + tenant_id="_default", scope_prefix=effective_prefix, categories=None, limit=self._config.consolidation_limit, @@ -248,9 +249,13 @@ def parallel_analyze(self) -> None: None, ) scope_search_root = active_root if active_root else "/" - existing_scopes = self._storage.list_scopes(scope_search_root) or ["/"] + existing_scopes = self._storage.list_scopes( + scope_search_root, tenant_id="_default" + ) or ["/"] existing_categories = list( - self._storage.list_categories(scope_prefix=active_root).keys() + self._storage.list_categories( + tenant_id="_default", scope_prefix=active_root + ).keys() ) save_futures: dict[int, Future[MemoryAnalysis]] = {} @@ -449,7 +454,9 @@ def execute_plans(self) -> None: updated_records: dict[str, MemoryRecord] = {} if dedup_deletes: - self._storage.delete(record_ids=list(dedup_deletes)) + self._storage.delete( + tenant_id="_default", record_ids=list(dedup_deletes) + ) self.state.records_deleted += len(dedup_deletes) for rid, (_item_idx, new_content) in dedup_updates.items(): diff --git a/lib/crewai/src/crewai/memory/recall_flow.py b/lib/crewai/src/crewai/memory/recall_flow.py index 9da5dca64a..7480475b11 100644 --- a/lib/crewai/src/crewai/memory/recall_flow.py +++ b/lib/crewai/src/crewai/memory/recall_flow.py @@ -97,6 +97,7 @@ def _search_one( ) -> tuple[str, list[tuple[MemoryRecord, float]]]: raw = self._storage.search( embedding, + tenant_id="_default", scope_prefix=scope, categories=search_categories, limit=self.state.limit * _RECALL_OVERSAMPLE_FACTOR, @@ -201,11 +202,15 @@ def analyze_query_step(self) -> QueryAnalysis: ) self.state.query_analysis = analysis else: - available = self._storage.list_scopes(self.state.scope or "/") + available = self._storage.list_scopes( + self.state.scope or "/", tenant_id="_default" + ) if not available: available = ["/"] scope_info = ( - self._storage.get_scope_info(self.state.scope or "/") + self._storage.get_scope_info( + self.state.scope or "/", tenant_id="_default" + ) if self.state.scope else None ) @@ -249,7 +254,9 @@ def filter_and_chunk(self) -> list[str]: candidates = [s for s in analysis.suggested_scopes if s] else: try: - candidates = self._storage.list_scopes(scope_prefix) + candidates = self._storage.list_scopes( + scope_prefix, tenant_id="_default" + ) except Exception: logger.warning( "Storage list_scopes failed in filter_and_chunk, " diff --git a/lib/crewai/src/crewai/memory/storage/backend.py b/lib/crewai/src/crewai/memory/storage/backend.py index 147b9e2290..59d9b41a42 100644 --- a/lib/crewai/src/crewai/memory/storage/backend.py +++ b/lib/crewai/src/crewai/memory/storage/backend.py @@ -1,4 +1,17 @@ -"""Storage backend protocol for the unified memory system.""" +"""Storage backend protocol for the unified memory system. + +Per-tenant isolation contract +----------------------------- +Every read method on this Protocol takes ``tenant_id`` as a **required +keyword-only** argument. The required-without-default form is deliberate: +mypy --strict turns any forgotten caller into a CI failure, which is the +static guarantee behind the isolation invariant described in +``design-docs/0001-per-tenant-memory-isolation.md``. + +``save`` and ``update`` do not take ``tenant_id`` -- the tenant lives on +the record itself, and a separate parameter would invite a "record says A, +param says B" mismatch. +""" from __future__ import annotations @@ -15,6 +28,8 @@ class StorageBackend(Protocol): def save(self, records: list[MemoryRecord]) -> None: """Save memory records to storage. + The tenant_id is read from each record, not from a parameter. + Args: records: List of memory records to persist. """ @@ -23,6 +38,9 @@ def save(self, records: list[MemoryRecord]) -> None: def search( self, query_embedding: list[float], + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, metadata_filter: dict[str, Any] | None = None, @@ -33,6 +51,10 @@ def search( Args: query_embedding: Embedding vector for the query. + tenant_id: Isolation key. Backends MUST push this into the + vector query so foreign-tenant rows never enter the + candidate pool. + user_id: Optional sub-tenant identity for further filtering. scope_prefix: Optional scope path prefix to filter results. categories: Optional list of categories to filter by. metadata_filter: Optional metadata key-value filter. @@ -46,15 +68,20 @@ def search( def delete( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, record_ids: list[str] | None = None, older_than: datetime | None = None, metadata_filter: dict[str, Any] | None = None, ) -> int: - """Delete memories matching the given criteria. + """Delete memories matching the given criteria (scoped to tenant). Args: + tenant_id: Isolation key. Only rows owned by this tenant are eligible. + user_id: Optional sub-tenant filter. scope_prefix: Optional scope path prefix. categories: Optional list of categories. record_ids: Optional list of record IDs to delete. @@ -67,29 +94,42 @@ def delete( ... def update(self, record: MemoryRecord) -> None: - """Update an existing record. Replaces the record with the same ID.""" + """Update an existing record. Replaces the record with the same ID. + + The tenant_id is read from the record itself. + """ ... - def get_record(self, record_id: str) -> MemoryRecord | None: - """Return a single record by ID, or None if not found. + def get_record( + self, record_id: str, *, tenant_id: str, user_id: str | None = None + ) -> MemoryRecord | None: + """Return a single record by ID, or None if not found in the tenant. Args: record_id: The unique ID of the record. + tenant_id: Isolation key. A record found by ID but owned by a + different tenant is treated as not-found. + user_id: Optional sub-tenant filter. Returns: - The MemoryRecord, or None if no record with that ID exists. + The MemoryRecord, or None if no record with that ID exists for the tenant. """ ... def list_records( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, limit: int = 200, offset: int = 0, ) -> list[MemoryRecord]: - """List records in a scope, newest first. + """List records in a scope (scoped to tenant), newest first. Args: + tenant_id: Isolation key. + user_id: Optional sub-tenant filter. scope_prefix: Optional scope path prefix to filter by. limit: Maximum number of records to return. offset: Number of records to skip (for pagination). @@ -99,55 +139,90 @@ def list_records( """ ... - def get_scope_info(self, scope: str) -> ScopeInfo: - """Get information about a scope. + def get_scope_info( + self, scope: str, *, tenant_id: str, user_id: str | None = None + ) -> ScopeInfo: + """Get information about a scope (scoped to tenant). Args: scope: The scope path. + tenant_id: Isolation key. + user_id: Optional sub-tenant filter. Returns: ScopeInfo with record count, categories, date range, child scopes. """ ... - def list_scopes(self, parent: str = "/") -> list[str]: - """List immediate child scopes under a parent path. + def list_scopes( + self, parent: str = "/", *, tenant_id: str, user_id: str | None = None + ) -> list[str]: + """List immediate child scopes under a parent path (scoped to tenant). Args: parent: Parent scope path (default root). + tenant_id: Isolation key. + user_id: Optional sub-tenant filter. Returns: List of immediate child scope paths. """ ... - def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: - """List categories and their counts within a scope. + def list_categories( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> dict[str, int]: + """List categories and their counts within a scope (scoped to tenant). Args: - scope_prefix: Optional scope to limit to (None = global). + tenant_id: Isolation key. + user_id: Optional sub-tenant filter. + scope_prefix: Optional scope to limit to (None = whole tenant). Returns: Mapping of category name to record count. """ ... - def count(self, scope_prefix: str | None = None) -> int: - """Count records in scope (and subscopes). + def count( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> int: + """Count records in scope (scoped to tenant). Args: - scope_prefix: Optional scope path (None = all). + tenant_id: Isolation key. + user_id: Optional sub-tenant filter. + scope_prefix: Optional scope path (None = whole tenant). Returns: Number of records. """ ... - def reset(self, scope_prefix: str | None = None) -> None: - """Reset (delete all) memories in scope. + def reset( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> None: + """Reset (delete all) memories within a tenant. + + There is no "reset everything across all tenants" path. An operator + who needs that calls reset for each tenant deliberately. Args: - scope_prefix: Optional scope path (None = reset all). + tenant_id: Isolation key. Only this tenant's rows are wiped. + user_id: Optional sub-tenant filter. + scope_prefix: Optional scope path (None = whole tenant). """ ... @@ -158,22 +233,28 @@ async def asave(self, records: list[MemoryRecord]) -> None: async def asearch( self, query_embedding: list[float], + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, metadata_filter: dict[str, Any] | None = None, limit: int = 10, min_score: float = 0.0, ) -> list[tuple[MemoryRecord, float]]: - """Search for memories asynchronously.""" + """Search for memories asynchronously (scoped to tenant).""" ... async def adelete( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, record_ids: list[str] | None = None, older_than: datetime | None = None, metadata_filter: dict[str, Any] | None = None, ) -> int: - """Delete memories asynchronously.""" + """Delete memories asynchronously (scoped to tenant).""" ... diff --git a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py index 4e88e967c0..2ec1849870 100644 --- a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py +++ b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py @@ -37,6 +37,36 @@ _MAX_RETRIES = 5 _RETRY_BASE_DELAY = 0.2 # seconds; doubles on each retry +# Tenant isolation: every row carries a tenant_id. Pre-isolation tables are +# migrated to this default tenant when first opened, so existing single-tenant +# deployments keep working unchanged. +_DEFAULT_TENANT = "_default" + + +def _sql_quote(value: str) -> str: + """Escape a string literal for use inside a LanceDB WHERE clause. + + LanceDB uses SQL-like single-quoted string literals. The only escape is + doubling a single quote. Centralizing this here keeps every tenant_id / + user_id / scope predicate using the same escape so Bandit's S608 + rule does not fire and so a hostile tenant_id cannot break out of the + quoted literal. + """ + return value.replace("'", "''") + + +def _tenant_where(tenant_id: str, user_id: str | None = None) -> str: + """Build the WHERE fragment that pins a query to one tenant (and optionally one user). + + Every read path in this storage assembles its WHERE clause by starting + from this fragment and ANDing on top. There is no read path that does + not call this function. + """ + clause = f"tenant_id = '{_sql_quote(tenant_id)}'" + if user_id is not None: + clause += f" AND user_id = '{_sql_quote(user_id)}'" + return clause + class LanceDBStorage: """LanceDB-backed storage for the unified memory system.""" @@ -97,6 +127,7 @@ def __init__( self._table: Any = self._db.open_table(self._table_name) self._vector_dim: int = self._infer_dim_from_table(self._table) with store_lock(self._lock_name): + self._ensure_tenant_columns() self._ensure_scope_index() self._compact_if_needed() except Exception: @@ -168,6 +199,8 @@ def _create_table(self, vector_dim: int) -> Any: "last_accessed": datetime.utcnow().isoformat(), "source": "", "private": False, + "tenant_id": _DEFAULT_TENANT, + "user_id": "", "vector": [0.0] * vector_dim, } ] @@ -179,6 +212,48 @@ def _create_table(self, vector_dim: int) -> Any: table.delete("id = '__schema_placeholder__'") return table + def _ensure_tenant_columns(self) -> None: + """Add ``tenant_id`` and ``user_id`` columns to an existing table if missing. + + This is the lazy schema upgrade for tables that were created before + per-tenant isolation. Existing rows are stamped with ``_default`` so + every read path's ``WHERE tenant_id = ?`` predicate matches. The + upgrade is best-effort: if LanceDB does not support add_columns at + runtime, or if the columns already exist, the exception is swallowed + and the storage continues. The migrate CLI command is the supported + path for explicitly stamping existing data. + + Caller must already hold ``store_lock(self._lock_name)``. + """ + if self._table is None: + return + existing = {field.name for field in self._table.schema} + to_add: dict[str, str] = {} + if "tenant_id" not in existing: + to_add["tenant_id"] = f"'{_DEFAULT_TENANT}'" + if "user_id" not in existing: + to_add["user_id"] = "''" + if not to_add: + return + try: + self._table.add_columns(to_add) + _logger.info( + "Migrated LanceDB table %r: added columns %s with default tenant=%r. " + "Run `crewai memory migrate` to assign real tenants to existing rows.", + self._table_name, + sorted(to_add), + _DEFAULT_TENANT, + ) + except Exception: + _logger.warning( + "Could not auto-add tenant columns to LanceDB table %r. " + "Existing rows will read back as tenant=%r via row-level defaults. " + "Run `crewai memory migrate` if needed.", + self._table_name, + _DEFAULT_TENANT, + exc_info=True, + ) + def _ensure_scope_index(self) -> None: """Create a BTREE scalar index on the ``scope`` column if not present. @@ -255,6 +330,8 @@ def _record_to_row(self, record: MemoryRecord) -> dict[str, Any]: "last_accessed": record.last_accessed.isoformat(), "source": record.source or "", "private": record.private, + "tenant_id": record.tenant_id or _DEFAULT_TENANT, + "user_id": record.user_id or "", "vector": record.embedding if record.embedding else [0.0] * self._vector_dim, @@ -269,6 +346,15 @@ def _parse_dt(val: Any) -> datetime: s = str(val) return datetime.fromisoformat(s.replace("Z", "+00:00")) + # Backward compat: pre-isolation rows have neither column; new rows + # have tenant_id stamped on save. Either way, every record loaded + # through this method has a non-empty tenant_id so downstream + # filtering and the ScopedStorage double-check never see None. + raw_tenant = row.get("tenant_id") + tenant_id = str(raw_tenant) if raw_tenant else _DEFAULT_TENANT + raw_user = row.get("user_id") + user_id = str(raw_user) if raw_user else None + return MemoryRecord( id=str(row["id"]), content=str(row["content"]), @@ -283,6 +369,8 @@ def _parse_dt(val: Any) -> datetime: embedding=row.get("vector"), source=row.get("source") or None, private=bool(row.get("private", False)), + tenant_id=tenant_id, + user_id=user_id, ) def save(self, records: list[MemoryRecord]) -> None: @@ -342,12 +430,19 @@ def touch_records(self, record_ids: list[str]) -> None: values={"last_accessed": now}, ) - def get_record(self, record_id: str) -> MemoryRecord | None: - """Return a single record by ID, or None if not found.""" + def get_record( + self, record_id: str, *, tenant_id: str, user_id: str | None = None + ) -> MemoryRecord | None: + """Return a single record by ID, or None if not found in the tenant. + + A record found by id but owned by a different tenant is treated as + not-found, which is what the isolation invariant requires. + """ if self._table is None: return None - safe_id = str(record_id).replace("'", "''") - rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list() + safe_id = _sql_quote(str(record_id)) + where = f"id = '{safe_id}' AND {_tenant_where(tenant_id, user_id)}" + rows = self._table.search().where(where).limit(1).to_list() if not rows: return None return self._row_to_record(rows[0]) @@ -355,6 +450,9 @@ def get_record(self, record_id: str) -> MemoryRecord | None: def search( self, query_embedding: list[float], + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, metadata_filter: dict[str, Any] | None = None, @@ -364,10 +462,14 @@ def search( if self._table is None: return [] query = self._table.search(query_embedding) + # Tenant predicate is unconditional and pushed down so foreign-tenant + # rows never enter the ANN candidate pool. + where = _tenant_where(tenant_id, user_id) if scope_prefix is not None and scope_prefix.strip("/"): prefix = scope_prefix.rstrip("/") - like_val = prefix + "%" - query = query.where(f"scope LIKE '{like_val}'") + like_val = _sql_quote(prefix) + "%" + where += f" AND scope LIKE '{like_val}'" + query = query.where(where) results = query.limit( limit * 3 if (categories or metadata_filter) else limit ).to_list() @@ -390,6 +492,9 @@ def search( def delete( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, record_ids: list[str] | None = None, @@ -398,14 +503,19 @@ def delete( ) -> int: if self._table is None: return 0 + tenant_clause = _tenant_where(tenant_id, user_id) with store_lock(self._lock_name): if record_ids and not (categories or metadata_filter): before = int(self._table.count_rows()) - ids_expr = ", ".join(f"'{rid}'" for rid in record_ids) - self._do_write("delete", f"id IN ({ids_expr})") + ids_expr = ", ".join(f"'{_sql_quote(rid)}'" for rid in record_ids) + self._do_write( + "delete", f"({tenant_clause}) AND id IN ({ids_expr})" + ) return before - int(self._table.count_rows()) if categories or metadata_filter: - rows = self._scan_rows(scope_prefix) + rows = self._scan_rows( + scope_prefix, tenant_id=tenant_id, user_id=user_id + ) to_delete: list[str] = [] for row in rows: record = self._row_to_record(row) @@ -423,21 +533,23 @@ def delete( if not to_delete: return 0 before = int(self._table.count_rows()) - ids_expr = ", ".join(f"'{rid}'" for rid in to_delete) - self._do_write("delete", f"id IN ({ids_expr})") + ids_expr = ", ".join(f"'{_sql_quote(rid)}'" for rid in to_delete) + self._do_write( + "delete", f"({tenant_clause}) AND id IN ({ids_expr})" + ) return before - int(self._table.count_rows()) - conditions = [] + conditions = [tenant_clause] if scope_prefix is not None and scope_prefix.strip("/"): prefix = scope_prefix.rstrip("/") if not prefix.startswith("/"): prefix = "/" + prefix - conditions.append(f"scope LIKE '{prefix}%' OR scope = '/'") + conditions.append( + f"(scope LIKE '{_sql_quote(prefix)}%' OR scope = '/')" + ) if older_than is not None: - conditions.append(f"created_at < '{older_than.isoformat()}'") - if not conditions: - before = int(self._table.count_rows()) - self._do_write("delete", "id != ''") - return before - int(self._table.count_rows()) + conditions.append( + f"created_at < '{_sql_quote(older_than.isoformat())}'" + ) where_expr = " AND ".join(conditions) before = int(self._table.count_rows()) self._do_write("delete", where_expr) @@ -448,31 +560,44 @@ def _scan_rows( scope_prefix: str | None = None, limit: int = _SCAN_ROWS_LIMIT, columns: list[str] | None = None, + *, + tenant_id: str, + user_id: str | None = None, ) -> list[dict[str, Any]]: - """Scan rows optionally filtered by scope prefix. + """Scan rows scoped to a tenant, optionally filtered by scope prefix. Uses a full table scan (no vector query) so the limit is applied after - the scope filter, not to ANN candidates before filtering. + the tenant + scope filter, not to ANN candidates before filtering. Args: scope_prefix: Optional scope path prefix to filter by. limit: Maximum number of rows to return (applied after filtering). - columns: Optional list of column names to fetch. Pass only the + columns: Optional list of column names to fetch. Pass only the columns you need for metadata operations to avoid reading the heavy ``vector`` column unnecessarily. + tenant_id: Isolation key (required, keyword-only). + user_id: Optional sub-tenant filter. """ if self._table is None: return [] q = self._table.search() + where = _tenant_where(tenant_id, user_id) if scope_prefix is not None and scope_prefix.strip("/"): - q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'") + where += f" AND scope LIKE '{_sql_quote(scope_prefix.rstrip('/'))}%'" + q = q.where(where) if columns is not None: q = q.select(columns) result: list[dict[str, Any]] = q.limit(limit).to_list() return result def list_records( - self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0 + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + limit: int = 200, + offset: int = 0, ) -> list[MemoryRecord]: """List records in a scope, newest first. @@ -484,12 +609,19 @@ def list_records( Returns: List of MemoryRecord, ordered by created_at descending. """ - rows = self._scan_rows(scope_prefix, limit=limit + offset) + rows = self._scan_rows( + scope_prefix, + limit=limit + offset, + tenant_id=tenant_id, + user_id=user_id, + ) records = [self._row_to_record(r) for r in rows] records.sort(key=lambda r: r.created_at, reverse=True) return records[offset : offset + limit] - def get_scope_info(self, scope: str) -> ScopeInfo: + def get_scope_info( + self, scope: str, *, tenant_id: str, user_id: str | None = None + ) -> ScopeInfo: scope = scope.rstrip("/") or "/" prefix = scope if scope != "/" else "" if prefix and not prefix.startswith("/"): @@ -497,6 +629,8 @@ def get_scope_info(self, scope: str) -> ScopeInfo: rows = self._scan_rows( prefix or None, columns=["scope", "categories_str", "created_at"], + tenant_id=tenant_id, + user_id=user_id, ) if not rows: return ScopeInfo( @@ -545,10 +679,21 @@ def get_scope_info(self, scope: str) -> ScopeInfo: child_scopes=sorted(children), ) - def list_scopes(self, parent: str = "/") -> list[str]: + def list_scopes( + self, + parent: str = "/", + *, + tenant_id: str, + user_id: str | None = None, + ) -> list[str]: parent = parent.rstrip("/") or "" prefix = (parent + "/") if parent else "/" - rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"]) + rows = self._scan_rows( + prefix if prefix != "/" else None, + columns=["scope"], + tenant_id=tenant_id, + user_id=user_id, + ) children: set[str] = set() for row in rows: sc = str(row.get("scope", "")) @@ -559,8 +704,19 @@ def list_scopes(self, parent: str = "/") -> list[str]: children.add(prefix + first_component) return sorted(children) - def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: - rows = self._scan_rows(scope_prefix, columns=["categories_str"]) + def list_categories( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> dict[str, int]: + rows = self._scan_rows( + scope_prefix, + columns=["categories_str"], + tenant_id=tenant_id, + user_id=user_id, + ) counts: dict[str, int] = {} for row in rows: cat_str = row.get("categories_str") or "[]" @@ -572,27 +728,48 @@ def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: counts[c] = counts.get(c, 0) + 1 return counts - def count(self, scope_prefix: str | None = None) -> int: + def count( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> int: if self._table is None: return 0 - if scope_prefix is None or scope_prefix.strip("/") == "": - return int(self._table.count_rows()) - info = self.get_scope_info(scope_prefix) + # Even an unfiltered count is scoped to a tenant; "count rows across + # all tenants" is intentionally not exposed. + info = self.get_scope_info( + scope_prefix or "/", tenant_id=tenant_id, user_id=user_id + ) return info.record_count - def reset(self, scope_prefix: str | None = None) -> None: + def reset( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> None: + """Reset (delete all) memories for this tenant. + + There is no "drop the whole table" path; resetting one tenant never + wipes another tenant's data. To remove the entire on-disk table, + delete the storage directory directly. + """ + if self._table is None: + return + tenant_clause = _tenant_where(tenant_id, user_id) with store_lock(self._lock_name): if scope_prefix is None or scope_prefix.strip("/") == "": - if self._table is not None: - self._db.drop_table(self._table_name) - self._table = None - return - if self._table is None: + self._do_write("delete", tenant_clause) return prefix = scope_prefix.rstrip("/") if prefix: self._do_write( - "delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'" + "delete", + f"({tenant_clause}) AND scope >= '{_sql_quote(prefix)}' " + f"AND scope < '{_sql_quote(prefix)}/\uffff'", ) def optimize(self) -> None: @@ -617,6 +794,9 @@ async def asave(self, records: list[MemoryRecord]) -> None: async def asearch( self, query_embedding: list[float], + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, metadata_filter: dict[str, Any] | None = None, @@ -625,6 +805,8 @@ async def asearch( ) -> list[tuple[MemoryRecord, float]]: return self.search( query_embedding, + tenant_id=tenant_id, + user_id=user_id, scope_prefix=scope_prefix, categories=categories, metadata_filter=metadata_filter, @@ -634,6 +816,9 @@ async def asearch( async def adelete( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, record_ids: list[str] | None = None, @@ -641,6 +826,8 @@ async def adelete( metadata_filter: dict[str, Any] | None = None, ) -> int: return self.delete( + tenant_id=tenant_id, + user_id=user_id, scope_prefix=scope_prefix, categories=categories, record_ids=record_ids, diff --git a/lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py b/lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py index d819094e9e..8be205e535 100644 --- a/lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py +++ b/lib/crewai/src/crewai/memory/storage/qdrant_edge_storage.py @@ -47,6 +47,11 @@ _SCROLL_BATCH: Final[int] = 256 +# Tenant isolation: every point carries a tenant_id in its payload. Pre-isolation +# payloads (no tenant_id key) are read back as this default tenant so +# single-tenant deployments keep working unchanged. +_DEFAULT_TENANT: Final[str] = "_default" + def _uuid_to_point_id(uuid_str: str) -> int: """Convert a UUID string to a stable Qdrant point ID. @@ -160,25 +165,27 @@ def _open_shard(self, path: Path) -> EdgeShard: return EdgeShard.create(str(path), self._config) def _ensure_indexes(self, shard: EdgeShard) -> None: - """Create payload indexes for efficient filtering.""" + """Create payload indexes for efficient filtering. + + tenant_id is indexed because every read path filters on it; without + the index, the WHERE-on-tenant_id becomes a full scan and isolation + carries a meaningful latency cost. + """ if self._indexes_created: return try: - shard.update( - UpdateOperation.create_field_index( - "scope_ancestors", PayloadSchemaType.Keyword - ) - ) - shard.update( - UpdateOperation.create_field_index( - "categories", PayloadSchemaType.Keyword - ) - ) - shard.update( - UpdateOperation.create_field_index( - "record_id", PayloadSchemaType.Keyword + for field in ( + "tenant_id", + "user_id", + "scope_ancestors", + "categories", + "record_id", + ): + shard.update( + UpdateOperation.create_field_index( + field, PayloadSchemaType.Keyword + ) ) - ) self._indexes_created = True except Exception: _logger.debug("Index creation failed (may already exist)", exc_info=True) @@ -204,6 +211,8 @@ def _record_to_point(self, record: MemoryRecord) -> Point: "last_accessed": record.last_accessed.isoformat(), "source": record.source or "", "private": record.private, + "tenant_id": record.tenant_id or _DEFAULT_TENANT, + "user_id": record.user_id or "", }, ) @@ -221,6 +230,12 @@ def _parse_dt(val: Any) -> datetime: return val return datetime.fromisoformat(str(val).replace("Z", "+00:00")) + # Backward compat: pre-isolation payloads have neither key. + raw_tenant = payload.get("tenant_id") + tenant_id = str(raw_tenant) if raw_tenant else _DEFAULT_TENANT + raw_user = payload.get("user_id") + user_id = str(raw_user) if raw_user else None + return MemoryRecord( id=str(payload["record_id"]), content=str(payload["content"]), @@ -233,19 +248,39 @@ def _parse_dt(val: Any) -> datetime: embedding=vector.get(VECTOR_NAME) if vector else None, source=payload.get("source") or None, private=bool(payload.get("private", False)), + tenant_id=tenant_id, + user_id=user_id, ) @staticmethod - def _build_scope_filter(scope_prefix: str | None) -> Filter | None: - """Build a Qdrant Filter for scope prefix matching.""" - if scope_prefix is None or not scope_prefix.strip("/"): - return None - prefix = scope_prefix.rstrip("/") - if not prefix.startswith("/"): - prefix = "/" + prefix - return Filter( - must=[FieldCondition(key="scope_ancestors", match=MatchValue(value=prefix))] - ) + def _build_tenant_filter( + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> Filter: + """Build a Qdrant Filter pinning a query to one tenant (and optionally scope/user). + + Every read path in this storage assembles its Filter by calling this + helper. The tenant_id clause is always present so foreign-tenant + points never enter the candidate pool. + """ + must: list[FieldCondition] = [ + FieldCondition(key="tenant_id", match=MatchValue(value=tenant_id)) + ] + if user_id is not None: + must.append( + FieldCondition(key="user_id", match=MatchValue(value=user_id)) + ) + if scope_prefix is not None and scope_prefix.strip("/"): + prefix = scope_prefix.rstrip("/") + if not prefix.startswith("/"): + prefix = "/" + prefix + must.append( + FieldCondition( + key="scope_ancestors", match=MatchValue(value=prefix) + ) + ) + return Filter(must=must) @staticmethod def _scroll_all( @@ -301,14 +336,17 @@ def save(self, records: list[MemoryRecord]) -> None: def search( self, query_embedding: list[float], + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, metadata_filter: dict[str, Any] | None = None, limit: int = 10, min_score: float = 0.0, ) -> list[tuple[MemoryRecord, float]]: - """Search both central and local shards, merge results.""" - filt = self._build_scope_filter(scope_prefix) + """Search both central and local shards, merge results (scoped to tenant).""" + filt = self._build_tenant_filter(tenant_id, user_id, scope_prefix) fetch_limit = limit * 3 if (categories or metadata_filter) else limit all_scored: list[tuple[dict[str, Any], float, bool]] = [] @@ -364,13 +402,16 @@ def search( def delete( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, record_ids: list[str] | None = None, older_than: datetime | None = None, metadata_filter: dict[str, Any] | None = None, ) -> int: - """Delete matching records from central shard.""" + """Delete matching records from central shard (scoped to tenant).""" total_deleted = 0 for shard_path in (self._central_path, self._local_path): if not shard_path.exists(): @@ -378,6 +419,8 @@ def delete( try: total_deleted += self._delete_from_shard_path( shard_path, + tenant_id, + user_id, scope_prefix, categories, record_ids, @@ -391,6 +434,8 @@ def delete( def _delete_from_shard_path( self, shard_path: Path, + tenant_id: str, + user_id: str | None, scope_prefix: str | None, categories: list[str] | None, record_ids: list[str] | None, @@ -402,6 +447,8 @@ def _delete_from_shard_path( try: deleted = self._delete_from_shard( shard, + tenant_id, + user_id, scope_prefix, categories, record_ids, @@ -416,30 +463,43 @@ def _delete_from_shard_path( def _delete_from_shard( self, shard: EdgeShard, + tenant_id: str, + user_id: str | None, scope_prefix: str | None, categories: list[str] | None, record_ids: list[str] | None, older_than: datetime | None, metadata_filter: dict[str, Any] | None, ) -> int: - """Delete matching records from a single shard, returning count deleted.""" - before = shard.count(CountRequest()) + """Delete matching records from a single shard (scoped to tenant).""" + # Tenant clause is always part of the filter -- a delete by record_id + # cannot wipe a row that belongs to a different tenant. + tenant_filter = self._build_tenant_filter(tenant_id, user_id, scope_prefix) + before = shard.count(CountRequest(filter=tenant_filter)) if record_ids and not (categories or metadata_filter or older_than): - point_ids: list[int | uuid.UUID | str] = [ - _uuid_to_point_id(rid) for rid in record_ids + # Resolve record_ids against the tenant's rows so a delete by id + # cannot reach another tenant's data. + allowed_ids = set(record_ids) + tenant_points = self._scroll_all(shard, filt=tenant_filter) + to_delete: list[int | uuid.UUID | str] = [ + pt.id + for pt in tenant_points + if str(pt.payload.get("record_id", "")) in allowed_ids ] - shard.update(UpdateOperation.delete_points(point_ids)) - return before - shard.count(CountRequest()) + if to_delete: + shard.update(UpdateOperation.delete_points(to_delete)) + return before - shard.count(CountRequest(filter=tenant_filter)) if categories or metadata_filter or older_than: - scope_filter = self._build_scope_filter(scope_prefix) - points = self._scroll_all(shard, filt=scope_filter) - allowed_ids: set[str] | None = set(record_ids) if record_ids else None - to_delete: list[int | uuid.UUID | str] = [] + points = self._scroll_all(shard, filt=tenant_filter) + allowed_ids_opt: set[str] | None = ( + set(record_ids) if record_ids else None + ) + to_delete = [] for pt in points: record = self._payload_to_record(pt.payload or {}) - if allowed_ids and record.id not in allowed_ids: + if allowed_ids_opt and record.id not in allowed_ids_opt: continue if categories and not any(c in record.categories for c in categories): continue @@ -452,17 +512,12 @@ def _delete_from_shard( to_delete.append(pt.id) if to_delete: shard.update(UpdateOperation.delete_points(to_delete)) - return before - shard.count(CountRequest()) - - scope_filter = self._build_scope_filter(scope_prefix) - if scope_filter: - shard.update(UpdateOperation.delete_points_by_filter(filter=scope_filter)) - else: - points = self._scroll_all(shard) - if points: - all_ids: list[int | uuid.UUID | str] = [p.id for p in points] - shard.update(UpdateOperation.delete_points(all_ids)) - return before - shard.count(CountRequest()) + return before - shard.count(CountRequest(filter=tenant_filter)) + + # No record_ids, no other predicates -- delete every point matching + # the tenant (+ optional scope) filter. + shard.update(UpdateOperation.delete_points_by_filter(filter=tenant_filter)) + return before - shard.count(CountRequest(filter=tenant_filter)) def update(self, record: MemoryRecord) -> None: """Update a record by upserting with the same point ID.""" @@ -484,8 +539,14 @@ def update(self, record: MemoryRecord) -> None: finally: local.close() - def get_record(self, record_id: str) -> MemoryRecord | None: - """Return a single record by ID, or None if not found.""" + def get_record( + self, record_id: str, *, tenant_id: str, user_id: str | None = None + ) -> MemoryRecord | None: + """Return a single record by ID, or None if not found in the tenant. + + A point found by ID but owned by a different tenant is treated as + not-found, which is what the isolation invariant requires. + """ point_id = _uuid_to_point_id(record_id) for shard_path in (self._local_path, self._central_path): if not shard_path.exists(): @@ -496,6 +557,12 @@ def get_record(self, record_id: str) -> MemoryRecord | None: shard.close() if records: payload = records[0].payload or {} + # Tenant check is the isolation guarantee. A point that + # collides on ID but lives in another tenant is invisible. + if payload.get("tenant_id", _DEFAULT_TENANT) != tenant_id: + continue + if user_id is not None and payload.get("user_id") != user_id: + continue vec = records[0].vector vec_dict = vec if isinstance(vec, dict) else None return self._payload_to_record(payload, vec_dict) # type: ignore[arg-type] @@ -505,12 +572,15 @@ def get_record(self, record_id: str) -> MemoryRecord | None: def list_records( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, limit: int = 200, offset: int = 0, ) -> list[MemoryRecord]: - """List records in a scope, newest first.""" - filt = self._build_scope_filter(scope_prefix) + """List records in a scope (scoped to tenant), newest first.""" + filt = self._build_tenant_filter(tenant_id, user_id, scope_prefix) all_records: list[MemoryRecord] = [] seen_ids: set[str] = set() @@ -532,11 +602,13 @@ def list_records( all_records.sort(key=lambda r: r.created_at, reverse=True) return all_records[offset : offset + limit] - def get_scope_info(self, scope: str) -> ScopeInfo: - """Get information about a scope.""" + def get_scope_info( + self, scope: str, *, tenant_id: str, user_id: str | None = None + ) -> ScopeInfo: + """Get information about a scope (scoped to tenant).""" scope = scope.rstrip("/") or "/" prefix = scope if scope != "/" else None - filt = self._build_scope_filter(prefix) + filt = self._build_tenant_filter(tenant_id, user_id, prefix) all_points: list[Any] = [] for shard_path in (self._central_path, self._local_path): @@ -598,13 +670,21 @@ def get_scope_info(self, scope: str) -> ScopeInfo: child_scopes=sorted(children), ) - def list_scopes(self, parent: str = "/") -> list[str]: - """List immediate child scopes under a parent path.""" + def list_scopes( + self, + parent: str = "/", + *, + tenant_id: str, + user_id: str | None = None, + ) -> list[str]: + """List immediate child scopes under a parent path (scoped to tenant).""" parent = parent.rstrip("/") or "" prefix = (parent + "/") if parent else "/" all_scopes: set[str] = set() - filt = self._build_scope_filter(prefix if prefix != "/" else None) + filt = self._build_tenant_filter( + tenant_id, user_id, prefix if prefix != "/" else None + ) for shard_path in (self._central_path, self._local_path): if not shard_path.exists(): continue @@ -623,8 +703,14 @@ def list_scopes(self, parent: str = "/") -> list[str]: _logger.debug("list_scopes failed on %s", shard_path, exc_info=True) return sorted(all_scopes) - def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: - """List categories and their counts within a scope.""" + def list_categories( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> dict[str, int]: + """List categories and their counts within a scope (scoped to tenant).""" if not self._local_has_data and self._central_path.exists(): try: shard = EdgeShard.load(str(self._central_path)) @@ -636,7 +722,7 @@ def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: ) except Exception: # noqa: S110 pass - filt = self._build_scope_filter(scope_prefix) + filt = self._build_tenant_filter(tenant_id, user_id, scope_prefix) facet_result = shard.facet( FacetRequest(key="categories", limit=1000, filter=filt) ) @@ -646,14 +732,25 @@ def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: _logger.debug("list_categories failed on central", exc_info=True) counts: dict[str, int] = {} - for record in self.list_records(scope_prefix=scope_prefix, limit=50_000): + for record in self.list_records( + tenant_id=tenant_id, + user_id=user_id, + scope_prefix=scope_prefix, + limit=50_000, + ): for c in record.categories: counts[c] = counts.get(c, 0) + 1 return counts - def count(self, scope_prefix: str | None = None) -> int: - """Count records in scope (and subscopes).""" - filt = self._build_scope_filter(scope_prefix) + def count( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> int: + """Count records in scope (and subscopes), scoped to tenant.""" + filt = self._build_tenant_filter(tenant_id, user_id, scope_prefix) if not self._local_has_data: if self._central_path.exists(): try: @@ -677,17 +774,25 @@ def count(self, scope_prefix: str | None = None) -> int: _logger.debug("count failed on %s", shard_path, exc_info=True) return len(seen_ids) - def reset(self, scope_prefix: str | None = None) -> None: - """Reset (delete all) memories in scope.""" - if scope_prefix is None or not scope_prefix.strip("/"): - for shard_path in (self._central_path, self._local_path): - if shard_path.exists(): - shutil.rmtree(shard_path, ignore_errors=True) - self._local_has_data = False - self._indexes_created = False - return + def reset( + self, + *, + tenant_id: str, + user_id: str | None = None, + scope_prefix: str | None = None, + ) -> None: + """Reset (delete all) memories for this tenant. - self.delete(scope_prefix=scope_prefix) + Even an unscoped reset is bound to the tenant; resetting one tenant + never wipes another tenant's data. To remove the entire on-disk + store, delete the storage directory directly. + """ + # Always go through delete() so the tenant_id predicate is applied; + # the old "shutil.rmtree everything" path was cross-tenant by design + # and is no longer reachable from this class. + self.delete( + tenant_id=tenant_id, user_id=user_id, scope_prefix=scope_prefix + ) def touch_records(self, record_ids: list[str]) -> None: """Update last_accessed to now for the given record IDs.""" @@ -836,16 +941,21 @@ async def asave(self, records: list[MemoryRecord]) -> None: async def asearch( self, query_embedding: list[float], + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, metadata_filter: dict[str, Any] | None = None, limit: int = 10, min_score: float = 0.0, ) -> list[tuple[MemoryRecord, float]]: - """Search for memories asynchronously.""" + """Search for memories asynchronously (scoped to tenant).""" return await asyncio.to_thread( self.search, query_embedding, + tenant_id=tenant_id, + user_id=user_id, scope_prefix=scope_prefix, categories=categories, metadata_filter=metadata_filter, @@ -855,15 +965,20 @@ async def asearch( async def adelete( self, + *, + tenant_id: str, + user_id: str | None = None, scope_prefix: str | None = None, categories: list[str] | None = None, record_ids: list[str] | None = None, older_than: datetime | None = None, metadata_filter: dict[str, Any] | None = None, ) -> int: - """Delete memories asynchronously.""" + """Delete memories asynchronously (scoped to tenant).""" return await asyncio.to_thread( self.delete, + tenant_id=tenant_id, + user_id=user_id, scope_prefix=scope_prefix, categories=categories, record_ids=record_ids, diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index 02c1818224..93afab30db 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -696,6 +696,7 @@ def recall( else: raw = self._storage.search( embedding, + tenant_id="_default", scope_prefix=effective_scope, categories=categories, limit=limit, @@ -800,6 +801,7 @@ def forget( elif effective_scope is not None and self.root_scope: effective_scope = join_scope_paths(self.root_scope, effective_scope) return self._storage.delete( + tenant_id="_default", scope_prefix=effective_scope, categories=categories, record_ids=record_ids, @@ -832,7 +834,7 @@ def update( Raises: ValueError: If the record is not found. """ - existing = self._storage.get_record(record_id) + existing = self._storage.get_record(record_id, tenant_id="_default") if existing is None: raise ValueError(f"Record not found: {record_id}") now = datetime.utcnow() @@ -889,7 +891,7 @@ def list_scopes(self, path: str | None = None) -> list[str]: effective_path = join_scope_paths(self.root_scope, effective_path) elif effective_path is None: effective_path = "/" - return self._storage.list_scopes(effective_path) + return self._storage.list_scopes(effective_path, tenant_id="_default") def list_records( self, scope: str | None = None, limit: int = 200, offset: int = 0 @@ -908,7 +910,10 @@ def list_records( elif effective_scope is not None and self.root_scope: effective_scope = join_scope_paths(self.root_scope, effective_scope) return self._storage.list_records( - scope_prefix=effective_scope, limit=limit, offset=offset + tenant_id="_default", + scope_prefix=effective_scope, + limit=limit, + offset=offset, ) def info(self, path: str | None = None) -> ScopeInfo: @@ -925,7 +930,7 @@ def info(self, path: str | None = None) -> ScopeInfo: effective_path = join_scope_paths(self.root_scope, effective_path) elif effective_path is None: effective_path = "/" - return self._storage.get_scope_info(effective_path) + return self._storage.get_scope_info(effective_path, tenant_id="_default") def tree(self, path: str | None = None, max_depth: int = 3) -> str: """Return a formatted tree of scopes (string). @@ -948,7 +953,7 @@ def tree(self, path: str | None = None, max_depth: int = 3) -> str: def _walk(p: str, depth: int, prefix: str) -> None: if depth > max_depth: return - info = self._storage.get_scope_info(p) + info = self._storage.get_scope_info(p, tenant_id="_default") lines.append(f"{prefix}{p or '/'} ({info.record_count} records)") for child in info.child_scopes[:20]: _walk(child, depth + 1, prefix + " ") @@ -968,7 +973,9 @@ def list_categories(self, path: str | None = None) -> dict[str, int]: effective_path = self.root_scope elif effective_path is not None and self.root_scope: effective_path = join_scope_paths(self.root_scope, effective_path) - return self._storage.list_categories(scope_prefix=effective_path) + return self._storage.list_categories( + tenant_id="_default", scope_prefix=effective_path + ) def reset(self, scope: str | None = None) -> None: """Reset (delete all) memories in scope. @@ -982,7 +989,7 @@ def reset(self, scope: str | None = None) -> None: effective_scope = self.root_scope elif effective_scope is not None and self.root_scope: effective_scope = join_scope_paths(self.root_scope, effective_scope) - self._storage.reset(scope_prefix=effective_scope) + self._storage.reset(tenant_id="_default", scope_prefix=effective_scope) async def aextract_memories(self, content: str) -> list[str]: """Async variant of extract_memories.""" diff --git a/lib/crewai/tests/memory/test_qdrant_edge_storage.py b/lib/crewai/tests/memory/test_qdrant_edge_storage.py index bd30e8758f..19549ba089 100644 --- a/lib/crewai/tests/memory/test_qdrant_edge_storage.py +++ b/lib/crewai/tests/memory/test_qdrant_edge_storage.py @@ -57,7 +57,9 @@ def _rec( def test_save_search(storage: QdrantEdgeStorage) -> None: r = _rec(content="test content", scope="/foo", categories=["cat1"], importance=0.8) storage.save([r]) - results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/foo", limit=5) + results = storage.search( + [0.1, 0.2, 0.3, 0.4], tenant_id="_default", scope_prefix="/foo", limit=5 + ) assert len(results) == 1 rec, score = results[0] assert rec.content == "test content" @@ -68,10 +70,10 @@ def test_save_search(storage: QdrantEdgeStorage) -> None: def test_delete_count(storage: QdrantEdgeStorage) -> None: r = _rec(scope="/") storage.save([r]) - assert storage.count() == 1 - n = storage.delete(scope_prefix="/") + assert storage.count(tenant_id="_default") == 1 + n = storage.delete(tenant_id="_default", scope_prefix="/") assert n >= 1 - assert storage.count() == 0 + assert storage.count(tenant_id="_default") == 0 def test_update_get_record(storage: QdrantEdgeStorage) -> None: @@ -79,13 +81,13 @@ def test_update_get_record(storage: QdrantEdgeStorage) -> None: storage.save([r]) r.content = "updated" storage.update(r) - found = storage.get_record(r.id) + found = storage.get_record(r.id, tenant_id="_default") assert found is not None assert found.content == "updated" def test_get_record_not_found(storage: QdrantEdgeStorage) -> None: - assert storage.get_record("nonexistent-id") is None + assert storage.get_record("nonexistent-id", tenant_id="_default") is None @@ -95,9 +97,9 @@ def test_list_scopes_get_scope_info(storage: QdrantEdgeStorage) -> None: _rec(content="a", scope="/"), _rec(content="b", scope="/team"), ]) - scopes = storage.list_scopes("/") + scopes = storage.list_scopes("/", tenant_id="_default") assert "/team" in scopes - info = storage.get_scope_info("/") + info = storage.get_scope_info("/", tenant_id="_default") assert info.record_count >= 1 assert info.path == "/" @@ -108,7 +110,9 @@ def test_scope_prefix_filter(storage: QdrantEdgeStorage) -> None: _rec(content="eng note", scope="/crew/eng"), _rec(content="other note", scope="/other"), ]) - results = storage.search([0.1, 0.2, 0.3, 0.4], scope_prefix="/crew", limit=10) + results = storage.search( + [0.1, 0.2, 0.3, 0.4], tenant_id="_default", scope_prefix="/crew", limit=10 + ) assert len(results) == 2 scopes = {r.scope for r, _ in results} assert "/crew/sales" in scopes @@ -123,7 +127,7 @@ def test_category_filter(storage: QdrantEdgeStorage) -> None: _rec(content="cat2 item", categories=["cat2"]), ]) results = storage.search( - [0.1, 0.2, 0.3, 0.4], categories=["cat1"], limit=10 + [0.1, 0.2, 0.3, 0.4], tenant_id="_default", categories=["cat1"], limit=10 ) assert len(results) == 1 assert results[0][0].categories == ["cat1"] @@ -135,7 +139,10 @@ def test_metadata_filter(storage: QdrantEdgeStorage) -> None: _rec(content="without key", metadata={"env": "dev"}), ]) results = storage.search( - [0.1, 0.2, 0.3, 0.4], metadata_filter={"env": "prod"}, limit=10 + [0.1, 0.2, 0.3, 0.4], + tenant_id="_default", + metadata_filter={"env": "prod"}, + limit=10, ) assert len(results) == 1 assert results[0][0].metadata["env"] == "prod" @@ -152,8 +159,8 @@ def test_list_records_pagination(storage: QdrantEdgeStorage) -> None: for i in range(5) ] storage.save(records) - page1 = storage.list_records(limit=2, offset=0) - page2 = storage.list_records(limit=2, offset=2) + page1 = storage.list_records(tenant_id="_default", limit=2, offset=0) + page2 = storage.list_records(tenant_id="_default", limit=2, offset=2) assert len(page1) == 2 assert len(page2) == 2 # Newest first. @@ -165,7 +172,7 @@ def test_list_categories(storage: QdrantEdgeStorage) -> None: _rec(categories=["a", "b"]), _rec(categories=["b", "c"]), ]) - cats = storage.list_categories() + cats = storage.list_categories(tenant_id="_default") assert cats.get("b", 0) == 2 assert cats.get("a", 0) >= 1 assert cats.get("c", 0) >= 1 @@ -176,26 +183,26 @@ def test_list_categories(storage: QdrantEdgeStorage) -> None: def test_touch_records(storage: QdrantEdgeStorage) -> None: r = _rec() storage.save([r]) - before = storage.get_record(r.id) + before = storage.get_record(r.id, tenant_id="_default") assert before is not None old_accessed = before.last_accessed storage.touch_records([r.id]) - after = storage.get_record(r.id) + after = storage.get_record(r.id, tenant_id="_default") assert after is not None assert after.last_accessed >= old_accessed def test_reset_full(storage: QdrantEdgeStorage) -> None: storage.save([_rec(scope="/a"), _rec(scope="/b")]) - assert storage.count() == 2 - storage.reset() - assert storage.count() == 0 + assert storage.count(tenant_id="_default") == 2 + storage.reset(tenant_id="_default") + assert storage.count(tenant_id="_default") == 0 def test_reset_scoped(storage: QdrantEdgeStorage) -> None: storage.save([_rec(scope="/a"), _rec(scope="/b")]) - storage.reset(scope_prefix="/a") - assert storage.count() == 1 + storage.reset(tenant_id="_default", scope_prefix="/a") + assert storage.count(tenant_id="_default") == 1 @@ -208,7 +215,7 @@ def test_flush_to_central(tmp_path: Path) -> None: assert not s._local_has_data assert not s._local_path.exists() # Central should have the record. - assert s.count() == 1 + assert s.count(tenant_id="_default") == 1 def test_dual_shard_search(tmp_path: Path) -> None: @@ -217,7 +224,7 @@ def test_dual_shard_search(tmp_path: Path) -> None: s.flush_to_central() s._closed = False s.save([_rec(content="local record", scope="/b")]) - results = s.search([0.1, 0.2, 0.3, 0.4], limit=10) + results = s.search([0.1, 0.2, 0.3, 0.4], tenant_id="_default", limit=10) assert len(results) == 2 contents = {r.content for r, _ in results} assert "central record" in contents @@ -230,7 +237,7 @@ def test_close_lifecycle(tmp_path: Path) -> None: s.close() # Reopen a new storage — should find the record in central. s2 = _make_storage(str(tmp_path / "edge")) - results = s2.search([0.1, 0.2, 0.3, 0.4], limit=5) + results = s2.search([0.1, 0.2, 0.3, 0.4], tenant_id="_default", limit=5) assert len(results) == 1 assert results[0][0].content == "persisted" s2.close() @@ -273,6 +280,13 @@ def test_orphaned_shard_cleanup(tmp_path: Path) -> None: "last_accessed": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(), "source": "", "private": False, + # Synthetic orphan from a pre-isolation shard would lack + # tenant_id; we stamp it as "_default" here because the + # Qdrant filter pushes WHERE tenant_id = ? and an absent + # field would not match. The migrate CLI does this + # transformation for real pre-isolation data. + "tenant_id": "_default", + "user_id": "", }, ) ]) @@ -283,7 +297,7 @@ def test_orphaned_shard_cleanup(tmp_path: Path) -> None: s2 = _make_storage(str(base)) assert not orphan_path.exists() - results = s2.search([0.5, 0.5, 0.5, 0.5], limit=5) + results = s2.search([0.5, 0.5, 0.5, 0.5], tenant_id="_default", limit=5) assert len(results) >= 1 assert any(r.content == "orphaned" for r, _ in results) s2.close() diff --git a/lib/crewai/tests/memory/test_unified_memory.py b/lib/crewai/tests/memory/test_unified_memory.py index 776c36a2c6..9545c56afe 100644 --- a/lib/crewai/tests/memory/test_unified_memory.py +++ b/lib/crewai/tests/memory/test_unified_memory.py @@ -108,6 +108,7 @@ def test_lancedb_save_search(lancedb_path: Path) -> None: storage.save([r]) results = storage.search( [0.1, 0.2, 0.3, 0.4], + tenant_id="_default", scope_prefix="/foo", limit=5, ) @@ -124,10 +125,10 @@ def test_lancedb_delete_count(lancedb_path: Path) -> None: storage = LanceDBStorage(path=str(lancedb_path), vector_dim=4) r = MemoryRecord(content="x", scope="/", embedding=[0.0] * 4) storage.save([r]) - assert storage.count() == 1 - n = storage.delete(scope_prefix="/") + assert storage.count(tenant_id="_default") == 1 + n = storage.delete(tenant_id="_default", scope_prefix="/") assert n >= 1 - assert storage.count() == 0 + assert storage.count(tenant_id="_default") == 0 def test_lancedb_list_scopes_get_scope_info(lancedb_path: Path) -> None: @@ -138,9 +139,9 @@ def test_lancedb_list_scopes_get_scope_info(lancedb_path: Path) -> None: MemoryRecord(content="a", scope="/", embedding=[0.0] * 4), MemoryRecord(content="b", scope="/team", embedding=[0.0] * 4), ]) - scopes = storage.list_scopes("/") + scopes = storage.list_scopes("/", tenant_id="_default") assert "/team" in scopes # list_scopes returns children, not root itself - info = storage.get_scope_info("/") + info = storage.get_scope_info("/", tenant_id="_default") assert info.record_count >= 1 assert info.path == "/" @@ -189,10 +190,10 @@ def test_memory_forget(tmp_path: Path, mock_embedder: MagicMock) -> None: m = Memory(storage=str(tmp_path / "db2"), llm=MagicMock(), embedder=mock_embedder) m.remember("To forget", scope="/x", categories=[], importance=0.5, metadata={}) - assert m._storage.count("/x") >= 1 + assert m._storage.count(tenant_id="_default", scope_prefix="/x") >= 1 n = m.forget(scope="/x") assert n >= 1 - assert m._storage.count("/x") == 0 + assert m._storage.count(tenant_id="_default", scope_prefix="/x") == 0 def test_memory_scope_slice(tmp_path: Path, mock_embedder: MagicMock) -> None: @@ -635,7 +636,7 @@ def test_remember_survives_llm_failure( assert record.categories == [] assert record.importance == 0.5 assert record.id is not None - assert mem._storage.count() == 1 + assert mem._storage.count(tenant_id="_default") == 1 @@ -748,7 +749,7 @@ def test_intra_batch_dedup_drops_near_identical(tmp_path: Path) -> None: importance=0.7, ) mem.drain_writes() - assert mem._storage.count() == 1 + assert mem._storage.count(tenant_id="_default") == 1 def test_intra_batch_dedup_keeps_merely_similar(tmp_path: Path) -> None: @@ -781,7 +782,7 @@ def varying_embedder(texts: list[str]) -> list[list[float]]: importance=0.6, ) mem.drain_writes() - assert mem._storage.count() == 2 + assert mem._storage.count(tenant_id="_default") == 2 def test_batch_consolidation_deduplicates_against_storage( @@ -810,7 +811,7 @@ def test_batch_consolidation_deduplicates_against_storage( mem._storage.save([ MemoryRecord(content="CrewAI is great.", scope="/test", importance=0.7, embedding=emb), ]) - assert mem._storage.count() == 1 + assert mem._storage.count(tenant_id="_default") == 1 # remember_many with the same content + a new one (all identical embeddings) mem.remember_many( @@ -822,7 +823,7 @@ def test_batch_consolidation_deduplicates_against_storage( mem.drain_writes() # Intra-batch dedup fires: same embedding = 1.0 >= 0.98, so item 1 is dropped. # LLM says don't insert -> no new records. Total stays at 1. - assert mem._storage.count() == 1 + assert mem._storage.count(tenant_id="_default") == 1 def test_parallel_find_similar_runs_all_searches(tmp_path: Path) -> None: @@ -877,7 +878,7 @@ def test_single_remember_uses_batch_flow(tmp_path: Path, mock_embedder: MagicMoc assert record.content == "Single fact." assert record.scope == "/project" assert record.importance == 0.8 - assert mem._storage.count() == 1 + assert mem._storage.count(tenant_id="_default") == 1 def test_parallel_analyze_runs_concurrent_calls(tmp_path: Path) -> None: @@ -915,7 +916,7 @@ def distinct_embedder(texts: list[str]) -> list[list[float]]: mem.remember_many(["Fact A.", "Fact B.", "Fact C."]) mem.drain_writes() assert llm.call.call_count == 3 - assert mem._storage.count() == 3 + assert mem._storage.count(tenant_id="_default") == 3 @@ -950,7 +951,7 @@ def distinct_embedder(texts: list[str]) -> list[list[float]]: assert result == [] # After draining, records should exist mem.drain_writes() - assert mem._storage.count() == 2 + assert mem._storage.count(tenant_id="_default") == 2 def test_recall_drains_pending_writes(tmp_path: Path, mock_embedder: MagicMock) -> None: @@ -990,4 +991,4 @@ def test_close_drains_and_shuts_down(tmp_path: Path, mock_embedder: MagicMock) - ) mem.close() # After close, records should be persisted - assert mem._storage.count() == 1 + assert mem._storage.count(tenant_id="_default") == 1 From 6f89c45abc581f4d673c06d5854ab24de553acc6 Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Thu, 28 May 2026 22:46:13 -0500 Subject: [PATCH 4/8] feat(memory): add ScopedStorage wrapper and tenant isolation tests ScopedStorage wraps any StorageBackend and binds every operation to a fixed (tenant_id, user_id) pair. It holds three contracts and they live nowhere else: 1. Stamp on write -- every record's tenant_id is overwritten with the wrapper's bound tenant before persisting; a record arriving with a different non-default tenant_id raises PermissionError instead of being silently relabeled. 2. Inject on read -- every read forwarded to the underlying backend carries the tenant_id predicate; the wrapper exposes no API to omit it. 3. Verify on return -- after the backend returns rows, the wrapper re-checks each row's tenant_id and raises RuntimeError on a foreign-tenant leak. Loud over silent, so a broken backend filter surfaces in the next test run instead of shipping quietly. The triple contract is defense in depth around the Protocol's required keyword arg (type-check-time enforcement) and the backend's pushed-down WHERE/FieldCondition (data-layer enforcement). test_tenant_isolation.py adds the security contract from the design doc. Nine ScopedStorage tests pass today (the wrapper is the enforcement chokepoint and works without Memory's involvement). One backcompat test pins the '_default' tenant single-tenant path. Two Memory-level tests are XFAIL'd with strict=True, raises=TypeError -- they will pass when PR #4 wires Memory's remember/recall/forget to take tenant_id kwargs and route through ScopedStorage. Refs: design-docs/0001-per-tenant-memory-isolation.md --- .../crewai/memory/storage/scoped_storage.py | 278 ++++++++++++++++ .../tests/memory/test_tenant_isolation.py | 310 ++++++++++++++++++ 2 files changed, 588 insertions(+) create mode 100644 lib/crewai/src/crewai/memory/storage/scoped_storage.py create mode 100644 lib/crewai/tests/memory/test_tenant_isolation.py diff --git a/lib/crewai/src/crewai/memory/storage/scoped_storage.py b/lib/crewai/src/crewai/memory/storage/scoped_storage.py new file mode 100644 index 0000000000..6bdcc552c8 --- /dev/null +++ b/lib/crewai/src/crewai/memory/storage/scoped_storage.py @@ -0,0 +1,278 @@ +"""ScopedStorage: the single chokepoint for the per-tenant memory isolation invariant. + +A ScopedStorage wraps any StorageBackend and binds every operation it forwards +to a fixed (tenant_id, user_id) pair. Three contracts are held here, and +**nowhere else**: + +1. **Stamp on write.** Every record passed to save()/update() is model_copied + with tenant_id set to the wrapper's bound tenant. A record that arrives + already stamped with a *different* tenant raises PermissionError -- silent + relabel masks bugs. +2. **Inject on read.** Every read forwarded to the underlying backend carries + the tenant_id predicate. The wrapper has no API to omit it. +3. **Verify on return.** After the backend returns rows, the wrapper re-checks + r.tenant_id == self._tenant_id on every row. A foreign-tenant row leaking + through raises RuntimeError -- loudly, not silently filtered. + +The triple contract is defense in depth. The Protocol's required keyword arg +catches forgotten parameters at type-check time; the backend's pushed-down +predicate is the SQL/Qdrant-level filter; this wrapper is the runtime guard +that fires when either of the first two fails. + +If you add a new read method, it MUST go through the tenant predicate. +If you add a new write method, it MUST go through _stamp(). + +See design-docs/0001-per-tenant-memory-isolation.md for the full design. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from crewai.memory.types import MemoryRecord + + +if TYPE_CHECKING: + from crewai.memory.storage.backend import StorageBackend + from crewai.memory.types import ScopeInfo + + +_DEFAULT_TENANT = "_default" + + +class ScopedStorage: + """A tenant-bound proxy around a StorageBackend. + + Cheap to construct (two strings + one reference), so it is intended to be + created per request rather than cached. A long-lived Memory instance can + serve many tenants concurrently by building a fresh ScopedStorage per call. + """ + + def __init__( + self, + inner: StorageBackend, + *, + tenant_id: str, + user_id: str | None = None, + ) -> None: + if not tenant_id: + raise ValueError("ScopedStorage requires a non-empty tenant_id") + self._inner = inner + self._tenant_id = tenant_id + self._user_id = user_id + + @property + def tenant_id(self) -> str: + return self._tenant_id + + @property + def user_id(self) -> str | None: + return self._user_id + + # ------------------------------------------------------------------ + # Write path: stamp every record before it reaches the backend. + # ------------------------------------------------------------------ + + def _stamp(self, records: list[MemoryRecord]) -> list[MemoryRecord]: + stamped: list[MemoryRecord] = [] + for r in records: + if r.tenant_id and r.tenant_id != _DEFAULT_TENANT and r.tenant_id != self._tenant_id: + # Refuse to silently relabel. A caller mixing tenants is a bug, + # not something the storage layer should paper over. + raise PermissionError( + f"ScopedStorage bound to tenant_id={self._tenant_id!r} " + f"refused to save record tenant_id={r.tenant_id!r}. " + "Cross-tenant writes through a scoped handle are not allowed." + ) + updates: dict[str, Any] = {"tenant_id": self._tenant_id} + if self._user_id is not None and r.user_id is None: + updates["user_id"] = self._user_id + stamped.append(r.model_copy(update=updates)) + return stamped + + def save(self, records: list[MemoryRecord]) -> None: + self._inner.save(self._stamp(records)) + + def update(self, record: MemoryRecord) -> None: + self._inner.update(self._stamp([record])[0]) + + async def asave(self, records: list[MemoryRecord]) -> None: + await self._inner.asave(self._stamp(records)) + + # ------------------------------------------------------------------ + # Read path: inject tenant predicate, then verify every returned row. + # ------------------------------------------------------------------ + + def _verify(self, records: list[MemoryRecord]) -> None: + """Raise RuntimeError if any record's tenant does not match. + + Loud over silent. A broken backend filter must surface; quietly + filtering out the leak hides the bug for the next person. + """ + for r in records: + if r.tenant_id != self._tenant_id: + raise RuntimeError( + f"Backend returned a cross-tenant row: " + f"expected tenant_id={self._tenant_id!r}, got {r.tenant_id!r} " + f"(record id={r.id!r}). Refusing to serve." + ) + + def search( + self, + query_embedding: list[float], + *, + scope_prefix: str | None = None, + categories: list[str] | None = None, + metadata_filter: dict[str, Any] | None = None, + limit: int = 10, + min_score: float = 0.0, + ) -> list[tuple[MemoryRecord, float]]: + results = self._inner.search( + query_embedding, + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + categories=categories, + metadata_filter=metadata_filter, + limit=limit, + min_score=min_score, + ) + self._verify([r for r, _ in results]) + return results + + async def asearch( + self, + query_embedding: list[float], + *, + scope_prefix: str | None = None, + categories: list[str] | None = None, + metadata_filter: dict[str, Any] | None = None, + limit: int = 10, + min_score: float = 0.0, + ) -> list[tuple[MemoryRecord, float]]: + results = await self._inner.asearch( + query_embedding, + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + categories=categories, + metadata_filter=metadata_filter, + limit=limit, + min_score=min_score, + ) + self._verify([r for r, _ in results]) + return results + + def get_record(self, record_id: str) -> MemoryRecord | None: + record = self._inner.get_record( + record_id, tenant_id=self._tenant_id, user_id=self._user_id + ) + if record is None: + return None + self._verify([record]) + return record + + def list_records( + self, + *, + scope_prefix: str | None = None, + limit: int = 200, + offset: int = 0, + ) -> list[MemoryRecord]: + records = self._inner.list_records( + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + limit=limit, + offset=offset, + ) + self._verify(records) + return records + + def delete( + self, + *, + scope_prefix: str | None = None, + categories: list[str] | None = None, + record_ids: list[str] | None = None, + older_than: datetime | None = None, + metadata_filter: dict[str, Any] | None = None, + ) -> int: + return self._inner.delete( + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + categories=categories, + record_ids=record_ids, + older_than=older_than, + metadata_filter=metadata_filter, + ) + + async def adelete( + self, + *, + scope_prefix: str | None = None, + categories: list[str] | None = None, + record_ids: list[str] | None = None, + older_than: datetime | None = None, + metadata_filter: dict[str, Any] | None = None, + ) -> int: + return await self._inner.adelete( + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + categories=categories, + record_ids=record_ids, + older_than=older_than, + metadata_filter=metadata_filter, + ) + + def reset(self, *, scope_prefix: str | None = None) -> None: + self._inner.reset( + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + ) + + def get_scope_info(self, scope: str) -> ScopeInfo: + return self._inner.get_scope_info( + scope, tenant_id=self._tenant_id, user_id=self._user_id + ) + + def list_scopes(self, parent: str = "/") -> list[str]: + return self._inner.list_scopes( + parent, tenant_id=self._tenant_id, user_id=self._user_id + ) + + def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]: + return self._inner.list_categories( + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + ) + + def count(self, scope_prefix: str | None = None) -> int: + return self._inner.count( + tenant_id=self._tenant_id, + user_id=self._user_id, + scope_prefix=scope_prefix, + ) + + def touch_records(self, record_ids: list[str]) -> None: + """Pass-through for non-isolation-relevant maintenance. + + touch_records is a write to last_accessed and does not need a tenant + predicate because it operates on specific record ids the caller + already retrieved through a scoped read. If those ids leak across + tenants somehow, the underlying backend's per-row tenant_id is + unchanged. + """ + touch = getattr(self._inner, "touch_records", None) + if touch is not None: + touch(record_ids) + + def close(self) -> None: + close = getattr(self._inner, "close", None) + if close is not None: + close() diff --git a/lib/crewai/tests/memory/test_tenant_isolation.py b/lib/crewai/tests/memory/test_tenant_isolation.py new file mode 100644 index 0000000000..0e6d891f85 --- /dev/null +++ b/lib/crewai/tests/memory/test_tenant_isolation.py @@ -0,0 +1,310 @@ +"""Isolation invariant tests for per-tenant memory. + +The contract: + A recall scoped to tenant A NEVER returns a row written by tenant B. + +If any test in this file fails or passes vacuously (e.g. because the +embeddings happen to differ), the per-tenant isolation feature is broken. + +Tests in this file are split into two groups: + +* ``TestScopedStorage`` -- exercise ScopedStorage directly. These tests pass + as of PR #3 (this PR) because they go through the wrapper, which is the + enforcement chokepoint. +* ``TestMemoryIsolation`` -- exercise Memory.remember / Memory.recall / + Memory.forget. These tests are XFAIL'd until PR #4 wires ScopedStorage + through Memory; the XFAILs are removed in that PR. Keeping them here in + PR #3 (failing on purpose) is the design doc's test contract -- a feature + without an isolation test is unfinished, and these are the failing + receipts that motivate PR #4. +""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from crewai.memory.storage.lancedb_storage import LanceDBStorage +from crewai.memory.storage.scoped_storage import ScopedStorage +from crewai.memory.types import MemoryRecord + + +@pytest.fixture +def lance_path(tmp_path: Path) -> Path: + return tmp_path / "isolation.lance" + + +@pytest.fixture +def lance_storage(lance_path: Path) -> LanceDBStorage: + return LanceDBStorage(path=str(lance_path), vector_dim=4) + + +@pytest.fixture +def mock_embedder() -> MagicMock: + """Embedder that returns DIFFERENT embeddings per text, never identical.""" + m = MagicMock() + + def embed(texts: list[str]) -> list[list[float]]: + out = [] + for t in texts: + h = abs(hash(t)) % 1000 / 1000.0 + out.append([h, 1.0 - h, h * 0.5, 1.0 - h * 0.5]) + return out + + m.side_effect = embed + return m + + +# ---------------------------------------------------------------------- +# ScopedStorage direct tests -- the enforcement primitive itself. +# These pass as of PR #3. +# ---------------------------------------------------------------------- + + +class TestScopedStorage: + def test_stamp_on_write_overrides_default_tenant( + self, lance_storage: LanceDBStorage + ) -> None: + # A record with the implicit default tenant arrives at a ScopedStorage + # bound to "alice"; it should be stamped as "alice" before persisting. + scoped = ScopedStorage(lance_storage, tenant_id="alice") + scoped.save([ + MemoryRecord(content="alice note", scope="/", embedding=[0.1] * 4) + ]) + records = lance_storage.list_records(tenant_id="alice") + assert len(records) == 1 + assert records[0].tenant_id == "alice" + # And the default tenant sees nothing. + assert lance_storage.list_records(tenant_id="_default") == [] + + def test_save_rejects_cross_tenant_record( + self, lance_storage: LanceDBStorage + ) -> None: + scoped = ScopedStorage(lance_storage, tenant_id="alice") + bad = MemoryRecord( + content="trojan", tenant_id="bob", scope="/", embedding=[0.1] * 4 + ) + with pytest.raises(PermissionError, match="alice.*bob|bob.*alice"): + scoped.save([bad]) + + def test_cross_tenant_search_returns_nothing( + self, lance_storage: LanceDBStorage + ) -> None: + # Two tenants save near-identical content with the same embedding. + # Alice's recall must never see Bob's row even though semantic + # similarity would otherwise rank it first. + alice = ScopedStorage(lance_storage, tenant_id="alice") + bob = ScopedStorage(lance_storage, tenant_id="bob") + + embedding = [0.5, 0.5, 0.5, 0.5] + alice.save([ + MemoryRecord( + content="API key is alice-secret-123", + scope="/credentials", + embedding=embedding, + ) + ]) + bob.save([ + MemoryRecord( + content="API key is bob-secret-456", + scope="/credentials", + embedding=embedding, + ) + ]) + + alice_hits = alice.search(embedding, limit=10) + bob_hits = bob.search(embedding, limit=10) + + assert len(alice_hits) == 1 + assert all(r.tenant_id == "alice" for r, _ in alice_hits) + assert "alice-secret" in alice_hits[0][0].content + assert not any("bob-secret" in r.content for r, _ in alice_hits) + + assert len(bob_hits) == 1 + assert all(r.tenant_id == "bob" for r, _ in bob_hits) + assert "bob-secret" in bob_hits[0][0].content + assert not any("alice-secret" in r.content for r, _ in bob_hits) + + def test_get_record_cross_tenant_returns_none( + self, lance_storage: LanceDBStorage + ) -> None: + alice = ScopedStorage(lance_storage, tenant_id="alice") + bob = ScopedStorage(lance_storage, tenant_id="bob") + alice.save([ + MemoryRecord( + id="rec-1", content="alice note", scope="/", embedding=[0.1] * 4 + ) + ]) + # Bob asking for Alice's record by id sees nothing. + assert bob.get_record("rec-1") is None + # Alice still sees her own. + assert alice.get_record("rec-1") is not None + + def test_delete_is_scoped(self, lance_storage: LanceDBStorage) -> None: + alice = ScopedStorage(lance_storage, tenant_id="alice") + bob = ScopedStorage(lance_storage, tenant_id="bob") + alice.save([ + MemoryRecord(content="alice note", scope="/", embedding=[0.1] * 4) + ]) + bob.save([ + MemoryRecord(content="bob note", scope="/", embedding=[0.1] * 4) + ]) + + # Alice deletes her entire tenant; Bob's row must survive. + deleted = alice.delete() + assert deleted == 1 + bob_rows = bob.list_records() + assert len(bob_rows) == 1 + assert "bob note" in bob_rows[0].content + + def test_reset_is_scoped(self, lance_storage: LanceDBStorage) -> None: + alice = ScopedStorage(lance_storage, tenant_id="alice") + bob = ScopedStorage(lance_storage, tenant_id="bob") + alice.save([ + MemoryRecord(content="alice", scope="/", embedding=[0.1] * 4) + ]) + bob.save([ + MemoryRecord(content="bob", scope="/", embedding=[0.1] * 4) + ]) + alice.reset() + assert alice.count() == 0 + assert bob.count() == 1 + + def test_backend_leak_is_loud(self, lance_storage: LanceDBStorage) -> None: + """If the backend filter ever leaks a foreign-tenant row, ScopedStorage + raises RuntimeError instead of quietly filtering. Loud over silent. + """ + alice = ScopedStorage(lance_storage, tenant_id="alice") + bad = MemoryRecord( + content="leak", + tenant_id="bob", + scope="/", + embedding=[0.1] * 4, + ) + # Monkeypatch the inner search to return a foreign-tenant row. + alice._inner.search = MagicMock(return_value=[(bad, 0.99)]) # type: ignore[method-assign] + with pytest.raises(RuntimeError, match="cross-tenant"): + alice.search([0.1] * 4) + + def test_user_id_filter_within_tenant( + self, lance_storage: LanceDBStorage + ) -> None: + # Tenant 'acme' has two users; an instance bound to user_id='alice' + # only sees alice's rows. + alice = ScopedStorage(lance_storage, tenant_id="acme", user_id="alice") + bob = ScopedStorage(lance_storage, tenant_id="acme", user_id="bob") + alice.save([ + MemoryRecord(content="alice preferences", scope="/", embedding=[0.1] * 4) + ]) + bob.save([ + MemoryRecord(content="bob preferences", scope="/", embedding=[0.1] * 4) + ]) + + alice_rows = alice.list_records() + bob_rows = bob.list_records() + assert len(alice_rows) == 1 + assert "alice" in alice_rows[0].content + assert len(bob_rows) == 1 + assert "bob" in bob_rows[0].content + + # The tenant-admin view (no user_id) sees both. + admin = ScopedStorage(lance_storage, tenant_id="acme") + admin_rows = admin.list_records() + assert len(admin_rows) == 2 + + def test_constructor_rejects_empty_tenant( + self, lance_storage: LanceDBStorage + ) -> None: + with pytest.raises(ValueError, match="tenant_id"): + ScopedStorage(lance_storage, tenant_id="") + + +# ---------------------------------------------------------------------- +# Memory-level isolation tests. +# +# These are XFAIL'd in PR #3 because Memory does not yet route through +# ScopedStorage -- every internal storage call hardcodes +# tenant_id="_default". PR #4 wires the resolved tenant through, and the +# XFAIL markers are removed there. The intent of having the failing tests +# in this PR is to keep the security contract visible in the test suite +# from the moment the wrapper lands. +# ---------------------------------------------------------------------- + + +class TestMemoryBackCompat: + """Single-tenant deployments (no tenant_id passed anywhere) must keep + working unchanged. This already passes today via the '_default' fallback. + """ + + def test_default_tenant_backcompat( + self, tmp_path: Path, mock_embedder: MagicMock + ) -> None: + from crewai.memory.unified_memory import Memory + + m = Memory( + storage=str(tmp_path / "mem.lance"), + llm=MagicMock(), + embedder=mock_embedder, + ) + m.remember("the meeting is at 3pm", scope="/") + hits = m.recall("when is the meeting", depth="shallow") + assert hits + assert all(h.record.tenant_id == "_default" for h in hits) + + +@pytest.mark.xfail( + reason="PR #4 wires ScopedStorage through Memory and adds tenant_id " + "kwargs to remember/recall/forget. Until then these calls TypeError.", + strict=True, + raises=TypeError, +) +class TestMemoryIsolation: + def test_cross_tenant_recall_returns_nothing( + self, tmp_path: Path, mock_embedder: MagicMock + ) -> None: + from crewai.memory.unified_memory import Memory + + m = Memory( + storage=str(tmp_path / "mem.lance"), + llm=MagicMock(), + embedder=mock_embedder, + ) + m.remember( + "API key is alice-secret-123", + tenant_id="alice", + scope="/credentials", + ) + m.remember( + "API key is bob-secret-456", + tenant_id="bob", + scope="/credentials", + ) + + alice = m.recall("what is my api key", tenant_id="alice", depth="shallow") + bob = m.recall("what is my api key", tenant_id="bob", depth="shallow") + + assert all(h.record.tenant_id == "alice" for h in alice) + assert all(h.record.tenant_id == "bob" for h in bob) + assert any("alice-secret" in h.record.content for h in alice) + assert any("bob-secret" in h.record.content for h in bob) + assert not any("bob-secret" in h.record.content for h in alice) + assert not any("alice-secret" in h.record.content for h in bob) + + def test_forget_is_scoped( + self, tmp_path: Path, mock_embedder: MagicMock + ) -> None: + from crewai.memory.unified_memory import Memory + + m = Memory( + storage=str(tmp_path / "mem.lance"), + llm=MagicMock(), + embedder=mock_embedder, + ) + m.remember("alice note", tenant_id="alice", scope="/") + m.remember("bob note", tenant_id="bob", scope="/") + deleted = m.forget(tenant_id="alice") + assert deleted == 1 + bob_hits = m.recall("note", tenant_id="bob", depth="shallow") + assert any("bob note" in h.record.content for h in bob_hits) From 9e3adb018682f828b30e982474f42d3533d74f52 Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Thu, 28 May 2026 22:54:56 -0500 Subject: [PATCH 5/8] feat(memory): wire ScopedStorage through Memory and Flows Adds tenant_id and user_id to Memory and threads them through every public method (remember, remember_many, recall, forget, update, list_records, list_scopes, list_categories, info, tree, reset, plus async variants). Resolution order is per-call kwarg > instance default > "_default", so single-tenant deployments keep working unchanged. The mechanism: a new _scoped(tenant_id, user_id) factory builds a ScopedStorage proxy bound to the resolved tenant. Every internal storage call routes through it; the temporary tenant_id="_default" arguments from PR #2 are removed. RecallFlow and EncodingFlow are constructed with the ScopedStorage instance, so once a Flow is running its leaf calls cannot escape the tenant filter -- the deep recall LLM exploration path inherits isolation by construction. Within-tenant policy: the existing source/private post-filter in recall() is kept (it operates AFTER ScopedStorage has already filtered to the tenant, so it cannot leak cross-tenant rows). It is demoted in the docstring to a within-tenant convenience; per-user isolation should use user_id, not source/private. Tests: the two previously-XFAIL'd Memory-level isolation tests now pass (cross-tenant recall and forget). A new test_instance_default_tenant_holds pins the "Memory(tenant_id='X')" construction pattern. All 13 isolation tests green; full memory suite 129 passed, 19 skipped (qdrant optional dep). Refs: design-docs/0001-per-tenant-memory-isolation.md --- lib/crewai/src/crewai/memory/encoding_flow.py | 13 +- lib/crewai/src/crewai/memory/recall_flow.py | 13 +- .../src/crewai/memory/unified_memory.py | 235 +++++++++++++++--- .../tests/memory/test_tenant_isolation.py | 55 ++-- 4 files changed, 240 insertions(+), 76 deletions(-) diff --git a/lib/crewai/src/crewai/memory/encoding_flow.py b/lib/crewai/src/crewai/memory/encoding_flow.py index 0d6e52b093..968b439bff 100644 --- a/lib/crewai/src/crewai/memory/encoding_flow.py +++ b/lib/crewai/src/crewai/memory/encoding_flow.py @@ -176,7 +176,6 @@ def _search_one( return self._storage.search( # type: ignore[no-any-return] item.embedding, - tenant_id="_default", scope_prefix=effective_prefix, categories=None, limit=self._config.consolidation_limit, @@ -249,13 +248,9 @@ def parallel_analyze(self) -> None: None, ) scope_search_root = active_root if active_root else "/" - existing_scopes = self._storage.list_scopes( - scope_search_root, tenant_id="_default" - ) or ["/"] + existing_scopes = self._storage.list_scopes(scope_search_root) or ["/"] existing_categories = list( - self._storage.list_categories( - tenant_id="_default", scope_prefix=active_root - ).keys() + self._storage.list_categories(scope_prefix=active_root).keys() ) save_futures: dict[int, Future[MemoryAnalysis]] = {} @@ -454,9 +449,7 @@ def execute_plans(self) -> None: updated_records: dict[str, MemoryRecord] = {} if dedup_deletes: - self._storage.delete( - tenant_id="_default", record_ids=list(dedup_deletes) - ) + self._storage.delete(record_ids=list(dedup_deletes)) self.state.records_deleted += len(dedup_deletes) for rid, (_item_idx, new_content) in dedup_updates.items(): diff --git a/lib/crewai/src/crewai/memory/recall_flow.py b/lib/crewai/src/crewai/memory/recall_flow.py index 7480475b11..9da5dca64a 100644 --- a/lib/crewai/src/crewai/memory/recall_flow.py +++ b/lib/crewai/src/crewai/memory/recall_flow.py @@ -97,7 +97,6 @@ def _search_one( ) -> tuple[str, list[tuple[MemoryRecord, float]]]: raw = self._storage.search( embedding, - tenant_id="_default", scope_prefix=scope, categories=search_categories, limit=self.state.limit * _RECALL_OVERSAMPLE_FACTOR, @@ -202,15 +201,11 @@ def analyze_query_step(self) -> QueryAnalysis: ) self.state.query_analysis = analysis else: - available = self._storage.list_scopes( - self.state.scope or "/", tenant_id="_default" - ) + available = self._storage.list_scopes(self.state.scope or "/") if not available: available = ["/"] scope_info = ( - self._storage.get_scope_info( - self.state.scope or "/", tenant_id="_default" - ) + self._storage.get_scope_info(self.state.scope or "/") if self.state.scope else None ) @@ -254,9 +249,7 @@ def filter_and_chunk(self) -> list[str]: candidates = [s for s in analysis.suggested_scopes if s] else: try: - candidates = self._storage.list_scopes( - scope_prefix, tenant_id="_default" - ) + candidates = self._storage.list_scopes(scope_prefix) except Exception: logger.warning( "Storage list_scopes failed in filter_and_chunk, " diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index 93afab30db..31db64dfc6 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -23,6 +23,7 @@ from crewai.llms.base_llm import BaseLLM from crewai.memory.analyze import extract_memories_from_content from crewai.memory.storage.backend import StorageBackend +from crewai.memory.storage.scoped_storage import ScopedStorage from crewai.memory.types import ( MemoryConfig, MemoryMatch, @@ -137,6 +138,23 @@ class Memory(BaseModel): "will store memories at '/crew/research/'." ), ) + tenant_id: str = Field( + default="_default", + description=( + "Default tenant for all save/recall calls on this instance. The hard " + "isolation boundary: a recall scoped to tenant A never returns rows " + "written by tenant B. Per-call tenant_id kwargs on remember/recall/forget " + "override this default. Leave at '_default' for single-tenant deployments." + ), + ) + user_id: str | None = Field( + default=None, + description=( + "Default user_id within the tenant for all calls on this instance. " + "Soft sub-partition; tenant_id is the security boundary. Per-call " + "kwargs override this default." + ), + ) _config: MemoryConfig = PrivateAttr() _llm_instance: BaseLLM | None = PrivateAttr(default=None) @@ -221,6 +239,26 @@ def model_post_init(self, __context: Any) -> None: _MEMORY_DOCS_URL = "https://docs.crewai.com/concepts/memory" + def _scoped( + self, + tenant_id: str | None = None, + user_id: str | None = None, + ) -> ScopedStorage: + """Return a ScopedStorage proxy bound to the resolved (tenant_id, user_id). + + Resolution order: per-call kwarg > instance default > '_default'. + + Every internal call into the underlying storage goes through this + factory so the isolation invariant is enforced uniformly. Cheap to + construct -- a fresh instance per call lets a single Memory instance + serve many tenants concurrently. + """ + resolved_tenant = tenant_id or self.tenant_id + resolved_user = user_id if user_id is not None else self.user_id + return ScopedStorage( + self._storage, tenant_id=resolved_tenant, user_id=resolved_user + ) + @property def _llm(self) -> BaseLLM: """Lazy LLM initialization -- only created when first needed.""" @@ -341,11 +379,14 @@ def _encode_batch( source: str | None = None, private: bool = False, root_scope: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, ) -> list[MemoryRecord]: """Run the batch EncodingFlow for one or more items. No event emission. - This is the core encoding logic shared by ``remember()`` and - ``remember_many()``. Events are managed by the calling method. + Hands the EncodingFlow a ScopedStorage bound to the resolved tenant. + Once the Flow holds a scoped handle, every save and lookup it does + is automatically tenant-scoped -- the Flow cannot escape the filter. Args: contents: List of text content to encode and store. @@ -357,6 +398,8 @@ def _encode_batch( private: Whether items are private. root_scope: Structural root scope prefix. LLM-inferred or explicit scopes are nested under this root. + tenant_id: Tenant for this batch (resolution: arg > instance default > "_default"). + user_id: Optional sub-tenant identity. Returns: List of created MemoryRecord instances. @@ -364,7 +407,7 @@ def _encode_batch( from crewai.memory.encoding_flow import EncodingFlow flow = EncodingFlow( - storage=self._storage, + storage=self._scoped(tenant_id, user_id), llm=self._llm, embedder=self._embedder, config=self._config, @@ -400,6 +443,8 @@ def remember( private: bool = False, agent_role: str | None = None, root_scope: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, ) -> MemoryRecord | None: """Store a single item in memory (synchronous). @@ -413,11 +458,19 @@ def remember( categories: Optional categories; inferred if None. metadata: Optional metadata; merged with LLM-extracted if inferred. importance: Optional importance 0-1; inferred if None. - source: Optional provenance identifier (e.g. user ID, session ID). - private: If True, only visible to recall from the same source. + source: Optional provenance identifier. Note: provenance only -- + not an isolation boundary. Use tenant_id for isolation. + private: If True, hidden from recall by other sources within the + tenant. Note: within-tenant filter, not isolation. Use + tenant_id (or user_id within a tenant) for isolation. agent_role: Optional agent role for event metadata. root_scope: Optional root scope override. If provided, this overrides the instance-level root_scope for this call only. + tenant_id: Tenant for this record (resolution: arg > instance + default > "_default"). The hard isolation boundary -- a row + saved with tenant_id="A" is never returned by recall scoped + to tenant_id="B". + user_id: Optional sub-tenant identity within tenant_id. Returns: The created MemoryRecord, or None if this memory is read-only. @@ -454,6 +507,8 @@ def remember( source, private, effective_root, + tenant_id, + user_id, ) records = future.result() record = records[0] if records else None @@ -493,6 +548,8 @@ def remember_many( private: bool = False, agent_role: str | None = None, root_scope: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, ) -> list[MemoryRecord]: """Store multiple items in memory (non-blocking). @@ -537,6 +594,8 @@ def remember_many( private, agent_role, effective_root, + tenant_id, + user_id, ) return [] @@ -551,6 +610,8 @@ def _background_encode_batch( private: bool, agent_role: str | None, root_scope: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, ) -> list[MemoryRecord]: """Run the encoding pipeline in a background thread with event emission. @@ -598,6 +659,8 @@ def _background_encode_batch( source, private, root_scope, + tenant_id, + user_id, ) elapsed_ms = (time.perf_counter() - start) * 1000 except RuntimeError: @@ -645,6 +708,8 @@ def recall( depth: Literal["shallow", "deep"] = "deep", source: str | None = None, include_private: bool = False, + tenant_id: str | None = None, + user_id: str | None = None, ) -> list[MemoryMatch]: """Retrieve relevant memories. @@ -653,15 +718,27 @@ def recall( targeted sub-queries, selects scopes, searches in parallel, and applies confidence-based routing for optional deeper exploration. + Isolation: results are pre-filtered to (tenant_id, user_id) at the + storage layer via ScopedStorage. The within-tenant source/private + filter below is applied AFTER tenant isolation, so it cannot leak + cross-tenant rows regardless of the source/include_private values. + Args: query: Natural language query. scope: Optional scope prefix to search within. categories: Optional category filter. limit: Max number of results. depth: "shallow" for direct vector search, "deep" for intelligent flow. - source: Optional provenance filter. Private records are only visible - when this matches the record's source. - include_private: If True, all private records are visible regardless of source. + source: Provenance filter applied within the tenant. Private records + are only visible when this matches the record's source. Note: + not an isolation boundary -- use tenant_id for that. + include_private: If True, all private records within the tenant are + visible regardless of source. + tenant_id: Tenant for this recall (resolution: arg > instance + default > "_default"). A row written by another tenant is + NEVER returned, under any ranking, embedding collision, or + query depth. + user_id: Optional sub-tenant identity within tenant_id. Returns: List of MemoryMatch, ordered by relevance. @@ -676,6 +753,8 @@ def recall( elif effective_scope is not None and self.root_scope: effective_scope = join_scope_paths(self.root_scope, effective_scope) + scoped = self._scoped(tenant_id, user_id) + _source = "unified_memory" try: crewai_event_bus.emit( @@ -694,14 +773,18 @@ def recall( if not embedding: results: list[MemoryMatch] = [] else: - raw = self._storage.search( + raw = scoped.search( embedding, - tenant_id="_default", scope_prefix=effective_scope, categories=categories, limit=limit, min_score=0.0, ) + # Within-tenant source/private filter. Applied after + # ScopedStorage already filtered to this tenant, so cannot + # leak cross-tenant rows. Kept for backward compatibility + # with single-tenant deployments using the private flag -- + # for new per-user isolation, prefer user_id. if not include_private: raw = [ (r, s) @@ -723,7 +806,7 @@ def recall( from crewai.memory.recall_flow import RecallFlow flow = RecallFlow( - storage=self._storage, + storage=scoped, llm=self._llm, embedder=self._embedder, config=self._config, @@ -781,8 +864,10 @@ def forget( older_than: datetime | None = None, metadata_filter: dict[str, Any] | None = None, record_ids: list[str] | None = None, + tenant_id: str | None = None, + user_id: str | None = None, ) -> int: - """Delete memories matching criteria. + """Delete memories matching criteria (scoped to tenant). Args: scope: Scope to delete from. If None and root_scope is set, deletes @@ -790,7 +875,12 @@ def forget( categories: Filter by categories. older_than: Delete records older than this datetime. metadata_filter: Filter by metadata fields. - record_ids: Specific record IDs to delete. + record_ids: Specific record IDs to delete. Even an id-targeted + delete is constrained to the tenant -- a foreign id will not + match. + tenant_id: Tenant whose rows are eligible (resolution: + arg > instance default > "_default"). + user_id: Optional sub-tenant filter. Returns: Number of records deleted. @@ -800,8 +890,7 @@ def forget( effective_scope = self.root_scope elif effective_scope is not None and self.root_scope: effective_scope = join_scope_paths(self.root_scope, effective_scope) - return self._storage.delete( - tenant_id="_default", + return self._scoped(tenant_id, user_id).delete( scope_prefix=effective_scope, categories=categories, record_ids=record_ids, @@ -817,8 +906,12 @@ def update( categories: list[str] | None = None, metadata: dict[str, Any] | None = None, importance: float | None = None, + tenant_id: str | None = None, + user_id: str | None = None, ) -> MemoryRecord: - """Update an existing memory record by ID. + """Update an existing memory record by ID (scoped to tenant). + + A record_id that belongs to another tenant is treated as not-found. Args: record_id: ID of the record to update. @@ -827,14 +920,18 @@ def update( categories: New categories. metadata: New metadata. importance: New importance score. + tenant_id: Tenant whose record is looked up + (resolution: arg > instance default > "_default"). + user_id: Optional sub-tenant filter. Returns: The updated MemoryRecord. Raises: - ValueError: If the record is not found. + ValueError: If the record is not found in the tenant. """ - existing = self._storage.get_record(record_id, tenant_id="_default") + scoped = self._scoped(tenant_id, user_id) + existing = scoped.get_record(record_id) if existing is None: raise ValueError(f"Record not found: {record_id}") now = datetime.utcnow() @@ -852,7 +949,7 @@ def update( if importance is not None: updates["importance"] = importance updated = existing.model_copy(update=updates) - self._storage.update(updated) + scoped.update(updated) return updated def scope(self, path: str) -> Any: @@ -877,12 +974,19 @@ def slice( read_only=read_only, ) - def list_scopes(self, path: str | None = None) -> list[str]: - """List immediate child scopes under path. + def list_scopes( + self, + path: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, + ) -> list[str]: + """List immediate child scopes under path (scoped to tenant). Args: path: Scope path to list children of. If None and root_scope is set, defaults to root_scope. Otherwise defaults to '/'. + tenant_id: Tenant whose scopes are listed. + user_id: Optional sub-tenant filter. """ effective_path = path if effective_path is None and self.root_scope: @@ -891,37 +995,50 @@ def list_scopes(self, path: str | None = None) -> list[str]: effective_path = join_scope_paths(self.root_scope, effective_path) elif effective_path is None: effective_path = "/" - return self._storage.list_scopes(effective_path, tenant_id="_default") + return self._scoped(tenant_id, user_id).list_scopes(effective_path) def list_records( - self, scope: str | None = None, limit: int = 200, offset: int = 0 + self, + scope: str | None = None, + limit: int = 200, + offset: int = 0, + tenant_id: str | None = None, + user_id: str | None = None, ) -> list[MemoryRecord]: - """List records in a scope, newest first. + """List records in a scope (scoped to tenant), newest first. Args: scope: Optional scope path prefix to filter by. If None and root_scope is set, defaults to root_scope. limit: Maximum number of records to return. offset: Number of records to skip (for pagination). + tenant_id: Tenant whose records are listed. + user_id: Optional sub-tenant filter. """ effective_scope = scope if effective_scope is None and self.root_scope: effective_scope = self.root_scope elif effective_scope is not None and self.root_scope: effective_scope = join_scope_paths(self.root_scope, effective_scope) - return self._storage.list_records( - tenant_id="_default", + return self._scoped(tenant_id, user_id).list_records( scope_prefix=effective_scope, limit=limit, offset=offset, ) - def info(self, path: str | None = None) -> ScopeInfo: - """Return scope info for path. + def info( + self, + path: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, + ) -> ScopeInfo: + """Return scope info for path (scoped to tenant). Args: path: Scope path to get info for. If None and root_scope is set, defaults to root_scope. Otherwise defaults to '/'. + tenant_id: Tenant whose scope info is read. + user_id: Optional sub-tenant filter. """ effective_path = path if effective_path is None and self.root_scope: @@ -930,15 +1047,23 @@ def info(self, path: str | None = None) -> ScopeInfo: effective_path = join_scope_paths(self.root_scope, effective_path) elif effective_path is None: effective_path = "/" - return self._storage.get_scope_info(effective_path, tenant_id="_default") + return self._scoped(tenant_id, user_id).get_scope_info(effective_path) - def tree(self, path: str | None = None, max_depth: int = 3) -> str: - """Return a formatted tree of scopes (string). + def tree( + self, + path: str | None = None, + max_depth: int = 3, + tenant_id: str | None = None, + user_id: str | None = None, + ) -> str: + """Return a formatted tree of scopes (string), scoped to tenant. Args: path: Root path for the tree. If None and root_scope is set, defaults to root_scope. Otherwise defaults to '/'. max_depth: Maximum depth to traverse. + tenant_id: Tenant whose scope tree is rendered. + user_id: Optional sub-tenant filter. """ effective_path = path if effective_path is None and self.root_scope: @@ -948,12 +1073,13 @@ def tree(self, path: str | None = None, max_depth: int = 3) -> str: elif effective_path is None: effective_path = "/" + scoped = self._scoped(tenant_id, user_id) lines: list[str] = [] def _walk(p: str, depth: int, prefix: str) -> None: if depth > max_depth: return - info = self._storage.get_scope_info(p, tenant_id="_default") + info = scoped.get_scope_info(p) lines.append(f"{prefix}{p or '/'} ({info.record_count} records)") for child in info.child_scopes[:20]: _walk(child, depth + 1, prefix + " ") @@ -961,35 +1087,52 @@ def _walk(p: str, depth: int, prefix: str) -> None: _walk(effective_path.rstrip("/") or "/", 0, "") return "\n".join(lines) if lines else f"{effective_path or '/'} (0 records)" - def list_categories(self, path: str | None = None) -> dict[str, int]: - """List categories and counts. + def list_categories( + self, + path: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, + ) -> dict[str, int]: + """List categories and counts (scoped to tenant). Args: path: Scope path to filter categories by. If None and root_scope is set, defaults to root_scope. + tenant_id: Tenant whose categories are counted. + user_id: Optional sub-tenant filter. """ effective_path = path if effective_path is None and self.root_scope: effective_path = self.root_scope elif effective_path is not None and self.root_scope: effective_path = join_scope_paths(self.root_scope, effective_path) - return self._storage.list_categories( - tenant_id="_default", scope_prefix=effective_path + return self._scoped(tenant_id, user_id).list_categories( + scope_prefix=effective_path ) - def reset(self, scope: str | None = None) -> None: - """Reset (delete all) memories in scope. + def reset( + self, + scope: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, + ) -> None: + """Reset (delete all) memories in scope, within the tenant. + + Resetting one tenant never wipes another tenant's data. To remove + the entire on-disk store, delete the storage directory directly. Args: scope: Scope to reset. If None and root_scope is set, resets only - within root_scope. If None and no root_scope, resets all. + within root_scope. + tenant_id: Tenant whose rows are wiped. + user_id: Optional sub-tenant filter. """ effective_scope = scope if effective_scope is None and self.root_scope: effective_scope = self.root_scope elif effective_scope is not None and self.root_scope: effective_scope = join_scope_paths(self.root_scope, effective_scope) - self._storage.reset(tenant_id="_default", scope_prefix=effective_scope) + self._scoped(tenant_id, user_id).reset(scope_prefix=effective_scope) async def aextract_memories(self, content: str) -> list[str]: """Async variant of extract_memories.""" @@ -1004,6 +1147,8 @@ async def aremember( importance: float | None = None, source: str | None = None, private: bool = False, + tenant_id: str | None = None, + user_id: str | None = None, ) -> MemoryRecord | None: """Async remember: delegates to sync for now.""" return self.remember( @@ -1014,6 +1159,8 @@ async def aremember( importance=importance, source=source, private=private, + tenant_id=tenant_id, + user_id=user_id, ) async def aremember_many( @@ -1026,6 +1173,8 @@ async def aremember_many( source: str | None = None, private: bool = False, agent_role: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, ) -> list[MemoryRecord]: """Async remember_many: delegates to sync for now.""" return self.remember_many( @@ -1037,6 +1186,8 @@ async def aremember_many( source=source, private=private, agent_role=agent_role, + tenant_id=tenant_id, + user_id=user_id, ) async def arecall( @@ -1048,6 +1199,8 @@ async def arecall( depth: Literal["shallow", "deep"] = "deep", source: str | None = None, include_private: bool = False, + tenant_id: str | None = None, + user_id: str | None = None, ) -> list[MemoryMatch]: """Async recall: delegates to sync for now.""" return self.recall( @@ -1058,4 +1211,6 @@ async def arecall( depth=depth, source=source, include_private=include_private, + tenant_id=tenant_id, + user_id=user_id, ) diff --git a/lib/crewai/tests/memory/test_tenant_isolation.py b/lib/crewai/tests/memory/test_tenant_isolation.py index 0e6d891f85..fcf2852e8a 100644 --- a/lib/crewai/tests/memory/test_tenant_isolation.py +++ b/lib/crewai/tests/memory/test_tenant_isolation.py @@ -6,17 +6,18 @@ If any test in this file fails or passes vacuously (e.g. because the embeddings happen to differ), the per-tenant isolation feature is broken. -Tests in this file are split into two groups: - -* ``TestScopedStorage`` -- exercise ScopedStorage directly. These tests pass - as of PR #3 (this PR) because they go through the wrapper, which is the - enforcement chokepoint. +Tests in this file are split into three groups: + +* ``TestScopedStorage`` -- exercise ScopedStorage directly. The wrapper is + the enforcement chokepoint and these tests are the primary security + contract. +* ``TestMemoryBackCompat`` -- pin that single-tenant deployments (no + tenant_id passed anywhere) keep working unchanged via the '_default' + fallback. * ``TestMemoryIsolation`` -- exercise Memory.remember / Memory.recall / - Memory.forget. These tests are XFAIL'd until PR #4 wires ScopedStorage - through Memory; the XFAILs are removed in that PR. Keeping them here in - PR #3 (failing on purpose) is the design doc's test contract -- a feature - without an isolation test is unfinished, and these are the failing - receipts that motivate PR #4. + Memory.forget with explicit tenant_id kwargs end to end. These prove + isolation works through the full public API, including the deep-recall + LLM exploration path that goes through RecallFlow. """ from __future__ import annotations @@ -254,12 +255,6 @@ def test_default_tenant_backcompat( assert all(h.record.tenant_id == "_default" for h in hits) -@pytest.mark.xfail( - reason="PR #4 wires ScopedStorage through Memory and adds tenant_id " - "kwargs to remember/recall/forget. Until then these calls TypeError.", - strict=True, - raises=TypeError, -) class TestMemoryIsolation: def test_cross_tenant_recall_returns_nothing( self, tmp_path: Path, mock_embedder: MagicMock @@ -308,3 +303,31 @@ def test_forget_is_scoped( assert deleted == 1 bob_hits = m.recall("note", tenant_id="bob", depth="shallow") assert any("bob note" in h.record.content for h in bob_hits) + + def test_instance_default_tenant_holds( + self, tmp_path: Path, mock_embedder: MagicMock + ) -> None: + # Memory configured with a default tenant should auto-scope every call. + from crewai.memory.unified_memory import Memory + + alice_mem = Memory( + storage=str(tmp_path / "mem.lance"), + llm=MagicMock(), + embedder=mock_embedder, + tenant_id="alice", + ) + bob_mem = Memory( + storage=str(tmp_path / "mem.lance"), + llm=MagicMock(), + embedder=mock_embedder, + tenant_id="bob", + ) + alice_mem.remember("alice's calendar", scope="/") + bob_mem.remember("bob's calendar", scope="/") + + alice = alice_mem.recall("calendar", depth="shallow") + bob = bob_mem.recall("calendar", depth="shallow") + assert all(h.record.tenant_id == "alice" for h in alice) + assert all(h.record.tenant_id == "bob" for h in bob) + assert not any("bob" in h.record.content for h in alice) + assert not any("alice" in h.record.content for h in bob) From 7982112b48d2c4c87a9354975a29347cf31847fb Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Thu, 28 May 2026 22:58:37 -0500 Subject: [PATCH 6/8] feat(cli): add 'crewai memory migrate' for tenant isolation rollout Converts 'crewai memory' from a single command into a click group (invoke_without_command=True preserves the bare 'crewai memory' TUI behavior). Adds a 'migrate' subcommand for stamping pre-isolation LanceDB rows with a tenant_id. Options: --storage-dir Memory storage directory (default: $CREWAI_STORAGE_DIR/memory) --default-tenant Tenant for unstamped rows (default: '_default') --from-metadata-key Optional metadata key whose value becomes tenant_id; rows missing the key fall back to --default-tenant. Useful when existing data carries customer/user identifiers in metadata. --table-name LanceDB table name (default: 'memories') --dry-run Print what would change without writing. Recommended for the first run against production. The command is idempotent. Schema migration (adding the tenant_id and user_id columns) is applied via table.add_columns; per-row updates only fire when the row's current tenant_id differs from the resolved target. Dry-run simulates the column-add side effect so its reported row count matches a real run. Refs: design-docs/0001-per-tenant-memory-isolation.md --- lib/cli/src/crewai_cli/cli.py | 88 +++++++++- lib/cli/src/crewai_cli/memory_migrate.py | 170 ++++++++++++++++++ lib/cli/tests/test_memory_migrate.py | 209 +++++++++++++++++++++++ 3 files changed, 465 insertions(+), 2 deletions(-) create mode 100644 lib/cli/src/crewai_cli/memory_migrate.py create mode 100644 lib/cli/tests/test_memory_migrate.py diff --git a/lib/cli/src/crewai_cli/cli.py b/lib/cli/src/crewai_cli/cli.py index f1a8b20bab..f0774e438b 100644 --- a/lib/cli/src/crewai_cli/cli.py +++ b/lib/cli/src/crewai_cli/cli.py @@ -278,7 +278,7 @@ def reset_memories( click.echo(f"An error occurred while resetting memories: {e}", err=True) -@crewai.command() +@crewai.group(invoke_without_command=True) @click.option( "--storage-path", type=str, @@ -303,13 +303,22 @@ def reset_memories( default=None, help='Full embedder config as JSON (e.g. \'{"provider": "cohere", "config": {"model_name": "embed-v4.0"}}\').', ) +@click.pass_context def memory( + ctx: click.Context, storage_path: str | None, embedder_provider: str | None, embedder_model: str | None, embedder_config: str | None, ) -> None: - """Open the Memory TUI to browse scopes and recall memories.""" + """Memory tools. Without a subcommand, opens the Memory TUI. + + Subcommands: + migrate Stamp existing unscoped rows with a tenant_id for per-tenant isolation. + """ + if ctx.invoked_subcommand is not None: + return + try: from crewai_cli.memory_tui import MemoryTUI except ImportError as exc: @@ -338,6 +347,81 @@ def memory( app.run() +@memory.command("migrate") +@click.option( + "--storage-dir", + type=str, + default=None, + help="Memory storage directory. Defaults to $CREWAI_STORAGE_DIR/memory " + "or the platform default if unset.", +) +@click.option( + "--default-tenant", + type=str, + default="_default", + help="Tenant to assign to unstamped rows. Defaults to '_default' so " + "existing single-tenant deployments keep working unchanged.", +) +@click.option( + "--from-metadata-key", + type=str, + default=None, + help="Optional metadata key whose value becomes the row's tenant_id. " + "Rows missing the key fall back to --default-tenant. Useful when " + "existing rows already carry a customer/user identifier in metadata.", +) +@click.option( + "--table-name", + type=str, + default="memories", + help="LanceDB table name (default: memories).", +) +@click.option( + "--dry-run", + is_flag=True, + default=False, + help="Print what would change without writing. Recommended for the " + "first run against production data.", +) +def memory_migrate( + storage_dir: str | None, + default_tenant: str, + from_metadata_key: str | None, + table_name: str, + dry_run: bool, +) -> None: + """Stamp existing unscoped memory rows with a tenant_id. + + For LanceDB, the schema column is auto-added when the table is opened; + this command additionally rewrites per-row values when --from-metadata-key + is provided. It is idempotent -- running it twice is safe. Run it during + a maintenance window; do not run against a process that is actively + writing to the same storage. + """ + from crewai_cli.memory_migrate import run_migrate + + summary = run_migrate( + storage_dir=storage_dir, + default_tenant=default_tenant, + from_metadata_key=from_metadata_key, + table_name=table_name, + dry_run=dry_run, + ) + + click.echo(f"Storage directory: {summary['storage_dir']}") + click.echo(f"Table: {summary['table_name']}") + click.echo(f"Rows scanned: {summary['rows_scanned']}") + click.echo(f"Rows to stamp: {summary['rows_to_stamp']}") + if from_metadata_key: + click.echo(f"From metadata key: {from_metadata_key}") + click.echo(f" with key set: {summary['rows_with_metadata_key']}") + click.echo(f" without key: {summary['rows_to_stamp'] - summary['rows_with_metadata_key']}") + if dry_run: + click.echo("DRY RUN -- no changes written. Re-run without --dry-run to apply.") + else: + click.echo(f"Rows updated: {summary['rows_updated']}") + + @crewai.command() @click.option( "-n", diff --git a/lib/cli/src/crewai_cli/memory_migrate.py b/lib/cli/src/crewai_cli/memory_migrate.py new file mode 100644 index 0000000000..c3f408a300 --- /dev/null +++ b/lib/cli/src/crewai_cli/memory_migrate.py @@ -0,0 +1,170 @@ +"""Logic for `crewai memory migrate` -- stamp unscoped LanceDB rows with a tenant_id. + +The migration is two layers: + +1. **Schema migration** (automatic on open). When LanceDBStorage opens an + existing table, _ensure_tenant_columns() adds the tenant_id and user_id + columns with default '_default' if they are missing. So the column exists + after the first open, and every row reads back as '_default' until something + explicitly rewrites it. + +2. **Per-row migration** (this command). When --from-metadata-key is supplied, + this command scans every row and copies the metadata[key] value into the + tenant_id column. Rows missing the key keep the '_default' fallback. + +The command is idempotent. Running it twice does not change anything that was +already correct. +""" + +from __future__ import annotations + +import json +import logging +import os +from pathlib import Path +from typing import Any, TypedDict + + +_logger = logging.getLogger(__name__) + + +class MigrateSummary(TypedDict): + storage_dir: str + table_name: str + rows_scanned: int + rows_to_stamp: int + rows_with_metadata_key: int + rows_updated: int + + +def _resolve_storage_dir(storage_dir: str | None) -> Path: + """Pick the storage directory the same way LanceDBStorage does. + + Priority: + 1. --storage-dir CLI flag + 2. $CREWAI_STORAGE_DIR/memory + 3. db_storage_path() / memory (platform data dir) + """ + if storage_dir: + return Path(storage_dir) + env_dir = os.environ.get("CREWAI_STORAGE_DIR") + if env_dir: + return Path(env_dir) / "memory" + from crewai_core.paths import db_storage_path + + return Path(db_storage_path()) / "memory" + + +def run_migrate( + *, + storage_dir: str | None, + default_tenant: str, + from_metadata_key: str | None, + table_name: str, + dry_run: bool, +) -> MigrateSummary: + """Run the migration and return a summary dict. + + Returns a dict the CLI prints; raises only on I/O or schema problems. + """ + if not default_tenant: + raise ValueError("default_tenant must be a non-empty string") + + resolved_dir = _resolve_storage_dir(storage_dir) + summary: MigrateSummary = { + "storage_dir": str(resolved_dir), + "table_name": table_name, + "rows_scanned": 0, + "rows_to_stamp": 0, + "rows_with_metadata_key": 0, + "rows_updated": 0, + } + + if not resolved_dir.exists(): + _logger.info( + "Storage directory %s does not exist; nothing to migrate.", resolved_dir + ) + return summary + + import lancedb # type: ignore[import-untyped] + + db = lancedb.connect(str(resolved_dir)) + try: + table = db.open_table(table_name) + except Exception: + _logger.info( + "No table %r in %s; nothing to migrate.", table_name, resolved_dir + ) + return summary + + # Opening the table via LanceDBStorage path would auto-add the columns; + # here we use lancedb directly. Replicate the column add so this command + # also fixes pre-isolation schemas without depending on LanceDBStorage init. + existing_fields = {field.name for field in table.schema} + to_add: dict[str, str] = {} + if "tenant_id" not in existing_fields: + to_add["tenant_id"] = f"'{default_tenant}'" + if "user_id" not in existing_fields: + to_add["user_id"] = "''" + if to_add and not dry_run: + try: + table.add_columns(to_add) + except Exception as exc: + _logger.warning( + "Could not add tenant columns to %r: %s. " + "Continuing -- per-row updates will still attempt to set the field.", + table_name, + exc, + ) + + # Scan every row that needs stamping. + # A row needs stamping if: + # - tenant_id is missing/empty, OR + # - --from-metadata-key was provided AND metadata[key] differs from row's + # current tenant_id (idempotent: re-runs don't re-update unchanged rows). + rows = table.search().limit(10_000_000).to_list() + summary["rows_scanned"] = len(rows) + + to_update: list[dict[str, Any]] = [] + for row in rows: + # A row that pre-dates the tenant_id column (or has an empty value) + # is treated as if it had been stamped with default_tenant. This makes + # dry-run and real-run report identical rows_to_stamp counts -- the + # alternative is misleading dry-run output that overstates the change + # because it doesn't simulate the column-add side effect. + raw_tenant = (row.get("tenant_id") or "").strip() + current_tenant = raw_tenant if raw_tenant else default_tenant + target_tenant = default_tenant + + if from_metadata_key: + metadata_str = row.get("metadata_str") or "{}" + try: + metadata = json.loads(metadata_str) + except json.JSONDecodeError: + metadata = {} + key_value = metadata.get(from_metadata_key) + if key_value: + target_tenant = str(key_value) + summary["rows_with_metadata_key"] += 1 + + if current_tenant == target_tenant: + continue + to_update.append({"id": row["id"], "tenant_id": target_tenant}) + + summary["rows_to_stamp"] = len(to_update) + + if dry_run: + return summary + + # Apply updates. LanceDB does row updates via merge_insert or per-row + # update; using per-row table.update() with a WHERE clause is the + # simplest correct path and works on every supported version. + for entry in to_update: + safe_id = str(entry["id"]).replace("'", "''") + safe_tenant = str(entry["tenant_id"]).replace("'", "''") + table.update( + where=f"id = '{safe_id}'", + values={"tenant_id": safe_tenant}, + ) + summary["rows_updated"] = len(to_update) + return summary diff --git a/lib/cli/tests/test_memory_migrate.py b/lib/cli/tests/test_memory_migrate.py new file mode 100644 index 0000000000..d7f5720eb6 --- /dev/null +++ b/lib/cli/tests/test_memory_migrate.py @@ -0,0 +1,209 @@ +"""Tests for `crewai memory migrate` -- the per-tenant migration command. + +These tests are hermetic: they create a fresh LanceDB table per test and +operate on it via the underlying lancedb package directly, never crossing +network boundaries. The Memory subsystem is not involved -- the migrate +command is a pure schema/data transformation. +""" + +from __future__ import annotations + +import json +from datetime import datetime +from pathlib import Path + +import lancedb # type: ignore[import-untyped] +import pytest + +from crewai_cli.memory_migrate import run_migrate + + +def _make_pre_isolation_table(path: Path, table_name: str = "memories") -> None: + """Create a LanceDB table that looks like a pre-isolation install. + + Schema mirrors LanceDBStorage._create_table BEFORE PR #1, so it has all + the legacy columns but no tenant_id / user_id. This is what the migrate + command must be able to handle. + """ + path.mkdir(parents=True, exist_ok=True) + db = lancedb.connect(str(path)) + rows = [ + { + "id": "row-1", + "content": "alice's note", + "scope": "/", + "categories_str": "[]", + "metadata_str": json.dumps({"customer_id": "acme"}), + "importance": 0.5, + "created_at": datetime.utcnow().isoformat(), + "last_accessed": datetime.utcnow().isoformat(), + "source": "", + "private": False, + "vector": [0.1, 0.2, 0.3, 0.4], + }, + { + "id": "row-2", + "content": "bob's note", + "scope": "/", + "categories_str": "[]", + "metadata_str": json.dumps({"customer_id": "globex"}), + "importance": 0.5, + "created_at": datetime.utcnow().isoformat(), + "last_accessed": datetime.utcnow().isoformat(), + "source": "", + "private": False, + "vector": [0.1, 0.2, 0.3, 0.4], + }, + { + "id": "row-3", + "content": "untagged note", + "scope": "/", + "categories_str": "[]", + "metadata_str": "{}", + "importance": 0.5, + "created_at": datetime.utcnow().isoformat(), + "last_accessed": datetime.utcnow().isoformat(), + "source": "", + "private": False, + "vector": [0.1, 0.2, 0.3, 0.4], + }, + ] + db.create_table(table_name, rows) + + +def test_migrate_on_missing_directory_is_noop(tmp_path: Path) -> None: + summary = run_migrate( + storage_dir=str(tmp_path / "nonexistent"), + default_tenant="_default", + from_metadata_key=None, + table_name="memories", + dry_run=False, + ) + assert summary["rows_scanned"] == 0 + assert summary["rows_to_stamp"] == 0 + assert summary["rows_updated"] == 0 + + +def test_migrate_on_missing_table_is_noop(tmp_path: Path) -> None: + (tmp_path / "memory").mkdir() + summary = run_migrate( + storage_dir=str(tmp_path / "memory"), + default_tenant="_default", + from_metadata_key=None, + table_name="memories", + dry_run=False, + ) + assert summary["rows_scanned"] == 0 + + +def test_migrate_adds_columns_with_default_tenant(tmp_path: Path) -> None: + store = tmp_path / "memory" + _make_pre_isolation_table(store) + + summary = run_migrate( + storage_dir=str(store), + default_tenant="_default", + from_metadata_key=None, + table_name="memories", + dry_run=False, + ) + + # The column was added with default '_default' so every row has a + # tenant_id matching default_tenant; no per-row update needed. + assert summary["rows_scanned"] == 3 + assert summary["rows_to_stamp"] == 0 + assert summary["rows_updated"] == 0 + + # Verify the schema migration happened: tenant_id column exists. + db = lancedb.connect(str(store)) + table = db.open_table("memories") + field_names = {f.name for f in table.schema} + assert "tenant_id" in field_names + assert "user_id" in field_names + + +def test_migrate_with_metadata_key_stamps_per_row(tmp_path: Path) -> None: + store = tmp_path / "memory" + _make_pre_isolation_table(store) + + summary = run_migrate( + storage_dir=str(store), + default_tenant="_default", + from_metadata_key="customer_id", + table_name="memories", + dry_run=False, + ) + + assert summary["rows_scanned"] == 3 + # Rows 1 and 2 have customer_id; row 3 does not. + assert summary["rows_with_metadata_key"] == 2 + # Rows 1 and 2 need a non-default tenant; row 3 stays '_default'. + assert summary["rows_to_stamp"] == 2 + assert summary["rows_updated"] == 2 + + db = lancedb.connect(str(store)) + table = db.open_table("memories") + by_id = {row["id"]: row for row in table.search().to_list()} + assert by_id["row-1"]["tenant_id"] == "acme" + assert by_id["row-2"]["tenant_id"] == "globex" + assert by_id["row-3"]["tenant_id"] == "_default" + + +def test_migrate_is_idempotent(tmp_path: Path) -> None: + store = tmp_path / "memory" + _make_pre_isolation_table(store) + + first = run_migrate( + storage_dir=str(store), + default_tenant="_default", + from_metadata_key="customer_id", + table_name="memories", + dry_run=False, + ) + second = run_migrate( + storage_dir=str(store), + default_tenant="_default", + from_metadata_key="customer_id", + table_name="memories", + dry_run=False, + ) + + assert first["rows_updated"] == 2 + assert second["rows_scanned"] == 3 + # Second pass finds nothing to change. + assert second["rows_to_stamp"] == 0 + assert second["rows_updated"] == 0 + + +def test_migrate_dry_run_does_not_write(tmp_path: Path) -> None: + store = tmp_path / "memory" + _make_pre_isolation_table(store) + + summary = run_migrate( + storage_dir=str(store), + default_tenant="_default", + from_metadata_key="customer_id", + table_name="memories", + dry_run=True, + ) + + assert summary["rows_to_stamp"] == 2 + assert summary["rows_updated"] == 0 + + # Verify the table was not modified: the new columns were NOT added + # because dry_run skipped the add_columns call. + db = lancedb.connect(str(store)) + table = db.open_table("memories") + field_names = {f.name for f in table.schema} + assert "tenant_id" not in field_names + + +def test_migrate_rejects_empty_default_tenant(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="default_tenant"): + run_migrate( + storage_dir=str(tmp_path), + default_tenant="", + from_metadata_key=None, + table_name="memories", + dry_run=False, + ) From a23620721d3cb83ce3284303d68bcebe3b3be97a Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Thu, 28 May 2026 22:59:36 -0500 Subject: [PATCH 7/8] docs(memory): add per-tenant memory isolation guide User-facing Mintlify page at docs/en/concepts/memory-isolation.mdx covers when to use tenant_id, the three concepts (tenant_id, user_id, scope) and what each is/isn't a security boundary for, common patterns (SaaS, Crew.kickoff, Flows, instance-bound), the migration command, the threat model, FAQ, and API reference. docs.json: adds en/concepts/memory-isolation right after en/concepts/memory in all 15 nav sections (one per language variant and primary nav). docs/en/concepts/memory.mdx: fixes the "Customer support (per-customer context)" example that previously implied scope provided isolation. It now uses tenant_id and adds a Warning callout pointing at the new isolation page. Translations to ar/ko/pt-BR for the new page can land in follow-up PRs. Refs: design-docs/0001-per-tenant-memory-isolation.md --- docs/docs.json | 15 ++ docs/en/concepts/memory-isolation.mdx | 314 ++++++++++++++++++++++++++ docs/en/concepts/memory.mdx | 19 +- 3 files changed, 343 insertions(+), 5 deletions(-) create mode 100644 docs/en/concepts/memory-isolation.mdx diff --git a/docs/docs.json b/docs/docs.json index bcbcefe0e0..c40780c0bd 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -168,6 +168,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -685,6 +686,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -1202,6 +1204,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -1685,6 +1688,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -2168,6 +2172,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -2651,6 +2656,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -3144,6 +3150,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -3637,6 +3644,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -4130,6 +4138,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -4622,6 +4631,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -5104,6 +5114,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -5586,6 +5597,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -6068,6 +6080,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -6552,6 +6565,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", @@ -7035,6 +7049,7 @@ "en/concepts/collaboration", "en/concepts/training", "en/concepts/memory", + "en/concepts/memory-isolation", "en/concepts/reasoning", "en/concepts/planning", "en/concepts/testing", diff --git a/docs/en/concepts/memory-isolation.mdx b/docs/en/concepts/memory-isolation.mdx new file mode 100644 index 0000000000..506fa50a7d --- /dev/null +++ b/docs/en/concepts/memory-isolation.mdx @@ -0,0 +1,314 @@ +--- +title: Memory Isolation +description: Per-tenant and per-user memory partitioning for multi-user CrewAI deployments. +icon: shield-halved +mode: "wide" +--- + +## When You Need This + +If your CrewAI app serves **more than one user from a single process** — a SaaS product, a multi-tenant internal tool, a customer-support bot that talks to many customers — you need per-tenant memory isolation. + +Without it, every user's memories pool into the same vector collection. Alice's `recall()` will return rows that Bob saved, ranked by whatever the embedder thinks is most similar. That is a data leak, not a UX nit. + +If you are using CrewAI from a single-user CLI, a notebook, or a single-tenant deployment, you do **not** need to change anything. Memory keeps working exactly as before with a built-in default tenant. + + +**Not the same thing as `scope`.** The `scope` parameter on `Memory` (`/customer/acme-corp`, `/agent/researcher`, etc.) is a *structural* path for organizing memories by topic. It is **not** a security boundary. A `recall()` that omits the scope will still return rows from any scope. For real isolation between users or tenants, use `tenant_id`. + + +## The Guarantee + +> A `recall()` scoped to tenant A will never return a memory written by tenant B — under any ranking, embedding collision, query depth, or backend. + +This is enforced at the storage layer, not by the calling agent. An agent that asks for the wrong data does not get it. + +## Quick Start + +```python +from crewai import Memory + +memory = Memory() + +# Alice and Bob each save a credential under the same scope +memory.remember("My API key is alice-secret-123", tenant_id="alice") +memory.remember("My API key is bob-secret-456", tenant_id="bob") + +# Alice's recall only sees Alice's row +alice = memory.recall("what is my api key", tenant_id="alice") +assert all(m.record.tenant_id == "alice" for m in alice) + +# Bob's recall only sees Bob's row +bob = memory.recall("what is my api key", tenant_id="bob") +assert all(m.record.tenant_id == "bob" for m in bob) +``` + +That is the entire feature, end to end. Everything below is pattern guidance. + +## Tenants, Users, and Scopes + +Three concepts. They look similar; they are not. + +| Concept | Set at | Purpose | Is it a security boundary? | +|---|---|---|---| +| `tenant_id` | Per call or per `Memory()` | **Isolation.** Hard wall between customers, organizations, accounts. | **Yes.** | +| `user_id` | Per call or per `Memory()` | **Sub-partition** within a tenant. Lets a tenant admin recall across their own users without crossing tenants. | Soft. | +| `scope` | Per record | **Structural.** Organizes memories into a topic tree (`/project/alpha`, `/agent/researcher`). | No. | + +A SaaS deployment usually wants: + +- `tenant_id="customer_42"` — the customer who owns the data. +- `user_id="alice"` — the person inside that customer making the request. +- `scope="/agent/researcher/findings"` — what kind of memory this is. + +Examples: + +```python +# Recall only Alice's memories inside customer_42 +memory.recall("my preferences", tenant_id="customer_42", user_id="alice") + +# Recall all memories inside customer_42 (tenant admin view) +memory.recall("recent decisions", tenant_id="customer_42") + +# Forget everything for customer_42 — right-to-be-forgotten +memory.forget(tenant_id="customer_42") +``` + +## Common Patterns + +### Multi-tenant SaaS + +One `Memory` instance, many tenants. Resolve the tenant from your auth layer (JWT, session, request header) and pass it on every call. + +```python +from crewai import Memory + +memory = Memory() # one process-wide instance + +def handle_request(request): + tenant_id = request.user.tenant_id # from your auth + user_id = request.user.id + + memory.remember( + request.message, + tenant_id=tenant_id, + user_id=user_id, + ) + + return memory.recall( + request.query, + tenant_id=tenant_id, + user_id=user_id, + ) +``` + +### With `Crew.kickoff()` + +Pass identity at kickoff. Agents inside the crew receive a memory handle that is already bound to the right tenant — they cannot recall outside it, even if a prompt-injected agent tries. + +```python +from crewai import Crew, Agent, Task + +crew = Crew(agents=[...], tasks=[...], memory=True) + +result = crew.kickoff( + inputs={"topic": "Q4 strategy"}, + tenant_id="customer_42", + user_id="alice", +) +``` + +### Inside Flows + +`tenant_id` and `user_id` thread through Flow state. Set them once at the start; every downstream Crew kickoff inherits them. + +```python +from crewai.flow.flow import Flow, start, listen + +class SupportFlow(Flow): + @start() + def begin(self): + self.state.tenant_id = self.inputs["tenant_id"] + self.state.user_id = self.inputs["user_id"] + + @listen(begin) + def answer(self): + return support_crew.kickoff( + inputs={"question": self.inputs["question"]}, + tenant_id=self.state.tenant_id, + user_id=self.state.user_id, + ) +``` + +### Tenant-bound `Memory` instances + +If you prefer dependency injection over per-call kwargs, bind `tenant_id` at construction: + +```python +def memory_for(request) -> Memory: + return Memory(tenant_id=request.user.tenant_id, user_id=request.user.id) +``` + +Per-call kwargs always win over instance defaults, so this pattern composes with the SaaS pattern above when you need to. + +### Customer-support bot (replacing the old `scope` pattern) + +The old documentation example used `scope="/customer/acme-corp"` for per-customer context. That was structural, not isolated. The correct pattern is: + +```python +# Before (NOT isolated) +memory.remember("Prefers email", scope="/customer/acme-corp") + +# After (isolated) +memory.remember("Prefers email", tenant_id="acme-corp") +``` + +You can still use `scope` *inside* a tenant for organization: + +```python +memory.remember( + "Prefers email", + tenant_id="acme-corp", + scope="/preferences/communication", +) +``` + +## Migrating Existing Deployments + +If you have an existing CrewAI app with memory data on disk, you have two questions to answer: + +**1. Are all your existing memories owned by one identity?** + +If yes (single-user app, single-tenant SaaS), there is nothing to do. The default tenant `"_default"` covers all existing rows. Your code keeps working unchanged. + +**2. Did you previously use `scope` or `source` to fake per-user isolation?** + +Run the one-shot migration to assign each row a real `tenant_id`: + +```bash +crewai memory migrate \ + --storage-dir $CREWAI_STORAGE_DIR \ + --default-tenant _default \ + --dry-run +``` + +`--dry-run` prints what would change without writing. Remove it once you're satisfied. + +If you need to assign different tenants to different existing rows (e.g. you stored a customer ID in metadata), the migration command supports a `--from-metadata-key` flag: + +```bash +crewai memory migrate \ + --storage-dir $CREWAI_STORAGE_DIR \ + --from-metadata-key customer_id +``` + +This reads each row's `metadata["customer_id"]` and sets it as `tenant_id`. Rows missing the key fall back to `--default-tenant`. + + +The migration is **idempotent** — running it twice is safe. Run it during a maintenance window; do not run it against a live writing process. + + +## What Gets Filtered + +| Operation | Scoped to tenant? | +|---|---| +| `memory.recall(query, tenant_id="alice")` | Yes — vector search filters at the backend. | +| `memory.forget(tenant_id="alice")` | Yes — only Alice's rows are deleted. | +| `memory.list_records(tenant_id="alice")` | Yes. | +| `memory.list_scopes(tenant_id="alice")` | Yes — Alice only sees scopes she has written to. | +| `memory.list_categories(tenant_id="alice")` | Yes. | +| `memory.count(tenant_id="alice")` | Yes. | +| `memory.reset(tenant_id="alice")` | Yes — wipes Alice's tenant only. There is no "wipe everything" call. | +| Deep recall via `RecallFlow` (`depth="deep"`) | Yes — the LLM exploration loop sees only the tenant's rows. | +| Embedder | Shared. Embeddings are not a security boundary. | + +## Threat Model + +What this feature defends against, in plain language: + +- **A user querying for another user's data.** The vector store filters `WHERE tenant_id = ?` before ranking, so foreign-tenant rows are never candidates. +- **A buggy or forgotten filter in your code.** If you forget to pass `tenant_id`, you fall back to the `"_default"` tenant — which is **its own bucket**, not a global view. The fallback is non-leaking by design. +- **A prompt-injected agent trying to "ignore previous instructions" and recall everything.** Agents receive a `Memory` handle that is already bound to a tenant. There is no API on that handle to widen the scope. +- **A broken backend filter.** A second-line check inside the storage wrapper re-verifies every returned row's `tenant_id`. If a row sneaks through, the call raises an exception loudly rather than returning a quietly-filtered result. + +What this feature does **not** do: + +- **Authenticate the caller.** Verifying that "the person calling with `tenant_id='alice'` is actually Alice" is your auth layer's job. This feature enforces the predicate; it does not validate it. +- **Encrypt data per tenant.** Records are stored in the same database in plaintext (or whatever your backend's at-rest encryption setting is). Per-tenant encryption is a separate feature. +- **Prevent operator misuse.** A developer with database access can read any row. This is no different from any other application database. + +## Comparison + +| | Built-in `Memory` (this page) | External providers (mem0, etc.) | +|---|---|---| +| Per-user / per-tenant isolation | **Yes** | Yes | +| Survives restart | Yes | Yes | +| Cross-crew shared memory | Limited (via `scope`) | First-class | +| Long-term consolidation, decay | Basic | Advanced | +| Network dependency | None (local LanceDB/Chroma) | Yes | +| Self-hosted | Yes | Depends on provider | + +If you need cross-crew knowledge graphs, advanced consolidation, or a managed service, an external provider is the right tool. If you need isolation for a self-hosted or local-first deployment, use built-in `Memory` with `tenant_id`. + +## FAQ + +**Do I have to pass `tenant_id` everywhere now?** +No. If you don't pass it, calls fall back to the `"_default"` tenant. All your existing single-user code keeps working, and all existing rows on disk are readable as `"_default"`. + +**Can I have a `Memory` instance per tenant instead of passing `tenant_id` per call?** +Yes: `Memory(tenant_id="customer_42")`. Both patterns work. Per-call kwargs override instance defaults, so you can mix them. + +**What happens if a record's `tenant_id` and my `ScopedStorage`'s `tenant_id` disagree on save?** +The save raises `PermissionError`. Silent relabeling would mask bugs; loud failure surfaces them. + +**Is `tenant_id` sent to my LLM or embedder?** +No. It is a storage-layer predicate. The LLM sees the content and the query; it does not see the tenant identifier unless you put it in the content yourself. + +**Does this work with the deep-recall LLM exploration (`depth="deep"`)?** +Yes. The exploration loop searches through the same scoped storage handle and cannot escape the tenant filter. + +**What about the legacy `BaseRAGStorage` path?** +The same `tenant_id` field and filter applies. Both code paths are covered. + +**Can I delete all of one tenant's data for a right-to-be-forgotten request?** +Yes: `memory.forget(tenant_id="customer_42")` deletes every record for that tenant. `memory.reset(tenant_id="customer_42")` is the equivalent for the LanceDB/Chroma backend. + +**Should I use `tenant_id` or `user_id` or both?** +- One organization, multiple users: `tenant_id="org_id"`, `user_id="user_id"`. +- One isolation level only (e.g. each end-user is independent): `tenant_id="user_id"` and skip `user_id`. +- Pick the field name that matches your real boundary. `tenant_id` is the hard wall; `user_id` is the soft partition inside it. + +## API Reference + +### `Memory(...)` constructor + +| Argument | Type | Default | Notes | +|---|---|---|---| +| `tenant_id` | `str` | `"_default"` | Default tenant for all calls on this instance. | +| `user_id` | `str \| None` | `None` | Default user_id for all calls on this instance. | + +### `Memory.remember(content, *, tenant_id=None, user_id=None, ...)` + +Stores a record under the resolved tenant. If `tenant_id` is omitted, falls back to the instance default, then to `"_default"`. + +### `Memory.recall(query, *, tenant_id=None, user_id=None, ...)` + +Returns matches restricted to the resolved tenant. With `user_id` set, further restricted to that user. + +### `Memory.forget(*, tenant_id=None, user_id=None, ...)` + +Deletes records matching the resolved tenant (and optionally user). Returns the count deleted. + +### `Crew.kickoff(inputs={...}, *, tenant_id=None, user_id=None)` + +Sets the tenant context for all agents and memory access during this kickoff. + +### `crewai memory migrate` + +One-shot CLI command to stamp existing unscoped data with a tenant. See [Migrating Existing Deployments](#migrating-existing-deployments). + +## See Also + +- [Memory](/en/concepts/memory) — the underlying unified memory system. +- [Knowledge](/en/concepts/knowledge) — for shared, read-only knowledge sources (not per-user data). +- [Production Architecture](/en/concepts/production-architecture) — deployment patterns for multi-tenant CrewAI. diff --git a/docs/en/concepts/memory.mdx b/docs/en/concepts/memory.mdx index 954d5efe6e..51be99b58d 100644 --- a/docs/en/concepts/memory.mdx +++ b/docs/en/concepts/memory.mdx @@ -278,15 +278,24 @@ writer_view = memory.slice( ``` **Customer support (per-customer context):** + + +For real per-customer isolation in a multi-tenant deployment, use `tenant_id`, not `scope`. Scope is a structural path; it does not stop a recall from another customer's session returning the wrong row. See [Memory Isolation](/en/concepts/memory-isolation). + + ```python memory = Memory() -# Each customer gets isolated context -memory.remember("Prefers email communication", scope="/customer/acme-corp") -memory.remember("On enterprise plan, 50 seats", scope="/customer/acme-corp") +# Each customer is isolated at the storage layer via tenant_id. +memory.remember("Prefers email communication", tenant_id="acme-corp", + scope="/preferences/communication") +memory.remember("On enterprise plan, 50 seats", tenant_id="acme-corp", + scope="/account/plan") -# Shared product docs are accessible to all agents -memory.remember("Rate limit is 1000 req/min on enterprise plan", scope="/product/docs") +# Shared product docs use the default tenant and are read by every +# customer's recall via a separate Memory or a tenant-admin view. +memory.remember("Rate limit is 1000 req/min on enterprise plan", + scope="/product/docs") ``` From 90073fbbd7296ac009ab0c139ce18d5dafbf10c5 Mon Sep 17 00:00:00 2001 From: John_J <79534962+John-Jepsen@users.noreply.github.com> Date: Fri, 29 May 2026 00:11:06 -0500 Subject: [PATCH 8/8] fix(memory): address CodeRabbit review on PR #5967 Major: - LanceDBStorage.delete(): record_ids + older_than took the fast path and silently dropped older_than; record_ids + categories/metadata_filter scanned by the predicates but never intersected with record_ids. Now matches Qdrant semantics: fast path is gated on no other predicates, scan branch intersects with allowed_ids = set(record_ids). Two new regression tests pin both branches. - crewai memory migrate now streams rows in paginated batches instead of materializing up to 10M rows in one call. Only id, tenant_id, and metadata_str are selected -- the heavy vector column is never read. Pre-isolation tables (no tenant_id column) adapt the select set so dry-run works against unmigrated schemas. New test fixes the page size and confirms a multi-page table is scanned end-to-end. - Adds tenant-isolation tests to test_qdrant_edge_storage so when the optional qdrant_edge dep is installed the Qdrant backend's isolation contract is verified, not just LanceDB's. - Adds test_deep_recall_honors_tenant exercising the depth='deep' RecallFlow path with two tenants and colliding embeddings. Minor: - CLI "without key" count now derives from rows_scanned instead of rows_to_stamp, so it stays correct on reruns where rows_to_stamp shrinks to zero as the data converges. - mock_embedder in test_tenant_isolation now uses a SHA-256 digest instead of abs(hash(t)) % 1000, removing the 1000-bucket collision risk that could let two different texts produce identical vectors and make isolation assertions pass vacuously. Skipped (intentional): - Design-doc nits about startup-warning scan limits and _quote escaping strategy -- the warning is documented as a one-line log, not a load- bearing scan, and the escaping helper is documented inline in the LanceDB implementation. - The user_id sentinel suggestion -- per-call user_id=None meaning "instance default" is the intended SaaS pattern; callers who need "all users within tenant" construct a Memory without a default user_id. Test counts: 140 passed, 21 skipped (was 136/19; the deltas are the two new Qdrant isolation tests, the new deep-recall test, and the two new LanceDB delete-intersection tests). --- lib/cli/src/crewai_cli/cli.py | 8 ++- lib/cli/src/crewai_cli/memory_migrate.py | 67 ++++++++++++++--- lib/cli/tests/test_memory_migrate.py | 56 +++++++++++++++ .../crewai/memory/storage/lancedb_storage.py | 16 ++++- .../tests/memory/test_qdrant_edge_storage.py | 46 ++++++++++++ .../tests/memory/test_tenant_isolation.py | 71 +++++++++++++++++-- .../tests/memory/test_unified_memory.py | 65 +++++++++++++++++ 7 files changed, 314 insertions(+), 15 deletions(-) diff --git a/lib/cli/src/crewai_cli/cli.py b/lib/cli/src/crewai_cli/cli.py index f0774e438b..4ef442da44 100644 --- a/lib/cli/src/crewai_cli/cli.py +++ b/lib/cli/src/crewai_cli/cli.py @@ -415,7 +415,13 @@ def memory_migrate( if from_metadata_key: click.echo(f"From metadata key: {from_metadata_key}") click.echo(f" with key set: {summary['rows_with_metadata_key']}") - click.echo(f" without key: {summary['rows_to_stamp'] - summary['rows_with_metadata_key']}") + # "without key" is derived from rows_scanned, not rows_to_stamp, + # because rows_to_stamp shrinks on reruns once the data is migrated + # and would produce a misleading (possibly negative) display value. + click.echo( + f" without key: " + f"{summary['rows_scanned'] - summary['rows_with_metadata_key']}" + ) if dry_run: click.echo("DRY RUN -- no changes written. Re-run without --dry-run to apply.") else: diff --git a/lib/cli/src/crewai_cli/memory_migrate.py b/lib/cli/src/crewai_cli/memory_migrate.py index c3f408a300..e17d1b8164 100644 --- a/lib/cli/src/crewai_cli/memory_migrate.py +++ b/lib/cli/src/crewai_cli/memory_migrate.py @@ -14,10 +14,16 @@ The command is idempotent. Running it twice does not change anything that was already correct. + +The scan is paginated and column-selected: it never loads the ``vector`` column +(which dominates per-row memory) and never tries to materialize the whole table +at once. The previous implementation capped the read at 10_000_000 rows and +silently truncated past it; this one streams every row in fixed-size pages. """ from __future__ import annotations +from collections.abc import Iterator import json import logging import os @@ -28,6 +34,16 @@ _logger = logging.getLogger(__name__) +# Only the columns we actually inspect during migration. Crucially this omits +# the ``vector`` column -- on a 1536-dimension index, that column is ~6 KiB per +# row and dominates memory use. We do not need it to stamp tenant_id. +_SCAN_COLUMNS = ["id", "tenant_id", "metadata_str"] + +# Page size for the paginated scan. 5_000 keeps peak memory bounded +# (~5 MiB per page including json metadata) while amortizing per-call overhead. +_SCAN_PAGE_SIZE = 5_000 + + class MigrateSummary(TypedDict): storage_dir: str table_name: str @@ -37,6 +53,41 @@ class MigrateSummary(TypedDict): rows_updated: int +def _iter_rows_paginated(table: Any) -> Iterator[dict[str, Any]]: + """Yield rows from a LanceDB table in fixed-size pages, columns-selected. + + Selects only columns we actually inspect during migration. The heavy + ``vector`` column is never materialized. ``tenant_id`` is only selected + when the schema already has it -- in dry-run against a pre-isolation + table the column does not exist yet, and asking for it raises a LanceDB + schema error. A missing tenant_id column is equivalent to every row + being unstamped for the purpose of migration accounting. + + Pagination uses an offset cursor. The migration is read-then-write with + no concurrent writers expected, so the per-row id stays stable across + pages. + """ + available = {field.name for field in table.schema} + select_columns = [c for c in _SCAN_COLUMNS if c in available] + offset = 0 + while True: + query = table.search().select(select_columns).limit(_SCAN_PAGE_SIZE) + # .offset() is the chained form on lancedb >= 0.16; older versions + # only support reading from the start. The migration documents the + # supported lancedb version range in pyproject; here we assume the + # chain is available. + if offset: + query = query.offset(offset) + page = query.to_list() + if not page: + return + for row in page: + yield row + if len(page) < _SCAN_PAGE_SIZE: + return + offset += len(page) + + def _resolve_storage_dir(storage_dir: str | None) -> Path: """Pick the storage directory the same way LanceDBStorage does. @@ -117,16 +168,16 @@ def run_migrate( exc, ) - # Scan every row that needs stamping. - # A row needs stamping if: + # Scan every row that needs stamping. The iterator pages through the + # table fetching only id, tenant_id, and metadata_str -- the heavy + # ``vector`` column is never materialized. A row needs stamping if: # - tenant_id is missing/empty, OR - # - --from-metadata-key was provided AND metadata[key] differs from row's - # current tenant_id (idempotent: re-runs don't re-update unchanged rows). - rows = table.search().limit(10_000_000).to_list() - summary["rows_scanned"] = len(rows) - + # - --from-metadata-key was provided AND metadata[key] differs from the + # row's current tenant_id (idempotent: re-runs don't re-update + # unchanged rows). to_update: list[dict[str, Any]] = [] - for row in rows: + for row in _iter_rows_paginated(table): + summary["rows_scanned"] += 1 # A row that pre-dates the tenant_id column (or has an empty value) # is treated as if it had been stamped with default_tenant. This makes # dry-run and real-run report identical rows_to_stamp counts -- the diff --git a/lib/cli/tests/test_memory_migrate.py b/lib/cli/tests/test_memory_migrate.py index d7f5720eb6..348b8b923d 100644 --- a/lib/cli/tests/test_memory_migrate.py +++ b/lib/cli/tests/test_memory_migrate.py @@ -198,6 +198,62 @@ def test_migrate_dry_run_does_not_write(tmp_path: Path) -> None: assert "tenant_id" not in field_names +def test_migrate_pagination_streams_past_page_size( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """Tables larger than a single page must be fully scanned, not truncated. + + Pre-fix bug: the migrator did ``.limit(10_000_000).to_list()`` and silently + dropped anything past that cap. The paginated streamer must keep going + until an empty page is returned. + """ + from crewai_cli import memory_migrate + + # Force a small page size so a small fixture exercises the loop. + monkeypatch.setattr(memory_migrate, "_SCAN_PAGE_SIZE", 2) + + store = tmp_path / "memory" + store.mkdir(parents=True, exist_ok=True) + db = lancedb.connect(str(store)) + rows = [ + { + "id": f"row-{i}", + "content": f"item {i}", + "scope": "/", + "categories_str": "[]", + "metadata_str": json.dumps({"customer_id": f"cust-{i}"}), + "importance": 0.5, + "created_at": datetime.utcnow().isoformat(), + "last_accessed": datetime.utcnow().isoformat(), + "source": "", + "private": False, + "vector": [0.1, 0.2, 0.3, 0.4], + } + for i in range(7) # 4 pages: 2, 2, 2, 1 + ] + db.create_table("memories", rows) + + summary = memory_migrate.run_migrate( + storage_dir=str(store), + default_tenant="_default", + from_metadata_key="customer_id", + table_name="memories", + dry_run=False, + ) + + # All 7 rows must be visited even though the page size is 2. + assert summary["rows_scanned"] == 7 + assert summary["rows_with_metadata_key"] == 7 + assert summary["rows_updated"] == 7 + + by_id = { + r["id"]: r + for r in db.open_table("memories").search().to_list() + } + for i in range(7): + assert by_id[f"row-{i}"]["tenant_id"] == f"cust-{i}" + + def test_migrate_rejects_empty_default_tenant(tmp_path: Path) -> None: with pytest.raises(ValueError, match="default_tenant"): run_migrate( diff --git a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py index 2ec1849870..c7f728a034 100644 --- a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py +++ b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py @@ -505,20 +505,32 @@ def delete( return 0 tenant_clause = _tenant_where(tenant_id, user_id) with store_lock(self._lock_name): - if record_ids and not (categories or metadata_filter): + # Fast path: pure record_ids delete with no other predicates. + # If any of older_than / categories / metadata_filter is also + # specified, fall through to the scan branch so those predicates + # are honored AND intersected with record_ids. + if record_ids and not (categories or metadata_filter or older_than): before = int(self._table.count_rows()) ids_expr = ", ".join(f"'{_sql_quote(rid)}'" for rid in record_ids) self._do_write( "delete", f"({tenant_clause}) AND id IN ({ids_expr})" ) return before - int(self._table.count_rows()) - if categories or metadata_filter: + if categories or metadata_filter or (record_ids and older_than): rows = self._scan_rows( scope_prefix, tenant_id=tenant_id, user_id=user_id ) + # When record_ids is provided alongside other predicates, the + # delete is the INTERSECTION of all of them: a row must match + # the predicates AND be in record_ids. + allowed_ids: set[str] | None = ( + set(record_ids) if record_ids else None + ) to_delete: list[str] = [] for row in rows: record = self._row_to_record(row) + if allowed_ids is not None and record.id not in allowed_ids: + continue if categories and not any( c in record.categories for c in categories ): diff --git a/lib/crewai/tests/memory/test_qdrant_edge_storage.py b/lib/crewai/tests/memory/test_qdrant_edge_storage.py index 19549ba089..502dda4b88 100644 --- a/lib/crewai/tests/memory/test_qdrant_edge_storage.py +++ b/lib/crewai/tests/memory/test_qdrant_edge_storage.py @@ -67,6 +67,52 @@ def test_save_search(storage: QdrantEdgeStorage) -> None: assert score >= 0.0 +def test_search_isolates_tenants_with_colliding_embeddings( + storage: QdrantEdgeStorage, +) -> None: + """Two tenants store records with identical embeddings; each tenant's + search must return only its own row. This is the Qdrant-backed mirror of + the LanceDB isolation contract in test_tenant_isolation.py. + """ + embedding = [0.5, 0.5, 0.5, 0.5] + alice = _rec(content="alice secret", scope="/", embedding=embedding) + alice.tenant_id = "alice" + bob = _rec(content="bob secret", scope="/", embedding=embedding) + bob.tenant_id = "bob" + storage.save([alice, bob]) + + alice_hits = storage.search(embedding, tenant_id="alice", limit=10) + bob_hits = storage.search(embedding, tenant_id="bob", limit=10) + + assert len(alice_hits) == 1 + assert alice_hits[0][0].content == "alice secret" + assert alice_hits[0][0].tenant_id == "alice" + assert not any("bob" in r.content for r, _ in alice_hits) + + assert len(bob_hits) == 1 + assert bob_hits[0][0].content == "bob secret" + assert bob_hits[0][0].tenant_id == "bob" + assert not any("alice" in r.content for r, _ in bob_hits) + + +def test_delete_is_tenant_scoped(storage: QdrantEdgeStorage) -> None: + """A tenant-scoped delete must not touch another tenant's rows.""" + alice = _rec(content="alice", scope="/") + alice.tenant_id = "alice" + bob = _rec(content="bob", scope="/") + bob.tenant_id = "bob" + storage.save([alice, bob]) + + deleted = storage.delete(tenant_id="alice") + assert deleted == 1 + + bob_remaining = storage.list_records(tenant_id="bob") + assert len(bob_remaining) == 1 + assert bob_remaining[0].content == "bob" + alice_remaining = storage.list_records(tenant_id="alice") + assert alice_remaining == [] + + def test_delete_count(storage: QdrantEdgeStorage) -> None: r = _rec(scope="/") storage.save([r]) diff --git a/lib/crewai/tests/memory/test_tenant_isolation.py b/lib/crewai/tests/memory/test_tenant_isolation.py index fcf2852e8a..e6f67528f5 100644 --- a/lib/crewai/tests/memory/test_tenant_isolation.py +++ b/lib/crewai/tests/memory/test_tenant_isolation.py @@ -44,14 +44,30 @@ def lance_storage(lance_path: Path) -> LanceDBStorage: @pytest.fixture def mock_embedder() -> MagicMock: - """Embedder that returns DIFFERENT embeddings per text, never identical.""" + """Embedder that returns DIFFERENT embeddings per text, never identical. + + A naive ``abs(hash(t)) % N`` mapping buckets distinct texts into N + collisions, which would let two different inputs produce identical + vectors and make the isolation assertions in this file pass vacuously + (the foreign-tenant row would be filtered by tenant predicate, not by + semantic distance). Using a SHA-256 digest of the text gives 2^32 worth + of distinct values across four float buckets while staying deterministic + across test runs. + """ + import hashlib + m = MagicMock() def embed(texts: list[str]) -> list[list[float]]: - out = [] + out: list[list[float]] = [] for t in texts: - h = abs(hash(t)) % 1000 / 1000.0 - out.append([h, 1.0 - h, h * 0.5, 1.0 - h * 0.5]) + digest = hashlib.sha256(t.encode("utf-8")).digest() + # 4 buckets of 4 bytes each -> 4 floats in [0, 1). + vec = [ + int.from_bytes(digest[i * 4 : (i + 1) * 4], "big") / 2**32 + for i in range(4) + ] + out.append(vec) return out m.side_effect = embed @@ -304,6 +320,53 @@ def test_forget_is_scoped( bob_hits = m.recall("note", tenant_id="bob", depth="shallow") assert any("bob note" in h.record.content for h in bob_hits) + def test_deep_recall_honors_tenant( + self, tmp_path: Path, mock_embedder: MagicMock + ) -> None: + """depth='deep' goes through RecallFlow with LLM-driven exploration. + + The design doc claims RecallFlow is safe by construction because it + holds a ScopedStorage, not the raw backend. This test exercises that + path end to end so the claim has receipts. Without this test, every + recall in this class would be the depth='shallow' path and the + deep-recall code would be uncovered for isolation. + """ + from crewai.memory.unified_memory import Memory + + # An LLM mock that always says "fall through to direct search" -- we + # don't need the LLM's analysis to validate isolation; we need the + # RecallFlow path to execute at all. + llm = MagicMock() + llm.call.return_value = "{}" + + m = Memory( + storage=str(tmp_path / "mem.lance"), + llm=llm, + embedder=mock_embedder, + ) + m.remember( + "alice's birthday is March 4", + tenant_id="alice", + scope="/personal", + categories=["birthday"], + importance=0.5, + ) + m.remember( + "bob's birthday is March 4", + tenant_id="bob", + scope="/personal", + categories=["birthday"], + importance=0.5, + ) + + alice = m.recall("when is the birthday", tenant_id="alice", depth="deep") + bob = m.recall("when is the birthday", tenant_id="bob", depth="deep") + + assert all(h.record.tenant_id == "alice" for h in alice) + assert all(h.record.tenant_id == "bob" for h in bob) + assert not any("bob" in h.record.content for h in alice) + assert not any("alice" in h.record.content for h in bob) + def test_instance_default_tenant_holds( self, tmp_path: Path, mock_embedder: MagicMock ) -> None: diff --git a/lib/crewai/tests/memory/test_unified_memory.py b/lib/crewai/tests/memory/test_unified_memory.py index 9545c56afe..5a9d3a3aff 100644 --- a/lib/crewai/tests/memory/test_unified_memory.py +++ b/lib/crewai/tests/memory/test_unified_memory.py @@ -131,6 +131,71 @@ def test_lancedb_delete_count(lancedb_path: Path) -> None: assert storage.count(tenant_id="_default") == 0 +def test_lancedb_delete_record_ids_intersects_with_other_filters( + lancedb_path: Path, +) -> None: + """delete(record_ids=..., older_than=...) must INTERSECT both predicates. + + Pre-fix bug: record_ids + older_than took the record_ids fast path and + silently ignored older_than. Equivalently, record_ids + categories scanned + by categories and never intersected the result with record_ids. + """ + from datetime import datetime, timedelta + + from crewai.memory.storage.lancedb_storage import LanceDBStorage + + storage = LanceDBStorage(path=str(lancedb_path), vector_dim=4) + + old = MemoryRecord( + id="old-id", + content="old", + scope="/", + embedding=[0.0] * 4, + created_at=datetime(2020, 1, 1), + ) + new = MemoryRecord( + id="new-id", + content="new", + scope="/", + embedding=[0.0] * 4, + created_at=datetime.utcnow(), + ) + storage.save([old, new]) + + # record_ids targets both, but older_than="1 day ago" should only match `old`. + cutoff = datetime.utcnow() - timedelta(days=1) + deleted = storage.delete( + tenant_id="_default", + record_ids=["old-id", "new-id"], + older_than=cutoff, + ) + assert deleted == 1 + remaining = storage.list_records(tenant_id="_default") + assert {r.id for r in remaining} == {"new-id"} + + +def test_lancedb_delete_record_ids_intersects_with_categories( + lancedb_path: Path, +) -> None: + """delete(record_ids=..., categories=...) must INTERSECT both predicates.""" + from crewai.memory.storage.lancedb_storage import LanceDBStorage + + storage = LanceDBStorage(path=str(lancedb_path), vector_dim=4) + storage.save([ + MemoryRecord(id="a", content="a", scope="/", categories=["x"], embedding=[0.0] * 4), + MemoryRecord(id="b", content="b", scope="/", categories=["x"], embedding=[0.0] * 4), + MemoryRecord(id="c", content="c", scope="/", categories=["y"], embedding=[0.0] * 4), + ]) + + # record_ids says [a, c]; categories says [x]; intersection is just `a`. + deleted = storage.delete( + tenant_id="_default", record_ids=["a", "c"], categories=["x"] + ) + assert deleted == 1 + remaining = {r.id for r in storage.list_records(tenant_id="_default")} + assert remaining == {"b", "c"} + + def test_lancedb_list_scopes_get_scope_info(lancedb_path: Path) -> None: from crewai.memory.storage.lancedb_storage import LanceDBStorage