diff --git a/cloudflare_workers/translation/index.ts b/cloudflare_workers/translation/index.ts index f2672557ea..6efd865bda 100644 --- a/cloudflare_workers/translation/index.ts +++ b/cloudflare_workers/translation/index.ts @@ -13,6 +13,9 @@ const TRANSLATION_STORE_CLEANUP_INTERVAL_SECONDS = 60 const TRANSLATION_REQUEUE_AFTER_SECONDS = 60 const TRANSLATION_BATCH_LEASE_SECONDS = 15 * 60 const TRANSLATION_STORE_TABLE = 'translation_messages_cache' +const TRANSLATION_RATE_LIMIT_PATH = '/translation/messages-rate-limit' +const TRANSLATION_RATE_LIMIT_TTL_SECONDS = 60 +const DEFAULT_TRANSLATION_REQUEST_LIMIT = 30 const CLAIMED_TRANSLATION_BATCH_INDEX_OFFSET = 1_000_000_000 const PLACEHOLDER_PATTERN = /\{[\w.]+\}|%\w+%?|\$\d+/g @@ -64,6 +67,7 @@ interface TranslationWorkerBindings { AI?: AiBinding DB_TRANSLATIONS?: D1Database ENV_NAME?: string + TRANSLATION_MESSAGES_RATE_LIMIT?: string TRANSLATION_MESSAGES_QUEUE?: Queue> TRANSLATION_MODEL?: string } @@ -98,6 +102,11 @@ interface TranslationQueuePayload { targetLanguage?: string } +interface RateLimitEntry { + count: number + resetAt: number +} + type MessageEntry = [string, string] type TranslationStoreEntryInput = Omit @@ -140,7 +149,8 @@ function jsonResponse(data: unknown, status = 200, headers: HeadersInit = {}) { } function errorResponse(status: number, code: string, message: string) { - return jsonResponse({ error: code, message }, status) + const headers = status === 429 ? { 'Retry-After': String(TRANSLATION_RATE_LIMIT_TTL_SECONDS) } : {} + return jsonResponse({ error: code, message }, status, headers) } function serializeError(error: unknown) { @@ -193,6 +203,53 @@ async function sha256Hex(value: string) { return encoded } +function clientIP(request: Request) { + return request.headers.get('cf-connecting-ip') + ?? request.headers.get('x-real-ip') + ?? request.headers.get('x-forwarded-for')?.split(',')[0]?.trim() + ?? '' +} + +function translationRequestLimit(env: TranslationWorkerBindings) { + const configured = Number.parseInt(env.TRANSLATION_MESSAGES_RATE_LIMIT ?? '', 10) + return Number.isFinite(configured) && configured > 0 ? configured : DEFAULT_TRANSLATION_REQUEST_LIMIT +} + +function rateLimitResetAt(now = Date.now()) { + return now + TRANSLATION_RATE_LIMIT_TTL_SECONDS * 1000 +} + +async function translationRateLimitCacheRequest(ip: string) { + const ipHash = await sha256Hex(ip) + const url = new URL(TRANSLATION_RATE_LIMIT_PATH, 'https://translation-cache.capgo.local') + url.searchParams.set('ip', ipHash) + return new Request(url) +} + +async function checkTranslationRequestRateLimit(request: Request, env: TranslationWorkerBindings) { + const ip = clientIP(request) + if (!ip) + return + + const cache = globalThis.caches?.default + if (!cache) + return + + const cacheRequest = await translationRateLimitCacheRequest(ip) + const cached = await cache.match(cacheRequest) + const existing = cached ? await cached.json().catch(() => null) as RateLimitEntry | null : null + const count = (existing?.count ?? 0) + 1 + const resetAt = existing?.resetAt && existing.resetAt > Date.now() ? existing.resetAt : rateLimitResetAt() + const entry: RateLimitEntry = { count, resetAt } + + await cache.put(cacheRequest, jsonResponse(entry, 200, { + 'Cache-Control': `max-age=${TRANSLATION_RATE_LIMIT_TTL_SECONDS}`, + })) + + if (count > translationRequestLimit(env)) + fail(429, 'rate_limited', 'Too many translation requests') +} + function recordOf(value: unknown): Record | null { if (value === null || typeof value !== 'object' || Array.isArray(value)) return null @@ -969,6 +1026,8 @@ async function handleTranslationMessages(request: Request, env: TranslationWorke return readyResponse } + await checkTranslationRequestRateLimit(request, env) + try { const queuedResponse = await queueCurrentTranslationResponse(env, requestId, readyRequest, checksum, targetLanguage, model) if (queuedResponse) diff --git a/tests/translation-queue.unit.test.ts b/tests/translation-queue.unit.test.ts index 562a22d63f..1cac6aba53 100644 --- a/tests/translation-queue.unit.test.ts +++ b/tests/translation-queue.unit.test.ts @@ -14,6 +14,24 @@ function stubWorkerCache() { return cache } +function stubStatefulWorkerCache() { + const store = new Map() + const cache = { + match: vi.fn(async (request: Request) => { + const response = store.get(request.url) + return response?.clone() ?? null + }), + put: vi.fn(async (request: Request, response: Response) => { + store.set(request.url, response.clone()) + }), + } + Object.defineProperty(globalThis, 'caches', { + configurable: true, + value: { default: cache }, + }) + return cache +} + function createTranslationStoreMock(latestReadyEntry: Record | null) { return { prepare: vi.fn((sql: string) => ({ @@ -195,4 +213,39 @@ describe('translation queue helpers', () => { expect(payload.status).toBe('pending') expect(queue.send).toHaveBeenCalledTimes(1) }) + + it('rate limits repeated public translation queue requests by client IP', async () => { + stubStatefulWorkerCache() + const db = createTranslationStoreMock(null) + const queue = { + send: vi.fn(), + } + const env = { + DB_TRANSLATIONS: db, + TRANSLATION_MESSAGES_QUEUE: queue, + TRANSLATION_MESSAGES_RATE_LIMIT: '1', + } as any + + const firstResponse = await translationWorker.fetch(new Request('https://api.capgo.app/translation/messages', { + body: JSON.stringify({ targetLanguage: 'fr' }), + headers: { + 'Content-Type': 'application/json', + 'cf-connecting-ip': '203.0.113.10', + }, + method: 'POST', + }), env) + const secondResponse = await translationWorker.fetch(new Request('https://api.capgo.app/translation/messages', { + body: JSON.stringify({ targetLanguage: 'es' }), + headers: { + 'Content-Type': 'application/json', + 'cf-connecting-ip': '203.0.113.10', + }, + method: 'POST', + }), env) + + expect(firstResponse.status).toBe(202) + expect(secondResponse.status).toBe(429) + expect(secondResponse.headers.get('retry-after')).toBe('60') + expect(queue.send).toHaveBeenCalledTimes(1) + }) })