"""Base route extractor protocol and types."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class WebSocketMessageType(Enum):
    """Supported WebSocket message types for testing.

    Each member maps a symbolic name to its lowercase string value, so a
    member can be looked up from that string, e.g.
    ``WebSocketMessageType("json")``.
    """

    TEXT = "text"
    BINARY = "binary"
    JSON = "json"
@dataclass
class WebSocketMetadata:
    """Metadata specific to WebSocket routes.

    This dataclass captures WebSocket-specific configuration that differs from
    standard HTTP routes. It includes protocol negotiation details, message
    format specifications, and connection behavior expectations.

    Every list-valued field is built by a ``default_factory``, so each
    instance gets its own fresh list rather than sharing one default.

    Attributes:
        subprotocols: List of supported WebSocket subprotocols
            (e.g., ["graphql-ws"]). The client may request one of these
            during the handshake. Defaults to an empty list.
        accepted_message_types: List of message types the endpoint accepts.
            Used for generating appropriate test messages. Defaults to
            TEXT and JSON.
        sends_message_types: List of message types the endpoint may send.
            Used for response validation. Defaults to TEXT and JSON.
        auto_accept: Whether the server auto-accepts connections (Litestar
            style) or requires manual accept() call (Starlette/FastAPI
            style). Defaults to True.
        ping_interval: Expected ping interval in seconds, if applicable.
        max_message_size: Maximum message size in bytes, if known.
        close_codes: Expected close codes for normal shutdown. Defaults to
            [1000, 1001].
    """

    subprotocols: list[str] = field(default_factory=list)
    accepted_message_types: list[WebSocketMessageType] = field(
        default_factory=lambda: [WebSocketMessageType.TEXT, WebSocketMessageType.JSON]
    )
    sends_message_types: list[WebSocketMessageType] = field(
        default_factory=lambda: [WebSocketMessageType.TEXT, WebSocketMessageType.JSON]
    )
    auto_accept: bool = True
    ping_interval: float | None = None
    max_message_size: int | None = None
    close_codes: list[int] = field(default_factory=lambda: [1000, 1001])
# RouteInfo: normalized description of a single discovered route.
@dataclass
class RouteInfo:
    """Normalized route information.

    This dataclass represents a discovered route from an ASGI application,
    containing all metadata needed for property-based testing. It supports
    both HTTP routes and WebSocket endpoints through the is_websocket flag
    and optional websocket_metadata field.

    Only ``path`` and ``methods`` are required; every other field has a
    default, and the dict/list defaults are created fresh per instance.

    Attributes:
        path: The route path pattern (e.g., "/users/{user_id}").
        methods: HTTP methods for this route (e.g., ["GET", "POST"]).
            For WebSocket routes, this is typically ["WEBSOCKET"] or empty.
        name: Optional route name from the framework.
        handler: Reference to the handler function/coroutine.
        path_params: Mapping of path parameter names to their types.
        query_params: Mapping of query parameter names to their types.
        body_type: Type annotation for the request body, if applicable.
        tags: Framework-assigned tags for grouping/categorization.
        deprecated: Whether the route is marked as deprecated.
        description: Documentation string for the route.
        is_websocket: True if this is a WebSocket endpoint.
        websocket_metadata: Additional WebSocket-specific configuration.
    """

    path: str
    methods: list[str]
    name: str | None = None
    handler: Callable[..., Any] | None = None
    path_params: dict[str, type] = field(default_factory=dict)
    query_params: dict[str, type] = field(default_factory=dict)
    body_type: type | None = None
    tags: list[str] = field(default_factory=list)
    deprecated: bool = False
    description: str | None = None
    is_websocket: bool = False
    websocket_metadata: WebSocketMetadata | None = None

    # Defined explicitly, so @dataclass keeps this instead of generating
    # its own field-by-field repr.
    def __repr__(self) -> str:
        """Return a compact label for the route.

        WebSocket routes render as ``RouteInfo(WS <path>)``; all other routes
        render as ``RouteInfo(<comma-joined methods> <path>)``.
        """
        if self.is_websocket:
            return f"RouteInfo(WS {self.path})"
        methods_str = ",".join(self.methods)
        return f"RouteInfo({methods_str} {self.path})"

    @property
    def is_http(self) -> bool:
        """Check if this is an HTTP route (not WebSocket)."""
        return not self.is_websocket