diff --git a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py index 7269e4f17ba3..8a23f3a195b1 100644 --- a/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +++ b/scripts/speech_recognition/convert_to_tarred_audio_dataset.py @@ -78,14 +78,17 @@ import argparse import copy import json +import math import os import random import tarfile +import tempfile from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime from io import BytesIO from typing import Any, List, Optional, Union +from urllib.parse import urlparse import soundfile from joblib import Parallel, delayed @@ -100,6 +103,161 @@ DALI_INDEX_SCRIPT_AVAILABLE = False +def is_s3_path(path: Optional[str]) -> bool: + return path is not None and str(path).startswith("s3://") + + +def _is_not_found_error(exc: Exception) -> bool: + status_code = getattr(exc, "status_code", None) + if status_code == 404: + return True + response = getattr(exc, "response", None) + if response is not None and getattr(response, "status_code", None) == 404: + return True + return False + + +class AISS3HTTPClient: + def __init__(self, endpoint: str, token: str): + import requests + + self._base = endpoint.rstrip("/") + self._session = requests.Session() + self._session.headers["Authorization"] = f"Bearer {token}" + + def bucket(self, bucket_name: str, provider: str = "s3"): + if provider != "s3": + raise ValueError(f"AISS3HTTPClient only supports provider='s3', got {provider!r}") + return AISS3HTTPBucket(self._base, self._session, bucket_name) + + +class AISS3HTTPBucket: + def __init__(self, base_url: str, session, bucket_name: str): + self._base = base_url + self._session = session + self._bucket_name = bucket_name + + def object(self, key: str): + return AISS3HTTPObject(self._base, self._session, self._bucket_name, key) + + +class AISS3HTTPObject: + def __init__(self, base_url: str, session, bucket_name: str, key: str): + from urllib.parse import quote + + self._session = session + self._url = f"{base_url}/s3/{bucket_name}/{quote(key)}" + + def head(self): + response = self._session.head(self._url) + if response.status_code >= 400: + err = RuntimeError(f"STATUS:{response.status_code}, MESSAGE:{response.reason}, REQ_URL:{self._url}") + err.status_code = response.status_code + raise err + return response.headers + + def get_writer(self): + return AISS3HTTPObjectWriter(self._session, self._url) + + +class AISS3HTTPObjectWriter: + def __init__(self, session, url: str): + self._session = session + self._url = url + + def put_file(self, path: str): + with open(path, "rb") as f: + response = self._session.put(self._url, data=f) + if response.status_code >= 400: + err = RuntimeError(f"STATUS:{response.status_code}, MESSAGE:{response.reason}, REQ_URL:{self._url}") + err.status_code = response.status_code + raise err + return response + + +class OutputTarget: + def __init__(self, target_dir: str): + self.target_dir = target_dir + self.is_s3 = is_s3_path(target_dir) + self.s3_client = None + self.bucket = None + self.bucket_name = None + self.key_prefix = "" + self._tempdir = None + + if self.is_s3: + parsed = urlparse(target_dir) + self.bucket_name = parsed.netloc + self.key_prefix = parsed.path.lstrip("/").rstrip("/") + self._tempdir = tempfile.TemporaryDirectory(prefix="nemo_tarred_audio_") + self.local_dir = self._tempdir.name + endpoint = os.environ.get("AIS_ENDPOINT") + token = os.environ.get("AIS_AUTHN_TOKEN") + missing = [name for name, value in (("AIS_ENDPOINT", endpoint), ("AIS_AUTHN_TOKEN", token)) if not value] + if missing: + raise ValueError(f"S3 target_dir requires environment variables: {', '.join(missing)}") + try: + from aistore.sdk.client import Client + except ModuleNotFoundError as exc: + print( + f"AIStore SDK import failed because {exc.name!r} is missing; " + "falling back to AIS S3-compatible HTTP upload." + ) + self.s3_client = AISS3HTTPClient(endpoint, token) + else: + self.s3_client = Client(endpoint, token=token) + self.bucket = self.s3_client.bucket(self.bucket_name, provider="s3") + print(f"Uploading tarred dataset output to {self.target_dir}") + else: + self.local_dir = target_dir + + def cleanup(self): + if self._tempdir is not None: + self._tempdir.cleanup() + + def relative_path(self, local_path: str) -> str: + return os.path.relpath(local_path, self.local_dir).replace(os.sep, "/") + + def object_key(self, relative_path: str) -> str: + relative_path = relative_path.replace(os.sep, "/").lstrip("/") + if self.key_prefix: + return f"{self.key_prefix}/{relative_path}" + return relative_path + + def object_uri(self, relative_path: str) -> str: + return f"s3://{self.bucket_name}/{self.object_key(relative_path)}" + + def display_path(self, local_path: str) -> str: + if not self.is_s3: + return local_path + return self.object_uri(self.relative_path(local_path)) + + def exists(self, local_path: str) -> bool: + if not self.is_s3: + return os.path.exists(local_path) + + relative_path = self.relative_path(local_path) + key = self.object_key(relative_path) + try: + self.bucket.object(key).head() + return True + except Exception as exc: + if _is_not_found_error(exc): + return False + raise + + def upload_file(self, local_path: str, remove_after: bool = False) -> None: + if not self.is_s3: + return + + relative_path = self.relative_path(local_path) + key = self.object_key(relative_path) + print(f"Uploading {local_path} -> {self.object_uri(relative_path)}") + self.bucket.object(key).get_writer().put_file(local_path) + if remove_after: + os.remove(local_path) + + @dataclass class ASRTarredDatasetConfig: num_shards: int = -1 @@ -112,6 +270,7 @@ class ASRTarredDatasetConfig: shard_manifests: bool = True keep_files_together: bool = False force_codec: Optional[str] = None + force_sampling_rate: Optional[int] = None use_lhotse: bool = False use_bucketing: bool = False num_buckets: Optional[int] = None @@ -154,6 +313,7 @@ class ASRTarredDatasetBuilder: def __init__(self): self.config = None + self.output_target = None def configure(self, config: ASRTarredDatasetConfig): """ @@ -167,6 +327,20 @@ def configure(self, config: ASRTarredDatasetConfig): if self.config.num_shards < 0: raise ValueError("`num_shards` must be > 0. Please fill in the metadata information correctly.") + def _output_exists(self, local_path: str) -> bool: + if self.output_target is not None: + return self.output_target.exists(local_path) + return os.path.exists(local_path) + + def _upload_output_file(self, local_path: str, remove_after: bool = False) -> None: + if self.output_target is not None: + self.output_target.upload_file(local_path, remove_after=remove_after) + + def _display_output_path(self, local_path: str) -> str: + if self.output_target is not None: + return self.output_target.display_path(local_path) + return local_path + def create_new_dataset( self, manifest_path: str, @@ -213,6 +387,17 @@ def create_new_dataset( if not os.path.exists(target_dir): os.makedirs(target_dir) + if not config.shuffle: + self._create_new_dataset_streaming( + manifest_path=manifest_path, + target_dir=target_dir, + buckets_num=buckets_num, + dynamic_buckets_num=dynamic_buckets_num, + only_manifests=only_manifests, + dry_run=dry_run, + ) + return + # Read the existing manifest entries, total_duration, filtered_entries, filtered_duration = self._read_manifest(manifest_path, config) @@ -286,6 +471,7 @@ def create_new_dataset( for entry in manifest: json.dump(entry, m2, ensure_ascii=False) m2.write('\n') + self._upload_output_file(new_manifest_shard_path, remove_after=True) # Flatten the list of list of entries to a list of entries new_entries = [sample for manifest in new_entries_list for sample in manifest] @@ -316,9 +502,12 @@ def create_new_dataset( for k, v in bucketing_kwargs.items(): setattr(metadata.dataset_config, k, v) + self._upload_output_file(new_manifest_path, remove_after=True) + # Write metadata metadata_yaml = OmegaConf.structured(metadata) OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) + self._upload_output_file(new_metadata_path, remove_after=True) def estimate_dynamic_bucketing_duration_bins(self, manifest_path: str, num_buckets: int = 30) -> dict: from lhotse import CutSet @@ -492,6 +681,7 @@ def create_concatenated_dataset( for entry in manifest: json.dump(entry, m2, ensure_ascii=False) m2.write('\n') + self._upload_output_file(new_manifest_shard_path, remove_after=True) # Flatten the list of list of entries to a list of entries new_entries = [sample for manifest in new_entries_list for sample in manifest] @@ -517,6 +707,8 @@ def create_concatenated_dataset( json.dump(entry, m2, ensure_ascii=False) m2.write('\n') + self._upload_output_file(new_manifest_path, remove_after=True) + # Preserve historical metadata base_metadata = metadata @@ -541,6 +733,7 @@ def create_concatenated_dataset( # Write metadata metadata_yaml = OmegaConf.structured(metadata) OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) + self._upload_output_file(new_metadata_path, remove_after=True) def _read_manifest(self, manifest_path: Union[str, List[str]], config: ASRTarredDatasetConfig): """Read and filters data from the manifest""" @@ -549,10 +742,7 @@ def _read_manifest(self, manifest_path: Union[str, List[str]], config: ASRTarred filtered_entries = [] filtered_duration = 0.0 - if isinstance(manifest_path, str): - manifest_paths = manifest_path.split(",") - else: - manifest_paths = manifest_path + manifest_paths = self._get_manifest_paths(manifest_path) print(f"Found {len(manifest_paths)} manifest files to be processed") for manifest_file in manifest_paths: @@ -566,63 +756,227 @@ def _read_manifest(self, manifest_path: Union[str, List[str]], config: ASRTarred return entries, total_duration, filtered_entries, filtered_duration - def _read_single_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): - # Read the existing manifest - entries = [] - total_duration = 0.0 - filtered_entries = [] - filtered_duration = 0.0 - print(f"Reading manifest: {manifest_path}") + def _get_manifest_paths(self, manifest_path: Union[str, List[str]]): + if isinstance(manifest_path, str): + return manifest_path.split(",") + return manifest_path + + def _prepare_manifest_entry(self, entry: dict, manifest_path: str, config: ASRTarredDatasetConfig): + audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file" + if config.slice_with_offset and "offset" not in entry: + raise KeyError( + f"Manifest entry does not contain 'offset' field, but '--slice_with_offset' is enabled: {entry}" + ) + if audio_key not in entry: + raise KeyError(f"Manifest entry does not contain 'audio_filepath' or 'audio_file' key: {entry}") + audio_filepath = entry[audio_key] + if not os.path.isfile(audio_filepath) and not os.path.isabs(audio_filepath): + audio_filepath_abs = os.path.join(os.path.dirname(manifest_path), audio_filepath) + if not os.path.isfile(audio_filepath_abs): + raise FileNotFoundError(f"Could not find {audio_filepath} or {audio_filepath_abs}!") + entry[audio_key] = audio_filepath_abs + if audio_key != "audio_filepath": + entry["audio_filepath"] = entry[audio_key] + + is_valid = (config.max_duration is None or entry['duration'] < config.max_duration) and ( + config.min_duration is None or entry['duration'] >= config.min_duration + ) + return entry, is_valid + + def _iter_single_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig, action: str = "Reading"): + print(f"{action} manifest: {manifest_path}") with open(manifest_path, 'r', encoding='utf-8') as m: for line in m: line = line.strip() if not line: continue entry = json.loads(line) - audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file" - if config.slice_with_offset and "offset" not in entry: - raise KeyError( - f"Manifest entry does not contain 'offset' field, but '--slice_with_offset' is enabled: {entry}" - ) - if audio_key not in entry: - raise KeyError(f"Manifest entry does not contain 'audio_filepath' or 'audio_file' key: {entry}") - audio_filepath = entry[audio_key] - if not os.path.isfile(audio_filepath) and not os.path.isabs(audio_filepath): - audio_filepath_abs = os.path.join(os.path.dirname(manifest_path), audio_filepath) - if not os.path.isfile(audio_filepath_abs): - raise FileNotFoundError(f"Could not find {audio_filepath} or {audio_filepath_abs}!") - entry[audio_key] = audio_filepath_abs - if (config.max_duration is None or entry['duration'] < config.max_duration) and ( - config.min_duration is None or entry['duration'] >= config.min_duration - ): - entries.append(entry) + yield self._prepare_manifest_entry(entry, manifest_path, config) + + def _read_single_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): + # Read the existing manifest + entries = [] + total_duration = 0.0 + filtered_entries = [] + filtered_duration = 0.0 + + for entry, is_valid in self._iter_single_manifest(manifest_path, config): + if is_valid: + entries.append(entry) + total_duration += entry["duration"] + else: + filtered_entries.append(entry) + filtered_duration += entry['duration'] + + return entries, total_duration, filtered_entries, filtered_duration + + def _count_manifest(self, manifest_path: Union[str, List[str]], config: ASRTarredDatasetConfig): + entries_count = 0 + total_duration = 0.0 + filtered_entries_count = 0 + filtered_duration = 0.0 + manifest_paths = self._get_manifest_paths(manifest_path) + + print(f"Found {len(manifest_paths)} manifest files to be processed") + for manifest_file in manifest_paths: + for entry, is_valid in self._iter_single_manifest(str(manifest_file), config, action="Counting"): + if is_valid: + entries_count += 1 total_duration += entry["duration"] else: - filtered_entries.append(entry) - filtered_duration += entry['duration'] + filtered_entries_count += 1 + filtered_duration += entry["duration"] - return entries, total_duration, filtered_entries, filtered_duration + return entries_count, total_duration, filtered_entries_count, filtered_duration + + def _iter_manifest_entries(self, manifest_path: Union[str, List[str]], config: ASRTarredDatasetConfig): + for manifest_file in self._get_manifest_paths(manifest_path): + for entry, is_valid in self._iter_single_manifest(str(manifest_file), config): + if is_valid: + yield entry + + def _write_manifest_entries(self, manifest_path: str, entries) -> None: + with open(manifest_path, 'w', encoding='utf-8') as m2: + for entry in entries: + json.dump(entry, m2, ensure_ascii=False) + m2.write('\n') + + def _write_shard_manifest(self, target_dir: str, entries) -> None: + if not self.config.shard_manifests: + return + if not entries: + return + + sharded_manifests_dir = target_dir + '/sharded_manifests' + if not os.path.exists(sharded_manifests_dir): + os.makedirs(sharded_manifests_dir) + + shard_id = entries[0]['shard_id'] + new_manifest_shard_path = os.path.join(sharded_manifests_dir, f'manifest_{shard_id}.json') + self._write_manifest_entries(new_manifest_shard_path, entries) + self._upload_output_file(new_manifest_shard_path, remove_after=True) + + def _create_new_dataset_streaming( + self, + manifest_path: str, + target_dir: str, + buckets_num: int = 1, + dynamic_buckets_num: int = 30, + only_manifests: bool = False, + dry_run: bool = False, + ): + config = self.config + entries_count, total_duration, filtered_entries_count, filtered_duration = self._count_manifest( + manifest_path, config + ) + + entries_per_shard = entries_count // config.num_shards + remainder_entries = entries_count % config.num_shards + print( + f"\n Min duration: {config.min_duration} s" + f"\n Max duration: {config.max_duration} s" + f"\n Entries after filtration: {entries_count} / {entries_count + filtered_entries_count}" + f"\n Duration after filtration: {total_duration:.2f} / {total_duration + filtered_duration:.2f} s" + f"\n Shards: {config.num_shards}" + f"\n Entries per shard: {entries_per_shard}" + f"\n Remainder entries: {remainder_entries}" + ) + if dry_run: + return + + if entries_count == 0: + print("No tarred dataset was created as there were 0 valid samples after filtering!") + return + + if entries_per_shard == 0: + print( + "No tarred dataset was created because the number of valid samples is smaller than " + "the requested number of shards." + ) + return + + manifest_folder, _ = os.path.split(self._get_manifest_paths(manifest_path)[0]) + new_manifest_path = os.path.join(target_dir, 'tarred_audio_manifest.json') + total_new_entries = 0 + + current_entries = [] + current_shard_id = 0 + start_idx = 0 + valid_entries = self._iter_manifest_entries(manifest_path, config) + + with open(new_manifest_path, 'w', encoding='utf-8') as m2: + for entry in valid_entries: + if current_shard_id >= config.num_shards: + break + + current_entries.append(entry) + if len(current_entries) < entries_per_shard: + continue + + end_idx = start_idx + entries_per_shard + print(f"Shard {current_shard_id} has entries {start_idx} ~ {end_idx}") + files = {ent["audio_filepath"] for ent in current_entries} + print(f"Shard {current_shard_id} contains {len(files)} files") + if current_shard_id == config.num_shards - 1: + print(f"Have {remainder_entries} entries left over that will be discarded.") + + new_entries = self._create_shard( + current_entries, target_dir, current_shard_id, manifest_folder, only_manifests + ) + self._write_shard_manifest(target_dir, new_entries) + for new_entry in new_entries: + json.dump(new_entry, m2, ensure_ascii=False) + m2.write('\n') + total_new_entries += len(new_entries) + + current_entries = [] + current_shard_id += 1 + start_idx = end_idx + + print("Total number of entries in manifest :", total_new_entries) + + # Write metadata (default metadata for new datasets) + new_metadata_path = os.path.join(target_dir, 'metadata.yaml') + metadata = ASRTarredDatasetMetadata() + + metadata.dataset_config = config + metadata.num_samples_per_shard = entries_per_shard + + if buckets_num <= 1: + bucketing_kwargs = self.estimate_dynamic_bucketing_duration_bins( + new_manifest_path, num_buckets=dynamic_buckets_num + ) + for k, v in bucketing_kwargs.items(): + setattr(metadata.dataset_config, k, v) + + self._upload_output_file(new_manifest_path, remove_after=True) + + metadata_yaml = OmegaConf.structured(metadata) + OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) + self._upload_output_file(new_metadata_path, remove_after=True) def _write_to_tar( self, tar, audio_filepath: str, squashed_filename: str, duration: float = None, offset: float = 0 ) -> None: codec = self.config.force_codec + force_sampling_rate = self.config.force_sampling_rate + source_sampling_rate = soundfile.info(audio_filepath).samplerate to_transcode = not (codec is None or audio_filepath.endswith(f".{codec}")) to_crop = not (duration is None and offset == 0) + to_resample = force_sampling_rate is not None and force_sampling_rate != source_sampling_rate - if not to_crop and not to_transcode: + if not to_crop and not to_transcode and not to_resample: # Add existing file without transcoding, trimming, or re-encoding. tar.add(audio_filepath, arcname=squashed_filename) return - # Standard processing: read, trim, and transcode the audio file - with soundfile.SoundFile(audio_filepath) as f: - sampling_rate = f.samplerate - # Trim audio based on offset and duration. - start_sample = int(offset * sampling_rate) - num_frames = int(duration * sampling_rate) if duration else -1 - audio, sampling_rate = soundfile.read(file_path, start=start_sample, frames=num_frames) + start_sample = int(offset * source_sampling_rate) + num_frames = int(duration * source_sampling_rate) if duration else -1 + audio, sampling_rate = soundfile.read(audio_filepath, start=start_sample, frames=num_frames) + if to_resample: + audio = self._resample_audio(audio, sampling_rate, force_sampling_rate) + sampling_rate = force_sampling_rate # Determine codec parameters. if codec is not None: @@ -644,9 +998,26 @@ def _write_to_tar( # Add the in-memory audio file to the tar archive. ti = tarfile.TarInfo(encoded_squashed_filename) encoded_audio.seek(0) - ti.size = len(encoded_audio.getvalue()) + ti.size = encoded_audio.getbuffer().nbytes tar.addfile(ti, encoded_audio) + def _tar_audio_filename(self, squashed_filename: str) -> str: + if self.config.force_codec is None: + return squashed_filename + base, _ = os.path.splitext(squashed_filename) + return f"{base}.{self.config.force_codec}" + + def _resample_audio(self, audio, source_sampling_rate: int, target_sampling_rate: int): + if source_sampling_rate == target_sampling_rate: + return audio + + from scipy.signal import resample_poly + + common = math.gcd(source_sampling_rate, target_sampling_rate) + up = target_sampling_rate // common + down = source_sampling_rate // common + return resample_poly(audio, up, down, axis=0) + def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = None, only_manifests: bool = False): """Creates a tarball containing the audio files from `entries`.""" if self.config.sort_in_shards: @@ -655,8 +1026,14 @@ def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = No new_entries = [] tar_filepath = os.path.join(target_dir, f'audio_{shard_id}.tar') - if not only_manifests: + tar_exists = self._output_exists(tar_filepath) + write_tar = not only_manifests and not tar_exists + if tar_exists and not only_manifests: + print(f"Skipping existing tar shard: {self._display_output_path(tar_filepath)}") + if write_tar: tar = tarfile.open(tar_filepath, mode='w', dereference=True) + else: + tar = None count = dict() for entry in tqdm(entries, desc="Creating shard.."): @@ -704,8 +1081,8 @@ def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = No ) entry_duration = "_".join(entry_duration) - to_write = base + "_" + entry_offset + "_" + entry_duration + ext - if not only_manifests: + to_write = self._tar_audio_filename(base + "_" + entry_offset + "_" + entry_duration + ext) + if write_tar: self._write_to_tar( tar, audio_filepath, to_write, duration=entry['duration'], offset=entry['offset'] ) @@ -715,12 +1092,12 @@ def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = No del entry['offset'] else: if squashed_filename not in count: - if not only_manifests: - self._write_to_tar(tar, audio_filepath, squashed_filename) - to_write = squashed_filename + to_write = self._tar_audio_filename(squashed_filename) + if write_tar: + self._write_to_tar(tar, audio_filepath, to_write) count[squashed_filename] = 1 else: - to_write = base + "-sub" + str(count[squashed_filename]) + ext + to_write = self._tar_audio_filename(base + "-sub" + str(count[squashed_filename]) + ext) count[squashed_filename] += 1 if only_manifests: @@ -734,8 +1111,9 @@ def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = No } new_entries.append(new_entry) - if not only_manifests: + if write_tar: tar.close() + self._upload_output_file(tar_filepath, remove_after=True) return new_entries @classmethod @@ -790,12 +1168,16 @@ def create_tar_datasets( write_metadata: bool = False, no_shard_manifests: bool = False, force_codec: str = None, + force_sampling_rate: int = None, workers: int = 1, slice_with_offset: bool = False, only_manifests: bool = False, dry_run: bool = False, ): builder = ASRTarredDatasetBuilder() + output_target = OutputTarget(target_dir) + target_dir = output_target.local_dir + builder.output_target = output_target shard_manifests = False if no_shard_manifests else True @@ -811,14 +1193,17 @@ def create_tar_datasets( shard_manifests=shard_manifests, keep_files_together=keep_files_together, force_codec=force_codec, + force_sampling_rate=force_sampling_rate, slice_with_offset=slice_with_offset, ) metadata.dataset_config = dataset_cfg output_path = os.path.join(target_dir, 'default_metadata.yaml') OmegaConf.save(metadata, output_path, resolve=True) + output_target.upload_file(output_path, remove_after=True) print(f"Default metadata written to {output_path}") - exit(0) + output_target.cleanup() + return if concat_manifest_paths is None or len(concat_manifest_paths) == 0: # Create a tarred dataset from scratch @@ -832,6 +1217,7 @@ def create_tar_datasets( shard_manifests=shard_manifests, keep_files_together=keep_files_together, force_codec=force_codec, + force_sampling_rate=force_sampling_rate, slice_with_offset=slice_with_offset, ) builder.configure(config) @@ -868,6 +1254,7 @@ def create_tar_datasets( metadata.dataset_config.shuffle_seed = shuffle_seed metadata.dataset_config.sort_in_shards = sort_in_shards metadata.dataset_config.shard_manifests = shard_manifests + metadata.dataset_config.force_sampling_rate = force_sampling_rate builder.configure(metadata.dataset_config) @@ -883,11 +1270,22 @@ def create_tar_datasets( dry_run=dry_run, ) - if not dry_run and (DALI_INDEX_SCRIPT_AVAILABLE and dali_index.INDEX_CREATOR_AVAILABLE): + if not dry_run and output_target.is_s3: + print("Skipping DALI Tarfile Index construction for S3 target_dir.") + elif not dry_run and (DALI_INDEX_SCRIPT_AVAILABLE and dali_index.INDEX_CREATOR_AVAILABLE): print("Constructing DALI Tarfile Index - ", target_dir) index_config = dali_index.DALITarredIndexConfig(tar_dir=target_dir, workers=workers) dali_index.main(index_config) + output_target.cleanup() + + +def positive_int(value: str) -> int: + value_int = int(value) + if value_int <= 0: + raise argparse.ArgumentTypeError("value must be a positive integer") + return value_int + if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -1006,6 +1404,15 @@ def create_tar_datasets( "Supports libnsndfile formats (example values: 'opus', 'flac')." ), ) + parser.add_argument( + "--force_sampling_rate", + type=positive_int, + default=None, + help=( + "If specified, resample audio to this sampling rate before writing it into the tar file. " + "Example: --force_sampling_rate=16000." + ), + ) parser.add_argument( "--only_manifests", action='store_true',