C4 Group CNN
This post documents my process of learning simple group convolutions from scratch, which I worked through back in 2021 alongside the Geometric Deep Learning (GDL) course. It is based on tutorial 2 of the course, which can be found here. To remind myself of the key concepts, I cleaned up my old code and summarized it here.
The key to a group equivariant neural network is to identify the input, hidden, and output spaces, because the network acts as a function mapping from one space to another. I use \(X\) and \(Y\) for the input and output spaces and \(H_i\) for the i-th hidden space. The network consists of a number of layers that transform the input through the hidden spaces to the output:
\[X \to H_1 \to H_2 \to ... \to H_i \to ... \to Y\]Each layer is represented by a \(\to\) above, and depending on which spaces a layer maps from and to, its design changes accordingly. This design is the focus of this post.
0. Dependencies
Let’s just load some dependencies for later use.
# packages
import os, torch, einops
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
TOL = 1e-5
# if torch.cuda.is_available(): DEVICE = 'cuda'
# else: DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'
DEVICE = 'cpu' # the 'mps' device seems buggy, or I haven't spent the time to write better code for the Mac GPU
print(DEVICE)
1. Group
A group \((G, \circ)\) is a set \(G\) with a binary operation \(\circ : G\times G \to G\) satisfying the following (the \(\circ\) is omitted below):
- Associativity: \(a (bc) = (ab) c\)
- Identity: \(e \in G\), \(g e = e g = g\)
- Inverse: for every \(g \in G\) there exists \(g^{-1} \in G\) with \(gg^{-1} = g^{-1}g = e\)
Consider the group \(G = C_4\), the cyclic group of planar rotations by multiples of \(\pi / 2\), so \(|C_4| = 4\). Concretely, \(C_4\) can be identified with \(\{0, 1, 2, 3\}\) under addition modulo 4. Let \(X\) be the set of \(n\times n\) grayscale images. An image \(x \in X\) is a function \(x: p \mapsto x(p) \in \mathbb{R}\) which maps each pixel \(p = (h, w)\) to a real number.
An element \(g \in G = C_4\) transforms an image \(x \in X\) into the rotated image \(gx \in X\), defined by \([gx](p) = x(g^{-1}p)\), where \(g^{-1}p\) is the corresponding pixel in the unrotated image. This action is implemented as follows:
# rotate an image, the x is ... x H x W in dimension
def rotate(x, r): return x.rot90(r, dims=(-2, -1))
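As a quick sanity check (a small sketch), rotate respects the group structure of \(C_4\): rotations compose by adding their exponents modulo 4, \(r = 0\) is the identity, and \(r = 3\) is the inverse of \(r = 1\).
# sanity check that rotate implements the C4 group action
x = torch.randn(1, 1, 8, 8)
assert torch.equal(rotate(rotate(x, 1), 2), rotate(x, 3))  # composition: g^2 g^1 = g^3
assert torch.equal(rotate(x, 0), x)                        # identity
assert torch.equal(rotate(rotate(x, 1), 3), x)             # inverse: g^3 g^1 = e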
2. Equivariant Convolution Layers
An equivariant layer (or function) is a map \(\psi: X \to Y\) from an input G-space \(X\) to an output G-space \(Y\) satisfying \(\psi(gx) = g\psi(x)\). Here the input space is the pixel space, and we choose the same output space, \(Y = X\), so both input and output are grayscale images (in general they may have different sizes / dimensions).
The equivariant layer will be a \(3 \times 3\) convolution filter. We first verify that a random filter is not a rotation-equivariant layer.
# random filter
torch.manual_seed(37)
x = torch.randn(1, 1, 33, 33) ** 3
gx = rotate(x, 1)
filter = torch.randn(1, 1, 3, 3)
psi_x = torch.conv2d(x, filter, bias=None, padding=1)
psi_gx = torch.conv2d(gx, filter, bias=None, padding=1)
g_psi_x = rotate(psi_x, 1)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(g_psi_x[0, 0].numpy())
axes[0].set_title(r'$g.\psi(x)$')
axes[1].imshow(psi_gx[0, 0].numpy())
axes[1].set_title(r'$\psi(g.x)$')
plt.show()
print('Equivariant' if torch.allclose(psi_gx, g_psi_x, atol=TOL, rtol=TOL) else 'Not equivariant!!')
Not equivariant!!

Clearly, if the filter \(\psi\) has no constraints on its weights,
\[\psi(gx) \neq g\psi(x)\]There must be some \(C_4\) symmetry baked into the filter. Let's first try an isotropic filter with 2 trainable weights: one in the middle and the other in the ring.
# Isotropic filter
# The filter looks like:
# a, a, a
# a, b, a
# a, a, a
torch.manual_seed(37)
filter = torch.empty(1, 1, 3, 3)
filter[0, 0, 1, 1] = torch.randn(())  # middle pixel (torch.randn so the manual seed takes effect)
mask = torch.ones(3, 3, dtype=torch.bool)
mask[1, 1] = 0
filter[0, 0, mask] = torch.randn(())  # ring pixels
# we recycle the previous code:
psi_x = torch.conv2d(x, filter, bias=None, padding=1)
psi_gx = torch.conv2d(gx, filter, bias=None, padding=1)
g_psi_x = rotate(psi_x, 1)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(g_psi_x[0, 0].numpy())
axes[0].set_title(r'$g.\psi(x)$')
axes[1].imshow(psi_gx[0, 0].numpy())
axes[1].set_title(r'$\psi(g.x)$')
plt.show()
print('Equivariant' if torch.allclose(psi_gx, g_psi_x, atol=TOL, rtol=TOL) else 'Not equivariant!!')
Equivariant

This works because an isotropic filter is itself invariant under rotation (\(g\psi = \psi\)), so convolving with it commutes with rotating. Let's use this filter to build an equivariant convolution layer:
class IsotropicEConv2d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, bias=True):
        super(IsotropicEConv2d, self).__init__()
        self.kernel_size = 3
        self.stride = self.dilation = self.padding = 1
        self.in_channels = in_channels
        self.out_channels = out_channels
        # each filter contains 2 parameters (middle pixel and ring);
        # there are in_channels x out_channels filters in total
        self.weight = torch.nn.Parameter(torch.empty(self.out_channels, self.in_channels, 2).normal_(mean=0, std=1/np.sqrt(in_channels * out_channels)), requires_grad=True)
        self.bias = torch.nn.Parameter(torch.zeros(out_channels), requires_grad=True) if bias else None

    def build_filter(self):
        mask = torch.ones(3, 3, dtype=torch.bool)
        filter = torch.zeros(self.out_channels, self.in_channels, 3, 3)
        filter[:, :, mask] = self.weight[:, :, 1].unsqueeze(2).repeat(1, 1, 9)  # fill all 9 pixels with the ring weight
        filter[:, :, 1, 1] = self.weight[:, :, 0]  # then overwrite the middle pixel
        return filter

    def forward(self, x):
        _filter = self.build_filter()
        return torch.conv2d(x, _filter, stride=self.stride, padding=self.padding, dilation=self.dilation, bias=self.bias)
Since this layer maps from the input space \(X\) back to \(X\), we define the following equivariance checker.
def check_equivariance_XX(layer, x, tol=TOL):
    layer.eval()
    gx = rotate(x, 1)
    with torch.no_grad():
        psi_x = layer(x)
        psi_gx = layer(gx)
    g_psi_x = rotate(psi_x, 1)
    assert not torch.allclose(psi_x, torch.zeros_like(psi_x), atol=tol, rtol=tol)
    print('Equivariant' if torch.allclose(psi_gx, g_psi_x, atol=tol, rtol=tol) else 'Not equivariant!!')
    return psi_gx, g_psi_x
Let's check if the IsotropicEConv2d is actually equivariant:
torch.manual_seed(37)
in_channels = 5
out_channels = 10
batchsize = 16
S = 33
x = torch.randn(batchsize, in_channels, S, S) ** 2
layer = IsotropicEConv2d(in_channels=in_channels, out_channels=out_channels, bias=True)
psi_gx, g_psi_x = check_equivariance_XX(layer, x, tol=TOL)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(g_psi_x[0, 0].numpy())
axes[0].set_title(r'$g.\psi(x)$')
axes[1].imshow(psi_gx[0, 0].numpy())
axes[1].set_title(r'$\psi(g.x)$')
plt.show()
Equivariant

Unfortunately, isotropic filters are not very expressive. Instead, we would like to use more general, unconstrained filters. To do so, we need to rely on group convolution.
Let \(X\) be the space of grayscale images, \(\psi \in X\) be a filter and \(x \in X\) be an input image. The group convolution is
\[[\psi \star x](t, r) = \sum_{p} \psi((t, r)^{-1}p)\,x(p) = \sum_{p} \psi(r^{-1}(p - t))\,x(p)\]The output of this convolution is not a grayscale image in \(X\): it is a function over the roto-translation group. Used in a group convolution, the filter \(\psi\) maps the input space \(X\) into a new, larger space \(Y\), where \(Y\) is the space of all functions \(y: p4 \to \mathbb{R}\).
This is the lifting convolution, since it maps the space \(X\) to the more complex space \(Y\). Note that a function \(y \in Y\) can be implemented as a 4-channel image, where the i-th channel is defined as \(y_i(t) = y(t, r=i) \in \mathbb{R}\), \(i \in \{0, 1, 2, 3\}\).
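Why is this convolution equivariant? A quick check, substituting \(p = gp'\) in the sum: \[[\psi \star (gx)](t, r) = \sum_{p} \psi((t, r)^{-1}p)\,x(g^{-1}p) = \sum_{p'} \psi((t, r)^{-1}gp')\,x(p') = [\psi \star x](g^{-1}(t, r)),\]so rotating the input image by \(g\) corresponds to acting with \(g\) on the output function over the group, which is exactly the transformation that rotate_p4 below implements.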
At the end of the day, we want a network
\[X \to H_1 \to H_2 \to ... \to Y\]where \(X\) is the grayscale (or 3-channel) image space, the \(H_i\)'s are the hidden group spaces, and \(Y\) can be a pooled invariant output or an equivariant output space.
def rotate_p4(y, r):
    # y is (..., 4, h, w)
    # r = 0, 1, 2, 3
    assert len(y.shape) >= 3
    assert y.shape[-3] == 4
    assert r in [0, 1, 2, 3]
    ry = y.clone()
    # rotate each of the 4 channels spatially AND cyclically shift the channel index by r
    for i in range(4): ry[..., (i + r) % 4, :, :] = rotate(y[..., i, :, :], r)
    return ry
See the p4 rotation by \(r = 1\):
# Test the rotation by r = 1
torch.manual_seed(37)
y = torch.randn(1, 1, 4, 33, 33) ** 3
y[0, 0, :, 16, 16] = torch.ones(4)*10
ry = rotate_p4(y, 1)
fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, squeeze=True, figsize=(16, 4))
for i in range(4):
axes[i].imshow(y[0, 0, i].numpy())
fig.suptitle('Original y')
plt.show()
fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, squeeze=True, figsize=(16, 4))
for i in range(4):
axes[i].imshow(ry[0, 0, i].numpy())
fig.suptitle('Rotated y')
plt.show()


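As with rotate, we can verify that rotate_p4 composes like the group (a small sanity check, reusing the y from above):
# rotate_p4 respects the C4 group structure
assert torch.equal(rotate_p4(rotate_p4(y, 1), 1), rotate_p4(y, 2))  # composition
assert torch.equal(rotate_p4(y, 0), y)                              # identity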
Next, we will build a lifting convolution. The input is a grayscale image \(x \in X\) and the output is a function \(y \in Y\). This can be realized with the usual convolution by using 4 rotated copies of a SINGLE learnable filter: the image is convolved with each copy independently by stacking the 4 copies into a single filter with 4 output channels.
Finally, a convolutional layer usually includes a bias term. In a normal convolutional network, it is common to share the same bias over all pixels, i.e. the same bias is summed to the features at each pixel. Similarly, when we use a lifting convolution, we share the bias over all pixels but also over all rotations, i.e. over the output channels.
# Lifting conv of the C4 group
class LiftingConv2d(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0, bias: bool = True):
        super(LiftingConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = 1
        self.dilation = 1
        # learnable weights for `out_channels x in_channels` filters, each of shape `kernel_size x kernel_size`;
        # we later populate the larger C4 filter of shape `out_channels x 4 x in_channels x kernel_size x kernel_size` by rotating 4 times
        self.weight = torch.nn.Parameter(torch.empty(self.out_channels,
                                                     self.in_channels,
                                                     self.kernel_size,
                                                     self.kernel_size).normal_(mean=0,
                                                                               std=1/np.sqrt(self.in_channels * self.out_channels)),
                                         requires_grad=True)
        self.bias = None
        if bias:
            self.bias = torch.nn.Parameter(torch.randn(self.out_channels), requires_grad=True)

    def _build_filter(self):
        # from the learnable parameters, build
        # - the `out_channels x 4 x in_channels x kernel_size x kernel_size` filter
        # - the `out_channels x 4` bias (shared over the 4 rotations)
        _filter = torch.zeros(self.out_channels, 4, self.in_channels, self.kernel_size, self.kernel_size)
        for i in range(4): _filter[:, i] = rotate(self.weight, i)
        if self.bias is not None:
            _bias = einops.repeat(self.bias, 'c -> c g', g=4)
        else:
            _bias = torch.zeros(self.out_channels, 4)
        return _filter, _bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _filter, _bias = self._build_filter()
        assert _filter.shape == (self.out_channels, 4, self.in_channels, self.kernel_size, self.kernel_size), f'lifting _filter has shape {_filter.shape}'
        assert _bias.shape == (self.out_channels, 4), f'lifting _bias has shape {_bias.shape}'
        # fold the 4 rotations into the output-channel dimension for a standard conv2d call
        _filter = einops.rearrange(_filter, 'o g i w h -> (o g) i w h')
        _bias = einops.rearrange(_bias, 'c g -> (c g)')
        out = torch.conv2d(x, _filter, stride=self.stride, padding=self.padding, dilation=self.dilation, bias=_bias)
        return einops.rearrange(out, 'b (o g) w h -> b o g w h', g=4)
Since this layer maps from the input space \(X\) to the group space \(Y\), we define a corresponding equivariance checker: the input is rotated in \(X\) and the output is rotated in \(Y\).
def check_equivariance_XY(layer, x, tol=TOL):
    layer.eval()
    gx = rotate(x, 1)  # rotate in X
    with torch.no_grad():
        psi_x = layer(x)
        psi_gx = layer(gx)
    g_psi_x = rotate_p4(psi_x, 1)  # rotate in Y
    assert not torch.allclose(psi_x, torch.zeros_like(psi_x), atol=tol, rtol=tol)
    print('Equivariant' if torch.allclose(psi_gx, g_psi_x, atol=tol, rtol=tol) else 'Not equivariant!!')
    return psi_gx, g_psi_x
Let's check if the LiftingConv2d is equivariant.
# Let's check if the layer is really equivariant
in_channels = 5
out_channels = 10
kernel_size = 3
batchsize = 6
S = 33
layer = LiftingConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=1, bias=True)
# layer.to(DEVICE)
layer.eval()
x = torch.randn(batchsize, in_channels, S, S)
psi_gx, g_psi_x = check_equivariance_XY(layer, x)
Equivariant
The lifting convolution is only the first piece of the \(C_4\) equivariant neural network. Remember we are after
\[X \to H_1 \to H_2 \to ... \to Y\]and the lifting convolution is only the first "\(\to\)"; we still need to build equivariant layers \(H_i \to H_j\). Compare this with the usual CNN: \(X \to X \to X \to ... \to Y\).
We will construct a convolution acting on the OUTPUT of the lifting convolution. Again we use 4 rotated copies of the filters, but each copy must now also cyclically shift the 4 rotation channels coming out of the lifting convolution, i.e. the filters are rotated with rotate_p4. The output of this group convolution again has 4 rotation channels, so we can stack group convolutions to build a deep G-CNN.
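In formulas (a sketch, following the same convention as the lifting convolution): for a feature \(y: p4 \to \mathbb{R}\) and a filter \(\psi\) defined on the group, \[[\psi \star y](t, r) = \sum_{s \in C_4}\sum_{p} \psi(r^{-1}(p - t), r^{-1}s)\,y(p, s),\]so each rotated copy of the filter is rotated spatially by \(r\) AND cyclically shifted along the group axis by \(r\), which is exactly what rotate_p4 does to the weights below.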
# GroupConv2d
class GroupConv2d(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: int = 0, bias: bool = True):
        super(GroupConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = 1
        self.dilation = 1
        # one learnable p4 filter of shape `4 x kernel_size x kernel_size` per (output, input) channel pair
        self.weight = torch.nn.Parameter(torch.empty(self.out_channels,
                                                     self.in_channels,
                                                     4,
                                                     self.kernel_size,
                                                     self.kernel_size).normal_(mean=0,
                                                                               std=1/np.sqrt(self.in_channels * self.out_channels)),
                                         requires_grad=True)
        self.bias = None
        if bias:
            self.bias = torch.nn.Parameter(torch.randn(self.out_channels), requires_grad=True)

    def _build_filter(self):
        # from the learnable parameters, build
        # - the `out_channels x 4 x in_channels x 4 x kernel_size x kernel_size` filter
        # - the `out_channels x 4` bias (shared over the 4 rotations)
        _filter = torch.zeros(self.out_channels, 4, self.in_channels, 4, self.kernel_size, self.kernel_size)
        for i in range(4): _filter[:, i] = rotate_p4(self.weight, i)
        if self.bias is not None:
            _bias = einops.repeat(self.bias, 'c -> c g', g=4)
        else:
            _bias = torch.zeros(self.out_channels, 4)
        return _filter, _bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _filter, _bias = self._build_filter()
        assert _filter.shape == (self.out_channels, 4, self.in_channels, 4, self.kernel_size, self.kernel_size), f'groupconv _filter has shape {_filter.shape}'
        assert _bias.shape == (self.out_channels, 4), f'groupconv _bias has shape {_bias.shape}'
        # fold the rotations into the channel dimensions for a standard conv2d call
        _filter = einops.rearrange(_filter, 'o g i s w h -> (o g) (i s) w h')
        _bias = einops.rearrange(_bias, 'c g -> (c g)')
        # x is `batch_size x in_channels x 4 x W x H`
        x = einops.rearrange(x, 'b i s w h -> b (i s) w h')
        out = torch.conv2d(x, _filter, stride=self.stride, padding=self.padding, dilation=self.dilation, bias=_bias)
        return einops.rearrange(out, 'b (o g) w h -> b o g w h', g=4)
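A quick way to see the weight tying (a small sketch with hypothetical sizes): the layer stores out_channels x in_channels x 4 x k x k parameters, while the expanded filter passed to conv2d has (out_channels*4) x (in_channels*4) x k x k entries, i.e. 4x more, because each parameter is reused by the 4 rotated copies.
# weight tying in GroupConv2d: stored parameters vs the expanded filter
gconv = GroupConv2d(in_channels=5, out_channels=10, kernel_size=3, padding=1, bias=False)
f, _ = gconv._build_filter()
print(gconv.weight.numel(), f.numel())  # 1800 vs 7200: each weight appears in 4 rotated copies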
Now, GroupConv2d maps from the group space \(Y\) to the same group space \(Y\). We define another equivariance checker.
def check_equivariance_YY(layer, x, tol=TOL):
    layer.eval()
    gx = rotate_p4(x, 1)  # rotate in Y
    with torch.no_grad():
        psi_x = layer(x)
        psi_gx = layer(gx)
    g_psi_x = rotate_p4(psi_x, 1)  # rotate in Y
    assert not torch.allclose(psi_x, torch.zeros_like(psi_x), atol=tol, rtol=tol)
    print('Equivariant' if torch.allclose(psi_gx, g_psi_x, atol=tol, rtol=tol) else 'Not equivariant!!')
    return psi_gx, g_psi_x
And check the equivariance.
# Let's check if the layer is really equivariant
torch.manual_seed(42)
in_channels = 5
out_channels = 10
kernel_size = 3
batchsize = 4
S = 33
layer = GroupConv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=1, bias=True)
layer.eval()
x = torch.randn(batchsize, in_channels, 4, S, S)**2
psi_gx, g_psi_x = check_equivariance_YY(layer, x)
Equivariant
3. Implement A Deep Rotation Equivariant CNN
Finally, we can combine the layers implemented above to build a rotation equivariant CNN. The model takes as input batches of \(33 \times 33\) images with a single input channel.
The network performs a first lifting convolution with 8 output channels, followed by 4 group convolutions with, respectively, 16, 32, 64 and 128 output channels. All convolutions have kernel size 3, padding 1 and stride 1, and use a bias. Each convolution is followed by a torch.nn.ReLU, and all but the last by a torch.nn.MaxPool3d. Note that we use MaxPool3d rather than MaxPool2d since our feature tensors have 5 dimensions (there is an additional dimension of size 4). The pooling layers use a kernel of size (1, 3, 3), a stride of (1, 2, 2) and a padding of (0, 1, 1). This ensures pooling is done only over the spatial dimensions, while the rotational dimension is preserved. Instead of a final max-pool over the rotational dimension, the implementation below averages over all four rotations (the _c4_pool method), which makes the output invariant.
Finally, the features extracted by the convolutional encoder are fed to a small fully connected head to classify the input into 10 classes.
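To see where the first linear layer's input size of 3200 comes from, trace the spatial size (a quick sketch): each pooling layer maps \(33 \to 17 \to 9 \to 5\), so the final feature tensor is \(128 \times 4 \times 5 \times 5\); after averaging over the rotational dimension and flattening, we get \(128 \times 5 \times 5 = 3200\) features.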
# Equivariant Networks
# Lifting layer:    X -> Y over p4
# GroupConv layers: Y -> Y over p4
# So the idea is to have a lifting layer followed by group conv layers and nonlinearities.
# To make the model invariant, make sure to use the C4 pooling at the end.
class C4CNN(torch.nn.Module):
    def __init__(self, n_classes=10):
        super(C4CNN, self).__init__()
        channels = [8, 16, 32, 64, 128]
        self.n_classes = n_classes
        self.liftingconv = LiftingConv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1, bias=True)
        layers = []
        for i in range(4):
            layers.append(GroupConv2d(in_channels=channels[i], out_channels=channels[i + 1], kernel_size=3, padding=1, bias=True))
            layers.append(torch.nn.ReLU())
            if i != 3: layers.append(torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)))
        self.groupconv_encoder = torch.nn.Sequential(*layers)
        self.flatten = torch.nn.Flatten(start_dim=1)
        self.head = torch.nn.Sequential(torch.nn.Linear(3200, 256),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(256, 32),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(32, n_classes))

    def _c4_pool(self, x: torch.Tensor) -> torch.Tensor:
        assert len(x.shape) == 5  # batch, channel, group, height, width
        assert x.shape[2] == 4
        # average over all 4 spatial rotations and all 4 group channels;
        # the result is invariant to C4 rotations of the input
        x_pre_pool = torch.concat([rotate(x, i) for i in range(4)], dim=2)
        x_pool = x_pre_pool.mean(2)
        return x_pool

    def forward(self, x: torch.Tensor):
        lifted_x = self.liftingconv(x)
        gp_encoded_x = self.groupconv_encoder(lifted_x)  # batch x channel x group x height x width
        # before pooling the features are equivariant; after pooling the output is invariant
        pooled_x = self._c4_pool(gp_encoded_x)  # batch x channel x height x width
        flattened_x = self.flatten(pooled_x)  # batch x 3200
        return self.head(flattened_x)
So C4CNN is invariant to \(C_4\) rotations, meaning it can recognize an image even when the image is rotated by a multiple of \(\pi/2\). Though the output is invariant, the hidden features are equivariant. In other words:
Rotated image -> rotated features -> rotated features -> ... -> invariant output
Keeping the hidden features equivariant makes the model more powerful and data efficient, because the model is already restricted to respect the symmetry.
Let's check if the C4CNN is invariant.
net = C4CNN()
torch.manual_seed(37)
x = torch.randn(5, 1, 33, 33)
y = net(x)
assert y.shape == (5, 10) # batch x n_classes
# Let's check if the model is invariant!
gx = rotate(x, 1)
gy = net(gx)
assert torch.allclose(y, gy, atol=TOL, rtol=TOL)
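We can go one step further and check invariance under all four rotations (a quick sketch reusing net and x from above):
# the output should be unchanged for every element of C4
for r in range(4):
    assert torch.allclose(y, net(rotate(x, r)), atol=TOL, rtol=TOL)
print('Invariant to all C4 rotations')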
4. Compare CNN to GCNN on rotated MNIST dataset
After building the C4CNN as a \(C_4\)-invariant model with \(C_4\)-equivariant features, we want to compare it with a typical CNN on the rotated MNIST dataset.
# dataset
# https://zenodo.org/record/3670627/files/mnist_rotation_new.zip?download=1
class MnistRotDataset(Dataset):
    def __init__(self, mode, transform=None):
        assert mode in ['train', 'test']
        if mode == 'train': file = './mnist_rotation_new/mnist_all_rotation_normalized_float_train_valid.amat'
        else: file = './mnist_rotation_new/mnist_all_rotation_normalized_float_test.amat'
        self.transform = transform
        data = np.loadtxt(file, delimiter=' ')
        self.labels = data[:, -1].astype(np.int64)
        self.num_samples = len(self.labels)
        self.images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)
        # images in MNIST are only 28x28
        # we pad them to have shape 33x33
        self.images = np.pad(self.images, pad_width=((0, 0), (2, 3), (2, 3)), mode='edge')
        assert self.images.shape == (self.labels.shape[0], 33, 33)

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.labels)
train_set = MnistRotDataset('train', ToTensor())
test_set = MnistRotDataset('test', ToTensor())
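As a quick smoke test (assuming the zip linked above has been downloaded and extracted into ./mnist_rotation_new):
# peek at one sample
img, label = train_set[0]
print(img.shape, label)  # expected: torch.Size([1, 33, 33]) and an integer label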
We define functions to train and test the models: a typical PyTorch training loop, and an evaluation pass without gradients.
def train_model(model: torch.nn.Module, nepoch: int=30):
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-5)
    model.train()
    for epoch in tqdm(range(nepoch)):
        for i, (x, t) in enumerate(train_loader):
            # x, t = x.to(DEVICE), t.to(DEVICE)
            y = model(x)
            loss = loss_function(y, t)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    return model

def test_model(model: torch.nn.Module):
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)
    total = 0
    correct = 0
    with torch.no_grad():
        model.eval()
        for i, (x, t) in tqdm(enumerate(test_loader)):
            x = x.to(DEVICE)
            t = t.to(DEVICE)
            y = model(x)
            _, prediction = torch.max(y.data, 1)
            total += t.shape[0]
            correct += (prediction == t).sum().item()
    accuracy = correct / total * 100.
    return accuracy
Next, we define a CNN model similar to C4CNN by recycling the code.
# Build a normal CNN
class CNN(torch.nn.Module):
    def __init__(self, n_classes=10):
        super(CNN, self).__init__()
        channels = [8, 16, 32, 64, 128]
        self.first_conv = torch.nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
        layers = []
        for i in range(4):
            layers.append(torch.nn.Conv2d(in_channels=channels[i], out_channels=channels[i + 1], kernel_size=3))
            layers.append(torch.nn.ReLU())
            if i != 3: layers.append(torch.nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)))
        self.convs = torch.nn.Sequential(*layers)
        self.flatten = torch.nn.Flatten(start_dim=1)
        self.head = torch.nn.Sequential(torch.nn.Linear(128, 32),
                                        torch.nn.ReLU(),
                                        torch.nn.Linear(32, n_classes))

    def forward(self, x):
        x = self.first_conv(x)
        # x = torch.nn.functional.layer_norm(x, x.shape[-3:])
        x = torch.nn.functional.relu(x)
        x = self.convs(x)
        x = self.flatten(x)
        # apply average pooling over remaining spatial dimensions
        # x = torch.nn.functional.adaptive_avg_pool2d(x, 1).squeeze()
        x = self.head(x)
        return x
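As a quick negative control (a small sketch using a hypothetical untrained cnn_net), the plain CNN is not rotation invariant:
torch.manual_seed(37)
cnn_net = CNN()
x = torch.randn(5, 1, 33, 33)
y, gy = cnn_net(x), cnn_net(rotate(x, 1))
print('Invariant' if torch.allclose(y, gy, atol=TOL, rtol=TOL) else 'Not invariant (as expected)')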
Let’s finally get the models trained and report the accuracies.
# training and keep the stats
print('Training C4CNN')
c4cnn = train_model(C4CNN(), nepoch=50)
print('Training CNN')
cnn = train_model(CNN(), nepoch=50)
acc_c4cnn = test_model(c4cnn)
acc_cnn = test_model(cnn)
print(f'C4CNN Test Accuracy: {acc_c4cnn :.3f}')
print(f'CNN Test Accuracy: {acc_cnn :.3f}')
C4CNN Test Accuracy: 91.744
CNN Test Accuracy: 82.134
5. Final Note
The performance of C4CNN is significantly higher. I also did 25 repeated runs: C4CNN and CNN averaged 92% and 81% accuracy, respectively. However, C4CNN took about 5x longer to train. Considering that a plain CNN would need (maybe) 4x the data via rotation augmentation, and that approximately equivariant filters also tend to emerge naturally during training, the trade-off might not be beneficial in this case.
However, built-in equivariance makes a big difference for continuous groups, where one cannot simply rotate the filters to achieve equivariance.