Source code for flask_limiter.wrappers

from __future__ import annotations

import dataclasses
import typing
import weakref
from typing import Callable, Iterator, List, Optional, Tuple, Union

from flask import request
from flask.wrappers import Response
from limits import RateLimitItem, parse_many
from limits.strategies import RateLimiter

if typing.TYPE_CHECKING:
    from .extension import Limiter


[docs] class RequestLimit: """ Provides details of a rate limit within the context of a request """ #: The instance of the rate limit limit: RateLimitItem #: The full key for the request against which the rate limit is tested key: str #: Whether the limit was breached within the context of this request breached: bool #: Whether the limit is a shared limit shared: bool def __init__( self, extension: Limiter, limit: RateLimitItem, request_args: List[str], breached: bool, shared: bool, ) -> None: self.extension: weakref.ProxyType[Limiter] = weakref.proxy(extension) self.limit = limit self.request_args = request_args self.key = limit.key_for(*request_args) self.breached = breached self.shared = shared self._window: Optional[Tuple[int, int]] = None @property def limiter(self) -> RateLimiter: return typing.cast(RateLimiter, self.extension.limiter) @property def window(self) -> Tuple[int, int]: if not self._window: self._window = self.limiter.get_window_stats(self.limit, *self.request_args) return self._window @property def reset_at(self) -> int: """Timestamp at which the rate limit will be reset""" return int(self.window[0] + 1) @property def remaining(self) -> int: """Quantity remaining for this rate limit""" return self.window[1]
@dataclasses.dataclass(eq=True, unsafe_hash=True) class Limit: """ simple wrapper to encapsulate limits and their context """ limit: RateLimitItem key_func: Callable[[], str] _scope: Optional[Union[str, Callable[[str], str]]] per_method: bool = False methods: Optional[Tuple[str, ...]] = None error_message: Optional[str] = None exempt_when: Optional[Callable[[], bool]] = None override_defaults: Optional[bool] = False deduct_when: Optional[Callable[[Response], bool]] = None on_breach: Optional[Callable[[RequestLimit], Optional[Response]]] = None _cost: Union[Callable[[], int], int] = 1 shared: bool = False def __post_init__(self) -> None: if self.methods: self.methods = tuple([k.lower() for k in self.methods]) @property def is_exempt(self) -> bool: """Check if the limit is exempt.""" if self.exempt_when: return self.exempt_when() return False @property def scope(self) -> Optional[str]: return ( self._scope(request.endpoint or "") if callable(self._scope) else self._scope ) @property def cost(self) -> int: if isinstance(self._cost, int): return self._cost return self._cost() @property def method_exempt(self) -> bool: """Check if the limit is not applicable for this method""" return self.methods is not None and request.method.lower() not in self.methods def scope_for(self, endpoint: str, method: Optional[str]) -> str: """ Derive final bucket (scope) for this limit given the endpoint and request method. If the limit is shared between multiple routes, the scope does not include the endpoint. """ limit_scope = self.scope if limit_scope: if self.shared: scope = limit_scope else: scope = f"{endpoint}:{limit_scope}" else: scope = endpoint if self.per_method: assert method scope += f":{method.upper()}" return scope @dataclasses.dataclass(eq=True, unsafe_hash=True) class LimitGroup: """ represents a group of related limits either from a string or a callable that returns one """ limit_provider: Union[Callable[[], str], str] key_function: Callable[[], str] scope: Optional[Union[str, Callable[[str], str]]] = None methods: Optional[Tuple[str, ...]] = None error_message: Optional[str] = None exempt_when: Optional[Callable[[], bool]] = None override_defaults: Optional[bool] = False deduct_when: Optional[Callable[[Response], bool]] = None on_breach: Optional[Callable[[RequestLimit], Optional[Response]]] = None per_method: bool = False cost: Optional[Union[Callable[[], int], int]] = None shared: bool = False def __iter__(self) -> Iterator[Limit]: limit_str = ( self.limit_provider() if callable(self.limit_provider) else self.limit_provider ) limit_items = parse_many(limit_str) if limit_str else [] for limit in limit_items: yield Limit( limit, self.key_function, self.scope, self.per_method, self.methods, self.error_message, self.exempt_when, self.override_defaults, self.deduct_when, self.on_breach, self.cost or 1, self.shared, )