diff --git a/plugins/index.ts b/plugins/index.ts index 641738b50..296e36a08 100644 --- a/plugins/index.ts +++ b/plugins/index.ts @@ -66,6 +66,7 @@ import { handler as javelinguardrails } from './javelin/guardrails'; import { handler as f5GuardrailsScan } from './f5-guardrails/scan'; import { handler as azureShieldPrompt } from './azure/shieldPrompt'; import { handler as azureProtectedMaterial } from './azure/protectedMaterial'; +import { handler as lakeraguard } from './lakera/main-function'; export const plugins = { default: { @@ -176,4 +177,7 @@ export const plugins = { 'f5-guardrails': { scan: f5GuardrailsScan, }, + lakera: { + guard: lakeraguard, + }, }; diff --git a/plugins/lakera/main-function.ts b/plugins/lakera/main-function.ts new file mode 100644 index 000000000..ac1ac1778 --- /dev/null +++ b/plugins/lakera/main-function.ts @@ -0,0 +1,232 @@ +import { + HookEventType, + PluginContext, + PluginHandler, + PluginParameters, +} from '../types'; +import { HttpError, post } from '../utils'; +import { + applyMasksToMessages, + isOnlyPiiViolation, + type PayloadItem, +} from './redaction'; + +function normalizeMessages(messages: any[]): any[] { + if (!messages?.length) return []; + return messages.map((message: any) => { + if (typeof message.content === 'string') return message; + if (Array.isArray(message.content)) { + const text = message.content.reduce( + (acc: string, item: any) => + acc + (item?.type === 'text' ? `${item.text}\n` : ''), + '' + ); + return { ...message, content: text }; + } + return message; + }); +} + +function extractMessages( + context: PluginContext, + eventType: HookEventType +): any[] { + const reqJson = context.request?.json || {}; + let messages = reqJson.messages; + if (messages && Array.isArray(messages)) { + const base = JSON.parse(JSON.stringify(messages)); + const normalized = normalizeMessages(base); + if (eventType === 'afterRequestHook') { + const rjson = context.response?.json || {}; + const choices = rjson.choices || []; + const ch0 = choices[0]; + if (ch0?.message && ch0.message.content != null) { + normalized.push({ + role: ch0.message.role || 'assistant', + content: ch0.message.content, + }); + } + } + return normalized; + } + const text = context.request?.text; + if (typeof text === 'string' && text.trim()) { + const msgs: any[] = [{ role: 'user', content: text }]; + if (eventType === 'afterRequestHook') { + const respText = context.response?.text; + if (typeof respText === 'string' && respText.trim()) { + msgs.push({ role: 'assistant', content: respText }); + } + } + return msgs; + } + return []; +} + +function portkeyMetadataToLakera( + meta: Record | undefined +): Record | undefined { + if (!meta || typeof meta !== 'object') return undefined; + const out: Record = {}; + const u = meta._user ?? meta.user_id; + if (u != null) out.user_id = String(u); + if (meta.session_id != null) out.session_id = String(meta.session_id); + if (meta.ip_address != null) out.ip_address = String(meta.ip_address); + return Object.keys(out).length ? out : undefined; +} + +export const handler: PluginHandler = async ( + context: PluginContext, + parameters: PluginParameters, + eventType: HookEventType +) => { + let error: any = null; + let verdict = false; + let data: any = null; + let transformed = false; + const transformedData: Record = { + request: { json: null, text: null }, + response: { json: null, text: null }, + }; + + try { + const apiKey = parameters.credentials?.apiKey as string | undefined; + if (!apiKey) { + throw new Error( + 'Missing Lakera apiKey: set credentials.apiKey in the guardrail config' + ); + } + const projectID = parameters.projectID; + + const messages = extractMessages(context, eventType); + if (!messages.length) { + return { + error: null, + verdict: true, + data: { explanation: 'no messages to screen' }, + }; + } + + const apiBase = String( + (parameters.credentials?.apiBase as string | undefined) ?? + 'https://api.lakera.ai' + ).replace(/\/$/, ''); + const url = `${apiBase}/v2/guard`; + const body: Record = { + messages, + payload: true, + breakdown: true, + }; + if (projectID) { + body.project_id = projectID; + } + const lm = portkeyMetadataToLakera( + context.metadata as Record + ); + if (lm) body.metadata = lm; + + const headers = { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + }; + + const lakeraResp: any = await post( + url, + body, + { headers }, + parameters.timeout || 30000 + ); + + const flagged = Boolean(lakeraResp.flagged); + const breakdown = lakeraResp.breakdown || []; + const payload = (lakeraResp.payload || []) as PayloadItem[]; + + const safeLog = { ...lakeraResp }; + // Strip raw spans (contain PII text positions) and internal Lakera IDs + // (detector_id, policy_id, project_id) from caller-visible data. + delete safeLog.payload; + delete safeLog.breakdown; + data = { + lakera: { + ...safeLog, + detectedTypes: breakdown + .filter((b: any) => b.detected) + .map((b: any) => b.detector_type), + }, + }; + + if (!flagged) { + verdict = true; + return { error, verdict, data }; + } + + const endInclusive = Boolean(parameters.endInclusive); + + if (isOnlyPiiViolation(breakdown) && payload.length > 0) { + const { messages: maskedMsgs, warnings } = applyMasksToMessages( + messages, + payload, + endInclusive + ); + + if (warnings.some((w) => w.includes('multimodal'))) { + verdict = false; + return { + error, + verdict, + data: { + ...data, + warnings, + explanation: + 'multimodal content cannot be masked in this plugin build', + }, + }; + } + + if (eventType === 'beforeRequestHook') { + const reqJson = context.request?.json + ? JSON.parse(JSON.stringify(context.request.json)) + : {}; + reqJson.messages = maskedMsgs; + transformedData.request.json = reqJson; + transformed = true; + } else { + const respJson = context.response?.json + ? JSON.parse(JSON.stringify(context.response.json)) + : {}; + const choices = respJson.choices || []; + if (choices[0]?.message && maskedMsgs.length > 0) { + const last = maskedMsgs[maskedMsgs.length - 1]; + if (last && last.role === 'assistant') { + choices[0].message = choices[0].message || {}; + choices[0].message.content = last.content; + respJson.choices = choices; + } + } + transformedData.response.json = respJson; + transformed = true; + } + + verdict = true; + return { + error, + verdict, + data: { ...data, warnings }, + transformedData, + transformed, + }; + } + + verdict = false; + return { error, verdict, data }; + } catch (e: any) { + // Strip stack trace to avoid leaking internal file paths to callers. + delete e?.stack; + error = e; + verdict = false; + if (e instanceof HttpError) { + data = { httpStatus: e.response?.status, body: e.response?.body }; + } + return { error, verdict, data }; + } +}; diff --git a/plugins/lakera/manifest.json b/plugins/lakera/manifest.json new file mode 100644 index 000000000..7e1c314f5 --- /dev/null +++ b/plugins/lakera/manifest.json @@ -0,0 +1,66 @@ +{ + "id": "lakera", + "description": "Lakera Guard — screen prompts and responses via POST /v2/guard. Supports blocking and PII redaction (payload) when only pii/* detectors fire.", + "credentials": { + "type": "object", + "properties": { + "apiKey": { + "type": "string", + "label": "Lakera API key", + "description": "Create at platform.lakera.ai (Guard API key)", + "encrypted": true + }, + "apiBase": { + "type": "string", + "label": "API base URL (optional)", + "description": "Default https://api.lakera.ai — use a regional host if required" + } + }, + "required": ["apiKey"] + }, + "functions": [ + { + "name": "Guard — screen content", + "id": "guard", + "supportedHooks": ["beforeRequestHook", "afterRequestHook"], + "type": "guardrail", + "description": [ + { + "type": "subHeading", + "text": "Calls Lakera Guard /v2/guard with payload+breakdown. Blocks on policy hits; redacts PII spans when breakdown shows only pii/* detectors." + } + ], + "parameters": { + "type": "object", + "properties": { + "projectID": { + "type": "string", + "label": "Lakera project ID", + "description": [ + { + "type": "subHeading", + "text": "Project whose policy defines detectors (recommended)" + } + ] + }, + "endInclusive": { + "type": "boolean", + "label": "Payload end offset is inclusive", + "description": [ + { + "type": "subHeading", + "text": "Leave false unless your Lakera tier emits inclusive end indices" + } + ] + }, + "timeout": { + "type": "number", + "label": "HTTP timeout (ms)", + "description": [{ "type": "subHeading", "text": "Default 30000" }] + } + }, + "required": [] + } + } + ] +} diff --git a/plugins/lakera/redaction.ts b/plugins/lakera/redaction.ts new file mode 100644 index 000000000..d137f16f4 --- /dev/null +++ b/plugins/lakera/redaction.ts @@ -0,0 +1,200 @@ +/** + * Multi-span PII masking from Lakera `payload` (message_id, start, end, detector_type). + * Half-open [start, end) unless endInclusive. + */ + +export interface PayloadItem { + message_id?: number; + start?: number; + end?: number; + detector_type?: string; + [key: string]: unknown; +} + +export function dedupePayloadItems(items: PayloadItem[]): PayloadItem[] { + const seen = new Set(); + const out: PayloadItem[] = []; + for (const it of items) { + const key = `${it.message_id}\0${it.start}\0${it.end}\0${it.detector_type}`; + if (seen.has(key)) continue; + seen.add(key); + out.push(it); + } + return out; +} + +export function mergeOverlappingIntervals( + spans: [number, number][] +): [number, number][] { + if (spans.length === 0) return []; + const sorted = [...spans].sort((a, b) => a[0] - b[0] || a[1] - b[1]); + const merged: [number, number][] = []; + let [curS, curE] = sorted[0]; + for (let i = 1; i < sorted.length; i++) { + const [s, e] = sorted[i]; + if (s <= curE) { + curE = Math.max(curE, e); + } else { + merged.push([curS, curE]); + curS = s; + curE = e; + } + } + merged.push([curS, curE]); + return merged; +} + +export function normalizeSpan( + start: number, + end: number, + length: number, + endInclusive: boolean +): [number, number] | null { + let e = endInclusive ? end + 1 : end; + if (start < 0 || e < 0 || start >= length || e > length || start >= e) + return null; + return [start, e]; +} + +export function maskLabel(detectorType: string): string { + const raw = (detectorType || '').trim(); + const base = raw.includes('/') ? raw.split('/').pop() || 'PII' : raw || 'PII'; + const safe = + base + .replace(/[^a-zA-Z0-9]+/g, '_') + .replace(/^_|_$/g, '') + .toUpperCase() || 'PII'; + return `[MASKED_${safe}]`; +} + +export type SpanLabel = [number, number, string]; + +export function mergeSpansWithLabels(raw: SpanLabel[]): SpanLabel[] { + if (raw.length === 0) return []; + const sorted = [...raw].sort((a, b) => a[0] - b[0] || a[1] - b[1]); + const merged: SpanLabel[] = []; + let [curS, curE, curL] = sorted[0]; + for (let i = 1; i < sorted.length; i++) { + const [s, e, lab] = sorted[i]; + if (s <= curE) { + curE = Math.max(curE, e); + } else { + merged.push([curS, curE, curL]); + curS = s; + curE = e; + curL = lab; + } + } + merged.push([curS, curE, curL]); + return merged; +} + +function collectNormalizedSpans( + items: PayloadItem[], + messageId: number, + textLen: number, + endInclusive: boolean +): { spans: SpanLabel[]; warnings: string[] } { + const warnings: string[] = []; + const spans: SpanLabel[] = []; + const forMsg = dedupePayloadItems( + items.filter((p) => p.message_id === messageId) + ); + for (const p of forMsg) { + if (p.start === undefined || p.end === undefined) { + warnings.push(`skip span missing start/end: ${JSON.stringify(p)}`); + continue; + } + const start = Number(p.start); + const end = Number(p.end); + const norm = normalizeSpan(start, end, textLen, endInclusive); + if (!norm) { + warnings.push( + `skip out-of-range span start=${start} end=${end} len=${textLen}` + ); + continue; + } + const [a, b] = norm; + spans.push([a, b, String(p.detector_type || '')]); + } + return { spans, warnings }; +} + +export function applyPayloadMasksToString( + text: string, + payloadItems: PayloadItem[], + messageId: number, + endInclusive: boolean +): { text: string; warnings: string[] } { + const { spans, warnings } = collectNormalizedSpans( + payloadItems, + messageId, + text.length, + endInclusive + ); + const merged = mergeSpansWithLabels(spans); + let out = text; + for (const [start, end, dt] of merged.sort((a, b) => b[0] - a[0])) { + out = out.slice(0, start) + maskLabel(dt) + out.slice(end); + } + return { text: out, warnings }; +} + +export function isOnlyPiiViolation( + breakdown: + | Array<{ detected?: boolean; detector_type?: string }> + | null + | undefined +): boolean { + if (!breakdown || breakdown.length === 0) return false; + let any = false; + for (const item of breakdown) { + if (!item.detected) continue; + any = true; + const dt = (item.detector_type || '').trim(); + if (!dt.startsWith('pii/')) return false; + } + return any; +} + +export function applyMasksToMessages( + messages: Array>, + payload: PayloadItem[] | null | undefined, + endInclusive: boolean +): { messages: Array>; warnings: string[] } { + const warnings: string[] = []; + if (!payload || payload.length === 0) { + return { messages: messages.map((m) => ({ ...m })), warnings }; + } + const out = messages.map((m) => ({ ...m, content: m.content })); + const indices = new Set( + payload + .map((p) => p.message_id) + .filter((id) => id !== undefined && id !== null) + ); + for (const idx of indices) { + const i = Number(idx); + if (i < 0 || i >= out.length) { + warnings.push( + `payload references message_id=${i} but only ${out.length} messages` + ); + continue; + } + const msg = out[i]; + const content = msg.content; + if (Array.isArray(content)) { + warnings.push( + `message ${i}: multimodal content array — redaction skipped for PoC` + ); + continue; + } + if (typeof content !== 'string') { + warnings.push(`message ${i}: non-string content — skipped`); + continue; + } + const r = applyPayloadMasksToString(content, payload, i, endInclusive); + warnings.push(...r.warnings); + msg.content = r.text; + } + return { messages: out, warnings }; +} diff --git a/plugins/lakera/test-file.test.ts b/plugins/lakera/test-file.test.ts new file mode 100644 index 000000000..170900ad8 --- /dev/null +++ b/plugins/lakera/test-file.test.ts @@ -0,0 +1,118 @@ +import { + applyMasksToMessages, + applyPayloadMasksToString, + dedupePayloadItems, + isOnlyPiiViolation, + mergeOverlappingIntervals, + normalizeSpan, +} from './redaction'; + +describe('lakera redaction helpers', () => { + it('mergeOverlappingIntervals merges overlap and adjacent', () => { + expect( + mergeOverlappingIntervals([ + [0, 3], + [2, 5], + ]) + ).toEqual([[0, 5]]); + expect( + mergeOverlappingIntervals([ + [0, 2], + [2, 4], + ]) + ).toEqual([[0, 4]]); + expect( + mergeOverlappingIntervals([ + [0, 1], + [5, 6], + ]) + ).toEqual([ + [0, 1], + [5, 6], + ]); + }); + + it('dedupePayloadItems', () => { + const items = [ + { message_id: 0, start: 1, end: 2, detector_type: 'pii/a' }, + { message_id: 0, start: 1, end: 2, detector_type: 'pii/a' }, + ]; + expect(dedupePayloadItems(items)).toHaveLength(1); + }); + + it('normalizeSpan half-open', () => { + expect(normalizeSpan(0, 3, 10, false)).toEqual([0, 3]); + expect(normalizeSpan(0, 11, 10, false)).toBeNull(); + }); + + it('two non-overlapping spans', () => { + const text = 'hello SECRET1 world SECRET2 end'; + const payload = [ + { message_id: 0, start: 6, end: 13, detector_type: 'pii/foo' }, + { message_id: 0, start: 20, end: 27, detector_type: 'pii/bar' }, + ]; + const { text: out } = applyPayloadMasksToString(text, payload, 0, false); + expect(out).not.toContain('SECRET1'); + expect(out).not.toContain('SECRET2'); + expect(out).toContain('[MASKED_'); + }); + + it('overlapping spans merged once', () => { + const text = '0123456789'; + const payload = [ + { message_id: 0, start: 2, end: 5, detector_type: 'pii/a' }, + { message_id: 0, start: 4, end: 7, detector_type: 'pii/b' }, + ]; + const { text: out } = applyPayloadMasksToString(text, payload, 0, false); + expect(out.split('[MASKED_').length - 1).toBe(1); + }); + + it('unicode emoji index', () => { + const text = 'hi 👋 there'; + const i = text.indexOf('👋'); + const payload = [ + { message_id: 0, start: i, end: i + 1, detector_type: 'pii/x' }, + ]; + const { text: out } = applyPayloadMasksToString(text, payload, 0, false); + expect(out).not.toContain('👋'); + expect(out).toContain('[MASKED_'); + }); + + it('invalid span skipped', () => { + const text = 'short'; + const payload = [ + { message_id: 0, start: 0, end: 99, detector_type: 'pii/x' }, + ]; + const { text: out, warnings } = applyPayloadMasksToString( + text, + payload, + 0, + false + ); + expect(out).toBe(text); + expect(warnings.length).toBeGreaterThan(0); + }); + + it('message_id isolation', () => { + const msgs = [ + { role: 'user', content: 'aaa' }, + { role: 'user', content: 'bbb' }, + ]; + const payload = [ + { message_id: 1, start: 0, end: 1, detector_type: 'pii/x' }, + ]; + const { messages: out } = applyMasksToMessages(msgs, payload, false); + expect(out[0].content).toBe('aaa'); + expect(out[1].content).not.toBe('bbb'); + }); + + it('isOnlyPiiViolation', () => { + expect(isOnlyPiiViolation([])).toBe(false); + expect( + isOnlyPiiViolation([{ detected: true, detector_type: 'prompt_attack' }]) + ).toBe(false); + expect( + isOnlyPiiViolation([{ detected: true, detector_type: 'pii/email' }]) + ).toBe(true); + }); +});