Source code for sinabs.backend.dynapcnn.flipdims
import torch
import torch.nn as nn
from typing import Tuple
[docs]class FlipDims(nn.Module):
def __init__(
self, flip_x: bool = False, flip_y: bool = False, swap_xy: bool = False
):
super().__init__()
self.flip_x = flip_x
self.flip_y = flip_y
self.swap_xy = swap_xy
[docs] def forward(self, data):
_, _, h, w = list(data.shape)
# Flip along x and y axis
if self.flip_y:
data = data.flip(2)
if self.flip_x:
data = data.flip(3)
if self.swap_xy:
data = torch.transpose(data, 2, 3)
return data
[docs] def get_output_shape(self, input_shape: Tuple) -> Tuple:
"""
Retuns the output dimensions
:param input_shape: (channels, height, width)
:return: (channels, height, width)
"""
return input_shape