Summary: This post showcases a workaround for optimizing a `tf.keras.Model` with a TensorFlow-based L-BFGS optimizer from TensorFlow Probability. The complete code can be found at my GitHub Gist here.
Update (06/08/2020): I've updated the code on the GitHub Gist to show how to save loss values into a list when using the `@tf.function` decorator. I haven't updated this blog post, however, so the code line numbers here may not match the code on GitHub.
While first-order optimizers such as SGD and Adam nowadays dominate the training of deep neural networks, some people, including me, may want to use second-order methods, such as L-BFGS. The problem is that TensorFlow 2.0 does not have a built-in L-BFGS optimizer.
PyTorch provides L-BFGS, so I guess that using Keras with a PyTorch backend may be a possible workaround. But I don't use the original Keras; I use TensorFlow 2.0 and build Keras models with the `tf.keras` module. That means I'm not able to switch the backend. And it's difficult for me to migrate my code to the original Keras because my code has many customized things that require TensorFlow 2.0.
A workaround is to use the L-BFGS solver from the SciPy library to train a `tf.keras.Model` or its subclasses. We can find example code for this workaround through a Google search. The problem with this workaround, however, is that vanilla SciPy is not GPU-capable, and many people like me use TensorFlow precisely because we need GPU computing. In addition, I personally don't think SciPy is suitable for serious, large-scale calculations; it is for prototyping, not for something supposed to run on HPC clusters. Though that's just my personal opinion.
Fortunately, a TensorFlow-based L-BFGS solver exists in a library called TensorFlow Probability. The API documentation of this solver is here. We can use it through something like `import tensorflow_probability as tfp` and then `result = tfp.optimizer.lbfgs_minimize(...)`. The returned object, `result`, contains several pieces of data, and the final optimized parameters will be in `result.position`. If using a GPU version of TensorFlow, this L-BFGS solver should also run on GPUs.
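As a quick illustration (not from the original post), here is a minimal self-contained example that minimizes a simple quadratic function; training a model follows the same pattern:

```python
import tensorflow as tf
import tensorflow_probability as tfp

# A toy objective: f(x) = sum((x - 3)^2), whose minimum is at x = [3, 3, ..., 3].
# tfp.math.value_and_gradient conveniently returns both the loss and its gradient.
def quadratic(x):
    return tfp.math.value_and_gradient(
        lambda x: tf.reduce_sum((x - 3.0) ** 2), x)

result = tfp.optimizer.lbfgs_minimize(
    quadratic, initial_position=tf.zeros(5), max_iterations=50)

print(result.converged.numpy())  # True if the solver converged
print(result.position.numpy())   # approximately [3. 3. 3. 3. 3.]
```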
Apparently, the solver is not implemented as a subclass of `tf.keras.optimizers.Optimizer`, so we are not able to use it directly with `model.compile(...)` and `model.fit(...)`. The solver is a function, so we need some workaround or wrapper to use it. Let's first see the arguments of this function:
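The most relevant arguments are roughly the following (paraphrased from the TFP API documentation; the defaults and the full argument list may differ between TFP versions):

```python
tfp.optimizer.lbfgs_minimize(
    value_and_gradients_function,  # callable: 1D params -> (loss, 1D gradients)
    initial_position,              # 1D tf.Tensor of initial parameter values
    num_correction_pairs=10,       # history size of the L-BFGS approximation
    tolerance=1e-8,                # gradient-norm stopping tolerance
    max_iterations=50,             # maximum number of L-BFGS iterations
    ...                            # other arguments omitted here
)
```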
The `value_and_gradients_function` argument is a function or a callable object that returns the loss and the gradients with respect to the parameters. This callable should take in the parameters that we want to optimize, which are the model's trainable parameters (i.e., the kernels and biases of the trainable layers). See the first notable thing here? The callable `value_and_gradients_function` takes in model parameters, not training data.
The second thing is that the model parameters fed into `value_and_gradients_function` should be a single 1D `tf.Tensor`. But TensorFlow and Keras store trainable model parameters as a list of multidimensional `tf.Variable` objects; we can easily see this with `print(model.trainable_variables)` (assuming `model` is an instance of `tf.keras.Model` or one of its subclasses). This means we need a way to transform a list of multidimensional `tf.Variable` objects into a single 1D tensor. This can be done with `tf.dynamic_stitch`.
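For instance, here is a minimal sketch (the model is purely illustrative) of flattening a model's trainable variables into one 1D tensor:

```python
import tensorflow as tf

# A small model just for illustration.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation="tanh", input_shape=(3,)),
    tf.keras.layers.Dense(1),
])

# Record which flat indices each variable occupies, then stitch all
# trainable variables into a single 1D tensor.
shapes = [v.shape for v in model.trainable_variables]
n_tensors = len(shapes)

idx = []     # stitch indices of each variable within the 1D tensor
offset = 0
for shape in shapes:
    n = shape.num_elements()
    idx.append(tf.reshape(tf.range(offset, offset + n, dtype=tf.int32), shape))
    offset += n

params_1d = tf.dynamic_stitch(idx, model.trainable_variables)
print(params_1d.shape)  # (total number of trainable parameters,)
```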
We also need a way to convert a 1D `tf.Tensor` back to a list of multidimensional `tf.Tensor`, `tf.Variable`, or `numpy.ndarray` objects so that we can update the parameters of the model. (Basically, under most situations, we can treat a `tf.Variable` like a `tf.Tensor`, and vice versa.) There are two ways to update a model's parameters:
- We can use `model.set_weights(params)` to update the values of the parameters. In this case, `params` is a list of multidimensional `numpy.ndarray`, and we need to convert the aforementioned 1D `tf.Tensor` into `params`. In TensorFlow 2.0, this is not difficult because of the default eager execution behavior: we first partition the 1D `tf.Tensor` with `tf.dynamic_partition` into a list of tensors, convert the list of tensors to a list of `numpy.ndarray` with `tensor.numpy()`, and then reshape each array to the corresponding shape. Finally, we call `model.set_weights(params)`.
- The other way is more computationally efficient and works in graph mode, without eager execution. Each element in `model.trainable_variables` is a `tf.Variable`, which provides a member method, `assign`, to update the values in it. So we first partition the 1D `tf.Tensor` into a list of tensors with `tf.dynamic_partition`. Next, we use a `for` loop to reshape each tensor and assign it to the corresponding `tf.Variable` in `model.trainable_variables`. (A sketch of this approach follows this list.)
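Here is a minimal sketch of the second approach, continuing the flattening sketch above (so `model`, `shapes`, `n_tensors`, and `params_1d` are assumed to already exist):

```python
# "part" marks, for every flat position, which trainable variable it belongs to.
part = []
for i, shape in enumerate(shapes):
    part.extend([i] * shape.num_elements())
part = tf.constant(part)

def assign_new_model_parameters(params_1d):
    """Scatter a 1D parameter tensor back into model.trainable_variables."""
    params = tf.dynamic_partition(params_1d, part, n_tensors)
    for i, (shape, param) in enumerate(zip(shapes, params)):
        model.trainable_variables[i].assign(tf.reshape(param, shape))

# Example: write the current (flattened) parameters back unchanged.
assign_new_model_parameters(params_1d)
```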
The third thing is that, when returning the gradients, the gradients should also be a 1D `tf.Tensor` to match what `value_and_gradients_function` is expected to return. This, again, can be done with `tf.dynamic_stitch`.
In addition, the argument `initial_position` of `tfp.optimizer.lbfgs_minimize` should be the initial parameter values of the model. And, of course, we should use `tf.dynamic_stitch` to convert the initial model parameters to a 1D `tf.Tensor` before passing them to `initial_position`.
In a nutshell, when we create a function for the `value_and_gradients_function` argument, this function should store the following information inside it:

- the `tf.keras.Model` model we want to use,
- the loss function to use,
- the training data on which we want to evaluate the loss,
- the information required by `tf.dynamic_stitch` to convert a list of multidimensional `tf.Tensor` to a 1D `tf.Tensor`, and
- the information required by `tf.dynamic_partition` to convert a 1D `tf.Tensor` to a list of multidimensional `tf.Tensor` or `numpy.ndarray`.
We can define a function factory to create such a function for us. Here’s an example:
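(The full implementation lives in the linked Gist; the following is a condensed sketch of the same idea rather than the Gist verbatim. The argument names `model`, `loss_fn`, `train_x`, and `train_y` are illustrative.)

```python
import tensorflow as tf


def function_factory(model, loss_fn, train_x, train_y):
    """Build a value_and_gradients_function for tfp.optimizer.lbfgs_minimize.

    A condensed sketch: the returned function maps a 1D parameter tensor to
    (loss, 1D gradients) for the given model, loss function, and training data.
    """
    # Bookkeeping for converting between a list of variables and a 1D tensor.
    shapes = [v.shape for v in model.trainable_variables]
    n_tensors = len(shapes)

    idx = []    # stitch indices (for tf.dynamic_stitch)
    part = []   # partition indices (for tf.dynamic_partition)
    offset = 0
    for i, shape in enumerate(shapes):
        n = shape.num_elements()
        idx.append(tf.reshape(tf.range(offset, offset + n, dtype=tf.int32), shape))
        part.extend([i] * n)
        offset += n
    part = tf.constant(part)

    def assign_new_model_parameters(params_1d):
        """Update the model's trainable variables from a 1D tensor."""
        params = tf.dynamic_partition(params_1d, part, n_tensors)
        for i, (shape, param) in enumerate(zip(shapes, params)):
            model.trainable_variables[i].assign(tf.reshape(param, shape))

    @tf.function
    def f(params_1d):
        """The function handed to the L-BFGS solver."""
        with tf.GradientTape() as tape:
            # Update the model, then evaluate the loss on the training data.
            assign_new_model_parameters(params_1d)
            loss = loss_fn(train_y, model(train_x, training=True))
        # Compute gradients and flatten them into a single 1D tensor.
        grads = tape.gradient(loss, model.trainable_variables)
        grads_1d = tf.dynamic_stitch(idx, grads)

        tf.print("loss:", loss)
        return loss, grads_1d

    # Expose helpers so the caller can build initial_position and write the
    # optimized parameters back into the model afterwards.
    f.idx = idx
    f.assign_new_model_parameters = assign_new_model_parameters
    return f
```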
And here's an example of how to use this function factory together with `tfp.optimizer.lbfgs_minimize` to train a `tf.keras.Model` model.
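(Again a condensed sketch rather than the Gist verbatim; the data, model, and `max_iterations` value are purely illustrative.)

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative data: learn y = x1 + 2*x2 from random samples.
x = tf.random.normal((200, 2))
y = tf.reshape(x[:, 0] + 2.0 * x[:, 1], (-1, 1))

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="tanh", input_shape=(2,)),
    tf.keras.layers.Dense(1),
])
loss_fn = tf.keras.losses.MeanSquaredError()

func = function_factory(model, loss_fn, x, y)

# Flatten the initial trainable variables into the required 1D tensor.
init_params = tf.dynamic_stitch(func.idx, model.trainable_variables)

results = tfp.optimizer.lbfgs_minimize(
    value_and_gradients_function=func,
    initial_position=init_params,
    max_iterations=100)

# Write the optimized parameters back into the model before using it.
func.assign_new_model_parameters(results.position)
```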
The complete example code can be found at my GitHub Gist here.
Finally, the example code is just meant to give a sense of how to use the L-BFGS solver from TensorFlow Probability. Using a function factory is not the only option: `value_and_gradients_function` can be any callable object, so we can also wrap the model with a Python class and implement the `__call__` method. It's all up to us.