Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New feature: File download support #702

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
52 changes: 50 additions & 2 deletions sqladmin/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import io
import logging
from pathlib import Path
from types import MethodType
from typing import (
TYPE_CHECKING,
Expand All @@ -27,7 +28,12 @@
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.responses import (
FileResponse,
JSONResponse,
RedirectResponse,
Response,
)
from starlette.routing import Mount, Route
from starlette.staticfiles import StaticFiles

Expand All @@ -37,9 +43,11 @@
from sqladmin.authentication import AuthenticationBackend, login_required
from sqladmin.forms import WTFORMS_ATTRS, WTFORMS_ATTRS_REVERSED
from sqladmin.helpers import (
get_filename_from_path,
get_object_identifier,
is_async_session_maker,
slugify_action_name,
value_is_filepath,
)
from sqladmin.models import BaseView, ModelView
from sqladmin.templating import Jinja2Templates
Expand Down Expand Up @@ -116,6 +124,8 @@
templates.env.globals["admin"] = self
templates.env.globals["is_list"] = lambda x: isinstance(x, list)
templates.env.globals["get_object_identifier"] = get_object_identifier
templates.env.globals["value_is_filepath"] = value_is_filepath
templates.env.globals["get_filename_from_path"] = get_filename_from_path

return templates

Expand Down Expand Up @@ -311,6 +321,21 @@
if request.path_params["export_type"] not in model_view.export_types:
raise HTTPException(status_code=404)

async def _get_file(self, request: Request) -> Path:
"""Get file path"""

identity = request.path_params["identity"]
identifier = request.path_params["pk"]
column_name = request.path_params["column_name"]

model_view = self._find_model_view(identity)
file_path = await model_view.get_object_filepath(identifier, column_name)

request_path = Path(file_path)
if not request_path.is_file():
raise HTTPException(status_code=404)

Check warning on line 336 in sqladmin/application.py

View check run for this annotation

Codecov / codecov/patch

sqladmin/application.py#L336

Added line #L336 was not covered by tests
return request_path


class Admin(BaseAdminView):
"""Main entrypoint to admin interface.
Expand Down Expand Up @@ -417,6 +442,18 @@
),
Route("/login", endpoint=self.login, name="login", methods=["GET", "POST"]),
Route("/logout", endpoint=self.logout, name="logout", methods=["GET"]),
Route(
"/{identity}/{pk:path}/{column_name}/download/",
endpoint=self.download_file,
name="file_download",
methods=["GET"],
),
Route(
"/{identity}/{pk:path}/{column_name}/read/",
endpoint=self.reed_file,
name="file_read",
methods=["GET"],
),
]

self.admin.router.routes = routes
Expand Down Expand Up @@ -490,7 +527,6 @@
@login_required
async def create(self, request: Request) -> Response:
"""Create model endpoint."""

await self._create(request)

identity = request.path_params["identity"]
Expand Down Expand Up @@ -626,6 +662,18 @@
await self.authentication_backend.logout(request)
return RedirectResponse(request.url_for("admin:index"), status_code=302)

async def download_file(self, request: Request) -> Response:
"""Download file endpoint."""
request_path = await self._get_file(request)
return FileResponse(request_path, filename=request_path.name)

async def reed_file(self, request: Request) -> Response:
"""Read file endpoint."""
request_path = await self._get_file(request)
return FileResponse(
request_path, filename=request_path.name, content_disposition_type="inline"
)

async def ajax_lookup(self, request: Request) -> Response:
"""Ajax lookup route."""

Expand Down
10 changes: 10 additions & 0 deletions sqladmin/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ def secure_filename(filename: str) -> str:
return filename


def value_is_filepath(value: Any) -> bool:
"""Check if a value is a filepath."""
return isinstance(value, str) and os.path.isfile(value)


def get_filename_from_path(path: str) -> str:
"""Get filename from path."""
return os.path.basename(path)


class Writer(ABC):
"""https://docs.python.org/3/library/csv.html#writer-objects"""

Expand Down
26 changes: 26 additions & 0 deletions sqladmin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,32 @@ async def get_object_for_delete(self, value: Any) -> Any:
stmt = self._stmt_by_identifier(value)
return await self._get_object_by_pk(stmt)

async def get_object_filepath(self, identifier: str, column_name: str) -> Any:
stmt = self._stmt_by_identifier(identifier)
obj = await self._get_object_by_pk(stmt)
column_value = getattr(obj, column_name)
return column_value

async def _get_file_download_link(
self, request: Request, obj: Any, pk: int, column_name: str
) -> Any:
return request.url_for(
"admin:file_download",
identity=slugify_class_name(obj.__class__.__name__),
pk=pk,
column_name=column_name,
)

async def _get_file_link(
self, request: Request, obj: Any, pk: int, column_name: str
) -> Any:
return request.url_for(
"admin:file_read",
identity=slugify_class_name(obj.__class__.__name__),
pk=pk,
column_name=column_name,
)

def _stmt_by_identifier(self, identifier: str) -> Select:
stmt = select(self.model)
pks = get_primary_keys(self.model)
Expand Down
11 changes: 11 additions & 0 deletions sqladmin/templates/details.html
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,19 @@ <h3 class="card-title">
</td>
{% endif %}
{% else %}
{% if value_is_filepath(value) %}
<td>
<a href="{{ model_view._get_file_link(request, model, get_object_identifier(model), name) }}">
{{ get_filename_from_path(value) }}
</a>
<a href="{{ model_view._get_file_download_link(request, model, get_object_identifier(model), name) }}">
<span class="me-1"><i class="fa-solid fa-download"></i></span>
</a>
</td>
{% else %}
<td>{{ formatted_value }}</td>
{% endif %}
{% endif %}
</tr>
{% endfor %}
</tbody>
Expand Down
11 changes: 11 additions & 0 deletions sqladmin/templates/list.html
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,19 @@ <h3 class="card-title">{{ model_view.name_plural }}</h3>
<td><a href="{{ model_view._url_for_details_with_prop(request, row, name) }}">{{ formatted_value }}</a></td>
{% endif %}
{% else %}
{% if value_is_filepath(value) %}
<td>
<a href="{{ model_view._get_file_link(request, row, get_object_identifier(row), name) }}">
{{ get_filename_from_path(value) }}
</a>
<a href="{{ model_view._get_file_download_link(request, row, get_object_identifier(row), name) }}">
<span class="me-1"><i class="fa-solid fa-download"></i></span>
</a>
</td>
{% else %}
<td>{{ formatted_value }}</td>
{% endif %}
{% endif %}
{% endfor %}
</tr>
{% endfor %}
Expand Down
151 changes: 151 additions & 0 deletions tests/test_views/test_file_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import io
import re
from typing import Any, AsyncGenerator

import pytest
from fastapi_storages import FileSystemStorage, StorageFile
from fastapi_storages.integrations.sqlalchemy import FileType
from httpx import AsyncClient
from sqlalchemy import Column, Integer, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import declarative_base, sessionmaker
from starlette.applications import Starlette
from starlette.datastructures import UploadFile

from sqladmin import Admin, ModelView
from tests.common import async_engine as engine

pytestmark = pytest.mark.anyio

Base = declarative_base() # type: Any
session_maker = sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)

app = Starlette()
admin = Admin(app=app, engine=engine)

storage = FileSystemStorage(path=".uploads")


class User(Base):
__tablename__ = "users"

id = Column(Integer, primary_key=True)
file = Column(FileType(FileSystemStorage(".uploads")))


@pytest.fixture
async def prepare_database() -> AsyncGenerator[None, None]:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)

await engine.dispose()


@pytest.fixture
async def client(prepare_database: Any) -> AsyncGenerator[AsyncClient, None]:
async with AsyncClient(app=app, base_url="http://testserver") as c:
yield c


class UserAdmin(ModelView, model=User):
column_list = [User.id, User.file]


admin.add_view(UserAdmin)


async def _query_user() -> Any:
stmt = select(User).limit(1)
async with session_maker() as s:
result = await s.execute(stmt)
return result.scalar_one()


async def test_detail_view(client: AsyncClient) -> None:
async with session_maker() as session:
user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc")))
session.add(user)
await session.commit()

response = await client.get("/admin/user/details/1")

user = await _query_user()

assert response.status_code == 200
assert isinstance(user.file, StorageFile) is True
assert user.file.name == "upload.txt"
assert user.file.path == ".uploads/upload.txt"
assert user.file.open().read() == b"abc"

assert (
'<span class="me-1"><i class="fa-solid fa-download"></i></span>'
in response.text
)
assert '<a href="http://testserver/admin/user/1/file/read/">' in response.text
assert '<a href="http://testserver/admin/user/1/file/download/">' in response.text


async def test_list_view(client: AsyncClient) -> None:
async with session_maker() as session:
for i in range(10):
user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc")))
session.add(user)
await session.commit()

response = await client.get("/admin/user/list")

user = await _query_user()

assert response.status_code == 200
assert isinstance(user.file, StorageFile) is True
assert user.file.name == "upload.txt"
assert user.file.path == ".uploads/upload.txt"
assert user.file.open().read() == b"abc"

pattern_span = re.compile(
r'<span class="me-1"><i class="fa-solid fa-download"></i></span>'
)
pattern_a_read = re.compile(
r'<a href="http://testserver/admin/user/\d+/file/read/">'
)
pattern_a_download = re.compile(
r'<a href="http://testserver/admin/user/\d+/file/download/">'
)

count_span = len(pattern_span.findall(response.text))
count_a_read = len(pattern_a_read.findall(response.text))
count_a_download = len(pattern_a_download.findall(response.text))

assert count_span == count_a_read == count_a_download == 10


async def test_file_download(client: AsyncClient) -> None:
async with session_maker() as session:
for i in range(10):
user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc")))
session.add(user)
await session.commit()

response = await client.get("/admin/user/1/file/download/")

assert response.status_code == 200

with open(".uploads/download.txt", "wb") as local_file:
local_file.write(response.content)

assert open(".uploads/download.txt", "rb").read() == b"abc"


async def test_file_read(client: AsyncClient) -> None:
async with session_maker() as session:
for i in range(10):
user = User(file=UploadFile(filename="upload.txt", file=io.BytesIO(b"abc")))
session.add(user)
await session.commit()

response = await client.get("/admin/user/1/file/read/")
assert response.status_code == 200
assert response.text == "abc"
Loading