from typing import Any, List, Mapping, Optional

import torch
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import Field
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from transformers.models.mistral.modeling_mistral import MistralForCausalLM


class CustomLLMMistral(LLM):
    """LangChain LLM wrapper around a locally loaded, 4-bit quantized Mistral model."""

    model: MistralForCausalLM = Field(...)
    tokenizer: LlamaTokenizerFast = Field(...)

    def __init__(self):
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"
        # Load the weights in 4-bit so the 7B model fits on a single GPU.
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            quantization_config=quantization_config,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # LLM is a pydantic model: passing the fields to super().__init__
        # assigns self.model and self.tokenizer, so no extra assignment is needed.
        super().__init__(model=model, tokenizer=tokenizer)

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
        messages = [
            {"role": "user", "content": prompt},
        ]
        # Wrap the prompt in Mistral's [INST] ... [/INST] chat template.
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)
        generated_ids = self.model.generate(
            model_inputs,
            max_new_tokens=512,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            top_k=4,
            temperature=0.7,
        )
        decoded = self.tokenizer.batch_decode(generated_ids)
        # Keep only the text after the instruction block and drop the EOS token.
        output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()
        # Truncate at the first stop word, as LangChain expects.
        if stop is not None:
            for word in stop:
                output = output.split(word)[0].strip()
        # The model sometimes leaves a Markdown code fence unclosed; pad with
        # backticks until the output ends with ```.
        while not output.endswith("```"):
            output += "`"
        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.model}