forked from FrancescoDatascientest/mc_sec_api
-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy path3_exemple_oauth.py
More file actions
140 lines (119 loc) · 4.86 KB
/
3_exemple_oauth.py
File metadata and controls
140 lines (119 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import hashlib
import hmac
import json
import os
import tempfile
from datetime import datetime, timedelta, timezone
from hashlib import sha256

import joblib
import joblib
import jwt
import numpy as np
from cryptography.fernet import Fernet
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from sklearn.preprocessing import OneHotEncoder
app = FastAPI()

# Constants
JSON_FILE_PATH = os.path.expanduser("./users/users.json")
# Key used to sign and verify JWTs. Taken from the SECRET_KEY environment
# variable so issued tokens stay valid across restarts (including uvicorn's
# reload mode) and across worker processes. If it is unset or empty, fall back
# to a random per-process key: fine for local experiments, but every issued
# token becomes invalid whenever the process restarts.
SECRET_KEY = os.environ.get("SECRET_KEY") or Fernet.generate_key()
ALGORITHM = "HS256"
# Tells FastAPI to read a bearer token from the Authorization header; the
# token is obtained from the /token endpoint.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token")
# User Models
class User(BaseModel):
    """Request body for /register: profile fields plus the plaintext password."""

    username: str
    first_name: str
    last_name: str
    password: str  # plaintext as sent by the client; hashed before storage
class UserOut(BaseModel):
    """Public user profile returned by /register; it has no password field."""

    username: str
    first_name: str
    last_name: str
class UserInDB(User):
    """User record as persisted in the JSON users file.

    ``register`` builds it with ``password`` set to the hex SHA-256 digest of
    the plaintext password; ``load_users`` rebuilds it from the file.
    """

    # Redeclares the inherited field with the same type. Here it holds the
    # stored digest, not the plaintext.
    password: str

    class Config:
        # NOTE(review): ``orm_mode`` is the pydantic v1 spelling. Pydantic v2
        # renamed it ``from_attributes`` and warns on the old key; confirm
        # which pydantic version this runs against.
        orm_mode = True
class UserPred(BaseModel):
    """Request body for /predict/: the features fed to the model.

    ``sex``, ``favorite_color`` and ``favorite_food`` are free-form strings
    here; ``validate_user_input`` rejects values outside the allow-lists.
    """

    age: int
    sex: str
    favorite_color: str
    favorite_food: str
class Token(BaseModel):
    """Response body of the /token endpoint."""

    access_token: str  # the signed JWT
    token_type: str  # the /token endpoint always returns "bearer"
# Model and Encoder Initialization
# NOTE(security): joblib.load unpickles the file, so only ever point this at a
# trusted, locally produced model artifact.
model = joblib.load("./models/model_fin2.pkl")

# Allow-lists for the categorical inputs; also used by validate_user_input.
allowed_favorite_colors = ['Red', 'Blue', 'Green', 'Yellow', 'Purple']
allowed_favorite_foods = ['Pizza', 'Pasta', 'Burger', 'Sushi', 'Salad', 'Ice Cream']
allowed_sex = ['Male', 'Female']

# Column order (sex, color, food) must match the order used by
# preprocess_user_data when it calls encoder.transform.
encoder = OneHotEncoder(
    categories=[allowed_sex, allowed_favorite_colors, allowed_favorite_foods],
    sparse_output=False,
)

# The categories are given explicitly, so fitting on one valid row (the first
# allowed value of each column) is enough to make the encoder usable.
dummy_data = np.array([[allowed_sex[0], allowed_favorite_colors[0], allowed_favorite_foods[0]]])
encoder.fit(dummy_data)
# Helper Functions
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hex SHA-256 digest.

    Args:
        plain_password: Password supplied by the client.
        hashed_password: Hex digest stored for the user.

    Returns:
        True when the SHA-256 hex digest of ``plain_password`` equals
        ``hashed_password``, otherwise False.
    """
    candidate = sha256(plain_password.encode()).hexdigest()
    # Constant-time comparison, so response timing does not leak how much of
    # the stored digest matched. Both sides are encoded to bytes because
    # compare_digest raises TypeError on non-ASCII str arguments.
    # NOTE(security): unsalted SHA-256 is a weak password hash. Moving to a
    # salted KDF (e.g. hashlib.scrypt) would require re-hashing stored users.
    return hmac.compare_digest(candidate.encode(), hashed_password.encode())
def load_users():
    """Load every stored user from ``JSON_FILE_PATH``.

    Returns:
        A list of ``UserInDB`` built from the JSON array in the users file.
        The list is empty when the file does not exist yet or contains only
        whitespace.

    Raises:
        json.JSONDecodeError: if the file has non-empty, non-JSON content.
            This is deliberately not swallowed: ``save_user`` rewrites the
            whole file from this list, so treating a corrupt file as "no
            users" would erase every account on the next registration.
    """
    # EAFP: opening directly avoids the race between an exists() check and open().
    try:
        with open(JSON_FILE_PATH, "r", encoding="utf-8") as file:
            raw = file.read()
    except FileNotFoundError:
        return []
    # A freshly created, empty file holds no users rather than invalid JSON.
    if not raw.strip():
        return []
    return [UserInDB(**record) for record in json.loads(raw)]
def save_user(user: UserInDB):
    """Append ``user`` to the users file, replacing the file atomically.

    Reads the current users with ``load_users``, appends ``user`` and writes
    the whole list back as a JSON array.

    Why a temp file: the data is first written to a temporary file in the same
    directory, then moved over ``JSON_FILE_PATH`` with ``os.replace``. A crash
    mid-write therefore cannot leave a truncated users file.

    Side effects:
        - The parent directory is created if it is missing.
        - The replaced file gets the owner-only permissions that
          ``tempfile.mkstemp`` uses.

    Note: the read-modify-write is not synchronised, so concurrent calls can
    still lose an update.
    """
    users = load_users()
    users.append(user)
    directory = os.path.dirname(os.path.abspath(JSON_FILE_PATH))
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as file:
            json.dump([stored.dict() for stored in users], file)
        os.replace(tmp_path, JSON_FILE_PATH)
    except BaseException:
        # Best-effort removal of the partial temp file; the original users
        # file is untouched. The original error is re-raised either way.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def create_access_token(data: dict, expires_delta: timedelta = None):
    """Create a signed JWT carrying ``data`` plus an ``exp`` claim.

    Args:
        data: Claims to embed (e.g. ``{"sub": username}``). The dict is
            copied, not mutated.
        expires_delta: Token lifetime. ``None`` means 15 minutes. Any explicit
            value, including ``timedelta(0)``, is honoured as given.

    Returns:
        The JWT encoded with ``SECRET_KEY`` using ``ALGORITHM``.
    """
    # Test for None explicitly: a plain truthiness check would replace a
    # zero-length timedelta with the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    to_encode = data.copy()
    # Use a timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    to_encode.update({"exp": datetime.now(timezone.utc) + expires_delta})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_current_user(token: str = Depends(oauth2_scheme)):
    """Resolve the request's bearer token to the authenticated user.

    Returns:
        ``{"username": <sub claim>}`` for a valid, unexpired token.

    Raises:
        HTTPException: 401 with detail "Token expired" for an expired token.
        HTTPException: 401 with detail "Could not validate credentials" for
            any other invalid token (bad signature, malformed, disallowed
            algorithm) or for a token without a ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Must be caught before InvalidTokenError, which is its base class.
        raise HTTPException(
            status_code=401,
            detail="Token expired",
            headers={"WWW-Authenticate": "Bearer"},
        ) from None
    except jwt.InvalidTokenError:
        raise HTTPException(
            status_code=401,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from None
    username = payload.get("sub")
    if username is None:
        raise HTTPException(
            status_code=401,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"username": username}
def validate_user_input(user: UserPred):
    """Reject prediction inputs whose categorical values are not allowed.

    The checks run in a fixed order (color, then food, then sex), and the
    first failure is reported.

    Raises:
        HTTPException: 400 naming the first offending field.
    """
    rules = (
        (user.favorite_color, allowed_favorite_colors, "Invalid favorite color"),
        (user.favorite_food, allowed_favorite_foods, "Invalid favorite food"),
        (user.sex, allowed_sex, "Invalid sex"),
    )
    for value, permitted, message in rules:
        if value not in permitted:
            raise HTTPException(status_code=400, detail=message)
def preprocess_user_data(user: UserPred):
    """Build the model's single-row feature matrix for ``user``.

    The row is ``age`` followed by the one-hot encoding of
    (sex, favorite_color, favorite_food), in that column order.

    Returns:
        A 2-D NumPy array with exactly one row.
    """
    raw_categories = [[user.sex, user.favorite_color, user.favorite_food]]
    one_hot = encoder.transform(np.array(raw_categories)).ravel()
    row = np.concatenate(([user.age], one_hot))
    return row[np.newaxis, :]
# Endpoints
@app.post("/register", response_model=UserOut)
async def register(user: User):
    """Register a new user and return their public profile.

    The password is stored as its hex SHA-256 digest, the same scheme
    ``verify_password`` checks against.

    Raises:
        HTTPException: 409 if ``user.username`` is already registered.
    """
    # Reject duplicate usernames. Login accepts the first stored record whose
    # username AND password match, so a second record with the same username
    # but a different password would let its creator obtain tokens for that
    # username.
    if any(existing.username == user.username for existing in load_users()):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Username already registered",
        )
    hashed_password = sha256(user.password.encode()).hexdigest()
    user_data = user.dict(exclude={"password"})
    user_in_db = UserInDB(**user_data, password=hashed_password)
    # NOTE: the duplicate check and this write are not atomic; two
    # simultaneous registrations of the same name can still both succeed.
    save_user(user_in_db)
    return UserOut(**user_data)
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    """Exchange username/password form fields for a 30-minute bearer token.

    Raises:
        HTTPException: 401 "Invalid credentials" when no stored user has both
            the given username and password.
    """
    matching_user = next(
        (
            candidate
            for candidate in load_users()
            if candidate.username == form_data.username
            and verify_password(form_data.password, candidate.password)
        ),
        None,
    )
    if matching_user is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    access_token = create_access_token(
        {"sub": form_data.username},
        expires_delta=timedelta(minutes=30),
    )
    return {"access_token": access_token, "token_type": "bearer"}
@app.post("/predict/")
async def predict_sign(user: UserPred, current_user: dict = Depends(get_current_user)):
    """Predict an astrological sign from the submitted features.

    Requires a valid bearer token (resolved by ``get_current_user``, which
    returns a dict); the resolved user is not otherwise used.

    Returns:
        ``{"astrological_sign": <first prediction>}``.

    Raises:
        HTTPException: 400 from ``validate_user_input`` for values outside
            the allow-lists.
    """
    validate_user_input(user)
    features = preprocess_user_data(user)
    prediction = model.predict(features)[0]
    # Presumably a NumPy array element (scikit-learn convention). NumPy
    # scalars such as numpy.int64 are not handled by FastAPI's default JSON
    # encoding, so convert them to the equivalent built-in Python value.
    if isinstance(prediction, np.generic):
        prediction = prediction.item()
    return {"astrological_sign": prediction}
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when running as a script.
    import uvicorn

    # The app is passed as an import string because uvicorn requires that
    # form for reload mode.
    # NOTE(review): host 0.0.0.0 listens on all interfaces, and reload=True is
    # a development setting; confirm both before deploying.
    uvicorn.run(
        "3_exemple_oauth:app",
        port=8000,
        host="0.0.0.0",
        reload=True,
    )