import logging
import re
from threading import Thread
from typing import Any, Dict, Optional

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompts longer than this are truncated from the left before generation.
MAX_INPUT_TOKEN_LENGTH = 16000

# Custom EndpointHandler serving a PEFT (LoRA-style) adapter on top of a causal language model.
class EndpointHandler:

    def __init__(self, path: str = ""):
        # Read the adapter config to find the base model, load the base model in 8-bit,
        # then attach the fine-tuned adapter weights.
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,
            trust_remote_code=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Llama-2 style chat markers: user turns wrapped in [INST] ... [/INST],
        # the system prompt wrapped in <<SYS>> ... <</SYS>>.
        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"

        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Chat-style inputs.
        message = data.get("message")
        chat_history = data.get("chat_history", [])
        system_prompt = data.get("system_prompt", "")

        # Optional fields used by generate() to build a system prompt when none is provided.
        instruction = data.get("instruction")
        conclusions = data.get("conclusions")
        context = data.get("context")

        # Sampling parameters.
        max_new_tokens = data.get("max_new_tokens", 1024)
        temperature = data.get("temperature", 0.6)
        top_p = data.get("top_p", 0.9)
        top_k = data.get("top_k", 50)
        repetition_penalty = data.get("repetition_penalty", 1.2)

        if message is None or system_prompt is None:
            raise ValueError("Missing required parameters: 'message' and 'system_prompt'.")

        output = generate(
            tokenizer=self.tokenizer,
            model=self.model,
            message=message,
            chat_history=chat_history,
            system_prompt=system_prompt,
            instruction=instruction,
            conclusions=conclusions,
            context=context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )

        LOGGER.info(f"Generated text: {output}")
        return output

def generate(
    tokenizer,
    model,
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    instruction: Optional[str] = None,
    conclusions: Optional[list[tuple[str, str]]] = None,
    context: Optional[list[str]] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    end_sequences: list[str] = ["[INST]", "[/INST]", "\n"],
) -> dict:
LOGGER.info(f"instruction: {instruction}") |
|
LOGGER.info(f"conclusions: {conclusions}") |
|
LOGGER.info(f"context: {context}") |
|
|
|
if not system_prompt and instruction is not None and conclusions is not None and context is not None: |
|
system_prompt = "Instruction: {}\nConclusions:\n".format(instruction) |
|
for idx, (conclusion, conclusion_key) in enumerate(conclusions): |
|
system_prompt += "{}: {}\n".format(conclusion, conclusion_key) |
|
system_prompt += "\nContext:\n" |
|
for idx, ctx in enumerate(context): |
|
system_prompt += "{}: [{}]\n".format(ctx, idx + 1) |
|
|
|
|
|
conversation = [{"role": "system", "content": system_prompt}] |
|
for user, assistant in chat_history: |
|
if user is not None: |
|
conversation.extend([{"role": "user", "content": user}]) |
|
conversation.extend([{"role": "assistant", "content": assistant}]) |
|
conversation.append({"role": "user", "content": message}) |
|
|
|
|
|
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    # Keep only the most recent MAX_INPUT_TOKEN_LENGTH tokens if the prompt is too long.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation in a background thread so the streamer can be consumed as tokens
    # arrive (the documented usage pattern for TextIteratorStreamer).
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    # Accumulate streamed text and stop as soon as one of the end sequences appears.
    outputs = []
    generated_text = ""
    for text in streamer:
        outputs.append(text)
        generated_text = "".join(outputs)
        for end_sequence in end_sequences:
            if end_sequence in generated_text:
                generated_text = generated_text.replace(end_sequence, "")
                return parse(generated_text, conclusions, end_sequences)
    return parse(generated_text, conclusions, end_sequences)


def parse(generated_text: str, conclusions: Optional[list[tuple[str, str]]], end_sequences: list[str]) -> dict:
    conclusion_found = None

    # Strip any remaining end-sequence markers and surrounding whitespace.
    for end_sequence in end_sequences:
        generated_text = generated_text.replace(end_sequence, "")
    generated_text = generated_text.strip()

    # Look for one of the known conclusion labels in the output and remove it from the text.
    if conclusions:
        for conclusion_key, _ in conclusions:
            if conclusion_key in generated_text:
                conclusion_found = conclusion_key
                generated_text = generated_text.replace(conclusion_key, "")

    # Collect context citations of the form [1], [2], ... as integers.
    context_pattern = r"\[\d+\]"
    context_matches = re.findall(context_pattern, generated_text)
    context_numbers = [int(match.strip("[]")) for match in context_matches]

    return {
        "generated_text": generated_text,
        "conclusion": conclusion_found,
        "context": context_numbers,
    }
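

# Illustrative local smoke test (not part of the endpoint contract): a minimal sketch of
# how the handler might be invoked, assuming a PEFT adapter checkout at ./adapter
# (hypothetical path) and a payload shaped like the fields read in __call__ above.
if __name__ == "__main__":
    handler = EndpointHandler(path="./adapter")  # hypothetical adapter directory
    sample_payload = {
        "message": "Which conclusion is best supported?",
        "instruction": "Pick the supported conclusion and cite the context passages you used.",
        "conclusions": [("A", "Demand increased"), ("B", "Demand decreased")],
        "context": ["Sales rose 12% year over year.", "Prices were unchanged."],
        "chat_history": [],
        "max_new_tokens": 256,
    }
    result = handler(sample_payload)
    # Expected shape: {"generated_text": ..., "conclusion": ..., "context": [...]}
    print(result)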