import logging
import re
from typing import Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from peft import PeftConfig, PeftModel
import torch
LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_INPUT_TOKEN_LENGTH = 16000
class EndpointHandler:
    def __init__(self, path=""):
        # Resolve the base model from the adapter config, then load it in 8-bit across available devices
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path, load_in_8bit=True, trust_remote_code=True, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Llama-2-style chat template: system prompt wrapped in <<SYS>> tags, user turns in [INST] ... [/INST]
        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"
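        # Example of what this template produces (assumed Llama-2-style special tokens, where
        # bos_token is "<s>" and eos_token is "</s>"): the conversation
        #   [{"role": "system", "content": "Be brief."},
        #    {"role": "user", "content": "Hi"},
        #    {"role": "assistant", "content": "Hello"}]
        # renders as:
        #   <s><<SYS>>\nBe brief.\n<</SYS>>\n\n<s>[INST] Hi [/INST] Hello </s>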
        # Load the LoRA adapter on top of the 8-bit base model
        self.model = PeftModel.from_pretrained(model, path)
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Extract required parameters from the request payload
        message = data.get("message")
        chat_history = data.get("chat_history", [])
        system_prompt = data.get("system_prompt", "")
        # Optional parameters used to build a system prompt when none is given
        instruction = data.get("instruction")
        conclusions = data.get("conclusions")
        context = data.get("context")
        # Generation parameters with default values
        max_new_tokens = data.get("max_new_tokens", 1024)
        temperature = data.get("temperature", 0.6)
        top_p = data.get("top_p", 0.9)
        top_k = data.get("top_k", 50)
        repetition_penalty = data.get("repetition_penalty", 1.2)
        if message is None or system_prompt is None:
            raise ValueError("Missing required parameters: 'message' and 'system_prompt' must not be None.")
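        # Illustrative request payload (field names follow the lookups above; values are examples only):
        #   {
        #     "message": "Does the report support claim X?",
        #     "chat_history": [["previous user turn", "previous assistant reply"]],
        #     "instruction": "Answer with one of the conclusions and cite context numbers.",
        #     "conclusions": [["Supported", "the claim is backed by the context"]],
        #     "context": ["First source passage", "Second source passage"],
        #     "max_new_tokens": 256
        #   }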
        # Call the generate function
        output = generate(
            tokenizer=self.tokenizer,
            model=self.model,
            message=message,
            chat_history=chat_history,
            system_prompt=system_prompt,
            instruction=instruction,
            conclusions=conclusions,
            context=context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )
        # Log and return the parsed result
        prediction = output
        LOGGER.info(f"Generated text: {prediction}")
        return prediction
def generate(
    tokenizer,
    model,
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    instruction: str = None,
    conclusions: list[tuple[str, str]] = None,
    context: list[str] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    end_sequences: list[str] = ["[INST]", "[/INST]", "\n"],
) -> dict:
    LOGGER.info(f"instruction: {instruction}")
    LOGGER.info(f"conclusions: {conclusions}")
    LOGGER.info(f"context: {context}")
    # If no system_prompt is provided, construct one from instruction, conclusions, and context
    if not system_prompt and instruction is not None and conclusions is not None and context is not None:
        system_prompt = "Instruction: {}\nConclusions:\n".format(instruction)
        for conclusion, conclusion_key in conclusions:
            system_prompt += "{}: {}\n".format(conclusion, conclusion_key)
        system_prompt += "\nContext:\n"
        for idx, ctx in enumerate(context):
            system_prompt += "{}: [{}]\n".format(ctx, idx + 1)
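    # Illustrative example of the constructed system prompt, given
    #   instruction = "Pick a conclusion and cite sources"
    #   conclusions = [("Supported", "the evidence backs the claim")]
    #   context = ["Passage A", "Passage B"]
    #
    #   Instruction: Pick a conclusion and cite sources
    #   Conclusions:
    #   Supported: the evidence backs the claim
    #
    #   Context:
    #   Passage A: [1]
    #   Passage B: [2]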
    # Build the conversation from the system prompt, prior turns, and the new message
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        if user is not None:
            conversation.extend([{"role": "user", "content": user}])
        conversation.extend([{"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
    # Tokenize the conversation, keeping only the most recent MAX_INPUT_TOKEN_LENGTH tokens
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)
    # Create a TextIteratorStreamer instance
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=False
    )
    # Generate the response using the TextIteratorStreamer
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Note: generate() is called synchronously, so the streamer's queue is fully populated
    # before the loop below drains it; the early return trims the output at the first end
    # sequence but does not stop generation itself.
    model.generate(**generate_kwargs)
    outputs = []
    generated_text = ""
    for text in streamer:
        outputs.append(text)
        generated_text = "".join(outputs)
        for end_sequence in end_sequences:
            if end_sequence in generated_text:
                generated_text = generated_text.replace(end_sequence, "")
                return parse(generated_text, conclusions, end_sequences)
    return parse(generated_text, conclusions, end_sequences)
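# For token-by-token streaming instead of waiting for generate() to finish, the usual
# TextIteratorStreamer pattern runs generation in a background thread; a minimal sketch,
# not wired into the handler above:
#
#   from threading import Thread
#   thread = Thread(target=model.generate, kwargs=generate_kwargs)
#   thread.start()
#   for text in streamer:
#       ...  # consume partial text as it is produced
#   thread.join()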
def parse(generated_text: str, conclusions: list[tuple[str, str]], end_sequences: list[str]) -> dict:
    # Initialize result fields
    conclusion_found = None
    context_numbers = []
    # Remove any remaining end sequences and surrounding whitespace
    for end_sequence in end_sequences:
        generated_text = generated_text.replace(end_sequence, "")
    generated_text = generated_text.strip()
    # If a conclusion key appears in the generated text, record it and strip it from the output
    if conclusions:
        for conclusion_key, _ in conclusions:
            if conclusion_key in generated_text:
                conclusion_found = conclusion_key
                generated_text = generated_text.replace(conclusion_key, "")
    # Extract bracketed context numbers such as [1], [2] from the generated text
    context_pattern = r"\[\d+\]"
    context_matches = re.findall(context_pattern, generated_text)
    context_numbers = [int(match.strip("[]")) for match in context_matches]
    return {
        "generated_text": generated_text,
        "conclusion": conclusion_found,
        "context": context_numbers,
    }
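# Worked example of parse (illustrative values): with
#   generated_text = " Supported [1] [3] [/INST]"
#   conclusions = [("Supported", "the evidence backs the claim")]
#   end_sequences = ["[INST]", "[/INST]", "\n"]
# the end sequence and the conclusion key are stripped, so parse returns
#   {"generated_text": " [1] [3]", "conclusion": "Supported", "context": [1, 3]}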