Note: The examples in this post were run with the GPU version of TensorFlow 2.0.
I need `float64` for my models. However, I only recently found out that even though the `dtype` of every layer in my models is `tf.float64` and my input `x` is a `tf.float64` tensor, calling `model(x)` still makes the model treat `x` as a `tf.float32` tensor. Here's a simple example to check this:
```python
import tensorflow as tf

class TestModel(tf.keras.Model):
    def __init__(self):
        super(TestModel, self).__init__()

    def call(self, inputs, training=False):
        # Report the dtype the model actually receives inside call()
        print("The type of input: {}".format(inputs.dtype))
        return inputs
```
Then, in any Python interpreter or script, run:
```python
x = tf.constant(1.0, dtype=tf.float64)
print("The type of x: {}".format(x.dtype))

model = TestModel()
y = model(x)
print("The type of output: {}".format(y.dtype))
```
The output should look something like this:
```
The type of x: <dtype: 'float64'>
The type of input: <dtype: 'float32'>
The type of output: <dtype: 'float32'>
```
To fix this, we also have to pass the `dtype` argument to the model itself, for example:
```python
class TestModel(tf.keras.Model):
    def __init__(self, dtype=tf.float64):
        # Forward dtype to the base class so the model no longer defaults to float32
        super(TestModel, self).__init__(dtype=dtype)

    def call(self, inputs, training=False):
        print("The type of input: {}".format(inputs.dtype))
        return inputs
```
Now `x`, `inputs`, and `y` are all `tf.float64`:
```
The type of x: <dtype: 'float64'>
The type of input: <dtype: 'float64'>
The type of output: <dtype: 'float64'>
```
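The `TestModel` example passes its input straight through, so it only shows the dtype label changing. The sketch below hints at why that matters in a real model. It assumes a hypothetical `DenseModel` wrapping a single `Dense` layer created with `dtype=tf.float64`; the class name, the all-ones kernel, and the `1.0 + 1e-10` input are made up for illustration. Even though the layer itself is `float64`, the model's own cast to `float32` can throw away precision before the layer ever sees the data.

```python
import tensorflow as tf

class DenseModel(tf.keras.Model):
    """Hypothetical model wrapping one float64 Dense layer."""

    def __init__(self, dtype=None):
        # dtype=None leaves the model at the default dtype (float32)
        super(DenseModel, self).__init__(dtype=dtype)
        # A 1x1 all-ones kernel with no bias makes this layer an identity map
        self.dense = tf.keras.layers.Dense(
            1, dtype=tf.float64, use_bias=False, kernel_initializer="ones")

    def call(self, inputs, training=False):
        return self.dense(inputs)

x = tf.constant([[1.0 + 1e-10]], dtype=tf.float64)

# Default model dtype: x is cast to float32 first (the 1e-10 is rounded away),
# then the float64 Dense layer casts the rounded value back to float64
y_default = DenseModel()(x)

# Model dtype set to float64: no intermediate cast, so the value survives
y_fixed = DenseModel(dtype=tf.float64)(x)

print(y_default.dtype, y_default.numpy())
print(y_fixed.dtype, y_fixed.numpy())
```

Note that `y_default` still reports `float64`, because the last layer to touch it was the `float64` `Dense` layer, so checking only the output dtype would not reveal the lost precision.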
This is not a complex problem, but I wasn't aware of it because the official documentation does not mention it, at least not at the time of writing. I didn't get a warning or an error message at runtime, either.
The underlying reason is that the `__call__` method of `tf.keras.layers.Layer`, a parent class of `tf.keras.Model`, calls a casting method, `_maybe_cast_input`, which casts the inputs to the model's `dtype`. So we must set the model's `dtype` to the one we want. Otherwise, no matter how we set up our data or layers, the model falls back to its default `dtype`, `tf.float32`, and casts the inputs to it.
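A quick way to see that default is to inspect the model's public `dtype` attribute alongside `tf.keras.backend.floatx()`, the global default floating-point type. The sketch below assumes a trimmed-down pass-through `TestModel` with no `__init__` override and no `print` call, so that a `dtype` keyword goes straight to `tf.keras.Model`; it only looks at public attributes and outputs, not at the private casting method itself.

```python
import tensorflow as tf

class TestModel(tf.keras.Model):
    # No __init__ override: keyword arguments such as dtype go to tf.keras.Model
    def call(self, inputs, training=False):
        return inputs

default_model = TestModel()
float64_model = TestModel(dtype=tf.float64)

# Without a dtype argument, the model's dtype is the global floatx setting
print(tf.keras.backend.floatx(), default_model.dtype)
print(float64_model.dtype)

x = tf.constant(1.0, dtype=tf.float64)
# The cast happens in __call__ before call() runs, so outputs follow the model's dtype
print(default_model(x).dtype, float64_model(x).dtype)
```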
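Another workaround is to make `tf.float64` TensorFlow's default floating-point type with `tf.keras.backend.set_floatx('float64')`. This call has to happen before the model and its layers are created, since each layer picks up the default when it is constructed.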
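A minimal sketch of that workaround, using a pass-through `TestModel` redefined without `__init__` so the block runs on its own:

```python
import tensorflow as tf

# Change the global default before building any layers or models
tf.keras.backend.set_floatx("float64")

class TestModel(tf.keras.Model):
    def call(self, inputs, training=False):
        return inputs

model = TestModel()  # no dtype argument needed now
x = tf.constant(1.0, dtype=tf.float64)
y = model(x)
print(model.dtype, y.dtype)
```

Because `set_floatx` changes a process-wide default, it affects every Keras layer created afterwards, which may or may not be what you want; passing `dtype` explicitly keeps the change local to one model.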