#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import logging
import uuid
from typing import List, Dict, Union, Optional, Generator, Any
import random
import time
from flask import Flask, request, Response, stream_with_context, jsonify, g, render_template_string
from llama_cpp import Llama
class JsonFormatter(logging.Formatter):
    def format(self, record):
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "name": record.name,
            "message": record.getMessage(),
            "pathname": record.pathname,
            "lineno": record.lineno,
        }
        if hasattr(record, 'request_id'):
            log_record['request_id'] = record.request_id
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record['stack_info'] = self.formatStack(record.stack_info)
        skip_keys = {'message', 'asctime', 'levelname', 'levelno', 'pathname', 'filename', 'module', 'funcName', 'lineno', 'created', 'msecs', 'relativeCreated', 'thread', 'threadName', 'process', 'processName', 'exc_info', 'exc_text', 'stack_info', 'request_id'}
        for key, value in record.__dict__.items():
            if not key.startswith('_') and key not in log_record and key not in skip_keys:
                # Ensure the extra value is JSON serializable before emitting it.
                try:
                    json.dumps(value)
                    log_record[key] = value
                except TypeError:
                    log_record[key] = str(value)  # Convert non-serializable types to string.
                except Exception:
                    log_record[key] = "[Unserializable Value]"
        return json.dumps(log_record)
def setup_logging():
    logger = logging.getLogger()
    if not logger.handlers:
        handler = logging.StreamHandler()
        formatter = JsonFormatter()
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logging.getLogger("werkzeug").setLevel(logging.ERROR)
    logging.getLogger("llama_cpp").setLevel(logging.WARNING)
    return logger

logger = setup_logging()
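
# Illustrative sketch (not part of the original app): what a record emitted through the
# JsonFormatter above looks like. Field names follow format(); the values are hypothetical.
#
#   logger.info("model loaded", extra={"request_id": "abc-123"})
#   -> {"timestamp": "2024-01-01 00:00:00,000", "level": "INFO", "name": "root",
#       "message": "model loaded", "pathname": "app.py", "lineno": 42,
#       "request_id": "abc-123"}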
MODEL_REPO = os.getenv("MODEL_REPO", "jnjj/vcvcvcv")
MODEL_FILE = os.getenv("MODEL_FILE", "gemma-3-4b-it-q4_0.gguf")
N_CTX_CONFIG = int(os.getenv("N_CTX", "2048"))
N_BATCH = int(os.getenv("N_BATCH", "512"))
N_GPU_LAYERS_CONFIG = int(os.getenv("N_GPU_LAYERS", "0"))
MAX_CONTINUATIONS = int(os.getenv("MAX_CONTINUATIONS", "-1"))
FIXED_REPEAT_PENALTY = float(os.getenv("FIXED_REPEAT_PENALTY", "1.1"))
FIXED_SEED = int(os.getenv("FIXED_SEED", "-1"))
# Default system prompt (Spanish): "You are a concise, direct and helpful assistant."
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "Eres un asistente conciso, directo y útil.")
CONTEXT_TRUNCATION_BUFFER_RATIO = float(os.getenv("CONTEXT_TRUNCATION_BUFFER_RATIO", "0.85"))
RANDOM_PARAMS_CHOICES = [
    {"temperature": 0.2, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.1, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.3, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.4, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.6, "top_p": 0.3, "top_k": 5},
    {"temperature": 0.5, "top_p": 0.7, "top_k": 20},
]
llm: Optional[Llama] = None
ACTUAL_N_CTX: int = N_CTX_CONFIG
ACTUAL_N_BATCH: int = N_BATCH
ACTUAL_N_GPU_LAYERS: int = N_GPU_LAYERS_CONFIG
class ContextLimitException(Exception):
    pass

class GenerationFailedException(Exception):
    pass
def prepare_messages(data: Dict, format: Optional[str] = None, request_id: str = 'N/A') -> List[Dict[str, str]]:
    messages_list = data.get("messages")
    prompt_str = data.get("prompt")
    system_instruction = data.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
    if not messages_list and not prompt_str:
        raise ValueError("Either 'messages' list or 'prompt' string is required.")
    if messages_list is not None and not isinstance(messages_list, list):
        raise ValueError("'messages' must be a list of dictionaries.")
    if prompt_str is not None and not isinstance(prompt_str, str):
        raise ValueError("'prompt' must be a string.")
    if system_instruction is not None and not isinstance(system_instruction, str):
        raise ValueError("'system_prompt' must be a string.")
    final_messages = []
    content_format_instruction = ""
    if format == "markdown":
        content_format_instruction = " Format your response using Markdown."
    elif format is not None:
        logger.warning(f"Unsupported format '{format}' requested.", extra={'request_id': request_id, 'format': format})
    # Join the system prompt and the optional format instruction, keeping a single space between them.
    effective_system_prompt_content = ((system_instruction or "").strip() + content_format_instruction).strip()
    if effective_system_prompt_content:
        final_messages.append({"role": "system", "content": effective_system_prompt_content})
    if messages_list:
        has_user_message = False
        for i, msg in enumerate(messages_list):
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                raise ValueError(f"Message at index {i} is invalid: must be a dictionary with 'role' and 'content'.")
            role = msg.get("role")
            content = msg.get("content", "")
            if not isinstance(content, str):
                logger.warning(f"Message content at index {i} (role: {role}) is not a string (type: {type(content)}). Converting to string.", extra={'request_id': request_id, 'message_index': i, 'role': role, 'content_type': str(type(content))})
                content = str(content)
            if role == "system":
                if i == 0 and final_messages and final_messages[0]["role"] == "system":
                    logger.info("Replacing default system prompt with user-provided system message.", extra={'request_id': request_id})
                    final_messages[0]["content"] = content
                elif i == 0 and not final_messages:
                    final_messages.append({"role": "system", "content": content})
                else:
                    logger.warning(f"Ignoring additional system message at index {i}; the system prompt is already set or must appear first.", extra={'request_id': request_id, 'message_index': i})
                continue
            elif role == "user":
                has_user_message = True
            final_messages.append({"role": role, "content": content})
        if not has_user_message and any(m["role"] != "system" for m in final_messages):
            logger.warning("The 'messages' list contains no user messages.", extra={'request_id': request_id})
    elif prompt_str:
        final_messages.append({"role": "user", "content": prompt_str})
    if not final_messages or all(m["role"] == "system" for m in final_messages):
        raise ValueError("No user or assistant messages found to generate a response.")
    return final_messages
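
# Illustrative sketch (hypothetical values): prepare_messages() accepts either a 'prompt'
# string or a 'messages' list and always front-loads a system message when one is configured.
#
#   prepare_messages({"prompt": "Hello"}, format="markdown")
#   -> [{"role": "system", "content": "<default system prompt> Format your response using Markdown."},
#       {"role": "user", "content": "Hello"}]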
def estimate_token_count(messages: List[Dict[str, str]], request_id: str = 'N/A') -> int:
    if not llm or not hasattr(llm, 'tokenize') or not hasattr(llm, 'apply_chat_template'):
        logger.warning("LLM or tokenizer/template function not available for token estimation.", extra={'request_id': request_id})
        return -1
    try:
        chat_prompt_string = llm.apply_chat_template(messages, add_generation_prompt=True)
        tokens = llm.tokenize(chat_prompt_string.encode('utf-8', errors='ignore'), add_bos=True)
        return len(tokens)
    except Exception as e:
        logger.error(f"Could not estimate token count using apply_chat_template: {e}", exc_info=True, extra={'request_id': request_id})
        # Rough fallback: assume roughly four characters per token.
        char_count = sum(len(m.get('content', '')) for m in messages)
        estimated_tokens = char_count // 4
        logger.warning(f"Falling back to character-based token estimation (~{estimated_tokens}).", extra={'request_id': request_id, 'estimated_tokens': estimated_tokens, 'char_count': char_count})
        return estimated_tokens
def get_effective_n_ctx() -> int:
    return ACTUAL_N_CTX
def truncate_messages_for_context(messages: List[Dict[str, str]], max_tokens: int, buffer_ratio: float, request_id: str = 'N/A') -> List[Dict[str, str]]:
    if not llm:
        return messages
    target_token_limit = int(max_tokens * buffer_ratio)
    truncated_messages: List[Dict[str, str]] = []
    system_prompt: Optional[Dict[str, str]] = None
    if messages and messages[0].get("role") == "system":
        system_prompt = messages[0]
        truncated_messages.append(system_prompt)
        remaining_messages = messages[1:]
    else:
        remaining_messages = messages
    current_token_count = estimate_token_count(truncated_messages, request_id=request_id) if truncated_messages else 0
    if current_token_count == -1:
        logger.warning("Could not estimate initial token count for truncation; proceeding cautiously with a character-based estimate.", extra={'request_id': request_id})
        current_token_count = sum(len(m.get('content', '')) for m in truncated_messages) // 4
    messages_to_add = []
    # Walk the history from newest to oldest, keeping messages while they fit within the target limit.
    for msg in reversed(remaining_messages):
        potential_list = ([system_prompt] if system_prompt else []) + [msg] + messages_to_add
        next_token_count = estimate_token_count(potential_list, request_id=request_id)
        if next_token_count != -1 and next_token_count <= target_token_limit:
            messages_to_add.insert(0, msg)
            current_token_count = next_token_count
        elif next_token_count == -1:
            logger.warning(f"Token estimation failed while adding message: {msg}. Stopping truncation early.", extra={'request_id': request_id})
            break
        else:
            logger.debug(f"Stopping truncation: adding the next message would exceed the target limit ({next_token_count} > {target_token_limit}).", extra={'request_id': request_id})
            break
    final_truncated_list = ([system_prompt] if system_prompt else []) + messages_to_add
    original_count = len(messages)
    final_count = len(final_truncated_list)
    if not final_truncated_list or all(m.get("role") == "system" for m in final_truncated_list):
        if any(m.get("role") == "user" for m in messages):
            last_user_message = next((m for m in reversed(messages) if m.get("role") == "user"), None)
            if last_user_message:
                logger.warning("Truncation resulted in empty or system-only messages; attempting to keep the last user message.", extra={'request_id': request_id})
                final_truncated_list = ([system_prompt] if system_prompt else []) + [last_user_message]
                final_count = len(final_truncated_list)
                current_token_count = estimate_token_count(final_truncated_list, request_id=request_id)
    if final_count < original_count:
        logger.warning(f"Context truncated: kept {final_count}/{original_count} messages. Estimated tokens: ~{current_token_count}/{target_token_limit} (target).",
                       extra={'request_id': request_id, 'kept': final_count, 'original': original_count, 'estimated_tokens': current_token_count, 'target_limit': target_token_limit})
    else:
        logger.debug(f"Context truncation check complete. Kept all {final_count} messages. Estimated tokens: ~{current_token_count}.",
                     extra={'request_id': request_id, 'kept': final_count, 'estimated_tokens': current_token_count})
    if not final_truncated_list:
        logger.error("Context truncation resulted in an empty message list!", extra={'request_id': request_id})
        return []
    return final_truncated_list
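
# Illustrative note (not part of the original app): the truncation strategy keeps the system
# message (if any) plus as many of the most recent turns as fit within
# int(max_tokens * buffer_ratio), dropping the oldest non-system messages first. With the
# defaults above (N_CTX=2048, CONTEXT_TRUNCATION_BUFFER_RATIO=0.85) the target is
# int(2048 * 0.85) = 1740 tokens, e.g. (hypothetical call):
#
#   truncate_messages_for_context(history, ACTUAL_N_CTX, CONTEXT_TRUNCATION_BUFFER_RATIO)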
def get_property_or_method_value(obj: Any, prop_name: str, default: Any = None) -> Any:
    """Safely get property value or call method if callable."""
    if hasattr(obj, prop_name):
        prop = getattr(obj, prop_name)
        if callable(prop):
            try:
                return prop()
            except Exception:
                logger.warning(f"Error calling method {prop_name} on {type(obj)}", exc_info=True)
                return default
        else:
            return prop
    return default
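
# Illustrative usage (hypothetical; exact attribute/method split varies by llama-cpp-python
# version): some metadata is exposed as methods and some as plain attributes, and the helper
# resolves both the same way:
#
#   get_property_or_method_value(llm, "n_ctx", N_CTX_CONFIG)   # called if n_ctx is a method
#   get_property_or_method_value(llm, "n_batch", N_BATCH)      # returned directly if an attribute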
def load_model():
    global llm, ACTUAL_N_CTX, ACTUAL_N_BATCH, ACTUAL_N_GPU_LAYERS
    logger.info(f"Attempting to load model: {MODEL_REPO}/{MODEL_FILE}")
    logger.info(f"Configuration: N_CTX={N_CTX_CONFIG}, N_BATCH={N_BATCH}, N_GPU_LAYERS={N_GPU_LAYERS_CONFIG}")
    try:
        llm = Llama.from_pretrained(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            n_ctx=N_CTX_CONFIG,
            n_batch=N_BATCH,
            n_gpu_layers=N_GPU_LAYERS_CONFIG,
            verbose=False,
            use_mmap=True,
            use_mlock=True,
        )
        logger.info("Model loaded successfully.")
        if llm:
            ACTUAL_N_CTX = get_property_or_method_value(llm, 'n_ctx', N_CTX_CONFIG)
            ACTUAL_N_BATCH = get_property_or_method_value(llm, 'n_batch', N_BATCH)
            ACTUAL_N_GPU_LAYERS = get_property_or_method_value(llm, 'n_gpu_layers', 0)
            if ACTUAL_N_CTX != N_CTX_CONFIG:
                logger.warning(f"Model's actual context size ({ACTUAL_N_CTX}) differs from config ({N_CTX_CONFIG}). Using actual.", extra={'actual_n_ctx': ACTUAL_N_CTX, 'configured_n_ctx': N_CTX_CONFIG})
            if ACTUAL_N_GPU_LAYERS != N_GPU_LAYERS_CONFIG:
                logger.warning(f"Model loaded with {ACTUAL_N_GPU_LAYERS} GPU layers despite requesting {N_GPU_LAYERS_CONFIG}. Check llama.cpp build or environment.", extra={'actual_gpu_layers': ACTUAL_N_GPU_LAYERS, 'configured_gpu_layers': N_GPU_LAYERS_CONFIG})
            logger.info(f"Actual Model Context Window (n_ctx): {ACTUAL_N_CTX}")
            logger.info(f"Actual Model Batch Size (n_batch): {ACTUAL_N_BATCH}")
            logger.info(f"Actual Model GPU Layers (n_gpu_layers): {ACTUAL_N_GPU_LAYERS}")
            try:
                test_tokens = llm.tokenize(b"Test sentence.")
                logger.info(f"Tokenizer test successful. 'Test sentence.' -> {len(test_tokens)} tokens.")
            except Exception as tokenize_e:
                logger.warning(f"Could not perform test tokenization: {tokenize_e}")
    except Exception as e:
        logger.error(f"Fatal error loading model: {e}", exc_info=True)
        llm = None
        logger.error("Model failed to load. Generation requests will not work.", extra={'error': str(e)})
app = Flask(__name__)

@app.before_request
def before_request_func():
    # Attach a unique request ID to every request for structured logging.
    g.request_id = str(uuid.uuid4())
    logger.debug(f"Incoming request: {request.method} {request.path} from {request.remote_addr}", extra={'request_id': g.request_id, 'path': request.path, 'method': request.method})

load_model()
html_code = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>LLM API Demo</title>
    <style>
        body { font-family: sans-serif; margin: 20px; line-height: 1.6; background-color: #f4f4f4; color: #333; }
        .container { max-width: 800px; margin: auto; background: #fff; padding: 20px; border-radius: 8px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); }
        h1, h2 { color: #0056b3; }
        .section { margin-bottom: 30px; padding: 20px; background-color: #e9e9e9; border-radius: 5px; }
        label { display: block; margin-bottom: 5px; font-weight: bold; }
        input[type="text"], input[type="number"], textarea, select {
            width: calc(100% - 22px); padding: 10px; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px;
        }
        textarea { resize: vertical; min-height: 100px; }
        button {
            display: inline-block; background-color: #007bff; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; font-size: 16px;
            margin-right: 10px; transition: background-color 0.3s ease;
        }
        button:hover { background-color: #0056b3; }
        button:disabled { background-color: #cccccc; cursor: not-allowed; }
        .output {
            background-color: #f9f9f9; border: 1px solid #ddd; padding: 15px; border-radius: 4px; white-space: pre-wrap; word-wrap: break-word; max-height: 400px; overflow-y: auto; font-family: monospace;
        }
        .error { color: red; font-weight: bold; }
        .info { color: green; }
        .warning { color: orange; }
        .param-info { font-style: italic; color: #555; margin-bottom: 10px; }
        .checkbox-container { display: flex; align-items: center; margin-bottom: 10px; }
        .checkbox-container input { margin-right: 5px; width: auto; }
        .continuation-info { font-weight: bold; }
    </style>
</head>
<body>
    <div class="container">
        <h1>LLM API Demonstration</h1>
        <div class="section">
            <h2>Health Check</h2>
            <button id="healthCheckBtn">Check Health</button>
            <p id="healthStatus"></p>
        </div>
        <div class="section">
            <h2>API Info</h2>
            <button id="apiInfoBtn">Get Info</button>
            <pre id="apiInfoOutput" class="output"></pre>
        </div>
        <div class="section">
            <h2>Generate Text (Automatic Continuation with Context Management)</h2>
            <label for="promptInput">Prompt / First User Message:</label>
            <textarea id="promptInput" placeholder="Enter your prompt here..."></textarea>
            <div class="param-info">Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX={{ ACTUAL_N_CTX }}). If the context limit is reached, the server will attempt to continue automatically by <strong class="continuation-info">truncating older messages (unlimited continuations)</strong>. Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.</div>
            <div>
                <label for="stopInput">Stop Sequences (comma-separated):</label>
                <input type="text" id="stopInput" value="">
            </div>
            <div>
                <label for="systemPromptInput">System Prompt (Optional Override - default: "{{ DEFAULT_SYSTEM_PROMPT | escape }}"):</label>
                <input type="text" id="systemPromptInput" placeholder="Leave empty to use default">
            </div>
            <div>
                <label for="formatSelect">Format:</label>
                <select id="formatSelect">
                    <option value="">None</option>
                    <option value="markdown">Markdown</option>
                </select>
            </div>
            <div class="checkbox-container">
                <input type="checkbox" id="streamCheckbox" checked>
                <label for="streamCheckbox">Stream Output</label>
            </div>
            <button id="generateBtn">Generate</button>
            <p id="generationStatus"></p>
            <pre id="generationOutput" class="output"></pre>
        </div>
    </div>
    <script>
        const healthCheckBtn = document.getElementById('healthCheckBtn');
        const healthStatus = document.getElementById('healthStatus');
        const apiInfoBtn = document.getElementById('apiInfoBtn');
        const apiInfoOutput = document.getElementById('apiInfoOutput');
        const promptInput = document.getElementById('promptInput');
        const stopInput = document.getElementById('stopInput');
        const systemPromptInput = document.getElementById('systemPromptInput');
        const formatSelect = document.getElementById('formatSelect');
        const streamCheckbox = document.getElementById('streamCheckbox');
        const generateBtn = document.getElementById('generateBtn');
        const generationOutput = document.getElementById('generationOutput');
        const generationStatus = document.getElementById('generationStatus');
        const API_BASE_URL = window.location.origin;

        async function checkHealth() {
            healthStatus.textContent = 'Checking...';
            healthStatus.className = '';
            try {
                const response = await fetch(`${API_BASE_URL}/health`);
                const data = await response.json();
                healthStatus.textContent = `Status: ${data.status}, Message: ${data.message}`;
                healthStatus.className = data.status === 'ok' ? 'info' : (data.status === 'warning' ? 'warning' : 'error');
            } catch (error) {
                healthStatus.textContent = `Error fetching health: ${error}`;
                healthStatus.className = 'error';
            }
        }

        async function getApiInfo() {
            apiInfoOutput.textContent = 'Loading...';
            apiInfoOutput.className = 'output';
            try {
                const response = await fetch(`${API_BASE_URL}/info`);
                if (!response.ok) {
                    try {
                        const errorData = await response.json();
                        throw new Error(`API Error ${response.status}: ${errorData.error || JSON.stringify(errorData)}`);
                    } catch (e) {
                        throw new Error(`API Error ${response.status}: ${response.statusText}`);
                    }
                }
                const data = await response.json();
                const nCtx = data?.model_config?.loaded_model_details?.n_ctx || '{{ ACTUAL_N_CTX }}';
                const maxCont = data?.generation_parameters?.max_automatic_continuations;
                const maxContDesc = maxCont === null ? "unlimited continuations" : `up to ${maxCont} times`;
                const description = document.querySelector('.param-info');
                if (description) {
                    description.innerHTML = `Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX=${nCtx}). If the context limit is reached, the server will attempt to continue automatically by <strong class="continuation-info">truncating older messages (${maxContDesc})</strong>. Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.`;
                }
                apiInfoOutput.textContent = JSON.stringify(data, null, 2);
            } catch (error) {
                apiInfoOutput.textContent = `Error fetching info: ${error}`;
                apiInfoOutput.className = 'output error';
            }
        }

        async function generateText() {
            generationOutput.textContent = '';
            generationStatus.textContent = 'Preparing request...';
            generationStatus.className = '';
            generateBtn.disabled = true;
            const prompt = promptInput.value;
            if (!prompt.trim()) {
                generationStatus.textContent = 'Error: Prompt cannot be empty.';
                generationStatus.className = 'error';
                generateBtn.disabled = false;
                return;
            }
            const messages = [{"role": "user", "content": prompt}];
            const stream = streamCheckbox.checked;
            const format = formatSelect.value || undefined;
            const stopSequences = stopInput.value.split(',').map(s => s.trim()).filter(s => s.length > 0);
            const stop = stopSequences.length > 0 ? stopSequences : undefined;
            const systemPrompt = systemPromptInput.value.trim() || undefined;
            const requestBody = {
                messages: messages,
                stop: stop,
                stream: stream,
                format: format,
                system_prompt: systemPrompt
            };
            generationStatus.textContent = 'Generating... (may continue automatically with context truncation if needed)';
            generationStatus.className = 'info';
            try {
                const response = await fetch(`${API_BASE_URL}/generate`, {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify(requestBody),
                });
                if (!response.ok) {
                    const errorText = await response.text();
                    let errorMessage = `Error: ${response.status} ${response.statusText}`;
                    try {
                        const errorData = JSON.parse(errorText);
                        errorMessage += ` - ${errorData.error || JSON.stringify(errorData.detail || errorData)}`;
                    } catch (jsonParseError) {
                        errorMessage += ` - ${errorText}`;
                    }
                    generationStatus.textContent = errorMessage;
                    generationStatus.className = 'error';
                    generateBtn.disabled = false;
                    return;
                }
                if (stream) {
                    const reader = response.body.getReader();
                    const decoder = new TextDecoder('utf-8');
                    let finished = false;
                    generationOutput.textContent = '';
                    let continuationCount = 0;
                    let lastStatusUpdate = Date.now();
                    while (!finished) {
                        const { done, value } = await reader.read();
                        if (done) {
                            finished = true;
                            if (!generationStatus.textContent.includes("finished") && !generationStatus.textContent.includes("stopped") && !generationStatus.textContent.includes("Error") && !generationStatus.textContent.includes("Max continuations")) {
                                generationStatus.textContent = `Streaming finished. Total continuations: ${continuationCount}.`;
                                generationStatus.className = 'info';
                            }
                            break;
                        }
                        const chunk = decoder.decode(value, { stream: true });
                        const continueMatch = chunk.match(/\n\[CONTINUING (\d+) - TRUNCATING CONTEXT\.\.\.\]\n/);
                        const errorMatch = chunk.match(/\n\[ERROR\](.*)/);
                        const infoMatch = chunk.match(/\n\[INFO\](.*)/);
                        if (continueMatch) {
                            continuationCount = parseInt(continueMatch[1]);
                            generationOutput.textContent += chunk;
                            generationStatus.textContent = `Context limit reached, truncating history and continuing generation (Continuation #${continuationCount})...`;
                            generationStatus.className = 'warning continuation-info';
                            lastStatusUpdate = Date.now();
                        } else if (errorMatch) {
                            generationOutput.textContent += chunk;
                            generationStatus.textContent = `Error during generation: ${errorMatch[1]}`;
                            generationStatus.className = 'error';
                            finished = true;
                        } else if (infoMatch) {
                            generationOutput.textContent += chunk;
                            generationStatus.textContent = `Generation info: ${infoMatch[1]}. Total continuations: ${continuationCount}.`;
                            generationStatus.className = 'info';
                            if (infoMatch[1].includes("stopped") || infoMatch[1].includes("finished") || infoMatch[1].includes("Max continuations")) {
                                finished = true;
                            }
                            lastStatusUpdate = Date.now();
                        } else {
                            generationOutput.textContent += chunk;
                            if (Date.now() - lastStatusUpdate > 1000 && !generationStatus.className.includes('warning') && !generationStatus.className.includes('error')) {
                                generationStatus.textContent = `Streaming... (Continuation #${continuationCount})`;
                                generationStatus.className = 'info';
                                lastStatusUpdate = Date.now();
                            }
                        }
                        generationOutput.scrollTop = generationOutput.scrollHeight;
                    }
                } else {
                    const text = await response.text();
                    const finishReason = response.headers.get('X-Finish-Reason');
                    const continuations = response.headers.get('X-Continuations');
                    const usageCompletionTokens = response.headers.get('X-Usage-Completion-Tokens');
                    generationOutput.textContent = text;
                    let statusText = `Generation finished. Reason: ${finishReason || 'unknown'}.`;
                    const contCount = parseInt(continuations || '0');
                    if (contCount > 0) {
                        statusText += ` Continuations: ${contCount} (context truncated).`;
                        generationStatus.className = 'warning continuation-info';
                    } else {
                        generationStatus.className = 'info';
                    }
                    if (usageCompletionTokens && usageCompletionTokens !== 'N/A') statusText += ` Completion Tokens: ~${usageCompletionTokens}.`;
                    if (text.includes("[ERROR]")) {
                        statusText = "Generation finished with errors. See output.";
                        generationStatus.className = 'error';
                    }
                    generationStatus.textContent = statusText;
                }
            } catch (error) {
                generationStatus.textContent = `Network or processing error: ${error}`;
                generationStatus.className = 'error';
                generationOutput.textContent += `\n\n[ERROR] Network or processing error: ${error}`;
            } finally {
                generateBtn.disabled = false;
            }
        }

        healthCheckBtn.addEventListener('click', checkHealth);
        apiInfoBtn.addEventListener('click', getApiInfo);
        generateBtn.addEventListener('click', generateText);
        checkHealth();
        getApiInfo();
    </script>
</body>
</html>
"""
@app.route("/")
def index():
    rendered_html = render_template_string(
        html_code,
        ACTUAL_N_CTX=ACTUAL_N_CTX,
        DEFAULT_SYSTEM_PROMPT=DEFAULT_SYSTEM_PROMPT
    )
    return rendered_html
@app.route("/health")
def health_check():
    if llm:
        if hasattr(llm, 'tokenize') and hasattr(llm, 'apply_chat_template'):
            return jsonify(status="ok", message="Model is loaded and ready."), 200
        else:
            logger.warning("Model loaded, but tokenizer or chat template functions might be missing.", extra={'request_id': getattr(g, 'request_id', 'N/A')})
            return jsonify(status="warning", message="Model loaded, but critical functions (tokenize/apply_chat_template) might be missing."), 200
    else:
        return jsonify(status="error", message="Model failed to load or is not available."), 503
@app.route("/info")
def model_info():
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("Info request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model not available."), 503
    model_details: Union[Dict[str, Any], str] = "Model details unavailable"
    try:
        n_embd = get_property_or_method_value(get_property_or_method_value(llm, '_model'), 'n_embd', 'N/A')
        model_details = {
            "n_embd": n_embd,
            "n_ctx": ACTUAL_N_CTX,
            "n_batch": ACTUAL_N_BATCH,
            "n_gpu_layers": ACTUAL_N_GPU_LAYERS,
            "tokenizer_present": hasattr(llm, 'tokenize'),
            "chat_handler_present": hasattr(llm, 'apply_chat_template') and hasattr(llm, 'create_chat_completion'),
        }
    except Exception as e:
        logger.warning(f"Could not retrieve all model details: {e}", extra={'request_id': request_id}, exc_info=True)
        model_details = f"Error retrieving some model details: {e}"
    info = {
        "status": "ok",
        "message": "Model is loaded. Generation continues automatically with context truncation if the context limit is hit.",
        "model_config": {
            "repo_id": MODEL_REPO,
            "filename": MODEL_FILE,
            "initial_load_config": {
                "n_ctx": N_CTX_CONFIG,
                "n_batch": N_BATCH,
                "n_gpu_layers": N_GPU_LAYERS_CONFIG,
            },
            "loaded_model_details": model_details,
        },
        "generation_parameters": {
            "note": f"No artificial 'max_tokens' limit. Generation proceeds until a stop sequence, EOS, or the context limit (N_CTX={ACTUAL_N_CTX}). Automatic continuation attempts by truncating context occur up to {MAX_CONTINUATIONS if MAX_CONTINUATIONS >= 0 else 'unlimited'} times if the context limit is reached. Sampling parameters (temperature, top_p, top_k) are chosen randomly per request/continuation cycle from predefined sets. Repeat penalty and seed are fixed.",
            "fixed_max_tokens": None,
            "fixed_repeat_penalty": FIXED_REPEAT_PENALTY,
            "fixed_seed": FIXED_SEED,
            "max_automatic_continuations": MAX_CONTINUATIONS if MAX_CONTINUATIONS >= 0 else None,
            "context_truncation_buffer_ratio": CONTEXT_TRUNCATION_BUFFER_RATIO,
            "randomly_chosen_from": RANDOM_PARAMS_CHOICES,
            "default_system_prompt": DEFAULT_SYSTEM_PROMPT,
            "user_controllable": ["messages", "prompt", "stop", "stream", "format", "system_prompt"],
        },
    }
    return jsonify(info), 200
def _generate_single_cycle(messages: List[Dict[str, str]], params: Dict, stream: bool, request_id: str) -> Union[Generator[Dict, None, None], Dict]:
    try:
        logger.debug(f"Starting llama.cpp chat completion call. Stream: {stream}. Messages: {len(messages)}. Params summary: temp={params.get('temperature')}, top_p={params.get('top_p')}, top_k={params.get('top_k')}, stop={params.get('stop')}", extra={'request_id': request_id, 'stream': stream, 'message_count': len(messages)})
        result = llm.create_chat_completion(
            messages=messages,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=params["stop"],
            seed=params["seed"],
            stream=stream,
        )
        return result
    except Exception as e:
        # Map context-window exhaustion onto ContextLimitException so the caller can truncate and retry.
        err_str = str(e).lower()
        if "context window is full" in err_str or \
           "kv cache is full" in err_str or \
           "llama_decode" in err_str or \
           (hasattr(e, 'condition') and isinstance(e.condition, str) and ("context length" in e.condition.lower() or "failed to decode" in e.condition.lower())):
            logger.warning(f"Caught N_CTX limit or related exception: {e}", extra={'request_id': request_id})
            raise ContextLimitException(str(e)) from e
        else:
            logger.error(f"Unhandled error during llama.cpp call: {e}", exc_info=True, extra={'request_id': request_id})
            raise GenerationFailedException(f"Unhandled llama.cpp error: {str(e)}") from e
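
# Illustrative sketch (hypothetical draw): the params dict passed into _generate_single_cycle()
# is assembled per cycle by generate() below, roughly as:
#
#   {**{"max_tokens": None, "repeat_penalty": 1.1, "seed": -1, "stop": None},
#    **random.choice(RANDOM_PARAMS_CHOICES)}
#   -> {"max_tokens": None, "repeat_penalty": 1.1, "seed": -1, "stop": None,
#       "temperature": 0.2, "top_p": 0.5, "top_k": 10}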
@app.route("/generate", methods=["POST"])
def generate():
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.error("Generate request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model is not available.", detail="The LLM model could not be loaded."), 503
    if not request.is_json:
        logger.warning("Request received without Content-Type: application/json", extra={'request_id': request_id})
        return jsonify(error="Invalid request header", detail="Content-Type must be application/json"), 415
    data = request.get_json()
    is_streaming = data.get("stream", True)
    response_format = data.get("format")
    log_data_summary = {k: v for k, v in data.items() if k not in ('messages', 'prompt')}
    log_data_summary['messages_count_initial'] = len(data.get('messages', [])) if 'messages' in data else 0
    log_data_summary['has_prompt_initial'] = 'prompt' in data
    log_data_summary['stream'] = is_streaming
    log_data_summary['format'] = response_format
    logger.info("Received generation request summary.", extra={'request_id': request_id, 'summary': log_data_summary})
    try:
        initial_messages = prepare_messages(data, format=response_format, request_id=request_id)
        base_params: Dict[str, Any] = {
            "max_tokens": None,
            "repeat_penalty": FIXED_REPEAT_PENALTY,
            "seed": FIXED_SEED,
        }
        stop = data.get("stop")
        if stop is not None:
            if isinstance(stop, list) and all(isinstance(s, str) for s in stop):
                base_params["stop"] = stop
            elif isinstance(stop, str):
                base_params["stop"] = [stop]
            else:
                raise ValueError(json.dumps({"stop": "Stop must be a string or a list of strings"}))
        else:
            base_params["stop"] = None
        effective_n_ctx = get_effective_n_ctx()
        input_token_count = estimate_token_count(initial_messages, request_id=request_id)
        if input_token_count != -1 and input_token_count > effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO:
            logger.warning(f"Initial input (~{input_token_count} tokens) likely exceeds the safe context window ({int(effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO)}). Attempting truncation.", extra={'request_id': request_id, 'initial_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'buffer_ratio': CONTEXT_TRUNCATION_BUFFER_RATIO})
            truncated_initial = truncate_messages_for_context(initial_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=request_id)
            truncated_token_count = estimate_token_count(truncated_initial, request_id=request_id)
            if not truncated_initial or (truncated_token_count != -1 and truncated_token_count > effective_n_ctx):
                error_msg = f"Input exceeds the context window ({effective_n_ctx}) even after attempting truncation. Input tokens (~{input_token_count}) / truncated tokens (~{truncated_token_count}). Reduce the initial message size."
                logger.error(error_msg, extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
                return jsonify(error="Input exceeds context window", detail=error_msg), 400
            else:
                logger.info(f"Initial input truncated from ~{input_token_count} to ~{truncated_token_count} tokens.", extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
                initial_messages = truncated_initial
                input_token_count = truncated_token_count
        elif input_token_count != -1:
            logger.info(f"Initial input token count: ~{input_token_count}. Effective context window: {effective_n_ctx}. Context buffer target: {int(effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO)}. Remaining: {effective_n_ctx - input_token_count}.", extra={'request_id': request_id, 'input_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'buffer_target': int(effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO), 'remaining_ctx': effective_n_ctx - input_token_count})
        else:
            logger.warning("Could not estimate initial token count. Proceeding; the context limit may be hit.", extra={'request_id': request_id})
    except ValueError as e:
        logger.error(f"Invalid input data or parameters: {e}", exc_info=True, extra={'request_id': request_id})
        try:
            error_detail = json.loads(str(e))
        except json.JSONDecodeError:
            error_detail = str(e)
        return jsonify(error="Invalid input", detail=error_detail), 400
    except Exception as e:
        logger.error(f"Unexpected error preparing request: {e}", exc_info=True, extra={'request_id': request_id})
        return jsonify(error="Internal server error", detail="An unexpected error occurred preparing the request."), 500
    current_messages = list(initial_messages)
    continuations = 0
    total_completion_tokens_generated = 0
    final_finish_reason = "unknown"
    final_usage = {}
    full_generated_text_nonstream = ""
    effective_n_ctx = get_effective_n_ctx()

    def streaming_generator(req_id):
        nonlocal current_messages, continuations, total_completion_tokens_generated, final_finish_reason, final_usage
        while True:
            if MAX_CONTINUATIONS >= 0 and continuations > MAX_CONTINUATIONS:
                logger.info(f"Max continuations ({MAX_CONTINUATIONS}) reached. Stopping streaming.", extra={'request_id': req_id})
                yield f"\n[INFO] Generation stopped: Max continuations reached ({MAX_CONTINUATIONS})."
                final_finish_reason = "max_continuations"
                break
            cycle_number = continuations + 1
            logger.info(f"Starting streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': req_id, 'cycle': cycle_number, 'message_count': len(current_messages)})
            chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
            current_params = {**base_params, **chosen_params}
            generated_this_cycle_content = ""
            finish_reason = None
            usage_this_cycle = {}
            try:
                streamer = _generate_single_cycle(current_messages, current_params, stream=True, request_id=req_id)
                for chunk in streamer:
                    choice = chunk.get("choices", [{}])[0]
                    delta = choice.get("delta", {})
                    token_content = delta.get("content")
                    chunk_finish_reason = choice.get("finish_reason")
                    chunk_usage = chunk.get("usage", {})
                    if token_content:
                        generated_this_cycle_content += token_content
                        yield token_content
                    if chunk_finish_reason:
                        finish_reason = chunk_finish_reason
                        usage_this_cycle = chunk_usage
                        final_usage = usage_this_cycle
                        break
                if not finish_reason and generated_this_cycle_content:
                    finish_reason = "end_of_stream"
                    logger.warning(f"Streaming cycle {cycle_number} ended without an explicit finish reason.", extra={'request_id': req_id, 'cycle': cycle_number})
            except ContextLimitException:
                logger.warning(f"Context limit caught during streaming cycle {cycle_number}.", extra={'request_id': req_id, 'cycle': cycle_number})
                finish_reason = 'length'
                yield f"\n[INFO] Context limit approached in cycle {cycle_number}. Attempting continuation...\n"
            except GenerationFailedException as e:
                logger.error(f"Generation failed in streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': req_id, 'cycle': cycle_number})
                yield f"\n[ERROR] Generation failed unexpectedly in cycle {cycle_number}: {e}"
                final_finish_reason = "error"
                break
            except Exception as e:
                logger.error(f"An unexpected error occurred in streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': req_id, 'cycle': cycle_number})
                yield f"\n[ERROR] An unexpected error occurred in cycle {cycle_number}: {str(e)}"
                final_finish_reason = "error"
                break
            if generated_this_cycle_content:
                # Fold this cycle's output into the running conversation so continuations stay coherent.
                if not current_messages or current_messages[-1].get('role') != 'assistant':
                    current_messages.append({"role": "assistant", "content": generated_this_cycle_content})
                else:
                    current_messages[-1]['content'] += generated_this_cycle_content
                total_completion_tokens_generated += usage_this_cycle.get("completion_tokens", 0)
            if finish_reason == 'stop' or finish_reason == 'end_of_stream':
                logger.info(f"Streaming generation stopped naturally in cycle {cycle_number}. Reason: {finish_reason}", extra={'request_id': req_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
                final_finish_reason = finish_reason if finish_reason != 'end_of_stream' else 'stop'
                yield "\n[INFO] Generation finished."
                break
            elif finish_reason == 'length':
                continuations += 1
                logger.warning(f"N_CTX limit reached in streaming cycle {cycle_number}. Attempting continuation {continuations}.", extra={'request_id': req_id, 'cycle': cycle_number, 'continuations': continuations})
                current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=req_id)
                if not current_messages or (len(current_messages) == 1 and current_messages[0].get("role") == "system"):
                    logger.error("Context truncation resulted in empty or system-only messages during streaming. Stopping.", extra={'request_id': req_id, 'cycle': cycle_number})
                    yield "\n[ERROR] Generation failed: Context truncation error."
                    final_finish_reason = "truncation_error"
                    break
                yield f"\n[CONTINUING {continuations} - TRUNCATING CONTEXT...]\n"
                time.sleep(0.05)
                continue
            else:
                logger.warning(f"Streaming generation cycle {cycle_number} ended with unexpected reason '{finish_reason}'. Stopping generation.", extra={'request_id': req_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
                yield f"\n[INFO] Generation stopped: Reason: {finish_reason or 'Unknown'}"
                final_finish_reason = finish_reason or "unknown"
                break
        logger.info(f"Streaming generation stream closed. Total continuations: {continuations}. Final reason: {final_finish_reason}", extra={'request_id': req_id, 'continuations': continuations, 'final_reason': final_finish_reason})

    if is_streaming:
        headers = {
            "Content-Type": "text/event-stream; charset=utf-8",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "X-Request-ID": request_id
        }
        return Response(stream_with_context(streaming_generator(request_id)), headers=headers)
    else:
        while True:
            if MAX_CONTINUATIONS >= 0 and continuations > MAX_CONTINUATIONS:
                logger.info(f"Max continuations ({MAX_CONTINUATIONS}) reached. Stopping non-streaming.", extra={'request_id': request_id})
                if full_generated_text_nonstream:
                    full_generated_text_nonstream += "\n\n"
                full_generated_text_nonstream += f"[INFO: Generation stopped: Max continuations reached ({MAX_CONTINUATIONS}).]"
                final_finish_reason = "max_continuations"
                break
            cycle_number = continuations + 1
            logger.info(f"Starting non-streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': request_id, 'cycle': cycle_number, 'message_count': len(current_messages)})
            chosen_params = random.choice(RANDOM_PARAMS_CHOICES)
            current_params = {**base_params, **chosen_params}
            logger.debug(f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}, stop={current_params['stop']}", extra={'request_id': request_id, 'cycle': cycle_number, 'params': current_params})
            generated_this_cycle_content = ""
            finish_reason = None
            usage_this_cycle = {}
            try:
                result = _generate_single_cycle(current_messages, current_params, stream=False, request_id=request_id)
                if result and "choices" in result and result["choices"]:
                    choice = result["choices"][0]
                    generated_this_cycle_content = choice.get("message", {}).get("content", "")
                    finish_reason = choice.get("finish_reason", "unknown")
                    usage_this_cycle = result.get("usage", {})
                    final_usage = usage_this_cycle
                else:
                    logger.error(f"Invalid response structure from llama_cpp in non-streaming cycle {cycle_number}: {result}", extra={'request_id': request_id, 'cycle': cycle_number, 'result': result})
                    if full_generated_text_nonstream:
                        full_generated_text_nonstream += "\n\n"
                    full_generated_text_nonstream += f"[ERROR: Invalid response structure from model in cycle {cycle_number}.]"
                    final_finish_reason = "internal_error"
                    break
                logger.info(f"Non-streaming cycle {cycle_number} finished. Reason: {finish_reason}. Usage: {usage_this_cycle}", extra={'request_id': request_id, 'cycle': cycle_number, 'usage': usage_this_cycle, 'finish_reason': finish_reason})
            except ContextLimitException:
                logger.warning(f"Context limit caught during non-streaming cycle {cycle_number}.", extra={'request_id': request_id, 'cycle': cycle_number})
                finish_reason = 'length'
            except GenerationFailedException as e:
                logger.error(f"Generation failed in non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id, 'cycle': cycle_number})
                if full_generated_text_nonstream:
                    full_generated_text_nonstream += "\n\n"
                full_generated_text_nonstream += f"[ERROR: Generation failed unexpectedly in cycle {cycle_number}: {e}]"
                final_finish_reason = "error"
                break
            except Exception as e:
                logger.error(f"An unexpected error occurred in non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id, 'cycle': cycle_number})
                if full_generated_text_nonstream:
                    full_generated_text_nonstream += "\n\n"
                full_generated_text_nonstream += f"[ERROR: An unexpected error occurred in cycle {cycle_number}: {str(e)}]"
                final_finish_reason = "error"
                break
            if generated_this_cycle_content:
                if continuations > 0 and full_generated_text_nonstream:
                    full_generated_text_nonstream += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT]\n\n"
                full_generated_text_nonstream += generated_this_cycle_content
                if not current_messages or current_messages[-1].get('role') != 'assistant':
                    current_messages.append({"role": "assistant", "content": generated_this_cycle_content})
                else:
                    current_messages[-1]['content'] += generated_this_cycle_content
                tokens_generated_cycle = usage_this_cycle.get("completion_tokens", 0)
                total_completion_tokens_generated += tokens_generated_cycle
            elif finish_reason == 'length':
                logger.warning(f"Non-streaming N_CTX limit hit in cycle {cycle_number} but no completion tokens were reported in usage.", extra={'request_id': request_id, 'cycle': cycle_number})
                if continuations > 0 and full_generated_text_nonstream:
                    full_generated_text_nonstream += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT - NO OUTPUT THIS CYCLE]\n\n"
            if finish_reason == 'stop':
                logger.info(f"Non-streaming generation stopped naturally (reason: stop) in cycle {cycle_number}.", extra={'request_id': request_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
                final_finish_reason = 'stop'
                break
            elif finish_reason == 'length':
                continuations += 1
                logger.warning(f"Non-streaming N_CTX limit reached in cycle {cycle_number}. Attempting continuation {continuations}.", extra={'request_id': request_id, 'cycle': cycle_number, 'continuations': continuations})
                current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=request_id)
                if not current_messages or (len(current_messages) == 1 and current_messages[0].get("role") == "system"):
                    logger.error("Context truncation resulted in empty or system-only messages during non-streaming. Stopping.", extra={'request_id': request_id, 'cycle': cycle_number})
                    if full_generated_text_nonstream:
                        full_generated_text_nonstream += "\n\n"
                    full_generated_text_nonstream += "[ERROR: Generation failed: Context truncation error.]"
                    final_finish_reason = "truncation_error"
                    break
                continue
            else:
                logger.warning(f"Non-streaming cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.", extra={'request_id': request_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
                if full_generated_text_nonstream:
                    full_generated_text_nonstream += "\n\n"
                full_generated_text_nonstream += f"[INFO: Generation stopped unexpectedly. Reason: {finish_reason or 'Unknown'}]"
                final_finish_reason = finish_reason or "unknown"
                break
        logger.info(f"Non-streaming generation finished after {continuations} continuations. Total completion tokens generated: {total_completion_tokens_generated}. Final reason: {final_finish_reason}", extra={'request_id': request_id, 'continuations': continuations, 'total_completion_tokens': total_completion_tokens_generated, 'final_reason': final_finish_reason})
        response = Response(full_generated_text_nonstream, mimetype="text/plain; charset=utf-8")
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Finish-Reason"] = final_finish_reason
        response.headers["X-Continuations"] = str(continuations)
        response.headers["X-Usage-Completion-Tokens"] = str(total_completion_tokens_generated)
        response.headers["X-Usage-Prompt-Tokens-Last-Cycle"] = str(final_usage.get("prompt_tokens", "N/A"))
        response.headers["X-Usage-Total-Tokens-Last-Cycle"] = str(final_usage.get("total_tokens", "N/A"))
        return response
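
# Minimal client sketch (not part of the original app). It assumes the server is reachable at
# http://localhost:7860 (the default PORT below), uses only the standard library, and is never
# invoked by the server itself; call it manually for a quick non-streaming smoke test.
def _example_generate_request(prompt: str = "Hello", base_url: str = "http://localhost:7860") -> None:
    import urllib.request

    body = json.dumps({"prompt": prompt, "stream": False}).encode("utf-8")
    req = urllib.request.Request(
        f"{base_url}/generate",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        text = resp.read().decode("utf-8")
        # The non-streaming response reports metadata in custom headers (see generate() above).
        print("finish reason:", resp.headers.get("X-Finish-Reason"))
        print("continuations:", resp.headers.get("X-Continuations"))
        print(text)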
if __name__ == "__main__":
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    is_debug = os.getenv("FLASK_DEBUG", "0") == "1"
    log_level = logging.DEBUG if is_debug else logging.INFO
    logger.setLevel(log_level)
    max_cont_desc = MAX_CONTINUATIONS if MAX_CONTINUATIONS >= 0 else 'UNLIMITED'
    logger.info(f"Starting Flask server on {host}:{port} (Debug mode: {is_debug})")
    logger.info(f"Model: {MODEL_REPO}/{MODEL_FILE}, N_CTX={ACTUAL_N_CTX}, Automatic Continuations: {max_cont_desc} (with context truncation)")
    if not llm:
        logger.critical("MODEL FAILED TO LOAD. SERVER WILL START BUT '/generate' WILL FAIL.")
    logger.info("Running with Flask development server.")
    app.run(host=host, port=port, threaded=True, debug=is_debug, use_reloader=False)