Skip to content

Commit 3844504

Browse files
author
Jaap Roes
committed
Add cors decorator
1 parent 75cc53c commit 3844504

File tree

5 files changed

+234
-14
lines changed

5 files changed

+234
-14
lines changed

src/corsheaders/decorators.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import functools
5+
from typing import Any
6+
from typing import Callable
7+
from typing import cast
8+
from typing import Optional
9+
from typing import TypeVar
10+
11+
from django.http import HttpRequest
12+
from django.http import HttpResponseBase
13+
14+
from corsheaders.conf import conf as _conf
15+
from corsheaders.conf import Settings
16+
from corsheaders.middleware import CorsMiddleware
17+
18+
F = TypeVar("F", bound=Callable[..., HttpResponseBase])
19+
20+
21+
def cors(func: Optional[F] = None, *, conf: Settings = _conf) -> F | Callable[[F], F]:
22+
if func is None:
23+
return cast(Callable[[F], F], functools.partial(cors, conf=conf))
24+
25+
assert callable(func)
26+
27+
if asyncio.iscoroutinefunction(func):
28+
29+
async def inner(
30+
_request: HttpRequest, *args: Any, **kwargs: Any
31+
) -> HttpResponseBase:
32+
async def get_response(request: HttpRequest) -> HttpResponseBase:
33+
return await func(request, *args, **kwargs)
34+
35+
return await CorsMiddleware(get_response, conf=conf)(_request)
36+
37+
else:
38+
39+
def inner(_request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
40+
def get_response(request: HttpRequest) -> HttpResponseBase:
41+
return func(request, *args, **kwargs)
42+
43+
return CorsMiddleware(get_response, conf=conf)(_request)
44+
45+
wrapper = functools.wraps(func)(inner)
46+
wrapper._skip_cors_middleware = True # type: ignore [attr-defined]
47+
return cast(F, wrapper)

src/corsheaders/middleware.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import re
5+
from typing import Any
56
from typing import Awaitable
67
from typing import Callable
78
from urllib.parse import SplitResult
@@ -54,22 +55,40 @@ def __call__(
5455
) -> HttpResponseBase | Awaitable[HttpResponseBase]:
5556
if self._is_coroutine:
5657
return self.__acall__(request)
57-
response: HttpResponseBase | None = self.check_preflight(request)
58-
if response is None:
59-
result = self.get_response(request)
60-
assert isinstance(result, HttpResponseBase)
61-
response = result
62-
self.add_response_headers(request, response)
63-
return response
58+
result = self.get_response(request)
59+
assert isinstance(result, HttpResponseBase)
60+
response = result
61+
if getattr(request, "_cors_preflight_done", False):
62+
return response
63+
else:
64+
# Request wasn't processed (e.g. because of a 404)
65+
return self.add_response_headers(
66+
request, self.check_preflight(request) or response
67+
)
6468

6569
async def __acall__(self, request: HttpRequest) -> HttpResponseBase:
66-
response = self.check_preflight(request)
67-
if response is None:
68-
result = self.get_response(request)
69-
assert not isinstance(result, HttpResponseBase)
70-
response = await result
71-
self.add_response_headers(request, response)
72-
return response
70+
result = self.get_response(request)
71+
assert not isinstance(result, HttpResponseBase)
72+
response = await result
73+
if getattr(response, "_cors_processing_done", False):
74+
return response
75+
else:
76+
# View wasn't processed (e.g. because of a 404)
77+
return self.add_response_headers(
78+
request, self.check_preflight(request) or response
79+
)
80+
81+
def process_view(
82+
self,
83+
request: HttpRequest,
84+
callback: Callable[[HttpRequest], HttpResponseBase],
85+
callback_args: Any,
86+
callback_kwargs: Any,
87+
) -> HttpResponseBase | None:
88+
if getattr(callback, "_skip_cors_middleware", False):
89+
# View is decorated and will add CORS headers itself
90+
return None
91+
return self.check_preflight(request)
7392

7493
def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None:
7594
"""
@@ -90,6 +109,7 @@ def add_response_headers(
90109
"""
91110
Add the respective CORS headers
92111
"""
112+
response._cors_processing_done = True
93113
enabled = getattr(request, "_cors_enabled", None)
94114
if enabled is None:
95115
enabled = self.is_enabled(request)

tests/test_decorators.py

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from __future__ import annotations
2+
3+
from django.test import TestCase
4+
from django.test.utils import modify_settings
5+
from django.test.utils import override_settings
6+
7+
from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN
8+
9+
10+
@modify_settings(
11+
MIDDLEWARE={
12+
"remove": "corsheaders.middleware.CorsMiddleware",
13+
}
14+
)
15+
@override_settings(CORS_ALLOWED_ORIGINS=["https://example.com"])
16+
class CorsDecoratorsTestCase(TestCase):
17+
def test_get_no_origin(self):
18+
resp = self.client.get("/decorated/hello/")
19+
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
20+
assert resp.content == b"Decorated: hello"
21+
22+
def test_get_not_in_allowed_origins(self):
23+
resp = self.client.get(
24+
"/decorated/hello/",
25+
HTTP_ORIGIN="https://example.net",
26+
)
27+
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
28+
assert resp.content == b"Decorated: hello"
29+
30+
def test_get_in_allowed_origins_preflight(self):
31+
resp = self.client.options(
32+
"/decorated/hello/",
33+
HTTP_ORIGIN="https://example.com",
34+
HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET",
35+
)
36+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
37+
assert resp.content == b""
38+
39+
def test_get_in_allowed_origins(self):
40+
resp = self.client.get(
41+
"/decorated/hello/",
42+
HTTP_ORIGIN="https://example.com",
43+
)
44+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
45+
assert resp.content == b"Decorated: hello"
46+
47+
async def test_async_get_not_in_allowed_origins(self):
48+
resp = await self.async_client.get(
49+
"/async-decorated/hello/",
50+
origin="https://example.org",
51+
)
52+
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
53+
assert resp.content == b"Async Decorated: hello"
54+
55+
async def test_async_get_in_allowed_origins_preflight(self):
56+
resp = await self.async_client.options(
57+
"/async-decorated/hello/",
58+
origin="https://example.com",
59+
access_control_request_method="GET",
60+
)
61+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
62+
assert resp.content == b""
63+
64+
async def test_async_get_in_allowed_origins(self):
65+
resp = await self.async_client.get(
66+
"/async-decorated/hello/",
67+
origin="https://example.com",
68+
)
69+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
70+
assert resp.content == b"Async Decorated: hello"
71+
72+
73+
class CorsDecoratorsWithConfTestCase(TestCase):
74+
def test_get_no_origin(self):
75+
resp = self.client.get("/decorated-with-conf/hello/")
76+
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
77+
assert resp.content == b"Decorated (with conf): hello"
78+
79+
def test_get_not_in_allowed_origins(self):
80+
resp = self.client.get(
81+
"/decorated-with-conf/hello/", HTTP_ORIGIN="https://example.net"
82+
)
83+
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
84+
assert resp.content == b"Decorated (with conf): hello"
85+
86+
def test_get_in_allowed_origins_preflight(self):
87+
resp = self.client.options(
88+
"/decorated-with-conf/hello/",
89+
HTTP_ORIGIN="https://example.com",
90+
HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET",
91+
)
92+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
93+
assert resp.content == b"Decorated (with conf): hello"
94+
95+
def test_get_in_allowed_origins(self):
96+
resp = self.client.get(
97+
"/decorated-with-conf/hello/",
98+
HTTP_ORIGIN="https://example.com",
99+
)
100+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
101+
assert resp.content == b"Decorated (with conf): hello"
102+
103+
async def test_async_get_not_in_allowed_origins(self):
104+
resp = await self.async_client.get(
105+
"/async-decorated-with-conf/hello/",
106+
origin="https://example.org",
107+
)
108+
assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp
109+
assert resp.content == b"Async Decorated (with conf): hello"
110+
111+
async def test_async_get_in_allowed_origins_preflight(self):
112+
resp = await self.async_client.options(
113+
"/async-decorated-with-conf/hello/",
114+
origin="https://example.com",
115+
access_control_request_method="GET",
116+
)
117+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
118+
assert resp.content == b""
119+
120+
async def test_async_get_in_allowed_origins(self):
121+
resp = await self.async_client.get(
122+
"/async-decorated-with-conf/hello/",
123+
origin="https://example.com",
124+
)
125+
assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com"
126+
assert resp.content == b"Async Decorated (with conf): hello"

tests/urls.py

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
urlpatterns = [
88
path("", views.index),
99
path("async/", views.async_),
10+
path("decorated/<slug:slug>/", views.decorated),
11+
path("decorated-with-conf/<slug:slug>/", views.decorated_with_conf),
12+
path("async-decorated/<slug:slug>/", views.async_decorated),
13+
path("async-decorated-with-conf/<slug:slug>/", views.async_decorated_with_conf),
1014
path("unauthorized/", views.unauthorized),
1115
path("delete-enabled/", views.delete_enabled_attribute),
1216
]

tests/views.py

+23
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from django.http import HttpResponse
66
from django.views.decorators.http import require_GET
77

8+
from corsheaders.decorators import cors
9+
from corsheaders.conf import Settings
10+
811

912
@require_GET
1013
def index(request):
@@ -15,6 +18,26 @@ async def async_(request):
1518
return HttpResponse("Asynchronous")
1619

1720

21+
@cors
22+
def decorated(request, slug):
23+
return HttpResponse(f"Decorated: {slug}")
24+
25+
26+
@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"]))
27+
def decorated_with_conf(request, slug):
28+
return HttpResponse(f"Decorated (with conf): {slug}")
29+
30+
31+
@cors
32+
async def async_decorated(request, slug):
33+
return HttpResponse(f"Async Decorated: {slug}")
34+
35+
36+
@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"]))
37+
async def async_decorated_with_conf(request, slug):
38+
return HttpResponse(f"Async Decorated (with conf): {slug}")
39+
40+
1841
def unauthorized(request):
1942
return HttpResponse("Unauthorized", status=HTTPStatus.UNAUTHORIZED)
2043

0 commit comments

Comments
 (0)