import time
import pytz
import functools
from threading import RLock
from collections import defaultdict, deque
from datetime import datetime, timedelta


class LimitBucket:
    """
    Token bucket that limits a section of code to a given rate.

    The bucket starts full with ``capacity`` tokens and is refilled at
    ``fill_rate`` tokens per second, never holding more than ``capacity``.
    """

    def __init__(self, capacity, fill_rate, is_lock: bool = False) -> None:
        """
        :param capacity: maximum number of tokens the bucket can hold
        :param fill_rate: refill rate, in tokens per second
        :param is_lock: when true, the public methods run under a reentrant lock
        """
        self._capacity = float(capacity)
        self._tokens = float(capacity)
        self._fill_rate = float(fill_rate)
        # Monotonic clock: only elapsed intervals matter, and it cannot jump
        # backwards when the system clock is adjusted.
        self._last_time = time.monotonic()
        self._is_lock = is_lock
        self._lock = RLock()

    def _get_cur_tokens(self):
        """Refill the bucket for the time elapsed and return the token count."""
        now = time.monotonic()
        if self._tokens < self._capacity:
            elapsed = now - self._last_time
            refilled = self._tokens + self._fill_rate * elapsed
            self._tokens = min(self._capacity, refilled)
        # Always advance the timestamp, even when the bucket is full; otherwise
        # idle time spent at full capacity would be credited to the next refill
        # and allow a burst larger than ``capacity``.
        self._last_time = now
        return self._tokens

    def get_cur_tokens(self):
        """
        Return the number of tokens currently available (after refilling).

        :return: current token count as a float
        """
        if self._is_lock:
            with self._lock:
                return self._get_cur_tokens()
        else:
            return self._get_cur_tokens()

    def _consume(self, tokens) -> bool:
        """Take ``tokens`` from the bucket if that many are available."""
        # The caller already holds the lock when locking is enabled, so the
        # unlocked helper is used directly.
        if tokens <= self._get_cur_tokens():
            self._tokens -= tokens
            return True
        return False

    def consume(self, tokens):
        """
        Try to take ``tokens`` tokens from the bucket.

        :param tokens: number of tokens requested
        :return: True if the tokens were taken, False if not enough were
            available (in which case nothing is taken)
        """
        if self._is_lock:
            with self._lock:
                return self._consume(tokens)
        else:
            return self._consume(tokens)


class RateLimiting:
    """
    Limit the overall call rate of a function or code section.

    Can be used as a decorator or as a context manager. The completion time
    of every guarded section is recorded in ``calls``; entering blocks with
    ``time.sleep`` while ``max_calls`` or more completions are recorded
    within the last ``period`` seconds.
    """

    def __init__(self, max_calls, period=1.0):
        """
        :param max_calls: number of calls allowed per period, must be > 0
        :param period: length of the sliding window in seconds, must be > 0
        :raises ValueError: if ``period`` or ``max_calls`` is not positive
        """
        if period <= 0:
            raise ValueError("Rate limiting period should be > 0")
        if max_calls <= 0:
            raise ValueError("Rate limiting number of calls should be > 0")

        # Completion timestamps (time.time()), oldest first
        self.calls = deque()

        self.period = period
        self.max_calls = max_calls

    def __call__(self, func):
        """Decorate ``func`` so that every call runs inside this limiter."""

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            with self:
                return func(*args, **kwargs)

        return wrapped

    def __enter__(self):
        """Block until another call is allowed, then return the limiter."""
        now = time.time()
        # Forget completions that have already left the window, so that time
        # spent idle counts towards the wait instead of being slept again.
        while self.calls and now - self.calls[0] >= self.period:
            self.calls.popleft()
        if len(self.calls) >= self.max_calls:
            # Sleep only for what remains until the oldest completion expires
            remaining = self.calls[0] + self.period - now
            if remaining > 0:
                time.sleep(remaining)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Record the completion time and drop timestamps outside the window."""
        self.calls.append(time.time())
        # With a single entry the span is 0 < period, so the deque never
        # becomes empty here.
        while self._timespan >= self.period:
            self.calls.popleft()

    @property
    def _timespan(self):
        """Seconds between the oldest and the newest recorded completion."""
        return self.calls[-1] - self.calls[0]


class FreqLimiter:
    """
    Per-key cooldown tracker.

    Each key maps to the timestamp (``time.time()``) from which it may be
    used again; keys that were never seen default to ``0.0``.

    Copy from: https://github.com/Ice-Cirno/HoshinoBot/blob/master/hoshino/util/__init__.py
    """

    def __init__(self, default_cd_seconds):
        """
        :param default_cd_seconds: cooldown applied when ``start_cd`` is
            called without a positive ``cd_time``
        """
        self.default_cd = default_cd_seconds
        self.next_time = defaultdict(float)

    def check(self, key) -> bool:
        """Return True when the cooldown of ``key`` has elapsed."""
        now = time.time()
        ready_at = self.next_time[key]
        return bool(now >= ready_at)

    def start_cd(self, key, cd_time=0):
        """
        Start a cooldown for ``key``.

        :param key: identifier to put on cooldown
        :param cd_time: cooldown in seconds; the default cooldown is used
            when this is not greater than 0
        """
        started = time.time()
        if cd_time > 0:
            duration = cd_time
        else:
            duration = self.default_cd
        self.next_time[key] = started + duration

    def left_time(self, key) -> float:
        """Return the seconds left on ``key``'s cooldown (negative once over)."""
        ready_at = self.next_time[key]
        return ready_at - time.time()


class DailyLimiter:
    """
    Per-key usage counter with a daily quota.

    All counters are cleared by ``check`` the first time it runs on a new
    day. A "day" is evaluated in ``tz`` and shifted by 6 hours, so the
    rollover happens at 06:00 local time rather than at midnight.

    Copy from: https://github.com/Ice-Cirno/HoshinoBot/blob/master/hoshino/util/__init__.py
    """

    tz = pytz.timezone("Asia/Shanghai")

    def __init__(self, max_num):
        """
        :param max_num: quota per key and day
        """
        # Ordinal of the day the counters belong to; -1 means "not set yet"
        self.today = -1
        self.count = defaultdict(int)
        self.max = max_num

    def check(self, key) -> bool:
        """
        Return True while ``key`` is still below its daily quota.

        Clears every counter first when the (shifted) day has changed since
        the previous call.
        """
        now = datetime.now(self.tz)
        # Use the full date ordinal rather than the day of month: comparing
        # only the day number would miss a reset when two checks fall on the
        # same day number of different months.
        day = (now - timedelta(hours=6)).toordinal()
        if day != self.today:
            self.today = day
            self.count.clear()
        return bool(self.count[key] < self.max)

    def get_num(self, key):
        """Return the current count for ``key`` (0 if it was never counted)."""
        return self.count[key]

    def increase(self, key, num=1):
        """Add ``num`` to the count of ``key``."""
        self.count[key] += num

    def reset(self, key):
        """Set the count of ``key`` back to 0."""
        self.count[key] = 0