Spaces:
Runtime error
Runtime error
Michael Natanael
Pengguna mengunggah file audio non-lirik atau rusak. Muncul pesan error bahwa file yang diunggah
8da3a98
| from flask import ( | |
| Flask, | |
| render_template, | |
| request, | |
| url_for, | |
| redirect, | |
| flash, | |
| get_flashed_messages, | |
| ) | |
| from flask_login import ( | |
| LoginManager, | |
| login_user, | |
| logout_user, | |
| login_required, | |
| current_user, | |
| ) | |
| from flask_sqlalchemy import SQLAlchemy | |
| from flask_login import UserMixin | |
| from werkzeug.security import generate_password_hash, check_password_hash | |
| from faster_whisper import WhisperModel | |
| from groq import Groq | |
| import tempfile | |
| import os | |
| import datetime | |
| import time | |
| import torch | |
| import numpy as np | |
| import requests | |
| from tqdm import tqdm | |
| from transformers import BertTokenizer | |
| from model.multi_class_model import MultiClassModel | |
| # from model.database import db, User | |
| from sqlalchemy.exc import OperationalError | |
| from sqlalchemy import inspect | |
| app = Flask(__name__) | |
| # === CONFIG === | |
| # CHECKPOINT_URL = "https://github.com/michael2002porto/bert_classification_indonesian_song_lyrics/releases/download/finetuned_checkpoints/original_split_synthesized.ckpt" | |
| CHECKPOINT_URL = "https://huggingface.co/nenafem/original_split_synthesized/resolve/main/original_split_synthesized.ckpt?download=true" | |
| CHECKPOINT_PATH = "final_checkpoint/original_split_synthesized.ckpt" | |
| AGE_LABELS = ["semua usia", "anak", "remaja", "dewasa"] | |
| DATABASE_URI = "postgresql://postgres.tcqmmongiztvqkxxebnc:I1Nnj0H72Z3mXWcp@aws-0-ap-southeast-1.pooler.supabase.com:6543/postgres" | |
| # === CONNECT DATABASE === | |
| app.config["SQLALCHEMY_DATABASE_URI"] = DATABASE_URI | |
| app.config["SECRET_KEY"] = "I1Nnj0H72Z3mXWcp" | |
| # init extensions | |
| db = SQLAlchemy(app) | |
| login_manager = LoginManager(app) | |
| login_manager.login_view = "login" | |
| try: | |
| db.session.execute("SELECT 1") | |
| print("✅ Database connected successfully.") | |
| except OperationalError as e: | |
| print(f"❌ Database connection failed: {e}") | |
| def show_schema_info(): | |
| inspector = inspect(db.engine) | |
| # Get current schema (by default it's 'public' unless set explicitly) | |
| current_schema = db.engine.url.database | |
| all_schemas = inspector.get_schema_names() | |
| public_tables = inspector.get_table_names(schema="public") | |
| return { | |
| "current_schema": current_schema, | |
| "available_schemas": all_schemas, | |
| "public_tables": public_tables, | |
| } | |
| class User(db.Model, UserMixin): | |
| __tablename__ = "user" | |
| id = db.Column(db.Integer, primary_key=True) | |
| email = db.Column(db.String(255), nullable=False) | |
| password = db.Column(db.String(255)) | |
| created_date = db.Column(db.DateTime, default=datetime.datetime.now()) | |
| history = db.relationship("History", backref="user", lazy=True) | |
| class History(db.Model): | |
| __tablename__ = "history" | |
| id = db.Column(db.Integer, primary_key=True) | |
| lyric = db.Column(db.Text, nullable=False) | |
| predicted_label = db.Column(db.String(255), nullable=False) | |
| children_prob = db.Column(db.Float) | |
| adolescents_prob = db.Column(db.Float) | |
| adults_prob = db.Column(db.Float) | |
| all_ages_prob = db.Column(db.Float) | |
| processing_time = db.Column(db.Float) # store duration in seconds | |
| created_date = db.Column(db.DateTime, default=datetime.datetime.now) | |
| speech_to_text = db.Column(db.Boolean) | |
| user_id = db.Column(db.Integer, db.ForeignKey("user.id")) | |
| # Load user for Flask-Login | |
| def load_user(user_id): | |
| return User.query.get(int(user_id)) | |
| # === FUNCTION TO DOWNLOAD CKPT IF NEEDED === | |
| def download_checkpoint_if_needed(url, save_path): | |
| if not os.path.exists(save_path): | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| print(f"📥 Downloading model checkpoint from {url}...") | |
| response = requests.get(url, stream=True, timeout=10) | |
| if response.status_code == 200: | |
| total = int(response.headers.get("content-length", 0)) | |
| with open(save_path, "wb") as f, tqdm( | |
| total=total, unit="B", unit_scale=True, desc="Downloading" | |
| ) as pbar: | |
| for chunk in response.iter_content(1024): | |
| f.write(chunk) | |
| pbar.update(len(chunk)) | |
| print("✅ Checkpoint downloaded!") | |
| else: | |
| raise Exception(f"❌ Failed to download: {response.status_code}") | |
| # === INITIAL SETUP: Download & Load Model === | |
| print(show_schema_info()) | |
| download_checkpoint_if_needed(CHECKPOINT_URL, CHECKPOINT_PATH) | |
| # Load groq | |
| client = Groq(api_key="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") | |
| # Load tokenizer | |
| tokenizer = BertTokenizer.from_pretrained("indolem/indobert-base-uncased") | |
| # Load model from checkpoint | |
| model = MultiClassModel.load_from_checkpoint( | |
| CHECKPOINT_PATH, n_out=4, dropout=0.3, lr=1e-5 | |
| ) | |
| model.eval() | |
| # === INITIAL SETUP: Faster Whisper === | |
| # https://github.com/SYSTRAN/faster-whisper | |
| # faster_whisper_model_size = "large-v3" | |
| faster_whisper_model_size = "turbo" | |
| # Run on GPU with FP16 | |
| # model = WhisperModel(model_size, device="cuda", compute_type="float16") | |
| # or run on GPU with INT8 | |
| # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16") | |
| # or run on CPU with INT8 | |
| faster_whisper_model = WhisperModel( | |
| faster_whisper_model_size, device="cpu", compute_type="int8" | |
| ) | |
| def faster_whisper(temp_audio_path): | |
| segments, info = faster_whisper_model.transcribe( | |
| temp_audio_path, | |
| language="id", | |
| beam_size=1, # Lower beam_size, faster but may miss words | |
| ) | |
| print( | |
| "Detected language '%s' with probability %f" | |
| % (info.language, info.language_probability) | |
| ) | |
| # for segment in segments: | |
| # print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text)) | |
| return " ".join(segment.text for segment in segments) | |
| def bert_predict(input_lyric): | |
| encoding = tokenizer.encode_plus( | |
| input_lyric, | |
| add_special_tokens=True, | |
| max_length=512, | |
| truncation=True, # Ensures input ≤512 tokens | |
| return_token_type_ids=True, | |
| padding="max_length", | |
| return_attention_mask=True, | |
| return_tensors="pt", | |
| ) | |
| with torch.no_grad(): | |
| prediction = model( | |
| encoding["input_ids"], | |
| encoding["attention_mask"], | |
| encoding["token_type_ids"], | |
| ) | |
| logits = prediction | |
| probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten() | |
| predicted_class = np.argmax(probabilities) | |
| predicted_label = AGE_LABELS[predicted_class] | |
| prob_results = [ | |
| (label, f"{prob:.4f}") for label, prob in zip(AGE_LABELS, probabilities) | |
| ] | |
| return predicted_label, prob_results | |
| # === ROUTES === | |
| def index(): | |
| return render_template("index.html") | |
| def transcribe(): | |
| try: | |
| # Load Whisper with Indonesian language support (large / turbo) | |
| # https://github.com/openai/whisper | |
| # whisper_model = whisper.load_model("large") | |
| # Start measuring time | |
| start_time = time.time() | |
| audio_file = request.files["file"] | |
| if audio_file: | |
| # Save uploaded audio to temp file | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio: | |
| temp_audio.write(audio_file.read()) | |
| temp_audio_path = temp_audio.name | |
| # Step 1: Transcribe | |
| # transcribed_text = faster_whisper(temp_audio_path).strip() | |
| with open(temp_audio_path, "rb") as file: | |
| transcription = client.audio.transcriptions.create( | |
| file=(temp_audio_path, file.read()), | |
| model="whisper-large-v3", | |
| prompt="Transkripsikan hanya bagian lirik lagu saja", | |
| language="id", | |
| response_format="verbose_json", | |
| temperature=0, | |
| ) | |
| transcribed_text = transcription.text.strip() | |
| os.remove(temp_audio_path) | |
| # Step 2: BERT Prediction | |
| predicted_label, prob_results = bert_predict(transcribed_text) | |
| # Stop timer | |
| end_time = time.time() | |
| total_time = end_time - start_time | |
| formatted_time = f"{total_time:.2f} seconds" | |
| # Insert log prediction | |
| new_prediction_history = History( | |
| lyric=transcribed_text, | |
| predicted_label=predicted_label, | |
| children_prob=prob_results[AGE_LABELS.index("anak")][1], | |
| adolescents_prob=prob_results[AGE_LABELS.index("remaja")][1], | |
| adults_prob=prob_results[AGE_LABELS.index("dewasa")][1], | |
| all_ages_prob=prob_results[AGE_LABELS.index("semua usia")][1], | |
| processing_time=round(total_time, 2), | |
| speech_to_text=True, | |
| user_id=current_user.id if current_user.is_authenticated else None, | |
| ) | |
| db.session.add(new_prediction_history) | |
| db.session.commit() | |
| return render_template( | |
| "transcribe.html", | |
| task=transcribed_text, | |
| prediction=predicted_label, | |
| probabilities=prob_results, | |
| total_time=formatted_time, | |
| ) | |
| except Exception as e: | |
| print("Error:", e) | |
| return str(e) | |
| def predict_text(): | |
| try: | |
| user_lyrics = request.form.get("lyrics", "").strip() | |
| if not user_lyrics: | |
| return "No lyrics provided.", 400 | |
| # Start timer | |
| start_time = time.time() | |
| # Step 1: BERT Prediction | |
| predicted_label, prob_results = bert_predict(user_lyrics) | |
| # End timer | |
| end_time = time.time() | |
| total_time = end_time - start_time | |
| formatted_time = f"{total_time:.2f} seconds" | |
| # Insert log prediction | |
| new_prediction_history = History( | |
| lyric=user_lyrics, | |
| predicted_label=predicted_label, | |
| children_prob=prob_results[AGE_LABELS.index("anak")][1], | |
| adolescents_prob=prob_results[AGE_LABELS.index("remaja")][1], | |
| adults_prob=prob_results[AGE_LABELS.index("dewasa")][1], | |
| all_ages_prob=prob_results[AGE_LABELS.index("semua usia")][1], | |
| processing_time=round(total_time, 2), | |
| user_id=current_user.id if current_user.is_authenticated else None, | |
| ) | |
| db.session.add(new_prediction_history) | |
| db.session.commit() | |
| return render_template( | |
| "transcribe.html", | |
| task=user_lyrics, | |
| prediction=predicted_label, | |
| probabilities=prob_results, | |
| total_time=formatted_time, | |
| ) | |
| except Exception as e: | |
| print("❌ Error in predict-text:", e) | |
| return str(e), 500 | |
| def register(): | |
| if request.method == "POST": | |
| email = request.form.get("email") | |
| password = request.form.get("password") | |
| confirm_password = request.form.get("confirm-password") | |
| if User.query.filter_by(email=email).first(): | |
| return render_template( | |
| "register.html", | |
| error="Email already taken!", | |
| email=email, | |
| password=password, | |
| confirm_password=confirm_password, | |
| ) | |
| if password != confirm_password: | |
| return render_template( | |
| "register.html", | |
| error="Password does not match!", | |
| email=email, | |
| password=password, | |
| confirm_password=confirm_password, | |
| ) | |
| hashed_password = generate_password_hash(password, method="pbkdf2:sha256") | |
| new_user = User(email=email, password=hashed_password) | |
| db.session.add(new_user) | |
| db.session.commit() | |
| flash( | |
| "Sign up successful! Please log in.", "success" | |
| ) # Flash the success message | |
| return redirect(url_for("login")) | |
| return render_template("register.html") | |
| def login(): | |
| if request.method == "POST": | |
| email = request.form.get("email") | |
| password = request.form.get("password") | |
| user = User.query.filter_by(email=email).first() | |
| if user and check_password_hash(user.password, password): | |
| login_user(user) | |
| return dashboard(login_alert=True) | |
| else: | |
| return render_template("login.html", error="Invalid email or password") | |
| return render_template("login.html") | |
| def dashboard(login_alert=False): | |
| if login_alert: | |
| flash(current_user.email, "success") | |
| return redirect(url_for("index")) | |
| def logout(): | |
| logout_user() | |
| return redirect(url_for("login")) | |
| def history(): | |
| data_history = ( | |
| History.query.filter_by(user_id=current_user.id) | |
| .order_by(History.created_date.desc()) | |
| .all() | |
| ) | |
| for item in data_history: | |
| item.probabilities = [ | |
| ("anak", f"{item.children_prob:.4f}"), | |
| ("remaja", f"{item.adolescents_prob:.4f}"), | |
| ("dewasa", f"{item.adults_prob:.4f}"), | |
| ("semua usia", f"{item.all_ages_prob:.4f}"), | |
| ] | |
| return render_template("history.html", data_history=data_history) | |
| if __name__ == "__main__": | |
| app.run(debug=True) | |