Deep Fine-Tuning

Learn how to fine-tune pre-trained models


Fine-tuning is a deep learning technique that refers to taking a pre-trained model and further training it on a new dataset or task. The term "fine-tuning" implies making small adjustments or refinements to the already learned representations in the pre-trained model rather than training from scratch.

Fine-tuning leverages the power of pre-trained models to improve their performance on a new, related task. It involves taking a pre-trained model, which was previously trained on a vast dataset for a general-purpose task, and tailoring it to a more specific task.

Why Choose Deep Fine-Tuning?

Clarifai offers a variety of pre-built models that are designed to help you create AI solutions quickly and efficiently. Clarifai models are the recommended starting point for many users because they offer incredibly fast training times, especially when you customize them using the transfer learning model type.

But there are some cases where accuracy and the ability to carefully target solutions take priority over speed and ease of use. Additionally, you may need a model to learn new features not recognized by existing Clarifai models.

For such cases, it is possible to "deep fine-tune" your custom models and integrate them directly within your workflows.

You might consider deep fine-tuning if:

  • You have a custom dataset. This lets you tailor the model to a specific application or domain, such as customizing it with proprietary data from a private company.
  • You have accurate labels. This provides a strong foundation for training your models, resulting in improved performance, reduced errors, and better alignment with the desired task or domain.
  • You have the expertise and time to fine-tune models, so you can modify a model's behavior to eliminate unwanted traits and instill desired ones.
  • You want to reduce hallucinations, especially when presenting the model with questions or prompts it hasn't encountered during its initial training.

Types of Deep Fine-Tuned Models

To create a deep fine-tuned model using the Clarifai API, you need to specify the type of model using the model_type_id parameter.

tip

You can use the ListModelTypes method to learn more about the available model types and their hyperparameters.
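Below is a minimal sketch of calling ListModelTypes with the Python gRPC client. The credential values are placeholders, and the printed fields assume each returned model type exposes an id and a title.

# Placeholders; replace with your own values.
USER_ID = "YOUR_USER_ID_HERE"
PAT = "YOUR_PAT_HERE"
APP_ID = "YOUR_APP_ID_HERE"

from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2

channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)

metadata = (("authorization", "Key " + PAT),)
userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)

# List the model types available to your app, together with their hyperparameters.
list_model_types_response = stub.ListModelTypes(
    service_pb2.ListModelTypesRequest(user_app_id=userDataObject),
    metadata=metadata,
)

if list_model_types_response.status.code != status_code_pb2.SUCCESS:
    raise Exception("List model types failed, status: " + list_model_types_response.status.description)

for model_type in list_model_types_response.model_types:
    print(model_type.id, model_type.title)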

Here are some types of deep fine-tuned models you can create:

  • Visual classifier (visual-classifier) — Create this model to classify images and video frames into a set of concepts.
  • Visual detector (visual-detector) — Create this model to detect bounding box regions in images or video frames and then classify the detected regions. You can also send the image regions to an image cropper model to create a new cropped image.
  • Visual embedder (visual-embedder) — Create this model to transform images and video frames into "high-level" vector representations understood by our AI models. These embeddings enable visual search and can be used as base models to train other models.
  • Visual segmenter (visual-segmenter) — Create this model to produce per-pixel masks that locate objects in images and then classify the objects, descriptive words, or topics within those masks.
  • Visual anomaly heatmap (visual-anomaly-heatmap) — Create this model to perform visual anomaly detection with image-level score and anomaly heatmap.
  • Text classifier (text-classifier) — Create this model to classify text into a set of concepts.
  • Text generator (text-to-text) — Create this model to generate or convert text based on the provided text input. For example, you can create it for prompt completion, translation, or summarization tasks.
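As a concrete illustration, here is a minimal sketch of creating a deep fine-tuned visual classifier by passing model_type_id when creating the model with PostModels. The credentials and the model ID "my-visual-classifier" are placeholders.

# Placeholders; replace with your own values.
USER_ID = "YOUR_USER_ID_HERE"
PAT = "YOUR_PAT_HERE"
APP_ID = "YOUR_APP_ID_HERE"

from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2

channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)

metadata = (("authorization", "Key " + PAT),)
userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)

# Create a model whose type is set through model_type_id; "my-visual-classifier" is a placeholder ID.
post_models_response = stub.PostModels(
    service_pb2.PostModelsRequest(
        user_app_id=userDataObject,
        models=[
            resources_pb2.Model(
                id="my-visual-classifier",
                model_type_id="visual-classifier",
            )
        ],
    ),
    metadata=metadata,
)

if post_models_response.status.code != status_code_pb2.SUCCESS:
    print(post_models_response.status)
    raise Exception("Post models failed, status: " + post_models_response.status.description)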

Number of Inputs

In general, deep fine-tuned models need more data than those trained using the transfer learning technique.

For most use cases, you’ll need at least 1000 training inputs, but it could be much more than this depending on your specific scenario.

Training Time Estimator

Before initiating the training of a deep fine-tuned model, you can estimate the anticipated duration of the training process. This offers transparency in expected training costs.

We currently charge $4 per hour.

The exact training time estimate depends on the following:

  • Model type;
  • Model configuration details;
  • Dataset statistics;
  • Hardware.

Clarifai’s Training Time Estimator is carefully designed to balance trade-offs between simplicity, generalization, and accuracy.

Notably, some model configurations and dataset statistics affect training time much more than others. For example, the number of items in the dataset directly affects the number of training steps in most configs, while the learning rate has no impact.

In addition, some parameters affect training time linearly (e.g., the number of items), others roughly quadratically (e.g., image size), and others approximately linearly, quadratically, or subquadratically depending on the model (e.g., the number of tokens in each input).

The current version of the Training Time Estimator provides estimates only for each template’s default parameter configuration, and we plan to include other parameter configurations in the upcoming releases.

The exact calculation based on the current AWS A10 GPU is:

training time = int(round(A * num_inputs * num_epochs + B)) 

Where A and B are coefficients estimated specifically for the template of each model type.
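For illustration, the short sketch below applies this formula and converts the result into an approximate cost at $4 per hour. The coefficients A and B are made up for the example; the real values are estimated per template by Clarifai.

# Hypothetical coefficients for illustration only; real values are estimated per template.
A = 0.05   # seconds per (input x epoch)
B = 120.0  # fixed overhead in seconds

def estimate_training_seconds(num_inputs, num_epochs):
    # training time = int(round(A * num_inputs * num_epochs + B))
    return int(round(A * num_inputs * num_epochs + B))

seconds = estimate_training_seconds(num_inputs=1000, num_epochs=10)
hours = seconds / 3600
print(f"~{seconds} s (~{hours:.2f} h) -> ~${hours * 4:.2f} at $4/hour")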

How to Estimate Training Time

info

Before using the Python SDK, Node.js SDK, or any of our gRPC clients, ensure they are properly installed on your machine. Refer to their respective installation guides for instructions on how to install and initialize them.

When training a deep fine-tuned model using the UI, the estimated duration of the training process is displayed, rounded to the nearest 15-minute increment.

Below is an example of how you can use the API to estimate the expected training time programmatically.

tip

Instead of providing an estimated input count, an alternative approach is to specify a dataset version ID in the train_info.params of the request. Here is an example: params.update({"template":"MMDetection_FasterRCNN", "dataset_version_id":"dataset-version-1681974758238s"}).

###################################################################################################
# In this section, we set the user authentication, app ID, model ID, and estimated input count.
# Change these strings to run your own example.
##################################################################################################

USER_ID = "YOUR_USER_ID_HERE"
# Your PAT (Personal Access Token) can be found in the Account's Security section
PAT = "YOUR_PAT_HERE"
APP_ID = "YOUR_APP_ID_HERE"
# Change these to get your training time estimate
MODEL_ID = "YOUR_CUSTOM_MODEL_ID_HERE"
ESTIMATED_INPUT_COUNT = 100

##########################################################################
# YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
##########################################################################

from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf.struct_pb2 import Struct

channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)

params = Struct()
params.update({
    "template": "MMDetection_FasterRCNN"
})

metadata = (("authorization", "Key " + PAT),)

userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)

training_time_estimate_response = stub.PostModelVersionsTrainingTimeEstimate(
    service_pb2.PostModelVersionsTrainingTimeEstimateRequest(
        user_app_id=userDataObject,
        model_id=MODEL_ID,
        model_versions=[
            resources_pb2.ModelVersion(
                train_info=resources_pb2.TrainInfo(params=params)
            ),
        ],
        estimated_input_count=ESTIMATED_INPUT_COUNT,
    ),
    metadata=metadata,
)

if training_time_estimate_response.status.code != status_code_pb2.SUCCESS:
    print(training_time_estimate_response.status)
    raise Exception("Post model versions training time estimate failed, status: " + training_time_estimate_response.status.description)

print(training_time_estimate_response)
Raw Output Example

status {
  code: SUCCESS
  description: "Ok"
  req_id: "f45dfcf36746a567f690744f0b3805a7"
}
training_time_estimates {
  seconds: 308
}
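
Because the estimate is returned in seconds and training is billed at $4 per hour, you can convert the response into an expected cost. The short sketch below continues from the previous snippet and assumes the response shape shown above, with a repeated training_time_estimates field whose entries carry a seconds value.

# Convert the first returned estimate into hours and an approximate cost at $4/hour.
estimate_seconds = training_time_estimate_response.training_time_estimates[0].seconds
estimate_hours = estimate_seconds / 3600
print(f"Estimated training time: ~{estimate_hours:.2f} hours, ~${estimate_hours * 4:.2f}")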

Incrementally Train a Model

You can update existing deep fine-tuned models with new data without retraining from scratch. After a model version is trained, a checkpoint file is automatically saved, and you can initiate incremental training from that previously trained checkpoint.

Below is an example of how you would perform incremental training from a specific version of a visual detector model.

###################################################################################################
# In this section, we set the user authentication, app ID, and details for incremental training.
# Change these strings to run your own example.
###################################################################################################

USER_ID = "YOUR_USER_ID_HERE"
# Your PAT (Personal Access Token) can be found in the Account's Security section
PAT = "YOUR_PAT_HERE"
APP_ID = "YOUR_APP_ID_HERE"
# Change these to incrementally train your own model
MODEL_ID = "detection-test"
MODEL_VERSION_ID = "5af1bd0fb79d47289ab82d5bb2325c81"
CONCEPT_ID = "face"

##########################################################################
# YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
##########################################################################

from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf.struct_pb2 import Struct

channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)

params = Struct()
params.update({
    "template": "MMDetection_SSD",
    "num_epochs": 1
})

metadata = (("authorization", "Key " + PAT),)

userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)

post_model_versions = stub.PostModelVersions(
    service_pb2.PostModelVersionsRequest(
        user_app_id=userDataObject,
        model_id=MODEL_ID,
        model_versions=[
            resources_pb2.ModelVersion(
                train_info=resources_pb2.TrainInfo(
                    params=params,
                    resume_from_model=resources_pb2.Model(
                        id=MODEL_ID,
                        model_version=resources_pb2.ModelVersion(id=MODEL_VERSION_ID),
                    ),
                ),
                output_info=resources_pb2.OutputInfo(
                    data=resources_pb2.Data(
                        concepts=[resources_pb2.Concept(id=CONCEPT_ID)]
                    ),
                ),
            )
        ],
    ),
    metadata=metadata,
)

if post_model_versions.status.code != status_code_pb2.SUCCESS:
    print(post_model_versions.status)
    raise Exception(
        "Post model versions failed, status: " + post_model_versions.status.description
    )

print(post_model_versions)