Can anyone tell me why my function is taking so long to run?
I wrote a K-means clustering algorithm and it works fine, but now I have to run it over 100 iterations and it is taking way too long. Can anyone tell me why? Here is my code:
%This code computes the accuracy of the k-means algorithm
clear all;
dataPts=(genData())';
numClust=3;
[m,n]=size(dataPts);
%Run the k-means algorithm numIter times
numIter=35;
for iter=1:numIter
    %Choose initial centers randomly. First generate random indices from which
    %to choose the centers.
    clustPos=randi([1,n],[1,numClust]);
    %Can now identify initial clusters
    for i=1:numClust
        clustCenter(1:m,i)=dataPts(1:m,clustPos(i));
    end
    oldClustCenter=clustCenter;
    while 1
        r=zeros([n,numClust]);
        %Now assign each data point to a cluster
        for j=1:n
            colNorm=(sqrt(sum((clustCenter-repmat(dataPts(:,j),1,numClust)).^2,1))).^2;
            [distToCenter(j),closestClust(j)]=min(colNorm);
            r(j,closestClust(j))=1;
        end
        %Now recompute the cluster centers using the indicator variables
        %clustCenter(:,1)=
        for j=1:numClust
            sum1=0;
            sum2=0;
            for i=1:n
                numerator=r(i,j)*dataPts(:,i)+sum1;
                sum1=numerator;
                denominator=r(i,j)+sum2;
                sum2=denominator;
            end
            clustCenter(:,j)=numerator/denominator;
        end
        if clustCenter==oldClustCenter
            break
        end
        oldClustCenter=clustCenter;
    end
    %Must now determine which data points are in the correct cluster and which
    %are in error.
    clust11Count=0;
    clust12Count=0;
    clust13Count=0;
    for k=1:50
        if r(k,:)==[1,0,0]
            clust11Count=clust11Count+1;
        elseif r(k,:)==[0,1,0]
            clust12Count=clust12Count+1;
        else
            clust13Count=clust13Count+1;
        end
    end
    [clust1Error,bestClust1]=max([clust11Count,clust12Count,clust13Count]);
    error1=50-clust1Error;
    clust21Count=0;
    clust22Count=0;
    clust23Count=0;
    for k=51:100
        if r(k,:)==[1,0,0]
            clust21Count=clust21Count+1;
        elseif r(k,:)==[0,1,0]
            clust22Count=clust22Count+1;
        else
            clust23Count=clust23Count+1;
        end
    end
    [clust2Error,bestClust2]=max([clust21Count,clust22Count,clust23Count]);
    error2=50-clust2Error;
    clust31Count=0;
    clust32Count=0;
    clust33Count=0;
    for k=101:150
        if r(k,:)==[1,0,0]
            clust31Count=clust31Count+1;
        elseif r(k,:)==[0,1,0]
            clust32Count=clust32Count+1;
        else
            clust33Count=clust33Count+1;
        end
    end
    [clust3Error,bestClust3]=max([clust31Count,clust32Count,clust33Count]);
    %It sometimes happens that one k-means generated cluster covers two true clusters.
    %In this case every point in k-means cluster 3 will be incorrect.
    if bestClust2==bestClust3
        error3=50;
    else
        error3=50-clust3Error;
    end
    avgError(iter)=((error1+error2+error3)/150);
end
avgavgError=(sum(avgError)/numIter)*100;
disp(['The average error is ' num2str(avgavgError) ' percent']);
And here is the function that generates the data:
function [ data ] = genData
    % generate data
    data = randn(150,2) ;
    data(1:50,:) = 3 + sqrt(.2) * data(1:50,:) ;
    data(51:100,:) = 6 + sqrt(.5) * data(51:100,:) ;
    data(101:150,:) = 7 + sqrt(.6) * data(101:150,:) ;
end
Answers (3)
Walter Roberson
on 24 Nov 2011
Several of your arrays are not pre-allocated. That can slow things down a lot.
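As a rough sketch of what pre-allocation could look like for the script in the question: the variable names are taken from that script, and the shapes are inferred from how each array is indexed there. The sizes are hard-coded to the question's values only to keep the snippet self-contained.

```matlab
% Sizes as in the question: 2-D points, 150 of them, 3 clusters, 35 runs.
m = 2; n = 150; numClust = 3; numIter = 35;

% Allocate once, before the loops, so the arrays never grow element by element.
clustCenter  = zeros(m, numClust);   % one column per cluster center
distToCenter = zeros(1, n);          % squared distance of each point to its center
closestClust = zeros(1, n);          % index of the nearest center for each point
avgError     = zeros(1, numIter);    % one error rate per k-means run
```

Without these lines, MATLAB has to enlarge each array every time a new index is assigned, which means repeated reallocation and copying inside the loops.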
Hin Kwan Wong
on 24 Nov 2011
Please read the MATLAB help on using the profiler to isolate the lines of code that take the most time to run. The help also contains many good tips and techniques for improving performance. Many of your for loops have potential for vectorization, and the tips cover that with examples.
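A minimal sketch of a profiling session, assuming you wrap whatever you want to measure between `profile on` and `profile off`. The loop here is only a stand-in workload, and `x` and `stats` are illustrative names, not part of the question's script.

```matlab
profile on                     % start collecting timing data
x = zeros(1, 1000);
for k = 1:1000                 % stand-in workload; put the k-means script here instead
    x(k) = sqrt(k);
end
profile off                    % stop collecting
stats = profile('info');       % collected timing data as a struct
% profile viewer               % in MATLAB, opens the interactive per-line report
```

The report lists time per function and per line, which should show directly whether the distance loop or the center update dominates the run time.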
Jan
on 24 Nov 2011
1. The pre-allocation already mentioned is the most important point.
2. Unnecessary calculations:
colNorm=(sqrt(sum((clustCenter - repmat(dataPts(:,j), 1, numClust)) .^ 2, 1))).^2;
SQRT(X).^2 == X for nonnegative X (up to rounding), so the square root and the outer square cancel and both can be dropped. The inner square has to stay applied to the difference:
colNorm = sum((clustCenter - repmat(dataPts(:,j), 1, numClust)) .^ 2, 1);
3. Slow loop
for i=1:numClust
    clustCenter(1:m,i)=dataPts(1:m,clustPos(i));
end
Faster without loop:
clustCenter = dataPts(:, clustPos);
4. The loops over k to calculate clustXYCount can also be accelerated:
q = r(1:50, :);
clust11Count = sum(q(:, 1) & q(:, 2)==0 & q(:, 3)==0);
clust12Count = sum(q(:, 1) == 0 & q(:, 2) & q(:, 3)==0);
clust13Count = 50 - clust11Count - clust12Count;
5. clear all wastes time without any benefit. It deletes all previously loaded functions from memory, and reloading them is very expensive. See: Answers: good programming practice.
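To make points 2 and 4 concrete, here is a hedged sketch of one assignment-and-update step on toy data. It goes a bit further than the list above: using `bsxfun` and `sub2ind` instead of `repmat` and the loop over points is one possible choice, and the vectorized center update is an addition that the list does not mention. The names `sqDist`, `newCenter`, and `blockCount` are made up for illustration.

```matlab
% Toy data with the same layout as the question: one point per column.
dataPts     = [0 0 10 10; 0 1 10 11];     % m-by-n, here 2-by-4
clustCenter = [0 10; 0 10];               % m-by-numClust, here 2-by-2
[m, n]   = size(dataPts);
numClust = size(clustCenter, 2);

% Squared distance from every point to every center, one center at a time.
% No sqrt is needed because min() picks the same center either way.
sqDist = zeros(n, numClust);
for c = 1:numClust
    d = bsxfun(@minus, dataPts, clustCenter(:, c));   % m-by-n differences
    sqDist(:, c) = sum(d .^ 2, 1)';
end
[distToCenter, closestClust] = min(sqDist, [], 2);

% Indicator matrix r, built without a loop over the points.
r = zeros(n, numClust);
r(sub2ind(size(r), (1:n)', closestClust)) = 1;

% New centers: sum of member points divided by member count.
% An empty cluster gives 0/0 = NaN, the same as the looped version.
newCenter = bsxfun(@rdivide, dataPts * r, sum(r, 1));

% Points per cluster within a block of rows (rows 1:2 here, 1:50 in the question).
blockCount = sum(r(1:2, :), 1);
```

Because each row of r contains exactly one 1, the column sums of a block of rows give the same three counts as the if/elseif loops over k.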