Recently, I have been working with equivariant and invariant models. These models leverage the inductive biases in the data: convolutional layers are translation equivariant, and the message aggregation function in graph neural networks is permutation invariant. Physical data typically lives in 3D, so the models should respect the Euclidean symmetries, E(3). If a model does not respect E(3) symmetry, we usually have to augment the data with random rotations and translations, which amounts to an infinite amount of augmented data and can make the model hard to optimize.
Let \(f: V \to V\) be an equivariant function with respect to a group \(G\), let \(g\) be a group element of \(G\), and let \(x\) be data in some vector space \(V\). Then
\[f(g \circ x) = g \circ f(x)\]
I am going to explore a general framework, e3nn, which takes care of the equivariant operations using irreducible representations, or irreps. There are 3 cases/tasks I want to explore in this post:
Tetris shape prediction: Given 3D coordinates of the voxels of a tetris block, predict the shape. The coordinates might be randomly rotated or translated and the model needs to generalize.
Trajectory prediction: Given a 12-body system and its coordinates at a certain time, predict the 12 coordinates at the next time step. The system and its physical motion should be independent of the reference frame; therefore, if we rotate or translate the coordinates, the predicted motion should transform equivariantly.
Electron density prediction: The electron density is usually represented by a linear combination of spherical harmonics, \(Y^l_m\). The goal is to predict the coefficients, which fits nicely with the irreps in e3nn.
1. Tetris
Given 3D coordinates of the voxels of a tetris block, predict the shape. The coordinates might be randomly rotated or translated and the model needs to generalize.
1.1 Packages
We will use some basic packages like torch, the geometric library torch_geometric, and its companion libraries torch_cluster and torch_scatter. The main one I am going to use is the e3nn package.
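For reference, the imports used throughout this section look roughly like this (torch_cluster and torch_scatter need to match your torch build):

```python
import torch
from torch_geometric.data import Data, Batch   # graph data containers
from torch_cluster import radius_graph         # builds edges from coordinates
from torch_scatter import scatter              # segment-wise aggregation

from e3nn import o3                            # irreps, spherical harmonics, tensor products
```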
1.2 Tetris toy dataset
Each tetris block contains 4 voxels of size 1x1x1. There are 8 different shapes.
Note that the chiral shapes have parity, i.e. the label of a chiral shape is a pseudo-scalar, which flips sign upon mirror reflection.
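Here is a sketch of how the toy dataset can be built. It closely follows the tetris example in the e3nn documentation, so the coordinates and label encoding below come from that example rather than from my exact script:

```python
def tetris():
    # 8 shapes, each made of 4 unit voxels
    pos = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],   # chiral_shape_1
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],   # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],   # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],   # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],   # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],   # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],   # zigzag
    ]
    pos = torch.tensor(pos, dtype=torch.get_default_dtype())

    # the chiral pair shares one output channel with labels +1 / -1 (a pseudo-scalar);
    # the remaining 6 shapes are one-hot on the other 6 channels
    labels = torch.tensor([
        [+1, 0, 0, 0, 0, 0, 0],  # chiral_shape_1
        [-1, 0, 0, 0, 0, 0, 0],  # chiral_shape_2
        [0, 1, 0, 0, 0, 0, 0],   # square
        [0, 0, 1, 0, 0, 0, 0],   # line
        [0, 0, 0, 1, 0, 0, 0],   # corner
        [0, 0, 0, 0, 1, 0, 0],   # L
        [0, 0, 0, 0, 0, 1, 0],   # T
        [0, 0, 0, 0, 0, 0, 1],   # zigzag
    ], dtype=torch.get_default_dtype())

    # apply a random rotation to each shape so the model has to generalize
    pos = torch.einsum("zij,zaj->zai", o3.rand_matrix(len(pos)), pos)

    # pack the 8 shapes into one torch_geometric batch (gives data.batch / data.ptr)
    return Batch.from_data_list([Data(pos=p, y=l[None]) for p, l in zip(pos, labels)])
```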
Everything should look familiar, except maybe for o3.rand_matrix(n). It generates \(n\) random 3x3 rotation matrices, which are applied to the coordinates.
The data follows the torch_geometric convention, where batch records which graph each position belongs to and ptr stores the start index of each graph.
1.3 GNN Baseline
Let’s build a simple GNN baseline model using message passing with attention: GATConv.
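Here is a minimal sketch of what such a baseline might look like. The class name, the hidden width, and the number of layers are my own choices; the featurization lines (irreps_sh, radius_graph, spherical harmonics) are the ones discussed below.

```python
from torch_geometric.nn import GATConv

class BaselineGNN(torch.nn.Module):
    """Non-equivariant baseline: spherical-harmonic features fed into plain GATConv layers."""

    def __init__(self, hidden=64, num_classes=7):
        super().__init__()
        # basis used only for featurization: 1x0e+1x1o+1x2e+1x3o -> 16 numbers per edge
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)
        dim_sh = self.irreps_sh.dim

        self.conv1 = GATConv(dim_sh, hidden, edge_dim=dim_sh)
        self.conv2 = GATConv(hidden, hidden, edge_dim=dim_sh)
        self.readout = torch.nn.Linear(hidden, num_classes)

    def forward(self, data):
        # fully connected graph within each shape (any r > 3.1 works for 4 voxels)
        edge_index = radius_graph(x=data.pos, r=10.1, batch=data.batch)
        src, dst = edge_index
        edge_vec = data.pos[src] - data.pos[dst]

        # project each edge vector onto Y^l_m for l = 0..3 -> 1 + 3 + 5 + 7 = 16 features
        edge_attr = o3.spherical_harmonics(self.irreps_sh, edge_vec,
                                           normalize=True, normalization="component")

        # node features: mean of the incoming edge features
        x = scatter(edge_attr, dst, dim=0, dim_size=data.pos.shape[0], reduce="mean")

        x = self.conv1(x, edge_index, edge_attr).relu()
        x = self.conv2(x, edge_index, edge_attr).relu()

        # per-graph readout: average the 4 node embeddings of each shape
        x = scatter(x, data.batch, dim=0, reduce="mean")
        return self.readout(x)
```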
Again, this is a very typical GNN with a standard forward pass. However, there is some foreign stuff sneaking into the code that I'll explain. The input data only contains a set of 4 coordinates, (x, y, z): the nodes and edges carry absolutely no features other than the coordinates. I am going to use spherical harmonics to featurize the edges and nodes.
The line self.irreps_sh = o3.Irreps.spherical_harmonics(3) defines the spherical-harmonics basis, containing the following irreps:
1x0e+1x1o+1x2e+1x3o
Ok, let me unpack this notation; a short snippet after the list makes it concrete in code. There are 4 irreps in the above expression. Each irrep has the form mxlp, where
m is the multiplicity, i.e. the number of features
x just stands for times
l is the rotation order: the angular momentum quantum number of electron orbitals, or the degree of the spherical harmonics.
l = 0 transforms as a scalar (one number); the l = 0 orbital is a sphere.
l = 1 transforms as a vector (three numbers); the l = 1 orbitals align with the x, y, z axes, corresponding to the 3 quantum numbers m = -1, 0, 1.
l = 2 transforms as a symmetric traceless rank-2 tensor (five numbers); the l = 2 orbitals are more complicated, corresponding to the 5 quantum numbers m = -2, -1, 0, 1, 2.
p is the parity, even e or odd o, which describes how these tensors transform under mirror reflection.
So 16x1o is a set of 16 vectors with odd parity.
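In code, just to make the notation concrete:

```python
irreps = o3.Irreps("1x0e + 1x1o + 1x2e + 1x3o")
print(irreps.dim)              # 1 + 3 + 5 + 7 = 16 numbers in total
print(o3.Irreps("16x1o").dim)  # 16 odd vectors -> 48 numbers
```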
Another line is radius_graph(x=data.pos, r=10.1, batch=data.batch), which constructs the edge_index just from the coordinates and the batch. I used r = 10.1 (anything r > 3.1 works) to construct the fully connected graph over the 4 nodes. The edges are directed, because each one carries a translation vector between two nodes. So from the above line we get 4 * 3 = 12 edges per graph. Therefore,
torch.Size([2, 96])
Since the input is only the 3D coordinates, the straightforward featurization is the distance vector \(\mathbf{r}_i - \mathbf{r}_j\), where \(\mathbf{r}\) is the position vector. We further use the coefficients of the spherical harmonics as edge features, i.e. we project the distance vector onto the l = 0, 1, 2, 3 spherical harmonics and collect the coefficients, a total of 1 + 3 + 5 + 7 = 16 numbers. Note that \(\langle Y^l_m, Y^{l'}_{m'} \rangle = \delta_{ll'}\delta_{mm'}\), meaning that they form an orthogonal basis.
The node features are the aggregate of the edge features of the incoming edges.
With the node and edge features using spherical harmonics, we can just create a graph convolution and a readout for tetris shape prediction.
Let’s build a general training function that can be used for all the models built for this task.
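Here is a rough sketch of such a function. The loss (MSE against the 7-dimensional labels) and the all-outputs-rounded accuracy are my assumptions; the post apparently also iterated over an optionally shuffled DataLoader, while this full-batch version only shows the shape of the loop:

```python
def train(model, data, epochs=10000, lr=1e-2):
    """Full-batch training on the 8 tetris shapes; returns a log of (epoch, loss, accuracy)."""
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    logs = []
    for epoch in range(epochs):
        optim.zero_grad()
        pred = model(data)
        loss = torch.nn.functional.mse_loss(pred, data.y)
        loss.backward()
        optim.step()

        with torch.no_grad():
            # a prediction counts as correct if every rounded output matches the label
            accuracy = pred.round().eq(data.y).all(dim=1).float().mean().item()
        logs.append((epoch, loss.item(), accuracy))
    return logs
```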
Now we can try to train our GNN baseline model:
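With the hypothetical names from the sketches above, that would be something like:

```python
data = tetris()                        # the randomly rotated toy dataset
gnn_logs = train(BaselineGNN(), data)  # 10000 full-batch epochs by default
```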
Training on a CPU took about 90 seconds for 10000 epochs. I tried training without and with shuffling; the final losses were \(41.9\) and \(40.5\) respectively. Both achieved a final accuracy of \(75\%\), which is not great given how easy the task is, the small amount of data, and the relatively large GATConv-based model.
1.4 E3NN
In the previous GNN model, I used a small part of e3nn for featurization, but the model is still not aware of the E(3) symmetry: it just did something like torch.nn.Linear and ignored how the elements relate and interact under E(3). We also touched on the irreps, which I feel are worth going through again.
Similar to the hidden dimension in the hidden layers of neural networks, we can also define the hidden irreps in e3nn.
The neural network layer (the analogue of a Linear layer) in e3nn is called FullyConnectedTensorProduct and has trainable weights. At initialization, it takes 2 input irreps and 1 output irreps; in the forward pass, it takes 2 tensors with the corresponding input irreps. This layer dictates the interaction between the different irreps.
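For example (the specific irreps below are arbitrary choices for illustration, not the ones used in the model):

```python
from e3nn.o3 import FullyConnectedTensorProduct

irreps_in1 = o3.Irreps("1x0e + 1x1o")
irreps_in2 = o3.Irreps.spherical_harmonics(2)   # 1x0e + 1x1o + 1x2e
irreps_out = o3.Irreps("2x0e + 3x1o")

tp = FullyConnectedTensorProduct(irreps_in1, irreps_in2, irreps_out)

x1 = irreps_in1.randn(10, -1)   # a batch of 10 feature vectors with irreps_in1
x2 = irreps_in2.randn(10, -1)
out = tp(x1, x2)                # shape (10, irreps_out.dim)
```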
We can check if the above tensor product is the same as exhausting all possible paths of the irreps.
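One way to do that check is against o3.FullTensorProduct, which enumerates every output irrep reachable from the two inputs; a sketch:

```python
full = o3.FullTensorProduct(irreps_in1, irreps_in2)
print(full.irreps_out.simplify())   # every irrep reachable from irreps_in1 x irreps_in2
print(tp)                           # lists the weighted paths of the fully connected product
```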
And we can visualize these interaction paths:
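TensorProduct modules ship with a small plotting helper for exactly this (it needs matplotlib installed):

```python
tp.visualize()   # draws each (in1, in2) -> out path of the tensor product
```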
Now we can define our simple equivariant model using e3nn’s FullyConnectedTensorProduct.
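Here is a sketch of what such a model could look like; the hidden irreps and the two-layer structure are my own choices, loosely modeled on the tetris example in the e3nn docs. The output irreps 1x0o + 6x0e mirror the labels: one pseudo-scalar for the chiral pair plus 6 scalars for the other shapes.

```python
class SimpleE3NN(torch.nn.Module):
    """Two FullyConnectedTensorProduct 'layers' acting on spherical-harmonic features."""

    def __init__(self):
        super().__init__()
        self.irreps_sh = o3.Irreps.spherical_harmonics(3)
        # hidden irreps: both parities, so that the pseudo-scalar output stays reachable
        irreps_hidden = o3.Irreps("32x0e + 16x1e + 16x1o + 8x2e + 8x2o")
        # output: 1 pseudo-scalar for the chiral pair + 6 scalars for the other shapes
        irreps_out = o3.Irreps("1x0o + 6x0e")

        self.tp1 = o3.FullyConnectedTensorProduct(self.irreps_sh, self.irreps_sh, irreps_hidden)
        self.tp2 = o3.FullyConnectedTensorProduct(irreps_hidden, self.irreps_sh, irreps_out)

    def forward(self, data):
        num_nodes = data.pos.shape[0]
        edge_index = radius_graph(x=data.pos, r=10.1, batch=data.batch)
        src, dst = edge_index
        edge_vec = data.pos[src] - data.pos[dst]
        edge_sh = o3.spherical_harmonics(self.irreps_sh, edge_vec,
                                         normalize=True, normalization="component")

        # initial node features: mean of the incoming edge features (an equivariant operation)
        node_in = scatter(edge_sh, dst, dim=0, dim_size=num_nodes, reduce="mean")

        # first interaction: combine sender features with edge features, aggregate per node
        node_mid = scatter(self.tp1(node_in[src], edge_sh), dst, dim=0,
                           dim_size=num_nodes, reduce="mean")

        # second interaction, then pool the 4 nodes of each shape into one prediction
        node_out = scatter(self.tp2(node_mid[src], edge_sh), dst, dim=0,
                           dim_size=num_nodes, reduce="mean")
        return scatter(node_out, data.batch, dim=0, reduce="mean")
```

Note that the hidden irreps include both parities; otherwise there would be no allowed path to the pseudo-scalar output and the chiral channel would always be zero.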
We use the train function again.
Clearly, this model trained better thanks to the E(3) equivariance. Without and with shuffling it achieved final losses of \(7.3\) and \(2.0\) respectively, both with \(100\%\) accuracy, which is a lot better than the GNN baseline.
1.5 Fancy E3NN
I have seen some people use the existing models in e3nn instead of writing the FullyConnectedTensorProduct layers by hand. We will use this fancier model from e3nn: gate_points_2101.Network.
We start off by assigning some hyperparameters:
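Roughly like the following. The argument names follow my reading of the e3nn source for gate_points_2101.Network and the values are illustrative rather than the ones behind the results below, so double-check them against the version you have installed:

```python
from e3nn.nn.models.gate_points_2101 import Network

net = Network(
    irreps_in=None,                       # no input node features, only positions
    irreps_hidden=o3.Irreps("32x0e + 16x1e + 16x1o + 8x2e + 8x2o"),
    irreps_out=o3.Irreps("1x0o + 6x0e"),  # 7 outputs: pseudo-scalar + 6 scalars
    irreps_node_attr=o3.Irreps("0e"),     # a single trivial scalar attribute per node
    irreps_edge_attr=o3.Irreps.spherical_harmonics(3),
    layers=3,
    max_radius=3.5,                       # large enough to fully connect each shape
    number_of_basis=10,
    radial_layers=1,
    radial_neurons=100,
    num_neighbors=3.0,                    # average neighbors per node (fully connected 4-node graph)
    num_nodes=4.0,                        # nodes per shape
    reduce_output=False,                  # keep per-node outputs; we pool per graph ourselves
)
```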
The output of the model is a 32 x 7 tensor, i.e. for each coordinate (regardless of which graph it belongs to) the model outputs a 7-dimensional vector. We need to aggregate these 4 vectors per shape according to batch, like before. So we create another wrapper model:
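A small wrapper along these lines (the names are mine):

```python
class FancyE3NN(torch.nn.Module):
    """Wrap the pre-built network and pool its per-node outputs into one prediction per shape."""

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, data):
        node_out = self.net(data)                                    # (32, 7): one 7-vector per voxel
        return scatter(node_out, data.batch, dim=0, reduce="mean")   # (8, 7): one prediction per shape
```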
Again, we use our train function.
And… take a look at those too-good-to-be-true losses and accuracies.
1.6 Comparisons
Let’s compare the 3 models trained: GNN, E3NN Basic, E3NN Fancy. We plot the training losses and accuracies using the logs.
1.7 Check Equivariance
Finally, we need to check the model equivariance, using our very first equation
\[f(g \circ x) = g \circ f(x)\]
where \(g\) here is a 3D rotation or translation and \(f\) is our trained model. The model is actually invariant, i.e. the output does not depend on the group operation applied to the input:
\[f(g \circ x) = f(x)\]
And we can test the equivariance using e3nn’s assert_equivariant function.
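Roughly, the check can look like this. assert_equivariant works on functions of plain tensors, so we wrap the trained model (called model here) into a function from positions to predictions and declare the input as Cartesian points so that e3nn knows to rotate and translate it:

```python
from e3nn.util.test import assert_equivariant

def model_on_positions(pos):
    # wrap the positions into a single-graph batch so the model sees the usual Data fields
    data = Batch.from_data_list([Data(pos=pos)])
    return model(data)

assert_equivariant(
    model_on_positions,
    irreps_in=["cartesian_points"],         # input transforms as a set of 3D points
    irreps_out=[o3.Irreps("1x0o + 6x0e")],  # output: invariant scalars plus one pseudo-scalar
)
```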
And this concludes part 1 of my initial messing around with e3nn.
References
Mario Geiger, Tess Smidt, e3nn: Euclidean Neural Networks, https://arxiv.org/abs/2207.09453