Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/providers/oracle/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ const OracleAPIConfig: ProviderAPIConfig = {
case 'stream-chatComplete':
endpoint = '/actions/chat';
break;
case 'embed':
endpoint = '/actions/embedText';
break;
default:
return '';
}
Expand Down
100 changes: 100 additions & 0 deletions src/providers/oracle/embed.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import { OracleEmbedConfig, OracleEmbedResponseTransform } from './embed';

const baseProviderOptions = {
oracleCompartmentId: 'ocid1.compartment.oc1..test',
} as any;

const transformInput = (params: any) => {
const cfg = OracleEmbedConfig.input as any;
return cfg.transform(params, baseProviderOptions);
};

const transformInputType = (params: any) => {
const cfg = OracleEmbedConfig.input_type as any;
return cfg.transform(params, baseProviderOptions);
};

const transformServingMode = (params: any) => {
const cfg = OracleEmbedConfig.model as any;
return cfg.transform(params, baseProviderOptions);
};

describe('Oracle Embeddings — request transform', () => {
it('wraps a string input into the OCI inputs array', () => {
expect(transformInput({ input: 'hello' })).toEqual(['hello']);
});

it('passes through an array of strings', () => {
expect(transformInput({ input: ['a', 'b'] })).toEqual(['a', 'b']);
});

it('extracts text from object items in input array', () => {
const out = transformInput({
input: [{ text: 'first' }, { text: 'second' }],
});
expect(out).toEqual(['first', 'second']);
});

it('returns undefined for empty array', () => {
expect(transformInput({ input: [] })).toBeUndefined();
});

it('maps lowercase input_type to OCI uppercase', () => {
expect(transformInputType({ input_type: 'search_document' })).toBe(
'SEARCH_DOCUMENT'
);
expect(transformInputType({ input_type: 'CLUSTERING' })).toBe('CLUSTERING');
});

it('builds an ON_DEMAND servingMode payload', () => {
expect(
transformServingMode({ model: 'cohere.embed-multilingual-v3.0' })
).toEqual({
servingType: 'ON_DEMAND',
modelId: 'cohere.embed-multilingual-v3.0',
});
});
});

describe('Oracle Embeddings — response transform', () => {
it('maps OCI embeddings to OpenAI shape', () => {
const oci = {
embeddings: [
[0.1, 0.2],
[0.3, 0.4],
],
modelId: 'cohere.embed-multilingual-v3.0',
modelVersion: '1.0',
inputTextTokenCount: 7,
};

const result: any = OracleEmbedResponseTransform(oci, 200, new Headers());

expect(result.object).toBe('list');
expect(result.data).toEqual([
{ object: 'embedding', embedding: [0.1, 0.2], index: 0 },
{ object: 'embedding', embedding: [0.3, 0.4], index: 1 },
]);
expect(result.model).toBe('cohere.embed-multilingual-v3.0');
expect(result.usage).toEqual({ prompt_tokens: 7, total_tokens: 7 });
});

it('returns an error response for non-200 with code', () => {
const result: any = OracleEmbedResponseTransform(
{ code: '400', message: 'bad input' } as any,
400,
new Headers()
);
expect(result.error).toBeDefined();
expect(result.error.message).toContain('bad input');
});

it('returns invalid-provider-response when embeddings missing', () => {
const result: any = OracleEmbedResponseTransform(
{ modelId: 'foo' } as any,
200,
new Headers()
);
expect(result.error).toBeDefined();
});
});
160 changes: 160 additions & 0 deletions src/providers/oracle/embed.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import { ORACLE } from '../../globals';
import {
EmbedParams,
EmbedResponse,
EmbedResponseData,
} from '../../types/embedRequestBody';
import { Options } from '../../types/requestBody';
import { ErrorResponse, ProviderConfig } from '../types';
import {
generateErrorResponse,
generateInvalidProviderResponseError,
} from '../utils';

type OracleInputType =
| 'SEARCH_DOCUMENT'
| 'SEARCH_QUERY'
| 'CLASSIFICATION'
| 'CLUSTERING'
| 'IMAGE';

type OracleTruncate = 'NONE' | 'START' | 'END';

export const OracleEmbedConfig: ProviderConfig = {
input: {
param: 'inputs',
required: true,
transform: (params: EmbedParams): string[] | undefined => {
if (typeof params.input === 'string') {
return [params.input];
}
if (Array.isArray(params.input)) {
const texts: string[] = [];
params.input.forEach((item) => {
if (typeof item === 'string') {
texts.push(item);
} else if (typeof item === 'object' && 'text' in item) {
texts.push((item as { text: string }).text);
}
});
return texts.length > 0 ? texts : undefined;
}
return undefined;
},
},
input_type: {
param: 'inputType',
required: false,
transform: (params: EmbedParams): OracleInputType | undefined => {
const typeMap: Record<string, OracleInputType> = {
search_document: 'SEARCH_DOCUMENT',
search_query: 'SEARCH_QUERY',
classification: 'CLASSIFICATION',
clustering: 'CLUSTERING',
image: 'IMAGE',
};
const inputType = (params as any).input_type;
if (inputType && typeMap[inputType.toLowerCase()]) {
return typeMap[inputType.toLowerCase()];
}
return undefined;
},
},
truncate: {
param: 'truncate',
required: false,
transform: (params: EmbedParams): OracleTruncate | undefined => {
const truncateMap: Record<string, OracleTruncate> = {
none: 'NONE',
start: 'START',
end: 'END',
};
const truncate = (params as any).truncate;
if (truncate && truncateMap[truncate.toLowerCase()]) {
return truncateMap[truncate.toLowerCase()];
}
return undefined;
},
},
is_echo: {
param: 'isEcho',
required: false,
transform: (params: EmbedParams): boolean | undefined => {
return (params as any).is_echo;
},
},
model: {
param: 'servingMode',
required: true,
transform: (params: EmbedParams): object => {
return {
servingType: 'ON_DEMAND',
modelId: params.model,
};
},
},
compartmentId: {
param: 'compartmentId',
required: true,
default: (_params: EmbedParams, provider: Options): string => {
return provider.oracleCompartmentId || '';
},
},
};

export interface OracleEmbedResponse {
embeddings: number[][];
modelId: string;
modelVersion: string;
inputTextTokenCount?: number;
}

export interface OracleEmbedErrorResponse {
code: string;
message: string;
}

export const OracleEmbedResponseTransform = (
response: OracleEmbedResponse | OracleEmbedErrorResponse,
responseStatus: number,
_responseHeaders: Headers
): EmbedResponse | ErrorResponse => {
if (responseStatus !== 200 && 'code' in response) {
return generateErrorResponse(
{
message: `oracle error: ${response.message || 'Unknown error'}`,
type: response.code?.toString() || null,
param: null,
code: response.code?.toString() || null,
},
ORACLE
);
}

const successResponse = response as OracleEmbedResponse;
if (
!successResponse.embeddings ||
!Array.isArray(successResponse.embeddings)
) {
return generateInvalidProviderResponseError(response, ORACLE);
}

const data: EmbedResponseData[] = successResponse.embeddings.map(
(embedding, index) => ({
object: 'embedding' as const,
embedding,
index,
})
);

return {
object: 'list',
data,
model: successResponse.modelId,
usage: {
prompt_tokens: successResponse.inputTextTokenCount || 0,
total_tokens: successResponse.inputTextTokenCount || 0,
},
provider: ORACLE,
};
};
3 changes: 3 additions & 0 deletions src/providers/oracle/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import {
OracleChatCompleteResponseTransform,
OracleChatCompleteStreamChunkTransform,
} from './chatComplete';
import { OracleEmbedConfig, OracleEmbedResponseTransform } from './embed';

const OracleConfig: ProviderConfigs = {
chatComplete: OracleChatCompleteConfig,
embed: OracleEmbedConfig,
api: OracleAPIConfig,
responseTransforms: {
chatComplete: OracleChatCompleteResponseTransform,
'stream-chatComplete': OracleChatCompleteStreamChunkTransform,
embed: OracleEmbedResponseTransform,
},
};

Expand Down