diff --git a/7002.py b/7002.py index c759e83..50b8de7 100644 --- a/7002.py +++ b/7002.py @@ -15,7 +15,7 @@ from api.ranking import routes as rank_routes from api.shop import routes as shop_routes from api.play import routes as play_routes -from config import HOST, PORT, DEBUG, SSL_CERT, SSL_KEY, ROOT_FOLDER, ACTUAL_HOST, ACTUAL_PORT +from config import DEBUG, SSL_CERT, SSL_KEY, ROOT_FOLDER, ACTUAL_HOST, ACTUAL_PORT if (os.path.isfile('./files/dlc_4max.html')): get_4max_version_string() @@ -42,6 +42,7 @@ app = Starlette(debug=DEBUG, routes=routes) @app.on_event("startup") async def startup(): + global redis await database.connect() await init_db() diff --git a/api/database.py b/api/database.py index 3533ea7..c0d9681 100644 --- a/api/database.py +++ b/api/database.py @@ -1,14 +1,17 @@ -from starlette.responses import JSONResponse, Response -from starlette.requests import Request - import sqlalchemy -from sqlalchemy import Table, Column, Integer, String, DateTime, ForeignKey +from sqlalchemy import Table, Column, Integer, String, DateTime from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy import select, update +from config import REDIS_ADDRESS, USE_REDIS_CACHE + import os import databases import datetime +if USE_REDIS_CACHE: + import redis.asyncio as aioredis + +redis = None DB_NAME = "player.db" DB_PATH = os.path.join(os.getcwd(), DB_NAME) @@ -81,7 +84,7 @@ blacklist = Table( ) async def init_db(): - + global redis if not os.path.exists(DB_PATH): print("[DB] Creating new database:", DB_PATH) @@ -93,6 +96,9 @@ async def init_db(): await engine.dispose() print("[DB] Database initialized successfully.") await ensure_user_columns() + if USE_REDIS_CACHE: + print("[DB] Connecting to Redis at", REDIS_ADDRESS) + redis = await aioredis.from_url("redis://" + REDIS_ADDRESS) async def get_user_data(uid, data_field): query = select(user.c[data_field]).where(user.c.device_id == uid[b'vid'][0].decode()) diff --git a/api/ranking.py b/api/ranking.py index 476a833..cc0af7d 100644 --- a/api/ranking.py +++ b/api/ranking.py @@ -2,11 +2,12 @@ from starlette.responses import HTMLResponse from starlette.requests import Request from starlette.routing import Route import os +import json from sqlalchemy import select, update -from config import AUTHORIZATION_NEEDED +from config import AUTHORIZATION_NEEDED, USE_REDIS_CACHE -from api.database import database, user, result, daily_reward, check_blacklist, check_whitelist +import api.database from api.crypt import decrypt_fields, encryptAES from api.templates import EXP_UNLOCKED_SONGS, TITLE_LISTS, SONG_LIST from api.misc import inform_page @@ -18,7 +19,7 @@ async def ranking(request: Request): should_serve = True if AUTHORIZATION_NEEDED: - should_serve = await check_whitelist(decrypted_fields) and not await check_blacklist(decrypted_fields) + should_serve = await api.database.check_whitelist(decrypted_fields) and not await api.database.check_blacklist(decrypted_fields) if should_serve: device_id = decrypted_fields[b'vid'][0].decode() @@ -60,7 +61,7 @@ async def ranking_detail(request: Request): should_serve = True if AUTHORIZATION_NEEDED: - should_serve = await check_whitelist(decrypted_fields) and not await check_blacklist(decrypted_fields) + should_serve = await api.database.check_whitelist(decrypted_fields) and not await api.database.check_blacklist(decrypted_fields) if should_serve: device_id = decrypted_fields[b'vid'][0].decode() @@ -108,6 +109,11 @@ async def ranking_detail(request: Request): play_results = None user_result = None device_result = None + if USE_REDIS_CACHE: + cache_key = f"{song_id}-{mode}" + cached = await api.database.redis.get(cache_key) + else: + cached = False if (song_id == -1): # Filter out the mobile/AC modes @@ -118,50 +124,57 @@ async def ranking_detail(request: Request): else: exclude = [1, 2, 3] - query = select(result.c.vid, result.c.sid, result.c.mode, result.c.avatar, result.c.score) - play_results = await database.fetch_all(query) - - query = select(daily_reward.c.device_id, daily_reward.c.title, daily_reward.c.avatar) - device_results_raw = await database.fetch_all(query) - device_results = {row["device_id"]: {"title": row["title"], "avatar": row["avatar"]} for row in device_results_raw} - - query = select(user.c.id, user.c.username, user.c.device_id) - user_results_raw = await database.fetch_all(query) - user_results = {row["id"]: {"username": row["username"], "device_id": row["device_id"]} for row in user_results_raw} - query = select(user).where(user.c.device_id == device_id) - cur_user = await database.fetch_one(query) + if cached and USE_REDIS_CACHE: + sorted_players = json.loads(cached) - player_scores = {} + else: + query = select(api.database.result.c.vid, api.database.result.c.sid, api.database.result.c.mode, api.database.result.c.avatar, api.database.result.c.score) + play_results = await api.database.database.fetch_all(query) - filtered_play_results = [play for play in play_results if int(play[2]) not in exclude] + query = select(api.database.daily_reward.c.device_id, api.database.daily_reward.c.title, api.database.daily_reward.c.avatar) + device_results_raw = await api.database.database.fetch_all(query) + device_results = {row["device_id"]: {"title": row["title"], "avatar": row["avatar"]} for row in device_results_raw} - for play in filtered_play_results: - did = play[0] - sid = play[1] - avatar = play[3] - score = play[4] - username, title = None, None + query = select(api.database.user.c.id, api.database.user.c.username, api.database.user.c.device_id) + user_results_raw = await api.database.database.fetch_all(query) + user_results = {row["id"]: {"username": row["username"], "device_id": row["device_id"]} for row in user_results_raw} + + query = select(api.database.user).where(api.database.user.c.device_id == device_id) + cur_user = await api.database.database.fetch_one(query) - if sid: - sid = int(sid) - if sid in user_results: - username = user_results[sid]["username"] - did = user_results[sid]["device_id"] - else: # Guest - username = f"Guest({did[-6:]})" + player_scores = {} - # title is device-specific - title = device_results.get(did, {}).get("title", "1") + filtered_play_results = [play for play in play_results if int(play[2]) not in exclude] - if username in player_scores: - player_scores[username]["score"] += int(score) - player_scores[username]["avatar"] = avatar # But avatar is based on latest play submission - player_scores[username]["title"] = title - else: - player_scores[username] = {"score": int(score), "avatar": avatar, "title": title} + for play in filtered_play_results: + did = play[0] + sid = play[1] + avatar = play[3] + score = play[4] + username, title = None, None - sorted_players = sorted(player_scores.items(), key=lambda x: x[1]["score"], reverse=True) + if sid: + sid = int(sid) + if sid in user_results: + username = user_results[sid]["username"] + did = user_results[sid]["device_id"] + else: # Guest + username = f"Guest({did[-6:]})" + + # title is device-specific + title = device_results.get(did, {}).get("title", "1") + + if username in player_scores: + player_scores[username]["score"] += int(score) + player_scores[username]["avatar"] = avatar # But avatar is based on latest play submission + player_scores[username]["title"] = title + else: + player_scores[username] = {"score": int(score), "avatar": avatar, "title": title} + + sorted_players = sorted(player_scores.items(), key=lambda x: x[1]["score"], reverse=True) + if USE_REDIS_CACHE: + await api.database.redis.set(cache_key, json.dumps(sorted_players), ex=300) username = cur_user[1] if cur_user else f"Guest({device_id[-6:]})" @@ -214,15 +227,21 @@ async def ranking_detail(request: Request): """ else: - query = select(result).where((result.c.id == song_id) & (result.c.mode == mode)) - play_results = await database.fetch_all(query) - play_results = sorted(play_results, key=lambda x: int(x[8]), reverse=True) + if cached and USE_REDIS_CACHE: + play_results = json.loads(cached) - query = select(user).where(user.c.device_id == device_id) - user_result = await database.fetch_one(query) + else: + query = select(api.database.result).where((api.database.result.c.id == song_id) & (api.database.result.c.mode == mode)) + play_results = await api.database.database.fetch_all(query) + play_results = sorted(play_results, key=lambda x: int(x[8]), reverse=True) + if USE_REDIS_CACHE: + await api.database.redis.set(cache_key, json.dumps(play_results), ex=300) - query = select(daily_reward).where(daily_reward.c.device_id == device_id) - device_result = await database.fetch_one(query) + query = select(api.database.user).where(api.database.user.c.device_id == device_id) + user_result = await api.database.database.fetch_one(query) + + query = select(api.database.daily_reward).where(api.database.daily_reward.c.device_id == device_id) + device_result = await api.database.database.fetch_one(query) user_id = user_result[0] if user_result else None username = user_result[1] if user_result else f"Guest({device_id[-6:]})" @@ -265,13 +284,13 @@ async def ranking_detail(request: Request): username = f"Guest({record[1][-6:]})" device_info = None if record[3]: - query = select(user.c.username).where(user.c.id == record[3]) - user_data = await database.fetch_one(query) + query = select(api.database.user.c.username).where(api.database.user.c.id == record[3]) + user_data = await api.database.database.fetch_one(query) if user_data: username = user_data["username"] - query = select(daily_reward.c.title).where(daily_reward.c.device_id == record[1]) - device_title = await database.fetch_one(query) + query = select(api.database.daily_reward.c.title).where(api.database.daily_reward.c.device_id == record[1]) + device_title = await api.database.database.fetch_one(query) if device_title: device_info = device_title["title"] else: @@ -321,7 +340,7 @@ async def status(request: Request): should_serve = True if AUTHORIZATION_NEEDED: - should_serve = await check_whitelist(decrypted_fields) and not await check_blacklist(decrypted_fields) + should_serve = await api.database.check_whitelist(decrypted_fields) and not await api.database.check_blacklist(decrypted_fields) if should_serve: device_id = decrypted_fields[b'vid'][0].decode() @@ -330,19 +349,19 @@ async def status(request: Request): if set_title: update_query = ( - update(daily_reward) - .where(daily_reward.c.device_id == device_id) + update(api.database.daily_reward) + .where(api.database.daily_reward.c.device_id == device_id) .values(title=set_title) ) - await database.execute(update_query) + await api.database.execute(update_query) - query = select(daily_reward).where(daily_reward.c.device_id == device_id) - user_data = await database.fetch_one(query) + query = select(api.database.daily_reward).where(api.database.daily_reward.c.device_id == device_id) + user_data = await api.database.database.fetch_one(query) user_name = f"Guest({device_id[-6:]})" if user_data: - query = select(user.c.username).where(user.c.device_id == device_id) - user_result = await database.fetch_one(query) + query = select(api.database.user.c.username).where(api.database.user.c.device_id == device_id) + user_result = await api.database.database.fetch_one(query) if user_result: user_name = user_result["username"] @@ -416,7 +435,7 @@ async def set_title(request: Request): should_serve = True if AUTHORIZATION_NEEDED: - should_serve = await check_whitelist(decrypted_fields) and not await check_blacklist(decrypted_fields) + should_serve = await api.database.check_whitelist(decrypted_fields) and not await api.database.check_blacklist(decrypted_fields) if should_serve: device_id = decrypted_fields[b'vid'][0].decode() @@ -424,8 +443,8 @@ async def set_title(request: Request): title_id = decrypted_fields[b'title_id'][0].decode() current_title = 1 - query = select(daily_reward.c.title).where(daily_reward.c.device_id == device_id) - row = await database.fetch_one(query) + query = select(api.database.daily_reward.c.title).where(api.database.daily_reward.c.device_id == device_id) + row = await api.database.database.fetch_one(query) if row: current_title = row["title"] @@ -458,7 +477,7 @@ async def mission(request: Request): should_serve = True if AUTHORIZATION_NEEDED: - should_serve = await check_whitelist(decrypted_fields) and not await check_blacklist(decrypted_fields) + should_serve = await api.database.check_whitelist(decrypted_fields) and not await api.database.check_blacklist(decrypted_fields) if should_serve: html = f"""
101001010