Skip to content
Merged

Dev #12

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ certs/
src/**/__pycache__
src/specs/*
src/.streamlit/secrets.toml
src/logs/
4 changes: 4 additions & 0 deletions src/endpoints/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from starlette.responses import RedirectResponse

from services.db import get_db
from utils.logger import logger

async def create_session(request: Request):
"""Create a session from an auth token.
Expand All @@ -19,6 +20,7 @@ async def create_session(request: Request):

token = request.query_params.get("token")
if not token:
logger.warn("Auth endpoint: missing token")
return RedirectResponse("/")

db = get_db()
Expand All @@ -28,11 +30,13 @@ async def create_session(request: Request):
auth_token is None
or auth_token.expires_at < datetime.now(timezone.utc)
):
logger.warn("Auth endpoint: token expired or invalid")
return RedirectResponse("/")

db.remove(auth_token)

session = db.create_session(auth_token.user_id)
logger.info(f"Session created for user_id: {auth_token.user_id}")

response = RedirectResponse("/")
response.set_cookie(
Expand Down
8 changes: 7 additions & 1 deletion src/services/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from models.user import User
from services.db import get_db
from services.ldap import authenticate
from utils.logger import logger

db = get_db()

Expand All @@ -27,6 +28,7 @@ def logout(session: Session) -> None:
"""

db.remove(session)
logger.info(f"User logged out: {session.user.ldap_uid}")


def login(uid: str, password: str) -> User|None:
Expand All @@ -43,12 +45,15 @@ def login(uid: str, password: str) -> User|None:
user_infos = authenticate(uid, password)

if user_infos is None or user_infos["uid"] is None:
logger.warn(f"Login failed for user: {uid}")
return None

user = db.get_user(user_infos["uid"])

create_session(user)

logger.info(f"User logged in: {uid}")

return user

def validate_session() -> Session|None:
Expand All @@ -70,6 +75,7 @@ def validate_session() -> Session|None:
return None

if session.expires_at < datetime.now(timezone.utc):
logger.info(f"Session expired, removed for user: {session.user.ldap_uid}")
db.remove(session)
return None

Expand Down
7 changes: 3 additions & 4 deletions src/services/ldap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

from config import LDAP_BASE_DN, LDAP_HOST, LDAP_PASSWORD, LDAP_PORT, LDAP_USER
from ldap3 import Entry, Server, Connection, ALL
from utils.logger import logger

def _get_server():
return Server(LDAP_HOST, port=int(LDAP_PORT), use_ssl=True, get_info=ALL)
Expand All @@ -27,8 +28,7 @@ def _search_user(uid: str, attributes: list[str]) -> Entry | None:
conn.unbind()
return entry
except Exception as e:
print(e)
print("search bind error")
logger.error(f"LDAP search error: {e}")
return None


Expand Down Expand Up @@ -69,8 +69,7 @@ def authenticate(uid: str, password: str) -> dict[str, str|None] | None:
return user_info

except Exception as e:
print(e)
print("auth bind error")
logger.error(f"LDAP auth error for {uid}: {e}")
return None


Expand Down
12 changes: 8 additions & 4 deletions src/services/mail.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from jinja2 import Environment, FileSystemLoader

from config import SMTP_PASSWORD, SMTP_SERVER, SMTP_PORT, SMTP_USERNAME
from utils.logger import logger
from models.project import Project
from models.user import User
from services.db import Db, get_db
Expand Down Expand Up @@ -307,8 +308,11 @@ def _send(self, mail: Mail) -> None:
context = ssl.create_default_context()
context.minimum_version = ssl.TLSVersion.TLSv1_3

with smtplib.SMTP(self._server, self._port, timeout=10) as smtp:
smtp.starttls(context=context)
smtp.login(self._username, self._password)
try:
with smtplib.SMTP(self._server, self._port, timeout=10) as smtp:
smtp.starttls(context=context)
smtp.login(self._username, self._password)

smtp.sendmail(self._sender, all_recipents, msg.as_string())
smtp.sendmail(self._sender, all_recipents, msg.as_string())
except Exception as e:
logger.error(f"Failed to send email to {all_recipents}: {e}")
7 changes: 7 additions & 0 deletions src/tests/unit/test_assignment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import patch, MagicMock
import pytest
import numpy as np
from utils.assignment import (
assignment_algorithm,
Expand All @@ -7,6 +8,12 @@
)


@pytest.fixture(autouse=True)
def mock_logger():
with patch("utils.assignment.logger") as mock:
yield mock


class TestAssignmentAlgorithm:
"""Tests for the assignment_algorithm function."""

Expand Down
7 changes: 7 additions & 0 deletions src/tests/unit/test_ldap.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from unittest.mock import patch, MagicMock
import pytest
from services.ldap import (
_get_server,
_search_user,
Expand All @@ -7,6 +8,12 @@
)


@pytest.fixture(autouse=True)
def mock_logger():
with patch("services.ldap.logger") as mock:
yield mock


class TestSearchUser:
"""Tests for the _search_user function."""

Expand Down
97 changes: 97 additions & 0 deletions src/tests/unit/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import sys
from unittest.mock import MagicMock

sys.modules["streamlit"] = MagicMock()

import pytest
from unittest.mock import patch, MagicMock
from importlib import reload

from utils.logger import Logger


class TestLoggerSingleton:
"""Tests for the Logger singleton pattern."""

@pytest.fixture(autouse=True)
def reset_singleton(self):
"""Reset singleton instance before each test."""
Logger._instance = None
Logger._initialized = False
yield
Logger._instance = None
Logger._initialized = False

def test_singleton_returns_same_instance(self):
"""Multiple Logger() calls return the same instance."""
l1 = Logger()
l2 = Logger()
assert l1 is l2

def test_is_initialized_returns_true(self):
"""is_initialized returns True after init."""
l = Logger()
assert l.is_initialized() is True


class TestLoggerMethods:
"""Tests for Logger public methods."""

@pytest.fixture(autouse=True)
def reset_singleton(self):
"""Reset singleton instance before each test."""
Logger._instance = None
Logger._initialized = False
yield

@patch("utils.logger.open", new_callable=MagicMock)
def test_info_writes_to_file(self, mock_open):
"""info() writes a log line to file."""
l = Logger()
l.init()
l.info("test message")

mock_open.assert_called()
mock_open.return_value.__enter__.return_value.write.assert_called()

@patch("utils.logger.open", new_callable=MagicMock)
def test_warn_writes_to_file(self, mock_open):
"""warn() writes a log line to file."""
l = Logger()
l.init()
l.warn("warning message")

mock_open.assert_called()

@patch("utils.logger.open", new_callable=MagicMock)
def test_error_writes_to_file(self, mock_open):
"""error() writes a log line to file."""
l = Logger()
l.init()
l.error("error message")

mock_open.assert_called()


class TestLoggerModuleExport:
"""Tests for the module-level logger export."""

@pytest.fixture(autouse=True)
def reset_singleton(self):
"""Reset singleton instance before each test."""
import utils.logger as logger_module
Logger._instance = None
Logger._initialized = False
logger_module._logger = None
yield
Logger._instance = None
Logger._initialized = False
logger_module._logger = None

def test_module_level_logger_is_singleton(self):
"""The module-level logger export is the singleton instance."""
import utils.logger as logger_module
reload(logger_module)

assert logger_module.logger is logger_module._logger
assert logger_module.logger is logger_module.get_logger()
6 changes: 6 additions & 0 deletions src/utils/assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scipy.optimize import linear_sum_assignment

from services.mail import Mailer
from utils.logger import logger

def assignment_algorithm(project_ratings: Sequence[ProjectRating], student_ids: list[int], project_ids: list[int]) -> tuple[list[int], list[int]] :
"""Assignment algorithm.
Expand Down Expand Up @@ -75,6 +76,9 @@ def assign_projects(program_id: int, project_ratings: Sequence[ProjectRating], d

db.assign_project(project_id, student_id)

if len(project_ratings) > 0:
logger.info(f"Assignment complete for program {project_ratings[0].project.program.name}, emails sent")

mailer.project_assignment(program_id)

def remind_students(students: Sequence[User], n_projects: int, mailer: Mailer):
Expand All @@ -94,7 +98,9 @@ def remind_students(students: Sequence[User], n_projects: int, mailer: Mailer):
if len(student.project_ratings) != n_projects:
students_to_remind.append(student)

if students_to_remind:
mailer.students_reminder(students_to_remind, urgent=True)
logger.info(f"Sent reminders to {len(students_to_remind)} students")


def start_assignment(program_id: int):
Expand Down
Loading
Loading