Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ asr:
device: cuda # Device for inference: 'cuda' or 'cpu'
device_id: 0 # GPU device ID
compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'
use_amp: true # Enable Automatic Mixed Precision
use_amp: false # Enable Automatic Mixed Precision


# ==========================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ asr:
device: cuda # Device for inference: 'cuda' or 'cpu'
device_id: 0 # GPU device ID
compute_dtype: bfloat16 # Compute precision: 'bfloat16' for Ampere+, 'float16' for older GPUs, or 'float32'
use_amp: true # Enable Automatic Mixed Precision
use_amp: false # Enable Automatic Mixed Precision
decoding:
strategy: "greedy_batch"
preserve_alignments: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_initial_cache_state(self, batch_size: int) -> tuple[Tensor, Tensor, Tens
Returns:
(tuple[Tensor, Tensor, Tensor]) the initial cache state of the encoder.
"""
return self.asr_model.encoder.get_initial_cache_state(batch_size=batch_size)
return self.asr_model.encoder.get_initial_cache_state(batch_size=batch_size, dtype=self.cast_dtype)

def get_drop_extra_pre_encoded(self) -> int:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def __post_init__(self) -> None:

self.drop_extra_pre_encoded = self.get_drop_extra_pre_encoded()

self.cast_dtype = torch.float32 if self.use_amp else self.compute_dtype
self.asr_model.to(self.cast_dtype)

def get_blank_id(self) -> int:
"""
Returns id of the blank token.
Expand Down Expand Up @@ -180,6 +183,8 @@ def stream_step(
if processed_signal_length.device != self.device:
processed_signal_length = processed_signal_length.to(self.device)

processed_signal = processed_signal.to(self.cast_dtype)

if context is None:
# create a dummy context
context = CacheAwareContext()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __post_init__(self) -> None:

self.drop_extra_pre_encoded = self.get_drop_extra_pre_encoded()

self.cast_dtype = torch.float32 if self.use_amp else self.compute_dtype
self.asr_model.to(self.cast_dtype)

def get_blank_id(self) -> int:
"""
Returns id of the blank token.
Expand Down Expand Up @@ -170,6 +173,8 @@ def stream_step(
if processed_signal_length.device != self.device:
processed_signal_length = processed_signal_length.to(self.device)

processed_signal = processed_signal.to(self.cast_dtype)

if context is None:
# create a dummy context
context = CacheAwareContext()
Expand Down
Loading