Skip to content

AddCoords

AddCoords

Bases: Module

AddCoords is a PyTorch module that adds additional channels to the input tensor containing the relative (normalized to [-1, 1]) coordinates of each input element along the specified number of dimensions (rank). Essentially, it adds spatial context information to the tensor.

Typically, these inputs are feature maps coming from some CNN, where the spatial organization of the input matters (such as an image or speech signal).

This additional spatial context allows subsequent layers (such as convolutions) to learn position-dependent features. For example, in tasks where the absolute position of features matters (such as denoising and segmentation tasks), it helps the model to know where (in terms of relative position) the features are.

Parameters:

Name Type Description Default
rank int

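The number of spatial dimensions of the input tensor, not counting the leading batch and channel dimensions. That is to say, this tells us how many dimensions the input tensor's spatial context has. It's assumed to be 1, 2, or 3, corresponding to 1D, 2D, or 3D data (for example, a sequence, an image, or a volume); the input is therefore expected to have rank + 2 dimensions in total.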

required
with_r bool

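Boolean indicating whether to add an extra radial distance channel or not. If True, an extra channel is appended, which holds a Euclidean (L2) distance computed from the normalized coordinates. Note that the implementation measures this distance from the point 0.5 in every normalized coordinate rather than from 0, which is the geometric center of the [-1, 1] range. This might be useful when the proximity to a central region of the image is important to the task.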

False
Source code in models/tts/delightful_tts/conv_blocks/add_coords.py
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
class AddCoords(Module):
    r"""Append normalized coordinate channels to a batch of feature maps.

    For an input of shape `(batch, channels, *spatial)` with `rank` spatial dimensions, one extra channel is
    appended per spatial dimension. The k-th added channel holds the position of each element along the k-th
    spatial dimension, rescaled linearly so that the first index maps to `-1` and the last index maps to `1`.

    This gives subsequent layers (such as convolutions) access to where a feature is located, which lets them
    learn position-dependent behaviour. This is the idea behind CoordConv.

    Args:
        rank (int): The number of spatial dimensions of the input (1, 2, or 3). The input passed to `forward`
                    must have `rank + 2` dimensions. Any other value makes `forward` raise
                    `NotImplementedError`.

        with_r (bool): Whether to append one additional radial channel computed from the coordinate channels.
                       The distance is measured from the point `0.5` in every normalized coordinate (see the
                       note in `forward`).
    """

    def __init__(self, rank: int, with_r: bool = False):
        super().__init__()
        self.rank = rank
        self.with_r = with_r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Return `x` with coordinate channels (and optionally a radial channel) concatenated on dim 1.

        The input tensor is not modified; a new tensor is returned. The output has
        `channels + rank` channels, or `channels + rank + 1` when `with_r` is True. Added channels are
        `float32`; the dtype of the result follows `torch.cat` type promotion.

        Args:
            x (torch.Tensor): Input of shape `(batch, channels, *spatial)` with `rank` spatial dimensions.

        Returns:
            out (torch.Tensor): The input with the coordinate channels appended, in spatial-dimension order,
                followed by the radial channel if requested.

        Raises:
            NotImplementedError: If `rank` is not 1, 2, or 3.
            ValueError: If `x` does not have `rank + 2` dimensions.
        """
        if self.rank not in (1, 2, 3):
            raise NotImplementedError

        if x.dim() != self.rank + 2:
            raise ValueError(
                f"AddCoords(rank={self.rank}) expects a {self.rank + 2}-D input "
                f"(batch, channels, *spatial), got shape {tuple(x.shape)}",
            )

        spatial_shape = list(x.shape[2:])
        full_shape = [x.shape[0], 1, *spatial_shape]

        coord_channels = []
        for axis, size in enumerate(spatial_shape):
            positions = torch.arange(size, dtype=torch.float32, device=x.device)
            # A dimension of size 1 would otherwise divide by zero and fill the channel with NaN;
            # with the clamp its single position maps to -1.
            positions = positions / max(size - 1, 1)
            positions = positions * 2 - 1

            # Shape the 1-D ramp so it lies along its own spatial axis, then broadcast it over the
            # batch and the remaining spatial axes (no integer matmul, no per-slice concatenation).
            broadcast_shape = [1] * (self.rank + 2)
            broadcast_shape[axis + 2] = size
            coord_channels.append(positions.view(broadcast_shape).expand(full_shape))

        out = torch.cat([x, *coord_channels], dim=1)

        if self.with_r:
            # NOTE(review): the 0.5 offset is kept exactly as in the previous implementation so existing
            # outputs do not change. Because coordinates span [-1, 1], the geometric center is 0, so this
            # is the distance from the point (0.5, ..., 0.5), not from the center.
            squared = torch.pow(coord_channels[0] - 0.5, 2)
            for channel in coord_channels[1:]:
                squared = squared + torch.pow(channel - 0.5, 2)
            out = torch.cat([out, torch.sqrt(squared)], dim=1)

        return out

forward(x)

Forward pass of the AddCoords module. Depending on the rank of the tensor, it adds one or more new channels with relative coordinate values. If with_r is True, an extra radial channel is included.

For example, for an image (rank=2), two channels would be added which contain the normalized x and y coordinates respectively of each pixel.

Calling the forward method returns a new tensor made of x with the added channels concatenated along the channel dimension; the original tensor x itself is not modified.

Parameters:

Name Type Description Default
x Tensor

The input tensor.

required

Returns:

Name Type Description
out Tensor

The input tensor with added coordinate and possibly radial channels.

Source code in models/tts/delightful_tts/conv_blocks/add_coords.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def forward(self, x: torch.Tensor) -> torch.Tensor:
    r"""Return `x` with coordinate channels (and optionally a radial channel) concatenated on dim 1.

    For an input of shape `(batch, channels, *spatial)` with `self.rank` spatial dimensions, one channel is
    appended per spatial dimension. The k-th added channel holds the position of each element along the
    k-th spatial dimension, rescaled linearly so the first index maps to `-1` and the last index maps to `1`.
    If `self.with_r` is True, one more channel with a radial distance is appended after them.

    The input tensor is not modified; a new tensor is returned. Added channels are `float32`; the dtype of
    the result follows `torch.cat` type promotion.

    Args:
        x (torch.Tensor): Input of shape `(batch, channels, *spatial)` with `self.rank` spatial dimensions.

    Returns:
        out (torch.Tensor): The input with the coordinate channels appended, in spatial-dimension order,
            followed by the radial channel if requested.

    Raises:
        NotImplementedError: If `self.rank` is not 1, 2, or 3.
        ValueError: If `x` does not have `self.rank + 2` dimensions.
    """
    if self.rank not in (1, 2, 3):
        raise NotImplementedError

    if x.dim() != self.rank + 2:
        raise ValueError(
            f"AddCoords(rank={self.rank}) expects a {self.rank + 2}-D input "
            f"(batch, channels, *spatial), got shape {tuple(x.shape)}",
        )

    spatial_shape = list(x.shape[2:])
    full_shape = [x.shape[0], 1, *spatial_shape]

    coord_channels = []
    for axis, size in enumerate(spatial_shape):
        positions = torch.arange(size, dtype=torch.float32, device=x.device)
        # A dimension of size 1 would otherwise divide by zero and fill the channel with NaN;
        # with the clamp its single position maps to -1.
        positions = positions / max(size - 1, 1)
        positions = positions * 2 - 1

        # Shape the 1-D ramp so it lies along its own spatial axis, then broadcast it over the
        # batch and the remaining spatial axes (no integer matmul, no per-slice concatenation).
        broadcast_shape = [1] * (self.rank + 2)
        broadcast_shape[axis + 2] = size
        coord_channels.append(positions.view(broadcast_shape).expand(full_shape))

    out = torch.cat([x, *coord_channels], dim=1)

    if self.with_r:
        # NOTE(review): the 0.5 offset is kept exactly as in the previous implementation so existing
        # outputs do not change. Because coordinates span [-1, 1], the geometric center is 0, so this
        # is the distance from the point (0.5, ..., 0.5), not from the center.
        squared = torch.pow(coord_channels[0] - 0.5, 2)
        for channel in coord_channels[1:]:
            squared = squared + torch.pow(channel - 0.5, 2)
        out = torch.cat([out, torch.sqrt(squared)], dim=1)

    return out