I have been encountering situations where I would like to deploy a trained PyTorch model without revealing the full model details. This can happen when end users might try to copy and reproduce the model (even though the training data is usually the key ingredient), when the code or model has not been published yet, or when the model source code simply needs to be protected.

My typical workflow for preliminary model deployment has been to set up the model, load the training checkpoint, and then run inference:

# saving the model
torch.save(model.state_dict(), 'model.pt')

# loading the model
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load('model.pt'))
model.eval()

# inference
with torch.no_grad(): output = model(x)  # x is the input tensor

Even if the above is dockerized, the container still holds the raw model code. An end user can simply ssh in and fetch it.

Another way to avoid revealing the model code in deployment is to use a cloud API, which I will explore later in my current project.

For now, what I am after is something like a binary file that can load the saved model and run inference on the user’s data. I found an interesting tool for this: TorchScript. Basically, it saves the model not as a state_dict but as a JIT-compiled function, similar to how Julia compiles a function into a binary.

Simple Example

A very basic example: say one has a model to deploy without shipping its source code.

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
torch.manual_seed(99)
x = torch.randn(2, 4)
with torch.no_grad(): print(model(x))

Output:

tensor([[0.5713, 0.4536, 0.0000, 0.0000],
        [0.7102, 0.2850, 0.1824, 0.1124]])

Instead of saving the state dictionary, one can trace the model and save the compiled module:

traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, 'model.jit')

torch.jit.trace takes two arguments: the model and example inputs, which can be a tensor or a tuple of tensors. If the model’s forward pass takes multiple tensors, pass them as a tuple: torch.jit.trace(model, (x1, x2, x3)).
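
As a quick illustration of the multi-input case (MultiInputModel below is a made-up module, just for this sketch), tracing a forward pass that takes two tensors looks like this:

import torch

class MultiInputModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x1, x2):
        return torch.relu(self.linear(x1) + x2)

m = MultiInputModel()
x1, x2 = torch.randn(2, 4), torch.randn(2, 4)
traced = torch.jit.trace(m, (x1, x2))  # example inputs passed as a tuple
torch.jit.save(traced, 'multi_input_model.jit')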

Now at inference time, one does not need to define Model in the code at all; just load model.jit and deploy.

import torch
torch.manual_seed(99)
x = torch.randn(2, 4)
loaded_model = torch.jit.load('model.jit')
with torch.no_grad(): print(loaded_model(x))

Output:

tensor([[0.5713, 0.4536, 0.0000, 0.0000],
        [0.7102, 0.2850, 0.1824, 0.1124]])

The outputs are the same.
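
If both the original model and the loaded module happen to be available in the same session, the match can also be checked programmatically:

# sanity check: the traced module reproduces the original outputs
with torch.no_grad():
    assert torch.allclose(model(x), loaded_model(x))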

GNN Example

The model I am trying to deploy is a GNN with multiple input tensors carrying complex information. Here I will use a trivial GNN model trained on the Cora dataset.

# load the Cora data
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

dataset = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())

print()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]

Output:

Dataset: Cora():
======================
Number of graphs: 1
Number of features: 1433
Number of classes: 7

The GNN model is just a three-hop graph convolution.

from torch_geometric.nn import GraphConv
import torch.nn.functional as F

class GraphNet(torch.nn.Module): 
    
    def __init__(self, hidden_channels): 
        super(GraphNet, self).__init__()
        torch.manual_seed(89)
        
        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, dataset.num_classes)
        
    def forward(self, x, edge_index): 
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv3(x, edge_index)
        return x
    
model = GraphNet(hidden_channels=32)
print(model)

Output:

GraphNet(
  (conv1): GraphConv(1433, 32)
  (conv2): GraphConv(32, 32)
  (conv3): GraphConv(32, 7)
)

Let’s train it briefly:

from tqdm.notebook import tqdm

model = GraphNet(hidden_channels=32)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

def train_gnn():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index) 
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  
    loss.backward()  
    optimizer.step()  
    return loss

def test_gnn():
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1) 
    test_correct = pred[data.test_mask] == data.y[data.test_mask]  
    test_acc = int(test_correct.sum()) / int(data.test_mask.sum())  
    return test_acc


for epoch in tqdm(range(1, 101)):
    loss = train_gnn()
    test_acc = test_gnn()
    if epoch % 50 == 0: print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {test_acc:.3f}')

Output:

Epoch: 050, Loss: 0.8477, Test Acc: 0.697
Epoch: 100, Loss: 0.5670, Test Acc: 0.731

Assume the model is now fully trained and ready to be saved.

with torch.no_grad():
    model.eval()  # make sure dropout layers are in inference mode before tracing
    traced_cell = torch.jit.trace(model, (data.x, data.edge_index))  # the model inputs are the node features x and the edge indices
torch.jit.save(traced_cell, 'gnn_model.jit')

At inference time, on another machine or in another compute environment, just load gnn_model.jit without explicitly writing the GraphNet class:

import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

model = torch.jit.load('gnn_model.jit')

data = Planetoid(root='data/Planetoid', name='Cora', transform=NormalizeFeatures())[0]

with torch.no_grad(): out = model(data.x, data.edge_index)
print(out.shape)

Output:

torch.Size([2708, 7])

This is the model prediction for 2708 nodes with 7 classes.
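
The logits can be turned into class predictions with argmax, and, since the Cora data ships with its labels, the test accuracy can be checked the same way as in test_gnn above:

pred = out.argmax(dim=1)  # predicted class per node
test_correct = pred[data.test_mask] == data.y[data.test_mask]
test_acc = int(test_correct.sum()) / int(data.test_mask.sum())
print(f'Test Acc: {test_acc:.3f}')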

We can still get the number of parameters in the same way and look into the model weights:

for param in model.parameters():
    print(param)
    break

print()
sum(param.numel() for param in model.parameters())

Output:

tensor([[ 0.0042,  0.0584,  0.0186,  ...,  0.0327,  0.0167,  0.0612],
        [ 0.0389,  0.0680,  0.0121,  ...,  0.0318,  0.0492,  0.0837],
        [ 0.0114,  0.0003, -0.0106,  ..., -0.0071,  0.0634, -0.0186],
        ...,
        [-0.0026,  0.0363,  0.0133,  ..., -0.0569,  0.0296,  0.0702],
        [ 0.0274, -0.0403, -0.0286,  ..., -0.0239, -0.0148,  0.0049],
        [ 0.0077,  0.0201, -0.0033,  ...,  0.0302, -0.0283,  0.0827]],
       requires_grad=True)

94279
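
The parameter names survive tracing as well, which is worth keeping in mind if the layer names themselves are considered sensitive:

for name, param in model.named_parameters():
    print(name, tuple(param.shape))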

Saving a Compiled Torch Function

torch.jit.trace extends beyond PyTorch models; it can be applied to PyTorch functions as well. It saves a compiled function that can later be loaded and called without defining the function explicitly, much like a binary.

def some_function(x): return torch.zeros_like(x)

x = torch.randn(5, 5)
some_function(x)

Output:

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

And use the same trick.

traced_cell = torch.jit.trace(some_function, (x))
torch.jit.save(traced_cell, 'some_function.jit')

In another Python script, or on another machine, just do:

f = torch.jit.load('some_function.jit')
f

Output:

RecursiveScriptModule(original_name=PlaceholderModule)

And

f(torch.randn(10, 10))

Output:

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

That’s pretty convenient. The inference pipeline can consist of a number of these compiled models as black boxes, without revealing the code of the models or functions.

Another advantage: when there are a number of models with different hyperparameters at deployment, I used to have to fetch each model’s config via wandb and instantiate the model accordingly before loading its state dict. With saved JIT-compiled models, I can just load them and run the forward pass (given that they all take the same inputs).
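
A minimal sketch of what that looks like (the file names here are hypothetical), assuming all variants take the same inputs:

import torch

model_paths = ['gnn_32.jit', 'gnn_64.jit', 'gnn_128.jit']  # hypothetical jit-compiled variants
models = [torch.jit.load(p) for p in model_paths]

with torch.no_grad():
    outs = [m(data.x, data.edge_index) for m in models]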