The EPI-EPMP multitasking model was proposed in this paper. The model is based on message-passing neural networks: it takes antibody and antigen structures as input and predicts the paratope on the antibody and the epitope on the antigen. The model architecture is shown below (Fig. 2 of the original paper).

The model uses two separate graph neural networks to process the antibody and antigen structures and obtain node embeddings for each residue. These node embeddings are then used to construct a bipartite graph, similar to classical recommender systems. The antigen and antibody node embeddings are passed around with a graph attention network (GAT) to obtain final embeddings for predicting the binding probability of both paratope and epitope residues.

The benefit of such a model is that it does not assume or sample relative orientations between antibody and antigen, which can be quite computationally expensive. Instead, it performs message passing between antigen and antibody residues and determines the epitope and paratope from their mutual structural context.

I will use PyTorch Geometric (PyG) to build the batched data and the EPI-EPMP model.


1. Data: Graph Pairing

For this model, one data point is one (antibody, antigen) pair, so we have to build paired data that supports minibatching. The tricky part is that the antibody and antigen generally have different numbers of residues, so we need to handle them separately.

Let's assume that one can transform protein structure data (i.e. atomic coordinates, PDB files) into a featurized graph G with node features x and edge index edge_index. The Graphein package is a handy tool for this purpose.
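To keep things self-contained, here is a minimal featurization sketch (a stand-in, not the Graphein API): residues are connected whenever their C-alpha atoms fall within a distance cutoff, and one-hot residue types serve as node features. The featurize helper and its inputs are hypothetical placeholders.

import torch
from torch_geometric.data import Data
from torch_geometric.nn import radius_graph  # requires torch-cluster

# hypothetical featurizer: coords is an (n_residues, 3) tensor of C-alpha
# coordinates; residue_types is an (n_residues,) tensor of integer type codes
def featurize(coords, residue_types, cutoff=10.0, num_residue_types=20):
    x = torch.nn.functional.one_hot(residue_types, num_residue_types).float()
    edge_index = radius_graph(coords, r=cutoff)  # connect residues within the cutoff
    return Data(x=x, edge_index=edge_index)

The following AbAgPair data class builds the paired graph: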

from torch_geometric.data import Data

class AbAgPair(Data): 
    
    # antigen and antibody default to None because PyG's batching
    # may instantiate this class without arguments when collating
    def __init__(self, antigen=None, antibody=None): 
        super().__init__()
        
        # antigen: add prefix 'ag' to all fields
        ag_attrs = ['edge_index', 'x']
        if antigen is not None: 
            for ag_attr in ag_attrs: 
                setattr(self, f'ag_{ag_attr}', getattr(antigen, ag_attr))
        
        # antibody: add prefix 'ab' to all fields
        ab_attrs = ['edge_index', 'x']
        if antibody is not None: 
            for ab_attr in ab_attrs: 
                setattr(self, f'ab_{ab_attr}', getattr(antibody, ab_attr))
    
    # incremental operation: for minibatching
    # the index incremental follows # of nodes of antibody or antigen
    def __inc__(self, key, value, *args, **kwargs):
        if key == 'ag_edge_index': return self.ag_x.shape[0]
        elif key == 'ab_edge_index': return self.ab_x.shape[0]
        else: return super().__inc__(key, value, *args, **kwargs)

Now, given a list of AbAgPair objects, we can build data loaders of antigen-antibody pairs.

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

datalist = []
for pairname in tqdm(pairnames[:n], desc='Loading Ab-Ag Pairs'): 
    g_ag, g_ab = ..., ... # the featurized graphs for this pair
    paired_data = AbAgPair(g_ag, g_ab)
    datalist.append(paired_data)

# train and test splitting
train_data, test_data = train_test_split(datalist, test_size=0.2, random_state=random_state)

# dataloaders: 
# we need to keep track of the batch index for antibody and antigen separately;
# the batched data will contain the additional fields 'ag_x_batch' and 'ab_x_batch'
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, follow_batch=['ag_x', 'ab_x'])
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, follow_batch=['ag_x', 'ab_x'])
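To see what the pairing produces, one can inspect a single minibatch; follow_batch adds the ag_x_batch and ab_x_batch index vectors that map each node back to its pair (a quick sketch):

# the paired fields are batched independently
batch = next(iter(train_dataloader))
print(batch.ag_x.shape, batch.ab_x.shape)                    # stacked node features
print(batch.ag_x_batch.unique(), batch.ab_x_batch.unique())  # pair indices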
    

2. Model: EPI-EPMP

Now that the data and loaders are constructed, we can build the model following the architecture in the figure above.

2.1 Model Architecture

The model needs to process the ab_ and ag_ information through separate pathways.

import torch

from torch_geometric.nn import GATConv, GCNConv, Linear
from torch.nn import Sequential, Tanh, BatchNorm1d

# EPI-EPMP Multitasking Model
class EpiEPMP(torch.nn.Module): 
    
    def __init__(self, node_attr_dim, edge_attr_dim, hidden_dim, h1_dim, h2_dim, share_weight=False, dropout=0.2, heads=4): 
        super(EpiEPMP, self).__init__()
        
        self.node_attr_dim = node_attr_dim
        self.edge_attr_dim = edge_attr_dim
        self.hidden_dim = hidden_dim
        self.h1_dim = h1_dim # dim after the first gnn
        self.h2_dim = h2_dim # dim after the bipartite message passing
        self.share_weight = share_weight # Should model weights be shared between Ag and Ab?
        self.dropout = dropout
        self.heads = heads
        
        # two-layer antigen GCN (effectively a 2-hop receptive field)
        self.ag_gnn1 = GCNConv(self.node_attr_dim, self.hidden_dim)
        self.ag_gnn2 = GCNConv(self.hidden_dim, self.h1_dim)
        
        if self.share_weight: 
            # if weights are shared, the antibody GCN layers are identical to the antigen GCN layers
            self.ab_gnn1 = self.ag_gnn1
            self.ab_gnn2 = self.ag_gnn2
        else: 
            # or the antibody uses new GCNs
            self.ab_gnn1 = GCNConv(self.node_attr_dim, self.hidden_dim)
            self.ab_gnn2 = GCNConv(self.hidden_dim, self.h1_dim)
        
        self.ag_bnorm1 = BatchNorm1d(self.hidden_dim)
        self.ag_bnorm2 = BatchNorm1d(self.h1_dim)
        self.ab_bnorm1 = BatchNorm1d(self.hidden_dim)
        self.ab_bnorm2 = BatchNorm1d(self.h1_dim)
        
        # The GATs for the bipartite graph
        self.bp_gnn1 = GATConv(self.h1_dim, self.hidden_dim, heads=self.heads, concat=True, dropout=self.dropout)
        self.bp_gnn2 = GATConv(self.hidden_dim * self.heads, self.h2_dim, heads=self.heads, concat=False, dropout=self.dropout)
        
        # antigen prediction head
        self.ag_classifier = Sequential(Linear(self.h1_dim + self.h2_dim, self.hidden_dim), 
                                        Tanh(),
                                        Linear(self.hidden_dim, 1))
        
        # the classifier can likewise be shared or separate,
        # just like the GCNs
        if self.share_weight: 
            self.ab_classifier = self.ag_classifier
        else: 
            self.ab_classifier = Sequential(Linear(self.h1_dim + self.h2_dim, self.hidden_dim), 
                                            Tanh(),
                                            Linear(self.hidden_dim, 1))
    
    
    
    def forward(self, ag_x, ag_edge_index, ag_x_batch, \
                      ab_x, ab_edge_index, ab_x_batch):
        
        # antigen gnn + batchnorm
        ag_h1 = self.ag_bnorm1(torch.tanh(self.ag_gnn1(ag_x, ag_edge_index)))
        ag_h1 = self.ag_bnorm2(torch.tanh(self.ag_gnn2(ag_h1, ag_edge_index)))
        
        # antibody gnn + batchnorm
        ab_h1 = self.ab_bnorm1(torch.tanh(self.ab_gnn1(ab_x, ab_edge_index)))
        ab_h1 = self.ab_bnorm2(torch.tanh(self.ab_gnn2(ab_h1, ab_edge_index)))
        
        # bipartite construction
        x, edge_index, ag_index, ab_index = bipartite(ag_h1, ag_x_batch, ab_h1, ab_x_batch)
        h2 = torch.tanh(self.bp_gnn1(x, edge_index))
        h2 = torch.tanh(self.bp_gnn2(h2, edge_index))
        ag_h2, ab_h2 = h2[ag_index], h2[ab_index]
        
        # concatenate the skip connection and apply the classifier heads
        ag_out = self.ag_classifier(torch.cat([ag_h1, ag_h2], dim=1))
        ab_out = self.ab_classifier(torch.cat([ab_h1, ab_h2], dim=1))
        
        return ag_out, ag_h1, ag_h2, ab_out, ab_h1, ab_h2
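For reference, a hypothetical instantiation and forward pass over one minibatch might look like this (the dimensions are arbitrary placeholders, not values from the paper):

# dimensions here are illustrative; match node_attr_dim to your node features
model = EpiEPMP(node_attr_dim=62, edge_attr_dim=0, hidden_dim=64, h1_dim=64, h2_dim=64)

batch = next(iter(train_dataloader))
ag_out, ag_h1, ag_h2, ab_out, ab_h1, ab_h2 = model(
    batch.ag_x, batch.ag_edge_index, batch.ag_x_batch,
    batch.ab_x, batch.ab_edge_index, batch.ab_x_batch)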

2.2 The Bipartite Graph

The bipartite graph is constructed in such a way that every antigen node has edges to all antibody nodes and vice versa. The antigen/antibody nodes themselves don't need to be connected within the bipartite graph because such locality is already captured in the original protein graphs; here we only want the model to pass messages between antigen and antibody.

The bipartite graphs are built from the minibatched paired data, i.e. batch_size bipartite graphs are built for each minibatch, so ag_x_batch and ab_x_batch are used to keep track of which nodes belong to which pair.

import torch
from torch_geometric.utils import add_self_loops


# create the bipartite graph from the batched node embeddings,
# keeping all tensors on the same device
def bipartite(ag_x, ag_x_batch, ab_x, ab_x_batch): 
    
    assert ag_x.device == ag_x_batch.device == ab_x.device == ab_x_batch.device
    device = ag_x.device
    
    # node attributes and edge index (edge_index must have shape (2, num_edges))
    x = torch.empty((0, ag_x.shape[1]), dtype=torch.float, device=device)
    edge_index = torch.empty((2, 0), dtype=torch.long, device=device)

    # antigen and antibody indices for later retrieval from the stacked matrix
    ag_index = torch.empty(0, dtype=torch.long, device=device)
    ab_index = torch.empty(0, dtype=torch.long, device=device)
    
    # unique batch indices for the antigens and antibodies
    ag_batch_unique, ab_batch_unique = ag_x_batch.unique(), ab_x_batch.unique()
    assert ab_batch_unique.shape[0] == ag_batch_unique.shape[0]
    
    # create one bipartite graph per batch index
    for b in ag_batch_unique: 
        idx_ag, idx_ab = torch.eq(ag_x_batch, b), torch.eq(ab_x_batch, b)
        n_ag, n_ab = idx_ag.sum().item(), idx_ab.sum().item()
        
        # node offset of this pair within the stacked graph
        offset = x.shape[0]
        
        # antigen node indices
        ag = torch.arange(offset, offset + n_ag, dtype=torch.long, device=device)
        ag_index = torch.cat([ag_index, ag])
        # antibody node indices
        ab = torch.arange(offset + n_ag, offset + n_ag + n_ab, dtype=torch.long, device=device)
        ab_index = torch.cat([ab_index, ab])
        
        # dense antigen -> antibody edges plus the reverse direction,
        # so that messages flow both ways
        src, dst = ag.repeat_interleave(n_ab), ab.repeat(n_ag)
        pair_edges = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])], dim=0)
        edge_index = torch.cat([edge_index, pair_edges], dim=1)
        
        # stack the node features
        x = torch.cat([x, ag_x[idx_ag], ab_x[idx_ab]], dim=0)
    
    # add self-loops once, after all pairs are stacked
    edge_index = add_self_loops(edge_index, num_nodes=x.shape[0])[0]
    
    return x, edge_index, ag_index, ab_index
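A quick sanity check of the sketch above with dummy features: a single pair with 3 antigen and 2 antibody nodes should yield 2 x 3 x 2 directed cross edges plus 5 self-loops.

ag_x, ab_x = torch.randn(3, 8), torch.randn(2, 8)
ag_b, ab_b = torch.zeros(3, dtype=torch.long), torch.zeros(2, dtype=torch.long)

x, ei, ag_ix, ab_ix = bipartite(ag_x, ag_b, ab_x, ab_b)
assert x.shape == (5, 8) and ei.shape[1] == 2 * 3 * 2 + 5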

3. In-Batch Sampling

Paratope and epitope residues generally make up less than 10% of the residues (nodes) in a protein (graph), so the labels are heavily imbalanced. One can scale the loss function to account for this; I chose balanced sampling instead, while keeping equal weighting of the paratope and epitope prediction losses. The in-batch sampling function is as follows.

import torch
import numpy as np

# sample an equal number of negative indices to match the positives
def sample_balanced_indices(y): 
    pos_ix = torch.nonzero(y).view(-1).cpu()
    n_pos = pos_ix.shape[0]
    neg_indices = torch.nonzero(y != 1).view(-1)
    neg_ix = torch.tensor(np.random.choice(neg_indices.cpu().numpy(), min(n_pos, neg_indices.shape[0]), replace=False))
    return torch.cat((pos_ix, neg_ix))

# For each batch index, sample a balanced set of positive and
# negative (non-epitope/non-paratope) nodes
def in_batch_sampling(y, batch): 

    ubatch = batch.unique()
    idx = torch.tensor([], dtype=torch.long)
    sampled_batch = torch.tensor([], dtype=torch.long)
    offset = 0 # running node count: the offset of each graph within the batch
    for b in ubatch: 
        yb = y[batch == b]

        ix = sample_balanced_indices(yb) + offset

        idx = torch.cat([idx, ix])
        # record which pair each sampled node came from
        sampled_batch = torch.cat([sampled_batch, torch.full((ix.shape[0],), int(b), dtype=torch.long)])
        offset += yb.shape[0]
 
    return idx, sampled_batch 
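As a usage sketch, the sampled indices mask the node-level losses; ag_y and ab_y below are assumed per-node binary label fields on the batch (my naming, not from the paper), and ag_out/ab_out come from the model's forward pass:

criterion = torch.nn.BCEWithLogitsLoss()

# balanced subsets per task (move the indices to the labels' device if on GPU)
ag_idx, _ = in_batch_sampling(batch.ag_y, batch.ag_x_batch)
ab_idx, _ = in_batch_sampling(batch.ab_y, batch.ab_x_batch)

# equal weighting of the epitope and paratope losses
loss = criterion(ag_out.view(-1)[ag_idx], batch.ag_y[ag_idx].float()) + \
       criterion(ab_out.view(-1)[ab_idx], batch.ab_y[ab_idx].float())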

4. Observations and Discussions

4.1 From the Paper

The authors demonstrated a convincing validation case with the SARS-CoV-2 spike protein and a corresponding antibody that targets it.

4.2 Performances

I experimented with self-generated structural features, i.e. ag_x and ab_x, and achieved performance similar to that reported in the paper.

For antibody paratopes: AUROC = 0.800 (0.966), AUPRC = 0.551 (0.752), where the bracketed values are those reported in the paper. The discrepancy arises because the authors fed only the CDR regions into the model, whereas I used the whole antibody, including the framework regions. This makes the paratope task more challenging: the labels are much more imbalanced, and noise or confusion can arise from loop regions beyond the CDR loops. In general, antibody CDRs can be determined purely from sequence, so it is no surprise that paratopes are easier to predict (they sit in the middle of the H3 and L3 CDR loops). The left figure below shows the predicted paratopes across different antibody Fabs.

For antigen epitopes: AUROC = 0.798 (0.710), AUPRC = 0.319 (0.277). The middle figure shows the ground truth and the right figure the model prediction for one sample in the test set, with the antigen and antibody shown in blue and green and the predicted probabilities overlaid on the surfaces using PyMOL.

It is usually interesting to look under the hood at what the model learned. Here are the PCA plots of the antigen and antibody node embeddings (ag_h1 and ab_h1), colored by node degree. There is no obvious correlation between the embeddings and node degree: the model focuses on structural context, i.e. node features, while still taking the local neighborhood into account.

4.3 Efficiency

Generally, training took about 48 GPU-hours for a few thousand data points and epochs. Inference took ~2 s. The computationally expensive part was the featurization.

A minibatch might bag LARGE antigens with >500 residues, which can consume a lot of memory or even trigger RuntimeError: CUDA error: out of memory.
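One simple mitigation, assuming dropping the largest complexes is acceptable, is to cap the antigen size when building the data list:

# sketch: filter out pairs whose antigen exceeds a residue budget
datalist = [d for d in datalist if d.ag_x.shape[0] <= 500]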

4.4 Some Thoughts

When the antibody-antigen complex is formed for cryo-EM or X-ray crystallography, the conformations of both change: side chains may protrude, and the structural context can differ from that of the free antibody or antigen. Using the bound complex to predict interfaces or epitopes might therefore suffer information leakage that undermines the model's generalizability to free antigens in antibody design applications.

This EPI-EPMP model is not a conditional model. It does message passing between antigen and antibody, but the authors did not show whether the model learned the conditional probability P(epitope | paratope) or the independent probabilities P(epitope)P(paratope). I was convinced the latter was the case, meaning that if one pairs any antibody with the same antigen, the epitope prediction does not vary. This might limit real-world use, as people care more about where different antibodies bind on a given target. A deeper look into the GAT attention weights in the bipartite graph confirms that the attention edges (or soft edges) are very sparse, i.e. the model might as well use the epitope node embedding ONLY rather than paying attention to ANY antibody nodes. I call this attention collapse: the dense attention edges collapsed to self-attention after training.
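A minimal way to probe this is GATConv's return_attention_weights option in PyG, comparing the attention mass on self-loops versus cross edges (here x and edge_index are the bipartite node embeddings and edges built inside forward):

# inspect the first bipartite GAT layer's attention coefficients
h, (ei, alpha) = model.bp_gnn1(x, edge_index, return_attention_weights=True)
self_mask = ei[0] == ei[1]
print('mean self-loop attention :', alpha[self_mask].mean().item())
print('mean cross-edge attention:', alpha[~self_mask].mean().item())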

One solution could be to force information flow from the antibody into the epitope prediction.

Sometimes the epitopes span a large portion of the antigen surface; narrowing them down further might require force-field (physical) modeling of the docking free energies, but at least the model has ruled out a large set of potential patches.


Here is the repo of my implementation of the EPI-EPMP multitasking model.

EPI_EPMP_Pytorch


Reference

  1. Alice Del Vecchio, Andreea Deac, Pietro Liò, Petar Veličković, Neural message passing for joint paratope-epitope prediction, https://arxiv.org/abs/2106.00757