File size: 1,055 Bytes
45c96b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from transformers import AutoModel, AutoConfig
from modeling_chatglm import ChatGLMForConditionalGeneration
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
import sentencepiece as spm
import torch
def dump_model():
    """Instantiate a ChatGLM model from the local config and dump its
    (randomly initialized — no pretrained weights are loaded) state dict
    to ``pytorch_model.bin`` in the current directory.
    """
    cfg = AutoConfig.from_pretrained('./', trust_remote_code=True)
    chatglm = ChatGLMForConditionalGeneration(cfg)
    torch.save(chatglm.state_dict(), 'pytorch_model.bin')
def dump_small_tokenizer():
    """Shrink ``./origin_tokenizer.model`` to its first 500 pieces and
    write the result to ``tokenizer.model``.

    NOTE(review): the original (Chinese) comment said "randomly select
    500 pieces", but the code actually keeps the FIRST 500 pieces in
    order — confirm this is the intended behavior.
    """
    sp = spm.SentencePieceProcessor()
    sp.Load('./origin_tokenizer.model')
    # Report vocab size and the special-token ids of the source model.
    for stat in (sp.piece_size(), sp.pad_id(), sp.unk_id(),
                 sp.eos_id(), sp.bos_id()):
        print(stat)
    # Re-parse the serialized proto so the piece list is mutable,
    # then drop everything past the first 500 entries.
    trimmed = sp_pb2_model.ModelProto()
    trimmed.ParseFromString(sp.serialized_model_proto())
    del trimmed.pieces[500:]
    with open('tokenizer.model', 'wb') as f:
        f.write(trimmed.SerializeToString())
# Script entry point: currently only regenerates the model weight file.
# Uncomment dump_small_tokenizer() to also rebuild the shrunk tokenizer.
if __name__ == "__main__":
    # dump_small_tokenizer()
    dump_model()
|