# Source code for espnet2.asr_transducer.decoder.abs_decoder

"""Abstract decoder definition for Transducer models."""

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple

import torch


class AbsDecoder(torch.nn.Module, ABC):
    """Abstract decoder module.

    Subclasses must override ``forward``, ``score``, ``batch_score``,
    ``set_device``, ``init_state`` and ``select_state``. All of them are
    declared with ``@abstractmethod``, so this class cannot be instantiated
    until every one is implemented. The base implementations only raise
    ``NotImplementedError``.
    """

    @abstractmethod
    def forward(self, labels: torch.Tensor) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)

        Returns:
            dec_out: Decoder output sequences. (B, T, D_dec)

        Raises:
            NotImplementedError: Always, in the abstract base implementation.

        """
        raise NotImplementedError

    @abstractmethod
    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
        """One-step forward hypothesis.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb)
            dec_state: Decoder hidden states.
                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None

        Raises:
            NotImplementedError: Always, in the abstract base implementation.

        """
        raise NotImplementedError

    @abstractmethod
    def batch_score(
        self,
        hyps: List[Any],
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb)
            states: Decoder hidden states.
                      ((N, B, D_dec), (N, B, D_dec) or None) or None

        Raises:
            NotImplementedError: Always, in the abstract base implementation.

        """
        raise NotImplementedError

    @abstractmethod
    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device to use. The annotation was previously
                ``torch.Tensor``, which did not match a device argument.

        Raises:
            NotImplementedError: Always, in the abstract base implementation.

        """
        raise NotImplementedError

    @abstractmethod
    def init_state(
        self, batch_size: int
    ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states.
                ((N, B, D_dec), (N, B, D_dec) or None) or None

        Raises:
            NotImplementedError: Always, in the abstract base implementation.

        """
        # NOTE: the return annotation previously used ``torch.tensor`` (the
        # tensor factory function) instead of the ``torch.Tensor`` type.
        raise NotImplementedError

    @abstractmethod
    def select_state(
        self,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
        idx: int = 0,
    ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Get specified ID state from batch of states, if provided.

        Args:
            states: Decoder hidden states.
                      ((N, B, D_dec), (N, B, D_dec) or None) or None
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID.
                ((N, 1, D_dec), (N, 1, D_dec) or None) or None

        Raises:
            NotImplementedError: Always, in the abstract base implementation.

        """
        raise NotImplementedError