# app.py — author: suryanshs16103; last commit 689c2ed ("Update app.py", verified)
import datetime
import requests
import matplotlib
import matplotlib.pyplot as plt
from mplfinance.original_flavor import candlestick_ohlc
import numpy as np
from sklearn.linear_model import LinearRegression
import os
from pathlib import Path
import streamlit as st
# Directory where rendered charts are saved; created at import time.
PLOT_DIR = Path("./Plots")
# exist_ok=True avoids the check-then-create race of exists() followed by mkdir().
PLOT_DIR.mkdir(exist_ok=True)

# Gate.io REST API v4 spot candlestick endpoint.
host = "https://api.gateio.ws"
prefix = "/api/v4"
headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
endpoint = '/spot/candlesticks'
url = host + prefix + endpoint
# Number of candles requested per API call; longer ranges are split into
# chunks of this size. Presumably chosen to stay under Gate.io's per-query
# point limit — confirm against the API docs before raising it.
max_API_request_allowed = 900
def lin_reg(data, threshold_channel_len):
    """Greedily split *data* into consecutive linear-regression channels.

    The regressor is column 0 of each row and the target is the midpoint
    ``(row[2] + row[3]) / 2``. Starting at index 0, every window
    ``data[start:stop]`` with ``stop >= start + threshold_channel_len`` is fitted
    with ``LinearRegression``. The window with the highest R^2 (the first one on
    ties) becomes a channel, and the next channel starts where it ended.
    Trailing rows too few to fill a window of ``threshold_channel_len`` are not
    covered.

    Args:
        data: sequence of rows indexed as ``[x, _, high, low, ...]``.
        threshold_channel_len: minimum number of rows per channel.

    Returns:
        A list of ``[slope, intercept, start, end_inclusive]`` entries.
    """
    n = len(data)
    X = np.array([row[0] for row in data]).reshape(-1, 1)
    y = np.array([(row[2] + row[3]) / 2 for row in data]).reshape(-1, 1)

    channels = []
    start = 0
    end = threshold_channel_len
    while start < end <= n:
        best_score = float("-inf")
        best = None
        for stop in range(end, n + 1):
            reg = LinearRegression().fit(X[start:stop], y[start:stop])
            r2 = reg.score(X[start:stop], y[start:stop])
            # NaN scores (e.g. single-sample windows) never compare greater,
            # so they are skipped exactly as before.
            if r2 > best_score:
                best_score = r2
                best = ([reg.coef_[0][0], reg.intercept_[0], start, stop - 1], stop)
        if best is None:
            # No window produced a usable score. Previously this fell through
            # to list_pf[-1] and restarted at index -1, which crashed on an
            # empty fit; stop instead.
            break
        params, stop = best
        channels.append(params)
        start = stop
        end = stop + threshold_channel_len
    return channels
def binary_search(data, line_type, m, b, epsilon):
    """Return the intercept of the tightest line of slope *m* bounding *data*.

    Top line (``line_type == "top"``): ``y = m*x + c`` must not sit below any
    candle high (column 2) by more than *epsilon*.

    Bottom line (``"bottom"``): the line must not sit above any candle low
    (column 3) by more than *epsilon*.

    Any other *line_type* is treated as ``"top"``. Column 0 of each row is x.

    The name is historical. The earlier version bisected over integer offsets
    with float floor-division. That ignored *epsilon* for bottom lines, and it
    never terminated when all bottom intercepts were equal (e.g. a single
    candle). The optimum has a closed form, computed directly here. For
    ``epsilon == 0`` it matches what the bisection returned whenever that
    terminated, up to floating-point rounding.

    Args:
        data: rows indexed as ``[x, _, high, low, ...]``.
        line_type: ``"top"`` or ``"bottom"``.
        m: slope shared by the bounding line.
        b: regression intercept; returned unchanged when *data* is empty.
        epsilon: tolerated overshoot of a candle past the line (>= 0).

    Returns:
        The intercept ``c`` of the bounding line.
    """
    if line_type == "bottom":
        offsets = [row[3] - m * row[0] for row in data]
        # Highest line that no low undercuts by more than epsilon.
        return min(offsets) + epsilon if offsets else b
    offsets = [row[2] - m * row[0] for row in data]
    # Lowest line that no high exceeds by more than epsilon.
    return max(offsets) - epsilon if offsets else b
def plot_lines(lines, plt, converted_data):
    """Draw each ``[slope, intercept, start, end]`` segment onto *plt*.

    A segment spans from the x value (column 0) of ``converted_data[start]`` to
    that of ``converted_data[end]``, sampled at 10 evenly spaced points.
    *plt* can be anything with a ``plot(x, y)`` method, e.g. the pyplot module.
    """
    for line in lines:
        slope, intercept, first_idx, last_idx = line
        x_first = converted_data[first_idx][0]
        x_last = converted_data[last_idx][0]
        xs = list(np.linspace(x_first, x_last, 10))
        ys = list(map(lambda x: slope * x + intercept, xs))
        plt.plot(xs, ys)
def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime, timeout=30):
    """Fetch Gate.io spot candlesticks for ``<currency>_USDT`` over a date range.

    The inclusive range ``[start_datetime, end_datetime]`` is split into chunks
    of ``max_API_request_allowed`` candle slots, one request per chunk. The JSON
    lists returned are concatenated in order.

    Args:
        currency: base currency symbol, e.g. ``"BTC"``.
        interval_timedelta: duration of one candle (should match *interval*).
        interval: interval string sent to the API, e.g. ``"4h"``.
        start_datetime: first candle time. Naive values are treated as UTC.
        end_datetime: last candle time (inclusive). Naive values are treated
            as UTC.
        timeout: seconds to wait for each HTTP request.

    Returns:
        The concatenated raw candle rows, or ``[]`` if the range is empty or a
        request failed. Failures are reported via ``st.error``.
    """
    def to_unix(moment):
        # Interpret naive datetimes as UTC so the query does not depend on the
        # host's local timezone; the caller decodes the timestamps as UTC.
        if moment.tzinfo is None:
            moment = moment.replace(tzinfo=datetime.timezone.utc)
        return int(moment.timestamp())

    if end_datetime < start_datetime:
        return []
    # Number of candle slots start, start+d, start+2d, ... that are <= end_datetime.
    total_dates = (end_datetime - start_datetime) // interval_timedelta + 1

    data = []
    for offset in range(0, total_dates, max_API_request_allowed):
        chunk_start = start_datetime + offset * interval_timedelta
        # Clamp to end_datetime so the final chunk does not pull in candles
        # from after the requested range.
        chunk_end = min(
            start_datetime + (offset + max_API_request_allowed - 1) * interval_timedelta,
            end_datetime,
        )
        query_param = {
            "currency_pair": "{}_USDT".format(currency),
            "from": to_unix(chunk_start),
            "to": to_unix(chunk_end),
            "interval": interval,
        }
        try:
            r = requests.get(url=url, headers=headers, params=query_param, timeout=timeout)
        except requests.RequestException as err:
            st.error("Could not reach the Gate.io API: {}".format(err))
            return []
        if r.status_code != 200:
            # Include the API's own response: an invalid pair or interval also
            # lands here, not only an over-long range.
            st.error(
                "Very Large Duration Selected. Please reduce Duration or increase Interval"
                " (HTTP {}: {})".format(r.status_code, r.text[:200])
            )
            return []
        data += r.json()
    return data
# Candle duration and the interval name sent to Gate.io, keyed by the UI value.
_INTERVAL_SETTINGS = {
    "1h": (datetime.timedelta(hours=1), "1h"),
    "4h": (datetime.timedelta(hours=4), "4h"),
    "1d": (datetime.timedelta(days=1), "1d"),
    # Gate.io names the weekly interval "7d"; "1w" is not among its accepted values.
    "1w": (datetime.timedelta(weeks=1), "7d"),
}


def _parse_mdy(text):
    """Parse an ``MM/DD/YYYY`` string into a naive midnight ``datetime``.

    Raises:
        ValueError: if *text* is not three ``/``-separated integers forming a
            valid calendar date.
    """
    month, day, year = (int(part) for part in text.strip().split("/"))
    return datetime.datetime(year=year, month=month, day=day)


def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
    """Fetch candles, fit regression channels, and render and save the chart.

    Args:
        currency: base currency symbol; the pair queried is ``<currency>_USDT``.
        interval: UI interval value (``"1h"``, ``"4h"``, ``"1d"`` or ``"1w"``).
            Other values fall back to weekly candles and are sent to the API
            unchanged.
        startdate: start date as ``MM/DD/YYYY``.
        enddate: end date as ``MM/DD/YYYY``.
        threshold_channel_len: minimum candles per channel, passed to ``lin_reg``.
        testcasecase_id: identifier embedded in the saved PNG's file name.

    Side effects: writes ``figure_<id>_<currency>_USDT.png`` under ``PLOT_DIR``
    and renders the figure with ``st.pyplot``. Invalid dates are reported with
    ``st.error``.
    """
    import matplotlib.dates as mdates

    try:
        start_datetime = _parse_mdy(startdate)
        end_datetime = _parse_mdy(enddate)
    except ValueError:
        st.error("Dates must be valid and in MM/DD/YYYY format.")
        return
    if end_datetime < start_datetime:
        st.error("End date must not be earlier than start date.")
        return

    interval_timedelta, api_interval = _INTERVAL_SETTINGS.get(
        interval, (datetime.timedelta(weeks=1), interval)
    )

    data = get_API_data(currency, interval_timedelta, api_interval, start_datetime, end_datetime)
    if not data:
        return

    # candlestick_ohlc expects rows of [date_num, open, high, low, close]; the
    # indices pick those fields out of each API row (per Gate.io's docs:
    # 2=close, 3=high, 4=low, 5=open).
    converted_data = [
        [
            mdates.date2num(
                datetime.datetime.fromtimestamp(float(d[0]), tz=datetime.timezone.utc)
            ),
            float(d[5]),
            float(d[3]),
            float(d[4]),
            float(d[2]),
        ]
        for d in data
    ]

    fitting_lines_data = lin_reg(converted_data, threshold_channel_len)
    epsilon = 0
    top_fitting_lines_data = []
    bottom_fitting_lines_data = []
    for m, b, start, end in fitting_lines_data:
        window = converted_data[start:end + 1]
        top_fitting_lines_data.append([m, binary_search(window, "top", m, b, epsilon), start, end])
        bottom_fitting_lines_data.append([m, binary_search(window, "bottom", m, b, epsilon), start, end])

    # date2num units are days; scale the candle body with the interval so
    # intraday candles do not overlap (daily candles keep width 0.4).
    candle_width = 0.4 * (interval_timedelta / datetime.timedelta(days=1))

    fig, ax = plt.subplots()
    try:
        candlestick_ohlc(ax, converted_data, width=candle_width, colorup='#77d879', colordown='#db3f3f')
        # Axes provide the same plot(x, y) that plot_lines calls; drawing on
        # this figure explicitly avoids pyplot's shared "current figure" state.
        plot_lines(top_fitting_lines_data, ax, converted_data)
        plot_lines(bottom_fitting_lines_data, ax, converted_data)
        ax.set_title("{}_USDT".format(currency))
        file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
        fig.savefig(os.path.join(PLOT_DIR, file_name))
        st.pyplot(fig)
    finally:
        # Without this, every button click leaks a figure into pyplot's registry.
        plt.close(fig)
def main():
    """Build the Streamlit form and draw the chart once the button is pressed."""
    st.title("Cryptocurrency Regression Analysis")
    st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")

    # Widget calls are evaluated in this order, which is also their on-screen order.
    form_values = {
        "currency": st.text_input("Currency", "BTC"),
        "interval": st.selectbox("Interval", ["4h", "1d", "1w"]),
        "startdate": st.text_input("Start Date (MM/DD/YYYY)", "01/01/2023"),
        "enddate": st.text_input("End Date (MM/DD/YYYY)", "02/01/2023"),
        "threshold_channel_len": st.number_input(
            "Threshold Channel Length", min_value=1, max_value=1000, value=10
        ),
    }

    # Nothing more to do until the user asks for a plot.
    if not st.button("Generate Plot"):
        return
    testcasecase(testcasecase_id=1, **form_values)


if __name__ == "__main__":
    main()