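"""Export a Hugging Face DistilBERT question-answering model to ONNX and
produce dynamically quantized variants with onnxruntime."""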
import torch
from transformers import AutoTokenizer, AutoConfig, DistilBertForQuestionAnswering
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
import os
import logging
from typing import Optional, Dict, Any
class ONNXModelConverter:
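    """Converts a DistilBERT QA checkpoint to ONNX, verifies the exported graph,
    and emits dynamically quantized copies alongside the tokenizer."""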
    def __init__(self, model_name: str, output_dir: str):
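        """Load the tokenizer, config, and weights for `model_name`, creating
        `output_dir` if it does not exist."""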
        self.model_name = model_name
        self.output_dir = output_dir
        self.setup_logging()
        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"Loading tokenizer {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        self.logger.info(f"Loading model config {model_name}...")
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

        self.logger.info(f"Loading model {model_name}...")
        try:
            self.model = DistilBertForQuestionAnswering.from_pretrained(
                model_name,
                config=config,
                trust_remote_code=True,
                torch_dtype=torch.float32  # export in full precision
            )
        except Exception as e:  # PyTorch weights missing or unreadable
            self.logger.info(f"PyTorch weights unavailable ({e}); trying TensorFlow weights")
            try:
                self.model = DistilBertForQuestionAnswering.from_pretrained(
                    model_name,
                    config=config,
                    trust_remote_code=True,
                    from_tf=True  # load from TensorFlow weights
                )
            except Exception as e:
                self.logger.error(f"Failed to load the model: {e}")
                raise  # re-raise the exception after logging

        self.model.eval()
    def setup_logging(self):
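        """Attach a stream handler with a timestamped format to this class's logger."""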
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        if not self.logger.handlers:  # avoid duplicate handlers if instantiated more than once
            handler = logging.StreamHandler()
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
    def prepare_dummy_inputs(self):
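        """Tokenize a short sample sentence to use as tracing input for the ONNX export."""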
        dummy_input = self.tokenizer(
            "Hello, how are you?",
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        )
        dummy_input.pop('token_type_ids', None)  # DistilBERT does not use token type ids
        return {
            'input_ids': dummy_input['input_ids'],
            'attention_mask': dummy_input['attention_mask'],
        }
    def export_to_onnx(self):
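        """Export the model to ONNX with dynamic batch and sequence axes; returns the output path."""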
        output_path = os.path.join(self.output_dir, "model.onnx")
        inputs = self.prepare_dummy_inputs()

        dynamic_axes = {
            'input_ids': {0: 'batch_size', 1: 'sequence_length'},
            'attention_mask': {0: 'batch_size', 1: 'sequence_length'},
            'start_logits': {0: 'batch_size', 1: 'sequence_length'},
            'end_logits': {0: 'batch_size', 1: 'sequence_length'},
        }

        class ModelWrapper(torch.nn.Module):
            """Unpacks the Hugging Face output object into a plain tuple for torch.onnx.export."""
            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, input_ids, attention_mask):
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
                return outputs.start_logits, outputs.end_logits

        wrapped_model = ModelWrapper(self.model)

        try:
            torch.onnx.export(
                wrapped_model,
                (inputs['input_ids'], inputs['attention_mask']),
                output_path,
                export_params=True,
                opset_version=14,  # or another suitable opset
                do_constant_folding=True,
                input_names=['input_ids', 'attention_mask'],
                output_names=['start_logits', 'end_logits'],
                dynamic_axes=dynamic_axes,
                verbose=False
            )
            self.logger.info(f"Model exported to {output_path}")
            return output_path
        except Exception as e:
            self.logger.error(f"ONNX export failed: {str(e)}")
            raise
    def verify_model(self, model_path: str):
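        """Run the ONNX checker on the exported graph; returns True on success."""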
        try:
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            self.logger.info("ONNX model verification successful")
            return True
        except Exception as e:
            self.logger.error(f"Model verification failed: {str(e)}")
            return False
    def quantize_model(self, model_path: str):
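        """Produce dynamically quantized copies of the model for several weight types;
        returns the paths that were written successfully."""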
        weight_types = {
            'int4': QuantType.QInt4,
            'int8': QuantType.QInt8,
            'uint4': QuantType.QUInt4,
            'uint8': QuantType.QUInt8,
            'int16': QuantType.QInt16,
            'uint16': QuantType.QUInt16,
        }
        all_quantized_paths = []
        for name, qtype in weight_types.items():
            quantized_path = os.path.join(self.output_dir, "model_" + name + ".onnx")
            try:
                quantize_dynamic(
                    model_path,
                    quantized_path,
                    weight_type=qtype
                )
                self.logger.info(f"Model quantized ({name}) and saved to {quantized_path}")
                all_quantized_paths.append(quantized_path)
            except Exception as e:
                # Some weight types (e.g. the 4-bit ones) may not be supported by
                # quantize_dynamic in every onnxruntime version; skip them instead
                # of aborting the whole run.
                self.logger.error(f"Quantization ({name}) failed: {str(e)}")
        return all_quantized_paths
    def convert(self) -> Dict[str, Any]:
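        """Run export, verification, and quantization, then save the tokenizer;
        returns the paths of all produced artifacts."""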
        try:
            onnx_path = self.export_to_onnx()
            if self.verify_model(onnx_path):
                quantized_paths = self.quantize_model(onnx_path)

                tokenizer_path = os.path.join(self.output_dir, "tokenizer")
                self.tokenizer.save_pretrained(tokenizer_path)
                self.logger.info(f"Tokenizer saved to {tokenizer_path}")

                return {
                    'onnx_model': onnx_path,
                    'quantized_models': quantized_paths,
                    'tokenizer': tokenizer_path
                }
            else:
                raise RuntimeError("Model verification failed")
        except Exception as e:
            self.logger.error(f"Conversion process failed: {str(e)}")
            raise
if __name__ == "__main__":
    MODEL_NAME = "Docty/question_and_answer"  # or any other DistilBERT QA checkpoint
    OUTPUT_DIR = "onnx"

    try:
        converter = ONNXModelConverter(MODEL_NAME, OUTPUT_DIR)
        results = converter.convert()
        print("\nConversion completed successfully!")
        print(f"ONNX model path: {results['onnx_model']}")
        print(f"Quantized model paths: {results['quantized_models']}")
        print(f"Tokenizer path: {results['tokenizer']}")
    except Exception as e:
        print(f"Conversion failed: {str(e)}")