diff --git a/modelz/client.py b/modelz/client.py index e28e2b0..5b1c371 100644 --- a/modelz/client.py +++ b/modelz/client.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Generator +from typing import Any, Dict, Generator from http import HTTPStatus from urllib.parse import urljoin @@ -131,3 +131,45 @@ def build(self, repo: str): ModelzResponse(resp) console.print(f"created the build job for repo [bold cyan]{repo}[/bold cyan]") + + @classmethod + def create_completion( + cls, + deployment: str, + model: str, + prompt: str, + params: Dict[str, Any] | None = None, + serde: str = "json", + ): + """Create a completion using the model. + + Args: + deployment: deployment ID + prompt: The prompt to use for the completion. + params: additional request params, will be serialized by `serde` + serde: serialize/deserialize method, choose from ("json", "msg", "raw") + """ + # Create an instance of the class. + client = cls(deployment=deployment) + + try: + from llmspec import LLMSpec + + # Instantiate LLMSpec and transform the prompt + llmspec = LLMSpec(prompt) + transformed_prompt = llmspec.to_model(model) + except ImportError as err: + raise ImportError( + "llmspec is required for LLM models" + "\nPlease install it with the command `pip install llmspec" + ) from err + + # Prepare request params + request_params = {"prompt": transformed_prompt} + if params: + request_params.update(params) + + # Get the inference result + response = client.inference(request_params, deployment, serde) + + return response