
Different calling methods of the metric objects in Keras

Disclaimer: I use the Keras interface from TensorFlow 2.0.0-rc1, so I’m not sure if the content applies to the original Keras library or other versions of TensorFlow. And note that TensorFlow 2.0.0-rc1 uses eager-execution mode by default.

Let’s say we have a metric object m=tf.keras.metrics.Mean(). There are two ways to add a number, for example, 10, to the metric m: the first way is to call m.update_state(10), and the other way is to simply call m(10). When we want to know the current mean value, we then use m.result().
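To make the two calling methods concrete, here is a small sketch. It assumes a TensorFlow 2.x install with eager execution, which is the default. It feeds the same numbers to two separate Mean objects, one for each style:

```python
import tensorflow as tf

# Two separate metric objects so the two calling styles don't share state.
m1 = tf.keras.metrics.Mean()
m2 = tf.keras.metrics.Mean()

# Style 1: update_state() only accumulates; read the value later with result().
m1.update_state(10)
m1.update_state(20)

# Style 2: calling the object accumulates AND returns the current mean.
first = m2(10)   # mean after one value
second = m2(20)  # mean after two values

print(float(m1.result()))           # 15.0
print(float(first), float(second))  # 10.0 15.0
```

Both objects end up with the same accumulated state. The only difference is that m2(...) also hands back the current mean on every call.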

I was wondering what the difference is between m(10) and m.update_state(10), because I couldn’t find the answer in the documentation of tf.keras.metrics.Mean (see here). The special method of a Python class that lets us write something like m(10) is __call__(), but that documentation doesn’t mention a __call__() method at all. And why did I care? Because the official tutorial here uses the second way to add values to metrics, without any explanation.

If we use an interactive environment, like an IPython console or a Jupyter Notebook, we can find some information regarding the two methods through help(tf.keras.metrics.Mean):

  • tf.keras.metrics.Mean.update_state:

    Accumulates statistics for computing the reduction metric. For example, if `values` is [1, 3, 5, 7] and reduction=SUM_OVER_BATCH_SIZE, then the value of `result()` is 4. If the `sample_weight` is specified as [1, 1, 0, 0] then value of `result()` would be 2.

  • tf.keras.metrics.Mean.__call__:

    Accumulates statistics and then computes metric result value.
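To double-check the arithmetic in the update_state docstring, here is a plain-Python sketch of a weighted mean, sum(w * v) / sum(w). It does not use TensorFlow. The function name weighted_mean is hypothetical, and returning 0.0 when the total weight is zero is an assumption of this sketch:

```python
def weighted_mean(values, sample_weight=None):
    """Plain-Python weighted mean: sum(w * v) / sum(w)."""
    if sample_weight is None:
        # No weights given: every value counts once (ordinary mean).
        sample_weight = [1.0] * len(values)
    total = sum(w * v for v, w in zip(values, sample_weight))
    count = sum(sample_weight)
    # Assumption of this sketch: report 0.0 instead of dividing by zero.
    return total / count if count else 0.0


print(weighted_mean([1, 3, 5, 7]))                # 4.0
print(weighted_mean([1, 3, 5, 7], [1, 1, 0, 0]))  # 2.0
```

With all weights equal to 1 this is the ordinary mean (16 / 4 = 4). With weights [1, 1, 0, 0], only 1 and 3 count, giving (1 + 3) / 2 = 2.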

So apparently, update_state only accumulates the statistics (for Mean, a running total and a count) but does not do the calculation (that is, averaging the numbers in our example), while __call__ accumulates the statistics and also does the calculation. From this point of view, I guess calling __call__ (m(10) in our example) does some unnecessary work whenever we ignore its return value. After all, we don’t really need to know the current mean every time we add a number.
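If we only need the mean at the end (say, once per epoch), one pattern that follows from this is to call update_state inside the loop and result() once afterwards. A minimal sketch, assuming TensorFlow 2.x; the loss values are made up for illustration:

```python
import tensorflow as tf

losses = [0.9, 0.7, 0.5, 0.3]  # hypothetical per-batch loss values

mean_loss = tf.keras.metrics.Mean()
for loss in losses:
    # Only accumulate the running total and count; no mean is computed here.
    mean_loss.update_state(loss)

# Compute the mean once, when it is actually needed.
print(float(mean_loss.result()))
```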

We can actually see this in the source code. As indicated by help(tf.keras.metrics.Mean), the __call__ method is inherited from tf.keras.metrics.Metric. Its source code can be viewed here:

  def __call__(self, *args, **kwargs):
    """Accumulates statistics and then computes metric result value.
    Args:
      *args:
      **kwargs: A mini-batch of inputs to the Metric,
        passed on to `update_state()`.
    Returns:
      The metric value tensor.
    """

    def replica_local_fn(*args, **kwargs):
      """Updates the state of the metric in a replica-local context."""
      update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
      with ops.control_dependencies([update_op]):
        result_t = self.result()  # pylint: disable=not-callable

        # We are adding the metric object as metadata on the result tensor.
        # This is required when we want to use a metric with `add_metric` API on
        # a Model/Layer in graph mode. This metric instance will later be used
        # to reset variable state after each epoch of training.
        # Example:
        #   model = Model()
        #   mean = Mean()
        #   model.add_metric(mean(values), name='mean')
        result_t._metric_obj = self  # pylint: disable=protected-access
        return result_t

    from tensorflow.python.keras.distribute import distributed_training_utils  # pylint:disable=g-import-not-at-top
    return distributed_training_utils.call_replica_local_fn(
        replica_local_fn, *args, **kwargs)

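Ignoring the docstring, the comments, and the distribution-strategy wrapper (call_replica_local_fn), we can see that __call__ simply calls update_state(...) and then result(), and returns the result. In graph mode, the control_dependencies block makes sure result() runs only after the update.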
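To see the same structure without TensorFlow's distribution and graph-mode machinery, here is a hypothetical toy class, ToyMean. It is not part of TensorFlow and is not its implementation. Its __call__ does what the stripped-down Metric.__call__ does: update_state(...) followed by result(). It also counts how often the mean is computed, which makes the efficiency point from above visible:

```python
class ToyMean:
    """Hypothetical, simplified stand-in for tf.keras.metrics.Mean.

    No TF variables, graphs, or distribution strategies: just the
    update_state / result / __call__ pattern on plain Python floats.
    """

    def __init__(self):
        self.total = 0.0
        self.count = 0.0
        self.result_calls = 0  # how many times the mean was computed

    def update_state(self, value, sample_weight=1.0):
        # Accumulate the running statistics only.
        self.total += value * sample_weight
        self.count += sample_weight

    def result(self):
        # Compute the mean from the accumulated statistics.
        self.result_calls += 1
        return self.total / self.count if self.count else 0.0

    def __call__(self, *args, **kwargs):
        # Same order as Metric.__call__: update first, then compute the result.
        self.update_state(*args, **kwargs)
        return self.result()


m = ToyMean()
for v in [10, 20, 30]:
    m(v)  # computes the mean on every call
print(m.result_calls)  # 3

m2 = ToyMean()
for v in [10, 20, 30]:
    m2.update_state(v)  # accumulates only
mean = m2.result()
print(mean, m2.result_calls)  # 20.0 1
```

Both objects hold the same total and count. The only difference is that the first one computed the mean three times instead of once.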
