File size: 3,839 Bytes
23c88d4
 
0d01ec2
f74a35a
940ba34
f05802f
a7925b2
7185b69
 
696277a
baea38f
 
 
696277a
 
 
 
 
a7925b2
696277a
 
 
 
1238274
9c77d88
f05802f
baea38f
a7925b2
2af2718
 
baea38f
2af2718
 
 
baea38f
2af2718
 
f74a35a
2af2718
f74a35a
baea38f
a7925b2
f74a35a
 
 
 
 
fe5131f
4111b35
940ba34
4111b35
940ba34
 
4111b35
 
2af2718
4111b35
a8076e2
940ba34
2af2718
940ba34
a7925b2
4111b35
940ba34
2af2718
4111b35
e64a138
 
4db4956
f05802f
 
 
e64a138
 
 
 
940ba34
e64a138
0d01ec2
 
f05802f
940ba34
0d01ec2
a5272ad
940ba34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import yfinance as yf
from prophet import Prophet
from sklearn.linear_model import LinearRegression
from neuralprophet import NeuralProphet
import pandas as pd
from datetime import datetime
import plotly.graph_objects as go

def download_data(ticker, start_date='2010-01-01'):
    """
    Download daily price history for ``ticker`` and reshape it for Prophet.

    Args:
        ticker: Ticker symbol passed to ``yf.download`` (e.g. "AAPL").
        start_date: First date to request, as a 'YYYY-MM-DD' string.

    Returns:
        A DataFrame with exactly two columns: ``ds`` (the 'Date' index from
        yfinance, timezone-naive) and ``y`` (the 'Adj Close' price).

    Raises:
        ValueError: if yfinance returns no rows, or the result has no
            'Adj Close' column.
    """
    # Ask explicitly for unadjusted columns: newer yfinance releases default to
    # auto_adjust=True, which folds adjustments into 'Close' and drops 'Adj Close'.
    data = yf.download(ticker, start=start_date, auto_adjust=False)
    if data.empty:
        raise ValueError(f"No data returned for {ticker}")

    # Newer yfinance returns (field, ticker) MultiIndex columns even for a single
    # ticker; keep only the field level so plain column names like 'Adj Close' work.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)

    data = data.reset_index()
    if 'Adj Close' not in data.columns:
        raise ValueError("Expected 'Adj Close' in columns")

    # Non-inplace rename on the column subset avoids pandas' SettingWithCopy pitfall.
    prices = data[['Date', 'Adj Close']].rename(columns={'Date': 'ds', 'Adj Close': 'y'})

    # Prophet rejects timezone-aware timestamps in the ds column.
    if prices['ds'].dt.tz is not None:
        prices = prices.assign(ds=prices['ds'].dt.tz_localize(None))
    return prices

def _forecast_prophet(data, periods):
    """
    Fit Prophet (yearly seasonality only) and forecast ``periods`` calendar days ahead.

    Returns Prophet's forecast frame reduced to ``ds`` ('YYYY-MM-DD' strings),
    ``yhat``, ``yhat_lower`` and ``yhat_upper``.
    """
    model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
    model.fit(data)
    future = model.make_future_dataframe(periods=periods, freq='D')
    forecast = model.predict(future)
    forecast['ds'] = forecast['ds'].dt.strftime('%Y-%m-%d')
    return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']]


def _forecast_linear(data, periods):
    """
    Fit a straight-line trend of price against elapsed calendar days and extrapolate it.

    Returns a frame with ``ds`` ('YYYY-MM-DD' strings for the ``periods`` days
    after the last observation) and the fitted ``yhat``.
    """
    origin = data['ds'].iloc[0]
    # Use elapsed calendar days as the feature for both the history and the
    # future. A row index would count only trading days in the history but
    # calendar days in the future, stretching the extrapolated trend.
    x_hist = (data['ds'] - origin).dt.days.to_numpy().reshape(-1, 1)
    model = LinearRegression()
    model.fit(x_hist, data['y'].to_numpy())

    # Start at the last observed date and drop it, leaving only new days.
    future_dates = pd.date_range(start=data['ds'].iloc[-1], periods=periods + 1, freq='D')[1:]
    x_future = (future_dates - origin).days.to_numpy().reshape(-1, 1)
    return pd.DataFrame({
        'ds': future_dates.strftime('%Y-%m-%d'),
        'yhat': model.predict(x_future),
    })


def _forecast_neuralprophet(data, periods):
    """
    Fit NeuralProphet on daily frequency and forecast ``periods`` days ahead.

    Returns ``ds`` ('YYYY-MM-DD' strings) and the model's ``yhat1`` column.
    """
    model = NeuralProphet()
    model.fit(data, freq='D')
    future = model.make_future_dataframe(data, periods=periods)
    forecast = model.predict(future)
    forecast['ds'] = pd.to_datetime(forecast['ds']).dt.strftime('%Y-%m-%d')
    return forecast[['ds', 'yhat1']]


def _build_forecast_figure(data, prophet_df, lr_df, np_df):
    """Overlay the three model forecasts and the observed prices on one Plotly figure."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=prophet_df['ds'], y=prophet_df['yhat'], mode='lines',
        name='Prophet Forecast (Blue)', line=dict(color='blue'),
    ))
    fig.add_trace(go.Scatter(
        x=lr_df['ds'], y=lr_df['yhat'], mode='lines',
        name='Linear Regression Forecast (Red)', line=dict(color='red'),
    ))
    fig.add_trace(go.Scatter(
        x=np_df['ds'], y=np_df['yhat1'], mode='lines',
        name='NeuralProphet Forecast (Green)', line=dict(color='green'),
    ))
    fig.add_trace(go.Scatter(
        x=data['ds'], y=data['y'], mode='lines',
        name='Actual (Black)', line=dict(color='black'),
    ))
    return fig


def predict_future_prices(ticker, periods=1825):
    """
    Forecast a ticker's price with Prophet, linear regression and NeuralProphet.

    Args:
        ticker: Stock ticker symbol forwarded to ``download_data``.
        periods: Number of calendar days to forecast past the last observation.
            The Gradio Number component may deliver this as a float, so it is
            coerced to int.

    Returns:
        A 4-tuple ``(fig, prophet_df, lr_df, np_df)``:
        - ``fig``: the Plotly figure with all forecasts and the actual prices;
        - ``prophet_df``: columns ``ds``, ``yhat``, ``yhat_lower``, ``yhat_upper``;
        - ``lr_df``: columns ``ds``, ``yhat``;
        - ``np_df``: columns ``ds``, ``yhat1``.

    Raises:
        ValueError: if ``periods`` is missing or less than 1, or if
            ``download_data`` rejects the ticker.
    """
    if periods is None:
        raise ValueError("Forecast period is required")
    # pandas date_range and the forecasting libraries expect an integer count.
    periods = int(periods)
    if periods < 1:
        raise ValueError("Forecast period must be at least 1 day")

    data = download_data(ticker)
    prophet_df = _forecast_prophet(data, periods)
    lr_df = _forecast_linear(data, periods)
    np_df = _forecast_neuralprophet(data, periods)
    fig = _build_forecast_figure(data, prophet_df, lr_df, np_df)
    return fig, prophet_df, lr_df, np_df


# Gradio interface: inputs for ticker and horizon, outputs for the chart and
# the three forecast tables, then launch the app.
with gr.Blocks() as app:
    with gr.Row():
        ticker_input = gr.Textbox(value="AAPL", label="Enter Stock Ticker for Forecast")
        # precision=0 makes the component round and return an int, since the
        # forecast horizon is a whole number of days.
        periods_input = gr.Number(value=1825, label="Forecast Period (days)", precision=0)
        forecast_button = gr.Button("Generate Forecast")

    forecast_chart = gr.Plot(label="Forecast Chart")
    forecast_data_prophet = gr.Dataframe(label="Prophet Forecast Data")
    forecast_data_lr = gr.Dataframe(label="Linear Regression Forecast Data")
    forecast_data_np = gr.Dataframe(label="NeuralProphet Forecast Data")

    # Output order must match predict_future_prices' return tuple:
    # (figure, prophet table, linear-regression table, NeuralProphet table).
    forecast_button.click(
        fn=predict_future_prices,
        inputs=[ticker_input, periods_input],
        outputs=[forecast_chart, forecast_data_prophet, forecast_data_lr, forecast_data_np]
    )

app.launch()