The adamupdate function in MATLAB R2024b incorrectly uses uint32 with sqrt and exhibits state corruption, causing errors even in minimal test cases.

% Test adamupdate function
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
state = []; % Initial state (empty)
optimizer = trainingOptions('adam', 'InitialLearnRate', 0.01); % Example optimizer
timeStep = uint32(1); % Initial time step
try
% Perform a single adamupdate
updatedLearnable = adamupdate(learnable, gradient, state, optimizer, timeStep);
% Display results
disp('adamupdate test successful!');
disp('Updated Learnable:');
disp(updatedLearnable);
catch ME
% Display error message
disp('adamupdate test failed!');
disp(['Error: ', ME.message]);
% Display the stack trace to help MathWorks track the problem.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
9x1 struct array with fields: file name line
% Perform a second adam update to test state persistence.
timeStep = uint32(2);
try
% Perform a second adamupdate
updatedLearnable2 = adamupdate(learnable, gradient, state, optimizer, double(timeStep));
% Display results
disp('adamupdate second test successful!');
disp('Updated Learnable:');
disp(updatedLearnable2);
catch ME
% Display error message
disp('adamupdate second test failed!');
disp(['Error: ', ME.message]);
% Display the stack trace to help MathWorks track the problem.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
12x1 struct array with fields: file name line
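Both failures are consistent with the accepted answer further down. adamupdate reads its inputs by position, following `adamupdate(net, grad, averageGrad, averageSqGrad, iteration)`, so in these calls `state` is taken as `averageGrad`, the `trainingOptions` object as `averageSqGrad`, and `timeStep` as `iteration`. The `uint32` in the first message therefore comes from the caller's `timeStep`, and the second message is plausibly the `trainingOptions` object, now sitting in the `averageSqGrad` slot, being rejected as non-numeric state. The plain-MATLAB sketch below (no toolboxes) illustrates one possible route to the first message, assuming the iteration count feeds a bias-correction term such as `1 - 0.999^t`: mixed double/integer arithmetic keeps the integer class, and `sqrt` does not accept integer input.

```matlab
% Plain MATLAB, no toolboxes: the integer class wins in mixed arithmetic.
t = uint32(1);
decay = 0.999^t;            % double base, uint32 exponent -> uint32 result
disp(class(decay))          % uint32
biasTerm = 1 - decay;       % 0.999 rounds to uint32(1), so this is uint32(0)
try
    sqrt(biasTerm);         % sqrt accepts floating-point input only
catch err
    disp(err.message)
end
% Converting the counter to double first keeps everything floating point.
biasTermDouble = 1 - 0.999^double(t);
disp(sqrt(biasTermDouble))  % about 0.0316
```

The 0.999 here is only an illustrative decay factor; the point is that any integer-typed counter mixed into floating-point formulas can silently turn them into integer arithmetic.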
Chika on 18 Mar 2025
Error message:
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
7×1 struct array with fields:
file
name
line
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
10×1 struct array with fields:
file
name
line


Accepted Answer

Joss Knight on 22 Mar 2025
Well, I admit the error messages aren't very helpful, but the basic problem is that passing a trainingOptions object as an argument to adamupdate is not supported. See the documentation for the correct syntax.
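A minimal sketch of the documented call pattern, assuming Deep Learning Toolbox is installed: the optimizer state starts as empty arrays, `iteration` is an ordinary double that starts at 1 and increases by one per call, and the learning rate is passed positionally after `iteration` (the documented optional inputs are `learnRate, gradDecay, sqGradDecay, epsilon`) instead of through `trainingOptions`. The random gradients and `numIterations` are hypothetical stand-ins for values a custom training loop would compute with `dlfeval` and `dlgradient`.

```matlab
% Sketch: chaining adamupdate calls with the positional syntax.
% Requires Deep Learning Toolbox.
rng(0);                                  % reproducible example data
learnable = dlarray(randn(5, 1));
averageGrad   = [];                      % empty state for the first call
averageSqGrad = [];
learnRate = 0.01;                        % plays the role of 'InitialLearnRate'
numIterations = 3;                       % hypothetical

for iteration = 1:numIterations          % double counter, starting at 1
    % Hypothetical gradient; a real loop would get this from dlgradient.
    grad = dlarray(randn(5, 1));
    % Feed every output back in so parameters and state carry forward.
    [learnable, averageGrad, averageSqGrad] = adamupdate(learnable, grad, ...
        averageGrad, averageSqGrad, iteration, learnRate);
end

disp(extractdata(learnable))
```

Reassigning all three outputs on each pass is what makes the optimizer state persist between iterations; nothing is stored inside adamupdate itself.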
Chika on 22 Mar 2025
I am extremely grateful to Joss Knight for pointing out the error and for his advice to look at the documentation for the adamupdate function.


More Answers (1)

Chika on 22 Mar 2025
% Corrected code, following the documentation as advised by Joss Knight
% Test adamupdate function (Built-in)
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
averageGrad = zeros(size(learnable)); % Initialize average gradient
averageSqGrad = zeros(size(learnable)); % Initialize average squared gradient
iteration = 1; % Initial iteration
try
% Perform a single adamupdate
[updatedLearnable, averageGrad, averageSqGrad] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);
% Display results
disp('adamupdate test successful!');
disp('Updated Learnable:');
disp(updatedLearnable);
disp('Average Gradient:');
disp(averageGrad);
disp('Average Squared Gradient:');
disp(averageSqGrad);
catch ME
% Display error message
disp('adamupdate test failed!');
disp(['Error: ', ME.message]);
%Display the stack trace.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.1693 -0.0385 0.0958 -0.0383 0.0295
Average Squared Gradient:
5x1 dlarray 0.0029 0.0001 0.0009 0.0001 0.0001
% Perform a second adam update to test state persistence.
iteration = 2;
try
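% Perform a second adam update, passing in the updated state.
% Note: the original learnable (not updatedLearnable) is passed here, so the
% parameter update is not chained; pass updatedLearnable to continue from step 1.
[updatedLearnable2, averageGrad2, averageSqGrad2] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);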
% Display results
disp('adamupdate second test successful!');
disp('Updated Learnable:');
disp(updatedLearnable2);
disp('Average Gradient:');
disp(averageGrad2);
disp('Average Squared Gradient:');
disp(averageSqGrad2);
catch ME
% Display error message
disp('adamupdate second test failed!');
disp(['Error: ', ME.message]);
%Display the stack trace.
disp('Stack Trace:');
disp(ME.stack);
end
adamupdate second test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.3217 -0.0732 0.1819 -0.0727 0.0560
Average Squared Gradient:
5x1 dlarray 0.0057 0.0003 0.0018 0.0003 0.0002
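In the second test, adamupdate receives the original `learnable` rather than `updatedLearnable`, together with the same `gradient`. The state does persist (the printed `averageGrad` grows from about 0.1*g to 0.19*g, consistent with a gradient decay of 0.9), yet "Updated Learnable" is identical. The plain-MATLAB sketch below shows why, assuming adamupdate applies the textbook Adam step with bias correction and its documented defaults (learning rate 0.001, decays 0.9 and 0.999, epsilon 1e-8): with a constant gradient the bias-corrected moments reduce to g and g.^2, so each step is essentially `learnRate*sign(g)`. The gradient values are hypothetical.

```matlab
% Plain-MATLAB sketch of two Adam steps with bias correction (no toolboxes).
% Assumed hyperparameters: the defaults documented for adamupdate.
learnRate   = 0.001;
gradDecay   = 0.9;
sqGradDecay = 0.999;
epsilon     = 1e-8;

g = [2; -0.5; 1; -0.25; 0.1];   % hypothetical gradient, reused at both steps
m = zeros(size(g));             % moving average of gradients
v = zeros(size(g));             % moving average of squared gradients
steps = zeros(numel(g), 2);     % parameter change produced at t = 1 and t = 2

for t = 1:2
    m = gradDecay*m + (1 - gradDecay)*g;
    v = sqGradDecay*v + (1 - sqGradDecay)*g.^2;
    mHat = m / (1 - gradDecay^t);       % bias-corrected mean, equal to g here
    vHat = v / (1 - sqGradDecay^t);     % bias-corrected square, equal to g.^2 here
    steps(:, t) = learnRate * mHat ./ (sqrt(vHat) + epsilon);
end

disp(steps)   % both columns are essentially learnRate*sign(g)
```

Because both calls start from the same parameters and take the same step, the results match; passing `updatedLearnable` to the second call would move the parameters a second time.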

Categories: Image Data Workflows


Version: R2024b
