import logging
import re
from typing import Any, Dict

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

device = "cuda" if torch.cuda.is_available() else "cpu"
MAX_INPUT_TOKEN_LENGTH = 16000


class EndpointHandler():
    def __init__(self, path=""):
        # Load the PEFT config to find the base model, then load the base model in 8-bit
        config = PeftConfig.from_pretrained(path)
        model = AutoModelForCausalLM.from_pretrained(
            config.base_model_name_or_path,
            load_in_8bit=True,
            trust_remote_code=True,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        # Llama-2 style chat template: [INST]/[/INST] for user turns, <<SYS>> markers for the system prompt
        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ bos_token + '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}"
        # Load the LoRA adapter on top of the base model
        self.model = PeftModel.from_pretrained(model, path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Required parameters from the request payload
        message = data.get("message")
        chat_history = data.get("chat_history", [])
        system_prompt = data.get("system_prompt", "")

        # Optional parameters used to build a system prompt when none is given
        instruction = data.get("instruction")
        conclusions = data.get("conclusions")
        context = data.get("context")

        # Optional generation parameters with default values
        max_new_tokens = data.get("max_new_tokens", 1024)
        temperature = data.get("temperature", 0.6)
        top_p = data.get("top_p", 0.9)
        top_k = data.get("top_k", 50)
        repetition_penalty = data.get("repetition_penalty", 1.2)

        if message is None or system_prompt is None:
            raise ValueError("Missing required parameters.")

        # Run generation
        output = generate(
            tokenizer=self.tokenizer,
            model=self.model,
            message=message,
            chat_history=chat_history,
            system_prompt=system_prompt,
            instruction=instruction,
            conclusions=conclusions,
            context=context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
        )

        # Postprocess
        prediction = output
        LOGGER.info(f"Generated text: {prediction}")
        return prediction


def generate(
    tokenizer,
    model,
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "",
    instruction: str = None,
    conclusions: list[tuple[str, str]] = None,
    context: list[str] = None,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    end_sequences: list[str] = ["[INST]", "[/INST]", "\n"],
) -> dict:
    LOGGER.info(f"instruction: {instruction}")
    LOGGER.info(f"conclusions: {conclusions}")
    LOGGER.info(f"context: {context}")

    # If no system prompt is provided, build one from instruction, conclusions, and context
    if not system_prompt and instruction is not None and conclusions is not None and context is not None:
        system_prompt = "Instruction: {}\nConclusions:\n".format(instruction)
        for conclusion, conclusion_key in conclusions:
            system_prompt += "{}: {}\n".format(conclusion, conclusion_key)
        system_prompt += "\nContext:\n"
        for idx, ctx in enumerate(context):
            system_prompt += "{}: [{}]\n".format(ctx, idx + 1)

    # Construct the conversation history
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        if user is not None:
            conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # Tokenize the conversation and truncate it from the left if it is too long
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    # Decoded text is collected through a TextIteratorStreamer; generation runs
    # to completion here and the streamer is consumed afterwards
    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=False,
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    model.generate(**generate_kwargs)

    outputs = []
    generated_text = ""
    for text in streamer:
        outputs.append(text)
        generated_text = "".join(outputs)
        # Stop as soon as an end sequence appears in the accumulated text
        for end_sequence in end_sequences:
            if end_sequence in generated_text:
                generated_text = generated_text.replace(end_sequence, "")
                return parse(generated_text, conclusions, end_sequences)

    return parse(generated_text, conclusions, end_sequences)


def parse(generated_text: str, conclusions: list[tuple[str, str]], end_sequences: list[str]) -> dict:
    conclusion_found = None
    context_numbers = []

    # Remove end sequences and clean the text
    for end_sequence in end_sequences:
        generated_text = generated_text.replace(end_sequence, "")
    generated_text = generated_text.strip()

    # Check for conclusion keys in the generated text
    if conclusions:
        for conclusion_key, _ in conclusions:
            if conclusion_key in generated_text:
                conclusion_found = conclusion_key
                generated_text = generated_text.replace(conclusion_key, "")

    # Extract context references such as "[1]" from the generated text
    context_pattern = r"\[\d+\]"
    context_matches = re.findall(context_pattern, generated_text)
    context_numbers = [int(match.strip("[]")) for match in context_matches]

    return {
        "generated_text": generated_text,
        "conclusion": conclusion_found,
        "context": context_numbers,
    }
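

# Example usage: a minimal local smoke-test sketch, not part of the Inference
# Endpoints contract. The adapter path and the request payload below are
# assumptions for illustration only; `path` must point at a directory (or Hub
# repo) containing the PEFT adapter, and the conclusion labels and context
# strings are made up.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")
    example_request = {
        "message": "Which conclusion does the evidence support?",
        "system_prompt": "",
        "instruction": "Pick one conclusion and cite the supporting context.",
        "conclusions": [("A", "The claim is supported"), ("B", "The claim is not supported")],
        "context": ["The report was published in 2021.", "The figures were independently audited."],
        "chat_history": [],
        "max_new_tokens": 256,
    }
    print(handler(example_request))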