|
from transformers import AutoModel, AutoConfig |
|
from modeling_chatglm import ChatGLMForConditionalGeneration |
|
from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model |
|
import sentencepiece as spm |
|
import torch |
|
|
|
|
|
def dump_model():
    """Build a ChatGLM model from the local config and save its weights.

    Loads the configuration from the current directory, constructs a
    freshly-initialized (untrained) ChatGLMForConditionalGeneration, and
    writes its state dict to ``pytorch_model.bin``.
    """
    cfg = AutoConfig.from_pretrained('./', trust_remote_code=True)
    chatglm = ChatGLMForConditionalGeneration(cfg)
    # NOTE(review): weights are random at this point — presumably this is
    # meant to produce a small dummy checkpoint; confirm with callers.
    torch.save(chatglm.state_dict(), 'pytorch_model.bin')
|
|
|
|
|
def dump_small_tokenizer(src_path='./origin_tokenizer.model',
                         dst_path='tokenizer.model',
                         vocab_size=500):
    """Write a truncated copy of a SentencePiece tokenizer model.

    Loads the tokenizer model at ``src_path``, keeps only the first
    ``vocab_size`` vocabulary pieces, and serializes the result to
    ``dst_path``. Defaults match the previous hard-coded values, so
    existing callers are unaffected.

    Args:
        src_path: Path of the source SentencePiece model file.
        dst_path: Path the truncated model is written to.
        vocab_size: Number of leading pieces to keep.
    """
    sp = spm.SentencePieceProcessor()
    sp.Load(src_path)

    # Debug output: source vocabulary size and special-token ids.
    print(sp.piece_size())
    print(sp.pad_id())
    print(sp.unk_id())
    print(sp.eos_id())
    print(sp.bos_id())

    new_sp = sp_pb2_model.ModelProto()
    new_sp.ParseFromString(sp.serialized_model_proto())
    # Slice deletion on the protobuf repeated field replaces the original
    # one-pop-per-iteration while loop — same result, one operation.
    del new_sp.pieces[vocab_size:]
    with open(dst_path, 'wb') as f:
        f.write(new_sp.SerializeToString())
|
|
|
|
|
if __name__ == "__main__":

    # Only the model dump runs by default; dump_small_tokenizer() is
    # defined above but must be invoked manually when the truncated
    # tokenizer file needs regenerating.
    dump_model()
|
|
|
|
|
|