|
from flask import Flask, render_template_string, request, redirect, url_for, jsonify |
|
from huggingface_hub import snapshot_download |
|
from transformers import pipeline |
|
import os |
|
import threading |
|
|
|
# The WSGI application object; the view functions below register on it via @app.route.
app = Flask(import_name=__name__)
|
|
|
|
|
# Models offered in the UI, in display order, as (label, Hugging Face Hub repo id).
# The label is the value the form submits and also names the local download
# folder under MODEL_DIR, so renaming a label orphans an already-downloaded copy.
# NOTE(review): the "GPT-OSS-20B" label points at Together's
# GPT-NeoXT-Chat-Base-20B repo, not a GPT-OSS checkpoint — confirm which was intended.
# NOTE(review): meta-llama repos are usually gated on the Hub, so downloading may
# require an access token — confirm.
_MODEL_CATALOG = (
    ("GPT-OSS-20B", "togethercomputer/GPT-NeoXT-Chat-Base-20B"),
    ("GPT-Neo-1.3B", "EleutherAI/gpt-neo-1.3B"),
    ("LLaMA-7B", "meta-llama/Llama-2-7b-hf"),
    ("Mistral-7B", "mistralai/Mistral-7B-v0.1"),
)
AVAILABLE_MODELS = dict(_MODEL_CATALOG)
|
|
|
# Root directory for downloaded model snapshots. Each model gets its own
# sub-directory named after its AVAILABLE_MODELS label. A relative path
# resolves against the process's working directory. Set the MODEL_DIR
# environment variable to store snapshots elsewhere; an unset or empty value
# falls back to "models".
MODEL_DIR = os.environ.get("MODEL_DIR") or "models"

# Created eagerly at import time so downloads can write into it; exist_ok
# keeps this idempotent across restarts.
os.makedirs(MODEL_DIR, exist_ok=True)
|
|
|
# Mutable runtime state shared by the request handlers and the download thread.
# current_model: label of the model selected via the "use" form action.
# pipe: the text-generation pipeline built for that model.
# Neither exists until "Use Model" is pressed.
current_model = pipe = None

# Latest background-download state served as JSON by /progress.
# download_model() rebinds this name to a fresh dict; it does not mutate it in place.
download_status = dict(model=None, status="idle")
|
|
|
|
|
def download_model(model_name, repo_id):
    """Download a Hub repository snapshot into ``MODEL_DIR/<model_name>``.

    Intended to run on a background thread. Progress is published by
    rebinding the module-level ``download_status`` dict, which the
    /progress endpoint serves: ``"downloading"`` while running, then
    ``"done"`` or ``"error: <message>"``.

    Args:
        model_name: Display label (expected to be a key of AVAILABLE_MODELS);
            also used as the name of the target sub-directory.
        repo_id: Hugging Face Hub repository id to download.
    """
    global download_status

    download_status = {"model": model_name, "status": "downloading"}
    target_dir = os.path.join(MODEL_DIR, model_name)

    try:
        snapshot_download(repo_id, local_dir=target_dir)
    except Exception as exc:
        # This is the worker thread's top-level boundary. Record a short
        # message for the UI, and keep the full traceback in the app log
        # instead of discarding it.
        app.logger.exception("Download of %s (%s) failed", model_name, repo_id)
        download_status = {"model": model_name, "status": f"error: {exc}"}
    else:
        download_status = {"model": model_name, "status": "done"}
|
|
|
|
|
# Allows only one snapshot download at a time, so repeated "Download Model"
# clicks cannot start several threads writing into the same directory.
# The request handler acquires it; the worker thread releases it when the
# download ends (a threading.Lock may be released from any thread).
_download_lock = threading.Lock()


def _run_download(model_name, repo_id):
    """Worker-thread entry point: run download_model, then free the download slot."""
    try:
        download_model(model_name, repo_id)
    finally:
        _download_lock.release()


def _render(answer=None):
    """Render the main page for the current model, with optional answer text."""
    return render_template_string(
        TEMPLATE, models=AVAILABLE_MODELS, current=current_model, answer=answer
    )


def _start_download(model_name, repo_id):
    """Start a background download unless one is already running; redirect home."""
    if _download_lock.acquire(blocking=False):
        try:
            threading.Thread(target=_run_download, args=(model_name, repo_id)).start()
        except Exception:
            # The worker never started, so it cannot release the lock itself.
            _download_lock.release()
            raise
    # Otherwise a download is already in progress and /progress keeps reporting it.
    return redirect(url_for("index"))


def _use_model(model_name, repo_id):
    """Load *model_name* as the active text-generation pipeline.

    Prefers the local snapshot in ``MODEL_DIR/<model_name>`` when it looks
    complete and is not currently being downloaded. Otherwise it falls back
    to fetching *repo_id* from the Hub. The globals are updated only after a
    successful load, so a failure leaves the previously active model in place
    and shows an error on the page.
    """
    global current_model, pipe

    local_dir = os.path.join(MODEL_DIR, model_name)
    status = download_status  # read once; the download thread rebinds it
    still_downloading = (
        status.get("model") == model_name and status.get("status") == "downloading"
    )
    # Heuristic (assumes transformers-format repos, which ship a config.json):
    # treat the presence of config.json as "snapshot downloaded".
    if not still_downloading and os.path.isfile(os.path.join(local_dir, "config.json")):
        source = local_dir
    else:
        source = repo_id

    try:
        # NOTE: the old pipeline stays referenced until the new one is ready,
        # so memory may briefly hold both models.
        new_pipe = pipeline("text-generation", model=source, device_map="auto")
    except Exception as exc:
        app.logger.exception("Failed to load model %s from %s", model_name, source)
        return _render(answer=f"Failed to load model {model_name}: {exc}")

    pipe = new_pipe
    current_model = model_name
    return redirect(url_for("index"))


def _answer(user_input):
    """Generate a completion for *user_input* with the active pipeline and render it."""
    generator = pipe  # read once: another request may swap the global meanwhile
    if generator is None:
        # User-facing text (Russian): "No model selected."
        return _render(answer="Модель не выбрана.")
    try:
        result = generator(
            user_input, max_new_tokens=200, do_sample=True, temperature=0.7
        )
    except Exception as exc:
        app.logger.exception("Generation failed for model %s", current_model)
        return _render(answer=f"Generation failed: {exc}")
    # NOTE(review): generated_text may include the prompt itself depending on
    # pipeline defaults (return_full_text) — confirm this is the desired display.
    return _render(answer=result[0]["generated_text"])


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the model-manager page and dispatch its form actions.

    A POST carries the name of the submit button that was pressed:
      * "download": start a background download of the selected model;
      * "use": load the selected model as the active pipeline;
      * "ask": run the active pipeline on the submitted prompt.
    Any other request just renders the page.
    """
    if request.method == "POST":
        form = request.form
        if "download" in form or "use" in form:
            model_name = form["model"]
            repo_id = AVAILABLE_MODELS.get(model_name)
            if repo_id is None:
                # Reject labels not offered in the UI (previously a KeyError, i.e.
                # a 500). The name is not echoed, to avoid reflecting input as HTML.
                return "Unknown model", 400
            if "download" in form:
                return _start_download(model_name, repo_id)
            return _use_model(model_name, repo_id)
        if "ask" in form:
            return _answer(form["user_input"])
    return _render()
|
|
|
|
|
@app.route("/progress")
def progress():
    """Report the latest background-download state as a JSON object.

    The payload has the keys "model" and "status". The page script polls
    this endpoint to update its progress box.
    """
    snapshot = download_status  # single read of the module-level state
    return jsonify(**snapshot)
|
|
|
|
|
# Jinja2 page rendered by index() through render_template_string.
# Context variables:
#   models: AVAILABLE_MODELS (label -> repo id)
#   current: active model label or None
#   answer: text to display, or None
# The inline script polls /progress and shows download state in #progress-box.
TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>AI Model Manager</title>
<style>
body { font-family: Arial; margin: 30px; background: #f4f4f4; }
.card { background: white; padding: 20px; margin-bottom: 20px; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
select, input[type=text], button { padding: 10px; margin-top: 10px; width: 100%; }
button { background: #007BFF; color: white; border: none; cursor: pointer; }
button:hover { background: #0056b3; }
#progress-box { margin-top: 10px; padding: 10px; background: #eee; border-radius: 5px; display: none; }
.answer { margin-top: 15px; padding: 10px; background: #e8f0fe; border-radius: 5px; white-space: pre-wrap; }
</style>
</head>
<body>
<h2>AI Model Manager</h2>

<div class="card">
<form method="post">
<label for="model">Choose model:</label>
<select id="model" name="model">
{% for name, repo in models.items() %}
<option value="{{name}}" {% if current == name %}selected{% endif %}>{{name}} ({{repo}})</option>
{% endfor %}
</select>
<button type="submit" name="download">Download Model</button>
<button type="submit" name="use">Use Model</button>
</form>
<p><b>Current model:</b> {{ current if current else "None" }}</p>
<div id="progress-box">Checking status...</div>
</div>

<div class="card">
<form method="post">
<label for="user_input">Ask AI:</label>
<input type="text" id="user_input" name="user_input" placeholder="Enter your question..." required>
<button type="submit" name="ask">Ask</button>
</form>
{% if answer %}
<div class="answer"><b>Answer:</b> {{ answer }}</div>
{% endif %}
</div>

<script>
function checkProgress() {
  fetch("/progress")
    .then(response => response.json())
    .then(data => {
      const box = document.getElementById("progress-box");
      if (data.status === "idle") {
        box.style.display = "none";
        return;
      }
      box.style.display = "block";
      // Use textContent so an exception message in the status is never parsed as HTML.
      if (data.status === "downloading") {
        box.textContent = "Downloading model: " + data.model + "...";
      } else if (data.status === "done") {
        box.textContent = "✅ Model " + data.model + " downloaded successfully!";
      } else {
        box.textContent = "⚠️ Error: " + String(data.status).replace(/^error: */, "");
      }
    })
    .catch(() => {});  // server briefly unreachable: retry on the next tick
}

checkProgress();
setInterval(checkProgress, 2000);
</script>
</body>
</html>
"""
|
|
|
if __name__ == "__main__":
    # The Werkzeug interactive debugger lets anyone who can reach the page
    # execute arbitrary Python. It must not be enabled by default while
    # listening on all interfaces (0.0.0.0). Opt in explicitly, for local
    # development only, with FLASK_DEBUG=1 (or "true").
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(debug=debug_enabled, host="0.0.0.0", port=5000)
|
|