|
|
|
|
|
|
|
import os |
|
import json |
|
import logging |
|
import uuid |
|
from typing import List, Dict, Union, Optional, Generator, Any |
|
import random |
|
import time |
|
|
|
from flask import Flask, request, Response, stream_with_context, jsonify, g, render_template_string |
|
from llama_cpp import Llama |
|
|
|
class JsonFormatter(logging.Formatter):
    """Logging formatter that renders every record as one JSON object.

    The object always carries ``timestamp``, ``level``, ``name``, ``message``,
    ``pathname`` and ``lineno``.  ``request_id``, ``exception`` and
    ``stack_info`` are added when present on the record, followed by any other
    public record attribute (typically values passed through ``extra=``).
    """

    # LogRecord attributes that are either already emitted under another key
    # or deliberately left out of the JSON payload.
    _SKIP_KEYS = frozenset({
        'message', 'asctime', 'levelname', 'levelno', 'pathname', 'filename',
        'module', 'funcName', 'lineno', 'created', 'msecs', 'relativeCreated',
        'thread', 'threadName', 'process', 'processName', 'exc_info',
        'exc_text', 'stack_info', 'request_id',
    })

    def format(self, record):
        """Return *record* serialized as a JSON string."""
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "name": record.name,
            "message": record.getMessage(),
            "pathname": record.pathname,
            "lineno": record.lineno,
        }
        if hasattr(record, 'request_id'):
            log_record['request_id'] = record.request_id
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record['stack_info'] = self.formatStack(record.stack_info)

        # Copy the remaining public attributes (mostly ``extra=`` values).
        for key, value in record.__dict__.items():
            if key.startswith('_') or key in log_record or key in self._SKIP_KEYS:
                continue
            log_record[key] = value

        # default=str is deliberate: a log call must never fail because an
        # extra value or a log argument is not JSON-serializable, so such
        # values are emitted as their str() form instead of raising TypeError.
        return json.dumps(log_record, default=str)
|
|
|
|
|
def setup_logging():
    """Configure the root logger for JSON output and return it.

    A stream handler using ``JsonFormatter`` is attached only when the root
    logger has no handlers yet.  The root level is set to INFO, while the
    ``werkzeug`` and ``llama_cpp`` loggers are limited to ERROR and WARNING
    respectively.
    """
    root_logger = logging.getLogger()

    if not root_logger.handlers:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(JsonFormatter())
        root_logger.addHandler(stream_handler)

    root_logger.setLevel(logging.INFO)

    # Third-party loggers and the minimum level each one may emit.
    quiet_levels = (
        ("werkzeug", logging.ERROR),
        ("llama_cpp", logging.WARNING),
    )
    for logger_name, level in quiet_levels:
        logging.getLogger(logger_name).setLevel(level)

    return root_logger
|
|
|
# Module-wide logger: the root logger as configured by setup_logging().
logger = setup_logging()


# --- Model loading configuration -------------------------------------------
# Repo id and file name handed to Llama.from_pretrained() in load_model().
MODEL_REPO = os.getenv("MODEL_REPO", "jnjj/vcvcvcv")
MODEL_FILE = os.getenv("MODEL_FILE", "gemma-3-4b-it-q4_0.gguf")
# Requested context window; load_model() overwrites this global with the
# value reported by the loaded model when the two differ.
N_CTX = int(os.getenv("N_CTX", "2048"))
N_BATCH = int(os.getenv("N_BATCH", "512"))
# Hard-coded (not read from the environment).
N_GPU_LAYERS = 0

# --- Generation configuration ----------------------------------------------
FIXED_REPEAT_PENALTY = float(os.getenv("FIXED_REPEAT_PENALTY", "1.1"))
FIXED_SEED = int(os.getenv("FIXED_SEED", "-1"))
# Used by prepare_messages() when the request carries no "system_prompt".
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "Eres un asistente conciso, directo y útil.")
# Fraction of the context window that truncate_messages_for_context() is
# asked to stay within (it is passed as that function's buffer_ratio).
CONTEXT_TRUNCATION_BUFFER_RATIO = float(os.getenv("CONTEXT_TRUNCATION_BUFFER_RATIO", "0.85"))

# Sampling presets; one entry is picked with random.choice() whenever
# generation parameters are assembled.
RANDOM_PARAMS_CHOICES = [
    {"top_k": 10, "top_p": 0.5, "temperature": 0.2},
    {"top_k": 10, "top_p": 0.5, "temperature": 0.1},
    {"top_k": 10, "top_p": 0.5, "temperature": 0.3},
    {"top_k": 10, "top_p": 0.5, "temperature": 0.4},
    {"top_k": 5, "top_p": 0.3, "temperature": 0.6},
    {"top_k": 20, "top_p": 0.7, "temperature": 0.5},
]

# The loaded model; stays None until load_model() succeeds (and is reset to
# None when loading fails).
llm: Optional[Llama] = None
|
|
|
|
|
def parse_and_validate_params(data: Dict) -> Dict:
    """Build the generation parameters for one request.

    Only ``stop`` is taken from the request body.  ``max_tokens`` is always
    ``None``, the sampling values come from a randomly chosen entry of
    ``RANDOM_PARAMS_CHOICES``, and repeat penalty / seed come from the
    module-level fixed settings.

    Args:
        data: Parsed JSON request body.

    Returns:
        Dict with the keys ``max_tokens``, ``temperature``, ``top_p``,
        ``top_k``, ``repeat_penalty``, ``seed`` and ``stop`` (a list of
        non-empty strings, or ``None``).

    Raises:
        ValueError: If ``stop`` is neither a string nor a list of strings.
            The message is a JSON object mapping field name to error text.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    errors = {}

    chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
    params = {
        "max_tokens": None,
        "temperature": chosen_params["temperature"],
        "top_p": chosen_params["top_p"],
        "top_k": chosen_params["top_k"],
        "repeat_penalty": FIXED_REPEAT_PENALTY,
        "seed": FIXED_SEED,
        "stop": None,
    }

    stop = data.get("stop")
    if stop is not None:
        if isinstance(stop, str):
            stop = [stop]
        if isinstance(stop, list) and all(isinstance(s, str) for s in stop):
            # Drop empty strings: an empty stop sequence is a substring of
            # any text, so it would end generation before anything is
            # produced.  An empty result means "no stop sequences".
            params["stop"] = [s for s in stop if s] or None
        else:
            errors["stop"] = "Stop must be a string or a list of strings"

    if errors:
        logger.error(f"Parameter validation failed for allowed fields: {errors}", extra={'request_id': request_id})
        raise ValueError(json.dumps(errors))

    logger.debug(
        f"Using parameters: max_tokens={params['max_tokens']}, repeat_penalty={params['repeat_penalty']}, seed={params['seed']}, temperature={params['temperature']}, top_p={params['top_p']}, top_k={params['top_k']}",
        extra={'request_id': request_id},
    )

    return params
|
|
|
|
|
def prepare_messages(data: Dict, format: Optional[str] = None) -> List[Dict[str, str]]:
    """Turn a request body into the chat message list sent to the model.

    The system prompt is ``data["system_prompt"]`` (or the default when the
    field is missing or null), optionally followed by a Markdown formatting
    instruction.  A system message at index 0 of ``data["messages"]`` replaces
    that prompt; system messages at any other position are ignored.  When no
    ``messages`` are given, ``data["prompt"]`` becomes a single user message.

    Args:
        data: Parsed JSON request body.
        format: Requested output format.  Only ``"markdown"`` is supported;
            any other non-None value is logged and ignored.

    Returns:
        List of ``{"role": ..., "content": ...}`` dicts containing at least
        one non-system message.

    Raises:
        ValueError: If neither ``messages`` nor ``prompt`` is given, a field
            has the wrong type, a message lacks ``role``/``content``, or only
            system messages remain.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    messages_list = data.get("messages")
    prompt_str = data.get("prompt")

    # An explicit JSON null is treated like an omitted field; previously it
    # reached .strip() below and crashed with AttributeError.
    system_instruction = data.get("system_prompt")
    if system_instruction is None:
        system_instruction = DEFAULT_SYSTEM_PROMPT

    if not messages_list and not prompt_str:
        raise ValueError("Either 'messages' list or 'prompt' string is required.")

    if messages_list and not isinstance(messages_list, list):
        raise ValueError("'messages' must be a list of dictionaries.")
    if prompt_str and not isinstance(prompt_str, str):
        raise ValueError("'prompt' must be a string.")
    # Checked unconditionally so falsy non-strings (0, [], False) are
    # rejected as bad input instead of failing on .strip().
    if not isinstance(system_instruction, str):
        raise ValueError("'system_prompt' must be a string.")

    final_messages = []

    system_parts = [system_instruction.strip()]
    if format == "markdown":
        system_parts.append("Format your response using Markdown.")
    elif format is not None:
        logger.warning(f"Unsupported format '{format}' requested.", extra={'request_id': request_id})

    # Join with a space so the instruction does not run into the last
    # sentence of the system prompt.
    effective_system_prompt = " ".join(part for part in system_parts if part)
    if effective_system_prompt:
        final_messages.append({"role": "system", "content": effective_system_prompt})

    if messages_list:
        has_user_message = False
        for i, msg in enumerate(messages_list):
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                raise ValueError(f"Message at index {i} is invalid: must be a dictionary with 'role' and 'content'.")

            role = msg.get("role")
            content = msg.get("content", "")

            if not isinstance(content, str):
                logger.warning(f"Message content at index {i} (role: {role}) is not a string (type: {type(content)}). Converting to string.", extra={'request_id': request_id})
                content = str(content)

            if role == "system":
                if i == 0 and final_messages and final_messages[0]["role"] == "system":
                    logger.info("Replacing default system prompt with user-provided system message.", extra={'request_id': request_id})
                    final_messages[0] = {"role": "system", "content": content}
                elif i == 0 and not final_messages:
                    final_messages.append({"role": "system", "content": content})
                else:
                    logger.warning(f"Ignoring additional system message at index {i} as system prompt is already set or should be at the start.", extra={'request_id': request_id})
                # System messages are fully handled above; never append them
                # a second time below.
                continue

            if role == "user":
                has_user_message = True

            final_messages.append({"role": role, "content": content})

        if not has_user_message and any(m["role"] != "system" for m in final_messages):
            logger.warning("The 'messages' list contains no user messages.", extra={'request_id': request_id})

    elif prompt_str:
        final_messages.append({"role": "user", "content": prompt_str})

    if not final_messages or all(m["role"] == "system" for m in final_messages):
        raise ValueError("No user or assistant messages found to generate a response.")

    return final_messages
|
|
|
|
|
def estimate_token_count(messages: List[Dict[str, str]]) -> int:
    """Estimate how many tokens *messages* occupy in the model context.

    Strategies, from most to least accurate:

    1. Render the chat template and tokenize the result.
    2. Tokenize a plain ``"role: content"`` join of the messages (used when
       the model object has no ``apply_chat_template`` or the template fails).
    3. Approximate as characters // 4 when the model object cannot tokenize.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The estimated token count, or ``-1`` when no model is loaded or
        tokenization itself fails.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("LLM or tokenizer/template function not available for token estimation.", extra={'request_id': request_id})
        return -1

    # This heuristic used to sit behind a guard that returned -1 for the
    # same condition, which made it unreachable.
    if not hasattr(llm, 'tokenize'):
        logger.warning("`tokenize` or `apply_chat_template` not found on LLM object. Cannot estimate tokens accurately.", extra={'request_id': request_id})
        char_count = sum(len(m.get('content', '')) for m in messages)
        return char_count // 4

    template_error = None
    if hasattr(llm, 'apply_chat_template'):
        try:
            chat_prompt_string = llm.apply_chat_template(messages, add_generation_prompt=False)
            tokens = llm.tokenize(chat_prompt_string.encode('utf-8', errors='ignore'), add_bos=True)
            return len(tokens)
        except Exception as e:
            # Fall through to the plain-join strategy below.
            template_error = e

    try:
        simple_text = "\n".join(f"{m.get('role', 'unknown')}: {m.get('content', '')}" for m in messages)
        tokens = llm.tokenize(simple_text.encode('utf-8', errors='ignore'), add_bos=True)
    except Exception as e_inner:
        logger.error(f"Could not estimate token count using either method: {e_inner}", exc_info=True, extra={'request_id': request_id})
        return -1

    if template_error is not None:
        logger.warning(f"Chat template failed during token estimation, using simple join. Error: {template_error}", extra={'request_id': request_id})
    else:
        # Logged at debug level: this path runs once per candidate message
        # list during truncation.
        logger.debug("No chat template function on LLM object; token estimate uses simple join.", extra={'request_id': request_id})
    return len(tokens)
|
|
|
def get_effective_n_ctx() -> int:
    """Return the context size reported by the loaded model.

    Falls back to the configured ``N_CTX`` when no model is loaded, the model
    exposes no callable ``n_ctx``, or calling it raises.
    """
    n_ctx_getter = getattr(llm, 'n_ctx', None) if llm else None
    if not callable(n_ctx_getter):
        return N_CTX

    try:
        return n_ctx_getter()
    except Exception:
        logger.warning("Failed to call llm.n_ctx(), falling back to N_CTX config value.")
        return N_CTX
|
|
|
|
|
def truncate_messages_for_context(messages: List[Dict[str, str]], max_tokens: int, buffer_ratio: float) -> List[Dict[str, str]]:
    """Drop the oldest messages until the conversation fits a token budget.

    A leading system message is always kept.  The remaining messages are
    added newest-first for as long as the estimated token count stays within
    ``int(max_tokens * buffer_ratio)``.  The newest message is kept even when
    it does not fit on its own, so the result is never reduced to just the
    system prompt; callers that need a hard limit should re-estimate the
    returned list.

    Args:
        messages: Full chat history, oldest first.
        max_tokens: Size of the context window in tokens.
        buffer_ratio: Fraction of ``max_tokens`` the result should stay within.

    Returns:
        The truncated message list.  ``messages`` itself is returned when no
        model is loaded; an empty list is returned for empty input.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        return messages
    if not messages:
        logger.error("Truncation called with empty input, returning empty.", extra={'request_id': request_id})
        return []

    target_token_limit = int(max_tokens * buffer_ratio)

    if messages[0].get("role") == "system":
        system_prefix = [messages[0]]
        remaining_messages = messages[1:]
    else:
        system_prefix = []
        remaining_messages = messages

    current_token_count = estimate_token_count(system_prefix) if system_prefix else 0
    if current_token_count == -1:
        logger.warning("Could not estimate initial token count for truncation, proceeding cautiously.", extra={'request_id': request_id})
        current_token_count = 0

    # Walk the history newest-first, keeping messages while they fit.
    messages_to_add: List[Dict[str, str]] = []
    for msg in reversed(remaining_messages):
        next_token_count = estimate_token_count(system_prefix + [msg] + messages_to_add)

        if next_token_count == -1:
            logger.warning(f"Token estimation failed while adding message: {msg}. Stopping truncation early.", extra={'request_id': request_id})
            break
        if next_token_count > target_token_limit:
            logger.debug(f"Stopping truncation: Adding next message would exceed target limit ({next_token_count} > {target_token_limit}).", extra={'request_id': request_id})
            break

        messages_to_add.insert(0, msg)
        current_token_count = next_token_count

    if not messages_to_add and remaining_messages:
        # Nothing fit (or estimation failed).  Keep the newest message anyway:
        # returning only the system prompt would make the model answer
        # without any user input.
        logger.error("Truncation could not fit any conversation message within the target; keeping the most recent message.", extra={'request_id': request_id})
        messages_to_add = [remaining_messages[-1]]
        forced_count = estimate_token_count(system_prefix + messages_to_add)
        if forced_count != -1:
            current_token_count = forced_count

    final_truncated_list = system_prefix + messages_to_add

    original_count = len(messages)
    final_count = len(final_truncated_list)
    if final_count < original_count:
        logger.warning(
            f"Context truncated: Kept {final_count}/{original_count} messages. Estimated tokens: ~{current_token_count}/{target_token_limit} (target).",
            extra={'request_id': request_id, 'kept': final_count, 'original': original_count, 'estimated_tokens': current_token_count, 'target_limit': target_token_limit},
        )
    else:
        logger.debug(
            f"Context truncation check complete. Kept all {final_count} messages. Estimated tokens: ~{current_token_count}.",
            extra={'request_id': request_id, 'kept': final_count, 'estimated_tokens': current_token_count},
        )

    return final_truncated_list
|
|
|
|
|
def load_model():
    """Load the model into the module-level ``llm`` and log its settings.

    On success ``N_CTX`` is replaced by the context size the model reports
    when it differs from the configured one, and a test tokenization is
    attempted.  Any exception during loading is logged and leaves ``llm`` set
    to ``None``; nothing is raised to the caller.
    """
    global llm, N_CTX

    logger.info(f"Attempting to load model: {MODEL_REPO}/{MODEL_FILE}")
    effective_n_gpu_layers = 0
    logger.info(f"Configuration: N_CTX={N_CTX}, N_BATCH={N_BATCH}, N_GPU_LAYERS={effective_n_gpu_layers} (forced CPU)")

    load_kwargs = {
        "repo_id": MODEL_REPO,
        "filename": MODEL_FILE,
        "n_ctx": N_CTX,
        "n_batch": N_BATCH,
        "n_gpu_layers": effective_n_gpu_layers,
        "verbose": False,
        "use_mmap": True,
        "use_mlock": True,
    }

    try:
        llm = Llama.from_pretrained(**load_kwargs)
        logger.info("Model loaded successfully.")
        if not llm:
            return

        # Prefer the context size the model actually reports.
        reported_n_ctx = get_effective_n_ctx()
        if reported_n_ctx != N_CTX:
            logger.warning(
                f"Model's actual context size ({reported_n_ctx}) differs from initial config ({N_CTX}). Using actual value: {reported_n_ctx}",
                extra={'actual_n_ctx': reported_n_ctx, 'configured_n_ctx': N_CTX},
            )
            N_CTX = reported_n_ctx

        reported_n_batch = getattr(llm, 'n_batch', N_BATCH)
        reported_gpu_layers = getattr(llm, 'n_gpu_layers', 0)

        logger.info(f"Actual Model Context Window (n_ctx): {N_CTX}")
        logger.info(f"Actual Model Batch Size (n_batch): {reported_n_batch}")
        logger.info(f"Actual Model GPU Layers (n_gpu_layers): {reported_gpu_layers} (should be 0 for CPU)")

        if N_CTX < 1024 or reported_n_batch < 64:
            logger.warning(
                "Model loaded with relatively small N_CTX or N_BATCH. Performance or max generation length might be impacted.",
                extra={'n_ctx': N_CTX, 'n_batch': reported_n_batch},
            )
        if reported_gpu_layers > 0:
            logger.warning(
                f"Model loaded with {reported_gpu_layers} GPU layers despite requesting 0. Check llama.cpp build or environment.",
                extra={'actual_gpu_layers': reported_gpu_layers},
            )

        # Smoke-test the tokenizer; a failure here is logged, not fatal.
        try:
            sample_tokens = llm.tokenize(b"Test sentence.")
            logger.info(f"Tokenizer test successful. 'Test sentence.' -> {len(sample_tokens)} tokens.")
        except Exception as tokenize_e:
            logger.warning(f"Could not perform test tokenization: {tokenize_e}")

    except Exception as e:
        logger.error(f"Fatal error loading model: {e}", exc_info=True)
        llm = None
        logger.error("Model failed to load. Generation requests will not work.", extra={'error': str(e)})
|
|
|
app = Flask(__name__)


@app.before_request
def before_request_func():
    """Assign a fresh UUID to ``g.request_id`` and log the incoming request."""
    request_id = str(uuid.uuid4())
    g.request_id = request_id

    log_fields = {
        'request_id': request_id,
        'path': request.path,
        'method': request.method,
    }
    logger.debug(f"Incoming request: {request.method} {request.path} from {request.remote_addr}", extra=log_fields)


# The model is loaded once, when this module is imported.
load_model()
|
|
|
html_code = """ |
|
<!DOCTYPE html> |
|
<html lang="en"> |
|
<head> |
|
<meta charset="UTF-8"> |
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"> |
|
<title>LLM API Demo</title> |
|
<style> |
|
body { font-family: sans-serif; margin: 20px; line-height: 1.6; background-color: #f4f4f4; color: #333; } |
|
.container { max-width: 800px; margin: auto; background: #fff; padding: 20px; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); } |
|
h1, h2 { color: #0056b3; } |
|
.section { margin-bottom: 30px; padding: 20px; background-color: #e9e9e9; border-radius: 5px; } |
|
label { display: block; margin-bottom: 5px; font-weight: bold; } |
|
input[type="text"], input[type="number"], textarea, select { |
|
width: calc(100% - 22px); padding: 10px; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; |
|
} |
|
textarea { resize: vertical; min-height: 100px; } |
|
button { |
|
display: inline-block; background-color: #007bff; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; font-size: 16px; |
|
margin-right: 10px; transition: background-color 0.3s ease; |
|
} |
|
button:hover { background-color: #0056b3; } |
|
button:disabled { background-color: #cccccc; cursor: not-allowed; } |
|
.output { |
|
background-color: #f9f9f9; border: 1px solid #ddd; padding: 15px; border-radius: 4px; white-space: pre-wrap; word-wrap: break-word; max-height: 400px; overflow-y: auto; font-family: monospace; |
|
} |
|
.error { color: red; font-weight: bold; } |
|
.info { color: green; } |
|
.warning { color: orange; } |
|
.param-fixed { font-style: italic; color: #555; margin-bottom: 10px; } |
|
.checkbox-container { display: flex; align-items: center; margin-bottom: 10px; } |
|
.checkbox-container input { margin-right: 5px; width: auto; } |
|
.continuation-info { font-weight: bold; } |
|
</style> |
|
</head> |
|
<body> |
|
<div class="container"> |
|
<h1>LLM API Demonstration</h1> |
|
<div class="section"> |
|
<h2>Health Check</h2> |
|
<button id="healthCheckBtn">Check Health</button> |
|
<p id="healthStatus"></p> |
|
</div> |
|
<div class="section"> |
|
<h2>API Info</h2> |
|
<button id="apiInfoBtn">Get Info</button> |
|
<pre id="apiInfoOutput" class="output"></pre> |
|
</div> |
|
<div class="section"> |
|
<h2>Generate Text (Automatic Continuation with Context Management)</h2> |
|
<label for="promptInput">Prompt / First User Message:</label> |
|
<textarea id="promptInput" placeholder="Enter your prompt here..."></textarea> |
|
|
|
<div class="param-fixed">Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX={{ N_CTX }}). If the context limit is reached, the server will attempt to continue automatically by <strong class="continuation-info">truncating older messages (unlimited continuations)</strong>. Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.</div> |
|
|
|
<div> |
|
<label for="stopInput">Stop Sequences (comma-separated):</label> |
|
<input type="text" id="stopInput" value=""> |
|
</div> |
|
<div> |
|
<label for="systemPromptInput">System Prompt (Optional Override - default: "{{ DEFAULT_SYSTEM_PROMPT | escape }}"):</label> |
|
<input type="text" id="systemPromptInput" placeholder="Leave empty to use default"> |
|
</div> |
|
<div> |
|
<label for="formatSelect">Format:</label> |
|
<select id="formatSelect"> |
|
<option value="">None</option> |
|
<option value="markdown">Markdown</option> |
|
</select> |
|
</div> |
|
<div class="checkbox-container"> |
|
<input type="checkbox" id="streamCheckbox" checked> |
|
<label for="streamCheckbox">Stream Output</label> |
|
</div> |
|
<button id="generateBtn">Generate</button> |
|
<p id="generationStatus"></p> |
|
<pre id="generationOutput" class="output"></pre> |
|
</div> |
|
</div> |
|
<script> |
|
// Handles to the page elements used by the functions below.
const healthCheckBtn = document.getElementById('healthCheckBtn');
const healthStatus = document.getElementById('healthStatus');
const apiInfoBtn = document.getElementById('apiInfoBtn');
const apiInfoOutput = document.getElementById('apiInfoOutput');
const promptInput = document.getElementById('promptInput');
const stopInput = document.getElementById('stopInput');
const systemPromptInput = document.getElementById('systemPromptInput');
const formatSelect = document.getElementById('formatSelect');
const streamCheckbox = document.getElementById('streamCheckbox');
const generateBtn = document.getElementById('generateBtn');
const generationOutput = document.getElementById('generationOutput');
const generationStatus = document.getElementById('generationStatus');
// API requests go to the same origin that served this page.
const API_BASE_URL = window.location.origin;
|
|
|
async function checkHealth() {
    // Query /health and show the reported status, colored by severity.
    healthStatus.textContent = 'Checking...';
    healthStatus.className = '';

    let data;
    try {
        const response = await fetch(`${API_BASE_URL}/health`);
        data = await response.json();
        healthStatus.textContent = `Status: ${data.status}, Message: ${data.message}`;
    } catch (error) {
        healthStatus.textContent = `Error fetching health: ${error}`;
        healthStatus.className = 'error';
        return;
    }

    // Anything other than ok / warning is displayed as an error.
    switch (data.status) {
        case 'ok':
            healthStatus.className = 'info';
            break;
        case 'warning':
            healthStatus.className = 'warning';
            break;
        default:
            healthStatus.className = 'error';
    }
}
|
|
|
async function getApiInfo() {
    // Fetch /info, refresh the explanatory note with the values the server
    // reports, and show the raw JSON in the output box.
    apiInfoOutput.textContent = 'Loading...';
    apiInfoOutput.className = 'output';
    try {
        const response = await fetch(`${API_BASE_URL}/info`);
        if (!response.ok) {
            // Prefer the error text from the JSON body; fall back to the HTTP
            // status text when the body is not JSON. The throw sits outside
            // the try so the detailed message is not swallowed by the catch.
            let detail = response.statusText;
            try {
                const errorData = await response.json();
                detail = errorData.error || JSON.stringify(errorData);
            } catch (parseError) {
                // Body was not valid JSON; keep the status text.
            }
            throw new Error(`API Error ${response.status}: ${detail}`);
        }
        const data = await response.json();
        // Falls back to the value rendered into the page by the server.
        const nCtx = data?.model_config?.loaded_model_details?.n_ctx || '{{ N_CTX }}';
        const maxContinuations = data?.generation_parameters?.max_automatic_continuations;
        const maxContDesc = maxContinuations === null ? "unlimited continuations" : `up to ${maxContinuations} times`;

        const description = document.querySelector('.param-fixed');
        if (description) {
            description.innerHTML = `Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX=${nCtx}). If the context limit is reached, the server will attempt to continue automatically by <strong class="continuation-info">truncating older messages (${maxContDesc})</strong>. Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.`;
        }

        apiInfoOutput.textContent = JSON.stringify(data, null, 2);
    } catch (error) {
        apiInfoOutput.textContent = `Error fetching info: ${error}`;
        apiInfoOutput.className = 'output error';
    }
}
|
|
|
|
|
async function generateText() {
    // Send the prompt to /generate and render the reply, either streamed
    // chunk by chunk or as one response body.
    generationOutput.textContent = '';
    generationStatus.textContent = 'Preparing request...';
    generationStatus.className = '';
    generateBtn.disabled = true;

    const prompt = promptInput.value;
    if (!prompt.trim()) {
        generationStatus.textContent = 'Error: Prompt cannot be empty.';
        generationStatus.className = 'error';
        generateBtn.disabled = false;
        return;
    }

    const messages = [{"role": "user", "content": prompt}];
    const stream = streamCheckbox.checked;
    const format = formatSelect.value || undefined;
    const stopSequences = stopInput.value.split(',').map(s => s.trim()).filter(s => s.length > 0);
    const stop = stopSequences.length > 0 ? stopSequences : undefined;
    const systemPrompt = systemPromptInput.value.trim() || undefined;

    // Fields left undefined are omitted by JSON.stringify.
    const requestBody = {
        messages: messages,
        stop: stop,
        stream: stream,
        format: format,
        system_prompt: systemPrompt
    };

    generationStatus.textContent = 'Generating... (may continue automatically with context truncation if needed)';
    generationStatus.className = 'info';

    try {
        const response = await fetch(`${API_BASE_URL}/generate`, {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(requestBody),
        });

        if (!response.ok) {
            const errorText = await response.text();
            let errorMessage = `Error: ${response.status} ${response.statusText}`;
            try {
                const errorData = JSON.parse(errorText);
                errorMessage += ` - ${errorData.error || JSON.stringify(errorData.detail || errorData)}`;
            } catch (jsonParseError) {
                errorMessage += ` - ${errorText}`;
            }
            generationStatus.textContent = errorMessage;
            generationStatus.className = 'error';
            generateBtn.disabled = false;
            return;
        }

        if (stream) {
            const reader = response.body.getReader();
            const decoder = new TextDecoder('utf-8');
            let finished = false;
            generationOutput.textContent = '';
            let continuationCount = 0;

            while (!finished) {
                const { done, value } = await reader.read();
                if (done) {
                    finished = true;
                    if (!generationStatus.textContent.includes("finished") && !generationStatus.textContent.includes("stopped") && !generationStatus.textContent.includes("Error")) {
                        generationStatus.textContent = `Streaming finished. Continuations: ${continuationCount}.`;
                        generationStatus.className = 'info';
                    }
                    break;
                }
                const chunk = decoder.decode(value, { stream: true });
                // This script lives inside a non-raw Python string, so every
                // backslash meant for JavaScript is written doubled here.
                // With single backslashes the newline escapes became real
                // line breaks inside the regex literal and the whole script
                // failed to parse.
                const continueMatch = chunk.match(/\\n\\[CONTINUING (\\d+) - TRUNCATING CONTEXT\\.\\.\\.\\]\\n/);
                if (continueMatch) {
                    continuationCount = parseInt(continueMatch[1], 10);
                    generationOutput.textContent += chunk;
                    generationStatus.textContent = `Context limit reached, truncating history and continuing generation (Continuation #${continuationCount})...`;
                    generationStatus.className = 'warning continuation-info';
                } else if (chunk.startsWith("\\n[ERROR]")) {
                    generationOutput.textContent += chunk;
                    generationStatus.textContent = 'Error during generation (see output).';
                    generationStatus.className = 'error';
                    finished = true;
                } else if (chunk.startsWith("\\n[INFO] Generation stopped")) {
                    generationOutput.textContent += chunk;
                    generationStatus.textContent = `Generation stopped (see output for reason). Continuations: ${continuationCount}.`;
                    generationStatus.className = 'info';
                    finished = true;
                } else {
                    generationOutput.textContent += chunk;
                    // Keep an earlier warning or error status visible.
                    if (!generationStatus.className.includes('warning') && !generationStatus.className.includes('error')) {
                        generationStatus.textContent = `Streaming... (Continuation #${continuationCount})`;
                        generationStatus.className = 'info';
                    }
                }
                generationOutput.scrollTop = generationOutput.scrollHeight;
            }
        } else {
            const text = await response.text();
            const finishReason = response.headers.get('X-Finish-Reason');
            const continuations = response.headers.get('X-Continuations');
            const usageTokens = response.headers.get('X-Usage-Completion-Tokens');

            generationOutput.textContent = text;
            let statusText = `Generation finished. Reason: ${finishReason || 'unknown'}.`;
            if (continuations && parseInt(continuations, 10) > 0) {
                statusText += ` Continuations: ${continuations} (context truncated).`;
                generationStatus.className = 'warning continuation-info';
            } else {
                generationStatus.className = 'info';
            }
            if (usageTokens) statusText += ` Tokens: ~${usageTokens}.`;
            generationStatus.textContent = statusText;
        }

    } catch (error) {
        generationStatus.textContent = `Network or processing error: ${error}`;
        generationStatus.className = 'error';
    } finally {
        generateBtn.disabled = false;
    }
}
|
|
|
|
|
// Wire each button to its handler.
healthCheckBtn.addEventListener('click', checkHealth);
apiInfoBtn.addEventListener('click', getApiInfo);
generateBtn.addEventListener('click', generateText);

// Populate the health and info sections once when the script runs.
checkHealth();
getApiInfo();
|
</script> |
|
</body> |
|
</html> |
|
""" |
|
|
|
@app.route("/")
def index():
    """Serve the demo page rendered from ``html_code``.

    The template receives the current ``N_CTX`` and ``DEFAULT_SYSTEM_PROMPT``
    values.
    """
    template_context = {
        "N_CTX": N_CTX,
        "DEFAULT_SYSTEM_PROMPT": DEFAULT_SYSTEM_PROMPT,
    }
    return render_template_string(html_code, **template_context)
|
|
|
|
|
@app.route("/health", methods=["GET"])
def health_check():
    """Report whether the model is loaded.

    Returns 503 with status ``error`` when no model is loaded, 200 with
    status ``ok`` when the model exposes both ``tokenize`` and
    ``apply_chat_template``, and 200 with status ``warning`` otherwise.
    """
    if not llm:
        return jsonify(status="error", message="Model failed to load or is not available."), 503

    required_attrs = ('tokenize', 'apply_chat_template')
    if all(hasattr(llm, attr) for attr in required_attrs):
        return jsonify(status="ok", message="Model is loaded and ready."), 200

    logger.warning("Model loaded, but tokenizer or chat template functions might be missing.")
    return jsonify(status="warning", message="Model loaded, but critical functions (tokenize/apply_chat_template) might be missing."), 200
|
|
|
@app.route("/info", methods=["GET"])
def model_info():
    """Describe the loaded model and the fixed generation settings as JSON.

    Returns 503 when no model is loaded.  If collecting the model details
    raises, ``loaded_model_details`` holds an error string instead of a dict.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("Info request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model not available."), 503

    actual_n_ctx = get_effective_n_ctx()

    model_details: Union[Dict[str, Any], str]
    try:
        reported_n_batch = getattr(llm, 'n_batch', N_BATCH)
        reported_gpu_layers = getattr(llm, 'n_gpu_layers', 0)

        # n_embd is read from the wrapped model object when it offers a
        # callable for it; otherwise it is reported as 'N/A'.
        n_embd = 'N/A'
        n_embd_getter = getattr(getattr(llm, '_model', None), 'n_embd', None)
        if callable(n_embd_getter):
            try:
                n_embd = n_embd_getter()
            except Exception as embd_e:
                logger.warning(f"Could not get n_embd: {embd_e}", extra={'request_id': request_id})

        has_chat_handler = hasattr(llm, 'apply_chat_template') and hasattr(llm, 'create_chat_completion')
        model_details = {
            "n_embd": n_embd,
            "n_ctx": actual_n_ctx,
            "n_batch": reported_n_batch,
            "n_gpu_layers": reported_gpu_layers,
            "tokenizer_present": hasattr(llm, 'tokenize'),
            "chat_handler_present": has_chat_handler,
        }
    except Exception as e:
        logger.warning(f"Could not retrieve all model details: {e}", extra={'request_id': request_id})
        model_details = f"Error retrieving some model details: {e}"

    model_config = {
        "repo_id": MODEL_REPO,
        "filename": MODEL_FILE,
        "initial_load_config": {
            # Read from the environment again because the N_CTX global may
            # have been overwritten by load_model().
            "n_ctx": os.getenv("N_CTX", "2048"),
            "n_batch": N_BATCH,
            "n_gpu_layers": 0,
        },
        "loaded_model_details": model_details,
    }

    generation_parameters = {
        "note": f"No artificial 'max_tokens' limit. Generation proceeds until stop sequence, EOS, or context limit (N_CTX={actual_n_ctx}). Automatic continuation attempts by truncating context **indefinitely** if context limit is reached. Sampling parameters (temperature, top_p, top_k) are chosen randomly per request/continuation cycle from predefined sets. Repeat penalty and seed are fixed.",
        "fixed_max_tokens": None,
        "fixed_repeat_penalty": FIXED_REPEAT_PENALTY,
        "fixed_seed": FIXED_SEED,
        "max_automatic_continuations": None,
        "context_truncation_buffer_ratio": CONTEXT_TRUNCATION_BUFFER_RATIO,
        "randomly_chosen_from": RANDOM_PARAMS_CHOICES,
        "default_system_prompt": DEFAULT_SYSTEM_PROMPT,
        "user_controllable": ["messages", "prompt", "stop", "stream", "format", "system_prompt"],
    }

    info = {
        "status": "ok",
        "message": "Model is loaded. Generation continues automatically with context truncation if context limit is hit.",
        "model_config": model_config,
        "generation_parameters": generation_parameters,
    }
    return jsonify(info), 200
|
|
|
|
|
@app.route("/generate", methods=["POST"]) |
|
def generate(): |
|
request_id = getattr(g, 'request_id', 'N/A') |
|
if not llm: |
|
logger.error("Generate request received but model is not loaded.", extra={'request_id': request_id}) |
|
return jsonify(error="Model is not available.", detail="The LLM model could not be loaded."), 503 |
|
|
|
if not request.is_json: |
|
logger.warning("Request received without Content-Type: application/json", extra={'request_id': request_id}) |
|
return jsonify(error="Invalid request header", detail="Content-Type must be application/json"), 415 |
|
|
|
data = request.get_json() |
|
is_streaming = data.get("stream", True) |
|
response_format = data.get("format") |
|
|
|
log_data_summary = {k: v for k, v in data.items() if k not in ('messages', 'prompt')} |
|
log_data_summary['messages_count_initial'] = len(data.get('messages', [])) if 'messages' in data else 0 |
|
log_data_summary['has_prompt_initial'] = 'prompt' in data |
|
log_data_summary['stream'] = is_streaming |
|
log_data_summary['format'] = response_format |
|
logger.info(f"Received generation request summary.", extra={'request_id': request_id, 'summary': log_data_summary}) |
|
|
|
try: |
|
initial_messages = prepare_messages(data, format=response_format) |
|
base_generation_params = parse_and_validate_params(data) |
|
|
|
effective_n_ctx = get_effective_n_ctx() |
|
input_token_count = estimate_token_count(initial_messages) |
|
|
|
if input_token_count != -1 and input_token_count >= effective_n_ctx: |
|
truncated_initial = truncate_messages_for_context(initial_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO) |
|
truncated_token_count = estimate_token_count(truncated_initial) |
|
|
|
if truncated_token_count != -1 and truncated_token_count >= effective_n_ctx: |
|
error_msg = f"Initial input exceeds context window ({effective_n_ctx}) even after attempting truncation. Input tokens (~{input_token_count}) / Truncated tokens (~{truncated_token_count}). Reduce initial message size significantly." |
|
logger.error(error_msg, extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx}) |
|
return jsonify(error="Input exceeds context window", detail=error_msg), 400 |
|
else: |
|
logger.warning(f"Initial input (~{input_token_count} tokens) exceeded context window ({effective_n_ctx}). Truncated to ~{truncated_token_count} tokens.", extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx}) |
|
initial_messages = truncated_initial |
|
input_token_count = truncated_token_count |
|
|
|
elif input_token_count != -1: |
|
logger.info(f"Initial input token count: ~{input_token_count}. Context window: {effective_n_ctx}. Remaining: {effective_n_ctx - input_token_count}.", extra={'request_id': request_id, 'input_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'remaining_ctx': effective_n_ctx - input_token_count}) |
|
else: |
|
logger.warning("Could not estimate initial token count. Proceeding with generation, may hit context limit.", extra={'request_id': request_id}) |
|
|
|
|
|
logger.info(f"Processing request with {len(initial_messages)} initial messages. Stream={is_streaming}. Format={response_format}. max_tokens=None (dynamic). Unlimited Continuations.", extra={'request_id': request_id}) |
|
|
|
except ValueError as e: |
|
logger.error(f"Invalid input data: {e}", exc_info=True, extra={'request_id': request_id}) |
|
try: error_detail = json.loads(str(e)) |
|
except json.JSONDecodeError: error_detail = str(e) |
|
return jsonify(error="Invalid input", detail=error_detail), 400 |
|
except Exception as e: |
|
logger.error(f"Unexpected error preparing request: {e}", exc_info=True, extra={'request_id': request_id}) |
|
return jsonify(error="Internal server error", detail="An unexpected error occurred processing the request."), 500 |
|
|
|
|
|
if is_streaming: |
|
def generate_streaming_with_continuation(current_request_id: str) -> Generator[str, None, None]:
    """Stream generated text, restarting generation whenever the context fills up.

    Each cycle picks a random temperature/top_p/top_k preset from
    ``RANDOM_PARAMS_CHOICES``, overlays it on ``base_generation_params`` and
    streams one chat completion over the running conversation. Text produced
    in a cycle is stored as assistant content in that conversation. When a
    cycle ends with ``finish_reason == 'length'`` (or raises an error that is
    recognised as a context-limit failure) the conversation is truncated with
    ``truncate_messages_for_context`` and another cycle starts; any other
    outcome ends the stream.

    Continuations are unlimited as long as they make progress: the stream is
    aborted only after several consecutive context-limit cycles that produced
    no text at all, which would otherwise loop forever.

    Args:
        current_request_id: Identifier attached to every log record emitted here.

    Yields:
        Generated text fragments, interleaved with bracketed ``[INFO]``,
        ``[ERROR]`` and ``[CONTINUING n ...]`` status markers.
    """
    # Consecutive context-limit cycles without any output tolerated before
    # giving up. One empty cycle is legitimate (the prompt itself may overflow
    # and be fixed by truncation); repeated ones mean truncation cannot help.
    max_stalled_cycles = 3

    def _is_context_limit_error(exc: Exception) -> bool:
        """Return True when *exc* looks like a context-window/decode failure."""
        text = str(exc).lower()
        if any(marker in text for marker in ("context window is full", "kv cache is full", "llama_decode")):
            return True
        if hasattr(exc, 'condition'):
            condition = str(exc.condition).lower()
            return "context length" in condition or "failed to decode" in condition
        return False

    current_messages = list(initial_messages)
    continuations = 0
    stalled_cycles = 0
    # Counts streamed content chunks, used as an approximate token count.
    total_tokens_generated_stream = 0
    effective_n_ctx = get_effective_n_ctx()

    while True:
        cycle_number = continuations + 1
        logger.info(f"Starting streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': current_request_id})

        # Sampling parameters are re-drawn for every cycle.
        chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
        current_params = {
            **base_generation_params,
            "temperature": chosen_params["temperature"],
            "top_p": chosen_params["top_p"],
            "top_k": chosen_params["top_k"],
        }
        logger.debug(f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}", extra={'request_id': current_request_id})

        cycle_fragments: List[str] = []
        finish_reason = None
        hit_context_limit_in_cycle = False

        try:
            streamer = llm.create_chat_completion(
                messages=current_messages,
                max_tokens=current_params["max_tokens"],
                temperature=current_params["temperature"],
                top_p=current_params["top_p"],
                top_k=current_params["top_k"],
                repeat_penalty=current_params["repeat_penalty"],
                stop=current_params["stop"],
                seed=current_params["seed"],
                stream=True,
            )

            for chunk in streamer:
                # Tolerate a missing *or empty* "choices" list and a null delta.
                choice = (chunk.get("choices") or [{}])[0]
                delta = choice.get("delta") or {}
                token = delta.get("content")
                current_chunk_finish_reason = choice.get("finish_reason")

                if token:
                    cycle_fragments.append(token)
                    total_tokens_generated_stream += 1
                    yield token

                if current_chunk_finish_reason:
                    finish_reason = current_chunk_finish_reason
                    logger.info(f"Streaming chunk finished cycle {cycle_number}. Reason: {finish_reason}", extra={'request_id': current_request_id, 'finish_reason': finish_reason})
                    if finish_reason == 'length':
                        hit_context_limit_in_cycle = True
                    usage = chunk.get("usage")
                    if usage:
                        logger.debug(f"Usage reported in final chunk: {usage}", extra={'request_id': current_request_id, 'usage': usage})
                    break

        except Exception as e:
            if _is_context_limit_error(e):
                # Treated exactly like a 'length' finish: truncate and continue.
                logger.warning(f"N_CTX limit or related exception caught during streaming cycle {cycle_number}: {e}", extra={'request_id': current_request_id})
                hit_context_limit_in_cycle = True
                finish_reason = 'length'
            else:
                logger.error(f"Unhandled error during streaming generation cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': current_request_id})
                yield f"\n[ERROR] Generation failed unexpectedly in cycle {cycle_number}: {str(e)}"
                return

        # Text streamed before an exception is kept as well.
        generated_this_cycle = "".join(cycle_fragments)

        if generated_this_cycle:
            stalled_cycles = 0
            if not current_messages or current_messages[-1].get('role') != 'assistant':
                current_messages.append({"role": "assistant", "content": generated_this_cycle})
            else:
                # Extend the trailing assistant message; a missing or null
                # content is replaced instead of raising on the concatenation.
                last_message = current_messages[-1]
                if last_message.get('content') is None:
                    last_message['content'] = generated_this_cycle
                else:
                    last_message['content'] += generated_this_cycle
        elif hit_context_limit_in_cycle:
            stalled_cycles += 1
            logger.warning(f"Context limit hit in streaming cycle {cycle_number} but no tokens were generated in this cycle. Check model behavior.", extra={'request_id': current_request_id})
        elif not finish_reason:
            logger.warning(f"Stream cycle {cycle_number} ended without generating tokens or a definite finish reason. Stopping.", extra={'request_id': current_request_id})
            yield "\n[INFO] Generation stopped: Cycle ended unexpectedly."
            break

        if finish_reason == 'stop':
            logger.info(f"Generation stopped naturally (reason: stop) in streaming cycle {cycle_number}. Total stream tokens: ~{total_tokens_generated_stream}", extra={'request_id': current_request_id})
            yield "\n[INFO] Generation stopped: Stop sequence or EOS."
            break

        elif hit_context_limit_in_cycle:
            if stalled_cycles >= max_stalled_cycles:
                # Truncating again cannot help; stop instead of looping forever.
                logger.error(f"Context limit hit in {stalled_cycles} consecutive streaming cycles without any output. Stopping.", extra={'request_id': current_request_id})
                yield "\n[ERROR] Generation failed: Context limit hit repeatedly without producing output."
                break

            continuations += 1
            logger.warning(f"N_CTX limit reached in streaming cycle {cycle_number}. Attempting continuation {continuations} (reinicio de contador).", extra={'request_id': current_request_id})
            current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO)
            if not current_messages:
                logger.error("Context truncation resulted in empty messages during streaming. Stopping.", extra={'request_id': current_request_id})
                yield "\n[ERROR] Generation failed: Context truncation error."
                break
            yield f"\n[CONTINUING {continuations} - TRUNCATING CONTEXT...]\n"
            time.sleep(0.1)
            continue

        else:
            logger.warning(f"Streaming generation cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.", extra={'request_id': current_request_id, 'finish_reason': finish_reason})
            yield f"\n[INFO] Generation stopped: Reason: {finish_reason or 'Unknown'}"
            break

    logger.info(f"Streaming generation finished after {continuations} continuations. Total stream tokens generated: ~{total_tokens_generated_stream}", extra={'request_id': current_request_id, 'continuations': continuations, 'total_stream_tokens': total_tokens_generated_stream})
|
|
|
|
|
headers = { |
|
"Content-Type": "text/event-stream; charset=utf-8", |
|
"Cache-Control": "no-cache", |
|
"Connection": "keep-alive", |
|
"X-Accel-Buffering": "no", |
|
"X-Request-ID": request_id |
|
} |
|
return Response(stream_with_context(generate_streaming_with_continuation(request_id)), headers=headers) |
|
|
|
else: |
|
current_messages = list(initial_messages) |
|
continuations = 0 |
|
full_generated_text = "" |
|
total_tokens_generated_nonstream = 0 |
|
final_finish_reason = "unknown" |
|
final_usage = {} |
|
effective_n_ctx = get_effective_n_ctx() |
|
|
|
while True: |
|
cycle_number = continuations + 1 |
|
logger.info(f"Starting non-streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': request_id}) |
|
|
|
chosen_params = random.choice(RANDOM_PARAMS_CHOICES) |
|
current_dynamic_params = { |
|
"temperature": chosen_params["temperature"], |
|
"top_p": chosen_params["top_p"], |
|
"top_k": chosen_params["top_k"], |
|
} |
|
current_params = {**base_generation_params, **current_dynamic_params} |
|
logger.debug(f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}", extra={'request_id': request_id}) |
|
|
|
|
|
generated_this_cycle = "" |
|
finish_reason = None |
|
hit_context_limit_in_cycle = False |
|
usage_this_cycle = {} |
|
|
|
try: |
|
result = llm.create_chat_completion( |
|
messages=current_messages, |
|
max_tokens=current_params["max_tokens"], |
|
temperature=current_params["temperature"], |
|
top_p=current_params["top_p"], |
|
top_k=current_params["top_k"], |
|
repeat_penalty=current_params["repeat_penalty"], |
|
stop=current_params["stop"], |
|
seed=current_params["seed"], |
|
stream=False, |
|
) |
|
|
|
if result and "choices" in result and result["choices"]: |
|
choice = result["choices"][0] |
|
generated_this_cycle = choice.get("message", {}).get("content", "") |
|
finish_reason = choice.get("finish_reason", "unknown") |
|
else: |
|
logger.error(f"Invalid response structure from llama_cpp in non-streaming cycle {cycle_number}: {result}", extra={'request_id': request_id}) |
|
return jsonify(error="Generation failed", detail=f"Invalid response structure from model in cycle {cycle_number}."), 500 |
|
|
|
|
|
usage_this_cycle = result.get("usage", {}) |
|
final_finish_reason = finish_reason |
|
if usage_this_cycle: final_usage = usage_this_cycle |
|
|
|
logger.info(f"Non-streaming cycle {cycle_number} finished. Reason: {finish_reason}. Usage: {usage_this_cycle}", extra={'request_id': request_id, 'usage': usage_this_cycle, 'finish_reason': finish_reason}) |
|
|
|
if finish_reason == 'length': |
|
hit_context_limit_in_cycle = True |
|
|
|
except Exception as e: |
|
err_str = str(e).lower() |
|
if "context window is full" in err_str or \ |
|
"kv cache is full" in err_str or \ |
|
"llama_decode" in err_str or \ |
|
(hasattr(e, 'condition') and ("context length" in str(e.condition).lower() or "failed to decode" in str(e.condition).lower())): |
|
logger.warning(f"N_CTX limit or related exception caught during non-streaming cycle {cycle_number}: {e}", extra={'request_id': request_id}) |
|
hit_context_limit_in_cycle = True |
|
finish_reason = 'length' |
|
else: |
|
logger.error(f"Unhandled error during non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id}) |
|
return jsonify(error="Generation failed", detail=f"Internal error in cycle {cycle_number}: {str(e)}"), 500 |
|
|
|
if generated_this_cycle: |
|
if continuations > 0 and full_generated_text: |
|
full_generated_text += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT]\n\n" |
|
full_generated_text += generated_this_cycle |
|
tokens_generated_cycle = usage_this_cycle.get("completion_tokens", 0) |
|
total_tokens_generated_nonstream += tokens_generated_cycle |
|
|
|
if not current_messages or current_messages[-1].get('role') != 'assistant': |
|
current_messages.append({"role": "assistant", "content": generated_this_cycle}) |
|
else: |
|
current_messages[-1]['content'] += generated_this_cycle |
|
|
|
elif hit_context_limit_in_cycle: |
|
logger.warning(f"Non-streaming N_CTX limit hit in cycle {cycle_number} but no completion tokens reported.", extra={'request_id': request_id}) |
|
if continuations > 0 and full_generated_text: |
|
full_generated_text += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT - NO OUTPUT THIS CYCLE]\n\n" |
|
|
|
elif not finish_reason: |
|
logger.warning(f"Non-streaming cycle {cycle_number} ended without generating tokens or a finish reason. Stopping.", extra={'request_id': request_id}) |
|
full_generated_text += f"\n[INFO: Generation stopped: Cycle {cycle_number} ended unexpectedly.]" |
|
break |
|
|
|
|
|
if finish_reason == 'stop': |
|
logger.info(f"Non-streaming generation stopped naturally (reason: stop) in cycle {cycle_number}.", extra={'request_id': request_id}) |
|
break |
|
|
|
elif hit_context_limit_in_cycle: |
|
continuations += 1 |
|
logger.warning(f"Non-streaming N_CTX limit reached in cycle {cycle_number}. Attempting continuation {continuations} (reinicio de contador).", extra={'request_id': request_id}) |
|
current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO) |
|
if not current_messages: |
|
logger.error("Context truncation resulted in empty messages during non-streaming. Stopping.", extra={'request_id': request_id}) |
|
full_generated_text += f"\n[ERROR: Generation failed: Context truncation error.]" |
|
final_finish_reason = "truncation_error" |
|
break |
|
continue |
|
else: |
|
logger.warning(f"Non-streaming cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.", extra={'request_id': request_id, 'finish_reason': finish_reason}) |
|
full_generated_text += f"\n\n[INFO: Generation stopped unexpectedly. Reason: {finish_reason or 'Unknown'}]" |
|
break |
|
|
|
logger.info(f"Non-streaming generation finished after {continuations} continuations. Total completion tokens reported: {total_tokens_generated_nonstream}. Final reason: {final_finish_reason}", extra={'request_id': request_id, 'continuations': continuations, 'total_completion_tokens': total_tokens_generated_nonstream, 'final_reason': final_finish_reason}) |
|
|
|
response = Response(full_generated_text, mimetype="text/plain; charset=utf-8") |
|
response.headers["X-Request-ID"] = request_id |
|
response.headers["X-Finish-Reason"] = final_finish_reason |
|
response.headers["X-Continuations"] = str(continuations) |
|
total_prompt_tokens = final_usage.get("prompt_tokens", "N/A") |
|
response.headers["X-Usage-Completion-Tokens"] = str(total_tokens_generated_nonstream) |
|
response.headers["X-Usage-Prompt-Tokens-Last-Cycle"] = str(total_prompt_tokens) |
|
response.headers["X-Usage-Total-Tokens-Last-Cycle"] = str(final_usage.get("total_tokens", "N/A")) |
|
return response |
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read HOST / PORT / FLASK_DEBUG from the environment,
    # adjust the log level, then serve the app with Waitress (normal mode) or
    # the Flask development server (debug mode, or when Waitress is missing).
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    is_debug = os.getenv("FLASK_DEBUG", "0") == "1"

    log_level = logging.DEBUG if is_debug else logging.INFO
    logger.setLevel(log_level)

    logger.info(f"Starting Flask server on {host}:{port} (Debug mode: {is_debug})")
    logger.info(f"Model: {MODEL_REPO}/{MODEL_FILE}, N_CTX={N_CTX}, Automatic Continuations: UNLIMITED (with context truncation)")

    if not llm:
        # The server still starts so the failure is visible to clients.
        logger.critical("MODEL FAILED TO LOAD. SERVER WILL START BUT '/generate' WILL FAIL.")

    if is_debug:
        logger.info("Running with Flask development server (Debug=True).")
        # NOTE(review): reloader is disabled, presumably so the model is not
        # loaded a second time by the reloader child process - confirm.
        app.run(host=host, port=port, threaded=True, debug=is_debug, use_reloader=False)
    else:
        # Only the import is guarded: an ImportError raised while the server is
        # running must not be mistaken for "Waitress is not installed".
        try:
            from waitress import serve
        except ImportError:
            logger.warning("Waitress not found. Falling back to Flask development server. Install waitress for production.")
            app.run(host=host, port=port, threaded=True, debug=is_debug)
        else:
            logger.info("Running with Waitress production server.")
            serve(app, host=host, port=port, threads=8)