File size: 314 Bytes
b0a8a80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""Convert a Flax (JAX) GPT-2 checkpoint and its tokenizer to PyTorch format.

Run this script from the checkpoint directory. It does two things:

1. Loads the Flax weights found in MODEL_DIR into a PyTorch
   ``GPT2LMHeadModel`` and saves them back into the same directory.
2. Re-loads and re-saves the tokenizer in that directory.
"""

from transformers import AutoTokenizer, GPT2LMHeadModel

# Directory holding the Flax checkpoint. The converted PyTorch weights and the
# tokenizer files are written back to this same directory. "." is resolved
# relative to the current working directory.
MODEL_DIR = "."

# from_flax=True tells from_pretrained to read the Flax weights and convert
# them into the PyTorch model.
# NOTE(review): presumably requires both flax and torch to be installed — confirm.
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR, from_flax=True)
model.save_pretrained(MODEL_DIR)

# Round-trip the tokenizer through AutoTokenizer and write it back out,
# presumably so the directory ends up with a complete set of tokenizer files
# for PyTorch users — TODO confirm intent.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
tokenizer.save_pretrained(MODEL_DIR)