340 lines
11 KiB
Python
340 lines
11 KiB
Python
import asyncio
import io
import logging
import os
import urllib.parse
import uuid

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile, status
from fastapi.responses import StreamingResponse
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload, MediaIoBaseUpload
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import config
from app.core.security import get_current_user
from app.db import models
|
||
|
||
# Router for every media endpoint: mounted under /media and tagged "media".
mediaRouter = APIRouter(prefix='/media', tags=['media'])
|
||
|
||
# ---------------------------------------------------------------------------
# Google Drive client initialisation
# ---------------------------------------------------------------------------
def _get_drive_service():
    """Build an authenticated Google Drive v3 client.

    Credentials are assembled from ``config.GOOGLE_REFRESH_TOKEN``,
    ``config.GOOGLE_CLIENT_ID`` and ``config.GOOGLE_CLIENT_SECRET`` with the
    full ``drive`` scope. No initial access token is supplied (``token=None``);
    google-auth obtains one through the refresh token when needed.

    Returns:
        The ``googleapiclient`` resource for Drive v3.

    Raises:
        HTTPException: 500 if the credentials or client cannot be constructed.
    """
    try:
        credentials = Credentials(
            token=None,
            refresh_token=config.GOOGLE_REFRESH_TOKEN,
            token_uri='https://oauth2.googleapis.com/token',
            client_id=config.GOOGLE_CLIENT_ID,
            client_secret=config.GOOGLE_CLIENT_SECRET,
            scopes=['https://www.googleapis.com/auth/drive'],
        )
        return build('drive', 'v3', credentials=credentials)
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to initialize Google Drive service: {str(e)}"
        ) from e
|
||
|
||
# ---------------------------------------------------------------------------
# Per-user quota enforcement (evicts the oldest files when over quota)
# ---------------------------------------------------------------------------
async def _cleanup_google_drive_quota(db: AsyncSession, owner_id: int | None, new_file_size: int):
    """Evict the owner's oldest files until ``new_file_size`` more bytes fit.

    Usage is the sum of ``MediaItem.size_bytes`` over the owner's items,
    excluding the user's active avatar (``User.avatar_file_id``). When usage
    plus ``new_file_size`` exceeds the quota (``config.HOME_USER_QUOTA_BYTES``,
    default 10 GiB), the owner's items are removed oldest-first (by
    ``created_at``) from Drive and from the DB until the new file fits or no
    evictable item is left. The active avatar is never evicted.

    Drive deletion is best effort: a failure is logged and the DB row is still
    removed, which can leave an orphaned object on Drive.

    NOTE(review): if ``new_file_size`` alone exceeds the quota, every evictable
    item is removed and the caller still proceeds — confirm this is intended.

    Args:
        db: Session used for all queries; committed once eviction has run.
        owner_id: Owner whose quota is enforced; ``None`` skips the check.
        new_file_size: Size in bytes of the file about to be stored.

    Raises:
        HTTPException: 500 if the Drive client cannot be built.
    """
    if owner_id is None:
        return

    user_quota = getattr(config, "HOME_USER_QUOTA_BYTES", 10737418240)  # 10 GiB default

    user_res = await db.execute(select(models.User).where(models.User.id == owner_id))
    user = user_res.scalars().first()
    active_avatar_id = user.avatar_file_id if user else None

    # Items that both count towards the quota and may be evicted. The active
    # avatar is excluded from the sum AND from eviction, so it is never deleted
    # (and its size is never subtracted from a total it was not part of).
    criteria = [models.MediaItem.owner_id == owner_id]
    if active_avatar_id:
        criteria.append(models.MediaItem.file_id != active_avatar_id)

    sum_res = await db.execute(select(func.sum(models.MediaItem.size_bytes)).where(*criteria))
    total_used = int(sum_res.scalar() or 0)

    if total_used + new_file_size <= user_quota:
        return

    files_res = await db.execute(
        select(models.MediaItem)
        .where(*criteria)
        .order_by(models.MediaItem.created_at.asc())
    )
    files = files_res.scalars().all()
    service = _get_drive_service()
    logger = logging.getLogger(__name__)

    for file_record in files:
        if total_used + new_file_size <= user_quota:
            break

        drive_id = file_record.storage_file_id
        try:
            # The Google client is synchronous; run the network call off the event loop.
            await asyncio.to_thread(
                service.files().delete(fileId=drive_id, supportsAllDrives=True).execute
            )
        except Exception:
            # Best effort: still drop the DB row so quota is freed, but leave a
            # trace of the possibly orphaned Drive object.
            logger.warning(
                "Failed to delete Drive file %s during quota cleanup", drive_id, exc_info=True
            )

        total_used -= file_record.size_bytes or 0
        await db.delete(file_record)

    await db.commit()
|
||
|
||
# ---------------------------------------------------------------------------
# Upload support
# ---------------------------------------------------------------------------
class SeekableFastAPIStream(io.RawIOBase):
    """Raw binary stream adapter around an uploaded file object.

    ``read``/``seek``/``tell`` are forwarded to ``raw_file``. A shadow
    ``_position`` is maintained so that ``tell`` and ``seek`` still return a
    value when the wrapped object's own ``tell``/``seek`` raise. In that
    fallback, ``SEEK_END`` cannot be emulated and leaves the position as is.
    The stream always reports itself as readable and seekable.
    """

    def __init__(self, raw_file):
        self.raw_file = raw_file
        self._position = self._query_raw_position(0)

    def _query_raw_position(self, fallback):
        """Return ``raw_file.tell()``, or *fallback* if that call raises."""
        try:
            return self.raw_file.tell()
        except Exception:
            return fallback

    def readinto(self, b):
        data = self.raw_file.read(len(b))
        if not data:
            return 0
        count = len(data)
        b[:count] = data
        # Prefer the wrapped file's notion of position; otherwise advance locally.
        self._position = self._query_raw_position(self._position + count)
        return count

    def seek(self, offset, whence=io.SEEK_SET):
        try:
            self.raw_file.seek(offset, whence)
            self._position = self.raw_file.tell()
        except Exception:
            # Emulate the move on the shadow position (SEEK_END is not emulated).
            if whence == io.SEEK_CUR:
                self._position += offset
            elif whence == io.SEEK_SET:
                self._position = offset
        return self._position

    def tell(self):
        return self._query_raw_position(self._position)

    def seekable(self):
        return True

    def readable(self):
        return True
||
|
||
async def get_db():
    """FastAPI dependency that yields one ``AsyncSession`` per request.

    The session is created from ``models.AsyncSessionLocal``. Leaving the
    ``async with`` block closes it, whether the request succeeded or raised,
    so no separate ``close()`` call is needed.
    """
    async with models.AsyncSessionLocal() as db:
        yield db
|
||
|
||
def _measure_upload_size(upload: UploadFile) -> int:
    """Return the uploaded file's size in bytes.

    Uses ``upload.size`` when it is set. Otherwise it measures the underlying
    file object by seeking to its end, then restores the previous position.
    """
    if upload.size is not None:
        return upload.size
    raw = upload.file
    start = raw.tell()
    raw.seek(0, os.SEEK_END)
    size = raw.tell()
    raw.seek(start)
    return size


def _upload_to_drive(service, stream, file_metadata: dict, mimetype: str) -> dict:
    """Upload *stream* to Drive with a resumable request (blocking call).

    Returns the created file resource, restricted to the ``id`` and ``size``
    fields. Intended to run in a worker thread because the Google client is
    synchronous.
    """
    media = MediaIoBaseUpload(
        stream,
        mimetype=mimetype,
        chunksize=5 * 1024 * 1024,
        resumable=True,
    )
    request = service.files().create(
        body=file_metadata,
        media_body=media,
        fields='id,size',
        supportsAllDrives=True,
    )
    response = None
    while response is None:
        # num_retries lets the client retry transient failures of each chunk
        # with backoff, instead of blindly retrying every error (including 4xx).
        _progress, response = request.next_chunk(num_retries=3)
    return response


async def _discard_drive_file(service, drive_id) -> None:
    """Best-effort deletion of a Drive file whose DB record could not be saved."""
    if not drive_id:
        return
    try:
        await asyncio.to_thread(
            service.files().delete(fileId=drive_id, supportsAllDrives=True).execute
        )
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not delete orphaned Drive file %s", drive_id, exc_info=True
        )


@mediaRouter.post('/upload')
@mediaRouter.post('/v2/upload')
async def upload_file(
    file: UploadFile = File(...),
    current_user: models.User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Store an uploaded file on Google Drive and record it for the current user.

    The owner's quota is enforced first. The file is then uploaded as
    ``<file_id>.enc`` into ``config.GOOGLE_DRIVE_FOLDER_ID``, and a
    ``MediaItem`` row is committed. If the commit fails, the freshly uploaded
    Drive file is deleted (best effort) so it does not linger outside quota
    accounting.

    Returns:
        ``{'status': 'ok', 'file_id': <hex id>}``.

    Raises:
        HTTPException: 400 for a missing filename or a file larger than
            ``config.MEDIA_UPLOAD_MAX_BYTES`` (default 50 MiB). Other
            ``HTTPException``s propagate unchanged; any other failure maps to 500.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail='No selected file')

    max_upload_size = getattr(config, "MEDIA_UPLOAD_MAX_BYTES", 52428800)

    try:
        # Always know the size, so the limit and the quota are enforced even
        # when UploadFile.size is not populated.
        file_size = _measure_upload_size(file)
        if file_size > max_upload_size:
            raise HTTPException(
                status_code=400,
                detail=f'File too large (max {max_upload_size} bytes)'
            )

        if file_size:
            await _cleanup_google_drive_quota(db, current_user.id, file_size)

        service = _get_drive_service()
        file_id = uuid.uuid4().hex
        content_type = file.content_type or 'application/octet-stream'
        file_metadata = {
            'name': f"{file_id}.enc",
            'parents': [config.GOOGLE_DRIVE_FOLDER_ID],
        }

        drive_file = await asyncio.to_thread(
            _upload_to_drive,
            service,
            SeekableFastAPIStream(file.file),
            file_metadata,
            content_type,
        )
        drive_id = drive_file.get('id')

        try:
            db.add(models.MediaItem(
                file_id=file_id,
                owner_id=current_user.id,
                original_filename=file.filename,
                content_type=content_type,
                storage_file_id=drive_id,
                size_bytes=file_size,
            ))
            await db.commit()
        except Exception:
            # Without a DB row the Drive object would be invisible to quota cleanup.
            await _discard_drive_file(service, drive_id)
            raise

    except HTTPException:
        # Keep intended status codes (400 size limit, 500 from client setup) intact.
        raise
    except Exception as e:
        logging.getLogger(__name__).exception("Upload operation failed")
        raise HTTPException(
            status_code=500, detail=f"Upload operation failed: {str(e)}"
        ) from e
    finally:
        await file.close()

    return {'status': 'ok', 'file_id': file_id}
|
||
|
||
@mediaRouter.get('/size/{file_id}')
async def get_file_size(file_id: str, db: AsyncSession = Depends(get_db)):
    """Describe a stored media item without downloading its content.

    Returns a dict with ``file_id``, ``size`` (the stored ``size_bytes``),
    ``file_name`` (percent-encoded with ``urllib.parse.quote``) and
    ``content_type``.

    Raises:
        HTTPException: 404 when no ``MediaItem`` has the given ``file_id``.
    """
    # NOTE(review): no authentication dependency — presumably the random
    # file_id acts as an access capability; confirm this is intended.
    lookup = select(models.MediaItem).where(models.MediaItem.file_id == file_id)
    record = (await db.execute(lookup)).scalars().first()

    if not record:
        raise HTTPException(status_code=404, detail='File not found')

    return {
        "file_id": file_id,
        "size": record.size_bytes,
        "file_name": urllib.parse.quote(record.original_filename),
        "content_type": record.content_type,
    }
|
||
|
||
@mediaRouter.get('/{file_id}')
async def get_file(file_id: str, db: AsyncSession = Depends(get_db)):
    """Stream a stored media item from Google Drive as a download.

    The first chunk is fetched before the response is returned, so Drive
    errors (missing object, auth failure) surface as HTTP 502 instead of a
    truncated 200 response. Each downloaded chunk is handed to the client and
    then dropped, keeping memory bounded by the 1 MiB chunk size instead of
    the file size.

    Raises:
        HTTPException: 404 if ``file_id`` is unknown; 500 if the Drive client
            cannot be built; 502 if the first chunk cannot be fetched.
    """
    # NOTE(review): no authentication dependency — presumably the random
    # file_id acts as an access capability; confirm this is intended.
    res = await db.execute(select(models.MediaItem).where(models.MediaItem.file_id == file_id))
    db_file = res.scalars().first()

    if not db_file:
        raise HTTPException(status_code=404, detail='File not found')

    # Read everything needed from the ORM row now; the body is streamed later.
    drive_id = db_file.storage_file_id
    media_type = db_file.content_type
    encoded_filename = urllib.parse.quote(db_file.original_filename)

    buffer = io.BytesIO()

    def _pull_chunk():
        """Download the next chunk (blocking); return ``(data, done)``.

        The buffer is emptied after each read so it never holds more than one
        chunk; the downloader tracks its own byte offset for the next request.
        """
        _progress, finished = downloader.next_chunk()
        data = buffer.getvalue()
        buffer.seek(0)
        buffer.truncate()
        return data, finished

    try:
        service = _get_drive_service()
        request = service.files().get_media(fileId=drive_id, supportsAllDrives=True)
        downloader = MediaIoBaseDownload(buffer, request, chunksize=1024 * 1024)
        # Prefetch so that Drive failures are reported before headers are sent.
        first_chunk, first_done = await asyncio.to_thread(_pull_chunk)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=502,
            detail=f"Error fetching file from Google Drive: {str(e)}"
        ) from e

    async def _async_stream_drive_file():
        chunk, done = first_chunk, first_done
        while True:
            if chunk:
                yield chunk
            if done:
                return
            chunk, done = await asyncio.to_thread(_pull_chunk)

    headers = {
        "Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"
    }
    return StreamingResponse(
        _async_stream_drive_file(),
        media_type=media_type,
        headers=headers
    )
|
||
|
||
@mediaRouter.post('/copy')
async def copy(
    file_id: str = Form(...),
    current_user: models.User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Duplicate an existing media item into the current user's library.

    The Drive object is copied under a new ``<new_file_id>.enc`` name, the
    current user's quota is enforced for the copied size, and a new
    ``MediaItem`` owned by the current user is committed. If quota enforcement
    or the commit fails, the Drive copy is deleted again (best effort).

    Returns:
        ``{"status": "ok", "new_file_id": <hex id>}``.

    Raises:
        HTTPException: 404 if the source item does not exist. Other
            ``HTTPException``s propagate unchanged; any other failure maps to 500.
    """
    # NOTE(review): any authenticated user can copy any file_id; confirm that
    # no ownership/visibility check is required here.
    res = await db.execute(select(models.MediaItem).where(models.MediaItem.file_id == file_id))
    old_record = res.scalars().first()

    if not old_record:
        raise HTTPException(status_code=404, detail='Source file not found')

    # Snapshot the source row now: quota cleanup commits the session (which may
    # expire loaded attributes) and can even evict this very record.
    source_drive_id = old_record.storage_file_id
    original_filename = old_record.original_filename
    content_type = old_record.content_type
    size_bytes = old_record.size_bytes

    new_file_id = uuid.uuid4().hex

    try:
        service = _get_drive_service()
        # Copy BEFORE quota cleanup, so eviction cannot delete the source first.
        drive_file = await asyncio.to_thread(
            service.files().copy(
                fileId=source_drive_id,
                body={'name': f"{new_file_id}.enc"},
                fields='id',
                supportsAllDrives=True,
            ).execute
        )
        new_drive_id = drive_file.get('id')

        try:
            await _cleanup_google_drive_quota(db, current_user.id, size_bytes or 0)
            db.add(models.MediaItem(
                file_id=new_file_id,
                owner_id=current_user.id,
                original_filename=original_filename,
                content_type=content_type,
                storage_file_id=new_drive_id,
                size_bytes=size_bytes,
            ))
            await db.commit()
        except Exception:
            # Do not leave an untracked copy on Drive.
            try:
                await asyncio.to_thread(
                    service.files().delete(fileId=new_drive_id, supportsAllDrives=True).execute
                )
            except Exception:
                logging.getLogger(__name__).warning(
                    "Could not delete orphaned Drive copy %s", new_drive_id, exc_info=True
                )
            raise
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Copy operation failed: {str(e)}") from e

    return {"status": "ok", "new_file_id": new_file_id}