"""Mock implementations for testing OAuth providers and services.

This module provides configurable mock objects that tests can use to simulate
OAuth provider behavior without making actual HTTP requests.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock
if TYPE_CHECKING:
from collections.abc import Mapping
@dataclass
class MockHTTPResponse:
    """Canned HTTP response used when testing OAuth provider HTTP interactions.

    The object carries a status code, headers and three independent body
    representations (JSON, text, raw bytes). Nothing is derived from anything
    else: each accessor simply hands back the value that was configured.
    Note that ``text`` is a method here, so call ``response.text()``.

    Args:
        status_code: HTTP status code (default: 200)
        headers: Response headers (default: empty dict)
        json_data: JSON response body (default: empty dict)
        text_data: Text response body (default: empty string)
        content: Raw bytes response body (default: empty bytes)

    Example:
        >>> response = MockHTTPResponse(status_code=200, json_data={"access_token": "mock_token"})
        >>> assert response.json() == {"access_token": "mock_token"}
    """

    status_code: int = 200
    # default_factory gives every instance its own dict instead of a shared one.
    headers: dict[str, str] = field(default_factory=dict)
    json_data: dict[str, Any] = field(default_factory=dict)
    text_data: str = ""
    content: bytes = b""

    def json(self) -> dict[str, Any]:
        """Return the configured JSON body (the same dict object, not a copy)."""
        body = self.json_data
        return body

    def text(self) -> str:
        """Return the configured text body."""
        body = self.text_data
        return body

    def raise_for_status(self) -> None:
        """Raise for error status codes; do nothing otherwise.

        Raises:
            Exception: If ``status_code`` is 400 or greater. The message is
                ``"HTTP <status_code>"``.
        """
        # Guard clause: anything below 400 counts as success.
        if self.status_code < 400:
            return
        raise Exception(f"HTTP {self.status_code}")
@dataclass
class MockOAuthProvider:
    """Configurable mock OAuth provider for testing.

    The mock exposes the methods an OAuth provider is expected to offer
    (authorization URL, code exchange, token refresh, user info, revocation)
    and returns canned data, so OAuth flows can be tested without any network
    access. It is intended to stand in for the project's ``OAuthProvider``
    protocol; the protocol itself is not imported here.

    Args:
        provider_name: Provider identifier (default: "mock")
        authorize_url: Authorization endpoint URL
        token_url: Token exchange endpoint URL
        user_info_url: User info endpoint URL
        scope: OAuth scopes to request
        configured: Whether provider has valid configuration
        access_token: Mock access token to return
        mock_refresh_token: Mock refresh token to return (named this way
            because ``refresh_token`` is a method of this class)
        user_info: Mock user info to return; when ``None`` a built-in default
            user object is returned instead
        raise_on_exchange: If True, raise exception during code exchange
        raise_on_refresh: If True, raise exception during token refresh
        raise_on_user_info: If True, raise exception when fetching user info

    Example:
        >>> from litestar_oauth.types import OAuthUserInfo
        >>> provider = MockOAuthProvider(
        ...     provider_name="github",
        ...     access_token="gho_mock_token",
        ...     user_info=OAuthUserInfo(
        ...         provider="github",
        ...         oauth_id="12345",
        ...         email="test@example.com",
        ...     ),
        ... )
        >>> url = provider.get_authorization_url("http://localhost/callback", "state123")
        >>> assert "state=state123" in url
    """

    provider_name: str = "mock"
    authorize_url: str = "https://oauth.mock/authorize"
    token_url: str = "https://oauth.mock/token"
    user_info_url: str = "https://oauth.mock/userinfo"
    scope: str = "user:email"
    configured: bool = True
    access_token: str = "mock_access_token"
    mock_refresh_token: str | None = "mock_refresh_token"
    user_info: Any = None  # Should be OAuthUserInfo but avoiding import
    raise_on_exchange: bool = False
    raise_on_refresh: bool = False
    raise_on_user_info: bool = False
    # Not read by any method of this class; kept as a constructor field.
    _http_client: AsyncMock | None = None

    def get_authorization_url(
        self,
        redirect_uri: str,
        state: str,
        *,
        scope: str | None = None,
        extra_params: dict[str, str] | None = None,
    ) -> str:
        """Generate OAuth authorization URL.

        Args:
            redirect_uri: Callback URL after authorization
            state: CSRF protection token
            scope: Optional custom scopes (defaults to provider scope)
            extra_params: Additional query parameters; these override the
                standard parameters on key collisions

        Returns:
            Full authorization URL with a percent-encoded query string
        """
        from urllib.parse import urlencode

        params = {
            "client_id": "mock_client_id",
            "redirect_uri": redirect_uri,
            "state": state,
            "scope": scope or self.scope,
            "response_type": "code",
        }
        if extra_params:
            params.update(extra_params)
        # urlencode escapes reserved characters, so a redirect URI that carries
        # its own query string (or a scope list with spaces) cannot corrupt
        # the generated URL.
        return f"{self.authorize_url}?{urlencode(params)}"

    def _build_token(self) -> Any:
        """Build the canned token object shared by code exchange and refresh."""
        # datetime is only needed for the timestamp below, so the import is
        # kept function-local.
        from datetime import datetime, timezone

        # Simple attribute bag standing in for the project's OAuthToken type.
        class MockToken:
            def __init__(self, data: dict[str, Any]) -> None:
                self.access_token = data["access_token"]
                self.token_type = data["token_type"]
                self.expires_in = data["expires_in"]
                self.refresh_token = data.get("refresh_token")
                self.scope = data.get("scope")
                self.id_token = data.get("id_token")
                self.raw_response = data.get("raw_response", {})

        token_data: dict[str, Any] = {
            "access_token": self.access_token,
            "token_type": "Bearer",
            "expires_in": 3600,
            "refresh_token": self.mock_refresh_token,
            "scope": self.scope,
            "raw_response": {
                "access_token": self.access_token,
                "token_type": "Bearer",
                "expires_in": 3600,
                "created_at": int(datetime.now(timezone.utc).timestamp()),
            },
        }
        return MockToken(token_data)

    async def exchange_code(
        self,
        code: str,
        redirect_uri: str,
    ) -> Any:  # Should return OAuthToken
        """Exchange authorization code for access token.

        Args:
            code: Authorization code from OAuth callback (ignored by the mock)
            redirect_uri: Callback URL (ignored by the mock)

        Returns:
            Token-like object exposing ``access_token``, ``token_type``,
            ``expires_in``, ``refresh_token``, ``scope``, ``id_token`` and
            ``raw_response`` attributes

        Raises:
            Exception: If raise_on_exchange is True
        """
        if self.raise_on_exchange:
            msg = "Mock token exchange failure"
            raise Exception(msg)
        return self._build_token()

    async def refresh_token(
        self,
        refresh_token: str,
    ) -> Any:  # Should return OAuthToken
        """Refresh an expired access token.

        Args:
            refresh_token: Refresh token from original token exchange
                (ignored by the mock)

        Returns:
            Token-like object with the same shape as ``exchange_code`` returns

        Raises:
            Exception: If raise_on_refresh is True
        """
        if self.raise_on_refresh:
            msg = "Mock token refresh failure"
            raise Exception(msg)
        # Build the token directly rather than via exchange_code, so that
        # raise_on_exchange only affects code exchange, as documented.
        return self._build_token()

    async def get_user_info(
        self,
        access_token: str,
    ) -> Any:  # Should return OAuthUserInfo
        """Fetch user profile information.

        Args:
            access_token: Valid OAuth access token (ignored by the mock)

        Returns:
            The configured ``user_info`` object, or a default mock user when
            none was configured

        Raises:
            Exception: If raise_on_user_info is True
        """
        if self.raise_on_user_info:
            msg = "Mock user info fetch failure"
            raise Exception(msg)
        # Compare against None so a configured-but-falsy object is still used.
        if self.user_info is not None:
            return self.user_info

        # Default user: attribute bag standing in for OAuthUserInfo.
        class MockUserInfo:
            def __init__(self) -> None:
                self.provider = "mock"
                self.oauth_id = "mock_user_123"
                self.email = "mock@example.com"
                self.email_verified = True
                self.username = "mockuser"
                self.first_name = "Mock"
                self.last_name = "User"
                self.avatar_url = "https://example.com/avatar.jpg"
                self.profile_url = "https://example.com/mockuser"
                self.raw_data: dict[str, Any] = {}

        return MockUserInfo()

    async def revoke_token(
        self,
        token: str,
        *,
        token_type_hint: str = "access_token",
    ) -> None:
        """Revoke an access or refresh token.

        The mock implementation does nothing and always succeeds.

        Args:
            token: Token to revoke
            token_type_hint: Type of token (access_token or refresh_token)
        """
class MockOAuthService:
    """Pre-configured mock OAuth service for testing.

    This class provides a complete mock OAuth service with in-memory state
    management and provider registration capabilities for integration testing.

    Args:
        providers: Optional mapping of provider names to MockOAuthProvider instances

    Example:
        >>> from litestar_oauth.testing.mocks import MockOAuthService
        >>> service = MockOAuthService()
        >>> await service.register_mock_provider("github")
        >>> state = await service.create_state("github", "http://localhost/callback")
        >>> assert state is not None
    """

    def __init__(
        self,
        providers: Mapping[str, MockOAuthProvider] | None = None,
    ) -> None:
        """Initialize mock OAuth service."""
        # Copy the mapping so later registrations never mutate the caller's dict.
        self._providers: dict[str, MockOAuthProvider] = dict(providers) if providers else {}
        # Pending state tokens mapped to the data recorded at creation time.
        self._states: dict[str, dict[str, Any]] = {}

    def register(self, provider: MockOAuthProvider) -> None:
        """Register a mock OAuth provider.

        A provider registered under an existing name replaces the previous one.

        Args:
            provider: MockOAuthProvider instance to register
        """
        self._providers[provider.provider_name] = provider

    def get_provider(self, provider_name: str) -> MockOAuthProvider | None:
        """Get a registered provider by name.

        Args:
            provider_name: Name of the provider to retrieve

        Returns:
            MockOAuthProvider instance or None if not found
        """
        return self._providers.get(provider_name)

    def _require_provider(self, provider_name: str) -> MockOAuthProvider:
        """Return the named provider, raising ValueError if it is not registered."""
        provider = self.get_provider(provider_name)
        if provider is None:
            msg = f"Provider {provider_name} not registered"
            raise ValueError(msg)
        return provider

    async def create_state(
        self,
        provider: str,
        redirect_uri: str,
        *,
        next_url: str | None = None,
        ttl: int = 600,
    ) -> str:
        """Create a new OAuth state token.

        Args:
            provider: Provider name for this state
            redirect_uri: Callback URL
            next_url: Optional URL to redirect to after OAuth flow
            ttl: Time-to-live in seconds (ignored in mock)

        Returns:
            State token string
        """
        import secrets

        state = secrets.token_urlsafe(32)
        self._states[state] = {
            "provider": provider,
            "redirect_uri": redirect_uri,
            "next_url": next_url,
            # Fixed timestamp: the mock never expires states.
            "created_at": "2024-01-01T00:00:00Z",
        }
        return state

    async def validate_state(self, state: str) -> dict[str, Any] | None:
        """Validate and consume a state token.

        The token is removed on lookup, so each state validates at most once.

        Args:
            state: State token to validate

        Returns:
            State data dict or None if invalid
        """
        return self._states.pop(state, None)

    async def exchange_code(
        self,
        provider_name: str,
        code: str,
        redirect_uri: str,
    ) -> Any:  # Should return OAuthToken
        """Exchange authorization code for access token.

        Args:
            provider_name: Name of the OAuth provider
            code: Authorization code from callback
            redirect_uri: Callback URL (must match initial request)

        Returns:
            Whatever the provider's ``exchange_code`` returns

        Raises:
            ValueError: If provider not found
        """
        provider = self._require_provider(provider_name)
        return await provider.exchange_code(code, redirect_uri)

    async def get_user_info(
        self,
        provider_name: str,
        access_token: str,
    ) -> Any:  # Should return OAuthUserInfo
        """Fetch user information from provider.

        Args:
            provider_name: Name of the OAuth provider
            access_token: Valid access token

        Returns:
            Whatever the provider's ``get_user_info`` returns

        Raises:
            ValueError: If provider not found
        """
        provider = self._require_provider(provider_name)
        return await provider.get_user_info(access_token)

    async def register_mock_provider(
        self,
        provider_name: str,
        **kwargs: Any,
    ) -> MockOAuthProvider:
        """Convenience method to create and register a mock provider.

        Args:
            provider_name: Name for the provider
            **kwargs: Additional arguments for MockOAuthProvider

        Returns:
            Created and registered MockOAuthProvider instance
        """
        provider = MockOAuthProvider(provider_name=provider_name, **kwargs)
        self.register(provider)
        return provider
# Public API of this module: the names exported by a star-import.
__all__ = ["MockHTTPResponse", "MockOAuthProvider", "MockOAuthService"]