Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keycloak SSO #5711

Open
wants to merge 14 commits into
base: refactor/argilla-server/better-oauth2-integration
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions argilla-frontend/components/features/login/components/KeycloakLogo.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
<template>
<!--https://github.com/keycloak/keycloak-misc/blob/main/logo/icon.svg-->
<svg
width="256"
height="256"
viewBox="0 0 44.216 39.861"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<path
d="m88.61 138.456 5.716-9.865 23.018-.004 5.686 9.965.007 19.932-5.691 9.957-23.012.008-5.782-9.965z"
style="
display: inline;
fill: #4d4d4d;
fill-opacity: 1;
stroke-width: 0.264583;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M88.552 158.481h10.375l-5.699-10.041 4.634-9.982-9.252-.002-5.795 10.065"
style="
fill: #ededed;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M102.073 158.481h7.582l6.706-9.773-6.589-10.156h-8.921l-5.373 9.814z"
style="
fill: #e0e0e0;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m82.815 148.52 5.738 9.964h10.374l-5.636-9.93z"
style="
fill: #acacac;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m95.589 148.522 6.484 9.963h7.582l6.601-9.959z"
style="
fill: #9e9e9e;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m98.157 148.529-1.958.569-1.877-.572 7.667-13.288 1.918 3.316"
style="
fill: #00b8e3;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m103.9 158.482-1.909 3.332-5.093-5.487-2.58-7.797v-.004h3.838"
style="
fill: #33c6e9;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M94.322 148.526h-.003v.003l-1.918 3.322-1.925-3.307 1.952-3.386 5.728-9.92h3.834"
style="
fill: #008aaa;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M115.42 158.481h11.611l-.007-19.93h-11.605z"
style="
fill: #d4d4d4;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M115.42 148.554v9.93h11.59v-9.93z"
style="
fill: #919191;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M101.992 161.817h-3.836l-5.755-9.966 1.918-3.321z"
style="
fill: #00b8e3;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m117.333 148.526-7.669 13.289c-.705-1.036-1.913-3.331-1.913-3.331l5.753-9.959z"
style="
fill: #008aaa;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m113.495 161.815-3.831-.001 7.67-13.288 1.917-3.317 1.921 3.34m-3.839-.023h-3.828l-5.755-9.973 1.905-3.314 4.658 5.922z"
style="
fill: #00b8e3;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M119.25 145.205v.003l-1.917 3.318-7.677-13.286 3.841.002z"
style="
fill: #33c6e9;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
</svg>
</template>
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
<template>
<BaseButton class="sign-in-button" @click="$emit('click')">
<KeycloakLogo v-if="provider === 'keycloak'" />
{{ signinText }}
</BaseButton>
</template>

<script>
import KeycloakLogo from "./KeycloakLogo.vue";

export default {
name: "OAuthLoginButton",
components: {
KeycloakLogo,
},
props: {
provider: {
type: String,
Expand Down
2 changes: 1 addition & 1 deletion argilla-frontend/translation/de.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ export default {
button: {
ignore_and_continue: "Ignorieren und fortfahren",
login: "Anmelden",
signin_with_provider: "Anmeldung bei {provider} starten",
signin_with_provider: "Mit {provider} anmelden",
"hf-login": "Mit Hugging Face anmelden",
sign_in_with_username: "Mit Benutzername anmelden",
cancel: "Abbrechen",
Expand Down
44 changes: 38 additions & 6 deletions argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors.future import NotFoundError
from argilla_server.models import User
from argilla_server.models import User, UserRole, Workspace, WorkspaceUser
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings
Expand Down Expand Up @@ -61,14 +61,46 @@ async def get_access_token(
if not userinfo.username:
raise RuntimeError("OAuth error: Missing username")

user = await User.get_by(db, username=userinfo.username)
if user is None:
user = await accounts.create_user_with_random_password(
user_w_workspace = await accounts.get_user_by_username(db, username=userinfo.username)
if user_w_workspace is None:
exs_workspaces = await accounts.list_workspaces(db)
exs_workspaces = [w.name for w in exs_workspaces]
default_available_workspaces = [workspace.name for workspace in settings.oauth.allowed_workspaces]
workspaces = userinfo.available_workspaces or default_available_workspaces
# Check first if workspaces exist
workspaces = [w for w in workspaces if w in exs_workspaces]

user_w_workspace = await accounts.create_user_with_random_password(
db,
username=userinfo.username,
first_name=userinfo.first_name,
last_name=userinfo.last_name,
role=userinfo.role,
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
workspaces=workspaces,
)
else:
# With existing user update the role if needed
if user_w_workspace.role != userinfo.role:
user_w_workspace = await user_w_workspace.update(db, role=userinfo.role)
# With existing user update the workspaces if needed
if user_w_workspace.role != UserRole.owner and set(user_w_workspace.workspaces) != set(
userinfo.available_workspaces
):
for workspace_name in userinfo.available_workspaces:
workspace = await Workspace.get_by(db, name=workspace_name)
if not workspace:
continue

await WorkspaceUser.create(
db,
workspace_id=workspace.id,
user_id=user_w_workspace.id,
autocommit=False,
)
for workspace in user_w_workspace.workspaces:
if workspace.name not in userinfo.available_workspaces:
ws_user = await WorkspaceUser.get_by(db, workspace_id=workspace.id, user_id=user_w_workspace.id)
await ws_user.delete(db, autocommit=False)
await db.commit()

return Token(access_token=accounts.generate_user_token(user))
return Token(access_token=accounts.generate_user_token(user_w_workspace))
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

import os
from typing import Type, Dict, Any
from typing import Type, Dict, Any, Optional, List

from social_core.backends.oauth import BaseOAuth2
from social_core.backends.open_id_connect import OpenIdConnectAuth
from social_core.backends.utils import load_backends
from social_core.strategy import BaseStrategy

from argilla_server.errors.future import NotFoundError
from argilla_server.models import UserRole


class Strategy(BaseStrategy):
Expand Down Expand Up @@ -48,6 +49,61 @@ class HuggingfaceOpenId(OpenIdConnectAuth):
DEFAULT_SCOPE = ["openid", "profile"]


class KeycloakOpenId(OpenIdConnectAuth):
"""Huggingface OpenID Connect authentication backend."""

name = "keycloak"

def oidc_endpoint(self) -> str:
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
value = super().oidc_endpoint()

if value is None:
from social_core.utils import setting_name

name = setting_name("OIDC_ENDPOINT")
raise ValueError(
"oidc_endpoint needs to be set in the Keycloak configuration. "
f"Please set the {name} environment variable."
)

return value

def get_user_details(self, response: Dict[str, Any]) -> Dict[str, Any]:
user = super().get_user_details(response)

if role := self._extract_role(response):
user["role"] = role

if available_workspaces := self._extract_available_workspaces(response):
user["available_workspaces"] = available_workspaces

return user

def _extract_role(self, response: Dict[str, Any]) -> Optional[str]:
roles = self._read_realm_roles(response)
role_to_value = {UserRole.owner: 3, UserRole.admin: 2, UserRole.annotator: 1}
role_list = [role.split(":")[1] for role in roles if role.startswith("argilla_role:")]
if role_list:
max_role = max(role_list, key=lambda s: role_to_value.get(s, 0))
return max_role

def _extract_available_workspaces(self, response: Dict[str, Any]) -> List[str]:
roles = self._read_realm_roles(response)

workspaces = []
for role in roles:
if role.startswith("argilla_workspace:"):
workspace = role.split(":")[1]
workspaces.append(workspace)

return workspaces

@classmethod
def _read_realm_roles(cls, response) -> List[str]:
realm_access = response.get("realm_access") or {}
return realm_access.get("roles") or []


_SUPPORTED_BACKENDS = {}


Expand All @@ -56,6 +112,7 @@ def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseO

backends = [
"argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId",
"argilla_server.security.authentication.oauth2._backends.KeycloakOpenId",
"social_core.backends.github.GithubOAuth2",
"social_core.backends.google.GoogleOAuth2",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,19 @@ def username(self) -> str:
def first_name(self) -> str:
return self.get("first_name") or self.username

@property
def last_name(self) -> str:
return self.get("last_name") or ""

@property
def role(self) -> UserRole:
role = self.get("role") or self._parse_role_from_environment()
return UserRole(role)

@property
def available_workspaces(self) -> Optional[list]:
return self.get("available_workspaces")

def _parse_role_from_environment(self) -> Optional[UserRole]:
"""This is a temporal solution, and it will be replaced by a proper Sign up process"""
if self["username"] == os.getenv("USERNAME"):
Expand Down
Loading