# Hugging Face Spaces page header captured with the source (Spaces: Sleeping);
# it is not part of the script.
import numpy as np
import plotly.graph_objs as go
import streamlit as st
import sympy as sp
# Page-level configuration; Streamlit requires this to be the first st.* call.
st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")

# Global green-on-black theme, injected as raw CSS. The selectors target
# Streamlit's generated class names, which are not a stable API.
# NOTE(review): confirm the selectors still match the deployed Streamlit version.
st.markdown("""
<style>
html, body, [class*="css"] {
    font-family: 'Segoe UI', sans-serif;
    background-color: #000000;
    color: #00ff00;
}
h1 {
    font-size: 36px;
    font-weight: 700;
    margin-bottom: 0.5em;
    color: #00ff00;
}
.stTextInput > div > input {
    border: 2px solid #00ff00;
    border-radius: 8px;
    padding: 0.5em;
    font-size: 16px;
    background-color: #111;
    color: #00ff00;
}
.stButton > button {
    background-color: #00ff00;
    color: black;
    font-weight: 600;
    border-radius: 8px;
    padding: 0.6em 1.2em;
    font-size: 16px;
}
.stButton > button:hover {
    background-color: #00cc00;
    transition: 0.3s;
}
.stMarkdown {
    font-size: 18px;
    font-weight: 500;
    color: #00ff00;
}
.element-container:has(.stButton) {
    margin-top: 1em;
    margin-bottom: 1em;
}
.stColumns {
    gap: 0.5rem !important;
}
.st-c1 {
    font-weight: bold;
    color: #00ff00;
}
.stSuccess {
    font-size: 18px;
    font-weight: 600;
    color: black;
    background-color: #00ff00;
    border-radius: 6px;
    padding: 0.4em 0.8em;
}
</style>
""", unsafe_allow_html=True)
st.title("Interactive Gradient Descent Visualizer")

# Symbol that user expressions are written in and differentiated against.
x = sp.Symbol('x')

# Initial per-session state. Streamlit re-runs the whole script on every
# interaction, so anything that must survive a rerun lives in session_state.
defaults = {
    "step": 0,              # number of descent iterations performed so far
    "points": [],           # x values visited, starting point first
    "gradient_func": None,  # numeric callable for the derivative
    "func": None,           # numeric callable for the function itself
    "parsed_func": None,    # SymPy expression parsed from the input text
    "func_input": "x^2+x",  # raw text of the function box
}

# Seed only the keys that are missing so existing state survives reruns.
for key, val in defaults.items():
    if key not in st.session_state:
        st.session_state[key] = val
def _use_preset(expression):
    """Load a preset expression into the function text box.

    Runs as a button callback, i.e. before the script reruns, so the text box
    is redrawn with the new value immediately. Assigning to the widget's key
    after the widget has been drawn in the same run is not allowed.
    """
    st.session_state.func_input = expression


# Bind the text box to session state through its key. Its content is then
# always available as st.session_state.func_input, and the preset buttons
# below can change what the box displays.
st.text_input("Function", key="func_input")

st.markdown("Try these functions:")

# (button label, expression loaded into the text box)
presets = (
    ("x²", "x^2"),
    ("x³", "x^3"),
    ("sin(x)", "sin(x)"),
    ("1/x", "1/x"),
    ("Polynomial", "x**4 - 3*x**3 + 2"),
)
for column, (label, expression) in zip(st.columns(len(presets)), presets):
    column.button(label, on_click=_use_preset, args=(expression,))
start_point = st.text_input("Starting Point", "5")
setup = st.button("Set Up")

if setup:
    # Validate everything first and commit to session state only at the end,
    # so a failed setup leaves the previous run intact.
    try:
        start = float(start_point)
        if not np.isfinite(start):
            raise ValueError(f"{start_point!r} is not a finite number")
    except ValueError as e:
        st.error(f"Invalid starting point: {e}")
    else:
        try:
            # Accept "^" as the power operator in addition to Python's "**".
            expr = st.session_state.func_input.replace("^", "**")
            # NOTE(security): sympify() evaluates its input with eval(). That
            # is acceptable for a local demo but unsafe for untrusted users on
            # a shared deployment.
            parsed = sp.sympify(expr)
            # Any symbol other than x would only fail later, when the
            # lambdified function is called with that name undefined.
            unknown = parsed.free_symbols - {x}
            if unknown:
                names = ", ".join(sorted(str(s) for s in unknown))
                raise ValueError(f"only the variable x is supported, found: {names}")
            func = sp.lambdify(x, parsed, "numpy")
            gradient_func = sp.lambdify(x, sp.diff(parsed, x), "numpy")
        except Exception as e:
            # UI boundary: report any parse or differentiation failure.
            st.error(f"Error parsing function: {e}")
        else:
            st.session_state.step = 0
            st.session_state.points = [start]
            st.session_state.parsed_func = parsed
            st.session_state.func = func
            st.session_state.gradient_func = gradient_func
learning_rate = st.text_input("Learning Rate", "0.01")

if st.button("Next Iteration"):
    state = st.session_state
    if state.func is None or state.gradient_func is None or not state.points:
        st.warning("Please set up the function first.")
    else:
        try:
            lr = float(learning_rate)
            x_curr = state.points[-1]
            grad_val = state.gradient_func(x_curr)
            # One gradient descent update: x_next = x_curr - lr * f'(x_curr).
            x_next = float(x_curr - lr * grad_val)
            # A diverging step or a singular gradient (e.g. 1/x at 0) yields
            # inf or nan. Storing it would break the plot's axis range on
            # every later rerun, so reject it here.
            if not np.isfinite(x_next):
                raise ValueError(
                    f"step from x = {x_curr} is not finite; "
                    "try a smaller learning rate or another starting point"
                )
        except Exception as e:
            # UI boundary: report the failure and leave the state untouched.
            st.error(f"Iteration error: {e}")
        else:
            state.points.append(x_next)
            state.step += 1
# Draw the function curve and the descent path once a function is set up.
if st.session_state.func is not None and st.session_state.points:
    try:
        func = st.session_state.func
        points = st.session_state.points

        # The view always covers at least [-6, 6] and widens to include every
        # visited point, so max_x > min_x and the margin is always positive.
        min_x = min(min(points), -6)
        max_x = max(max(points), 6)
        margin = (max_x - min_x) * 0.2
        x_vals = np.linspace(min_x - margin, max_x + margin, 400)
        iter_points = np.array(points)

        # lambdify returns a bare scalar for constant expressions (e.g. "5"),
        # so expand each result to the shape of its input before plotting.
        y_vals = np.broadcast_to(func(x_vals), x_vals.shape)
        iter_y = np.broadcast_to(func(iter_points), iter_points.shape)

        trace1 = go.Scatter(
            x=x_vals, y=y_vals, mode='lines', name='Function',
            line=dict(color='cyan'),
        )
        trace2 = go.Scatter(
            x=iter_points, y=iter_y, mode='markers+lines',
            name='Gradient Descent Path', marker=dict(color='red'),
        )
        # Highlight the most recent point and label it with its x value.
        trace3 = go.Scatter(
            x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
            marker=dict(color='lime', size=12),
            text=[f"{iter_points[-1]:.6f}"],
            textposition="top center",
            name="Current Point",
        )
        layout = go.Layout(
            title=f"Iteration {st.session_state.step}",
            xaxis=dict(title="x - axis", color="lime", showgrid=True, gridcolor="gray"),
            yaxis=dict(title="y - axis", color="lime", showgrid=True, gridcolor="gray"),
            plot_bgcolor="black",
            paper_bgcolor="black",
            font=dict(color="lime"),
        )
        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:
        # UI boundary: evaluation or rendering failures are shown, not raised.
        st.error(f"Plot error: {e}")