"""OAuth2 service for managing providers and state.
This module provides the central OAuthService class that coordinates OAuth
providers, manages state for security, and provides a high-level API for
OAuth operations.
"""
from __future__ import annotations
import secrets
from collections.abc import Mapping
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any
from litestar_oauth.exceptions import ExpiredStateError, InvalidStateError, ProviderNotConfiguredError
from litestar_oauth.types import OAuthState
if TYPE_CHECKING:
    # Needed only for type annotations. The module enables
    # ``from __future__ import annotations``, so annotations are never
    # evaluated at runtime and this import is skipped outside type checking.
    # NOTE(review): presumably guarded to avoid an import cycle with
    # ``litestar_oauth.base`` -- confirm against that module.
    from litestar_oauth.base import OAuthProvider
# Public names exported by a star-import of this module.
__all__ = ("OAuthService", "OAuthStateManager")
class OAuthStateManager:
    """In-memory state manager for OAuth2 CSRF protection.

    This manager generates, stores, and validates state tokens used in OAuth
    flows to prevent CSRF attacks. States live in a plain dictionary on the
    instance. Each state expires after its own TTL: the ``ttl`` passed to
    :meth:`generate_state` when one was given, otherwise the manager's current
    ``default_ttl``.

    Nothing removes expired entries automatically. An expired state is dropped
    when it is validated after expiry, or in bulk by :meth:`cleanup_expired`.

    Attributes:
        default_ttl: Default time-to-live for state tokens in seconds.
    """

    def __init__(self, default_ttl: int = 600) -> None:
        """Initialize the state manager.

        Args:
            default_ttl: Default lifetime for state tokens in seconds.
                Defaults to 600 (10 minutes).
        """
        self.default_ttl = default_ttl
        self._states: dict[str, OAuthState] = {}
        # Per-state TTL overrides in seconds, keyed by state string. A state
        # without an entry here falls back to ``default_ttl`` at check time.
        self._ttls: dict[str, int] = {}

    def generate_state(
        self,
        provider: str,
        redirect_uri: str,
        next_url: str | None = None,
        extra_data: dict[str, Any] | None = None,
        ttl: int | None = None,
    ) -> OAuthState:
        """Generate and store a new OAuth state token.

        The token is produced by ``secrets.token_urlsafe(32)``.

        Args:
            provider: Name of the OAuth provider.
            redirect_uri: URI for OAuth callback.
            next_url: Optional URL to redirect to after authentication.
            extra_data: Optional additional data to store with state.
            ttl: Optional custom TTL for this state in seconds. When omitted,
                the state expires according to ``default_ttl``.

        Returns:
            The generated OAuth state object, which is also kept in storage.
        """
        state_string = secrets.token_urlsafe(32)
        oauth_state = OAuthState(
            state=state_string,
            provider=provider,
            redirect_uri=redirect_uri,
            next_url=next_url,
            extra_data=extra_data or {},
        )
        self._states[state_string] = oauth_state
        # Only explicit overrides are recorded, so states created without a
        # custom TTL keep following ``default_ttl``.
        if ttl is not None:
            self._ttls[state_string] = ttl
        # Note: In production, consider using a background task or cache with
        # TTL. This is a simple in-memory implementation with no scheduled
        # cleanup; see ``cleanup_expired``.
        return oauth_state

    def validate_state(self, state: str, provider: str | None = None) -> OAuthState:
        """Validate and retrieve an OAuth state without removing it.

        Args:
            state: State string to validate.
            provider: Optional provider name to verify against.

        Returns:
            The validated OAuth state object.

        Raises:
            InvalidStateError: If state is not found or provider doesn't match.
            ExpiredStateError: If state has exceeded its TTL. The expired
                state is removed from storage before raising.
        """
        oauth_state = self._states.get(state)
        if oauth_state is None:
            raise InvalidStateError(f"State not found: {state}")
        if self._is_expired(state, oauth_state, datetime.now(timezone.utc)):
            self._discard(state)
            raise ExpiredStateError(f"State has expired: {state}")
        # A provider mismatch leaves the state in storage.
        if provider is not None and oauth_state.provider != provider:
            raise InvalidStateError(f"Provider mismatch: expected {provider}, got {oauth_state.provider}")
        return oauth_state

    def consume_state(self, state: str, provider: str | None = None) -> OAuthState:
        """Validate and remove an OAuth state (one-time use).

        The state is removed only when validation succeeds.

        Args:
            state: State string to consume.
            provider: Optional provider name to verify against.

        Returns:
            The validated OAuth state object.

        Raises:
            InvalidStateError: If state is not found or provider doesn't match.
            ExpiredStateError: If state has exceeded its TTL.
        """
        oauth_state = self.validate_state(state, provider)
        self._discard(state)
        return oauth_state

    def cleanup_expired(self) -> int:
        """Remove all expired states from storage.

        Returns:
            Number of expired states removed.
        """
        now = datetime.now(timezone.utc)
        # Collect the keys first so the dictionary is not mutated mid-iteration.
        expired_states = [
            state_str
            for state_str, oauth_state in self._states.items()
            if self._is_expired(state_str, oauth_state, now)
        ]
        for state_str in expired_states:
            self._discard(state_str)
        return len(expired_states)

    def clear(self) -> None:
        """Remove all states from storage.

        This is primarily useful for testing or application shutdown.
        """
        self._states.clear()
        self._ttls.clear()

    def _is_expired(self, state: str, oauth_state: OAuthState, now: datetime) -> bool:
        """Return whether ``oauth_state`` is older than its effective TTL at ``now``."""
        # NOTE(review): assumes ``OAuthState.created_at`` is a timezone-aware
        # datetime; subtracting a naive value from ``now`` would raise TypeError.
        ttl = self._ttls.get(state, self.default_ttl)
        return (now - oauth_state.created_at).total_seconds() > ttl

    def _discard(self, state: str) -> None:
        """Drop ``state`` and any TTL override recorded for it."""
        self._states.pop(state, None)
        self._ttls.pop(state, None)
class OAuthService:
    """Central service for OAuth2 provider management and operations.

    This service manages multiple OAuth providers, handles state management for
    security, and provides a unified API for OAuth operations.

    Attributes:
        providers: Registry of configured OAuth providers, keyed by name.
        state_manager: Manager for OAuth state tokens.
    """

    def __init__(
        self,
        providers: Mapping[str, OAuthProvider] | None = None,
        state_manager: OAuthStateManager | None = None,
    ) -> None:
        """Initialize the OAuth service.

        Args:
            providers: Optional mapping of provider names to provider
                instances. The mapping is copied into a new dict.
            state_manager: Optional custom state manager. If not provided,
                a default ``OAuthStateManager`` is created.
        """
        self.providers: dict[str, OAuthProvider] = dict(providers) if providers else {}
        # Compare against None rather than relying on truthiness: a custom
        # manager that defines ``__len__`` or ``__bool__`` can be falsy while
        # empty and must not be silently replaced by the default.
        self.state_manager = OAuthStateManager() if state_manager is None else state_manager

    def register(self, provider: OAuthProvider) -> None:
        """Register an OAuth provider with the service.

        The provider is stored under its ``provider_name``.

        Args:
            provider: Provider instance to register.

        Raises:
            ValueError: If a provider with the same name is already registered.
        """
        name = provider.provider_name
        if name in self.providers:
            raise ValueError(f"Provider '{name}' is already registered")
        self.providers[name] = provider

    def get_provider(self, provider_name: str) -> OAuthProvider:
        """Retrieve a registered OAuth provider.

        Args:
            provider_name: Name of the provider to retrieve.

        Returns:
            The requested OAuth provider instance.

        Raises:
            ProviderNotConfiguredError: If the provider is not registered, or
                is registered but its ``is_configured()`` returns a falsy value.
        """
        provider = self.providers.get(provider_name)
        if provider is None:
            raise ProviderNotConfiguredError(
                f"Provider '{provider_name}' is not configured. Available providers: {', '.join(self.list_providers())}"
            )
        if not provider.is_configured():
            raise ProviderNotConfiguredError(f"Provider '{provider_name}' is registered but not properly configured")
        return provider

    def list_providers(self) -> list[str]:
        """Get names of all registered providers.

        Returns:
            List of provider names currently registered, in registry order.
        """
        return list(self.providers)

    async def get_authorization_url(
        self,
        provider_name: str,
        redirect_uri: str,
        next_url: str | None = None,
        extra_data: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> str:
        """Generate an authorization URL for a provider.

        This is the first step in the OAuth flow. A state token is generated
        and stored through the state manager, then passed to the provider so
        the resulting URL carries it for CSRF protection.

        Args:
            provider_name: Name of the OAuth provider to use.
            redirect_uri: URI where the provider should redirect after authorization.
            next_url: Optional URL to redirect to after successful authentication.
            extra_data: Optional additional data to preserve across the OAuth flow.
            **kwargs: Additional provider-specific parameters, forwarded
                unchanged to the provider's ``get_authorization_url``.

        Returns:
            Complete authorization URL to redirect the user to.

        Raises:
            ProviderNotConfiguredError: If the provider is not available.
        """
        # Resolve the provider first so no state is stored for an unknown or
        # unconfigured provider.
        provider = self.get_provider(provider_name)
        oauth_state = self.state_manager.generate_state(
            provider=provider_name,
            redirect_uri=redirect_uri,
            next_url=next_url,
            extra_data=extra_data,
        )
        return await provider.get_authorization_url(
            redirect_uri=redirect_uri,
            state=oauth_state.state,
            **kwargs,
        )

    @classmethod
    def from_config(
        cls,
        providers_config: Mapping[str, Mapping[str, Any]],
        provider_classes: Mapping[str, type[OAuthProvider]],
        state_ttl: int = 600,
    ) -> OAuthService:
        """Create an OAuthService from configuration dictionaries.

        This factory method simplifies service setup when configuration comes
        from files, environment variables, or other sources. Each provider is
        built by calling its class with the configuration mapping expanded as
        keyword arguments. Names in ``providers_config`` that have no entry in
        ``provider_classes`` are skipped without error.

        Args:
            providers_config: Mapping of provider names to their configuration.
                Each config should include 'client_id', 'client_secret', and
                optional 'scope'.
            provider_classes: Mapping of provider names to their implementation classes.
            state_ttl: Time-to-live for state tokens in seconds.

        Returns:
            Configured OAuthService instance.

        Example::

            from litestar_oauth.providers import GoogleProvider, GitHubProvider

            config = {
                "google": {
                    "client_id": "your-client-id",
                    "client_secret": "your-secret",
                    "scope": ["openid", "email", "profile"],
                },
                "github": {
                    "client_id": "your-client-id",
                    "client_secret": "your-secret",
                },
            }
            classes = {
                "google": GoogleProvider,
                "github": GitHubProvider,
            }
            service = OAuthService.from_config(config, classes)
        """
        providers: dict[str, OAuthProvider] = {
            provider_name: provider_classes[provider_name](**provider_config)
            for provider_name, provider_config in providers_config.items()
            if provider_name in provider_classes
        }
        state_manager = OAuthStateManager(default_ttl=state_ttl)
        return cls(providers=providers, state_manager=state_manager)