Deep Fine-Tuning
Learn how to fine-tune pre-trained models
Fine-tuning is a deep learning technique in which a pre-trained model is further trained on a new dataset or task. The term "fine-tuning" implies making small adjustments or refinements to the representations the pre-trained model has already learned, rather than training from scratch.
Fine-tuning builds on a pre-trained model to improve its performance on a new, related task. The model was previously trained on a vast dataset for a general-purpose task, and fine-tuning tailors it to a more specific one.
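To make the idea concrete, here is a toy sketch in plain Python. It is not Clarifai's training procedure.

- A one-variable linear model is first trained on a large "general" dataset.
- It is then further trained for only a few steps on a small, related dataset.
- For comparison, the same short run is repeated starting from zero weights.

The datasets, learning rate, and step counts are arbitrary illustrative choices.

```python
# Toy illustration only (plain Python, not Clarifai's training procedure):
# pre-train a 1-D linear model y = w*x + b on a large general dataset, then
# fine-tune those weights on a small related dataset.

def mse(w, b, data):
    """Mean squared error of y = w*x + b over (x, y) pairs."""
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def train(w, b, data, lr, steps):
    """Full-batch gradient descent on MSE, starting from the given (w, b)."""
    n = len(data)
    for _ in range(steps):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return w, b


# Large general-purpose dataset: y = 2x + 1 on x in [-5, 5]
general = [(k / 10, 2 * (k / 10) + 1) for k in range(-50, 51)]
# Small, related target dataset: y = 2.2x + 1.5 on x in [-0.5, 0.5]
specific = [(k / 10, 2.2 * (k / 10) + 1.5) for k in range(-5, 6)]

# Pre-training: start from zero weights and train on the general task
w_pre, b_pre = train(0.0, 0.0, general, lr=0.05, steps=500)

# Fine-tuning: start from the pre-trained weights with a short training budget
w_ft, b_ft = train(w_pre, b_pre, specific, lr=0.05, steps=50)

# Baseline: the same short budget, but starting from scratch
w_new, b_new = train(0.0, 0.0, specific, lr=0.05, steps=50)

print(f"fine-tuned loss:   {mse(w_ft, b_ft, specific):.4f}")
print(f"from-scratch loss: {mse(w_new, b_new, specific):.4f}")
```

With these settings, the fine-tuned run ends with a much lower loss. It starts close to the target solution and only needs small adjustments.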
Why Choose Deep Fine-Tuning?
Clarifai offers a variety of pre-built models designed to help you create AI solutions quickly and efficiently. Clarifai models are the recommended starting point for many users because they offer very fast training times, especially when you customize them using the transfer learning model type.
In some cases, however, accuracy and the ability to carefully target solutions take priority over speed and ease of use. You may also need a model to learn new features that existing Clarifai models do not recognize.
In such cases, you can "deep fine-tune" your custom models and integrate them directly into your workflows.
You might consider deep fine-tuning if:
- You have a custom dataset. This lets you tailor the model to a specific application or domain, such as customizing it with a private company's proprietary data.
- You have accurate labels. This provides a strong foundation for training your models, resulting in improved performance, fewer errors, and better alignment with the desired task or domain.
- You have the expertise and time to fine-tune models. This lets you modify the model's behavior to eliminate unwanted traits and instill desired ones.
- You want to reduce hallucinations, especially when the model is presented with questions or prompts it did not encounter during its initial training.
Types of Deep Fine-Tuned Models
To create a deep fine-tuned model using the Clarifai API, you need to specify the model type with the model_type_id parameter.
You can use the ListModelTypes method to learn more about the available model types and their hyperparameters.
Here are some of the types of deep fine-tuned models you can create:
- Visual classifier (visual-classifier): Create this model to classify images and video frames into a set of concepts.
- Visual detector (visual-detector): Create this model to detect bounding box regions in images or video frames and then classify the detected regions. You can also send the detected regions to an image cropper model to create new cropped images.
- Visual embedder (visual-embedder): Create this model to transform images and video frames into high-level vector representations (embeddings) that our AI models understand. These embeddings enable visual search, and the embedder can serve as a base model for training other models.
- Visual segmenter (visual-segmenter): Create this model to produce per-pixel masks that show where things are in an image, and then classify the objects, descriptive words, or topics within those masks.
- Visual anomaly heatmap (visual-anomaly-heatmap): Create this model to perform visual anomaly detection, producing an image-level score and an anomaly heatmap.
- Text classifier (text-classifier): Create this model to classify text into a set of concepts.
- Text generator (text-to-text): Create this model to generate or convert text based on the provided text input, for example for prompt completion, translation, or summarization tasks.
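A display name does not always match its model_type_id string; for example, the text generator uses text-to-text. The hypothetical lookup table below maps the display names listed above to their IDs. It is a convenience sketch, not part of any Clarifai SDK.

The list on this page is not exhaustive. ListModelTypes remains the authoritative source for the available types.

```python
# Hypothetical lookup table (not part of any Clarifai SDK) mapping the display
# names listed on this page to their model_type_id strings. This list is not
# exhaustive; ListModelTypes returns the authoritative set.

MODEL_TYPE_IDS = {
    "visual classifier": "visual-classifier",
    "visual detector": "visual-detector",
    "visual embedder": "visual-embedder",
    "visual segmenter": "visual-segmenter",
    "visual anomaly heatmap": "visual-anomaly-heatmap",
    "text classifier": "text-classifier",
    "text generator": "text-to-text",  # the ID differs from the display name
}


def model_type_id_for(display_name: str) -> str:
    """Return the model_type_id for a display name, ignoring case and extra spaces."""
    key = " ".join(display_name.lower().split())
    try:
        return MODEL_TYPE_IDS[key]
    except KeyError:
        raise ValueError(
            f"{display_name!r} is not one of the types listed on this page"
        ) from None


print(model_type_id_for("Text generator"))  # prints text-to-text
```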
Number of Inputs
In general, deep fine-tuned models need more data than those trained using the transfer learning technique.
For most use cases, you’ll need at least 1,000 training inputs, and possibly many more, depending on your specific scenario.
Training Time Estimator
Before training a deep fine-tuned model, you can estimate how long the training process will take. This gives you visibility into the expected training costs.
We currently charge $4 per hour.
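At the $4 per hour rate stated above, converting a duration estimate into an approximate cost is simple arithmetic. The sketch below has two assumptions:

- The estimate is expressed in seconds.
- Cost scales linearly with duration.

This page does not state the billing granularity (per second, per minute, or per hour), so treat the result as a rough figure.

```python
# Rough cost sketch at the $4/hour rate stated on this page. Assumes the
# estimate is in seconds and that cost scales linearly with duration; the
# actual billing granularity is not specified here.

HOURLY_RATE_USD = 4.0


def estimated_cost_usd(estimated_seconds: float, hourly_rate: float = HOURLY_RATE_USD) -> float:
    """Convert an estimated training duration in seconds into dollars."""
    return estimated_seconds / 3600 * hourly_rate


# Example: a 2.5-hour estimate (9,000 seconds)
print(f"${estimated_cost_usd(9000):.2f}")  # prints "$10.00"
```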
The training time estimate depends on the following:
- Model type;
- Model configuration details;
- Dataset statistics;
- Hardware.
Clarifai’s Training Time Estimator is carefully designed to balance trade-offs between simplicity, generalization, and accuracy.
Notably, some model configurations and dataset statistics affect training time much more than others. For example, the number of items in the dataset directly affects the number of training steps in most configurations, while the learning rate has no impact.
In addition, some parameters affect training time linearly (e.g., the number of items), others quadratically (e.g., image size), and others approximately linearly, quadratically, or subquadratically, depending on the model (e.g., the number of tokens in each input).
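The difference between linear and quadratic effects can be shown with a toy cost model. In this model, time grows linearly with the number of inputs and quadratically with the side length of a square image, since the pixel count grows with the square of the side.

The base constants are made up for illustration. They are not Clarifai's estimator coefficients.

```python
# Toy scaling model, for illustration only: time grows linearly with the
# number of inputs and quadratically with image side length (pixel count).
# The base constants are hypothetical, not Clarifai's estimator coefficients.

def toy_training_time(num_inputs: int, image_side: int,
                      base_seconds_per_input: float = 0.5,
                      base_side: int = 512) -> float:
    """Seconds for a hypothetical job: linear in inputs, quadratic in image side."""
    size_factor = (image_side / base_side) ** 2
    return num_inputs * base_seconds_per_input * size_factor


t = toy_training_time(1000, 512)
print(toy_training_time(2000, 512) / t)   # doubling inputs: prints 2.0
print(toy_training_time(1000, 1024) / t)  # doubling image side: prints 4.0
```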
The current version of the Training Time Estimator provides estimates only for each template’s default parameter configuration; we plan to support other parameter configurations in upcoming releases.
For the current AWS A10 GPU, the estimate is calculated as:
training time = int(round(A * num_inputs * num_epochs + B))
where A and B are coefficients estimated separately for the template of each model type.
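The formula above translates directly into Python. The per-template coefficients A and B are not published on this page, so the values below are placeholders. Reading the result as seconds is an assumption based on the seconds field in the API response.

```python
# Sketch of the estimator formula: int(round(A * num_inputs * num_epochs + B)).
# A and B are per-template coefficients not given on this page; the values
# used below are hypothetical placeholders.

def estimate_training_time(num_inputs: int, num_epochs: int, a: float, b: float) -> int:
    """Apply training time = int(round(A * num_inputs * num_epochs + B))."""
    return int(round(a * num_inputs * num_epochs + b))


A_HYPOTHETICAL = 0.25
B_HYPOTHETICAL = 120.0

print(estimate_training_time(1000, 10, A_HYPOTHETICAL, B_HYPOTHETICAL))  # prints 2620
```

Python's built-in round() rounds exact halves to the nearest even integer, so a raw value of 2.5 becomes 2 rather than 3.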
How to Estimate Training Time
Before using the Python SDK, Node.js SDK, or any of our gRPC clients, ensure they are properly installed on your machine. Refer to their respective installation guides for instructions on how to install and initialize them.
When you train a deep fine-tuned model using the UI, the estimated training duration is displayed in hours, rounded down to the nearest 15-minute increment.
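The UI display rule can be sketched by reading it as "round down to a 15-minute increment, then express in hours". The exact formatting the UI uses is an assumption; only the rounding rule comes from this page.

```python
# Sketch of the UI display rule: round the estimate down to a 15-minute
# increment and express it in hours. The UI's exact formatting is assumed.

INCREMENT_SECONDS = 15 * 60


def ui_display_hours(estimated_seconds: int) -> float:
    """Round down to the nearest 15-minute increment, expressed in hours."""
    floored = (estimated_seconds // INCREMENT_SECONDS) * INCREMENT_SECONDS
    return floored / 3600


print(ui_display_hours(5000))  # 1 h 23 min 20 s rounds down to 1.25
```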
Below is an example of how you can use the API to estimate the expected training time programmatically.
Instead of providing an estimated input count, you can specify a dataset version ID in the train_info.params of the request. For example: params.update({"template":"MMDetection_FasterRCNN", "dataset_version_id":"dataset-version-1681974758238s"}).
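The two ways of sizing the estimate can be captured in one hypothetical helper, build_time_estimate_body (not part of any Clarifai SDK). It builds the JSON body for the REST time-estimate call, mirroring the JavaScript and cURL examples on this page:

- With an estimated input count, it sets estimated_input_count at the top level.
- With a dataset version ID, it puts dataset_version_id in train_info.params.

Requiring exactly one of the two is a design choice in this sketch, not a documented API rule.

```python
# Hypothetical helper (not part of any Clarifai SDK) that builds the JSON body
# for the training time estimate REST request, mirroring the JavaScript and
# cURL examples on this page.
import json
from typing import Optional


def build_time_estimate_body(template: str,
                             estimated_input_count: Optional[int] = None,
                             dataset_version_id: Optional[str] = None) -> str:
    """Return a JSON string for the time-estimate request body."""
    # Design choice for this sketch: require exactly one sizing option.
    if (estimated_input_count is None) == (dataset_version_id is None):
        raise ValueError("Provide exactly one of estimated_input_count or dataset_version_id")
    params = {"template": template}
    body = {"model_versions": [{"train_info": {"params": params}}]}
    if dataset_version_id is not None:
        params["dataset_version_id"] = dataset_version_id  # goes in train_info.params
    else:
        body["estimated_input_count"] = estimated_input_count  # top-level field
    return json.dumps(body)


print(build_time_estimate_body("MMDetection_FasterRCNN", estimated_input_count=100))
print(build_time_estimate_body("MMDetection_FasterRCNN",
                               dataset_version_id="dataset-version-1681974758238s"))
```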
- Python (gRPC)
- JavaScript (REST)
- Java (gRPC)
- cURL
###################################################################################################
# In this section, we set the user authentication, app ID, model ID, and estimated input count.
# Change these strings to run your own example.
##################################################################################################
USER_ID = "YOUR_USER_ID_HERE"
# Your PAT (Personal Access Token) can be found in the Account's Security section
PAT = "YOUR_PAT_HERE"
APP_ID = "YOUR_APP_ID_HERE"
# Change these to get your training time estimate
MODEL_ID = "YOUR_CUSTOM_MODEL_ID_HERE"
ESTIMATED_INPUT_COUNT = 100
##########################################################################
# YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
##########################################################################
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf.struct_pb2 import Struct
channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)
params = Struct()
params.update({
"template": "MMDetection_FasterRCNN"
})
metadata = (("authorization", "Key " + PAT),)
userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)
training_time_estimate_response = stub.PostModelVersionsTrainingTimeEstimate(
service_pb2.PostModelVersionsTrainingTimeEstimateRequest(
user_app_id=userDataObject,
model_id=MODEL_ID,
model_versions=[
resources_pb2.ModelVersion(
train_info=resources_pb2.TrainInfo(params=params)
),
],
estimated_input_count=ESTIMATED_INPUT_COUNT
),
metadata=metadata,
)
if training_time_estimate_response.status.code != status_code_pb2.SUCCESS:
    print(training_time_estimate_response.status)
    raise Exception("Training time estimate failed, status: " + training_time_estimate_response.status.description)
print(training_time_estimate_response)
<!--index.html file-->
<script>
///////////////////////////////////////////////////////////////////////////////////////////////////
// In this section, we set the user authentication, app ID, model ID, and estimated input count.
// Change these strings to run your own example.
//////////////////////////////////////////////////////////////////////////////////////////////////
const USER_ID = "YOUR_USER_ID_HERE";
// Your PAT (Personal Access Token) can be found in the Account's Security section
const PAT = "YOUR_PAT_HERE";
const APP_ID = "YOUR_APP_ID_HERE";
// Change these to get your training time estimate
const MODEL_ID = "YOUR_CUSTOM_MODEL_ID_HERE";
const ESTIMATED_INPUT_COUNT = 100;
///////////////////////////////////////////////////////////////////////////////////
// YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
///////////////////////////////////////////////////////////////////////////////////
const raw = JSON.stringify({
"user_app_id": {
"user_id": USER_ID,
"app_id": APP_ID
},
"model_versions": [{
"train_info": {
"params": {
"template": "MMDetection_FasterRCNN"
}
},
}],
"estimated_input_count": ESTIMATED_INPUT_COUNT
});
const requestOptions = {
method: "POST",
headers: {
"Content-Type": "application/json",
"Authorization": "Key " + PAT
},
body: raw
};
fetch(`https://api.clarifai.com/v2/users/${USER_ID}/apps/${APP_ID}/models/${MODEL_ID}/versions/time_estimate/`, requestOptions)
.then(response => response.text())
.then(result => console.log(result))
.catch(error => console.log("error", error));
</script>
package com.clarifai.example;
import com.clarifai.grpc.api.*;
import com.clarifai.channel.ClarifaiChannel;
import com.clarifai.credentials.ClarifaiCallCredentials;
import com.clarifai.grpc.api.status.StatusCode;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
public class ClarifaiExample {
//////////////////////////////////////////////////////////////////////////////////////////////////////
// In this section, we set the user authentication, app ID, model ID, and estimated input count.
// Change these strings to run your own example.
/////////////////////////////////////////////////////////////////////////////////////////////////////
static final String USER_ID = "YOUR_USER_ID_HERE";
// Your PAT (Personal Access Token) can be found in the Account's Security section
static final String PAT = "YOUR_PAT_HERE";
static final String APP_ID = "YOUR_APP_ID_HERE";
// Change these to get your training time estimate
static final String MODEL_ID = "YOUR_CUSTOM_MODEL_ID_HERE";
static final int ESTIMATED_INPUT_COUNT = 100;
///////////////////////////////////////////////////////////////////////////////////
// YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
///////////////////////////////////////////////////////////////////////////////////
public static void main(String[] args) {
V2Grpc.V2BlockingStub stub = V2Grpc.newBlockingStub(ClarifaiChannel.INSTANCE.getGrpcChannel())
.withCallCredentials(new ClarifaiCallCredentials(PAT));
Struct.Builder params = Struct.newBuilder()
.putFields("template", Value.newBuilder().setStringValue("MMDetection_FasterRCNN").build());
MultiTrainingTimeEstimateResponse trainingTimeEstimateResponse = stub.postModelVersionsTrainingTimeEstimate(
PostModelVersionsTrainingTimeEstimateRequest.newBuilder()
.setUserAppId(UserAppIDSet.newBuilder().setUserId(USER_ID).setAppId(APP_ID))
.setModelId(MODEL_ID)
.addModelVersions(ModelVersion.newBuilder()
.setTrainInfo(TrainInfo.newBuilder()
.setParams(params)
)
)
.setEstimatedInputCount(ESTIMATED_INPUT_COUNT)
.build()
);
if (trainingTimeEstimateResponse.getStatus().getCode() != StatusCode.SUCCESS) {
throw new RuntimeException("Training time estimate failed, status: " + trainingTimeEstimateResponse.getStatus());
}
System.out.print(trainingTimeEstimateResponse);
}
}
curl -X POST "https://api.clarifai.com/v2/users/YOUR_USER_ID_HERE/apps/YOUR_APP_ID_HERE/models/YOUR_MODEL_ID_HERE/versions/time_estimate/" \
-H "Authorization: Key YOUR_PAT_HERE" \
-H "Content-Type: application/json" \
-d '{
"model_versions": [{
"train_info": {
"params": {
"template": "MMDetection_FasterRCNN"
}
}
}],
"estimated_input_count": 100
}'
Raw Output Example
status {
code: SUCCESS
description: "Ok"
req_id: "f45dfcf36746a567f690744f0b3805a7"
}
training_time_estimates {
seconds: 308
}
Incrementally Train a Model
You can update existing deep fine-tuned models with new data without retraining from scratch. After a model version is trained, a checkpoint file is saved automatically, and you can start incremental training from that version's checkpoint.
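The checkpoint idea can be illustrated with a toy sketch in plain Python and JSON. It does not use Clarifai's checkpoint format or training procedure. The steps are:

1. Train a tiny linear model and save its weights to a file.
2. Load the weights back and continue training briefly on new data.
3. For comparison, run the same short training on the new data starting from zero weights.

```python
# Toy illustration of resuming from a saved checkpoint (plain Python and JSON,
# not Clarifai's checkpoint format or training procedure).
import json
import os
import tempfile


def loss(weights, data):
    """Mean squared error of y = w*x + b over (x, y) pairs."""
    w, b = weights["w"], weights["b"]
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def train(weights, data, lr=0.1, steps=100):
    """Full-batch gradient descent on MSE, starting from the given weights."""
    w, b = weights["w"], weights["b"]
    n = len(data)
    for _ in range(steps):
        grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
        w, b = w - lr * grad_w, b - lr * grad_b
    return {"w": w, "b": b}


def save_checkpoint(weights, path):
    with open(path, "w") as f:
        json.dump(weights, f)


def load_checkpoint(path):
    with open(path) as f:
        return json.load(f)


# Temporary directory for this demo
checkpoint_path = os.path.join(tempfile.mkdtemp(), "version-1.json")

# Version 1: trained from scratch on the original data, then checkpointed
original_data = [(k / 10, 3 * k / 10 + 1) for k in range(-10, 11)]
v1 = train({"w": 0.0, "b": 0.0}, original_data)
save_checkpoint(v1, checkpoint_path)

# Version 2: resume from the version 1 checkpoint and train briefly on new data
new_data = [(k / 10, 3 * k / 10 + 1) for k in range(11, 21)]
v2 = train(load_checkpoint(checkpoint_path), new_data, steps=20)

# Comparison: the same short run on the new data, starting from scratch
scratch = train({"w": 0.0, "b": 0.0}, new_data, steps=20)

print(f"resumed loss on new data:      {loss(v2, new_data):.4f}")
print(f"from-scratch loss on new data: {loss(scratch, new_data):.4f}")
```

The resumed run starts from weights that already fit the task, so a short run on the new data is enough. The from-scratch run with the same budget remains noticeably worse.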
Below is an example of how you would perform incremental training from a specific version of a visual detector model.
- Python (gRPC)
- JavaScript (REST)
- Java (gRPC)
- cURL
###################################################################################################
# In this section, we set the user authentication, app ID, and details for incremental training.
# Change these strings to run your own example.
###################################################################################################
USER_ID = "YOUR_USER_ID_HERE"
# Your PAT (Personal Access Token) can be found in the Account's Security section
PAT = "YOUR_PAT_HERE"
APP_ID = "YOUR_APP_ID_HERE"
# Change these to incrementally train your own model
MODEL_ID = "detection-test"
MODEL_VERSION_ID = "5af1bd0fb79d47289ab82d5bb2325c81"
CONCEPT_ID = "face"
##########################################################################
# YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
##########################################################################
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf.struct_pb2 import Struct
channel = ClarifaiChannel.get_grpc_channel()
stub = service_pb2_grpc.V2Stub(channel)
params = Struct()
params.update({
"template": "MMDetection_SSD",
"num_epochs": 1
})
metadata = (("authorization", "Key " + PAT),)
userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)
post_model_versions = stub.PostModelVersions(
service_pb2.PostModelVersionsRequest(
user_app_id=userDataObject,
model_id=MODEL_ID,
model_versions=[
resources_pb2.ModelVersion(
train_info=resources_pb2.TrainInfo(
params=params,
resume_from_model=resources_pb2.Model(
id=MODEL_ID,
model_version=resources_pb2.ModelVersion(id=MODEL_VERSION_ID),
),
),
output_info=resources_pb2.OutputInfo(
data=resources_pb2.Data(
concepts=[resources_pb2.Concept(id=CONCEPT_ID)]
),
),
)
],
),
metadata=metadata,
)
if post_model_versions.status.code != status_code_pb2.SUCCESS:
    print(post_model_versions.status)
    raise Exception(
        "Post model versions failed, status: " + post_model_versions.status.description
    )
print(post_model_versions)
<!--index.html file-->
<script>
////////////////////////////////////////////////////////////////////////////////////////////////////////
// In this section, we set the user authentication, app ID, and details for incremental training.
// Change these strings to run your own example.
////////////////////////////////////////////////////////////////////////////////////////////////////////
const USER_ID = "YOUR_USER_ID_HERE";
// Your PAT (Personal Access Token) can be found in the Account's Security section
const PAT = "YOUR_PAT_HERE";
const APP_ID = "YOUR_APP_ID_HERE";
// Change these to incrementally train your own model
const MODEL_ID = "detection-test";
const MODEL_VERSION_ID = "5af1bd0fb79d47289ab82d5bb2325c81";
const CONCEPT_ID = "face";
///////////////////////////////////////////////////////////////////////////////////
// YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
///////////////////////////////////////////////////////////////////////////////////
const raw = JSON.stringify({
"user_app_id": {
"user_id": USER_ID,
"app_id": APP_ID
},
"model_versions": [{
"train_info": {
"params": {
"template": "MMDetection_SSD",
"num_epochs": 1
},
"resume_from_model": {
"id": MODEL_ID,
"model_version": {
"id": MODEL_VERSION_ID
}
}
},
"output_info": {
"data": {
"concepts": [
{
"id": CONCEPT_ID
}
]
}
}
}]
});
const requestOptions = {
method: "POST",
headers: {
"Content-Type": "application/json",
"Authorization": "Key " + PAT
},
body: raw
};
fetch(`https://api.clarifai.com/v2/users/${USER_ID}/apps/${APP_ID}/models/${MODEL_ID}/versions`, requestOptions)
.then(response => response.text())
.then(result => console.log(result))
.catch(error => console.log("error", error));
</script>
package com.clarifai.example;
import com.clarifai.grpc.api.*;
import com.clarifai.channel.ClarifaiChannel;
import com.clarifai.credentials.ClarifaiCallCredentials;
import com.clarifai.grpc.api.status.StatusCode;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
public class ClarifaiExample {
//////////////////////////////////////////////////////////////////////////////////////////////////////
// In this section, we set the user authentication, app ID, and details for incremental training.
// Change these strings to run your own example.
//////////////////////////////////////////////////////////////////////////////////////////////////////
static final String USER_ID = "YOUR_USER_ID_HERE";
// Your PAT (Personal Access Token) can be found in the Account's Security section
static final String PAT = "YOUR_PAT_HERE";
static final String APP_ID = "YOUR_APP_ID_HERE";
// Change these to incrementally train your own model
static final String MODEL_ID = "detection-test";
static final String MODEL_VERSION_ID = "5af1bd0fb79d47289ab82d5bb2325c81";
static final String CONCEPT_ID = "face";
///////////////////////////////////////////////////////////////////////////////////
// YOU DO NOT NEED TO CHANGE ANYTHING BELOW THIS LINE TO RUN THIS EXAMPLE
///////////////////////////////////////////////////////////////////////////////////
public static void main(String[] args) {
V2Grpc.V2BlockingStub stub = V2Grpc.newBlockingStub(ClarifaiChannel.INSTANCE.getGrpcChannel())
.withCallCredentials(new ClarifaiCallCredentials(PAT));
Struct.Builder params = Struct.newBuilder()
.putFields("template", Value.newBuilder().setStringValue("MMDetection_SSD").build())
.putFields("num_epochs", Value.newBuilder().setNumberValue(1).build());
SingleModelResponse postModelVersionsResponse = stub.postModelVersions(
PostModelVersionsRequest.newBuilder()
.setUserAppId(UserAppIDSet.newBuilder().setUserId(USER_ID).setAppId(APP_ID))
.setModelId(MODEL_ID)
.addModelVersions(ModelVersion.newBuilder()
.setTrainInfo(TrainInfo.newBuilder()
.setParams(params)
.setResumeFromModel(Model.newBuilder()
.setId(MODEL_ID)
.setModelVersion(ModelVersion.newBuilder()
.setId(MODEL_VERSION_ID)
)
)
)
.setOutputInfo(OutputInfo.newBuilder()
.setData(Data.newBuilder()
.addConcepts(Concept.newBuilder()
.setId(CONCEPT_ID)
)
)
)
)
.build()
);
if (postModelVersionsResponse.getStatus().getCode() != StatusCode.SUCCESS) {
throw new RuntimeException("Post model versions failed, status: " + postModelVersionsResponse.getStatus());
}
System.out.println(postModelVersionsResponse);
}
}
curl -X POST "https://api.clarifai.com/v2/users/YOUR_USER_ID_HERE/apps/YOUR_APP_ID_HERE/models/detection-test/versions" \
-H "Authorization: Key YOUR_PAT_HERE" \
-H "Content-Type: application/json" \
-d '{
"model_versions": [{
"train_info": {
"params": {
"template": "MMDetection_SSD",
"num_epochs": 1
},
"resume_from_model": {
"id": "detection-test",
"model_version": {
"id": "5af1bd0fb79d47289ab82d5bb2325c81"
}
}
},
"output_info": {
"data": {
"concepts": [
{
"id": "face"
}
]
}
}
}]
}'
- Visual Classifier: Learn how to create and train a visual classifier model
- Visual Detector: Learn about our visual detector model type
- Visual Segmenter: Learn about our visual segmenter model type
- Visual Anomaly: Learn about our visual anomaly model type
- Visual Embedder: Learn about our visual embedder model type
- Clusterer: Learn about our clusterer model type
- Text Classifier: Learn about our text classifier model type
- Text Generation: Learn about our text-to-text model type and understand its fine-tuning process