diff --git a/scim2_client/client.py b/scim2_client/client.py index 10d1103..f4615f5 100644 --- a/scim2_client/client.py +++ b/scim2_client/client.py @@ -205,10 +205,13 @@ def get_resource_model(self, name: str) -> Optional[type[Resource]]: def _check_resource_model( self, resource_model: type[Resource], payload=None ) -> None: - if ( - resource_model not in self.resource_models - and resource_model not in CONFIG_RESOURCES - ): + schema_to_check = resource_model.model_fields["schemas"].default[0] + for element in self.resource_models: + schema = element.model_fields["schemas"].default[0] + if schema_to_check == schema: + return + + if resource_model not in CONFIG_RESOURCES: raise SCIMRequestError( f"Unknown resource type: '{resource_model}'", source=payload ) @@ -640,13 +643,13 @@ def build_resource_models( for schema, resource_type in resource_types_by_schema.items(): schema_obj = schema_objs_by_schema[schema] model = Resource.from_schema(schema_obj) - extensions = [] + extensions: tuple[type[Extension], ...] = () for ext_schema in resource_type.schema_extensions or []: schema_obj = schema_objs_by_schema[ext_schema.schema_] extension = Extension.from_schema(schema_obj) - extensions.append(extension) + extensions = extensions + (extension,) if extensions: - model = model[Union[tuple(extensions)]] + model = model[Union[extensions]] resource_models.append(model) return tuple(resource_models) diff --git a/tests/test_delete.py b/tests/test_delete.py index 4d17a22..6d6d4e2 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -1,12 +1,16 @@ import pytest from scim2_models import Error -from scim2_models import Group +from scim2_models import Resource from scim2_models import User from scim2_client import RequestNetworkError from scim2_client import SCIMRequestError +class UnregisteredResource(Resource): + schemas: list[str] = ["urn:test:schemas:UnregisteredResource"] + + def test_delete_user(httpserver, sync_client): """Nominal case for a User deletion.""" httpserver.expect_request( @@ -45,7 +49,7 @@ def test_errors(httpserver, code, sync_client): def test_invalid_resource_model(httpserver, sync_client): """Test that resource_models passed to the method must be part of SCIMClient.resource_models.""" with pytest.raises(SCIMRequestError, match=r"Unknown resource type"): - sync_client.delete(Group(display_name="foobar"), id="foobar") + sync_client.delete(UnregisteredResource, id="foobar") def test_dont_check_response_payload(httpserver, sync_client): diff --git a/tests/test_discovery.py b/tests/test_discovery.py new file mode 100644 index 0000000..9006ac1 --- /dev/null +++ b/tests/test_discovery.py @@ -0,0 +1,96 @@ +import threading +import wsgiref.simple_server +from typing import Annotated +from typing import Union + +import portpicker +import pytest +from httpx import Client +from scim2_models import EnterpriseUser +from scim2_models import Extension +from scim2_models import Group +from scim2_models import Meta +from scim2_models import Required +from scim2_models import ResourceType +from scim2_models import User + +from scim2_client.engines.httpx import SyncSCIMClient + +scim2_server = pytest.importorskip("scim2_server") +from scim2_server.backend import InMemoryBackend # noqa: E402 +from scim2_server.provider import SCIMProvider # noqa: E402 + + +class OtherExtension(Extension): + schemas: Annotated[list[str], Required.true] = [ + "urn:ietf:params:scim:schemas:extension:Other:1.0:User" + ] + + test: str | None = None + test2: list[str] | None = None + + +def get_schemas(): + schemas = [ + User.to_schema(), + Group.to_schema(), + OtherExtension.to_schema(), + EnterpriseUser.to_schema(), + ] + + # SCIMProvider register_schema requires meta object to be set + for schema in schemas: + schema.meta = Meta(resource_type="Schema") + + return schemas + + +def get_resource_types(): + resource_types = [ + ResourceType.from_resource(User[Union[EnterpriseUser, OtherExtension]]), + ResourceType.from_resource(Group), + ] + + # SCIMProvider register_resource_type requires meta object to be set + for resource_type in resource_types: + resource_type.meta = Meta(resource_type="ResourceType") + + return resource_types + + +@pytest.fixture(scope="session") +def server(): + backend = InMemoryBackend() + provider = SCIMProvider(backend) + for schema in get_schemas(): + provider.register_schema(schema) + for resource_type in get_resource_types(): + provider.register_resource_type(resource_type) + + host = "localhost" + port = portpicker.pick_unused_port() + httpd = wsgiref.simple_server.make_server(host, port, provider) + + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield host, port + finally: + httpd.shutdown() + server_thread.join() + + +def test_discovery_resource_types_multiple_extensions(server): + host, port = server + client = Client(base_url=f"http://{host}:{port}") + scim_client = SyncSCIMClient(client) + + scim_client.discover() + assert scim_client.get_resource_model("User") + assert scim_client.get_resource_model("Group") + + # Try to create a user to see if discover filled everything correctly + user_request = User[Union[EnterpriseUser, OtherExtension]]( + user_name="bjensen@example.com" + ) + scim_client.create(user_request)