True-bfloat16 inference for cache aware pipelines #15763
Signed-off-by: naymaraq <dkaramyan@nvidia.com>
If we are looking for RTF improvements, maybe we should move the caching step to after the QKV projection. That projection is recomputed on the cached inputs at every step. The cost would be double the memory needed for the cache, but that is already affordable since the cache is capped at a certain size.
artbataev
left a comment
Thanks a lot! Great improvement!
Good point. We need to understand what kind of RTFx gain it will provide at the cost of doubling the memory usage.
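To make the memory side of this trade-off concrete, here is a rough back-of-the-envelope sketch. It assumes that caching after the projection means storing the projected K and V (two tensors of the model width) per cached token instead of one layer input. The cache length and model width below are hypothetical placeholders, not values from this PR.

```python
def cache_bytes(tokens: int, width: int, tensors_per_token: int, bytes_per_element: int) -> int:
    """Bytes needed by one attention layer's cache under the stated assumptions."""
    return tokens * width * tensors_per_token * bytes_per_element


CACHE_TOKENS = 70  # hypothetical cap on cached frames
D_MODEL = 512      # hypothetical encoder width
FP32_BYTES = 4
BF16_BYTES = 2

# Cache the layer input (one tensor per token) vs. projected K and V (two tensors).
pre_proj_fp32 = cache_bytes(CACHE_TOKENS, D_MODEL, 1, FP32_BYTES)
post_proj_fp32 = cache_bytes(CACHE_TOKENS, D_MODEL, 2, FP32_BYTES)
post_proj_bf16 = cache_bytes(CACHE_TOKENS, D_MODEL, 2, BF16_BYTES)

print(f"input cache, float32: {pre_proj_fp32} bytes")
print(f"K/V cache, float32:   {post_proj_fp32} bytes")
print(f"K/V cache, bfloat16:  {post_proj_bf16} bytes")
```

Under these assumptions, a K/V cache stored in bfloat16 takes the same number of bytes as the single-tensor cache in float32. The RTFx gain still has to be measured.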
/claude review
Test coverage: There are no unit tests covering the cache-aware inference wrappers ( |
/ok to test 4d15f6e
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do ?
Fix true `bfloat16` (`use_amp=false`) inference for cache-aware streaming ASR. Observed significant RTFx improvement and 2x cache compression.

Collection: [ASR]
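The 2x cache compression follows from element size: a `bfloat16` value keeps the upper 16 bits of the `float32` bit pattern, so each cached element takes 2 bytes instead of 4. A minimal stdlib-only sketch of that conversion is below. The helper names are illustrative and not part of NeMo or PyTorch, and the rounding shown (round to nearest, ties to even) handles finite inputs only.

```python
import struct


def float_to_bf16_bits(x: float) -> int:
    """Round a finite float to bfloat16 and return its 16-bit pattern."""
    bits32 = struct.unpack(">I", struct.pack(">f", x))[0]
    # Round to nearest, ties to even, then keep the upper 16 bits.
    rounded = bits32 + 0x7FFF + ((bits32 >> 16) & 1)
    return (rounded >> 16) & 0xFFFF


def bf16_bits_to_float(bits16: int) -> float:
    """Expand a bfloat16 bit pattern back to a float by zero-filling the low 16 bits."""
    return struct.unpack(">f", struct.pack(">I", bits16 << 16))[0]


value = 0.1
as_bf16 = bf16_bits_to_float(float_to_bf16_bits(value))

fp32_size = struct.calcsize(">f")  # bytes per float32 element
bf16_size = struct.calcsize(">H")  # bytes per 16-bit element

print(f"{value} stored as bfloat16 reads back as {as_bf16}")
print(f"bytes per element: float32={fp32_size}, bfloat16={bf16_size}")
```

The exponent range is the same as `float32`; only mantissa precision is reduced, which is why the stored value differs slightly from the input.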
Changelog
- Problem: With `use_amp: false`, the cache-aware wrappers cast the model weights to `bfloat16` but disabled autocast, while the input mel features and the encoder caches stayed `float32`. The encoder then received `float32` inputs/caches against `bfloat16` weights with no autocast to reconcile them, raising a dtype-mismatch error.
- Cast input features: In both `CacheAwareRNNTInferenceWrapper.stream_step` and `CacheAwareCTCInferenceWrapper.stream_step`, cast `processed_signal` to `self.cast_dtype` after the device move.
- Align cache dtype: `CacheAwareASRInferenceWrapper.get_initial_cache_state` now passes `dtype=self.cast_dtype`, so the context manager's persistent cache storage matches the encoder's output caches.
- Config defaults: Set `use_amp: false` in `cache_aware_rnnt.yaml` and `cache_aware_ctc.yaml` so the example configs run true bf16 by default.

Results
Setup:
Key Findings:
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items you can still open "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.
Additional Information
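For readers who want the dtype issue from the Changelog in miniature, here is a toy model in plain Python. `FakeTensor` and `encoder_step` are hypothetical stand-ins that track only a dtype name. They are not the NeMo wrappers or PyTorch tensors, and the strict check only imitates how an op rejects mixed dtypes when autocast is off.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tensor that only tracks its dtype name."""
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        return replace(self, dtype=dtype)


def encoder_step(weights, features, cache, autocast_enabled):
    """Toy encoder call: without autocast, every operand must match the weights."""
    if not autocast_enabled:
        for name, tensor in (("features", features), ("cache", cache)):
            if tensor.dtype != weights.dtype:
                raise TypeError(f"{name} is {tensor.dtype}, weights are {weights.dtype}")
    # In this toy, the output cache simply takes the weights' dtype.
    return FakeTensor(weights.dtype)


cast_dtype = "bfloat16"
weights = FakeTensor("float32").to(cast_dtype)  # model weights cast by the wrapper
features = FakeTensor("float32")                # mel features as produced upstream
cache = FakeTensor("float32")                   # initial cache before the fix

# Before the fix: float32 inputs and caches meet bfloat16 weights with autocast off.
try:
    encoder_step(weights, features, cache, autocast_enabled=False)
    failed = False
except TypeError as err:
    failed = True
    print(f"mismatch: {err}")

# After the fix: cast the features and create the cache in cast_dtype.
new_cache = encoder_step(
    weights, features.to(cast_dtype), FakeTensor(cast_dtype), autocast_enabled=False
)
print(f"output cache dtype: {new_cache.dtype}")
```

The two changes mirror the Changelog items: one cast on the input features, and one dtype argument when the initial cache is created.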