# Source code for sinabs.backend.dynapcnn.crop2d

from typing import Union, List, Tuple
import numpy as np
from torch import nn

# Anything indexable as ``cropping[0]`` / ``cropping[1]`` yielding two bounds each.
ArrayLike = Union[np.ndarray, List, Tuple]


class Crop2d(nn.Module):
    """
    Crop the last two dimensions of a 4-dimensional input to a rectangle.

    The crop bounds are used directly as slice bounds, i.e. the output is
    ``input[:, :, top:bottom, left:right]``.
    """

    def __init__(
        self,
        cropping: ArrayLike = ((0, 0), (0, 0)),
    ):
        """
        Crop input to the rectangle dimensions

        :param cropping: ((top, bottom), (left, right)), used as the slice \
            bounds ``top:bottom`` and ``left:right``. Note that the default \
            ``((0, 0), (0, 0))`` selects the empty slices ``0:0``.
        """
        super().__init__()
        self.top_crop, self.bottom_crop = cropping[0]
        self.left_crop, self.right_crop = cropping[1]

    def forward(self, binary_input):
        """
        Slice the input and record statistics about the cropped output.

        :param binary_input: Object indexed with four indices; the last two \
            dimensions are cropped, the first two are kept in full.
        :return: ``binary_input[:, :, top:bottom, left:right]``
        """
        # Crop the data array
        crop_out = binary_input[
            :,
            :,
            self.top_crop : self.bottom_crop,
            self.left_crop : self.right_crop,
        ]
        # Shape of the output without its leading dimension
        self.out_shape = crop_out.shape[1:]
        # Sum of absolute values over the whole cropped output
        self.spikes_number = crop_out.abs().sum()
        # Size of the leading dimension
        # NOTE(review): presumably the number of time steps - confirm with callers
        self.tw = len(crop_out)
        return crop_out

    @staticmethod
    def _cropped_length(start, stop, length) -> int:
        """Return how many elements ``start:stop`` selects from ``length`` items."""
        # slice.indices applies the regular slicing rules (clamping to the
        # available length, negative bounds counted from the end, ``None``
        # meaning "open"), so the result matches what ``forward`` produces.
        return len(range(*slice(start, stop).indices(length)))

    def get_output_shape(self, input_shape: Tuple) -> Tuple:
        """
        Returns the output dimensions

        The crop bounds are resolved against ``input_shape`` with the same
        rules slicing uses, so out-of-range, negative or inverted bounds give
        the size that ``forward`` actually yields instead of a plain
        difference of the bounds.

        :param input_shape: (channels, height, width)
        :return: (channels, height, width)
        """
        channels, height, width = input_shape
        return (
            channels,
            self._cropped_length(self.top_crop, self.bottom_crop, height),
            self._cropped_length(self.left_crop, self.right_crop, width),
        )

    def __repr__(self):
        return f"Crop2d(({self.top_crop}, {self.bottom_crop}), ({self.left_crop}, {self.right_crop}))"