Training/Cross validation/Test sets

Reuben Addison on 1 Mar 2019
Answered: Shubham on 4 Sep 2024
I am new to machine learning and a little lost on these concepts. I trained a model on my training set (60% of the data) and got the optimized parameters (theta values). Do I have to train on the cross-validation set too? If yes, what do I then do with the new theta values from the cross-validation data? If not, why do I still need a cross-validation set, since I can test my model's accuracy with just the test set? I would appreciate it if someone could walk me through this.
Also, how do I assess model accuracy with my test set? Thanks in advance for any help, even a tutorial video or code.

Answers (1)

Shubham on 4 Sep 2024
Hi Reuben,
Here's a basic guide to training, validating, and testing a classification model in MATLAB. MATLAB provides several built-in functions for machine learning. For example, fitcsvm trains a support vector machine (SVM) classifier, but only for one- or two-class problems; for data with more than two classes, such as the three-species fisheriris dataset, fitcecoc combines several binary SVM learners into one multiclass model.
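For reference, fitcsvm on its own is enough when there are only two classes. A minimal sketch, assuming we drop setosa from fisheriris so that only versicolor and virginica remain (the two-class subset and the 70/30 split are illustrative choices, not part of the original answer):

```matlab
% Minimal two-class sketch: the class subset and split ratio are assumptions
load fisheriris
% Keep two of the three species so that fitcsvm (binary SVM) applies
isTwoClass = ~strcmp(species, 'setosa');
X = meas(isTwoClass, :);
y = species(isTwoClass);

rng(1);  % fix the random seed so the split is reproducible
cv = cvpartition(numel(y), 'HoldOut', 0.3);

% Fit a binary SVM on the training portion only
svm = fitcsvm(X(training(cv), :), y(training(cv)));

% Labels are cell arrays of character vectors, so compare with strcmp
pred = predict(svm, X(test(cv), :));
acc = mean(strcmp(pred, y(test(cv))));
fprintf('Two-class held-out accuracy: %.2f%%\n', acc * 100);
```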
Steps in MATLAB
  1. Load and Split Data: First, split your dataset into training, validation, and test sets.
  2. Train the Model: Use the training set to train your model.
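  3. Validate the Model: Use the validation set to tune hyperparameters and check for overfitting. You do not retrain on the validation set or keep new parameters (theta) from it; you only use it to compare candidate models.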
  4. Test the Model: Use the test set to evaluate the final model's performance.
Example Code
Here's a simple example using MATLAB:
% Load your dataset
load fisheriris % Example dataset
X = meas; % Features
y = species; % Labels
% Split the data into training, validation, and test sets
rng(1); % fix the random seed so the split is reproducible
cv = cvpartition(length(y), 'HoldOut', 0.4);
XTrain = X(training(cv), :);
yTrain = y(training(cv), :);
XTemp = X(test(cv), :);
yTemp = y(test(cv), :);
% Further split the temp data into validation and test
cv2 = cvpartition(length(yTemp), 'HoldOut', 0.5);
XVal = XTemp(training(cv2), :);
yVal = yTemp(training(cv2), :);
XTest = XTemp(test(cv2), :);
yTest = yTemp(test(cv2), :);
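% Train the model using the training set
% fitcsvm only handles one or two classes and fisheriris has three species,
% so fitcecoc is used here to combine binary SVM learners (its default)
model = fitcecoc(XTrain, yTrain);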
% Validate the model using the validation set
valPredictions = predict(model, XVal);
% The labels are cell arrays of character vectors, so compare them with strcmp, not ==
valAccuracy = mean(strcmp(valPredictions, yVal));
fprintf('Validation Accuracy: %.2f%%\n', valAccuracy * 100);
% Test the model using the test set
testPredictions = predict(model, XTest);
testAccuracy = mean(strcmp(testPredictions, yTest));
fprintf('Test Accuracy: %.2f%%\n', testAccuracy * 100);
% Display the confusion matrix (rows: true class, columns: predicted class)
confMat = confusionmat(yTest, testPredictions);
disp('Confusion Matrix:');
disp(confMat);
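On the question of how to assess accuracy: accuracy is the number of correct predictions divided by the number of test observations, which is also the sum of the confusion matrix diagonal divided by the sum of all its entries. A minimal sketch with made-up labels (the labels 'a' and 'b' are hypothetical, chosen only to show the arithmetic):

```matlab
% Hypothetical true and predicted labels, just to illustrate the arithmetic
yTrue = {'a'; 'a'; 'b'; 'b'; 'b'};
yPred = {'a'; 'b'; 'b'; 'b'; 'a'};

% C(i,j) counts observations whose true class is order{i}
% and whose predicted class is order{j}
[C, order] = confusionmat(yTrue, yPred);

% Correct predictions sit on the diagonal
accuracy = trace(C) / sum(C(:));   % 3 correct out of 5 -> 0.6

% Per-class recall: fraction of each true class predicted correctly
recall = diag(C) ./ sum(C, 2);

fprintf('Accuracy: %.2f\n', accuracy);
for i = 1:numel(order)
    fprintf('Recall for %s: %.2f\n', order{i}, recall(i));
end
```

The diagonal-based value matches mean(strcmp(yTrue, yPred)); the confusion matrix additionally shows which classes are being confused with each other.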
Explanation
  • Data Splitting: We first split the data into training and temporary sets (60% training, 40% temporary). The temporary set is further split into validation and test sets (20% each).
  • Model Training: We train an SVM model using the training data.
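  • Validation: We evaluate the trained model on the validation set. No new parameters (theta) are fitted on it; if you try several settings (for example, different kernels or regularization strengths), you train each one on the training set and keep the one with the best validation accuracy. A large gap between training and validation accuracy is a sign of overfitting.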
  • Testing: Finally, we evaluate the chosen model once on the test set. Because the test data played no part in training or model selection, this gives an unbiased estimate of its accuracy on new data.
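To make the role of the validation set concrete, here is a minimal sketch of model selection with the same 60/20/20 split. It assumes that the hyperparameter being tuned is the SVM box constraint (set through templateSVM) and that the candidate values are illustrative only; both are assumptions, not part of the original answer.

```matlab
% Sketch: pick a hyperparameter on the validation set, then test once.
% Assumption: BoxConstraint is the tuned hyperparameter; values are illustrative.
load fisheriris
X = meas;
y = species;
rng(1); % reproducible split

% 60% training, 40% held out, then the held-out part split 50/50
cv = cvpartition(numel(y), 'HoldOut', 0.4);
XTrain = X(training(cv), :);  yTrain = y(training(cv));
XTemp  = X(test(cv), :);      yTemp  = y(test(cv));
cv2 = cvpartition(numel(yTemp), 'HoldOut', 0.5);
XVal  = XTemp(training(cv2), :);  yVal  = yTemp(training(cv2));
XTest = XTemp(test(cv2), :);      yTest = yTemp(test(cv2));

boxValues = [0.01 0.1 1 10 100];   % hypothetical candidate values
valAcc = zeros(size(boxValues));
models = cell(size(boxValues));
for k = 1:numel(boxValues)
    % Every candidate is fitted on the training set only
    t = templateSVM('BoxConstraint', boxValues(k));
    models{k} = fitcecoc(XTrain, yTrain, 'Learners', t);
    % The validation set only scores the candidate; nothing is refitted on it
    valAcc(k) = mean(strcmp(predict(models{k}, XVal), yVal));
end

% Keep the candidate with the highest validation accuracy (first one on ties)
[bestValAcc, best] = max(valAcc);
bestModel = models{best};
fprintf('Best BoxConstraint: %g (validation accuracy %.2f%%)\n', ...
    boxValues(best), bestValAcc * 100);

% Use the test set exactly once, with the selected model
testAcc = mean(strcmp(predict(bestModel, XTest), yTest));
fprintf('Test accuracy: %.2f%%\n', testAcc * 100);
```

A common follow-up, not shown here, is to refit the selected setting on the combined training and validation data before the final test evaluation.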

Version

R2018a
