Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image, ImageDraw, ImageFont, ImageOps | |
import base64 | |
import io | |
import json | |
import logging | |
import os | |
import requests | |
import struct | |
import tempfile | |
import numpy as np | |
from cryptography.hazmat.primitives import serialization | |
from cryptography.hazmat.primitives.asymmetric import rsa, padding | |
from cryptography.hazmat.primitives.ciphers.aead import AESGCM | |
from cryptography.hazmat.primitives import hashes | |
from cryptography.exceptions import InvalidTag | |
from gradio_client import Client | |
from huggingface_hub import InferenceClient | |
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# --- Configuration for Endpoints ---
# Local JSON cache of saved servers, and the shared remote list it is seeded from.
LOCAL_ENDPOINTS_FILE = "endpoints.json"
REMOTE_ENDPOINTS_URL = (
    "https://huggingface.co/spaces/Space-Share/bucket/raw/main/endpoints.json"
)

# --- Constants and Key Management ---
# Number of bits in the big-endian length prefix stored in the image's first LSBs.
HEADER_BITS = 32
# AES-GCM nonce length in bytes.
AES_GCM_NONCE_SIZE = 12

# PEM private key taken from the environment (None when the secret is unset).
KEYLOCK_PRIV_KEY_PEM = os.environ.get('KEYLOCK_PRIV_KEY')
# Populated by the key-loading code that runs at import time.
PRIVATE_KEY_OBJECT = None
PUBLIC_KEY_PEM_STRING = ""
KEYLOCK_STATUS_MESSAGE = ""

# Demo-only credential store: API key -> {"user", "permissions"}. Not a real database.
MOCK_USER_DATABASE = {
    "sk-12345-abcde": {"user": "demo-user", "permissions": "read"},
    "sk-67890-fghij": {"user": "admin-user", "permissions": "read,write,delete"},
}
# Resolve the embedded server's RSA key at import time. Without a configured
# secret an ephemeral 2048-bit key is generated, so images created in one
# session cannot be decrypted after a restart.
if not KEYLOCK_PRIV_KEY_PEM:
    logger.warning("No KEYLOCK_PRIV_KEY secret found. Generating a temporary key pair for this session.")
    temp_priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    KEYLOCK_PRIV_KEY_PEM = temp_priv_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode('utf-8')
    KEYLOCK_STATUS_MESSAGE = "β οΈ No secret found. Using a temporary key for this session. Keys will be lost on restart."
else:
    # At this point the key has only been found, not parsed; success is
    # logged after load_pem_private_key below succeeds.
    logger.info("Found private key in environment variable 'KEYLOCK_PRIV_KEY'.")
    KEYLOCK_STATUS_MESSAGE = "β Loaded successfully from secrets/environment variable."
try:
    PRIVATE_KEY_OBJECT = serialization.load_pem_private_key(KEYLOCK_PRIV_KEY_PEM.encode(), password=None)
    PUBLIC_KEY_PEM_STRING = PRIVATE_KEY_OBJECT.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    ).decode('utf-8')
    KEYLOCK_STATUS_MESSAGE += "\nβ Public key derived successfully."
    logger.info("Private key parsed and public key derived successfully.")
except Exception as e:
    # Top-level boundary: keep the app running so the UI can display the
    # problem, but also record the cause in the server log.
    logger.error(f"Failed to parse KEYLOCK private key: {e}")
    PRIVATE_KEY_OBJECT = None
    PUBLIC_KEY_PEM_STRING = "Error: Key could not be processed."
    KEYLOCK_STATUS_MESSAGE += f"\nβ Failed to parse key: {e}"
# --- Core Helper Functions --- | |
def _parse_secret_data(secret_data_str: str) -> dict: | |
stripped_input = secret_data_str.strip() | |
try: | |
data_dict = json.loads(stripped_input) | |
if isinstance(data_dict, dict): return data_dict | |
except json.JSONDecodeError: | |
pass | |
data_dict = {} | |
for line in stripped_input.splitlines(): | |
line = line.strip() | |
if not line or line.startswith('#'): continue | |
separator = ':' if ':' in line else '=' | |
if separator not in line: continue | |
parts = line.split(separator, 1) | |
if len(parts) == 2: | |
key = parts[0].strip().strip("'\"") | |
value = parts[1].strip().strip(",").strip().strip("'\"") | |
if key: data_dict[key] = value | |
return data_dict | |
def prepare_base_image(uploaded_image: Image.Image | None, progress) -> Image.Image:
    """Return a 600x600 RGB base image to draw on and embed data into.

    Sources are tried in order:
    1. the user's uploaded image,
    2. a default Unsplash photo (downloaded),
    3. an AI-generated image from the Hugging Face Inference API.

    Args:
        uploaded_image: Optional user-supplied image.
        progress: Progress callback, invoked as ``progress(0, desc=...)``.

    Returns:
        A 600x600 RGB image (center-cropped/resized with LANCZOS).

    Raises:
        gr.Error: If both the download and the AI fallback fail.
    """
    size = 600
    if uploaded_image is not None:
        progress(0, desc="β Using uploaded image...")
        return ImageOps.fit(uploaded_image.convert("RGB"), (size, size), Image.Resampling.LANCZOS)
    try:
        progress(0, desc="β³ Fetching default background...")
        response = requests.get("https://images.unsplash.com/photo-1506318137071-a8e063b4bec0?q=80&w=1200&auto=format&fit=crop", timeout=10)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content)).convert("RGB")
        return ImageOps.fit(img, (size, size), Image.Resampling.LANCZOS)
    except Exception as e:
        logger.warning(f"Default image fetch failed: {e}. Falling back to AI.")
    try:
        progress(0, desc="β³ Generating image with SDXL-Lightning...")
        client = InferenceClient()
        result = client.text_to_image("A stunning view of a distant galaxy, nebulae, and constellations, digital art, vibrant colors", model="sd-community/sdxl-lightning")
        # huggingface_hub's text_to_image returns a PIL image; raw bytes are
        # tolerated too in case a different client version returns them.
        img = result if isinstance(result, Image.Image) else Image.open(io.BytesIO(result))
        return ImageOps.fit(img.convert("RGB"), (size, size), Image.Resampling.LANCZOS)
    except Exception as e:
        raise gr.Error(f"All image sources failed. AI error: {e}") from e
def create_encrypted_image(payload_dict: dict, public_key_pem: str, base_image: Image.Image, overlay_option: str, server_url: str) -> Image.Image:
    """Encrypt ``payload_dict`` for a server and hide it inside a labelled image.

    Hybrid encryption: the JSON payload is sealed with a fresh AES-256-GCM key,
    and that key is wrapped with the server's RSA public key (OAEP, SHA-256).
    The resulting blob

        [4-byte BE len(wrapped_key)] [wrapped_key] [nonce] [ciphertext+tag]

    is prefixed with its own 4-byte big-endian length and written, one bit per
    value, into the least-significant bit of the image's RGB channel values
    (row-major order). The embedding happens after the visible overlay is
    drawn, so the overlay does not disturb the hidden data.

    Args:
        payload_dict: JSON-serialisable data to embed.
        public_key_pem: Server RSA public key in PEM format.
        base_image: Image to draw on; it is copied, not modified.
        overlay_option: "Keys and Values", "Keys Only" or "None" -- which
            payload labels to draw (in plaintext) on the image.
        server_url: Shown under the title with "https://" removed.

    Returns:
        A new RGB image carrying the hidden payload.
        NOTE: LSB data only survives lossless saving (e.g. PNG).

    Raises:
        ValueError: If the payload does not fit in the image's LSBs.
        Errors from parsing ``public_key_pem`` propagate unchanged.
    """
    # --- 1. Hybrid encryption ---
    json_bytes = json.dumps(payload_dict).encode('utf-8')
    public_key = serialization.load_pem_public_key(public_key_pem.encode('utf-8'))
    aes_key, nonce = os.urandom(32), os.urandom(AES_GCM_NONCE_SIZE)
    ciphertext = AESGCM(aes_key).encrypt(nonce, json_bytes, None)
    rsa_encrypted_key = public_key.encrypt(
        aes_key,
        padding.OAEP(mgf=padding.MGF1(hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
    )
    encrypted_payload = struct.pack('>I', len(rsa_encrypted_key)) + rsa_encrypted_key + nonce + ciphertext

    img = base_image.copy().convert("RGB")
    width, height = img.size

    # Frame the blob with its length and expand it to one bit per channel
    # value. np.unpackbits emits bits MSB-first within each byte (same order
    # as format(b, '08b')). Check capacity before doing any drawing work.
    framed_payload = struct.pack('>I', len(encrypted_payload)) + encrypted_payload
    payload_bits = np.unpackbits(np.frombuffer(framed_payload, dtype=np.uint8))
    if payload_bits.size > width * height * 3:
        raise ValueError("Payload is too large for the image.")

    # --- 2. Visible overlay ---
    draw = ImageDraw.Draw(img, "RGBA")
    try:
        font_bold = ImageFont.truetype("DejaVuSans-Bold.ttf", 30)
        font_regular = ImageFont.truetype("DejaVuSans.ttf", 15)
        font_small = ImageFont.truetype("DejaVuSans.ttf", 12)
    except IOError:
        font_bold = ImageFont.load_default(size=28)
        font_regular = ImageFont.load_default(size=14)
        font_small = ImageFont.load_default(size=12)
    overlay_color, title_color, key_color, value_color = (15, 23, 42, 190), (226, 232, 240), (148, 163, 184), (241, 245, 249)
    draw.rectangle([0, 20, width, 100], fill=overlay_color)
    draw.text((width / 2, 45), "KeyLock Secure Data", fill=title_color, font=font_bold, anchor="ms")
    draw.text((width / 2, 75), server_url.replace("https://", ""), fill=key_color, font=font_small, anchor="ms")

    # Skip the label box entirely for an empty payload (it would be malformed).
    if overlay_option != "None" and payload_dict:
        items = [(str(k), str(v)) for k, v in payload_dict.items()]
        keys_only = overlay_option == "Keys Only"
        lines = [k for k, _ in items] if keys_only else [f"{k}: {v}" for k, v in items]
        line_gap, box_padding = 6, 15
        line_heights = [draw.textbbox((0, 0), line, font=font_regular)[3] for line in lines]
        box_y0 = height - (sum(line_heights) + (len(lines) - 1) * line_gap + 2 * box_padding) - 20
        draw.rectangle([20, box_y0, width - 20, height - 20], fill=overlay_color)
        current_y = box_y0 + box_padding
        for (key, value), line_height in zip(items, line_heights):
            if keys_only:
                draw.text((35, current_y), key, fill=key_color, font=font_regular)
            else:
                key_text = f"{key}:"
                draw.text((35, current_y), key_text, fill=key_color, font=font_regular)
                key_bbox = draw.textbbox((35, current_y), key_text, font=font_regular)
                draw.text((key_bbox[2] + 8, current_y), value, fill=value_color, font=font_regular)
            current_y += line_height + line_gap

    # --- 3. LSB embedding: overwrite bit 0 of the first N channel values ---
    pixel_data = np.array(img).ravel()
    n_bits = payload_bits.size
    pixel_data[:n_bits] = (pixel_data[:n_bits] & 0xFE) | payload_bits
    return Image.fromarray(pixel_data.reshape((height, width, 3)), 'RGB')
# --- API Functions --- | |
def api_get_info():
    """Describe the embedded KeyLock server and the payload keys it expects.

    Served as the ``keylock-info`` API endpoint. The ``required_payload_keys``
    entries carry a ``key_name``, a human ``description`` and an ``example``
    value.
    """
    api_key_field = {"key_name": "API_KEY", "description": "Your unique API Key.", "example": "sk-12345-abcde"}
    user_field = {"key_name": "USER", "description": "The user ID for the key.", "example": "demo-user"}
    return {
        "name": "Embedded KeyLock Server",
        "version": "2.1",
        "documentation": "This server can generate and authenticate KeyLock images.",
        "required_payload_keys": [api_key_field, user_field],
    }
def api_get_public_key():
    """Return the embedded server's public key text (served as ``keylock-pub``).

    This is the module-level PUBLIC_KEY_PEM_STRING: a PEM string when key
    loading succeeded, otherwise the error placeholder stored there.
    """
    pem_text = PUBLIC_KEY_PEM_STRING
    return pem_text
def api_decode_and_auth(image_base64_string: str) -> dict:
    """Extract, decrypt and authenticate a KeyLock payload hidden in an image.

    The image's RGB channel values are read in row-major order. The LSBs of
    the first HEADER_BITS values encode a big-endian byte count, followed by
    that many bytes laid out as

        [4-byte BE len(wrapped_key)] [wrapped_key] [AES-GCM nonce] [ciphertext+tag]

    The AES key is unwrapped with the server's RSA private key (OAEP, SHA-256),
    and the decrypted JSON object's ``API_KEY``/``USER`` are checked against
    MOCK_USER_DATABASE.

    Args:
        image_base64_string: Base64-encoded image file contents.

    Returns:
        ``{"status", "message", "details"}``. ``status`` is "Success",
        "Failed" (credentials did not match) or "Error" (decoding or
        decryption failed).

    Raises:
        gr.Error: If the server has no private key loaded.
    """
    if not PRIVATE_KEY_OBJECT: raise gr.Error("Server is not configured with a private key.")
    try:
        image_bytes = base64.b64decode(image_base64_string)
        pixel_data = np.array(Image.open(io.BytesIO(image_bytes)).convert("RGB")).ravel()
        lsb = pixel_data & 1
        if lsb.size < HEADER_BITS:
            raise ValueError("Image is too small to contain a KeyLock header.")
        data_length = int("".join(str(bit) for bit in lsb[:HEADER_BITS]), 2)
        # The length comes from untrusted input: bound it by what the image can
        # actually hold BEFORE allocating a buffer of that size (a forged
        # header could otherwise request ~4 GB).
        max_length = (lsb.size - HEADER_BITS) // 8
        if not 0 < data_length <= max_length:
            raise ValueError(f"Embedded payload length {data_length} is invalid (image holds at most {max_length} bytes).")
        # packbits is MSB-first, matching the big-endian bit order of the header.
        crypto_payload = np.packbits(lsb[HEADER_BITS:HEADER_BITS + data_length * 8]).tobytes()

        offset = 4
        encrypted_aes_key_len = struct.unpack('>I', crypto_payload[:offset])[0]
        encrypted_aes_key = crypto_payload[offset:offset + encrypted_aes_key_len]
        offset += encrypted_aes_key_len
        nonce = crypto_payload[offset:offset + AES_GCM_NONCE_SIZE]
        offset += AES_GCM_NONCE_SIZE
        ciphertext_with_tag = crypto_payload[offset:]

        recovered_aes_key = PRIVATE_KEY_OBJECT.decrypt(encrypted_aes_key, padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None))
        decrypted_payload = json.loads(AESGCM(recovered_aes_key).decrypt(nonce, ciphertext_with_tag, None).decode('utf-8'))
        if not isinstance(decrypted_payload, dict):
            raise ValueError("Decrypted payload is not a JSON object.")

        db_entry = MOCK_USER_DATABASE.get(decrypted_payload.get('API_KEY'))
        if db_entry and db_entry.get("user") == decrypted_payload.get('USER'):
            return {"status": "Success", "message": f"User '{decrypted_payload.get('USER')}' authenticated.", "details": decrypted_payload}
        return {"status": "Failed", "message": "Invalid credentials.", "details": decrypted_payload}
    except Exception as e:
        # API boundary: report any failure as an "Error" result. Some exceptions
        # (e.g. an AES-GCM tag mismatch) may carry an empty message, so fall
        # back to the exception type name.
        detail = str(e) or type(e).__name__
        logger.warning(f"KeyLock authentication error: {detail}")
        return {"status": "Error", "message": f"Decryption/Processing Failed: {detail}", "details": {}}
# --- Endpoints File Management --- | |
def load_endpoints():
    """Load the saved-server list, preferring the local cache.

    1. If LOCAL_ENDPOINTS_FILE exists and holds a non-empty JSON list, return it.
    2. Otherwise fetch REMOTE_ENDPOINTS_URL. If that returns a JSON list,
       cache it locally via ``save_endpoints`` and return it.
    3. If both fail, return an empty list.

    Returns:
        list: Server entries (as written by this app, dicts with keys like
        "name", "link", "public_key", "info"), or ``[]``.
    """
    if os.path.exists(LOCAL_ENDPOINTS_FILE):
        try:
            with open(LOCAL_ENDPOINTS_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (ValueError, OSError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Local file '{LOCAL_ENDPOINTS_FILE}' is invalid: {e}. Trying remote.")
        else:
            if isinstance(data, list) and data:
                logger.info(f"Loaded endpoints from local file: {LOCAL_ENDPOINTS_FILE}")
                return data
            if not isinstance(data, list):
                logger.warning(f"Local file '{LOCAL_ENDPOINTS_FILE}' is invalid: root element is not a list. Trying remote.")
    logger.info(f"Fetching endpoints from remote URL: {REMOTE_ENDPOINTS_URL}")
    try:
        response = requests.get(REMOTE_ENDPOINTS_URL, timeout=5)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        logger.error(f"Failed to fetch or parse remote endpoints: {e}. Starting with an empty list.")
        return []
    if not isinstance(data, list):
        logger.error("Remote endpoints are not a JSON list. Starting with an empty list.")
        return []
    logger.info("Successfully fetched remote endpoints. Caching locally.")
    save_endpoints(data)
    return data
def save_endpoints(endpoints_list):
    """Persist ``endpoints_list`` to LOCAL_ENDPOINTS_FILE as indented JSON.

    The data is serialised before any file is touched, then written to a
    sibling ``.tmp`` file and moved into place with ``os.replace``. A
    serialisation error or a crash mid-write therefore cannot leave a
    truncated endpoints file behind. OS errors are logged, not raised.

    Raises:
        TypeError: If ``endpoints_list`` is not JSON-serialisable (raised
            before the existing file is modified).
    """
    serialized = json.dumps(endpoints_list, indent=4)
    tmp_path = f"{LOCAL_ENDPOINTS_FILE}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(serialized)
        os.replace(tmp_path, LOCAL_ENDPOINTS_FILE)
        logger.info(f"Successfully saved endpoints to {LOCAL_ENDPOINTS_FILE}")
    except OSError as e:
        logger.error(f"Error saving endpoints to {LOCAL_ENDPOINTS_FILE}: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the temp file may not exist
# --- Gradio UI Definition --- | |
theme = gr.themes.Soft(primary_hue="sky", secondary_hue="blue", neutral_hue="slate")
with gr.Blocks(theme=theme, title="KeyLock Showcase") as demo:
    # NOTE(review): container nesting below was reconstructed from a copy whose
    # indentation was stripped -- confirm against the deployed layout.
    # Per-session state: every saved server, and the one the client tab is
    # currently connected to ({} means "not connected").
    all_servers_state = gr.State(value=load_endpoints())
    active_server_state = gr.State({})
    gr.Markdown("# π KeyLock Showcase")
    gr.Markdown("A comprehensive toolkit for generating and testing KeyLock authentication images against live servers.")
    with gr.Tabs():
        with gr.TabItem("Client Operations"):
            gr.Markdown("### 1. Connect to a Target Server")
            with gr.Row():
                saved_servers_dropdown = gr.Dropdown(label="Load Saved Server", interactive=True)
                with gr.Column():
                    server_url_input = gr.Textbox(label="Or Add New Server by URL", placeholder="https://your-server.hf.space")
                    connect_button = gr.Button("Connect New Server", variant="primary")
            client_status_display = gr.Markdown("**Status:** Not Connected")
            with gr.Accordion("2. Create an Encrypted Image for the Connected Server", open=False) as client_generate_accordion:
                with gr.Row():
                    with gr.Column(scale=2):
                        client_payload_input = gr.Textbox(label="Secret Data (key:value or JSON)", lines=5)
                        client_overlay_radio = gr.Radio(label="Show Labels on Image", choices=["Keys and Values", "Keys Only", "None"], value="Keys and Values")
                        client_base_image_input = gr.Image(label="Optional Base Image", type="pil", height=200)
                        client_generate_button = gr.Button("Create Image", variant="secondary")
                    with gr.Column(scale=3):
                        client_generated_image_preview = gr.Image(label="Generated Image Preview", interactive=False)
                        client_generated_file_output = gr.File(label="Download Uncorrupted PNG", interactive=False, file_count="single")
            with gr.Accordion("3. Test an Existing Image with the Connected Server", open=False) as client_test_accordion:
                client_test_image_input = gr.Image(type="filepath", label="Upload Encrypted Image")
                client_auth_status_display = gr.Markdown(visible=False)
                client_auth_result_output = gr.JSON(label="Server Authentication Response")
                # Hide any previous result as soon as the test image changes.
                client_test_image_input.change(lambda: (gr.update(visible=False), None), outputs=[client_auth_status_display, client_auth_result_output])
        with gr.TabItem("Server Showcase & Admin"):
            gr.Markdown("## Embedded Server Details")
            gr.Markdown("This Gradio app is also running its own KeyLock server. You can use its details to test the client.")
            gr.Textbox(label="Embedded Server Status", value=KEYLOCK_STATUS_MESSAGE, interactive=False, lines=3)
            gr.Code(label="Embedded Server Public Key", value=PUBLIC_KEY_PEM_STRING, language="python")
            # Example payload built from the embedded server's advertised keys.
            _example_payload = {k['key_name']: k['example'] for k in api_get_info()["required_payload_keys"]}
            gr.JSON(label="Embedded Server Required Payload", value=_example_payload)
            with gr.Accordion("Generate Image with Embedded Server", open=False):
                server_payload_input = gr.JSON(label="Payload to Encrypt", value=dict(_example_payload))
                server_generate_button = gr.Button("Generate Image", variant="secondary")
                server_generated_file_output = gr.File(label="Download Uncorrupted PNG", interactive=False, file_count="single")
            with gr.Accordion("Admin: Manage Saved Servers (endpoints.json)", open=False):
                endpoints_editor_textbox = gr.Textbox(label="Edit endpoints.json", lines=15, interactive=True, placeholder="Enter a list of servers in JSON format.")
                save_endpoints_button = gr.Button("Save Changes to local endpoints.json", variant="primary")
            with gr.Accordion("Admin: Generate New Key Pair", open=False):
                gen_keys_button = gr.Button("βοΈ Generate New 2048-bit Key Pair")
                with gr.Row():
                    output_private_key = gr.Textbox(lines=8, label="Generated Private Key", interactive=False, show_copy_button=True)
                    output_public_key = gr.Textbox(lines=8, label="Generated Public Key", interactive=False, show_copy_button=True)

    # --- UI Logic and Event Handlers ---
    def _server_name(server):
        """Display name of a saved-server entry ('Unnamed Server' if absent)."""
        return server.get('name', 'Unnamed Server')

    def _save_png_to_temp(img):
        """Write ``img`` losslessly to a new temporary .png file and return its path."""
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
            # Write through the open handle instead of reopening by name,
            # which fails on platforms that lock open temp files.
            img.save(f, format="PNG", compress_level=1)
        return f.name

    def _active_server_from_saved(server):
        """Convert a saved endpoints entry into the active-server state dict.

        Raises:
            gr.Error: If the entry lacks a 'link' or 'public_key'.
        """
        if not server.get('link') or not server.get('public_key'):
            raise gr.Error(f"Saved server '{_server_name(server)}' is missing its 'link' or 'public_key' field.")
        return {'name': _server_name(server), 'url': server['link'], 'pubkey': server['public_key'], 'info': server.get('info') or {}}

    def initialize_ui(all_servers_list):
        """On page load, fill the saved-server dropdown and the endpoints editor."""
        dropdown_choices = [_server_name(s) for s in all_servers_list] if all_servers_list else []
        json_text_content = json.dumps(all_servers_list, indent=4) if all_servers_list else "[]"
        return {
            saved_servers_dropdown: gr.update(choices=dropdown_choices),
            endpoints_editor_textbox: json_text_content,
        }

    def process_server_connection(server_data, all_servers):
        """Build the UI updates that mark ``server_data`` as the active server.

        The payload textbox placeholder lists the server's required keys with
        their example values.
        """
        required_keys = (server_data.get('info') or {}).get('required_payload_keys', [])
        placeholder = "\n".join(
            f"{k['key_name']}: {k.get('example', '')}"
            for k in required_keys
            if isinstance(k, dict) and 'key_name' in k
        )
        status_md = f"**Status:** β Connected to **{server_data['name']}**"
        return {active_server_state: server_data, client_status_display: status_md, client_payload_input: gr.update(placeholder=placeholder), all_servers_state: all_servers}

    def load_server_from_dropdown(server_name, all_servers):
        """Connect to the saved server whose display name is ``server_name``.

        Returns ``{}`` (no UI change) when nothing is selected or nothing
        matches, e.g. after the dropdown is cleared by a reload.
        """
        if not server_name:
            return {}
        server = next((s for s in all_servers if _server_name(s) == server_name), None)
        if server is None:
            return {}
        return process_server_connection(_active_server_from_saved(server), all_servers)

    def add_new_server_from_url(url, all_servers):
        """Connect to a KeyLock server by URL, saving it if it is not known yet.

        A new server is queried for its info (``/keylock-info``) and public key
        (``/keylock-pub``), appended to ``all_servers`` and persisted.

        Raises:
            gr.Error: On an empty URL or if the server cannot be queried.
        """
        if not url or not url.strip(): raise gr.Error("Please provide a server URL.")
        url = url.strip().rstrip('/')
        existing_server = next((s for s in all_servers if s.get('link') == url), None)
        if existing_server is not None:
            existing_name = _server_name(existing_server)
            gr.Info(f"Server '{existing_name}' already exists. Loading it.")
            updates = process_server_connection(_active_server_from_saved(existing_server), all_servers)
            updates[saved_servers_dropdown] = gr.update(value=existing_name)
            return updates
        try:
            client = Client(url, verbose=False)
            info = client.predict(api_name="/keylock-info")
            pubkey = client.predict(api_name="/keylock-pub")
            if not isinstance(info, dict):
                raise TypeError("server info is not a JSON object")
        except Exception as e:
            raise gr.Error(f"Connection Failed: {e}") from e
        new_server = {'name': info.get('name', url), 'link': url, 'public_key': pubkey, 'info': info}
        all_servers.append(new_server)
        save_endpoints(all_servers)
        gr.Info(f"Successfully added and saved '{new_server['name']}'!")
        server_for_state = {'name': new_server['name'], 'url': url, 'pubkey': pubkey, 'info': info}
        updates = process_server_connection(server_for_state, all_servers)
        updates[saved_servers_dropdown] = gr.update(choices=[_server_name(s) for s in all_servers], value=server_for_state['name'])
        updates[endpoints_editor_textbox] = json.dumps(all_servers, indent=4)
        return updates

    def update_and_save_endpoints(json_string):
        """Validate the admin editor's JSON, save it, and reset the client connection.

        Raises:
            gr.Error: If the text is not a JSON list of objects.
        """
        try:
            new_endpoints = json.loads(json_string)
            if not isinstance(new_endpoints, list): raise TypeError("The root JSON element must be a list of objects.")
            if not all(isinstance(s, dict) for s in new_endpoints): raise TypeError("Every entry in the list must be a JSON object.")
        except (json.JSONDecodeError, TypeError) as e:
            raise gr.Error(f"Invalid format. Please check your JSON. Error: {e}") from e
        save_endpoints(new_endpoints)
        gr.Info("Endpoints saved successfully to endpoints.json!")
        new_choices = [_server_name(s) for s in new_endpoints]
        return {
            all_servers_state: new_endpoints,
            saved_servers_dropdown: gr.update(choices=new_choices, value=None),
            client_status_display: gr.update(value="**Status:** Not Connected (Endpoints Reloaded)"),
            active_server_state: {},
        }

    def client_generate_image_wrapper(active_server, payload_str, overlay, base_img, progress=gr.Progress(track_tqdm=True)):
        """Create a KeyLock image for the connected server.

        Returns the PNG path twice (image preview and download).
        """
        if not active_server: raise gr.Error("Not connected to a server.")
        payload_dict = _parse_secret_data(payload_str or "")
        if not payload_dict: raise gr.Error("Invalid payload format. Please provide key:value pairs or a valid JSON object.")
        base_image = prepare_base_image(base_img, progress)
        try:
            img = create_encrypted_image(payload_dict, active_server['pubkey'], base_image, overlay, active_server['url'])
        except ValueError as e:
            # e.g. an unparsable server public key or a payload too large for the image
            raise gr.Error(f"Could not create image: {e}") from e
        png_path = _save_png_to_temp(img)
        return png_path, png_path

    def client_authenticate_wrapper(active_server, image_path):
        """Send an uploaded image to the connected server's ``/keylock-auth`` endpoint.

        Returns a Markdown status update and the server's raw response. Request
        failures are reported inline (and as a warning toast), not raised.
        """
        if not active_server: raise gr.Error("Not connected to a server.")
        if not image_path: raise gr.Error("Please upload an image.")
        try:
            with open(image_path, "rb") as f:
                b64_img = base64.b64encode(f.read()).decode('utf-8')
            client = Client(active_server['url'])
            response = client.predict(b64_img, api_name="/keylock-auth")
            status = response.get("status")
        except Exception as e:
            logger.warning(f"Authentication request failed: {e}")
            gr.Warning(f"Authentication request failed: {e}")
            return gr.update(value=f"### β οΈ Request Error: {e}", visible=True), None
        if status == "Success":
            status_md = "### β Authentication Successful"
        elif status == "Failed":
            status_md = "### β Authentication Failed"
        else:
            status_md = f"### β οΈ Server Error: {response.get('message')}"
        return gr.update(value=status_md, visible=True), response

    def server_generate_image_wrapper(payload, progress=gr.Progress(track_tqdm=True)):
        """Create a KeyLock image for this app's own embedded server; return the PNG path."""
        if PRIVATE_KEY_OBJECT is None: raise gr.Error("Server is not configured with a private key.")
        if not isinstance(payload, dict) or not payload: raise gr.Error("Payload must be a non-empty JSON object.")
        try:
            img = create_encrypted_image(payload, PUBLIC_KEY_PEM_STRING, prepare_base_image(None, progress), "Keys and Values", "Embedded Server")
        except ValueError as e:
            raise gr.Error(f"Could not create image: {e}") from e
        return _save_png_to_temp(img)

    def generate_pem_keys():
        """Generate a fresh RSA-2048 key pair as (PKCS8 private PEM, SubjectPublicKeyInfo PEM)."""
        pk = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        priv = pk.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption()).decode()
        pub = pk.public_key().public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo).decode()
        return priv, pub

    # --- Wire UI Events ---
    demo.load(initialize_ui, inputs=[all_servers_state], outputs=[saved_servers_dropdown, endpoints_editor_textbox])
    saved_servers_dropdown.change(load_server_from_dropdown, inputs=[saved_servers_dropdown, all_servers_state], outputs=[active_server_state, client_status_display, client_payload_input, all_servers_state])
    connect_button.click(add_new_server_from_url, inputs=[server_url_input, all_servers_state], outputs=[active_server_state, client_status_display, client_payload_input, all_servers_state, saved_servers_dropdown, endpoints_editor_textbox])
    save_endpoints_button.click(update_and_save_endpoints, inputs=[endpoints_editor_textbox], outputs=[all_servers_state, saved_servers_dropdown, client_status_display, active_server_state])
    client_generate_button.click(client_generate_image_wrapper, inputs=[active_server_state, client_payload_input, client_overlay_radio, client_base_image_input], outputs=[client_generated_image_preview, client_generated_file_output])
    client_test_image_input.upload(client_authenticate_wrapper, inputs=[active_server_state, client_test_image_input], outputs=[client_auth_status_display, client_auth_result_output])
    server_generate_button.click(server_generate_image_wrapper, inputs=[server_payload_input], outputs=[server_generated_file_output])
    gen_keys_button.click(generate_pem_keys, outputs=[output_private_key, output_public_key])

    # --- Define API Endpoints ---
    # Hidden Interfaces register the named API routes that remote KeyLock
    # clients call ("/keylock-info", "/keylock-pub", "/keylock-auth").
    with gr.Row(visible=False):
        gr.Interface(fn=api_get_info, inputs=None, outputs=gr.JSON(), api_name="keylock-info")
        gr.Interface(fn=api_get_public_key, inputs=None, outputs=gr.Textbox(), api_name="keylock-pub")
        gr.Interface(fn=api_decode_and_auth, inputs=gr.Textbox(), outputs=gr.JSON(), api_name="keylock-auth")
# Start the Gradio server only when this file is executed directly, not on import.
if __name__ == "__main__":
    demo.launch()