kemuriririn committed on
Commit
2249567
·
1 Parent(s): a2a66cd

modify elo calculation

Browse files
Files changed (2) hide show
  1. app.py +105 -56
  2. models.py +68 -30
app.py CHANGED
@@ -662,62 +662,111 @@ def generate_tts():
662
  )
663
  # --- End Cache Check ---
664
 
665
- # --- Cache Miss: Generate on the fly ---
666
- app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
667
- available_models = Model.query.filter_by(
668
- model_type=ModelType.TTS, is_active=True
669
- ).all()
670
- if len(available_models) < 2:
671
- return jsonify({"error": "Not enough TTS models available"}), 500
672
-
673
- # 新增:a和b模型都需通过缓存和静音检测
674
- candidate_models = available_models.copy()
675
- random.shuffle(candidate_models)
676
- valid_pairs = []
677
- # 枚举所有模型对,找到第一个都通过的组合
678
- for i in range(len(candidate_models)):
679
- model_a = candidate_models[i]
680
- audio_a_path = find_cached_audio(str(model_a.id), text, reference_audio_path)
681
- app.logger.warning(f"checking {audio_a_path}")
682
- if not audio_a_path or has_long_silence(audio_a_path):
683
- continue
684
- # 检测到a模型音频有效,继续检测b模型
685
- for j in range(i + 1, len(candidate_models)):
686
- model_b = candidate_models[j]
687
- audio_b_path = find_cached_audio(str(model_b.id), text, reference_audio_path)
688
- app.logger.warning(f"checking {audio_b_path}")
689
- if not audio_b_path or has_long_silence(audio_b_path):
690
- continue
691
- valid_pairs.append((model_a, audio_a_path, model_b, audio_b_path))
692
- app.logger.warning(f"Found valid model pair: {model_a.name} and {model_b.name} for text '{text[:50]}...'")
693
- break
694
- if not valid_pairs:
695
- return jsonify({"error": "所有模型均未通过持久化缓存和静音检测,无法生成音频"}), 500
696
-
697
- # 随机选一个合格组合
698
- model_a, audio_a_path, model_b, audio_b_path = random.choice(valid_pairs)
699
- session_id = str(uuid.uuid4())
700
- app.tts_sessions[session_id] = {
701
- "model_a": model_a.id,
702
- "model_b": model_b.id,
703
- "audio_a": audio_a_path,
704
- "audio_b": audio_b_path,
705
- "text": text,
706
- "created_at": datetime.utcnow(),
707
- "expires_at": datetime.utcnow() + timedelta(minutes=30),
708
- "voted": False,
709
- }
710
- # 清理临时参考音频文件
711
- if reference_audio_path and os.path.exists(reference_audio_path):
712
- os.remove(reference_audio_path)
713
- return jsonify({
714
- "session_id": session_id,
715
- "audio_a": f"/api/tts/audio/{session_id}/a",
716
- "audio_b": f"/api/tts/audio/{session_id}/b",
717
- "expires_in": 1800,
718
- "cache_hit": True,
719
- })
720
- # --- End Cache Miss ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
721
 
722
 
723
  @app.route("/api/tts/audio/<session_id>/<model_key>")
 
662
  )
663
  # --- End Cache Check ---
664
 
665
+ # --- Cache Miss: Local File Cache ---
666
+ # 对于预置文本和预置prompt,检查本地缓存
667
+ if text in predefined_texts and prompt_md5 in predefined_prompts.values():
668
+ app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
669
+ available_models = Model.query.filter_by(
670
+ model_type=ModelType.TTS, is_active=True
671
+ ).all()
672
+ if len(available_models) < 2:
673
+ return jsonify({"error": "Not enough TTS models available"}), 500
674
+
675
+ # 新增:a和b模型都需通过缓存检测
676
+ candidate_models = available_models.copy()
677
+ valid_models = []
678
+ invalid_models = []
679
+ for model in candidate_models:
680
+ audio_path = find_cached_audio(model.name, text, prompt_audio_path=reference_audio_path)
681
+ if audio_path and os.path.exists(audio_path):
682
+ valid_models.append(model)
683
+ else:
684
+ invalid_models.append(model)
685
+
686
+ if len(valid_models) < 2:
687
+ return jsonify({"error": "Not enough valid TTS model results available"}), 500
688
+
689
+ apply_filter_penalty_and_redistribute(invalid_models, valid_models, penalty_amount=1.0)
690
+
691
+ # 从有结果的模型中随机选择两个
692
+ model_a,model_b = random.sample(valid_models, 2)
693
+ audio_a_path = find_cached_audio(model_a.name, text, prompt_audio_path=reference_audio_path)
694
+ audio_b_path = find_cached_audio(model_b.name, text, prompt_audio_path=reference_audio_path)
695
+
696
+ session_id = str(uuid.uuid4())
697
+ app.tts_sessions[session_id] = {
698
+ "model_a": model_a.id,
699
+ "model_b": model_b.id,
700
+ "audio_a": audio_a_path,
701
+ "audio_b": audio_b_path,
702
+ "text": text,
703
+ "created_at": datetime.utcnow(),
704
+ "expires_at": datetime.utcnow() + timedelta(minutes=30),
705
+ "voted": False,
706
+ }
707
+ # 清理临时参考音频文件
708
+ if reference_audio_path and os.path.exists(reference_audio_path):
709
+ os.remove(reference_audio_path)
710
+ return jsonify({
711
+ "session_id": session_id,
712
+ "audio_a": f"/api/tts/audio/{session_id}/a",
713
+ "audio_b": f"/api/tts/audio/{session_id}/b",
714
+ "expires_in": 1800,
715
+ "cache_hit": True,
716
+ })
717
+ # --- End Cache Miss ---
718
+ else:
719
+ app.logger.warning(f"TTS Cache MISS for: '{text[:50]}...'. Generating on the fly.")
720
+ available_models = Model.query.filter_by(
721
+ model_type=ModelType.TTS, is_active=True
722
+ ).all()
723
+ if len(available_models) < 2:
724
+ return jsonify({"error": "Not enough TTS models available"}), 500
725
+
726
+ # Get two random models with weighted selection
727
+ models = get_weighted_random_models(available_models, 2, ModelType.TTS)
728
+
729
+ # Generate audio concurrently using a local executor for clarity within the request
730
+ with ThreadPoolExecutor(max_workers=2, thread_name_prefix='AudioGen') as audio_executor:
731
+ future_a = audio_executor.submit(generate_and_save_tts, text, models[0].id, RUNTIME_CACHE_DIR,
732
+ prompt_audio_path=reference_audio_path)
733
+ future_b = audio_executor.submit(generate_and_save_tts, text, models[1].id, RUNTIME_CACHE_DIR,
734
+ prompt_audio_path=reference_audio_path)
735
+
736
+ timeout_seconds = 120
737
+ audio_a_path, ref_a = future_a.result(timeout=timeout_seconds)
738
+ audio_b_path, ref_b = future_b.result(timeout=timeout_seconds)
739
+
740
+ if not audio_a_path or not audio_b_path:
741
+ return jsonify({"error": "Failed to generate TTS audio"}), 500
742
+
743
+ session_id = str(uuid.uuid4())
744
+ app.tts_sessions[session_id] = {
745
+ "model_a": models[0].id,
746
+ "model_b": models[1].id,
747
+ "audio_a": audio_a_path,
748
+ "audio_b": audio_b_path,
749
+ "text": text,
750
+ "created_at": datetime.utcnow(),
751
+ "expires_at": datetime.utcnow() + timedelta(minutes=30),
752
+ "voted": False,
753
+ }
754
+
755
+ # Clean up temporary reference audio file if it was provided
756
+ if reference_audio_path and os.path.exists(reference_audio_path):
757
+ os.remove(reference_audio_path)
758
+
759
+ # Return response with session ID and audio URLs
760
+ return jsonify(
761
+ {
762
+ "session_id": session_id,
763
+ "audio_a": f"/api/tts/audio/{session_id}/a",
764
+ "audio_b": f"/api/tts/audio/{session_id}/b",
765
+ "expires_in": 1800, # 30 minutes in seconds
766
+ "cache_hit": False,
767
+ }
768
+ )
769
+
770
 
771
 
772
  @app.route("/api/tts/audio/<session_id>/<model_key>")
models.py CHANGED
@@ -84,6 +84,7 @@ class EloHistory(db.Model):
84
  model_id = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
85
  timestamp = db.Column(db.DateTime, default=datetime.utcnow)
86
  elo_score = db.Column(db.Float, nullable=False)
 
87
  vote_id = db.Column(db.Integer, db.ForeignKey("vote.id"), nullable=True)
88
  model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
89
 
@@ -130,18 +131,18 @@ def record_vote(user_id, text, chosen_model_id, rejected_model_id, model_type):
130
  db.session.rollback()
131
  return None, "One or both models not found for the specified model type"
132
 
133
- k_factor_winner = get_dynamic_k_factor(chosen_model.match_count)
134
- k_factor_loser = get_dynamic_k_factor(rejected_model.match_count)
 
 
 
 
 
135
 
136
- # Calculate new Elo ratings
137
- new_chosen_elo, new_rejected_elo = calculate_elo_change_dynamic_k(
138
- chosen_model.current_elo, rejected_model.current_elo, k_factor_winner, k_factor_loser
139
  )
140
 
141
- # new_chosen_elo, new_rejected_elo = calculate_elo_change(
142
- # chosen_model.current_elo, rejected_model.current_elo
143
- # )
144
-
145
  # Update model stats
146
  chosen_model.current_elo = new_chosen_elo
147
  chosen_model.win_count += 1
@@ -535,32 +536,69 @@ def toggle_user_leaderboard_visibility(user_id):
535
  return user.show_in_leaderboard
536
 
537
 
538
- def get_dynamic_k_factor(match_count):
539
- """
540
- 使用连续衰减函数动态计算K因子。
541
- K因子会从一个最大值平滑地衰减到一个最小值。
542
-
543
- Args:
544
- match_count (int): 模型的总比赛次数。
545
-
546
- Returns:
547
- float: 计算出的K因子。
548
- """
549
  k_max = 40 # 新模型的最大K因子
550
  k_min = 10 # 成熟模型的最小K因子
551
- decay_speed = 50.0 # 衰减速度,数值越大,K因子下降越慢
 
 
 
 
552
 
553
- # 指数衰减公式: K = K_min + (K_max - K_min) * e^(-match_count / decay_speed)
554
- k_factor = k_min + (k_max - k_min) * math.exp(-match_count / decay_speed)
 
 
 
 
555
 
556
  return k_factor
557
 
558
- def calculate_elo_change_dynamic_k(winner_elo, loser_elo, k_factor_winner, k_factor_loser):
559
- """根据双方不同的K因子计算Elo等级分变化。"""
560
- expected_winner = 1 / (1 + math.pow(10, (loser_elo - winner_elo) / 400))
561
- expected_loser = 1 / (1 + math.pow(10, (winner_elo - loser_elo) / 400))
562
 
563
- winner_new_elo = winner_elo + k_factor_winner * (1 - expected_winner)
564
- loser_new_elo = loser_elo + k_factor_loser * (0 - expected_loser)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
566
- return winner_new_elo, loser_new_elo
 
 
84
  model_id = db.Column(db.String(100), db.ForeignKey("model.id"), nullable=False)
85
  timestamp = db.Column(db.DateTime, default=datetime.utcnow)
86
  elo_score = db.Column(db.Float, nullable=False)
87
+ by_system = db.Column(db.Boolean, default=False) # Whether this is a penalty or reward change
88
  vote_id = db.Column(db.Integer, db.ForeignKey("vote.id"), nullable=True)
89
  model_type = db.Column(db.String(20), nullable=False) # 'tts' or 'conversational'
90
 
 
131
  db.session.rollback()
132
  return None, "One or both models not found for the specified model type"
133
 
134
+ # --- ELO 计算逻辑与 test_elo.py 保持一致 ---
135
+ # a. 计算双方的基础动态K因子
136
+ max_match = max(chosen_model.match_count, rejected_model.match_count, 10)
137
+ k_winner_base = get_dynamic_k_factor(chosen_model.match_count, max_match)
138
+ k_loser_base = get_dynamic_k_factor(rejected_model.match_count, max_match)
139
+ # b. 取平均K因子
140
+ base_k = (k_winner_base + k_loser_base) / 2.0
141
 
142
+ new_chosen_elo, new_rejected_elo = calculate_elo_change(
143
+ chosen_model.current_elo, rejected_model.current_elo, k_factor=base_k
 
144
  )
145
 
 
 
 
 
146
  # Update model stats
147
  chosen_model.current_elo = new_chosen_elo
148
  chosen_model.win_count += 1
 
536
  return user.show_in_leaderboard
537
 
538
 
539
+ def get_dynamic_k_factor(match_count, max_match_count):
 
 
 
 
 
 
 
 
 
 
540
  k_max = 40 # 新模型的最大K因子
541
  k_min = 10 # 成熟模型的最小K因子
542
+ decay_factor = 5.0 # 衰减因子,控制K因子下降的速度
543
+
544
+ # 防止除以零
545
+ if max_match_count == 0:
546
+ return k_max
547
 
548
+ # 计算相对比赛进度 (0到1之间)
549
+ relative_progress = match_count / max_match_count
550
+
551
+ # 使用指数衰减公式,但基于相对进度
552
+ # K = K_min + (K_max - K_min) * e^(-decay_factor * relative_progress)
553
+ k_factor = k_min + (k_max - k_min) * math.exp(-decay_factor * relative_progress)
554
 
555
  return k_factor
556
 
557
+ def apply_filter_penalty_and_redistribute(unavailable_models, available_models, penalty_amount=1.0):
558
+ """
559
+ 对不可用的模型施加惩罚,并将扣除的分数平均重新分配给可用的模型。
560
+ 这确保了系统的ELO总分保持不变(零和)。
561
 
562
+ Args:
563
+ unavailable_models (list[Model]): 因被过滤而不可用的模型对象列表。
564
+ available_models (list[Model]): 当前可用的模型对象列表。
565
+ penalty_amount (float): 每个不可用模型被扣除的ELO分数。
566
+ """
567
+ if not unavailable_models or not available_models:
568
+ # 如果没有不可用模型或没有可用的模型来接收分数,则不执行任何操作
569
+ return
570
+
571
+ # 1. 计算总惩罚分数
572
+ total_penalty = len(unavailable_models) * penalty_amount
573
+ reward_per_model = total_penalty / len(available_models)
574
+
575
+ # 2. 从不可用模型中扣除分数并记录历史
576
+ for model in unavailable_models:
577
+ new_elo = model.current_elo - penalty_amount
578
+ model.current_elo = new_elo
579
+ # 为惩罚创建一条历史记录 (没有 vote_id)
580
+ penalty_history = EloHistory(
581
+ model_id=model.id,
582
+ elo_score=new_elo,
583
+ vote_id=None,
584
+ by_system=True,
585
+ model_type=model.model_type,
586
+ )
587
+ db.session.add(penalty_history)
588
+
589
+ # 3. 将分数奖励给可用模型并记录历史
590
+ for model in available_models:
591
+ new_elo = model.current_elo + reward_per_model
592
+ model.current_elo = new_elo
593
+ # 为奖励创建一条历史记录 (没有 vote_id)
594
+ reward_history = EloHistory(
595
+ model_id=model.id,
596
+ elo_score=new_elo,
597
+ vote_id=None,
598
+ by_system=True,
599
+ model_type=model.model_type,
600
+ )
601
+ db.session.add(reward_history)
602
 
603
+ # 4. 提交所有更改到数据库
604
+ db.session.commit()