Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/proxy/handlers/chat_completions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
proxy::{
AppState,
hooks::{HOOK_FILTER_ALL, HOOK_MANAGER, HookContext, ResponseData, TokenUsage},
middlewares::RequestModel,
hooks2::{RequestContext, authorization},
},
utils::future::maybe_timeout,
};
Expand All @@ -32,9 +32,11 @@ pub async fn chat_completions(
State(_state): State<AppState>,
Extension(span_ctx): Extension<SpanContext>,
mut hook_ctx: HookContext,
mut request_ctx: RequestContext,
Json(mut request_data): Json<ChatCompletionRequest>,
) -> Result<Response, ChatCompletionError> {
hook_ctx.insert(RequestModel(request_data.model));
authorization::check(&mut request_ctx, request_data.model.clone()).await?;

let mut request = Request::new(Body::empty()); //TODO
HOOK_MANAGER
.pre_call(&mut hook_ctx, &mut request, HOOK_FILTER_ALL)
Expand Down
8 changes: 7 additions & 1 deletion src/proxy/handlers/chat_completions/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::time::error::Elapsed;

use crate::{providers::ProviderError, proxy::hooks::HookError};
use crate::{
providers::ProviderError,
proxy::{hooks::HookError, hooks2::authorization::AuthorizationError},
};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
Expand Down Expand Up @@ -143,6 +146,8 @@ pub struct ChatCompletionChunk {

#[derive(Debug, Error)]
pub enum ChatCompletionError {
#[error("Authorization error: {0}")]
AuthorizationError(#[from] AuthorizationError),
#[error("Provider error: {0}")]
ProviderError(#[from] ProviderError),
#[error("Request timed out")]
Expand All @@ -154,6 +159,7 @@ pub enum ChatCompletionError {
impl IntoResponse for ChatCompletionError {
fn into_response(self) -> axum::response::Response {
match self {
ChatCompletionError::AuthorizationError(err) => err.into_response(),
ChatCompletionError::ProviderError(err) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({
Expand Down
4 changes: 4 additions & 0 deletions src/proxy/handlers/embeddings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::{
proxy::{
AppState,
hooks::{HOOK_FILTER_ALL, HOOK_MANAGER, HookContext, ResponseData},
hooks2::{RequestContext, authorization},
middlewares::RequestModel,
},
utils::future::maybe_timeout,
Expand All @@ -25,8 +26,11 @@ use crate::{
pub async fn embeddings(
State(_state): State<AppState>,
mut hook_ctx: HookContext,
mut request_ctx: RequestContext,
Json(mut request_data): Json<EmbeddingRequest>,
) -> Result<Response, EmbeddingError> {
authorization::check(&mut request_ctx, request_data.model.clone()).await?;

// PRE CALL HOOKS START
hook_ctx.insert(RequestModel(request_data.model));

Expand Down
8 changes: 7 additions & 1 deletion src/proxy/handlers/embeddings/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::time::error::Elapsed;

use crate::{providers::ProviderError, proxy::hooks::HookError};
use crate::{
providers::ProviderError,
proxy::{hooks::HookError, hooks2::authorization::AuthorizationError},
};

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
Expand Down Expand Up @@ -48,6 +51,8 @@ pub struct EmbeddingResponse {

#[derive(Debug, Error)]
pub enum EmbeddingError {
#[error("Authorization error: {0}")]
AuthorizationError(#[from] AuthorizationError),
#[error("Provider error: {0}")]
ProviderError(#[from] ProviderError),
#[error("Request timed out")]
Expand All @@ -59,6 +64,7 @@ pub enum EmbeddingError {
impl IntoResponse for EmbeddingError {
fn into_response(self) -> axum::response::Response {
match self {
EmbeddingError::AuthorizationError(err) => err.into_response(),
EmbeddingError::ProviderError(err) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({
Expand Down
4 changes: 3 additions & 1 deletion src/proxy/handlers/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
proxy::{
AppState,
hooks::{HOOK_FILTER_NONE, HOOK_MANAGER, HookContext, HookError},
hooks2::RequestContext,
},
};

Expand Down Expand Up @@ -56,13 +57,14 @@ impl IntoResponse for ModelError {
pub async fn list_models(
State(state): State<AppState>,
mut hook_ctx: HookContext,
request_ctx: RequestContext,
mut request: Request,
) -> Result<Response, ModelError> {
HOOK_MANAGER
.pre_call(&mut hook_ctx, &mut request, HOOK_FILTER_NONE)
.await?;

let api_key = hook_ctx
let api_key = request_ctx
.get::<ResourceEntry<ApiKey>>()
.cloned()
.expect("apikey should exist in context");
Expand Down
2 changes: 0 additions & 2 deletions src/proxy/hooks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mod metric;
mod rate_limit;
mod validate_model;

use std::{
any::Any,
Expand Down Expand Up @@ -352,7 +351,6 @@ impl HookManager {
pub static HOOK_MANAGER: LazyLock<HookManager> = LazyLock::new(|| {
let mut manager = HookManager::new();
manager
.register(Box::new(validate_model::ValidateModelHook))
.register(Box::new(rate_limit::RateLimitHook))
.register(Box::new(metric::MetricHook));
manager
Expand Down
115 changes: 0 additions & 115 deletions src/proxy/hooks/validate_model.rs

This file was deleted.

86 changes: 86 additions & 0 deletions src/proxy/hooks2/authorization/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use anyhow::Result;
Comment thread
bzp2010 marked this conversation as resolved.
use axum::{Json, response::IntoResponse};
use http::StatusCode;
use log::error;
use serde_json::json;
use thiserror::Error;

use crate::{
config::entities::{ApiKey, ResourceEntry},
proxy::hooks2::RequestContext,
};

#[derive(Clone)]
pub struct RequestModel(#[allow(unused)] pub String);

#[derive(Debug, Clone, Error, PartialEq, Eq, Hash)]
pub enum AuthorizationError {
#[error("Model '{0}' not found")]
ModelNotFound(String),
#[error("Access to model '{0}' is forbidden")]
AccessForbidden(String),

// INTERNAL ERROR
#[error("Apikey not found in context")]
MissingApiKeyInContext,
}

impl IntoResponse for AuthorizationError {
fn into_response(self) -> axum::response::Response {
match self {
AuthorizationError::ModelNotFound(_) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": self.to_string(),
"type": "invalid_request_error",
"code": "model_not_found"
}
})),
)
.into_response(),
AuthorizationError::AccessForbidden(_) => (
StatusCode::FORBIDDEN,
Json(json!({
"error": {
"message": self.to_string(),
"type": "invalid_request_error",
"code": "model_access_forbidden"
}
})),
)
.into_response(),
AuthorizationError::MissingApiKeyInContext => {
(StatusCode::INTERNAL_SERVER_ERROR).into_response()
}
}
}
}

#[fastrace::trace]
pub async fn check(ctx: &mut RequestContext, model_name: String) -> Result<(), AuthorizationError> {
let model = match ctx.app_state().resources().models.get_by_name(&model_name) {
Some(model) => model,
None => {
return Err(AuthorizationError::ModelNotFound(model_name.clone()));
}
};

let api_key = match ctx.get::<ResourceEntry<ApiKey>>().cloned() {
Some(api_key) => api_key,
None => {
error!("API key not found in context");
return Err(AuthorizationError::MissingApiKeyInContext);
}
};

// Check if API key has access to this model
if !api_key.allowed_models.contains(&model_name) {
return Err(AuthorizationError::AccessForbidden(model_name.clone()));
}

ctx.insert(model);
ctx.insert(RequestModel(model_name));

Ok(())
}
53 changes: 53 additions & 0 deletions src/proxy/hooks2/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use std::ops::{Deref, DerefMut};

use axum::extract::FromRequestParts;
use http::request::Parts;

use crate::{
config::entities::{ApiKey, ResourceEntry},
proxy::AppState,
};

pub struct RequestContext {
#[allow(unused)]
app_state: AppState,
extensions: http::Extensions,
}

impl FromRequestParts<AppState> for RequestContext {
type Rejection = ();

async fn from_request_parts(
parts: &mut Parts,
state: &AppState,
) -> Result<Self, Self::Rejection> {
let mut ctx = http::Extensions::new();
ctx.insert(parts.extensions.remove::<ResourceEntry<ApiKey>>().expect(
"Authentication middleware should have inserted ApiKey into request extensions",
));
Ok(Self {
app_state: state.clone(),
extensions: ctx,
})
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

impl Deref for RequestContext {
type Target = http::Extensions;

fn deref(&self) -> &Self::Target {
&self.extensions
}
}

impl DerefMut for RequestContext {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.extensions
}
}

impl RequestContext {
pub fn app_state(&self) -> &AppState {
&self.app_state
}
Comment thread
bzp2010 marked this conversation as resolved.
}
Loading
Loading