diff --git a/src/providers/oracle/api.ts b/src/providers/oracle/api.ts index 37ca193b2..30447eed6 100644 --- a/src/providers/oracle/api.ts +++ b/src/providers/oracle/api.ts @@ -37,6 +37,9 @@ const OracleAPIConfig: ProviderAPIConfig = { case 'stream-chatComplete': endpoint = '/actions/chat'; break; + case 'embed': + endpoint = '/actions/embedText'; + break; default: return ''; } diff --git a/src/providers/oracle/embed.test.ts b/src/providers/oracle/embed.test.ts new file mode 100644 index 000000000..26cebe2e1 --- /dev/null +++ b/src/providers/oracle/embed.test.ts @@ -0,0 +1,100 @@ +import { OracleEmbedConfig, OracleEmbedResponseTransform } from './embed'; + +const baseProviderOptions = { + oracleCompartmentId: 'ocid1.compartment.oc1..test', +} as any; + +const transformInput = (params: any) => { + const cfg = OracleEmbedConfig.input as any; + return cfg.transform(params, baseProviderOptions); +}; + +const transformInputType = (params: any) => { + const cfg = OracleEmbedConfig.input_type as any; + return cfg.transform(params, baseProviderOptions); +}; + +const transformServingMode = (params: any) => { + const cfg = OracleEmbedConfig.model as any; + return cfg.transform(params, baseProviderOptions); +}; + +describe('Oracle Embeddings — request transform', () => { + it('wraps a string input into the OCI inputs array', () => { + expect(transformInput({ input: 'hello' })).toEqual(['hello']); + }); + + it('passes through an array of strings', () => { + expect(transformInput({ input: ['a', 'b'] })).toEqual(['a', 'b']); + }); + + it('extracts text from object items in input array', () => { + const out = transformInput({ + input: [{ text: 'first' }, { text: 'second' }], + }); + expect(out).toEqual(['first', 'second']); + }); + + it('returns undefined for empty array', () => { + expect(transformInput({ input: [] })).toBeUndefined(); + }); + + it('maps lowercase input_type to OCI uppercase', () => { + expect(transformInputType({ input_type: 'search_document' })).toBe( + 'SEARCH_DOCUMENT' + ); + expect(transformInputType({ input_type: 'CLUSTERING' })).toBe('CLUSTERING'); + }); + + it('builds an ON_DEMAND servingMode payload', () => { + expect( + transformServingMode({ model: 'cohere.embed-multilingual-v3.0' }) + ).toEqual({ + servingType: 'ON_DEMAND', + modelId: 'cohere.embed-multilingual-v3.0', + }); + }); +}); + +describe('Oracle Embeddings — response transform', () => { + it('maps OCI embeddings to OpenAI shape', () => { + const oci = { + embeddings: [ + [0.1, 0.2], + [0.3, 0.4], + ], + modelId: 'cohere.embed-multilingual-v3.0', + modelVersion: '1.0', + inputTextTokenCount: 7, + }; + + const result: any = OracleEmbedResponseTransform(oci, 200, new Headers()); + + expect(result.object).toBe('list'); + expect(result.data).toEqual([ + { object: 'embedding', embedding: [0.1, 0.2], index: 0 }, + { object: 'embedding', embedding: [0.3, 0.4], index: 1 }, + ]); + expect(result.model).toBe('cohere.embed-multilingual-v3.0'); + expect(result.usage).toEqual({ prompt_tokens: 7, total_tokens: 7 }); + }); + + it('returns an error response for non-200 with code', () => { + const result: any = OracleEmbedResponseTransform( + { code: '400', message: 'bad input' } as any, + 400, + new Headers() + ); + expect(result.error).toBeDefined(); + expect(result.error.message).toContain('bad input'); + }); + + it('returns invalid-provider-response when embeddings missing', () => { + const result: any = OracleEmbedResponseTransform( + { modelId: 'foo' } as any, + 200, + new Headers() + ); + expect(result.error).toBeDefined(); + }); +}); diff --git a/src/providers/oracle/embed.ts b/src/providers/oracle/embed.ts new file mode 100644 index 000000000..fae2dac98 --- /dev/null +++ b/src/providers/oracle/embed.ts @@ -0,0 +1,160 @@ +import { ORACLE } from '../../globals'; +import { + EmbedParams, + EmbedResponse, + EmbedResponseData, +} from '../../types/embedRequestBody'; +import { Options } from '../../types/requestBody'; +import { ErrorResponse, ProviderConfig } from '../types'; +import { + generateErrorResponse, + generateInvalidProviderResponseError, +} from '../utils'; + +type OracleInputType = + | 'SEARCH_DOCUMENT' + | 'SEARCH_QUERY' + | 'CLASSIFICATION' + | 'CLUSTERING' + | 'IMAGE'; + +type OracleTruncate = 'NONE' | 'START' | 'END'; + +export const OracleEmbedConfig: ProviderConfig = { + input: { + param: 'inputs', + required: true, + transform: (params: EmbedParams): string[] | undefined => { + if (typeof params.input === 'string') { + return [params.input]; + } + if (Array.isArray(params.input)) { + const texts: string[] = []; + params.input.forEach((item) => { + if (typeof item === 'string') { + texts.push(item); + } else if (typeof item === 'object' && 'text' in item) { + texts.push((item as { text: string }).text); + } + }); + return texts.length > 0 ? texts : undefined; + } + return undefined; + }, + }, + input_type: { + param: 'inputType', + required: false, + transform: (params: EmbedParams): OracleInputType | undefined => { + const typeMap: Record = { + search_document: 'SEARCH_DOCUMENT', + search_query: 'SEARCH_QUERY', + classification: 'CLASSIFICATION', + clustering: 'CLUSTERING', + image: 'IMAGE', + }; + const inputType = (params as any).input_type; + if (inputType && typeMap[inputType.toLowerCase()]) { + return typeMap[inputType.toLowerCase()]; + } + return undefined; + }, + }, + truncate: { + param: 'truncate', + required: false, + transform: (params: EmbedParams): OracleTruncate | undefined => { + const truncateMap: Record = { + none: 'NONE', + start: 'START', + end: 'END', + }; + const truncate = (params as any).truncate; + if (truncate && truncateMap[truncate.toLowerCase()]) { + return truncateMap[truncate.toLowerCase()]; + } + return undefined; + }, + }, + is_echo: { + param: 'isEcho', + required: false, + transform: (params: EmbedParams): boolean | undefined => { + return (params as any).is_echo; + }, + }, + model: { + param: 'servingMode', + required: true, + transform: (params: EmbedParams): object => { + return { + servingType: 'ON_DEMAND', + modelId: params.model, + }; + }, + }, + compartmentId: { + param: 'compartmentId', + required: true, + default: (_params: EmbedParams, provider: Options): string => { + return provider.oracleCompartmentId || ''; + }, + }, +}; + +export interface OracleEmbedResponse { + embeddings: number[][]; + modelId: string; + modelVersion: string; + inputTextTokenCount?: number; +} + +export interface OracleEmbedErrorResponse { + code: string; + message: string; +} + +export const OracleEmbedResponseTransform = ( + response: OracleEmbedResponse | OracleEmbedErrorResponse, + responseStatus: number, + _responseHeaders: Headers +): EmbedResponse | ErrorResponse => { + if (responseStatus !== 200 && 'code' in response) { + return generateErrorResponse( + { + message: `oracle error: ${response.message || 'Unknown error'}`, + type: response.code?.toString() || null, + param: null, + code: response.code?.toString() || null, + }, + ORACLE + ); + } + + const successResponse = response as OracleEmbedResponse; + if ( + !successResponse.embeddings || + !Array.isArray(successResponse.embeddings) + ) { + return generateInvalidProviderResponseError(response, ORACLE); + } + + const data: EmbedResponseData[] = successResponse.embeddings.map( + (embedding, index) => ({ + object: 'embedding' as const, + embedding, + index, + }) + ); + + return { + object: 'list', + data, + model: successResponse.modelId, + usage: { + prompt_tokens: successResponse.inputTextTokenCount || 0, + total_tokens: successResponse.inputTextTokenCount || 0, + }, + provider: ORACLE, + }; +}; diff --git a/src/providers/oracle/index.ts b/src/providers/oracle/index.ts index 5d6ecce53..dcc9b358e 100644 --- a/src/providers/oracle/index.ts +++ b/src/providers/oracle/index.ts @@ -5,13 +5,16 @@ import { OracleChatCompleteResponseTransform, OracleChatCompleteStreamChunkTransform, } from './chatComplete'; +import { OracleEmbedConfig, OracleEmbedResponseTransform } from './embed'; const OracleConfig: ProviderConfigs = { chatComplete: OracleChatCompleteConfig, + embed: OracleEmbedConfig, api: OracleAPIConfig, responseTransforms: { chatComplete: OracleChatCompleteResponseTransform, 'stream-chatComplete': OracleChatCompleteStreamChunkTransform, + embed: OracleEmbedResponseTransform, }, };