diff --git a/pyramid_openapi3/__init__.py b/pyramid_openapi3/__init__.py index 13ed88b..b721f94 100644 --- a/pyramid_openapi3/__init__.py +++ b/pyramid_openapi3/__init__.py @@ -66,25 +66,20 @@ def includeme(config: Configurator) -> None: def openapi_validated(request: Request) -> dict: """Get validated parameters.""" - # Validate request and attach all findings for view to introspect - validate_request = asbool( - request.registry.settings.get( - "pyramid_openapi3.enable_request_validation", True - ) - ) - validate_response = asbool( - request.registry.settings.get( - "pyramid_openapi3.enable_response_validation", True + + # we need this here in case someone calls request.openapi_validated on + # a view marked with openapi=False + if not request.environ.get("pyramid_openapi3.enabled"): + raise AttributeError( + "Cannot do openapi request validation on a view marked with openapi=False" ) - ) - request.environ["pyramid_openapi3.validate_response"] = validate_response + gsettings = settings = request.registry.settings["pyramid_openapi3"] route_settings = gsettings.get("routes") if route_settings and request.matched_route.name in route_settings: settings = request.registry.settings[route_settings[request.matched_route.name]] - if validate_request: # pragma: no branch - request.environ["pyramid_openapi3.validate_request"] = True + if request.environ.get("pyramid_openapi3.validate_request"): openapi_request = PyramidOpenAPIRequestFactory.create(request) validated = settings["request_validator"].validate(openapi_request) return validated @@ -97,7 +92,7 @@ def openapi_validated(request: Request) -> dict: def openapi_view(view: View, info: ViewDeriverInfo) -> View: - """View deriver that takes care of request/response validation. + """View deriver that takes care of request validation. If `openapi=True` is passed to `@view_config`, this decorator will: @@ -109,12 +104,29 @@ def openapi_view(view: View, info: ViewDeriverInfo) -> View: if info.options.get("openapi"): def wrapper_view(context: Context, request: Request) -> Response: - validate_request = asbool( + + # We need this to be able to raise AttributeError if view code + # accesses request.openapi_validated on a view that is marked + # with openapi=False + request.environ["pyramid_openapi3.enabled"] = True + + # If view is marked with openapi=True (i.e. we are in this + # function) and registry settings are not set to disable + # validation, then do request/response validation + request.environ["pyramid_openapi3.validate_request"] = asbool( request.registry.settings.get( "pyramid_openapi3.enable_request_validation", True ) ) - if validate_request and request.openapi_validated.errors: + request.environ["pyramid_openapi3.validate_response"] = asbool( + request.registry.settings.get( + "pyramid_openapi3.enable_response_validation", True + ) + ) + + # request validation can happen already here, but response validation + # needs to happen later in a tween + if request.openapi_validated and request.openapi_validated.errors: raise RequestValidationError(errors=request.openapi_validated.errors) # Do the view diff --git a/pyramid_openapi3/tests/test_path_parameters.py b/pyramid_openapi3/tests/test_path_parameters.py index a4ceb9d..5ca8beb 100644 --- a/pyramid_openapi3/tests/test_path_parameters.py +++ b/pyramid_openapi3/tests/test_path_parameters.py @@ -6,14 +6,8 @@ from webtest.app import TestApp -class _FooResource: - def __init__(self, request: Request) -> None: - self.request = request - self.foo_id = request.openapi_validated.parameters["path"]["foo_id"] - - -def _foo_view(context: _FooResource, request: Request) -> int: - return context.foo_id +def _foo_view(request: Request) -> int: + return request.openapi_validated.parameters["path"]["foo_id"] def test_path_parameter_validation() -> None: @@ -46,8 +40,12 @@ def test_path_parameter_validation() -> None: config.include("pyramid_openapi3") config.pyramid_openapi3_spec(tempdoc.name) config.pyramid_openapi3_register_routes() - config.add_route("foo_route", "/foo/{foo_id}", factory=_FooResource) - config.add_view(_foo_view, route_name="foo_route", renderer="json") + # config.add_route("foo_route", "/foo/{foo_id}", factory=_FooResource) + config.add_route("foo_route", "/foo/{foo_id}") + config.add_view( + openapi=True, view=_foo_view, route_name="foo_route", renderer="json" + ) + app = config.make_wsgi_app() test_app = TestApp(app) resp = test_app.get("/foo/1") diff --git a/pyramid_openapi3/tests/test_validation.py b/pyramid_openapi3/tests/test_validation.py index b3d5fa6..519ba89 100644 --- a/pyramid_openapi3/tests/test_validation.py +++ b/pyramid_openapi3/tests/test_validation.py @@ -267,6 +267,31 @@ def test_nonapi_view(self) -> None: self.assertEqual(start_response.status, "200 OK") self.assertIn(b"foo", b"".join(response)) + def test_nonapi_view_raises_AttributeError(self) -> None: + """Test non-openapi view that accesses request.openapi_validated.""" + + def should_raise_error(request): + request.openapi_validated + + self._add_view(openapi=False, view_func=should_raise_error) + # run request through router + router = Router(self.config.registry) + environ = { + "wsgi.url_scheme": "http", + "SERVER_NAME": "localhost", + "SERVER_PORT": "8080", + "REQUEST_METHOD": "GET", + "PATH_INFO": "/foo", + } + start_response = DummyStartResponse() + with self.assertRaises(AttributeError) as cm: + router(environ, start_response) + + self.assertEqual( + str(cm.exception), + "Cannot do openapi request validation on a view marked with openapi=False", + ) + def test_request_validation_disabled(self) -> None: """Test View with request validation disabled.""" self.config.registry.settings[ diff --git a/pyramid_openapi3/tween.py b/pyramid_openapi3/tween.py index 0ef26d3..9cdd626 100644 --- a/pyramid_openapi3/tween.py +++ b/pyramid_openapi3/tween.py @@ -27,8 +27,8 @@ def excview_tween(request: Request) -> Response: try: response = handler(request) if not request.environ.get("pyramid_openapi3.validate_response"): - # not an openapi view or response validation not requested return response + # validate response openapi_request = PyramidOpenAPIRequestFactory.create(request) openapi_response = PyramidOpenAPIResponseFactory.create(response)