# hii / app.py
# Author: SuryaaNaik
# Last commit: 9922bf4 "Update app.py" (verified)
import streamlit as st
import numpy as np
import plotly.graph_objs as go
import sympy as sp
# Page configuration has to be the first Streamlit call in the script.
st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")

# Green-on-black theme for text, headings, inputs, buttons and success boxes.
# It is injected as a raw <style> element, which is why unsafe_allow_html is on.
_THEME_CSS = """
<style>
html, body, [class*="css"] {
font-family: 'Segoe UI', sans-serif;
background-color: #000000;
color: #00ff00;
}
h1 {
font-size: 36px;
font-weight: 700;
margin-bottom: 0.5em;
color: #00ff00;
}
.stTextInput > div > input {
border: 2px solid #00ff00;
border-radius: 8px;
padding: 0.5em;
font-size: 16px;
background-color: #111;
color: #00ff00;
}
.stButton > button {
background-color: #00ff00;
color: black;
font-weight: 600;
border-radius: 8px;
padding: 0.6em 1.2em;
font-size: 16px;
}
.stButton > button:hover {
background-color: #00cc00;
transition: 0.3s;
}
.stMarkdown {
font-size: 18px;
font-weight: 500;
color: #00ff00;
}
.element-container:has(.stButton) {
margin-top: 1em;
margin-bottom: 1em;
}
.stColumns {
gap: 0.5rem !important;
}
.st-c1 {
font-weight: bold;
color: #00ff00;
}
.stSuccess {
font-size: 18px;
font-weight: 600;
color: black;
background-color: #00ff00;
border-radius: 6px;
padding: 0.4em 0.8em;
}
</style>
"""
st.markdown(_THEME_CSS, unsafe_allow_html=True)
st.title("Interactive Gradient Descent Visualizer")

# The single independent variable every user-entered function is written in.
x = sp.Symbol("x")

# Per-session values that must survive Streamlit reruns.
defaults = dict(
    step=0,              # number of gradient-descent updates taken so far
    points=[],           # x positions visited, starting point first
    gradient_func=None,  # numeric derivative produced by sympy.lambdify
    func=None,           # numeric function produced by sympy.lambdify
    parsed_func=None,    # sympy expression parsed from the user's text
    func_input="x^2+x",  # text shown in the "Function" box
)

# Only fill in keys this session does not have yet, so values written on
# earlier reruns are kept.
_missing_state = {
    name: initial
    for name, initial in defaults.items()
    if name not in st.session_state
}
for name, initial in _missing_state.items():
    st.session_state[name] = initial
# Function entry. The text box is keyed to st.session_state["func_input"], so
# it always displays the value stored under that key. The "Set Up" handler
# reads the same key. The initial value comes from the session defaults.
st.text_input("Function", key="func_input")
st.markdown("Try these functions:")

# (button label, expression loaded into the "Function" box)
_PRESET_FUNCTIONS = (
    ("x²", "x^2"),
    ("x³", "x^3"),
    ("sin(x)", "sin(x)"),
    ("1/x", "1/x"),
    ("Polynomial", "x**4 - 3*x**3 + 2"),
)


def _load_preset(expression):
    """Button callback: put *expression* into the "Function" text box.

    Streamlit runs ``on_click`` callbacks before the next rerun of the script,
    so the value is already in session state when the keyed text box is
    re-created.

    Two alternatives do not work:
    - Changing a keyed widget's value after it has been created in the
      current run is not allowed.
    - Overwriting an unkeyed box's value after it has rendered leaves the old
      text on screen until a later rerun.
    """
    st.session_state.func_input = expression


for column, (label, expression) in zip(
    st.columns(len(_PRESET_FUNCTIONS)), _PRESET_FUNCTIONS
):
    column.button(label, on_click=_load_preset, args=(expression,))
start_point = st.text_input("Starting Point", "5")
setup = st.button("Set Up")
if setup:
    # Start a fresh run. If anything below fails, the point list stays empty.
    # That hides the plot and makes "Next Iteration" ask for a new set-up.
    st.session_state.step = 0
    st.session_state.points = []
    try:
        expr = st.session_state.func_input.replace("^", "**")
        # NOTE(security): sympy.sympify evaluates its input with eval(), and this
        # text comes straight from the user. Only deploy the app where every
        # user is trusted.
        parsed = sp.sympify(expr)
        # lambdify(x, ...) binds only x. Any other symbol would surface as a
        # NameError the first time the function is evaluated. This includes
        # "e", which sympify reads as a plain symbol rather than Euler's
        # number; use exp(...) instead.
        other_symbols = parsed.free_symbols - {x}
        if other_symbols:
            names = ", ".join(sorted(str(s) for s in other_symbols))
            raise ValueError(f"only the variable x is supported (found: {names})")
        func = sp.lambdify(x, parsed, "numpy")
        gradient_func = sp.lambdify(x, sp.diff(parsed, x), "numpy")
    except Exception as e:  # SympifyError, SyntaxError, TypeError, ValueError, ...
        st.error(f"Error parsing function: {e}")
    else:
        try:
            x0 = float(start_point)
        except ValueError:
            x0 = None
        # float() also accepts "inf" and "nan"; those would break the plot's range.
        if x0 is None or not np.isfinite(x0):
            st.error(f"Invalid starting point: {start_point!r} is not a finite number")
        else:
            # Commit only once both the function and the start point are valid.
            st.session_state.parsed_func = parsed
            st.session_state.func = func
            st.session_state.gradient_func = gradient_func
            st.session_state.points.append(x0)
learning_rate = st.text_input("Learning Rate", "0.01")
if st.button("Next Iteration"):
    # "Set Up" stores the numeric function, its derivative and the starting
    # point. Without all three there is nothing to step from.
    if (
        st.session_state.func is None
        or st.session_state.gradient_func is None
        or not st.session_state.points
    ):
        st.warning("Please set up the function first.")
    else:
        try:
            lr = float(learning_rate)
            x_curr = st.session_state.points[-1]
            grad_val = st.session_state.gradient_func(x_curr)
            # One gradient-descent update: x <- x - lr * f'(x).
            x_next = x_curr - lr * grad_val
            # Reject a complex, infinite or NaN point. Causes include divergence
            # from a too-large learning rate, a NaN learning rate, or a
            # fractional power of a negative x. Storing it would poison every
            # later iteration and break the plot.
            if np.iscomplexobj(x_next) or not np.isfinite(x_next):
                raise ValueError(
                    f"the update produced {x_next}; try a smaller learning rate "
                    "or a different starting point"
                )
        except Exception as e:
            st.error(f"Iteration error: {e}")
        else:
            st.session_state.points.append(float(x_next))
            st.session_state.step += 1
# Plot the function, the descent path so far, and the current point.
if st.session_state.func is not None and st.session_state.points:
    try:
        # x-range: at least [-6, 6], widened to include every visited point,
        # plus a 20% margin on each side.
        min_x = min(min(st.session_state.points), -6)
        max_x = max(max(st.session_state.points), 6)
        margin = (max_x - min_x) * 0.2 if max_x > min_x else 1
        x_vals = np.linspace(min_x - margin, max_x + margin, 400)
        iter_points = np.array(st.session_state.points)
        # For an expression without x (e.g. "5"), lambdify builds a function
        # that returns a bare scalar. Plotly rejects a scalar as a trace's y
        # data, so broadcast to one y value per x value.
        y_vals = np.broadcast_to(st.session_state.func(x_vals), x_vals.shape)
        iter_y = np.broadcast_to(
            st.session_state.func(iter_points), iter_points.shape
        )

        trace1 = go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Function', line=dict(color='cyan'))
        trace2 = go.Scatter(x=iter_points, y=iter_y, mode='markers+lines', name='Gradient Descent Path', marker=dict(color='red'))
        trace3 = go.Scatter(
            x=[iter_points[-1]],
            y=[iter_y[-1]],
            mode='markers+text',
            marker=dict(color='lime', size=12),
            text=[f"{iter_points[-1]:.6f}"],
            textposition="top center",
            name="Current Point",
        )
        layout = go.Layout(
            title=f"Iteration {st.session_state.step}",
            xaxis=dict(title="x - axis", color="lime", showgrid=True, gridcolor="gray"),
            yaxis=dict(title="y - axis", color="lime", showgrid=True, gridcolor="gray"),
            plot_bgcolor="black",
            paper_bgcolor="black",
            font=dict(color="lime"),
        )
        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:
        st.error(f"Plot error: {e}")