RoBERTa
Use cases
Transformer-based masked language model for text representation and classification (for example, sentiment analysis).
Description
RoBERTa builds on BERT’s masked language modeling strategy and changes key parts of BERT’s pretraining recipe: it removes BERT’s next-sentence prediction objective and trains with much larger mini-batches and learning rates. RoBERTa was also trained on an order of magnitude more data than BERT, and for longer. As a result, RoBERTa representations generalize better to downstream tasks than BERT’s.
Model
| Model | Download | Download (with sample test data) | ONNX version | Opset version | Accuracy | 
|---|---|---|---|---|---|
| RoBERTa-BASE | 499 MB | 295 MB | 1.6 | 11 | 88.5 | 
| RoBERTa-SequenceClassification | 499 MB | 432 MB | 1.6 | 9 | MCC of 0.85 | 
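The RoBERTa-SequenceClassification row reports a Matthews correlation coefficient (MCC) rather than plain accuracy. MCC ranges from -1 to 1 and, for a binary classifier, is computed from the confusion-matrix counts as (TP·TN − FP·FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). The following is a minimal sketch of that formula; the function name is hypothetical, and the evaluation data behind the 0.85 figure is not stated in this README.

```python
import math


def matthews_corrcoef_binary(tp: int, tn: int, fp: int, fn: int) -> float:
    """Matthews correlation coefficient for a binary confusion matrix."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    if denom == 0:
        # An empty row or column leaves MCC undefined; returning 0.0 is a common convention
        return 0.0
    return (tp * tn - fp * fn) / denom
```

A perfect classifier scores 1.0, a perfectly inverted one scores -1.0, and 0 corresponds to chance-level agreement.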
Source
- PyTorch RoBERTa => ONNX RoBERTa
- PyTorch RoBERTa + script changes => ONNX RoBERTa-SequenceClassification
Conversion
The RoBERTa-BASE model was exported with the benchmark script.
A tutorial for converting the RoBERTa-SequenceClassification model is available in the conversion notebook.
An official HuggingFace tool for converting transformers models to ONNX is also available.
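The benchmark script itself is not reproduced here. As a hedged sketch only, and not the script actually used for this model, a RoBERTa-BASE export with `torch.onnx.export` could look like the following. The helper names, the output names `last_hidden_state`/`pooler_output`, and the default file name `roberta-base-11.onnx` are assumptions. Opset 11 matches the table above. The code targets the TorchScript-based exporter; recent PyTorch releases may default to a different exporter.

```python
def roberta_dynamic_axes():
    """Mark batch and sequence dimensions as dynamic: (batch_size, sequence_length)."""
    return {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
        "pooler_output": {0: "batch_size"},
    }


def export_roberta_base(onnx_path="roberta-base-11.onnx", opset=11):
    """Export Hugging Face roberta-base to ONNX (hypothetical helper, downloads weights)."""
    # Imported lazily so this module can be loaded without PyTorch installed
    import torch
    from transformers import RobertaModel, RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    # return_dict=False makes forward() return a plain tuple, which traces cleanly
    model = RobertaModel.from_pretrained("roberta-base", return_dict=False)
    model.eval()

    # Dummy input of shape (1, sequence_length), dtype int64
    dummy = torch.tensor([tokenizer.encode("Hello, World", add_special_tokens=True)])
    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy,),
            onnx_path,
            input_names=["input_ids"],
            output_names=["last_hidden_state", "pooler_output"],
            dynamic_axes=roberta_dynamic_axes(),
            opset_version=opset,
        )
    return onnx_path
```

Declaring both axes dynamic lets the exported graph accept any batch size and sequence length, matching the dynamic `(batch_size, sequence_length)` input described below.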
Inference
We used ONNX Runtime to perform the inference.
A tutorial for running inference on the RoBERTa-SequenceClassification model with onnxruntime is available in the inference notebook.
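The inference notebook is not reproduced here. As a rough sketch of an ONNX Runtime call, the hypothetical helper `run_roberta` below feeds a batch of token ids to an `onnxruntime.InferenceSession` and returns its outputs. It reads the input name from the session instead of hard-coding it. The model file name in the usage comment is an assumption.

```python
import numpy as np


def run_roberta(session, input_ids):
    """Run one forward pass through an ONNX RoBERTa model.

    session:   an onnxruntime.InferenceSession (or any object exposing the same
               get_inputs() / run() methods).
    input_ids: token ids shaped (batch_size, sequence_length).
    Returns the list of output arrays produced by session.run().
    """
    ids = np.asarray(input_ids, dtype=np.int64)  # the model expects int64 token ids
    if ids.ndim != 2:
        raise ValueError(f"input_ids must be 2-D (batch_size, sequence_length), got {ids.shape}")
    input_name = session.get_inputs()[0].name  # expected to be "input_ids"
    # Passing None as the output list asks ONNX Runtime to return every model output
    return session.run(None, {input_name: ids})


# Example usage (hypothetical model file name; requires onnxruntime and transformers):
#   import onnxruntime as ort
#   from transformers import RobertaTokenizer
#   session = ort.InferenceSession("roberta-sequence-classification-9.onnx",
#                                  providers=["CPUExecutionProvider"])
#   tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
#   ort_out = run_roberta(session, [tokenizer.encode("This film is so good")])
```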
Input
input_ids: Indices of input tokens in the vocabulary. It is an int64 tensor of dynamic shape (batch_size, sequence_length), produced by tokenizing the text with RobertaTokenizer.
For the RoBERTa-BASE model: the input is a text string. Example: "Text to encode: Hello, World"
For the RoBERTa-SequenceClassification model: the input is a text string whose sentiment is to be classified. Example: "This film is so good"
Preprocessing
For both the RoBERTa-BASE and RoBERTa-SequenceClassification models, use tokenizer.encode() to encode the input text:
```python
import torch
from transformers import RobertaTokenizer

text = "This film is so good"
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# add_special_tokens=True wraps the text in <s> ... </s>; unsqueeze(0) adds a batch dimension of 1
input_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=True)).unsqueeze(0)
# ONNX Runtime expects NumPy arrays: feed input_ids.numpy() (dtype int64) as "input_ids"
```
Output
For RoBERTa-BASE model:
The model outputs two float32 tensors: the last hidden states, of shape [batch_size, seq_len, 768], and the pooled output, of shape [batch_size, 768].
For RoBERTa-SequenceClassification model:
The model outputs a float32 tensor of shape [batch_size, 2] holding one score per sentiment class (index 0: negative, index 1: positive).
Postprocessing
For RoBERTa-BASE model:
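```python
# ort_out is the list returned by ONNX Runtime's InferenceSession.run()
last_hidden_states = ort_out[0]  # shape [batch_size, seq_len, 768]
```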
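`last_hidden_states` holds one 768-dimensional vector per token. If a single vector per input is needed, two common options (neither prescribed by this README) are taking the vector at the first position, which corresponds to the `<s>` token added by the tokenizer, or averaging over all positions. The helper names below are hypothetical. The sketch assumes unpadded inputs, since the model takes no attention mask.

```python
import numpy as np


def first_token_embeddings(last_hidden_states):
    """[batch_size, seq_len, hidden] -> [batch_size, hidden], using position 0 (<s>)."""
    states = np.asarray(last_hidden_states)
    if states.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {states.shape}")
    return states[:, 0, :]


def mean_pooled_embeddings(last_hidden_states):
    """[batch_size, seq_len, hidden] -> [batch_size, hidden], averaging over all positions.

    Assumes no padding: every position is a real token.
    """
    states = np.asarray(last_hidden_states)
    if states.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {states.shape}")
    return states.mean(axis=1)
```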
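For RoBERTa-SequenceClassification model: print the sentiment prediction:
```python
import numpy as np

# ort_out[0] holds the [batch_size, 2] class scores; take the argmax for the first example
pred = int(np.argmax(ort_out[0], axis=1)[0])
if pred == 0:
    print("Prediction: negative")
elif pred == 1:
    print("Prediction: positive")
```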
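The argmax above yields only a label. If a confidence score is also wanted, a softmax can turn the two scores into probabilities. This assumes the scores are unnormalized logits, as Hugging Face sequence-classification heads produce, although this README does not say so. Below is a minimal sketch for a whole batch. The function names are hypothetical.

```python
import numpy as np

# Index -> label mapping taken from the postprocessing snippet in this README
LABELS = ("negative", "positive")


def softmax(logits, axis=-1):
    """Numerically stable softmax: subtract the max before exponentiating."""
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)


def predict_sentiment(scores):
    """Map a [batch_size, 2] score array to (label, probability) pairs, one per row."""
    scores = np.asarray(scores, dtype=np.float64)
    probs = softmax(scores, axis=1)
    preds = np.argmax(probs, axis=1)
    return [(LABELS[i], float(probs[row, i])) for row, i in enumerate(preds)]
```

For example, `predict_sentiment(ort_out[0])` would return one `(label, probability)` pair per input sentence.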
Dataset
The RoBERTa-BASE model was pretrained on five datasets:
- BookCorpus, a dataset consisting of 11,038 unpublished books;
- English Wikipedia (excluding lists, tables and headers);
- CC-News, a dataset containing 63 million English news articles crawled between September 2016 and February 2019;
- OpenWebText, an open-source recreation of the WebText dataset used to train GPT-2;
- Stories, a dataset containing a subset of CommonCrawl data filtered to match the story-like style of Winograd schemas.
Pretrained RoBERTa-BASE model weights can be downloaded here.
RoBERTa-SequenceClassification model weights can be downloaded here.
Validation accuracy
GLUE (Wang et al., 2019) (dev set, single model, single-task finetuning)
| Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B | 
|---|---|---|---|---|---|---|---|---|
| roberta.base | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2 | 
Metric and benchmarking details are provided by fairseq.
Publication/Attribution
- RoBERTa: A Robustly Optimized BERT Pretraining Approach. Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov
References
- The RoBERTa-SequenceClassification model is converted directly from seldon-models/pytorch
- Accelerate your NLP pipelines using Hugging Face Transformers and ONNX Runtime
Contributors
License
Apache 2.0
