import os
import signal

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

from .Model import Model


# SIGALRM-based timeout: generation calls that exceed the alarm window raise a
# TimeoutError (see Llama.query). SIGALRM is only available on Unix platforms.
def handle_timeout(sig, frame):
    raise TimeoutError('took too long')
signal.signal(signal.SIGALRM, handle_timeout)

class Llama(Model):
    def __init__(self, config, device="cuda:0"):
        super().__init__(config)
        self.device = device
        self.max_output_tokens = int(config["params"]["max_output_tokens"])
        # Use the configured API key if present, otherwise fall back to HF_TOKEN.
        api_pos = int(config["api_key_info"]["api_key_use"])
        self.hf_token = config["api_key_info"]["api_keys"][api_pos] or os.getenv("HF_TOKEN")
        # `token` replaces the deprecated `use_auth_token` argument.
        self.tokenizer = AutoTokenizer.from_pretrained(self.name, token=self.hf_token)
        self.model = None  # Loaded lazily on first use (see _load_model_if_needed).
        # Stop generation at either the EOS token or Llama 3's end-of-turn token.
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

    def _load_model_if_needed(self):
        # Lazy-load the model so the class can be constructed without a GPU attached.
        if self.model is None:
            model = AutoModelForCausalLM.from_pretrained(
                self.name,
                torch_dtype=torch.bfloat16,
                token=self.hf_token,
                device_map="auto",  # or omit entirely to default to CPU
            )
            self.model = model
        return self.model

    @spaces.GPU
    def query(self, msg, max_tokens=128000):
        model = self._load_model_if_needed().to("cuda")
        # self.messages holds the chat template (system + user turns); the user
        # turn is overwritten in place with the incoming message.
        messages = self.messages
        messages[1]["content"] = msg

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            truncation=True
        ).to(model.device)

        attention_mask = torch.ones(input_ids.shape, device=model.device)

        try:
            # Abort generation if it runs longer than 60 seconds (SIGALRM handler above).
            signal.alarm(60)
            output_tokens = model.generate(
                input_ids,
                max_length=max_tokens,  # total length budget (prompt + generated tokens)
                attention_mask=attention_mask,
                eos_token_id=self.terminators,
                top_k=50,  # no effect here: do_sample=False forces greedy decoding
                do_sample=False
            )
            signal.alarm(0)
        except TimeoutError:
            print("time out")
            return "time out"

        # Decode only the newly generated tokens, skipping the prompt.
        return self.tokenizer.decode(output_tokens[0][input_ids.shape[-1]:], skip_special_tokens=True)

    def get_prompt_length(self, msg):
        # Token count of the full templated chat prompt (system + user turns).
        model = self._load_model_if_needed()
        messages = self.messages
        messages[1]["content"] = msg
        input_ids = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(model.device)
        return len(input_ids[0])

    def cut_context(self, msg, max_length):
        # Truncate a raw text message to at most `max_length` tokens.
        tokens = self.tokenizer.encode(msg, add_special_tokens=True)
        truncated_tokens = tokens[:max_length]
        truncated_text = self.tokenizer.decode(truncated_tokens, skip_special_tokens=True)
        return truncated_text
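

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module). The config
# layout below is assumed from the keys read in __init__ ("params",
# "api_key_info") and from the `self.name` / `self.messages` attributes the
# Model base class is expected to provide; the model id is illustrative only.
# ---------------------------------------------------------------------------
# config = {
#     "model_info": {"name": "meta-llama/Meta-Llama-3-8B-Instruct"},
#     "params": {"max_output_tokens": 256},
#     "api_key_info": {"api_key_use": 0, "api_keys": [""]},
# }
# llm = Llama(config)
# print(llm.get_prompt_length("How long is this prompt?"))
# print(llm.query("Explain lazy model loading in one sentence.", max_tokens=4096))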