# Source code for pytest_routes.discovery.starlette
"""Starlette/FastAPI route extraction."""
from __future__ import annotations
import inspect
import re
from typing import Any, get_origin, get_type_hints
from pytest_routes.discovery.base import RouteExtractor, RouteInfo
[docs]
class StarletteExtractor(RouteExtractor):
    """Extract routes from Starlette and FastAPI applications.

    This extractor provides route discovery for Starlette-based frameworks,
    including vanilla Starlette and FastAPI applications. It handles route
    mounts, path parameter parsing, and query parameter extraction.

    The extractor supports:
        - Recursive route collection through Mount instances
        - Path parameter extraction with type conversion (``int`` and ``float``
          convertors; any other placeholder, including ``path``, maps to ``str``)
        - Query parameter detection from endpoint signatures
        - FastAPI-specific handling (BaseModel body parameters are excluded
          from the query parameters)
        - Starlette request/response parameter filtering

    Example:
        >>> from starlette.applications import Starlette
        >>> from starlette.responses import JSONResponse
        >>> from starlette.routing import Route
        >>>
        >>> async def get_user(request):
        ...     user_id = request.path_params["user_id"]
        ...     return JSONResponse({"id": user_id})
        >>>
        >>> app = Starlette(routes=[Route("/users/{user_id:int}", get_user, methods=["GET"])])
        >>> extractor = StarletteExtractor()
        >>> routes = extractor.extract_routes(app)
        >>> routes[0].path_params
        {'user_id': <class 'int'>}

    Note:
        - HEAD methods are automatically filtered out
        - Mount instances are recursively traversed with path prefix accumulation
        - FastAPI BaseModel parameters are detected and skipped from query params
        - One RouteInfo is produced per HTTP method of a route
    """
[docs]
def supports(self, app: Any) -> bool:
    """Report whether *app* is a Starlette or FastAPI application.

    Each framework is imported lazily and probed on its own, so a missing
    framework simply means that check is skipped instead of raising.

    Args:
        app: The ASGI application to check.

    Returns:
        True if the app is a Starlette or FastAPI instance, False otherwise
        (including when neither framework is installed).
    """
    try:
        from starlette.applications import Starlette
    except ImportError:
        # Starlette is not installed; fall through to the FastAPI probe.
        pass
    else:
        if isinstance(app, Starlette):
            return True

    try:
        from fastapi import FastAPI
    except ImportError:
        # Neither check could match.
        return False
    return isinstance(app, FastAPI)
[docs]
def extract_routes(self, app: Any) -> list[RouteInfo]:
    """Extract all HTTP and WebSocket routes from a Starlette or FastAPI application.

    The application's route list is walked recursively: Mount instances are
    descended into with their path prepended, and every HTTP or WebSocket
    route found is turned into a RouteInfo.

    Args:
        app: A Starlette or FastAPI application instance (anything exposing
            a ``routes`` attribute understood by the collector).

    Returns:
        A list of RouteInfo objects. For each entry:
        ``path`` is the full route path including mount prefixes;
        ``methods`` holds a single HTTP method, or ``"WEBSOCKET"``;
        ``name`` and ``handler`` come from the route's name and endpoint;
        ``path_params`` maps names parsed from the path to types;
        ``query_params`` maps names taken from the endpoint signature to types;
        ``body_type`` is always None for this extractor;
        ``is_websocket`` / ``websocket_metadata`` are set for WebSocket routes.

    Example:
        >>> from fastapi import FastAPI, Query, WebSocket
        >>> from pydantic import BaseModel
        >>>
        >>> app = FastAPI()
        >>>
        >>> class User(BaseModel):
        ...     name: str
        ...     email: str
        >>>
        >>> @app.get("/users/{user_id}")
        ... async def get_user(user_id: int, include_posts: bool = Query(False)):
        ...     return {"id": user_id, "include_posts": include_posts}
        >>>
        >>> @app.post("/users")
        ... async def create_user(user: User):
        ...     return {"name": user.name}
        >>>
        >>> @app.websocket("/ws/chat")
        ... async def websocket_endpoint(websocket: WebSocket):
        ...     await websocket.accept()
        ...     await websocket.send_json({"type": "welcome"})
        >>>
        >>> extractor = StarletteExtractor()
        >>> by_name = {route.name: route for route in extractor.extract_routes(app)}
        >>> by_name["get_user"].path_params  # no convertor in the path -> str
        {'user_id': <class 'str'>}
        >>> by_name["get_user"].query_params
        {'include_posts': <class 'bool'>}
        >>> by_name["create_user"].query_params
        {}
        >>> by_name["websocket_endpoint"].is_websocket
        True

    Note:
        - Path parameter types come from the path pattern only
          (``{id:int}``, ``{x:float}``); endpoint annotations are not consulted
        - Recursively processes Mount instances to handle sub-applications
        - HEAD methods are automatically filtered out for HTTP routes
        - Path prefixes from Mount instances are accumulated
        - WebSocket routes have auto_accept=False in their metadata
    """
    discovered: list[RouteInfo] = []
    # Begin at the top-level route list with an empty mount prefix; the
    # collector appends into ``discovered`` in place.
    self._collect_routes(app.routes, "", discovered)
    return discovered
def _collect_routes(self, route_list: list[Any], prefix: str, collected: list[RouteInfo]) -> None:
    """Recursively collect routes, handling mounts and WebSocket routes.

    Args:
        route_list: Routes to inspect at this nesting level.
        prefix: Accumulated path of all enclosing Mount instances.
        collected: Output list; discovered RouteInfo objects are appended
            to it in place.

    Note:
        One RouteInfo is emitted per HTTP method (HEAD excluded). Methods are
        visited in sorted order: ``route.methods`` is a set, so iterating it
        directly would make the discovery order depend on string hashing and
        differ between interpreter processes.
    """
    from starlette.routing import Mount, Route, WebSocketRoute

    for route in route_list:
        if isinstance(route, Mount):
            # Descend into the mounted routes, extending the path prefix.
            self._collect_routes(route.routes or [], prefix + route.path, collected)
            continue

        if isinstance(route, WebSocketRoute):
            full_path = prefix + route.path
            path_params = self._parse_path_params(full_path)
            collected.append(self._build_websocket_route_info(route, full_path, path_params))
            continue

        if not isinstance(route, Route):
            # Other route kinds are not HTTP/WebSocket endpoints; ignore them.
            continue

        full_path = prefix + route.path
        # A route without explicit methods is treated as GET-only.
        for method in sorted(route.methods or ["GET"]):
            if method == "HEAD":
                continue
            # Parsed per method so every RouteInfo owns its own dict objects.
            path_params = self._parse_path_params(full_path)
            collected.append(
                RouteInfo(
                    path=full_path,
                    methods=[method],
                    name=route.name,
                    handler=route.endpoint,
                    path_params=path_params,
                    query_params=self._extract_query_params(route.endpoint, path_params),
                    body_type=None,
                )
            )
def _parse_path_params(self, path: str) -> dict[str, type]:
    """Map every placeholder in a Starlette path pattern to a Python type.

    Placeholders look like ``{name}`` or ``{name:convertor}``. The ``int``
    and ``float`` convertors map to the matching builtin; a missing or any
    other convertor (e.g. ``path``) maps to ``str``.

    Args:
        path: The route path pattern, e.g. ``"/users/{user_id:int}"``.

    Returns:
        Parameter names mapped to their types, in order of first appearance.
    """
    convertor_types: dict[str, type] = {"int": int, "float": float}
    # Group 1 is the parameter name, optional group 2 the convertor.
    placeholder = re.compile(r"\{([^}:]+)(?::([^}]+))?\}")
    return {
        found.group(1): convertor_types.get(found.group(2), str)
        for found in placeholder.finditer(path)
    }
def _extract_query_params(  # noqa: C901
    self, endpoint: Any, path_params: dict[str, type]
) -> dict[str, type]:
    """Extract query parameters from an endpoint signature.

    A signature parameter is reported as a query parameter unless it:
        - is a path parameter
        - is named ``data`` (conventional request-body name) or is one of
          the framework names ``request``, ``response``, ``websocket``,
          ``background_tasks``
        - is annotated with an ASGI type (name contains ``Request``,
          ``WebSocket``, ``HTTPConnection`` or ``Response``) or with a class
          that has ``BaseModel`` in its MRO -- checked both on the raw
          annotation and after unwrapping ``Optional[...]``
        - has a default produced by a FastAPI body/form/file/header/cookie
          or dependency marker

    Args:
        endpoint: The endpoint function.
        path_params: Already extracted path parameters.

    Returns:
        Dictionary mapping query param names to their types. Generic
        annotations are reduced to their first non-None argument; anything
        that is not a concrete class (or has no annotation) maps to ``str``.
        Returns an empty dict if the endpoint is not callable or its
        signature / type hints cannot be resolved.
    """
    if not callable(endpoint):
        return {}

    try:
        sig = inspect.signature(endpoint)
        hints = get_type_hints(endpoint)
    except (ValueError, TypeError, NameError):
        # No introspectable signature, or unresolved forward references.
        return {}

    empty = inspect.Parameter.empty
    skipped_names = set(path_params) | {"data", "request", "response", "websocket", "background_tasks"}
    asgi_type_markers = ("Request", "WebSocket", "HTTPConnection", "Response")
    # Class names of default-value markers that are NOT query parameters.
    non_query_defaults = {"FieldInfo", "Body", "File", "Form", "Depends", "Security", "Header", "Cookie"}

    def unwrap(annotation: Any) -> Any:
        """Reduce a parameterised annotation to its first non-None argument."""
        if get_origin(annotation) is None:
            return annotation
        args = getattr(annotation, "__args__", None) or ()
        return next((arg for arg in args if arg is not type(None)), annotation)

    def is_non_query_type(annotation: Any) -> bool:
        """Best-effort check for ASGI objects and Pydantic body models."""
        try:
            type_name = getattr(annotation, "__name__", str(annotation))
            if any(marker in str(type_name) for marker in asgi_type_markers):
                return True
            mro = getattr(annotation, "__mro__", ())
            return any(base.__name__ == "BaseModel" for base in mro)
        except Exception:  # noqa: BLE001
            # Exotic annotation objects may misbehave on attribute access;
            # treat them as ordinary parameters rather than failing.
            return False

    query_params: dict[str, type] = {}
    for param_name, param in sig.parameters.items():
        if param_name in skipped_names:
            continue

        annotation = hints.get(param_name, param.annotation)
        # Identity comparison: ``!=`` could invoke an arbitrary ``__ne__``.
        has_annotation = annotation is not empty
        resolved = unwrap(annotation) if has_annotation else annotation

        # Check the unwrapped type as well, so Optional[Model] and
        # Optional[Request] are filtered like their bare counterparts.
        if has_annotation and (is_non_query_type(annotation) or is_non_query_type(resolved)):
            continue

        if param.default is not empty and type(param.default).__name__ in non_query_defaults:
            continue

        if has_annotation and isinstance(resolved, type):
            query_params[param_name] = resolved
        else:
            # No annotation, or nothing concrete to map to: default to str.
            query_params[param_name] = str

    return query_params
def _build_websocket_route_info(self, route: Any, full_path: str, path_params: dict[str, type]) -> RouteInfo:
    """Create the RouteInfo describing a Starlette/FastAPI WebSocket route.

    The metadata lists TEXT, BINARY and JSON as both accepted and sent
    message types, declares no subprotocols, and sets ``auto_accept`` to
    False.

    Args:
        route: A Starlette WebSocketRoute instance.
        full_path: The full path including any mount prefixes.
        path_params: Already extracted path parameters.

    Returns:
        RouteInfo with ``methods=["WEBSOCKET"]``, ``is_websocket=True`` and
        the WebSocket metadata attached; query params are left empty.
    """
    from pytest_routes.discovery.base import WebSocketMessageType, WebSocketMetadata

    every_message_type = (
        WebSocketMessageType.TEXT,
        WebSocketMessageType.BINARY,
        WebSocketMessageType.JSON,
    )
    metadata = WebSocketMetadata(
        subprotocols=[],
        # Separate list objects so the two fields never alias each other.
        accepted_message_types=list(every_message_type),
        sends_message_types=list(every_message_type),
        auto_accept=False,
    )

    return RouteInfo(
        path=full_path,
        methods=["WEBSOCKET"],
        name=route.name,
        handler=route.endpoint,
        path_params=path_params,
        query_params={},
        body_type=None,
        tags=[],
        deprecated=False,
        description=None,
        is_websocket=True,
        websocket_metadata=metadata,
    )