Attention

Bases: Module

Attention class that creates an attention mechanism with optional context.

Source code in models/tts/styledtts2/diffusion/attention.py
class Attention(nn.Module):
    r"""Attention class that creates an attention mechanism with optional context."""

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        use_rel_pos: bool,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        r"""Initialize the Attention with features, head features, number of heads, and relative position parameters.

        Args:
            features (int): The number of input features.
            head_features (int): The number of features in each head.
            num_heads (int): The number of heads.
            out_features (Optional[int]): The number of output features. If None, it will be set to the number of input features.
            context_features (Optional[int]): The number of context features. If None, it will be set to the number of input features.
            use_rel_pos (bool): Whether to use relative position bias.
            rel_pos_num_buckets (Optional[int]): The number of buckets for relative position bias. Required if use_rel_pos is True.
            rel_pos_max_distance (Optional[int]): The maximum distance for relative position bias. Required if use_rel_pos is True.
        """
        super().__init__()
        self.context_features = context_features
        mid_features = head_features * num_heads
        context_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(context_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False,
        )
        self.to_kv = nn.Linear(
            in_features=context_features, out_features=mid_features * 2, bias=False,
        )

        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        r"""Forward pass of the Attention.

        Args:
            x (Tensor): The input tensor.
            context (Optional[Tensor]): The context tensor. If None, the input tensor will be used as the context.

        Returns:
            Tensor: The output tensor.
        """
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message

        # Use context if provided
        context = default(context, x)
        # Normalize then compute q from input and k,v from context
        x, context = self.norm(x), self.norm_context(context)
        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))

        # Compute and return attention
        return self.attention(q, k, v)
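
A minimal usage sketch. The import path is assumed from the source location shown above, and the feature sizes are illustrative:

import torch

from models.tts.styledtts2.diffusion.attention import Attention  # assumed import path

# Self-attention: keys and values come from the input itself
attn = Attention(
    features=256,
    head_features=64,
    num_heads=4,
    use_rel_pos=False,
)
x = torch.randn(2, 128, 256)        # (batch, sequence length, features)
y = attn(x)                         # (2, 128, 256)

# Cross-attention: keys and values come from a separate context sequence
cross_attn = Attention(
    features=256,
    head_features=64,
    num_heads=4,
    context_features=512,
    use_rel_pos=False,
)
context = torch.randn(2, 32, 512)   # (batch, context length, context_features)
y = cross_attn(x, context=context)  # (2, 128, 256)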

__init__(features, *, head_features, num_heads, out_features=None, context_features=None, use_rel_pos, rel_pos_num_buckets=None, rel_pos_max_distance=None)

Initialize the Attention with features, head features, number of heads, and relative position parameters.

Parameters:

    features (int): The number of input features. Required.
    head_features (int): The number of features in each head. Required.
    num_heads (int): The number of heads. Required.
    out_features (Optional[int]): The number of output features. If None, it will be set to the number of input features. Default: None.
    context_features (Optional[int]): The number of context features. If None, it will be set to the number of input features. Default: None.
    use_rel_pos (bool): Whether to use relative position bias. Required.
    rel_pos_num_buckets (Optional[int]): The number of buckets for relative position bias. Required if use_rel_pos is True. Default: None.
    rel_pos_max_distance (Optional[int]): The maximum distance for relative position bias. Required if use_rel_pos is True. Default: None.

forward(x, *, context=None)

Forward pass of the Attention.

Parameters:

    x (Tensor): The input tensor. Required.
    context (Optional[Tensor]): The context tensor. If None, the input tensor will be used as the context. Default: None.

Returns:

    Tensor: The output tensor.


AttentionBase

Bases: Module

AttentionBase class that creates a base attention mechanism.

Source code in models/tts/styledtts2/diffusion/attention.py
class AttentionBase(nn.Module):
    r"""AttentionBase class that creates a base attention mechanism."""

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        use_rel_pos: bool,
        out_features: Optional[int] = None,
        rel_pos_num_buckets: Optional[int] = None,
        rel_pos_max_distance: Optional[int] = None,
    ):
        r"""Initialize the AttentionBase with features, head features, number of heads, and relative position parameters.

        Args:
            features (int): The number of input features.
            head_features (int): The number of features in each head.
            num_heads (int): The number of heads.
            use_rel_pos (bool): Whether to use relative position bias.
            out_features (Optional[int]): The number of output features. If None, it will be set to the number of input features.
            rel_pos_num_buckets (Optional[int]): The number of buckets for relative position bias. Required if use_rel_pos is True.
            rel_pos_max_distance (Optional[int]): The maximum distance for relative position bias. Required if use_rel_pos is True.
        """
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        self.use_rel_pos = use_rel_pos
        mid_features = head_features * num_heads

        if use_rel_pos:
            if not exists(rel_pos_num_buckets):
                raise ValueError("rel_pos_num_buckets must be provided.")
            if not exists(rel_pos_max_distance):
                raise ValueError("rel_pos_max_distance must be provided.")

            self.rel_pos = RelativePositionBias(
                num_buckets=rel_pos_num_buckets,
                max_distance=rel_pos_max_distance,
                num_heads=num_heads,
            )
        if out_features is None:
            out_features = features

        self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        r"""Forward pass of the AttentionBase.

        Args:
            q (Tensor): The query tensor.
            k (Tensor): The key tensor.
            v (Tensor): The value tensor.

        Returns:
            Tensor: The output tensor.
        """
        # Split heads
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
        # Compute similarity matrix
        sim = einsum("... n d, ... m d -> ... n m", q, k)
        sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
        sim = sim * self.scale
        # Get attention matrix with softmax
        attn = sim.softmax(dim=-1)
        # Compute values
        out = einsum("... n m, ... m d -> ... n d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
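
A minimal sketch of calling AttentionBase directly with pre-projected q, k, v tensors. The import path is assumed from the source location shown above, and the sizes are illustrative:

import torch

from models.tts.styledtts2.diffusion.attention import AttentionBase  # assumed import path

base = AttentionBase(
    features=256,
    head_features=64,
    num_heads=4,
    use_rel_pos=True,
    rel_pos_num_buckets=32,
    rel_pos_max_distance=128,
)
# q, k, v must already be projected to num_heads * head_features = 256 channels
q = torch.randn(2, 100, 256)
k = torch.randn(2, 100, 256)
v = torch.randn(2, 100, 256)
out = base(q, k, v)  # (2, 100, 256): heads are merged and projected back to out_features

Note that the scaling by head_features ** -0.5 is applied after the relative position bias is added, so the bias is scaled along with the raw dot products.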

__init__(features, *, head_features, num_heads, use_rel_pos, out_features=None, rel_pos_num_buckets=None, rel_pos_max_distance=None)

Initialize the AttentionBase with features, head features, number of heads, and relative position parameters.

Parameters:

    features (int): The number of input features. Required.
    head_features (int): The number of features in each head. Required.
    num_heads (int): The number of heads. Required.
    use_rel_pos (bool): Whether to use relative position bias. Required.
    out_features (Optional[int]): The number of output features. If None, it will be set to the number of input features. Default: None.
    rel_pos_num_buckets (Optional[int]): The number of buckets for relative position bias. Required if use_rel_pos is True. Default: None.
    rel_pos_max_distance (Optional[int]): The maximum distance for relative position bias. Required if use_rel_pos is True. Default: None.

forward(q, k, v)

Forward pass of the AttentionBase.

Parameters:

    q (Tensor): The query tensor. Required.
    k (Tensor): The key tensor. Required.
    v (Tensor): The value tensor. Required.

Returns:

    Tensor: The output tensor.


RelativePositionBias

Bases: Module

RelativePositionBias class that creates a relative position bias for attention mechanisms.

Source code in models/tts/styledtts2/diffusion/attention.py
class RelativePositionBias(nn.Module):
    r"""RelativePositionBias class that creates a relative position bias for attention mechanisms."""

    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
        r"""Initialize the RelativePositionBias with a number of buckets, maximum distance, and number of heads.

        Args:
            num_buckets (int): The number of buckets for the relative position bias.
            max_distance (int): The maximum distance for the relative position bias.
            num_heads (int): The number of heads for the relative position bias.
        """
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.num_heads = num_heads
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position: Tensor, num_buckets: int, max_distance: int,
    ) -> Tensor:
        r"""Compute the relative position bucket.

        Args:
            relative_position (Tensor): The relative position tensor.
            num_buckets (int): The number of buckets.
            max_distance (int): The maximum distance.

        Returns:
            Tensor: The relative position bucket tensor.
        """
        num_buckets //= 2
        ret = (relative_position >= 0).to(torch.long) * num_buckets
        n = torch.abs(relative_position)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = (
            max_exact
            + (
                torch.log(n.float() / max_exact)
                / log(max_distance / max_exact)
                * (num_buckets - max_exact)
            ).long()
        )
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1),
        )

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, num_queries: int, num_keys: int) -> Tensor:
        r"""Forward pass of the RelativePositionBias.

        Args:
            num_queries (int): The number of queries.
            num_keys (int): The number of keys.

        Returns:
            Tensor: The output tensor.
        """
        i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
        q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")

        relative_position_bucket = self._relative_position_bucket(
            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance,
        )

        bias = self.relative_attention_bias(relative_position_bucket)
        bias = rearrange(bias, "m n h -> 1 h m n")
        return bias
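
A minimal sketch showing the bias produced by the forward pass. The import path is assumed from the source location shown above, and the values are illustrative:

import torch

from models.tts.styledtts2.diffusion.attention import RelativePositionBias  # assumed import path

rel_pos = RelativePositionBias(num_buckets=32, max_distance=128, num_heads=4)
bias = rel_pos(num_queries=100, num_keys=100)
print(bias.shape)  # torch.Size([1, 4, 100, 100]); broadcasts over the batch dimension when added to the similarity matrix

Each relative offset is mapped to one of num_buckets learned embeddings: the sign of the offset selects one half of the buckets, small offsets get exact buckets, larger offsets are log-spaced, and anything at or beyond max_distance is clamped into the outermost bucket.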

__init__(num_buckets, max_distance, num_heads)

Initialize the RelativePositionBias with a number of buckets, maximum distance, and number of heads.

Parameters:

    num_buckets (int): The number of buckets for the relative position bias. Required.
    max_distance (int): The maximum distance for the relative position bias. Required.
    num_heads (int): The number of heads for the relative position bias. Required.

forward(num_queries, num_keys)

Forward pass of the RelativePositionBias.

Parameters:

    num_queries (int): The number of queries. Required.
    num_keys (int): The number of keys. Required.

Returns:

    Tensor: The output tensor.


FeedForward(features, multiplier)

Creates a feed-forward neural network with GELU activation in the middle layer.

Parameters:

    features (int): The number of input and output features. Required.
    multiplier (int): The factor by which the number of features is multiplied to obtain the width of the middle layer. Required.

Returns:

    nn.Module: A feed-forward neural network module.

Source code in models/tts/styledtts2/diffusion/attention.py
def FeedForward(features: int, multiplier: int) -> nn.Module:
    r"""Creates a feed-forward neural network with GELU activation in the middle layer.

    Args:
        features (int): The number of input and output features.
        multiplier (int): The factor to multiply the number of features to get the number of features in the middle layer.

    Returns:
        nn.Module: A feed-forward neural network module.
    """
    mid_features = features * multiplier
    return nn.Sequential(
        nn.Linear(in_features=features, out_features=mid_features),
        nn.GELU(),
        nn.Linear(in_features=mid_features, out_features=features),
    )
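
A minimal usage sketch. The import path is assumed from the source location shown above:

import torch

from models.tts.styledtts2.diffusion.attention import FeedForward  # assumed import path

ff = FeedForward(features=256, multiplier=4)  # Linear(256 -> 1024), GELU, Linear(1024 -> 256)
x = torch.randn(2, 100, 256)
y = ff(x)  # (2, 100, 256); applied position-wise over the last dimension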