Cannot get tracing to work on complex custom deep learning layer

14 views (last 30 days)
I'm trying to get a MATLAB version of https://github.com/jfcrenshaw/pzflow to work because I need something like it buried deep in a MATLAB workflow. My working code is at https://github.com/jeremylea/DLextras/tree/main/mzflow. No matter what I try, I cannot get it to train. I keep getting this error: 'dlgradient' inputs must be traced dlarray objects or cell arrays, structures, or tables containing traced dlarray objects. To enable tracing, use 'dlfeval'. This is despite dlfeval being in the call stack... The main custom layer is quite complex and only uses the learnables tangentially, as inputs to the knot locations in a spline; these then affect the loss through the Jacobian, not through the main output. The code flow is test_flow->Flow.train->dlfeval->loss_fun->dlgradient. It follows the standard training path from the examples and matches custom training setups I have done before that work.
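For reference, the standard pattern I'm comparing against looks roughly like this (a simplified sketch with a placeholder network, data, and loss, not the actual mzflow code):

net = dlnetwork([featureInputLayer(2) fullyConnectedLayer(2)]);   % placeholder network
X = dlarray(rand(2,64), "CB");                                    % placeholder batch
trailingAvg = []; trailingAvgSq = [];
for iteration = 1:10
    [loss, gradients] = dlfeval(@lossFun, net, X);
    [net, trailingAvg, trailingAvgSq] = adamupdate(net, gradients, ...
        trailingAvg, trailingAvgSq, iteration);
end

function [loss, gradients] = lossFun(net, X)
    Y = forward(net, X);                          % traced forward pass
    loss = mean(Y.^2, "all");                     % placeholder scalar loss
    gradients = dlgradient(loss, net.Learnables);
end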
Hours of debugging tell me that the tape is recording all of the operations on the complete set of layers, that the weight matrices are on the tape, and that the weight matrices are in a recording state when they are called in predict. However, somehow the bijector.Learnables value passed to dlgradient is a table of values that are not in a recording state (which they should be within the dlfeval call?). I can't figure out how and when the matrices get switched out of a recording state, or whether I have two copies. Can anyone help?
I've tried replacing the embedded dlnetwork with a custom set of weights and biases - that doesn't help. I've also tried using the state to capture the log-determinant values (which is much cleaner and would be my preferred design). My next option is to figure out how to make this a direct call to Python and abandon the Deep Learning Toolbox... Unfortunately, the pzflow library uses JAX, which is hard to run on Windows, so that probably also means moving the entire flow to Linux.

Accepted Answer

Richard on 19 May 2023
You are passing the bijector network into dlfeval as data that is copied inside the closure of an anonymous function. dlfeval cannot see these pieces of data because they are private to the closure, and thus cannot convert them to inputs in the trace. You need to change the call to dlfeval so that at least the network is an input to dlfeval:
[~,gradients] = dlfeval(@loss_fun,this.bijector,this.latent,Xbat,Cbat);
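For contrast, a call shaped like the one below (my guess at roughly how the current code reads; the exact arguments may differ) hides the network inside the anonymous function's closure, so its learnables never become traced inputs:

[~,gradients] = dlfeval(@(X,C) loss_fun(this.bijector,this.latent,X,C),Xbat,Cbat);   % network hidden in the closure - not traced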
It also looks like your implementations of log_prob, and possibly also sample, in the distributions.Uniform class may cause an issue: from inspection it looks to me as if they do not derive their output in a traceable chain from the input, which will ultimately result in the loss value not being traced.
In log_prob, the problem line is:
log_prob = repmat(-inf,size(mask));
log_prob is not a tracing variable and thus it will not record the mask application on the next line. I think a better implementation is:
function inputs = log_prob(this, inputs)
    mask = all((inputs >= 0) & (inputs <= 1), finddim(inputs,"C"));
    inputs(:) = -inf;
    inputs(mask) = -this.input_dim*log(1.0);
end
I was also a bit suspicious about some of the code in the custom layers' forward methods - it isn't clear that they are all correctly tracing everything they do. For example, the use of extractdata in the NeuralSplineCoupling class is a red flag that often indicates something will be lost from the trace. You may need to write custom backward implementations for some of these cases.
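As a self-contained illustration (not your actual code) of how extractdata drops values off the tape, note that anything derived from the extracted copy cannot carry gradients, while operations on the original dlarray still can:

% called as: [loss,grad] = dlfeval(@lossWithExtract, dlarray(randn(5,1)))
function [loss, grad] = lossWithExtract(x)
    xv = extractdata(x);         % untraced copy: gradients cannot flow through it
    mask = xv > 0;               % fine for logical indexing and control flow
    y = x.^2;                    % traced, because it operates on the dlarray itself
    loss = sum(y(mask));         % still traced
    grad = dlgradient(loss, x);  % works; would error if loss were built from xv
end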
1 comment
Jeremy Lea on 22 May 2023
Thanks for getting back to me so fast on this. The closure was the problem. I knew it would be something simple and obvious. I've pushed a working version up to GitHub. Now I just need to make it work in my actual problem :-). I also need to write a README, and then I'll release it on MATLAB Central.
These normalizing flows are quite confusing - they're a bit like "inverted self-GANs". The network actually runs backwards and produces a uniform distribution, so you never directly use its outputs. To use it, you generate uniform data and feed it back through the inverse functions. The learnables all enter through the Jacobians of the transform, which produce the likelihoods of the inputs, and those likelihoods are what trains the model, like a discriminator in a GAN. The extractdata calls are mostly around handling outliers (caused by numerical issues); the output is the same with or without them, so it's probably best to cut them off from the trace there anyway. A tiny standalone toy of the training idea is sketched below.
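To make that concrete, here is the toy (not the mzflow classes, just a single learnable affine map, and using a normal latent for simplicity instead of the flow's uniform one), showing how the likelihood used for training comes from the latent density plus the log-determinant of the Jacobian:

X = dlarray(randn(1,256), "CB");               % toy training data
s = dlarray(0.5); b = dlarray(0.0);            % learnable scale and shift
for iteration = 1:200
    [loss, gs, gb] = dlfeval(@nllLoss, s, b, X);
    s = s - 0.05*gs;  b = b - 0.05*gb;         % plain gradient descent
end

function [loss, gs, gb] = nllLoss(s, b, X)
    Z = s.*X + b;                                    % forward map: data -> latent
    log_det = log(abs(s));                           % log|det J| of the affine map
    log_prob = -0.5*Z.^2 - 0.5*log(2*pi) + log_det;  % change-of-variables log-likelihood
    loss = -mean(log_prob, "all");                   % negative log-likelihood
    [gs, gb] = dlgradient(loss, s, b);
end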
The examples are pretty cool, especially the last conditional one.


More Answers (0)
