Skip to content

Commit

Permalink
Merge pull request #1530 from weaviate/rbac/refactor-has-permission
Browse files Browse the repository at this point in the history
Make changes to user flow of has permissions:
  • Loading branch information
tsmith023 authored Feb 3, 2025
2 parents 62985c4 + 0cf0af2 commit 40c547f
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ env:
WEAVIATE_125: 1.25.29
WEAVIATE_126: 1.26.13
WEAVIATE_127: 1.27.9
WEAVIATE_128: 1.28.2-2c00437
WEAVIATE_128: 1.28.4-6553adc

jobs:
lint-and-format:
Expand Down
12 changes: 8 additions & 4 deletions integration/test_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,15 @@ def test_multiple_permissions(client_factory: ClientFactory) -> None:
assert role.data_permissions[0].action == Actions.Data.CREATE
assert role.data_permissions[1].action == Actions.Data.UPDATE

assert client.roles.has_permission(
permission=role.collections_permissions[0], role=role_name
assert client.roles.has_permissions(
permissions=role.collections_permissions[0], role=role_name
)
assert client.roles.has_permission(
permission=required_permissions[1][0], role=role_name
assert client.roles.has_permissions(permissions=role.data_permissions, role=role_name)
assert client.roles.has_permissions(
permissions=required_permissions[1][0], role=role_name
)
assert client.roles.has_permissions(permissions=required_permissions[0], role=role_name)
assert client.roles.has_permissions(permissions=required_permissions[1], role=role_name)
assert client.roles.has_permissions(permissions=required_permissions, role=role_name)
finally:
client.roles.delete(role_name)
93 changes: 57 additions & 36 deletions weaviate/rbac/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,13 @@ def values() -> List[str]:
return [action.value for action in BackupsAction]


class _Permission(BaseModel):
class _InputPermission(BaseModel):
@abstractmethod
def _to_weaviate(self) -> WeaviatePermission:
raise NotImplementedError()


class _CollectionsPermission(_Permission):
class _CollectionsPermission(_InputPermission):
collection: str
tenant: str
action: CollectionsAction
Expand All @@ -160,7 +160,7 @@ def _to_weaviate(self) -> WeaviatePermission:
}


class TenantsPermission(_Permission):
class _TenantsPermission(_InputPermission):
collection: str
action: TenantsAction

Expand All @@ -174,7 +174,7 @@ def _to_weaviate(self) -> WeaviatePermission:
}


class _NodesPermission(_Permission):
class _NodesPermission(_InputPermission):
verbosity: Verbosity
collection: str
action: NodesAction
Expand All @@ -189,7 +189,7 @@ def _to_weaviate(self) -> WeaviatePermission:
}


class _RolesPermission(_Permission):
class _RolesPermission(_InputPermission):
role: str
action: RolesAction

Expand All @@ -202,14 +202,14 @@ def _to_weaviate(self) -> WeaviatePermission:
}


class _UsersPermission(_Permission):
class _UsersPermission(_InputPermission):
action: UsersAction

def _to_weaviate(self) -> WeaviatePermission:
return {"action": self.action}


class _BackupsPermission(_Permission):
class _BackupsPermission(_InputPermission):
collection: str
action: BackupsAction

Expand All @@ -222,7 +222,7 @@ def _to_weaviate(self) -> WeaviatePermission:
}


class _ClusterPermission(_Permission):
class _ClusterPermission(_InputPermission):
action: ClusterAction

def _to_weaviate(self) -> WeaviatePermission:
Expand All @@ -231,7 +231,7 @@ def _to_weaviate(self) -> WeaviatePermission:
}


class _DataPermission(_Permission):
class _DataPermission(_InputPermission):
collection: str
tenant: str
object_: str
Expand All @@ -248,8 +248,14 @@ def _to_weaviate(self) -> WeaviatePermission:
}


class _OutputPermission:
@abstractmethod
def _to_weaviate(self) -> WeaviatePermission:
raise NotImplementedError()


@dataclass
class CollectionsPermission:
class CollectionsPermission(_OutputPermission):
collection: str
action: CollectionsAction

Expand All @@ -264,7 +270,7 @@ def _to_weaviate(self) -> WeaviatePermission:


@dataclass
class DataPermission:
class DataPermission(_OutputPermission):
collection: str
action: DataAction

Expand All @@ -280,7 +286,7 @@ def _to_weaviate(self) -> WeaviatePermission:


@dataclass
class RolesPermission:
class RolesPermission(_OutputPermission):
role: str
action: RolesAction

Expand All @@ -294,23 +300,23 @@ def _to_weaviate(self) -> WeaviatePermission:


@dataclass
class UsersPermission:
class UsersPermission(_OutputPermission):
action: UsersAction

def _to_weaviate(self) -> WeaviatePermission:
return {"action": self.action}


@dataclass
class ClusterPermission:
class ClusterPermission(_OutputPermission):
action: ClusterAction

def _to_weaviate(self) -> WeaviatePermission:
return {"action": self.action}


@dataclass
class BackupsPermission:
class BackupsPermission(_OutputPermission):
collection: str
action: BackupsAction

Expand All @@ -324,7 +330,7 @@ def _to_weaviate(self) -> WeaviatePermission:


@dataclass
class NodesPermission:
class NodesPermission(_OutputPermission):
collection: Optional[str]
verbosity: Verbosity
action: NodesAction
Expand All @@ -339,6 +345,21 @@ def _to_weaviate(self) -> WeaviatePermission:
}


@dataclass
class TenantsPermission(_OutputPermission):
collection: str
action: TenantsAction

def _to_weaviate(self) -> WeaviatePermission:
return {
"action": self.action,
"tenants": {
"collection": _capitalize_first_letter(self.collection),
"tenant": "*",
},
}


PermissionsOutputType = Union[
ClusterPermission,
CollectionsPermission,
Expand Down Expand Up @@ -474,12 +495,12 @@ class User:


PermissionsInputType = Union[
_Permission,
Sequence[_Permission],
Sequence[Sequence[_Permission]],
Sequence[Union[_Permission, Sequence[_Permission]]],
_InputPermission,
Sequence[_InputPermission],
Sequence[Sequence[_InputPermission]],
Sequence[Union[_InputPermission, Sequence[_InputPermission]]],
]
PermissionsCreateType = List[_Permission]
PermissionsCreateType = List[_InputPermission]


class _DataFactory:
Expand Down Expand Up @@ -536,20 +557,20 @@ def delete(*, collection: Optional[str] = None) -> _CollectionsPermission:

class _TenantsFactory:
@staticmethod
def create(*, collection: Optional[str] = None) -> TenantsPermission:
return TenantsPermission(collection=collection or "*", action=TenantsAction.CREATE)
def create(*, collection: Optional[str] = None) -> _TenantsPermission:
return _TenantsPermission(collection=collection or "*", action=TenantsAction.CREATE)

@staticmethod
def read(*, collection: Optional[str] = None) -> TenantsPermission:
return TenantsPermission(collection=collection or "*", action=TenantsAction.READ)
def read(*, collection: Optional[str] = None) -> _TenantsPermission:
return _TenantsPermission(collection=collection or "*", action=TenantsAction.READ)

@staticmethod
def update(*, collection: Optional[str] = None) -> TenantsPermission:
return TenantsPermission(collection=collection or "*", action=TenantsAction.UPDATE)
def update(*, collection: Optional[str] = None) -> _TenantsPermission:
return _TenantsPermission(collection=collection or "*", action=TenantsAction.UPDATE)

@staticmethod
def delete(*, collection: Optional[str] = None) -> TenantsPermission:
return TenantsPermission(collection=collection or "*", action=TenantsAction.DELETE)
def delete(*, collection: Optional[str] = None) -> _TenantsPermission:
return _TenantsPermission(collection=collection or "*", action=TenantsAction.DELETE)


class _RolesFactory:
Expand Down Expand Up @@ -611,7 +632,7 @@ def data(
update: bool = False,
delete: bool = False,
) -> PermissionsCreateType:
permissions: List[_Permission] = []
permissions: List[_InputPermission] = []
if isinstance(collection, str):
collection = [collection]
for c in collection:
Expand All @@ -634,7 +655,7 @@ def collections(
update_config: bool = False,
delete_collection: bool = False,
) -> PermissionsCreateType:
permissions: List[_Permission] = []
permissions: List[_InputPermission] = []
if isinstance(collection, str):
collection = [collection]
for c in collection:
Expand All @@ -657,7 +678,7 @@ def tenants(
update: bool = False,
delete: bool = False,
) -> PermissionsCreateType:
permissions: List[_Permission] = []
permissions: List[_InputPermission] = []
if isinstance(collection, str):
collection = [collection]
for c in collection:
Expand All @@ -675,7 +696,7 @@ def tenants(
def roles(
*, role: Union[str, Sequence[str]], read: bool = False, manage: bool = False
) -> PermissionsCreateType:
permissions: List[_Permission] = []
permissions: List[_InputPermission] = []
if isinstance(role, str):
role = [role]
for r in role:
Expand All @@ -689,7 +710,7 @@ def roles(
def backup(
*, collection: Union[str, Sequence[str]], manage: bool = False
) -> PermissionsCreateType:
permissions: List[_Permission] = []
permissions: List[_InputPermission] = []
if isinstance(collection, str):
collection = [collection]
for c in collection:
Expand All @@ -704,7 +725,7 @@ def nodes(
verbosity: Verbosity = "minimal",
read: bool = False,
) -> PermissionsCreateType:
permissions: List[_Permission] = []
permissions: List[_InputPermission] = []
if isinstance(collection, str):
collection = [collection]
for c in collection:
Expand All @@ -714,7 +735,7 @@ def nodes(

@staticmethod
def cluster(*, read: bool = False) -> PermissionsCreateType:
permissions: List[_Permission] = []
permissions: List[_InputPermission] = []
if read:
permissions.append(_ClusterFactory.read())
return permissions
42 changes: 30 additions & 12 deletions weaviate/rbac/roles.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import json
from typing import Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, Sequence, Union, cast

from weaviate.connect import ConnectionV4
from weaviate.connect.v4 import _ExpectedStatusCodes
from weaviate.rbac.models import (
_Permission,
_OutputPermission,
_InputPermission,
PermissionsOutputType,
PermissionsInputType,
Role,
Expand Down Expand Up @@ -274,7 +276,7 @@ async def add_permissions(self, *, permissions: PermissionsInputType, role_name:
permissions: The permissions to add to the role.
role_name: The name of the role to add the permissions to.
"""
if isinstance(permissions, _Permission):
if isinstance(permissions, _InputPermission):
permissions = [permissions]
await self._add_permissions(
[permission._to_weaviate() for permission in _flatten_permissions(permissions)],
Expand All @@ -292,17 +294,22 @@ async def remove_permissions(
permissions: The permissions to remove from the role.
role_name: The name of the role to remove the permissions from.
"""
if isinstance(permissions, _Permission):
if isinstance(permissions, _InputPermission):
permissions = [permissions]
await self._remove_permissions(
[permission._to_weaviate() for permission in _flatten_permissions(permissions)],
role_name,
)

async def has_permission(
self, *, permission: Union[_Permission, PermissionsOutputType], role: str
async def has_permissions(
self,
*,
permissions: Union[
PermissionsInputType, PermissionsOutputType, Sequence[PermissionsOutputType]
],
role: str,
) -> bool:
"""Check if a role has a specific permission.
"""Check if a role has a specific set of permission.
Args:
permission: The permission to check.
Expand All @@ -311,15 +318,26 @@ async def has_permission(
Returns:
True if the role has the permission, False otherwise.
"""
return await self._has_permission(permission._to_weaviate(), role)
return all(
await asyncio.gather(
*[
self._has_permission(permission._to_weaviate(), role)
for permission in _flatten_permissions(permissions)
]
)
)


def _flatten_permissions(permissions: PermissionsInputType) -> List[_Permission]:
if isinstance(permissions, _Permission):
def _flatten_permissions(
permissions: Union[PermissionsInputType, PermissionsOutputType, Sequence[PermissionsOutputType]]
) -> List[Union[_InputPermission, _OutputPermission]]:
if isinstance(permissions, _InputPermission) or isinstance(permissions, _OutputPermission):
return [permissions]
flattened_permissions: List[_Permission] = []
flattened_permissions: List[Union[_InputPermission, _OutputPermission]] = []
for permission in permissions:
if isinstance(permission, _Permission):
if isinstance(permission, _InputPermission):
flattened_permissions.append(permission)
elif isinstance(permission, _OutputPermission):
flattened_permissions.append(permission)
else:
flattened_permissions.extend(permission)
Expand Down
Loading

0 comments on commit 40c547f

Please sign in to comment.