From b47f9296d4e3d25c05f91c8ad61a0ee107ff0370 Mon Sep 17 00:00:00 2001 From: UnitedAirforce Date: Sun, 22 Jun 2025 08:48:48 +0800 Subject: [PATCH] fixes --- 7002.py | 3 +- api/database.py | 36 +++++++--- api/play.py | 170 +++++++++++++++++++++++++++++++++++++++++++++++ api/ranking.py | 6 +- api/user.py | 163 ++------------------------------------------- requirements.txt | 10 ++- 6 files changed, 212 insertions(+), 176 deletions(-) create mode 100644 api/play.py diff --git a/7002.py b/7002.py index 6e47f69..c759e83 100644 --- a/7002.py +++ b/7002.py @@ -13,6 +13,7 @@ from api.misc import get_4max_version_string from api.user import routes as user_routes from api.ranking import routes as rank_routes from api.shop import routes as shop_routes +from api.play import routes as play_routes from config import HOST, PORT, DEBUG, SSL_CERT, SSL_KEY, ROOT_FOLDER, ACTUAL_HOST, ACTUAL_PORT @@ -33,7 +34,7 @@ async def serve_file(request): routes = [] -routes = routes + user_routes + rank_routes + shop_routes +routes = routes + user_routes + rank_routes + shop_routes + play_routes routes.append(Route("/{path:path}", serve_file)) diff --git a/api/database.py b/api/database.py index a8e4e25..3533ea7 100644 --- a/api/database.py +++ b/api/database.py @@ -53,20 +53,19 @@ result = Table( Column("rid", Integer, primary_key=True, autoincrement=True), Column("vid", String(512), nullable=False), Column("tid", String(512), nullable=False), - Column("sid", String(512), nullable=False), + Column("sid", Integer, nullable=False), Column("stts", String(64)), - Column("id", String(8)), - Column("mode", String(4)), - Column("avatar", String(4)), - Column("score", String(16)), + Column("id", Integer), + Column("mode", Integer), + Column("avatar", Integer), + Column("score", Integer), Column("high_score", String(128)), Column("play_rslt", String(128)), - Column("item", String(16)), + Column("item", Integer), Column("os", String(16)), Column("os_ver", String(16)), Column("ver", String(16)), - Column("mike", String(8)), - + Column("mike", Integer), ) whitelist = Table( @@ -93,6 +92,7 @@ async def init_db(): await engine.dispose() print("[DB] Database initialized successfully.") + await ensure_user_columns() async def get_user_data(uid, data_field): query = select(user.c[data_field]).where(user.c.device_id == uid[b'vid'][0].decode()) @@ -125,4 +125,22 @@ async def check_blacklist(uid): ) async with database.transaction(): result = await database.fetch_one(query) - return result is None \ No newline at end of file + return result is None + +async def ensure_user_columns(): + import aiosqlite + + async with aiosqlite.connect(DB_PATH) as db: + async with db.execute("PRAGMA table_info(user);") as cursor: + columns = [row[1] async for row in cursor] + + alter_needed = False + if "save_id" not in columns: + await db.execute("ALTER TABLE user ADD COLUMN save_id TEXT;") + alter_needed = True + if "coin_mp" not in columns: + await db.execute("ALTER TABLE user ADD COLUMN coin_mp INTEGER DEFAULT 1;") + alter_needed = True + if alter_needed: + await db.commit() + print("[DB] Added missing columns to user table.") diff --git a/api/play.py b/api/play.py new file mode 100644 index 0000000..439f90e --- /dev/null +++ b/api/play.py @@ -0,0 +1,170 @@ +from starlette.responses import Response +from starlette.requests import Request +from starlette.routing import Route +import os +import json +from sqlalchemy import select, update, insert +import xml.etree.ElementTree as ET + +from config import ROOT_FOLDER, START_COIN, COIN_REWARD, AUTHORIZATION_NEEDED + +from api.database import database, user, daily_reward, result, result, check_blacklist, check_whitelist +from api.crypt import decrypt_fields +from api.templates import START_STAGES, EXP_UNLOCKED_SONGS + +async def result_request(request: Request): + decrypted_fields, _ = await decrypt_fields(request) + if not decrypted_fields: + return Response("""10Invalid request data.""", media_type="application/xml") + + should_serve = True + if AUTHORIZATION_NEEDED: + should_serve = await check_whitelist(decrypted_fields) and not await check_blacklist(decrypted_fields) + + if not should_serve: + return Response("""403Access denied.""", media_type="application/xml") + + device_id = decrypted_fields[b'vid'][0].decode() + file_path = os.path.join(ROOT_FOLDER, "files/result.xml") + try: + tree = ET.parse(file_path) + root = tree.getroot() + except Exception as e: + return Response(f"""500Error parsing XML: {str(e)}""", media_type="application/xml") + + vid = decrypted_fields[b'vid'][0].decode() + stts = decrypted_fields[b'stts'][0].decode() + track_id = decrypted_fields[b'id'][0].decode() + mode = decrypted_fields[b'mode'][0].decode() + avatar = decrypted_fields[b'avatar'][0].decode() + score = int(decrypted_fields[b'score'][0].decode()) + high_score = decrypted_fields[b'high_score'][0].decode() + play_rslt = decrypted_fields[b'play_rslt'][0].decode() + item = decrypted_fields[b'item'][0].decode() + device_os = decrypted_fields[b'os'][0].decode() + os_ver = decrypted_fields[b'os_ver'][0].decode() + tid = decrypted_fields[b'tid'][0].decode() + ver = decrypted_fields[b'ver'][0].decode() + mike = decrypted_fields[b'mike'][0].decode() + + if int(track_id) not in range(616, 1024) or int(mode) not in range(0, 4): + query = select(daily_reward.c.coin).where(daily_reward.c.device_id == device_id) + row = await database.fetch_one(query) + query = select(user.c.coin_mp).where(user.c.device_id == device_id) + coin_mp_row = await database.fetch_one(query) + if coin_mp_row is None: + coin_mp_row = {"coin_mp": 1} + current_coin = row["coin"] if row and row["coin"] else START_COIN + updated_coin = current_coin + COIN_REWARD * coin_mp_row["coin_mp"] + + update_query = ( + update(daily_reward) + .where(daily_reward.c.device_id == device_id) + .values(coin=updated_coin) + ) + await database.execute(update_query) + + query = select(user.c.id).where(user.c.device_id == vid) + user_row = await database.fetch_one(query) + sid = user_row["id"] if user_row else "" + + do_insert = False + do_update_sid = False + do_update_vid = False + last_row_id = 0 + + if sid: + query = select(result.c.rid, result.c.score).where( + (result.c.id == track_id) & + (result.c.mode == mode) & + (result.c.sid == sid) + ).order_by(result.c.score.desc()) + records = await database.fetch_all(query) + if records: + last_row_id = records[0]["rid"] + if score > int(records[0]["score"]): + do_update_sid = True + else: + do_insert = True + else: + query = select(result.c.rid, result.c.score).where( + (result.c.id == track_id) & + (result.c.mode == mode) & + (result.c.sid == "") & + (result.c.vid == vid) + ).order_by(result.c.score.desc()) + records = await database.fetch_all(query) + if records: + last_row_id = records[0]["rid"] + if score > records[0]["score"]: + do_update_vid = True + else: + do_insert = True + + if do_insert: + insert_query = insert(result).values( + vid=vid, stts=stts, id=track_id, mode=mode, avatar=avatar, + score=score, high_score=high_score, play_rslt=play_rslt, item=item, + os=device_os, os_ver=os_ver, tid=tid, sid=sid, ver=ver, mike=mike + ) + result_obj = await database.execute(insert_query) + last_row_id = result_obj + elif do_update_sid: + update_query = ( + update(result) + .where((result.c.sid == sid) & (result.c.id == track_id) & (result.c.mode == mode)) + .values( + stts=stts, avatar=avatar, score=score, high_score=high_score, + play_rslt=play_rslt, item=item, os=device_os, os_ver=os_ver, + tid=tid, ver=ver, mike=mike, vid=vid + ) + ) + await database.execute(update_query) + elif do_update_vid: + update_query = ( + update(result) + .where((result.c.vid == vid) & (result.c.id == track_id) & (result.c.mode == mode)) + .values( + stts=stts, avatar=avatar, score=score, high_score=high_score, + play_rslt=play_rslt, item=item, os=device_os, os_ver=os_ver, + sid=sid, ver=ver, mike=mike + ) + ) + await database.execute(update_query) + + query = select(daily_reward.c.my_stage).where(daily_reward.c.device_id == device_id) + row = await database.fetch_one(query) + my_stage = set(json.loads(row["my_stage"])) if row and row["my_stage"] else set(START_STAGES) + + current_exp = int(stts.split(",")[0]) + for song in EXP_UNLOCKED_SONGS: + if song["lvl"] <= current_exp: + my_stage.add(song["id"]) + + my_stage = sorted(my_stage) + update_query = ( + update(daily_reward) + .where(daily_reward.c.device_id == device_id) + .values(lvl=current_exp, avatar=int(avatar), my_stage=json.dumps(my_stage)) + ) + await database.execute(update_query) + + query = select(result.c.rid, result.c.score).where( + (result.c.id == track_id) & (result.c.mode == mode) + ).order_by(result.c.score.desc()) + records = await database.fetch_all(query) + + rank = None + for idx, record in enumerate(records, start=1): + if record["rid"] == last_row_id: + rank = idx + break + + after_element = root.find('.//after') + after_element.text = str(rank) + xml_response = ET.tostring(tree.getroot(), encoding='unicode') + return Response(xml_response, media_type="application/xml") + +routes = [ + Route('/result.php', result_request, methods=['GET']) +] \ No newline at end of file diff --git a/api/ranking.py b/api/ranking.py index fac94ca..476a833 100644 --- a/api/ranking.py +++ b/api/ranking.py @@ -81,7 +81,6 @@ async def ranking_detail(request: Request): html = f"""
{song_name}
""" button_modes = [1, 2, 3] - print(len(difficulty_levels)) if (len(difficulty_levels) == 6): button_labels.extend(["AC-Easy", "AC-Normal", "AC-Hard"]) @@ -215,10 +214,9 @@ async def ranking_detail(request: Request): """ else: - query = select(result).where( - (result.c.id == song_id) & (result.c.mode == mode) - ).order_by(result.c.score.desc()) + query = select(result).where((result.c.id == song_id) & (result.c.mode == mode)) play_results = await database.fetch_all(query) + play_results = sorted(play_results, key=lambda x: int(x[8]), reverse=True) query = select(user).where(user.c.device_id == device_id) user_result = await database.fetch_one(query) diff --git a/api/user.py b/api/user.py index 6dc6bdc..2a2198c 100644 --- a/api/user.py +++ b/api/user.py @@ -8,12 +8,12 @@ import secrets from sqlalchemy import select, update, insert import xml.etree.ElementTree as ET -from config import ROOT_FOLDER, START_COIN, COIN_REWARD, AUTHORIZATION_NEEDED, HOST, PORT +from config import ROOT_FOLDER, START_COIN, AUTHORIZATION_NEEDED, HOST, PORT from api.misc import is_alphanumeric, inform_page, verify_password, hash_password, crc32_decimal, get_model_pak, get_tune_pak, get_skin_pak, get_m4a_path, get_stage_path, get_stage_zero -from api.database import database, user, daily_reward, result, get_user_data, set_user_data, check_blacklist, check_whitelist +from api.database import database, user, daily_reward, get_user_data, set_user_data, check_blacklist, check_whitelist from api.crypt import decrypt_fields -from api.templates import START_AVATARS, START_STAGES, EXP_UNLOCKED_SONGS +from api.templates import START_AVATARS, START_STAGES async def info(request: Request): file_path = os.path.join(ROOT_FOLDER, "files/history.html") @@ -652,160 +652,6 @@ async def bonus(request: Request): return Response(xml_response, media_type="application/xml") -async def result_request(request: Request): - decrypted_fields, _ = await decrypt_fields(request) - if not decrypted_fields: - return Response("""10Invalid request data.""", media_type="application/xml") - - should_serve = True - if AUTHORIZATION_NEEDED: - should_serve = await check_whitelist(decrypted_fields) and not await check_blacklist(decrypted_fields) - - if not should_serve: - return Response("""403Access denied.""", media_type="application/xml") - - device_id = decrypted_fields[b'vid'][0].decode() - file_path = os.path.join(ROOT_FOLDER, "files/result.xml") - try: - tree = ET.parse(file_path) - root = tree.getroot() - except Exception as e: - return Response(f"""500Error parsing XML: {str(e)}""", media_type="application/xml") - - vid = decrypted_fields[b'vid'][0].decode() - stts = decrypted_fields[b'stts'][0].decode() - track_id = decrypted_fields[b'id'][0].decode() - mode = decrypted_fields[b'mode'][0].decode() - avatar = decrypted_fields[b'avatar'][0].decode() - score = int(decrypted_fields[b'score'][0].decode()) - high_score = decrypted_fields[b'high_score'][0].decode() - play_rslt = decrypted_fields[b'play_rslt'][0].decode() - item = decrypted_fields[b'item'][0].decode() - device_os = decrypted_fields[b'os'][0].decode() - os_ver = decrypted_fields[b'os_ver'][0].decode() - tid = decrypted_fields[b'tid'][0].decode() - ver = decrypted_fields[b'ver'][0].decode() - mike = decrypted_fields[b'mike'][0].decode() - - if int(track_id) not in range(616, 1024) or int(mode) not in range(0, 4): - query = select(daily_reward.c.coin).where(daily_reward.c.device_id == device_id) - row = await database.fetch_one(query) - query = select(user.c.coin_mp).where(user.c.device_id == device_id) - coin_mp_row = await database.fetch_one(query) - if coin_mp_row is None: - coin_mp_row = {"coin_mp": 1} - current_coin = row["coin"] if row and row["coin"] else START_COIN - updated_coin = current_coin + COIN_REWARD * coin_mp_row["coin_mp"] - - update_query = ( - update(daily_reward) - .where(daily_reward.c.device_id == device_id) - .values(coin=updated_coin) - ) - await database.execute(update_query) - - query = select(user.c.id).where(user.c.device_id == vid) - user_row = await database.fetch_one(query) - sid = user_row["id"] if user_row else "" - - do_insert = False - do_update_sid = False - do_update_vid = False - last_row_id = 0 - - if sid: - query = select(result.c.rid, result.c.score).where( - (result.c.id == track_id) & - (result.c.mode == mode) & - (result.c.sid == sid) - ).order_by(result.c.score.desc()) - records = await database.fetch_all(query) - if records: - last_row_id = records[0]["rid"] - if score > int(records[0]["score"]): - do_update_sid = True - else: - do_insert = True - else: - query = select(result.c.rid, result.c.score).where( - (result.c.id == track_id) & - (result.c.mode == mode) & - (result.c.sid == "") & - (result.c.vid == vid) - ).order_by(result.c.score.desc()) - records = await database.fetch_all(query) - if records: - last_row_id = records[0]["rid"] - if score > records[0]["score"]: - do_update_vid = True - else: - do_insert = True - - if do_insert: - insert_query = insert(result).values( - vid=vid, stts=stts, id=track_id, mode=mode, avatar=avatar, - score=score, high_score=high_score, play_rslt=play_rslt, item=item, - os=device_os, os_ver=os_ver, tid=tid, sid=sid, ver=ver, mike=mike - ) - result_obj = await database.execute(insert_query) - last_row_id = result_obj - elif do_update_sid: - update_query = ( - update(result) - .where((result.c.sid == sid) & (result.c.id == track_id) & (result.c.mode == mode)) - .values( - stts=stts, avatar=avatar, score=score, high_score=high_score, - play_rslt=play_rslt, item=item, os=device_os, os_ver=os_ver, - tid=tid, ver=ver, mike=mike, vid=vid - ) - ) - await database.execute(update_query) - elif do_update_vid: - update_query = ( - update(result) - .where((result.c.vid == vid) & (result.c.id == track_id) & (result.c.mode == mode)) - .values( - stts=stts, avatar=avatar, score=score, high_score=high_score, - play_rslt=play_rslt, item=item, os=device_os, os_ver=os_ver, - sid=sid, ver=ver, mike=mike - ) - ) - await database.execute(update_query) - - query = select(daily_reward.c.my_stage).where(daily_reward.c.device_id == device_id) - row = await database.fetch_one(query) - my_stage = set(json.loads(row["my_stage"])) if row and row["my_stage"] else set(START_STAGES) - - current_exp = int(stts.split(",")[0]) - for song in EXP_UNLOCKED_SONGS: - if song["lvl"] <= current_exp: - my_stage.add(song["id"]) - - my_stage = sorted(my_stage) - update_query = ( - update(daily_reward) - .where(daily_reward.c.device_id == device_id) - .values(lvl=current_exp, avatar=int(avatar), my_stage=json.dumps(my_stage)) - ) - await database.execute(update_query) - - query = select(result.c.rid, result.c.score).where( - (result.c.id == track_id) & (result.c.mode == mode) - ).order_by(result.c.score.desc()) - records = await database.fetch_all(query) - - rank = None - for idx, record in enumerate(records, start=1): - if record["rid"] == last_row_id: - rank = idx - break - - after_element = root.find('.//after') - after_element.text = str(rank) - xml_response = ET.tostring(tree.getroot(), encoding='unicode') - return Response(xml_response, media_type="application/xml") - - routes = [ Route('/info.php', info, methods=['GET']), Route('/history.php', history, methods=['GET']), @@ -824,6 +670,5 @@ routes = [ Route('/start.php', start, methods=['GET']), Route('/sync.php', sync, methods=['GET', 'POST']), Route('/ttag.php', ttag, methods=['GET']), - Route('/login_bonus.php', bonus, methods=['GET']), - Route('/result.php', result_request, methods=['GET']), + Route('/login_bonus.php', bonus, methods=['GET']) ] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9fb40bc..97822e2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,8 @@ +cryptography +databases +multipart +requests +sqlalchemy starlette -bcrypt -pycryptodome -requests \ No newline at end of file +urllib +uvicorn \ No newline at end of file