"""Interactive gradient-descent visualizer (Streamlit app).

The user enters a function of ``x``, a starting point and a learning rate,
then steps one-dimensional gradient descent one iteration at a time.  The
function curve, the path of iterates and the current point are plotted.
"""

import streamlit as st
import numpy as np
import plotly.graph_objs as go
import sympy as sp

st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
# NOTE(review): this renders an empty HTML block -- looks like a leftover
# placeholder for custom CSS; fill it in or remove it.
st.markdown(""" """, unsafe_allow_html=True)
st.title("Interactive Gradient Descent Visualizer")

x = sp.Symbol('x')

# One-click example functions: (button label, expression text).
PRESETS = (
    ("x²", "x^2"),
    ("x³", "x^3"),
    ("sin(x)", "sin(x)"),
    ("1/x", "1/x"),
    ("Polynomial", "x**4 - 3*x**3 + 2"),
)


def _use_preset(expr_text):
    """Button callback: load *expr_text* into the "Function" text box.

    Callbacks run before the script reruns, i.e. before the keyed
    text_input is created -- the only point at which Streamlit allows a
    widget's session-state value to be changed.
    """
    st.session_state.func_input = expr_text


def _parse_finite_float(text):
    """Return *text* parsed as a finite float; raise ValueError otherwise."""
    value = float(text)
    if not np.isfinite(value):
        raise ValueError(f"{text!r} is not a finite number")
    return value


def _compile_function(text):
    """Parse *text* into ``(expr, f, df)``; f and df are numpy callables.

    ``^`` is accepted as exponentiation.  Raises ValueError when the
    expression uses any symbol other than ``x``; errors from
    sympify/lambdify/diff propagate unchanged.

    SECURITY: sympy.sympify evaluates its string input with eval(); do not
    expose this app to untrusted users without a safer parser.
    """
    expr = sp.sympify(text.replace("^", "**"))
    unknown = expr.free_symbols - {x}
    if unknown:
        names = ", ".join(sorted(str(s) for s in unknown))
        raise ValueError(f"only the variable x may be used (found: {names})")
    return (
        expr,
        sp.lambdify(x, expr, "numpy"),
        sp.lambdify(x, sp.diff(expr, x), "numpy"),
    )


def _evaluate(fn, values):
    """Evaluate lambdified *fn* on *values* as a float array of equal shape.

    lambdify of a constant expression ignores its argument and returns a
    scalar, so the result is broadcast back to the input's shape.
    Non-finite results (poles, overflow, domain errors) are replaced by NaN
    so the plot shows a gap instead of failing.
    """
    arr = np.asarray(values, dtype=float)
    with np.errstate(all="ignore"):
        out = np.asarray(fn(arr), dtype=float)
    out = np.array(np.broadcast_to(out, arr.shape), dtype=float)
    out[~np.isfinite(out)] = np.nan
    return out


defaults = {
    "step": 0,              # iterations performed since the last "Set Up"
    "points": [],           # iterates x_0, x_1, ... (Python floats)
    "gradient_func": None,  # numpy callable for f'(x)
    "func": None,           # numpy callable for f(x)
    "parsed_func": None,    # SymPy expression for f
    "func_input": "x^2+x",  # value of the keyed "Function" text box
}
for key, val in defaults.items():
    if key not in st.session_state:
        st.session_state[key] = val

# Keyed widget: its value lives in st.session_state.func_input, so the preset
# buttons' callbacks update the visible text on the same click.
st.text_input("Function", key="func_input")

st.markdown("Try these functions:")
for col, (label, expr_text) in zip(st.columns(len(PRESETS)), PRESETS):
    col.button(label, on_click=_use_preset, args=(expr_text,))

start_point = st.text_input("Starting Point", "5")
setup = st.button("Set Up")

if setup:
    # Validate every input before touching session state, so bad input
    # leaves the previous (consistent) setup and plot intact.
    try:
        parsed, func, gradient_func = _compile_function(st.session_state.func_input)
    except Exception as e:  # sympify/lambdify raise many exception types
        st.error(f"Error parsing function: {e}")
    else:
        try:
            x0 = _parse_finite_float(start_point)
        except ValueError as e:
            st.error(f"Invalid starting point: {e}")
        else:
            st.session_state.step = 0
            st.session_state.points = [x0]
            st.session_state.parsed_func = parsed
            st.session_state.func = func
            st.session_state.gradient_func = gradient_func

learning_rate = st.text_input("Learning Rate", "0.01")

if st.button("Next Iteration"):
    if (
        st.session_state.func is None
        or st.session_state.gradient_func is None
        or not st.session_state.points
    ):
        st.warning("Please set up the function first.")
    else:
        x_curr = st.session_state.points[-1]
        try:
            lr = _parse_finite_float(learning_rate)
            # Evaluate on np.float64 so overflow yields inf (handled below)
            # instead of the OverflowError Python-float ** would raise.
            with np.errstate(all="ignore"):
                grad_val = float(st.session_state.gradient_func(np.float64(x_curr)))
            x_next = x_curr - lr * grad_val
        except Exception as e:
            st.error(f"Iteration error: {e}")
        else:
            if np.isfinite(x_next):
                st.session_state.points.append(x_next)
                st.session_state.step += 1
            else:
                # Appending nan/inf would break every later plot.
                st.warning(
                    f"Iteration produced a non-finite value (the gradient at "
                    f"x = {x_curr:g} is undefined or the step overflowed). "
                    f"Try a smaller learning rate or another starting point."
                )

if st.session_state.func is not None and st.session_state.points:
    try:
        iter_points = np.asarray(st.session_state.points, dtype=float)
        # Show at least [-6, 6], widened to include every iterate, plus a
        # 20% margin on each side (the span is therefore always positive).
        min_x = min(iter_points.min(), -6)
        max_x = max(iter_points.max(), 6)
        margin = (max_x - min_x) * 0.2
        x_vals = np.linspace(min_x - margin, max_x + margin, 400)
        y_vals = _evaluate(st.session_state.func, x_vals)
        iter_y = _evaluate(st.session_state.func, iter_points)

        trace1 = go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Function',
                            line=dict(color='cyan'))
        trace2 = go.Scatter(x=iter_points, y=iter_y, mode='markers+lines',
                            name='Gradient Descent Path', marker=dict(color='red'))
        trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
                            marker=dict(color='lime', size=12),
                            text=[f"{iter_points[-1]:.6f}"], textposition="top center",
                            name="Current Point")
        layout = go.Layout(
            title=f"Iteration {st.session_state.step}",
            xaxis=dict(title="x - axis", color="lime", showgrid=True, gridcolor="gray"),
            yaxis=dict(title="y - axis", color="lime", showgrid=True, gridcolor="gray"),
            plot_bgcolor="black",
            paper_bgcolor="black",
            font=dict(color="lime"),
        )
        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:
        st.error(f"Plot error: {e}")