From 14e8c749626616f185e5d5971517b2b2e2c60873 Mon Sep 17 00:00:00 2001 From: vvzvlad Date: Tue, 3 Sep 2024 04:24:43 +0300 Subject: [PATCH] new model --- app.py | 44 ++++++++++++++++++++++++++++-------- model.py | 59 ++++++++++++++++++++++++++++-------------------- requirements.txt | 11 ++++++++- 3 files changed, 80 insertions(+), 34 deletions(-) diff --git a/app.py b/app.py index 4f94268..2232b3d 100644 --- a/app.py +++ b/app.py @@ -4,7 +4,7 @@ import pandas as pd import numpy as np from datetime import datetime from flask import Flask, jsonify, Response -from model import download_data, format_data, train_model +from model import download_data, format_data, train_model, training_price_data_path from config import model_file_path app = Flask(__name__) @@ -19,14 +19,36 @@ def update_data(): def get_eth_inference(): """Load model and predict current price.""" - with open(model_file_path, "rb") as f: - loaded_model = pickle.load(f) + try: + with open(model_file_path, "rb") as f: + loaded_model = pickle.load(f) - now_timestamp = pd.Timestamp(datetime.now()).timestamp() - X_new = np.array([now_timestamp]).reshape(-1, 1) - current_price_pred = loaded_model.predict(X_new) + # Загружаем последние данные из файла + price_data = pd.read_csv(training_price_data_path) - return current_price_pred[0] + # Используем последние значения признаков для предсказания + X_new = ( + price_data[ + [ + "timestamp", + "price_diff", + "volatility", + "volume", + "moving_avg_7", + "moving_avg_30", + ] + ] + .iloc[-1] + .values.reshape(1, -1) + ) + + # Делаем предсказание + current_price_pred = loaded_model.predict(X_new) + + return current_price_pred[0] + except Exception as e: + print(f"Error during inference: {str(e)}") + raise @app.route("/inference/") @@ -34,13 +56,17 @@ def generate_inference(token): """Generate inference for given token.""" if not token or token != "ETH": error_msg = "Token is required" if not token else "Token not supported" - return Response(json.dumps({"error": error_msg}), status=400, mimetype='application/json') + return Response( + json.dumps({"error": error_msg}), status=400, mimetype="application/json" + ) try: inference = get_eth_inference() return Response(str(inference), status=200) except Exception as e: - return Response(json.dumps({"error": str(e)}), status=500, mimetype='application/json') + return Response( + json.dumps({"error": str(e)}), status=500, mimetype="application/json" + ) @app.route("/update") diff --git a/model.py b/model.py index 7311661..11ab48d 100644 --- a/model.py +++ b/model.py @@ -1,15 +1,14 @@ import os import pickle +import numpy as np +from xgboost import XGBRegressor from zipfile import ZipFile from datetime import datetime import pandas as pd -import numpy as np from sklearn.model_selection import train_test_split -from sklearn import linear_model from updater import download_binance_monthly_data, download_binance_daily_data from config import data_base_path, model_file_path - binance_data_path = os.path.join(data_base_path, "binance/futures-klines") training_price_data_path = os.path.join(data_base_path, "eth_price_data.csv") @@ -35,19 +34,14 @@ def download_data(): def format_data(): - files = sorted([x for x in os.listdir(binance_data_path)]) + files = sorted([x for x in os.listdir(binance_data_path) if x.endswith(".zip")]) - # No files to process if len(files) == 0: return price_df = pd.DataFrame() for file in files: zip_file_path = os.path.join(binance_data_path, file) - - if not zip_file_path.endswith(".zip"): - continue - myzip = ZipFile(zip_file_path) with myzip.open(myzip.filelist[0]) as f: line = f.readline() @@ -70,30 +64,43 @@ def format_data(): df.index.name = "date" price_df = pd.concat([price_df, df]) + price_df["timestamp"] = price_df.index.map(pd.Timestamp.timestamp) + price_df["price_diff"] = price_df["close"].diff() + price_df["volatility"] = (price_df["high"] - price_df["low"]) / price_df["open"] + price_df["volume"] = price_df["volume"] + price_df["moving_avg_7"] = price_df["close"].rolling(window=7).mean() + price_df["moving_avg_30"] = price_df["close"].rolling(window=30).mean() + + # Удаляем строки с NaN значениями + price_df.dropna(inplace=True) + + # Сохраняем данные price_df.sort_index().to_csv(training_price_data_path) def train_model(): - # Load the eth price data price_data = pd.read_csv(training_price_data_path) - df = pd.DataFrame() - # Convert 'date' to a numerical value (timestamp) we can use for regression - df["date"] = pd.to_datetime(price_data["date"]) - df["date"] = df["date"].map(pd.Timestamp.timestamp) + # Используем дополнительные признаки + x = price_data[ + [ + "timestamp", + "price_diff", + "volatility", + "volume", + "moving_avg_7", + "moving_avg_30", + ] + ] + y = price_data["close"] - df["price"] = price_data[["open", "close", "high", "low"]].mean(axis=1) - - # Reshape the data to the shape expected by sklearn - x = df["date"].values.reshape(-1, 1) - y = df["price"].values.reshape(-1, 1) - - # Split the data into training set and test set - x_train, _, y_train, _ = train_test_split(x, y, test_size=0.2, random_state=0) + x_train, x_test, y_train, y_test = train_test_split( + x, y, test_size=0.2, random_state=0 + ) # Train the model print("Training model...") - model = linear_model.Lasso(alpha=0.1) + model = XGBRegressor() model.fit(x_train, y_train) print("Model trained.") @@ -104,4 +111,8 @@ def train_model(): with open(model_file_path, "wb") as f: pickle.dump(model, f) - print(f"Trained model saved to {model_file_path}") \ No newline at end of file + print(f"Trained model saved to {model_file_path}") + + # Optional: Оценка модели + y_pred = model.predict(x_test) + print(f"Mean Absolute Error: {np.mean(np.abs(y_test - y_pred))}") \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index feb7a96..f9a3cd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,13 @@ numpy==1.26.2 pandas==2.1.3 Requests==2.32.0 scikit_learn==1.3.2 -werkzeug>=3.0.3 # not directly required, pinned by Snyk to avoid a vulnerability \ No newline at end of file +werkzeug>=3.0.3 # not directly required, pinned by Snyk to avoid a vulnerability +itsdangerous +Jinja2 +MarkupSafe +python-dateutil +pytz +scipy +six +sklearn +xgboost \ No newline at end of file