Spaces:
Runtime error
Runtime error
File size: 4,545 Bytes
5fc6e5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
from typing import Optional
from loguru import logger
from mlflow.tracking import MlflowClient
def get_best_model_by_tag(
    language: str,
    tag_key: str = "best_model",
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Retrieve the best model for a specific language using MLflow tags.

    Searches every experiment returned by ``MlflowClient.search_experiments()``
    for runs tagged ``<tag_key> = 'true'`` whose ``Language`` tag equals
    ``language``, and keeps the single run with the highest ``metric``.

    Args:
        language: Programming language (java, python, pharo)
        tag_key: Tag key to search for (default: "best_model")
        metric: Metric to use for ordering, highest first (default: "f1_score")

    Returns:
        Dict with keys ``run_id``, ``artifact`` (the run's ``model_name`` tag)
        and ``model_id`` (the run's ``model_id`` tag); either tag value may be
        None. Returns None when no experiment or matching run exists, or when
        any MLflow call fails.
    """
    try:
        # Client creation and experiment listing live inside the try so that
        # an unreachable/misconfigured tracking server yields None (letting
        # callers fall back to other sources) instead of an uncaught error.
        client = MlflowClient()
        experiments = client.search_experiments()
        if not experiments:
            logger.error("No experiments found in MLflow")
            return None

        # Reuse the experiments already fetched to resolve the winning run's
        # experiment name, avoiding an extra get_experiment() round-trip.
        exp_names = {exp.experiment_id: exp.name for exp in experiments}

        # NOTE(review): tag_key/language are interpolated verbatim; a value
        # containing a single quote yields a malformed filter, which MLflow
        # presumably rejects (ending up in the except branch below).
        runs = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in experiments],
            filter_string=f"tags.{tag_key} = 'true' and tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )
        if not runs:
            logger.warning(f"No runs found with tag '{tag_key}' for language '{language}'")
            return None

        best_run = runs[0]
        run_id = best_run.info.run_id
        exp_id = best_run.info.experiment_id
        exp_name = exp_names.get(exp_id, exp_id)
        run_name = best_run.info.run_name
        artifact_name = best_run.data.tags.get("model_name")
        model_id = best_run.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {exp_name}/{run_name} ({run_id}), artifact={artifact_name}")
        return {
            "run_id": run_id,
            "artifact": artifact_name,
            "model_id": model_id
        }
    except Exception as e:
        # Broad catch by design: this lookup is best-effort and callers treat
        # None as "try the next source". logger.exception keeps the traceback.
        logger.exception(f"Error searching for best model: {e}")
        return None
def get_best_model_info(
    language: str,
    fallback_registry: Optional[dict] = None
) -> dict:
    """
    Retrieve the best model information for a language.

    Sources are tried in this order, and the first hit is returned:
      1. the run tagged ``best_model = 'true'`` (``get_best_model_by_tag``);
      2. ``fallback_registry[language]``, if a registry is given and contains
         ``language``;
      3. the run with the best metric for the language
         (``get_best_model_by_metric``).

    Args:
        language: Programming language
        fallback_registry: Optional mapping from language to a model-info
            dict (presumably with ``run_id`` and ``artifact`` keys; the
            entry's shape is not checked here).

    Returns:
        The model-info dict from the first source that yields one. A registry
        entry is returned as-is (the same object, not a copy).

    Raises:
        ValueError: If none of the sources yields a model for ``language``.
    """
    model_info = get_best_model_by_tag(language, "best_model")
    if model_info:
        logger.info(f"Using tagged best model for {language}")
        return model_info

    if fallback_registry and language in fallback_registry:
        logger.warning(f"No tagged model found for {language}, using fallback registry")
        return fallback_registry[language]

    model_info = get_best_model_by_metric(language)
    if model_info:
        logger.warning(f"Using best model by metric for {language}")
        return model_info

    raise ValueError(f"No model found for language {language}")
def get_best_model_by_metric(
    language: str,
    metric: str = "f1_score"
) -> Optional[dict]:
    """
    Find the model with the best metric for a language.

    Searches every experiment returned by ``MlflowClient.search_experiments()``
    for runs whose ``Language`` tag equals ``language`` (no ``best_model`` tag
    required) and keeps the single run with the highest ``metric``.

    Args:
        language: Programming language
        metric: Metric to use for ordering, highest first

    Returns:
        Dict with keys ``run_id``, ``artifact`` (the run's ``model_name`` tag)
        and ``model_id`` (the run's ``model_id`` tag); either tag value may be
        None. Returns None when no experiment or matching run exists, or when
        any MLflow call fails.
    """
    try:
        # Client creation and experiment listing live inside the try so that
        # an unreachable/misconfigured tracking server yields None rather than
        # an uncaught exception, matching the Optional return contract.
        client = MlflowClient()
        experiments = client.search_experiments()
        if not experiments:
            logger.error("No experiments found in MLflow")
            return None

        # Reuse the experiments already fetched to resolve the winning run's
        # experiment name, avoiding an extra get_experiment() round-trip.
        exp_names = {exp.experiment_id: exp.name for exp in experiments}

        # NOTE(review): language is interpolated verbatim; a value containing
        # a single quote yields a malformed filter, which MLflow presumably
        # rejects (ending up in the except branch below).
        runs = client.search_runs(
            experiment_ids=[exp.experiment_id for exp in experiments],
            filter_string=f"tags.Language = '{language}'",
            order_by=[f"metrics.{metric} DESC"],
            max_results=1
        )
        if not runs:
            logger.warning(f"No runs found for language '{language}'")
            return None

        best_run = runs[0]
        run_id = best_run.info.run_id
        exp_id = best_run.info.experiment_id
        exp_name = exp_names.get(exp_id, exp_id)
        run_name = best_run.info.run_name
        artifact_name = best_run.data.tags.get("model_name")
        model_id = best_run.data.tags.get("model_id")
        logger.info(f"Found best model for {language}: {exp_name}/{run_name} ({run_id}), artifact={artifact_name}")
        return {
            "run_id": run_id,
            "artifact": artifact_name,
            "model_id": model_id
        }
    except Exception as e:
        # Broad catch by design: this is the last best-effort lookup before
        # the caller gives up. logger.exception keeps the traceback.
        logger.exception(f"Error finding best model by metric: {e}")
        return None
|