Example of BlindAI deployment with Wav2Vec2
This example shows how you can run a Wav2Vec2 model to perform speech-to-text with confidentiality guarantees.
By using BlindAI, users can send audio data to be analyzed by the AI without having to fear privacy leaks.
Wav2Vec2 is a state-of-the-art Transformer model for speech. You can learn more about it in FAIR's blog post.
Installing dependencies
Install the dependencies this example needs.
!pip install -q transformers[onnx] torch
We will need librosa to load the "hello world" audio file. You might need to downgrade numpy to 1.21 to make it work. The following commands should do the trick to install librosa:
!pip install -q --upgrade numpy==1.21
!pip install -q librosa
In addition, you might need to install ffmpeg to have a backend to process the wav file.
!sudo apt-get install -y ffmpeg
Install the latest version of BlindAI.
!pip install blindai
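To confirm the installation, you can print the installed package version. A minimal sketch using the standard library's importlib.metadata (Python 3.8+); the exact version string depends on your environment:
import importlib.metadata
# Print the installed blindai package version
print(importlib.metadata.version("blindai"))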
Preparing the model
Here we will use a Wav2Vec2 base model fine-tuned on 960 hours of LibriSpeech. The first step is to get the model and its processor.
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
# load model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
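Transformers models are typically already in evaluation mode after from_pretrained, but making it explicit is a common safeguard before tracing or exporting a model, since it disables training-only layers such as dropout. This is a standard PyTorch step, not specific to BlindAI:
# Make sure the model is in inference mode
model.eval()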
Let's download a "hello world" audio file to use as an example:
!wget https://github.com/mithril-security/blindai/raw/master/examples/wav2vec2/hello_world.wav
We can listen to it here:
import IPython.display as ipd
ipd.Audio('hello_world.wav')
We can then see the Wav2Vec2 model in action on the "hello world" file.
import librosa
audio, rate = librosa.load("hello_world.wav", sr=16000)
# Tokenize sampled audio to input into model
input_values = processor(audio, sampling_rate=rate, return_tensors="pt", padding="longest").input_values
# Retrieve logits
logits = model(input_values).logits
# Take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
print(transcription)
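If everything went well, the printed transcription should read something like ['HELLO WORLD'] (the model's vocabulary is uppercase characters).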
To facilitate deployment, we will add the post-processing directly to the model. This way the client will not have to do the post-processing itself.
import torch.nn as nn

# Let's embed the post-processing phase with argmax inside our model
class ArgmaxLayer(nn.Module):
    def __init__(self):
        super(ArgmaxLayer, self).__init__()

    def forward(self, outputs):
        return torch.argmax(outputs.logits, dim=-1)

final_layer = ArgmaxLayer()

# Finally we concatenate everything
full_model = nn.Sequential(model, final_layer)
We can check that the results are the same.
predicted_ids = full_model(input_values)
transcription = processor.batch_decode(predicted_ids)
transcription
Now we can export the model in ONNX format, so that we can later feed the ONNX file to our BlindAI server.
torch.onnx.export(
    full_model,
    input_values,
    'wav2vec2_hello_world.onnx',
    export_params=True,
    opset_version=11)
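As an optional sanity check, you can run the exported file locally with onnxruntime before uploading it. This is a minimal sketch, assuming onnxruntime is installed (pip install onnxruntime); it is not required for the BlindAI deployment itself:
import onnxruntime as ort
# Load the exported model on CPU and run the same example input
session = ort.InferenceSession("wav2vec2_hello_world.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
onnx_ids = session.run(None, {input_name: input_values.numpy()})[0]
# The decoded transcription should match the PyTorch output
print(processor.batch_decode(torch.tensor(onnx_ids)))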
Deployment on BlindAI
Now we can upload the model to BlindAI Cloud. To upload the model, make sure you have an API key.
You can get one on the Mithril Cloud.
You might get an error if the name you want to use is already taken, as models are uniquely identified by their model_id. We will implement namespaces soon to avoid that. Meanwhile, you will have to choose a unique ID. We provide an example below to upload your model with a unique name:
import blindai
import uuid
api_key = "YOUR_API_KEY" # Enter your API key here
model_id = "wav2vec2-" + str(uuid.uuid4())
# Upload the ONNX file to the remote enclave
with blindai.connect(api_key=api_key) as client:
    response = client.upload_model("wav2vec2_hello_world.onnx", model_id=model_id)
Sending data for confidential prediction
Now it's time to check that it works live!
We will prepare some input for the model inside the secure enclave of BlindAI to process.
First we prepare our input data, the "hello world" audio file.
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import librosa
# load model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
audio, rate = librosa.load("hello_world.wav", sr=16000)
# Tokenize sampled audio to input into model
input_values = processor(audio, sampling_rate=rate, return_tensors="pt", padding="longest").input_values
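Note that the model was exported with the fixed shape of the example input, so the audio sent for prediction should have the same length as the audio used at export time; supporting variable-length inputs would require exporting with dynamic_axes in torch.onnx.export. Here we reuse the same "hello world" file, so the shapes match.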
Now we can send the audio data to be processed confidentially!
with blindai.connect() as client:
    response = client.predict(model_id, input_values)
We can reconstruct the output now:
# Decode the output
processor.batch_decode(response.output[0].as_torch().unsqueeze(0))
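Once you are done, you can remove the model from the server, which also frees up its model_id. A minimal sketch, assuming your version of the BlindAI client exposes a delete_model method (check the API reference of your release):
# Remove the uploaded model once finished
# (assumes the client provides delete_model)
with blindai.connect(api_key=api_key) as client:
    client.delete_model(model_id)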