Model Upload

Learn how to upload a model using Clarifai SDKs


The Clarifai SDKs allow you to upload custom models easily. Whether you're working with a pre-trained model from an external source or one you've built from scratch, Clarifai allows seamless integration of your models, enabling you to take advantage of the platform’s powerful capabilities.

Once uploaded, your model can be used alongside Clarifai's suite of AI tools. It is automatically deployed and ready to be evaluated, combined with other models and agent operators in a workflow, or used as is to serve inference requests.

Let’s demonstrate how you can successfully upload different types of models to the Clarifai platform.

info
  • This new feature is in Public Preview. If you'd like to test it out and provide feedback, please request access here.

  • This new upload experience is compatible with the latest clarifai Python package, starting from version 10.9.2.

  • If you prefer the previous upload method, which is supported up to version 10.8.4, you can refer to the documentation here.

tip

You can run the following command to clone the repository containing examples of how to upload various model types and follow along with this documentation: git clone https://github.com/Clarifai/examples.git. After cloning it, go to the models/model_upload folder.

Prerequisites

Installation

To begin, install the latest version of the clarifai Python package.

pip install --upgrade clarifai

Environment Set Up

Before proceeding, ensure that the CLARIFAI_PAT (Personal Access Token) environment variable is set. You can generate the PAT key in your Personal Settings page by navigating to the Security section.

This token is essential for authenticating your connection to the Clarifai platform.

export CLARIFAI_PAT=YOUR_PERSONAL_ACCESS_TOKEN_HERE
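
If you drive the upload from a Python script, it can help to fail early when the token is missing. The following is a minimal sketch; the helper name `require_pat` is hypothetical and not part of the Clarifai SDK.

```python
import os


def require_pat(env=os.environ):
    """Return the PAT from the environment, or raise a clear error if it is missing."""
    pat = env.get("CLARIFAI_PAT", "").strip()
    if not pat:
        raise RuntimeError(
            "CLARIFAI_PAT is not set; export it before using the Clarifai CLI or SDK.")
    return pat
```

Calling `require_pat()` at the top of a script surfaces a missing token before any model files are built or uploaded.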

Create Project Directory

Create a project directory and organize your files as indicated below to fit the requirements of uploading models to the Clarifai platform.

your_model_directory/
├── 1/
│ └── model.py
├── requirements.txt
└── config.yaml
  • your_model_directory/ – The main directory containing your model files.
    • 1/ – A subdirectory that holds the model file. The folder must be named 1.
      • model.py – Contains the code that defines your model, including loading the model and running inference.
    • requirements.txt – Lists the Python libraries and dependencies required to run your model.
    • config.yaml – Contains model metadata and configuration details necessary for building the Docker image, defining compute resources, and uploading the model to Clarifai.
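
The layout above can be scaffolded with a few lines of Python. This is only a sketch: the function name `scaffold_model_dir` is hypothetical, and the files it creates are empty placeholders that you still need to fill in.

```python
from pathlib import Path


def scaffold_model_dir(root):
    """Create the directory layout described above with empty placeholder files."""
    root = Path(root)
    (root / "1").mkdir(parents=True, exist_ok=True)
    for relative in ("1/model.py", "requirements.txt", "config.yaml"):
        path = root / relative
        if not path.exists():  # never overwrite existing work
            path.touch()
    # Report what is on disk, as POSIX-style paths relative to the root
    return sorted(p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file())
```

For example, `scaffold_model_dir("your_model_directory")` creates the `1/` folder and the three files if they do not already exist.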

How to Upload a Model

Let's talk about the common steps you'd follow to upload any type of model to the Clarifai platform.

Step 1: Define the config.yaml File

The config.yaml file is essential for specifying the model’s metadata, compute resource requirements, and model checkpoints.

Here’s a breakdown of the key sections in the file.

Model Info

This section defines your model ID, Clarifai user ID, and Clarifai app ID, which will determine where the model is uploaded on the Clarifai platform.

model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "text-to-text" # Change this based on your model type (e.g., visual-classifier, text-to-text)

Compute Resources

Here, you define the minimum compute resources required for running your model, including CPU, memory, and optional GPU specifications.

inference_compute_info:
  cpu_limit: "2"
  cpu_memory: "13Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"] # Specify the GPU type if needed
  accelerator_memory: "15Gi"
  • cpu_limit – Number of CPUs allocated for the model (follows Kubernetes notation, e.g., "1", "2").
  • cpu_memory – Minimum memory required for the CPU (uses Kubernetes notation, e.g., "1Gi", "1500Mi", "3Gi").
  • num_accelerators – Number of GPUs or TPUs to use for inference.
  • accelerator_type – List of accelerator models (GPU or TPU) the model can run on (e.g., "NVIDIA-A10G").
  • accelerator_memory – Minimum memory required for the GPU or TPU.
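
To make the memory notation concrete, here is a small sketch that converts values such as "13Gi" or "1500Mi" to bytes. It handles only the binary suffixes (Ki, Mi, Gi, Ti), which is a subset of what Kubernetes accepts, and the helper name `memory_to_bytes` is hypothetical.

```python
# Binary suffixes: each step is a factor of 1024
_UNITS = {"Ki": 2**10, "Mi": 2**20, "Gi": 2**30, "Ti": 2**40}


def memory_to_bytes(quantity):
    """Convert a binary-suffixed quantity such as "13Gi" or "1500Mi" to bytes."""
    for suffix, factor in _UNITS.items():
        if quantity.endswith(suffix):
            return int(float(quantity[: -len(suffix)]) * factor)
    return int(quantity)  # plain byte count with no suffix
```

A helper like this can be used to compare the `cpu_memory` you request against the memory available on the machine where you test the model.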

Model Checkpoints

If you're using a model from Hugging Face, you can automatically download its checkpoints by specifying the appropriate configuration in this section. For private or restricted Hugging Face repositories, include an access token.

checkpoints:
  type: "huggingface"
  repo_id: "meta-llama/Meta-Llama-3-8B-Instruct"
  hf_token: "your_hf_token" # Required for private models

Model Concepts or Labels

important

This section is required if your model outputs concepts or labels and is not being directly loaded from Hugging Face.

For models that output concepts or labels, such as classification or detection models, you must define a concepts section in the config.yaml file:

concepts:
  - id: '0'
    name: bus
  - id: '1'
    name: person
  - id: '2'
    name: bicycle
  - id: '3'
    name: car

note

If you're using a model from Hugging Face and the checkpoints section is defined, the Clarifai platform will automatically infer concepts. In this case, you don’t need to manually specify them.
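
When a model has many labels, writing the concepts section by hand is error-prone. The sketch below renders it from a Python list. It assumes that each concept ID is the zero-based index of the label in the model's output, which matches how the examples on this page index concepts; the helper name `concepts_yaml` is hypothetical, and label names containing special YAML characters would need quoting.

```python
def concepts_yaml(labels):
    """Render a list of label names as a `concepts:` YAML section with string IDs."""
    lines = ["concepts:"]
    for index, name in enumerate(labels):
        lines.append(f"  - id: '{index}'")
        lines.append(f"    name: {name}")
    return "\n".join(lines)


print(concepts_yaml(["bus", "person", "bicycle", "car"]))
```

The printed text can be pasted into config.yaml as is.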

Step 2: Define Dependencies in requirements.txt

The requirements.txt file lists all the Python dependencies your model needs. This ensures that the necessary libraries are installed in the runtime environment.
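
Most of the examples on this page pin exact versions (for instance, torch==2.4.1). Pinning is a common way to keep builds reproducible, though this page does not state it as a platform requirement. As a rough sketch, the hypothetical helper below lists requirement lines that are not pinned with `==`.

```python
def unpinned_requirements(text):
    """Return requirement lines that do not pin an exact version with `==`."""
    unpinned = []
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments and surrounding whitespace
        if line and "==" not in line:
            unpinned.append(line)
    return unpinned
```

You could run it over the contents of requirements.txt before uploading and decide whether the unpinned entries are intentional.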

Step 3: Prepare the model.py File

The model.py file contains the logic for your model, including how it loads and handles predictions. This file must implement a class that inherits from ModelRunner and defines the following methods:

  • load_model() – Initializes and loads the model, preparing it for inference.
  • predict(input_data) – Handles the core logic for making predictions. It processes the input data and returns the output response.
  • generate(input_data) – Provides output in a streaming manner, if applicable to the model's use case.
  • stream(input_data) – Manages both streaming input and output, primarily for more advanced use cases where data is processed continuously.

from clarifai.runners.models.model_runner import ModelRunner


class YourCustomModelRunner(ModelRunner):

    def load_model(self):
        # Initialize and load the model here
        pass

    def predict(self, request):
        # Handle the input and return the model's predictions
        pass

    def generate(self, request):
        # Handle streaming output (if applicable)
        pass

    def stream(self, request):
        # Handle both streaming input and output
        pass

Step 4: Test the Model Locally

Before uploading your model to the Clarifai platform, it's important to test it locally to catch any typos or misconfigurations in the code.

This can prevent upload failures due to issues in the model.py or incorrect model implementation. It also ensures the model runs smoothly and that all dependencies are correctly configured.

You can test the model within a Docker container or a Python virtual environment.

Recommendation

If Docker is installed on your system, it is highly recommended to use it for testing or running the model. Docker provides better isolation and avoids dependency conflicts.

warning

Ensure your local environment has sufficient memory and compute resources to load and run the model for testing.

There are two types of CLI (command line interface) commands you can use to test your models in your local development environment.

1. Using the test-locally Command

This method allows you to test your model with a single CLI command. It runs the model locally and sends a sample request to verify that the model responds successfully. The results of the request are displayed directly in the console.

Here is how to test a model in a Docker Container:

clarifai model test-locally --model_path {add_model_path_here} --mode container

Here is how to test a model in a virtual environment:

clarifai model test-locally --model_path {add_model_path_here} --mode env

2. Using the run-locally Command

This method starts a local gRPC server at https://localhost:{port}/ for running the model. Once the server is running, you can perform inference on the model via the Clarifai client SDK.

Here is how to test a model in a Docker Container:

clarifai model run-locally --model_path {add_model_path_here} --mode container --port 8000

Here is how to test a model in a virtual environment:

clarifai model run-locally --model_path {add_model_path_here} --mode env --port 8000

Once the model is running locally, you need to configure the CLARIFAI_API_BASE environment variable to point to the localhost and port where the gRPC server is running.

export CLARIFAI_API_BASE="localhost:{port}"

You can then make different types of inference requests using the model — unary-unary, unary-stream, or stream-stream predict calls.

Here is an example of a unary-unary prediction call:

from clarifai.client.model import Model
model = Model(model_id='model_id', user_id='user_id', app_id='app_id') # no need to provide any actual values of `model_id`, `user_id` and `app_id`

image_url = "https://samples.clarifai.com/metro-north.jpg"

# Model Predict
model_prediction = model.predict_by_url(image_url)

CLI Flags

These are the key CLI flags available for local testing and running your models:

  • --model_path — Path to the model directory.

  • --mode — Specify how to run the model: env for virtual environment or container for Docker container. Defaults to env.

  • -p or --port — The port to host the gRPC server for running the model locally. Defaults to 8000.

  • --keep_env — Retain the virtual environment after testing the model locally (applicable for env mode). Defaults to False.

  • --keep_image — Retain the Docker image built after testing the model locally (applicable for container mode). Defaults to False.
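
If you prefer to launch local tests from a script, the commands above can be assembled as an argument list. This sketch uses only the subcommands and flags documented on this page; the helper name `local_test_command` is hypothetical.

```python
import shlex


def local_test_command(model_path, mode="env", port=None, subcommand="test-locally"):
    """Build the argument list for `clarifai model test-locally` or `run-locally`."""
    if subcommand not in ("test-locally", "run-locally"):
        raise ValueError(f"unknown subcommand: {subcommand}")
    if mode not in ("env", "container"):
        raise ValueError(f"unknown mode: {mode}")
    argv = ["clarifai", "model", subcommand, "--model_path", str(model_path), "--mode", mode]
    if port is not None:  # only meaningful when a local gRPC server is started
        argv += ["--port", str(port)]
    return argv


# Show the command as it would be typed in a shell
print(shlex.join(local_test_command("./my_model", mode="container")))
```

Passing the returned list to `subprocess.run(argv, check=True)` would execute the command, provided the clarifai package is installed.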

Step 5: Upload the Model to Clarifai

Once your model is ready, upload it to the Clarifai platform by running the following command:

clarifai model upload --model_path {add_model_path_here}

This command builds the model’s Docker image using the defined compute resources and uploads it to Clarifai, where it can be served in production.

Examples

info

You can find various model upload examples here, which demonstrate different use cases and optimizations.

Image Classifier

model.py

# Model to be uploaded: https://huggingface.co/Falconsai/nsfw_image_detection

import os
from io import BytesIO
from typing import Iterator

import requests
import torch
from clarifai.runners.models.model_runner import ModelRunner
from clarifai.utils.logging import logger
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor


def preprocess_image(image_url=None, image_base64=None):
    """Open an image from raw bytes or, failing that, from a URL."""
    if image_base64:
        return Image.open(BytesIO(image_base64))
    if image_url:
        return Image.open(BytesIO(requests.get(image_url).content))
    raise ValueError("Either image_url or image_base64 must be provided.")


class MyRunner(ModelRunner):
    """A custom runner that loads the model and classifies images using it."""

    def load_model(self):
        """Load the model here."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Running on device: {self.device}")

        # If config.yaml has a checkpoints section, the checkpoints are
        # downloaded to this path when the model is uploaded.
        checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints")

        self.model = AutoModelForImageClassification.from_pretrained(checkpoint_path).to(
            self.device)
        self.processor = ViTImageProcessor.from_pretrained(checkpoint_path)
        logger.info("Done loading!")

    def predict(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> service_pb2.MultiOutputResponse:
        """Called for each unary request. Classifies every input image and returns
        one output per input.
        """

        # Get the concept protos from the model.
        concept_protos = request.model.model_version.output_info.data.concepts

        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            output = resources_pb2.Output()
            data = inp.data
            output_concepts = []

            # Raw image bytes take precedence over the URL.
            img = preprocess_image(image_url=data.image.url, image_base64=data.image.base64)

            with torch.no_grad():
                inputs = self.processor(images=img, return_tensors="pt").to(self.device)
                model_output = self.model(**inputs)
                logits = model_output.logits

            # Convert logits to probabilities and list concepts from most to least likely.
            probs = torch.softmax(logits, dim=-1)[0]
            sorted_indices = torch.argsort(probs, dim=-1, descending=True)
            for idx in sorted_indices:
                concept_protos[idx.item()].value = probs[idx.item()].item()
                output_concepts.append(concept_protos[idx.item()])

            output.data.concepts.extend(output_concepts)

            output.status.code = status_code_pb2.SUCCESS
            outputs.append(output)
        return service_pb2.MultiOutputResponse(outputs=outputs)

    def generate(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError(
            "Generate method is not implemented for image classification models.")

    def stream(
        self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError(
            "Stream method is not implemented for image classification models.")

requirements.txt

torch==2.4.1
tokenizers==0.19.1
transformers==4.44.1
pillow==10.4.0
requests==2.32.3

config.yaml

# This is the sample config file for the image-classifier model

model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "visual-classifier"

build_info:
  python_version: "3.10"

inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "2Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "3Gi"

checkpoints:
  type: "huggingface"
  repo_id: "Falconsai/nsfw_image_detection"
  hf_token: "hf_token"
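
The ranking step in the predict method above (softmax over the logits, then sorting from most to least likely) can be illustrated without torch. This is a plain-Python sketch of the same idea, not the code the runner uses; the function name `rank_concepts` and the label names in the usage line are illustrative.

```python
import math


def rank_concepts(logits, names):
    """Softmax the logits and return (name, probability) pairs, highest first."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    probs = [value / total for value in exps]
    return sorted(zip(names, probs), key=lambda pair: pair[1], reverse=True)


ranked = rank_concepts([2.0, 0.5], ["normal", "nsfw"])
```

The probabilities sum to 1, and the first entry of `ranked` is the most likely concept.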

Image Detector

model.py

# Model to be uploaded: https://huggingface.co/facebook/detr-resnet-50

import os
from io import BytesIO
from typing import Iterator

import requests
import torch
from clarifai.runners.models.model_runner import ModelRunner
from clarifai.utils.logging import logger
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor


def preprocess_image(image_url=None, image_base64=None):
    """Open an image from raw bytes or, failing that, from a URL."""
    if image_base64:
        return Image.open(BytesIO(image_base64))
    if image_url:
        return Image.open(BytesIO(requests.get(image_url).content))
    raise ValueError("Either image_url or image_base64 must be provided.")


class MyRunner(ModelRunner):
    """A custom runner that loads the DETR model and detects objects in images."""

    def load_model(self):
        """Load the model here."""
        checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints")
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Running on device: {self.device}")

        self.model = DetrForObjectDetection.from_pretrained(
            checkpoint_path, revision="no_timm").to(self.device)
        self.processor = DetrImageProcessor.from_pretrained(checkpoint_path, revision="no_timm")
        logger.info("Done loading!")

    def predict(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> service_pb2.MultiOutputResponse:
        """Called for each unary request. Detects objects in every input image and
        returns one output per input.
        """

        # Get the concept protos from the model.
        concept_protos = request.model.model_version.output_info.data.concepts

        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            output = resources_pb2.Output()
            data = inp.data
            output_regions = []

            # Raw image bytes take precedence over the URL.
            img = preprocess_image(image_url=data.image.url, image_base64=data.image.base64)

            with torch.no_grad():
                inputs = self.processor(images=img, return_tensors="pt").to(self.device)
                model_output = self.model(**inputs)

            # Convert outputs (bounding boxes and class logits) to COCO API.
            # Keep only detections with score > 0.7 (you can set any other value).
            target_sizes = torch.tensor([img.size[::-1]])
            results = self.processor.post_process_object_detection(
                model_output, target_sizes=target_sizes, threshold=0.7)[0]

            width, height = img.size
            for score, label_idx, box in zip(results["scores"], results["labels"],
                                             results["boxes"]):
                # Normalize the pixel coordinates to the [0, 1] range.
                x_min, y_min, x_max, y_max = box

                top_row = round(y_min.item() / height, 2)
                left_col = round(x_min.item() / width, 2)
                bottom_row = round(y_max.item() / height, 2)
                right_col = round(x_max.item() / width, 2)

                output_region = resources_pb2.Region()
                output_region.id = str(label_idx.item())
                output_region.value = score.item()

                # append() copies the concept; add() on a repeated message field
                # accepts keyword arguments only.
                concept_protos[label_idx.item()].value = score.item()
                output_region.data.concepts.append(concept_protos[label_idx.item()])

                output_region.region_info.bounding_box.top_row = top_row
                output_region.region_info.bounding_box.left_col = left_col
                output_region.region_info.bounding_box.bottom_row = bottom_row
                output_region.region_info.bounding_box.right_col = right_col

                output_regions.append(output_region)

            output.data.regions.extend(output_regions)

            output.status.code = status_code_pb2.SUCCESS
            outputs.append(output)
        return service_pb2.MultiOutputResponse(outputs=outputs)

    def generate(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError("Generate method is not implemented for image detection models.")

    def stream(
        self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError("Stream method is not implemented for image detection models.")

requirements.txt

torch==2.4.1
tokenizers==0.19.1
transformers==4.44.2
pillow==10.4.0
requests==2.32.3

config.yaml

# This is the sample config file for the image-detection model

model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "visual-detector"

build_info:
  python_version: "3.10"

inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "2Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "5Gi"

checkpoints:
  type: "huggingface"
  repo_id: "facebook/detr-resnet-50"
  hf_token: "hf_token"
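
The detector's predict method converts pixel coordinates into fractions of the image size before filling in the bounding box. The arithmetic can be sketched in isolation as follows; the function name `normalize_box` is hypothetical, and the dictionary keys mirror the bounding-box field names used above.

```python
def normalize_box(box, width, height):
    """Convert pixel (x_min, y_min, x_max, y_max) into fractional bounding-box fields."""
    x_min, y_min, x_max, y_max = box
    return {
        "top_row": round(y_min / height, 2),     # rows are measured along the height
        "left_col": round(x_min / width, 2),     # columns are measured along the width
        "bottom_row": round(y_max / height, 2),
        "right_col": round(x_max / width, 2),
    }
```

For a 640 by 480 image, `normalize_box((64, 48, 320, 240), 640, 480)` gives a box whose top-left corner is at 10% of each dimension and whose bottom-right corner is at 50%.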

Large Language Models (LLMs)

model.py

# Model to be uploaded: https://huggingface.co/casperhansen/llama-3-8b-instruct-awq

import os
from threading import Thread
from typing import Iterator

import torch
from clarifai.runners.models.model_runner import ModelRunner
from clarifai.utils.logging import logger
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from google.protobuf import json_format
from transformers import (AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer)


class MyRunner(ModelRunner):
    """A custom runner that loads the Llama model and generates text using it."""

    def load_model(self):
        """Load the model here."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Running on device: {self.device}")

        # If config.yaml has a checkpoints section, the checkpoints are
        # downloaded to this path when the model is uploaded.
        checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")

        self.tokenizer = AutoTokenizer.from_pretrained(checkpoints)
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoints,
            low_cpu_mem_usage=True,
            device_map=self.device,
            torch_dtype=torch.bfloat16,
        )
        # Create a streamer for streaming the output of the model
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        logger.info("Done loading!")

    def predict(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> service_pb2.MultiOutputResponse:
        """Called for each unary request. Generates a completion for every text
        input using the Llama model and returns one output per input.
        """

        # TODO: Could cache the model and this conversion if the hash is the same.
        model = request.model
        output_info = {}
        if request.model.model_version.id != "":
            output_info = json_format.MessageToDict(
                model.model_version.output_info, preserving_proto_field_name=True)

        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            data = inp.data

            # Optional use of output_info
            inference_params = {}
            if "params" in output_info:
                inference_params = output_info["params"]

            temperature = inference_params.get("temperature", 0.7)
            max_tokens = inference_params.get("max_tokens", 100)
            max_tokens = int(max_tokens)

            top_k = inference_params.get("top_k", 40)
            top_k = int(top_k)
            top_p = inference_params.get("top_p", 1.0)

            if data.text.raw != "":
                prompt = data.text.raw

            inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
            output_tokens = self.model.generate(
                **inputs,
                eos_token_id=self.tokenizer.eos_token_id,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_tokens,
                top_p=top_p,
                top_k=top_k,
            )
            # Decode only the newly generated tokens, skipping the prompt.
            llm_outputs = self.tokenizer.batch_decode(
                output_tokens[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)

            output = resources_pb2.Output()
            output.data.text.raw = llm_outputs[0]

            output.status.code = status_code_pb2.SUCCESS
            outputs.append(output)
        return service_pb2.MultiOutputResponse(outputs=outputs)

    def generate(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        """Example yielding a whole batch of streamed stuff back."""

        # TODO: Could cache the model and this conversion if the hash is the same.
        model = request.model
        output_info = {}
        if request.model.model_version.id != "":
            output_info = json_format.MessageToDict(
                model.model_version.output_info, preserving_proto_field_name=True)

        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            data = inp.data

            # Optional use of output_info
            inference_params = {}
            if "params" in output_info:
                inference_params = output_info["params"]

            temperature = inference_params.get("temperature", 0.7)
            max_tokens = inference_params.get("max_tokens", 100)
            max_tokens = int(max_tokens)
            top_p = inference_params.get("top_p", 1.0)

            top_k = inference_params.get("top_k", 40)
            top_k = int(top_k)

            kwargs = dict(
                temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, top_k=top_k)

            if data.text.raw != "":
                prompt = data.text.raw

            # Use the device chosen in load_model instead of assuming CUDA.
            inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            generation_kwargs = dict(input_ids=inputs, streamer=self.streamer, **kwargs)
            # Run generation in a background thread so tokens can be yielded as they arrive.
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            for new_text in self.streamer:
                output = resources_pb2.Output()

                output.data.text.raw = new_text
                output.status.code = status_code_pb2.SUCCESS
                result = service_pb2.MultiOutputResponse(
                    status=status_pb2.Status(
                        code=status_code_pb2.SUCCESS,
                        description="Success",
                    ),
                    outputs=[output],
                )
                yield result
            thread.join()

    def stream(
        self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        """Example yielding a whole batch of streamed stuff back."""
        output_info = {}
        for ri, request in enumerate(request_iterator):
            if ri == 0:  # only first request has model information.
                model = request.model
                if request.model.model_version.id != "":
                    output_info = json_format.MessageToDict(
                        model.model_version.output_info, preserving_proto_field_name=True)

            # TODO: parallelize this over inputs in a single request.
            for inp in request.inputs:
                data = inp.data

                # Optional use of output_info
                inference_params = {}
                if "params" in output_info:
                    inference_params = output_info["params"]

                temperature = inference_params.get("temperature", 0.7)
                max_tokens = inference_params.get("max_tokens", 100)
                max_tokens = int(max_tokens)
                top_p = inference_params.get("top_p", 1.0)

                top_k = inference_params.get("top_k", 40)
                top_k = int(top_k)

                kwargs = dict(
                    temperature=temperature, top_p=top_p, max_new_tokens=max_tokens, top_k=top_k)

                if data.text.raw != "":
                    prompt = data.text.raw

                # Use the device chosen in load_model instead of assuming CUDA.
                inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                generation_kwargs = dict(input_ids=inputs, streamer=self.streamer, **kwargs)
                # Run generation in a background thread so tokens can be yielded as they arrive.
                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
                thread.start()

                for new_text in self.streamer:
                    output = resources_pb2.Output()

                    output.data.text.raw = new_text
                    output.status.code = status_code_pb2.SUCCESS
                    result = service_pb2.MultiOutputResponse(
                        status=status_pb2.Status(
                            code=status_code_pb2.SUCCESS,
                            description="Success",
                        ),
                        outputs=[output],
                    )
                    yield result
                thread.join()

requirements.txt

torch==2.3.1
tokenizers==0.19.1
transformers==4.44.2
accelerate==0.34.2
scipy==1.10.1
optimum==1.22.0
xformers==0.0.27
protobuf==5.27.3
einops==0.8.0
requests==2.32.2
sentence_transformers==2.2.0
sentencepiece==0.2.0
autoawq==0.2.6

config.yaml

# This is the sample config file for the Llama model

model:
  id: "llama-3-8b-instruct"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "text-to-text"

build_info:
  python_version: "3.10"

inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "8Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "12Gi"

checkpoints:
  type: "huggingface"
  repo_id: "casperhansen/llama-3-8b-instruct-awq"
  hf_token: "hf_token"

tip

You can refer to the examples repository mentioned above for additional examples of uploading other large language models (LLMs).
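
The LLM runner above repeats the same parameter handling in predict, generate, and stream. One way to factor it out is sketched below. The defaults are the ones used in the example; the helper name `generation_kwargs` is hypothetical, and the integer coercion is presumably needed because numeric values converted from the request arrive as floats.

```python
# Defaults taken from the example runner above
DEFAULTS = {"temperature": 0.7, "max_tokens": 100, "top_k": 40, "top_p": 1.0}


def generation_kwargs(output_info):
    """Merge request params over the defaults and coerce integer-valued fields."""
    params = {**DEFAULTS, **output_info.get("params", {})}
    return {
        "temperature": float(params["temperature"]),
        "top_p": float(params["top_p"]),
        "top_k": int(params["top_k"]),
        "max_new_tokens": int(params["max_tokens"]),  # name expected by model.generate
    }
```

Each method could then call `generation_kwargs(output_info)` once and pass the result to `self.model.generate`.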

Speech Recognition Model

model.py

# Model to be uploaded: https://platform.openai.com/docs/guides/speech-to-text/quickstart

import copy
import io
from typing import Iterator

import requests
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from openai import OpenAI

from clarifai.runners.models.model_runner import ModelRunner


def bytes_to_audio_file(audio_bytes):
    """Convert bytes data into a file-like object."""
    if not audio_bytes:
        raise ValueError("Audio bytes cannot be empty.")
    audio_file = io.BytesIO(audio_bytes)
    audio_file.name = "audio.mp3"  # This name is used for the API
    return audio_file


def preprocess_audio(audio_url=None, audio_bytes=None, chunk_size=1024, stream=False):
    """
    Fetch and preprocess audio data from a URL or bytes.

    Parameters:
      audio_url (str): URL to fetch audio from (if provided).
      audio_bytes (bytes): Audio data in bytes (if provided).
      chunk_size (int): Size of chunks for streaming.
      stream (bool): Whether to stream the audio in chunks.

    Returns:
      Generator of byte chunks if stream is True, otherwise the audio data as bytes.
    """

    if audio_bytes:
        if stream:
            # Stream the audio in chunks (generator)
            def audio_stream_generator():
                for i in range(0, len(audio_bytes), chunk_size):
                    yield audio_bytes[i:i + chunk_size]

            return audio_stream_generator()
        else:
            # Return a single chunk of audio
            return audio_bytes
    elif audio_url:
        response = requests.get(audio_url, stream=stream)
        if response.status_code != 200:
            raise Exception(f"Failed to fetch audio. Status code: {response.status_code}")

        if stream:
            # Stream the audio in chunks (generator)
            def audio_stream_generator():
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # Filter out keep-alive new chunks
                        yield chunk

            return audio_stream_generator()
        else:
            # Return a single chunk of audio
            return response.content

    else:
        raise ValueError("Either 'audio_url' or 'audio_bytes' must be provided")


OPENAI_API_KEY = "API_KEY"


class MyRunner(ModelRunner):
    """A custom runner that is used for transcribing audio."""

    def load_model(self):
        """Load the model here."""
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.modelname = "whisper-1"
        self.language = None

        # reset the task in set_translate_task
        self.task = "transcribe"

    def predict(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> service_pb2.MultiOutputResponse:
        """Called for each unary request. Transcribes every audio input and returns
        one output per input.
        """

        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            output = resources_pb2.Output()

            data = inp.data
            audio_bytes = None
            if data.audio.base64:
                audio_bytes = preprocess_audio(audio_bytes=data.audio.base64, stream=False)

            elif data.audio.url:
                audio_bytes = preprocess_audio(
                    audio_url=data.audio.url,
                    stream=False,
                )

            # Send audio bytes to Whisper for transcription
            transcription = self.client.audio.transcriptions.create(
                model=self.modelname, language=self.language, file=bytes_to_audio_file(audio_bytes)
            )

            # Set the output data
            output.data.text.raw = transcription.text
            output.status.code = status_code_pb2.SUCCESS
            outputs.append(output)
        return service_pb2.MultiOutputResponse(
            outputs=outputs,
        )

    def generate(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        def request_iterator(request, chunk_size=1024):
            # Split each input's audio into chunks and yield one request per chunk.
            request_copy = copy.deepcopy(request)
            for inp in request_copy.inputs:
                data = inp.data

                audio_chunks = None
                if data.audio.base64:
                    audio_chunks = preprocess_audio(
                        audio_bytes=data.audio.base64, stream=True, chunk_size=chunk_size
                    )
                elif data.audio.url:
                    audio_chunks = preprocess_audio(
                        audio_url=data.audio.url,
                        stream=True,
                        chunk_size=chunk_size,
                    )

                for chunk in audio_chunks:
                    inp.data.audio.base64 = chunk
                    yield request_copy

        chunk_size = 1024 * 1024
        return self.stream(request_iterator(request, chunk_size=chunk_size))

    def stream(
        self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        for request in request_iterator:
            for inp in request.inputs:
                output = resources_pb2.Output()

                data = inp.data
                chunk_size = 10 * 1024 * 1024
                # Check the audio field (not the image field) for raw bytes.
                if data.audio.base64 != b"":
                    audio_chunks = preprocess_audio(
                        audio_bytes=data.audio.base64, stream=True, chunk_size=chunk_size
                    )
                elif data.audio.url != "":
                    audio_chunks = preprocess_audio(
                        audio_url=data.audio.url, stream=True, chunk_size=chunk_size
                    )

                for chunk in audio_chunks:
                    transcription = self.client.audio.transcriptions.create(
                        model=self.modelname, language=self.language, file=bytes_to_audio_file(chunk)
                    )
                    # Set the output data
                    output.data.text.raw = transcription.text

                    output.status.code = status_code_pb2.SUCCESS
                    result = service_pb2.MultiOutputResponse(
                        status=status_pb2.Status(
                            code=status_code_pb2.SUCCESS,
                            description="Success",
                        ),
                        outputs=[output],
                    )
                    yield result

requirements.txt

openai
requests

config.yaml

# This is the sample config file for the Openai Whisper model

model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "audio-to-text"

build_info:
  python_version: "3.10"

inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "500Mi"
  num_accelerators: 0
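
The streaming path of the speech example relies on splitting a byte string into fixed-size chunks. That step can be sketched on its own as follows; the function name `iter_chunks` is hypothetical. Note that cutting compressed audio at arbitrary byte offsets may not produce segments that can each be decoded independently, so treat the chunk size as something to validate against your audio format.

```python
def iter_chunks(audio_bytes, chunk_size=1024):
    """Yield successive chunk_size-byte slices of audio_bytes; the last may be shorter."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(audio_bytes), chunk_size):
        yield audio_bytes[start:start + chunk_size]
```

Joining the chunks in order reproduces the original bytes, which makes the helper easy to check.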