True-bfloat16 inference for cache aware pipelines #15763

Open
naymaraq wants to merge 2 commits into main from dkaramyan/cache-pipelines-bf16

@naymaraq naymaraq commented Jun 7, 2026

Collaborator

Important

The Update branch button must only be pressed on very rare occasions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do ?

Fix true bfloat16 (use_amp=false) inference for cache-aware streaming ASR. In the benchmarks below, true bf16 raises RTFx by roughly 23% to 28% over bf16 with AMP and halves cache memory relative to fp32.

Collection: [ASR]

Changelog

Problem: With use_amp: false, the cache-aware wrappers cast the model weights to bfloat16 and disabled autocast, while the input mel features and the encoder caches stayed float32. The encoder then received float32 inputs and caches against bfloat16 weights with no autocast to reconcile them, which raised a dtype-mismatch error.

  • Cast input features: In both CacheAwareRNNTInferenceWrapper.stream_step and CacheAwareCTCInferenceWrapper.stream_step, cast processed_signal to self.cast_dtype after the device move.

  • Align cache dtype: CacheAwareASRInferenceWrapper.get_initial_cache_state now passes dtype=self.cast_dtype, so the context manager's persistent cache storage matches the encoder's output caches.

  • Config defaults: Set use_amp: false in cache_aware_rnnt.yaml and cache_aware_ctc.yaml so the example configs run true bf16 by default.
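
The mismatch and the two casts described above can be reproduced in miniature. The sketch below is illustrative only: a single `torch.nn.Linear` is a hypothetical stand-in for the encoder, and `cast_dtype`, the tensor shapes, and the cache layout are assumptions chosen for the example, not the wrappers' actual code.

```python
import torch

# Hypothetical stand-in for the encoder: one linear layer with bfloat16 weights.
encoder = torch.nn.Linear(80, 16).to(torch.bfloat16)
cast_dtype = torch.bfloat16  # plays the role of self.cast_dtype (assumed name)

# Mel features arrive as float32, as in the failing case.
processed_signal = torch.randn(4, 80, dtype=torch.float32)

with torch.no_grad():
    # Without autocast, float32 inputs against bfloat16 weights are rejected.
    try:
        encoder(processed_signal)
        mismatch_raised = False
    except RuntimeError:
        mismatch_raised = True

    # Fix 1: cast the input features to the model dtype before the forward pass.
    output = encoder(processed_signal.to(cast_dtype))

# Fix 2: allocate the persistent cache in the same dtype as the encoder output.
# The shape here is arbitrary; only the dtype matters for the comparison.
cache_fp32 = torch.zeros(2, 4, 70, 16, dtype=torch.float32)
cache_bf16 = torch.zeros(2, 4, 70, 16, dtype=cast_dtype)
bytes_fp32 = cache_fp32.numel() * cache_fp32.element_size()
bytes_bf16 = cache_bf16.numel() * cache_bf16.element_size()
```

Because bfloat16 uses 2 bytes per element against 4 for float32, the bf16 cache in this sketch takes half the storage of the fp32 one, which is consistent with the 2x cache reduction reported in the results.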

Results

Setup:

  • num_slots=256 and batch_size=64
  • Experiments done on NVIDIA RTX 5000 Ada GPU

Key Findings:

  • True-bf16 inference halved cache memory versus fp32 (936 MB vs. 1872 MB) with essentially no WER degradation across all attention context sizes.
  • True-bf16 inference also improved RTFx over bf16 with AMP at every attention context size, by roughly 23% to 28% on LS-OTHER (for example, 499 vs. 407 at [70, 13]).
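
| Method | AMP | Att Context Size | Comp Ratio | LS CLEAN | LS OTHER | TED | VOX | AVG. | RTFX (LS-OTHER) |
|--------|-----|------------------|------------|----------|----------|-----|-----|------|-----------------|
| bf16 | TRUE | [70, 13] | x2 | 5.28% | 8.26% | 12.03% | 11.13% | 9.18% | 407 |
| bf16 | FALSE | [70, 13] | x2 | 5.29% | 8.26% | 12.01% | 11.15% | 9.18% | 499 |
| bf16 | TRUE | [70, 6] | x2 | 5.38% | 8.49% | 12.02% | 11.23% | 9.28% | 264 |
| bf16 | FALSE | [70, 6] | x2 | 5.37% | 8.48% | 12.00% | 11.21% | 9.26% | 338 |
| bf16 | TRUE | [70, 1] | x2 | 5.57% | 8.91% | 12.30% | 11.56% | 9.58% | 107 |
| bf16 | FALSE | [70, 1] | x2 | 5.56% | 8.91% | 12.30% | 11.51% | 9.57% | 136 |
| bf16 | TRUE | [70, 0] | x2 | 6.06% | 9.80% | 12.61% | 12.70% | 10.29% | 63 |
| bf16 | FALSE | [70, 0] | x2 | 6.08% | 9.79% | 12.66% | 12.67% | 10.30% | 80 |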
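
For readers who want the relative gains rather than raw RTFx, the short calculation below derives them from the reported numbers. It only re-uses figures already listed in the results (RTFx on LS-OTHER and the two cache sizes); the variable names are illustrative.

```python
# RTFx on LS-OTHER copied from the results: (AMP=TRUE, AMP=FALSE) per context size.
rtfx = {
    "[70, 13]": (407, 499),
    "[70, 6]": (264, 338),
    "[70, 1]": (107, 136),
    "[70, 0]": (63, 80),
}

# Relative throughput of true bf16 (AMP off) over bf16 with AMP.
speedup = {ctx: no_amp / amp for ctx, (amp, no_amp) in rtfx.items()}

# Cache memory reported in Key Findings (MB): fp32 vs. true bf16.
cache_ratio = 1872 / 936

for ctx, s in speedup.items():
    print(f"{ctx}: {s:.2f}x")
print(f"cache reduction vs fp32: {cache_ratio:.1f}x")
```

The per-context speedups land between about 1.23x and 1.28x, and the cache ratio is exactly 2x.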

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

Signed-off-by: naymaraq <dkaramyan@nvidia.com>
@naymaraq naymaraq requested a review from artbataev June 7, 2026 19:59

copy-pr-bot Bot commented Jun 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions github-actions Bot added the ASR label Jun 7, 2026
@MahmoudAshraf97

Contributor

If we are looking for RTF improvements, maybe we should move the caching step after the QKV projection. That projection is recomputed at every step on cached inputs. The cost would be double the memory needed for the cache, but that is affordable since the cache is capped at a certain size.

@artbataev artbataev left a comment

Collaborator

Thanks a lot! Great improvement!

naymaraq commented Jun 8, 2026

Collaborator Author

> If we are looking for RTF improvements, maybe we should move the caching step after the QKV projection. That projection is recomputed at every step on cached inputs. The cost would be double the memory needed for the cache, but that is affordable since the cache is capped at a certain size.

Good point. We need to understand what kind of RTFx gain it will provide at the cost of doubling the memory usage.

@naymaraq naymaraq added the CI label Jun 8, 2026
@naymaraq naymaraq enabled auto-merge (squash) June 8, 2026 10:14
naymaraq commented Jun 8, 2026

Collaborator Author

/claude review

claude Bot commented Jun 8, 2026

Contributor

Test coverage: There are no unit tests covering the cache-aware inference wrappers (CacheAwareRNNTInferenceWrapper, CacheAwareCTCInferenceWrapper). A test that exercises stream_step with use_amp=False and a non-float32 compute_dtype would guard against dtype-mismatch regressions — which is exactly the bug this PR fixes.

@artbataev

Collaborator

/ok to test 4d15f6e

3 participants