"""Set of methods to build Transducer encoder architecture."""
from typing import Any, Dict, List, Optional, Union
from espnet2.asr_transducer.activation import get_activation
from espnet2.asr_transducer.encoder.blocks.branchformer import Branchformer
from espnet2.asr_transducer.encoder.blocks.conformer import Conformer
from espnet2.asr_transducer.encoder.blocks.conv1d import Conv1d
from espnet2.asr_transducer.encoder.blocks.conv_input import ConvInput
from espnet2.asr_transducer.encoder.modules.attention import ( # noqa: H301
RelPositionMultiHeadedAttention,
)
from espnet2.asr_transducer.encoder.modules.convolution import ( # noqa: H301
ConformerConvolution,
ConvolutionalSpatialGatingUnit,
)
from espnet2.asr_transducer.encoder.modules.multi_blocks import MultiBlocks
from espnet2.asr_transducer.encoder.modules.normalization import get_normalization
from espnet2.asr_transducer.encoder.modules.positional_encoding import ( # noqa: H301
RelPositionalEncoding,
)
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
PositionwiseFeedForward,
)
[docs]def build_main_parameters(
pos_wise_act_type: str = "swish",
conv_mod_act_type: str = "swish",
pos_enc_dropout_rate: float = 0.0,
pos_enc_max_len: int = 5000,
simplified_att_score: bool = False,
norm_type: str = "layer_norm",
conv_mod_norm_type: str = "layer_norm",
after_norm_eps: Optional[float] = None,
after_norm_partial: Optional[float] = None,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
left_chunk_size: int = 0,
**activation_parameters,
) -> Dict[str, Any]:
"""Build encoder main parameters.
Args:
pos_wise_act_type: Conformer position-wise feed-forward activation type.
conv_mod_act_type: Conformer convolution module activation type.
pos_enc_dropout_rate: Positional encoding dropout rate.
pos_enc_max_len: Positional encoding maximum length.
simplified_att_score: Whether to use simplified attention score computation.
norm_type: X-former normalization module type.
conv_mod_norm_type: Conformer convolution module normalization type.
after_norm_eps: Epsilon value for the final normalization.
after_norm_partial: Value for the final normalization with RMSNorm.
dynamic_chunk_training: Whether to use dynamic chunk training.
short_chunk_threshold: Threshold for dynamic chunk selection.
short_chunk_size: Minimum number of frames during dynamic chunk training.
left_chunk_size: Number of frames in left context.
**activations_parameters: Parameters of the activation functions.
(See espnet2/asr_transducer/activation.py)
Returns:
: Main encoder parameters
"""
main_params = {}
main_params["pos_wise_act"] = get_activation(
pos_wise_act_type, **activation_parameters
)
main_params["conv_mod_act"] = get_activation(
conv_mod_act_type, **activation_parameters
)
main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate
main_params["pos_enc_max_len"] = pos_enc_max_len
main_params["simplified_att_score"] = simplified_att_score
main_params["norm_type"] = norm_type
main_params["conv_mod_norm_type"] = conv_mod_norm_type
(
main_params["after_norm_class"],
main_params["after_norm_args"],
) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial)
main_params["dynamic_chunk_training"] = dynamic_chunk_training
main_params["short_chunk_threshold"] = max(0, short_chunk_threshold)
main_params["short_chunk_size"] = max(0, short_chunk_size)
main_params["left_chunk_size"] = max(0, left_chunk_size)
return main_params
[docs]def build_positional_encoding(
block_size: int, configuration: Dict[str, Any]
) -> RelPositionalEncoding:
"""Build positional encoding block.
Args:
block_size: Input/output size.
configuration: Positional encoding configuration.
Returns:
: Positional encoding module.
"""
return RelPositionalEncoding(
block_size,
configuration.get("pos_enc_dropout_rate", 0.0),
max_len=configuration.get("pos_enc_max_len", 5000),
)
[docs]def build_conv1d_block(
configuration: List[Dict[str, Any]],
causal: bool,
) -> Conv1d:
"""Build Conv1d block.
Args:
configuration: Conv1d block configuration.
Returns:
: Conv1d block function.
"""
return lambda: Conv1d(
configuration["input_size"],
configuration["output_size"],
configuration["kernel_size"],
stride=configuration.get("stride", 1),
dilation=configuration.get("dilation", 1),
groups=configuration.get("groups", 1),
bias=configuration.get("bias", True),
relu=configuration.get("relu", True),
batch_norm=configuration.get("batch_norm", False),
causal=causal,
dropout_rate=configuration.get("dropout_rate", 0.0),
)
[docs]def build_body_blocks(
configuration: List[Dict[str, Any]],
main_params: Dict[str, Any],
output_size: int,
) -> MultiBlocks:
"""Build encoder body blocks.
Args:
configuration: Body blocks configuration.
main_params: Encoder main parameters.
output_size: Architecture output size.
Returns:
MultiBlocks function encapsulation all encoder blocks.
"""
fn_modules = []
extended_conf = []
for c in configuration:
if c.get("num_blocks") is not None:
extended_conf += c["num_blocks"] * [
{c_i: c[c_i] for c_i in c if c_i != "num_blocks"}
]
else:
extended_conf += [c]
for i, c in enumerate(extended_conf):
block_type = c["block_type"]
if block_type == "branchformer":
module = build_branchformer_block(c, main_params)
elif block_type == "conformer":
module = build_conformer_block(c, main_params)
elif block_type == "conv1d":
module = build_conv1d_block(c, main_params["dynamic_chunk_training"])
else:
raise NotImplementedError
fn_modules.append(module)
return MultiBlocks(
[fn() for fn in fn_modules],
output_size,
norm_class=main_params["after_norm_class"],
norm_args=main_params["after_norm_args"],
)