# ritual/grpc-balancer.py

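"""JSON-RPC load balancer for Ethereum-compatible backends.

Accepts JSON-RPC POST requests on PORT, answers eth_chainId locally
(0x2105 = Base mainnet), and proxies everything else to backend RPC
servers chosen by their recent error rate. Backend URLs are read from
a Grist document, and aggregate health statistics are written back to it.
"""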

import json
import logging
import random
import socket
import threading
import time
from collections import deque
from datetime import datetime, timedelta, timezone

import requests
from flask import Flask, request, Response
from grist_api import GristDocAPI

logging.basicConfig(level=logging.INFO)

app = Flask(__name__)

# Filled from the Grist RPC_list table at startup.
BACKEND_SERVERS = []
# server id -> deque of (timestamp, success) tuples within the sliding window.
SERVER_STATS = {}
STATS_LOCK = threading.Lock()
# 'from' address -> deque of request timestamps within the sliding window.
ADDRESS_STATS = {}
ADDRESS_STATS_LOCK = threading.Lock()
STATISTICS_THRESHOLD = 10          # minimum samples before weighted selection is used
STATISTICS_WINDOW = timedelta(minutes=10)
MAX_WORKERS = 500                  # waitress worker threads
MAX_ERROR_RATE = 0.7               # backends above this error rate are skipped
PORT = 5000
# Hop-by-hop headers (RFC 7230) that must not be forwarded to the client.
HOP_BY_HOP_HEADERS = {
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'te',
    'trailers',
    'transfer-encoding',
    'upgrade',
}
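

# GRIST is a thin convenience wrapper around grist_api.GristDocAPI; table and
# column names containing spaces are normalized to underscores before each call.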
class GRIST:
    def __init__(self, server, doc_id, api_key, logger):
        self.server = server
        self.doc_id = doc_id
        self.api_key = api_key
        self.logger = logger
        self.grist = GristDocAPI(doc_id, server=server, api_key=api_key)

    def table_name_convert(self, table_name):
        return table_name.replace(" ", "_")

    def to_timestamp(self, dtime: datetime) -> int:
        if dtime.tzinfo is None:
            dtime = dtime.replace(tzinfo=timezone(timedelta(hours=3)))
        return int(dtime.timestamp())

    def insert_row(self, data, table):
        data = {key.replace(" ", "_"): value for key, value in data.items()}
        row_id = self.grist.add_records(self.table_name_convert(table), [data])
        return row_id

    def update_column(self, row_id, column_name, value, table):
        if isinstance(value, datetime):
            value = self.to_timestamp(value)
        column_name = column_name.replace(" ", "_")
        self.grist.update_records(self.table_name_convert(table), [{"id": row_id, column_name: value}])

    def delete_row(self, row_id, table):
        self.grist.delete_records(self.table_name_convert(table), [row_id])

    def update(self, row_id, updates, table):
        for column_name, value in updates.items():
            if isinstance(value, datetime):
                updates[column_name] = self.to_timestamp(value)
        updates = {column_name.replace(" ", "_"): value for column_name, value in updates.items()}
        self.grist.update_records(self.table_name_convert(table), [{"id": row_id, **updates}])

    def fetch_table(self, table):
        return self.grist.fetch_table(self.table_name_convert(table))

    def find_record(self, id=None, state=None, name=None, table=None):
        if table is None:
            raise ValueError("Table is not specified")
        table_data = self.grist.fetch_table(self.table_name_convert(table))
        if id is not None:
            return [row for row in table_data if row.id == id]
        if state is not None and name is not None:
            return [row for row in table_data if row.State == state and row.Name == name]
        if state is not None:
            return [row for row in table_data if row.State == state]
        if name is not None:
            return [row for row in table_data if row.Name == name]
        return []

    def find_settings(self, key, table):
        for record in self.fetch_table(table):
            if record.Setting == key:
                if record.Value is None or record.Value == "":
                    raise ValueError(f"Setting {key} blank")
                return record.Value
        raise ValueError(f"Setting {key} not found")
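

# Request flow: every JSON-RPC POST is parsed, per-'from'-address call counts
# are tracked, eth_chainId is answered locally, and everything else is forwarded
# to backends in the order given by select_servers(), falling through to the
# next backend on failure.
#
# Illustrative client call (any JSON-RPC method; id/params below are examples):
#   curl -X POST http://localhost:5000/ -H 'Content-Type: application/json' \
#        -d '{"jsonrpc":"2.0","id":1,"method":"eth_blockNumber","params":[]}'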
@app.route('/', methods=['POST'])
def proxy():
    data = request.get_data()
    headers = dict(request.headers)
    headers.pop('Accept-Encoding', None)
    try:
        data_json = json.loads(data.decode('utf-8'))
    except json.JSONDecodeError:
        logging.warning(f'Invalid JSON from {request.remote_addr}: {data}')
        return Response('Invalid JSON', status=400)

    # Track request rate per 'from' address.
    def update_address_stats(from_address):
        now = datetime.now(timezone.utc)
        with ADDRESS_STATS_LOCK:
            if from_address not in ADDRESS_STATS:
                ADDRESS_STATS[from_address] = deque()
            ADDRESS_STATS[from_address].append(now)
            # Drop timestamps that have fallen outside the statistics window.
            while ADDRESS_STATS[from_address] and ADDRESS_STATS[from_address][0] < now - STATISTICS_WINDOW:
                ADDRESS_STATS[from_address].popleft()

    # Extract the 'from' address from a JSON-RPC request object.
    def extract_from_address(req):
        params = req.get("params", [])
        if isinstance(params, list) and len(params) > 0 and isinstance(params[0], dict):
            return params[0].get("from")
        return None

    # Batch request: an array of request objects.
    if isinstance(data_json, list):
        for req in data_json:
            from_address = extract_from_address(req)
            if from_address:
                update_address_stats(from_address)
    elif isinstance(data_json, dict):
        from_address = extract_from_address(data_json)
        if from_address:
            update_address_stats(from_address)
        # Answer eth_chainId locally instead of hitting a backend.
        if data_json.get("method") == "eth_chainId":
            response_json = {
                "jsonrpc": "2.0",
                "id": data_json.get("id"),
                "result": "0x2105"  # Base mainnet
            }
            response_str = json.dumps(response_json)
            return Response(response_str, status=200, mimetype='application/json')

    # Try backends in the order returned by select_servers(), falling through
    # to the next one on any failure.
    selected_servers = select_servers()
    for server in selected_servers:
        server_url = server['url']
        server_id = server['id']
        try:
            headers['Host'] = server_url.split('//')[-1].split('/')[0]
            #logging.info(f'Proxying request to {server_url}: {data}')
            response = requests.post(server_url, data=data, headers=headers, timeout=5)
            if response.status_code == 200:
                print(".", end="", flush=True)
                # Optional verbose logging with truncated payloads (disabled):
                #MAX_DATA_LENGTH = 20
                #data_str = data.decode('utf-8')
                #data_json = json.loads(data_str)
                #if "jsonrpc" in data_json: data_json.pop("jsonrpc")
                #if 'params' in data_json and isinstance(data_json['params'], list):
                #    for idx, param in enumerate(data_json['params']):
                #        if isinstance(param, dict) and 'data' in param:
                #            original_data = param['data']
                #            if isinstance(original_data, str) and len(original_data) > MAX_DATA_LENGTH:
                #                truncated_data = original_data[:MAX_DATA_LENGTH - len("....SKIPPED")] + "....SKIPPED"
                #                data_json['params'][idx]['data'] = truncated_data
                #truncated_data_str = json.dumps(data_json)
                #response_str = response.content.decode('utf-8')
                #response_json = json.loads(response_str)
                #if "jsonrpc" in response_json: response_json.pop("jsonrpc")
                #if 'result' in response_json:
                #    original_result = response_json['result']
                #    if isinstance(original_result, str) and len(original_result) > MAX_DATA_LENGTH:
                #        truncated_result = original_result[:MAX_DATA_LENGTH - len("....SKIPPED")] + "....SKIPPED"
                #        response_json['result'] = truncated_result
                #truncated_response_str = json.dumps(response_json)
                #logging.info(f'OK: {request.remote_addr}: {truncated_data_str} -> {server_url}: {response.status_code}/{truncated_response_str}')
                with STATS_LOCK:
                    SERVER_STATS[server_id].append((datetime.now(timezone.utc), True))
                filtered_headers = {
                    k: v for k, v in response.headers.items()
                    if k.lower() not in HOP_BY_HOP_HEADERS
                }
                # requests may have decompressed the body, so the backend's
                # Content-Encoding/Content-Length no longer describe it.
                filtered_headers.pop('Content-Encoding', None)
                filtered_headers.pop('Content-Length', None)
                # Also drop any headers named in the Connection header.
                connection_header = response.headers.get('Connection', '')
                connection_tokens = [token.strip().lower() for token in connection_header.split(',')]
                for token in connection_tokens:
                    filtered_headers.pop(token, None)
                return Response(response.content, status=response.status_code, headers=filtered_headers)
            else:
                logging.warning(f'Failed to proxy request to {server_url}: {response.status_code}/{response.content}')
                with STATS_LOCK:
                    SERVER_STATS[server_id].append((datetime.now(timezone.utc), False))
                continue
        except requests.exceptions.RequestException as e:
            logging.error(f'Exception while proxying to {server_url}: {e}')
            with STATS_LOCK:
                SERVER_STATS[server_id].append((datetime.now(timezone.utc), False))
            continue
    return Response('All backend servers are unavailable', status=503)
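

# Backend selection: per-server (timestamp, success) samples older than
# STATISTICS_WINDOW are discarded, servers whose error rate exceeds
# MAX_ERROR_RATE are excluded, and one of the remaining servers is picked with
# probability proportional to its success rate (1 - error_rate); the other
# healthy servers follow as fallbacks, sorted by error rate. For example, with
# two healthy backends at error rates 0.1 and 0.5, the selection weights come
# out to 0.9/1.4 ≈ 0.64 and 0.5/1.4 ≈ 0.36.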
def select_servers():
    now = datetime.now(timezone.utc)
    with STATS_LOCK:
        # Prune samples that have fallen outside the sliding window.
        for server in BACKEND_SERVERS:
            server_id = server['id']
            stats = SERVER_STATS[server_id]
            while stats and stats[0][0] < now - STATISTICS_WINDOW:
                stats.popleft()
        total_requests = sum(len(SERVER_STATS[server['id']]) for server in BACKEND_SERVERS)
    if total_requests < STATISTICS_THRESHOLD:
        # Not enough data to judge backends; try them all in random order.
        servers = BACKEND_SERVERS.copy()
        random.shuffle(servers)
        #logging.info("Total requests below threshold. Shuffled servers: %s", servers)
        return servers
    server_scores = []
    with STATS_LOCK:
        for server in BACKEND_SERVERS:
            server_id = server['id']
            stats = SERVER_STATS[server_id]
            failures = sum(1 for t, success in stats if not success)
            successes = len(stats) - failures
            total = successes + failures
            error_rate = failures / total if total > 0 else 0
            server_scores.append({
                'server': server,
                'failures': failures,
                'successes': successes,
                'error_rate': error_rate
            })
            #logging.info(f"Server {server_id}: Failures={failures}, Successes={successes}, Error Rate={error_rate:.2f}")
    healthy_servers = [s for s in server_scores if s['error_rate'] <= MAX_ERROR_RATE]
    if not healthy_servers:
        logging.warning("No healthy servers available.")
        return BACKEND_SERVERS.copy()
    healthy_servers.sort(key=lambda x: x['error_rate'])
    total_weight = sum(1 - s['error_rate'] for s in healthy_servers)
    if total_weight == 0:
        weights = [1 for _ in healthy_servers]
    else:
        weights = [(1 - s['error_rate']) / total_weight for s in healthy_servers]
    selected_server = random.choices([s['server'] for s in healthy_servers], weights=weights, k=1)[0]
    selected_servers = [selected_server] + [s['server'] for s in healthy_servers if s['server'] != selected_server]
    return selected_servers
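

# Background reporter: every 30 seconds the aggregate successes/failures/
# requests-per-second over the current window are written to the node's
# Health column in Grist via the callback passed in from the main block.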
def upload_stats_to_grist(update_row):
    while True:
        try:
            total_stats = {
                'failures': 0,
                'successes': 0,
                'rps': 0
            }
            with STATS_LOCK:
                for server in BACKEND_SERVERS:
                    server_id = server['id']
                    server_stats = SERVER_STATS[server_id]
                    failures = sum(1 for t, success in server_stats if not success)
                    successes = len(server_stats) - failures
                    total_stats['failures'] += failures
                    total_stats['successes'] += successes
                    total_stats['rps'] += len(server_stats) / STATISTICS_WINDOW.total_seconds()
            health = f"{total_stats['successes']}/{total_stats['failures']}/{total_stats['rps']:.2f}"
            update_row({"Health": health})
        except Exception as e:
            logging.error(f"Failed to upload stats to Grist: {str(e)}")
        time.sleep(30)
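

# Startup: Grist connection details are read from /root/node/grist.json, the
# current host's row is located in the Nodes table by hostname, backend URLs
# are loaded from the RPC_list table, and the app is served with waitress.
# Expected config shape (values below are placeholders):
#   {
#     "grist_server": "https://grist.example.com",
#     "grist_doc_id": "XXXXXXXXXXXX",
#     "grist_api_key": "XXXXXXXXXXXX"
#   }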
if __name__ == '__main__':
    GRIST_ROW_NAME = socket.gethostname()
    NODES_TABLE = "Nodes"
    RPC_TABLE = "RPC_list"

    with open('/root/node/grist.json', 'r', encoding='utf-8') as f:
        grist_data = json.load(f)
    host = grist_data.get('grist_server')
    doc_id = grist_data.get('grist_doc_id')
    api_key = grist_data.get('grist_api_key')
    grist = GRIST(host, doc_id, api_key, logging)

    current_vm = grist.find_record(name=GRIST_ROW_NAME, table=NODES_TABLE)[0]

    def grist_callback(msg):
        grist.update(current_vm.id, msg, NODES_TABLE)

    BACKEND_SERVERS = []
    SERVER_STATS = {}
    table = grist.fetch_table(RPC_TABLE)
    for row in table:
        if row.URL:
            server_info = {'id': row.id, 'url': row.URL}
            BACKEND_SERVERS.append(server_info)
            SERVER_STATS[row.id] = deque()

    upload_thread = threading.Thread(target=upload_stats_to_grist, daemon=True, args=(grist_callback,))
    upload_thread.start()

    from waitress import serve
    logging.info(f"Starting server on port {PORT}")
    serve(app, host='0.0.0.0', port=PORT, threads=MAX_WORKERS, connection_limit=1000)