Example of BlindAI deployment with Wav2Vec2
This example shows how you can run a Wav2Vec2 model to perform Speech-To-Text with confidentiality guarantees.
By using BlindAI, users can send their recorded conversations to the AI model for analysis without having to fear privacy leaks.
Wav2Vec2 is a state-of-the-art Transformer model for speech. You can learn more about it in FAIR's blog post.
Install the dependencies this example needs.
!pip install -q transformers[onnx] torch
We will need librosa to load the "hello world" audio file. You might need to downgrade numpy to 1.21 to make it work. The following commands should install both:
!pip install -q --upgrade numpy==1.21
!pip install -q librosa
In addition, you might need to install ffmpeg so that librosa has a backend to process the WAV file.
!sudo apt-get install -y ffmpeg
Install the latest version of BlindAI.
!pip install blindai
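Here we will use the base Wav2Vec2 model fine-tuned on 960 hours of LibriSpeech (facebook/wav2vec2-base-960h). The first step is to load the model and its processor, which bundles the feature extractor and the tokenizer.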
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
# load model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
Next, we need a "hello world" audio file to use as an example. Download it and save it as hello_world.wav in the working directory.
We can hear it here:
import IPython.display as ipd
ipd.Audio("hello_world.wav")
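Before going further, it can help to confirm that the downloaded file is a readable WAV file and to see its native sample rate, since the model expects 16 kHz audio. The sketch below uses only Python's standard wave module; describe_wav is a hypothetical helper, not part of BlindAI or librosa. Note that wave only reads uncompressed PCM WAV files, so a file in another encoding will raise wave.Error even though librosa may still load it.

```python
import wave


def describe_wav(path):
    """Return basic properties of a PCM WAV file (hypothetical helper)."""
    with wave.open(path, "rb") as wav:
        frames = wav.getnframes()
        rate = wav.getframerate()
        return {
            "channels": wav.getnchannels(),
            "sample_rate": rate,
            "sample_width_bytes": wav.getsampwidth(),
            "duration_s": frames / rate,
        }
```

For example, describe_wav("hello_world.wav") returns the clip's properties. If sample_rate is not 16000, that is only informational: librosa.load(..., sr=16000) resamples the audio when it is loaded.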
We can then see the Wav2Vec2 model in action on the hello world file.
import librosa
# librosa resamples the audio to 16 kHz, the sampling rate Wav2Vec2 expects
audio, rate = librosa.load("hello_world.wav", sr=16000)
# Convert the sampled audio into the model's input tensor (a batch of one clip)
input_values = processor(audio, sampling_rate=rate, return_tensors="pt", padding="longest").input_values
# Retrieve logits
logits = model(input_values).logits
# Take argmax and decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
print(transcription)
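The argmax step above picks one token per audio frame (roughly one frame every 20 ms), and processor.batch_decode then turns that frame-level sequence into text. For CTC models such as Wav2Vec2ForCTC, greedy decoding amounts to merging consecutive repeated tokens, dropping the blank token (the <pad> token for this tokenizer) and turning the word delimiter "|" into a space. The pure-Python sketch below mimics that on a tiny hypothetical vocabulary; it illustrates the technique and is not the transformers implementation, which also handles special tokens and other options.

```python
from itertools import groupby

# Hypothetical toy vocabulary; the real one comes from processor.tokenizer.
TOY_VOCAB = {0: "<pad>", 1: "|", 2: "H", 3: "E", 4: "L", 5: "O", 6: "W", 7: "R", 8: "D"}


def ctc_greedy_decode(ids, vocab, blank="<pad>", delimiter="|"):
    """Collapse repeated tokens, drop blanks, map the delimiter to spaces."""
    # 1. Merge runs of identical consecutive ids (one symbol may span frames).
    collapsed = [key for key, _ in groupby(ids)]
    # 2. Drop the CTC blank, which is what separates genuine double letters.
    tokens = [vocab[i] for i in collapsed if vocab[i] != blank]
    # 3. Turn word delimiters into spaces and normalize the whitespace.
    text = "".join(" " if t == delimiter else t for t in tokens)
    return " ".join(text.split())
```

With frame ids [2, 2, 3, 0, 4, 4, 0, 4, 5, 5, 1, 1, 6, 5, 7, 7, 4, 8, 0] this returns "HELLO WORLD". Without the blank between the two L runs, they would collapse into a single L, which is why the model learns to emit blanks inside double letters.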
To simplify deployment, we embed the post-processing step (the argmax over the logits) directly in the model. This way, the client does not have to perform it.
import torch.nn as nn
# Let's embed the post-processing phase with argmax inside our model
class ArgmaxLayer(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, outputs):
        # outputs is the Wav2Vec2ForCTC output; keep the most likely token per frame
        return torch.argmax(outputs.logits, dim=-1)
final_layer = ArgmaxLayer()
# Finally, we chain the model and the argmax layer
full_model = nn.Sequential(model, final_layer)
We can check that the results are the same.
predicted_ids = full_model(input_values)
transcription = processor.batch_decode(predicted_ids)
transcription
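ArgmaxLayer changes the model's output contract: instead of float logits of shape (batch, frames, vocab_size), the full model returns integer token ids of shape (batch, frames), which is much less data to send back from the enclave. The NumPy sketch below illustrates this reduction on random logits. The sizes are made up for illustration; 32 is the vocabulary size we believe facebook/wav2vec2-base-960h uses, but check model.config.vocab_size.

```python
import numpy as np

# Illustrative sizes: one clip, 49 frames, 32 vocabulary entries.
batch, frames, vocab_size = 1, 49, 32
rng = np.random.default_rng(0)
logits = rng.standard_normal((batch, frames, vocab_size)).astype(np.float32)

# Same reduction as ArgmaxLayer: keep the best-scoring token for each frame.
predicted_ids = np.argmax(logits, axis=-1)

print(logits.shape, logits.nbytes)  # (1, 49, 32) 6272
print(predicted_ids.shape)          # (1, 49)
```

In PyTorch, torch.argmax returns int64 ids, so the exported ONNX model outputs 8 bytes per frame instead of 32 float32 values (128 bytes) per frame.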
Now we can export the model to ONNX format, so that we can later upload the ONNX file to our BlindAI server.
torch.onnx.export(
    full_model,
    input_values,
    "wav2vec2_hello_world.onnx",
    export_params=True,
    opset_version=11,
)
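Because no dynamic_axes argument is passed to torch.onnx.export, the exported graph declares its input with exactly the shape of the example input_values, that is, the number of samples in the hello world clip. This is fine here, since the same clip is sent later, but other recordings would need the same length. One workaround, sketched below under that assumption, is to zero-pad or truncate the processed samples to the export length before sending them; fit_to_length is a hypothetical helper, and truncation obviously drops audio. Alternatively, one could export with input_names=["input_values"] and dynamic_axes={"input_values": {1: "samples"}}, provided the BlindAI server accepts dynamically shaped models, which we have not verified.

```python
import numpy as np


def fit_to_length(input_values, target_len):
    """Zero-pad or truncate a (batch, samples) array along the sample axis.

    Hypothetical helper: target_len should equal the sample count the ONNX
    model was exported with (input_values.shape[1] at export time).
    """
    x = np.asarray(input_values, dtype=np.float32)
    current = x.shape[1]
    if current >= target_len:
        # Too long: keep only the first target_len samples.
        return x[:, :target_len]
    # Too short: append zeros, the same padding value the processor uses.
    padding = np.zeros((x.shape[0], target_len - current), dtype=x.dtype)
    return np.concatenate([x, padding], axis=1)
```

A possible usage is export_len = input_values.shape[1] at export time, then torch.from_numpy(fit_to_length(new_input_values.numpy(), export_len)) for a new recording.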
Now we can upload the model to BlindAI Cloud. To do so, you need an API key.
You can get one on the Mithril Cloud.
You might get an error if the name you want to use is already taken, as models are uniquely identified by their model_id. We will implement namespaces soon to avoid that. Meanwhile, you will have to choose a unique ID. The example below uploads your model under a unique name:
import blindai
import uuid
api_key = "YOUR_API_KEY"  # Enter your API key here
model_id = "wav2vec2-" + str(uuid.uuid4())
# Upload the ONNX file to the remote enclave
with blindai.connect(api_key=api_key) as client:
    response = client.upload_model("wav2vec2_hello_world.onnx", model_id=model_id)
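Hard-coding the key as above is convenient in a notebook but easy to leak when the notebook is shared. A common alternative, sketched here, is to read the key from an environment variable. The name BLINDAI_API_KEY is our own choice, not a variable BlindAI reads by itself, and load_api_key is a hypothetical helper.

```python
import os


def load_api_key(var_name="BLINDAI_API_KEY"):
    """Read the API key from an environment variable instead of the notebook."""
    key = os.environ.get(var_name, "").strip()
    if not key:
        raise RuntimeError(
            f"Set the {var_name} environment variable to your BlindAI API key."
        )
    return key
```

The result can then be passed as before, for example api_key = load_api_key() followed by blindai.connect(api_key=api_key).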
Now it's time to check it's working live!
All that remains is to prepare an input for the model, which BlindAI will process inside its secure enclave.
First, we prepare our input data, the hello world audio file.
from transformers import Wav2Vec2Processor
import librosa
# Load the processor only: the model itself now lives in the enclave
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
audio, rate = librosa.load("hello_world.wav", sr=16000)
# Convert the sampled audio into the model's input tensor
input_values = processor(audio, sampling_rate=rate, return_tensors="pt", padding="longest").input_values
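Besides adding a batch dimension, the processor's feature extractor normalizes each clip to zero mean and unit variance when do_normalize is enabled in its configuration, which we believe is the case for facebook/wav2vec2-base-960h (check processor.feature_extractor.do_normalize). The NumPy sketch below reproduces that step to show what input_values contains. The 1e-7 epsilon mirrors the value we understand transformers uses, and normalize_audio is a hypothetical helper, so compare it against the processor's output before relying on it.

```python
import numpy as np


def normalize_audio(audio, eps=1e-7):
    """Zero-mean, unit-variance normalization of a 1-D waveform, plus batch axis."""
    x = np.asarray(audio, dtype=np.float32)
    normed = (x - x.mean()) / np.sqrt(x.var() + eps)
    # Add the leading batch dimension, as return_tensors="pt" does for one clip.
    return normed[np.newaxis, :]
```

If the assumptions above hold, np.allclose(normalize_audio(audio), input_values.numpy(), atol=1e-5) should be True for the hello world clip.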
Now we can send the audio data to be processed confidentially!
with blindai.connect() as client:
    response = client.predict(model_id, input_values)
We can now decode the output:
# Decode the output; unsqueeze(0) adds a leading batch dimension, since batch_decode expects a batch of sequences
processor.batch_decode(response.output.as_torch().unsqueeze(0))