
Optimize TensorFlow & Keras models with L-BFGS from TensorFlow Probability

Summary: This post showcases a workaround to optimize a tf.keras.Model model with a TensorFlow-based L-BFGS optimizer from TensorFlow Probability. The complete code can be found at my GitHub Gist here.

Update (06/08/2020): I’ve updated the code on GitHub Gist to show how to save loss values into a list when using the @tf.function decorator. But I didn’t update the blog post here, so the code line numbers may not match the code on GitHub.

While SGD, Adam, and similar optimizers nowadays dominate the training of deep neural networks, some people, including me, may want to use second-order (quasi-Newton) methods, such as L-BFGS. The problem is that TensorFlow 2.0 does not provide L-BFGS.

PyTorch provides L-BFGS, so I guessed that using Keras with a PyTorch backend might be a possible workaround. But I don’t use the original Keras. I use TensorFlow 2.0 and build Keras models with the tf.keras module, which means I’m not able to switch the backend. And it’s difficult for me to migrate my code to the original Keras because my code has many customized parts that require TensorFlow 2.0.

A workaround is to use the L-BFGS solver from the SciPy library to train a tf.keras.Model or its subclasses. We can find some example code for this workaround with a Google search. The problem with this workaround, however, is that vanilla SciPy is not GPU-capable, and many people like me use TensorFlow because we need GPU computing. In addition, I personally don’t think SciPy is suited for serious, large-scale calculations. It is for prototyping, not for something supposed to run on HPC clusters. That is just my personal opinion, though.

Fortunately, a TensorFlow-based L-BFGS solver exists in a library called TensorFlow Probability. The API documentation of this solver is here. We can use it through something like import tensorflow_probability as tfp and then result = tfp.optimizer.lbfgs_minimize(...). The returned object, result, contains several fields (for example, whether the solver converged and how many iterations it took), and the final optimized parameters are in result.position. If we use a GPU-enabled build of TensorFlow, this L-BFGS solver should also run on GPUs.

Apparently, the solver is not implemented as a subclass of tf.keras.optimizers.Optimizer, so we are not able to use it directly with model.compile(...) and model.fit(...). The solver is a function, and we need a workaround or a wrapper to use it. Let’s first look at the arguments of this function:

```python
tfp.optimizer.lbfgs_minimize(
    value_and_gradients_function,
    initial_position,
    num_correction_pairs=10,
    tolerance=1e-08,
    x_tolerance=0,
    f_relative_tolerance=0,
    initial_inverse_hessian_estimate=None,
    max_iterations=50,
    parallel_iterations=1,
    stopping_condition=None,
    name=None
)
```
Arguments of tfp.optimizer.lbfgs_minimize

The value_and_gradients_function is a function or a callable object that returns the loss and the gradients with respect to parameters. And this callable object should take in the parameters that we want to optimize, which are a model’s trainable parameters (i.e., the kernels & biases of trainable layers). See the first notable thing here? The callable object value_and_gradients_function takes in model parameters, not training data.

The second thing is that the model parameters fed into value_and_gradients_function should be a 1D tf.Tensor. But TensorFlow and Keras store trainable model parameters as a list of multidimensional tf.Variable. We can easily see this with print(model.trainable_variables) (assuming model is an instance of tf.keras.Model or its subclasses). This means we need to find a way to transform a list of multidimensional tf.Variable to a single 1D tensor. This can be done with tf.dynamic_stitch.
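A small sketch of the flattening step with tf.dynamic_stitch. Two plain tf.Variable objects stand in for one layer's kernel and bias (an assumption for illustration); each index tensor has the same shape as its variable and tells dynamic_stitch where every entry lands in the 1D result.

```python
import numpy as np
import tensorflow as tf

# Stand-ins for a layer's kernel (2x3) and bias (3,); values chosen arbitrarily.
variables = [
    tf.Variable(np.arange(6.0).reshape(2, 3)),
    tf.Variable(np.array([10.0, 11.0, 12.0])),
]

# One index tensor per variable, shaped like that variable, holding the
# positions its entries should occupy in the flat output.
indices = []
offset = 0
for v in variables:
    n = int(np.prod(tuple(v.shape)))
    indices.append(tf.reshape(tf.range(offset, offset + n), tuple(v.shape)))
    offset += n

flat = tf.dynamic_stitch(indices, variables)
print(flat.numpy())  # the kernel entries 0..5 followed by the bias 10, 11, 12
```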

We also need a way to convert a 1D tf.Tensor back to a list of multidimensional tf.Tensor, tf.Variable, or numpy.ndarray objects so that we can update the parameters of the model. (In most situations, we can treat a tf.Variable like a tf.Tensor, and vice versa.) There are two ways to update a model’s parameters:

  1. We can use model.set_weights(params) to update the values of the parameters. In this case, params is a list of multidimensional numpy.ndarray, and we need to convert the aforementioned 1D tf.Tensor to params. In TensorFlow 2.0, this is not difficult because eager execution is the default. We first partition the 1D tf.Tensor into a list of tensors with tf.dynamic_partition, convert the list of tensors to a list of numpy.ndarray with tensor.numpy(), and then reshape each array to the corresponding shape. Finally, we can call model.set_weights(params). Note that model.set_weights expects all of the model’s weights, including non-trainable ones (for example, the moving statistics of batch normalization), so this works as described only when the model has no non-trainable weights.
  2. The other way is more efficient and also works in graph mode (for example, inside a function decorated with @tf.function), not only with eager execution. Each element in model.trainable_variables is a tf.Variable, which provides a member method, assign, to update its values. So we first partition the 1D tf.Tensor into a list of tensors with tf.dynamic_partition. Next, we use a for loop to reshape each tensor and then assign it to the corresponding tf.Variable in model.trainable_variables.
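Both update strategies can be sketched on a tiny one-layer model. The model, its float64 dtype, and the parameter vector are assumptions for illustration; the partition ids play the role of the information that tf.dynamic_partition needs.

```python
import numpy as np
import tensorflow as tf

# A tiny float64 model (assumed for this sketch): kernel (2, 3) + bias (3,).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,), dtype="float64"),
    tf.keras.layers.Dense(3, dtype="float64"),
])
shapes = [tuple(v.shape) for v in model.trainable_variables]

# Partition ids: entry k of the flat vector belongs to variable part[k].
part = []
for i, shape in enumerate(shapes):
    part.extend([i] * int(np.prod(shape)))
part = tf.constant(part)

params_1d = tf.range(9, dtype=tf.float64)  # hypothetical new parameter vector
chunks = tf.dynamic_partition(params_1d, part, len(shapes))

# Way 1: NumPy arrays + model.set_weights (eager only). This model has no
# non-trainable weights, so the list lines up with model.get_weights().
model.set_weights([c.numpy().reshape(s) for c, s in zip(chunks, shapes)])

# Way 2: assign into each variable directly; this also works inside @tf.function.
for var, chunk, shape in zip(model.trainable_variables, chunks, shapes):
    var.assign(tf.reshape(chunk, shape))

print(model.trainable_variables[0].numpy())  # kernel now holds 0..5
print(model.trainable_variables[1].numpy())  # bias now holds 6, 7, 8
```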

The third thing is that the gradients we return should also be a 1D tf.Tensor, matching the 1D parameter tensor that value_and_gradients_function receives. This again can be done with tf.dynamic_stitch.

In addition, the argument initial_position of tfp.optimizer.lbfgs_minimize should hold the initial parameter values of the model. And of course, we should use tf.dynamic_stitch to convert the initial model parameters to a 1D tf.Tensor before passing them to initial_position.
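The sketch below shows both remaining conversions on a toy model: stitching the initial variables into initial_position, and stitching the gradient list returned by tf.GradientTape into a 1D tensor. It also compares the result with a tf.concat of flattened tensors, an alternative that should give the same ordering here because the stitch indices are consecutive ranges in variable order. The model and data are hypothetical.

```python
import numpy as np
import tensorflow as tf

# Toy model and data, assumed only for illustration (17 trainable parameters).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(4, activation="tanh"),
    tf.keras.layers.Dense(1),
])
x = np.random.default_rng(0).normal(size=(8, 2)).astype("float32")
y = np.sum(x**2, axis=1, keepdims=True)

# Stitch indices, built the same way as in the function factory.
idx, count = [], 0
for v in model.trainable_variables:
    n = int(np.prod(tuple(v.shape)))
    idx.append(tf.reshape(tf.range(count, count + n), tuple(v.shape)))
    count += n

# initial_position for lbfgs_minimize: current parameters as one 1D tensor.
init_params = tf.dynamic_stitch(idx, model.trainable_variables)

# Alternative flattening (not what the post uses).
init_params_concat = tf.concat(
    [tf.reshape(v, [-1]) for v in model.trainable_variables], axis=0)

# tape.gradient returns a list shaped like the variables; stitch it the same way.
with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
grads_1d = tf.dynamic_stitch(idx, tape.gradient(loss, model.trainable_variables))
print(init_params.shape, grads_1d.shape)
```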

In a nutshell, when we create a function for the value_and_gradients_function argument, this function should store the following information inside it:

  1. the tf.keras.Model model we want to use,
  2. the loss function to use,
  3. the training data that we want to evaluate the loss,
  4. the information required by tf.dynamic_stitch to convert a list of multidimensional tf.Tensor to a 1D tf.Tensor, and
  5. the information required by tf.dynamic_partition to convert a 1D tf.Tensor to a list of multidimensional tf.Tensor or numpy.ndarray.

We can define a function factory to create such a function for us. Here’s an example:

```python
import numpy
import tensorflow as tf


def function_factory(model, loss, train_x, train_y):
    """A factory to create a function required by tfp.optimizer.lbfgs_minimize.

    Args:
        model [in]: an instance of `tf.keras.Model` or its subclasses.
        loss [in]: a function with signature loss_value = loss(true_y, pred_y),
            i.e., the Keras convention (e.g., tf.keras.losses.MeanSquaredError()).
        train_x [in]: the input part of training data.
        train_y [in]: the output part of training data.

    Returns:
        A function that has a signature of:
            loss_value, gradients = f(model_parameters).
    """

    # obtain the shapes of all trainable parameters in the model
    shapes = tf.shape_n(model.trainable_variables)
    n_tensors = len(shapes)

    # we'll use tf.dynamic_stitch and tf.dynamic_partition later, so we need to
    # prepare required information first
    count = 0
    idx = []  # stitch indices
    part = []  # partition indices

    for i, shape in enumerate(shapes):
        n = int(numpy.prod(shape))  # numpy.product is deprecated/removed in newer NumPy
        idx.append(tf.reshape(tf.range(count, count+n, dtype=tf.int32), shape))
        part.extend([i]*n)
        count += n

    part = tf.constant(part)

    @tf.function
    def assign_new_model_parameters(params_1d):
        """A function updating the model's parameters with a 1D tf.Tensor.

        Args:
            params_1d [in]: a 1D tf.Tensor representing the model's trainable parameters.
        """

        params = tf.dynamic_partition(params_1d, part, n_tensors)
        for i, (shape, param) in enumerate(zip(shapes, params)):
            model.trainable_variables[i].assign(tf.reshape(param, shape))

    # now create a function that will be returned by this factory
    @tf.function
    def f(params_1d):
        """A function that can be used by tfp.optimizer.lbfgs_minimize.

        This function is created by function_factory.

        Args:
            params_1d [in]: a 1D tf.Tensor.

        Returns:
            A scalar loss and the gradients w.r.t. the `params_1d`.
        """

        # use GradientTape so that we can calculate the gradient of loss w.r.t. parameters
        with tf.GradientTape() as tape:
            # update the parameters in the model
            assign_new_model_parameters(params_1d)
            # calculate the loss (Keras losses take the true values first)
            loss_value = loss(train_y, model(train_x, training=True))

        # calculate gradients and convert to 1D tf.Tensor
        grads = tape.gradient(loss_value, model.trainable_variables)
        grads = tf.dynamic_stitch(idx, grads)

        # print out iteration & loss
        f.iter.assign_add(1)
        tf.print("Iter:", f.iter, "loss:", loss_value)

        # store loss value so we can retrieve later
        tf.py_function(f.history.append, inp=[loss_value], Tout=[])

        return loss_value, grads

    # store these information as members so we can use them outside the scope
    f.iter = tf.Variable(0)
    f.idx = idx
    f.part = part
    f.shapes = shapes
    f.assign_new_model_parameters = assign_new_model_parameters
    f.history = []  # filled by the tf.py_function call inside f

    return f
```

And here’s an example of how to use this function factory together with tfp.optimizer.lbfgs_minimize to train a tf.keras.Model.

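```python
import numpy
import tensorflow as tf
import tensorflow_probability as tfp

# use float64 everywhere so the model, the loss, and the L-BFGS state share one dtype
tf.keras.backend.set_floatx("float64")

# training data: a 2D grid (the grid size here is an assumption) with targets x1^2 + x2^2
x_1d = numpy.linspace(-1., 1., 11)
x1, x2 = numpy.meshgrid(x_1d, x_1d)
inps = numpy.stack((x1.flatten(), x2.flatten()), 1)
outs = numpy.reshape(inps[:, 0]**2+inps[:, 1]**2, (x_1d.size**2, 1))

# prepare prediction model, loss function, and the function passed to L-BFGS solver
pred_model = tf.keras.Sequential(
    [tf.keras.Input(shape=[2,]),
     tf.keras.layers.Dense(64, "tanh"),
     tf.keras.layers.Dense(64, "tanh"),
     tf.keras.layers.Dense(1, None)])

loss_fun = tf.keras.losses.MeanSquaredError()
func = function_factory(pred_model, loss_fun, inps, outs)

# convert initial model parameters to a 1D tf.Tensor
init_params = tf.dynamic_stitch(func.idx, pred_model.trainable_variables)

# train the model with L-BFGS solver
results = tfp.optimizer.lbfgs_minimize(
    value_and_gradients_function=func, initial_position=init_params, max_iterations=500)

# after training, the final optimized parameters are still in results.position,
# so we have to manually put them back into the model
params = tf.dynamic_partition(results.position, func.part, len(func.shapes))
for var, shape, param in zip(pred_model.trainable_variables, func.shapes, params):
    var.assign(tf.reshape(param, shape))
```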

The complete example code can be found at my GitHub Gist here.

Finally, the example code is only meant to give a sense of how to use the L-BFGS solver from TensorFlow Probability. Using a function factory is not the only option: value_and_gradients_function can be any callable object, so we can also wrap the model in a Python class that implements the __call__ method. It’s all up to us.

29 Comments

  1. Anonymous

    Thanks for sharing this code – was incredibly useful!!

  2. Anonymous

    Nice buddy, but what about early stopping?

    • The early stopping mechanism in TensorFlow may not work with lbfgs, because the lbfgs function is not a TensorFlow optimizer object. I usually just allow the lbfgs to run up to the max iterations. But I can think of two workarounds:

      1. Check the loss in, for example, the “func” in the blog post. Whenever the true loss reaches a satisfying value, we force “func” to return zero for the loss. And the lbfgs will stop because it thinks the loss is zero already.

      2. Use the “tolerance”, “x_tolerance”, “f_relative_tolerance” arguments of the lbfgs function to stop the optimization early.

      Just some thoughts. I didn’t test these ideas.

  3. Alex

    appreciated for your great work!!!

  4. Anonymous

    Thanks a lot for sharing this code. Do you know if it is possible to create a list of all the loss values?

    • During each epoch/iteration, the L-BFGS optimizer calls the function that returns the loss and gradients several times. So it depends on what kind of loss history you want to put into the list: the loss each time the loss-gradient function is called, or the loss at the end of each epoch/iteration.

      To save the loss each time the function “func” is called, we can add “f.history = []” after line 105 and add “f.history.append(loss_value)” before line 98. Finally, after optimization, retrieve the loss history with “func.history”

      But to save only the loss at the end of each epoch, I haven’t had a good solution to achieve it.

      • Aron

        Hi, Thanks for the code! I’m trying to save the history. Saving at every function call rather than every evaluation is not so much a problem, but doing exactly as written above, adding “f.history = []” and “f.history.append(loss_value)”, isn’t working. It returns this object “[]”, which I think doesn’t have any data in it, but I’m new to TensorFlow, so maybe it does and I just don’t see how to get it out? (This is in TensorFlow 2.2.0)

        • Yeah, you’re right. That doesn’t work because the function `f` is defined with “@tf.function” decorator, which makes it just a graph construction and disables the eager execution. If you’re not familiar with TensorFlow 1.x, then graph construction means the function does not actually calculate the numbers; instead, it just remembers what it needs and how to calculate a value.

          Anyway, in a nutshell, for toy problems, you can just remove the “@tf.function” at line 70, and then “f.history.append” should work.

          For serious runs, removing “@tf.function” will make training slower. We may need to utilize “tf.py_function”. In other words, replace “f.history.append(loss_value)” with “tf.py_function(f.history.append, inp=[loss_value], Tout=[])”

          I’ve updated the code on GitHub: https://gist.github.com/piyueh/712ec7d4540489aad2dcfb80f9a54993

          • Aron

            Thanks for the update, that works!
            However, I think the implementation is incorrect. As you wrote above, for each parameter update the loss function is called several times. This is done in the line search part of BFGS and is meant to only inspect the value of the loss function along the direction of the next step to determine the best length of this step, not to update the current parameters.

            But as I understand this code updates the parameters each time the loss function is called, so it actually takes several steps of all of the lengths tried in a given direction, probably overshooting every time and slowing or preventing convergence.

            If we had access to a counter of the line search we could modify the code to update only when the line search is done, but I don’t see this in the documentation.

            What do you think?
            (sorry may have posted twice accidentally)

        • I’m not sure about this. The loss-gradient function takes in a tensor as parameters, which is provided by the L-BFGS solver from TensorFlow Probability. To me, this implies the L-BFGS solver keeps its own copy of parameters inside the solver. It doesn’t use the parameters in our model instance. That is also why we have to provide an argument “initial_position” to L-BFGS. L-BFGS starts its copy from “initial_position.”

          We update the parameters in our model instance just to calculate the loss and gradients. But whatever parameters we put inside the model instance should not affect the L-BFGS and line-search steps. For example, if the neural network is very basic, say, a one-hidden-layer network, then in our loss-gradient function we don’t even have to update the model parameters or use the model instance. We can just use the tensor provided by L-BFGS to calculate the loss and gradients. This shows that the parameters inside the model instance have nothing to do with the actual parameters used inside L-BFGS.

          So if the line search is not correct, this is probably something in TensorFlow Probability’s L-BFGS code. But I’m just guessing, not so sure.

  5. M Buckner

    Is it possible for the L-BFGS optimizer to train two different models/networks at the same time because they share a custom loss function?

    • By sharing a custom loss function, you mean the loss is calculated by using both networks? Then yes. But you have to find a way to stitch the parameters from the two models into one single 1D tensor. And also you need to figure out how to split a 1D tensor and to distribute the resulting sub-tensors to the parameters in the two models.

  6. OP

    Hi, Thanks for the code! I am wondering if there is a simple way to convert the code to use tf 1.15?

    • Hi, I never used TF 1.15, so I’m not sure how different it is from TF 2.0. But I remember TF 1.15 has some kind of interfaces to SciPy functions. SciPy’s BFGS optimizer has a similar function signature to TensorFlow Probability’s L-BFGS, so maybe try modifying that SciPy interface?

  7. Anonymous

    Thanks for the code! I want to ask, can this code be transformed to the L-BFGS-B? And how?

    • Hi, this code uses TensorFlow Probability’s L-BFGS, and TensorFlow Probability does not seem to have L-BFGS-B. So I guess the answer is no.

      If you don’t mind using SciPy’s L-BFGS-B (I think it’s non-parallel CPU code), it’s very similar to the posted solution. See https://stackoverflow.com/q/59029854/12488965

  8. Anonymous

    Hi, This is really awesome and works very well! Is it possible to process the training data in mini-batches with this method?

    • Hi, I think so, though I never tried it, and you may have to make more workarounds in the code. Also, I doubt L-BFGS’ efficiency when using mini-batches.

      1. From the mathematical aspect, the regular L-BFGS method does not work well with mini-batch training. If you google papers on L-BFGS for mini-batch training, you’ll see this is probably still an ongoing research topic, and people are still developing modified L-BFGS variants for the mini-batch approach. But the L-BFGS in TensorFlow Probability is just the regular version of L-BFGS and does not have those modifications.

      2. From the computational aspect, L-BFGS usually requires a lot of memory. So if a problem is big enough to require a mini-batch approach, then L-BFGS may be too memory-demanding for this problem.

      • Anonymous

        Ah I see thanks for your informative reply! Very helpful and once again this code is really great to use thanks for sharing it!

  9. Nathan Wycoff

    Dude thank you so much for this. It’s so frustrating how TFP’s optim stuff is so different in terms of usage from the vanilla TF….

  10. Thanks very much for sharing this under the MIT license! I recently ran into a use case where a deterministic solver was preferable to a stochastic solver, and wanted to thank you & let you know I’ve adapted parts of this code in the following package: https://github.com/mbhynes/kormos.
    There are a few modifications in the function/gradient evaluation loop to keep it robust even with models with large memory requirements.
    Cheers!
    Mike

  11. John L

    Thank you very much for your code! I have a question, though, about “max_iterations”. It seems that if I put, for example, max_it = 1000, it never stops at exactly 1000 iterations. I have noticed it does almost 3 times max_it, so for that example it would stop after almost 3000 iterations. Also, there are cases where it stops before even reaching that value. Do you know what this could be attributed to? Thanks in advance for any answer.

    • The number of iterations printed on the screen is not the number of L-BFGS iterations. It is the number of loss-function evaluations. Each L-BFGS iteration needs to evaluate the loss function and the gradients several times, so you will see the number printed on the screen higher than the `max_iterations`. On the other hand, if the L-BFGS stops earlier, it means other stopping criteria kick in (for example, when the two consecutive iterations do not give significant changes in the losses).
