Skip to content

Commit 15cde20

Browse files
authored
fix: create new client if token decoding fails (#1495)
fixes #1401
1 parent b9fd05f commit 15cde20

4 files changed

Lines changed: 122 additions & 65 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ argon2 = "0.5.0"
6262
base64 = "0.22.0"
6363
cookie = "0.18.1"
6464
hex = "0.4"
65-
openid = { version = "0.15.0", default-features = false, features = ["rustls"] }
65+
openid = { version = "0.18.3", default-features = false, features = ["rustls"] }
6666
rustls = "0.22.4"
6767
rustls-pemfile = "2.1.2"
6868
sha2 = "0.10.8"

src/handlers/http/middleware.rs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,16 @@ use actix_web::{
2424
dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
2525
error::{ErrorBadRequest, ErrorForbidden, ErrorUnauthorized},
2626
http::header::{self, HeaderName},
27-
web::Data,
2827
};
2928
use chrono::{Duration, Utc};
3029
use futures_util::future::LocalBoxFuture;
3130

3231
use crate::{
3332
handlers::{
3433
AUTHORIZATION_KEY, KINESIS_COMMON_ATTRIBUTES_KEY, LOG_SOURCE_KEY, LOG_SOURCE_KINESIS,
35-
STREAM_NAME_HEADER_KEY, http::rbac::RBACError,
34+
STREAM_NAME_HEADER_KEY,
35+
http::{modal::OIDC_CLIENT, rbac::RBACError},
3636
},
37-
oidc::DiscoveredClient,
3837
option::Mode,
3938
parseable::PARSEABLE,
4039
rbac::{
@@ -145,7 +144,7 @@ where
145144
when request is made from Kinesis Firehose.
146145
For requests made from other clients, no change.
147146
148-
## Section start */
147+
## Section start */
149148
if let Some(kinesis_common_attributes) =
150149
req.request().headers().get(KINESIS_COMMON_ATTRIBUTES_KEY)
151150
{
@@ -183,12 +182,13 @@ where
183182

184183
// if session is expired, refresh token
185184
if sessions().is_session_expired(&key) {
186-
let oidc_client = match http_req.app_data::<Data<Option<DiscoveredClient>>>() {
187-
Some(client) => {
188-
let c = client.clone().into_inner();
189-
c.as_ref().clone()
190-
}
191-
None => None,
185+
let oidc_client = if let Some(client) = OIDC_CLIENT.get()
186+
&& let Some(client) = client
187+
{
188+
let guard = client.read().await;
189+
Some(guard.client().clone())
190+
} else {
191+
None
192192
};
193193

194194
if let Some(client) = oidc_client
@@ -208,13 +208,19 @@ where
208208
};
209209

210210
if let Some(oauth_data) = bearer_to_refresh {
211-
let Ok(refreshed_token) = client
211+
let refreshed_token = match client
212212
.refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str()))
213213
.await
214-
else {
215-
return Err(ErrorUnauthorized(
216-
"Your session has expired or is no longer valid. Please re-authenticate to access this resource.",
217-
));
214+
{
215+
Ok(bearer) => bearer,
216+
Err(e) => {
217+
tracing::error!("client refresh_token call failed- {e}");
218+
// remove user session
219+
Users.remove_session(&key);
220+
return Err(ErrorUnauthorized(
221+
"Your session has expired or is no longer valid. Please re-authenticate to access this resource.",
222+
));
223+
}
218224
};
219225

220226
let expires_in =

src/handlers/http/modal/mod.rs

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,20 @@
1818

1919
use std::{fmt, path::Path, sync::Arc};
2020

21-
use actix_web::{
22-
App, HttpServer,
23-
middleware::from_fn,
24-
web::{self, ServiceConfig},
25-
};
21+
use actix_web::{App, HttpServer, middleware::from_fn, web::ServiceConfig};
2622
use actix_web_prometheus::PrometheusMetrics;
2723
use anyhow::Context;
2824
use async_trait::async_trait;
2925
use base64::{Engine, prelude::BASE64_STANDARD};
3026
use bytes::Bytes;
3127
use futures::future;
28+
use once_cell::sync::OnceCell;
3229
use openid::Discovered;
3330
use relative_path::RelativePathBuf;
3431
use serde::{Deserialize, Serialize};
3532
use serde_json::{Map, Value};
3633
use ssl_acceptor::get_ssl_acceptor;
37-
use tokio::sync::oneshot;
34+
use tokio::sync::{RwLock, oneshot};
3835
use tracing::{error, info, warn};
3936

4037
use crate::{
@@ -43,7 +40,7 @@ use crate::{
4340
correlation::CORRELATIONS,
4441
hottier::{HotTierManager, StreamHotTier},
4542
metastore::metastore_traits::MetastoreObject,
46-
oidc::Claims,
43+
oidc::{Claims, DiscoveredClient},
4744
option::Mode,
4845
parseable::PARSEABLE,
4946
storage::{ObjectStorageProvider, PARSEABLE_ROOT_DIRECTORY},
@@ -63,6 +60,27 @@ pub mod utils;
6360

6461
pub type OpenIdClient = Arc<openid::Client<Discovered, Claims>>;
6562

63+
pub static OIDC_CLIENT: OnceCell<Option<Arc<RwLock<GlobalClient>>>> = OnceCell::new();
64+
65+
#[derive(Debug)]
66+
pub struct GlobalClient {
67+
client: DiscoveredClient,
68+
}
69+
70+
impl GlobalClient {
71+
pub fn set(&mut self, client: DiscoveredClient) {
72+
self.client = client;
73+
}
74+
75+
pub fn client(&self) -> &DiscoveredClient {
76+
&self.client
77+
}
78+
79+
pub fn new(client: DiscoveredClient) -> Self {
80+
Self { client }
81+
}
82+
}
83+
6684
// to be decided on what the Default version should be
6785
pub const DEFAULT_VERSION: &str = "v4";
6886

@@ -95,16 +113,14 @@ pub trait ParseableServer {
95113
where
96114
Self: Sized,
97115
{
98-
let oidc_client = match oidc_client {
99-
Some(config) => {
100-
let client = config
101-
.connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code"))
102-
.await?;
103-
Some(client)
104-
}
105-
106-
None => None,
107-
};
116+
if let Some(config) = oidc_client {
117+
let client = config
118+
.connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code"))
119+
.await?;
120+
OIDC_CLIENT.get_or_init(|| Some(Arc::new(RwLock::new(GlobalClient::new(client)))));
121+
} else {
122+
OIDC_CLIENT.get_or_init(|| None);
123+
}
108124

109125
// get the ssl stuff
110126
let ssl = get_ssl_acceptor(
@@ -120,7 +136,6 @@ pub trait ParseableServer {
120136
// fn that creates the app
121137
let create_app_fn = move || {
122138
App::new()
123-
.app_data(web::Data::new(oidc_client.clone()))
124139
.wrap(prometheus.clone())
125140
.configure(|config| Self::configure_routes(config))
126141
.wrap(from_fn(health_check::check_shutdown_middleware))

src/handlers/http/oidc.rs

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,31 @@
1616
*
1717
*/
1818

19-
use std::collections::HashSet;
19+
use std::{collections::HashSet, sync::Arc};
2020

2121
use actix_web::{
2222
HttpRequest, HttpResponse,
2323
cookie::{Cookie, SameSite, time},
2424
http::header::{self, ContentType},
25-
web::{self, Data},
25+
web,
2626
};
2727
use chrono::{Duration, TimeDelta};
2828
use http::StatusCode;
2929
use openid::{Bearer, Options, Token, Userinfo};
3030
use regex::Regex;
3131
use serde::Deserialize;
32+
use tokio::sync::RwLock;
3233
use ulid::Ulid;
3334
use url::Url;
3435

3536
use crate::{
36-
handlers::{COOKIE_AGE_DAYS, SESSION_COOKIE_NAME, USER_COOKIE_NAME, USER_ID_COOKIE_NAME},
37+
handlers::{
38+
COOKIE_AGE_DAYS, SESSION_COOKIE_NAME, USER_COOKIE_NAME, USER_ID_COOKIE_NAME,
39+
http::{
40+
API_BASE_PATH, API_VERSION,
41+
modal::{GlobalClient, OIDC_CLIENT},
42+
},
43+
},
3744
oidc::{Claims, DiscoveredClient},
3845
parseable::PARSEABLE,
3946
rbac::{
@@ -73,20 +80,18 @@ pub async fn login(
7380
));
7481
}
7582

76-
let oidc_client = match req.app_data::<Data<Option<DiscoveredClient>>>() {
77-
Some(client) => {
78-
let c = client.clone().into_inner();
79-
c.as_ref().clone()
80-
}
83+
let oidc_client = match OIDC_CLIENT.get() {
84+
Some(c) => c.as_ref().cloned(),
8185
None => None,
8286
};
87+
8388
let session_key = extract_session_key_from_req(&req).ok();
8489
let (session_key, oidc_client) = match (session_key, oidc_client) {
8590
(None, None) => return Ok(redirect_no_oauth_setup(query.redirect.clone())),
8691
(None, Some(client)) => {
8792
return Ok(redirect_to_oidc(
8893
query,
89-
&client,
94+
client.read().await.client(),
9095
PARSEABLE.options.scope.to_string().as_str(),
9196
));
9297
}
@@ -131,7 +136,7 @@ pub async fn login(
131136
if let Some(oidc_client) = oidc_client {
132137
redirect_to_oidc(
133138
query,
134-
&oidc_client,
139+
oidc_client.read().await.client(),
135140
PARSEABLE.options.scope.to_string().as_str(),
136141
)
137142
} else {
@@ -144,13 +149,11 @@ pub async fn login(
144149
}
145150

146151
pub async fn logout(req: HttpRequest, query: web::Query<RedirectAfterLogin>) -> HttpResponse {
147-
let oidc_client = match req.app_data::<Data<Option<DiscoveredClient>>>() {
148-
Some(client) => {
149-
let c = client.clone().into_inner();
150-
c.as_ref().clone()
151-
}
152+
let oidc_client = match OIDC_CLIENT.get() {
153+
Some(c) => Some(c.as_ref().unwrap().read().await.client().clone()),
152154
None => None,
153155
};
156+
154157
let Some(session) = extract_session_key_from_req(&req).ok() else {
155158
return redirect_to_client(query.redirect.as_str(), None);
156159
};
@@ -170,16 +173,21 @@ pub async fn logout(req: HttpRequest, query: web::Query<RedirectAfterLogin>) ->
170173

171174
/// Handler for code callback
172175
/// User should be redirected to page they were trying to access with cookie
173-
pub async fn reply_login(
174-
req: HttpRequest,
175-
login_query: web::Query<Login>,
176-
) -> Result<HttpResponse, OIDCError> {
177-
let oidc_client = req.app_data::<Data<Option<DiscoveredClient>>>().unwrap();
178-
let oidc_client = oidc_client.clone().into_inner().as_ref().clone().unwrap();
179-
let Ok((mut claims, user_info, bearer)): Result<(Claims, Userinfo, Bearer), anyhow::Error> =
180-
request_token(oidc_client, &login_query).await
181-
else {
182-
return Ok(HttpResponse::Unauthorized().finish());
176+
pub async fn reply_login(login_query: web::Query<Login>) -> Result<HttpResponse, OIDCError> {
177+
let oidc_client = if let Some(oidc_client) = OIDC_CLIENT.get()
178+
&& let Some(oidc_client) = oidc_client
179+
{
180+
oidc_client
181+
} else {
182+
return Err(OIDCError::Unauthorized);
183+
};
184+
185+
let (mut claims, user_info, bearer) = match request_token(oidc_client, &login_query).await {
186+
Ok(v) => v,
187+
Err(e) => {
188+
tracing::error!("reply_login call failed- {e}");
189+
return Ok(HttpResponse::Unauthorized().finish());
190+
}
183191
};
184192
let username = user_info
185193
.name
@@ -351,6 +359,7 @@ pub fn redirect_to_client(
351359
response.cookie(cookie);
352360
}
353361
response.insert_header((header::CACHE_CONTROL, "no-store"));
362+
354363
response.finish()
355364
}
356365

@@ -387,19 +396,46 @@ pub fn cookie_userid(user_id: &str) -> Cookie<'static> {
387396
}
388397

389398
pub async fn request_token(
390-
oidc_client: DiscoveredClient,
399+
oidc_client: &Arc<RwLock<GlobalClient>>,
391400
login_query: &Login,
392401
) -> anyhow::Result<(Claims, Userinfo, Bearer)> {
393-
let mut token: Token<Claims> = oidc_client.request_token(&login_query.code).await?.into();
394-
let Some(id_token) = token.id_token.as_mut() else {
402+
let old_client = oidc_client.read().await.client().clone();
403+
let mut token: Token<Claims> = old_client.request_token(&login_query.code).await?.into();
404+
405+
let id_token = if let Some(token) = token.id_token.as_mut() {
406+
token
407+
} else {
395408
return Err(anyhow::anyhow!("No id_token provided"));
396409
};
397410

398-
oidc_client.decode_token(id_token)?;
399-
oidc_client.validate_token(id_token, None, None)?;
411+
if let Err(e) = old_client.decode_token(id_token) {
412+
tracing::error!("error while decoding the id_token- {e}");
413+
let new_client = PARSEABLE
414+
.options
415+
.openid()
416+
.unwrap()
417+
.connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code"))
418+
.await?;
419+
420+
// Reuse the already-obtained token, just decode with new client's JWKS
421+
new_client.decode_token(id_token)?;
422+
new_client.validate_token(id_token, None, None)?;
423+
let claims = id_token.payload().expect("payload is decoded").clone();
424+
425+
let userinfo = new_client.request_userinfo(&token).await?;
426+
let bearer = token.bearer;
427+
428+
// replace old client with new one
429+
drop(old_client);
430+
431+
oidc_client.write().await.set(new_client);
432+
return Ok((claims, userinfo, bearer));
433+
}
434+
435+
old_client.validate_token(id_token, None, None)?;
400436
let claims = id_token.payload().expect("payload is decoded").clone();
401437

402-
let userinfo = oidc_client.request_userinfo(&token).await?;
438+
let userinfo = old_client.request_userinfo(&token).await?;
403439
let bearer = token.bearer;
404440
Ok((claims, userinfo, bearer))
405441
}

0 commit comments

Comments
 (0)