File size: 2,330 Bytes
2013214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc7b249
41e8fd4
2013214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.statespace.sarimax import SARIMAX
import os

# Resolve the sample CSV relative to this script's own directory so loading
# does not depend on the current working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
csv_path = os.path.join(script_dir, 'sample_data_extended.csv')
# Raw data exactly as read from disk; it is rebound to the preprocessed
# frame further down in this module.
df = pd.read_csv(csv_path)

def preprocess_data(df):
    """Return a cleaned copy of *df* that is ready for modelling.

    Steps performed:
      * drop every row that contains a missing value;
      * parse the 'Time' column as datetimes and make it the index;
      * convert 'Total expense (VND)' and 'Income (VND)' to float, removing
        thousands-separator commas when the values are text.

    The input frame is not modified.

    Args:
        df: DataFrame with at least the columns 'Time',
            'Total expense (VND)' and 'Income (VND)'.

    Returns:
        A new DataFrame indexed by 'Time' whose two money columns are floats.
    """
    money_columns = ('Total expense (VND)', 'Income (VND)')

    # Work on an explicit copy so the column assignments below never touch
    # (or warn about) the frame passed in by the caller.
    df = df.dropna().copy()

    # Set the time column as the index
    df['Time'] = pd.to_datetime(df['Time'])
    df = df.set_index('Time')

    for column in money_columns:
        values = df[column]
        # The .str accessor raises on numeric data, so only strip commas when
        # the column is not already numeric (e.g. the CSV had no separators).
        if not pd.api.types.is_numeric_dtype(values):
            # regex=False: treat ',' as a literal on every pandas version.
            values = values.astype(str).str.replace(',', '', regex=False)
        df[column] = values.astype(float)
    return df

def plot_data(df):
    """Show expense and income as two labelled lines on one 10x6 figure.

    Args:
        df: DataFrame containing the columns 'Total expense (VND)' and
            'Income (VND)'; each is plotted against the frame's index.
    """
    plt.figure(figsize=(10, 6))
    # Each series uses its own column name as the legend label.
    for column in ('Total expense (VND)', 'Income (VND)'):
        plt.plot(df[column], label=column)
    plt.legend(loc='best')
    plt.show()
    
def fit_model(df):
    """Fit a SARIMAX model of total expense with exogenous regressors.

    Args:
        df: DataFrame containing 'Total expense (VND)' (the modelled series)
            and the regressor columns 'Income (VND)', 'Interest rate (%)',
            'Inflation rate (%)' and 'Holidays'.

    Returns:
        The fitted results object returned by ``SARIMAX(...).fit()``.
    """
    # Exogenous variables supplied alongside the output series.
    exog_columns = ['Income (VND)', 'Interest rate (%)', 'Inflation rate (%)', 'Holidays']
    sarimax = SARIMAX(
        df['Total expense (VND)'],
        exog=df[exog_columns],
        order=(1, 1, 1),               # non-seasonal order
        seasonal_order=(1, 1, 1, 12),  # seasonal order; last entry is the period
    )
    return sarimax.fit()

def get_input_data(income, interest_rate, inflation_rate, holidays):
    """Build a one-row DataFrame of regressor values for a single forecast.

    Args:
        income: Value for 'Income (VND)'; coerced with ``float``.
        interest_rate: Value for 'Interest rate (%)'; coerced with ``float``.
        inflation_rate: Value for 'Inflation rate (%)'; coerced with ``float``.
        holidays: Value for 'Holidays'; coerced with ``int``.

    Returns:
        A DataFrame with exactly one row and the four columns above, in that
        order.

    Raises:
        ValueError: If a value cannot be converted by ``float``/``int``.
    """
    names = ('Income (VND)', 'Interest rate (%)', 'Inflation rate (%)', 'Holidays')
    # Coerce in the same order as the columns so the first bad value raises.
    values = (
        float(income),
        float(interest_rate),
        float(inflation_rate),
        int(holidays),
    )
    return pd.DataFrame({name: [value] for name, value in zip(names, values)})

def forecast_expense(model_fit, input_data, df):
    """Forecast the expense for the step right after the end of *df*.

    When the first 'Income (VND)' value in *input_data* is below 5000000 the
    model is not consulted and the result is that income times 0.78492.
    Otherwise a single prediction is requested from *model_fit* at position
    ``len(df)`` with *input_data* as the exogenous values.

    Args:
        model_fit: Fitted model exposing ``predict(start, end, exog)``.
        input_data: DataFrame of regressors; must contain 'Income (VND)'.
        df: The data the model was fitted on; only its length is used.

    Returns:
        The forecast expense as a scalar.
    """
    income = input_data['Income (VND)'].iloc[0]
    # Low incomes bypass the model and use a fixed ratio of income instead.
    if income < 5000000:
        return income * 0.78492

    next_step = len(df)  # first position past the fitted sample
    prediction = model_fit.predict(start=next_step, end=next_step, exog=input_data)
    return prediction.iloc[0]

# Runs at import time: clean the raw frame loaded above and fit the model
# once, leaving `df` and `model_fit` available as module-level names.
df = preprocess_data(df)
model_fit = fit_model(df)

# def main():
#     global df
#     df = preprocess_data(df)
#     model_fit = fit_model(df)
#     input_data = get_input_data(10000000, 5, 3, 0)
#     forecast = forecast_expense(model_fit, input_data, df)
#     print(forecast)

# if __name__ == '__main__':
#     main()