|
import os |
|
import io |
|
import json |
|
import base64 |
|
import struct |
|
import logging |
|
|
|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
|
from cryptography.hazmat.primitives import hashes |
|
from cryptography.hazmat.primitives import serialization |
|
from cryptography.hazmat.primitives.asymmetric import padding, rsa |
|
from cryptography.exceptions import InvalidTag |
|
|
|
# --- Payload layout constants ---------------------------------------------
# Number of least-significant bits at the start of the image that hold the
# length (in bytes) of the hidden payload.
HEADER_BITS = 32
# Size in bytes of the AES-GCM nonce stored right after the RSA-wrapped key.
AES_GCM_NONCE_SIZE = 12

# --- Logging ----------------------------------------------------------------
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
# --- Private key loading ----------------------------------------------------
# The key comes from the KEYLOCK_PRIV_KEY environment variable or, if that is
# unset, from a demo key file under ../keys/. The outcome is summarised in
# KEYLOCK_STATUS_MESSAGE; on any failure PRIVATE_KEY_OBJECT stays None, which
# leaves the API non-functional instead of crashing at import time.
KEYLOCK_PRIV_KEY_PEM = os.environ.get('KEYLOCK_PRIV_KEY')
PRIVATE_KEY_OBJECT = None
PUBLIC_KEY_PEM_STRING = ""
KEYLOCK_STATUS_MESSAGE = ""

if not KEYLOCK_PRIV_KEY_PEM:
    # NOTE(review): the file name contains a space ("THIS _IS"); confirm it
    # matches the key file actually shipped in ../keys/.
    dev_key_path = os.path.join(os.path.dirname(__file__), '..', 'keys', 'DEMO_ONLY_THIS _IS_SECRET_keylock_priv_key.pem')
    try:
        # PEM is ASCII text; pin the encoding instead of relying on the locale.
        with open(dev_key_path, "r", encoding="utf-8") as f:
            KEYLOCK_PRIV_KEY_PEM = f.read()
    except FileNotFoundError:
        logger.error(f"FATAL: Private key not found at '{dev_key_path}' and 'KEYLOCK_PRIV_KEY' secret is not set.")
        KEYLOCK_STATUS_MESSAGE = "❌ NOT FOUND. The API is non-functional. Set the `KEYLOCK_PRIV_KEY` secret or provide the demo key file."
    except (OSError, UnicodeDecodeError) as e:
        # The path exists but cannot be read (permissions, a directory, bad
        # bytes): degrade like the missing-file case rather than abort import.
        logger.error(f"FATAL: Could not read private key file at '{dev_key_path}': {e}")
        KEYLOCK_STATUS_MESSAGE = f"❌ UNREADABLE. The API is non-functional. Could not read the demo key file: {e}"
    else:
        logger.warning("Loaded private key from dev path. This is for local testing only.")
        KEYLOCK_STATUS_MESSAGE = "⚠️ Loaded from development key file. This is for local testing but insecure for production."
else:
    logger.info("Successfully loaded private key from environment variable 'KEYLOCK_PRIV_KEY'.")
    KEYLOCK_STATUS_MESSAGE = "✅ Loaded successfully from secrets/environment variable. Recommended secure configuration."

if KEYLOCK_PRIV_KEY_PEM:
    try:
        PRIVATE_KEY_OBJECT = serialization.load_pem_private_key(KEYLOCK_PRIV_KEY_PEM.encode(), password=None)
        # decode_data() unwraps the AES key with RSA-OAEP, so only RSA keys work;
        # reject anything else here instead of failing on every request.
        if not isinstance(PRIVATE_KEY_OBJECT, rsa.RSAPrivateKey):
            raise TypeError(f"Expected an RSA private key, got {type(PRIVATE_KEY_OBJECT).__name__}.")
        public_key = PRIVATE_KEY_OBJECT.public_key()
        PUBLIC_KEY_PEM_STRING = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode('utf-8')
        KEYLOCK_STATUS_MESSAGE += "\n✅ Public key derived successfully."
    except Exception as e:
        # Startup boundary: record the failure and keep the module importable.
        logger.error(f"Failed to load private key or derive public key: {e}", exc_info=True)
        PRIVATE_KEY_OBJECT = None
        PUBLIC_KEY_PEM_STRING = f"Error: Could not process the configured private key. Details: {e}"
        KEYLOCK_STATUS_MESSAGE += f"\n❌ Failed to parse key: {e}"
|
|
|
# Demo credential store mapping an API key to its account record
# ("user" name and comma-separated "permissions"). Stand-in for a real user
# database; example_authenticate() looks API keys up here.
MOCK_USER_DATABASE = {
    "sk-12345-abcde": {
        "user": "demo-user",
        "permissions": "read",
    },
    "sk-67890-fghij": {
        "user": "admin-user",
        "permissions": "read,write,delete",
    },
}
|
|
|
def example_authenticate(api_key: str, user_id: str) -> bool:
    """Check an (API key, user id) pair against MOCK_USER_DATABASE.

    The arguments typically come from a client-supplied decrypted JSON
    payload, so they may be missing or of any JSON type. Anything other than a
    non-empty string is rejected with ``False`` instead of raising.

    Args:
        api_key: API key claimed by the client.
        user_id: User name that must match the record stored for ``api_key``.

    Returns:
        True if ``api_key`` is known and belongs to ``user_id``, else False.
    """
    if not (isinstance(api_key, str) and isinstance(user_id, str)) or not api_key or not user_id:
        return False

    db_entry = MOCK_USER_DATABASE.get(api_key)
    if db_entry and db_entry.get("user") == user_id:
        # %r quotes the name and escapes control characters (log-injection safe).
        logger.info("Authentication successful for user %r.", user_id)
        return True

    # Do not log any part of the secret API key.
    logger.warning("Authentication failed for user %r.", user_id)
    return False
|
|
|
def get_public_key():
    """Return the server's public key as a PEM (SubjectPublicKeyInfo) string.

    Raises:
        gr.Error: if no usable private key was loaded at startup.
    """
    # Validate the key object and the PEM header rather than searching the
    # string for "Error": the base64 body of a valid PEM can contain it.
    if PRIVATE_KEY_OBJECT is None or not PUBLIC_KEY_PEM_STRING.startswith("-----BEGIN PUBLIC KEY-----"):
        raise gr.Error("Server key is not configured correctly.")
    return PUBLIC_KEY_PEM_STRING
|
|
|
def get_server_info():
    """Describe this server: name, version, usage summary and endpoints.

    Returns:
        dict: a freshly built metadata dict. Its ``endpoints`` entry maps each
        route to a "METHOD - description" string.
    """
    endpoints = {}
    endpoints["/keylock-pub"] = "GET - Returns the server's public key."
    endpoints["/keylock-info"] = "GET - Returns this information object."
    endpoints["/keylock-server"] = "POST - Decrypts a KeyLock image and attempts authentication."

    summary = (
        "This server decrypts data hidden in KeyLock images and performs a mock authentication. "
        "Use /keylock-pub to get the public key, and POST to /keylock-server with a base64 image "
        "string to decrypt and authenticate."
    )

    return dict(
        name="KeyLock Auth Server",
        version="1.3",
        documentation=summary,
        endpoints=endpoints,
    )
|
|
|
def _extract_lsb_payload(pixel_data: np.ndarray) -> bytes:
    """Read the length-prefixed byte string hidden in the LSBs of ``pixel_data``.

    The first HEADER_BITS least-significant bits hold the payload length in
    bytes as a big-endian unsigned integer. The next ``length * 8`` LSBs hold
    the payload, most significant bit first.

    Raises:
        ValueError: if the image is too small for the header, the length is
            zero, or the image holds fewer bits than the header announces.
    """
    if pixel_data.size < HEADER_BITS:
        raise ValueError("Image too small to contain a data header.")
    header_bits = pixel_data[:HEADER_BITS] & 1
    data_length = int("".join(str(bit) for bit in header_bits), 2)
    if data_length == 0:
        raise ValueError("No data found in image.")
    end_offset = HEADER_BITS + data_length * 8
    if pixel_data.size < end_offset:
        raise ValueError("Image data corrupt or truncated.")
    # Exactly data_length * 8 bits: packbits (MSB-first by default) yields
    # exactly data_length bytes with no padding, in vectorised native code.
    return np.packbits(pixel_data[HEADER_BITS:end_offset] & 1).tobytes()


def _split_crypto_payload(crypto_payload: bytes) -> tuple:
    """Split the hidden payload into (encrypted AES key, nonce, ciphertext+tag).

    Layout: a 4-byte big-endian length N, then N bytes of RSA-encrypted AES
    key, then AES_GCM_NONCE_SIZE bytes of nonce, then the AES-GCM ciphertext
    with its tag.

    Raises:
        struct.error: if the payload is shorter than the 4-byte length prefix.
        ValueError: if the payload ends before the nonce is complete.
    """
    length_prefix_size = struct.calcsize('>I')
    (encrypted_aes_key_len,) = struct.unpack_from('>I', crypto_payload, 0)
    key_end = length_prefix_size + encrypted_aes_key_len
    nonce_end = key_end + AES_GCM_NONCE_SIZE
    if len(crypto_payload) < nonce_end:
        raise ValueError("Encrypted payload is truncated.")
    return (
        crypto_payload[length_prefix_size:key_end],
        crypto_payload[key_end:nonce_end],
        crypto_payload[nonce_end:],
    )


def decode_data(image_base64_string: str) -> dict:
    """Decrypt the credentials hidden in a KeyLock image and authenticate them.

    Pipeline:
        base64 -> RGB image -> LSB payload -> RSA-OAEP(SHA-256) unwrap of the
        AES key -> AES-GCM decrypt -> JSON object -> example_authenticate().

    Args:
        image_base64_string: Base64-encoded image bytes. A ``data:...,`` URL
            prefix is also accepted and stripped.

    Returns:
        A dict with ``authentication_status`` ("Success"/"Failed"),
        ``message`` and ``decoded_payload``. On success it also includes
        ``granted_permissions``.

    Raises:
        gr.Error: if no private key is configured, the input cannot be decoded
            or decrypted, or an unexpected error occurs.
    """
    if not PRIVATE_KEY_OBJECT:
        error_msg = "Server Error: The API is not configured with a private key."
        logger.error(error_msg)
        raise gr.Error(error_msg)

    try:
        if isinstance(image_base64_string, str) and image_base64_string.startswith("data:"):
            # b64decode silently discards non-alphabet characters, so a leftover
            # "data:image/png;base64," header would corrupt the decoded bytes.
            image_base64_string = image_base64_string.partition(",")[2]
        image_buffer = base64.b64decode(image_base64_string)
        with Image.open(io.BytesIO(image_buffer)) as img:
            pixel_data = np.array(img.convert("RGB")).ravel()

        crypto_payload = _extract_lsb_payload(pixel_data)
        encrypted_aes_key, nonce, ciphertext_with_tag = _split_crypto_payload(crypto_payload)

        recovered_aes_key = PRIVATE_KEY_OBJECT.decrypt(
            encrypted_aes_key,
            padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
        )
        decrypted_bytes = AESGCM(recovered_aes_key).decrypt(nonce, ciphertext_with_tag, None)
        decrypted_payload = json.loads(decrypted_bytes.decode('utf-8'))
        if not isinstance(decrypted_payload, dict):
            raise ValueError("Decrypted payload is not a JSON object.")

        # Log field names only: the payload carries the client's API key.
        logger.info("Successfully decoded payload with fields: %s", sorted(decrypted_payload))

        api_key = decrypted_payload.get('API_KEY')
        user_id = decrypted_payload.get('USER')
        is_authenticated = example_authenticate(api_key=api_key, user_id=user_id)

        if is_authenticated:
            return {
                "authentication_status": "Success",
                "message": f"User '{user_id}' successfully authenticated.",
                "granted_permissions": MOCK_USER_DATABASE[api_key]['permissions'],
                "decoded_payload": decrypted_payload,
            }
        return {
            "authentication_status": "Failed",
            "message": "Authentication failed. Invalid credentials provided in the image.",
            "decoded_payload": decrypted_payload,
        }

    except (ValueError, InvalidTag, TypeError, struct.error, OSError) as e:
        # Client-side problems: bad base64 (binascii.Error is a ValueError),
        # undecodable image (PIL's UnidentifiedImageError is an OSError),
        # malformed payload, wrong key or tampered ciphertext (InvalidTag).
        logger.warning(f"Decryption failed: {e}")
        raise gr.Error(f"Decryption failed. Image may be corrupt or used the wrong public key. Details: {e}")
    except Exception as e:
        logger.error(f"An unexpected server error occurred during decryption: {e}", exc_info=True)
        raise gr.Error(f"An unexpected server error occurred. Details: {e}")
|
|
|
def generate_rsa_keys(*, key_size: int = 2048):
    """Generate a fresh RSA key pair and return it as PEM strings.

    Args:
        key_size: RSA modulus size in bits. It is keyword-only so positional
            callers are unaffected. The default of 2048 is the size this
            function always used.

    Returns:
        tuple[str, str]: ``(private_pem, public_pem)``. The private key is an
        unencrypted PKCS#8 PEM, matching the ``password=None`` used when the
        server loads its key. The public key is a SubjectPublicKeyInfo PEM.
    """
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size)
    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode('utf-8')
    public_pem = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    ).decode('utf-8')
    return private_pem, public_pem