Source code for espnet.nets.pytorch_backend.transformer.longformer_attention

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2022 Roshan Sharma (Carnegie Mellon University)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Longformer based Local Attention Definition."""

from longformer.longformer import LongformerConfig, LongformerSelfAttention
from torch import nn


class LongformerAttention(nn.Module):
    """Longformer based Local Attention Definition."""

    def __init__(self, config: LongformerConfig, layer_id: int):
        """Compute Longformer based Self-Attention.

        Args:
            config: Longformer attention configuration.
            layer_id: Integer representing the layer index.
        """
        super().__init__()
        self.attention_window = config.attention_window[layer_id]
        self.attention_layer = LongformerSelfAttention(config, layer_id=layer_id)
        self.attention = None
    def forward(self, query, key, value, mask):
        """Compute Longformer Self-Attention with masking.

        Expects `len(hidden_states)` to be a multiple of `attention_window`.
        Padding to `attention_window` happens in :meth:`encoder.forward`
        to avoid redoing the padding on each layer.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        # Map the ESPnet mask convention (1 = valid frame, 0 = padding) to the
        # Longformer convention (0 = local attention, -1 = padding).
        attention_mask = mask.int()
        attention_mask[mask == 0] = -1
        attention_mask[mask == 1] = 0
        output, self.attention = self.attention_layer(
            hidden_states=query,
            attention_mask=attention_mask.unsqueeze(1),
            head_mask=None,
            output_attentions=True,
        )
        return output
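

# --- Usage sketch (illustrative, not part of the original ESPnet module) ---
# A minimal example, assuming the allenai `longformer` package is installed and
# that the pure-PyTorch "sliding_chunks" attention mode is used. All sizes below
# (window, hidden size, head count, sequence length) are made-up values chosen
# for illustration. It shows how the per-layer window is taken from
# `LongformerConfig.attention_window` and how `forward` consumes the ESPnet-style
# mask (1 = valid frame, 0 = padding), remapping it internally to the Longformer
# convention.
if __name__ == "__main__":
    import torch

    config = LongformerConfig(
        attention_window=[8],             # one window size per layer (assumed single layer)
        attention_dilation=[1],
        attention_mode="sliding_chunks",  # pure-PyTorch mode, no custom CUDA kernel
        autoregressive=False,
        hidden_size=64,
        num_attention_heads=4,
        attention_probs_dropout_prob=0.0,
    )
    attention = LongformerAttention(config, layer_id=0)

    batch, time, size = 2, 16, 64         # time already padded to fit the window
    query = torch.randn(batch, time, size)
    mask = torch.ones(batch, 1, time, dtype=torch.bool)
    mask[1, 0, 12:] = False               # mark trailing frames of one utterance as padding

    output = attention(query, query, query, mask)
    print(output.shape)                   # torch.Size([2, 16, 64])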