diff --git a/apisix/plugins/limit-count/sliding-window/sliding-window.lua b/apisix/plugins/limit-count/sliding-window/sliding-window.lua index 89d33a2731a8..f21e5b8a4825 100644 --- a/apisix/plugins/limit-count/sliding-window/sliding-window.lua +++ b/apisix/plugins/limit-count/sliding-window/sliding-window.lua @@ -40,12 +40,8 @@ end local function get_counter_key(self, key, time) local wid = get_window_id(self, time) - -- Prefix with plugin_name (set only for the Redis-backed stores) so that two - -- plugins reusing this module on the same resource with identical config - -- (and therefore the same gen_limit_key) cannot share a Redis counter, the - -- way the fixed-window Redis path already isolates them. The local store is - -- already namespaced by its per-plugin shared dict, so it passes no name and - -- keeps the original key format. + -- plugin_name (Redis stores only) keeps plugins that share a key apart, + -- like the fixed-window Redis path already does. if self.plugin_name then return string_format("%s:%s.%s.counter", self.plugin_name, key, wid) end @@ -53,30 +49,6 @@ local function get_counter_key(self, key, time) end -local function get_last_rate(self, sample, now_ms, red_cli) - local a_window_ago_from_now = now_ms - self.window_size - local last_counter_key = get_counter_key(self, sample, a_window_ago_from_now) - - local last_count, err = self.store:get(last_counter_key, red_cli) - if err then - return nil, err - end - if not last_count then - last_count = 0 - end - if last_count > self.limit then - -- in incoming we also reactively check for exceeding limit - -- after icnrementing the counter. So even though counter can be higher - -- than the limit as a result of racy behaviour we would still throttle - -- anyway. That is way it is important to correct the last count here - -- to avoid over-punishment. - last_count = self.limit - end - - return last_count / self.window_size -end - - function _M.new(store, limit, window_size, red_cli) if not store then return nil, "'store' parameter is missing" @@ -87,6 +59,9 @@ function _M.new(store, limit, window_size, red_cli) if not store.get then return nil, "'store' has to implement 'get' function" end + if not store.check_and_incr then + return nil, "'store' has to implement 'check_and_incr' function" + end return setmetatable({ store = store, @@ -107,6 +82,9 @@ function _M.new_with_red_cli_factory(store, limit, window_size, red_cli_factory, if not store.get then return nil, "'store' has to implement 'get' function" end + if not store.check_and_incr then + return nil, "'store' has to implement 'check_and_incr' function" + end return setmetatable({ store = store, @@ -136,6 +114,7 @@ end function _M.incoming(self, key, cost) local now = ngx_now() local counter_key = get_counter_key(self, key, now) + local last_counter_key = get_counter_key(self, key, now - self.window_size) local remaining_time = self.window_size - now % self.window_size local red_cli, err @@ -146,63 +125,37 @@ function _M.incoming(self, key, cost) end end - local count, err = self.store:get(counter_key, self.red_cli or red_cli) - if err then - return nil, err - end - if not count then - count = 0 - end - log.debug("count: ", count, ", limit: ", self.limit) - if count >= self.limit then - return nil, "rejected", round_off_decimal_places(remaining_time, 2) - end - - local last_rate - last_rate, err = get_last_rate(self, key, now, self.red_cli or red_cli) - if err then - return nil, err, 0 - end - - local estimated_last_window_count = last_rate * remaining_time - local estimated_final_count = estimated_last_window_count + count - log.debug("estimated_final_count: ", estimated_final_count, ", limit: ", self.limit) - if estimated_final_count >= self.limit then - local desired_delay = - get_desired_delay(self, remaining_time, last_rate, count) - return nil, "rejected", round_off_decimal_places(desired_delay, 2) - end - + -- One atomic step decides accept/reject and increments only on accept, so + -- concurrent requests can't all pass the check before any increment lands. local expiry = self.window_size * 2 - local new_count - new_count, err = self.store:incr(counter_key, cost, expiry, self.red_cli or red_cli) - if err then - return nil, err, 0 - end + local res + res, err = self.store:check_and_incr(counter_key, last_counter_key, cost, + self.limit, self.window_size, remaining_time, expiry, + self.red_cli or red_cli) if red_cli then red_cli:set_keepalive(10000, 100) end - -- The below limit checking is only to cope with a racy behaviour where - -- counter for the given sample is incremented at the same time by multiple - -- sliding_window instances. That is we re-adjust the new count by ignoring - -- the current occurrence of the sample. Otherwise the limit would - -- unncessarily be exceeding. - local new_adjusted_count = new_count - cost - log.debug("new_adjusted_count: ", new_adjusted_count, ", limit: ", self.limit) - - if new_adjusted_count >= self.limit then - -- incr above might take long enough to make difference, so - -- we recalculate time-dependant variables. - remaining_time = self.window_size - ngx_now() % self.window_size - return nil, "rejected", round_off_decimal_places(remaining_time, 2) + if not res then + return nil, err, 0 end - local remaining = self.limit - new_count - estimated_last_window_count - local rounded_remaining = math_floor(remaining) + local accepted, count, last_count = res[1], res[2], res[3] + local last_rate = last_count / self.window_size + local estimated_last_window_count = last_rate * remaining_time + log.debug("accepted: ", accepted, ", count: ", count, ", limit: ", self.limit) + + if accepted == 0 then + if count >= self.limit then + return nil, "rejected", round_off_decimal_places(remaining_time, 2) + end + local desired_delay = get_desired_delay(self, remaining_time, last_rate, count) + return nil, "rejected", round_off_decimal_places(desired_delay, 2) + end - return 0, rounded_remaining, round_off_decimal_places(remaining_time, 2) + local remaining = self.limit - count - estimated_last_window_count + return 0, math_floor(remaining), round_off_decimal_places(remaining_time, 2) end diff --git a/apisix/plugins/limit-count/sliding-window/store/redis.lua b/apisix/plugins/limit-count/sliding-window/store/redis.lua index 6ff77d6033ad..f5eade46ba89 100644 --- a/apisix/plugins/limit-count/sliding-window/store/redis.lua +++ b/apisix/plugins/limit-count/sliding-window/store/redis.lua @@ -33,6 +33,47 @@ local incr_script = core.string.compress_script([=[ local incr_script_sha = to_hex(ngx.sha1_bin(incr_script)) +-- Decide accept/reject and increment (only on accept) in one atomic step, so +-- concurrent requests can't all pass the check before an increment lands. +-- KEYS[1] is the current window counter; the previous window's count comes via +-- ARGV because it is frozen (in the past, never written concurrently) and +-- redis-cluster only allows single-key EVAL. Returns {accepted, count, last}: +-- count is the post-incr value on accept, else the current count; last is the +-- previous window count, capped at the limit. +local check_incr_script = core.string.compress_script([=[ + local cost = tonumber(ARGV[1]) + local limit = tonumber(ARGV[2]) + local window_size = tonumber(ARGV[3]) + local remaining_time = tonumber(ARGV[4]) + local expiry = ARGV[5] + local last = tonumber(ARGV[6]) + if last > limit then + last = limit + end + + local cur_ttl = redis.call('pttl', KEYS[1]) + local cur = 0 + if cur_ttl >= 0 then + cur = tonumber(redis.call('get', KEYS[1]) or 0) + end + + local estimated = last / window_size * remaining_time + cur + if cur >= limit or estimated >= limit then + return {0, cur, last} + end + + local new + if cur_ttl < 0 then + redis.call('set', KEYS[1], cost, 'EX', expiry) + new = cost + else + new = redis.call('incrby', KEYS[1], cost) + end + return {1, new, last} +]=]) +local check_incr_script_sha = to_hex(ngx.sha1_bin(check_incr_script)) + + -- TODO: keepalive or close function _M.incr(self, key, delta, expiry, red) -- nk key1 argv1 argv2 @@ -54,6 +95,38 @@ function _M.incr(self, key, delta, expiry, red) end +function _M.check_and_incr(self, current_key, last_key, cost, limit, + window_size, remaining_time, expiry, red) + -- previous window is frozen, so a single-key GET is safe and keeps the + -- atomic EVAL to one key, which redis-cluster requires + local last, err = red:get(last_key) + if err then + return nil, err + end + if not last or last == ngx_null then + last = 0 + end + + local res + res, err = red:evalsha(check_incr_script_sha, 1, current_key, + cost, limit, window_size, remaining_time, expiry, last) + if err and core.string.has_prefix(err, "NOSCRIPT") then + core.log.warn("redis evalsha failed: ", err, ". Falling back to eval") + res, err = red:eval(check_incr_script, 1, current_key, + cost, limit, window_size, remaining_time, expiry, last) + end + if err then + return nil, err + end + + if not res then + return nil, "malformed redis response while calling check_and_incr" + end + + return res +end + + -- TODO: keepalive or close function _M.get(self, key, red) local value, err = red:get(key) diff --git a/apisix/plugins/limit-count/sliding-window/store/shared-dict.lua b/apisix/plugins/limit-count/sliding-window/store/shared-dict.lua index 0401bf0b4b1e..72a6626e68f7 100644 --- a/apisix/plugins/limit-count/sliding-window/store/shared-dict.lua +++ b/apisix/plugins/limit-count/sliding-window/store/shared-dict.lua @@ -52,6 +52,36 @@ function _M.incr(self, key, delta, expiry) return new_value end +-- Counterpart of the redis store's atomic check. Shared dict ops don't yield, +-- so get/decide/incr can't interleave within a worker. They aren't atomic +-- across workers though, so a concurrent burst may admit a few extra requests +-- at a window boundary. Best-effort by design; the redis store is exact. +function _M.check_and_incr(self, current_key, last_key, cost, limit, + window_size, remaining_time, expiry) + local dict = self.dict + local last = dict:get(last_key) or 0 + if last > limit then + last = limit + end + + local cur = dict:get(current_key) or 0 + local estimated = last / window_size * remaining_time + cur + if cur >= limit or estimated >= limit then + return {0, cur, last} + end + + local new, err, forcible = dict:incr(current_key, cost, 0, expiry) + if err then + return nil, err + end + + if forcible then + log.warn("shared dictionary is full, removed valid key(s) to store the new one") + end + + return {1, new, last} +end + function _M.get(self, key) local value, err = self.dict:get(key) if not value then diff --git a/t/plugin/limit-count-sliding.t b/t/plugin/limit-count-sliding.t index 2286eb1b09df..e6c069a38f7c 100644 --- a/t/plugin/limit-count-sliding.t +++ b/t/plugin/limit-count-sliding.t @@ -292,3 +292,47 @@ commit delay: 0, remaining: -3 --- response_body a over limit: rejected b independent: 0, remaining: 1 + + + +=== TEST 8: check_and_incr decides and increments atomically, never on reject +# the accept/reject decision and the increment happen in one atomic step, so +# concurrent requests cannot all pass the check before any increment lands. an +# over-limit request must reject and leave the counter untouched. +--- config + location /t { + content_by_lua_block { + local redis_store = + require("apisix.plugins.limit-count.sliding-window.store.redis") + local redis_cli = require("apisix.plugins.limit-count.util").redis_cli + local conf = { + redis_host = "127.0.0.1", + redis_port = 6379, + redis_database = 1, + } + local red = redis_cli(conf) + local limit, window, remaining_time, expiry = 2, 5, 5, 10 + local cur = "ut-atomic-cur-" .. ngx.now() + local last = "ut-atomic-last-" .. ngx.now() + + local function call(cost) + return redis_store.check_and_incr(redis_store, cur, last, cost, + limit, window, remaining_time, expiry, red) + end + + local r1 = call(1) + ngx.say("accept ", r1[1], " count ", r1[2]) + local r2 = call(1) + ngx.say("accept ", r2[1], " count ", r2[2]) + -- over the limit now: must reject and not increment + local r3 = call(1) + ngx.say("accept ", r3[1], " count ", r3[2]) + local stored = red:get(cur) + ngx.say("stored: ", stored) + } + } +--- response_body +accept 1 count 1 +accept 1 count 2 +accept 0 count 2 +stored: 2