Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/proxy/handlers/chat_completions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::{
proxy::{
AppState,
hooks::{HOOK_FILTER_ALL, HOOK_MANAGER, HookContext, ResponseData, TokenUsage},
hooks2::{RequestContext, authorization},
middlewares::RequestModel,
},
utils::future::maybe_timeout,
Expand All @@ -31,10 +32,17 @@ use crate::{
pub async fn chat_completions(
State(_state): State<AppState>,
Extension(span_ctx): Extension<SpanContext>,
mut request_ctx: RequestContext,
mut hook_ctx: HookContext,
Json(mut request_data): Json<ChatCompletionRequest>,
) -> Result<Response, ChatCompletionError> {
authorization::check(&mut request_ctx, request_data.model.clone()).await?;

// TODO: remove
let _model = request_ctx.get::<ResourceEntry<Model>>().unwrap().clone();
hook_ctx.insert(_model);
hook_ctx.insert(RequestModel(request_data.model));

let mut request = Request::new(Body::empty()); //TODO
HOOK_MANAGER
.pre_call(&mut hook_ctx, &mut request, HOOK_FILTER_ALL)
Expand Down
8 changes: 7 additions & 1 deletion src/proxy/handlers/chat_completions/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::time::error::Elapsed;

use crate::{providers::ProviderError, proxy::hooks::HookError};
use crate::{
providers::ProviderError,
proxy::{hooks::HookError, hooks2::authorization::AuthorizationError},
};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
Expand Down Expand Up @@ -143,6 +146,8 @@ pub struct ChatCompletionChunk {

#[derive(Debug, Error)]
pub enum ChatCompletionError {
#[error("Authorization error: {0}")]
AuthorizationError(#[from] AuthorizationError),
#[error("Provider error: {0}")]
ProviderError(#[from] ProviderError),
#[error("Request timed out")]
Expand All @@ -154,6 +159,7 @@ pub enum ChatCompletionError {
impl IntoResponse for ChatCompletionError {
fn into_response(self) -> axum::response::Response {
match self {
ChatCompletionError::AuthorizationError(err) => err.into_response(),
ChatCompletionError::ProviderError(err) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({
Expand Down
7 changes: 7 additions & 0 deletions src/proxy/handlers/embeddings/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@ use crate::{
proxy::{
AppState,
hooks::{HOOK_FILTER_ALL, HOOK_MANAGER, HookContext, ResponseData},
hooks2::{RequestContext, authorization},
middlewares::RequestModel,
},
utils::future::maybe_timeout,
};

pub async fn embeddings(
State(_state): State<AppState>,
mut request_ctx: RequestContext,
mut hook_ctx: HookContext,
Json(mut request_data): Json<EmbeddingRequest>,
) -> Result<Response, EmbeddingError> {
authorization::check(&mut request_ctx, request_data.model.clone()).await?;

// PRE CALL HOOKS START
// TODO: remove
let _model = request_ctx.get::<ResourceEntry<Model>>().unwrap().clone();
hook_ctx.insert(_model);
hook_ctx.insert(RequestModel(request_data.model));

let mut request = Request::new(Body::empty()); //TODO
Expand Down
8 changes: 7 additions & 1 deletion src/proxy/handlers/embeddings/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::time::error::Elapsed;

use crate::{providers::ProviderError, proxy::hooks::HookError};
use crate::{
providers::ProviderError,
proxy::{hooks::HookError, hooks2::authorization::AuthorizationError},
};

#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
Expand Down Expand Up @@ -48,6 +51,8 @@ pub struct EmbeddingResponse {

#[derive(Debug, Error)]
pub enum EmbeddingError {
#[error("Authorization error: {0}")]
AuthorizationError(#[from] AuthorizationError),
#[error("Provider error: {0}")]
ProviderError(#[from] ProviderError),
#[error("Request timed out")]
Expand All @@ -59,6 +64,7 @@ pub enum EmbeddingError {
impl IntoResponse for EmbeddingError {
fn into_response(self) -> axum::response::Response {
match self {
EmbeddingError::AuthorizationError(err) => err.into_response(),
EmbeddingError::ProviderError(err) => (
StatusCode::BAD_GATEWAY,
Json(serde_json::json!({
Expand Down
4 changes: 3 additions & 1 deletion src/proxy/handlers/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
proxy::{
AppState,
hooks::{HOOK_FILTER_NONE, HOOK_MANAGER, HookContext, HookError},
hooks2::RequestContext,
},
};

Expand Down Expand Up @@ -55,14 +56,15 @@ impl IntoResponse for ModelError {
#[fastrace::trace]
pub async fn list_models(
State(state): State<AppState>,
request_ctx: RequestContext,
mut hook_ctx: HookContext,
mut request: Request,
) -> Result<Response, ModelError> {
HOOK_MANAGER
.pre_call(&mut hook_ctx, &mut request, HOOK_FILTER_NONE)
.await?;

let api_key = hook_ctx
let api_key = request_ctx
.get::<ResourceEntry<ApiKey>>()
.cloned()
.expect("apikey should exist in context");
Expand Down
2 changes: 0 additions & 2 deletions src/proxy/hooks/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mod metric;
mod rate_limit;
mod validate_model;

use std::{
any::Any,
Expand Down Expand Up @@ -352,7 +351,6 @@ impl HookManager {
pub static HOOK_MANAGER: LazyLock<HookManager> = LazyLock::new(|| {
let mut manager = HookManager::new();
manager
.register(Box::new(validate_model::ValidateModelHook))
.register(Box::new(rate_limit::RateLimitHook))
.register(Box::new(metric::MetricHook));
manager
Expand Down
115 changes: 0 additions & 115 deletions src/proxy/hooks/validate_model.rs

This file was deleted.

86 changes: 86 additions & 0 deletions src/proxy/hooks2/authorization/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use anyhow::Result;
Comment thread
bzp2010 marked this conversation as resolved.
use axum::{Json, response::IntoResponse};
use http::StatusCode;
use log::error;
use serde_json::json;
use thiserror::Error;

use crate::{
config::entities::{ApiKey, ResourceEntry},
proxy::hooks2::RequestContext,
};

#[derive(Clone)]
pub struct RequestModel(#[allow(unused)] pub String);

#[derive(Debug, Clone, Error, PartialEq, Eq, Hash)]
pub enum AuthorizationError {
#[error("Model '{0}' not found")]
ModelNotFound(String),
#[error("Access to model '{0}' is forbidden")]
AccessForbidden(String),

// INTERNAL ERROR
#[error("Apikey not found in context")]
MissingApiKeyInContext,
}

impl IntoResponse for AuthorizationError {
fn into_response(self) -> axum::response::Response {
match self {
AuthorizationError::ModelNotFound(_) => (
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": self.to_string(),
"type": "invalid_request_error",
"code": "model_not_found"
}
})),
)
.into_response(),
AuthorizationError::AccessForbidden(_) => (
StatusCode::FORBIDDEN,
Json(json!({
"error": {
"message": self.to_string(),
"type": "invalid_request_error",
"code": "model_access_forbidden"
}
})),
)
.into_response(),
AuthorizationError::MissingApiKeyInContext => {
(StatusCode::INTERNAL_SERVER_ERROR).into_response()
}
}
}
}

#[fastrace::trace]
pub async fn check(ctx: &mut RequestContext, model_name: String) -> Result<(), AuthorizationError> {
let model = match ctx.app_state().resources().models.get_by_name(&model_name) {
Some(model) => model,
None => {
return Err(AuthorizationError::ModelNotFound(model_name.clone()));
}
};

let api_key = match ctx.get::<ResourceEntry<ApiKey>>().cloned() {
Some(api_key) => api_key,
None => {
error!("API key not found in context");
return Err(AuthorizationError::MissingApiKeyInContext);
}
};

// Check if API key has access to this model
if !api_key.allowed_models.contains(&model_name) {
return Err(AuthorizationError::AccessForbidden(model_name.clone()));
}

ctx.insert(model);
ctx.insert(RequestModel(model_name));

Ok(())
}
Loading
Loading