Spaces:
Running
Running
Commit
·
2249567
1
Parent(s):
a2a66cd
Modify Elo calculation
Browse files
app.py
CHANGED
@@ -662,62 +662,111 @@ def generate_tts():
|
|
662 |
)
|
663 |
# --- End Cache Check ---
|
664 |
|
665 |
-
# --- Cache Miss:
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
721 |
|
722 |
|
723 |
@app.route("/api/tts/audio/<session_id>/<model_key>")
|
|
|
662 |
)
|
663 |
# --- End Cache Check ---
|
664 |
|
665 |
+
# --- Cache Miss: Local File Cache ---
|
666 |
+
# 对于预置文本和预置prompt,检查本地缓存
|
667 |
+
if text in predefined_texts and prompt_md5 in predefined_prompts.values():
|
668 |
+
app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
|
669 |
+
available_models = Model.query.filter_by(
|
670 |
+
model_type=ModelType.TTS, is_active=True
|
671 |
+
).all()
|
672 |
+
if len(available_models) < 2:
|
673 |
+
return jsonify({"error": "Not enough TTS models available"}), 500
|
674 |
+
|
675 |
+
# 新增:a和b模型都需通过缓存检测
|
676 |
+
candidate_models = available_models.copy()
|
677 |
+
valid_models = []
|
678 |
+
invalid_models = []
|
679 |
+
for model in candidate_models:
|
680 |
+
audio_path = find_cached_audio(model.name, text, prompt_audio_path=reference_audio_path)
|
681 |
+
if audio_path and os.path.exists(audio_path):
|
682 |
+
valid_models.append(model)
|
683 |
+
else:
|
684 |
+
invalid_models.append(model)
|
685 |
+
|
686 |
+
if len(valid_models) < 2:
|
687 |
+
return jsonify({"error": "Not enough valid TTS model results available"}), 500
|
688 |
+
|
689 |
+
apply_filter_penalty_and_redistribute(invalid_models, valid_models, penalty_amount=1.0)
|
690 |
+
|
691 |
+
# 从有结果的模型中随机选择两个
|
692 |
+
model_a,model_b = random.sample(valid_models, 2)
|
693 |
+
audio_a_path = find_cached_audio(model_a.name, text, prompt_audio_path=reference_audio_path)
|
694 |
+
audio_b_path = find_cached_audio(model_b.name, text, prompt_audio_path=reference_audio_path)
|
695 |
+
|
696 |
+
session_id = str(uuid.uuid4())
|
697 |
+
app.tts_sessions[session_id] = {
|
698 |
+
"model_a": model_a.id,
|
699 |
+
"model_b": model_b.id,
|
700 |
+
"audio_a": audio_a_path,
|
701 |
+
"audio_b": audio_b_path,
|
702 |
+
"text": text,
|
703 |
+
"created_at": datetime.utcnow(),
|
704 |
+
"expires_at": datetime.utcnow() + timedelta(minutes=30),
|
705 |
+
"voted": False,
|
706 |
+
}
|
707 |
+
# 清理临时参考音频文件
|
708 |
+
if reference_audio_path and os.path.exists(reference_audio_path):
|
709 |
+
os.remove(reference_audio_path)
|
710 |
+
return jsonify({
|
711 |
+
"session_id": session_id,
|
712 |
+
"audio_a": f"/api/tts/audio/{session_id}/a",
|
713 |
+
"audio_b": f"/api/tts/audio/{session_id}/b",
|
714 |
+
"expires_in": 1800,
|
715 |
+
"cache_hit": True,
|
716 |
+
})
|
717 |
+
# --- End Cache Miss ---
|
718 |
+
else:
|
719 |
+
app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
|
720 |
+
available_models = Model.query.filter_by(
|
721 |
+
model_type=ModelType.TTS, is_active=True
|
722 |
+
).all()
|
723 |
+
if len(available_models) < 2:
|
724 |
+
return jsonify({"error": "Not enough TTS models available"}), 500
|
725 |
+
|
726 |
+
# Get two random models with weighted selection
|
727 |
+
models = get_weighted_random_models(available_models, 2, ModelType.TTS)
|
728 |
+
|
729 |
+
# Generate audio concurrently using a local executor for clarity within the request
|
730 |
+
with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
|
731 |
+
future_a = audio_executor.submit(generate_and_save_tts, text, models[0].id, RUNTIME_CACHE_DIR,
|
732 |
+
prompt_audio_path=reference_audio_path)
|
733 |
+
future_b = audio_executor.submit(generate_and_save_tts, text, models[1].id, RUNTIME_CACHE_DIR,
|
734 |
+
prompt_audio_path=reference_audio_path)
|
735 |
+
|
736 |
+
timeout_seconds = 120
|
737 |
+
audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
|
738 |
+
audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)
|
739 |
+
|
740 |
+
if not audio_a_path or not audio_b_path:
|
741 |
+
return jsonify({"error": "Failed to generate TTS audio"}), 500
|
742 |
+
|
743 |
+
session_id = str(uuid.uuid4())
|
744 |
+
app.tts_sessions[session_id] = {
|
745 |
+
"model_a": models[0].id,
|
746 |
+
"model_b": models[1].id,
|
747 |
+
"audio_a": audio_a_path,
|
748 |
+
"audio_b": audio_b_path,
|
749 |
+
"text": text,
|
750 |
+
"created_at": datetime.utcnow(),
|
751 |
+
"expires_at": datetime.utcnow() + timedelta(minutes=30),
|
752 |
+
"voted": False,
|
753 |
+
}
|
754 |
+
|
755 |
+
# Clean up temporary reference audio file if it was provided
|
756 |
+
if reference_audio_path and os.path.exists(reference_audio_path):
|
757 |
+
os.remove(reference_audio_path)
|
758 |
+
|
759 |
+
# Return response with session ID and audio URLs
|
760 |
+
return jsonify(
|
761 |
+
{
|
762 |
+
"session_id": session_id,
|
763 |
+
"audio_a": f"/api/tts/audio/{session_id}/a",
|
764 |
+
"audio_b": f"/api/tts/audio/{session_id}/b",
|
765 |
+
"expires_in": 1800, # 30 minutes in seconds
|
766 |
+
"cache_hit": False,
|
767 |
+
}
|
768 |
+
)
|
769 |
+
|
770 |
|
771 |
|
772 |
@app.route("/api/tts/audio/<session_id>/<model_key>")
|
models.py
CHANGED
@@ -84,6 +84,7 @@ class EloHistory(db.Model):
|
|
84 |
model_id = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
|
85 |
timestamp = db.Column(db.DateTime, default=datetime.utcnow)
|
86 |
elo_score = db.Column(db.Float, nullable=False)
|
|
|
87 |
vote_id = db.Column(db.Integer, db.ForeignKey("vote.id"), nullable=True)
|
88 |
model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
|
89 |
|
@@ -130,18 +131,18 @@ def record_vote(user_id, text, chosen_model_id, rejected_model_id, model_type):
|
|
130 |
db.session.rollback()
|
131 |
return None, "One or both models not found for the specified model type"
|
132 |
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
chosen_model.current_elo, rejected_model.current_elo, k_factor_winner, k_factor_loser
|
139 |
)
|
140 |
|
141 |
-
# new_chosen_elo, new_rejected_elo = calculate_elo_change(
|
142 |
-
# chosen_model.current_elo, rejected_model.current_elo
|
143 |
-
# )
|
144 |
-
|
145 |
# Update model stats
|
146 |
chosen_model.current_elo = new_chosen_elo
|
147 |
chosen_model.win_count += 1
|
@@ -535,32 +536,69 @@ def toggle_user_leaderboard_visibility(user_id):
|
|
535 |
return user.show_in_leaderboard
|
536 |
|
537 |
|
538 |
-
def get_dynamic_k_factor(match_count):
|
539 |
-
"""
|
540 |
-
使用连续衰减函数动态计算K因子。
|
541 |
-
K因子会从一个最大值平滑地衰减到一个最小值。
|
542 |
-
|
543 |
-
Args:
|
544 |
-
match_count (int): 模型的总比赛次数。
|
545 |
-
|
546 |
-
Returns:
|
547 |
-
float: 计算出的K因子。
|
548 |
-
"""
|
549 |
k_max = 40 # 新模型的最大K因子
|
550 |
k_min = 10 # 成熟模型的最小K因子
|
551 |
-
|
|
|
|
|
|
|
|
|
552 |
|
553 |
-
#
|
554 |
-
|
|
|
|
|
|
|
|
|
555 |
|
556 |
return k_factor
|
557 |
|
558 |
-
def
|
559 |
-
"""
|
560 |
-
|
561 |
-
|
562 |
|
563 |
-
|
564 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
565 |
|
566 |
-
|
|
|
|
84 |
model_id = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
|
85 |
timestamp = db.Column(db.DateTime, default=datetime.utcnow)
|
86 |
elo_score = db.Column(db.Float, nullable=False)
|
87 |
+
by_system = db.Column(db.Boolean, default=False) # Whether this is a penalty or reward change
|
88 |
vote_id = db.Column(db.Integer, db.ForeignKey("vote.id"), nullable=True)
|
89 |
model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
|
90 |
|
|
|
131 |
db.session.rollback()
|
132 |
return None, "One or both models not found for the specified model type"
|
133 |
|
134 |
+
# --- ELO 计算逻辑与 test_elo.py 保持一致 ---
|
135 |
+
# a. 计算双方的基础动态K因子
|
136 |
+
max_match = max(chosen_model.match_count, rejected_model.match_count, 10)
|
137 |
+
k_winner_base = get_dynamic_k_factor(chosen_model.match_count, max_match)
|
138 |
+
k_loser_base = get_dynamic_k_factor(rejected_model.match_count, max_match)
|
139 |
+
# b. 取平均K因子
|
140 |
+
base_k = (k_winner_base + k_loser_base) / 2.0
|
141 |
|
142 |
+
new_chosen_elo, new_rejected_elo = calculate_elo_change(
|
143 |
+
chosen_model.current_elo, rejected_model.current_elo, k_factor=base_k
|
|
|
144 |
)
|
145 |
|
|
|
|
|
|
|
|
|
146 |
# Update model stats
|
147 |
chosen_model.current_elo = new_chosen_elo
|
148 |
chosen_model.win_count += 1
|
|
|
536 |
return user.show_in_leaderboard
|
537 |
|
538 |
|
539 |
+
def get_dynamic_k_factor(match_count, max_match_count):
    """Compute an Elo K-factor that decays with a model's relative experience.

    The K-factor starts at ``k_max`` for a model with no matches and decays
    exponentially towards ``k_min`` as ``match_count`` approaches
    ``max_match_count``:

        K = k_min + (k_max - k_min) * exp(-decay_factor * relative_progress)

    where ``relative_progress = match_count / max_match_count``.

    Args:
        match_count: Number of matches the model has played. Negative values
            are treated as 0 so the result never exceeds ``k_max``.
        max_match_count: Reference match count used to normalise
            ``match_count``. A non-positive value is treated as "no history"
            and yields ``k_max``.

    Returns:
        The K-factor, within ``[k_min, k_max]``.
    """
    k_max = 40  # maximum K-factor, used for brand-new models
    k_min = 10  # minimum K-factor, approached by mature models
    decay_factor = 5.0  # controls how quickly K drops as progress grows

    # Guard against division by zero; a negative reference count would flip
    # the sign of the exponent and blow K up far beyond k_max.
    if max_match_count <= 0:
        return k_max

    # Relative progress; floored at 0 for the same reason as above.
    relative_progress = max(match_count, 0) / max_match_count

    # Exponential decay based on relative (not absolute) progress.
    k_factor = k_min + (k_max - k_min) * math.exp(-decay_factor * relative_progress)

    return k_factor
|
556 |
|
557 |
+
def _apply_system_elo_delta(model, delta):
    """Shift ``model.current_elo`` by ``delta`` and queue a system history row."""
    new_elo = model.current_elo + delta
    model.current_elo = new_elo
    # System-driven adjustments are not tied to a vote, hence vote_id=None.
    db.session.add(
        EloHistory(
            model_id=model.id,
            elo_score=new_elo,
            vote_id=None,
            by_system=True,
            model_type=model.model_type,
        )
    )


def apply_filter_penalty_and_redistribute(unavailable_models, available_models, penalty_amount=1.0):
    """Penalise unavailable models and share the deducted Elo among available ones.

    Each model in ``unavailable_models`` loses ``penalty_amount`` Elo. The total
    deducted is split evenly across ``available_models``, so the sum of all
    Elo scores is unchanged (zero-sum). Every adjustment is recorded as an
    ``EloHistory`` row with ``by_system=True`` and no ``vote_id``, and all
    changes are committed in a single transaction.

    Args:
        unavailable_models (list[Model]): Models that were filtered out.
        available_models (list[Model]): Models that are currently usable and
            receive the redistributed points.
        penalty_amount (float): Elo deducted from each unavailable model.

    Returns:
        None. Does nothing when either list is empty, because there is then
        nobody to penalise or nobody to receive the points.

    Raises:
        Exception: Whatever ``db.session.commit()`` raises is re-raised after
            the session has been rolled back.
    """
    if not unavailable_models or not available_models:
        return

    # 1. Total penalty and the equal share each available model receives.
    total_penalty = len(unavailable_models) * penalty_amount
    reward_per_model = total_penalty / len(available_models)

    # 2. Deduct from every unavailable model.
    for model in unavailable_models:
        _apply_system_elo_delta(model, -penalty_amount)

    # 3. Reward every available model.
    for model in available_models:
        _apply_system_elo_delta(model, reward_per_model)

    # 4. Persist everything at once; roll back on failure so the shared
    #    session is not left in a failed state for later requests.
    try:
        db.session.commit()
    except Exception:
        db.session.rollback()
        raise
|