#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""HTTP front-end (Flask) around a llama.cpp chat model.

Endpoints:
    GET  /          demo page rendered from ``html_code``
    GET  /health    readiness probe
    GET  /info      static and loaded-model configuration
    POST /generate  chat generation, streamed or buffered

``/generate`` imposes no ``max_tokens``.  When a generation cycle ends because
the context window is full, older messages are dropped and another cycle is
started ("continuation").  Sampling parameters are drawn at random from
``RANDOM_PARAMS_CHOICES`` for every cycle.
"""
import os
import json
import logging
import uuid
from typing import List, Dict, Union, Optional, Generator, Any
import random
import time

from flask import Flask, request, Response, stream_with_context, jsonify, g, render_template_string
from llama_cpp import Llama


class JsonFormatter(logging.Formatter):
    """Format every log record as one JSON object per line."""

    # LogRecord attributes that are already emitted under another key or are
    # deliberately left out of the JSON payload.
    _SKIP_KEYS = frozenset({
        'message', 'asctime', 'levelname', 'levelno', 'pathname', 'filename',
        'module', 'funcName', 'lineno', 'created', 'msecs', 'relativeCreated',
        'thread', 'threadName', 'process', 'processName', 'exc_info',
        'exc_text', 'stack_info', 'request_id',
    })

    def format(self, record):
        """Return ``record`` serialised as a JSON string.

        Besides the fixed fields, every public attribute found on the record
        (i.e. whatever was passed through ``extra=``) is copied into the
        payload unless it is listed in ``_SKIP_KEYS``.
        """
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "name": record.name,
            "message": record.getMessage(),
            "pathname": record.pathname,
            "lineno": record.lineno,
        }
        if hasattr(record, 'request_id'):
            log_record['request_id'] = record.request_id
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record['stack_info'] = self.formatStack(record.stack_info)

        for key, value in record.__dict__.items():
            if key.startswith('_') or key in log_record or key in self._SKIP_KEYS:
                continue
            log_record[key] = value

        # default=str: an ``extra`` value that is not JSON-serialisable must
        # not make the formatter raise and lose the log line.
        return json.dumps(log_record, default=str)


def setup_logging():
    """Configure the root logger for JSON output and return it.

    A stream handler is attached only when the root logger has none yet, so
    calling this twice does not duplicate output.  The ``werkzeug`` and
    ``llama_cpp`` loggers are raised to ERROR and WARNING respectively.
    """
    root_logger = logging.getLogger()
    if not root_logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(JsonFormatter())
        root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)
    logging.getLogger("werkzeug").setLevel(logging.ERROR)
    logging.getLogger("llama_cpp").setLevel(logging.WARNING)
    return root_logger


logger = setup_logging()

# --- Configuration (environment overridable) -------------------------------
MODEL_REPO = os.getenv("MODEL_REPO", "jnjj/vcvcvcv")
MODEL_FILE = os.getenv("MODEL_FILE", "gemma-3-4b-it-q4_0.gguf")
N_CTX = int(os.getenv("N_CTX", "2048"))
N_BATCH = int(os.getenv("N_BATCH", "512"))
N_GPU_LAYERS = 0  # CPU only; not configurable
FIXED_REPEAT_PENALTY = float(os.getenv("FIXED_REPEAT_PENALTY", "1.1"))
FIXED_SEED = int(os.getenv("FIXED_SEED", "-1"))
# The default text previously contained mojibake ("Ăștil"): UTF-8 "útil"
# decoded with the wrong codec.  The file declares utf-8, so the intended
# character is restored.
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "Eres un asistente conciso, directo y útil.")
CONTEXT_TRUNCATION_BUFFER_RATIO = float(os.getenv("CONTEXT_TRUNCATION_BUFFER_RATIO", "0.85"))
# Number of consecutive cycles that may hit the context limit without emitting
# any text before generation is aborted.  A value <= 0 disables the guard and
# restores unbounded retrying.
MAX_STALLED_CYCLES = int(os.getenv("MAX_STALLED_CYCLES", "3"))

RANDOM_PARAMS_CHOICES = [
    {"top_k": 10, "top_p": 0.5, "temperature": 0.2},
    {"top_k": 10, "top_p": 0.5, "temperature": 0.1},
    {"top_k": 10, "top_p": 0.5, "temperature": 0.3},
    {"top_k": 10, "top_p": 0.5, "temperature": 0.4},
    {"top_k": 5, "top_p": 0.3, "temperature": 0.6},
    {"top_k": 20, "top_p": 0.7, "temperature": 0.5},
]

# Set by load_model(); stays None when loading fails.
llm: Optional[Llama] = None


def parse_and_validate_params(data: Dict) -> Dict:
    """Build the generation parameters for one request.

    Only ``stop`` is taken from ``data``.  ``max_tokens`` is always ``None``,
    ``repeat_penalty`` and ``seed`` come from the module constants, and the
    sampling parameters are one random entry of ``RANDOM_PARAMS_CHOICES``.

    Args:
        data: Parsed JSON request body.

    Returns:
        Dict with keys ``max_tokens``, ``temperature``, ``top_p``, ``top_k``,
        ``repeat_penalty``, ``seed`` and ``stop``.

    Raises:
        ValueError: ``stop`` is neither a string nor a list of strings.  The
            exception message is a JSON object mapping field name to error.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    errors = {}
    chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
    params = {
        "max_tokens": None,
        "temperature": chosen_params["temperature"],
        "top_p": chosen_params["top_p"],
        "top_k": chosen_params["top_k"],
        "repeat_penalty": FIXED_REPEAT_PENALTY,
        "seed": FIXED_SEED,
    }

    stop = data.get("stop")
    if stop is None:
        params["stop"] = None
    elif isinstance(stop, list) and all(isinstance(s, str) for s in stop):
        params["stop"] = stop
    elif isinstance(stop, str):
        params["stop"] = [stop]
    else:
        errors["stop"] = "Stop must be a string or a list of strings"

    if errors:
        logger.error(f"Parameter validation failed for allowed fields: {errors}", extra={'request_id': request_id})
        raise ValueError(json.dumps(errors))

    logger.debug(
        f"Using parameters: max_tokens={params['max_tokens']}, repeat_penalty={params['repeat_penalty']}, "
        f"seed={params['seed']}, temperature={params['temperature']}, top_p={params['top_p']}, top_k={params['top_k']}",
        extra={'request_id': request_id},
    )
    return params


def prepare_messages(data: Dict, format: Optional[str] = None) -> List[Dict[str, str]]:
    """Turn the request body into the chat message list sent to the model.

    Either ``messages`` (list of ``{"role", "content"}`` dicts) or ``prompt``
    (string) must be present; ``messages`` wins when both are given.  A system
    message is prepended, built from ``system_prompt`` (default
    ``DEFAULT_SYSTEM_PROMPT``; ``null`` is treated as "not provided") plus a
    Markdown instruction when ``format == "markdown"``.  A system message at
    index 0 of ``messages`` replaces that prepended one; system messages at
    any other index are ignored.  Non-string contents are converted with
    ``str()``.

    Args:
        data: Parsed JSON request body.
        format: Requested output format.  Only ``"markdown"`` is supported;
            other non-None values are logged and ignored.

    Returns:
        The message list, containing at least one non-system message.

    Raises:
        ValueError: missing/ill-typed input, a malformed message entry, or no
            non-system message in the result.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    messages_list = data.get("messages")
    prompt_str = data.get("prompt")
    system_instruction = data.get("system_prompt")
    if system_instruction is None:
        # Missing key and explicit JSON null both mean "use the default".
        system_instruction = DEFAULT_SYSTEM_PROMPT

    if not messages_list and not prompt_str:
        raise ValueError("Either 'messages' list or 'prompt' string is required.")
    if messages_list and not isinstance(messages_list, list):
        raise ValueError("'messages' must be a list of dictionaries.")
    if prompt_str and not isinstance(prompt_str, str):
        raise ValueError("'prompt' must be a string.")
    # Checked unconditionally: a falsy non-string (0, False, []) would
    # otherwise reach .strip() below and surface as a 500.
    if not isinstance(system_instruction, str):
        raise ValueError("'system_prompt' must be a string.")

    final_messages = []
    content_format_instruction = ""
    if format == "markdown":
        content_format_instruction = " Format your response using Markdown."
    elif format is not None:
        logger.warning(f"Unsupported format '{format}' requested.", extra={'request_id': request_id})

    # Join with a single space so the instruction is not glued to the last
    # word of the system prompt.
    prompt_parts = (system_instruction.strip(), content_format_instruction.strip())
    effective_system_prompt = " ".join(part for part in prompt_parts if part)
    if effective_system_prompt:
        final_messages.append({"role": "system", "content": effective_system_prompt})

    if messages_list:
        has_user_message = False
        for i, msg in enumerate(messages_list):
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                raise ValueError(f"Message at index {i} is invalid: must be a dictionary with 'role' and 'content'.")
            role = msg.get("role")
            content = msg.get("content", "")
            if not isinstance(content, str):
                logger.warning(
                    f"Message content at index {i} (role: {role}) is not a string (type: {type(content)}). Converting to string.",
                    extra={'request_id': request_id},
                )
                content = str(content)

            if role == "system":
                if i == 0 and final_messages and final_messages[0]["role"] == "system":
                    logger.info("Replacing default system prompt with user-provided system message.", extra={'request_id': request_id})
                    final_messages[0] = {"role": "system", "content": content}
                elif i == 0 and not final_messages:
                    final_messages.append({"role": "system", "content": content})
                else:
                    logger.warning(
                        f"Ignoring additional system message at index {i} as system prompt is already set or should be at the start.",
                        extra={'request_id': request_id},
                    )
                continue

            if role == "user":
                has_user_message = True
            final_messages.append({"role": role, "content": content})

        if not has_user_message and any(m["role"] != "system" for m in final_messages):
            logger.warning("The 'messages' list contains no user messages.", extra={'request_id': request_id})
    elif prompt_str:
        final_messages.append({"role": "user", "content": prompt_str})

    if not final_messages or all(m["role"] == "system" for m in final_messages):
        raise ValueError("No user or assistant messages found to generate a response.")
    return final_messages


def estimate_token_count(messages: List[Dict[str, str]]) -> int:
    """Estimate how many tokens ``messages`` occupy in the prompt.

    Strategies, in order:
      1. ``llm.apply_chat_template`` + ``llm.tokenize`` when both exist;
      2. ``llm.tokenize`` over a plain ``"role: content"`` join;
      3. ``total characters // 4`` when the model has no ``tokenize``.

    Returns:
        The estimated token count, or ``-1`` when no model is loaded or every
        tokenisation attempt raised.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("LLM or tokenizer/template function not available for token estimation.", extra={'request_id': request_id})
        return -1

    if not hasattr(llm, 'tokenize'):
        logger.warning("`tokenize` or `apply_chat_template` not found on LLM object. Cannot estimate tokens accurately.", extra={'request_id': request_id})
        char_count = sum(len(str(m.get('content') or '')) for m in messages)
        return char_count // 4

    template_error: Any = None
    if hasattr(llm, 'apply_chat_template'):
        try:
            chat_prompt_string = llm.apply_chat_template(messages, add_generation_prompt=False)
            tokens = llm.tokenize(chat_prompt_string.encode('utf-8', errors='ignore'), add_bos=True)
            return len(tokens)
        except Exception as e:
            template_error = e

    try:
        simple_text = "\n".join(f"{m.get('role', 'unknown')}: {m.get('content', '')}" for m in messages)
        tokens = llm.tokenize(simple_text.encode('utf-8', errors='ignore'), add_bos=True)
        if template_error is not None:
            logger.warning(f"Chat template failed during token estimation, using simple join. Error: {template_error}", extra={'request_id': request_id})
        else:
            # No template method at all: expected on every call, keep it quiet.
            logger.debug("apply_chat_template not available; estimating tokens from simple join.", extra={'request_id': request_id})
        return len(tokens)
    except Exception as e_inner:
        logger.error(f"Could not estimate token count using either method: {e_inner}", exc_info=True, extra={'request_id': request_id})
        return -1


def get_effective_n_ctx() -> int:
    """Return the loaded model's context size, or the configured ``N_CTX``.

    ``llm.n_ctx()`` is used when a model is loaded and exposes it as a
    callable; any failure falls back to the module-level ``N_CTX``.
    """
    if llm and hasattr(llm, 'n_ctx') and callable(llm.n_ctx):
        try:
            return llm.n_ctx()
        except Exception:
            logger.warning("Failed to call llm.n_ctx(), falling back to N_CTX config value.")
    return N_CTX


def truncate_messages_for_context(messages: List[Dict[str, str]], max_tokens: int, buffer_ratio: float) -> List[Dict[str, str]]:
    """Drop the oldest messages until the estimate fits the token budget.

    The budget is ``int(max_tokens * buffer_ratio)``.  A leading system
    message is always kept.  The remaining messages are added back from
    newest to oldest while the estimated total stays within budget; the walk
    stops at the first message that does not fit or when estimation fails.
    The newest message is always kept, even when it alone exceeds the budget,
    so the result never degenerates to "system prompt only".

    Args:
        messages: Conversation, oldest first.
        max_tokens: Context window size in tokens.
        buffer_ratio: Fraction of ``max_tokens`` the prompt may occupy.

    Returns:
        The truncated list (same dict objects, not copies).  ``messages`` is
        returned unchanged when no model is loaded.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        return messages

    target_token_limit = int(max_tokens * buffer_ratio)

    system_prompt: Optional[Dict[str, str]] = None
    remaining_messages = messages
    if messages and messages[0].get("role") == "system":
        system_prompt = messages[0]
        remaining_messages = messages[1:]
    prefix: List[Dict[str, str]] = [system_prompt] if system_prompt else []

    current_token_count = estimate_token_count(prefix) if prefix else 0
    if current_token_count == -1:
        logger.warning("Could not estimate initial token count for truncation, proceeding cautiously.", extra={'request_id': request_id})
        current_token_count = 0

    messages_to_add: List[Dict[str, str]] = []
    for msg in reversed(remaining_messages):
        next_token_count = estimate_token_count(prefix + [msg] + messages_to_add)
        if next_token_count == -1:
            logger.warning(f"Token estimation failed while adding message: {msg}. Stopping truncation early.", extra={'request_id': request_id})
            break
        if next_token_count > target_token_limit:
            logger.debug(
                f"Stopping truncation: Adding next message would exceed target limit ({next_token_count} > {target_token_limit}).",
                extra={'request_id': request_id},
            )
            break
        messages_to_add.insert(0, msg)
        current_token_count = next_token_count

    if not messages_to_add and remaining_messages:
        # Nothing fit.  Keep the newest message anyway: generating from the
        # system prompt alone would silently discard the whole conversation.
        # Callers compare the resulting size against the real context window.
        logger.error("Truncation resulted in an empty message list! Returning last message.", extra={'request_id': request_id})
        messages_to_add = [remaining_messages[-1]]
        recount = estimate_token_count(prefix + messages_to_add)
        if recount != -1:
            current_token_count = recount

    final_truncated_list = prefix + messages_to_add
    original_count = len(messages)
    final_count = len(final_truncated_list)

    if final_count < original_count:
        logger.warning(
            f"Context truncated: Kept {final_count}/{original_count} messages. Estimated tokens: ~{current_token_count}/{target_token_limit} (target).",
            extra={'request_id': request_id, 'kept': final_count, 'original': original_count,
                   'estimated_tokens': current_token_count, 'target_limit': target_token_limit},
        )
    else:
        logger.debug(
            f"Context truncation check complete. Kept all {final_count} messages. Estimated tokens: ~{current_token_count}.",
            extra={'request_id': request_id, 'kept': final_count, 'estimated_tokens': current_token_count},
        )

    if not final_truncated_list:
        logger.error("Truncation called with empty input, returning empty.", extra={'request_id': request_id})
        return []
    return final_truncated_list


def load_model():
    """Download/load the GGUF model into the module-level ``llm``.

    On success ``N_CTX`` is overwritten with the context size reported by the
    model and a few sanity warnings are logged.  On any failure ``llm`` is set
    to ``None`` and the error is logged; the exception is not propagated, so
    the server can still start and report the problem via ``/health``.
    """
    global llm, N_CTX
    logger.info(f"Attempting to load model: {MODEL_REPO}/{MODEL_FILE}")
    effective_n_gpu_layers = 0
    logger.info(f"Configuration: N_CTX={N_CTX}, N_BATCH={N_BATCH}, N_GPU_LAYERS={effective_n_gpu_layers} (forced CPU)")
    try:
        llm = Llama.from_pretrained(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            n_ctx=N_CTX,
            n_batch=N_BATCH,
            n_gpu_layers=effective_n_gpu_layers,
            verbose=False,
            use_mmap=True,
            use_mlock=True,
        )
        logger.info("Model loaded successfully.")
        if llm:
            actual_n_ctx = get_effective_n_ctx()
            if actual_n_ctx != N_CTX:
                logger.warning(
                    f"Model's actual context size ({actual_n_ctx}) differs from initial config ({N_CTX}). Using actual value: {actual_n_ctx}",
                    extra={'actual_n_ctx': actual_n_ctx, 'configured_n_ctx': N_CTX},
                )
                N_CTX = actual_n_ctx
            actual_n_batch = getattr(llm, 'n_batch', N_BATCH)
            actual_n_gpu_layers = getattr(llm, 'n_gpu_layers', 0)
            logger.info(f"Actual Model Context Window (n_ctx): {N_CTX}")
            logger.info(f"Actual Model Batch Size (n_batch): {actual_n_batch}")
            logger.info(f"Actual Model GPU Layers (n_gpu_layers): {actual_n_gpu_layers} (should be 0 for CPU)")
            if N_CTX < 1024 or actual_n_batch < 64:
                logger.warning(
                    "Model loaded with relatively small N_CTX or N_BATCH. Performance or max generation length might be impacted.",
                    extra={'n_ctx': N_CTX, 'n_batch': actual_n_batch},
                )
            if actual_n_gpu_layers > 0:
                logger.warning(
                    f"Model loaded with {actual_n_gpu_layers} GPU layers despite requesting 0. Check llama.cpp build or environment.",
                    extra={'actual_gpu_layers': actual_n_gpu_layers},
                )
            # Smoke-test the tokenizer; a failure here is logged, not fatal.
            try:
                test_tokens = llm.tokenize(b"Test sentence.")
                logger.info(f"Tokenizer test successful. 'Test sentence.' -> {len(test_tokens)} tokens.")
            except Exception as tokenize_e:
                logger.warning(f"Could not perform test tokenization: {tokenize_e}")
    except Exception as e:
        logger.error(f"Fatal error loading model: {e}", exc_info=True)
        llm = None
        logger.error("Model failed to load. Generation requests will not work.", extra={'error': str(e)})


app = Flask(__name__)


@app.before_request
def before_request_func():
    """Attach a fresh UUID to ``g.request_id`` for log correlation."""
    g.request_id = str(uuid.uuid4())
    logger.debug(
        f"Incoming request: {request.method} {request.path} from {request.remote_addr}",
        extra={'request_id': g.request_id, 'path': request.path, 'method': request.method},
    )


# The model is loaded at import time so it is ready before the first request.
load_model()

# NOTE(review): this template contains no HTML tags — presumably the markup
# was stripped somewhere upstream; confirm against the original page.  The
# text is kept exactly as found.
html_code = """ LLM API Demo

LLM API Demonstration

Health Check

API Info


        

Generate Text (Automatic Continuation with Context Management)

Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX={{ N_CTX }}). If the context limit is reached, the server will attempt to continue automatically by truncating older messages (unlimited continuations). Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.


        
"""


@app.route("/")
def index():
    """Render the demo page from ``html_code``."""
    return render_template_string(
        html_code,
        N_CTX=N_CTX,
        DEFAULT_SYSTEM_PROMPT=DEFAULT_SYSTEM_PROMPT,
    )


@app.route("/health", methods=["GET"])
def health_check():
    """Readiness probe.

    Returns 200/``ok`` when the model is loaded and exposes ``tokenize`` and
    ``apply_chat_template``, 200/``warning`` when it is loaded without one of
    them, and 503/``error`` when no model is loaded.
    """
    if not llm:
        return jsonify(status="error", message="Model failed to load or is not available."), 503
    if hasattr(llm, 'tokenize') and hasattr(llm, 'apply_chat_template'):
        return jsonify(status="ok", message="Model is loaded and ready."), 200
    logger.warning("Model loaded, but tokenizer or chat template functions might be missing.")
    return jsonify(status="warning", message="Model loaded, but critical functions (tokenize/apply_chat_template) might be missing."), 200


@app.route("/info", methods=["GET"])
def model_info():
    """Describe the configured and the loaded model as JSON.

    Returns 503 when no model is loaded.  If reading the loaded-model
    attributes raises, ``loaded_model_details`` becomes an error string
    instead of a dict and the response is still 200.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("Info request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model not available."), 503

    model_details: Union[Dict[str, Any], str] = "Model details unavailable"
    actual_n_ctx = get_effective_n_ctx()
    try:
        actual_n_batch = getattr(llm, 'n_batch', N_BATCH)
        actual_n_gpu_layers = getattr(llm, 'n_gpu_layers', 0)
        n_embd = 'N/A'
        if hasattr(llm, '_model') and hasattr(llm._model, 'n_embd') and callable(llm._model.n_embd):
            try:
                n_embd = llm._model.n_embd()
            except Exception as embd_e:
                logger.warning(f"Could not get n_embd: {embd_e}", extra={'request_id': request_id})
        model_details = {
            "n_embd": n_embd,
            "n_ctx": actual_n_ctx,
            "n_batch": actual_n_batch,
            "n_gpu_layers": actual_n_gpu_layers,
            "tokenizer_present": hasattr(llm, 'tokenize'),
            "chat_handler_present": hasattr(llm, 'apply_chat_template') and hasattr(llm, 'create_chat_completion'),
        }
    except Exception as e:
        logger.warning(f"Could not retrieve all model details: {e}", extra={'request_id': request_id})
        model_details = f"Error retrieving some model details: {e}"

    info = {
        "status": "ok",
        "message": "Model is loaded. Generation continues automatically with context truncation if context limit is hit.",
        "model_config": {
            "repo_id": MODEL_REPO,
            "filename": MODEL_FILE,
            "initial_load_config": {
                # Read from the environment again because load_model() may
                # have overwritten the module-level N_CTX.
                "n_ctx": os.getenv("N_CTX", "2048"),
                "n_batch": N_BATCH,
                "n_gpu_layers": 0,
            },
            "loaded_model_details": model_details,
        },
        "generation_parameters": {
            "note": f"No artificial 'max_tokens' limit. Generation proceeds until stop sequence, EOS, or context limit (N_CTX={actual_n_ctx}). Automatic continuation attempts by truncating context **indefinitely** if context limit is reached. Sampling parameters (temperature, top_p, top_k) are chosen randomly per request/continuation cycle from predefined sets. Repeat penalty and seed are fixed.",
            "fixed_max_tokens": None,
            "fixed_repeat_penalty": FIXED_REPEAT_PENALTY,
            "fixed_seed": FIXED_SEED,
            "max_automatic_continuations": None,
            "max_stalled_continuation_cycles": MAX_STALLED_CYCLES,
            "context_truncation_buffer_ratio": CONTEXT_TRUNCATION_BUFFER_RATIO,
            "randomly_chosen_from": RANDOM_PARAMS_CHOICES,
            "default_system_prompt": DEFAULT_SYSTEM_PROMPT,
            "user_controllable": ["messages", "prompt", "stop", "stream", "format", "system_prompt"],
        },
    }
    return jsonify(info), 200


def _is_context_limit_error(exc: Exception) -> bool:
    """Tell whether ``exc`` looks like a "context window exhausted" failure.

    The decision is a substring match on the exception text and, when
    present, on its ``condition`` attribute.
    """
    err_str = str(exc).lower()
    if "context window is full" in err_str or "kv cache is full" in err_str or "llama_decode" in err_str:
        return True
    if hasattr(exc, 'condition'):
        condition = str(exc.condition).lower()
        return "context length" in condition or "failed to decode" in condition
    return False


def _pick_cycle_params(base_generation_params: Dict) -> Dict:
    """Return the base parameters with freshly drawn sampling settings."""
    chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
    return {
        **base_generation_params,
        "temperature": chosen_params["temperature"],
        "top_p": chosen_params["top_p"],
        "top_k": chosen_params["top_k"],
    }


def _create_completion(messages: List[Dict[str, str]], params: Dict, stream: bool):
    """Call ``llm.create_chat_completion`` with the given cycle parameters."""
    return llm.create_chat_completion(
        messages=messages,
        max_tokens=params["max_tokens"],
        temperature=params["temperature"],
        top_p=params["top_p"],
        top_k=params["top_k"],
        repeat_penalty=params["repeat_penalty"],
        stop=params["stop"],
        seed=params["seed"],
        stream=stream,
    )


def _append_assistant_text(messages: List[Dict[str, str]], text: str) -> None:
    """Record generated ``text`` at the end of the conversation.

    Extends a trailing assistant message, otherwise appends a new one.  The
    trailing dict is replaced rather than mutated because the list is only a
    shallow copy of the request's initial messages.
    """
    if messages and messages[-1].get('role') == 'assistant':
        last = messages[-1]
        messages[-1] = {**last, 'content': last['content'] + text}
    else:
        messages.append({"role": "assistant", "content": text})


def _stalled(stalled_cycles: int) -> bool:
    """True when the stalled-cycle guard is enabled and has been reached."""
    return MAX_STALLED_CYCLES > 0 and stalled_cycles >= MAX_STALLED_CYCLES


def _stream_with_continuation(initial_messages: List[Dict[str, str]], base_generation_params: Dict,
                              current_request_id: str) -> Generator[str, None, None]:
    """Yield generated text chunk by chunk, continuing across context limits.

    Each cycle streams one chat completion.  On ``finish_reason == 'stop'``
    the stream ends; on ``'length'`` (or an exception recognised by
    ``_is_context_limit_error``) the conversation is truncated and a new cycle
    starts.  Status lines (``[INFO]``, ``[ERROR]``, ``[CONTINUING n ...]``)
    are interleaved in the same text stream.
    """
    log_extra = {'request_id': current_request_id}
    current_messages = list(initial_messages)
    continuations = 0
    stalled_cycles = 0  # consecutive limit hits that produced no text
    total_tokens_generated_stream = 0
    effective_n_ctx = get_effective_n_ctx()

    while True:
        cycle_number = continuations + 1
        logger.info(f"Starting streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra=log_extra)
        current_params = _pick_cycle_params(base_generation_params)
        logger.debug(
            f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}",
            extra=log_extra,
        )

        generated_this_cycle = ""
        finish_reason = None
        hit_context_limit_in_cycle = False
        try:
            for chunk in _create_completion(current_messages, current_params, stream=True):
                choice = chunk.get("choices", [{}])[0]
                token = choice.get("delta", {}).get("content")
                if token:
                    generated_this_cycle += token
                    # Counts streamed chunks, used as an approximate token total.
                    total_tokens_generated_stream += 1
                    yield token
                chunk_finish_reason = choice.get("finish_reason")
                if chunk_finish_reason:
                    finish_reason = chunk_finish_reason
                    logger.info(
                        f"Streaming chunk finished cycle {cycle_number}. Reason: {finish_reason}",
                        extra={**log_extra, 'finish_reason': finish_reason},
                    )
                    if finish_reason == 'length':
                        hit_context_limit_in_cycle = True
                    usage = chunk.get("usage")
                    if usage:
                        logger.debug(f"Usage reported in final chunk: {usage}", extra={**log_extra, 'usage': usage})
                    break
        except Exception as e:
            if _is_context_limit_error(e):
                logger.warning(f"N_CTX limit or related exception caught during streaming cycle {cycle_number}: {e}", extra=log_extra)
                hit_context_limit_in_cycle = True
                finish_reason = 'length'
            else:
                logger.error(f"Unhandled error during streaming generation cycle {cycle_number}: {e}", exc_info=True, extra=log_extra)
                yield f"\n[ERROR] Generation failed unexpectedly in cycle {cycle_number}: {str(e)}"
                return

        if generated_this_cycle:
            _append_assistant_text(current_messages, generated_this_cycle)
            stalled_cycles = 0
        elif hit_context_limit_in_cycle:
            stalled_cycles += 1
            logger.warning(
                f"Context limit hit in streaming cycle {cycle_number} but no tokens were generated in this cycle. Check model behavior.",
                extra=log_extra,
            )
        elif not finish_reason:
            logger.warning(f"Stream cycle {cycle_number} ended without generating tokens or a definite finish reason. Stopping.", extra=log_extra)
            yield "\n[INFO] Generation stopped: Cycle ended unexpectedly."
            break

        if finish_reason == 'stop':
            logger.info(
                f"Generation stopped naturally (reason: stop) in streaming cycle {cycle_number}. Total stream tokens: ~{total_tokens_generated_stream}",
                extra=log_extra,
            )
            yield "\n[INFO] Generation stopped: Stop sequence or EOS."
            break
        elif hit_context_limit_in_cycle:
            if _stalled(stalled_cycles):
                # Retrying again would loop forever without making progress.
                logger.error(f"Context limit hit in {stalled_cycles} consecutive streaming cycles without output. Stopping.", extra=log_extra)
                yield "\n[ERROR] Generation failed: Context limit hit repeatedly without producing output."
                break
            continuations += 1
            logger.warning(
                f"N_CTX limit reached in streaming cycle {cycle_number}. Attempting continuation {continuations} (reinicio de contador).",
                extra=log_extra,
            )
            current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO)
            if not current_messages:
                logger.error("Context truncation resulted in empty messages during streaming. Stopping.", extra=log_extra)
                yield "\n[ERROR] Generation failed: Context truncation error."
                break
            yield f"\n[CONTINUING {continuations} - TRUNCATING CONTEXT...]\n"
            time.sleep(0.1)
            continue
        else:
            logger.warning(
                f"Streaming generation cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.",
                extra={**log_extra, 'finish_reason': finish_reason},
            )
            yield f"\n[INFO] Generation stopped: Reason: {finish_reason or 'Unknown'}"
            break

    logger.info(
        f"Streaming generation finished after {continuations} continuations. Total stream tokens generated: ~{total_tokens_generated_stream}",
        extra={**log_extra, 'continuations': continuations, 'total_stream_tokens': total_tokens_generated_stream},
    )


def _generate_buffered(initial_messages: List[Dict[str, str]], base_generation_params: Dict, request_id: str):
    """Run the continuation loop without streaming and build the response.

    Returns a ``text/plain`` response whose body is the concatenated output
    of all cycles (with ``[CONTINUATION n ...]`` markers between them) and
    whose ``X-*`` headers carry finish reason, continuation count and usage.
    Returns a JSON 500 response when the model returns an unusable result or
    raises an error that is not a context-limit error.
    """
    log_extra = {'request_id': request_id}
    current_messages = list(initial_messages)
    continuations = 0
    stalled_cycles = 0  # consecutive limit hits that produced no text
    full_generated_text = ""
    total_tokens_generated_nonstream = 0
    final_finish_reason = "unknown"
    final_usage: Dict[str, Any] = {}
    effective_n_ctx = get_effective_n_ctx()

    while True:
        cycle_number = continuations + 1
        logger.info(f"Starting non-streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra=log_extra)
        current_params = _pick_cycle_params(base_generation_params)
        logger.debug(
            f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}",
            extra=log_extra,
        )

        generated_this_cycle = ""
        finish_reason = None
        hit_context_limit_in_cycle = False
        usage_this_cycle: Dict[str, Any] = {}
        try:
            result = _create_completion(current_messages, current_params, stream=False)
            if not (result and "choices" in result and result["choices"]):
                logger.error(f"Invalid response structure from llama_cpp in non-streaming cycle {cycle_number}: {result}", extra=log_extra)
                return jsonify(error="Generation failed", detail=f"Invalid response structure from model in cycle {cycle_number}."), 500
            choice = result["choices"][0]
            generated_this_cycle = choice.get("message", {}).get("content", "")
            finish_reason = choice.get("finish_reason", "unknown")
            # ``or {}``: a present-but-null usage must not break .get() below.
            usage_this_cycle = result.get("usage") or {}
            final_finish_reason = finish_reason
            if usage_this_cycle:
                final_usage = usage_this_cycle
            logger.info(
                f"Non-streaming cycle {cycle_number} finished. Reason: {finish_reason}. Usage: {usage_this_cycle}",
                extra={**log_extra, 'usage': usage_this_cycle, 'finish_reason': finish_reason},
            )
            if finish_reason == 'length':
                hit_context_limit_in_cycle = True
        except Exception as e:
            if _is_context_limit_error(e):
                logger.warning(f"N_CTX limit or related exception caught during non-streaming cycle {cycle_number}: {e}", extra=log_extra)
                hit_context_limit_in_cycle = True
                finish_reason = 'length'
            else:
                logger.error(f"Unhandled error during non-streaming cycle {cycle_number}: {e}", exc_info=True, extra=log_extra)
                return jsonify(error="Generation failed", detail=f"Internal error in cycle {cycle_number}: {str(e)}"), 500

        if generated_this_cycle:
            if continuations > 0 and full_generated_text:
                full_generated_text += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT]\n\n"
            full_generated_text += generated_this_cycle
            total_tokens_generated_nonstream += usage_this_cycle.get("completion_tokens", 0)
            _append_assistant_text(current_messages, generated_this_cycle)
            stalled_cycles = 0
        elif hit_context_limit_in_cycle:
            stalled_cycles += 1
            logger.warning(f"Non-streaming N_CTX limit hit in cycle {cycle_number} but no completion tokens reported.", extra=log_extra)
            if continuations > 0 and full_generated_text:
                full_generated_text += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT - NO OUTPUT THIS CYCLE]\n\n"
        elif not finish_reason:
            logger.warning(f"Non-streaming cycle {cycle_number} ended without generating tokens or a finish reason. Stopping.", extra=log_extra)
            full_generated_text += f"\n[INFO: Generation stopped: Cycle {cycle_number} ended unexpectedly.]"
            break

        if finish_reason == 'stop':
            logger.info(f"Non-streaming generation stopped naturally (reason: stop) in cycle {cycle_number}.", extra=log_extra)
            break
        elif hit_context_limit_in_cycle:
            if _stalled(stalled_cycles):
                # Without this guard the request would never return.
                logger.error(f"Context limit hit in {stalled_cycles} consecutive non-streaming cycles without output. Stopping.", extra=log_extra)
                full_generated_text += "\n[ERROR: Generation failed: Context limit hit repeatedly without producing output.]"
                final_finish_reason = "stalled"
                break
            continuations += 1
            logger.warning(
                f"Non-streaming N_CTX limit reached in cycle {cycle_number}. Attempting continuation {continuations} (reinicio de contador).",
                extra=log_extra,
            )
            current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO)
            if not current_messages:
                logger.error("Context truncation resulted in empty messages during non-streaming. Stopping.", extra=log_extra)
                full_generated_text += "\n[ERROR: Generation failed: Context truncation error.]"
                final_finish_reason = "truncation_error"
                break
            continue
        else:
            logger.warning(
                f"Non-streaming cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.",
                extra={**log_extra, 'finish_reason': finish_reason},
            )
            full_generated_text += f"\n\n[INFO: Generation stopped unexpectedly. Reason: {finish_reason or 'Unknown'}]"
            break

    logger.info(
        f"Non-streaming generation finished after {continuations} continuations. Total completion tokens reported: {total_tokens_generated_nonstream}. Final reason: {final_finish_reason}",
        extra={**log_extra, 'continuations': continuations,
               'total_completion_tokens': total_tokens_generated_nonstream, 'final_reason': final_finish_reason},
    )

    # content_type (not mimetype) so the charset parameter is sent verbatim.
    response = Response(full_generated_text, content_type="text/plain; charset=utf-8")
    response.headers["X-Request-ID"] = request_id
    response.headers["X-Finish-Reason"] = str(final_finish_reason)
    response.headers["X-Continuations"] = str(continuations)
    response.headers["X-Usage-Completion-Tokens"] = str(total_tokens_generated_nonstream)
    response.headers["X-Usage-Prompt-Tokens-Last-Cycle"] = str(final_usage.get("prompt_tokens", "N/A"))
    response.headers["X-Usage-Total-Tokens-Last-Cycle"] = str(final_usage.get("total_tokens", "N/A"))
    return response


@app.route("/generate", methods=["POST"])
def generate():
    """Generate a chat completion for the posted JSON body.

    Recognised body keys: ``messages`` or ``prompt`` (one is required),
    ``system_prompt``, ``stop``, ``format`` and ``stream`` (default true).

    Responses:
        503  model not loaded
        415  request is not ``application/json``
        400  body is not a JSON object, validation failed, or the input does
             not fit the context window even after truncation
        500  unexpected failure while preparing or generating
        200  ``text/event-stream`` of raw text chunks when streaming,
             otherwise ``text/plain`` with usage in ``X-*`` headers
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.error("Generate request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model is not available.", detail="The LLM model could not be loaded."), 503
    if not request.is_json:
        logger.warning("Request received without Content-Type: application/json", extra={'request_id': request_id})
        return jsonify(error="Invalid request header", detail="Content-Type must be application/json"), 415

    # silent=True: malformed JSON yields None instead of raising, so every
    # bad body (unparseable, array, scalar, null) gets the same JSON 400.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        logger.warning("Request body is not a JSON object.", extra={'request_id': request_id})
        return jsonify(error="Invalid input", detail="Request body must be a JSON object."), 400

    is_streaming = data.get("stream", True)
    response_format = data.get("format")

    # Log everything except the (potentially large) conversation itself.
    raw_messages = data.get('messages')
    log_data_summary = {k: v for k, v in data.items() if k not in ('messages', 'prompt')}
    log_data_summary['messages_count_initial'] = len(raw_messages) if isinstance(raw_messages, list) else 0
    log_data_summary['has_prompt_initial'] = 'prompt' in data
    log_data_summary['stream'] = is_streaming
    log_data_summary['format'] = response_format
    logger.info("Received generation request summary.", extra={'request_id': request_id, 'summary': log_data_summary})

    try:
        initial_messages = prepare_messages(data, format=response_format)
        base_generation_params = parse_and_validate_params(data)
        effective_n_ctx = get_effective_n_ctx()
        input_token_count = estimate_token_count(initial_messages)

        if input_token_count != -1 and input_token_count >= effective_n_ctx:
            truncated_initial = truncate_messages_for_context(initial_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO)
            truncated_token_count = estimate_token_count(truncated_initial)
            if truncated_token_count != -1 and truncated_token_count >= effective_n_ctx:
                error_msg = (
                    f"Initial input exceeds context window ({effective_n_ctx}) even after attempting truncation. "
                    f"Input tokens (~{input_token_count}) / Truncated tokens (~{truncated_token_count}). "
                    f"Reduce initial message size significantly."
                )
                logger.error(error_msg, extra={'request_id': request_id, 'initial_tokens': input_token_count,
                                               'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
                return jsonify(error="Input exceeds context window", detail=error_msg), 400
            logger.warning(
                f"Initial input (~{input_token_count} tokens) exceeded context window ({effective_n_ctx}). Truncated to ~{truncated_token_count} tokens.",
                extra={'request_id': request_id, 'initial_tokens': input_token_count,
                       'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx},
            )
            initial_messages = truncated_initial
            input_token_count = truncated_token_count
        elif input_token_count != -1:
            logger.info(
                f"Initial input token count: ~{input_token_count}. Context window: {effective_n_ctx}. Remaining: {effective_n_ctx - input_token_count}.",
                extra={'request_id': request_id, 'input_tokens': input_token_count, 'n_ctx': effective_n_ctx,
                       'remaining_ctx': effective_n_ctx - input_token_count},
            )
        else:
            logger.warning("Could not estimate initial token count. Proceeding with generation, may hit context limit.", extra={'request_id': request_id})

        logger.info(
            f"Processing request with {len(initial_messages)} initial messages. Stream={is_streaming}. Format={response_format}. max_tokens=None (dynamic). Unlimited Continuations.",
            extra={'request_id': request_id},
        )
    except ValueError as e:
        logger.error(f"Invalid input data: {e}", exc_info=True, extra={'request_id': request_id})
        # parse_and_validate_params encodes its errors as JSON; other
        # ValueErrors carry a plain message.
        try:
            error_detail = json.loads(str(e))
        except json.JSONDecodeError:
            error_detail = str(e)
        return jsonify(error="Invalid input", detail=error_detail), 400
    except Exception as e:
        logger.error(f"Unexpected error preparing request: {e}", exc_info=True, extra={'request_id': request_id})
        return jsonify(error="Internal server error", detail="An unexpected error occurred processing the request."), 500

    if not is_streaming:
        return _generate_buffered(initial_messages, base_generation_params, request_id)

    headers = {
        "Content-Type": "text/event-stream; charset=utf-8",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "X-Request-ID": request_id,
    }
    # stream_with_context keeps the request context (and thus ``g``) alive
    # while the generator runs.
    stream = _stream_with_continuation(initial_messages, base_generation_params, request_id)
    return Response(stream_with_context(stream), headers=headers)


if __name__ == "__main__":
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    is_debug = os.getenv("FLASK_DEBUG", "0") == "1"
    log_level = logging.DEBUG if is_debug else logging.INFO
    logger.setLevel(log_level)
    logger.info(f"Starting Flask server on {host}:{port} (Debug mode: {is_debug})")
    logger.info(f"Model: {MODEL_REPO}/{MODEL_FILE}, N_CTX={N_CTX}, Automatic Continuations: UNLIMITED (with context truncation)")
    if not llm:
        logger.critical("MODEL FAILED TO LOAD. SERVER WILL START BUT '/generate' WILL FAIL.")

    if not is_debug:
        try:
            from waitress import serve
            logger.info("Running with Waitress production server.")
            serve(app, host=host, port=port, threads=8)
        except ImportError:
            logger.warning("Waitress not found. Falling back to Flask development server. Install waitress for production.")
            app.run(host=host, port=port, threaded=True, debug=is_debug)
    else:
        logger.info("Running with Flask development server (Debug=True).")
        # The reloader is disabled so the model is not loaded twice.
        app.run(host=host, port=port, threaded=True, debug=is_debug, use_reloader=False)