I have been encountering situations where I would like to deploy a trained PyTorch model but am reluctant to reveal the full model details. This might be the case when end users may try to copy and reproduce the model (even though the training data is the key), when the code or model has not been published, or when the model source code needs to be protected.
My typical workflow for preliminary model deployment has been to set up the model, load the training checkpoint, and then run inference:
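Something along these lines (a minimal sketch; Model, model.pt, and x are placeholders for the real model class, checkpoint file, and input data):

```python
import torch

# Sketch of the usual checkpoint-based deployment:
# the model class must be importable wherever this runs.
model = Model()
model.load_state_dict(torch.load("model.pt"))  # restore trained weights
model.eval()

with torch.no_grad():
    prediction = model(x)  # inference on the user's data
```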
Even if the above is dockerized, the raw model code still lives inside the container. The end user can just ssh in and fetch the code.
Another way to avoid revealing the model code in the deployment is to use a cloud API, which I will explore later in my current project.
For now, what I am after is something like a binary file that can load the saved model and run inference on the user’s data for prediction. I found an interesting tool for this: TorchScript. Basically, it saves the model not as a state_dict but as a JIT-compiled function, similar to how Julia compiles a function into a binary.
Simple Example
A very basic example: say one has a model to deploy without exposing its source code.
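For illustration, here is a stand-in toy model (the architecture and layer sizes are made up for this example):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        return self.net(x)

model = Model()
x = torch.randn(4, 10)  # dummy batch: 4 samples, 10 features
print(model(x))
```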
Output:
Instead of saving the model dictionary, one can do
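For example, tracing the toy model above with its dummy input and saving the compiled module (the file name model.jit is arbitrary):

```python
# Trace the model with an example input, then save the compiled module
# instead of the state_dict.
traced_model = torch.jit.trace(model, x)
torch.jit.save(traced_model, "model.jit")
```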
torch.jit.trace takes two arguments: the model and example inputs, which can be a tensor or a tuple of tensors. If the model’s forward pass takes multiple tensors, the call becomes torch.jit.trace(model, (x1, x2, x3)).
Now at inference time one does not need to define Model in the code; just load model.jit and deploy.
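A sketch of that inference side, assuming the same dummy input x is available:

```python
import torch

# No Model class definition is needed here.
loaded = torch.jit.load("model.jit")
loaded.eval()

with torch.no_grad():
    print(loaded(x))
```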
Output:
The outputs are the same.
GNN Example
The model I am trying to deploy is a GNN model that takes multiple input tensors carrying complex information. I will use a trivial GNN model trained on the Cora dataset.
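Loading Cora, assuming PyTorch Geometric (torch_geometric) is installed; the data root path is arbitrary:

```python
import torch
from torch_geometric.datasets import Planetoid

# Cora: 2708 nodes, 1433 bag-of-words features, 7 classes.
dataset = Planetoid(root="data/Planetoid", name="Cora")
data = dataset[0]
print(data)
```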
Output:
The GNN model is just a three-hop graph convolution.
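A sketch of such a model with three stacked GCNConv layers from PyTorch Geometric (the hidden size of 16 is an arbitrary choice):

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GraphNet(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden=16):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.conv3 = GCNConv(hidden, num_classes)

    def forward(self, x, edge_index):
        # three rounds of message passing, i.e. three hops
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return self.conv3(x, edge_index)

model = GraphNet(dataset.num_features, dataset.num_classes)
print(model)
```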
Output:
Let’s train it briefly:
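A minimal full-batch training loop, continuing from the blocks above (the optimizer settings are common Cora defaults, not necessarily what was used originally):

```python
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()

for epoch in range(50):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"epoch {epoch:3d}  loss {loss.item():.4f}")
```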
Output:
Assume the model is fully trained and ready to be saved.
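Since the forward pass takes two tensors, they are passed to torch.jit.trace as a tuple:

```python
model.eval()

# Trace with the real node features and edge index as example inputs.
traced_gnn = torch.jit.trace(model, (data.x, data.edge_index))
torch.jit.save(traced_gnn, "gnn_model.jit")
```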
And at inference time, on another machine or in another compute environment, just load gnn_model.jit without explicitly writing the GraphNet model:
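A sketch of the inference side (data here stands for whatever graph tensors are supplied at deployment):

```python
import torch

gnn_model = torch.jit.load("gnn_model.jit")
gnn_model.eval()

with torch.no_grad():
    pred = gnn_model(data.x, data.edge_index)
print(pred)
print(pred.shape)
```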
Output:
This is the model prediction for 2708 nodes with 7 classes.
We can still get the number of parameters in the same way and look into the model weights:
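For example, with gnn_model being the loaded module from above:

```python
# The loaded ScriptModule still exposes parameters() and state_dict().
num_params = sum(p.numel() for p in gnn_model.parameters())
print(num_params)
print(list(gnn_model.state_dict().keys()))
```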
Output:
Save compiled torch function
torch.jit.trace extends beyond PyTorch models: it can be applied to PyTorch functions as well. It can save a compiled function, which can then be loaded without defining the function explicitly, like a binary.
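A sketch with a made-up function scale_and_sum standing in for whatever function needs to be hidden:

```python
import torch

def scale_and_sum(x, w):
    # an ordinary PyTorch function, no nn.Module involved
    return (x * w).sum(dim=-1)

x = torch.randn(3, 5)
w = torch.randn(5)

traced_fn = torch.jit.trace(scale_and_sum, (x, w))
print(traced_fn(x, w))
```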
Output:
And use the same trick to save the compiled function.
In another Python script, or on another machine, just do:
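Continuing the hypothetical example (the file name is arbitrary):

```python
torch.jit.save(traced_fn, "scale_and_sum.jit")
```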
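For instance, assuming the saved file has been copied over:

```python
import torch

loaded_fn = torch.jit.load("scale_and_sum.jit")
print(loaded_fn)
```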
Output:
And call it just like the original function:
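With the same example inputs as before:

```python
print(loaded_fn(x, w))
```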
Output:
That’s pretty convenient. The inference pipeline can consist of a number of these compiled models as black boxes, without revealing the code of the models or functions.
Another advantage: at deployment there are often a number of models with different hyperparameters. I used to have to fetch each model config via wandb and instantiate the model accordingly before loading the state dictionary. With saved JIT-compiled models, I can just load them and run the forward pass (given that they all take the same inputs).