reorganization
This commit is contained in:
188
blueprints/users/__init__.py
Normal file
188
blueprints/users/__init__.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from quart import Blueprint, jsonify, request
|
||||
from quart_jwt_extended import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
jwt_refresh_token_required,
|
||||
get_jwt_identity,
|
||||
)
|
||||
from .models import User
|
||||
from .oidc_service import OIDCUserService
|
||||
from config.oidc_config import oidc_config
|
||||
import secrets
|
||||
import httpx
|
||||
from urllib.parse import urlencode
|
||||
import hashlib
|
||||
import base64
|
||||
|
||||
|
||||
user_blueprint = Blueprint("user_api", __name__, url_prefix="/api/user")
|
||||
|
||||
# In-memory storage for OIDC state/PKCE (production: use Redis or database)
|
||||
# Format: {state: {"pkce_verifier": str, "redirect_after_login": str}}
|
||||
_oidc_sessions = {}
|
||||
|
||||
|
||||
@user_blueprint.route("/oidc/login", methods=["GET"])
|
||||
async def oidc_login():
|
||||
"""
|
||||
Initiate OIDC login flow
|
||||
Generates PKCE parameters and redirects to Authelia
|
||||
"""
|
||||
if not oidc_config.validate_config():
|
||||
return jsonify({"error": "OIDC not configured"}), 500
|
||||
|
||||
try:
|
||||
# Generate PKCE parameters
|
||||
code_verifier = secrets.token_urlsafe(64)
|
||||
|
||||
# For PKCE, we need code_challenge = BASE64URL(SHA256(code_verifier))
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
|
||||
.decode()
|
||||
.rstrip("=")
|
||||
)
|
||||
|
||||
# Generate state for CSRF protection
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
# Store PKCE verifier and state for callback validation
|
||||
_oidc_sessions[state] = {
|
||||
"pkce_verifier": code_verifier,
|
||||
"redirect_after_login": request.args.get("redirect", "/"),
|
||||
}
|
||||
|
||||
# Get authorization endpoint from discovery
|
||||
discovery = await oidc_config.get_discovery_document()
|
||||
auth_endpoint = discovery.get("authorization_endpoint")
|
||||
|
||||
# Build authorization URL
|
||||
params = {
|
||||
"client_id": oidc_config.client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": oidc_config.redirect_uri,
|
||||
"scope": "openid email profile groups",
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
auth_url = f"{auth_endpoint}?{urlencode(params)}"
|
||||
|
||||
return jsonify({"auth_url": auth_url})
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"OIDC login failed: {str(e)}"}), 500
|
||||
|
||||
|
||||
@user_blueprint.route("/oidc/callback", methods=["GET"])
|
||||
async def oidc_callback():
|
||||
"""
|
||||
Handle OIDC callback from Authelia
|
||||
Exchanges authorization code for tokens, verifies ID token, and creates/updates user
|
||||
"""
|
||||
# Get authorization code and state from callback
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
|
||||
if error:
|
||||
return jsonify({"error": f"OIDC error: {error}"}), 400
|
||||
|
||||
if not code or not state:
|
||||
return jsonify({"error": "Missing code or state"}), 400
|
||||
|
||||
# Validate state and retrieve PKCE verifier
|
||||
session = _oidc_sessions.pop(state, None)
|
||||
if not session:
|
||||
return jsonify({"error": "Invalid or expired state"}), 400
|
||||
|
||||
pkce_verifier = session["pkce_verifier"]
|
||||
|
||||
# Exchange authorization code for tokens
|
||||
discovery = await oidc_config.get_discovery_document()
|
||||
token_endpoint = discovery.get("token_endpoint")
|
||||
|
||||
token_data = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": oidc_config.redirect_uri,
|
||||
"client_id": oidc_config.client_id,
|
||||
"client_secret": oidc_config.client_secret,
|
||||
"code_verifier": pkce_verifier,
|
||||
}
|
||||
|
||||
# Use client_secret_post method (credentials in POST body)
|
||||
async with httpx.AsyncClient() as client:
|
||||
token_response = await client.post(token_endpoint, data=token_data)
|
||||
|
||||
if token_response.status_code != 200:
|
||||
return jsonify(
|
||||
{"error": f"Failed to exchange code for token: {token_response.text}"}
|
||||
), 400
|
||||
|
||||
tokens = token_response.json()
|
||||
|
||||
id_token = tokens.get("id_token")
|
||||
if not id_token:
|
||||
return jsonify({"error": "No ID token received"}), 400
|
||||
|
||||
# Verify ID token
|
||||
try:
|
||||
claims = await oidc_config.verify_id_token(id_token)
|
||||
except Exception as e:
|
||||
return jsonify({"error": f"ID token verification failed: {str(e)}"}), 400
|
||||
|
||||
# Get or create user from OIDC claims
|
||||
user = await OIDCUserService.get_or_create_user_from_oidc(claims)
|
||||
|
||||
# Issue backend JWT tokens
|
||||
access_token = create_access_token(identity=str(user.id))
|
||||
refresh_token = create_refresh_token(identity=str(user.id))
|
||||
|
||||
# Return tokens to frontend
|
||||
# Frontend will handle storing these and redirecting
|
||||
return jsonify(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user={
|
||||
"id": str(user.id),
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"groups": user.ldap_groups,
|
||||
"is_admin": user.is_admin(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@user_blueprint.route("/refresh", methods=["POST"])
|
||||
@jwt_refresh_token_required
|
||||
async def refresh():
|
||||
"""Refresh access token (unchanged from original)"""
|
||||
user_id = get_jwt_identity()
|
||||
new_token = create_access_token(identity=user_id)
|
||||
return jsonify(access_token=new_token)
|
||||
|
||||
|
||||
# Legacy username/password login - kept for backward compatibility during migration
|
||||
@user_blueprint.route("/login", methods=["POST"])
|
||||
async def login():
|
||||
"""
|
||||
Legacy username/password login
|
||||
This can be removed after full OIDC migration is complete
|
||||
"""
|
||||
data = await request.get_json()
|
||||
username = data.get("username")
|
||||
password = data.get("password")
|
||||
|
||||
user = await User.filter(username=username).first()
|
||||
|
||||
if not user or not user.verify_password(password):
|
||||
return jsonify({"msg": "Invalid credentials"}), 401
|
||||
|
||||
access_token = create_access_token(identity=str(user.id))
|
||||
refresh_token = create_refresh_token(identity=str(user.id))
|
||||
|
||||
return jsonify(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user={"id": str(user.id), "username": user.username},
|
||||
)
|
||||
26
blueprints/users/decorators.py
Normal file
26
blueprints/users/decorators.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Authentication decorators for role-based access control.
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from quart import jsonify
|
||||
from quart_jwt_extended import jwt_refresh_token_required, get_jwt_identity
|
||||
from .models import User
|
||||
|
||||
|
||||
def admin_required(fn):
|
||||
"""
|
||||
Decorator that requires the user to be an admin (member of lldap_admin group).
|
||||
Must be used on async route handlers.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
@jwt_refresh_token_required
|
||||
async def wrapper(*args, **kwargs):
|
||||
user_id = get_jwt_identity()
|
||||
user = await User.get_or_none(id=user_id)
|
||||
if not user or not user.is_admin():
|
||||
return jsonify({"error": "Admin access required"}), 403
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
46
blueprints/users/models.py
Normal file
46
blueprints/users/models.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from tortoise.models import Model
|
||||
from tortoise import fields
|
||||
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
class User(Model):
|
||||
id = fields.UUIDField(primary_key=True)
|
||||
username = fields.CharField(max_length=255)
|
||||
password = fields.BinaryField(null=True) # Hashed - nullable for OIDC users
|
||||
email = fields.CharField(max_length=100, unique=True)
|
||||
|
||||
# OIDC fields
|
||||
oidc_subject = fields.CharField(
|
||||
max_length=255, unique=True, null=True, index=True
|
||||
) # "sub" claim from OIDC
|
||||
auth_provider = fields.CharField(
|
||||
max_length=50, default="local"
|
||||
) # "local" or "oidc"
|
||||
ldap_groups = fields.JSONField(default=[]) # LDAP groups from OIDC claims
|
||||
|
||||
created_at = fields.DatetimeField(auto_now_add=True)
|
||||
updated_at = fields.DatetimeField(auto_now=True)
|
||||
|
||||
class Meta:
|
||||
table = "users"
|
||||
|
||||
def has_group(self, group: str) -> bool:
|
||||
"""Check if user belongs to a specific LDAP group."""
|
||||
return group in (self.ldap_groups or [])
|
||||
|
||||
def is_admin(self) -> bool:
|
||||
"""Check if user is an admin (member of lldap_admin group)."""
|
||||
return self.has_group("lldap_admin")
|
||||
|
||||
def set_password(self, plain_password: str):
|
||||
self.password = bcrypt.hashpw(
|
||||
plain_password.encode("utf-8"),
|
||||
bcrypt.gensalt(),
|
||||
)
|
||||
|
||||
def verify_password(self, plain_password: str):
|
||||
if not self.password:
|
||||
return False
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), self.password)
|
||||
81
blueprints/users/oidc_service.py
Normal file
81
blueprints/users/oidc_service.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
OIDC User Management Service
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import uuid4
|
||||
from .models import User
|
||||
|
||||
|
||||
class OIDCUserService:
|
||||
"""Service for managing OIDC user authentication and provisioning"""
|
||||
|
||||
@staticmethod
|
||||
async def get_or_create_user_from_oidc(claims: Dict[str, Any]) -> User:
|
||||
"""
|
||||
Get existing user by OIDC subject, or create new user from OIDC claims
|
||||
|
||||
Args:
|
||||
claims: Decoded OIDC ID token claims
|
||||
|
||||
Returns:
|
||||
User object (existing or newly created)
|
||||
"""
|
||||
oidc_subject = claims.get("sub")
|
||||
if not oidc_subject:
|
||||
raise ValueError("No 'sub' claim in ID token")
|
||||
|
||||
# Try to find existing user by OIDC subject
|
||||
user = await User.filter(oidc_subject=oidc_subject).first()
|
||||
|
||||
if user:
|
||||
# Update user info from latest claims (optional)
|
||||
user.email = claims.get("email", user.email)
|
||||
user.username = (
|
||||
claims.get("preferred_username") or claims.get("name") or user.username
|
||||
)
|
||||
# Update LDAP groups from claims
|
||||
user.ldap_groups = claims.get("groups", [])
|
||||
await user.save()
|
||||
return user
|
||||
|
||||
# Check if user exists by email (migration case)
|
||||
email = claims.get("email")
|
||||
if email:
|
||||
user = await User.filter(email=email, auth_provider="local").first()
|
||||
if user:
|
||||
# Migrate existing local user to OIDC
|
||||
user.oidc_subject = oidc_subject
|
||||
user.auth_provider = "oidc"
|
||||
user.password = None # Clear password
|
||||
user.ldap_groups = claims.get("groups", [])
|
||||
await user.save()
|
||||
return user
|
||||
|
||||
# Create new user from OIDC claims
|
||||
username = (
|
||||
claims.get("preferred_username")
|
||||
or claims.get("name")
|
||||
or claims.get("email", "").split("@")[0]
|
||||
or f"user_{oidc_subject[:8]}"
|
||||
)
|
||||
|
||||
# Extract LDAP groups from claims
|
||||
groups = claims.get("groups", [])
|
||||
|
||||
user = await User.create(
|
||||
id=uuid4(),
|
||||
username=username,
|
||||
email=email or f"{oidc_subject}@oidc.local", # Fallback if no email claim
|
||||
oidc_subject=oidc_subject,
|
||||
auth_provider="oidc",
|
||||
password=None,
|
||||
ldap_groups=groups,
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def find_user_by_oidc_subject(oidc_subject: str) -> Optional[User]:
|
||||
"""Find user by OIDC subject ID"""
|
||||
return await User.filter(oidc_subject=oidc_subject).first()
|
||||
Reference in New Issue
Block a user