Skip to content

Commit

Permalink
Merge branch 'refactor/argilla-server/better-oauth2-integration' into…
Browse files Browse the repository at this point in the history
… feature/better-oauth2-integration-keycloak
  • Loading branch information
frascuchon authored Dec 3, 2024
2 parents 7f319d2 + b3e2cbe commit 0da655e
Show file tree
Hide file tree
Showing 9 changed files with 378 additions and 7 deletions.
1 change: 1 addition & 0 deletions argilla-server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ These are the section headers that we use:

### Changed

- API endpoint added to the User router to allow updates to User objects ([#5615](https://github.com/argilla-io/argilla/pull/5615))
- Changed default python version to 3.13. ([#5649](https://github.com/argilla-io/argilla/pull/5649))
- Changed Pydantic version to v2. ([#5666](https://github.com/argilla-io/argilla/pull/5666))

Expand Down
17 changes: 16 additions & 1 deletion argilla-server/src/argilla_server/api/handlers/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from argilla_server.api.policies.v1 import UserPolicy, authorize
from argilla_server.api.schemas.v1.users import User as UserSchema
from argilla_server.api.schemas.v1.users import UserCreate, Users
from argilla_server.api.schemas.v1.users import UserCreate, Users, UserUpdate
from argilla_server.api.schemas.v1.workspaces import Workspaces
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
Expand Down Expand Up @@ -89,6 +89,21 @@ async def delete_user(
return await accounts.delete_user(db, user)


@router.patch("/users/{user_id}", status_code=status.HTTP_200_OK, response_model=UserSchema)
async def update_user(
*,
db: AsyncSession = Depends(get_async_db),
user_id: UUID,
user_update: UserUpdate,
current_user: User = Security(auth.get_current_user),
):
user = await User.get_or_raise(db, user_id)

await authorize(current_user, UserPolicy.update)

return await accounts.update_user(db, user, user_update.model_dump(exclude_unset=True))


@router.get("/users/{user_id}/workspaces", response_model=Workspaces)
async def list_user_workspaces(
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ async def list(cls, actor: User) -> bool:
async def create(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def update(cls, actor: User) -> bool:
return actor.is_owner

@classmethod
async def delete(cls, actor: User) -> bool:
return actor.is_owner
Expand Down
39 changes: 34 additions & 5 deletions argilla-server/src/argilla_server/api/schemas/v1/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,35 @@
# limitations under the License.

from datetime import datetime
from typing import List, Optional
from typing import Annotated, List, Optional
from uuid import UUID

from pydantic import BaseModel, Field, constr, ConfigDict

from argilla_server.api.schemas.v1.commons import UpdateSchema
from argilla_server.enums import UserRole

USER_PASSWORD_MIN_LENGTH = 8
USER_PASSWORD_MAX_LENGTH = 100

UserFirstName = Annotated[
constr(min_length=1, strip_whitespace=True), Field(..., description="The first name for the user")
]
UserLastName = Annotated[
constr(min_length=1, strip_whitespace=True), Field(..., description="The last name for the user")
]
UserUsername = Annotated[str, Field(..., min_length=1, description="The username for the user")]

UserPassword = Annotated[
str,
Field(
...,
min_length=USER_PASSWORD_MIN_LENGTH,
max_length=USER_PASSWORD_MAX_LENGTH,
description="The password for the user",
),
]


class User(BaseModel):
id: UUID
Expand All @@ -40,11 +59,21 @@ class User(BaseModel):


class UserCreate(BaseModel):
username: str = Field(..., min_length=1)
password: str = Field(min_length=USER_PASSWORD_MIN_LENGTH, max_length=USER_PASSWORD_MAX_LENGTH)
first_name: constr(min_length=1, strip_whitespace=True)
last_name: Optional[constr(min_length=1, strip_whitespace=True)] = None
first_name: UserFirstName
last_name: Optional[UserLastName] = None
username: UserUsername
role: Optional[UserRole] = None
password: UserPassword


class UserUpdate(UpdateSchema):
__non_nullable_fields__ = {"first_name", "username", "role", "password"}

first_name: Optional[UserFirstName] = None
last_name: Optional[UserLastName] = None
username: Optional[UserUsername] = None
role: Optional[UserRole] = None
password: Optional[UserPassword] = None


class Users(BaseModel):
Expand Down
12 changes: 12 additions & 0 deletions argilla-server/src/argilla_server/contexts/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ async def create_user_with_random_password(
return await create_user(db, user_attrs, workspaces)


async def update_user(db: AsyncSession, user: User, user_attrs: dict) -> User:
username = user_attrs.get("username")
if username is not None and username != user.username:
if await get_user_by_username(db, username):
raise UnprocessableEntityError(f"Username {username!r} already exists")

if "password" in user_attrs:
user_attrs["password_hash"] = hash_password(user_attrs.pop("password"))

return await user.update(db, **user_attrs)


async def delete_user(db: AsyncSession, user: User) -> User:
return await user.delete(db)

Expand Down
Loading

0 comments on commit 0da655e

Please sign in to comment.