diff --git a/sqladmin/application.py b/sqladmin/application.py index 7f582690..6215b088 100644 --- a/sqladmin/application.py +++ b/sqladmin/application.py @@ -1,6 +1,7 @@ import inspect import io import logging +from pathlib import Path from types import MethodType from typing import ( TYPE_CHECKING, @@ -27,7 +28,12 @@ from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.responses import ( + FileResponse, + JSONResponse, + RedirectResponse, + Response, +) from starlette.routing import Mount, Route from starlette.staticfiles import StaticFiles @@ -37,9 +43,11 @@ from sqladmin.authentication import AuthenticationBackend, login_required from sqladmin.forms import WTFORMS_ATTRS, WTFORMS_ATTRS_REVERSED from sqladmin.helpers import ( + get_filename_from_path, get_object_identifier, is_async_session_maker, slugify_action_name, + value_is_filepath, ) from sqladmin.models import BaseView, ModelView from sqladmin.templating import Jinja2Templates @@ -116,6 +124,8 @@ def init_templating_engine(self) -> Jinja2Templates: templates.env.globals["admin"] = self templates.env.globals["is_list"] = lambda x: isinstance(x, list) templates.env.globals["get_object_identifier"] = get_object_identifier + templates.env.globals["value_is_filepath"] = value_is_filepath + templates.env.globals["get_filename_from_path"] = get_filename_from_path return templates @@ -311,6 +321,21 @@ async def _export(self, request: Request) -> None: if request.path_params["export_type"] not in model_view.export_types: raise HTTPException(status_code=404) + async def _get_file(self, request: Request) -> Path: + """Get file path""" + + identity = request.path_params["identity"] + identifier = request.path_params["pk"] + column_name = request.path_params["column_name"] + + model_view = self._find_model_view(identity) + file_path = await model_view.get_object_filepath(identifier, column_name) + + request_path = Path(file_path) + if not request_path.is_file(): + raise HTTPException(status_code=404) + return request_path + class Admin(BaseAdminView): """Main entrypoint to admin interface. @@ -417,6 +442,18 @@ async def http_exception( ), Route("/login", endpoint=self.login, name="login", methods=["GET", "POST"]), Route("/logout", endpoint=self.logout, name="logout", methods=["GET"]), + Route( + "/{identity}/{pk:path}/{column_name}/download/", + endpoint=self.download_file, + name="file_download", + methods=["GET"], + ), + Route( + "/{identity}/{pk:path}/{column_name}/read/", + endpoint=self.reed_file, + name="file_read", + methods=["GET"], + ), ] self.admin.router.routes = routes @@ -490,7 +527,6 @@ async def delete(self, request: Request) -> Response: @login_required async def create(self, request: Request) -> Response: """Create model endpoint.""" - await self._create(request) identity = request.path_params["identity"] @@ -626,6 +662,18 @@ async def logout(self, request: Request) -> Response: await self.authentication_backend.logout(request) return RedirectResponse(request.url_for("admin:index"), status_code=302) + async def download_file(self, request: Request) -> Response: + """Download file endpoint.""" + request_path = await self._get_file(request) + return FileResponse(request_path, filename=request_path.name) + + async def reed_file(self, request: Request) -> Response: + """Read file endpoint.""" + request_path = await self._get_file(request) + return FileResponse( + request_path, filename=request_path.name, content_disposition_type="inline" + ) + async def ajax_lookup(self, request: Request) -> Response: """Ajax lookup route.""" diff --git a/sqladmin/helpers.py b/sqladmin/helpers.py index 0d8cc518..26731b37 100644 --- a/sqladmin/helpers.py +++ b/sqladmin/helpers.py @@ -132,6 +132,16 @@ def secure_filename(filename: str) -> str: return filename +def value_is_filepath(value: Any) -> bool: + """Check if a value is a filepath.""" + return isinstance(value, str) and os.path.isfile(value) + + +def get_filename_from_path(path: str) -> str: + """Get filename from path.""" + return os.path.basename(path) + + class Writer(ABC): """https://docs.python.org/3/library/csv.html#writer-objects""" diff --git a/sqladmin/models.py b/sqladmin/models.py index 896010f5..4927fece 100644 --- a/sqladmin/models.py +++ b/sqladmin/models.py @@ -818,6 +818,32 @@ async def get_object_for_delete(self, value: Any) -> Any: stmt = self._stmt_by_identifier(value) return await self._get_object_by_pk(stmt) + async def get_object_filepath(self, identifier: str, column_name: str) -> Any: + stmt = self._stmt_by_identifier(identifier) + obj = await self._get_object_by_pk(stmt) + column_value = getattr(obj, column_name) + return column_value + + async def _get_file_download_link( + self, request: Request, obj: Any, pk: int, column_name: str + ) -> Any: + return request.url_for( + "admin:file_download", + identity=slugify_class_name(obj.__class__.__name__), + pk=pk, + column_name=column_name, + ) + + async def _get_file_link( + self, request: Request, obj: Any, pk: int, column_name: str + ) -> Any: + return request.url_for( + "admin:file_read", + identity=slugify_class_name(obj.__class__.__name__), + pk=pk, + column_name=column_name, + ) + def _stmt_by_identifier(self, identifier: str) -> Select: stmt = select(self.model) pks = get_primary_keys(self.model) diff --git a/sqladmin/templates/details.html b/sqladmin/templates/details.html index 34ae2c36..a9acdf64 100644 --- a/sqladmin/templates/details.html +++ b/sqladmin/templates/details.html @@ -36,8 +36,19 @@