import time
from typing import Any, List, Union
from urllib.parse import urlparse
import pymemcache
from freiner.errors import FreinerConfigurationError
from freiner.types import Host
# Any of the pymemcache client flavours that MemcachedStorage accepts as its backend.
MemcachedClient = Union[
    pymemcache.Client,
    pymemcache.PooledClient,
    pymemcache.HashClient,
]
class MemcachedStorage:
    """
    Rate limit storage with memcached as backend.

    Depends on the `pymemcache` library.
    """

    # Number of extra compare-and-set rounds ``incr`` may make after its first attempt
    # before giving up (so at most ``MAX_CAS_RETRIES + 1`` attempts in total).
    MAX_CAS_RETRIES = 10

    def __init__(self, client: MemcachedClient):
        """
        :param client: An already-configured pymemcache client used for all storage calls.
        """
        self._client: MemcachedClient = client

    @classmethod
    def from_uri(cls, uri: str, **options: Any) -> "MemcachedStorage":
        """
        Build a storage instance from a memcached URI.

        :param uri: URI of the form `memcached://host:port,host:port` or
            `memcached:///run/path/to/sock`.
        :param options: All remaining keyword arguments are passed directly to the constructor
            of :class:`pymemcache.client.base.Client` (or of
            :class:`pymemcache.client.hash.HashClient` when more than one host is given).
        :raises FreinerConfigurationError: When no hosts could be parsed from the supplied URI.
        """
        parsed_uri = urlparse(uri)
        hosts: List[Union[Host, str]] = []
        if parsed_uri.netloc:
            # Comma-separated list of host:port pairs.
            for loc in parsed_uri.netloc.strip().split(","):
                if not loc:
                    continue
                host, port = loc.split(":")
                hosts.append((host, int(port)))
        elif parsed_uri.path:
            # No network location, only a path: treat it as a Unix domain socket path.
            hosts = [parsed_uri.path]

        if not hosts:
            raise FreinerConfigurationError(f"No Memcached hosts parsed from URI: {uri}")

        if len(hosts) > 1:
            client = pymemcache.HashClient(hosts, **options)
        else:
            client = pymemcache.Client(hosts[0], **options)
        return cls(client)

    def get(self, key: str) -> int:
        """
        Retrieve the current request count for the given rate limit key.

        :param key: The key to get the counter value for.
        :return: The stored counter, or 0 when the key does not exist.
        """
        return int(self._client.get(key) or 0)

    def clear(self, key: str) -> None:
        """
        Resets the rate limit for the given key.

        :param key: The key to clear rate limits for.
        """
        self._client.delete(key)
        # Drop the companion expiry record too; otherwise get_expiry() would keep reporting
        # the end of a window that has just been reset.
        self._client.delete(key + "/expires")

    def incr(self, key: str, expiry: int, elastic_expiry: bool = False) -> int:
        """
        Increments the counter for the given rate limit key.

        :param key: The key to increment.
        :param expiry: Amount in seconds for the key to expire in.
        :param elastic_expiry: Whether to keep extending the rate limit window every hit.
        :return: The number of hits currently on the rate limit for the given key.
        """
        # The key can expire between any two memcached round trips. Whenever we observe that
        # it has disappeared, restart from add() so it is re-created. Otherwise we would be
        # operating on a missing key (e.g. calling cas() with cas=None).
        count = 1
        for _ in range(self.MAX_CAS_RETRIES + 1):
            if self._client.add(key, 1, expiry, noreply=False):
                # We created the key: this is the first hit of a new window.
                self._set_expiry(key, expiry)
                return 1

            if not elastic_expiry:
                current = self._client.incr(key, 1)
                if current is not None:
                    return int(current)
                continue  # key expired between add() and incr(); start over

            value, cas = self._client.gets(key)
            if cas is None:
                continue  # key expired between add() and gets(); start over

            count = int(value or 0) + 1
            # cas() stores the new value with a fresh ``expiry``, which is what extends the
            # window on every hit. A falsy result means another writer raced us (or the key
            # vanished), so loop and try again.
            if self._client.cas(key, count, cas, expiry):
                self._set_expiry(key, expiry)
                return count

        # Every attempt lost a race: best effort, report the last count we computed.
        return count

    def _set_expiry(self, key: str, expiry: int) -> None:
        """
        Record the absolute expiry timestamp of ``key`` under ``key + "/expires"``.

        The record itself expires after ``expiry`` seconds, alongside the counter.
        """
        self._client.set(key + "/expires", expiry + time.time(), expire=expiry, noreply=False)

    def get_expiry(self, key: str) -> float:
        """
        Retrieve the expected expiry time for the given rate limit key.

        :param key: The key to get the expiry time for.
        :return: The time at which the current rate limit for the given key ends
            (the current time when no expiry record exists).
        """
        return float(self._client.get(key + "/expires") or time.time())

    def check(self) -> bool:
        """
        Check if the connection to the storage backend is healthy.

        :return: ``True`` if a probe ``get`` succeeds, ``False`` if it raises.
        """
        try:
            self._client.get("freiner-check")
        except Exception:
            # Deliberately best-effort: any client/connection error means "unhealthy".
            return False
        return True
# Public API of this module, as exported by ``from ... import *``.
__all__ = ["MemcachedClient", "MemcachedStorage"]