try:
    from extensions.telegram_bot.source.generators.abstract_generator import AbstractGenerator
except ImportError:
    from source.generators.abstract_generator import AbstractGenerator

import glob
import os
import sys
from typing import List

import torch

# Add the bundled exllama directory to sys.path so its internal imports resolve.
sys.path.append(os.path.join(os.path.split(__file__)[0], "exllama"))

from source.generators.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
from source.generators.exllama.tokenizer import ExLlamaTokenizer
from source.generators.exllama.generator import ExLlamaGenerator


class Generator(AbstractGenerator):
    # The path to the LLM model directory is stored in self.model_directory (set in __init__).
    model_change_allowed = False  # whether the model can be changed without stopping the generator
    preset_change_allowed = True  # whether the preset_file can be changed on the fly
    def __init__(self, model_path: str, n_ctx=4096, seed=0, n_gpu_layers=0):
        self.n_ctx = n_ctx
        self.seed = seed
        self.n_gpu_layers = n_gpu_layers
        self.model_directory = model_path
        # Locate the files we need within the model directory
        self.tokenizer_path = os.path.join(self.model_directory, "tokenizer.model")
        self.model_config_path = os.path.join(self.model_directory, "config.json")
        self.st_pattern = os.path.join(self.model_directory, "*.safetensors")
        self.model_path = glob.glob(self.st_pattern)[0]
        # Create config, model, tokenizer and generator
        self.ex_config = ExLlamaConfig(self.model_config_path)  # create config from config.json
        self.ex_config.model_path = self.model_path  # supply path to the model weights file
        self.ex_config.max_seq_len = n_ctx
        self.ex_config.max_input_len = n_ctx
        self.ex_config.max_attention_size = n_ctx**2
        self.model = ExLlama(self.ex_config)  # create ExLlama instance and load the weights
        self.tokenizer = ExLlamaTokenizer(self.tokenizer_path)  # create tokenizer from tokenizer model file
        self.cache = ExLlamaCache(self.model, max_seq_len=n_ctx)  # create cache for inference
        self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)  # create generator
    def generate_answer(
        self, prompt, generation_params, eos_token, stopping_strings, default_answer: str, turn_template="", **kwargs
    ):
        # If generation fails, the default answer is returned
        answer = default_answer
        try:
            # Configure generator sampling settings
            self.generator.disallow_tokens([self.tokenizer.eos_token_id])
            self.generator.settings.token_repetition_penalty_max = generation_params["repetition_penalty"]
            self.generator.settings.temperature = generation_params["temperature"]
            self.generator.settings.top_p = generation_params["top_p"]
            self.generator.settings.top_k = generation_params["top_k"]
            self.generator.settings.typical = generation_params["typical_p"]
            # Set a fresh random seed for this generation
            random_data = os.urandom(4)
            random_seed = int.from_bytes(random_data, byteorder="big")
            torch.manual_seed(random_seed)
            torch.cuda.manual_seed(random_seed)
            # Produce a simple generation
            answer = self.generate_custom(
                prompt, stopping_strings=stopping_strings, max_new_tokens=generation_params["max_new_tokens"]
            )
            # Strip the prompt from the returned text, keeping only the newly generated part
            answer = answer[len(prompt) :]
        except Exception as exception:
            print("generator_wrapper get answer error:", str(exception) + str(exception.args))
        return answer
    def generate_custom(self, prompt, stopping_strings: List, max_new_tokens=128):
        self.generator.end_beam_search()
        # Tokenize the prompt and prime the generator cache with it
        ids, mask = self.tokenizer.encode(prompt, return_mask=True, max_seq_len=self.model.config.max_seq_len)
        self.generator.gen_begin(ids, mask=mask)
        max_new_tokens = min(max_new_tokens, self.model.config.max_seq_len - ids.shape[1])
        eos = torch.zeros((ids.shape[0],), dtype=torch.bool)
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token(mask=mask)
            for j in range(token.shape[0]):
                if token[j, 0].item() == self.tokenizer.eos_token_id:
                    eos[j] = True
            text = self.tokenizer.decode(
                self.generator.sequence[0] if self.generator.sequence.shape[0] == 1 else self.generator.sequence
            )
            # Stop early if the decoded text ends with one of the stopping strings
            for stopping in stopping_strings:
                if text.endswith(stopping):
                    text = text[: -len(stopping)]
                    return text
            if eos.all():
                break
        text = self.tokenizer.decode(
            self.generator.sequence[0] if self.generator.sequence.shape[0] == 1 else self.generator.sequence
        )
        return text
    def tokens_count(self, text: str):
        encoded = self.tokenizer.encode(text, max_seq_len=20480)
        return len(encoded[0])

    def get_model_list(self):
        bins = []
        for i in os.listdir("../../models"):
            if i.endswith(".bin"):
                bins.append(i)
        return bins

    def load_model(self, model_file: str):
        return None
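

# A minimal usage sketch (not part of the original module). The model directory,
# generation parameters, and stopping strings below are illustrative placeholders;
# in the bot they come from its configuration and preset files.
if __name__ == "__main__":
    generator = Generator(model_path="models/llama-7b-4bit", n_ctx=2048)  # hypothetical model directory
    params = {
        "repetition_penalty": 1.15,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "typical_p": 1.0,
        "max_new_tokens": 128,
    }
    reply = generator.generate_answer(
        prompt="User: Hello!\nBot:",
        generation_params=params,
        eos_token=None,  # unused by this generator, accepted for interface compatibility
        stopping_strings=["\nUser:"],
        default_answer="...",
    )
    print(reply)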