Can anyone tell me why my function is taking so long to run?

Alex on 24 Nov 2011
I wrote a K-means clustering algorithm, and it works fine, but now I have to run it over 100 iterations and it is taking way too long. Can anyone tell me why? Here is my code:
%This code computes the accuracy of the k-means algorithm
clear all;
dataPts=(genData())';
numClust=3;
[m,n]=size(dataPts);
%Run the k-means algorithm numIter times
numIter=35;
for iter=1:numIter
    %choose initial centers randomly. First generate random indices from which
    %to choose the centers.
    clustPos=randi([1,n],[1,numClust]);
    %Can now identify initial clusters
    for i=1:numClust
        clustCenter(1:m,i)=dataPts(1:m,clustPos(i));
    end
    oldClustCenter=clustCenter;
    while 1
        r=zeros([n,numClust]);
        %Now assign each data point to a cluster
        for j=1:n
            colNorm=(sqrt(sum((clustCenter-repmat(dataPts(:,j),1,numClust)).^2,1))).^2;
            [distToCenter(j),closestClust(j)]=min(colNorm);
            r(j,closestClust(j))=1;
        end
        %Now recompute the cluster centers using the indicator variables
        %clustCenter(:,1)=
        for j=1:numClust
            sum1=0;
            sum2=0;
            for i=1:n
                numerator=r(i,j)*dataPts(:,i)+sum1;
                sum1=numerator;
                denominator=r(i,j)+sum2;
                sum2=denominator;
            end
            clustCenter(:,j)=numerator/denominator;
        end
        if clustCenter==oldClustCenter
            break
        end
        oldClustCenter=clustCenter;
    end
    %must now determine which data points are in the correct cluster and which
    %are in error.
    clust11Count=0;
    clust12Count=0;
    clust13Count=0;
    for k=1:50
        if r(k,:)==[1,0,0]
            clust11Count=clust11Count+1;
        elseif r(k,:)==[0,1,0]
            clust12Count=clust12Count+1;
        else
            clust13Count=clust13Count+1;
        end
    end
    [clust1Error,bestClust1]=max([clust11Count,clust12Count,clust13Count]);
    error1=50-clust1Error;
    clust21Count=0;
    clust22Count=0;
    clust23Count=0;
    for k=51:100
        if r(k,:)==[1,0,0]
            clust21Count=clust21Count+1;
        elseif r(k,:)==[0,1,0]
            clust22Count=clust22Count+1;
        else
            clust23Count=clust23Count+1;
        end
    end
    [clust2Error,bestClust2]=max([clust21Count,clust22Count,clust23Count]);
    error2=50-clust2Error;
    clust31Count=0;
    clust32Count=0;
    clust33Count=0;
    for k=101:150
        if r(k,:)==[1,0,0]
            clust31Count=clust31Count+1;
        elseif r(k,:)==[0,1,0]
            clust32Count=clust32Count+1;
        else
            clust33Count=clust33Count+1;
        end
    end
    [clust3Error,bestClust3]=max([clust31Count,clust32Count,clust33Count]);
    %It sometimes happens that one k-means generated cluster covers two true clusters.
    %In this case every point in k-means cluster 3 will be incorrect.
    if bestClust2==bestClust3
        error3=50;
    else
        error3=50-clust3Error;
    end
    avgError(iter)=((error1+error2+error3)/150);
end
avgavgError=(sum(avgError)/numIter)*100;
disp(['The average error is ' num2str(avgavgError) ' percent']);
And here is the function that generates the data:
function [ data ] = genData
    % generate data
    data = randn(150,2) ;
    data(1:50,:) = 3 + sqrt(.2) * data(1:50,:) ;
    data(51:100,:) = 6 + sqrt(.5) * data(51:100,:) ;
    data(101:150,:) = 7 + sqrt(.6) * data(101:150,:) ;
end

Answers (3)

Walter Roberson on 24 Nov 2011
Several of your arrays (e.g. clustCenter, distToCenter, closestClust and avgError) are not pre-allocated. Growing an array inside a loop can slow things down a lot.
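A minimal sketch of pre-allocating the arrays that the question's code grows inside its loops. The sizes (2-D points, n = 150, three clusters, numIter = 35) are taken from the question; placing these lines just before the `for iter` loop is an assumption about where they fit best.

```matlab
% Sizes taken from the question: 2-D points, 150 of them, 3 clusters.
m = 2; n = 150; numClust = 3; numIter = 35;

% Allocate each array once, at its final size, before entering the loops.
avgError     = zeros(1, numIter);   % one error value per k-means run
clustCenter  = zeros(m, numClust);  % one column per cluster center
distToCenter = zeros(1, n);         % squared distance of each point to its center
closestClust = zeros(1, n);         % index of the nearest center for each point

for iter = 1:numIter
    % ... the body of one k-means run would go here ...
    avgError(iter) = 0;             % placeholder: writes into existing storage, no growth
end
```

Without the `zeros` calls, each assignment past the current end of an array can force MATLAB to reallocate and copy it.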

Hin Kwan Wong on 24 Nov 2011
Please read the MATLAB documentation on using the profiler to isolate the lines of code that take most of the run time. The documentation also contains many good tips and techniques for improving code performance. Your code has many for loops that could be vectorized, and the tips cover that with examples.
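As one example of such vectorization, here is a sketch of a single k-means assignment and update step without the loops over points. It assumes the question's layout (dataPts is m-by-n with one point per column, clustCenter is m-by-numClust); the toy data and the choice to loop over clusters rather than points are illustrative, and `bsxfun` is used so the code also runs on releases without implicit expansion. Like the original code, it yields NaN for a cluster that ends up with no points.

```matlab
% Hypothetical toy data: 4 points in 2-D (one per column) and 2 initial centers.
dataPts = [0 0 10 10; 0 1 0 1];
clustCenter = [0 10; 0 0];
n = size(dataPts, 2);
numClust = size(clustCenter, 2);

% Assignment step: D(j,k) is the squared distance from point j to center k.
D = zeros(n, numClust);
for k = 1:numClust                  % loop over the few clusters, not the many points
    diffs = bsxfun(@minus, dataPts, clustCenter(:, k));
    D(:, k) = sum(diffs .^ 2, 1)';
end
[distToCenter, closestClust] = min(D, [], 2);   % nearest center for every point

% Indicator matrix r (n-by-numClust), same meaning as in the question.
r = zeros(n, numClust);
r(sub2ind([n, numClust], (1:n)', closestClust)) = 1;

% Update step: dataPts * r sums the points of each cluster,
% sum(r, 1) counts them, so the ratio is the cluster mean.
clustCenter = bsxfun(@rdivide, dataPts * r, sum(r, 1));
```

Here the first two points are assigned to cluster 1 and the last two to cluster 2, so the updated centers are [0 10; 0.5 0.5].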

Jan on 24 Nov 2011
1. The already mentioned pre-allocation is the most important point.
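2. Unnecessary calculations:
colNorm=(sqrt(sum((clustCenter - repmat(dataPts(:,j), 1, numClust)) .^ 2, 1))).^2;
First, SQRT(X).^2 == X (up to rounding), so both the SQRT and the outer power can be dropped. The squaring must stay on the difference, though: squaring dataPts(:,j) alone before REPMATting it does not give a squared distance, because (c - x).^2 is not c - x.^2.
colNorm = sum((clustCenter - repmat(dataPts(:,j), 1, numClust)) .^ 2, 1);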
3. Slow loop:
for i=1:numClust
    clustCenter(1:m,i)=dataPts(1:m,clustPos(i));
end
Faster without a loop:
clustCenter = dataPts(:, clustPos);
4. The loops over k that compute the clustXYCount variables can also be accelerated, e.g. for the first 50 points:
q = r(1:50, :);
clust11Count = sum(q(:, 1) & q(:, 2)==0 & q(:, 3)==0);
clust12Count = sum(q(:, 1) == 0 & q(:, 2) & q(:, 3)==0);
clust13Count = 50 - clust11Count - clust12Count;
5. clear all wastes time without any benefit: it removes all previously loaded functions from memory, and reloading them is very expensive. See: Answers: good programming practice.
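A small numerical check of point 2, using made-up data (three 2-D centers and one point x standing in for dataPts(:,j)): removing SQRT and the outer .^2 changes the result only at rounding level, whereas squaring the point before subtracting does not produce a squared distance at all (here every value even comes out negative).

```matlab
% Hypothetical data: three 2-D centers (columns) and one point x.
clustCenter = [1 4 0; 2 -1 3];
x = [2; 1];                                        % stands in for dataPts(:,j)

diffs = clustCenter - repmat(x, 1, 3);             % center minus point, per column
original   = (sqrt(sum(diffs .^ 2, 1))) .^ 2;      % expression from the question
simplified = sum(diffs .^ 2, 1);                   % SQRT and outer .^2 removed
squaredFirst = sum(clustCenter - repmat(x .^ 2, 1, 3), 1);  % squares x only: wrong

maxDiff = max(abs(original - simplified));         % only floating-point rounding
disp(simplified)      % squared distances: 2 8 8
disp(squaredFirst)    % not distances: -2 -2 -2
```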
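The counting in point 4 can also be done for all three groups at once. This sketch assumes, as in the question, that each row of r contains exactly one 1 and that the true groups are consecutive blocks of rows of equal size; the tiny r below is made up, and accumarray is one possible tool for this, not the only one.

```matlab
% Hypothetical indicator matrix: 6 points, 3 true groups of 2 points each,
% 3 k-means clusters. Each row has exactly one 1.
r = [1 0 0; 1 0 0; 0 1 0; 0 0 1; 0 0 1; 0 0 1];
groupSize = 2;                          % 50 in the question
numPts = size(r, 1);

% Cluster label of every point (column index of the 1 in each row).
[~, label] = max(r, [], 2);

% True group of every point: 1 for the first groupSize rows, 2 for the next, ...
group = ceil((1:numPts)' / groupSize);

% counts(g, c) = number of points of true group g assigned to cluster c.
% Row g plays the role of [clustG1Count, clustG2Count, clustG3Count].
counts = accumarray([group, label], 1, [3, 3]);

% Best-matching cluster and number of correctly clustered points per group.
[correct, bestClust] = max(counts, [], 2);
```

With the question's data, groupSize would be 50 and the per-group errors would be groupSize - correct, before applying the special case for bestClust2 == bestClust3.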
