How to create a hierarchical neural network (merge multiple neural networks together)?

Hi all,
I am using MATLAB R2020b (including Deep Learning Toolbox) and I am new to this. I have just implemented a function that creates a neural network with a specific structure:
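function x = createNetwork(d,d_stern,M_stern)
% Custom shallow network with d inputs: M_stern parallel tansig hidden
% layers of 4*d_stern neurons, one tansig neuron per hidden layer, and a
% single output neuron that combines those M_stern neurons.
l5 = M_stern *2 + 1;              % total number of layers
net = network;
net.numInputs = 1;
net.inputs{1}.size = d;
net.numLayers = l5;
net.biasConnect = ones([l5,1]);   % every layer has a bias
% Only the hidden layers 1..M_stern read the input
inCon = zeros([l5,1]);
inCon(1:M_stern) = 1;
net.inputConnect = inCon;
for i = 1:M_stern
    net.layerConnect(l5,l5-i)=1;       % single-neuron layer -> output layer
    net.layerConnect(M_stern+i,i)=1;   % hidden layer i -> single-neuron layer M_stern+i
end
% The last layer is the network output
outCon = zeros([l5,1]);
outCon(l5) = 1;
net.outputConnect = transpose(outCon);
for i = 1:M_stern
    net.layers{i}.size=4*d_stern;
end
for i = M_stern+1:l5
    net.layers{i}.size = 1;
end
net.numWeightElements;   % no effect: the value is neither stored nor displayed
% Expected number of weight and bias elements
Weights = M_stern*(4*d_stern*(d+2)+2)+1;
% tansig everywhere except the output layer, which keeps its default transfer function
for i = 1:l5-1
    net.layers{i}.transferFcn = 'tansig';
end
for i = 1:l5
    net.layers{i}.initFcn = 'initnw';
end
x = net;
view(net)
end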
It works as expected. Here is an example of the network with d = 5, d_stern = 1, M_stern = 2:
[Figure: network diagram for d = 5, d_stern = 1, M_stern = 2]
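The Weights line computes the number of weight and bias elements in closed form. A minimal sketch that rebuilds the count layer by layer from the structure of createNetwork shows where the formula comes from (countWeights and closedForm are illustrative anonymous functions, not toolbox functions):

```matlab
% Count the weight and bias elements of createNetwork(d, d_stern, M_stern)
% layer by layer (illustrative helper, not a toolbox function)
countWeights = @(d, d_stern, M_stern) ...
    M_stern*4*d_stern*(d + 1) ...   % hidden layers: d input weights + 1 bias per neuron
    + M_stern*(4*d_stern + 1) ...   % single-neuron layers: 4*d_stern weights + 1 bias each
    + (M_stern + 1);                % output neuron: M_stern weights + 1 bias

% Closed form used in createNetwork (the Weights line)
closedForm = @(d, d_stern, M_stern) M_stern*(4*d_stern*(d+2)+2) + 1;

countWeights(5, 1, 2)   % 48 + 10 + 3 = 61
closedForm(5, 1, 2)     % 61
```

Grouping the terms gives M_stern*(4*d_stern*(d+1) + 4*d_stern + 1) + M_stern + 1, which simplifies to the closed form. This is presumably the value the net.numWeightElements line was meant to be compared against.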
Now my aim is to create a much "bigger" network that consists of several copies of the network defined above. For example, if
f1 = createNetwork(5,3,2);
f2 = createNetwork(5,3,2);
f3 = createNetwork(5,3,2);
g1 = createNetwork(3,3,2);
g2 = createNetwork(3,3,2);
g3 = createNetwork(3,3,2);
h = createNetwork(3,3,2);
What I now want is a network that calculates
h( g1(f1(x), f2(x), f3(x)), g2(f1(x), f2(x), f3(x)), g3(f1(x), f2(x), f3(x)) ).
The bigger network would then be trained on data. The goal is a function createBigNetwork(d,d_stern,M_stern,p), where p is the number of "neural network layers" of subnetworks above the first one: in the example above (levels f, g and h), p = 2. Furthermore, createBigNetwork(d,d_stern,M_stern,0) should be the same as createNetwork(d,d_stern,M_stern).
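To make the target function concrete, here is a minimal plain-MATLAB sketch of the forward pass with random, untrained weights and no toolbox objects. It assumes each subnetwork behaves like createNetwork (tansig, which equals tanh, in the hidden and single-neuron layers; a linear output) and that every level below the top has three subnetworks, as in the example. The names makeParams, forwardSub and nPerLevel are hypothetical:

```matlab
% Sizes matching the example in the question
d = 5; d_stern = 3; M_stern = 2; p = 2;
nPerLevel = 3;   % assumption: 3 subnetworks per level below the top (f's, g's)

% Random parameters for one createNetwork-style subnetwork with dIn inputs.
% W1 stacks the M_stern hidden layers; the kron mask makes W2 block-diagonal,
% so single neuron i only sees hidden layer i (layerConnect(M_stern+i,i)).
makeParams = @(dIn) struct( ...
    'W1', randn(M_stern*4*d_stern, dIn), ...
    'b1', randn(M_stern*4*d_stern, 1), ...
    'W2', kron(eye(M_stern), ones(1, 4*d_stern)) .* randn(M_stern, M_stern*4*d_stern), ...
    'b2', randn(M_stern, 1), ...
    'w3', randn(1, M_stern), ...
    'b3', randn(1, 1));

% Two tanh stages, then a linear output; x is dIn-by-N (one column per observation)
forwardSub = @(P, x) P.w3 * tanh(P.W2 * tanh(P.W1 * x + P.b1) + P.b2) + P.b3;

x = rand(d, 4);   % 4 example observations
z = x;            % input of the current level
for k = 0:p
    if k == p
        nNets = 1;           % top level: a single network (h)
    else
        nNets = nPerLevel;
    end
    out = zeros(nNets, size(z, 2));
    for n = 1:nNets
        P = makeParams(size(z, 1));   % independent weights per subnetwork
        out(n, :) = forwardSub(P, z);
    end
    z = out;   % stacked outputs feed every subnetwork of the next level
end
y = z;         % 1-by-4: one output per observation
```

Each f is evaluated once and its output is shared by all g's, which is what h(g1(f1(x), f2(x), f3(x)), ...) requires; with p = 0 the loop reduces to a single subnetwork applied to x.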
Is there any way to do this in Matlab? Any help is greatly appreciated!

Answers (1)

Ayush Aniket on 11 Jun 2025
You can create a hierarchical neural network in MATLAB by using the addLayers and connectLayers functions. Because MATLAB's shallow network object is limited for this kind of deep learning architecture, you might want to use dlnetwork for more flexibility. Refer to the code snippet below:
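function bigNet = createBigNetwork(d, d_stern, M_stern, p)
% Builds the hierarchy as one layer graph and converts it to a dlnetwork.
% A shallow network object cannot be passed to dlnetwork, so each
% subnetwork is rebuilt from Deep Learning Toolbox layers (see addSubnet).
% Level 0 has nPerLevel subnetworks reading the input x (the f's), each
% level 1..p-1 has nPerLevel subnetworks reading the concatenated outputs
% of the level below (the g's), and level p is a single subnetwork (h).
% p = 0 therefore gives a single createNetwork-style network.
nPerLevel = 3;   % as in the example; an assumption for general p
lgraph = layerGraph(featureInputLayer(d, 'Name', 'input'));
prevOut = {'input'};   % names of the layers feeding the current level
for k = 0:p
    if k == p
        nNets = 1;
    else
        nNets = nPerLevel;
    end
    % Stack the previous level's scalar outputs into one vector
    if numel(prevOut) > 1
        src = sprintf('L%d_in', k);
        lgraph = addLayers(lgraph, concatenationLayer(1, numel(prevOut), 'Name', src));
        for j = 1:numel(prevOut)
            lgraph = connectLayers(lgraph, prevOut{j}, sprintf('%s/in%d', src, j));
        end
    else
        src = prevOut{1};
    end
    % Add nNets independent subnetworks that all read src
    curOut = cell(1, nNets);
    for n = 1:nNets
        [lgraph, curOut{n}] = addSubnet(lgraph, sprintf('L%d_N%d', k, n), src, d_stern, M_stern);
    end
    prevOut = curOut;
end
% View the layer graph, then create the trainable network
analyzeNetwork(lgraph);
bigNet = dlnetwork(lgraph);
end

function [lgraph, outName] = addSubnet(lgraph, prefix, src, d_stern, M_stern)
% Rebuilds the createNetwork structure from layers: M_stern branches of
% fullyConnected(4*d_stern) -> tanh -> fullyConnected(1) -> tanh, whose
% outputs are combined by a linear fullyConnected(1) output layer.
outName = [prefix '_out'];
lgraph = addLayers(lgraph, fullyConnectedLayer(1, 'Name', outName));
catName = [prefix '_cat'];
if M_stern > 1   % concatenationLayer needs at least two inputs
    lgraph = addLayers(lgraph, concatenationLayer(1, M_stern, 'Name', catName));
    lgraph = connectLayers(lgraph, catName, outName);
end
for i = 1:M_stern
    b = sprintf('%s_b%d', prefix, i);
    lgraph = addLayers(lgraph, [
        fullyConnectedLayer(4*d_stern, 'Name', [b '_fc1'])
        tanhLayer('Name', [b '_tanh1'])
        fullyConnectedLayer(1, 'Name', [b '_fc2'])
        tanhLayer('Name', [b '_tanh2'])]);
    lgraph = connectLayers(lgraph, src, [b '_fc1']);
    if M_stern > 1
        lgraph = connectLayers(lgraph, [b '_tanh2'], sprintf('%s/in%d', catName, i));
    else
        lgraph = connectLayers(lgraph, [b '_tanh2'], outName);
    end
end
end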
You can also check out the following documentation on nested layers for more details on creating and training such networks: https://www.mathworks.com/help/deeplearning/ug/create-network-with-nested-layers.html


Version: R2020b
