# KontextFlux -> OpenAI-compatible chat-completions adapter (Hugging Face Space).
import os | |
import json | |
import time | |
import uuid | |
import base64 | |
import hashlib | |
import asyncio | |
import websocket | |
from io import BytesIO | |
from typing import Dict, List, Optional, Union, Any, Tuple | |
import requests | |
import httpx | |
from fastapi import FastAPI, HTTPException, Depends | |
from fastapi.responses import StreamingResponse | |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
from pydantic import BaseModel, Field | |
from Crypto.Cipher import AES | |
from Crypto.Util.Padding import pad, unpad | |
from itertools import cycle | |
class KontextFluxEncryptor:
    """Replicates the encryption logic from a.js to generate the 'xtx' header hash."""

    def __init__(self, config_data: dict):
        # Key material delivered by the getConfig endpoint:
        #   kis    - base64 blob holding delimiter-joined key/IV fragments
        #   ra1    - AES-encrypted main key (decrypted in get_xtx_hash)
        #   ra2    - AES-encrypted main IV
        #   random - number whose digits select which 'kis' fragments to use
        self.kis = config_data["kis"]
        self.ra1 = config_data["ra1"]
        self.ra2 = config_data["ra2"]
        self.random = config_data["random"]

    def _aes_decrypt(self, key: str, iv: str, ciphertext_b64: str) -> str:
        """Decrypts AES-CBC base64 encoded data."""
        cipher = AES.new(key.encode("utf-8"), AES.MODE_CBC, iv.encode("utf-8"))
        decoded_ciphertext = base64.b64decode(ciphertext_b64)
        decrypted_padded = cipher.decrypt(decoded_ciphertext)
        # PKCS#7 padding is stripped before decoding back to text.
        return unpad(decrypted_padded, AES.block_size).decode("utf-8")

    def _aes_encrypt(self, key: str, iv: str, plaintext: str) -> str:
        """Encrypts plaintext with AES-CBC and returns a base64 encoded string."""
        cipher = AES.new(key.encode("utf-8"), AES.MODE_CBC, iv.encode("utf-8"))
        padded_data = pad(plaintext.encode("utf-8"), AES.block_size)
        encrypted_data = cipher.encrypt(padded_data)
        return base64.b64encode(encrypted_data).decode("utf-8")

    def get_xtx_hash(self, payload: dict) -> str:
        """Generates the final MD5 hash for the 'xtx' header.

        Serializes the payload deterministically (sorted keys, compact JSON,
        '<'/'>' stripped, values base64-encoded), AES-encrypts the result with
        a key/IV recovered from the config material, and returns the MD5 hex
        digest of the ciphertext.
        """
        sorted_keys = sorted(payload.keys())
        serialized_parts = []
        for key in sorted_keys:
            value = payload[key]
            stringified_value = json.dumps(value, separators=(",", ":"), ensure_ascii=False)
            # Angle brackets are removed to mirror the upstream JS sanitization.
            safe_value = stringified_value.replace("<", "").replace(">", "")
            encoded_value = base64.b64encode(safe_value.encode("utf-8")).decode("utf-8")
            serialized_parts.append(f"{key}={encoded_value}")
        serialized_payload = "".join(serialized_parts)
        # 'kis' decodes to several fragments separated by a fixed sentinel.
        decoded_kis = base64.b64decode(self.kis).split(b"=sj+Ow2R/v")
        # Digits of 'random' index into the fragment list (replicates a.js):
        # first digit = width of the key index, last digit = width of the IV
        # index; the indices themselves are sliced out of the middle digits.
        random_str = str(self.random)
        y = int(random_str[0])
        b = int(random_str[-1])
        k = int(random_str[2 : 2 + y])
        s_idx = int(random_str[4 + y : 4 + y + b])
        intermediate_key = decoded_kis[k].decode("utf-8")
        intermediate_iv = decoded_kis[s_idx].decode("utf-8")
        # The intermediate pair unlocks the real key/IV shipped in ra1/ra2.
        main_key = self._aes_decrypt(intermediate_key, intermediate_iv, self.ra1)
        main_iv = self._aes_decrypt(intermediate_key, intermediate_iv, self.ra2)
        encrypted_payload = self._aes_encrypt(main_key, main_iv, serialized_payload)
        final_hash = hashlib.md5(encrypted_payload.encode("utf-8")).hexdigest()
        return final_hash
def get_config():
    """Fetch a fresh KontextFlux client configuration.

    Posts the next rotating backend token to the getConfig endpoint and
    returns the 'data' object (session token plus the key material consumed
    by KontextFluxEncryptor).
    """
    endpoint = "https://api.kontextflux.com/client/common/getConfig"
    body = {"token": get_token(), "referrer": ""}
    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Content-Type": "application/json",
        "Origin": "https://kontextflux.com",
        "Referer": "https://kontextflux.com/",
    }
    resp = requests.post(endpoint, data=json.dumps(body), headers=request_headers)
    resp.raise_for_status()
    return resp.json()["data"]
async def upload_file(config, image_bytes: bytes, filename: str = "image.png"):
    """Upload an image to KontextFlux and return the upload metadata dict."""
    endpoint = "https://api.kontextflux.com/client/resource/uploadFile"
    # Multipart body; the backend expects the literal content type "null".
    form_files = [("file", (filename, BytesIO(image_bytes), "null"))]
    # Even an empty payload must be signed into the 'xtx' header.
    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Authorization": config["token"],
        "xtx": KontextFluxEncryptor(config).get_xtx_hash({}),
    }
    async with httpx.AsyncClient() as client:
        resp = await client.post(endpoint, files=form_files, headers=request_headers)
    resp.raise_for_status()
    return resp.json()["data"]
def create_draw_task(config, prompt: str, keys: Optional[List[str]] = None, size: str = "auto"):
    """Create an image generation (draw) task.

    Args:
        config: Config dict from get_config() (provides token + key material).
        prompt: Text prompt for the image.
        keys: Upload keys of reference images; defaults to no references.
        size: KontextFlux size option ("auto", "2:3", "3:2", "1:1").

    Returns:
        The backend-assigned draw task id.
    """
    # BUG FIX: the default was a mutable `[]` shared across calls; use None
    # as the sentinel (backward compatible — callers passing a list are
    # unaffected, and the default still serializes as an empty list).
    if keys is None:
        keys = []
    url = "https://api.kontextflux.com/client/styleAI/draw"
    payload = {"keys": keys, "prompt": prompt, "size": size}
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/137.0.0.0 Safari/537.36 Edg/137.0.0.0",
        "Accept-Encoding": "gzip, deflate, br, zstd",
        "Content-Type": "application/json",
        "Authorization": config["token"],
        # The signed hash must cover exactly the JSON payload being sent.
        "xtx": KontextFluxEncryptor(config).get_xtx_hash(payload),
    }
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    response.raise_for_status()
    return response.json()["data"]["id"]
async def process_image_url(image_url: str) -> Optional[bytes]:
    """Resolve an image reference to raw bytes.

    Accepts either an inline base64 data URI (``data:image/...``) or an
    http/https URL. Returns None (after logging) for unsupported schemes or
    on any decode/download failure.
    """
    try:
        if image_url.startswith('data:image/'):
            # Inline payload: drop the "data:image/...;base64," prefix.
            _, encoded = image_url.split(",", 1)
            return base64.b64decode(encoded)
        if image_url.startswith(('http://', 'https://')):
            # Remote image: fetch with redirects allowed and a 60s budget.
            async with httpx.AsyncClient() as client:
                resp = await client.get(image_url, timeout=60, follow_redirects=True)
                resp.raise_for_status()
                return resp.content
        print(f"Unsupported image URL format: {image_url[:30]}...")
        return None
    except Exception as e:
        print(f"Error processing image: {e}")
        return None
# Pydantic Models | |
class ChatMessage(BaseModel):
    """One OpenAI-style chat message."""
    role: str
    # Either a plain string, or the multimodal part list
    # ([{"type": "text"|"image_url", ...}]) consumed by chat_completions.
    content: Union[str, List[Dict[str, Any]]]
    # Optional reasoning channel; mirrored by streaming progress deltas.
    reasoning_content: Optional[str] = None
class ChatRequest(BaseModel):
    """OpenAI-compatible /v1/chat/completions request body."""
    model: str
    messages: List[ChatMessage]
    stream: bool = True
    size: str = "auto"  # KontextFlux size parameter: auto, 2:3, 3:2, 1:1
    # Accepted for OpenAI-client compatibility but not forwarded to the backend.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    top_p: Optional[float] = None
class ModelInfo(BaseModel):
    """One entry in the OpenAI-style /v1/models listing."""
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "kontextflux"
class ModelList(BaseModel):
    """OpenAI-style /v1/models response envelope."""
    object: str = "list"
    data: List[ModelInfo]
class StreamChoice(BaseModel):
    """One choice inside a chat.completion.chunk SSE event."""
    # Incremental delta: carries 'role', 'content', or 'reasoning_content'.
    delta: Dict[str, Any] = Field(default_factory=dict)
    index: int = 0
    finish_reason: Optional[str] = None
class StreamResponse(BaseModel):
    """OpenAI-style chat.completion.chunk payload for SSE streaming."""
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[StreamChoice]
# FastAPI App
app = FastAPI(title="KontextFlux OpenAI API Adapter")
# auto_error=False so a missing header reaches our own 401 in authenticate_client.
security = HTTPBearer(auto_error=False)
# Global variables
# Comma-separated KontextFlux session tokens, rotated round-robin per request.
# NOTE(review): an unset TOKENS env var yields [""] (a single empty token),
# which get_token() will happily hand out — confirm intended behavior.
_TOKENS = os.getenv("TOKENS", "")
TOKENS = _TOKENS.split(",")
iterator = cycle(TOKENS)
# Client-facing API key checked by authenticate_client.
API_KEY = os.getenv("API_KEY", "linux.do")
VALID_SIZES = ["auto", "2:3", "3:2", "1:1"]  # Supported KontextFlux size options
def get_token():
    """Return the next backend token in round-robin order."""
    return next(iterator)
def parse_prompt_parameters(prompt: str) -> Tuple[str, str]:
    """
    Parse parameters from user prompt and return cleaned prompt with extracted size.
    Supports formats like:
    - "A beautiful landscape --size 2:3"
    - "Generate an image --size=1:1 of a cat"
    - "Create art with size:3:2"
    Returns:
        tuple: (cleaned_prompt, size)
    """
    import re
    # Default values
    size = "auto"
    cleaned_prompt = prompt
    # Pattern to match size parameters in various formats.
    # NOTE(review): the 2nd/3rd patterns also match a bare "size 2:3" inside
    # normal prose — confirm that is acceptable.
    size_patterns = [
        r'--size[=\s]+([^\s]+)',  # --size 2:3 or --size=2:3
        r'size[:\s]+([^\s]+)',    # size:2:3 or size 2:3
        r'\bsize[=\s]+([^\s]+)',  # size=2:3 or size 2:3
    ]
    for pattern in size_patterns:
        match = re.search(pattern, prompt, re.IGNORECASE)
        if match:
            extracted_size = match.group(1).strip()
            # Only whitelisted sizes are accepted; an invalid capture falls
            # through to the next (looser) pattern.
            if extracted_size in VALID_SIZES:
                size = extracted_size
                # Remove the parameter from prompt
                cleaned_prompt = re.sub(pattern, '', prompt, flags=re.IGNORECASE).strip()
                # Clean up extra spaces
                cleaned_prompt = re.sub(r'\s+', ' ', cleaned_prompt).strip()
                break
    return cleaned_prompt, size
def validate_size_parameter(size: str) -> str:
    """Return ``size`` unchanged, or raise HTTP 400 if it is not supported."""
    if size in VALID_SIZES:
        return size
    raise HTTPException(
        status_code=400,
        detail=f"Invalid size '{size}'. Valid options: {', '.join(VALID_SIZES)}"
    )
async def authenticate_client(auth: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    """Authenticate client based on API key in Authorization header.

    Raises:
        HTTPException: 503 if no key is configured, 401 if the header is
        missing, 403 if the presented key does not match.
    """
    if not API_KEY:
        raise HTTPException(status_code=503, detail="Service unavailable: Client API keys not configured.")
    if not auth or not auth.credentials:
        raise HTTPException(status_code=401, detail="API key required in Authorization header.")
    # BUG FIX: the original `auth.credentials not in API_KEY` was a substring
    # test against the key string, so any substring of the configured key
    # (even a single character) was accepted. Require an exact match.
    if auth.credentials != API_KEY:
        raise HTTPException(status_code=403, detail="Invalid client API key.")
@app.on_event("startup")
async def startup():
    """Application startup initialization.

    Registered via on_event so it actually runs at boot; previously this
    coroutine was defined but never invoked by anything.
    """
    print("Starting KontextFlux OpenAI API Adapter server...")
    print("Server initialization completed.")
@app.get("/v1/models")
async def list_models(_: None = Depends(authenticate_client)):
    """List available models (OpenAI-compatible /v1/models).

    FIX: the route is advertised in __main__ but was never registered with
    the FastAPI app; the decorator wires it up.
    """
    return ModelList(data=[ModelInfo(id="kontext-flux")])
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest, _: None = Depends(authenticate_client)):
    """Create chat completion using KontextFlux backend.

    Extracts the text prompt and any image attachments from the OpenAI-style
    message list, uploads the images, starts a draw task, and returns either
    a streaming SSE response with progress updates or a blocking completion
    containing the finished image as a markdown link.

    FIX: the route decorator was missing (endpoint advertised in __main__ but
    never registered with the app).
    """
    if request.model != "kontext-flux":
        raise HTTPException(status_code=404, detail=f"Model '{request.model}' not found.")
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided in the request.")
    # Extract prompt text and image URLs from all messages.
    prompt_parts = []
    image_urls = []
    for message in request.messages:
        if isinstance(message.content, str):
            prompt_parts.append(message.content)
        elif isinstance(message.content, list):
            for part in message.content:
                if part.get("type") == "text" and part.get("text"):
                    prompt_parts.append(part["text"])
                elif part.get("type") == "image_url" and part.get("image_url", {}).get("url"):
                    image_urls.append(part["image_url"]["url"])
    raw_prompt = " ".join(filter(None, prompt_parts))
    if not raw_prompt and not image_urls:
        raise HTTPException(status_code=400, detail="Request must contain text prompt or at least one image.")
    # Inline "--size" style parameters take precedence over the request field.
    prompt, parsed_size = parse_prompt_parameters(raw_prompt)
    final_size = parsed_size if parsed_size != "auto" else request.size
    final_size = validate_size_parameter(final_size)
    try:
        # Get kontextflux config (keys + session token)
        config = get_config()
        # Process and upload images; references that fail to resolve are
        # skipped (process_image_url returns None on failure).
        uploaded_keys = []
        for i, image_url in enumerate(image_urls):
            image_bytes = await process_image_url(image_url)
            if image_bytes:
                upload_result = await upload_file(config, image_bytes, f"image_{i}.png")
                uploaded_keys.append(upload_result["key"])
        # Create draw task with dynamic size parameter
        draw_id = create_draw_task(config, prompt, uploaded_keys, final_size)
        if request.stream:
            return StreamingResponse(
                kontextflux_stream_generator(config, draw_id, request.model),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )
        else:
            # Non-streaming response (wait for completion)
            final_url = await wait_for_completion(config, draw_id)
            return {
                "id": f"chatcmpl-{uuid.uuid4().hex}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [{
                    "message": {
                        "role": "assistant",
                        # BUG FIX: content was an empty f-string and final_url
                        # was computed but never used — return the image link.
                        "content": f"![Generated Image]({final_url})"
                    },
                    "index": 0,
                    "finish_reason": "stop"
                }],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
            }
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors as 500s.
        raise
    except Exception as e:
        print(f"Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
async def kontextflux_stream_generator(config, draw_id: str, model: str):
    """Yield OpenAI-style SSE chunks while the draw task runs.

    Progress updates are surfaced as `reasoning_content` deltas; the finished
    image is emitted as a markdown link, followed by a stop chunk and the
    [DONE] sentinel.
    """
    stream_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_time = int(time.time())
    # Send initial role delta
    yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={'role': 'assistant'})]).model_dump_json()}\n\n"
    try:
        # Open the signed progress WebSocket; the request payload is also
        # hashed into the `xtx` query parameter.
        e = {"token": config["token"], "id": draw_id}
        xtx = KontextFluxEncryptor(config).get_xtx_hash(e)
        url = f"wss://api.kontextflux.com/client/styleAI/checkWs?xtx={xtx}"
        ws = websocket.create_connection(url)
        try:
            ws.send(json.dumps(e))
            while True:
                # NOTE(review): ws.recv() blocks the event loop between
                # messages; the sleep only paces polling — confirm acceptable.
                await asyncio.sleep(0.1)  # Small delay to prevent blocking
                try:
                    msg = ws.recv()
                    data = json.loads(msg)
                    if data["content"]["photo"]:
                        # Generation completed
                        final_url = data["content"]["photo"]["url"]
                        # BUG FIX: the content delta was an empty f-string and
                        # final_url was unused — emit the markdown image link.
                        yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={'content': f'![Generated Image]({final_url})'})]).model_dump_json()}\n\n"
                        # Send completion signal
                        yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={}, finish_reason='stop')]).model_dump_json()}\n\n"
                        yield "data: [DONE]\n\n"
                        break
                    else:
                        # Send progress update as reasoning content.
                        progress = data["content"]["progress"]
                        # Pick an emoji matching the current phase.
                        if progress < 20:
                            emoji = "🚀"  # starting
                        elif progress < 40:
                            emoji = "⚙️"  # processing
                        elif progress < 60:
                            emoji = "✨"  # halfway
                        elif progress < 80:
                            emoji = "🔍"  # detail generation
                        elif progress < 100:
                            emoji = "🎨"  # final polish
                        else:
                            emoji = "✅"  # done
                        # Render a 20-cell text progress bar.
                        bar_length = 20
                        filled_length = int(bar_length * progress / 100)
                        bar = "█" * filled_length + "░" * (bar_length - filled_length)
                        # Format the progress line (user-facing, kept verbatim).
                        reasoning_text = f"{emoji} 图像生成进度 |{bar}| {progress}%\n"
                        yield f"data: {StreamResponse(id=stream_id, created=created_time, model=model, choices=[StreamChoice(delta={'reasoning_content': reasoning_text})]).model_dump_json()}\n\n"
                except websocket.WebSocketTimeoutException:
                    continue
                except Exception as e:
                    print(f"WebSocket error: {e}")
                    break
        finally:
            # ROBUSTNESS FIX: close the socket even when the loop exits via an
            # exception (previously close() was skipped on errors).
            ws.close()
    except Exception as e:
        print(f"Stream processing error: {e}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
        yield "data: [DONE]\n\n"
async def wait_for_completion(config, draw_id: str) -> str:
    """Wait for image generation completion (non-streaming).

    Polls the signed progress WebSocket roughly once per second until the
    backend reports a finished photo, then returns its URL.
    """
    e = {"token": config["token"], "id": draw_id}
    xtx = KontextFluxEncryptor(config).get_xtx_hash(e)
    url = f"wss://api.kontextflux.com/client/styleAI/checkWs?xtx={xtx}"
    ws = websocket.create_connection(url)
    ws.send(json.dumps(e))
    try:
        while True:
            # NOTE(review): ws.recv() is a blocking call on the event loop;
            # the sleep only paces the polling — confirm acceptable here.
            await asyncio.sleep(1)
            msg = ws.recv()
            data = json.loads(msg)
            if data["content"]["photo"]:
                return data["content"]["photo"]["url"]
    finally:
        ws.close()
if __name__ == "__main__":
    import uvicorn
    # Startup banner: advertised endpoints and count of rotating backend tokens.
    print("\n--- KontextFlux OpenAI API Adapter ---")
    print("Endpoints:")
    print(" GET /v1/models")
    print(" POST /v1/chat/completions")
    print(f"TOKENS: {len(TOKENS)}")
    print("------------------------------------")
    uvicorn.run(app, host="0.0.0.0", port=8000)