Disclaimer: I use the Keras interface from TensorFlow 2.0.0-rc1, so I’m not sure if the content applies to the original Keras library or other versions of TensorFlow. And note that TensorFlow 2.0.0-rc1 uses eager-execution mode by default.
Let’s say we have a metric object m = tf.keras.metrics.Mean(). There are two ways to add a number, for example 10, to the metric m: the first way is to call m.update_state(10), and the other way is to simply call m(10). When we want to know the current mean value, we then use m.result().
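To make the two entry points concrete, here is a minimal pure-Python stand-in for the Mean metric. This is my own sketch of the interface, not TensorFlow’s implementation:

```python
class Mean:
    """A toy Mean metric that mimics the tf.keras.metrics.Mean interface."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        # Accumulate statistics only; no result is computed here.
        self.total += value
        self.count += 1

    def result(self):
        # Compute the metric from the accumulated statistics.
        return self.total / self.count if self.count else 0.0

    def __call__(self, value):
        # Update the state and immediately compute the result.
        self.update_state(value)
        return self.result()

m = Mean()
m.update_state(10)   # first way: accumulate only
print(m(20))         # second way: accumulate and compute -> 15.0
print(m.result())    # current mean -> 15.0
```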
I was wondering what’s the difference between using m(10) and m.update_state(10), as I couldn’t find information in the documentation of tf.keras.metrics.Mean (see here). Calling m(10) relies on __call__(), the special member function of a Python class that allows us to use an instance as if it were a function. But in that documentation, there’s no mention of a __call__() method. And why did I care? Because in the official tutorial here, the tutorial uses the second way to add a value to metrics without explanation.
If we use an interactive environment, like an IPython console or a Jupyter Notebook, we can find some information regarding the two methods through their docstrings. The docstring of update_state says:

Accumulates statistics for computing the reduction metric. For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, then the value of `result()` is 4. If the `sample_weight` is specified as [1, 1, 0, 0] then value of `result()` would be 2.

And the docstring of __call__ says:

Accumulates statistics and then computes metric result value.
So update_state simply updates the sequence of numbers but does not do the calculation (that is, averaging the numbers in our example), while __call__ updates the sequence and also does the calculation. From this point of view, I guess calling __call__ (or m(10) in our example) is less efficient. After all, we don’t really need to know the current mean value every time we add a number to the sequence.
We can actually see this from the source code. As indicated by help(tf.keras.metrics.Mean), the method __call__ is inherited from tf.keras.metrics.Metric. Its source code can be viewed here:
def __call__(self, *args, **kwargs):
  """Accumulates statistics and then computes metric result value.

  Args:
    *args:
    **kwargs: A mini-batch of inputs to the Metric,
      passed on to `update_state()`.

  Returns:
    The metric value tensor.
  """

  def replica_local_fn(*args, **kwargs):
    """Updates the state of the metric in a replica-local context."""
    update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
    with ops.control_dependencies([update_op]):
      result_t = self.result()  # pylint: disable=not-callable

      # We are adding the metric object as metadata on the result tensor.
      # This is required when we want to use a metric with `add_metric` API on
      # a Model/Layer in graph mode. This metric instance will later be used
      # to reset variable state after each epoch of training.
      # Example:
      #   model = Model()
      #   mean = Mean()
      #   model.add_metric(mean(values), name='mean')
      result_t._metric_obj = self  # pylint: disable=protected-access
      return result_t

  from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
  return distributed_training_utils.call_replica_local_fn(
      replica_local_fn, *args, **kwargs)
Removing the docstrings and the distribution-strategy plumbing, we can see that __call__ is actually calling update_state(...) and then result().
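Stripped down to that control flow, __call__ reduces to roughly the following. This is a simplified paraphrase, not the real TensorFlow code, and the Total subclass is made up here just to exercise it:

```python
class Metric:
    """Simplified paraphrase of tf.keras.metrics.Metric.__call__."""

    def __call__(self, *args, **kwargs):
        self.update_state(*args, **kwargs)  # accumulate the new values
        return self.result()                # then compute the metric value

class Total(Metric):
    """A made-up subclass just to demonstrate the control flow."""

    def __init__(self):
        self.value = 0.0

    def update_state(self, x):
        self.value += x

    def result(self):
        return self.value

t = Total()
t(3)         # updates state, then computes and returns the result
print(t(4))  # prints 7.0
```

So every call pays for a result() computation, which is why using update_state in a loop and calling result() once at the end avoids redundant work.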