| import db |
| import numpy as np |
| import plotly.graph_objects as go |
|
|
|
|
def _empty_estimates_figure():
    """Return a blank figure carrying the "no estimate available" message."""
    fig = go.Figure()
    fig.add_annotation(
        text="Aucune estimation disponible",
        x=0.5,
        y=0.5,
        # Paper coordinates keep the message centred whatever the axis ranges
        # end up being on an otherwise empty figure.
        xref="paper",
        yref="paper",
        showarrow=False,
    )
    return fig


def _prompt_names_by_id(prompts):
    """Map prompt id -> prompt name, or return {} when the columns are missing."""
    if "name" not in prompts.columns or "id" not in prompts.columns:
        return {}
    # Keep the first row for each id, i.e. the row a filter + iloc[0] would pick.
    first_rows = prompts.drop_duplicates(subset="id", keep="first")
    return dict(zip(first_rows["id"], first_rows["name"]))


def plot_estimates_distribution():
    """Plot one Gaussian curve per estimate with a dotted marker at each mean.

    Loads the "estimates" and "prompts" tables through ``db.load``. Each row of
    "estimates" is read for its ``mu`` and ``sigma`` columns and for a prompt
    identifier (``prompt_id`` when present, otherwise ``id``). The curve is
    labelled with the matching ``name`` from "prompts" when that table has
    ``id`` and ``name`` columns and contains the identifier; otherwise the
    identifier itself is used as the label.

    Rows whose ``mu`` or ``sigma`` is not finite, or whose ``sigma`` is not
    strictly positive, are skipped: the normal density is undefined for them
    and they would otherwise produce NaN curves and a NaN axis range.

    Returns:
        A ``plotly.graph_objects.Figure``. When either table is empty, or no
        row can be drawn, the figure only contains a centred
        "Aucune estimation disponible" annotation.
    """
    estimates = db.load("estimates")
    prompts = db.load("prompts")
    if estimates.empty or prompts.empty:
        return _empty_estimates_figure()

    # Discard rows for which the density cannot be evaluated.
    mu_values = estimates["mu"].to_numpy(dtype=float)
    sigma_values = estimates["sigma"].to_numpy(dtype=float)
    drawable = (
        np.isfinite(mu_values) & np.isfinite(sigma_values) & (sigma_values > 0)
    )
    estimates = estimates[drawable]
    if estimates.empty:
        return _empty_estimates_figure()

    # Shared x grid: three of the widest sigmas beyond the extreme means, so
    # every curve is evaluated on the same support.
    widest_sigma = estimates["sigma"].max()
    x = np.linspace(
        estimates["mu"].min() - 3 * widest_sigma,
        estimates["mu"].max() + 3 * widest_sigma,
        500,
    )

    # Built once instead of re-filtering the prompts table for every estimate.
    names_by_id = _prompt_names_by_id(prompts)

    fig = go.Figure()
    shapes = []

    for _, row in estimates.iterrows():
        mu = row["mu"]
        sigma = row["sigma"]
        prompt_id = row["prompt_id"] if "prompt_id" in row else row["id"]
        name = names_by_id.get(prompt_id, str(prompt_id))

        # Normal probability density evaluated on the shared grid.
        y = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=f"{name}",
                hovertemplate=f"<b>{name}</b><br>Score (mu): {mu:.2f}<br>Sigma: {sigma:.2f}<extra></extra>",
            )
        )

        # Dotted vertical marker at the mean, as tall as the sampled curve.
        shapes.append(
            dict(
                type="line",
                x0=mu,
                x1=mu,
                y0=0,
                y1=float(y.max()),
                line=dict(
                    color="gray",
                    width=2,
                    dash="dot",
                ),
                xref="x",
                yref="y",
            )
        )

    fig.update_layout(
        title="Distribution gaussienne de chaque prompt",
        xaxis_title="Score (mu)",
        yaxis_title="Densité",
        template="plotly_white",
        shapes=shapes,
    )
    return fig
|
|