This commit is contained in:
Ryan Chen
2025-12-25 07:36:26 -08:00
parent f5e2d68cd2
commit 913875188a
18 changed files with 799 additions and 219 deletions

View File

@@ -6,13 +6,161 @@ from quart_jwt_extended import (
get_jwt_identity,
)
from .models import User
from .oidc_service import OIDCUserService
from oidc_config import oidc_config
import secrets
import httpx
from urllib.parse import urlencode
import hashlib
import base64
user_blueprint = Blueprint("user_api", __name__, url_prefix="/api/user")
# In-memory storage for OIDC state/PKCE (production: use Redis or database)
# Format: {state: {"pkce_verifier": str, "redirect_after_login": str}}
_oidc_sessions = {}
@user_blueprint.route("/oidc/login", methods=["GET"])
async def oidc_login():
"""
Initiate OIDC login flow
Generates PKCE parameters and redirects to Authelia
"""
if not oidc_config.validate_config():
return jsonify({"error": "OIDC not configured"}), 500
try:
# Generate PKCE parameters
code_verifier = secrets.token_urlsafe(64)
# For PKCE, we need code_challenge = BASE64URL(SHA256(code_verifier))
code_challenge = (
base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
.decode()
.rstrip("=")
)
# Generate state for CSRF protection
state = secrets.token_urlsafe(32)
# Store PKCE verifier and state for callback validation
_oidc_sessions[state] = {
"pkce_verifier": code_verifier,
"redirect_after_login": request.args.get("redirect", "/"),
}
# Get authorization endpoint from discovery
discovery = await oidc_config.get_discovery_document()
auth_endpoint = discovery.get("authorization_endpoint")
# Build authorization URL
params = {
"client_id": oidc_config.client_id,
"response_type": "code",
"redirect_uri": oidc_config.redirect_uri,
"scope": "openid email profile",
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
auth_url = f"{auth_endpoint}?{urlencode(params)}"
return jsonify({"auth_url": auth_url})
except Exception as e:
return jsonify({"error": f"OIDC login failed: {str(e)}"}), 500
@user_blueprint.route("/oidc/callback", methods=["GET"])
async def oidc_callback():
"""
Handle OIDC callback from Authelia
Exchanges authorization code for tokens, verifies ID token, and creates/updates user
"""
# Get authorization code and state from callback
code = request.args.get("code")
state = request.args.get("state")
error = request.args.get("error")
if error:
return jsonify({"error": f"OIDC error: {error}"}), 400
if not code or not state:
return jsonify({"error": "Missing code or state"}), 400
# Validate state and retrieve PKCE verifier
session = _oidc_sessions.pop(state, None)
if not session:
return jsonify({"error": "Invalid or expired state"}), 400
pkce_verifier = session["pkce_verifier"]
# Exchange authorization code for tokens
discovery = await oidc_config.get_discovery_document()
token_endpoint = discovery.get("token_endpoint")
token_data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": oidc_config.redirect_uri,
"client_id": oidc_config.client_id,
"client_secret": oidc_config.client_secret,
"code_verifier": pkce_verifier,
}
# Use client_secret_post method (credentials in POST body)
async with httpx.AsyncClient() as client:
token_response = await client.post(token_endpoint, data=token_data)
if token_response.status_code != 200:
return jsonify({"error": f"Failed to exchange code for token: {token_response.text}"}), 400
tokens = token_response.json()
id_token = tokens.get("id_token")
if not id_token:
return jsonify({"error": "No ID token received"}), 400
# Verify ID token
try:
claims = await oidc_config.verify_id_token(id_token)
except Exception as e:
return jsonify({"error": f"ID token verification failed: {str(e)}"}), 400
# Get or create user from OIDC claims
user = await OIDCUserService.get_or_create_user_from_oidc(claims)
# Issue backend JWT tokens
access_token = create_access_token(identity=str(user.id))
refresh_token = create_refresh_token(identity=str(user.id))
# Return tokens to frontend
# Frontend will handle storing these and redirecting
return jsonify(
access_token=access_token,
refresh_token=refresh_token,
user={"id": str(user.id), "username": user.username, "email": user.email},
)
@user_blueprint.route("/refresh", methods=["POST"])
@jwt_refresh_token_required
async def refresh():
"""Refresh access token (unchanged from original)"""
user_id = get_jwt_identity()
new_token = create_access_token(identity=user_id)
return jsonify(access_token=new_token)
# Legacy username/password login - kept for backward compatibility during migration
@user_blueprint.route("/login", methods=["POST"])
async def login():
"""
Legacy username/password login
This can be removed after full OIDC migration is complete
"""
data = await request.get_json()
username = data.get("username")
password = data.get("password")
@@ -28,13 +176,5 @@ async def login():
return jsonify(
access_token=access_token,
refresh_token=refresh_token,
user={"id": user.id, "username": user.username},
user={"id": str(user.id), "username": user.username},
)
@user_blueprint.route("/refresh", methods=["POST"])
@jwt_refresh_token_required
async def refresh():
user_id = get_jwt_identity()
new_token = create_access_token(identity=user_id)
return jsonify(access_token=new_token)

View File

@@ -8,8 +8,13 @@ import bcrypt
class User(Model):
id = fields.UUIDField(primary_key=True)
username = fields.CharField(max_length=255)
password = fields.BinaryField() # Hashed
password = fields.BinaryField(null=True) # Hashed - nullable for OIDC users
email = fields.CharField(max_length=100, unique=True)
# OIDC fields
oidc_subject = fields.CharField(max_length=255, unique=True, null=True, index=True) # "sub" claim from OIDC
auth_provider = fields.CharField(max_length=50, default="local") # "local" or "oidc"
created_at = fields.DatetimeField(auto_now_add=True)
updated_at = fields.DatetimeField(auto_now=True)
@@ -23,4 +28,6 @@ class User(Model):
)
def verify_password(self, plain_password: str):
if not self.password:
return False
return bcrypt.checkpw(plain_password.encode("utf-8"), self.password)

View File

@@ -0,0 +1,76 @@
"""
OIDC User Management Service
"""
from typing import Dict, Any, Optional
from uuid import uuid4
from .models import User
class OIDCUserService:
"""Service for managing OIDC user authentication and provisioning"""
@staticmethod
async def get_or_create_user_from_oidc(claims: Dict[str, Any]) -> User:
"""
Get existing user by OIDC subject, or create new user from OIDC claims
Args:
claims: Decoded OIDC ID token claims
Returns:
User object (existing or newly created)
"""
oidc_subject = claims.get("sub")
if not oidc_subject:
raise ValueError("No 'sub' claim in ID token")
# Try to find existing user by OIDC subject
user = await User.filter(oidc_subject=oidc_subject).first()
if user:
# Update user info from latest claims (optional)
user.email = claims.get("email", user.email)
user.username = (
claims.get("preferred_username")
or claims.get("name")
or user.username
)
await user.save()
return user
# Check if user exists by email (migration case)
email = claims.get("email")
if email:
user = await User.filter(email=email, auth_provider="local").first()
if user:
# Migrate existing local user to OIDC
user.oidc_subject = oidc_subject
user.auth_provider = "oidc"
user.password = None # Clear password
await user.save()
return user
# Create new user from OIDC claims
username = (
claims.get("preferred_username")
or claims.get("name")
or claims.get("email", "").split("@")[0]
or f"user_{oidc_subject[:8]}"
)
user = await User.create(
id=uuid4(),
username=username,
email=email
or f"{oidc_subject}@oidc.local", # Fallback if no email claim
oidc_subject=oidc_subject,
auth_provider="oidc",
password=None,
)
return user
@staticmethod
async def find_user_by_oidc_subject(oidc_subject: str) -> Optional[User]:
"""Find user by OIDC subject ID"""
return await User.filter(oidc_subject=oidc_subject).first()