Files
Groove_Coaster_2_Server/new_server_7003/api/database.py
2025-12-03 10:06:08 +08:00

536 lines
19 KiB
Python

import sqlalchemy
from sqlalchemy import Table, Column, Integer, String, DateTime, JSON, ForeignKey, Index
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import select, update
import base64
import aiofiles
import json
from config import START_COIN, SIMULTANEOUS_LOGINS
from api.template import START_AVATARS, START_STAGES
import os
import databases
from datetime import datetime, timedelta
DB_NAME = "player.db"
DB_PATH = os.path.join(os.getcwd(), DB_NAME)
DATABASE_URL = f"sqlite+aiosqlite:///{DB_PATH}"
CACHE_DB_NAME = "cache.db"
CACHE_DB_PATH = os.path.join(os.getcwd(), CACHE_DB_NAME)
CACHE_DATABASE_URL = f"sqlite+aiosqlite:///{CACHE_DB_PATH}"
cache_database = databases.Database(CACHE_DATABASE_URL)
cache_metadata = sqlalchemy.MetaData()
player_database = databases.Database(DATABASE_URL)
player_metadata = sqlalchemy.MetaData()
#----------------------- Table definitions -----------------------#
accounts = Table(
"accounts",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("username", String(20), unique=True, nullable=False),
Column("password_hash", String(255), nullable=False),
Column("save_crc", String(24), nullable=True),
Column("save_timestamp", DateTime, nullable=True),
Column("save_id", String(24), nullable=True),
Column("coin_mp", Integer, default=0),
Column("title", Integer, default=1),
Column("avatar", Integer, default=1),
Column("mobile_delta", Integer, default=0),
Column("arcade_delta", Integer, default=0),
Column("total_delta", Integer, default=0),
Column("created_at", DateTime, default=datetime.utcnow),
Column("updated_at", DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
)
devices = Table(
"devices",
player_metadata,
Column("device_id", String(64), primary_key=True),
Column("user_id", Integer, ForeignKey("accounts.id")),
Column("my_stage", JSON, default=[]),
Column("my_avatar", JSON, default=[]),
Column("item", JSON, default=[]),
Column("daily_day", Integer, default=0),
Column("daily_timestamp", DateTime, default=datetime.min),
Column("coin", Integer, default=START_COIN),
Column("lvl", Integer, default=1),
Column("title", Integer, default=1),
Column("avatar", Integer, default=1),
Column("created_at", DateTime, default=datetime.utcnow),
Column("updated_at", DateTime, default=datetime.utcnow, onupdate=datetime.utcnow),
Column("last_login_at", DateTime, default=None),
Column("bind_token", String(64), unique=True, nullable=True)
)
results = Table(
"results",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("device_id", String(64), ForeignKey("devices.device_id")),
Column("user_id", Integer, ForeignKey("accounts.id")),
Column("stts", JSON, nullable=False),
Column("song_id", Integer, nullable=False),
Column("mode", Integer, nullable=False),
Column("avatar", Integer, nullable=False),
Column("score", Integer, nullable=False),
Column("high_score", JSON, nullable=False),
Column("play_rslt", JSON, nullable=False),
Column("item", Integer, nullable=False),
Column("os", String(8), nullable=False),
Column("os_ver", String(16), nullable=False),
Column("ver", String(8), nullable=False),
Column("created_at", DateTime, default=datetime.utcnow)
)
Index(
"idx_results_song_mode_score",
results.c.song_id,
results.c.mode,
results.c.score.desc(),
)
webs = Table(
"webs",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("user_id", Integer, ForeignKey("accounts.id")),
Column("permission", Integer, default=1),
Column("web_token", String(128), unique=True, nullable=False),
Column("last_save_export", Integer, nullable=True),
Column("created_at", DateTime, default=datetime.utcnow),
Column("updated_at", DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
)
batch_tokens = Table(
"batch_tokens",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("batch_token", String(128), unique=True, nullable=False),
Column("expire_at", DateTime, nullable=False),
Column("uses_left", Integer, default=1),
Column("auth_id", String(64), nullable=False),
Column("created_at", DateTime, default=datetime.utcnow),
Column("updated_at", DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
)
whitelists = Table(
"whitelists",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("device_id", String(64), ForeignKey("devices.device_id")),
)
blacklists = Table(
"blacklists",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("ban_terms", String(128), unique=True, nullable=False),
Column("reason", String(255), nullable=True)
)
binds = Table(
"binds",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("user_id", Integer, ForeignKey("accounts.id")),
Column("bind_account", String(128), unique=True, nullable=False),
Column("bind_code", String(6), nullable=False),
Column("is_verified", Integer, default=0),
Column("bind_date", DateTime, default=datetime.utcnow)
)
logs = Table(
"logs",
player_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("user_id", Integer, ForeignKey("accounts.id")),
Column("filename", String(255), nullable=False),
Column("filesize", Integer, nullable=False),
Column("timestamp", DateTime, default=datetime.utcnow)
)
ranking_cache = Table(
"ranking_cache",
cache_metadata,
Column("id", Integer, primary_key=True, autoincrement=True),
Column("key", String(16), nullable=False),
Column("value", JSON, nullable=False),
Column("expire_at", Integer)
)
#----------------------- End of Table definitions -----------------------#
async def init_db():
if not os.path.exists(DB_PATH):
print("[DB] Creating new database:", DB_PATH)
if not os.path.exists(CACHE_DB_PATH):
print("[DB] Creating new cache database:", CACHE_DB_PATH)
cache_engine = create_async_engine(CACHE_DATABASE_URL, echo=False)
async with cache_engine.begin() as conn:
await conn.run_sync(cache_metadata.create_all)
await cache_engine.dispose()
engine = create_async_engine(DATABASE_URL, echo=False)
async with engine.begin() as conn:
await conn.run_sync(player_metadata.create_all)
await engine.dispose()
print("[DB] Database initialized successfully.")
await ensure_user_columns()
async def ensure_user_columns():
import aiosqlite
async with aiosqlite.connect(DB_PATH) as db:
async with db.execute("PRAGMA table_info(devices);") as cursor:
columns = [row[1] async for row in cursor]
alter_needed = False
if "bind_token" not in columns:
await db.execute("ALTER TABLE devices ADD COLUMN bind_token TEXT;")
alter_needed = True
#if "coin_mp" not in columns:
# await db.execute("ALTER TABLE user ADD COLUMN coin_mp INTEGER DEFAULT 1;")
# alter_needed = True
if alter_needed:
await db.commit()
print("[DB] Added missing columns to user table.")
async def get_bind(user_id):
if not user_id:
return None
query = binds.select().where(binds.c.user_id == user_id)
result = await player_database.fetch_one(query)
return dict(result) if result else None
async def refresh_bind(user_id, device_id):
existing_bind = await get_bind(user_id)
if existing_bind and existing_bind['is_verified'] == 1:
new_auth_token = base64.urlsafe_b64encode(os.urandom(64)).decode("utf-8")
update_query = update(devices).where(devices.c.device_id == device_id).values(
bind_token=new_auth_token
)
await player_database.execute(update_query)
return new_auth_token
return ""
async def log_download(user_id, filename, filesize):
query = logs.insert().values(
user_id=user_id,
filename=filename,
filesize=filesize,
timestamp=datetime.utcnow()
)
await player_database.execute(query)
async def get_downloaded_bytes(user_id, hours):
query = select(sqlalchemy.func.sum(logs.c.filesize)).where(
(logs.c.user_id == user_id) &
(logs.c.timestamp >= datetime.utcnow() - timedelta(hours=hours))
)
result = await player_database.fetch_one(query)
return result[0] if result[0] is not None else 0
async def verify_user_code(code, user_id):
existing_bind = await get_bind(user_id)
if existing_bind and existing_bind['is_verified'] == 1:
return "This account is already bound to an account."
query = binds.select().where(
(binds.c.bind_code == code) &
(binds.c.user_id == user_id) &
(binds.c.is_verified == 0) &
(binds.c.bind_date >= datetime.utcnow() - timedelta(minutes=10))
)
result = await player_database.fetch_one(query)
if not result:
return "Invalid or expired verification code."
update_query = update(binds).where(binds.c.id == result['id']).values(
is_verified=1,
bind_date=datetime.utcnow()
)
await player_database.execute(update_query)
return "Verified and account successfully bound."
async def decrypt_fields_to_user_info(decrypted_fields):
device_id = decrypted_fields[b'vid'][0].decode()
query = devices.select().where(devices.c.device_id == device_id)
device_record = await player_database.fetch_one(query)
if device_record:
device_record = dict(device_record)
user_query = accounts.select().where(accounts.c.id == device_record['user_id'])
user_record = await player_database.fetch_one(user_query)
if user_record:
user_record = dict(user_record)
return user_record, device_record
return None, device_record
return None, None
async def get_device_info(device_id):
query = devices.select().where(devices.c.device_id == device_id)
device_record = await player_database.fetch_one(query)
device_record = dict(device_record) if device_record else None
return device_record
async def user_id_to_user_info(user_id):
user_query = accounts.select().where(accounts.c.id == user_id)
user_record = await player_database.fetch_one(user_query)
user_record = dict(user_record) if user_record else None
if user_record:
user_record = dict(user_record)
device_query = devices.select().where(devices.c.user_id == user_id)
device_record = await player_database.fetch_all(device_query)
device_record = [dict(d) for d in device_record]
return user_record, device_record
return None, None
async def user_id_to_user_info_simple(user_id):
user_query = accounts.select().where(accounts.c.id == user_id)
user_record = await player_database.fetch_one(user_query)
user_record = dict(user_record) if user_record else None
return user_record
async def user_name_to_user_info(username):
user_query = accounts.select().where(accounts.c.username == username)
user_record = await player_database.fetch_one(user_query)
user_record = dict(user_record) if user_record else None
return user_record
async def check_whitelist(decrypted_fields):
device_id = decrypted_fields[b'vid'][0].decode()
user_info, _ = await decrypt_fields_to_user_info(decrypted_fields)
query = select(whitelists.c.device_id).where((whitelists.c.device_id == device_id) | (whitelists.c.device_id == user_info['username']))
result = await player_database.fetch_one(query)
return result is not None
async def check_blacklist(decrypted_fields):
device_id = decrypted_fields[b'vid'][0].decode()
user_info, _ = await decrypt_fields_to_user_info(decrypted_fields)
query = select(blacklists.c.ban_terms).where((blacklists.c.ban_terms == device_id) | (blacklists.c.ban_terms == user_info['username']))
result = await player_database.fetch_one(query)
return result is None
async def get_user_entitlement_from_devices(user_id):
devices_query = select(devices.c.my_stage, devices.c.my_avatar).where(devices.c.user_id == user_id)
devices_list = await player_database.fetch_all(devices_query)
devices_list = [dict(dev) for dev in devices_list] if devices_list else []
stage_set = set()
avatar_set = set()
for dev in devices_list:
my_stages = dev['my_stage'] if dev['my_stage'] else []
my_avatars = dev['my_avatar'] if dev['my_avatar'] else []
stage_set.update(my_stages)
avatar_set.update(my_avatars)
return list(stage_set), list(avatar_set)
async def set_user_data_using_decrypted_fields(decrypted_fields, data_fields):
data_fields['updated_at'] = datetime.utcnow()
device_id = decrypted_fields[b'vid'][0].decode()
device_query = devices.select().where(devices.c.device_id == device_id)
device_result = await player_database.fetch_one(device_query)
if device_result:
user_id = device_result['user_id']
query = (
update(accounts)
.where(accounts.c.id == user_id)
.values(**data_fields)
)
await player_database.execute(query)
async def set_device_data_using_decrypted_fields(decrypted_fields, data_fields):
data_fields['updated_at'] = datetime.utcnow()
device_id = decrypted_fields[b'vid'][0].decode()
query = (
update(devices)
.where(devices.c.device_id == device_id)
.values(**data_fields)
)
await player_database.execute(query)
async def get_user_from_save_id(save_id):
query = accounts.select().where(accounts.c.save_id == save_id)
result = await player_database.fetch_one(query)
result = dict(result) if result else None
return result
async def create_user(username, password_hash, device_id):
insert_query = accounts.insert().values(
username=username,
password_hash=password_hash,
save_crc=None,
save_timestamp=None,
save_id=None,
coin_mp=1,
title=1,
avatar=1,
mobile_delta=0,
arcade_delta=0,
total_delta=0,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
user_id = await player_database.execute(insert_query)
await login_user(user_id, device_id)
async def logout_user(device_id):
query = (
update(devices)
.where(devices.c.device_id == device_id)
.values(user_id=None)
)
await player_database.execute(query)
async def login_user(user_id, device_id):
query = (
update(devices)
.where(devices.c.device_id == device_id)
.values(user_id=user_id, last_login_at=datetime.utcnow())
)
await player_database.execute(query)
_, device_list = await user_id_to_user_info(user_id)
if len(device_list) > SIMULTANEOUS_LOGINS:
sorted_devices = sorted(device_list, key=lambda d: d['last_login_at'] or datetime.min)
devices_to_logout = sorted_devices[:-SIMULTANEOUS_LOGINS]
for device in devices_to_logout:
await logout_user(device['device_id'])
async def create_device(device_id, current_time):
insert_query = devices.insert().values(
device_id=device_id,
user_id=None,
my_avatar=START_AVATARS,
my_stage=START_STAGES,
item=[],
daily_day=1,
daily_timestamp=current_time,
coin=START_COIN,
lvl=1,
title=1,
avatar=1,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
last_login_at=None
)
await player_database.execute(insert_query)
async def is_admin(token):
if not token:
return False
query = webs.select().where(webs.c.web_token == token)
web_data = await player_database.fetch_one(query)
if not web_data:
return False
if web_data['permission'] != 2:
return False
return web_data['user_id']
async def results_query(query_params):
query = select(results.c.id, results.c.user_id, results.c.score, results.c.avatar)
for key, value in query_params.items():
query = query.where(getattr(results.c, key) == value)
query = query.order_by(results.c.score.desc())
records = await player_database.fetch_all(query)
return [dict(record) for record in records]
async def clear_rank_cache(key):
delete_query = ranking_cache.delete().where(ranking_cache.c.key == key)
await cache_database.execute(delete_query)
async def write_rank_cache(key, value, expire_seconds=None):
if expire_seconds:
expire_time = datetime.utcnow() + timedelta(seconds=expire_seconds)
else:
expire_time = None
insert_query = ranking_cache.insert().values(
key=key,
value=value,
expire_at=expire_time
)
await cache_database.execute(insert_query)
async def get_rank_cache(key):
query = ranking_cache.select().where(ranking_cache.c.key == key)
result = await cache_database.fetch_one(query)
if result:
expire_at = datetime.fromisoformat(result['expire_at']) if result['expire_at'] else None
if expire_at and expire_at < datetime.utcnow():
await clear_rank_cache(key)
return None
return dict(result)['value']
return None
async def get_user_export_data(user_id):
user_data = {}
user_info, device_list = await user_id_to_user_info(user_id)
user_save = await read_user_save_file(user_id)
user_info['save_data'] = user_save
user_data['account'] = [user_info]
user_data['devices'] = device_list
all_results_query = results.select().where(results.c.user_id == user_id)
all_results = await player_database.fetch_all(all_results_query)
user_data['results'] = [dict(result) for result in all_results]
user_binds_query = binds.select().where(binds.c.user_id == user_id)
user_binds = await player_database.fetch_all(user_binds_query)
user_data['binds'] = [dict(bind) for bind in user_binds]
# Convert JSON fields to strings
for key, value in user_data.items():
if isinstance(value, list):
for item in value:
if isinstance(item, dict):
for field, field_value in item.items():
if isinstance(field_value, (dict, list)):
item[field] = json.dumps(field_value)
return user_data
async def read_user_save_file(user_id):
if user_id is None:
return ""
elif type(user_id) != int:
return ""
else:
try:
async with aiofiles.open(f"./save/{user_id}.dat", "rb") as file:
result = await file.read()
result = result.decode("utf-8")
return result
except FileNotFoundError:
return ""
async def write_user_save_file(user_id, data):
if user_id is None:
return
elif type(user_id) != int:
return
else:
try:
async with aiofiles.open(f"./save/{user_id}.dat", "wb") as file:
await file.write(data.encode("utf-8"))
except Exception as e:
print(f"An error occurred while writing the file: {e}")