#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import logging
import uuid
from typing import List, Dict, Union, Optional, Generator, Any
import random
import time
from flask import Flask, request, Response, stream_with_context, jsonify, g, render_template_string
from llama_cpp import Llama
class JsonFormatter(logging.Formatter):
def format(self, record):
log_record = {
"timestamp": self.formatTime(record, self.datefmt),
"level": record.levelname,
"name": record.name,
"message": record.getMessage(),
"pathname": record.pathname,
"lineno": record.lineno,
}
if hasattr(record, 'request_id'):
log_record['request_id'] = record.request_id
if record.exc_info:
log_record['exception'] = self.formatException(record.exc_info)
if record.stack_info:
log_record['stack_info'] = self.formatStack(record.stack_info)
skip_keys = {'message', 'asctime', 'levelname', 'levelno', 'pathname', 'filename', 'module', 'funcName', 'lineno', 'created', 'msecs', 'relativeCreated', 'thread', 'threadName', 'process', 'processName', 'exc_info', 'exc_text', 'stack_info', 'request_id'}
for key, value in record.__dict__.items():
if not key.startswith('_') and key not in log_record and key not in skip_keys:
log_record[key] = value
return json.dumps(log_record)
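# Each record is emitted as one JSON object per line, e.g. (illustrative values):
# {"timestamp": "2024-01-01 12:00:00,000", "level": "INFO", "name": "root",
#  "message": "Model loaded successfully.", "pathname": "/app/app.py", "lineno": 123}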
def setup_logging():
logger = logging.getLogger()
if not logger.handlers:
handler = logging.StreamHandler()
formatter = JsonFormatter()
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
logging.getLogger("werkzeug").setLevel(logging.ERROR)
logging.getLogger("llama_cpp").setLevel(logging.WARNING)
return logger
logger = setup_logging()
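# Configuration is read from environment variables; the values below are defaults.
# MODEL_REPO/MODEL_FILE select the GGUF model pulled from the Hugging Face Hub,
# N_CTX is the context window requested at load time, N_BATCH the prompt batch size,
# and CONTEXT_TRUNCATION_BUFFER_RATIO controls how aggressively history is trimmed
# when a generation cycle hits the context limit.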
MODEL_REPO = os.getenv("MODEL_REPO", "jnjj/vcvcvcv")
MODEL_FILE = os.getenv("MODEL_FILE", "gemma-3-4b-it-q4_0.gguf")
N_CTX = int(os.getenv("N_CTX", "2048"))
N_BATCH = int(os.getenv("N_BATCH", "512"))
N_GPU_LAYERS = 0
FIXED_REPEAT_PENALTY = float(os.getenv("FIXED_REPEAT_PENALTY", "1.1"))
FIXED_SEED = int(os.getenv("FIXED_SEED", "-1"))
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "You are a concise, direct, and helpful assistant.")
CONTEXT_TRUNCATION_BUFFER_RATIO = float(os.getenv("CONTEXT_TRUNCATION_BUFFER_RATIO", "0.85"))
RANDOM_PARAMS_CHOICES = [
{"top_k": 10, "top_p": 0.5, "temperature": 0.2},
{"top_k": 10, "top_p": 0.5, "temperature": 0.1},
{"top_k": 10, "top_p": 0.5, "temperature": 0.3},
{"top_k": 10, "top_p": 0.5, "temperature": 0.4},
{"top_k": 5, "top_p": 0.3, "temperature": 0.6},
{"top_k": 20, "top_p": 0.7, "temperature": 0.5},
]
llm: Optional[Llama] = None
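# Only the 'stop' field is honoured from the request body. Sampling parameters
# (temperature, top_p, top_k) are drawn at random from RANDOM_PARAMS_CHOICES,
# repeat_penalty and seed are fixed, and max_tokens is left unbounded (None).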
def parse_and_validate_params(data: Dict) -> Dict:
request_id = getattr(g, 'request_id', 'N/A')
params = {}
errors = {}
params["max_tokens"] = None
chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
params["temperature"] = chosen_params["temperature"]
params["top_p"] = chosen_params["top_p"]
params["top_k"] = chosen_params["top_k"]
params["repeat_penalty"] = FIXED_REPEAT_PENALTY
params["seed"] = FIXED_SEED
stop = data.get("stop")
if stop is not None:
if isinstance(stop, list) and all(isinstance(s, str) for s in stop):
params["stop"] = stop
elif isinstance(stop, str):
params["stop"] = [stop]
else:
errors["stop"] = "Stop must be a string or a list of strings"
else:
params["stop"] = None
if errors:
logger.error(f"Parameter validation failed for allowed fields: {errors}", extra={'request_id': request_id})
raise ValueError(json.dumps(errors))
logger.debug(f"Using parameters: max_tokens={params['max_tokens']}, repeat_penalty={params['repeat_penalty']}, seed={params['seed']}, temperature={params['temperature']}, top_p={params['top_p']}, top_k={params['top_k']}", extra={'request_id': request_id})
return params
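# Builds the chat message list sent to the model: a system prompt (the default, the
# caller's 'system_prompt', or a leading system message), then either the provided
# 'messages' or a single user turn built from 'prompt'. When format == "markdown",
# a Markdown instruction is appended to the system prompt.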
def prepare_messages(data: Dict, format: Optional[str] = None) -> List[Dict[str, str]]:
request_id = getattr(g, 'request_id', 'N/A')
messages_list = data.get("messages")
prompt_str = data.get("prompt")
system_instruction = data.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
if not messages_list and not prompt_str:
raise ValueError("Either 'messages' list or 'prompt' string is required.")
if messages_list and not isinstance(messages_list, list):
raise ValueError("'messages' must be a list of dictionaries.")
if prompt_str and not isinstance(prompt_str, str):
raise ValueError("'prompt' must be a string.")
if system_instruction and not isinstance(system_instruction, str):
raise ValueError("'system_prompt' must be a string.")
final_messages = []
content_format_instruction = ""
if format == "markdown":
content_format_instruction = " Format your response using Markdown."
elif format is not None:
logger.warning(f"Unsupported format '{format}' requested.", extra={'request_id': request_id})
    effective_system_prompt = (system_instruction.strip() + content_format_instruction).strip()
if effective_system_prompt:
final_messages.append({"role": "system", "content": effective_system_prompt})
user_provided_system = False
if messages_list:
has_user_message = False
for i, msg in enumerate(messages_list):
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
raise ValueError(f"Message at index {i} is invalid: must be a dictionary with 'role' and 'content'.")
role = msg.get("role")
content = msg.get("content", "")
if not isinstance(content, str):
logger.warning(f"Message content at index {i} (role: {role}) is not a string (type: {type(content)}). Converting to string.", extra={'request_id': request_id})
content = str(content)
if role == "system":
if i == 0 and final_messages and final_messages[0]["role"] == "system":
logger.info("Replacing default system prompt with user-provided system message.", extra={'request_id': request_id})
final_messages[0] = {"role": "system", "content": content}
user_provided_system = True
elif i == 0 and not final_messages:
final_messages.append({"role": "system", "content": content})
user_provided_system = True
else:
logger.warning(f"Ignoring additional system message at index {i} as system prompt is already set or should be at the start.", extra={'request_id': request_id})
continue
elif role == "user":
has_user_message = True
final_messages.append({"role": role, "content": content})
if not has_user_message and any(m["role"] != "system" for m in final_messages):
logger.warning("The 'messages' list contains no user messages.", extra={'request_id': request_id})
elif prompt_str:
final_messages.append({"role": "user", "content": prompt_str})
if not final_messages or all(m["role"] == "system" for m in final_messages):
raise ValueError("No user or assistant messages found to generate a response.")
return final_messages
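# Best-effort token count for a message list: apply the model's chat template and
# tokenize the result; if the template fails, tokenize a plain "role: content" join;
# if the tokenizer itself is unavailable, fall back to a ~4-characters-per-token
# heuristic or return -1 (unknown).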
def estimate_token_count(messages: List[Dict[str, str]]) -> int:
request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("LLM not available for token estimation.", extra={'request_id': request_id})
        return -1
    if not hasattr(llm, 'tokenize') or not hasattr(llm, 'apply_chat_template'):
        logger.warning("`tokenize` or `apply_chat_template` not found on LLM object. Falling back to a rough character-based estimate.", extra={'request_id': request_id})
        char_count = sum(len(m.get('content', '')) for m in messages)
        return char_count // 4
try:
chat_prompt_string = llm.apply_chat_template(messages, add_generation_prompt=False)
tokens = llm.tokenize(chat_prompt_string.encode('utf-8', errors='ignore'), add_bos=True)
return len(tokens)
except Exception as e:
try:
simple_text = "\n".join([f"{m.get('role', 'unknown')}: {m.get('content', '')}" for m in messages])
tokens = llm.tokenize(simple_text.encode('utf-8', errors='ignore'), add_bos=True)
logger.warning(f"Chat template failed during token estimation, using simple join. Error: {e}", extra={'request_id': request_id})
return len(tokens)
except Exception as e_inner:
logger.error(f"Could not estimate token count using either method: {e_inner}", exc_info=True, extra={'request_id': request_id})
return -1
def get_effective_n_ctx() -> int:
if llm and hasattr(llm, 'n_ctx') and callable(llm.n_ctx):
try:
return llm.n_ctx()
except Exception:
logger.warning("Failed to call llm.n_ctx(), falling back to N_CTX config value.")
return N_CTX
return N_CTX
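# Truncation keeps the system prompt (if any) and re-adds messages from newest to
# oldest while the estimated size stays under max_tokens * buffer_ratio; older turns
# are dropped first. For example, with the defaults N_CTX=2048 and ratio 0.85 the
# target budget is int(2048 * 0.85) = 1740 tokens.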
def truncate_messages_for_context(messages: List[Dict[str, str]], max_tokens: int, buffer_ratio: float) -> List[Dict[str, str]]:
request_id = getattr(g, 'request_id', 'N/A')
if not llm: return messages
target_token_limit = int(max_tokens * buffer_ratio)
truncated_messages: List[Dict[str, str]] = []
system_prompt: Optional[Dict[str, str]] = None
if messages and messages[0].get("role") == "system":
system_prompt = messages[0]
remaining_messages = messages[1:]
if system_prompt:
truncated_messages.append(system_prompt)
else:
remaining_messages = messages
current_token_count = estimate_token_count(truncated_messages) if truncated_messages else 0
if current_token_count == -1:
logger.warning("Could not estimate initial token count for truncation, proceeding cautiously.", extra={'request_id': request_id})
current_token_count = 0
messages_to_add = []
for msg in reversed(remaining_messages):
potential_list = [msg] + messages_to_add
if system_prompt:
potential_list_with_system = [system_prompt] + potential_list
else:
potential_list_with_system = potential_list
next_token_count = estimate_token_count(potential_list_with_system)
if next_token_count != -1 and next_token_count <= target_token_limit:
messages_to_add.insert(0, msg)
current_token_count = next_token_count
elif next_token_count == -1:
logger.warning(f"Token estimation failed while adding message: {msg}. Stopping truncation early.", extra={'request_id': request_id})
break
else:
logger.debug(f"Stopping truncation: Adding next message would exceed target limit ({next_token_count} > {target_token_limit}).", extra={'request_id': request_id})
break
final_truncated_list = ([system_prompt] if system_prompt else []) + messages_to_add
original_count = len(messages)
final_count = len(final_truncated_list)
if final_count < original_count:
logger.warning(f"Context truncated: Kept {final_count}/{original_count} messages. Estimated tokens: ~{current_token_count}/{target_token_limit} (target).",
extra={'request_id': request_id, 'kept': final_count, 'original': original_count, 'estimated_tokens': current_token_count, 'target_limit': target_token_limit})
else:
logger.debug(f"Context truncation check complete. Kept all {final_count} messages. Estimated tokens: ~{current_token_count}.",
extra={'request_id': request_id, 'kept': final_count, 'estimated_tokens': current_token_count})
if not final_truncated_list and messages:
logger.error("Truncation resulted in an empty message list! Returning last message.", extra={'request_id': request_id})
return [messages[-1]]
elif not final_truncated_list:
logger.error("Truncation called with empty input, returning empty.", extra={'request_id': request_id})
return []
return final_truncated_list
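# Loads the GGUF model with llama-cpp-python. Llama.from_pretrained fetches
# MODEL_FILE from the MODEL_REPO repository (typically cached locally via
# huggingface_hub) and runs it on CPU only (n_gpu_layers=0). After loading, the
# global N_CTX is updated to the context size the model actually reports.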
def load_model():
global llm, N_CTX
logger.info(f"Attempting to load model: {MODEL_REPO}/{MODEL_FILE}")
effective_n_gpu_layers = 0
logger.info(f"Configuration: N_CTX={N_CTX}, N_BATCH={N_BATCH}, N_GPU_LAYERS={effective_n_gpu_layers} (forced CPU)")
try:
llm = Llama.from_pretrained(
repo_id=MODEL_REPO,
filename=MODEL_FILE,
n_ctx=N_CTX,
n_batch=N_BATCH,
n_gpu_layers=effective_n_gpu_layers,
verbose=False,
use_mmap=True,
use_mlock=True,
)
logger.info("Model loaded successfully.")
if llm:
actual_n_ctx = get_effective_n_ctx()
if actual_n_ctx != N_CTX:
logger.warning(f"Model's actual context size ({actual_n_ctx}) differs from initial config ({N_CTX}). Using actual value: {actual_n_ctx}", extra={'actual_n_ctx': actual_n_ctx, 'configured_n_ctx': N_CTX})
N_CTX = actual_n_ctx
actual_n_batch = llm.n_batch if hasattr(llm, 'n_batch') else N_BATCH
actual_n_gpu_layers = llm.n_gpu_layers if hasattr(llm, 'n_gpu_layers') else 0
logger.info(f"Actual Model Context Window (n_ctx): {N_CTX}")
logger.info(f"Actual Model Batch Size (n_batch): {actual_n_batch}")
logger.info(f"Actual Model GPU Layers (n_gpu_layers): {actual_n_gpu_layers} (should be 0 for CPU)")
if N_CTX < 1024 or actual_n_batch < 64:
logger.warning("Model loaded with relatively small N_CTX or N_BATCH. Performance or max generation length might be impacted.", extra={'n_ctx': N_CTX, 'n_batch': actual_n_batch})
if actual_n_gpu_layers > 0:
logger.warning(f"Model loaded with {actual_n_gpu_layers} GPU layers despite requesting 0. Check llama.cpp build or environment.", extra={'actual_gpu_layers': actual_n_gpu_layers})
try:
test_tokens = llm.tokenize(b"Test sentence.")
logger.info(f"Tokenizer test successful. 'Test sentence.' -> {len(test_tokens)} tokens.")
except Exception as tokenize_e:
logger.warning(f"Could not perform test tokenization: {tokenize_e}")
except Exception as e:
logger.error(f"Fatal error loading model: {e}", exc_info=True)
llm = None
logger.error("Model failed to load. Generation requests will not work.", extra={'error': str(e)})
app = Flask(__name__)
@app.before_request
def before_request_func():
g.request_id = str(uuid.uuid4())
logger.debug(f"Incoming request: {request.method} {request.path} from {request.remote_addr}", extra={'request_id': g.request_id, 'path': request.path, 'method': request.method})
load_model()
html_code = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>LLM API Demo</title>
<style>
body { font-family: sans-serif; margin: 20px; line-height: 1.6; background-color: #f4f4f4; color: #333; }
.container { max-width: 800px; margin: auto; background: #fff; padding: 20px; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); }
h1, h2 { color: #0056b3; }
.section { margin-bottom: 30px; padding: 20px; background-color: #e9e9e9; border-radius: 5px; }
label { display: block; margin-bottom: 5px; font-weight: bold; }
input[type="text"], input[type="number"], textarea, select {
width: calc(100% - 22px); padding: 10px; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px;
}
textarea { resize: vertical; min-height: 100px; }
button {
display: inline-block; background-color: #007bff; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; font-size: 16px;
margin-right: 10px; transition: background-color 0.3s ease;
}
button:hover { background-color: #0056b3; }
button:disabled { background-color: #cccccc; cursor: not-allowed; }
.output {
background-color: #f9f9f9; border: 1px solid #ddd; padding: 15px; border-radius: 4px; white-space: pre-wrap; word-wrap: break-word; max-height: 400px; overflow-y: auto; font-family: monospace;
}
.error { color: red; font-weight: bold; }
.info { color: green; }
.warning { color: orange; }
.param-fixed { font-style: italic; color: #555; margin-bottom: 10px; }
.checkbox-container { display: flex; align-items: center; margin-bottom: 10px; }
.checkbox-container input { margin-right: 5px; width: auto; }
.continuation-info { font-weight: bold; }
</style>
</head>
<body>
<div class="container">
<h1>LLM API Demonstration</h1>
<div class="section">
<h2>Health Check</h2>
<button id="healthCheckBtn">Check Health</button>
<p id="healthStatus"></p>
</div>
<div class="section">
<h2>API Info</h2>
<button id="apiInfoBtn">Get Info</button>
<pre id="apiInfoOutput" class="output"></pre>
</div>
<div class="section">
<h2>Generate Text (Automatic Continuation with Context Management)</h2>
<label for="promptInput">Prompt / First User Message:</label>
<textarea id="promptInput" placeholder="Enter your prompt here..."></textarea>
<div class="param-fixed">Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX={{ N_CTX }}). If the context limit is reached, the server will attempt to continue automatically by <strong class="continuation-info">truncating older messages (unlimited continuations)</strong>. Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.</div>
<div>
<label for="stopInput">Stop Sequences (comma-separated):</label>
<input type="text" id="stopInput" value="">
</div>
<div>
<label for="systemPromptInput">System Prompt (Optional Override - default: "{{ DEFAULT_SYSTEM_PROMPT | escape }}"):</label>
<input type="text" id="systemPromptInput" placeholder="Leave empty to use default">
</div>
<div>
<label for="formatSelect">Format:</label>
<select id="formatSelect">
<option value="">None</option>
<option value="markdown">Markdown</option>
</select>
</div>
<div class="checkbox-container">
<input type="checkbox" id="streamCheckbox" checked>
<label for="streamCheckbox">Stream Output</label>
</div>
<button id="generateBtn">Generate</button>
<p id="generationStatus"></p>
<pre id="generationOutput" class="output"></pre>
</div>
</div>
<script>
const healthCheckBtn = document.getElementById('healthCheckBtn');
const healthStatus = document.getElementById('healthStatus');
const apiInfoBtn = document.getElementById('apiInfoBtn');
const apiInfoOutput = document.getElementById('apiInfoOutput');
const promptInput = document.getElementById('promptInput');
const stopInput = document.getElementById('stopInput');
const systemPromptInput = document.getElementById('systemPromptInput');
const formatSelect = document.getElementById('formatSelect');
const streamCheckbox = document.getElementById('streamCheckbox');
const generateBtn = document.getElementById('generateBtn');
const generationOutput = document.getElementById('generationOutput');
const generationStatus = document.getElementById('generationStatus');
const API_BASE_URL = window.location.origin;
async function checkHealth() {
healthStatus.textContent = 'Checking...';
healthStatus.className = '';
try {
const response = await fetch(`${API_BASE_URL}/health`);
const data = await response.json();
healthStatus.textContent = `Status: ${data.status}, Message: ${data.message}`;
healthStatus.className = data.status === 'ok' ? 'info' : (data.status === 'warning' ? 'warning' : 'error');
} catch (error) {
healthStatus.textContent = `Error fetching health: ${error}`;
healthStatus.className = 'error';
}
}
async function getApiInfo() {
apiInfoOutput.textContent = 'Loading...';
apiInfoOutput.className = 'output';
try {
const response = await fetch(`${API_BASE_URL}/info`);
if (!response.ok) {
try {
const errorData = await response.json();
throw new Error(`API Error ${response.status}: ${errorData.error || JSON.stringify(errorData)}`);
} catch (e) {
throw new Error(`API Error ${response.status}: ${response.statusText}`);
}
}
const data = await response.json();
const nCtx = data?.model_config?.loaded_model_details?.n_ctx || '{{ N_CTX }}';
const maxContDesc = data?.generation_parameters?.max_automatic_continuations === null ? "unlimited continuations" : `up to ${data?.generation_parameters?.max_automatic_continuations} times`;
const description = document.querySelector('.param-fixed');
if (description) {
description.innerHTML = `Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX=${nCtx}). If the context limit is reached, the server will attempt to continue automatically by <strong class="continuation-info">truncating older messages (${maxContDesc})</strong>. Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.`;
}
apiInfoOutput.textContent = JSON.stringify(data, null, 2);
} catch (error) {
apiInfoOutput.textContent = `Error fetching info: ${error}`;
apiInfoOutput.className = 'output error';
}
}
async function generateText() {
generationOutput.textContent = '';
generationStatus.textContent = 'Preparing request...';
generationStatus.className = '';
generateBtn.disabled = true;
const prompt = promptInput.value;
if (!prompt.trim()) {
generationStatus.textContent = 'Error: Prompt cannot be empty.';
generationStatus.className = 'error';
generateBtn.disabled = false;
return;
}
const messages = [{"role": "user", "content": prompt}];
const stream = streamCheckbox.checked;
const format = formatSelect.value || undefined;
const stopSequences = stopInput.value.split(',').map(s => s.trim()).filter(s => s.length > 0);
const stop = stopSequences.length > 0 ? stopSequences : undefined;
const systemPrompt = systemPromptInput.value.trim() || undefined;
const requestBody = {
messages: messages,
stop: stop,
stream: stream,
format: format,
system_prompt: systemPrompt
};
generationStatus.textContent = 'Generating... (may continue automatically with context truncation if needed)';
generationStatus.className = 'info';
try {
const response = await fetch(`${API_BASE_URL}/generate`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(requestBody),
});
if (!response.ok) {
const errorText = await response.text();
let errorMessage = `Error: ${response.status} ${response.statusText}`;
try {
const errorData = JSON.parse(errorText);
errorMessage += ` - ${errorData.error || JSON.stringify(errorData.detail || errorData)}`;
} catch (jsonParseError) {
errorMessage += ` - ${errorText}`;
}
generationStatus.textContent = errorMessage;
generationStatus.className = 'error';
generateBtn.disabled = false;
return;
}
if (stream) {
const reader = response.body.getReader();
const decoder = new TextDecoder('utf-8');
let finished = false;
generationOutput.textContent = '';
let continuationCount = 0;
while (!finished) {
const { done, value } = await reader.read();
if (done) {
finished = true;
if (!generationStatus.textContent.includes("finished") && !generationStatus.textContent.includes("stopped") && !generationStatus.textContent.includes("Error")) {
generationStatus.textContent = `Streaming finished. Continuations: ${continuationCount}.`;
generationStatus.className = 'info';
}
break;
}
const chunk = decoder.decode(value, { stream: true });
const continueMatch = chunk.match(/\\n\[CONTINUING (\d+) - TRUNCATING CONTEXT\.\.\.\]\\n/);
if (continueMatch) {
continuationCount = parseInt(continueMatch[1]);
generationOutput.textContent += chunk;
generationStatus.textContent = `Context limit reached, truncating history and continuing generation (Continuation #${continuationCount})...`;
generationStatus.className = 'warning continuation-info';
} else if (chunk.startsWith("\\n[ERROR]")) {
generationOutput.textContent += chunk;
generationStatus.textContent = 'Error during generation (see output).';
generationStatus.className = 'error';
finished = true;
} else if (chunk.startsWith("\\n[INFO] Generation stopped")) {
generationOutput.textContent += chunk;
generationStatus.textContent = `Generation stopped (see output for reason). Continuations: ${continuationCount}.`;
generationStatus.className = 'info';
finished = true;
} else {
generationOutput.textContent += chunk;
if (!generationStatus.className.includes('warning') && !generationStatus.className.includes('error')) {
generationStatus.textContent = `Streaming... (Continuation #${continuationCount})`;
generationStatus.className = 'info';
}
}
generationOutput.scrollTop = generationOutput.scrollHeight;
}
} else {
const text = await response.text();
const finishReason = response.headers.get('X-Finish-Reason');
const continuations = response.headers.get('X-Continuations');
const usageTokens = response.headers.get('X-Usage-Completion-Tokens');
generationOutput.textContent = text;
let statusText = `Generation finished. Reason: ${finishReason || 'unknown'}.`;
if (continuations && parseInt(continuations) > 0) {
statusText += ` Continuations: ${continuations} (context truncated).`;
generationStatus.className = 'warning continuation-info';
} else {
generationStatus.className = 'info';
}
if (usageTokens) statusText += ` Tokens: ~${usageTokens}.`;
generationStatus.textContent = statusText;
}
} catch (error) {
generationStatus.textContent = `Network or processing error: ${error}`;
generationStatus.className = 'error';
} finally {
generateBtn.disabled = false;
}
}
healthCheckBtn.addEventListener('click', checkHealth);
apiInfoBtn.addEventListener('click', getApiInfo);
generateBtn.addEventListener('click', generateText);
checkHealth();
getApiInfo();
</script>
</body>
</html>
"""
@app.route("/")
def index():
rendered_html = render_template_string(
html_code,
N_CTX=N_CTX,
DEFAULT_SYSTEM_PROMPT=DEFAULT_SYSTEM_PROMPT
)
return rendered_html
@app.route("/health", methods=["GET"])
def health_check():
if llm:
if hasattr(llm, 'tokenize') and hasattr(llm, 'apply_chat_template'):
return jsonify(status="ok", message="Model is loaded and ready."), 200
else:
logger.warning("Model loaded, but tokenizer or chat template functions might be missing.")
return jsonify(status="warning", message="Model loaded, but critical functions (tokenize/apply_chat_template) might be missing."), 200
else:
return jsonify(status="error", message="Model failed to load or is not available."), 503
@app.route("/info", methods=["GET"])
def model_info():
request_id = getattr(g, 'request_id', 'N/A')
if not llm:
logger.warning("Info request received but model is not loaded.", extra={'request_id': request_id})
return jsonify(error="Model not available."), 503
model_details: Union[Dict[str, Any], str] = "Model details unavailable"
actual_n_ctx = get_effective_n_ctx()
actual_n_batch = N_BATCH
actual_n_gpu_layers = N_GPU_LAYERS
try:
actual_n_batch = llm.n_batch if hasattr(llm, 'n_batch') else N_BATCH
actual_n_gpu_layers = llm.n_gpu_layers if hasattr(llm, 'n_gpu_layers') else 0
n_embd = 'N/A'
if hasattr(llm, '_model') and hasattr(llm._model, 'n_embd') and callable(llm._model.n_embd):
try:
n_embd = llm._model.n_embd()
except Exception as embd_e:
logger.warning(f"Could not get n_embd: {embd_e}", extra={'request_id': request_id})
model_details = {
"n_embd": n_embd,
"n_ctx": actual_n_ctx,
"n_batch": actual_n_batch,
"n_gpu_layers": actual_n_gpu_layers,
"tokenizer_present": hasattr(llm, 'tokenize'),
"chat_handler_present": hasattr(llm, 'apply_chat_template') and hasattr(llm, 'create_chat_completion'),
}
except Exception as e:
logger.warning(f"Could not retrieve all model details: {e}", extra={'request_id': request_id})
model_details = f"Error retrieving some model details: {e}"
info = {
"status": "ok",
"message": "Model is loaded. Generation continues automatically with context truncation if context limit is hit.",
"model_config": {
"repo_id": MODEL_REPO,
"filename": MODEL_FILE,
"initial_load_config": {
"n_ctx": os.getenv("N_CTX", "2048"),
"n_batch": N_BATCH,
"n_gpu_layers": 0,
},
"loaded_model_details": model_details,
},
"generation_parameters": {
"note": f"No artificial 'max_tokens' limit. Generation proceeds until stop sequence, EOS, or context limit (N_CTX={actual_n_ctx}). Automatic continuation attempts by truncating context **indefinitely** if context limit is reached. Sampling parameters (temperature, top_p, top_k) are chosen randomly per request/continuation cycle from predefined sets. Repeat penalty and seed are fixed.",
"fixed_max_tokens": None,
"fixed_repeat_penalty": FIXED_REPEAT_PENALTY,
"fixed_seed": FIXED_SEED,
"max_automatic_continuations": None,
"context_truncation_buffer_ratio": CONTEXT_TRUNCATION_BUFFER_RATIO,
"randomly_chosen_from": RANDOM_PARAMS_CHOICES,
"default_system_prompt": DEFAULT_SYSTEM_PROMPT,
"user_controllable": ["messages", "prompt", "stop", "stream", "format", "system_prompt"],
},
}
return jsonify(info), 200
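# /generate accepts a JSON body with either "messages" (a chat-style list) or
# "prompt" (a single user turn), plus optional "stop", "stream" (default: true),
# "format" ("markdown") and "system_prompt". Streaming responses are plain text
# chunks interleaved with bracketed [CONTINUING ...], [INFO] and [ERROR] markers;
# non-streaming responses return the full text with X-Finish-Reason,
# X-Continuations and X-Usage-* headers.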
@app.route("/generate", methods=["POST"])
def generate():
request_id = getattr(g, 'request_id', 'N/A')
if not llm:
logger.error("Generate request received but model is not loaded.", extra={'request_id': request_id})
return jsonify(error="Model is not available.", detail="The LLM model could not be loaded."), 503
if not request.is_json:
logger.warning("Request received without Content-Type: application/json", extra={'request_id': request_id})
return jsonify(error="Invalid request header", detail="Content-Type must be application/json"), 415
data = request.get_json()
is_streaming = data.get("stream", True)
response_format = data.get("format")
log_data_summary = {k: v for k, v in data.items() if k not in ('messages', 'prompt')}
log_data_summary['messages_count_initial'] = len(data.get('messages', [])) if 'messages' in data else 0
log_data_summary['has_prompt_initial'] = 'prompt' in data
log_data_summary['stream'] = is_streaming
log_data_summary['format'] = response_format
logger.info(f"Received generation request summary.", extra={'request_id': request_id, 'summary': log_data_summary})
try:
initial_messages = prepare_messages(data, format=response_format)
base_generation_params = parse_and_validate_params(data)
effective_n_ctx = get_effective_n_ctx()
input_token_count = estimate_token_count(initial_messages)
if input_token_count != -1 and input_token_count >= effective_n_ctx:
truncated_initial = truncate_messages_for_context(initial_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO)
truncated_token_count = estimate_token_count(truncated_initial)
if truncated_token_count != -1 and truncated_token_count >= effective_n_ctx:
error_msg = f"Initial input exceeds context window ({effective_n_ctx}) even after attempting truncation. Input tokens (~{input_token_count}) / Truncated tokens (~{truncated_token_count}). Reduce initial message size significantly."
logger.error(error_msg, extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
return jsonify(error="Input exceeds context window", detail=error_msg), 400
else:
logger.warning(f"Initial input (~{input_token_count} tokens) exceeded context window ({effective_n_ctx}). Truncated to ~{truncated_token_count} tokens.", extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
initial_messages = truncated_initial
input_token_count = truncated_token_count
elif input_token_count != -1:
logger.info(f"Initial input token count: ~{input_token_count}. Context window: {effective_n_ctx}. Remaining: {effective_n_ctx - input_token_count}.", extra={'request_id': request_id, 'input_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'remaining_ctx': effective_n_ctx - input_token_count})
else:
logger.warning("Could not estimate initial token count. Proceeding with generation, may hit context limit.", extra={'request_id': request_id})
logger.info(f"Processing request with {len(initial_messages)} initial messages. Stream={is_streaming}. Format={response_format}. max_tokens=None (dynamic). Unlimited Continuations.", extra={'request_id': request_id})
except ValueError as e:
logger.error(f"Invalid input data: {e}", exc_info=True, extra={'request_id': request_id})
try: error_detail = json.loads(str(e))
except json.JSONDecodeError: error_detail = str(e)
return jsonify(error="Invalid input", detail=error_detail), 400
except Exception as e:
logger.error(f"Unexpected error preparing request: {e}", exc_info=True, extra={'request_id': request_id})
return jsonify(error="Internal server error", detail="An unexpected error occurred processing the request."), 500
if is_streaming:
def generate_streaming_with_continuation(current_request_id: str) -> Generator[str, None, None]:
current_messages = list(initial_messages)
continuations = 0
total_tokens_generated_stream = 0
effective_n_ctx = get_effective_n_ctx()
while True:
cycle_number = continuations + 1
logger.info(f"Starting streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': current_request_id})
chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
current_dynamic_params = {
"temperature": chosen_params["temperature"],
"top_p": chosen_params["top_p"],
"top_k": chosen_params["top_k"],
}
current_params = {**base_generation_params, **current_dynamic_params}
logger.debug(f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}", extra={'request_id': current_request_id})
generated_this_cycle = ""
finish_reason = None
hit_context_limit_in_cycle = False
try:
streamer = llm.create_chat_completion(
messages=current_messages,
max_tokens=current_params["max_tokens"],
temperature=current_params["temperature"],
top_p=current_params["top_p"],
top_k=current_params["top_k"],
repeat_penalty=current_params["repeat_penalty"],
stop=current_params["stop"],
seed=current_params["seed"],
stream=True,
)
for chunk in streamer:
choice = chunk.get("choices", [{}])[0]
delta = choice.get("delta", {})
token = delta.get("content")
current_chunk_finish_reason = choice.get("finish_reason")
if token:
generated_this_cycle += token
total_tokens_generated_stream += 1
yield token
if current_chunk_finish_reason:
finish_reason = current_chunk_finish_reason
logger.info(f"Streaming chunk finished cycle {cycle_number}. Reason: {finish_reason}", extra={'request_id': current_request_id, 'finish_reason': finish_reason})
if finish_reason == 'length':
hit_context_limit_in_cycle = True
usage = chunk.get("usage")
if usage: logger.debug(f"Usage reported in final chunk: {usage}", extra={'request_id': current_request_id, 'usage': usage})
break
if not finish_reason:
pass
except Exception as e:
err_str = str(e).lower()
if "context window is full" in err_str or \
"kv cache is full" in err_str or \
"llama_decode" in err_str or \
(hasattr(e, 'condition') and ("context length" in str(e.condition).lower() or "failed to decode" in str(e.condition).lower())):
logger.warning(f"N_CTX limit or related exception caught during streaming cycle {cycle_number}: {e}", extra={'request_id': current_request_id})
hit_context_limit_in_cycle = True
finish_reason = 'length'
else:
logger.error(f"Unhandled error during streaming generation cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': current_request_id})
yield f"\n[ERROR] Generation failed unexpectedly in cycle {cycle_number}: {str(e)}"
return
if generated_this_cycle:
if not current_messages or current_messages[-1].get('role') != 'assistant':
current_messages.append({"role": "assistant", "content": generated_this_cycle})
else:
current_messages[-1]['content'] += generated_this_cycle
elif hit_context_limit_in_cycle:
logger.warning(f"Context limit hit in streaming cycle {cycle_number} but no tokens were generated in this cycle. Check model behavior.", extra={'request_id': current_request_id})
elif not finish_reason:
logger.warning(f"Stream cycle {cycle_number} ended without generating tokens or a definite finish reason. Stopping.", extra={'request_id': current_request_id})
yield f"\n[INFO] Generation stopped: Cycle ended unexpectedly."
break
if finish_reason == 'stop':
logger.info(f"Generation stopped naturally (reason: stop) in streaming cycle {cycle_number}. Total stream tokens: ~{total_tokens_generated_stream}", extra={'request_id': current_request_id})
yield f"\n[INFO] Generation stopped: Stop sequence or EOS."
break
elif hit_context_limit_in_cycle:
continuations += 1
logger.warning(f"N_CTX limit reached in streaming cycle {cycle_number}. Attempting continuation {continuations} (reinicio de contador).", extra={'request_id': current_request_id})
current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO)
if not current_messages:
logger.error("Context truncation resulted in empty messages during streaming. Stopping.", extra={'request_id': current_request_id})
yield f"\n[ERROR] Generation failed: Context truncation error."
break
yield f"\n[CONTINUING {continuations} - TRUNCATING CONTEXT...]\n"
time.sleep(0.1)
continue
else:
logger.warning(f"Streaming generation cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.", extra={'request_id': current_request_id, 'finish_reason': finish_reason})
yield f"\n[INFO] Generation stopped: Reason: {finish_reason or 'Unknown'}"
break
logger.info(f"Streaming generation finished after {continuations} continuations. Total stream tokens generated: ~{total_tokens_generated_stream}", extra={'request_id': current_request_id, 'continuations': continuations, 'total_stream_tokens': total_tokens_generated_stream})
headers = {
"Content-Type": "text/event-stream; charset=utf-8",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"X-Request-ID": request_id
}
return Response(stream_with_context(generate_streaming_with_continuation(request_id)), headers=headers)
else:
current_messages = list(initial_messages)
continuations = 0
full_generated_text = ""
total_tokens_generated_nonstream = 0
final_finish_reason = "unknown"
final_usage = {}
effective_n_ctx = get_effective_n_ctx()
while True:
cycle_number = continuations + 1
logger.info(f"Starting non-streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': request_id})
chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
current_dynamic_params = {
"temperature": chosen_params["temperature"],
"top_p": chosen_params["top_p"],
"top_k": chosen_params["top_k"],
}
current_params = {**base_generation_params, **current_dynamic_params}
logger.debug(f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}", extra={'request_id': request_id})
generated_this_cycle = ""
finish_reason = None
hit_context_limit_in_cycle = False
usage_this_cycle = {}
try:
result = llm.create_chat_completion(
messages=current_messages,
max_tokens=current_params["max_tokens"],
temperature=current_params["temperature"],
top_p=current_params["top_p"],
top_k=current_params["top_k"],
repeat_penalty=current_params["repeat_penalty"],
stop=current_params["stop"],
seed=current_params["seed"],
stream=False,
)
if result and "choices" in result and result["choices"]:
choice = result["choices"][0]
generated_this_cycle = choice.get("message", {}).get("content", "")
finish_reason = choice.get("finish_reason", "unknown")
else:
logger.error(f"Invalid response structure from llama_cpp in non-streaming cycle {cycle_number}: {result}", extra={'request_id': request_id})
return jsonify(error="Generation failed", detail=f"Invalid response structure from model in cycle {cycle_number}."), 500
usage_this_cycle = result.get("usage", {})
final_finish_reason = finish_reason
if usage_this_cycle: final_usage = usage_this_cycle
logger.info(f"Non-streaming cycle {cycle_number} finished. Reason: {finish_reason}. Usage: {usage_this_cycle}", extra={'request_id': request_id, 'usage': usage_this_cycle, 'finish_reason': finish_reason})
if finish_reason == 'length':
hit_context_limit_in_cycle = True
except Exception as e:
err_str = str(e).lower()
if "context window is full" in err_str or \
"kv cache is full" in err_str or \
"llama_decode" in err_str or \
(hasattr(e, 'condition') and ("context length" in str(e.condition).lower() or "failed to decode" in str(e.condition).lower())):
logger.warning(f"N_CTX limit or related exception caught during non-streaming cycle {cycle_number}: {e}", extra={'request_id': request_id})
hit_context_limit_in_cycle = True
finish_reason = 'length'
else:
logger.error(f"Unhandled error during non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id})
return jsonify(error="Generation failed", detail=f"Internal error in cycle {cycle_number}: {str(e)}"), 500
if generated_this_cycle:
if continuations > 0 and full_generated_text:
full_generated_text += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT]\n\n"
full_generated_text += generated_this_cycle
tokens_generated_cycle = usage_this_cycle.get("completion_tokens", 0)
total_tokens_generated_nonstream += tokens_generated_cycle
if not current_messages or current_messages[-1].get('role') != 'assistant':
current_messages.append({"role": "assistant", "content": generated_this_cycle})
else:
current_messages[-1]['content'] += generated_this_cycle
elif hit_context_limit_in_cycle:
logger.warning(f"Non-streaming N_CTX limit hit in cycle {cycle_number} but no completion tokens reported.", extra={'request_id': request_id})
if continuations > 0 and full_generated_text:
full_generated_text += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT - NO OUTPUT THIS CYCLE]\n\n"
elif not finish_reason:
logger.warning(f"Non-streaming cycle {cycle_number} ended without generating tokens or a finish reason. Stopping.", extra={'request_id': request_id})
full_generated_text += f"\n[INFO: Generation stopped: Cycle {cycle_number} ended unexpectedly.]"
break
if finish_reason == 'stop':
logger.info(f"Non-streaming generation stopped naturally (reason: stop) in cycle {cycle_number}.", extra={'request_id': request_id})
break
elif hit_context_limit_in_cycle:
continuations += 1
logger.warning(f"Non-streaming N_CTX limit reached in cycle {cycle_number}. Attempting continuation {continuations} (reinicio de contador).", extra={'request_id': request_id})
current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO)
if not current_messages:
logger.error("Context truncation resulted in empty messages during non-streaming. Stopping.", extra={'request_id': request_id})
full_generated_text += f"\n[ERROR: Generation failed: Context truncation error.]"
final_finish_reason = "truncation_error"
break
continue
else:
logger.warning(f"Non-streaming cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.", extra={'request_id': request_id, 'finish_reason': finish_reason})
full_generated_text += f"\n\n[INFO: Generation stopped unexpectedly. Reason: {finish_reason or 'Unknown'}]"
break
logger.info(f"Non-streaming generation finished after {continuations} continuations. Total completion tokens reported: {total_tokens_generated_nonstream}. Final reason: {final_finish_reason}", extra={'request_id': request_id, 'continuations': continuations, 'total_completion_tokens': total_tokens_generated_nonstream, 'final_reason': final_finish_reason})
response = Response(full_generated_text, mimetype="text/plain; charset=utf-8")
response.headers["X-Request-ID"] = request_id
response.headers["X-Finish-Reason"] = final_finish_reason
response.headers["X-Continuations"] = str(continuations)
total_prompt_tokens = final_usage.get("prompt_tokens", "N/A")
response.headers["X-Usage-Completion-Tokens"] = str(total_tokens_generated_nonstream)
response.headers["X-Usage-Prompt-Tokens-Last-Cycle"] = str(total_prompt_tokens)
response.headers["X-Usage-Total-Tokens-Last-Cycle"] = str(final_usage.get("total_tokens", "N/A"))
return response
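# Example requests (illustrative; assumes the server is running locally on the
# default port 7860):
#
#   # Streamed generation from a single prompt
#   curl -N -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about rivers.", "stream": true}'
#
#   # Non-streaming chat-style request with Markdown formatting; -D - prints the
#   # X-Finish-Reason / X-Continuations headers alongside the body
#   curl -s -D - -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false, "format": "markdown"}'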
if __name__ == "__main__":
host = os.getenv("HOST", "0.0.0.0")
port = int(os.getenv("PORT", "7860"))
is_debug = os.getenv("FLASK_DEBUG", "0") == "1"
log_level = logging.DEBUG if is_debug else logging.INFO
logger.setLevel(log_level)
logger.info(f"Starting Flask server on {host}:{port} (Debug mode: {is_debug})")
logger.info(f"Model: {MODEL_REPO}/{MODEL_FILE}, N_CTX={N_CTX}, Automatic Continuations: UNLIMITED (with context truncation)")
if not llm:
logger.critical("MODEL FAILED TO LOAD. SERVER WILL START BUT '/generate' WILL FAIL.")
if not is_debug:
try:
from waitress import serve
logger.info("Running with Waitress production server.")
serve(app, host=host, port=port, threads=8)
except ImportError:
logger.warning("Waitress not found. Falling back to Flask development server. Install waitress for production.")
app.run(host=host, port=port, threaded=True, debug=is_debug)
else:
logger.info("Running with Flask development server (Debug=True).")
app.run(host=host, port=port, threaded=True, debug=is_debug, use_reloader=False)