#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Flask HTTP API around a llama.cpp chat model (via llama-cpp-python).

Endpoints:
    GET  /          Demo page rendered from ``html_code``.
    GET  /health    Whether the model is loaded.
    GET  /info      Model configuration and generation policy.
    POST /generate  Chat generation, streamed or not. When the context window
                    fills up, older messages are truncated and generation is
                    continued automatically (up to MAX_CONTINUATIONS times;
                    a negative value means unlimited).

Configuration comes from environment variables (see the constants below).
The model is loaded once, at import time (``load_model()`` call below).
"""
import json
import logging
import os
import random
import time
import uuid
from typing import Any, Dict, Generator, Iterator, List, Optional, Union

from flask import Flask, Response, g, jsonify, render_template_string, request, stream_with_context
from llama_cpp import Llama


class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    Always emits timestamp, level, logger name, rendered message and source
    location; adds ``request_id``, exception and stack text when present, and
    merges any custom attributes passed via ``extra=...`` (values that are not
    JSON-serializable are stringified).
    """

    # LogRecord attributes that are already represented above or are internal
    # plumbing. 'msg'/'args' would duplicate the rendered 'message';
    # 'taskName' exists on Python 3.12+ records.
    _SKIP_KEYS = frozenset({
        'message', 'asctime', 'levelname', 'levelno', 'pathname', 'filename',
        'module', 'funcName', 'lineno', 'created', 'msecs', 'relativeCreated',
        'thread', 'threadName', 'process', 'processName', 'exc_info',
        'exc_text', 'stack_info', 'request_id', 'msg', 'args', 'taskName',
    })

    def format(self, record: logging.LogRecord) -> str:
        log_record: Dict[str, Any] = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "name": record.name,
            "message": record.getMessage(),
            "pathname": record.pathname,
            "lineno": record.lineno,
        }
        if hasattr(record, 'request_id'):
            log_record['request_id'] = record.request_id
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record['stack_info'] = self.formatStack(record.stack_info)

        for key, value in record.__dict__.items():
            if key.startswith('_') or key in log_record or key in self._SKIP_KEYS:
                continue
            # Probe serializability so one odd 'extra' value can't break the line.
            try:
                json.dumps(value)
                log_record[key] = value
            except TypeError:
                log_record[key] = str(value)
            except Exception:
                log_record[key] = "[Unserializable Value]"
        return json.dumps(log_record)


def setup_logging() -> logging.Logger:
    """Configure the root logger to emit JSON lines at INFO level.

    A handler is attached only if the root logger has none yet, so repeated
    calls (or a host that already configured logging) do not duplicate output.
    Werkzeug is limited to ERROR and llama_cpp to WARNING.
    """
    root = logging.getLogger()
    if not root.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(JsonFormatter())
        root.addHandler(handler)
    root.setLevel(logging.INFO)
    logging.getLogger("werkzeug").setLevel(logging.ERROR)
    logging.getLogger("llama_cpp").setLevel(logging.WARNING)
    return root


logger = setup_logging()

# --- Configuration (environment variables) ------------------------------------
MODEL_REPO = os.getenv("MODEL_REPO", "jnjj/vcvcvcv")
MODEL_FILE = os.getenv("MODEL_FILE", "gemma-3-4b-it-q4_0.gguf")
N_CTX_CONFIG = int(os.getenv("N_CTX", "2048"))
N_BATCH = int(os.getenv("N_BATCH", "512"))
N_GPU_LAYERS_CONFIG = int(os.getenv("N_GPU_LAYERS", "0"))
# Maximum automatic continuations per request; negative means unlimited.
MAX_CONTINUATIONS = int(os.getenv("MAX_CONTINUATIONS", "-1"))
FIXED_REPEAT_PENALTY = float(os.getenv("FIXED_REPEAT_PENALTY", "1.1"))
FIXED_SEED = int(os.getenv("FIXED_SEED", "-1"))
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "Eres un asistente conciso, directo y útil.")
# Prompts are truncated to this fraction of n_ctx, leaving room for output.
CONTEXT_TRUNCATION_BUFFER_RATIO = float(os.getenv("CONTEXT_TRUNCATION_BUFFER_RATIO", "0.85"))

# One entry is picked at random for every generation cycle.
RANDOM_PARAMS_CHOICES = [
    {"temperature": 0.2, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.1, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.3, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.4, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.6, "top_p": 0.3, "top_k": 5},
    {"temperature": 0.5, "top_p": 0.7, "top_k": 20},
]

# Set by load_model(); stays None if loading failed.
# NOTE(review): this single Llama instance is shared by every request thread
# (the dev server runs with threaded=True) and generation is not serialized;
# llama.cpp contexts are generally not safe for concurrent use — consider a
# lock or a single worker. TODO confirm.
llm: Optional[Llama] = None
# Values reported by the loaded model (initialized from the configuration).
ACTUAL_N_CTX: int = N_CTX_CONFIG
ACTUAL_N_BATCH: int = N_BATCH
ACTUAL_N_GPU_LAYERS: int = N_GPU_LAYERS_CONFIG


class ContextLimitException(Exception):
    """Raised when a llama.cpp error indicates the context window is exhausted."""


class GenerationFailedException(Exception):
    """Raised for any other llama.cpp failure during generation."""


def prepare_messages(data: Dict, format: Optional[str] = None, request_id: str = 'N/A') -> List[Dict[str, str]]:
    """Validate a request payload and build the chat message list.

    Accepts either ``messages`` (a list of ``{"role", "content"}`` dicts) or a
    plain ``prompt`` string; ``messages`` wins when both are given. A system
    message built from ``system_prompt`` (default: DEFAULT_SYSTEM_PROMPT; an
    explicit null means "none") is prepended, and a user-supplied system
    message at index 0 replaces it. System messages at later positions are
    ignored. With ``format="markdown"`` an instruction to answer in Markdown is
    appended to whichever system prompt ends up being used.

    Args:
        data: Parsed JSON request body.
        format: Optional output format; only "markdown" is supported.
        request_id: Correlation id for log records.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts.

    Raises:
        ValueError: if the payload is malformed or contains no user/assistant
            message.
    """
    messages_list = data.get("messages")
    prompt_str = data.get("prompt")
    system_instruction = data.get("system_prompt", DEFAULT_SYSTEM_PROMPT)

    if not messages_list and not prompt_str:
        raise ValueError("Either 'messages' list or 'prompt' string is required.")
    if messages_list is not None and not isinstance(messages_list, list):
        raise ValueError("'messages' must be a list of dictionaries.")
    if prompt_str is not None and not isinstance(prompt_str, str):
        raise ValueError("'prompt' must be a string.")
    if system_instruction is not None and not isinstance(system_instruction, str):
        raise ValueError("'system_prompt' must be a string.")

    content_format_instruction = ""
    if format == "markdown":
        content_format_instruction = "Format your response using Markdown."
    elif format is not None:
        logger.warning(f"Unsupported format '{format}' requested.", extra={'request_id': request_id, 'format': format})

    def _with_format(text: str) -> str:
        # Separate the instruction with a space so it doesn't run into the text.
        if text and content_format_instruction:
            return f"{text} {content_format_instruction}"
        return text or content_format_instruction

    final_messages: List[Dict[str, str]] = []
    effective_system_prompt_content = _with_format((system_instruction or "").strip())
    if effective_system_prompt_content:
        final_messages.append({"role": "system", "content": effective_system_prompt_content})

    if messages_list:
        has_user_message = False
        for i, msg in enumerate(messages_list):
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                raise ValueError(f"Message at index {i} is invalid: must be a dictionary with 'role' and 'content'.")
            role = msg.get("role")
            content = msg.get("content", "")
            if not isinstance(content, str):
                logger.warning(
                    f"Message content at index {i} (role: {role}) is not a string (type: {type(content)}). Converting to string.",
                    extra={'request_id': request_id, 'message_index': i, 'role': role, 'content_type': type(content).__name__},
                )
                content = "" if content is None else str(content)

            if role == "system":
                if i == 0 and final_messages and final_messages[0]["role"] == "system":
                    logger.info("Replacing default system prompt with user-provided system message.", extra={'request_id': request_id})
                    final_messages[0]["content"] = _with_format(content)
                elif i == 0 and not final_messages:
                    final_messages.append({"role": "system", "content": _with_format(content)})
                else:
                    logger.warning(
                        f"Ignoring additional system message at index {i} as system prompt is already set or should be at the start.",
                        extra={'request_id': request_id, 'message_index': i},
                    )
                continue
            if role == "user":
                has_user_message = True
            final_messages.append({"role": role, "content": content})

        if not has_user_message and any(m["role"] != "system" for m in final_messages):
            logger.warning("The 'messages' list contains no user messages.", extra={'request_id': request_id})
    elif prompt_str:
        final_messages.append({"role": "user", "content": prompt_str})

    if not final_messages or all(m["role"] == "system" for m in final_messages):
        raise ValueError("No user or assistant messages found to generate a response.")
    return final_messages


def _render_messages_plain(messages: List[Dict[str, str]]) -> str:
    """Render messages as plain "role: content" lines (no chat template)."""
    return "".join(f"{m.get('role', '')}: {m.get('content', '')}\n" for m in messages)


def estimate_token_count(messages: List[Dict[str, str]], request_id: str = 'N/A') -> int:
    """Estimate how many prompt tokens ``messages`` will occupy.

    Renders the messages with ``llm.apply_chat_template`` when the model
    exposes it; otherwise renders them as plain "role: content" lines (which
    ignores template marker tokens, so it is an approximation). The rendered
    text is tokenized with a BOS token. If rendering/tokenizing raises, falls
    back to roughly 4 characters per token.

    Returns:
        The estimated token count, or -1 when no model or tokenizer is loaded.
    """
    if not llm or not hasattr(llm, 'tokenize'):
        logger.warning("LLM or tokenizer/template function not available for token estimation.", extra={'request_id': request_id})
        return -1
    try:
        if hasattr(llm, 'apply_chat_template'):
            prompt_text = llm.apply_chat_template(messages, add_generation_prompt=True)
        else:
            prompt_text = _render_messages_plain(messages)
        return len(llm.tokenize(prompt_text.encode('utf-8', errors='ignore'), add_bos=True))
    except Exception as e:
        logger.error(f"Could not estimate token count: {e}", exc_info=True, extra={'request_id': request_id})
        char_count = sum(len(m.get('content', '')) for m in messages)
        estimated_tokens = char_count // 4
        logger.warning(
            f"Falling back to character-based token estimation (~{estimated_tokens})",
            extra={'request_id': request_id, 'estimated_tokens': estimated_tokens, 'char_count': char_count},
        )
        return estimated_tokens


def get_effective_n_ctx() -> int:
    """Return the context window size reported by the loaded model."""
    return ACTUAL_N_CTX


def truncate_messages_for_context(messages: List[Dict[str, str]], max_tokens: int, buffer_ratio: float, request_id: str = 'N/A') -> List[Dict[str, str]]:
    """Drop the oldest non-system messages until the prompt fits a token budget.

    The budget is ``int(max_tokens * buffer_ratio)``. A leading system message
    is always kept; the remaining messages are re-added newest-first while the
    estimated size stays within budget. When something was dropped, the kept
    history is trimmed so it starts with a user turn. If nothing but system
    messages would survive, the last user message is kept on its own even if
    it exceeds the budget, so callers must re-check the size.

    Returns:
        A new list (the input is not mutated); ``messages`` unchanged when no
        model is loaded; ``[]`` if nothing could be kept.
    """
    if not llm:
        return messages

    target_token_limit = int(max_tokens * buffer_ratio)
    system_prompt: Optional[Dict[str, str]] = None
    if messages and messages[0].get("role") == "system":
        system_prompt = messages[0]
        remaining_messages = messages[1:]
    else:
        remaining_messages = messages
    prefix = [system_prompt] if system_prompt else []

    kept: List[Dict[str, str]] = []
    kept_token_count: Optional[int] = None
    for msg in reversed(remaining_messages):
        next_token_count = estimate_token_count(prefix + [msg] + kept, request_id=request_id)
        if next_token_count == -1:
            logger.warning(f"Token estimation failed while adding message: {msg}. Stopping truncation early.", extra={'request_id': request_id})
            break
        if next_token_count > target_token_limit:
            logger.debug(
                f"Stopping truncation: Adding next message would exceed target limit ({next_token_count} > {target_token_limit}).",
                extra={'request_id': request_id},
            )
            break
        kept.insert(0, msg)
        kept_token_count = next_token_count

    has_user_turn = any(m.get("role") == "user" for m in messages)
    if has_user_turn and len(kept) < len(remaining_messages):
        # Some chat templates (e.g. Gemma's) reject conversations whose first
        # non-system turn is not from the user, so don't keep a dangling
        # assistant turn at the front once older messages were dropped.
        while kept and kept[0].get("role") != "user":
            kept.pop(0)
            kept_token_count = None

    if all(m.get("role") == "system" for m in kept):
        last_user_message = next((m for m in reversed(messages) if m.get("role") == "user"), None)
        if last_user_message is not None:
            logger.warning("Truncation resulted in empty or system-only messages, attempting to keep last user message.", extra={'request_id': request_id})
            kept = [last_user_message]
            kept_token_count = None

    final_truncated_list = prefix + kept
    if not final_truncated_list:
        logger.error("Context truncation resulted in an empty message list!", extra={'request_id': request_id})
        return []

    if kept_token_count is None:
        kept_token_count = estimate_token_count(final_truncated_list, request_id=request_id)
    original_count = len(messages)
    final_count = len(final_truncated_list)
    if final_count < original_count:
        logger.warning(
            f"Context truncated: Kept {final_count}/{original_count} messages. Estimated tokens: ~{kept_token_count}/{target_token_limit} (target).",
            extra={'request_id': request_id, 'kept': final_count, 'original': original_count, 'estimated_tokens': kept_token_count, 'target_limit': target_token_limit},
        )
    else:
        logger.debug(
            f"Context truncation check complete. Kept all {final_count} messages. Estimated tokens: ~{kept_token_count}.",
            extra={'request_id': request_id, 'kept': final_count, 'estimated_tokens': kept_token_count},
        )
    return final_truncated_list


def get_property_or_method_value(obj: Any, prop_name: str, default: Any = None) -> Any:
    """Return ``obj.prop_name``, calling it if it is a method.

    Returns ``default`` when the attribute is missing or calling it raises
    (the error is logged as a warning).
    """
    if not hasattr(obj, prop_name):
        return default
    prop = getattr(obj, prop_name)
    if not callable(prop):
        return prop
    try:
        return prop()
    except Exception:
        logger.warning(f"Error calling method {prop_name} on {type(obj)}", exc_info=True)
        return default


def load_model() -> None:
    """Load ``MODEL_FILE`` from ``MODEL_REPO`` into the global ``llm``.

    On success, ACTUAL_N_CTX / ACTUAL_N_BATCH / ACTUAL_N_GPU_LAYERS are read
    back from the model and mismatches with the configuration are logged. On
    failure the error is logged and ``llm`` is left as None so the endpoints
    report the model as unavailable.
    """
    global llm, ACTUAL_N_CTX, ACTUAL_N_BATCH, ACTUAL_N_GPU_LAYERS
    logger.info(f"Attempting to load model: {MODEL_REPO}/{MODEL_FILE}")
    logger.info(f"Configuration: N_CTX={N_CTX_CONFIG}, N_BATCH={N_BATCH}, N_GPU_LAYERS={N_GPU_LAYERS_CONFIG}")
    try:
        llm = Llama.from_pretrained(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            n_ctx=N_CTX_CONFIG,
            n_batch=N_BATCH,
            n_gpu_layers=N_GPU_LAYERS_CONFIG,
            verbose=False,
            use_mmap=True,
            use_mlock=True,
        )
        logger.info("Model loaded successfully.")

        ACTUAL_N_CTX = get_property_or_method_value(llm, 'n_ctx', N_CTX_CONFIG)
        ACTUAL_N_BATCH = get_property_or_method_value(llm, 'n_batch', N_BATCH)
        # If the model object doesn't expose its GPU layer count, assume the
        # requested value instead of reporting a spurious mismatch against 0.
        ACTUAL_N_GPU_LAYERS = get_property_or_method_value(llm, 'n_gpu_layers', N_GPU_LAYERS_CONFIG)
        if ACTUAL_N_CTX != N_CTX_CONFIG:
            logger.warning(
                f"Model's actual context size ({ACTUAL_N_CTX}) differs from config ({N_CTX_CONFIG}). Using actual.",
                extra={'actual_n_ctx': ACTUAL_N_CTX, 'configured_n_ctx': N_CTX_CONFIG},
            )
        if ACTUAL_N_GPU_LAYERS != N_GPU_LAYERS_CONFIG:
            logger.warning(
                f"Model loaded with {ACTUAL_N_GPU_LAYERS} GPU layers despite requesting {N_GPU_LAYERS_CONFIG}. Check llama.cpp build or environment.",
                extra={'actual_gpu_layers': ACTUAL_N_GPU_LAYERS, 'configured_gpu_layers': N_GPU_LAYERS_CONFIG},
            )
        logger.info(f"Actual Model Context Window (n_ctx): {ACTUAL_N_CTX}")
        logger.info(f"Actual Model Batch Size (n_batch): {ACTUAL_N_BATCH}")
        logger.info(f"Actual Model GPU Layers (n_gpu_layers): {ACTUAL_N_GPU_LAYERS}")
        try:
            test_tokens = llm.tokenize(b"Test sentence.")
            logger.info(f"Tokenizer test successful. 'Test sentence.' -> {len(test_tokens)} tokens.")
        except Exception as tokenize_e:
            logger.warning(f"Could not perform test tokenization: {tokenize_e}")
    except Exception as e:
        logger.error(f"Fatal error loading model: {e}", exc_info=True)
        llm = None
        logger.error("Model failed to load. Generation requests will not work.", extra={'error': str(e)})


app = Flask(__name__)


@app.before_request
def before_request_func():
    """Assign a fresh UUID request id to ``g`` for log correlation."""
    g.request_id = str(uuid.uuid4())
    logger.debug(
        f"Incoming request: {request.method} {request.path} from {request.remote_addr}",
        extra={'request_id': g.request_id, 'path': request.path, 'method': request.method},
    )


load_model()

# Jinja template for GET "/"; receives ACTUAL_N_CTX and DEFAULT_SYSTEM_PROMPT.
# NOTE(review): the template currently holds only plain text (no HTML markup,
# links or form), so the page cannot call the API by itself; markup may have
# been lost — confirm against the original source.
html_code = """ LLM API Demo

LLM API Demonstration

Health Check

API Info


        

Generate Text (Automatic Continuation with Context Management)

Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX={{ ACTUAL_N_CTX }}). If the context limit is reached, the server will attempt to continue automatically by truncating older messages (unlimited continuations). Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.


        
"""


@app.route("/")
def index():
    """Render the demo page with the live context size and default prompt."""
    return render_template_string(
        html_code,
        ACTUAL_N_CTX=ACTUAL_N_CTX,
        DEFAULT_SYSTEM_PROMPT=DEFAULT_SYSTEM_PROMPT,
    )


@app.route("/health", methods=["GET"])
def health_check():
    """Report readiness: 200 "ok"/"warning" when loaded, 503 otherwise."""
    if not llm:
        return jsonify(status="error", message="Model failed to load or is not available."), 503
    if hasattr(llm, 'tokenize') and hasattr(llm, 'apply_chat_template'):
        return jsonify(status="ok", message="Model is loaded and ready."), 200
    logger.warning("Model loaded, but tokenizer or chat template functions might be missing.", extra={'request_id': getattr(g, 'request_id', 'N/A')})
    return jsonify(status="warning", message="Model loaded, but critical functions (tokenize/apply_chat_template) might be missing."), 200


@app.route("/info", methods=["GET"])
def model_info():
    """Describe the loaded model and the server's generation policy (503 if not loaded)."""
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("Info request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model not available."), 503

    model_details: Union[Dict[str, Any], str] = "Model details unavailable"
    try:
        n_embd = get_property_or_method_value(get_property_or_method_value(llm, '_model'), 'n_embd', 'N/A')
        model_details = {
            "n_embd": n_embd,
            "n_ctx": ACTUAL_N_CTX,
            "n_batch": ACTUAL_N_BATCH,
            "n_gpu_layers": ACTUAL_N_GPU_LAYERS,
            "tokenizer_present": hasattr(llm, 'tokenize'),
            "chat_handler_present": hasattr(llm, 'apply_chat_template') and hasattr(llm, 'create_chat_completion'),
        }
    except Exception as e:
        logger.warning(f"Could not retrieve all model details: {e}", extra={'request_id': request_id}, exc_info=True)
        model_details = f"Error retrieving some model details: {e}"

    info = {
        "status": "ok",
        "message": "Model is loaded. Generation continues automatically with context truncation if context limit is hit.",
        "model_config": {
            "repo_id": MODEL_REPO,
            "filename": MODEL_FILE,
            "initial_load_config": {
                "n_ctx": N_CTX_CONFIG,
                "n_batch": N_BATCH,
                "n_gpu_layers": N_GPU_LAYERS_CONFIG,
            },
            "loaded_model_details": model_details,
        },
        "generation_parameters": {
            "note": f"No artificial 'max_tokens' limit. Generation proceeds until stop sequence, EOS, or context limit (N_CTX={ACTUAL_N_CTX}). Automatic continuation attempts by truncating context occur up to {MAX_CONTINUATIONS if MAX_CONTINUATIONS >= 0 else 'unlimited'} times if context limit is reached. Sampling parameters (temperature, top_p, top_k) are chosen randomly per request/continuation cycle from predefined sets. Repeat penalty and seed are fixed.",
            "fixed_max_tokens": None,
            "fixed_repeat_penalty": FIXED_REPEAT_PENALTY,
            "fixed_seed": FIXED_SEED,
            "max_automatic_continuations": MAX_CONTINUATIONS if MAX_CONTINUATIONS >= 0 else None,
            "context_truncation_buffer_ratio": CONTEXT_TRUNCATION_BUFFER_RATIO,
            "randomly_chosen_from": RANDOM_PARAMS_CHOICES,
            "default_system_prompt": DEFAULT_SYSTEM_PROMPT,
            "user_controllable": ["messages", "prompt", "stop", "stream", "format", "system_prompt"],
        },
    }
    return jsonify(info), 200


# Lower-cased substrings of llama.cpp / llama-cpp-python error messages that
# mean the context window is exhausted ("exceed context window" covers the
# "Requested tokens (...) exceed context window of ..." error).
_CONTEXT_LIMIT_MARKERS = (
    "context window is full",
    "kv cache is full",
    "llama_decode",
    "exceed context window",
)


def _to_generation_exception(e: Exception, request_id: str) -> Exception:
    """Map a llama.cpp error to ContextLimitException or GenerationFailedException (and log it)."""
    err_str = str(e).lower()
    condition = getattr(e, 'condition', None)
    condition_hit = isinstance(condition, str) and (
        "context length" in condition.lower() or "failed to decode" in condition.lower()
    )
    if condition_hit or any(marker in err_str for marker in _CONTEXT_LIMIT_MARKERS):
        logger.warning(f"Caught N_CTX limit or related exception: {e}", extra={'request_id': request_id})
        return ContextLimitException(str(e))
    logger.error(f"Unhandled error during llama.cpp call: {e}", exc_info=e, extra={'request_id': request_id})
    return GenerationFailedException(f"Unhandled llama.cpp error: {str(e)}")


def _translate_stream_errors(chunks: Iterator[Dict], request_id: str) -> Generator[Dict, None, None]:
    """Re-yield stream chunks, mapping errors raised mid-iteration like the call itself."""
    try:
        yield from chunks
    except Exception as e:
        raise _to_generation_exception(e, request_id) from e


def _generate_single_cycle(messages: List[Dict[str, str]], params: Dict, stream: bool, request_id: str) -> Union[Generator[Dict, None, None], Dict]:
    """Run one llama.cpp chat completion with ``params``.

    Returns the completion dict, or (``stream=True``) a generator of chunks.
    Errors are re-raised as ContextLimitException when they indicate the
    context window is exhausted, otherwise as GenerationFailedException. For
    streams this also covers errors raised while iterating, which is where
    llama-cpp-python surfaces most of them.
    """
    logger.debug(
        f"Starting llama.cpp chat completion call. Stream: {stream}. Messages: {len(messages)}. "
        f"Params summary: temp={params.get('temperature')}, top_p={params.get('top_p')}, "
        f"top_k={params.get('top_k')}, stop={params.get('stop')}",
        extra={'request_id': request_id, 'stream': stream, 'message_count': len(messages)},
    )
    try:
        result = llm.create_chat_completion(
            messages=messages,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=params["stop"],
            seed=params["seed"],
            stream=stream,
        )
    except Exception as e:
        raise _to_generation_exception(e, request_id) from e
    if stream:
        return _translate_stream_errors(result, request_id)
    return result


def _parse_stop(stop: Any) -> Optional[List[str]]:
    """Normalize the request's ``stop`` field to a list of strings, or None.

    Raises:
        ValueError: carrying a ``{"stop": ...}`` dict when the value is invalid.
    """
    if stop is None:
        return None
    if isinstance(stop, str):
        return [stop]
    if isinstance(stop, list) and all(isinstance(s, str) for s in stop):
        return stop
    raise ValueError({"stop": "Stop must be a string or a list of strings"})


def _value_error_detail(e: ValueError) -> Any:
    """Return the structured payload of ``e`` if it carries one, else its text."""
    if len(e.args) == 1 and isinstance(e.args[0], (dict, list)):
        return e.args[0]
    return str(e)


def _append_assistant_text(messages: List[Dict[str, str]], text: str) -> None:
    """Append ``text`` to a trailing assistant message, creating one if needed."""
    if messages and messages[-1].get('role') == 'assistant':
        messages[-1]['content'] += text
    else:
        messages.append({"role": "assistant", "content": text})


def _continuation_budget_exhausted(continuations: int) -> bool:
    """True when MAX_CONTINUATIONS is non-negative and already used up."""
    return 0 <= MAX_CONTINUATIONS <= continuations


def _is_empty_or_system_only(messages: List[Dict[str, str]]) -> bool:
    """True when truncation left nothing the model could respond to."""
    return not messages or (len(messages) == 1 and messages[0].get("role") == "system")


@app.route("/generate", methods=["POST"])
def generate():
    """Generate a chat completion for the JSON request body.

    Request fields: ``messages`` or ``prompt`` (one is required),
    ``system_prompt``, ``stop`` (string or list of strings), ``stream``
    (default true) and ``format`` ("markdown"). Sampling parameters are not
    client-controlled: each cycle picks a random RANDOM_PARAMS_CHOICES entry.

    Whenever a cycle ends because the context window filled up, older messages
    are truncated and another cycle runs (bounded by MAX_CONTINUATIONS).
    Streaming responses yield raw text chunks interleaved with bracketed
    ``[INFO]`` / ``[ERROR]`` / ``[CONTINUING ...]`` markers. Non-streaming
    responses return the full text as text/plain plus ``X-Finish-Reason``,
    ``X-Continuations`` and usage headers.

    Error responses: 503 (model not loaded), 415 (not JSON), 400 (invalid
    input or input too large), 500 (unexpected preparation error).
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.error("Generate request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model is not available.", detail="The LLM model could not be loaded."), 503
    if not request.is_json:
        logger.warning("Request received without Content-Type: application/json", extra={'request_id': request_id})
        return jsonify(error="Invalid request header", detail="Content-Type must be application/json"), 415

    # silent=True: malformed JSON yields None and gets the JSON 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        logger.warning("Request body is not a valid JSON object.", extra={'request_id': request_id})
        return jsonify(error="Invalid input", detail="Request body must be a JSON object."), 400

    is_streaming = data.get("stream", True)
    response_format = data.get("format")

    messages_field = data.get('messages')
    log_data_summary = {k: v for k, v in data.items() if k not in ('messages', 'prompt')}
    log_data_summary['messages_count_initial'] = len(messages_field) if isinstance(messages_field, list) else 0
    log_data_summary['has_prompt_initial'] = 'prompt' in data
    log_data_summary['stream'] = is_streaming
    log_data_summary['format'] = response_format
    logger.info("Received generation request summary.", extra={'request_id': request_id, 'summary': log_data_summary})

    effective_n_ctx = get_effective_n_ctx()
    buffer_target = int(effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO)
    try:
        initial_messages = prepare_messages(data, format=response_format, request_id=request_id)
        base_params: Dict[str, Any] = {
            "max_tokens": None,  # no artificial cap: stop/EOS/context limit ends a cycle
            "repeat_penalty": FIXED_REPEAT_PENALTY,
            "seed": FIXED_SEED,
            "stop": _parse_stop(data.get("stop")),
        }

        input_token_count = estimate_token_count(initial_messages, request_id=request_id)
        if input_token_count != -1 and input_token_count > effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO:
            logger.warning(
                f"Initial input (~{input_token_count} tokens) likely exceeds safe context window ({buffer_target}). Attempting truncation.",
                extra={'request_id': request_id, 'initial_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'buffer_ratio': CONTEXT_TRUNCATION_BUFFER_RATIO},
            )
            truncated_initial = truncate_messages_for_context(initial_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=request_id)
            truncated_token_count = estimate_token_count(truncated_initial, request_id=request_id)
            if not truncated_initial or (truncated_token_count != -1 and truncated_token_count > effective_n_ctx):
                error_msg = (
                    f"Input exceeds context window ({effective_n_ctx}) even after attempting truncation. "
                    f"Input tokens (~{input_token_count}) / Truncated tokens (~{truncated_token_count}). Reduce initial message size."
                )
                logger.error(error_msg, extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
                return jsonify(error="Input exceeds context window", detail=error_msg), 400
            logger.info(
                f"Initial input truncated from ~{input_token_count} to ~{truncated_token_count} tokens.",
                extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx},
            )
            initial_messages = truncated_initial
        elif input_token_count != -1:
            logger.info(
                f"Initial input token count: ~{input_token_count}. Effective context window: {effective_n_ctx}. Context buffer target: {buffer_target}. Remaining: {effective_n_ctx - input_token_count}.",
                extra={'request_id': request_id, 'input_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'buffer_target': buffer_target, 'remaining_ctx': effective_n_ctx - input_token_count},
            )
        else:
            logger.warning("Could not estimate initial token count. Proceeding, may hit context limit.", extra={'request_id': request_id})
    except ValueError as e:
        # Client error: log without a traceback, return the structured detail.
        logger.warning(f"Invalid input data or parameters: {e}", extra={'request_id': request_id})
        return jsonify(error="Invalid input", detail=_value_error_detail(e)), 400
    except Exception as e:
        logger.error(f"Unexpected error preparing request: {e}", exc_info=True, extra={'request_id': request_id})
        return jsonify(error="Internal server error", detail="An unexpected error occurred preparing the request."), 500

    # Mutable state shared by both the streaming generator and the inline
    # non-streaming loop below.
    current_messages = list(initial_messages)
    continuations = 0
    total_completion_tokens_generated = 0
    final_finish_reason = "unknown"
    final_usage: Dict[str, Any] = {}
    full_generated_text_nonstream = ""

    def streaming_generator(req_id):
        """Yield generated text across as many continuation cycles as needed."""
        nonlocal current_messages, continuations, total_completion_tokens_generated, final_finish_reason, final_usage
        while True:
            cycle_number = continuations + 1
            logger.info(
                f"Starting streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.",
                extra={'request_id': req_id, 'cycle': cycle_number, 'message_count': len(current_messages)},
            )
            current_params = {**base_params, **random.choice(RANDOM_PARAMS_CHOICES)}
            cycle_parts: List[str] = []
            finish_reason = None
            usage_this_cycle: Dict[str, Any] = {}
            try:
                for chunk in _generate_single_cycle(current_messages, current_params, stream=True, request_id=req_id):
                    # Tolerate chunks with a missing/empty "choices" list or null delta.
                    choice = (chunk.get("choices") or [{}])[0]
                    token_content = (choice.get("delta") or {}).get("content")
                    if token_content:
                        cycle_parts.append(token_content)
                        yield token_content
                    if choice.get("finish_reason"):
                        finish_reason = choice["finish_reason"]
                        usage_this_cycle = chunk.get("usage") or {}
                        final_usage = usage_this_cycle
                        break
                if not finish_reason and cycle_parts:
                    finish_reason = "end_of_stream"
                    logger.warning(f"Streaming cycle {cycle_number} ended without explicit finish reason.", extra={'request_id': req_id, 'cycle': cycle_number})
            except ContextLimitException:
                logger.warning(f"Context limit caught during streaming cycle {cycle_number}.", extra={'request_id': req_id, 'cycle': cycle_number})
                finish_reason = 'length'
                yield f"\n[INFO] Context limit approached in cycle {cycle_number}. Attempting continuation...\n"
            except GenerationFailedException as e:
                logger.error(f"Generation failed in streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': req_id, 'cycle': cycle_number})
                yield f"\n[ERROR] Generation failed unexpectedly in cycle {cycle_number}: {e}"
                final_finish_reason = "error"
                break
            except Exception as e:
                logger.error(f"An unexpected error occurred in streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': req_id, 'cycle': cycle_number})
                yield f"\n[ERROR] An unexpected error occurred in cycle {cycle_number}: {str(e)}"
                final_finish_reason = "error"
                break

            generated_this_cycle_content = "".join(cycle_parts)
            if generated_this_cycle_content:
                _append_assistant_text(current_messages, generated_this_cycle_content)
            total_completion_tokens_generated += usage_this_cycle.get("completion_tokens", 0)

            if finish_reason in ('stop', 'end_of_stream'):
                logger.info(
                    f"Streaming generation stopped naturally in cycle {cycle_number}. Reason: {finish_reason}",
                    extra={'request_id': req_id, 'cycle': cycle_number, 'finish_reason': finish_reason},
                )
                final_finish_reason = 'stop'
                yield "\n[INFO] Generation finished."
                break
            if finish_reason == 'length':
                # Check the budget before truncating so no continuation marker
                # is emitted for a cycle that will never run.
                if _continuation_budget_exhausted(continuations):
                    logger.info(f"Max continuations ({MAX_CONTINUATIONS}) reached. Stopping streaming.", extra={'request_id': req_id})
                    yield f"\n[INFO] Generation stopped: Max continuations reached ({MAX_CONTINUATIONS})."
                    final_finish_reason = "max_continuations"
                    break
                continuations += 1
                logger.warning(
                    f"N_CTX limit reached in streaming cycle {cycle_number}. Attempting continuation {continuations}.",
                    extra={'request_id': req_id, 'cycle': cycle_number, 'continuations': continuations},
                )
                previous_messages = current_messages
                current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=req_id)
                if _is_empty_or_system_only(current_messages):
                    logger.error("Context truncation resulted in empty or system-only messages during streaming. Stopping.", extra={'request_id': req_id, 'cycle': cycle_number})
                    yield "\n[ERROR] Generation failed: Context truncation error."
                    final_finish_reason = "truncation_error"
                    break
                if not generated_this_cycle_content and current_messages == previous_messages:
                    # Same input, no output: retrying would loop forever.
                    logger.error("Context truncation made no progress and no output was produced. Stopping.", extra={'request_id': req_id, 'cycle': cycle_number})
                    yield "\n[ERROR] Generation failed: Context truncation error."
                    final_finish_reason = "truncation_error"
                    break
                yield f"\n[CONTINUING {continuations} - TRUNCATING CONTEXT...]\n"
                time.sleep(0.05)
                continue
            logger.warning(
                f"Streaming generation cycle {cycle_number} ended with unexpected reason '{finish_reason}'. Stopping generation.",
                extra={'request_id': req_id, 'cycle': cycle_number, 'finish_reason': finish_reason},
            )
            yield f"\n[INFO] Generation stopped: Reason: {finish_reason or 'Unknown'}"
            final_finish_reason = finish_reason or "unknown"
            break
        logger.info(
            f"Streaming generation stream closed. Total continuations: {continuations}. Final reason: {final_finish_reason}",
            extra={'request_id': req_id, 'continuations': continuations, 'final_reason': final_finish_reason},
        )

    if is_streaming:
        # NOTE: chunks are raw text, not SSE-framed ("data: ..."), despite the
        # text/event-stream content type.
        headers = {
            "Content-Type": "text/event-stream; charset=utf-8",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "X-Request-ID": request_id,
        }
        return Response(stream_with_context(streaming_generator(request_id)), headers=headers)

    def _append_nonstream(note: str) -> None:
        # Separate notes from earlier output with a blank line.
        nonlocal full_generated_text_nonstream
        if full_generated_text_nonstream:
            full_generated_text_nonstream += "\n\n"
        full_generated_text_nonstream += note

    while True:
        cycle_number = continuations + 1
        logger.info(
            f"Starting non-streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.",
            extra={'request_id': request_id, 'cycle': cycle_number, 'message_count': len(current_messages)},
        )
        current_params = {**base_params, **random.choice(RANDOM_PARAMS_CHOICES)}
        logger.debug(
            f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}, stop={current_params['stop']}",
            extra={'request_id': request_id, 'cycle': cycle_number, 'params': current_params},
        )
        generated_this_cycle_content = ""
        finish_reason = None
        usage_this_cycle: Dict[str, Any] = {}
        try:
            result = _generate_single_cycle(current_messages, current_params, stream=False, request_id=request_id)
        except ContextLimitException:
            logger.warning(f"Context limit caught during non-streaming cycle {cycle_number}.", extra={'request_id': request_id, 'cycle': cycle_number})
            finish_reason = 'length'
        except GenerationFailedException as e:
            logger.error(f"Generation failed in non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id, 'cycle': cycle_number})
            _append_nonstream(f"[ERROR: Generation failed unexpectedly in cycle {cycle_number}: {e}]")
            final_finish_reason = "error"
            break
        except Exception as e:
            logger.error(f"An unexpected error occurred in non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id, 'cycle': cycle_number})
            _append_nonstream(f"[ERROR: An unexpected error occurred in cycle {cycle_number}: {str(e)}]")
            final_finish_reason = "error"
            break
        else:
            choices = result.get("choices") if isinstance(result, dict) else None
            if not choices:
                logger.error(
                    f"Invalid response structure from llama_cpp in non-streaming cycle {cycle_number}: {result}",
                    extra={'request_id': request_id, 'cycle': cycle_number, 'result': result},
                )
                _append_nonstream(f"[ERROR: Invalid response structure from model in cycle {cycle_number}.]")
                final_finish_reason = "internal_error"
                break
            choice = choices[0]
            generated_this_cycle_content = (choice.get("message") or {}).get("content") or ""
            finish_reason = choice.get("finish_reason", "unknown")
            usage_this_cycle = result.get("usage") or {}
            final_usage = usage_this_cycle
            logger.info(
                f"Non-streaming cycle {cycle_number} finished. Reason: {finish_reason}. Usage: {usage_this_cycle}",
                extra={'request_id': request_id, 'cycle': cycle_number, 'usage': usage_this_cycle, 'finish_reason': finish_reason},
            )

        if generated_this_cycle_content:
            if continuations > 0 and full_generated_text_nonstream:
                full_generated_text_nonstream += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT]\n\n"
            full_generated_text_nonstream += generated_this_cycle_content
            _append_assistant_text(current_messages, generated_this_cycle_content)
            total_completion_tokens_generated += usage_this_cycle.get("completion_tokens", 0)
        elif finish_reason == 'length':
            logger.warning(
                f"Non-streaming N_CTX limit hit in cycle {cycle_number} but no completion tokens reported in usage.",
                extra={'request_id': request_id, 'cycle': cycle_number},
            )
            if continuations > 0 and full_generated_text_nonstream:
                full_generated_text_nonstream += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT - NO OUTPUT THIS CYCLE]\n\n"

        if finish_reason == 'stop':
            logger.info(
                f"Non-streaming generation stopped naturally (reason: stop) in cycle {cycle_number}.",
                extra={'request_id': request_id, 'cycle': cycle_number, 'finish_reason': finish_reason},
            )
            final_finish_reason = 'stop'
            break
        if finish_reason == 'length':
            if _continuation_budget_exhausted(continuations):
                logger.info(f"Max continuations ({MAX_CONTINUATIONS}) reached. Stopping non-streaming.", extra={'request_id': request_id})
                _append_nonstream(f"[INFO: Generation stopped: Max continuations reached ({MAX_CONTINUATIONS}).]")
                final_finish_reason = "max_continuations"
                break
            continuations += 1
            logger.warning(
                f"Non-streaming N_CTX limit reached in cycle {cycle_number}. Attempting continuation {continuations}.",
                extra={'request_id': request_id, 'cycle': cycle_number, 'continuations': continuations},
            )
            previous_messages = current_messages
            current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=request_id)
            if _is_empty_or_system_only(current_messages):
                logger.error("Context truncation resulted in empty or system-only messages during non-streaming. Stopping.", extra={'request_id': request_id, 'cycle': cycle_number})
                _append_nonstream("[ERROR: Generation failed: Context truncation error.]")
                final_finish_reason = "truncation_error"
                break
            if not generated_this_cycle_content and current_messages == previous_messages:
                # Same input, no output: retrying would loop forever.
                logger.error("Context truncation made no progress and no output was produced. Stopping.", extra={'request_id': request_id, 'cycle': cycle_number})
                _append_nonstream("[ERROR: Generation failed: Context truncation error.]")
                final_finish_reason = "truncation_error"
                break
            continue
        logger.warning(
            f"Non-streaming cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.",
            extra={'request_id': request_id, 'cycle': cycle_number, 'finish_reason': finish_reason},
        )
        _append_nonstream(f"[INFO: Generation stopped unexpectedly. Reason: {finish_reason or 'Unknown'}]")
        final_finish_reason = finish_reason or "unknown"
        break

    logger.info(
        f"Non-streaming generation finished after {continuations} continuations. Total completion tokens generated: {total_completion_tokens_generated}. Final reason: {final_finish_reason}",
        extra={'request_id': request_id, 'continuations': continuations, 'total_completion_tokens': total_completion_tokens_generated, 'final_reason': final_finish_reason},
    )
    response = Response(full_generated_text_nonstream, mimetype="text/plain; charset=utf-8")
    response.headers["X-Request-ID"] = request_id
    response.headers["X-Finish-Reason"] = final_finish_reason
    response.headers["X-Continuations"] = str(continuations)
    response.headers["X-Usage-Completion-Tokens"] = str(total_completion_tokens_generated)
    response.headers["X-Usage-Prompt-Tokens-Last-Cycle"] = str(final_usage.get("prompt_tokens", "N/A"))
    response.headers["X-Usage-Total-Tokens-Last-Cycle"] = str(final_usage.get("total_tokens", "N/A"))
    return response


if __name__ == "__main__":
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    is_debug = os.getenv("FLASK_DEBUG", "0") == "1"
    logger.setLevel(logging.DEBUG if is_debug else logging.INFO)
    max_cont_desc = MAX_CONTINUATIONS if MAX_CONTINUATIONS >= 0 else 'UNLIMITED'
    logger.info(f"Starting Flask server on {host}:{port} (Debug mode: {is_debug})")
    logger.info(f"Model: {MODEL_REPO}/{MODEL_FILE}, N_CTX={ACTUAL_N_CTX}, Automatic Continuations: {max_cont_desc} (with context truncation)")
    if not llm:
        logger.critical("MODEL FAILED TO LOAD. SERVER WILL START BUT '/generate' WILL FAIL.")
    logger.info("Running with Flask development server.")
    app.run(host=host, port=port, threaded=True, debug=is_debug, use_reloader=False)