kontextflux / app.py
1v1's picture
Update app.py
623802b verified
import os
import json
import time
import uuid
import base64
import hashlib
import asyncio
import websocket
from io import BytesIO
from typing import Dict, List, Optional, Union, Any, Tuple
import requests
import httpx
from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from itertools import cycle
class KontextFluxEncryptor:
    """Python port of the a.js signing routine that yields the 'xtx' header hash."""

    def __init__(self, config_data):
        """Store the four config entries the signing routine reads."""
        self.kis, self.ra1, self.ra2, self.random = (
            config_data[name] for name in ("kis", "ra1", "ra2", "random")
        )

    def _aes_decrypt(self, key, iv, ciphertext_b64):
        """Base64-decode *ciphertext_b64*, AES-CBC decrypt it and return the unpadded text."""
        decryptor = AES.new(key.encode("utf-8"), AES.MODE_CBC, iv.encode("utf-8"))
        raw = base64.b64decode(ciphertext_b64)
        return unpad(decryptor.decrypt(raw), AES.block_size).decode("utf-8")

    def _aes_encrypt(self, key, iv, plaintext):
        """Pad and AES-CBC encrypt *plaintext*, returning the ciphertext as base64 text."""
        encryptor = AES.new(key.encode("utf-8"), AES.MODE_CBC, iv.encode("utf-8"))
        ciphertext = encryptor.encrypt(pad(plaintext.encode("utf-8"), AES.block_size))
        return base64.b64encode(ciphertext).decode("utf-8")

    def get_xtx_hash(self, payload):
        """Return the MD5 hex digest used as the 'xtx' header for *payload*.

        The payload is flattened to ``key=base64(json)`` entries in sorted key
        order, encrypted with a key/IV recovered from the config, and the
        base64 ciphertext is hashed.
        """

        def encode_entry(name):
            # Compact JSON with '<' and '>' removed, then base64-encoded.
            text = json.dumps(payload[name], separators=(",", ":"), ensure_ascii=False)
            text = text.replace("<", "").replace(">", "")
            encoded = base64.b64encode(text.encode("utf-8")).decode("utf-8")
            return f"{name}={encoded}"

        # Entries are concatenated with no separator between them.
        serialized_payload = "".join(encode_entry(name) for name in sorted(payload.keys()))

        # 'kis' decodes to several fragments separated by a fixed marker.
        fragments = base64.b64decode(self.kis).split(b"=sj+Ow2R/v")

        # 'random' is read as a digit string: its first and last digits give
        # the widths of the two fragment indices embedded inside it.
        digits = str(self.random)
        key_width = int(digits[0])
        iv_width = int(digits[-1])
        key_slot = int(digits[2:2 + key_width])
        iv_offset = 4 + key_width
        iv_slot = int(digits[iv_offset:iv_offset + iv_width])

        intermediate_key = fragments[key_slot].decode("utf-8")
        intermediate_iv = fragments[iv_slot].decode("utf-8")

        # The intermediate pair unlocks the real key and IV stored in ra1/ra2.
        main_key = self._aes_decrypt(intermediate_key, intermediate_iv, self.ra1)
        main_iv = self._aes_decrypt(intermediate_key, intermediate_iv, self.ra2)

        encrypted_payload = self._aes_encrypt(main_key, main_iv, serialized_payload)
        return hashlib.md5(encrypted_payload.encode("utf-8")).hexdigest()
def get_config(timeout: float = 30.0):
    """Fetch the KontextFlux client configuration.

    Posts the next token from ``get_token()`` to the getConfig endpoint and
    returns the ``data`` member of the JSON reply.

    Args:
        timeout: Seconds to wait for the HTTP request. Without a timeout
            ``requests`` waits indefinitely on a stalled connection.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If no response arrives within *timeout*.
    """
    url = "https://api.kontextflux.com/client/common/getConfig"
    payload = {"token": get_token(), "referrer": ""}
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Content-Type": "application/json",
        "Origin": "https://kontextflux.com",
        "Referer": "https://kontextflux.com/",
    }
    response = requests.post(
        url,
        data=json.dumps(payload),
        headers=headers,
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()["data"]
async def upload_file(config, image_bytes: bytes, filename: str = "image.png"):
    """Send *image_bytes* to the KontextFlux upload endpoint as a multipart file.

    Returns the ``data`` member of the JSON reply. The request is signed with
    the xtx hash of an empty payload and authorised with ``config["token"]``.
    """
    endpoint = "https://api.kontextflux.com/client/resource/uploadFile"
    # The multipart content type is sent as the literal string "null".
    multipart = [("file", (filename, BytesIO(image_bytes), "null"))]
    signature = KontextFluxEncryptor(config).get_xtx_hash({})
    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Authorization": config["token"],
        "xtx": signature,
    }
    async with httpx.AsyncClient() as http:
        reply = await http.post(endpoint, files=multipart, headers=request_headers)
        reply.raise_for_status()
        body = reply.json()
    return body["data"]
def create_draw_task(
    config,
    prompt: str,
    keys: Optional[List[str]] = None,
    size: str = "auto",
    timeout: float = 30.0,
):
    """Create an image generation task and return its id.

    Args:
        config: Config mapping from ``get_config()``; supplies the token and
            the values used to sign the request.
        prompt: Text prompt sent to the draw endpoint.
        keys: Keys of previously uploaded images. ``None`` means no images;
            a fresh list is created per call instead of sharing one mutable
            default between calls.
        size: Size option forwarded unchanged to the endpoint.
        timeout: Seconds to wait for the HTTP request.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        requests.Timeout: If no response arrives within *timeout*.
    """
    if keys is None:
        keys = []
    url = "https://api.kontextflux.com/client/styleAI/draw"
    payload = {"keys": keys, "prompt": prompt, "size": size}
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Content-Type": "application/json",
        "Authorization": config["token"],
        # The signature is computed over the same payload that is sent.
        "xtx": KontextFluxEncryptor(config).get_xtx_hash(payload),
    }
    response = requests.post(
        url,
        data=json.dumps(payload),
        headers=headers,
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()["data"]["id"]
async def process_image_url(image_url: str) -> Optional[bytes]:
    """Resolve an image reference to raw bytes.

    Accepts ``data:image/...`` URIs (decoded from base64) and http/https
    URLs (downloaded). Returns ``None`` for any other scheme or when
    decoding or downloading fails; failures are printed, not raised.
    """
    try:
        if image_url.startswith('data:image/'):
            # Everything after the first comma is the base64 payload.
            _header, b64_payload = image_url.split(",", 1)
            return base64.b64decode(b64_payload)

        if not image_url.startswith(('http://', 'https://')):
            print(f"Unsupported image URL format: {image_url[:30]}...")
            return None

        # Remote image: download it, following redirects.
        async with httpx.AsyncClient() as http:
            reply = await http.get(image_url, timeout=60, follow_redirects=True)
            reply.raise_for_status()
            return reply.content
    except Exception as e:
        print(f"Error processing image: {e}")
        return None
# Pydantic Models
class ChatMessage(BaseModel):
    """A single chat message: a role plus its content."""

    role: str
    # Plain text, or a list of part dicts (the request handler in this file
    # reads "text" and "image_url" parts).
    content: Union[str, List[Dict[str, Any]]]
    reasoning_content: Optional[str] = Field(default=None)
class ChatRequest(BaseModel):
    """Request body of the chat-completions endpoint."""

    model: str
    messages: List[ChatMessage]
    # Streaming is the default when the client does not say otherwise.
    stream: bool = Field(default=True)
    # KontextFlux size parameter: auto, 2:3, 3:2, 1:1
    size: str = Field(default="auto")
    # Accepted for client compatibility; the handler visible in this file
    # does not read these three fields.
    temperature: Optional[float] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
class ModelInfo(BaseModel):
    """Description of one model as listed by the models endpoint."""

    id: str
    object: str = Field(default="model")
    # Unix timestamp taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = Field(default="kontextflux")
class ModelList(BaseModel):
    """Envelope returned by the models endpoint: a typed list of ModelInfo."""

    object: str = Field(default="list")
    data: List[ModelInfo]
class StreamChoice(BaseModel):
    """One choice entry inside a streamed chunk."""

    # Incremental payload; the stream generator in this file fills it with
    # 'role', 'content' or 'reasoning_content'.
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = Field(default=0)
    finish_reason: Optional[str] = Field(default=None)
class StreamResponse(BaseModel):
    """One server-sent chunk of a streamed chat completion."""

    # A fresh "chatcmpl-<hex>" id unless the caller supplies one.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = Field(default="chat.completion.chunk")
    # Unix timestamp taken at creation unless the caller supplies one.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]
# FastAPI App
app = FastAPI(title="KontextFlux OpenAI API Adapter")
# auto_error=False: a missing Authorization header reaches authenticate_client
# as None instead of being rejected by the security scheme itself.
security = HTTPBearer(auto_error=False)
# Global variables
_TOKENS = os.getenv("TOKENS", "")
# Comma-separated upstream tokens. Surrounding whitespace and empty entries
# (e.g. "a, b" or "a,,b") are dropped so they are never sent as tokens. When
# nothing is configured the list falls back to a single empty token, which is
# what splitting an empty string produced before.
TOKENS = [tok for tok in (part.strip() for part in _TOKENS.split(",")) if tok] or [""]
# Round-robin iterator over the configured tokens.
iterator = cycle(TOKENS)
API_KEY = os.getenv("API_KEY", "linux.do")
VALID_SIZES = ["auto", "2:3", "3:2", "1:1"]  # Supported KontextFlux size options
def get_token():
    """Return the next upstream token from the module-level round-robin iterator."""
    token = next(iterator)
    return token
def parse_prompt_parameters(
    prompt: str,
    valid_sizes: Optional[List[str]] = None,
) -> Tuple[str, str]:
    """
    Extract a size parameter from a user prompt.

    Supports formats like:
    - "A beautiful landscape --size 2:3"
    - "Generate an image --size=1:1 of a cat"
    - "Create art with size:3:2"

    Every occurrence of each pattern is examined, and only the occurrence
    that carries a valid size is removed from the prompt. Ordinary uses of
    the word "size" elsewhere in the prompt are left intact.

    Args:
        prompt: Raw user prompt.
        valid_sizes: Accepted size values; defaults to the module-level
            VALID_SIZES.

    Returns:
        tuple: (cleaned_prompt, size). When no valid size is found the
        prompt is returned unchanged together with "auto".
    """
    import re

    allowed = VALID_SIZES if valid_sizes is None else valid_sizes

    # Patterns are tried in order; the explicit "--size" flag wins.
    size_patterns = (
        r'--size[=\s]+([^\s]+)',   # --size 2:3 or --size=2:3
        r'size[:\s]+([^\s]+)',     # size:2:3 or size 2:3
        r'\bsize[=\s]+([^\s]+)',   # size=2:3 or size 2:3
    )
    for pattern in size_patterns:
        for match in re.finditer(pattern, prompt, re.IGNORECASE):
            candidate = match.group(1).strip()
            if candidate not in allowed:
                # e.g. "font size large": not a parameter, keep looking.
                continue
            # Cut out just this occurrence, then collapse leftover spaces.
            remainder = prompt[:match.start()] + prompt[match.end():]
            cleaned_prompt = re.sub(r'\s+', ' ', remainder).strip()
            return cleaned_prompt, candidate
    return prompt, "auto"
def validate_size_parameter(size: str) -> str:
    """Return *size* unchanged when it is supported; otherwise raise HTTP 400."""
    if size in VALID_SIZES:
        return size
    raise HTTPException(
        status_code=400,
        detail=f"Invalid size '{size}'. Valid options: {', '.join(VALID_SIZES)}",
    )
async def authenticate_client(auth: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    """Authenticate the client from the bearer token in the Authorization header.

    API_KEY may hold one key or several separated by commas. The presented
    credential must equal the whole API_KEY value or one of its
    comma-separated entries. A substring test on the raw string is not used
    because it would let any fragment of a key authenticate.

    Raises:
        HTTPException: 503 when no API key is configured, 401 when the
            header is missing, 403 when the key does not match.
    """
    if not API_KEY:
        raise HTTPException(status_code=503, detail="Service unavailable: Client API keys not configured.")
    if not auth or not auth.credentials:
        raise HTTPException(status_code=401, detail="API key required in Authorization header.")

    # Exact-match candidates: the full value plus each comma-separated entry.
    accepted_keys = [API_KEY]
    accepted_keys.extend(entry.strip() for entry in API_KEY.split(",") if entry.strip())
    if auth.credentials not in accepted_keys:
        raise HTTPException(status_code=403, detail="Invalid client API key.")
@app.on_event("startup")
async def startup():
    """Print the two startup banner lines when the application starts."""
    for line in (
        "Starting KontextFlux OpenAI API Adapter server...",
        "Server initialization completed.",
    ):
        print(line)
@app.get("/v1/models", response_model=ModelList)
async def list_models(_: None = Depends(authenticate_client)):
    """Return the list of models this adapter exposes (a single entry)."""
    available = [ModelInfo(id="kontext-flux")]
    return ModelList(data=available)
def _collect_prompt_and_images(messages) -> Tuple[str, List[str]]:
    """Gather, in message order, the joined text prompt and every image URL."""
    prompt_parts = []
    image_urls = []
    for message in messages:
        if isinstance(message.content, str):
            prompt_parts.append(message.content)
            continue
        if not isinstance(message.content, list):
            continue
        for part in message.content:
            part_type = part.get("type")
            if part_type == "text" and part.get("text"):
                prompt_parts.append(part["text"])
            elif part_type == "image_url":
                # Accept {"url": ...} as well as a bare URL string; anything
                # else is ignored instead of raising AttributeError.
                image_ref = part.get("image_url")
                url = image_ref.get("url") if isinstance(image_ref, dict) else image_ref
                if isinstance(url, str) and url:
                    image_urls.append(url)
    return " ".join(filter(None, prompt_parts)), image_urls


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest, _: None = Depends(authenticate_client)):
    """Create a chat completion using the KontextFlux backend.

    Text from all messages forms the prompt, and referenced images are
    uploaded first. The result is streamed as server-sent events when
    ``request.stream`` is true, otherwise returned as a single completion
    object whose content is a markdown image link.

    Raises:
        HTTPException: 404 for an unknown model, 400 for an empty or invalid
            request, 500 when the backend interaction fails.
    """
    if request.model != "kontext-flux":
        raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found.")
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided in the request.")

    # Extract prompt and images from messages
    raw_prompt, image_urls = _collect_prompt_and_images(request.messages)
    if not raw_prompt and not image_urls:
        raise HTTPException(status_code=400, detail="Request must contain text prompt or at least one image.")

    # Parse parameters from prompt and validate
    prompt, parsed_size = parse_prompt_parameters(raw_prompt)
    # Use parsed size if found, otherwise use request parameter, fallback to default
    final_size = parsed_size if parsed_size != "auto" else request.size
    final_size = validate_size_parameter(final_size)

    loop = asyncio.get_running_loop()
    try:
        # get_config and create_draw_task use the blocking `requests` client;
        # run them in the default executor so the event loop keeps serving
        # other requests meanwhile.
        config = await loop.run_in_executor(None, get_config)

        # Process and upload images; ones that cannot be fetched are skipped.
        uploaded_keys = []
        for i, image_url in enumerate(image_urls):
            image_bytes = await process_image_url(image_url)
            if image_bytes:
                upload_result = await upload_file(config, image_bytes, f"image_{i}.png")
                uploaded_keys.append(upload_result["key"])

        # Create draw task with dynamic size parameter
        draw_id = await loop.run_in_executor(
            None, create_draw_task, config, prompt, uploaded_keys, final_size
        )

        if request.stream:
            return StreamingResponse(
                kontextflux_stream_generator(config, draw_id, request.model),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )

        # Non-streaming response (wait for completion)
        final_url = await wait_for_completion(config, draw_id)
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": f"![image]({final_url})"
                },
                "index": 0,
                "finish_reason": "stop"
            }],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        }
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500.
        raise
    except Exception as e:
        print(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e
def _format_progress(progress) -> str:
    """Render a progress percentage as an emoji, a 20-cell bar and the number."""
    # Pick an emoji that suits the current stage of generation.
    stages = (
        (20, "🚀"),   # starting
        (40, "⚙️"),   # processing
        (60, "✨"),   # about halfway
        (80, "🔍"),   # generating details
        (100, "🎨"),  # final polish
    )
    emoji = next((icon for limit, icon in stages if progress < limit), "✅")  # done
    # Build the progress bar.
    bar_length = 20
    filled_length = int(bar_length * progress / 100)
    bar = "█" * filled_length + "░" * (bar_length - filled_length)
    # Format the progress line shown to the user.
    return f"{emoji} 图像生成进度 |{bar}| {progress}%\n"


async def kontextflux_stream_generator(config, draw_id: str, model: str):
    """Yield server-sent events that follow a draw task until it finishes.

    Emits an initial role delta, one ``reasoning_content`` delta per progress
    message, and finally the image as markdown followed by a stop chunk and
    ``[DONE]``. On any failure an error event and ``[DONE]`` are emitted so
    the client always sees the stream terminate.

    The websocket client is blocking, so connect/send/recv run in the default
    executor rather than on the event loop. The socket is closed in
    ``finally``, which also covers the client disconnecting mid-stream.
    """
    stream_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())

    def sse_chunk(delta, finish_reason=None) -> str:
        # One "data:" event carrying a single-choice chunk.
        chunk = StreamResponse(
            id=stream_id,
            created=created_time,
            model=model,
            choices=[StreamChoice(delta=delta, finish_reason=finish_reason)],
        )
        return f"data: {chunk.model_dump_json()}\n\n"

    # Send initial role delta
    yield sse_chunk({'role': 'assistant'})

    loop = asyncio.get_running_loop()
    ws = None
    try:
        # Setup WebSocket connection
        request_body = {"token": config["token"], "id": draw_id}
        xtx = KontextFluxEncryptor(config).get_xtx_hash(request_body)
        url = f"wss://api.kontextflux.com/client/styleAI/checkWs?xtx={xtx}"
        ws = await loop.run_in_executor(None, websocket.create_connection, url)
        await loop.run_in_executor(None, ws.send, json.dumps(request_body))

        while True:
            try:
                msg = await loop.run_in_executor(None, ws.recv)
            except websocket.WebSocketTimeoutException:
                continue
            content = json.loads(msg)["content"]
            photo = content.get("photo")
            if photo:
                # Generation completed: send the image, then the stop signal.
                final_url = photo["url"]
                yield sse_chunk({'content': f'![image]({final_url})'})
                yield sse_chunk({}, finish_reason='stop')
                yield "data: [DONE]\n\n"
                break
            # Send progress update as reasoning content; messages without a
            # numeric progress value are skipped.
            progress = content.get("progress")
            if isinstance(progress, (int, float)):
                yield sse_chunk({'reasoning_content': _format_progress(progress)})
    except Exception as exc:
        print(f"Stream processing error: {exc}")
        yield f"data: {json.dumps({'error': str(exc)})}\n\n"
        yield "data: [DONE]\n\n"
    finally:
        if ws is not None:
            ws.close()
async def wait_for_completion(config, draw_id: str) -> str:
    """Wait for image generation to finish and return the image URL (non-streaming).

    Opens the status websocket for *draw_id* and reads messages until one
    carries a non-empty ``content.photo``. The websocket client is blocking,
    so connect/send/recv run in the default executor instead of on the event
    loop. The socket is closed on every exit path, including a failed send.

    Raises:
        Exception: Whatever the websocket client or JSON decoding raises when
            the connection fails or a message is malformed.
    """
    request_body = {"token": config["token"], "id": draw_id}
    xtx = KontextFluxEncryptor(config).get_xtx_hash(request_body)
    url = f"wss://api.kontextflux.com/client/styleAI/checkWs?xtx={xtx}"

    loop = asyncio.get_running_loop()
    ws = await loop.run_in_executor(None, websocket.create_connection, url)
    try:
        await loop.run_in_executor(None, ws.send, json.dumps(request_body))
        while True:
            # recv already waits for the next message, so no extra sleep is
            # needed between reads.
            msg = await loop.run_in_executor(None, ws.recv)
            photo = json.loads(msg)["content"]["photo"]
            if photo:
                return photo["url"]
    finally:
        ws.close()
if __name__ == "__main__":
    import uvicorn

    # Startup banner: endpoints and the number of configured tokens.
    banner_lines = (
        "\n--- KontextFlux OpenAI API Adapter ---",
        "Endpoints:",
        " GET /v1/models",
        " POST /v1/chat/completions",
        f"TOKENS: {len(TOKENS)}",
        "------------------------------------",
    )
    for banner_line in banner_lines:
        print(banner_line)

    # Serve on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)