Files
Groove_Coaster_2_Server/new_server_7003/api/file.py
UnitedAirforce c4ceefb282 fixes
2025-12-13 18:00:16 +08:00

117 lines
4.2 KiB
Python

from starlette.responses import Response, FileResponse
from starlette.requests import Request
from starlette.routing import Route
from sqlalchemy import select
import openpyxl
from io import BytesIO
import os
from api.database import player_database, devices, binds, batch_tokens, log_download, get_downloaded_bytes
from config import AUTHORIZATION_MODE, DAILY_DOWNLOAD_LIMIT
async def serve_file(request: Request):
auth_token = request.path_params['auth_token']
folder = request.path_params['folder']
filename = request.path_params['filename']
if folder not in ["audio", "stage", "pak"]:
return Response("Unauthorized", status_code=403)
if not filename.endswith(".zip") and not filename.endswith(".pak"):
return Response("Unauthorized", status_code=403)
existing_batch_token = select(batch_tokens).where((batch_tokens.c.batch_token == auth_token) & (batch_tokens.c.uses_left > -1))
batch_result = await player_database.fetch_one(existing_batch_token)
if batch_result:
pass
elif AUTHORIZATION_MODE == 0:
existing_device = select(devices).where(devices.c.device_id == auth_token)
result = await player_database.fetch_one(existing_device)
if not result:
return Response("Unauthorized", status_code=403)
else:
pass
else:
existing_device_query = select(devices).where((devices.c.bind_token == auth_token))
existing_device = await player_database.fetch_one(existing_device_query)
if not existing_device:
return Response("Unauthorized - device not found", status_code=403)
existing_bind_query = select(binds).where((binds.c.user_id == existing_device['user_id']) & (binds.c.is_verified == 1))
bind_result = await player_database.fetch_one(existing_bind_query)
if not bind_result:
return Response("Unauthorized - bind not verified", status_code=403)
daily_bytes = await get_downloaded_bytes(bind_result['user_id'], 24)
if daily_bytes >= DAILY_DOWNLOAD_LIMIT:
return Response("Daily download limit exceeded", status_code=403)
safe_filename = os.path.realpath(os.path.join(os.getcwd(), "files", "gc2", folder, filename))
base_directory = os.path.realpath(os.path.join(os.getcwd(), "files", "gc2", folder))
if not safe_filename.startswith(base_directory):
return Response("Unauthorized", status_code=403)
file_path = safe_filename
if os.path.isfile(file_path):
# get size of file
if AUTHORIZATION_MODE != 0 and not batch_result:
file_size = os.path.getsize(file_path)
await log_download(bind_result['user_id'], filename, file_size)
return FileResponse(file_path)
else:
return Response("File not found", status_code=404)
async def serve_public_file(request: Request):
path = request.path_params['path']
safe_filename = os.path.realpath(os.path.join(os.getcwd(), "files", path))
base_directory = os.path.realpath(os.path.join(os.getcwd(), "files"))
if not safe_filename.startswith(base_directory):
return Response("Unauthorized", status_code=403)
if os.path.isfile(safe_filename):
return FileResponse(safe_filename)
else:
return Response("File not found", status_code=404)
async def convert_user_export_data(data):
wb = openpyxl.Workbook()
# Remove default sheet
default_sheet = wb.active
wb.remove(default_sheet)
# Create sheet for each top-level key
for sheet_name, rows in data.items():
ws = wb.create_sheet(title=sheet_name)
# rows expected to be list[dict]
if isinstance(rows, list):
await write_dict_list_to_sheet(ws, rows)
stream = BytesIO()
wb.save(stream)
stream.seek(0)
return stream
async def write_dict_list_to_sheet(ws, rows):
if not rows:
return
# headers
headers = list(rows[0].keys())
ws.append(headers)
# rows
for row in rows:
ws.append([row.get(h, "") for h in headers])
routes = [
Route("/files/gc2/{auth_token}/{folder}/{filename}", serve_file, methods=["GET"]),
Route("/files/{path:path}", serve_public_file, methods=["GET"]),
]