Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions openfga_sdk/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import urllib3

from openfga_sdk.configuration import Configuration
from openfga_sdk.constants import USER_AGENT
from openfga_sdk.constants import (
TOKEN_EXPIRY_JITTER_IN_SEC,
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC,
USER_AGENT,
)
from openfga_sdk.credentials import Credentials
from openfga_sdk.exceptions import AuthenticationError
from openfga_sdk.telemetry.attributes import TelemetryAttributes
Expand All @@ -36,6 +40,8 @@ def __init__(self, credentials: Credentials, configuration=None):
self._credentials = credentials
self._access_token = None
self._access_expiry_time = None
self._access_token_expiry_buffer = 0
self._lock = asyncio.Lock()
self._telemetry = Telemetry()

if configuration is None:
Expand All @@ -45,13 +51,12 @@ def __init__(self, credentials: Credentials, configuration=None):

def _token_valid(self):
"""
Return whether token is valid
Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens)
"""
if self._access_token is None or self._access_expiry_time is None:
return False
if self._access_expiry_time < datetime.now():
return False
return True
remaining = (self._access_expiry_time - datetime.now()).total_seconds()
return remaining > self._access_token_expiry_buffer

async def _obtain_token(self, client):
"""
Expand All @@ -76,7 +81,9 @@ async def _obtain_token(self, client):
# Add scope parameter if scopes are configured
if configuration.scopes is not None:
if isinstance(configuration.scopes, list):
scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip())
scope_str = " ".join(
s.strip() for s in configuration.scopes if s and s.strip()
)
else:
scope_str = (
configuration.scopes.strip()
Expand Down Expand Up @@ -140,6 +147,10 @@ async def _obtain_token(self, client):
seconds=int(api_response.get("expires_in"))
)
self._access_token = api_response.get("access_token")
self._access_token_expiry_buffer = (
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC
+ random.random() * TOKEN_EXPIRY_JITTER_IN_SEC
)
Comment thread
SoulPancake marked this conversation as resolved.
Outdated
self._telemetry.metrics.credentialsRequest(
attributes={
TelemetryAttributes.fga_client_request_client_id: configuration.client_id
Expand All @@ -154,8 +165,8 @@ async def get_authentication_header(self, client):
"""
If configured, return the header for authentication
"""
# check to see token is valid
if not self._token_valid():
# In this case, the token is not valid, we need to get the refresh the token
await self._obtain_token(client)
async with self._lock:
if not self._token_valid():
await self._obtain_token(client)
return {"Authorization": f"Bearer {self._access_token}"}
30 changes: 21 additions & 9 deletions openfga_sdk/sync/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@
import math
import random
import sys
import threading
import time

from datetime import datetime, timedelta

import urllib3

from openfga_sdk.configuration import Configuration
from openfga_sdk.constants import USER_AGENT
from openfga_sdk.constants import (
TOKEN_EXPIRY_JITTER_IN_SEC,
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC,
USER_AGENT,
)
from openfga_sdk.credentials import Credentials
from openfga_sdk.exceptions import AuthenticationError
from openfga_sdk.telemetry.attributes import TelemetryAttributes
Expand All @@ -36,6 +41,8 @@ def __init__(self, credentials: Credentials, configuration=None):
self._credentials = credentials
self._access_token = None
self._access_expiry_time = None
self._access_token_expiry_buffer = 0
self._lock = threading.Lock()
self._telemetry = Telemetry()

if configuration is None:
Expand All @@ -45,13 +52,12 @@ def __init__(self, credentials: Credentials, configuration=None):

def _token_valid(self):
"""
Return whether token is valid
Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens)
"""
if self._access_token is None or self._access_expiry_time is None:
return False
if self._access_expiry_time < datetime.now():
return False
return True
remaining = (self._access_expiry_time - datetime.now()).total_seconds()
return remaining > self._access_token_expiry_buffer

def _obtain_token(self, client):
"""
Expand All @@ -76,7 +82,9 @@ def _obtain_token(self, client):
# Add scope parameter if scopes are configured
if configuration.scopes is not None:
if isinstance(configuration.scopes, list):
scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip())
scope_str = " ".join(
s.strip() for s in configuration.scopes if s and s.strip()
)
else:
scope_str = (
configuration.scopes.strip()
Expand Down Expand Up @@ -140,6 +148,10 @@ def _obtain_token(self, client):
seconds=int(api_response.get("expires_in"))
)
self._access_token = api_response.get("access_token")
self._access_token_expiry_buffer = (
Comment thread
SoulPancake marked this conversation as resolved.
Outdated
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC
+ random.random() * TOKEN_EXPIRY_JITTER_IN_SEC
)
self._telemetry.metrics.credentialsRequest(
attributes={
TelemetryAttributes.fga_client_request_client_id: configuration.client_id
Expand All @@ -154,8 +166,8 @@ def get_authentication_header(self, client):
"""
If configured, return the header for authentication
"""
# check to see token is valid
if not self._token_valid():
# In this case, the token is not valid, we need to get the refresh the token
self._obtain_token(client)
with self._lock:
if not self._token_valid():
self._obtain_token(client)
return {"Authorization": f"Bearer {self._access_token}"}
85 changes: 83 additions & 2 deletions test/oauth2_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from datetime import datetime, timedelta
from unittest import IsolatedAsyncioTestCase
from unittest.mock import patch
Expand All @@ -6,7 +8,7 @@

from openfga_sdk import rest
from openfga_sdk.configuration import Configuration
from openfga_sdk.constants import USER_AGENT
from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT
from openfga_sdk.credentials import CredentialConfiguration, Credentials
from openfga_sdk.exceptions import AuthenticationError
from openfga_sdk.oauth2 import OAuth2Client
Expand Down Expand Up @@ -34,7 +36,7 @@ async def test_get_authentication_valid_client_credentials(self):
"""
client = OAuth2Client(None)
client._access_token = "XYZ123"
client._access_expiry_time = datetime.now() + timedelta(seconds=60)
client._access_expiry_time = datetime.now() + timedelta(seconds=3600)
auth_header = await client.get_authentication_header(None)
self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"})

Expand Down Expand Up @@ -651,6 +653,85 @@ async def test_get_authentication_without_audience(self, mock_request):
)
await rest_client.close()

@patch.object(rest.RESTClientObject, "request")
@patch("openfga_sdk.oauth2.random")
async def test_get_authentication_refreshes_near_expiry_token(
self, mock_random, mock_request
):
"""
Token close to expiry (within buffer window) should trigger a proactive refresh
"""
mock_random.random.return_value = 0
short_lived_secs = max(1, TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - 1)

Comment thread
SoulPancake marked this conversation as resolved.
mock_request.side_effect = [
mock_response(
f'{{"expires_in": {short_lived_secs}, "access_token": "short-lived-token"}}',
200,
),
mock_response(
'{"expires_in": 3600, "access_token": "refreshed-token"}',
200,
),
]

credentials = Credentials(
method="client_credentials",
configuration=CredentialConfiguration(
client_id="myclientid",
client_secret="mysecret",
api_issuer="issuer.fga.example",
api_audience="myaudience",
),
)
rest_client = rest.RESTClientObject(Configuration())
client = OAuth2Client(credentials)

header1 = await client.get_authentication_header(rest_client)
header2 = await client.get_authentication_header(rest_client)

self.assertEqual(header1, {"Authorization": "Bearer short-lived-token"})
self.assertEqual(header2, {"Authorization": "Bearer refreshed-token"})
self.assertEqual(mock_request.call_count, 2)

await rest_client.close()

async def test_concurrent_requests_only_fetch_token_once(self):
"""
Multiple concurrent requests while the token is invalid should result in
only one token fetch — subsequent coroutines wait on the lock and reuse
the token obtained by the first.
"""
obtain_calls = []

credentials = Credentials(
method="client_credentials",
configuration=CredentialConfiguration(
client_id="myclientid",
client_secret="mysecret",
api_issuer="issuer.fga.example",
api_audience="myaudience",
),
)
oauth_client = OAuth2Client(credentials)

async def mock_obtain_token(client):
obtain_calls.append(1)
await asyncio.sleep(0) # yield so other coroutines reach the lock
oauth_client._access_token = "concurrent-token"
oauth_client._access_expiry_time = datetime.now() + timedelta(seconds=3600)
oauth_client._access_token_expiry_buffer = 300

with patch.object(oauth_client, "_obtain_token", side_effect=mock_obtain_token):
results = await asyncio.gather(
*[oauth_client.get_authentication_header(None) for _ in range(5)]
)

self.assertEqual(len(obtain_calls), 1)
self.assertTrue(
all(r == {"Authorization": "Bearer concurrent-token"} for r in results)
)

@patch.object(rest.RESTClientObject, "request")
async def test_get_authentication_with_scopes_no_audience(self, mock_request):
"""
Expand Down
94 changes: 91 additions & 3 deletions test/sync/oauth2_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import threading
import time

from datetime import datetime, timedelta
from unittest import IsolatedAsyncioTestCase
from unittest.mock import patch

import urllib3

from openfga_sdk.configuration import Configuration
from openfga_sdk.constants import USER_AGENT
from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT
from openfga_sdk.credentials import CredentialConfiguration, Credentials
from openfga_sdk.exceptions import AuthenticationError
from openfga_sdk.sync import rest
Expand Down Expand Up @@ -34,7 +37,7 @@ def test_get_authentication_valid_client_credentials(self):
"""
client = OAuth2Client(None)
client._access_token = "XYZ123"
client._access_expiry_time = datetime.now() + timedelta(seconds=60)
client._access_expiry_time = datetime.now() + timedelta(seconds=3600)
auth_header = client.get_authentication_header(None)
self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"})

Expand Down Expand Up @@ -427,6 +430,92 @@ def test_get_authentication_without_audience(self, mock_request):
)
rest_client.close()

@patch.object(rest.RESTClientObject, "request")
@patch("openfga_sdk.sync.oauth2.random")
def test_get_authentication_refreshes_near_expiry_token(
self, mock_random, mock_request
):
"""
Token close to expiry (within buffer window) should trigger a proactive refresh
"""
mock_random.random.return_value = 0
short_lived_secs = max(1, TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - 1)

Comment thread
SoulPancake marked this conversation as resolved.
mock_request.side_effect = [
mock_response(
f'{{"expires_in": {short_lived_secs}, "access_token": "short-lived-token"}}',
200,
),
mock_response(
'{"expires_in": 3600, "access_token": "refreshed-token"}',
200,
),
]

credentials = Credentials(
method="client_credentials",
configuration=CredentialConfiguration(
client_id="myclientid",
client_secret="mysecret",
api_issuer="issuer.fga.example",
api_audience="myaudience",
),
)
rest_client = rest.RESTClientObject(Configuration())
client = OAuth2Client(credentials)

header1 = client.get_authentication_header(rest_client)
header2 = client.get_authentication_header(rest_client)

self.assertEqual(header1, {"Authorization": "Bearer short-lived-token"})
self.assertEqual(header2, {"Authorization": "Bearer refreshed-token"})
self.assertEqual(mock_request.call_count, 2)

rest_client.close()

def test_concurrent_requests_only_fetch_token_once(self):
"""
Multiple concurrent threads while the token is invalid should result in
only one token fetch — subsequent threads wait on the lock and reuse
the token obtained by the first.
"""
obtain_calls = []

credentials = Credentials(
method="client_credentials",
configuration=CredentialConfiguration(
client_id="myclientid",
client_secret="mysecret",
api_issuer="issuer.fga.example",
api_audience="myaudience",
),
)
oauth_client = OAuth2Client(credentials)

def mock_obtain_token(client):
obtain_calls.append(1)
time.sleep(0.05) # hold the lock briefly so other threads queue up
oauth_client._access_token = "concurrent-token"
oauth_client._access_expiry_time = datetime.now() + timedelta(seconds=3600)
oauth_client._access_token_expiry_buffer = 300

results = []

def call():
results.append(oauth_client.get_authentication_header(None))

with patch.object(oauth_client, "_obtain_token", side_effect=mock_obtain_token):
threads = [threading.Thread(target=call) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()

self.assertEqual(len(obtain_calls), 1)
self.assertTrue(
all(r == {"Authorization": "Bearer concurrent-token"} for r in results)
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

@patch.object(rest.RESTClientObject, "request")
def test_get_authentication_with_scopes_no_audience(self, mock_request):
"""
Expand Down Expand Up @@ -477,4 +566,3 @@ def test_get_authentication_with_scopes_no_audience(self, mock_request):
},
)
rest_client.close()

Loading