Multi-Head Attention
MultiHeadAttention
Bases: Module
A class that implements a Multi-head Attention mechanism. Multi-head attention allows the model to jointly attend to different positions, with each head capturing a different aspect of the input.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`query_dim` | `int` | The dimensionality of the query. | required |
`key_dim` | `int` | The dimensionality of the key. | required |
`num_units` | `int` | The total number of dimensions of the output. | required |
`num_heads` | `int` | The number of parallel attention heads. | required |
Inputs:
- query: Tensor of shape [N, T_q, query_dim]
- key: Tensor of shape [N, T_k, key_dim]

Outputs:
- An output tensor of shape [N, T_q, num_units]
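For orientation, here is a minimal usage sketch. The import path is inferred from the source file noted below, the tensor sizes are illustrative, and whether `num_units` must divide evenly across `num_heads` depends on the implementation:

```python
import torch

# Import path inferred from the "Source code in" note below (assumption).
from models.tts.delightful_tts.attention.multi_head_attention import MultiHeadAttention

N, T_q, T_k = 2, 50, 60           # batch size, query length, key length
query_dim, key_dim = 256, 256     # illustrative dimensionalities
num_units, num_heads = 512, 8     # num_units assumed divisible across heads

attention = MultiHeadAttention(
    query_dim=query_dim,
    key_dim=key_dim,
    num_units=num_units,
    num_heads=num_heads,
)

query = torch.randn(N, T_q, query_dim)  # [N, T_q, query_dim]
key = torch.randn(N, T_k, key_dim)      # [N, T_k, key_dim]

out = attention(query, key)
print(out.shape)  # torch.Size([2, 50, 512]) -> [N, T_q, num_units]
```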
Source code in models/tts/delightful_tts/attention/multi_head_attention.py
forward(query, key)
Performs the forward pass over input tensors.
Parameters:

Name | Type | Description | Default |
---|---|---|---|
`query` | `Tensor` | The input tensor containing query vectors. It is expected to have dimensions [N, T_q, query_dim], where N is the batch size, T_q is the sequence length of queries, and query_dim is the dimensionality of a single query vector. | required |
`key` | `Tensor` | The input tensor containing key vectors. It is expected to have dimensions [N, T_k, key_dim], where N is the batch size, T_k is the sequence length of keys, and key_dim is the dimensionality of a single key vector. | required |
Returns:

Type | Description |
---|---|
`Tensor` | The output tensor of shape [N, T_q, num_units], the result of applying the multi-head attention mechanism to the provided queries and keys. |
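The source file itself is not reproduced on this page. As a reference point, here is a self-contained sketch of one common way such a forward pass is implemented: project queries and keys, split `num_units` across heads, apply scaled dot-product attention per head, and merge the heads back. Treating the key tensor as the value as well is an assumption here, since the documented signature takes only query and key:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchMultiHeadAttention(nn.Module):
    """Illustrative re-implementation; not the repository's actual code."""

    def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int):
        super().__init__()
        assert num_units % num_heads == 0, "num_units must divide evenly across heads"
        self.num_heads = num_heads
        self.scale = (num_units // num_heads) ** 0.5
        # Linear projections; the key input doubles as the value (assumption).
        self.W_query = nn.Linear(query_dim, num_units, bias=False)
        self.W_key = nn.Linear(key_dim, num_units, bias=False)
        self.W_value = nn.Linear(key_dim, num_units, bias=False)

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        q = self.W_query(query)  # [N, T_q, num_units]
        k = self.W_key(key)      # [N, T_k, num_units]
        v = self.W_value(key)    # [N, T_k, num_units]

        # Split the unit dimension into heads: [h, N, T, num_units / h]
        q = torch.stack(torch.chunk(q, self.num_heads, dim=2), dim=0)
        k = torch.stack(torch.chunk(k, self.num_heads, dim=2), dim=0)
        v = torch.stack(torch.chunk(v, self.num_heads, dim=2), dim=0)

        # Scaled dot-product attention per head: scores are [h, N, T_q, T_k]
        scores = torch.matmul(q, k.transpose(2, 3)) / self.scale
        weights = F.softmax(scores, dim=3)

        # Weighted sum of values, then merge heads back: [N, T_q, num_units]
        out = torch.matmul(weights, v)
        return torch.cat(torch.chunk(out, self.num_heads, dim=0), dim=3).squeeze(0)
```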
Source code in models/tts/delightful_tts/attention/multi_head_attention.py