"""
Mistral Model Wrapper for easy integration
"""
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face access token used when downloading the model and tokenizer
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")


class MistralModel:
    """Wrapper for the Mistral model with caching and optimization."""

    _instance = None
    _model = None
    _tokenizer = None

    def __new__(cls):
        # Singleton: every construction returns the same wrapper instance
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Load the model only once, no matter how many times the wrapper is created
        if MistralModel._model is None:
            self._initialize_model()

    def _initialize_model(self):
        """Initialize the Mistral model and tokenizer with optimizations."""
        print("Loading Mistral model...")
        model_id = "mistralai/Mistral-7B-Instruct-v0.2"

        # Load tokenizer
        MistralModel._tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            token=HUGGINGFACE_TOKEN,
            use_fast=False,
        )

        # Load model with optimizations
        MistralModel._model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=HUGGINGFACE_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto",
            load_in_8bit=True,  # 8-bit quantization for memory efficiency (requires bitsandbytes)
        )
        print("Mistral model loaded successfully!")

    def generate(
        self,
        prompt: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.95,
    ) -> str:
        """Generate a response from Mistral."""
        # Format the prompt with the Mistral instruction template; the tokenizer
        # prepends the <s> BOS token itself, so it is not written out here.
        formatted_prompt = f"[INST] {prompt} [/INST]"

        # Tokenize
        inputs = MistralModel._tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        )

        # Move inputs to the model's device
        device = next(MistralModel._model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate
        with torch.no_grad():
            outputs = MistralModel._model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=MistralModel._tokenizer.eos_token_id,
            )

        # Decode only the newly generated tokens, skipping the prompt
        response = MistralModel._tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        return response.strip()

    def generate_embedding(self, text: str) -> torch.Tensor:
        """Generate an embedding for the given text."""
        inputs = MistralModel._tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        device = next(MistralModel._model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = MistralModel._model(**inputs, output_hidden_states=True)

        # Mean-pool the last hidden state over tokens to get one embedding vector
        embeddings = outputs.hidden_states[-1].mean(dim=1)
        return embeddings
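

# Example usage: a minimal sketch, assuming a GPU with enough memory, the
# bitsandbytes package, and a valid HUGGINGFACE_TOKEN environment variable.
# The prompt and sentence texts below are illustrative only.
if __name__ == "__main__":
    model = MistralModel()

    # Instruction-style generation
    answer = model.generate("Explain what 8-bit quantization does.", max_length=128)
    print(answer)

    # Mean-pooled embeddings can be compared with cosine similarity
    emb_a = model.generate_embedding("Paris is the capital of France.")
    emb_b = model.generate_embedding("The capital of France is Paris.")
    similarity = torch.nn.functional.cosine_similarity(emb_a, emb_b)
    print(f"Cosine similarity: {similarity.item():.3f}")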