from typing import Optional

import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor

from espnet2.enh.decoder.abs_decoder import AbsDecoder
from espnet2.layers.stft import Stft
# True when the installed torch version is at least 1.9.0; STFTDecoder.forward
# only calls torch.is_complex when this flag is set.
is_torch_1_9_plus = V("1.9.0") <= V(torch.__version__)
class STFTDecoder(AbsDecoder):
    """STFT decoder for speech enhancement and separation.

    Turns a complex spectrum back into a waveform by calling the ``inverse``
    method of an internal ``Stft`` layer.
    """

    def __init__(
        self,
        n_fft: int = 512,
        win_length: Optional[int] = None,
        hop_length: int = 128,
        window="hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
    ):
        """Build the underlying ``Stft`` layer.

        Every argument is forwarded unchanged to ``espnet2.layers.stft.Stft``;
        see that class for the exact meaning of each one.

        Args:
            n_fft: forwarded to ``Stft`` (default 512).
            win_length: forwarded to ``Stft``; defaults to ``None``.
            hop_length: forwarded to ``Stft`` (default 128).
            window: forwarded to ``Stft`` (default ``"hann"``).
            center: forwarded to ``Stft`` (default True).
            normalized: forwarded to ``Stft`` (default False).
            onesided: forwarded to ``Stft`` (default True).
        """
        super().__init__()
        self.stft = Stft(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
        )

    def forward(self, input: ComplexTensor, ilens: torch.Tensor):
        """Forward.

        Args:
            input (ComplexTensor): spectrum [Batch, T, (C,) F]. On torch >= 1.9
                a native complex ``torch.Tensor`` is accepted as well.
            ilens (torch.Tensor): input lengths [Batch]

        Returns:
            wav (torch.Tensor): waveform produced by ``self.stft.inverse``;
                for a 4-D (multi-channel) input it is rearranged to
                [Batch, Nsamples, C].
            wav_lens (torch.Tensor): lengths as returned by
                ``self.stft.inverse``.

        Raises:
            TypeError: if ``input`` is neither a ``ComplexTensor`` nor, on
                torch >= 1.9, a native complex tensor.
        """
        # The isinstance test comes first so torch.is_complex is only called
        # on inputs that are not ComplexTensor. On torch < 1.9 the
        # parenthesised operand is always False, so every non-ComplexTensor
        # input is rejected instead of silently reaching the reshape below.
        if not isinstance(input, ComplexTensor) and not (
            is_torch_1_9_plus and torch.is_complex(input)
        ):
            raise TypeError("Only support complex tensors for stft decoder")

        bs = input.size(0)
        multi_channel = input.dim() == 4
        if multi_channel:
            # input: (Batch, T, C, F) -> (Batch * C, T, F)
            # size(1)/size(3) are read from the original 4-D tensor: the call
            # arguments are evaluated before `input` is rebound.
            input = input.transpose(1, 2).reshape(-1, input.size(1), input.size(3))
            # NOTE(review): ilens stays [Batch] while the reshaped input has
            # Batch * C rows; this relies on Stft.inverse tolerating that
            # mismatch - confirm against Stft.inverse.

        wav, wav_lens = self.stft.inverse(input, ilens)

        if multi_channel:
            # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C)
            wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2)

        return wav, wav_lens