|
import datetime |
|
import requests |
|
import matplotlib |
|
import matplotlib.pyplot as plt |
|
from mplfinance.original_flavor import candlestick_ohlc |
|
import numpy as np |
|
from sklearn.linear_model import LinearRegression |
|
import os |
|
from pathlib import Path |
|
import streamlit as st |
|
|
|
# Directory (relative to the process's working directory) where every
# generated chart PNG is saved.
PLOT_DIR = Path("./Plots")

# Create it once at import time. exist_ok=True replaces the old
# os.path.exists()/os.mkdir() pair, which raced when two sessions started at
# the same moment (the loser crashed with FileExistsError).
PLOT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
# Gate.io REST API v4: spot candlestick endpoint, assembled from its parts.
host = "https://api.gateio.ws"
prefix = "/api/v4"
endpoint = '/spot/candlesticks'
url = f"{host}{prefix}{endpoint}"

headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}

# Number of candles requested per HTTP call; get_API_data splits longer
# ranges into chunks of this many candles.
max_API_request_allowed = 900
|
|
|
def lin_reg(data, threshold_channel_len):
    """Greedily split the candles into consecutive segments and fit a line to each.

    The regressor is each row's x value (``row[0]``) and the target is the
    midpoint of its high and low, ``(row[2] + row[3]) / 2``.  Starting at row
    ``start``, every window ``data[start:stop]`` with
    ``stop >= start + threshold_channel_len`` (up to the end of ``data``) is
    fitted with ``LinearRegression``.  The window with the highest R^2 wins,
    and the shortest window wins on ties.  The next segment begins right after
    the chosen one.  Trailing rows that cannot form another window are left
    unfitted.

    Args:
        data: sequence of rows whose index 0 is the x value and indices 2 and
            3 are the high and low (``[x, open, high, low, close]`` as built
            in ``testcasecase``).
        threshold_channel_len: minimum number of rows per window.

    Returns:
        List of ``[slope, intercept, start, end]`` with inclusive row indices.
    """
    n = len(data)
    X = np.array([row[0] for row in data]).reshape(-1, 1)
    y = np.array([(row[2] + row[3]) / 2 for row in data]).reshape(-1, 1)

    segments = []
    start = 0
    first_stop = start + threshold_channel_len
    while start < first_stop <= n:
        best_score = float("-inf")
        best_segment = None
        for stop in range(first_stop, n + 1):
            # R^2 is undefined for a single sample (sklearn returns NaN with a
            # warning).  Such a window must never be chosen: previously the
            # all-NaN case left the selected index at -1, appended a bogus
            # segment and then crashed fitting the empty slice X[-1:0].
            if stop - start < 2:
                continue
            window_X, window_y = X[start:stop], y[start:stop]
            reg = LinearRegression().fit(window_X, window_y)
            score = reg.score(window_X, window_y)
            # Strict '>' keeps the shortest window on ties and ignores NaN.
            if score > best_score:
                best_score = score
                best_segment = [reg.coef_[0][0], reg.intercept_[0], start, stop - 1]
        if best_segment is None:
            # Only one row is left; there is no line to fit through it.
            break
        segments.append(best_segment)
        start = best_segment[3] + 1
        first_stop = start + threshold_channel_len
    return segments
|
|
|
def binary_search(data, line_type, m, b, epsilon):
    """Return the intercept of the tightest slope-``m`` line bounding the candles.

    For ``line_type == "top"`` the line ``y = m * x + c`` lies on or above
    every candle high (``row[2]``).  For ``"bottom"`` it lies on or below
    every candle low (``row[3]``).  A candle only counts as crossing the line
    when it pokes through by more than ``epsilon``.  The returned intercept is
    therefore the tightest one that no candle crosses:

    * top:    ``max(high - m * x) - epsilon``
    * bottom: ``min(low - m * x) + epsilon``

    The name is kept for compatibility.  The bound is computed in closed form
    because bisection is unnecessary.  The previous integer-step bisection
    never terminated on "bottom" lines whose candle intercepts were all equal
    (e.g. a one-candle segment).

    Args:
        data: rows ``[x, open, high, low, close]``.
        line_type: ``"top"`` or ``"bottom"``.
        m: slope of the line.
        b: unused; kept so existing callers keep working.
        epsilon: tolerance by which a candle may cross the line.

    Returns:
        The intercept ``c``.  For empty ``data`` this is ``-inf`` ("top") or
        ``inf`` ("bottom").

    Raises:
        ValueError: if ``line_type`` is neither ``"top"`` nor ``"bottom"``.
    """
    if line_type == "top":
        highest = max((row[2] - m * row[0] for row in data), default=float("-inf"))
        return highest - epsilon
    if line_type == "bottom":
        lowest = min((row[3] - m * row[0] for row in data), default=float("inf"))
        return lowest + epsilon
    raise ValueError(f"line_type must be 'top' or 'bottom', got {line_type!r}")
|
|
|
def plot_lines(lines, plt, converted_data):
    """Draw every ``[slope, intercept, start, end]`` segment onto ``plt``.

    Each segment is sampled at 10 evenly spaced x values running from the x
    value (column 0) of row ``start`` to that of row ``end`` in
    ``converted_data``.  It is drawn with a single ``plt.plot(xs, ys)`` call,
    so any object with a ``plot`` method (pyplot module or an Axes) works.
    """
    def _segment(slope, intercept, first, last):
        # Sample the segment's x range, then evaluate y = slope * x + intercept.
        xs = list(np.linspace(converted_data[first][0], converted_data[last][0], 10))
        return xs, [slope * x + intercept for x in xs]

    for line in lines:
        plt.plot(*_segment(*line))
|
|
|
def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime, timeout=30):
    """Download ``{currency}_USDT`` candlesticks from Gate.io for a date range.

    The candle open times ``start_datetime + k * interval_timedelta`` that are
    ``<= end_datetime`` are requested in chunks of ``max_API_request_allowed``
    candles.  Each chunk's ``to`` bound is clipped to ``end_datetime``.
    Previously the final chunk always asked for a full
    ``max_API_request_allowed`` intervals and returned candles far past the
    requested end date.

    Args:
        currency: base currency symbol, e.g. ``"BTC"``.
        interval_timedelta: spacing between consecutive candles.
        interval: value sent as the API's ``interval`` query parameter.
        start_datetime: first candle open time (inclusive).
        end_datetime: last candle open time (inclusive).  Both bounds are
            converted with ``datetime.timestamp()``, so naive values are read
            as local time.
        timeout: seconds to wait for each HTTP request.

    Returns:
        The concatenated JSON rows from all chunks.  Returns ``[]`` when the
        range is empty or any request fails; failures are reported with
        ``st.error``.
    """
    if end_datetime < start_datetime:
        return []
    # Count of open times start, start + td, ... that do not pass end_datetime.
    total_dates = (end_datetime - start_datetime) // interval_timedelta + 1

    data = []
    for first in range(0, total_dates, max_API_request_allowed):
        chunk_from = start_datetime + first * interval_timedelta
        chunk_to = min(
            start_datetime + (first + max_API_request_allowed - 1) * interval_timedelta,
            end_datetime,
        )
        query_param = {
            "currency_pair": "{}_USDT".format(currency),
            "from": int(chunk_from.timestamp()),
            "to": int(chunk_to.timestamp()),
            "interval": interval,
        }
        try:
            # Without a timeout a stalled connection would hang the app forever.
            r = requests.get(url=url, headers=headers, params=query_param, timeout=timeout)
        except requests.RequestException as exc:
            st.error("Could not reach the Gate.io API: {}".format(exc))
            return []
        if r.status_code != 200:
            st.error("Very Large Duration Selected. Please reduce Duration or increase Interval")
            return []
        data += r.json()
    return data
|
|
|
# Spacing between consecutive candles per UI interval; anything else falls
# back to one week.
_INTERVAL_TIMEDELTAS = {
    "1h": datetime.timedelta(hours=1),
    "4h": datetime.timedelta(hours=4),
    "1d": datetime.timedelta(days=1),
}

# UI labels that Gate.io spells differently.
# NOTE(review): Gate.io API v4 lists weekly candles as "7d" (no "1w") — confirm.
_GATEIO_INTERVAL_NAMES = {"1w": "7d"}


def _parse_us_date(text):
    """Parse ``MM/DD/YYYY`` text into a UTC-midnight ``datetime``.

    Raises:
        ValueError: if the text is not three ``/``-separated integers forming
            a valid calendar date.
    """
    month, day, year = [int(part) for part in text.strip().split("/")]
    return datetime.datetime(year=year, month=month, day=day, tzinfo=datetime.timezone.utc)


def _to_plot_rows(api_rows):
    """Convert raw API candle rows into ``[date_num, open, high, low, close]``.

    Index 0 of each API row is the epoch-seconds open time.  Indices 5, 3, 4
    and 2 are taken as open, high, low and close.
    """
    rows = []
    for candle in api_rows:
        opened_at = datetime.datetime.fromtimestamp(float(candle[0]), tz=datetime.timezone.utc)
        rows.append([
            matplotlib.dates.date2num(opened_at),
            float(candle[5]),
            float(candle[3]),
            float(candle[4]),
            float(candle[2]),
        ])
    return rows


def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
    """Fetch candles, fit regression channels, then save and display the chart.

    Args:
        currency: base currency ticker (quoted against USDT).  It is
            upper-cased and must be alphanumeric, because it becomes part of
            the saved file's name.
        interval: "1h", "4h", "1d" or "1w" (anything else is treated as weekly).
        startdate, enddate: inclusive range as ``MM/DD/YYYY``, interpreted in UTC.
        threshold_channel_len: minimum candles per fitted segment (see ``lin_reg``).
        testcasecase_id: tag used in the saved file name.

    Invalid input is reported with ``st.error`` and nothing is plotted.
    """
    currency = currency.strip().upper()
    # The ticker is interpolated into a file path below; rejecting anything
    # but letters/digits blocks path traversal such as "../..".
    if not currency.isalnum():
        st.error("Currency must be a ticker made of letters and digits, e.g. BTC")
        return

    try:
        # UTC so the request bounds agree with the UTC candle timestamps,
        # whatever the server's local timezone is.
        start_datetime = _parse_us_date(startdate)
        end_datetime = _parse_us_date(enddate)
    except ValueError:
        st.error("Dates must be valid calendar dates in MM/DD/YYYY format")
        return
    if end_datetime < start_datetime:
        st.error("End Date must not be before Start Date")
        return

    interval_timedelta = _INTERVAL_TIMEDELTAS.get(interval, datetime.timedelta(weeks=1))
    api_interval = _GATEIO_INTERVAL_NAMES.get(interval, interval)

    data = get_API_data(currency, interval_timedelta, api_interval, start_datetime, end_datetime)
    if not data:
        return
    converted_data = _to_plot_rows(data)

    fig, ax = plt.subplots()
    # Candle width is in x-axis units (days for date2num values).  Scaling it
    # with the candle spacing stops 4h candles overlapping and keeps weekly
    # ones visible.  0.4 of the spacing preserves the old width for "1d".
    candle_width = 0.4 * (interval_timedelta / datetime.timedelta(days=1))
    candlestick_ohlc(ax, converted_data, width=candle_width, colorup='#77d879', colordown='#db3f3f')

    epsilon = 0
    top_fitting_lines_data = []
    bottom_fitting_lines_data = []
    for m, b, start, end in lin_reg(converted_data, threshold_channel_len):
        segment = converted_data[start:end + 1]
        top_fitting_lines_data.append([m, binary_search(segment, "top", m, b, epsilon), start, end])
        bottom_fitting_lines_data.append([m, binary_search(segment, "bottom", m, b, epsilon), start, end])

    # Draw on this figure's own axes instead of pyplot's implicit current axes.
    plot_lines(top_fitting_lines_data, ax, converted_data)
    plot_lines(bottom_fitting_lines_data, ax, converted_data)
    ax.set_title("{}_USDT".format(currency))

    file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
    fig.savefig(os.path.join(PLOT_DIR, file_name))
    st.pyplot(fig)
    # Streamlit re-runs this on every click; close the figure so pyplot does
    # not keep every one ever drawn alive.
    plt.close(fig)
|
|
|
def main():
    """Render the input form and, when the button is pressed, build the chart."""
    st.title("Cryptocurrency Regression Analysis")
    st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")

    # Widgets are created in display order; the keys match testcasecase's
    # parameter names so the values can be forwarded directly.
    form_values = dict(
        currency=st.text_input("Currency", "BTC"),
        interval=st.selectbox("Interval", ["4h", "1d", "1w"]),
        startdate=st.text_input("Start Date (MM/DD/YYYY)", "01/01/2023"),
        enddate=st.text_input("End Date (MM/DD/YYYY)", "02/01/2023"),
        threshold_channel_len=st.number_input(
            "Threshold Channel Length", min_value=1, max_value=1000, value=10
        ),
    )

    # Hit the API only when the user explicitly asks for a plot.
    if not st.button("Generate Plot"):
        return
    testcasecase(testcasecase_id=1, **form_values)


if __name__ == "__main__":
    main()
|
|