"""Singleton wrapper around the Mistral-7B-Instruct model for text
generation and embedding extraction."""

import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

class MistralModel:
    """Singleton wrapper for the Mistral model with caching and optimization."""

    _instance = None
    _model = None
    _tokenizer = None

    def __new__(cls):
        # Singleton: every instantiation returns the same wrapper object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Load the model and tokenizer only once; later instantiations
        # reuse the class-level cache.
        if MistralModel._model is None:
            self._initialize_model()

    def _initialize_model(self):
        """Initialize the Mistral model and tokenizer with 8-bit quantization."""
        print("Loading Mistral model...")

        model_id = "mistralai/Mistral-7B-Instruct-v0.2"

        MistralModel._tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            token=HUGGINGFACE_TOKEN,
            use_fast=False,
        )

        # Passing load_in_8bit directly to from_pretrained is deprecated;
        # wrap it in a BitsAndBytesConfig instead.
        MistralModel._model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=HUGGINGFACE_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )

        print("Mistral model loaded successfully!")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Generate a response from Mistral for an instruction-style prompt."""
        # Mistral-Instruct prompt format. The tokenizer prepends <s> (BOS)
        # during encoding, so it is omitted here to avoid a duplicate BOS token.
        formatted_prompt = f"[INST] {prompt} [/INST]"

        inputs = MistralModel._tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        )

        # Move inputs to whichever device device_map="auto" placed the model on.
        device = next(MistralModel._model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = MistralModel._model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=MistralModel._tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the echoed prompt.
        response = MistralModel._tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        return response.strip()

    def generate_embedding(self, text: str) -> torch.Tensor:
        """Generate an embedding for text by mean-pooling the last hidden state."""
        inputs = MistralModel._tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )

        device = next(MistralModel._model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = MistralModel._model(**inputs, output_hidden_states=True)

        # Mean-pool over the sequence dimension. This is fine for a single
        # unpadded sequence; padded batches would need an attention-mask-weighted
        # mean instead.
        embeddings = outputs.hidden_states[-1].mean(dim=1)

        return embeddings
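

# Minimal usage sketch, not part of the original module: assumes the
# HUGGINGFACE_TOKEN environment variable is set and a GPU with enough
# memory for the 8-bit model is available.
if __name__ == "__main__":
    model = MistralModel()

    # The singleton guarantees that repeated instantiations share weights.
    assert model is MistralModel()

    answer = model.generate("Explain the singleton pattern in one sentence.")
    print(answer)

    # Embeddings come back as (1, hidden_size) tensors; cosine similarity
    # gives a rough semantic comparison between two texts.
    emb_a = model.generate_embedding("The cat sat on the mat.")
    emb_b = model.generate_embedding("A feline rested on the rug.")
    similarity = torch.nn.functional.cosine_similarity(emb_a, emb_b).item()
    print(f"Cosine similarity: {similarity:.3f}")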