# Source code for qbm4eo.decoder

"""
=============================================================================

This file is a part of qbm4eo.rst project.

https://github.com/FeralQubits/qbm4eo

=============================================================================

It has been modified as a part of the EuroHPC PL project funded at the Smart Growth
Operational Programme 2014-2020, Measure 4.2 under the grant agreement no.
POIR.04.02.00-00-D014/20-00.

=============================================================================
"""

from math import prod
from typing import Any, Dict, List, Optional, Sequence

import torch
from torch import nn


class ResBlockDeConvPart(nn.Module):
    """
    A single part of the ResBlockDeConv class.

    The part applies, in order, a LeakyReLU activation, a 3x3 transposed
    convolution (stride 1, padding 1) with the same number of input and
    output channels, and a 2D batch normalization.
    """

    def __init__(
        self,
        channels: int,
        *args: Any,
        negative_slope: float = 0.02,
        bias: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        A default constructor for the ResBlockDeConvPart class.

        :param channels: The number of channels in the input and output.
        :param negative_slope: The negative slope of the LeakyReLU
            activation function.
        :param bias: Whether to use a bias term in the convolutional layer.
        :param args: Additional arguments to pass to the super class.
        :param kwargs: Additional keyword arguments to pass to the super
            class.
        """
        super().__init__(*args, **kwargs)
        self.subnet: nn.Sequential = nn.Sequential(
            nn.LeakyReLU(negative_slope),
            # kernel_size=3 with stride=1 and padding=1 leaves the spatial
            # size of the feature map unchanged.
            nn.ConvTranspose2d(
                channels, channels, kernel_size=3, stride=1, padding=1, bias=bias
            ),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The forward pass of the ResBlockDeConvPart class.

        :param x: The input tensor.
        :return: The output tensor, i.e. ``x`` passed through the
            activation, the transposed convolution and the batch
            normalization.
        """
        return self.subnet(x)
class ResBlockDeConv(nn.Module):
    """
    A residual block for the decoder part of the LBAE model.

    The block first applies a stride-2 transposed convolution followed by
    batch normalization (``initial_block``), then adds the output of two
    stacked ResBlockDeConvPart modules (``middle_block``) to that result
    and applies a LeakyReLU activation to the sum.
    """

    def __init__(
        self,
        channels: int,
        *args: Any,
        in_channels: Optional[int] = None,
        negative_slope: float = 0.02,
        bias: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        A default constructor for the ResBlockDeConv class.

        :param channels: The number of channels in the output (and in the
            input, unless ``in_channels`` is given).
        :param in_channels: The number of channels in the input. If None,
            defaults to the number of channels in the output. Defaults to
            None.
        :param negative_slope: The negative slope of the LeakyReLU
            activation function.
        :param bias: Whether to use a bias term in the convolutional layers.
        :param args: Additional arguments to pass to the super class.
        :param kwargs: Additional keyword arguments to pass to the super
            class.
        """
        super().__init__(*args, **kwargs)
        if in_channels is None:
            in_channels = channels

        self.initial_block: nn.Sequential = nn.Sequential(
            # kernel_size=4, stride=2, padding=1 doubles the height and the
            # width of the feature map.
            nn.ConvTranspose2d(
                in_channels,
                channels,
                kernel_size=4,
                stride=2,
                padding=1,
                output_padding=0,
                bias=bias,
            ),
            nn.BatchNorm2d(channels),
        )
        self.middle_block: nn.Sequential = nn.Sequential(
            ResBlockDeConvPart(channels, negative_slope=negative_slope, bias=bias),
            ResBlockDeConvPart(channels, negative_slope=negative_slope, bias=bias),
        )
        # Kept for the activation applied after the residual sum in forward().
        self.negative_slope: float = negative_slope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The forward pass of the ResBlockDeConv class.

        :param x: The input tensor.
        :return: The output tensor.
        """
        x = self.initial_block(x)
        # The skip connection starts after the initial block, so both
        # summands have the same shape.
        skip: torch.Tensor = x
        x = self.middle_block(x)
        return nn.functional.leaky_relu(x + skip, self.negative_slope)
class LBAEDecoder(nn.Module):
    """
    The decoder part of the LBAE model.

    A linear layer expands the latent vector to ``prod(input_size)`` values,
    which are reshaped to ``input_size`` and passed through ``num_layers``
    ResBlockDeConv blocks, a final stride-2 transposed convolution with
    batch normalization and LeakyReLU, and a 3x3 transposed convolution
    producing ``output_size[0]`` channels. A sigmoid is applied to the
    result.
    """

    def __init__(
        self,
        input_size: Sequence[int],
        output_size: Sequence[int],
        latent_space_size: int,
        num_layers: int,
        *args: Any,
        negative_slope: float = 0.02,
        bias: bool = False,
        **kwargs: Any,
    ) -> None:
        """
        A default constructor for the LBAEDecoder class.

        :param input_size: The size of the tensor fed to the convolutional
            part. This should be a tuple of the form (channels, height,
            width). The number of channels is floor-halved by every
            ResBlockDeConv block, so it should be at least 2**num_layers
            (ideally divisible by it).
        :param output_size: The size of the output tensor. Only its first
            element (the number of output channels) is used here.
        :param latent_space_size: The size of the latent space. This is the
            size of the input to the linear layer.
        :param num_layers: The number of layers in the decoder. This is the
            number of ResBlockDeConv blocks.
        :param negative_slope: The negative slope of the LeakyReLU
            activation function, used in every block of the decoder.
        :param bias: Whether to use a bias term in the convolutional layers,
            used in every block of the decoder.
        :param args: Additional arguments to pass to the super class.
        :param kwargs: Additional keyword arguments to pass to the super
            class.
        """
        super().__init__(*args, **kwargs)
        self.input_size: Sequence[int] = input_size
        self.linear: nn.Linear = nn.Linear(latent_space_size, prod(input_size))

        base_channels: int = input_size[0]
        # Number of channels left after all residual blocks.
        final_channels: int = base_channels // 2**num_layers

        # Each residual block halves the number of channels. The
        # hyper-parameters are forwarded so that the whole decoder uses the
        # same activation slope and bias setting.
        layers: List[nn.Module] = [
            ResBlockDeConv(
                channels=base_channels // (2 ** (i + 1)),
                in_channels=base_channels // (2**i),
                negative_slope=negative_slope,
                bias=bias,
            )
            for i in range(num_layers)
        ]
        layers += [
            # Again, swapped input/output number of channels
            nn.ConvTranspose2d(
                final_channels,
                final_channels,
                kernel_size=4,
                stride=2,
                bias=bias,
                padding=1,
                output_padding=0,
            ),
            nn.BatchNorm2d(final_channels),
            nn.LeakyReLU(negative_slope),
            nn.ConvTranspose2d(
                final_channels,
                output_size[0],
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias,
            ),
        ]
        self.net: nn.Sequential = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The forward pass of the LBAEDecoder class.

        :param x: The input tensor. It is flattened to (batch, -1) before
            the linear layer, so the flattened size must equal the latent
            space size.
        :return: The output tensor of the decoder, with values in (0, 1)
            due to the final sigmoid.
        """
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        x = self.linear(x)
        # Reshape the flat linear output to (batch, channels, height, width).
        x = x.view(batch_size, *self.input_size)
        x = self.net(x)
        return torch.sigmoid(x)