
tf.keras.utils.to_categorical raises TypeError in graph mode

Stack Overflow Asked by Borun Chowdhury on November 5, 2020

I have a sparse categorical tensor that I want to convert into a one-hot encoded representation. I can get it to work in eager mode but not in graph mode. I do not understand what is going on.

Specifically, let's say I have a tensor of sparse class labels

y=tf.constant(np.random.choice([0,1,2],2).reshape(-1,1))

that I want to convert into a one-hot encoded representation. I defined the function

def tmp(y):
    return tf.keras.utils.to_categorical(y)

and it works as expected in eager mode. However, if I wrap the function with tf.function

@tf.function
def tmp(y):
    return tf.keras.utils.to_categorical(y)

then I get the exception

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-23-bb4411a7c6f7> in <module>
----> 1 tmp(y)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    625       # This is the first call of __call__, so we have to initialize.
    626       initializers = []
--> 627       self._initialize(args, kwds, add_initializers_to=initializers)
    628     finally:
    629       # At this point we know that the initialization is complete (or less

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    504     self._concrete_stateful_fn = (
    505         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 506             *args, **kwds))
    507 
    508     def invalid_creator_scope(*unused_args, **unused_kwds):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   2444       args, kwargs = None, None
   2445     with self._lock:
-> 2446       graph_function, _, _ = self._maybe_define_function(args, kwargs)
   2447     return graph_function
   2448 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2775 
   2776       self._function_cache.missed.add(call_context_key)
-> 2777       graph_function = self._create_graph_function(args, kwargs)
   2778       self._function_cache.primary[cache_key] = graph_function
   2779       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2665             arg_names=arg_names,
   2666             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2667             capture_by_value=self._capture_by_value),
   2668         self._function_attributes,
   2669         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    439         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    440         # the function a weak reference to itself to avoid a reference cycle.
--> 441         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    442     weak_wrapped_fn = weakref.ref(wrapped_fn)
    443 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

TypeError: in user code:

    <ipython-input-11-4e328cd877a4>:3 tmp  *
        return tf.keras.utils.to_categorical(y)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/np_utils.py:49 to_categorical  **
        y = np.array(y, dtype='int')

    TypeError: __array__() takes 1 positional argument but 2 were given

On the face of it, this looks like a NumPy error, but I verified that my NumPy installation accepts two arguments here, so what is going wrong?

More generally, what is the right way to convert a sparse categorical tensor to a one-hot encoded one in graph mode?

One Answer

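tf.keras.utils.to_categorical is a NumPy utility: as the last frame of the traceback shows, it calls np.array(y, dtype='int'), which needs concrete values. Inside a tf.function, y is a symbolic graph tensor without concrete values, so that conversion fails. As a workaround, you can use tf.one_hot instead. It is a TensorFlow op, so it works as expected in both eager and graph mode.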

The complete code is shown below.

import numpy as np
import tensorflow as tf

print(tf.__version__)
y = tf.constant(np.random.choice([0, 1, 2], [2, 2]))
print(y)

Output:

2.3.0
tf.Tensor(
[[1 0]
 [2 1]], shape=(2, 2), dtype=int64)

Eager Mode with tf.keras.utils.to_categorical

def tmp(y):
    return tf.keras.utils.to_categorical(y)
tmp(y)

Output:

array([[[0., 1., 0.],
        [1., 0., 0.]],

       [[0., 0., 1.],
        [0., 1., 0.]]], dtype=float32)
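Eager mode works because y is an EagerTensor holding concrete values that NumPy can read. The sketch below (record_mode and trace_log are illustrative names, not part of any API) records tf.executing_eagerly() when the same Python function is called directly and when it is traced by tf.function:

```python
import numpy as np
import tensorflow as tf

trace_log = []  # illustrative: collects the execution mode seen by each call


def record_mode(y):
    # Runs as ordinary Python. Under tf.function it only runs while tracing,
    # when y is a symbolic graph tensor with no concrete values.
    trace_log.append(tf.executing_eagerly())
    return y


y = tf.constant(np.random.choice([0, 1, 2], [2, 2]))

record_mode(y)               # plain call: eager mode
tf.function(record_mode)(y)  # traced call: graph mode
print(trace_log)             # [True, False]
```

This is why to_categorical, which calls np.array on its input, has nothing concrete to convert in graph mode, while tf.one_hot simply adds an op to the graph.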

Eager Mode with tf.one_hot

def tmp(y):
    return tf.one_hot(y,3)
tmp(y)

Output:

<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[0., 1., 0.],
        [1., 0., 0.]],

       [[0., 0., 1.],
        [0., 1., 0.]]], dtype=float32)>
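To check that the two approaches agree, a quick comparison on the same input might look like this (assuming three classes, as in the example above):

```python
import numpy as np
import tensorflow as tf

# Same kind of input as above: a 2x2 tensor of class ids in {0, 1, 2}.
y = tf.constant(np.random.choice([0, 1, 2], [2, 2]))

from_keras = tf.keras.utils.to_categorical(y.numpy(), num_classes=3)  # NumPy array
from_tf = tf.one_hot(y, 3)                                             # tf.Tensor

# Both are float32 one-hot encodings of shape (2, 2, 3) with identical values.
np.testing.assert_array_equal(from_tf.numpy(), from_keras)
print(from_keras.shape, from_tf.shape)
```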

Graph Mode with tf.one_hot

@tf.function
def tmp(y):
    return tf.one_hot(y,3)
tmp(y)

Output:

<tf.Tensor: shape=(2, 2, 3), dtype=float32, numpy=
array([[[0., 1., 0.],
        [1., 0., 0.]],

       [[0., 0., 1.],
        [0., 1., 0.]]], dtype=float32)>
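One difference matters for the input in the question: there y has shape (2, 1), and to_categorical drops a trailing dimension of size 1 (returning shape (2, 3)), whereas tf.one_hot(y, 3) would return shape (2, 1, 3). It also infers the number of classes from the data when none is given. A minimal sketch of a graph-compatible helper that mimics both behaviours follows; one_hot_like_to_categorical is a hypothetical name, not a TensorFlow or Keras API:

```python
import numpy as np
import tensorflow as tf


@tf.function
def one_hot_like_to_categorical(y, num_classes=None):
    # Hypothetical helper (not a TensorFlow API): approximates
    # tf.keras.utils.to_categorical using only TensorFlow ops.
    y = tf.cast(y, tf.int32)
    # Mimic to_categorical: drop a trailing dimension of size 1, (n, 1) -> (n,).
    if y.shape.rank is not None and y.shape.rank > 1 and y.shape[-1] == 1:
        y = tf.squeeze(y, axis=-1)
    if num_classes is None:
        # Infer the number of classes from the data (max label + 1).
        num_classes = tf.reduce_max(y) + 1
    return tf.one_hot(y, depth=num_classes)


out = one_hot_like_to_categorical(tf.constant([[0], [2]]))
print(out.shape)  # (2, 3)
```

Passing num_classes explicitly is usually preferable in graph mode, since an inferred depth leaves the last dimension of the output's static shape unknown.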

Answered by TFer2 on November 5, 2020
