"""Streamlit app that overlays linear-regression channels on Gate.io candlesticks.

The app downloads ``<CURRENCY>_USDT`` candlesticks from the Gate.io v4 REST API,
splits the series into consecutive segments with the best linear fit of the
candle mid-prices, shifts each fitted line so it bounds the candles' highs
(top line) or lows (bottom line), and renders the chart in Streamlit while also
saving it as a PNG under ``PLOT_DIR``.
"""
import datetime
import os
from pathlib import Path

import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import numpy as np
import requests
import streamlit as st
from mplfinance.original_flavor import candlestick_ohlc
from sklearn.linear_model import LinearRegression

PLOT_DIR = Path("./Plots")
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
PLOT_DIR.mkdir(exist_ok=True)

host = "https://api.gateio.ws"
prefix = "/api/v4"
headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
endpoint = '/spot/candlesticks'
url = host + prefix + endpoint
# Candles requested per API call; presumably chosen to stay under Gate.io's
# per-request cap — confirm against the API documentation.
max_API_request_allowed = 900
# Seconds before an API call is abandoned, so a stalled connection cannot hang the app.
REQUEST_TIMEOUT_SECONDS = 30

# Candle length for each supported interval string; anything else is treated as one week.
_INTERVAL_TIMEDELTAS = {
    "1h": datetime.timedelta(hours=1),
    "4h": datetime.timedelta(hours=4),
    "1d": datetime.timedelta(days=1),
}


def lin_reg(data, threshold_channel_len):
    """Greedily split a candle series into consecutive best-fit linear segments.

    Starting at index 0, every window ``data[start:end]`` with
    ``end - start >= threshold_channel_len`` is fitted with an ordinary least
    squares line of the candle mid-price ``(row[2] + row[3]) / 2`` against the
    time value ``row[0]``. The window with the highest R^2 (first one on ties)
    becomes a segment, and the next segment starts right after it. A tail
    shorter than ``threshold_channel_len`` is left unfitted.

    Args:
        data: Rows shaped ``[time, open, high, low, close]`` (as built in
            ``testcasecase``).
        threshold_channel_len: Minimum number of candles per segment.

    Returns:
        A list of ``[slope, intercept, start_index, end_index]`` entries, where
        ``end_index`` is inclusive.
    """
    n = len(data)
    X = np.array([row[0] for row in data], dtype=float).reshape(-1, 1)
    y = np.array([(row[2] + row[3]) / 2 for row in data], dtype=float).reshape(-1, 1)

    segments = []
    start = 0
    while threshold_channel_len > 0 and start + threshold_channel_len <= n:
        best_score = float("-inf")
        best = None
        for end in range(start + threshold_channel_len, n + 1):
            reg = LinearRegression().fit(X[start:end], y[start:end])
            score = reg.score(X[start:end], y[start:end])
            # R^2 is NaN for a single-point window; NaN never compares greater,
            # so such windows are never selected.
            if score > best_score:
                best_score = score
                best = (float(reg.coef_[0][0]), float(reg.intercept_[0]), end)
        if best is None:
            # No window had a defined score (only a lone trailing candle remained).
            break
        slope, intercept, end = best
        segments.append([slope, intercept, start, end - 1])
        start = end
    return segments


def binary_search(data, line_type, m, b, epsilon):
    """Return the intercept that shifts a line of slope ``m`` onto the candle envelope.

    The name is historical: the tightest intercept has a closed form, so no
    search is performed.

    * ``"top"``: the smallest intercept such that no candle high (``row[2]``)
      lies more than ``epsilon`` above the line, i.e.
      ``max(high - m * time) - epsilon``.
    * ``"bottom"``: the largest intercept such that no candle low (``row[3]``)
      lies more than ``epsilon`` below the line, i.e.
      ``min(low - m * time) + epsilon``.

    Args:
        data: Rows shaped ``[time, open, high, low, close]``.
        line_type: ``"top"`` or ``"bottom"``.
        m: Slope of the line.
        b: Fallback intercept, returned unchanged when ``data`` is empty.
        epsilon: Allowed penetration of candles through the line.

    Raises:
        ValueError: If ``line_type`` is neither ``"top"`` nor ``"bottom"``.
    """
    if line_type not in ("top", "bottom"):
        raise ValueError("line_type must be 'top' or 'bottom', got {!r}".format(line_type))
    if len(data) == 0:
        return b
    if line_type == "top":
        return max(row[2] - m * row[0] for row in data) - epsilon
    return min(row[3] - m * row[0] for row in data) + epsilon


def plot_lines(lines, plt, converted_data):
    """Draw each ``[slope, intercept, start, end]`` line across its candle span.

    ``plt`` may be the pyplot module or a matplotlib Axes; anything exposing a
    ``plot(x, y)`` method works. ``start``/``end`` index into ``converted_data``
    (``end`` inclusive), whose rows carry the time value at index 0.
    """
    for m, b, start, end in lines:
        x_data = np.linspace(converted_data[start][0], converted_data[end][0], 10)
        plt.plot(x_data, m * x_data + b)


def get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime):
    """Download ``<currency>_USDT`` candlesticks covering ``[start_datetime, end_datetime]``.

    The range is fetched in chunks of at most ``max_API_request_allowed``
    candles. The final chunk is capped at the last candle that starts on or
    before ``end_datetime``.

    Args:
        currency: Base currency symbol, e.g. ``"BTC"``.
        interval_timedelta: Length of one candle; must agree with ``interval``.
        interval: Gate.io interval string, e.g. ``"4h"``.
        start_datetime: Time of the first candle (inclusive). Should be
            timezone-aware; a naive value is interpreted in local time by
            ``datetime.timestamp``.
        end_datetime: Upper bound for candle start times (inclusive).

    Returns:
        The concatenated JSON rows from all chunks, or ``[]`` when the range is
        empty or any request fails. An error is shown in the Streamlit page on
        failure.
    """
    if end_datetime < start_datetime:
        return []
    # Number of candle start times start + k * interval that are <= end_datetime.
    total_dates = (end_datetime - start_datetime) // interval_timedelta + 1

    data = []
    for offset in range(0, total_dates, max_API_request_allowed):
        # Without this cap the last request would ask for up to
        # max_API_request_allowed - 1 candles beyond end_datetime.
        last_offset = min(offset + max_API_request_allowed, total_dates) - 1
        query_param = {
            "currency_pair": "{}_USDT".format(currency),
            "from": int((start_datetime + offset * interval_timedelta).timestamp()),
            "to": int((start_datetime + last_offset * interval_timedelta).timestamp()),
            "interval": interval,
        }
        try:
            r = requests.get(
                url=url,
                headers=headers,
                params=query_param,
                timeout=REQUEST_TIMEOUT_SECONDS,
            )
        except requests.RequestException as err:
            st.error("Could not reach the Gate.io API: {}".format(err))
            return []
        if r.status_code != 200:
            st.error("Very Large Duration Selected. Please reduce Duration or increase Interval")
            return []
        data += r.json()
    return data


def _parse_mm_dd_yyyy(text):
    """Parse ``MM/DD/YYYY`` into a UTC-midnight datetime; raise ``ValueError`` if malformed."""
    month, day, year = (int(part) for part in text.strip().split("/"))
    # UTC-aware so .timestamp() in get_API_data agrees with the UTC decoding of
    # the returned timestamps; a naive datetime would use the server's local zone.
    return datetime.datetime(year=year, month=month, day=day, tzinfo=datetime.timezone.utc)


def testcasecase(currency, interval, startdate, enddate, threshold_channel_len, testcasecase_id):
    """Fetch candles, fit regression channels, then save and display the chart.

    Args:
        currency: Base currency symbol; whitespace is stripped and it is
            uppercased. Must be alphanumeric.
        interval: ``"1h"``, ``"4h"`` or ``"1d"``; any other value is treated as
            one week.
        startdate: Start date as ``MM/DD/YYYY`` (UTC).
        enddate: End date as ``MM/DD/YYYY`` (UTC).
        threshold_channel_len: Minimum candles per regression segment.
        testcasecase_id: Included in the saved PNG's file name.

    Errors are reported through ``st.error``; the function returns early in
    that case.
    """
    currency = currency.strip().upper()
    # currency is user input that ends up in the request and in a file name.
    if not currency.isalnum():
        st.error("Currency must be a non-empty alphanumeric symbol such as BTC.")
        return
    try:
        start_datetime = _parse_mm_dd_yyyy(startdate)
        end_datetime = _parse_mm_dd_yyyy(enddate)
    except ValueError:
        st.error("Dates must be valid and in MM/DD/YYYY format.")
        return
    interval_timedelta = _INTERVAL_TIMEDELTAS.get(interval, datetime.timedelta(weeks=1))

    data = get_API_data(currency, interval_timedelta, interval, start_datetime, end_datetime)
    if len(data) == 0:
        return

    # Reorder each API row into the (time, open, high, low, close) layout that
    # candlestick_ohlc expects.
    # NOTE(review): assumes Gate.io rows are
    # [timestamp, quote volume, close, high, low, open, ...] — confirm against the API docs.
    converted_data = [
        [
            matplotlib.dates.date2num(
                datetime.datetime.fromtimestamp(float(d[0]), tz=datetime.timezone.utc)
            ),
            float(d[5]),
            float(d[3]),
            float(d[4]),
            float(d[2]),
        ]
        for d in data
    ]

    epsilon = 0
    top_fitting_lines_data = []
    bottom_fitting_lines_data = []
    for m, b, start, end in lin_reg(converted_data, threshold_channel_len):
        segment = converted_data[start:end + 1]
        top_fitting_lines_data.append([m, binary_search(segment, "top", m, b, epsilon), start, end])
        bottom_fitting_lines_data.append([m, binary_search(segment, "bottom", m, b, epsilon), start, end])

    # x values are matplotlib date numbers (days), so scale the body width by the
    # candle length; this keeps the historical 0.4 for daily candles.
    candle_width = 0.4 * (interval_timedelta / datetime.timedelta(days=1))

    fig, ax = plt.subplots()
    try:
        candlestick_ohlc(ax, converted_data, width=candle_width, colorup='#77d879', colordown='#db3f3f')
        ax.xaxis_date()
        plot_lines(top_fitting_lines_data, ax, converted_data)
        plot_lines(bottom_fitting_lines_data, ax, converted_data)
        ax.set_title("{}_USDT".format(currency))
        file_name = "figure_{}_{}_USDT.png".format(testcasecase_id, currency)
        fig.savefig(os.path.join(PLOT_DIR, file_name))
        st.pyplot(fig)
    finally:
        # Streamlit re-runs this on every click; pyplot keeps figures alive until closed.
        plt.close(fig)


def main():
    """Render the input form and generate the chart when the button is pressed."""
    st.title("Cryptocurrency Regression Analysis")
    st.write("Enter details to generate regression lines on cryptocurrency candlesticks.")
    currency = st.text_input("Currency", "BTC")
    interval = st.selectbox("Interval", ["4h", "1d", "1w"])
    startdate = st.text_input("Start Date (MM/DD/YYYY)", "01/01/2023")
    enddate = st.text_input("End Date (MM/DD/YYYY)", "02/01/2023")
    threshold_channel_len = st.number_input("Threshold Channel Length", min_value=1, max_value=1000, value=10)
    if st.button("Generate Plot"):
        testcasecase(currency, interval, startdate, enddate, threshold_channel_len, 1)


if __name__ == "__main__":
    main()