In this post, I would like to share how to convert a trained PyTorch model to a TensorFlow model.
ONNX
What is ONNX?
ONNX (Open Neural Network Exchange) is an open ecosystem that empowers AI developers to choose the right tools as their project evolves. Briefly speaking, it enables interoperability between different frameworks and streamlines the path from research to production, which helps increase the speed of innovation in the AI community. To achieve this, it defines an extensible computation graph model, as well as built-in operators and standard data types.
We can think of deep learning as computation over dataflow graphs. These graphs come in two types: dynamic and static. Different deep learning frameworks use different kinds of graphs. For instance, frameworks like TensorFlow, Caffe2, CNTK, and Theano prefer static graphs, while others such as PyTorch and Chainer use dynamic graphs.
Both have pros and cons. With a static graph, once the graph is defined it can be reused many times and run as fast as possible, because nothing new has to be created at execution time. A static computation graph can also be used to schedule work across a pool of computational devices, so the computational cost can be shared. And once the graph is defined, an optimizing compiler can transform it so that even a large graph runs efficiently on CPUs or GPUs. However, static graphs are not flexible, and because many logic errors only surface at execution time, they are harder to debug. A dynamic graph, on the other hand, is more flexible: you can define, change, and execute the network as you go (see the small sketch below). There are no special sessions like in static graphs; it is more Pythonic and easier to debug. However, it is not as fast as a static graph. So it would be really great if we could develop the model with a dynamic graph and deploy it with a static one, and this is exactly where ONNX comes in.
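To make the "dynamic" part concrete, here is a tiny, self-contained sketch (purely illustrative, not related to the hair-matting model): in PyTorch, ordinary Python control flow inside forward() can change the computation from one call to the next, which a fixed static graph cannot express directly.
import torch
import torch.nn as nn

class TinyDynamicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        # the number of times the layer is applied depends on the input itself,
        # so the graph is effectively rebuilt on every forward pass
        steps = int(x.abs().mean().item() * 3) + 1
        for _ in range(steps):
            x = torch.relu(self.fc(x))
        return x

net = TinyDynamicNet()
print(net(torch.randn(2, 8)).shape)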
In the rest of this blog, I will use an example to illustrate how to convert a PyTorch model to a TensorFlow model. The whole .ipynb notebook can be found in our repository.
Install ONNX
It is easy:
$ pip install tensorflow onnx onnx-tf
Import the PyTorch model
Our model is the one proposed by Alex L., Cheng C., et al. in ‘Real-time deep hair matting on mobile devices’.
For the PyTorch implementation of this model, you can refer to our repository.
Load the torch model and export it to an ONNX model
First, define a function to load the model from the pretrained file.
# import torch and our hair network
import torch
from models import MobileHairNet
# import config file
import config

def build_model(path, device):
    # instantiate the network and restore the trained weights
    model = MobileHairNet()
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    # move the model to the appropriate device
    model = model.to(device)
    return model
Then we can load the model. TORCH_MODEL_PATH is our pretrained model’s path: hair-dye/src/checkpoints/default/train_170.pt.
Note that to export the model to ONNX, we need a dummy input, so we just use a random tensor of shape (batch_size, channel_size, height_size, width_size). Our model takes an input of size (1, 3, 224, 224).
When we run the following export code, the notebook prints some information about the exported network; the output is shown right after the code.
from torch.autograd import Variable

torch_model = build_model(TORCH_MODEL_PATH, device)

# export model to ONNX
# ONNX needs a dummy input to trace the graph
dummy_input = Variable(torch.randn(1, 3, config.IMG_SIZE, config.IMG_SIZE))
# pred = torch_model(dummy_input)
# print("output size: ", pred.size())
input_names = ["input"]
output_names = ["output"]
torch.onnx.export(torch_model, dummy_input, ONNX_MODEL_PATH, verbose=True,
                  input_names=input_names, output_names=output_names)
…
%215 : Float(1, 32, 224, 224) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2]], scope: MobileHairNet/_EncodeBlockBlueGreen[encode_block1]/Sequential[layer]/_BlueBlock[0]/Conv2d[layer]
%216 : Float(1, 32, 224, 224) = onnx::Conv[dilations=[1, 1], group=32, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]], scope: MobileHairNet/_EncodeBlockBlueGreen[encode_block1]/Sequential[layer]/_GreenBlock[1]/Sequential[layer]/Conv2d[0]
%217 : Float(1, 32, 224, 224) = onnx::BatchNormalization[epsilon=1e-05, is_test=1, momentum=1], scope: MobileHairNet/_EncodeBlockBlueGreen[encode_block1]/Sequential[layer]/_GreenBlock[1]/Sequential[layer]/BatchNorm2d[1]
%218 : Float(1, 32, 224, 224) = onnx::Clip[max=6, min=0], scope: MobileHairNet/_EncodeBlockBlueGreen[encode_block1]/Sequential[layer]/_GreenBlock[1]/Sequential[layer]/ReLU6[2]
…
From this output, we can check whether the network was converted correctly.
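Beyond reading the printed graph, a quick numerical check is also possible. The sketch below is not part of the original workflow: it assumes onnxruntime is installed (pip install onnxruntime) and that torch_model lives on the CPU. It simply compares the PyTorch output with the exported ONNX output for the same random input.
import numpy as np
import onnxruntime as ort

dummy = np.random.randn(1, 3, config.IMG_SIZE, config.IMG_SIZE).astype(np.float32)

# prediction from the original PyTorch model
# (eval mode so BatchNorm uses its running statistics)
torch_model.eval()
with torch.no_grad():
    torch_out = torch_model(torch.from_numpy(dummy)).numpy()

# prediction from the exported ONNX file, using the "input"/"output" names
# we passed to torch.onnx.export
sess = ort.InferenceSession(ONNX_MODEL_PATH)
onnx_out = sess.run(["output"], {"input": dummy})[0]

print(np.allclose(torch_out, onnx_out, atol=1e-4))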
Import the ONNX model to TensorFlow
import onnx
from onnx_tf.backend import prepare

# load the onnx file
onnx_model = onnx.load(ONNX_MODEL_PATH)
# check that the model is well formed
onnx.checker.check_model(onnx_model)

# from onnx import version_converter
# print('The model before conversion:\n{}'.format(onnx_or_model))
# # A full list of supported adapters can be found here:
# # https://github.com/onnx/onnx/blob/master/onnx/version_converter.py#L21
# # Apply the version conversion on the original model
# onnx_model = version_converter.convert_version(onnx_or_model, 7)
# print('The model after conversion:\n{}'.format(onnx_model))

# import the ONNX model into a TensorFlow representation
tf_rep = prepare(onnx_model)
Here I got an error:
NotImplementedError: Upsample version 1 is not implemented.
I searched the issues on the ONNX repo on GitHub. Here is someone’s answer:
Upsample was an experimental op in ONNX before opset=7. Since ONNX decided to totally remove experimental ops, we have not taken the implementation for the experimental upsample (opset=1) into the v1.3.0 release. Your onnx file seems for opset=6 therefore the error. If you could possibly regenerate the onnx file with opset 7 or newer, it should work with onnx-tf v1.3.0. Sorry about the confusion.
So I tried the version_converter function in ONNX. Unfortunately, it does not work. Then I considered downgrading the version of ONNX. I searched the docs in the GitHub repo and found the version information:
ONNX version | File format version | Operator set version ai.onnx | Operator set version ai.onnx.ml
---|---|---|---
1.0 | 3 | 1 | 1
1.1 | 3 | 5 | 1
1.1.2 | 3 | 6 | 1
1.2 | 3 | 7 | 1
1.3 | 3 | 8 | 1
1.4.1 | 4 | 9 | 1
1.5.0 | 5 | 10 | 1
I also looked up operator set support for Upsample:
Upsample (deprecated)
Upsample the input tensor. Each dimension value of the output tensor is: output_dimension = floor(input_dimension * scale).
Version
This version of the operator has been deprecated since version 10 of the default ONNX operator set.
Other versions of this operator: Upsample-7, Upsample-9
So it seems I need operator set version 7, which means downgrading ONNX to version 1.2.
pip install tensorflow==1.13.1 onnx==1.2.1 onnx-tf==1.1.2
After some attempts, this combination of versions works fine for me.
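If you want to double-check the installed versions and the opset that your exported file declares, a small sketch like the following can help. It only uses standard attributes of the onnx package; opset_import lists the operator sets the model requires.
import onnx
import tensorflow as tf

print("onnx:", onnx.__version__)
print("tensorflow:", tf.__version__)

# each entry of opset_import has a domain ('' means the default ai.onnx domain)
# and the operator set version the model was exported with
model = onnx.load(ONNX_MODEL_PATH)
for opset in model.opset_import:
    print(opset.domain or "ai.onnx", opset.version)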
Save to file
tf_rep.export_graph(TF_MODEL_PATH)
TF_MODEL_PATH is the path of the new TensorFlow graph model.
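To make sure the saved file is usable on its own, here is a minimal sketch of loading it back with plain TensorFlow 1.x. It assumes the exported file is a frozen GraphDef and that the graph keeps the "input"/"output" names we chose during export; if the tensor names differ in your graph, inspect graph.get_operations() to find them.
import numpy as np
import tensorflow as tf

# read the serialized GraphDef produced by tf_rep.export_graph
with tf.gfile.GFile(TF_MODEL_PATH, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# import it into a fresh graph without renaming the nodes
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")

with tf.Session(graph=graph) as sess:
    x = np.random.randn(1, 3, config.IMG_SIZE, config.IMG_SIZE).astype(np.float32)
    out = sess.run("output:0", feed_dict={"input:0": x})
    print(out.shape)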
At this point, the conversion work is finished. The rest of the post is optional.
Run the ONNX model to see if it works (optional)
ImgTransformer will load the image and do the preprocessing.
from PIL import Image
from torchvision import transforms

class ImgTransformer():
    def __init__(self, img_size, color_aug=False):
        self.img_size = img_size
        self.color_aug = color_aug

    def transform(self, image):
        # center-crop to a square, resize, optionally apply color augmentation,
        # then convert to a tensor normalized to [-1, 1]
        transformer = transforms.Compose([
            transforms.CenterCrop(min(image.size[0], image.size[1])),
            transforms.Resize(self.img_size)
        ] + ([
            transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=.5),
            transforms.RandomGrayscale(p=0.1)
        ] if self.color_aug else []) + [
            transforms.ToTensor(),
            transforms.Normalize((.5, .5, .5), (.5, .5, .5))
        ])
        transform_image = transformer(image)
        return transform_image

    def load(self, path):
        image = Image.open(path).convert('RGB')
        return self.transform(image)
import os

img_path = os.path.join(DIR_PATH, config.TEST_CORPUS, "images", "Frame (150).jpg")
transformer = ImgTransformer(config.IMG_SIZE, color_aug=False)
img = transformer.load(img_path)
display(img)
tensor([[[ 0.4902,  0.5137,  0.5216,  ..., -0.6471, -0.6471, -0.6471],
         [ 0.4980,  0.5137,  0.5216,  ..., -0.6471, -0.6471, -0.6471],
         [ 0.5059,  0.5216,  0.5373,  ..., -0.6471, -0.6471, -0.6392],
         ...,
         [ 0.0118, -0.0039,  0.0118,  ...,  0.1529,  0.1373,  0.1137],
         [-0.0039,  0.0118,  0.0196,  ...,  0.1216,  0.0980,  0.0667],
         [-0.0118,  0.0118,  0.0196,  ...,  0.0980,  0.0667,  0.0353]],

        [[ 0.3569,  0.3804,  0.3882,  ..., -0.7882, -0.7882, -0.7882],
         [ 0.3647,  0.3804,  0.3882,  ..., -0.7882, -0.7882, -0.7882],
         [ 0.3647,  0.3804,  0.3961,  ..., -0.7882, -0.7882, -0.7804],
         ...,
         [-0.1529, -0.1686, -0.1529,  ..., -0.0196, -0.0353, -0.0510],
         [-0.1451, -0.1294, -0.1137,  ..., -0.0667, -0.0824, -0.0902],
         [-0.1451, -0.1216, -0.1059,  ..., -0.1059, -0.1137, -0.1216]],

        [[-0.2706, -0.2471, -0.2471,  ..., -0.7882, -0.7882, -0.7882],
         [-0.2627, -0.2471, -0.2471,  ..., -0.7882, -0.7882, -0.7882],
         [-0.2627, -0.2549, -0.2471,  ..., -0.7882, -0.7882, -0.7882],
         ...,
         [-0.4902, -0.5059, -0.4980,  ..., -0.3255, -0.3569, -0.4118],
         [-0.5059, -0.4902, -0.4824,  ..., -0.3725, -0.4039, -0.4275],
         [-0.5059, -0.4824, -0.4824,  ..., -0.4039, -0.4353, -0.4510]]])
To display the image, mask, and prediction in one row:
import numpy as np
import matplotlib.pyplot as plt

def create_multi_figure(rows, dye=False):
    fig = plt.figure()
    # 3 tensors per row means the middle one is the mask
    if len(rows[0]) == 3:
        names = ["Image", "Mask", "Prediction"]
    else:
        names = ["Image", "Prediction"]

    for i, data in enumerate(rows):
        img = data[0]
        prediction = data[-1]
        # reverse the normalization of the original image
        img = (img + 1) / 2
        data[0] = img

        if dye:
            transform_hue = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ColorJitter(
                    hue=0.5, saturation=0.8, contrast=0.2, brightness=0.3),
                transforms.ToTensor()
            ])
            dyed = transform_hue(img)
            # blend the recolored hair into the original image using the mask
            dyed = prediction * dyed + (1 - prediction) * img
            data.append(dyed)
            names.append('Dye')

        for j, d in enumerate(data):
            d = d.squeeze()
            im = d.data.numpy()
            if im.shape[0] != 3:
                # single-channel mask: repeat it to get 3 channels
                im = np.expand_dims(im, axis=0)
                im = np.concatenate((im, im, im), axis=0)
            im = im.transpose(1, 2, 0)

            f = fig.add_subplot(len(rows), len(data), i * len(data) + j + 1)
            f.imshow(im)
            if i == 0:
                f.set_title(names[j])
            f.set_xticks([])
            f.set_yticks([])

    return fig
def evaluateOne(img, model, absolute=True):
    img = img.to(device).unsqueeze(0)
    pred = model.run(img)[0]
    pred = torch.from_numpy(pred)

    if absolute:
        pred[pred > .5] = 1.
        pred[pred <= .5] = 0.
    else:
        pred[pred < .4] = 0

    rows = [[img[0], pred[0]]]
    create_multi_figure(rows, dye=True)
    plt.show()
evaluateOne(img, tf_rep)
The results are shown here:
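As a rough cross-check (a sketch only, assuming the torch_model loaded earlier is still available and everything is on the CPU), we can also compare the mask predicted through the TensorFlow representation with the one from the original PyTorch model:
import numpy as np

batch = img.unsqueeze(0)

# prediction from the original PyTorch model
torch_model.eval()
with torch.no_grad():
    torch_pred = torch_model(batch).numpy()

# prediction through the ONNX/TensorFlow representation
tf_pred = tf_rep.run(batch.numpy())[0]

# the two masks should agree up to small numerical differences
print(np.abs(torch_pred - tf_pred).max())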