import threading
import time
from collections import Counter
from typing import Counter as CounterType
from . import MovingWindow
class _LockableEntry:
__slots__ = ("atime", "expiry", "_lock")
def __init__(self, expiry: float):
self.atime: float = time.time()
self.expiry: float = self.atime + expiry
self._lock = threading.RLock()
def acquire(self) -> None:
self._lock.acquire()
def release(self) -> None:
self._lock.release()
def __enter__(self):
self.acquire()
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()
def __repr__(self) -> str:
return f"MemoryLockableEntry<atime={self.atime}, expiry={self.expiry}>" # pragma: no cover
class MemoryStorage:
    """
    rate limit storage using :py:class:`collections.Counter`
    as an in memory storage for fixed and elastic window strategies,
    and a simple list to implement moving window strategy.
    """

    def __init__(self):
        # Hit counters used by :meth:`incr` / :meth:`get`.
        self.storage: CounterType[str] = Counter()
        # key -> absolute time (``time.time()`` based) at which its counter lapses.
        self.expirations = {}
        # key -> list of entries used by :meth:`acquire_entry`, newest first.
        self.events = {}
        self.timer = threading.Timer(0.01, self.__expire_events)
        self.timer.start()

    def __expire_events(self) -> None:
        """
        Timer callback: drop stale entries from ``events`` and lapsed
        counters from ``storage`` / ``expirations``.
        """
        for key in list(self.events.keys()):
            # ``clear``/``reset`` may remove the key after the snapshot was
            # taken, so never index the dict directly from this thread.
            for event in list(self.events.get(key, ())):
                with event:
                    entries = self.events.get(key)
                    if entries and event.expiry <= time.time() and event in entries:
                        entries.remove(event)
        for key in list(self.expirations.keys()):
            expires_at = self.expirations.get(key)
            if expires_at is not None and expires_at <= time.time():
                self.storage.pop(key, None)
                self.expirations.pop(key, None)

    def __schedule_expiry(self) -> None:
        """Start a new one-shot cleanup timer unless one is still running."""
        if not self.timer.is_alive():
            self.timer = threading.Timer(0.01, self.__expire_events)
            self.timer.start()

    def incr(self, key: str, expiry: int, elastic_expiry: bool = False) -> int:
        """
        Increments the counter for the given rate limit key.

        :param key: The key to increment.
        :param expiry: Amount in seconds for the key to expire in.
        :param elastic_expiry: Whether to keep extending the rate limit window every hit.
        :return: The number of hits currently on the rate limit for the given key.
        """
        # Calling get() first discards a counter whose window already lapsed.
        self.get(key)
        self.__schedule_expiry()
        self.storage[key] += 1
        # The window is (re)started on the first hit, or on every hit when elastic.
        if elastic_expiry or self.storage[key] == 1:
            self.expirations[key] = time.time() + expiry
        return self.storage.get(key, 0)

    def get(self, key: str) -> int:
        """
        Retrieve the current request count for the given rate limit key.

        :param key: The key to get the counter value for.
        :return: The current count, or 0 when the key is unknown or expired.
        """
        if self.expirations.get(key, 0) <= time.time():
            self.storage.pop(key, None)
            self.expirations.pop(key, None)
        return self.storage.get(key, 0)

    def clear(self, key: str) -> None:
        """
        Resets the rate limit for the given key.

        :param key: The key to clear rate limits for.
        """
        self.storage.pop(key, None)
        self.expirations.pop(key, None)
        self.events.pop(key, None)

    def acquire_entry(self, key: str, limit: int, expiry: int, no_add: bool = False) -> bool:
        """
        Try to acquire an entry in the moving window of the given key.

        :param key: The rate limit key to acquire an entry in.
        :param limit: The total amount of entries allowed before hitting the rate limit.
        :param expiry: Amount in seconds for the acquired entry to expire in.
        :param no_add: If True, an entry is not actually acquired but the call
         instead serves as a 'check'.
        :return: True if an entry was (or could be) acquired, False if the
         limit is already reached within the window.
        """
        entries = self.events.setdefault(key, [])
        self.__schedule_expiry()
        timestamp = time.time()
        # Entries are newest-first, so index ``limit - 1`` holds the
        # ``limit``-th most recent hit, if that many exist.
        try:
            entry = entries[limit - 1]
        except IndexError:
            entry = None
        if entry and entry.atime > timestamp - expiry:
            return False
        if not no_add:
            entries.insert(0, _LockableEntry(expiry))
        return True

    def get_expiry(self, key: str) -> float:
        """
        Retrieve the expected expiry time for the given rate limit key.

        :param key: The key to get the expiry time for.
        :return: The time at which the current rate limit for the given key ends,
         or -1 when no expiry is recorded for the key.
        """
        return self.expirations.get(key, -1)

    def get_num_acquired(self, key: str, expiry: int) -> int:
        """
        returns the number of entries already acquired

        :param key: rate limit key to acquire an entry in
        :param expiry: expiry of the entry
        :return: The number of entries created within the last ``expiry`` seconds.
        """
        cutoff = time.time() - expiry
        return sum(1 for entry in self.events.get(key, ()) if entry.atime > cutoff)

    def get_moving_window(self, key: str, limit: int, expiry: int) -> "MovingWindow":
        """
        Retrieves the starting point and the number of entries in the moving window.

        :param key: The rate limit key to retrieve statistics about.
        :param limit: The total amount of entries allowed before hitting the rate limit.
        :param expiry: Amount in seconds for the acquired entry to expire in.
        :return: (start of window, number of acquired entries)
        """
        timestamp = time.time()
        acquired = self.get_num_acquired(key, expiry)
        # Entries are stored newest-first, so walk a reversed copy: the first
        # entry still inside the window is the oldest one, i.e. the window start.
        for item in self.events.get(key, [])[::-1]:
            if item.atime > timestamp - expiry:
                return MovingWindow(item.atime, acquired)
        return MovingWindow(timestamp, acquired)

    def check(self) -> bool:
        """
        Check if the connection to the storage backend is healthy.

        :return: Always True for this storage.
        """
        return True

    def reset(self) -> None:
        """Remove every counter, expiry and moving-window entry."""
        self.storage.clear()
        self.expirations.clear()
        self.events.clear()
# Names exported by ``from <module> import *``.
__all__ = ["MemoryStorage"]