import os
import signal

import torch
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

from .Model import Model


def handle_timeout(sig, frame):
    raise TimeoutError('took too long')


# Raise TimeoutError whenever a pending SIGALRM fires (armed in query below).
signal.signal(signal.SIGALRM, handle_timeout)
class Llama(Model):
    def __init__(self, config, device="cuda:0"):
        super().__init__(config)
        self.device = device
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        api_pos = int(config["api_key_info"]["api_key_use"])
        self.hf_token = config["api_key_info"]["api_keys"][api_pos] or os.getenv("HF_TOKEN")
        self.tokenizer = AutoTokenizer.from_pretrained(self.name, token=self.hf_token)
        self.model = None  # Loaded lazily by _load_model_if_needed
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
    def _load_model_if_needed(self):
        # Lazy initialisation so the heavy weights are only pulled in on first use.
        if self.model is None:
            model = AutoModelForCausalLM.from_pretrained(
                self.name,
                torch_dtype=torch.bfloat16,
                token=self.hf_token,
                device_map="auto",  # or omit entirely to default to CPU
            )
            self.model = model
        return self.model
    @spaces.GPU
    def query(self, msg, max_tokens=128000):
        # The @spaces.GPU decorator allocates a ZeroGPU device for this call,
        # so the lazily loaded model is moved to CUDA here.
        model = self._load_model_if_needed().to("cuda")
        messages = self.messages
        messages[1]["content"] = msg
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            truncation=True
        ).to(model.device)
        attention_mask = torch.ones_like(input_ids)
        try:
            # Abort generation if it runs for more than 60 seconds.
            signal.alarm(60)
            output_tokens = model.generate(
                input_ids,
                max_length=max_tokens,
                attention_mask=attention_mask,
                eos_token_id=self.terminators,
                top_k=50,
                do_sample=False
            )
            signal.alarm(0)
        except TimeoutError:
            print("time out")
            return "time out"
        # Decode only the newly generated tokens, skipping the prompt.
        return self.tokenizer.decode(output_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True)
    def get_prompt_length(self, msg):
        # Token count of the fully templated prompt.
        model = self._load_model_if_needed()
        messages = self.messages
        messages[1]["content"] = msg
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(model.device)
        return len(input_ids[0])
    def cut_context(self, msg, max_length):
        # Truncate msg to at most max_length tokens and decode back to text.
        tokens = self.tokenizer.encode(msg, add_special_tokens=True)
        truncated_tokens = tokens[:max_length]
        truncated_text = self.tokenizer.decode(truncated_tokens, skip_special_tokens=True)
        return truncated_text
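

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original Space. The config layout
    # below is an assumption inferred from the keys read in __init__, and the
    # base Model class is assumed to populate self.name (the HF model id) and
    # self.messages (a system + user chat template). Because of the relative
    # import above, run this as a module (python -m <package>.<this_module>).
    example_config = {
        "model_info": {"name": "meta-llama/Meta-Llama-3-8B-Instruct"},  # hypothetical key / model id
        "params": {"max_output_tokens": 256},
        "api_key_info": {"api_key_use": 0, "api_keys": [""]},  # empty string falls back to HF_TOKEN
    }
    llm = Llama(example_config)
    print(llm.get_prompt_length("Hello, world"))
    print(llm.query("Hello, world"))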