Disclaimer: I use the Keras interface from TensorFlow 2.0.0-rc1, so I’m not sure if the content applies to the original Keras library or other versions of TensorFlow. And note that TensorFlow 2.0.0-rc1 uses eager-execution mode by default.
Let’s say we have a metric object m = tf.keras.metrics.Mean(). There are two ways to add a number, for example 10, to the metric m: the first is to call m.update_state(10), and the other is to simply call m(10). When we want to know the current mean value, we use m.result().
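A minimal sketch of the two ways side by side, assuming a TensorFlow 2.x installation running in eager mode. The variable names `returned` and `current` are only for illustration.

```python
import tensorflow as tf

# One metric object that receives values in both ways.
m = tf.keras.metrics.Mean()

m.update_state(10)   # first way: accumulate the value only
returned = m(20)     # second way: accumulate, then get the current mean back

# Ask for the mean explicitly; both values above were recorded.
current = float(m.result())
print(float(returned), current)  # 15.0 15.0
```

Both calls feed the same running mean, so after adding 10 and 20 the result is 15 whichever way the values went in.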
I was wondering what the difference is between m(10) and m.update_state(10), as I couldn’t find an answer in the documentation of tf.keras.metrics.Mean (see here). The special method of a Python class that allows us to write something like m(10) is __call__(). But that documentation doesn’t mention a __call__() method. And why did I care? Because the official tutorial here uses the second way to add a value to metrics without explanation.
If we use an interactive environment, like an IPython console or a Jupyter Notebook, we can find some information regarding the two methods through help(tf.keras.metrics.Mean):
tf.keras.metrics.Mean.update_state:
Accumulates statistics for computing the reduction metric. For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, then the value of `result()` is 4. If the `sample_weight` is specified as [1, 1, 0, 0] then value of `result()` would be 2.
tf.keras.metrics.Mean.__call__:
Accumulates statistics and then computes metric result value.
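The numbers in the update_state docstring can be checked by hand. The sketch below uses a hypothetical helper, `weighted_mean`, written in plain Python. It is not TensorFlow's implementation, only the arithmetic the docstring describes: the weighted sum of the values divided by the sum of the weights.

```python
def weighted_mean(values, sample_weight=None):
    """Return sum(v * w) / sum(w); every weight defaults to 1."""
    if sample_weight is None:
        sample_weight = [1] * len(values)
    total = sum(v * w for v, w in zip(values, sample_weight))
    count = sum(sample_weight)
    return total / count


values = [1, 3, 5, 7]
unweighted = weighted_mean(values)              # (1 + 3 + 5 + 7) / 4
weighted = weighted_mean(values, [1, 1, 0, 0])  # (1 + 3) / 2
print(unweighted, weighted)  # 4.0 2.0
```

A weight of 0 removes a value from both the numerator and the denominator, which is why the last two entries stop counting.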
So apparently, update_state only accumulates the statistics and does not do the calculation (that is, averaging the numbers in our example), while __call__ accumulates the statistics and also does the calculation. From this point of view, I guess calling __call__ (the m(10) in our example) is less efficient when the returned value is not used. After all, we don’t really need to know the current mean value every time we add a number.
We can actually see this from the source code. As indicated by help(tf.keras.metrics.Mean), the method __call__ is inherited from tf.keras.metrics.Metric. Its source code can be viewed here:
def __call__(self, *args, **kwargs):
    """Accumulates statistics and then computes metric result value.

    Args:
      *args:
      **kwargs: A mini-batch of inputs to the Metric,
        passed on to `update_state()`.

    Returns:
      The metric value tensor.
    """

    def replica_local_fn(*args, **kwargs):
        """Updates the state of the metric in a replica-local context."""
        update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
        with ops.control_dependencies([update_op]):
            result_t = self.result()  # pylint: disable=not-callable

            # We are adding the metric object as metadata on the result tensor.
            # This is required when we want to use a metric with `add_metric` API on
            # a Model/Layer in graph mode. This metric instance will later be used
            # to reset variable state after each epoch of training.
            # Example:
            #   model = Model()
            #   mean = Mean()
            #   model.add_metric(mean(values), name='mean')
            result_t._metric_obj = self  # pylint: disable=protected-access
            return result_t

    from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
    return distributed_training_utils.call_replica_local_fn(
        replica_local_fn, *args, **kwargs)
Ignoring the docstrings and comments, we can see that __call__ simply calls update_state(...) and then result(), and returns that result.
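To make the pattern concrete without the TensorFlow machinery, here is a pure-Python stand-in. The class `RunningMean` and its `result_calls` counter are hypothetical, written only for this illustration; the real Metric class works on tensors and handles distribution strategies. The one thing it copies from the source above is the order of operations inside __call__.

```python
class RunningMean:
    """Pure-Python stand-in for a mean metric; not TensorFlow's code."""

    def __init__(self):
        self.total = 0.0
        self.count = 0
        self.result_calls = 0  # how many times the mean was computed

    def update_state(self, value):
        # Accumulate only; no averaging happens here.
        self.total += value
        self.count += 1

    def result(self):
        self.result_calls += 1
        return self.total / self.count if self.count else 0.0

    def __call__(self, value):
        # Same order as in the quoted source: update first, then compute.
        self.update_state(value)
        return self.result()


a, b = RunningMean(), RunningMean()
for x in [10, 20, 30]:
    a.update_state(x)  # accumulate only
    b(x)               # accumulate and compute the mean each time

mean_a, mean_b = a.result(), b.result()
print(mean_a, mean_b)                  # 20.0 20.0
print(a.result_calls, b.result_calls)  # 1 4
```

Both objects end with the same mean; the only difference is how many times the average was computed along the way.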