diff --git a/sqladmin/application.py b/sqladmin/application.py index 7f582690..6215b088 100644 --- a/sqladmin/application.py +++ b/sqladmin/application.py @@ -1,6 +1,7 @@ import inspect import io import logging +from pathlib import Path from types import MethodType from typing import ( TYPE_CHECKING, @@ -27,7 +28,12 @@ from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.responses import ( + FileResponse, + JSONResponse, + RedirectResponse, + Response, +) from starlette.routing import Mount, Route from starlette.staticfiles import StaticFiles @@ -37,9 +43,11 @@ from sqladmin.authentication import AuthenticationBackend, login_required from sqladmin.forms import WTFORMS_ATTRS, WTFORMS_ATTRS_REVERSED from sqladmin.helpers import ( + get_filename_from_path, get_object_identifier, is_async_session_maker, slugify_action_name, + value_is_filepath, ) from sqladmin.models import BaseView, ModelView from sqladmin.templating import Jinja2Templates @@ -116,6 +124,8 @@ def init_templating_engine(self) -> Jinja2Templates: templates.env.globals["admin"] = self templates.env.globals["is_list"] = lambda x: isinstance(x, list) templates.env.globals["get_object_identifier"] = get_object_identifier + templates.env.globals["value_is_filepath"] = value_is_filepath + templates.env.globals["get_filename_from_path"] = get_filename_from_path return templates @@ -311,6 +321,21 @@ async def _export(self, request: Request) -> None: if request.path_params["export_type"] not in model_view.export_types: raise HTTPException(status_code=404) + async def _get_file(self, request: Request) -> Path: + """Get file path""" + + identity = request.path_params["identity"] + identifier = request.path_params["pk"] + column_name = request.path_params["column_name"] + + model_view = self._find_model_view(identity) + file_path = await model_view.get_object_filepath(identifier, column_name) + + request_path = Path(file_path) + if not request_path.is_file(): + raise HTTPException(status_code=404) + return request_path + class Admin(BaseAdminView): """Main entrypoint to admin interface. @@ -417,6 +442,18 @@ async def http_exception( ), Route("/login", endpoint=self.login, name="login", methods=["GET", "POST"]), Route("/logout", endpoint=self.logout, name="logout", methods=["GET"]), + Route( + "/{identity}/{pk:path}/{column_name}/download/", + endpoint=self.download_file, + name="file_download", + methods=["GET"], + ), + Route( + "/{identity}/{pk:path}/{column_name}/read/", + endpoint=self.reed_file, + name="file_read", + methods=["GET"], + ), ] self.admin.router.routes = routes @@ -490,7 +527,6 @@ async def delete(self, request: Request) -> Response: @login_required async def create(self, request: Request) -> Response: """Create model endpoint.""" - await self._create(request) identity = request.path_params["identity"] @@ -626,6 +662,18 @@ async def logout(self, request: Request) -> Response: await self.authentication_backend.logout(request) return RedirectResponse(request.url_for("admin:index"), status_code=302) + async def download_file(self, request: Request) -> Response: + """Download file endpoint.""" + request_path = await self._get_file(request) + return FileResponse(request_path, filename=request_path.name) + + async def reed_file(self, request: Request) -> Response: + """Read file endpoint.""" + request_path = await self._get_file(request) + return FileResponse( + request_path, filename=request_path.name, content_disposition_type="inline" + ) + async def ajax_lookup(self, request: Request) -> Response: """Ajax lookup route.""" diff --git a/sqladmin/helpers.py b/sqladmin/helpers.py index 0d8cc518..26731b37 100644 --- a/sqladmin/helpers.py +++ b/sqladmin/helpers.py @@ -132,6 +132,16 @@ def secure_filename(filename: str) -> str: return filename +def value_is_filepath(value: Any) -> bool: + """Check if a value is a filepath.""" + return isinstance(value, str) and os.path.isfile(value) + + +def get_filename_from_path(path: str) -> str: + """Get filename from path.""" + return os.path.basename(path) + + class Writer(ABC): """https://docs.python.org/3/library/csv.html#writer-objects""" diff --git a/sqladmin/models.py b/sqladmin/models.py index 896010f5..4927fece 100644 --- a/sqladmin/models.py +++ b/sqladmin/models.py @@ -818,6 +818,32 @@ async def get_object_for_delete(self, value: Any) -> Any: stmt = self._stmt_by_identifier(value) return await self._get_object_by_pk(stmt) + async def get_object_filepath(self, identifier: str, column_name: str) -> Any: + stmt = self._stmt_by_identifier(identifier) + obj = await self._get_object_by_pk(stmt) + column_value = getattr(obj, column_name) + return column_value + + async def _get_file_download_link( + self, request: Request, obj: Any, pk: int, column_name: str + ) -> Any: + return request.url_for( + "admin:file_download", + identity=slugify_class_name(obj.__class__.__name__), + pk=pk, + column_name=column_name, + ) + + async def _get_file_link( + self, request: Request, obj: Any, pk: int, column_name: str + ) -> Any: + return request.url_for( + "admin:file_read", + identity=slugify_class_name(obj.__class__.__name__), + pk=pk, + column_name=column_name, + ) + def _stmt_by_identifier(self, identifier: str) -> Select: stmt = select(self.model) pks = get_primary_keys(self.model) diff --git a/sqladmin/templates/details.html b/sqladmin/templates/details.html index 34ae2c36..a9acdf64 100644 --- a/sqladmin/templates/details.html +++ b/sqladmin/templates/details.html @@ -36,8 +36,19 @@

{% endif %} {% else %} + {% if value_is_filepath(value) %} + + + {{ get_filename_from_path(value) }} + + + + + + {% else %} {{ formatted_value }} {% endif %} + {% endif %} {% endfor %} diff --git a/sqladmin/templates/list.html b/sqladmin/templates/list.html index e574026b..e094f487 100644 --- a/sqladmin/templates/list.html +++ b/sqladmin/templates/list.html @@ -151,8 +151,19 @@

{{ model_view.name_plural }}

{{ formatted_value }} {% endif %} {% else %} + {% if value_is_filepath(value) %} + + + {{ get_filename_from_path(value) }} + + + + + + {% else %} {{ formatted_value }} {% endif %} + {% endif %} {% endfor %} {% endfor %} diff --git a/tests/test_views/test_file_view.py b/tests/test_views/test_file_view.py new file mode 100644 index 00000000..d2bd7544 --- /dev/null +++ b/tests/test_views/test_file_view.py @@ -0,0 +1,151 @@ +import io +import re +from typing import Any, AsyncGenerator + +import pytest +from fastapi_storages import FileSystemStorage, StorageFile +from fastapi_storages.integrations.sqlalchemy import FileType +from httpx import AsyncClient +from sqlalchemy import Column, Integer, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import declarative_base, sessionmaker +from starlette.applications import Starlette +from starlette.datastructures import UploadFile + +from sqladmin import Admin, ModelView +from tests.common import async_engine as engine + +pytestmark = pytest.mark.anyio + +Base = declarative_base() # type: Any +session_maker = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False) + +app = Starlette() +admin = Admin(app=app, engine=engine) + +storage = FileSystemStorage(path=".uploads") + + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True) + file = Column(FileType(FileSystemStorage(".uploads"))) + + +@pytest.fixture +async def prepare_database() -> AsyncGenerator[None, None]: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + await engine.dispose() + + +@pytest.fixture +async def client(prepare_database: Any) -> AsyncGenerator[AsyncClient, None]: + async with AsyncClient(app=app, base_url="http://testserver") as c: + yield c + + +class UserAdmin(ModelView, model=User): + column_list = [User.id, User.file] + + +admin.add_view(UserAdmin) + + +async def _query_user() -> Any: + stmt = select(User).limit(1) + async with session_maker() as s: + result = await s.execute(stmt) + return result.scalar_one() + + +async def test_detail_view(client: AsyncClient) -> None: + async with session_maker() as session: + user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc"))) + session.add(user) + await session.commit() + + response = await client.get("/admin/user/details/1") + + user = await _query_user() + + assert response.status_code == 200 + assert isinstance(user.file, StorageFile) is True + assert user.file.name == "upload.txt" + assert user.file.path == ".uploads/upload.txt" + assert user.file.open().read() == b"abc" + + assert ( + '' + in response.text + ) + assert '' in response.text + assert '' in response.text + + +async def test_list_view(client: AsyncClient) -> None: + async with session_maker() as session: + for i in range(10): + user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc"))) + session.add(user) + await session.commit() + + response = await client.get("/admin/user/list") + + user = await _query_user() + + assert response.status_code == 200 + assert isinstance(user.file, StorageFile) is True + assert user.file.name == "upload.txt" + assert user.file.path == ".uploads/upload.txt" + assert user.file.open().read() == b"abc" + + pattern_span = re.compile( + r'' + ) + pattern_a_read = re.compile( + r'' + ) + pattern_a_download = re.compile( + r'' + ) + + count_span = len(pattern_span.findall(response.text)) + count_a_read = len(pattern_a_read.findall(response.text)) + count_a_download = len(pattern_a_download.findall(response.text)) + + assert count_span == count_a_read == count_a_download == 10 + + +async def test_file_download(client: AsyncClient) -> None: + async with session_maker() as session: + for i in range(10): + user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc"))) + session.add(user) + await session.commit() + + response = await client.get("/admin/user/1/file/download/") + + assert response.status_code == 200 + + with open(".uploads/download.txt", "wb") as local_file: + local_file.write(response.content) + + assert open(".uploads/download.txt", "rb").read() == b"abc" + + +async def test_file_read(client: AsyncClient) -> None: + async with session_maker() as session: + for i in range(10): + user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc"))) + session.add(user) + await session.commit() + + response = await client.get("/admin/user/1/file/read/") + assert response.status_code == 200 + assert response.text == "abc"