Skip to content

Commit c47a019

Browse files
authored
Merge pull request #644 from szmadd/feature/oidc-enhancements
Feature/OIDC enhancements
2 parents c2e7558 + 392549f commit c47a019

2 files changed

Lines changed: 290 additions & 29 deletions

File tree

backend/package/yuxi/services/oidc_service.py

Lines changed: 261 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ class OIDCConfig(BaseModel):
5151
username_claim: str = Field(default="preferred_username", description="用户名映射字段")
5252
email_claim: str = Field(default="email", description="邮箱映射字段")
5353
name_claim: str = Field(default="name", description="姓名映射字段")
54+
use_raw_username: bool = Field(default=False, description="是否使用原始用户名(不带oidc前缀)")
55+
fetch_department_info: bool = Field(default=False, description="是否从OIDC中获取部门信息")
56+
department_claim: str = Field(default="department", description="部门信息映射字段")
57+
force_prompt_login: bool = Field(default=False, description="是否强制用户重新登录(添加prompt=login参数)")
5458

5559
@classmethod
5660
def from_env(cls) -> "OIDCConfig":
@@ -82,6 +86,10 @@ def _env(name: str, default: str = "") -> str:
8286
username_claim=_env("OIDC_USERNAME_CLAIM", "preferred_username"),
8387
email_claim=_env("OIDC_EMAIL_CLAIM", "email"),
8488
name_claim=_env("OIDC_NAME_CLAIM", "name"),
89+
use_raw_username=os.environ.get("OIDC_USE_RAW_USERNAME", "false").lower() == "true",
90+
fetch_department_info=os.environ.get("OIDC_FETCH_DEPARTMENT_INFO", "false").lower() == "true",
91+
department_claim=_env("OIDC_DEPARTMENT_CLAIM", "department"),
92+
force_prompt_login=os.environ.get("OIDC_FORCE_PROMPT_LOGIN", "true").lower() == "true",
8593
)
8694

8795
def is_configured(self) -> bool:
@@ -277,6 +285,10 @@ async def build_authorization_url(cls, redirect_path: str = "/") -> str | None:
277285
"nonce": nonce,
278286
}
279287

288+
# 如果配置强制登录,添加 prompt=login 参数
289+
if oidc_config.force_prompt_login:
290+
params["prompt"] = "login"
291+
280292
query_string = urllib.parse.urlencode(params)
281293
return f"{metadata.authorization_endpoint}?{query_string}"
282294

@@ -374,54 +386,129 @@ def extract_user_info(cls, userinfo: dict[str, Any]) -> dict[str, Any]:
374386
if not name:
375387
name = username
376388

389+
department_name = None
390+
department_description = None
391+
if oidc_config.fetch_department_info:
392+
department_name = userinfo.get(oidc_config.department_claim)
393+
if not department_name:
394+
department_name = userinfo.get("department")
395+
396+
# 获取部门描述
397+
department_description = userinfo.get("department_description")
398+
if not department_description:
399+
department_description = userinfo.get("department_desc")
400+
377401
return {
378402
"sub": sub,
379403
"username": username,
380404
"email": email,
381405
"name": name,
406+
"department_name": department_name,
407+
"department_description": department_description,
382408
"raw": userinfo,
383409
}
384410

385411

386-
async def get_or_create_oidc_department(db) -> Department | None:
387-
"""获取或创建 OIDC 用户的默认部门"""
388-
dept_name = oidc_config.default_department
389-
390-
result = await db.execute(select(Department).filter(Department.name == dept_name))
412+
async def get_or_create_oidc_department(db, dept_name_from_oidc: str | None = None, dept_desc_from_oidc: str | None = None) -> Department | None:
413+
"""获取或创建 OIDC 用户的部门"""
414+
# 清理并验证从 OIDC 获取的部门名称
415+
processed_dept_name = None
416+
processed_dept_desc = None
417+
418+
if dept_name_from_oidc:
419+
# 去除首尾空格
420+
processed_dept_name = dept_name_from_oidc.strip()
421+
# 截断到 50 字符(匹配数据库限制)
422+
if len(processed_dept_name) > 50:
423+
processed_dept_name = processed_dept_name[:50]
424+
# 如果处理后为空,放弃使用
425+
if not processed_dept_name:
426+
processed_dept_name = None
427+
428+
# 清理并验证从 OIDC 获取的部门描述
429+
if dept_desc_from_oidc:
430+
processed_dept_desc = dept_desc_from_oidc.strip()
431+
# 截断到 255 字符(匹配数据库限制)
432+
if len(processed_dept_desc) > 255:
433+
processed_dept_desc = processed_dept_desc[:255]
434+
if not processed_dept_desc:
435+
processed_dept_desc = None
436+
437+
# 最终确定部门名称:优先使用处理后的OIDC部门名称,否则使用默认部门名称
438+
final_dept_name = processed_dept_name or oidc_config.default_department
439+
# 最终确定部门描述:优先使用处理后的OIDC部门描述,否则使用默认描述
440+
final_dept_desc = processed_dept_desc or f"{final_dept_name}部门"
441+
442+
result = await db.execute(select(Department).filter(Department.name == final_dept_name))
391443
dept = result.scalar_one_or_none()
392444

393-
if not dept:
394-
dept = Department(
395-
name=dept_name,
396-
description=f"{dept_name}部门",
397-
)
398-
db.add(dept)
399-
try:
400-
await db.commit()
401-
await db.refresh(dept)
402-
logger.info(f"Created OIDC department: {dept_name}")
403-
except IntegrityError:
404-
await db.rollback()
405-
result = await db.execute(select(Department).filter(Department.name == dept_name))
406-
dept = result.scalar_one_or_none()
445+
if dept:
446+
# 部门已存在,直接返回
447+
logger.info(f"Using existing department: {final_dept_name}")
448+
return dept
449+
450+
# 部门不存在,创建新部门
451+
dept = Department(
452+
name=final_dept_name,
453+
description=final_dept_desc,
454+
)
455+
db.add(dept)
456+
try:
457+
await db.commit()
458+
await db.refresh(dept)
459+
logger.info(f"Created OIDC department: {final_dept_name}")
460+
except IntegrityError:
461+
# 并发创建时部门可能已存在,再次查询
462+
await db.rollback()
463+
result = await db.execute(select(Department).filter(Department.name == final_dept_name))
464+
dept = result.scalar_one_or_none()
407465

408466
return dept
409467

410468

411469
async def find_user_by_oidc_sub(db, sub: str) -> User | None:
412470
"""通过 OIDC sub 查找用户"""
413-
oidc_user_id = f"oidc:{sub}"
414-
415-
result = await db.execute(select(User).filter(User.user_id == oidc_user_id, User.is_deleted == 0))
471+
# 方法1: 检查是否有用户的 user_id 直接等于 "oidc:{sub}"(标准 OIDC 用户)
472+
standard_oidc_user_id = f"oidc:{sub}"
473+
# 占位绑定记录会被标记为 is_deleted=1,但我们仍需要查询它们来获取绑定关系
474+
result = await db.execute(select(User).filter(
475+
User.user_id == standard_oidc_user_id,
476+
User.is_deleted == 0
477+
))
416478
user = result.scalar_one_or_none()
417479
if user:
418480
return user
419481

482+
# 方法2: 检查是否有绑定占位用户格式: "oidc:{sub}:{target_user_id}"(use_raw_username 绑定记录)
483+
# 绑定占位用户被标记为 is_deleted=1,需要包括deleted来查询
420484
legacy_result = await db.execute(
421-
select(User).filter(User.user_id.like(f"{oidc_user_id}:%"), User.is_deleted == 0).order_by(User.id.asc())
485+
select(User).filter(
486+
User.user_id.like(f"{standard_oidc_user_id}:%"),
487+
User.is_deleted.in_([0, 1])
488+
).order_by(User.id.asc())
422489
)
423490
legacy_users = list(legacy_result.scalars().all())
424491
if legacy_users:
492+
# 对于绑定占位用户,user_id 格式为 oidc:{sub}:{target_user_id},解析出 target_user_id 并返回真实用户
493+
for placeholder in legacy_users:
494+
if placeholder.is_deleted != 1:
495+
# 非deleted占位,直接返回
496+
return placeholder
497+
parts = placeholder.user_id.split(":")
498+
if len(parts) >= 3:
499+
try:
500+
target_user_id = int(parts[2])
501+
result = await db.execute(select(User).filter(User.id == target_user_id, User.is_deleted == 0))
502+
target_user = result.scalar_one_or_none()
503+
if target_user:
504+
logger.debug(f"Resolved OIDC binding placeholder {placeholder.user_id} to user {target_user_id}")
505+
return target_user
506+
except ValueError:
507+
continue
508+
# 如果没有解析出有效的目标用户,返回第一个非deleted legacy用户(向后兼容)
509+
for candidate in legacy_users:
510+
if candidate.is_deleted == 0:
511+
return candidate
425512
if len(legacy_users) > 1:
426513
logger.warning(f"Multiple legacy OIDC users matched for sub={sub}, use earliest id={legacy_users[0].id}")
427514
return legacy_users[0]
@@ -438,10 +525,77 @@ async def find_deleted_oidc_user_by_sub(db, sub: str) -> User | None:
438525
if deleted_user:
439526
return deleted_user
440527

528+
# 检查绑定占位格式 oidc:{sub}:{target_user_id}(占位本身是deleted,需要查询目标用户)
441529
legacy_result = await db.execute(
442-
select(User).filter(User.user_id.like(f"{oidc_user_id}:%"), User.is_deleted == 1).order_by(User.id.asc())
530+
select(User).filter(
531+
User.user_id.like(f"{oidc_user_id}:%"),
532+
User.is_deleted == 1
533+
).order_by(User.id.asc())
534+
)
535+
legacy_users = list(legacy_result.scalars().all())
536+
if legacy_users:
537+
for placeholder in legacy_users:
538+
parts = placeholder.user_id.split(":")
539+
if len(parts) >= 3:
540+
try:
541+
target_user_id = int(parts[2])
542+
result = await db.execute(select(User).filter(User.id == target_user_id, User.is_deleted == 1))
543+
target_user = result.scalar_one_or_none()
544+
if target_user:
545+
return target_user
546+
except ValueError:
547+
continue
548+
return legacy_users[0]
549+
return None
550+
551+
552+
async def _create_oidc_binding_placeholder(db, sub: str, target_user: User) -> None:
553+
"""创建 OIDC sub 绑定占位用户(仅用于记录绑定关系,不用于登录)
554+
555+
在 use_raw_username 模式下,我们创建一个占位用户格式: oidc:{sub}:{target_user_id},
556+
占位用户标记为 is_deleted=1(不参与实际登录),仅用于存储绑定关系,
557+
find_user_by_oidc_sub 查询时会读取该占位记录并解析出绑定的真实用户,
558+
这样就能在不修改User表结构的前提下,保持绑定关系可验证,防止账号冒用。
559+
560+
使用传入的同一个 db session,避免跨session一致性问题。
561+
"""
562+
# 占位用户格式: oidc:{sub}:{target_user_id},这样find_user_by_oidc_sub可以解析出目标用户ID
563+
oidc_placeholder_id = f"oidc:{sub}:{target_user.id}"
564+
# 占位用户标记为 deleted,查询时需要特别包括deleted才能找到
565+
result = await db.execute(select(User).filter(User.user_id == oidc_placeholder_id, User.is_deleted.in_([0, 1])))
566+
if result.scalar_one_or_none():
567+
# 占位用户已存在,无需重复创建
568+
return
569+
570+
# 创建占位用户:使用随机密码,标记为deleted,不用于实际登录,仅存储绑定关系
571+
random_password = secrets.token_urlsafe(32)
572+
password_hash = AuthUtils.hash_password(random_password)
573+
574+
# username 使用 oidc-binding-{sub_hash} 避免冲突,sub_hash 基于完整 sub 生成
575+
import hashlib
576+
sub_hash = hashlib.sha256(sub.encode()).hexdigest()[:8]
577+
username = f"oidc-binding-{sub_hash}"
578+
579+
placeholder_user = User(
580+
username=username,
581+
user_id=oidc_placeholder_id,
582+
phone_number=None,
583+
avatar=None,
584+
password_hash=password_hash,
585+
role=target_user.role,
586+
department_id=target_user.department_id,
587+
is_deleted=1, # 标记为deleted,不参与实际登录
588+
last_login=utc_now_naive(),
443589
)
444-
return legacy_result.scalar_one_or_none()
590+
591+
try:
592+
db.add(placeholder_user)
593+
await db.commit()
594+
logger.info(f"Created OIDC binding placeholder (deleted) for sub {sub} -> user {target_user.id} ({target_user.user_id})")
595+
except IntegrityError:
596+
# 并发创建冲突,回滚后忽略
597+
await db.rollback()
598+
logger.info(f"OIDC binding placeholder already exists for sub {sub}")
445599

446600

447601
async def build_unique_oidc_username(db, preferred_username: str, sub: str) -> str:
@@ -478,7 +632,35 @@ async def create_oidc_user(db, user_info: dict, department_id: int | None = None
478632

479633
sub = user_info["sub"]
480634
preferred_username = user_info["name"] or user_info["username"]
481-
user_id = f"oidc:{sub}"
635+
636+
# 根据配置决定用户ID是否带oidc前缀
637+
if oidc_config.use_raw_username:
638+
user_id = user_info["username"]
639+
# 检查用户名是否已存在
640+
result = await db.execute(select(User).filter(User.user_id == user_id, User.is_deleted == 0))
641+
existing_user = result.scalar_one_or_none()
642+
if existing_user:
643+
# 用户已存在,必须验证当前sub是否已经绑定到这个用户
644+
# 如果sub未绑定该用户,不能直接复用,存在账号冒用风险
645+
user_by_sub = await find_user_by_oidc_sub(db, sub)
646+
if user_by_sub and user_by_sub.id == existing_user.id:
647+
# sub 已经正确绑定到该用户,允许返回
648+
logger.info(f"User with raw username {user_id} already exists and bound to sub {sub}, returning existing user")
649+
return existing_user
650+
elif user_by_sub is None:
651+
# sub 尚未绑定任何用户,可以将sub绑定到这个现有用户
652+
logger.info(f"Binding new OIDC sub {sub} to existing user with raw username {user_id}")
653+
await _create_oidc_binding_placeholder(db, sub, existing_user)
654+
return existing_user
655+
else:
656+
# sub 已经绑定到另一个用户,冲突,拒绝创建
657+
logger.warning(f"Cannot create OIDC user with raw username {user_id}: sub {sub} is already bound to another user {user_by_sub.id}, conflict")
658+
raise HTTPException(
659+
status_code=status.HTTP_409_CONFLICT,
660+
detail=f"用户名 {user_id} 已存在且OIDC标识 {sub} 已绑定到其他账号,请联系管理员处理冲突",
661+
)
662+
else:
663+
user_id = f"oidc:{sub}"
482664

483665
random_password = secrets.token_urlsafe(32)
484666
password_hash = AuthUtils.hash_password(random_password)
@@ -500,6 +682,11 @@ async def create_oidc_user(db, user_info: dict, department_id: int | None = None
500682
}
501683
)
502684
logger.info(f"Created OIDC user: {new_user.username} ({user_id})")
685+
686+
# use_raw_username 模式下,创建占位用户记录绑定关系
687+
if oidc_config.use_raw_username:
688+
await _create_oidc_binding_placeholder(db, sub, new_user)
689+
503690
return new_user
504691
except IntegrityError:
505692
existing_user = await find_user_by_oidc_sub(db, sub)
@@ -590,7 +777,50 @@ async def oidc_callback_handler(code: str, state: str, db, request: Request | No
590777
if not sub:
591778
return _redirect_to_login_with_error("无法获取用户标识,请返回登录页重试")
592779

593-
user = await find_user_by_oidc_sub(db, sub)
780+
# 查找用户:总是先通过 sub 查找,保证绑定关系可验证
781+
user_by_sub = await find_user_by_oidc_sub(db, sub)
782+
783+
if oidc_config.use_raw_username:
784+
# 使用原始用户名模式
785+
username = extracted_info["username"]
786+
user = None
787+
if username:
788+
result = await db.execute(select(User).filter(User.user_id == username, User.is_deleted == 0))
789+
user_by_name = result.scalar_one_or_none()
790+
791+
if user_by_sub:
792+
# sub 已经绑定到一个用户
793+
if user_by_name and user_by_sub.id == user_by_name.id:
794+
# sub 绑定的用户就是找到的用户名用户 -> 验证通过
795+
user = user_by_name
796+
logger.info(f"OIDC user logged in with raw username: {username} (sub: {sub})")
797+
else:
798+
# sub 已经绑定到另一个用户,存在冲突,拒绝登录
799+
conflict_name = user_by_sub.username if not user_by_name else user_by_name.username
800+
logger.warning(f"OIDC sub {sub} is already bound to a different user, login rejected to prevent account hijacking (conflict: {conflict_name})")
801+
return _redirect_to_login_with_error("OIDC标识已绑定到其他账号,请联系管理员处理绑定冲突")
802+
else:
803+
# sub 尚未绑定到任何用户
804+
if user_by_name:
805+
# 用户名存在,且 sub 没有绑定 -> 允许登录,并创建绑定记录
806+
# 在不修改表结构的情况下,我们创建一个占位用户 oidc:{sub} 来记录绑定关系
807+
# 这个占位用户不会被用来登录,仅用于存储sub -> 用户的绑定关系
808+
user = user_by_name
809+
logger.info(f"Binding new OIDC sub {sub} to existing user with raw username: {username}")
810+
# 创建绑定占位用户(后台静默创建,不影响现有用户)
811+
await _create_oidc_binding_placeholder(db, sub, user_by_name)
812+
else:
813+
# 用户名不存在,需要创建新用户
814+
if oidc_config.auto_create_user:
815+
user = None # 让后续逻辑创建
816+
else:
817+
return _redirect_to_login_with_error("用户不存在,请联系管理员开通账号")
818+
else:
819+
# 没有获取到 username,回退到按sub查找
820+
user = user_by_sub
821+
else:
822+
# 标准 OIDC 模式,通过 sub 查找
823+
user = user_by_sub
594824

595825
if user:
596826
await update_oidc_user_login(db, user)
@@ -601,7 +831,10 @@ async def oidc_callback_handler(code: str, state: str, db, request: Request | No
601831
user = await restore_deleted_oidc_user(db, deleted_user, extracted_info)
602832
logger.info(f"OIDC deleted user restored and logged in: {user.username}")
603833
else:
604-
dept = await get_or_create_oidc_department(db)
834+
# 从用户信息中获取部门信息
835+
dept_name = extracted_info.get("department_name")
836+
dept_desc = extracted_info.get("department_description")
837+
dept = await get_or_create_oidc_department(db, dept_name, dept_desc)
605838
department_id = dept.id if dept else None
606839
user = await create_oidc_user(db, extracted_info, department_id)
607840
else:

0 commit comments

Comments
 (0)