Model Upload
Learn how to upload a model using Clarifai SDKs
The Clarifai SDKs allow you to upload custom models easily. Whether you're working with a pre-trained model from an external source or one you've built from scratch, you can integrate it with Clarifai and take advantage of the platform’s capabilities.
Once uploaded, your model can be used alongside Clarifai's suite of AI tools. It is automatically deployed and ready to be evaluated, combined with other models and agent operators in a workflow, or used on its own to serve inference requests.
Let’s demonstrate how you can successfully upload different types of models to the Clarifai platform.
This new feature is in Public Preview. If you'd like to test it out and provide feedback, please request access here.
This new upload experience is compatible with the latest clarifai Python package, starting from version 10.9.2. If you prefer the previous upload method, which is supported up to version 10.8.4, you can refer to the documentation here.
To follow along with this documentation, clone the repository containing examples of how to upload various model types:
git clone https://github.com/Clarifai/examples.git
After cloning it, go to the models/model_upload folder:
cd examples/models/model_upload
Prerequisites
Installation
To begin, install the latest version of the clarifai Python package.
pip install --upgrade clarifai
Environment Setup
Before proceeding, ensure that the CLARIFAI_PAT (Personal Access Token) environment variable is set. You can generate a PAT on your Personal Settings page by navigating to the Security section.
This token is required to authenticate your connection to the Clarifai platform.
export CLARIFAI_PAT=YOUR_PERSONAL_ACCESS_TOKEN_HERE
Create Project Directory
Create a project directory and organize your files as indicated below to fit the requirements of uploading models to the Clarifai platform.
your_model_directory/
├── 1/
│ └── model.py
├── requirements.txt
└── config.yaml
- your_model_directory/ – The main directory containing your model files.
- 1/ – A subdirectory that holds the model file (note that the folder is named 1).
- model.py – Contains the code that defines your model, including loading the model and running inference.
- requirements.txt – Lists the Python libraries and dependencies required to run your model.
- config.yaml – Contains model metadata and configuration details necessary for building the Docker image, defining compute resources, and uploading the model to Clarifai.
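If you start a new model from scratch, a small script can create this layout for you. The sketch below is a hypothetical helper (scaffold_model_dir and its placeholder file contents are our own, not part of the clarifai package); it only creates empty stand-ins that you then fill in as described in the steps that follow, and it never overwrites existing files.

```python
from pathlib import Path
from typing import List

# Placeholder contents (hypothetical); replace them with your real code, dependencies, and metadata.
SKELETON = {
    "1/model.py": "# Define your ModelRunner subclass here.\n",
    "requirements.txt": "",
    "config.yaml": 'model:\n  id: "model_id"\n',
}


def scaffold_model_dir(root: Path) -> List[Path]:
    """Create the your_model_directory/ layout under `root`, skipping files that already exist."""
    created = []
    for rel_path, content in SKELETON.items():
        path = root / rel_path
        path.parent.mkdir(parents=True, exist_ok=True)  # also creates the 1/ subdirectory
        if not path.exists():
            path.write_text(content)
            created.append(path)
    return created
```

For example, scaffold_model_dir(Path("your_model_directory")) returns the three paths it created; a second call returns an empty list.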
How to Upload a Model
Let's talk about the common steps you'd follow to upload any type of model to the Clarifai platform.
Step 1: Define the config.yaml File
The config.yaml file is essential for specifying the model’s metadata, compute resource requirements, and model checkpoints.
Here’s a breakdown of the key sections in the file.
Model Info
This section defines your model ID, Clarifai user ID, and Clarifai app ID, which will determine where the model is uploaded on the Clarifai platform.
model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "text-to-text"  # Change this based on your model type (e.g., visual-classifier, visual-detector, text-to-text)
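Because these four fields decide where the model lands, it can help to check them before uploading. The following is a minimal sketch of our own (not the platform's validator); it works on an already-parsed config dictionary and treats every Model Info key shown above as required, which is an assumption.

```python
from typing import List

# Keys from the Model Info section above; treating all of them as required is an assumption.
REQUIRED_MODEL_KEYS = ("id", "user_id", "app_id", "model_type_id")


def missing_model_keys(config: dict) -> List[str]:
    """Return the Model Info keys that are absent or empty in a parsed config.yaml."""
    model = config.get("model") or {}
    return [key for key in REQUIRED_MODEL_KEYS if not model.get(key)]
```

To run it against your file, you could parse config.yaml with a YAML library such as PyYAML (for example, yaml.safe_load) and pass the result in; PyYAML is a separate package and is not assumed to be installed.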
Compute Resources
Here, you define the minimum compute resources required for running your model, including CPU, memory, and optional GPU specifications.
inference_compute_info:
  cpu_limit: "2"
  cpu_memory: "13Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]  # Specify the GPU type if needed
  accelerator_memory: "15Gi"
- cpu_limit – Number of CPUs allocated for the model (follows Kubernetes notation, e.g., "1", "2").
- cpu_memory – Minimum memory required for the CPU (uses Kubernetes notation, e.g., "1Gi", "1500Mi", "3Gi").
- num_accelerators – Number of GPUs or TPUs to use for inference.
- accelerator_type – Type of accelerators (e.g., GPU or TPU) supported by the model (e.g., "NVIDIA-A10G").
- accelerator_memory – Minimum memory required for the GPU or TPU.
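Kubernetes quantities are easy to mistype: "Mi" and "Gi" are binary units (powers of 1024), "M" and "G" are decimal, and a lowercase "m" means milli, so a memory value written as "500m" would be half a byte. The sketch below converts such strings to plain numbers so you can sanity-check a config; it covers only the common suffixes and is our own helper, not part of the clarifai package.

```python
import re

# Common Kubernetes suffixes: binary (Ki, Mi, ...), decimal (k, M, ...), and milli (m).
_SUFFIXES = {
    "Ki": 1024, "Mi": 1024**2, "Gi": 1024**3, "Ti": 1024**4,
    "k": 10**3, "M": 10**6, "G": 10**9, "T": 10**12,
    "m": 10**-3, "": 1,
}


def parse_quantity(quantity: str) -> float:
    """Convert a quantity such as "13Gi", "1500Mi", or "2" into a plain number."""
    match = re.fullmatch(r"(\d+(?:\.\d+)?)(Ki|Mi|Gi|Ti|k|M|G|T|m)?", quantity.strip())
    if match is None:
        raise ValueError(f"unrecognized quantity: {quantity!r}")
    number, suffix = match.groups()
    return float(number) * _SUFFIXES[suffix or ""]
```

For instance, parse_quantity("13Gi") / 1024**3 gives 13.0, while parse_quantity("500m") gives 0.5.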
Model Checkpoints
If you're using a model from Hugging Face, you can automatically download its checkpoints by specifying the appropriate configuration in this section. For private or restricted Hugging Face repositories, include an access token.
checkpoints:
  type: "huggingface"
  repo_id: "meta-llama/Meta-Llama-3-8B-Instruct"
  hf_token: "your_hf_token"  # Required for private models
Model Concepts or Labels
This section is required if your model outputs concepts or labels and is not being directly loaded from Hugging Face.
For models that output concepts or labels, such as classification or detection models, you must define a concepts section in the config.yaml file:
concepts:
  - id: '0'
    name: bus
  - id: '1'
    name: person
  - id: '2'
    name: bicycle
  - id: '3'
    name: car
If you're using a model from Hugging Face and the checkpoints section is defined, the Clarifai platform will automatically infer concepts. In this case, you don’t need to specify them manually.
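When a model has many labels, writing this section by hand is error-prone. The sketch below renders it from an ordered list of label names, numbering the IDs from '0' as in the example above. It assumes the list order matches the model's output indices (the example model.py files later look up concepts by output index) and that label names need no YAML quoting; names containing colons or other special characters would need quotes. The function name is hypothetical.

```python
from typing import List


def concepts_yaml(labels: List[str]) -> str:
    """Render a concepts section, using each label's list index (as a string) for its id."""
    lines = ["concepts:"]
    for index, name in enumerate(labels):
        lines.append(f"  - id: '{index}'")
        lines.append(f"    name: {name}")
    return "\n".join(lines) + "\n"
```

You could append the returned text to config.yaml, or paste it in manually.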
Step 2: Define Dependencies in requirements.txt
The requirements.txt file lists all the Python dependencies your model needs. This ensures that the necessary libraries are installed in the runtime environment.
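Before testing locally in a virtual environment, you might want to confirm that the installed packages match the versions pinned in requirements.txt. This sketch uses the standard library's importlib.metadata. It only understands exact name==version pins (other specifiers are ignored), and both helper names are our own.

```python
from importlib import metadata
from typing import Callable, Dict, Optional


def parse_pins(requirements_text: str) -> Dict[str, str]:
    """Map package name to pinned version for every `name==version` line."""
    pins = {}
    for raw in requirements_text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop comments and surrounding whitespace
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def version_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, Optional[str]]:
    """Return {package: installed_version} for pins that are missing (None) or differ locally."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = installed
    return mismatches
```

For example, version_mismatches(parse_pins(open("requirements.txt").read())) returns an empty dict when every pin matches.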
Step 3: Prepare the model.py File
The model.py file contains the logic for your model, including how it loads and handles predictions. This file must implement a class that inherits from ModelRunner and defines the following methods:
- load_model() – Initializes and loads the model, preparing it for inference.
- predict(request) – Handles the core logic for making predictions. It processes the input data and returns the output response.
- generate(request) – Provides output in a streaming manner, if applicable to the model's use case.
- stream(request_iterator) – Manages both streaming input and output, primarily for more advanced use cases where data is processed continuously.
from clarifai.runners.models.model_runner import ModelRunner
from clarifai_grpc.grpc.api import service_pb2
class YourCustomModelRunner(ModelRunner):
    def load_model(self):
        # Initialize and load the model here
        pass
    def predict(self, request):
        # Handle the inputs in request.inputs and build one resources_pb2.Output per input
        outputs = []
        return service_pb2.MultiOutputResponse(outputs=outputs)
    def generate(self, request):
        # Handle streaming output (if applicable)
        pass
    def stream(self, request_iterator):
        # Handle both streaming input and output
        pass
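The three methods correspond to three call patterns: one request in and one response out (predict), one request in and a stream of responses out (generate), and a stream of requests in with a stream of responses out (stream). The sketch below illustrates only those shapes. EchoRunner is a hypothetical stand-in that does not inherit from ModelRunner, and plain lists of strings replace the Clarifai request and response protobufs so that it runs without the clarifai package.

```python
from typing import Iterator, List


class EchoRunner:
    """Hypothetical stand-in: lists of strings replace the request/response protobufs."""

    def load_model(self) -> None:
        # A real runner would load weights here; this one only sets a suffix.
        self.suffix = "!"

    def predict(self, inputs: List[str]) -> List[str]:
        # Unary-unary: one request in, one response (one output per input) out.
        return [text + self.suffix for text in inputs]

    def generate(self, inputs: List[str]) -> Iterator[List[str]]:
        # Unary-stream: one request in, several partial responses out (one word at a time).
        for text in inputs:
            for word in text.split():
                yield [word]

    def stream(self, request_iterator: Iterator[List[str]]) -> Iterator[List[str]]:
        # Stream-stream: answer each incoming request as soon as it arrives.
        for inputs in request_iterator:
            yield self.predict(inputs)
```

In a real runner, each yielded value would be a service_pb2.MultiOutputResponse, as in the examples later on this page.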
Step 4: Test the Model Locally
Before uploading your model to the Clarifai platform, it's important to test it locally to catch any typos or misconfigurations in the code.
This can prevent upload failures caused by errors in model.py or an incorrect model implementation. It also confirms that the model runs smoothly and that all dependencies are correctly configured.
You can test the model within a Docker container or a Python virtual environment.
If Docker is installed on your system, it is highly recommended to use it for testing or running the model. Docker provides better isolation and avoids dependency conflicts.
Ensure your local environment has sufficient memory and compute resources to load and run the model for testing.
There are two types of CLI (command line interface) commands you can use to test your models in your local development environment.
1. Using the test-locally Command
This method allows you to test your model with a single CLI command. It runs the model locally and sends a sample request to verify that the model responds successfully. The results of the request are displayed directly in the console.
Here is how to test a model in a Docker Container:
clarifai model test-locally --model_path {add_model_path_here} --mode container
Here is how to test a model in a virtual environment:
clarifai model test-locally --model_path {add_model_path_here} --mode env
2. Using the run-locally Command
This method starts a local gRPC server at localhost:{port} for running the model. Once the server is running, you can perform inference on the model via the Clarifai client SDK.
Here is how to test a model in a Docker Container:
clarifai model run-locally --model_path {add_model_path_here} --mode container --port 8000
Here is how to test a model in a virtual environment:
clarifai model run-locally --model_path {add_model_path_here} --mode env --port 8000
Once the model is running locally, configure the CLARIFAI_API_BASE environment variable to point to the localhost and port where the gRPC server is running.
export CLARIFAI_API_BASE="localhost:{port}"
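Building the container or environment can take a while, so a script that sends requests right after starting run-locally may race the server. The sketch below polls the port with the standard library until something accepts TCP connections. It is our own helper, and it only shows that the port is open. It does not confirm that the gRPC service is healthy.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 30.0) -> bool:
    """Poll until a TCP connection to host:port succeeds or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            with socket.create_connection((host, port), timeout=1.0):
                return True  # something is listening on the port
        except OSError:
            if time.monotonic() >= deadline:
                return False
            time.sleep(0.5)
```

For example, call wait_for_port("localhost", 8000, timeout=300) before creating the Model client.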
You can then make different types of inference requests to the model: unary-unary, unary-stream, or stream-stream predict calls.
Here is an example of a unary-unary prediction call:
from clarifai.client.model import Model
# No need to provide actual values for model_id, user_id, and app_id when calling the locally running model
model = Model(model_id='model_id', user_id='user_id', app_id='app_id')
image_url = "https://samples.clarifai.com/metro-north.jpg"
# Model Predict
model_prediction = model.predict_by_url(image_url, input_type="image")
print(model_prediction)
These are the key CLI flags available for local testing and running your models:
- --model_path – Path to the model directory.
- --mode – Specify how to run the model: env for a virtual environment or container for a Docker container. Defaults to env.
- -p or --port – The port to host the gRPC server for running the model locally. Defaults to 8000.
- --keep_env – Retain the virtual environment after testing the model locally (applicable for env mode). Defaults to False.
- --keep_image – Retain the Docker image built after testing the model locally (applicable for container mode). Defaults to False.
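If you drive local testing from a Python script or CI job, you can assemble the command as an argument list and pass it to subprocess. The sketch below uses only the flags listed above. It assumes that --keep_env and --keep_image are boolean switches that take no value, which this page does not state. The helper name is hypothetical, and it performs no checks beyond the two choices shown.

```python
from typing import List, Optional


def local_cli_args(
    command: str,
    model_path: str,
    mode: str = "env",
    port: Optional[int] = None,
    keep: bool = False,
) -> List[str]:
    """Build the argument list for `clarifai model test-locally` or `run-locally`."""
    if command not in ("test-locally", "run-locally"):
        raise ValueError("command must be 'test-locally' or 'run-locally'")
    if mode not in ("env", "container"):
        raise ValueError("mode must be 'env' or 'container'")
    args = ["clarifai", "model", command, "--model_path", model_path, "--mode", mode]
    if port is not None:
        args += ["--port", str(port)]
    if keep:
        # Assumed to be value-less switches: --keep_env for env mode, --keep_image for container mode.
        args.append("--keep_env" if mode == "env" else "--keep_image")
    return args
```

For example: subprocess.run(local_cli_args("test-locally", "./your_model_directory", mode="container"), check=True). Passing a list rather than a shell string avoids quoting problems with paths that contain spaces.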
Step 5: Upload the Model to Clarifai
Once your model is ready, upload it to the Clarifai platform by running the following command:
clarifai model upload --model_path {add_model_path_here}
This command builds the model’s Docker image using the defined compute resources and uploads it to Clarifai, where it can be served in production.
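Because the upload builds an image remotely, catching a missing file beforehand saves a round trip. This is a minimal pre-flight sketch. It checks only the three files from the project layout described earlier, not their contents, and the function name is our own.

```python
from pathlib import Path
from typing import List

# The files from the project layout described in "Create Project Directory".
REQUIRED_FILES = ("1/model.py", "requirements.txt", "config.yaml")


def missing_model_files(model_path: str) -> List[str]:
    """List the expected files that are missing under model_path."""
    root = Path(model_path)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
```

If missing_model_files("your_model_directory") returns an empty list, the layout is in place and you can run the upload command.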
Examples
You can find various model upload examples here, which demonstrate different use cases and optimizations.
Image Classifier
model.py
# Model to be uploaded: https://huggingface.co/Falconsai/nsfw_image_detection
import os
from io import BytesIO
from typing import Iterator
import requests
import torch
from clarifai.runners.models.model_runner import ModelRunner
from clarifai.utils.logging import logger
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
def preprocess_image(image_url=None, image_base64=None):
    if image_base64:
        img = Image.open(BytesIO(image_base64))
    elif image_url:
        img = Image.open(BytesIO(requests.get(image_url).content))
    else:
        raise ValueError("Either image_url or image_base64 must be provided")
    return img
class MyRunner(ModelRunner):
    """A custom runner that loads the model and classifies images using it.
    """
    def load_model(self):
        """Load the model here."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Running on device: {self.device}")
        # If the checkpoints section is in the config.yaml file, checkpoints are downloaded to this path at model upload time.
        checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints")
        self.model = AutoModelForImageClassification.from_pretrained(checkpoint_path,).to(self.device)
        self.processor = ViTImageProcessor.from_pretrained(checkpoint_path)
        logger.info("Done loading!")
    def predict(self, request: service_pb2.PostModelOutputsRequest
               ) -> service_pb2.MultiOutputResponse:
        """This is the method that will be called when the runner is run. It takes in an input and
        returns an output.
        """
        # Get the concept protos from the model.
        concept_protos = request.model.model_version.output_info.data.concepts
        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            output = resources_pb2.Output()
            data = inp.data
            output_concepts = []
            # Base64 bytes take precedence over the URL; raises ValueError if neither is set.
            img = preprocess_image(image_url=data.image.url, image_base64=data.image.base64)
            with torch.no_grad():
                inputs = self.processor(images=img, return_tensors="pt").to(self.device)
                model_output = self.model(**inputs)
                logits = model_output.logits
            probs = torch.softmax(logits, dim=-1)[0]
            sorted_indices = torch.argsort(probs, dim=-1, descending=True)
            for idx in sorted_indices:
                concept_protos[idx.item()].value = probs[idx.item()].item()
                output_concepts.append(concept_protos[idx.item()])
            output.data.concepts.extend(output_concepts)
            output.status.code = status_code_pb2.SUCCESS
            outputs.append(output)
        return service_pb2.MultiOutputResponse(outputs=outputs,)
    def generate(self, request: service_pb2.PostModelOutputsRequest
                ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError("Generate method is not implemented for image classification models.")
    def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
              ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError("Stream method is not implemented for image classification models.")
requirements.txt
torch==2.4.1
tokenizers==0.19.1
transformers==4.44.1
pillow==10.4.0
requests==2.32.3
config.yaml
# This is the sample config file for the image-classifier model
model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "visual-classifier"
build_info:
  python_version: "3.10"
inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "2Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "3Gi"
checkpoints:
  type: "huggingface"
  repo_id: "Falconsai/nsfw_image_detection"
  hf_token: "hf_token"
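The post-processing in predict above applies a softmax to the logits, sorts the indices in descending order of probability, and gives the concept at each index that probability. The pure-Python sketch below does the same thing so you can reason about it without torch. The function name is our own, and the label names in any call are up to you.

```python
import math
from typing import List, Sequence, Tuple


def rank_concepts(logits: Sequence[float], names: Sequence[str]) -> List[Tuple[str, float]]:
    """Softmax the logits, then pair each concept name with its probability, highest first."""
    peak = max(logits)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(names[i], probs[i]) for i in order]
```

Because the concept at index i receives the probability at index i, the order of the concepts stored with the model must match the model's label order.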
Image Detector
model.py
# Model to be uploaded: https://huggingface.co/facebook/detr-resnet-50
import os
from io import BytesIO
from typing import Iterator
import requests
import torch
from clarifai.runners.models.model_runner import ModelRunner
from clarifai.utils.logging import logger
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor
def preprocess_image(image_url=None, image_base64=None):
    if image_base64:
        img = Image.open(BytesIO(image_base64))
    elif image_url:
        img = Image.open(BytesIO(requests.get(image_url).content))
    else:
        raise ValueError("Either image_url or image_base64 must be provided")
    return img
class MyRunner(ModelRunner):
    """A custom runner that loads the DETR model and detects objects in images using it.
    """
    def load_model(self):
        """Load the model here."""
        checkpoint_path = os.path.join(os.path.dirname(__file__), "checkpoints")
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Running on device: {self.device}")
        self.model = DetrForObjectDetection.from_pretrained(
            checkpoint_path, revision="no_timm").to(self.device)
        self.processor = DetrImageProcessor.from_pretrained(checkpoint_path, revision="no_timm")
        logger.info("Done loading!")
    def predict(self, request: service_pb2.PostModelOutputsRequest
               ) -> service_pb2.MultiOutputResponse:
        """This is the method that will be called when the runner is run. It takes in an input and
        returns an output.
        """
        # Get the concept protos from the model.
        concept_protos = request.model.model_version.output_info.data.concepts
        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            output = resources_pb2.Output()
            data = inp.data
            output_regions = []
            # Base64 bytes take precedence over the URL; raises ValueError if neither is set.
            img = preprocess_image(image_url=data.image.url, image_base64=data.image.base64)
            with torch.no_grad():
                inputs = self.processor(images=img, return_tensors="pt").to(self.device)
                model_output = self.model(**inputs)
            # Convert outputs (bounding boxes and class logits) to COCO API format.
            # Only keep detections with score > 0.7 (you can set any other threshold).
            target_sizes = torch.tensor([img.size[::-1]])
            results = self.processor.post_process_object_detection(
                model_output, target_sizes=target_sizes, threshold=0.7)[0]
            width, height = img.size
            for score, label_idx, box in zip(results["scores"], results["labels"], results["boxes"]):
                # Normalize bounding box
                x_min, y_min, x_max, y_max = box
                top_row = round(y_min.item() / height, 2)
                left_col = round(x_min.item() / width, 2)
                bottom_row = round(y_max.item() / height, 2)
                right_col = round(x_max.item() / width, 2)
                output_region = resources_pb2.Region()
                output_region.id = str(label_idx.item())
                output_region.value = score.item()
                concept_protos[label_idx.item()].value = score.item()
                # add() on a repeated message field takes no message argument; copy into a new element instead.
                output_region.data.concepts.add().CopyFrom(concept_protos[label_idx.item()])
                output_region.region_info.bounding_box.top_row = top_row
                output_region.region_info.bounding_box.left_col = left_col
                output_region.region_info.bounding_box.bottom_row = bottom_row
                output_region.region_info.bounding_box.right_col = right_col
                output_regions.append(output_region)
            output.data.regions.extend(output_regions)
            output.status.code = status_code_pb2.SUCCESS
            outputs.append(output)
        return service_pb2.MultiOutputResponse(outputs=outputs,)
    def generate(self, request: service_pb2.PostModelOutputsRequest
                ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError("Generate method is not implemented for image detection models.")
    def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
              ) -> Iterator[service_pb2.MultiOutputResponse]:
        raise NotImplementedError("Stream method is not implemented for image detection models.")
requirements.txt
torch==2.4.1
tokenizers==0.19.1
transformers==4.44.2
pillow==10.4.0
requests==2.32.3
config.yaml
# This is the sample config file for the image-detection model
model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "visual-detector"
build_info:
  python_version: "3.10"
inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "2Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "5Gi"
checkpoints:
  type: "huggingface"
  repo_id: "facebook/detr-resnet-50"
  hf_token: "hf_token"
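The detector's predict method converts each pixel-space box into relative coordinates in the range 0 to 1 before filling the bounding_box fields (top_row, left_col, bottom_row, right_col), rounding to two decimals. The sketch below isolates that arithmetic in plain Python. The function name is our own, and it assumes boxes arrive as (x_min, y_min, x_max, y_max) in pixels, as the DETR post-processing in the example returns them.

```python
from typing import Dict, Tuple


def normalize_box(box: Tuple[float, float, float, float], width: int, height: int) -> Dict[str, float]:
    """Convert a pixel (x_min, y_min, x_max, y_max) box into relative bounding_box fields."""
    x_min, y_min, x_max, y_max = box
    return {
        "top_row": round(y_min / height, 2),  # rows are measured against the image height
        "left_col": round(x_min / width, 2),  # columns are measured against the image width
        "bottom_row": round(y_max / height, 2),
        "right_col": round(x_max / width, 2),
    }
```

For a 200x400 image, the box (10, 20, 110, 220) becomes top_row 0.05, left_col 0.05, bottom_row 0.55, and right_col 0.55.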
Large Language Models (LLMs)
model.py
# Model to be uploaded: https://huggingface.co/casperhansen/llama-3-8b-instruct-awq
import os
from threading import Thread
from typing import Iterator
import torch
from clarifai.runners.models.model_runner import ModelRunner
from clarifai.utils.logging import logger
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from google.protobuf import json_format
from transformers import (AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer)
class MyRunner(ModelRunner):
    """A custom runner that loads the Llama model and generates text using it.
    """
    def load_model(self):
        """Load the model here."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Running on device: {self.device}")
        # If the checkpoints section is in the config.yaml file, checkpoints are downloaded to this path at model upload time.
        checkpoints = os.path.join(os.path.dirname(__file__), "checkpoints")
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoints)
        self.model = AutoModelForCausalLM.from_pretrained(
            checkpoints,
            low_cpu_mem_usage=True,
            device_map=self.device,
            torch_dtype=torch.bfloat16,
        )
        # Create a streamer for streaming the output of the model
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        logger.info("Done loading!")
    def predict(self,
                request: service_pb2.PostModelOutputsRequest) -> service_pb2.MultiOutputResponse:
        """This is the method that will be called when the runner is run. It takes in an input and
        returns the response generated by the Llama model.
        """
        # TODO: Could cache the model and this conversion if the hash is the same.
        model = request.model
        output_info = {}
        if request.model.model_version.id != "":
            output_info = json_format.MessageToDict(
                model.model_version.output_info, preserving_proto_field_name=True)
        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            data = inp.data
            # Optional use of output_info
            inference_params = {}
            if "params" in output_info:
                inference_params = output_info["params"]
            temperature = float(inference_params.get("temperature", 0.7))
            max_tokens = int(inference_params.get("max_tokens", 100))
            top_k = int(inference_params.get("top_k", 40))
            top_p = float(inference_params.get("top_p", 1.0))
            if data.text.raw != "":
                prompt = data.text.raw
                inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
                output_tokens = self.model.generate(
                    **inputs,
                    eos_token_id=self.tokenizer.eos_token_id,
                    do_sample=True,
                    temperature=temperature,
                    max_new_tokens=max_tokens,
                    top_p=top_p,
                    top_k=top_k,
                )
                # Decode only the newly generated tokens, skipping the prompt.
                llm_outputs = self.tokenizer.batch_decode(
                    output_tokens[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
                output = resources_pb2.Output()
                output.data.text.raw = llm_outputs[0]
                output.status.code = status_code_pb2.SUCCESS
                outputs.append(output)
        return service_pb2.MultiOutputResponse(outputs=outputs,)
    def generate(self, request: service_pb2.PostModelOutputsRequest
                ) -> Iterator[service_pb2.MultiOutputResponse]:
        """Stream the generated text back one chunk at a time."""
        # TODO: Could cache the model and this conversion if the hash is the same.
        model = request.model
        output_info = {}
        if request.model.model_version.id != "":
            output_info = json_format.MessageToDict(
                model.model_version.output_info, preserving_proto_field_name=True)
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            data = inp.data
            # Optional use of output_info
            inference_params = {}
            if "params" in output_info:
                inference_params = output_info["params"]
            temperature = float(inference_params.get("temperature", 0.7))
            max_tokens = int(inference_params.get("max_tokens", 100))
            top_p = float(inference_params.get("top_p", 1.0))
            top_k = int(inference_params.get("top_k", 40))
            # do_sample=True so that temperature, top_p, and top_k take effect, as in predict().
            kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p,
                          max_new_tokens=max_tokens, top_k=top_k)
            if data.text.raw != "":
                prompt = data.text.raw
                # Move the input IDs to the model's device (which may be the CPU).
                inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                generation_kwargs = dict(input_ids=inputs, streamer=self.streamer, **kwargs)
                # Run generation in a background thread and read tokens from the streamer as they arrive.
                thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
                thread.start()
                for new_text in self.streamer:
                    output = resources_pb2.Output()
                    output.data.text.raw = new_text
                    output.status.code = status_code_pb2.SUCCESS
                    result = service_pb2.MultiOutputResponse(
                        status=status_pb2.Status(
                            code=status_code_pb2.SUCCESS,
                            description="Success",
                        ),
                        outputs=[output],
                    )
                    yield result
                thread.join()
    def stream(self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
              ) -> Iterator[service_pb2.MultiOutputResponse]:
        """Handle a stream of requests, streaming the generated text back for each one."""
        output_info = {}
        for ri, request in enumerate(request_iterator):
            if ri == 0:  # only the first request has model information.
                model = request.model
                if request.model.model_version.id != "":
                    output_info = json_format.MessageToDict(
                        model.model_version.output_info, preserving_proto_field_name=True)
            # TODO: parallelize this over inputs in a single request.
            for inp in request.inputs:
                data = inp.data
                # Optional use of output_info
                inference_params = {}
                if "params" in output_info:
                    inference_params = output_info["params"]
                temperature = float(inference_params.get("temperature", 0.7))
                max_tokens = int(inference_params.get("max_tokens", 100))
                top_p = float(inference_params.get("top_p", 1.0))
                top_k = int(inference_params.get("top_k", 40))
                kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p,
                              max_new_tokens=max_tokens, top_k=top_k)
                if data.text.raw != "":
                    prompt = data.text.raw
                    inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
                    generation_kwargs = dict(input_ids=inputs, streamer=self.streamer, **kwargs)
                    thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
                    thread.start()
                    for new_text in self.streamer:
                        output = resources_pb2.Output()
                        output.data.text.raw = new_text
                        output.status.code = status_code_pb2.SUCCESS
                        result = service_pb2.MultiOutputResponse(
                            status=status_pb2.Status(
                                code=status_code_pb2.SUCCESS,
                                description="Success",
                            ),
                            outputs=[output],
                        )
                        yield result
                    thread.join()
requirements.txt
torch==2.3.1
tokenizers==0.19.1
transformers==4.44.2
accelerate==0.34.2
scipy==1.10.1
optimum==1.22.0
xformers==0.0.27
protobuf==5.27.3
einops==0.8.0
requests==2.32.2
sentence_transformers==2.2.0
sentencepiece==0.2.0
autoawq==0.2.6
config.yaml
# This is the sample config file for the Llama model
model:
  id: "llama-3-8b-instruct"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "text-to-text"
build_info:
  python_version: "3.10"
inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "8Gi"
  num_accelerators: 1
  accelerator_type: ["NVIDIA-A10G"]
  accelerator_memory: "12Gi"
checkpoints:
  type: "huggingface"
  repo_id: "casperhansen/llama-3-8b-instruct-awq"
  hf_token: "hf_token"
You can refer to the examples repository mentioned above for additional examples of uploading other large language models (LLMs).
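The predict, generate, and stream methods in the LLM example each repeat the same parameter parsing. One option is to move that parsing into a single helper, sketched below. sampling_params is a hypothetical name, and the defaults are copied from the example. The int() and float() conversions matter because json_format.MessageToDict turns numbers stored in a protobuf Struct into Python floats, and max_new_tokens and top_k must be integers.

```python
from typing import Any, Dict


def sampling_params(output_info: Dict[str, Any]) -> Dict[str, Any]:
    """Read optional sampling settings from output_info["params"], falling back to defaults."""
    params = output_info.get("params", {})
    return {
        "temperature": float(params.get("temperature", 0.7)),
        "max_new_tokens": int(params.get("max_tokens", 100)),
        "top_p": float(params.get("top_p", 1.0)),
        "top_k": int(params.get("top_k", 40)),
    }
```

Inside a runner method, you could then call self.model.generate(**inputs, do_sample=True, **sampling_params(output_info)).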
Speech Recognition Model
model.py
# Model to be uploaded: https://platform.openai.com/docs/guides/speech-to-text/quickstart
import copy
import io
from typing import Iterator
import requests
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2, status_pb2
from openai import OpenAI
from clarifai.runners.models.model_runner import ModelRunner
def bytes_to_audio_file(audio_bytes):
    """Convert bytes data into a file-like object."""
    if not audio_bytes:
        raise ValueError("Audio bytes cannot be empty.")
    audio_file = io.BytesIO(audio_bytes)
    audio_file.name = "audio.mp3"  # This name is used for the API
    return audio_file
def preprocess_audio(audio_url=None, audio_bytes=None, chunk_size=1024, stream=False):
    """
    Fetch and preprocess audio data from a URL or bytes.
    Parameters:
        audio_url (str): URL to fetch audio from (if provided).
        audio_bytes (bytes): Audio data in bytes (if provided).
        chunk_size (int): Size of chunks for streaming.
        stream (bool): Whether to stream the audio in chunks.
    Returns:
        A generator of byte chunks if stream is True, otherwise the audio bytes.
    """
    if audio_bytes:
        if stream:
            # Stream the audio in chunks (generator)
            def audio_stream_generator():
                for i in range(0, len(audio_bytes), chunk_size):
                    yield audio_bytes[i : i + chunk_size]
            return audio_stream_generator()
        else:
            # Return a single chunk of audio
            return audio_bytes
    elif audio_url:
        response = requests.get(audio_url, stream=stream)
        if response.status_code != 200:
            raise Exception(f"Failed to fetch audio. Status code: {response.status_code}")
        if stream:
            # Stream the audio in chunks (generator)
            def audio_stream_generator():
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:  # Filter out keep-alive new chunks
                        yield chunk
            return audio_stream_generator()
        else:
            # Return a single chunk of audio
            return response.content
    else:
        raise ValueError("Either 'audio_url' or 'audio_bytes' must be provided")
# Replace with your OpenAI API key.
OPENAI_API_KEY = "API_KEY"
class MyRunner(ModelRunner):
    """A custom runner that is used for transcribing audio."""
    def load_model(self):
        """Load the model here."""
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.modelname = "whisper-1"
        self.language = None
        # reset the task in set_translate_task
        self.task = "transcribe"
    def predict(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> service_pb2.MultiOutputResponse:
        """This is the method that will be called when the runner is run. It takes in an input and
        returns an output.
        """
        outputs = []
        # TODO: parallelize this over inputs in a single request.
        for inp in request.inputs:
            output = resources_pb2.Output()
            data = inp.data
            audio_bytes = None
            if data.audio.base64:
                audio_bytes = preprocess_audio(audio_bytes=data.audio.base64, stream=False)
            elif data.audio.url:
                audio_bytes = preprocess_audio(
                    audio_url=data.audio.url,
                    stream=False,
                )
            # Send audio bytes to Whisper for transcription
            transcription = self.client.audio.transcriptions.create(
                model=self.modelname, language=self.language, file=bytes_to_audio_file(audio_bytes)
            )
            # Set the output data
            output.data.text.raw = transcription.text
            output.status.code = status_code_pb2.SUCCESS
            outputs.append(output)
        return service_pb2.MultiOutputResponse(
            outputs=outputs,
        )
    def generate(
        self, request: service_pb2.PostModelOutputsRequest
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        def request_iterator(request, chunk_size=1024):
            # Re-emit the request once per audio chunk, with the chunk placed in audio.base64.
            request_copy = copy.deepcopy(request)
            for inp in request_copy.inputs:
                data = inp.data
                audio_chunks = None
                if data.audio.base64:
                    audio_chunks = preprocess_audio(
                        audio_bytes=data.audio.base64, stream=True, chunk_size=chunk_size
                    )
                elif data.audio.url:
                    audio_chunks = preprocess_audio(
                        audio_url=data.audio.url,
                        stream=True,
                        chunk_size=chunk_size,
                    )
                for chunk in audio_chunks:
                    inp.data.audio.base64 = chunk
                    yield request_copy
        chunk_size = 1024 * 1024
        return self.stream(request_iterator(request, chunk_size=chunk_size))
    def stream(
        self, request_iterator: Iterator[service_pb2.PostModelOutputsRequest]
    ) -> Iterator[service_pb2.MultiOutputResponse]:
        for request in request_iterator:
            for inp in request.inputs:
                output = resources_pb2.Output()
                data = inp.data
                chunk_size = 10 * 1024 * 1024
                # Check the audio field (not the image field) for inline bytes first, then the URL.
                if data.audio.base64 != b"":
                    audio_chunks = preprocess_audio(
                        audio_bytes=data.audio.base64, stream=True, chunk_size=chunk_size
                    )
                elif data.audio.url != "":
                    audio_chunks = preprocess_audio(
                        audio_url=data.audio.url, stream=True, chunk_size=chunk_size
                    )
                else:
                    raise ValueError("Each input must contain audio bytes or an audio URL")
                for chunk in audio_chunks:
                    transcription = self.client.audio.transcriptions.create(
                        model=self.modelname, language=self.language, file=bytes_to_audio_file(chunk)
                    )
                    # Set the output data
                    output.data.text.raw = transcription.text
                    output.status.code = status_code_pb2.SUCCESS
                    result = service_pb2.MultiOutputResponse(
                        status=status_pb2.Status(
                            code=status_code_pb2.SUCCESS,
                            description="Success",
                        ),
                        outputs=[output],
                    )
                    yield result
requirements.txt
openai
requests
config.yaml
# This is the sample config file for the OpenAI Whisper model
model:
  id: "model_id"
  user_id: "user_id"
  app_id: "app_id"
  model_type_id: "audio-to-text"
build_info:
  python_version: "3.10"
inference_compute_info:
  cpu_limit: "1"
  cpu_memory: "500Mi"
  num_accelerators: 0
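In the speech recognition example, streaming works by slicing the audio bytes into fixed-size pieces. generate uses 1 MiB chunks (1024 * 1024), and stream re-chunks at 10 MiB, so each 1 MiB piece passes through stream unchanged. The sketch below isolates that slicing logic under a hypothetical name. One caveat applies: cutting a compressed file such as MP3 at arbitrary byte offsets does not guarantee that each piece decodes as a valid standalone audio file.

```python
from typing import Iterator


def iter_chunks(data: bytes, chunk_size: int) -> Iterator[bytes]:
    """Yield consecutive slices of `data`, each at most `chunk_size` bytes long."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    for start in range(0, len(data), chunk_size):
        yield data[start : start + chunk_size]
```

The last chunk is shorter whenever the data length is not a multiple of chunk_size, and empty input yields nothing.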