
Running into OOM errors on a workflow that used to work in prior versions. #1011

Description

@ayushdg

Running a script that reads, hashes, shuffles, drops non-duplicate elements, and writes out to Parquet, similar to the logic here: https://github.com/NVIDIA-NeMo/Curator/blob/main/nemo_curator/stages/deduplication/fuzzy/lsh/lsh.py
The script runs fine on 25.10 but consistently runs into OOM errors on 26.04.
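
For context, here is the core per-band logic from the patch below distilled into a minimal standalone cudf snippet on toy data (the id/minhash column names and toy values are illustrative, not Curator's defaults):

    # Minimal sketch of the band-hash -> melt -> drop-singletons logic.
    import cudf

    minhashes_per_band = 2
    df = cudf.DataFrame(
        {
            "id": [0, 1, 2],
            "minhash": [[11, 12, 13, 14], [55, 66, 99, 98], [11, 12, 13, 14]],
        }
    )

    id_df = df[["id"]]
    for k in range(2):  # bands [0, 2)
        idx = list(range(k * minhashes_per_band, (k + 1) * minhashes_per_band))
        rep = cudf.Series([idx]).repeat(len(df))
        # Hash the k-th slice of each minhash signature into a bucket id.
        id_df[f"_bucket_{k}"] = f"b{k}_" + df["minhash"].list.take(rep).hash_values(method="md5")

    # Melt to (id, bucket) pairs; ids sharing a bucket are candidate duplicates.
    pairs = id_df.melt(id_vars=["id"], value_name="_bucket_id")[["id", "_bucket_id"]]
    pairs = pairs[pairs["_bucket_id"].duplicated(keep=False)]
    # Both bands bucket docs 0 and 2 together; doc 1's singleton buckets are dropped.
    print(pairs.groupby("_bucket_id")["id"].agg(list).reset_index())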

A non-Curator reproducer is a minor patch applied to the existing bulk_ray_shuffle script in rapidsmpf:

--- a/bulk_ray_shuffle.py
+++ b/bulk_ray_shuffle.py
@@ -254,6 +254,62 @@
         for partition_id, partition in self.extract():
             self.write_table(partition, self.output_path, partition_id, column_names)
 
+    def lsh_read_and_insert(
+        self,
+        partitions: list[list[str]],
+        band_range: tuple[int, int],
+        minhashes_per_band: int,
+        id_field: str,
+        minhash_field: str,
+    ) -> list[str]:
+        """Read minhash parquets per partition, hash bands, melt, and insert.
+
+        `partitions` is a list of file groups (typically ~2 GiB each, produced by
+        Curator's FilePartitioningStage). Each partition is read as one cudf.read_parquet
+        call so reads stay aligned with the user-chosen blocksize.
+
+        band_range is half-open [start, end). All bands in the range are processed in a
+        single shuffle pass -- there is no iteration over bands.
+        """
+        import cudf
+        from rapidsmpf.utils.cudf import cudf_to_pylibcudf_table
+
+        bucket_field = "_bucket_id"
+        column_names = [id_field, bucket_field]
+        start, end = band_range
+        if start < 0 or start >= end:
+            raise ValueError(f"Invalid band range: {band_range}")
+        for partition_files in partitions:
+            df = cudf.read_parquet(partition_files, columns=[id_field, minhash_field])
+            if len(df) == 0:
+                continue
+            id_df = df[[id_field]]
+            for k in range(start, end):
+                idx = list(range(k * minhashes_per_band, (k + 1) * minhashes_per_band))
+                rep = cudf.Series([idx]).repeat(len(id_df))
+                id_df[f"_bucket_{k}"] = f"b{k}_" + df[minhash_field].list.take(rep).hash_values(method="md5")
+            value_vars = [f"_bucket_{k}" for k in range(start, end)]
+            band_df = id_df.melt(id_vars=[id_field], value_name=bucket_field, value_vars=value_vars)[column_names]
+            self.insert_chunk(cudf_to_pylibcudf_table(band_df), column_names)
+            del df, id_df, band_df
+        self.insert_finished()
+        return column_names
+
+    def lsh_extract_and_write(self, id_field: str) -> None:
+        """Extract shuffled bucket rows, drop singletons, group ids per bucket, write parquet."""
+        bucket_field = "_bucket_id"
+        column_names = [id_field, bucket_field]
+        for partition_id, partition in self.extract():
+            df = pylibcudf_to_cudf_dataframe(partition, column_names=column_names)
+            if len(df) == 0:
+                continue
+            df = df[df[bucket_field].duplicated(keep=False)]
+            if len(df) == 0:
+                continue
+            grouped = df.groupby(bucket_field)[id_field].agg(list).list.sort_values().reset_index()
+            grouped.to_parquet(f"{self.output_path}/part.{partition_id}.parquet", index=False)
+            del df, grouped
+
 
 def bulk_ray_shuffle(
     paths: list[str],
@@ -328,6 +384,67 @@
     ray.get([actor.cleanup.remote() for actor in actors])
 
 
+def lsh_bulk_ray_shuffle(
+    partitions: list[list[str]],
+    output_path: str,
+    num_bands: int,
+    minhashes_per_band: int,
+    id_field: str,
+    minhash_field: str,
+    num_workers: int = 8,
+    num_output_files: int | None = None,
+    rmm_pool_size: int = 1024 * 1024 * 1024,
+    spill_device: int | None = None,
+    *,
+    enable_statistics: bool = False,
+) -> None:
+    """LSH-mode driver: read minhashes per partition, hash bands, shuffle on bucket id, group, write.
+
+    `partitions` is a list of file groups produced by Curator's FilePartitioningStage
+    (each group is ~2 GiB by default). Partitions are distributed round-robin across
+    `num_workers` actors. All `num_bands` bands are shuffled in a single pass.
+    """
+    num_partitions = len(partitions)
+    total_num_partitions = num_output_files or num_partitions
+
+    actors = setup_ray_ucxx_cluster(
+        BulkRayShufflerActor,
+        num_workers=num_workers,
+        total_nparts=total_num_partitions,
+        shuffle_on=["_bucket_id"],
+        batchsize=1,
+        output_path=output_path,
+        enable_statistics=enable_statistics,
+        rmm_pool_size=rmm_pool_size,
+        spill_device=spill_device,
+    )
+    # Round-robin partition assignment so size variance is spread across actors.
+    actor_partitions: list[list[list[str]]] = [[] for _ in range(num_workers)]
+    for i, part in enumerate(partitions):
+        actor_partitions[i % num_workers].append(part)
+    print(
+        f"Distributing {num_partitions} partitions across {num_workers} actors "
+        f"(min/max per actor: {min(len(p) for p in actor_partitions)} / {max(len(p) for p in actor_partitions)})"
+    )
+
+    start_time = time.time()
+    insert_tasks = [
+        actor.lsh_read_and_insert.remote(
+            actor_partitions[i],
+            band_range=(0, num_bands),
+            minhashes_per_band=minhashes_per_band,
+            id_field=id_field,
+            minhash_field=minhash_field,
+        )
+        for i, actor in enumerate(actors)
+    ]
+    ray.get(insert_tasks)
+    ray.get([actor.lsh_extract_and_write.remote(id_field) for actor in actors])
+    end_time = time.time()
+    print(f"LSH shuffle time: {end_time - start_time} seconds")
+    ray.get([actor.cleanup.remote() for actor in actors])
+
+
 def dir_path(path: str) -> Path:
     """
     Validate that the given path is a directory and return a Path object.
@@ -367,12 +484,17 @@
     else:
         ray.init(num_gpus=args.num_workers, dashboard_host="0.0.0.0")
 
-    bulk_ray_shuffle(
-        paths=sorted(map(str, args.input.glob("**/*"))),
-        shuffle_on=args.on.split(","),
+    import json
+    with open(args.partitions_json) as fh:
+        partitions = json.load(fh)
+    lsh_bulk_ray_shuffle(
+        partitions=partitions,
         output_path=args.output,
+        num_bands=args.num_bands,
+        minhashes_per_band=args.minhashes_per_band,
+        id_field=args.id_field,
+        minhash_field=args.minhash_field,
         num_workers=args.num_workers,
-        batchsize=args.batchsize,
         num_output_files=args.n_output_files,
         enable_statistics=args.statistics,
         rmm_pool_size=args.rmm_pool_size,
@@ -392,10 +514,10 @@
         help="Number of workers to use.",
     )
     parser.add_argument(
-        "input",
-        type=dir_path,
-        metavar="INPUT_DIR_PATH",
-        help="Input directory path.",
+        "partitions_json",
+        type=str,
+        metavar="PARTITIONS_JSON",
+        help="Path to JSON file produced by make_partitions.py (list[list[str]]).",
     )
     parser.add_argument(
         "output",
@@ -404,10 +526,28 @@
         help="Output directory path.",
     )
     parser.add_argument(
-        "on",
-        metavar="COLUMN_LIST",
+        "--num-bands",
+        type=int,
+        default=20,
+        help="Number of LSH bands to process in a single shuffle pass (bands [0, num_bands)).",
+    )
+    parser.add_argument(
+        "--minhashes-per-band",
+        type=int,
+        default=13,
+        help="Number of minhashes per LSH band.",
+    )
+    parser.add_argument(
+        "--id-field",
+        type=str,
+        default="_curator_dedup_id",
+        help="Document id column.",
+    )
+    parser.add_argument(
+        "--minhash-field",
         type=str,
-        help="Comma-separated list of column names to shuffle on.",
+        default="_minhash_signature",
+        help="Minhash list column.",
     )
     parser.add_argument(
         "--n-output-files",
@@ -416,12 +556,6 @@
         help="Number of output files. Default preserves input file count.",
     )
     parser.add_argument(
-        "--batchsize",
-        type=int,
-        default=1,
-        help="Number of files to read on each MPI rank at once.",
-    )
-    parser.add_argument(
         "--rmm-pool-size",
         type=parse_bytes,
         default=format_bytes(int(rmm.mr.available_device_memory()[1] * 0.8)),
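
The partitions JSON is just a list[list[str]] of file groups. make_partitions.py is referenced by the CLI help but not included here, so this greedy ~2 GiB grouping is only a rough, hypothetical sketch of what it does (the 2 GiB target matches FilePartitioningStage's default mentioned above):

    # Hypothetical stand-in for make_partitions.py: greedily group parquet
    # files into ~2 GiB partitions and dump them as list[list[str]] JSON.
    import json
    import sys
    from pathlib import Path

    TARGET_BYTES = 2 * 1024**3  # ~2 GiB per partition (assumed default)

    input_dir, out_json = Path(sys.argv[1]), sys.argv[2]
    partitions: list[list[str]] = []
    current: list[str] = []
    current_bytes = 0
    for f in sorted(input_dir.glob("**/*.parquet")):
        size = f.stat().st_size
        if current and current_bytes + size > TARGET_BYTES:
            partitions.append(current)
            current, current_bytes = [], 0
        current.append(str(f))
        current_bytes += size
    if current:
        partitions.append(current)

    with open(out_json, "w") as fh:
        json.dump(partitions, fh)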

And to reproduce:

  python bulk_ray_shuffle_lsh.py \
      /path/to/partitions_2GiB.json \
      /path/to/output_dir \
      --num-workers 8 \
      --num-bands 20 --minhashes-per-band 13 \
      --rmm-pool-size 72GiB --spill-device 60GiB \
      --statistics

Requires 8x 80 GB GPUs. The run succeeds on rapidsmpf 25.10 and fails on 26.04.
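
For anyone trying to narrow down where memory blows up, a generic NVML polling loop to watch per-GPU usage while the repro runs (diagnostic aid only, not part of the reproducer):

    import time
    import pynvml

    pynvml.nvmlInit()
    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(pynvml.nvmlDeviceGetCount())]
    try:
        while True:
            used = [pynvml.nvmlDeviceGetMemoryInfo(h).used / 2**30 for h in handles]
            print(" | ".join(f"GPU{i}: {u:5.1f} GiB" for i, u in enumerate(used)))
            time.sleep(5)
    except KeyboardInterrupt:
        pynvml.nvmlShutdown()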
