
The implicit dtype conversion in the call and __call__ methods of tf.keras.Model

Note: The content in this post was done with the GPU version of TensorFlow 2.0.

I need float64 for my models. However, only recently did I find that even though the dtype of every layer in my model is tf.float64, and even though my x is a tf.float64 tensor, when I call model(x) the model still treats x as a tf.float32 tensor. Here’s a simple example to check this:

import tensorflow as tf

class TestModel(tf.keras.Model):
    def __init__(self):
        super(TestModel, self).__init__()

    def call(self, inputs, training=False):
        print("The type of input: {}".format(inputs.dtype))
        return inputs

And then, in any Python interpreter or script:

x = tf.constant(1.0, dtype=tf.float64)
print("The type of x: {}".format(x.dtype))
model = TestModel()
y = model(x)
print("The type of output: {}".format(y.dtype))

In the output, we should see something like:

The type of x: <dtype: 'float64'>
The type of input: <dtype: 'float32'>
The type of output: <dtype: 'float32'>

To solve this issue, we also have to pass the dtype argument to our model, something like:

class TestModel(tf.keras.Model):
    def __init__(self, dtype=tf.float64):
        super(TestModel, self).__init__(dtype=dtype)
    def call(self, inputs, training=False):
        print("The type of input: {}".format(inputs.dtype))
        return inputs

Now we can see that x, inputs, and y are all `tf.float64`:

The type of x: <dtype: 'float64'>
The type of input: <dtype: 'float64'>
The type of output: <dtype: 'float64'>

This is not a complex problem, but I wasn’t aware of it because the official documentation does not mention it, at least not as of this writing. And I didn’t get a warning or an error message at runtime, either.

The underlying reason is that tf.keras.Model inherits from tf.keras.layers.Layer, whose __call__ method invokes a casting method, _maybe_cast_inputs, which casts the inputs to the model’s dtype. So we must set the model’s dtype to the one we desire. Otherwise, no matter how we set up our data or layers, the model will always use the default, tf.float32.
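The casting step can be sketched roughly like this (a hypothetical simplification for illustration, not the actual TensorFlow source; the helper name maybe_cast_inputs_sketch is my own):

```python
import tensorflow as tf

def maybe_cast_inputs_sketch(inputs, layer_dtype):
    # Hypothetical simplification of what Layer.__call__ does internally:
    # floating-point inputs that don't match the layer's dtype are cast
    # to it before call() ever sees them.
    if inputs.dtype.is_floating and inputs.dtype != layer_dtype:
        return tf.cast(inputs, layer_dtype)
    return inputs

x = tf.constant(1.0, dtype=tf.float64)
print(maybe_cast_inputs_sketch(x, tf.float32).dtype)  # <dtype: 'float32'>
```

This is why the float64-ness of x is silently lost: the downcast happens in __call__, before our own call method runs.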

Another workaround is to force TensorFlow to use tf.float64 as the default floating-point type. We can do this through tf.keras.backend.set_floatx('float64').
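For example, calling set_floatx before building the model makes the default dtype float64, so the original TestModel (without any dtype argument) no longer downcasts its input:

```python
import tensorflow as tf

# Make float64 the default Keras floating-point type before
# constructing any models or layers.
tf.keras.backend.set_floatx('float64')

class TestModel(tf.keras.Model):
    def call(self, inputs, training=False):
        return inputs

x = tf.constant(1.0, dtype=tf.float64)
model = TestModel()
y = model(x)
print(y.dtype)  # <dtype: 'float64'>, with no dtype argument needed
```

Note that set_floatx only affects models and layers created after it is called, so it belongs at the top of the script.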
