File size: 5,677 Bytes
007d70f
358d2a9
007d70f
c7ebd48
 
 
 
 
 
 
358d2a9
 
 
 
c7ebd48
 
 
 
 
 
358d2a9
fb2a718
358d2a9
 
c7ebd48
 
358d2a9
c7ebd48
 
 
 
 
358d2a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7ebd48
358d2a9
 
 
 
 
 
c7ebd48
 
 
358d2a9
 
 
 
 
 
c7ebd48
358d2a9
 
 
 
c7ebd48
358d2a9
c7ebd48
 
358d2a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7ebd48
 
 
 
358d2a9
 
 
c7ebd48
358d2a9
 
 
 
c7ebd48
 
 
 
 
 
 
 
358d2a9
c7ebd48
 
 
358d2a9
c7ebd48
358d2a9
 
 
c7ebd48
358d2a9
 
c7ebd48
358d2a9
c7ebd48
 
358d2a9
 
 
 
 
c7ebd48
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'  # Set local writable cache

from flask import Flask, request, jsonify
from flask_cors import CORS
from flask_sock import Sock
import uuid
import time
import requests
from transformers import pipeline
from Crypto.Cipher import AES
from Crypto.Hash import SHA256
import base64
import threading

# Flask application, CORS enabled with flask_cors defaults, plus the
# flask_sock WebSocket extension used by the /ws/<session_id> route.
app = Flask(__name__)
CORS(app)
sock = Sock(app)

# Zero-shot classification pipeline using a lightweight model.
# Constructed once at import time, so the model is loaded before serving starts.
classifier = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-1")

# Active chat sessions, keyed by session id:
# {session_id: {password, created_at, messages, flagged, flagged_messages, connections}}
SESSIONS = {}

# Candidate labels each decrypted message is scored against
SENSITIVE_LABELS = ["terrorism", "blackmail", "national security threat"]

# Storage space API endpoint that flagged sessions are POSTed to
STORAGE_API = "https://mike23415-storage.hf.space/api/flag"

def decrypt_message(encrypted_b64, password):
    """Decrypt a base64-encoded AES-GCM message using a password-derived key.

    The decoded payload is expected to be laid out as
    ``nonce (12 bytes) || ciphertext || tag (16 bytes)``.
    NOTE(review): this layout (tag appended to the ciphertext) is assumed to
    match what the client produces -- confirm against the client code.

    Args:
        encrypted_b64: Base64 text (or bytes) of the encrypted payload.
        password: Session password; its SHA-256 digest is used as the AES key.

    Returns:
        The decrypted UTF-8 text, or ``None`` if the payload is malformed,
        fails authentication, or cannot be decoded.
    """
    nonce_len = 12
    tag_len = 16
    try:
        # Decode and size-check first so malformed input is rejected cheaply.
        encrypted = base64.b64decode(encrypted_b64)
        if len(encrypted) < nonce_len + tag_len:
            raise ValueError("payload too short to contain nonce and tag")

        # Derive the 256-bit key from the password.
        pw_hash = SHA256.new(password.encode()).digest()

        nonce = encrypted[:nonce_len]
        ciphertext = encrypted[nonce_len:-tag_len]
        tag = encrypted[-tag_len:]

        cipher = AES.new(pw_hash, AES.MODE_GCM, nonce=nonce)
        # Verify the tag instead of decrypting it as if it were ciphertext;
        # a wrong password or tampered message raises ValueError here.
        plaintext = cipher.decrypt_and_verify(ciphertext, tag).decode()
        return plaintext
    except Exception as e:
        # Best-effort by design: callers treat None as "could not decrypt".
        print(f"Decryption failed: {e}")
        return None

def flag_if_sensitive(decrypted_text, ip, session_id, role, encrypted_msg):
    """Classify a decrypted message and record it on the session if flagged.

    The text is scored against ``SENSITIVE_LABELS``; the first label whose
    score exceeds 0.8 marks the session as flagged and appends an entry to
    its ``flagged_messages`` list. At most one entry is added per message.

    Args:
        decrypted_text: Plaintext to classify; falsy values are ignored.
        ip: Remote address of the sender, stored with the flagged entry.
        session_id: Key of the session in ``SESSIONS``.
        role: Role label of the sender within the session.
        encrypted_msg: The original encrypted message, stored alongside the
            plaintext.

    Returns:
        None.
    """
    if not decrypted_text:
        return

    # The session may already have been removed by the cleanup timer while a
    # socket is still open; look it up once, before the costly model call.
    session = SESSIONS.get(session_id)
    if session is None:
        return

    # AI checks for sensitive labels
    result = classifier(decrypted_text, SENSITIVE_LABELS)
    for label, score in zip(result["labels"], result["scores"]):
        if score > 0.8:
            print(f"⚠️ FLAGGED: {label} with score {score}")
            session["flagged"] = True
            session["flagged_messages"].append({
                "encrypted_msg": encrypted_msg,
                "decrypted_msg": decrypted_text,
                "label": label,
                "score": score,
                "role": role,
                "ip": ip,
                "timestamp": time.time(),
            })
            # One flagged entry per message is enough.
            break

def log_flagged_session(session_id):
    """POST a flagged session's contents to the storage API.

    Does nothing when the session is unknown or was never flagged. Network
    and HTTP errors are reported on stdout and otherwise ignored, so callers
    are never interrupted by a logging failure.

    Args:
        session_id: Key of the session in ``SESSIONS``.

    Returns:
        None.
    """
    session = SESSIONS.get(session_id)
    if session is None or not session["flagged"]:
        return

    payload = {
        "session_id": session_id,
        "created_at": session["created_at"],
        "messages": session["messages"],
        # Sorted so the reported list is deterministic across runs.
        "unique_ips": sorted({msg["ip"] for msg in session["messages"]}),
        "flagged_messages": session["flagged_messages"],
    }
    try:
        response = requests.post(STORAGE_API, json=payload, timeout=3)
        # Without this an HTTP 4xx/5xx reply would be reported as success.
        response.raise_for_status()
        print(f"Logged flagged session {session_id}")
    except Exception as e:
        print(f"Failed to log session {session_id}: {e}")

def cleanup_session(session_id):
    """Log the session if it was flagged, then remove it from ``SESSIONS``.

    Unknown session ids are ignored.
    """
    if session_id not in SESSIONS:
        return

    # Logging must happen first: it reads the session out of SESSIONS.
    log_flagged_session(session_id)
    del SESSIONS[session_id]
    print(f"Deleted session {session_id}")

@app.route("/api/create_chat", methods=["POST"])
def create_chat():
    """Create a new chat session and schedule its removal after 15 minutes.

    Reads an optional ``password`` field from the JSON request body; when the
    body is missing, is not valid JSON, or is not a JSON object, the password
    falls back to ``"default"``.

    Returns:
        A JSON response with ``session_id``, ``short_id`` (the first eight
        characters of the session id) and ``short_url``.
    """
    # silent=True yields None instead of an error response for a missing or
    # malformed body; a non-object body (list, string, ...) is ignored too.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    password = data.get("password", "default")  # Default password if none provided

    session_id = str(uuid.uuid4())
    SESSIONS[session_id] = {
        "password": password,
        "created_at": time.time(),
        "messages": [],
        "flagged": False,
        "flagged_messages": [],
        "connections": [],
    }
    # Schedule cleanup after 15 minutes
    threading.Timer(900, cleanup_session, args=[session_id]).start()

    short_id = session_id[:8]  # Simplified short ID
    short_url = f"https://{request.host}/s/{short_id}"
    return jsonify({"session_id": session_id, "short_id": short_id, "short_url": short_url})

@sock.route('/ws/<session_id>')
def chat(ws, session_id):
    """Relay encrypted chat messages between the participants of a session.

    Each received message is stored on the session, decrypted with the
    session password for classification, and then broadcast (still
    encrypted) to every connection registered on the session, including the
    sender's own.

    Args:
        ws: The WebSocket connection for this participant.
        session_id: Key of the session in ``SESSIONS``; an unknown id gets an
            error frame and the socket is closed.
    """
    import json  # local import keeps this block self-contained

    ip = request.remote_addr or "unknown"
    session = SESSIONS.get(session_id)
    if session is None:
        ws.send('{"type": "error", "message": "Session not found"}')
        ws.close()
        return

    # Assign role based on join order: the first connection is the sender,
    # later ones are numbered receivers. The counter lives on the session so
    # the numbering does not depend on how many messages have been sent.
    join_number = session.get("join_count", 0)
    session["join_count"] = join_number + 1
    role = "Sender" if join_number == 0 else f"Receiver {join_number}"
    session["connections"].append(ws)

    try:
        while True:
            msg = ws.receive()
            if msg is None:
                break
            if session_id not in SESSIONS:
                # The cleanup timer removed the session; stop relaying.
                break
            if isinstance(msg, bytes):
                # Binary frames are normalised to text so they can be stored
                # and serialised as JSON.
                msg = msg.decode("utf-8", errors="replace")

            entry = {
                "role": role,
                "encrypted_msg": msg,
                "ip": ip,
                "timestamp": time.time()
            }
            session["messages"].append(entry)

            # Decrypt for AI analysis
            decrypted_text = decrypt_message(msg, session["password"])
            flag_if_sensitive(decrypted_text, ip, session_id, role, msg)

            # Serialise once with proper escaping: quotes or backslashes in
            # the message must not break (or inject into) the JSON frame.
            payload = json.dumps({"role": role, "encrypted_msg": msg})

            # Broadcast encrypted message to all participants. Iterate over a
            # copy because other handlers add/remove connections concurrently.
            for conn in list(session["connections"]):
                try:
                    conn.send(payload)
                except Exception as e:
                    # Best-effort delivery: one dead peer must not stop the rest.
                    print(f"Broadcast to a participant failed: {e}")
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Use the captured session dict: SESSIONS may no longer hold this id.
        if ws in session["connections"]:
            session["connections"].remove(ws)

@app.route("/")
def root():
    """Liveness endpoint: return a fixed plain-text status message."""
    status_message = "Real-time AI chat backend is running."
    return status_message

# When executed as a script, start Flask's built-in server listening on all
# interfaces at port 7860.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)