# ai2 / main.py
# Last change by Starchik1: "Update main.py" (commit 728186c)
from flask import Flask, render_template_string, request, redirect, url_for, jsonify
from huggingface_hub import snapshot_download
from transformers import pipeline
import os
import threading
app = Flask(__name__)

# Local store for downloaded checkpoints, one sub-directory per display name.
# Created at import time so background downloads always have a target.
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)

# Models offered in the UI: display name -> Hugging Face Hub repository id.
# NOTE(review): the "GPT-OSS-20B" label points at GPT-NeoXT-Chat-Base-20B;
# confirm which model is intended before renaming (the label is also used as
# the on-disk directory name under MODEL_DIR).
AVAILABLE_MODELS = {
    "GPT-OSS-20B": "togethercomputer/GPT-NeoXT-Chat-Base-20B",
    "GPT-Neo-1.3B": "EleutherAI/gpt-neo-1.3B",
    "LLaMA-7B": "meta-llama/Llama-2-7b-hf",
    "Mistral-7B": "mistralai/Mistral-7B-v0.1",
}

# Shared mutable state, rebound by the request handlers and the download thread.
current_model = None  # display name of the loaded model, or None
pipe = None  # text-generation pipeline for current_model, or None
# Background download progress; "status" is "idle", "downloading", "done",
# or "error: <message>".
download_status = {"model": None, "status": "idle"}
def download_model(model_name, repo_id):
    """Download a Hub repository into the local model store.

    Meant to run on the worker thread started by ``index()``. Copies
    ``repo_id`` into ``MODEL_DIR/<model_name>`` via ``snapshot_download`` and
    reports progress by rebinding the module-level ``download_status`` dict:
    ``"downloading"`` first, then ``"done"`` on success or
    ``"error: <exception message>"`` if the download raised.
    """
    global download_status
    download_status = {"model": model_name, "status": "downloading"}
    try:
        target = os.path.join(MODEL_DIR, model_name)
        snapshot_download(repo_id, local_dir=target)
    except Exception as exc:  # surface every failure to the UI instead of dying silently
        final_state = f"error: {exc}"
    else:
        final_state = "done"
    download_status = {"model": model_name, "status": final_state}
def _lookup_model(form):
    """Return ``(model_name, repo_id)`` for the submitted ``model`` field.

    ``repo_id`` is None when the field is missing or names a model that is not
    in AVAILABLE_MODELS.
    """
    model_name = form.get("model")
    return model_name, AVAILABLE_MODELS.get(model_name)


def _load_pipeline(model_name, repo_id):
    """Build a text-generation pipeline, preferring the copy saved by ``download_model``.

    The directory ``MODEL_DIR/<model_name>`` is used when it exists and that
    model is not currently being downloaded. If loading from it fails because
    files are missing or unusable (e.g. an earlier download was interrupted),
    the model is loaded from the Hub by ``repo_id`` instead. Errors from the
    Hub load propagate to the caller.
    """
    local_dir = os.path.join(MODEL_DIR, model_name)
    status = download_status  # read once: the download thread rebinds this global
    mid_download = status.get("model") == model_name and status.get("status") == "downloading"
    if os.path.isdir(local_dir) and not mid_download:
        try:
            return pipeline("text-generation", model=local_dir, device_map="auto")
        except (OSError, ValueError):
            pass  # incomplete local copy: deliberately fall back to the Hub below
    return pipeline("text-generation", model=repo_id, device_map="auto")


@app.route("/", methods=["GET", "POST"])
def index():
    """Render the model manager page and handle its form actions.

    POST actions are selected by the name of the submit button present in the form:

    * ``download``: start fetching the chosen model into ``MODEL_DIR`` on a
      background thread, then redirect back to the page.
    * ``use``: load the chosen model (the local copy if available, otherwise
      from the Hub) as the current pipeline, then redirect back to the page.
    * ``ask``: run the current pipeline on ``user_input`` and render the
      generated text. If no model is loaded, a "no model selected" message
      is shown instead.

    An unknown or missing model name yields a 400 response. A GET request, or a
    POST without a recognised action, just renders the page.
    """
    global current_model, pipe
    answer = None
    if request.method == "POST":
        if "download" in request.form:
            model_name, repo_id = _lookup_model(request.form)
            if repo_id is None:
                return "Unknown model", 400
            # Download off the request thread; the page polls /progress for the result.
            threading.Thread(target=download_model, args=(model_name, repo_id)).start()
            return redirect(url_for("index"))
        if "use" in request.form:
            model_name, repo_id = _lookup_model(request.form)
            if repo_id is None:
                return "Unknown model", 400
            # Load first, then publish both globals together, so a failed load
            # cannot leave current_model naming a model that `pipe` does not hold.
            new_pipe = _load_pipeline(model_name, repo_id)
            pipe = new_pipe
            current_model = model_name
            return redirect(url_for("index"))
        if "ask" in request.form:
            user_input = request.form.get("user_input", "")
            generator = pipe  # local reference: another request may swap `pipe` meanwhile
            if generator is not None:
                result = generator(user_input, max_new_tokens=200, do_sample=True, temperature=0.7)
                answer = result[0]["generated_text"]
            else:
                answer = "Модель не выбрана."
    return render_template_string(TEMPLATE, models=AVAILABLE_MODELS, current=current_model, answer=answer)
@app.route("/progress")
def progress():
    """Report the latest background-download state as JSON.

    Polled by the page script every few seconds. The payload carries the
    ``model`` and ``status`` keys of ``download_status``.
    """
    state = download_status  # read the global once; download_model rebinds it
    return jsonify(state)
# Single-page UI rendered by index() through render_template_string.
# Template variables: `models` (name -> repo id mapping), `current` (active
# model name or None) and `answer` (text to display, or None). The embedded
# script polls /progress every 2 seconds and shows the download state; the
# status text is inserted with textContent so exception messages are shown
# literally instead of being parsed as HTML.
TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
<title>AI Model Manager</title>
<style>
body { font-family: Arial; margin: 30px; background: #f4f4f4; }
.card { background: white; padding: 20px; margin-bottom: 20px; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
select, input[type=text], button { padding: 10px; margin-top: 10px; width: 100%; }
button { background: #007BFF; color: white; border: none; cursor: pointer; }
button:hover { background: #0056b3; }
#progress-box { margin-top: 10px; padding: 10px; background: #eee; border-radius: 5px; display: none; }
</style>
</head>
<body>
<h2>AI Model Manager</h2>
<div class="card">
<form method="post">
<label for="model">Choose model:</label>
<select name="model" id="model">
{% for name, repo in models.items() %}
<option value="{{name}}" {% if current == name %}selected{% endif %}>{{name}} ({{repo}})</option>
{% endfor %}
</select>
<button type="submit" name="download">Download Model</button>
<button type="submit" name="use">Use Model</button>
</form>
<p><b>Current model:</b> {{ current if current else "None" }}</p>
<div id="progress-box">Checking status...</div>
</div>
<div class="card">
<form method="post">
<label for="user_input">Ask AI:</label>
<input type="text" name="user_input" id="user_input" placeholder="Enter your question..." required>
<button type="submit" name="ask">Ask</button>
</form>
{% if answer %}
<div style="margin-top:15px; padding:10px; background:#e8f0fe; border-radius:5px;">
<b>Answer:</b>
<div style="white-space: pre-wrap;">{{ answer }}</div>
</div>
{% endif %}
</div>
<script>
function checkProgress() {
  fetch("/progress")
    .then(response => response.json())
    .then(data => {
      const box = document.getElementById("progress-box");
      if (data.status === "idle") {
        box.style.display = "none";
        return;
      }
      box.style.display = "block";
      // Plain-text insertion: error strings come from exception messages.
      if (data.status === "downloading") {
        box.textContent = "Downloading model: " + data.model + "...";
      } else if (data.status === "done") {
        box.textContent = "✅ Model " + data.model + " downloaded successfully!";
      } else {
        box.textContent = "⚠️ Error: " + data.status;
      }
    })
    .catch(() => {});  // server briefly unreachable: keep the last message, retry next tick
}
checkProgress();
setInterval(checkProgress, 2000);
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # The Werkzeug interactive debugger executes code submitted from the browser,
    # so it must not be enabled by default on an all-interfaces bind (0.0.0.0).
    # Opt in explicitly with FLASK_DEBUG=1 for local development.
    debug_mode = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(debug=debug_mode, host="0.0.0.0", port=5000)