Spaces:
Sleeping
Sleeping
File size: 3,839 Bytes
23c88d4 0d01ec2 f74a35a 940ba34 f05802f a7925b2 7185b69 696277a baea38f 696277a a7925b2 696277a 1238274 9c77d88 f05802f baea38f a7925b2 2af2718 baea38f 2af2718 baea38f 2af2718 f74a35a 2af2718 f74a35a baea38f a7925b2 f74a35a fe5131f 4111b35 940ba34 4111b35 940ba34 4111b35 2af2718 4111b35 a8076e2 940ba34 2af2718 940ba34 a7925b2 4111b35 940ba34 2af2718 4111b35 e64a138 4db4956 f05802f e64a138 940ba34 e64a138 0d01ec2 f05802f 940ba34 0d01ec2 a5272ad 940ba34 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import gradio as gr
import yfinance as yf
from prophet import Prophet
from sklearn.linear_model import LinearRegression
from neuralprophet import NeuralProphet
import pandas as pd
from datetime import datetime
import plotly.graph_objects as go
def download_data(ticker, start_date='2010-01-01'):
    """Download daily prices for *ticker* and reshape them into Prophet's format.

    Parameters
    ----------
    ticker : str
        Symbol passed straight to ``yfinance.download``.
    start_date : str
        First date to request (``YYYY-MM-DD``).

    Returns
    -------
    pandas.DataFrame
        Two columns: ``ds`` (the trading date) and ``y`` (the adjusted close).

    Raises
    ------
    ValueError
        If no rows are returned or the adjusted-close column is missing.
    """
    # auto_adjust=False keeps the separate 'Adj Close' column; recent yfinance
    # releases default to auto_adjust=True, which omits it.
    data = yf.download(ticker, start=start_date, auto_adjust=False)
    if data.empty:
        raise ValueError(f"No data returned for {ticker}")
    # Newer yfinance returns (field, ticker) MultiIndex columns even for a single
    # ticker; keep only the field level so 'Adj Close' can be selected by name.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    data = data.reset_index()
    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")
    # Build a new frame instead of renaming a column slice in place.
    return data[['Date', 'Adj Close']].rename(columns={'Date': 'ds', 'Adj Close': 'y'})
def _forecast_prophet(data, periods):
    """Fit Prophet on the history and predict history plus `periods` future days."""
    model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
    model.fit(data)
    future = model.make_future_dataframe(periods=periods, freq='D')
    forecast = model.predict(future)
    # Same YYYY-MM-DD string format used for every forecast table returned below.
    forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
    return forecast


def _forecast_linear(data, periods):
    """Fit a straight-line price trend over calendar time and extrapolate `periods` days."""
    origin = data['ds'].iloc[0]
    # Feature = elapsed calendar days, not row position: the history contains only
    # trading days while the forecast grid is every calendar day, so a row-index
    # feature would apply the per-trading-day slope to each calendar day.
    x_hist = (data['ds'] - origin).dt.days.to_numpy().reshape(-1, 1)
    model = LinearRegression()
    model.fit(x_hist, data['y'].to_numpy())
    # Daily dates strictly after the last observed date.
    future_dates = pd.date_range(start=data['ds'].iloc[-1], periods=periods + 1, freq='D')[1:]
    x_future = (future_dates - origin).days.to_numpy().reshape(-1, 1)
    return pd.DataFrame({
        'ds': future_dates.strftime('%Y-%m-%d'),
        'yhat': model.predict(x_future),
    })


def _forecast_neuralprophet(data, periods):
    """Fit NeuralProphet on the history and predict `periods` future days."""
    model = NeuralProphet()
    model.fit(data, freq='D')
    future = model.make_future_dataframe(data, periods=periods)
    forecast = model.predict(future)
    forecast['ds'] = pd.to_datetime(forecast['ds']).dt.strftime('%Y-%m-%d')
    return forecast


def _build_figure(actual, forecast_prophet, forecast_lr, forecast_np):
    """Overlay the three forecasts and the observed prices on one Plotly chart."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=forecast_prophet['ds'], y=forecast_prophet['yhat'], mode='lines',
                             name='Prophet Forecast (Blue)', line=dict(color='blue')))
    fig.add_trace(go.Scatter(x=forecast_lr['ds'], y=forecast_lr['yhat'], mode='lines',
                             name='Linear Regression Forecast (Red)', line=dict(color='red')))
    fig.add_trace(go.Scatter(x=forecast_np['ds'], y=forecast_np['yhat1'], mode='lines',
                             name='NeuralProphet Forecast (Green)', line=dict(color='green')))
    fig.add_trace(go.Scatter(x=actual['ds'], y=actual['y'], mode='lines',
                             name='Actual (Black)', line=dict(color='black')))
    return fig


def predict_future_prices(ticker, periods=1825):
    """Forecast a ticker's adjusted close with Prophet, linear regression and NeuralProphet.

    Parameters
    ----------
    ticker : str
        Stock symbol; surrounding whitespace is ignored.
    periods : int or float
        Number of calendar days to forecast. Float input (as delivered by a
        Gradio Number field) is truncated to an int.

    Returns
    -------
    tuple
        (plotly Figure,
         Prophet table [ds, yhat, yhat_lower, yhat_upper],
         linear-regression table [ds, yhat],
         NeuralProphet table [ds, yhat1]).

    Raises
    ------
    ValueError
        If `periods` is below 1, or if `download_data` finds no usable data.
    """
    # Date-range builders expect an integer count; a UI number field may give 1825.0.
    periods = int(periods)
    if periods < 1:
        raise ValueError("periods must be at least 1")
    data = download_data(str(ticker).strip())

    forecast_prophet = _forecast_prophet(data, periods)
    forecast_lr = _forecast_linear(data, periods)
    forecast_np = _forecast_neuralprophet(data, periods)

    fig = _build_figure(data, forecast_prophet, forecast_lr, forecast_np)
    return (
        fig,
        forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']],
        forecast_lr[['ds', 'yhat']],
        forecast_np[['ds', 'yhat1']],
    )
# Gradio interface: ticker and horizon inputs, a trigger button, the forecast
# chart, and one table per model, all wired to predict_future_prices.
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        # precision=0 makes Gradio round the value and pass it as an int day count.
        periods_input = gr.Number(value=1825, label="Forecast Period (days)", precision=0)
    forecast_button = gr.Button("Generate Forecast")
    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data")
    forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
    forecast_data_np = gr.Dataframe(label="NeuralProphet Forecast Data")
    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_np],
    )

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    app.launch()
|