#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import json
import logging
import uuid
from typing import List, Dict, Union, Optional, Generator, Any
import random
import time
from flask import Flask, request, Response, stream_with_context, jsonify, g, render_template_string
from llama_cpp import Llama
class JsonFormatter(logging.Formatter):
    """Format log records as single-line JSON objects.

    Every record carries timestamp, level, logger name, rendered message, source
    path and line number. ``request_id``, formatted exception text and stack info
    are added when present. Any other attribute attached through ``extra=`` is
    copied through, stringified when it is not JSON-serialisable.
    """

    # Built-in LogRecord attributes that must never be echoed as "extra" fields.
    # 'msg'/'args' (raw template + arguments, already rendered into "message") and
    # 'taskName' (Python 3.12+) used to leak into every log line.
    _RESERVED_KEYS = frozenset({
        'message', 'asctime', 'levelname', 'levelno', 'pathname', 'filename',
        'module', 'funcName', 'lineno', 'created', 'msecs', 'relativeCreated',
        'thread', 'threadName', 'process', 'processName', 'exc_info', 'exc_text',
        'stack_info', 'request_id', 'msg', 'args', 'name', 'taskName',
    })

    def format(self, record):
        log_record = {
            "timestamp": self.formatTime(record, self.datefmt),
            "level": record.levelname,
            "name": record.name,
            "message": record.getMessage(),
            "pathname": record.pathname,
            "lineno": record.lineno,
        }
        if hasattr(record, 'request_id'):
            log_record['request_id'] = record.request_id
        if record.exc_info:
            log_record['exception'] = self.formatException(record.exc_info)
        if record.stack_info:
            log_record['stack_info'] = self.formatStack(record.stack_info)
        for key, value in record.__dict__.items():
            if key.startswith('_') or key in log_record or key in self._RESERVED_KEYS:
                continue
            try:
                json.dumps(value)
            except TypeError:
                value = str(value)  # e.g. type objects, arbitrary instances
            except Exception:
                value = "[Unserializable Value]"  # e.g. circular references (ValueError)
            log_record[key] = value
        return json.dumps(log_record)
def setup_logging():
    """Configure the root logger for JSON output and quiet chatty libraries.

    A JSON-formatted StreamHandler is attached only if the root logger has no
    handlers yet. The root level is set to INFO, werkzeug to ERROR and
    llama_cpp to WARNING.

    Returns:
        The root logger.
    """
    root_logger = logging.getLogger()
    if not root_logger.handlers:
        json_handler = logging.StreamHandler()
        json_handler.setFormatter(JsonFormatter())
        root_logger.addHandler(json_handler)
    root_logger.setLevel(logging.INFO)
    for noisy_logger, level in (("werkzeug", logging.ERROR), ("llama_cpp", logging.WARNING)):
        logging.getLogger(noisy_logger).setLevel(level)
    return root_logger


logger = setup_logging()
# --- Model / runtime configuration, overridable through environment variables ---
MODEL_REPO = os.getenv("MODEL_REPO", "jnjj/vcvcvcv")
MODEL_FILE = os.getenv("MODEL_FILE", "gemma-3-4b-it-q4_0.gguf")
N_CTX_CONFIG = int(os.getenv("N_CTX", "2048"))
N_BATCH = int(os.getenv("N_BATCH", "512"))
N_GPU_LAYERS_CONFIG = int(os.getenv("N_GPU_LAYERS", "0"))
# A negative value means "no limit" on automatic continuations.
MAX_CONTINUATIONS = int(os.getenv("MAX_CONTINUATIONS", "-1"))
FIXED_REPEAT_PENALTY = float(os.getenv("FIXED_REPEAT_PENALTY", "1.1"))
FIXED_SEED = int(os.getenv("FIXED_SEED", "-1"))
# Spanish: "You are a concise, direct and helpful assistant."
DEFAULT_SYSTEM_PROMPT = os.getenv("DEFAULT_SYSTEM_PROMPT", "Eres un asistente conciso, directo y útil.")
# Fraction of n_ctx the prompt may occupy before older messages are truncated.
CONTEXT_TRUNCATION_BUFFER_RATIO = float(os.getenv("CONTEXT_TRUNCATION_BUFFER_RATIO", "0.85"))
# Sampling presets; one is picked at random for every generation cycle.
RANDOM_PARAMS_CHOICES = [
    {"temperature": 0.2, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.1, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.3, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.4, "top_p": 0.5, "top_k": 10},
    {"temperature": 0.6, "top_p": 0.3, "top_k": 5},
    {"temperature": 0.5, "top_p": 0.7, "top_k": 20},
]

# Runtime state filled in by load_model(); the ACTUAL_* values start at the configured ones.
llm: Optional["Llama"] = None
ACTUAL_N_CTX: int = N_CTX_CONFIG
ACTUAL_N_BATCH: int = N_BATCH
ACTUAL_N_GPU_LAYERS: int = N_GPU_LAYERS_CONFIG
class ContextLimitException(Exception):
    """A llama.cpp call failed because the context window / KV cache is exhausted.

    Raised by the generation-cycle wrapper so callers can truncate and continue.
    """


class GenerationFailedException(Exception):
    """A llama.cpp call failed for any reason other than a full context window."""
def prepare_messages(data: Dict, format: Optional[str] = None, request_id: str = 'N/A') -> List[Dict[str, str]]:
    """Build the chat message list sent to the model from a /generate request body.

    ``data["messages"]`` (a list of ``{"role", "content"}`` dicts) is used when it
    is non-empty; otherwise ``data["prompt"]`` becomes a single user message.
    The system message comes from ``data["system_prompt"]`` if given, otherwise
    DEFAULT_SYSTEM_PROMPT. An explicit ``None`` means "no system prompt". A
    system message at index 0 of ``messages`` replaces it, and later system
    messages are ignored. With ``format == "markdown"`` a Markdown instruction is
    appended to whichever system message ends up being used.

    Args:
        data: Parsed JSON request body.
        format: Optional output format; only "markdown" is recognised, anything
            else is logged and ignored.
        request_id: Correlation id attached to log records.

    Returns:
        The normalised list of ``{"role", "content"}`` dicts.

    Raises:
        ValueError: on missing or ill-typed inputs, or when no user/assistant
            message remains.
    """
    messages_list = data.get("messages")
    prompt_str = data.get("prompt")
    system_instruction = data.get("system_prompt", DEFAULT_SYSTEM_PROMPT)
    if not messages_list and not prompt_str:
        raise ValueError("Either 'messages' list or 'prompt' string is required.")
    if messages_list is not None and not isinstance(messages_list, list):
        raise ValueError("'messages' must be a list of dictionaries.")
    if prompt_str is not None and not isinstance(prompt_str, str):
        raise ValueError("'prompt' must be a string.")
    if system_instruction is not None and not isinstance(system_instruction, str):
        raise ValueError("'system_prompt' must be a string.")
    if system_instruction is None:
        # The check above deliberately allows null; treat it as "no system prompt"
        # instead of crashing on .strip().
        system_instruction = ""

    format_instruction = ""
    if format == "markdown":
        format_instruction = "Format your response using Markdown."
    elif format is not None:
        logger.warning(f"Unsupported format '{format}' requested.", extra={'request_id': request_id, 'format': format})

    def _with_format(text: str) -> str:
        # Join with a single space; plain concatenation of stripped parts glued the sentences together.
        return " ".join(part for part in (text.strip(), format_instruction) if part)

    final_messages: List[Dict[str, str]] = []
    effective_system_prompt_content = _with_format(system_instruction)
    if effective_system_prompt_content:
        final_messages.append({"role": "system", "content": effective_system_prompt_content})

    if messages_list:
        has_user_message = False
        for i, msg in enumerate(messages_list):
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                raise ValueError(f"Message at index {i} is invalid: must be a dictionary with 'role' and 'content'.")
            role = msg.get("role")
            content = msg.get("content", "")
            if content is None:
                content = ""  # avoid sending the literal text "None" to the model
            elif not isinstance(content, str):
                logger.warning(f"Message content at index {i} (role: {role}) is not a string (type: {type(content)}). Converting to string.", extra={'request_id': request_id, 'message_index': i, 'role': role, 'content_type': type(content)})
                content = str(content)
            if role == "system":
                if i == 0 and final_messages and final_messages[0]["role"] == "system":
                    logger.info("Replacing default system prompt with user-provided system message.", extra={'request_id': request_id})
                    # Keep the requested format instruction when the client supplies its own system prompt.
                    final_messages[0]["content"] = _with_format(content)
                elif i == 0 and not final_messages:
                    final_messages.append({"role": "system", "content": _with_format(content)})
                else:
                    logger.warning(f"Ignoring additional system message at index {i} as system prompt is already set or should be at the start.", extra={'request_id': request_id, 'message_index': i})
                continue
            if role == "user":
                has_user_message = True
            final_messages.append({"role": role, "content": content})
        if not has_user_message and any(m["role"] != "system" for m in final_messages):
            logger.warning("The 'messages' list contains no user messages.", extra={'request_id': request_id})
    elif prompt_str:
        final_messages.append({"role": "user", "content": prompt_str})

    if not final_messages or all(m["role"] == "system" for m in final_messages):
        raise ValueError("No user or assistant messages found to generate a response.")
    return final_messages
def estimate_token_count(messages: List[Dict[str, str]], request_id: str = 'N/A') -> int:
    """Estimate how many prompt tokens *messages* will occupy.

    Strategies, from most to least accurate:

    1. Render with ``llm.apply_chat_template`` (only if the loaded object
       provides it) and tokenize the result.
    2. Tokenize each message's content with ``llm.tokenize`` and add a fixed
       per-message allowance for role/turn markers. This allowance is a
       heuristic, not derived from the model's real chat template.
    3. Roughly 4 characters per token.

    Returns:
        The estimated token count, or -1 when no model/tokenizer is available.
    """
    if not llm or not hasattr(llm, 'tokenize'):
        logger.warning("LLM or tokenizer/template function not available for token estimation.", extra={'request_id': request_id})
        return -1
    if hasattr(llm, 'apply_chat_template'):
        try:
            chat_prompt_string = llm.apply_chat_template(messages, add_generation_prompt=True)
            tokens = llm.tokenize(chat_prompt_string.encode('utf-8', errors='ignore'), add_bos=True)
            return len(tokens)
        except Exception as e:
            logger.error(f"Could not estimate token count using apply_chat_template: {e}", exc_info=True, extra={'request_id': request_id})
    # NOTE(review): llama_cpp.Llama does not appear to expose apply_chat_template. Previously
    # this function then returned -1 on every call, which made context truncation discard the
    # whole history. Tokenizing each message directly keeps the estimate usable.
    per_message_overhead = 8  # heuristic allowance for role/turn marker tokens
    try:
        total_tokens = 1  # BOS token
        for message in messages:
            content = str(message.get('content', ''))
            total_tokens += len(llm.tokenize(content.encode('utf-8', errors='ignore'), add_bos=False)) + per_message_overhead
        return total_tokens + per_message_overhead  # allowance for the assistant generation prompt
    except Exception as e:
        logger.warning(f"Per-message token estimation failed: {e}", exc_info=True, extra={'request_id': request_id})
    char_count = sum(len(str(m.get('content', ''))) for m in messages)
    estimated_tokens = char_count // 4
    logger.warning(f"Falling back to character-based token estimation (~{estimated_tokens})", extra={'request_id': request_id, 'estimated_tokens': estimated_tokens, 'char_count': char_count})
    return estimated_tokens
def get_effective_n_ctx() -> int:
    """Return the context window size in effect.

    This is the value load_model() read back from the model, or N_CTX_CONFIG if
    the model has not been loaded.
    """
    effective_context_size = ACTUAL_N_CTX
    return effective_context_size
def truncate_messages_for_context(messages: List[Dict[str, str]], max_tokens: int, buffer_ratio: float, request_id: str = 'N/A') -> List[Dict[str, str]]:
    """Trim a conversation so its estimated size fits ``max_tokens * buffer_ratio``.

    The leading system message (if any) is always kept. The remaining messages
    are added newest-first while estimate_token_count stays within the target,
    and adding stops at the first message that does not fit or cannot be
    estimated.

    When the conversation contains a user message, the kept window is trimmed so
    it starts with a user turn. Chat templates that enforce user/assistant
    alternation (Gemma's, for instance) reject a window that begins with an
    assistant message. If nothing beyond the system prompt survives, the last
    user message is kept even when it exceeds the target, so callers must
    re-check the size.

    Returns:
        ``messages`` unchanged when no model is loaded, ``[]`` if the result
        would be empty, otherwise the truncated list (original dict objects,
        not copies).
    """
    if not llm:
        return messages
    target_token_limit = int(max_tokens * buffer_ratio)
    system_prompt: Optional[Dict[str, str]] = None
    if messages and messages[0].get("role") == "system":
        system_prompt = messages[0]
        remaining_messages = messages[1:]
    else:
        remaining_messages = messages
    head: List[Dict[str, str]] = [system_prompt] if system_prompt else []

    current_token_count = estimate_token_count(head, request_id=request_id) if head else 0
    if current_token_count == -1:
        logger.warning("Could not estimate initial token count for truncation, proceeding cautiously with char estimate.", extra={'request_id': request_id})
        current_token_count = sum(len(m.get('content', '')) for m in head) // 4

    # Walk backwards from the newest message, keeping chronological order in messages_to_add.
    messages_to_add: List[Dict[str, str]] = []
    for msg in reversed(remaining_messages):
        next_token_count = estimate_token_count(head + [msg] + messages_to_add, request_id=request_id)
        if next_token_count == -1:
            logger.warning(f"Token estimation failed while adding message: {msg}. Stopping truncation early.", extra={'request_id': request_id})
            break
        if next_token_count > target_token_limit:
            logger.debug(f"Stopping truncation: Adding next message would exceed target limit ({next_token_count} > {target_token_limit}).", extra={'request_id': request_id})
            break
        messages_to_add.insert(0, msg)
        current_token_count = next_token_count

    has_user_message = any(m.get("role") == "user" for m in messages)
    if has_user_message:
        # Start the kept window at its first user turn; without one, the window is cleared and
        # the last-user-message fallback below takes over.
        first_user_idx = next((i for i, m in enumerate(messages_to_add) if m.get("role") == "user"), len(messages_to_add))
        if first_user_idx:
            logger.debug(f"Dropping {first_user_idx} leading non-user message(s) so the kept window starts with a user turn.", extra={'request_id': request_id})
            messages_to_add = messages_to_add[first_user_idx:]
            if messages_to_add:
                current_token_count = estimate_token_count(head + messages_to_add, request_id=request_id)

    final_truncated_list = head + messages_to_add
    original_count = len(messages)
    if has_user_message and all(m.get("role") == "system" for m in final_truncated_list):
        last_user_message = next(m for m in reversed(messages) if m.get("role") == "user")
        logger.warning("Truncation resulted in empty or system-only messages, attempting to keep last user message.", extra={'request_id': request_id})
        final_truncated_list = head + [last_user_message]
        current_token_count = estimate_token_count(final_truncated_list, request_id=request_id)
    final_count = len(final_truncated_list)

    if final_count < original_count:
        logger.warning(f"Context truncated: Kept {final_count}/{original_count} messages. Estimated tokens: ~{current_token_count}/{target_token_limit} (target).",
                       extra={'request_id': request_id, 'kept': final_count, 'original': original_count, 'estimated_tokens': current_token_count, 'target_limit': target_token_limit})
    else:
        logger.debug(f"Context truncation check complete. Kept all {final_count} messages. Estimated tokens: ~{current_token_count}.",
                     extra={'request_id': request_id, 'kept': final_count, 'estimated_tokens': current_token_count})
    if not final_truncated_list:
        logger.error("Context truncation resulted in an empty message list!", extra={'request_id': request_id})
        return []
    return final_truncated_list
def get_property_or_method_value(obj: Any, prop_name: str, default: Any = None) -> Any:
    """Read ``obj.<prop_name>``, calling it first when it is callable.

    Args:
        obj: Object to inspect (may be None).
        prop_name: Attribute name to read.
        default: Returned when the attribute is missing or calling it raises.

    Returns:
        The attribute value, the result of calling it, or *default*. A failing
        call is logged as a warning rather than propagated.
    """
    if not hasattr(obj, prop_name):
        return default
    attribute = getattr(obj, prop_name)
    if not callable(attribute):
        return attribute
    try:
        return attribute()
    except Exception:
        logger.warning(f"Error calling method {prop_name} on {type(obj)}", exc_info=True)
        return default
def load_model():
    """Load MODEL_REPO/MODEL_FILE via ``Llama.from_pretrained`` and record its effective settings.

    Sets the module globals ``llm``, ``ACTUAL_N_CTX``, ``ACTUAL_N_BATCH`` and
    ``ACTUAL_N_GPU_LAYERS``. Any failure is logged and leaves ``llm`` as None,
    so request handlers report the model as unavailable instead of the process
    crashing.
    """
    global llm, ACTUAL_N_CTX, ACTUAL_N_BATCH, ACTUAL_N_GPU_LAYERS
    logger.info(f"Attempting to load model: {MODEL_REPO}/{MODEL_FILE}")
    logger.info(f"Configuration: N_CTX={N_CTX_CONFIG}, N_BATCH={N_BATCH}, N_GPU_LAYERS={N_GPU_LAYERS_CONFIG}")
    try:
        llm = Llama.from_pretrained(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            n_ctx=N_CTX_CONFIG,
            n_batch=N_BATCH,
            n_gpu_layers=N_GPU_LAYERS_CONFIG,
            verbose=False,
            use_mmap=True,
            use_mlock=True,
        )
        logger.info("Model loaded successfully.")
        if llm:
            ACTUAL_N_CTX = get_property_or_method_value(llm, 'n_ctx', N_CTX_CONFIG)
            ACTUAL_N_BATCH = get_property_or_method_value(llm, 'n_batch', N_BATCH)
            # NOTE(review): llama_cpp.Llama does not appear to expose an `n_gpu_layers` attribute
            # (the value lives in its model params). Defaulting to 0 made the mismatch warning
            # below fire on every GPU deployment, so fall back to the configured value instead.
            ACTUAL_N_GPU_LAYERS = get_property_or_method_value(llm, 'n_gpu_layers', N_GPU_LAYERS_CONFIG)
            if ACTUAL_N_CTX != N_CTX_CONFIG:
                logger.warning(f"Model's actual context size ({ACTUAL_N_CTX}) differs from config ({N_CTX_CONFIG}). Using actual.", extra={'actual_n_ctx': ACTUAL_N_CTX, 'configured_n_ctx': N_CTX_CONFIG})
            if ACTUAL_N_GPU_LAYERS != N_GPU_LAYERS_CONFIG:
                logger.warning(f"Model loaded with {ACTUAL_N_GPU_LAYERS} GPU layers despite requesting {N_GPU_LAYERS_CONFIG}. Check llama.cpp build or environment.", extra={'actual_gpu_layers': ACTUAL_N_GPU_LAYERS, 'configured_gpu_layers': N_GPU_LAYERS_CONFIG})
            logger.info(f"Actual Model Context Window (n_ctx): {ACTUAL_N_CTX}")
            logger.info(f"Actual Model Batch Size (n_batch): {ACTUAL_N_BATCH}")
            logger.info(f"Actual Model GPU Layers (n_gpu_layers): {ACTUAL_N_GPU_LAYERS}")
            try:
                test_tokens = llm.tokenize(b"Test sentence.")
                logger.info(f"Tokenizer test successful. 'Test sentence.' -> {len(test_tokens)} tokens.")
            except Exception as tokenize_e:
                logger.warning(f"Could not perform test tokenization: {tokenize_e}")
    except Exception as e:
        logger.error(f"Fatal error loading model: {e}", exc_info=True)
        llm = None
        logger.error("Model failed to load. Generation requests will not work.", extra={'error': str(e)})
app = Flask(__name__)


@app.before_request
def before_request_func():
    """Assign a fresh UUID to each incoming request (stored on ``flask.g``) and log it at DEBUG."""
    rid = str(uuid.uuid4())
    g.request_id = rid
    logger.debug(
        f"Incoming request: {request.method} {request.path} from {request.remote_addr}",
        extra={'request_id': rid, 'path': request.path, 'method': request.method},
    )


# The model is loaded eagerly, at import time, before the server starts handling requests.
load_model()
# Jinja template rendered by index(); {{ ACTUAL_N_CTX }} is substituted at render time.
# NOTE(review): the template contains no HTML markup at all (the tags look stripped), and the
# "(unlimited continuations)" wording is hard-coded even though MAX_CONTINUATIONS is
# configurable -- confirm both are intended.
html_code = """
LLM API Demo
LLM API Demonstration
Health Check
API Info
Generate Text (Automatic Continuation with Context Management)
Note: No artificial token limit. Generation continues until the model stops naturally, hits a stop sequence, or reaches the context window limit (N_CTX={{ ACTUAL_N_CTX }}). If the context limit is reached, the server will attempt to continue automatically by truncating older messages (unlimited continuations). Other parameters (Temperature, Top P, Top K, Repeat Penalty, Seed) are fixed/random per generation cycle.
"""
@app.route("/")
def index():
    """Serve the demo landing page, filling in the live context-window size."""
    template_vars = {
        "ACTUAL_N_CTX": ACTUAL_N_CTX,
        "DEFAULT_SYSTEM_PROMPT": DEFAULT_SYSTEM_PROMPT,
    }
    return render_template_string(html_code, **template_vars)
@app.route("/health", methods=["GET"])
def health_check():
    """Report model readiness.

    Returns:
        503 with status "error" when no model is loaded. Otherwise 200, with
        status "ok" when the model exposes both ``tokenize`` and
        ``apply_chat_template``, else status "warning".
    """
    if not llm:
        return jsonify(status="error", message="Model failed to load or is not available."), 503
    fully_capable = hasattr(llm, 'tokenize') and hasattr(llm, 'apply_chat_template')
    if fully_capable:
        return jsonify(status="ok", message="Model is loaded and ready."), 200
    logger.warning("Model loaded, but tokenizer or chat template functions might be missing.", extra={'request_id': getattr(g, 'request_id', 'N/A')})
    return jsonify(status="warning", message="Model loaded, but critical functions (tokenize/apply_chat_template) might be missing."), 200
@app.route("/info", methods=["GET"])
def model_info():
    """Describe the loaded model and the server's generation policy as JSON.

    Returns:
        503 when no model is loaded. Otherwise 200 with the configured and
        actual model settings, plus the fixed, random and client-controllable
        generation parameters.
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.warning("Info request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model not available."), 503

    model_details: Union[Dict[str, Any], str] = "Model details unavailable"
    try:
        inner_model = get_property_or_method_value(llm, '_model')
        model_details = {
            "n_embd": get_property_or_method_value(inner_model, 'n_embd', 'N/A'),
            "n_ctx": ACTUAL_N_CTX,
            "n_batch": ACTUAL_N_BATCH,
            "n_gpu_layers": ACTUAL_N_GPU_LAYERS,
            "tokenizer_present": hasattr(llm, 'tokenize'),
            "chat_handler_present": hasattr(llm, 'apply_chat_template') and hasattr(llm, 'create_chat_completion'),
        }
    except Exception as e:
        logger.warning(f"Could not retrieve all model details: {e}", extra={'request_id': request_id}, exc_info=True)
        model_details = f"Error retrieving some model details: {e}"

    # A negative MAX_CONTINUATIONS means unlimited: reported as None / 'unlimited'.
    continuation_cap = MAX_CONTINUATIONS if MAX_CONTINUATIONS >= 0 else None
    cap_text = 'unlimited' if continuation_cap is None else continuation_cap
    generation_parameters = {
        "note": (
            f"No artificial 'max_tokens' limit. Generation proceeds until stop sequence, EOS, "
            f"or context limit (N_CTX={ACTUAL_N_CTX}). Automatic continuation attempts by "
            f"truncating context occur up to {cap_text} times if context limit is reached. "
            "Sampling parameters (temperature, top_p, top_k) are chosen randomly per "
            "request/continuation cycle from predefined sets. Repeat penalty and seed are fixed."
        ),
        "fixed_max_tokens": None,
        "fixed_repeat_penalty": FIXED_REPEAT_PENALTY,
        "fixed_seed": FIXED_SEED,
        "max_automatic_continuations": continuation_cap,
        "context_truncation_buffer_ratio": CONTEXT_TRUNCATION_BUFFER_RATIO,
        "randomly_chosen_from": RANDOM_PARAMS_CHOICES,
        "default_system_prompt": DEFAULT_SYSTEM_PROMPT,
        "user_controllable": ["messages", "prompt", "stop", "stream", "format", "system_prompt"],
    }
    model_config = {
        "repo_id": MODEL_REPO,
        "filename": MODEL_FILE,
        "initial_load_config": {
            "n_ctx": N_CTX_CONFIG,
            "n_batch": N_BATCH,
            "n_gpu_layers": N_GPU_LAYERS_CONFIG,
        },
        "loaded_model_details": model_details,
    }
    payload = {
        "status": "ok",
        "message": "Model is loaded. Generation continues automatically with context truncation if context limit is hit.",
        "model_config": model_config,
        "generation_parameters": generation_parameters,
    }
    return jsonify(payload), 200
def _classify_llama_error(e: Exception, request_id: str) -> Exception:
    """Log a raw llama.cpp error and return the matching ContextLimitException or GenerationFailedException."""
    err_str = str(e).lower()
    condition = getattr(e, 'condition', None)
    condition_hit = isinstance(condition, str) and (
        "context length" in condition.lower() or "failed to decode" in condition.lower()
    )
    if ("context window is full" in err_str
            or "kv cache is full" in err_str
            or "llama_decode" in err_str
            or condition_hit):
        logger.warning(f"Caught N_CTX limit or related exception: {e}", extra={'request_id': request_id})
        return ContextLimitException(str(e))
    logger.error(f"Unhandled error during llama.cpp call: {e}", exc_info=True, extra={'request_id': request_id})
    return GenerationFailedException(f"Unhandled llama.cpp error: {str(e)}")


def _guard_stream(chunks, request_id: str) -> Generator[Dict, None, None]:
    """Re-yield streaming chunks, translating errors raised mid-iteration via _classify_llama_error."""
    try:
        yield from chunks
    except Exception as e:
        raise _classify_llama_error(e, request_id) from e


def _generate_single_cycle(messages: List[Dict[str, str]], params: Dict, stream: bool, request_id: str) -> Union[Generator[Dict, None, None], Dict]:
    """Run one llama.cpp chat completion with the given sampling *params*.

    Args:
        messages: Chat messages to send.
        params: Must contain max_tokens, temperature, top_p, top_k,
            repeat_penalty, stop and seed.
        stream: When True, return an iterator of completion chunks.
        request_id: Correlation id for logs.

    Returns:
        The completion dict, or (when streaming) an iterator of chunk dicts.

    Raises:
        ContextLimitException: the context window / KV cache is full.
        GenerationFailedException: any other llama.cpp failure.
        When streaming, either exception may also be raised while the
        returned iterator is being consumed.
    """
    try:
        logger.debug(f"Starting llama.cpp chat completion call. Stream: {stream}. Messages: {len(messages)}. Params summary: temp={params.get('temperature')}, top_p={params.get('top_p')}, top_k={params.get('top_k')}, stop={params.get('stop')}", extra={'request_id': request_id, 'stream': stream, 'message_count': len(messages)})
        result = llm.create_chat_completion(
            messages=messages,
            max_tokens=params["max_tokens"],
            temperature=params["temperature"],
            top_p=params["top_p"],
            top_k=params["top_k"],
            repeat_penalty=params["repeat_penalty"],
            stop=params["stop"],
            seed=params["seed"],
            stream=stream,
        )
    except Exception as e:
        raise _classify_llama_error(e, request_id) from e
    if stream:
        # A streamed completion is consumed lazily: prompt evaluation and decoding (and therefore
        # "context full" errors) happen while the caller iterates, outside the try above. Wrap
        # the iterator so those errors are classified the same way and can trigger continuation.
        return _guard_stream(result, request_id)
    return result
def _append_note(text: str, note: str) -> str:
    """Append *note* to *text*, separated by a blank line when *text* is non-empty."""
    return f"{text}\n\n{note}" if text else note


def _append_assistant_turn(messages: List[Dict[str, str]], content: str) -> None:
    """Record generated *content* in *messages*, extending a trailing assistant message if present."""
    if messages and messages[-1].get('role') == 'assistant':
        messages[-1]['content'] += content
    else:
        messages.append({"role": "assistant", "content": content})


@app.route("/generate", methods=["POST"])
def generate():
    """Generate a chat completion, continuing automatically when the context window fills up.

    The request body must be a JSON object with ``messages`` or ``prompt``, and
    optionally ``system_prompt``, ``format``, ``stop`` and ``stream`` (default
    True).

    When a cycle ends with finish_reason 'length' (or a ContextLimitException),
    the conversation is truncated with truncate_messages_for_context and another
    cycle runs. This repeats up to MAX_CONTINUATIONS times (unlimited when
    negative).

    Returns:
        Streaming: a chunked text response with inline [INFO]/[ERROR] markers.
        Non-streaming: a text/plain body with X-Finish-Reason, X-Continuations
        and usage headers.
        Errors: 503 (model not loaded), 415 (not JSON), 400 (invalid input or
        input too large), 500 (unexpected preparation error).
    """
    request_id = getattr(g, 'request_id', 'N/A')
    if not llm:
        logger.error("Generate request received but model is not loaded.", extra={'request_id': request_id})
        return jsonify(error="Model is not available.", detail="The LLM model could not be loaded."), 503
    if not request.is_json:
        logger.warning("Request received without Content-Type: application/json", extra={'request_id': request_id})
        return jsonify(error="Invalid request header", detail="Content-Type must be application/json"), 415
    data = request.get_json()
    # An array/string/number body has no .get(); reject it with 400 instead of crashing with a 500.
    if not isinstance(data, dict):
        logger.warning("Request JSON body is not an object.", extra={'request_id': request_id, 'body_type': type(data).__name__})
        return jsonify(error="Invalid input", detail="Request body must be a JSON object."), 400
    is_streaming = data.get("stream", True)
    response_format = data.get("format")
    raw_messages = data.get("messages")
    log_data_summary = {k: v for k, v in data.items() if k not in ('messages', 'prompt')}
    # len() only a real list: a malformed 'messages' (null, number, ...) must reach
    # prepare_messages' validation and yield a 400, not a TypeError here.
    log_data_summary['messages_count_initial'] = len(raw_messages) if isinstance(raw_messages, list) else 0
    log_data_summary['has_prompt_initial'] = 'prompt' in data
    log_data_summary['stream'] = is_streaming
    log_data_summary['format'] = response_format
    logger.info("Received generation request summary.", extra={'request_id': request_id, 'summary': log_data_summary})

    try:
        initial_messages = prepare_messages(data, format=response_format, request_id=request_id)
        base_params: Dict[str, Any] = {
            "max_tokens": None,  # no artificial cap; see the /info note
            "repeat_penalty": FIXED_REPEAT_PENALTY,
            "seed": FIXED_SEED,
        }
        stop = data.get("stop")
        if stop is None:
            base_params["stop"] = None
        elif isinstance(stop, str):
            base_params["stop"] = [stop]
        elif isinstance(stop, list) and all(isinstance(s, str) for s in stop):
            base_params["stop"] = stop
        else:
            raise ValueError({"stop": "Stop must be a string or a list of strings"})

        effective_n_ctx = get_effective_n_ctx()
        safe_ctx_target = int(effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO)
        input_token_count = estimate_token_count(initial_messages, request_id=request_id)
        if input_token_count != -1 and input_token_count > effective_n_ctx * CONTEXT_TRUNCATION_BUFFER_RATIO:
            logger.warning(f"Initial input (~{input_token_count} tokens) likely exceeds safe context window ({safe_ctx_target}). Attempting truncation.", extra={'request_id': request_id, 'initial_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'buffer_ratio': CONTEXT_TRUNCATION_BUFFER_RATIO})
            truncated_initial = truncate_messages_for_context(initial_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=request_id)
            truncated_token_count = estimate_token_count(truncated_initial, request_id=request_id)
            # A system-only (or empty) result leaves nothing to answer; all([]) covers the empty case.
            only_system_left = all(m.get("role") == "system" for m in truncated_initial)
            if only_system_left or (truncated_token_count != -1 and truncated_token_count > effective_n_ctx):
                error_msg = f"Input exceeds context window ({effective_n_ctx}) even after attempting truncation. Input tokens (~{input_token_count}) / Truncated tokens (~{truncated_token_count}). Reduce initial message size."
                logger.error(error_msg, extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
                return jsonify(error="Input exceeds context window", detail=error_msg), 400
            logger.info(f"Initial input truncated from ~{input_token_count} to ~{truncated_token_count} tokens.", extra={'request_id': request_id, 'initial_tokens': input_token_count, 'truncated_tokens': truncated_token_count, 'n_ctx': effective_n_ctx})
            initial_messages = truncated_initial
            input_token_count = truncated_token_count
        elif input_token_count != -1:
            logger.info(f"Initial input token count: ~{input_token_count}. Effective context window: {effective_n_ctx}. Context buffer target: {safe_ctx_target}. Remaining: {effective_n_ctx - input_token_count}.", extra={'request_id': request_id, 'input_tokens': input_token_count, 'n_ctx': effective_n_ctx, 'buffer_target': safe_ctx_target, 'remaining_ctx': effective_n_ctx - input_token_count})
        else:
            logger.warning("Could not estimate initial token count. Proceeding, may hit context limit.", extra={'request_id': request_id})
    except ValueError as e:
        logger.error(f"Invalid input data or parameters: {e}", exc_info=True, extra={'request_id': request_id})
        # ValueErrors raised with a dict payload (e.g. the 'stop' check) are returned as structured
        # JSON; json.loads(str(e)) could never parse a dict's Python repr.
        payload = e.args[0] if len(e.args) == 1 else None
        error_detail = payload if isinstance(payload, dict) else str(e)
        return jsonify(error="Invalid input", detail=error_detail), 400
    except Exception as e:
        logger.error(f"Unexpected error preparing request: {e}", exc_info=True, extra={'request_id': request_id})
        return jsonify(error="Internal server error", detail="An unexpected error occurred preparing the request."), 500

    current_messages = list(initial_messages)
    continuations = 0
    total_completion_tokens_generated = 0
    final_finish_reason = "unknown"
    final_usage: Dict[str, Any] = {}
    effective_n_ctx = get_effective_n_ctx()

    def streaming_generator(req_id):
        """Yield generated text across continuation cycles, with inline [INFO]/[ERROR] markers."""
        nonlocal current_messages, continuations, total_completion_tokens_generated, final_finish_reason, final_usage
        while True:
            if MAX_CONTINUATIONS >= 0 and continuations > MAX_CONTINUATIONS:
                logger.info(f"Max continuations ({MAX_CONTINUATIONS}) reached. Stopping streaming.", extra={'request_id': req_id})
                yield f"\n[INFO] Generation stopped: Max continuations reached ({MAX_CONTINUATIONS})."
                final_finish_reason = "max_continuations"
                break
            cycle_number = continuations + 1
            logger.info(f"Starting streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': req_id, 'cycle': cycle_number, 'message_count': len(current_messages)})
            current_params = {**base_params, **random.choice(RANDOM_PARAMS_CHOICES)}
            generated_this_cycle_content = ""
            finish_reason = None
            usage_this_cycle: Dict[str, Any] = {}
            try:
                streamer = _generate_single_cycle(current_messages, current_params, stream=True, request_id=req_id)
                for chunk in streamer:
                    # Tolerate an empty/None "choices" list or delta instead of raising mid-stream.
                    choice = (chunk.get("choices") or [{}])[0]
                    delta = choice.get("delta") or {}
                    token_content = delta.get("content")
                    chunk_finish_reason = choice.get("finish_reason")
                    if token_content:
                        generated_this_cycle_content += token_content
                        yield token_content
                    if chunk_finish_reason:
                        finish_reason = chunk_finish_reason
                        usage_this_cycle = chunk.get("usage") or {}
                        final_usage = usage_this_cycle
                        break
                if not finish_reason and generated_this_cycle_content:
                    finish_reason = "end_of_stream"
                    logger.warning(f"Streaming cycle {cycle_number} ended without explicit finish reason.", extra={'request_id': req_id, 'cycle': cycle_number})
            except ContextLimitException:
                logger.warning(f"Context limit caught during streaming cycle {cycle_number}.", extra={'request_id': req_id, 'cycle': cycle_number})
                finish_reason = 'length'
                yield f"\n[INFO] Context limit approached in cycle {cycle_number}. Attempting continuation...\n"
            except GenerationFailedException as e:
                logger.error(f"Generation failed in streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': req_id, 'cycle': cycle_number})
                yield f"\n[ERROR] Generation failed unexpectedly in cycle {cycle_number}: {e}"
                final_finish_reason = "error"
                break
            except Exception as e:
                logger.error(f"An unexpected error occurred in streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': req_id, 'cycle': cycle_number})
                yield f"\n[ERROR] An unexpected error occurred in cycle {cycle_number}: {str(e)}"
                final_finish_reason = "error"
                break

            if generated_this_cycle_content:
                _append_assistant_turn(current_messages, generated_this_cycle_content)
                total_completion_tokens_generated += usage_this_cycle.get("completion_tokens", 0)

            if finish_reason in ('stop', 'end_of_stream'):
                logger.info(f"Streaming generation stopped naturally in cycle {cycle_number}. Reason: {finish_reason}", extra={'request_id': req_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
                final_finish_reason = 'stop'
                yield "\n[INFO] Generation finished."
                break
            if finish_reason == 'length':
                continuations += 1
                logger.warning(f"N_CTX limit reached in streaming cycle {cycle_number}. Attempting continuation {continuations}.", extra={'request_id': req_id, 'cycle': cycle_number, 'continuations': continuations})
                current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=req_id)
                if not current_messages or (len(current_messages) == 1 and current_messages[0].get("role") == "system"):
                    logger.error("Context truncation resulted in empty or system-only messages during streaming. Stopping.", extra={'request_id': req_id, 'cycle': cycle_number})
                    yield "\n[ERROR] Generation failed: Context truncation error."
                    final_finish_reason = "truncation_error"
                    break
                yield f"\n[CONTINUING {continuations} - TRUNCATING CONTEXT...]\n"
                time.sleep(0.05)
                continue
            logger.warning(f"Streaming generation cycle {cycle_number} ended with unexpected reason '{finish_reason}'. Stopping generation.", extra={'request_id': req_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
            yield f"\n[INFO] Generation stopped: Reason: {finish_reason or 'Unknown'}"
            final_finish_reason = finish_reason or "unknown"
            break
        logger.info(f"Streaming generation stream closed. Total continuations: {continuations}. Final reason: {final_finish_reason}", extra={'request_id': req_id, 'continuations': continuations, 'final_reason': final_finish_reason})

    if is_streaming:
        # NOTE(review): the body is raw text, not SSE-framed ("data: ...\n\n"), despite the
        # text/event-stream content type; EventSource clients will not parse it. Kept for compatibility.
        headers = {
            "Content-Type": "text/event-stream; charset=utf-8",
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            "X-Request-ID": request_id,
        }
        return Response(stream_with_context(streaming_generator(request_id)), headers=headers)

    full_generated_text_nonstream = ""
    while True:
        if MAX_CONTINUATIONS >= 0 and continuations > MAX_CONTINUATIONS:
            logger.info(f"Max continuations ({MAX_CONTINUATIONS}) reached. Stopping non-streaming.", extra={'request_id': request_id})
            full_generated_text_nonstream = _append_note(full_generated_text_nonstream, f"[INFO: Generation stopped: Max continuations reached ({MAX_CONTINUATIONS}).]")
            final_finish_reason = "max_continuations"
            break
        cycle_number = continuations + 1
        logger.info(f"Starting non-streaming generation cycle {cycle_number}. Message count: {len(current_messages)}.", extra={'request_id': request_id, 'cycle': cycle_number, 'message_count': len(current_messages)})
        current_params = {**base_params, **random.choice(RANDOM_PARAMS_CHOICES)}
        logger.debug(f"Cycle {cycle_number} params: temp={current_params['temperature']}, top_p={current_params['top_p']}, top_k={current_params['top_k']}, stop={current_params['stop']}", extra={'request_id': request_id, 'cycle': cycle_number, 'params': current_params})
        generated_this_cycle_content = ""
        finish_reason = None
        usage_this_cycle: Dict[str, Any] = {}
        try:
            result = _generate_single_cycle(current_messages, current_params, stream=False, request_id=request_id)
            if result and "choices" in result and result["choices"]:
                choice = result["choices"][0]
                generated_this_cycle_content = (choice.get("message") or {}).get("content") or ""
                finish_reason = choice.get("finish_reason", "unknown")
                usage_this_cycle = result.get("usage") or {}
                final_usage = usage_this_cycle
            else:
                logger.error(f"Invalid response structure from llama_cpp in non-streaming cycle {cycle_number}: {result}", extra={'request_id': request_id, 'cycle': cycle_number, 'result': result})
                full_generated_text_nonstream = _append_note(full_generated_text_nonstream, f"[ERROR: Invalid response structure from model in cycle {cycle_number}.]")
                final_finish_reason = "internal_error"
                break
            logger.info(f"Non-streaming cycle {cycle_number} finished. Reason: {finish_reason}. Usage: {usage_this_cycle}", extra={'request_id': request_id, 'cycle': cycle_number, 'usage': usage_this_cycle, 'finish_reason': finish_reason})
        except ContextLimitException:
            logger.warning(f"Context limit caught during non-streaming cycle {cycle_number}.", extra={'request_id': request_id, 'cycle': cycle_number})
            finish_reason = 'length'
        except GenerationFailedException as e:
            logger.error(f"Generation failed in non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id, 'cycle': cycle_number})
            full_generated_text_nonstream = _append_note(full_generated_text_nonstream, f"[ERROR: Generation failed unexpectedly in cycle {cycle_number}: {e}]")
            final_finish_reason = "error"
            break
        except Exception as e:
            logger.error(f"An unexpected error occurred in non-streaming cycle {cycle_number}: {e}", exc_info=True, extra={'request_id': request_id, 'cycle': cycle_number})
            full_generated_text_nonstream = _append_note(full_generated_text_nonstream, f"[ERROR: An unexpected error occurred in cycle {cycle_number}: {str(e)}]")
            final_finish_reason = "error"
            break

        if generated_this_cycle_content:
            if continuations > 0 and full_generated_text_nonstream:
                full_generated_text_nonstream += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT]\n\n"
            full_generated_text_nonstream += generated_this_cycle_content
            _append_assistant_turn(current_messages, generated_this_cycle_content)
            total_completion_tokens_generated += usage_this_cycle.get("completion_tokens", 0)
        elif finish_reason == 'length':
            logger.warning(f"Non-streaming N_CTX limit hit in cycle {cycle_number} but no completion tokens reported in usage.", extra={'request_id': request_id, 'cycle': cycle_number})
            if continuations > 0 and full_generated_text_nonstream:
                full_generated_text_nonstream += f"\n\n[CONTINUATION {continuations} - TRUNCATED CONTEXT - NO OUTPUT THIS CYCLE]\n\n"

        if finish_reason == 'stop':
            logger.info(f"Non-streaming generation stopped naturally (reason: stop) in cycle {cycle_number}.", extra={'request_id': request_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
            final_finish_reason = 'stop'
            break
        if finish_reason == 'length':
            continuations += 1
            logger.warning(f"Non-streaming N_CTX limit reached in cycle {cycle_number}. Attempting continuation {continuations}.", extra={'request_id': request_id, 'cycle': cycle_number, 'continuations': continuations})
            current_messages = truncate_messages_for_context(current_messages, effective_n_ctx, CONTEXT_TRUNCATION_BUFFER_RATIO, request_id=request_id)
            if not current_messages or (len(current_messages) == 1 and current_messages[0].get("role") == "system"):
                logger.error("Context truncation resulted in empty or system-only messages during non-streaming. Stopping.", extra={'request_id': request_id, 'cycle': cycle_number})
                full_generated_text_nonstream = _append_note(full_generated_text_nonstream, "[ERROR: Generation failed: Context truncation error.]")
                final_finish_reason = "truncation_error"
                break
            continue
        logger.warning(f"Non-streaming cycle {cycle_number} ended with reason '{finish_reason}' or unexpectedly. Stopping generation.", extra={'request_id': request_id, 'cycle': cycle_number, 'finish_reason': finish_reason})
        full_generated_text_nonstream = _append_note(full_generated_text_nonstream, f"[INFO: Generation stopped unexpectedly. Reason: {finish_reason or 'Unknown'}]")
        final_finish_reason = finish_reason or "unknown"
        break

    logger.info(f"Non-streaming generation finished after {continuations} continuations. Total completion tokens generated: {total_completion_tokens_generated}. Final reason: {final_finish_reason}", extra={'request_id': request_id, 'continuations': continuations, 'total_completion_tokens': total_completion_tokens_generated, 'final_reason': final_finish_reason})
    response = Response(full_generated_text_nonstream, mimetype="text/plain; charset=utf-8")
    response.headers["X-Request-ID"] = request_id
    response.headers["X-Finish-Reason"] = final_finish_reason
    response.headers["X-Continuations"] = str(continuations)
    response.headers["X-Usage-Completion-Tokens"] = str(total_completion_tokens_generated)
    response.headers["X-Usage-Prompt-Tokens-Last-Cycle"] = str(final_usage.get("prompt_tokens", "N/A"))
    response.headers["X-Usage-Total-Tokens-Last-Cycle"] = str(final_usage.get("total_tokens", "N/A"))
    return response
if __name__ == "__main__":
    # Server address and debug mode come from the environment.
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    debug_enabled = os.getenv("FLASK_DEBUG", "0") == "1"
    logger.setLevel(logging.DEBUG if debug_enabled else logging.INFO)

    continuation_desc = 'UNLIMITED' if MAX_CONTINUATIONS < 0 else MAX_CONTINUATIONS
    logger.info(f"Starting Flask server on {host}:{port} (Debug mode: {debug_enabled})")
    logger.info(f"Model: {MODEL_REPO}/{MODEL_FILE}, N_CTX={ACTUAL_N_CTX}, Automatic Continuations: {continuation_desc} (with context truncation)")
    if not llm:
        logger.critical("MODEL FAILED TO LOAD. SERVER WILL START BUT '/generate' WILL FAIL.")
    logger.info("Running with Flask development server.")
    # use_reloader=False: presumably so the module (and its import-time model load) is not
    # executed a second time by the reloader process.
    app.run(host=host, port=port, threaded=True, debug=debug_enabled, use_reloader=False)