First, let's install the BlindAI client library:
!pip install blindai
The first step is to get the model in ONNX format. Let's pull the ResNet18 model from PyTorch Hub and export it to ONNX.
import torch

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

# The ONNX exporter runs the model once, so it needs an example input with
# the right shape: (batch, channels, height, width)
dummy_inputs = torch.zeros(1, 3, 224, 224)
torch.onnx.export(model, dummy_inputs, "resnet18.onnx")
Now we can upload the model to BlindAI Cloud. To upload the model, you need an API key, which you can get on the Mithril Cloud.
You might get an error if the name you want to use is already taken, because models are uniquely identified by their model_id. We plan to implement namespaces to avoid this. In the meantime, you will have to choose a unique ID.
import uuid

import blindai

api_key = "YOUR_API_KEY"  # Enter your API key here
# Append a random UUID so that the model ID is unique
model_id = "resnet18-" + str(uuid.uuid4())

# Upload the ONNX file to the remote enclave
with blindai.Connection(api_key=api_key) as client:
    response = client.upload_model("resnet18.onnx", model_id=model_id)
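Because an upload fails when the chosen model_id is already taken, one option is to generate a fresh ID and try again. The sketch below is hypothetical. `upload_with_fresh_id` is not part of BlindAI. It assumes that your upload callable raises some exception type when the ID is in use, and you pass that type in as `collision_error`. Check which error the client actually raises before you rely on this.

```python
import uuid
from typing import Callable, Type


def upload_with_fresh_id(
    upload: Callable[[str], object],
    collision_error: Type[Exception],
    prefix: str = "resnet18-",
    max_attempts: int = 3,
) -> str:
    """Upload under a random model ID, retrying with a new ID on collisions.

    `upload` is any callable that takes a model ID (hypothetical wrapper
    around the real upload call). `collision_error` is an ASSUMPTION: the
    exception type raised when the ID is already taken.
    """
    last_error = None
    for _ in range(max_attempts):
        model_id = prefix + str(uuid.uuid4())  # new random suffix on each attempt
        try:
            upload(model_id)
            return model_id  # the upload succeeded under this ID
        except collision_error as err:
            last_error = err  # the ID was taken, so try again with a new one
    raise RuntimeError(f"no free model ID after {max_attempts} attempts") from last_error
```

Inside the `with blindai.Connection(...) as client:` block, `upload` could be `lambda mid: client.upload_model("resnet18.onnx", model_id=mid)`. The function returns the ID that was actually used, and you then pass that ID to `predict`.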
Now that the model is uploaded, we just need to send it some data. Let's first grab the image we want the model to classify.
import urllib.request

from PIL import Image
from torchvision import transforms

# Download an example image from the PyTorch Hub repository
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

input_image = Image.open(filename)
input_image  # display the image in the notebook
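We preprocess the image into the input ResNet18 expects. The image is resized, center-cropped to 224x224, converted to a tensor, and normalized with the ImageNet mean and standard deviation.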
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
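The following plain-Python sketch makes the preprocessing arithmetic concrete. `Resize(256)` scales the shorter side to 256 pixels and keeps the aspect ratio. `CenterCrop(224)` cuts out a centered 224x224 window. `ToTensor` followed by `Normalize` maps each 8-bit channel value v to (v / 255 - mean) / std. The helper names are illustrative. The rounding follows torchvision's behavior as we understand it, so treat exact pixel offsets as approximate.

```python
MEAN = (0.485, 0.456, 0.406)  # ImageNet per-channel mean (same values as the pipeline)
STD = (0.229, 0.224, 0.225)   # ImageNet per-channel standard deviation


def resized_size(width: int, height: int, short_side: int = 256) -> tuple[int, int]:
    """(width, height) after Resize(short_side): shorter edge -> short_side, aspect kept."""
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side


def center_crop_box(width: int, height: int, crop: int = 224) -> tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb: tuple[int, int, int]) -> list[float]:
    """ToTensor (scale 0..255 to 0..1) followed by Normalize: (x - mean) / std."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD)]
```

For example, a 640x480 photo is first resized to 341x256, and the 224x224 crop is then taken from its center. A final `unsqueeze(0)` adds the batch dimension, which gives the (1, 3, 224, 224) shape used for the ONNX export.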
Now we just need to connect to the model inside the secure enclave and query it.
import blindai
import torch

with blindai.Connection(api_key=api_key) as client:
    # Send data to the ResNet18 model
    prediction = client.predict(model_id, input_batch)
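We can now map the prediction to a human-readable class name:
import requests
import torch

# Download the ImageNet class labels (one per line)
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

output = prediction.output.as_torch()
# Softmax over the class dimension (the last one) turns raw scores into probabilities
probabilities = torch.nn.functional.softmax(output, dim=-1)
labels[probabilities.argmax().item()], probabilities.max().item()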
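The last cell turns the model's 1000 raw scores (logits) into probabilities with a softmax and then picks the most likely class. The sketch below writes out that computation in plain Python on a list of logits. It subtracts the maximum before exponentiating so that large scores do not overflow, and this shift does not change the result. It is illustrative only. In the notebook, `torch.nn.functional.softmax` performs the same computation over the class dimension, and `torch.topk` is the tensor equivalent of the hypothetical `top_k` helper.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over a 1-D list of scores."""
    m = max(logits)                             # shift so the largest exponent is 0
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs: list[float], labels: list[str], k: int = 5) -> list[tuple[str, float]]:
    """Return the k most probable (label, probability) pairs, highest first."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i], probs[i]) for i in order[:k]]
```

Listing the top few classes, instead of only the argmax, can show whether the model hesitates between similar classes, such as several dog breeds.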