#!/usr/bin/env python3
"""Time repeated forward passes of bert-base-uncased on a fixed dummy batch."""
import time

import torch
from transformers import BertModel


def benchmark(
    batch_size: int = 16,
    seq_len: int = 256,
    iterations: int = 5,
    warmup: int = 1,
    device: str = "cuda",
) -> float:
    """Return wall-clock seconds taken by *iterations* BERT forward passes.

    Args:
        batch_size: Number of sequences in the dummy batch.
        seq_len: Length of each dummy sequence, in tokens.
        iterations: Number of timed forward passes.
        warmup: Untimed forward passes run first, so one-off startup costs
            (CUDA context creation, kernel selection, allocator growth) are
            not counted in the measurement.
        device: Torch device string the model and inputs are placed on.

    Returns:
        Elapsed seconds for the timed passes only.
    """
    dev = torch.device(device)

    model = BertModel.from_pretrained("bert-base-uncased")
    model.to(dev)
    model.eval()  # explicit: make sure dropout is disabled while timing
    model.requires_grad_(False)

    # Every token id is 1; the values don't matter for a throughput test.
    input_ids = torch.ones((batch_size, seq_len), dtype=torch.long, device=dev)

    def _sync() -> None:
        # CUDA kernels run asynchronously. Without a sync the timer would
        # measure launch overhead, not actual compute time.
        if dev.type == "cuda":
            torch.cuda.synchronize(dev)

    with torch.no_grad():
        for _ in range(warmup):
            model(input_ids)
        _sync()

        start = time.perf_counter()
        for _ in range(iterations):
            model(input_ids)
        _sync()
        elapsed = time.perf_counter() - start

    return elapsed


def main() -> None:
    """Run the benchmark with default settings and print elapsed seconds."""
    print(benchmark())


if __name__ == "__main__":
    main()