Skip to content

feat: add asyncio.Lock to prevent concurrent refreshes #234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions fastapi_azure_auth/openid_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from asyncio import Lock
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Optional

Expand All @@ -11,6 +12,8 @@

log = logging.getLogger('fastapi_azure_auth')

refresh_lock: Lock = Lock()


class OpenIdConfig:
def __init__(
Expand All @@ -35,24 +38,25 @@ async def load_config(self) -> None:
"""
Loads config from the Intility openid-config endpoint if it's over 24 hours old (or don't exist)
"""
refresh_time = datetime.now() - timedelta(hours=24)
if not self._config_timestamp or self._config_timestamp < refresh_time:
try:
log.debug('Loading Azure Entra ID OpenID configuration.')
await self._load_openid_config()
self._config_timestamp = datetime.now()
except Exception as error:
log.exception('Unable to fetch OpenID configuration from Azure Entra ID. Error: %s', error)
# We can't fetch an up to date openid-config, so authentication will not work.
if self._config_timestamp:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Connection to Azure Entra ID is down. Unable to fetch provider configuration',
headers={'WWW-Authenticate': 'Bearer'},
) from error

else:
raise RuntimeError(f'Unable to fetch provider information. {error}') from error
async with refresh_lock:
refresh_time = datetime.now() - timedelta(hours=24)
if not self._config_timestamp or self._config_timestamp < refresh_time:
try:
log.debug('Loading Azure Entra ID OpenID configuration.')
await self._load_openid_config()
self._config_timestamp = datetime.now()
except Exception as error:
log.exception('Unable to fetch OpenID configuration from Azure Entra ID. Error: %s', error)
# We can't fetch an up to date openid-config, so authentication will not work.
if self._config_timestamp:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Connection to Azure Entra ID is down. Unable to fetch provider configuration',
headers={'WWW-Authenticate': 'Bearer'},
) from error

else:
raise RuntimeError(f'Unable to fetch provider information. {error}') from error

log.info('fastapi-azure-auth loaded settings from Azure Entra ID.')
log.info('authorization endpoint: %s', self.authorization_endpoint)
Expand Down
30 changes: 29 additions & 1 deletion tests/test_provider_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
from datetime import datetime, timedelta

import httpx
import pytest
import respx
from asgi_lifespan import LifespanManager
from demo_project.api.dependencies import azure_scheme
from demo_project.main import app
from httpx import AsyncClient
from tests.utils import build_access_token, build_openid_keys, openid_configuration
from tests.utils import build_access_token, build_openid_keys, keys_url, openid_config_url, openid_configuration

from fastapi_azure_auth.openid_config import OpenIdConfig

Expand Down Expand Up @@ -64,3 +67,28 @@ async def test_custom_config_id(respx_mock):
)
await openid_config.load_config()
assert len(openid_config.signing_keys) == 2


async def test_concurrent_refresh_requests():
"""Test that concurrent refreshes are handled correctly"""
with respx.mock(assert_all_called=True) as mock:

async def slow_config_response(*args, **kwargs):
await asyncio.sleep(0.2)
return httpx.Response(200, json=openid_configuration())

async def slow_keys_response(*args, **kwargs):
await asyncio.sleep(0.2)
return httpx.Response(200, json=build_openid_keys())

config_route = mock.get(openid_config_url()).mock(side_effect=slow_config_response)
keys_route = mock.get(keys_url()).mock(side_effect=slow_keys_response)

azure_scheme.openid_config._config_timestamp = None

tasks = [azure_scheme.openid_config.load_config() for _ in range(5)]
await asyncio.gather(*tasks)

assert len(config_route.calls) == 1, "Config endpoint called multiple times"
assert len(keys_route.calls) == 1, "Keys endpoint called multiple times"
assert len(azure_scheme.openid_config.signing_keys) == 2
Loading