File size: 4,545 Bytes
5fc6e5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from typing import Optional

from loguru import logger
from mlflow.tracking import MlflowClient


def get_best_model_by_tag(
    language: str,
    tag_key: str = "best_model",
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Look up the explicitly tagged best model for a language across all
    MLflow experiments.

    Args:
        language: Programming language (java, python, pharo)
        tag_key: Tag key to search for (default: "best_model")
        metric: Metric to use for ordering (default: "f1_score")

    Returns:
        Dict with run_id and artifact_name of the best model or None if not found
    """

    client = MlflowClient()
    all_experiments = client.search_experiments()
    if not all_experiments:
        logger.error("No experiments found in MLflow")
        return None

    # Search every experiment; the tag filter narrows to runs explicitly
    # marked as best for this language, ordered by the chosen metric.
    query = f"tags.{tag_key} = 'true' and tags.Language = '{language}'"

    try:
        matches = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in all_experiments],
            filter_string=query,
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )

        if not matches:
            logger.warning(f"No runs found with tag '{tag_key}' for language '{language}'")
            return None

        top = matches[0]
        info = top.info
        tags = top.data.tags
        experiment_name = client.get_experiment(info.experiment_id).name
        artifact = tags.get("model_name")
        logger.info(
            f"Found best model for {language}: "
            f"{experiment_name}/{info.run_name} ({info.run_id}), artifact={artifact}"
        )

        return {
            "run_id": info.run_id,
            "artifact": artifact,
            "model_id": tags.get("model_id")
        }

    except Exception as e:
        # Any MLflow/tracking failure is treated as "no model found".
        logger.error(f"Error searching for best model: {e}")
        return None


def get_best_model_info(
    language: str,
    fallback_registry: Optional[dict] = None
) -> dict:
    """
    Retrieve the best model information for a language.

    Resolution order:
      1. Run explicitly tagged "best_model" for the language.
      2. Entry in the provided fallback registry, if any.
      3. Run with the highest metric value for the language.

    Args:
        language: Programming language (java, python, pharo)
        fallback_registry: Optional mapping of language -> model info dict
            (with "run_id" and "artifact" keys) used when no tagged run exists.

    Returns:
        Dict with run_id and artifact of the model

    Raises:
        ValueError: If no model can be found by any of the strategies.
    """

    model_info = get_best_model_by_tag(language, "best_model")

    if model_info:
        logger.info(f"Using tagged best model for {language}")
        return model_info

    if fallback_registry and language in fallback_registry:
        logger.warning(f"No tagged model found for {language}, using fallback registry")
        return fallback_registry[language]

    # Last resort: pick the run with the best metric, regardless of tags.
    model_info = get_best_model_by_metric(language)

    if model_info:
        logger.warning(f"Using best model by metric for {language}")
        return model_info

    raise ValueError(f"No model found for language {language}")


def get_best_model_by_metric(
    language: str,
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Find the model with the best metric for a language, ignoring tags
    other than the Language tag.

    Args:
        language: Programming language
        metric: Metric to use for ordering

    Returns:
        Dict with run_id and artifact of the model or None
    """

    client = MlflowClient()
    all_experiments = client.search_experiments()
    if not all_experiments:
        logger.error("No experiments found in MLflow")
        return None

    try:
        # One run is enough: ordering by the metric descending puts the
        # best-scoring run first.
        matches = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in all_experiments],
            filter_string=f"tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )

        if not matches:
            logger.warning(f"No runs found for language '{language}'")
            return None

        top = matches[0]
        info = top.info
        tags = top.data.tags
        experiment_name = client.get_experiment(info.experiment_id).name
        artifact = tags.get("model_name")
        logger.info(
            f"Found best model for {language}: "
            f"{experiment_name}/{info.run_name} ({info.run_id}), artifact={artifact}"
        )

        return {
            "run_id": info.run_id,
            "artifact": artifact,
            "model_id": tags.get("model_id")
        }

    except Exception as e:
        # Any MLflow/tracking failure is treated as "no model found".
        logger.error(f"Error finding best model by metric: {e}")
        return None