-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path test_jwts.py
More file actions
262 lines (193 loc) · 8.4 KB
/
test_jwts.py
File metadata and controls
262 lines (193 loc) · 8.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
from typing import Any, Dict, Iterable
from unittest.mock import patch
import jwt
import pytest
from guardpost.jwks import JWKS, InMemoryKeysProvider, KeysProvider
from guardpost.jwks.caching import CachingKeysProvider
from guardpost.jwks.openid import AuthorityKeysProvider
from guardpost.jwks.urls import URLKeysProvider
from guardpost.jwts import InvalidAccessToken, JWTValidator
from .serverfixtures import * # noqa
from .serverfixtures import BASE_URL, get_file_path, get_test_jwks
@pytest.fixture(scope="session")
def default_keys_provider() -> KeysProvider:
    """Provide, once per test session, an in-memory keys provider for the test JWKS."""
    test_jwks = get_test_jwks()
    return InMemoryKeysProvider(test_jwks)
class MockedKeysProvider(KeysProvider):
    """Keys provider returning a predefined sequence of JWKS, one per call.

    Each call to ``get_keys`` consumes the next item of ``mocked``; this lets a
    test control exactly which keys each successive fetch observes.
    """

    def __init__(self, mocked: Iterable[JWKS]) -> None:
        self.mocked = iter(mocked)

    async def get_keys(self) -> JWKS:
        try:
            return next(self.mocked)
        except StopIteration:
            # A StopIteration escaping a coroutine is converted by Python into an
            # opaque "coroutine raised StopIteration" RuntimeError (PEP 479);
            # raise a RuntimeError that states what actually went wrong instead.
            raise RuntimeError(
                "MockedKeysProvider: no more JWKS to return; the code under test "
                "fetched keys more times than the test expected"
            ) from None
def get_access_token(
    kid: str, payload: Dict[str, Any], include_headers: bool = True, fake_kid: str = ""
):
    """Sign ``payload`` as an RS256 JWT with the test private key ``{kid}.pem``.

    :param kid: name (without ``.pem``) of the private key file used to sign;
        also used as the ``kid`` header value unless ``fake_kid`` is given.
    :param payload: claims to encode in the token.
    :param include_headers: when False, no ``kid`` header is written at all
        (``fake_kid`` is then ignored).
    :param fake_kid: if non-empty, advertised in the ``kid`` header instead of
        ``kid``; used to forge tokens claiming a key they were not signed with.
    :return: whatever ``jwt.encode`` returns for the signed token.
    """
    # Be explicit about the encoding rather than depending on the platform locale.
    with open(get_file_path(f"{kid}.pem"), "r", encoding="utf-8") as key_file:
        private_key = key_file.read()
    headers = {"kid": fake_kid or kid} if include_headers else None
    return jwt.encode(payload, private_key, algorithm="RS256", headers=headers)
async def _valid_token_scenario(
    kid: str, validator: JWTValidator, include_headers: bool = True
):
    """Sign a token with test key ``kid`` and assert the validator returns its claims."""
    expected_claims = {"aud": "a", "iss": "b"}
    token = get_access_token(kid, expected_claims, include_headers=include_headers)
    assert await validator.validate_jwt(token) == expected_claims
async def _valid_tokens_scenario(validator: JWTValidator, include_headers: bool = True):
    """Run the single-token scenario for every test key, "0" through "4"."""
    for kid in map(str, range(5)):
        await _valid_token_scenario(kid, validator, include_headers)
def test_jwt_validator_raises_for_missing_key_source():
    """Building a validator with no source of keys raises TypeError."""
    audiences, issuers = ["a"], ["b"]
    with pytest.raises(TypeError):
        JWTValidator(valid_audiences=audiences, valid_issuers=issuers)
@pytest.mark.asyncio
async def test_jwt_validator_can_validate_valid_access_tokens(default_keys_provider):
    """Tokens signed by each of the known test keys validate successfully."""
    await _valid_tokens_scenario(
        JWTValidator(
            valid_audiences=["a"],
            valid_issuers=["b"],
            keys_provider=default_keys_provider,
        )
    )
class _CountingKeysProvider(KeysProvider):
    """Delegate to another keys provider, counting how often keys are fetched."""

    def __init__(self, inner: KeysProvider) -> None:
        self.inner = inner
        self.calls = 0

    async def get_keys(self) -> JWKS:
        self.calls += 1
        return await self.inner.get_keys()


@pytest.mark.asyncio
async def test_jwt_validator_cache_expiration(default_keys_provider):
    """Once cache_time has elapsed, keys are fetched again from the provider."""
    counting_provider = _CountingKeysProvider(default_keys_provider)
    with patch("guardpost.jwks.caching.time") as mock_time:
        mock_time.time.return_value = 0
        validator = JWTValidator(
            valid_audiences=["a"],
            valid_issuers=["b"],
            keys_provider=counting_provider,
            cache_time=10,
        )
        await _valid_tokens_scenario(validator)
        calls_before_expiry = counting_provider.calls
        assert calls_before_expiry >= 1

        # Simulate cache_time elapsed — keys must be re-fetched
        mock_time.time.return_value = 11
        await _valid_tokens_scenario(validator)
        # Without this check the test would also pass if the cache never expired.
        assert counting_provider.calls > calls_before_expiry
@pytest.mark.asyncio
async def test_jwt_validator_fetches_tokens_again_for_unknown_kid():
    """An unknown kid triggers a key re-fetch, but only after refresh_time."""
    all_keys = get_test_jwks().keys
    # first fetch serves keys 0-1, the second fetch serves the remaining keys
    keys_provider = MockedKeysProvider([JWKS(all_keys[0:2]), JWKS(all_keys[2:])])

    with patch("guardpost.jwks.caching.time") as mock_time:
        mock_time.time.return_value = 0
        validator = JWTValidator(
            valid_audiences=["a"],
            valid_issuers=["b"],
            keys_provider=keys_provider,
            cache_time=10,
            refresh_time=30,
        )
        for kid in ("0", "1"):
            await _valid_token_scenario(kid, validator)

        # t=1: kid "2" is unknown and refresh_time (30s) has not elapsed yet,
        # so validation is expected to fail
        mock_time.time.return_value = 1
        with pytest.raises(InvalidAccessToken):
            await _valid_token_scenario("2", validator)

        # t=31: refresh_time elapsed — the provider should now serve the new keys
        mock_time.time.return_value = 31
        for kid in ("2", "3", "4"):
            await _valid_token_scenario(kid, validator)
@pytest.mark.asyncio
async def test_jwt_validator_blocks_forged_access_tokens(default_keys_provider):
    """A token signed with an unknown key but claiming a known kid is rejected."""
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_provider=default_keys_provider,
    )
    # signed with key "x" while advertising kid "1"
    forged_token = get_access_token("x", {"aud": "a", "iss": "b"}, fake_kid="1")
    with pytest.raises(InvalidAccessToken):
        await validator.validate_jwt(forged_token)
@pytest.mark.asyncio
async def test_jwt_validator_blocks_forged_access_tokens_no_kid(default_keys_provider):
    """With require_kid=False, a kid-less token signed by an unknown key is rejected."""
    claims = {"aud": "a", "iss": "b"}
    # include_headers=False means no kid header is written (fake_kid is unused)
    forged_token = get_access_token("x", claims, fake_kid="1", include_headers=False)
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_provider=default_keys_provider,
        require_kid=False,
    )
    with pytest.raises(InvalidAccessToken):
        await validator.validate_jwt(forged_token)
@pytest.mark.asyncio
async def test_jwt_validator_blocks_invalid_kid(default_keys_provider):
    """A token whose kid ("x") is not among the provider's keys is rejected."""
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_provider=default_keys_provider,
    )
    unknown_kid_token = get_access_token("x", {"aud": "a", "iss": "b"})
    with pytest.raises(InvalidAccessToken):
        await validator.validate_jwt(unknown_kid_token)
@pytest.mark.asyncio
async def test_jwt_validator_can_validate_access_tokens_from_well_known_oidc_conf():
    """An authority yields a caching wrapper around an AuthorityKeysProvider."""
    authority = BASE_URL + "/"
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        authority=authority,
    )

    caching_provider = validator._keys_provider
    assert isinstance(caching_provider, CachingKeysProvider)
    inner_provider = caching_provider.keys_provider
    assert isinstance(inner_provider, AuthorityKeysProvider)
    assert inner_provider.authority == authority

    await _valid_tokens_scenario(validator)
@pytest.mark.asyncio
async def test_jwt_validator_can_validate_access_tokens_from_url():
    """A keys_url yields a caching wrapper around a URLKeysProvider."""
    jwks_url = BASE_URL + "/.well-known/jwks.json"
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_url=jwks_url,
    )

    caching_provider = validator._keys_provider
    assert isinstance(caching_provider, CachingKeysProvider)
    inner_provider = caching_provider.keys_provider
    assert isinstance(inner_provider, URLKeysProvider)
    assert inner_provider.url == jwks_url

    await _valid_tokens_scenario(validator)
@pytest.mark.asyncio
async def test_jwt_validator_raises_for_missing_key_id(default_keys_provider):
    """Unless require_kid=False is set, a token without a kid header is rejected."""
    token_without_kid = get_access_token(
        "0", {"aud": "a", "iss": "b"}, include_headers=False
    )
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_provider=default_keys_provider,
    )
    with pytest.raises(InvalidAccessToken):
        await validator.validate_jwt(token_without_kid)
@pytest.mark.asyncio
async def test_jwt_validator_supports_missing_key_id_by_configuration(
    default_keys_provider,
):
    """With require_kid=False, valid tokens without a kid header are accepted."""
    lenient_validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_provider=default_keys_provider,
        require_kid=False,
    )
    await _valid_tokens_scenario(lenient_validator, include_headers=False)
@pytest.mark.asyncio
async def test_jwt_validator_raises_for_invalid_issuer(default_keys_provider):
    """A correctly signed token from an unexpected issuer is rejected."""
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_provider=default_keys_provider,
    )
    wrong_issuer_token = get_access_token("0", {"aud": "a", "iss": "NO"})
    with pytest.raises(InvalidAccessToken):
        await validator.validate_jwt(wrong_issuer_token)
@pytest.mark.asyncio
async def test_jwt_validator_raises_for_invalid_audience(default_keys_provider):
    """A correctly signed token for an unexpected audience is rejected."""
    validator = JWTValidator(
        valid_audiences=["a"],
        valid_issuers=["b"],
        keys_provider=default_keys_provider,
    )
    wrong_audience_token = get_access_token("0", {"aud": "NO", "iss": "b"})
    with pytest.raises(InvalidAccessToken):
        await validator.validate_jwt(wrong_audience_token)
def test_authority_keys_provider_raises_for_missing_parameter():
    """AuthorityKeysProvider rejects both a None and an empty authority."""
    for missing_authority in (None, ""):
        with pytest.raises(TypeError):
            AuthorityKeysProvider(missing_authority)  # type: ignore
def test_url_keys_provider_raises_for_missing_parameter():
    """URLKeysProvider rejects both a None and an empty URL."""
    for missing_url in (None, ""):
        with pytest.raises(TypeError):
            URLKeysProvider(missing_url)  # type: ignore
def test_caching_keys_provider_raises_for_missing_parameter():
    """CachingKeysProvider refuses to wrap a None keys provider."""
    missing_provider = None
    with pytest.raises(TypeError):
        CachingKeysProvider(missing_provider, 1)  # type: ignore