Use XLA with tf.function


This tutorial trains a TensorFlow model to classify the MNIST dataset, where the training function is compiled using XLA.

First, import TensorFlow. Eager execution is enabled by default in TensorFlow 2.

import tensorflow as tf

Then define some necessary constants and prepare the MNIST dataset.

# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct digit labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000

# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()

# Casting from raw data to the required datatypes.
def cast(images, labels):
  images = tf.cast(
      tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
  labels = tf.cast(labels, tf.int64)
  return (images, labels)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
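
The -1 in tf.reshape(images, [-1, IMAGE_SIZE]) asks TensorFlow to infer that dimension from the total number of elements. The arithmetic can be sketched in plain Python. Here infer_reshape is a hypothetical helper written only for illustration; it is not a TensorFlow function and assumes all dimensions are non-zero.

```python
from math import prod

IMAGE_SIZE = 28 * 28  # same constant as in the tutorial


def infer_reshape(shape, target):
    """Resolve at most one -1 in `target` (illustrative helper, not a TF API)."""
    total = prod(shape)
    known = prod(d for d in target if d != -1)
    inferred, remainder = divmod(total, known)
    if target.count(-1) > 1 or remainder or (-1 not in target and inferred != 1):
        raise ValueError(f"cannot reshape {shape} into {target}")
    return tuple(inferred if d == -1 else d for d in target)


# One training batch of 100 images, each 28 x 28 pixels.
batch_shape = infer_reshape((100, 28, 28), (-1, IMAGE_SIZE))
# The full MNIST test set of 10,000 images.
test_shape = infer_reshape((10000, 28, 28), (-1, IMAGE_SIZE))
print(batch_shape, test_shape)
```

Each image is flattened into a single row of 784 values, so the batch dimension is preserved and only the two pixel axes are merged.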

Finally, define the model and the optimizer. The model uses a single dense layer.

layer = tf.keras.layers.Dense(NUM_CLASSES)
optimizer = tf.keras.optimizers.Adam()
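
As a rough check on model size, the parameter count of the single dense layer can be worked out by hand, assuming it is fed the flattened 784-pixel images prepared above (Keras creates the layer's variables on its first call). The sketch below is plain Python arithmetic, not a TensorFlow API, and the variable names are illustrative.

```python
IMAGE_SIZE = 28 * 28
NUM_CLASSES = 10

kernel_params = IMAGE_SIZE * NUM_CLASSES  # weight matrix of shape (784, 10)
bias_params = NUM_CLASSES                 # one bias value per class
trainable_params = kernel_params + bias_params

# Adam tracks a first- and a second-moment estimate for each trainable value.
adam_slot_values = 2 * trainable_params

print(trainable_params, adam_slot_values)
```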

Define the training function

In the training function, you compute the predicted labels (logits) using the layer defined above, and then minimize the loss by applying its gradients with the optimizer. To compile the computation using XLA, place it inside tf.function with jit_compile=True.

@tf.function(jit_compile=True)
def train_mnist(images, labels):
  images, labels = cast(images, labels)

  # Record the forward pass so the loss can be differentiated.
  with tf.GradientTape() as tape:
    predicted_labels = layer(images)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=predicted_labels, labels=labels
    ))
  # Compute gradients of the loss and apply one optimizer update.
  layer_variables = layer.trainable_variables
  grads = tape.gradient(loss, layer_variables)
  optimizer.apply_gradients(zip(grads, layer_variables))
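
To make the loss used in train_mnist more concrete, here is a plain-Python sketch of sparse softmax cross-entropy for a single example, together with its gradient with respect to the logits. The function names are hypothetical and this is not TensorFlow's implementation; it only illustrates the math, using the usual max-subtraction trick for numerical stability.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """Loss for one example: -log(softmax(logits)[label])."""
    top = max(logits)
    shifted = [z - top for z in logits]  # subtract the max for stability
    log_sum = math.log(sum(math.exp(z) for z in shifted))
    return -(shifted[label] - log_sum)


def loss_gradient_wrt_logits(logits, label):
    """Gradient of the loss above: softmax(logits) - one_hot(label)."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    denom = sum(exps)
    return [e / denom - (1.0 if i == label else 0.0)
            for i, e in enumerate(exps)]


uniform_logits = [0.0] * 10  # a model with no preference among 10 classes
loss = sparse_softmax_cross_entropy(uniform_logits, label=3)
grad = loss_gradient_wrt_logits(uniform_logits, label=3)
print(loss, grad[3])
```

With all logits equal, the loss is log(10), roughly 2.3, which is a useful baseline for what an untrained 10-class model should report.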

Train and test the model

Once you have defined the training function, train the model.

for images, labels in train_ds:
  # optimizer.iterations counts completed steps, so stop once it reaches TRAIN_STEPS.
  if optimizer.iterations >= TRAIN_STEPS:
    break
  train_mnist(images, labels)
I0000 00:00:1721388230.602891   14480 service.cc:146] XLA service 0xb51fe10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1721388230.602933   14480 service.cc:154]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1721388230.602937   14480 service.cc:154]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1721388230.602940   14480 service.cc:154]   StreamExecutor device (2): Tesla T4, Compute Capability 7.5
I0000 00:00:1721388230.602942   14480 service.cc:154]   StreamExecutor device (3): Tesla T4, Compute Capability 7.5
I0000 00:00:1721388230.941982   14480 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
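
The stopping condition of a loop like this is easy to get off by one, because optimizer.iterations starts at 0 and is incremented once per apply_gradients call. The sketch below replaces the optimizer with a plain counter (steps_run is a hypothetical helper, not part of TensorFlow) to compare a strict and an inclusive check.

```python
def steps_run(limit, stop_when):
    """Count how many training steps a loop executes before breaking."""
    iterations = 0  # stands in for optimizer.iterations, which starts at 0
    steps = 0
    while True:
        if stop_when(iterations, limit):
            break
        iterations += 1  # one increment per simulated apply_gradients call
        steps += 1
    return steps


TRAIN_STEPS = 1000
strict = steps_run(TRAIN_STEPS, lambda it, n: it > n)      # break only once past the limit
inclusive = steps_run(TRAIN_STEPS, lambda it, n: it >= n)  # break on reaching the limit
print(strict, inclusive)
```

The strict comparison runs one step more than the limit, while the inclusive comparison runs exactly TRAIN_STEPS steps.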

And, finally, check the accuracy:

images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
Prediction accuracy after training: tf.Tensor(0.8818, shape=(), dtype=float32)
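
The accuracy computation above takes the argmax of each row of logits, compares it with the label, and averages the matches. A minimal plain-Python sketch of the same idea on toy data follows; the accuracy helper and the toy values are made up for illustration.

```python
def accuracy(logits_batch, labels):
    """Fraction of rows whose largest logit sits at the labelled class."""
    predictions = [max(range(len(row)), key=row.__getitem__)
                   for row in logits_batch]
    correct = [float(p == y) for p, y in zip(predictions, labels)]
    return sum(correct) / len(correct)


toy_logits = [[0.1, 2.0, -1.0],  # argmax is class 1
              [1.5, 0.2, 0.3],   # argmax is class 0
              [0.0, 0.1, 0.9],   # argmax is class 2
              [0.4, 0.3, 0.2]]   # argmax is class 0
toy_labels = [1, 0, 2, 1]        # the last prediction is wrong
print(accuracy(toy_logits, toy_labels))
```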

Behind the scenes, the XLA compiler has compiled the entire tf.function to HLO, which enables fusion optimizations. Using experimental_get_compiler_ir, you can inspect the HLO code (other useful values for "stage" are optimized_hlo for HLO after optimizations and optimized_hlo_dot for a Graphviz graph):

print(train_mnist.experimental_get_compiler_ir(images, labels)(stage='hlo'))
HloModule a_inference_train_mnist_5553__.192, input_output_alias={ {0}: (2, {}, may-alias), {1}: (3, {}, may-alias), {2}: (5, {}, may-alias), {3}: (6, {}, may-alias), {4}: (7, {}, may-alias), {5}: (8, {}, may-alias), {6}: (9, {}, may-alias) }, entry_computation_layout={(f32[10000,784]{1,0}, s64[10000]{0}, f32[784,10]{1,0}, f32[10]{0}, f32[], /*index=5*/s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, f32[10]{0}, f32[10]{0})->(f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0})}

%max_float_.71 (x.72: f32[], y.73: f32[]) -> f32[] {
  %x.72 = f32[] parameter(0)
  %y.73 = f32[] parameter(1)
  ROOT %maximum.74 = f32[] maximum(f32[] %x.72, f32[] %y.73)
}

%add_float_.81 (x.82: f32[], y.83: f32[]) -> f32[] {
  %x.82 = f32[] parameter(0)
  %y.83 = f32[] parameter(1)
  ROOT %add.84 = f32[] add(f32[] %x.82, f32[] %y.83)
}

%add_float_.100 (x.101: f32[], y.102: f32[]) -> f32[] {
  %x.101 = f32[] parameter(0)
  %y.102 = f32[] parameter(1)
  ROOT %add.103 = f32[] add(f32[] %x.101, f32[] %y.102)
}

%Mean-reduction.112 (x.113: f32[], y.114: f32[]) -> f32[] {
  %x.113 = f32[] parameter(0)
  %y.114 = f32[] parameter(1)
  ROOT %add.115 = f32[] add(f32[] %x.113, f32[] %y.114)
}

%gradient_tape_dense_1_Add_Sum-reduction.129 (x.130: f32[], y.131: f32[]) -> f32[] {
  %x.130 = f32[] parameter(0)
  %y.131 = f32[] parameter(1)
  ROOT %add.132 = f32[] add(f32[] %x.130, f32[] %y.131)
}

ENTRY %a_inference_train_mnist_5553__.192 (arg0.1: f32[10000,784], arg1.2: s64[10000], arg2.3: f32[784,10], arg3.4: f32[10], arg4.5: f32[], arg5.6: s64[], arg6.7: f32[784,10], arg7.8: f32[784,10], arg8.9: f32[10], arg9.10: f32[10]) -> (f32[784,10], f32[10], s64[], f32[784,10], f32[784,10], /*index=5*/f32[10], f32[10]) {
  %arg1.2 = s64[10000]{0} parameter(1), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %reshape.12 = s64[10000]{0} reshape(s64[10000]{0} %arg1.2)
  %broadcast.51 = s64[10000,10]{1,0} broadcast(s64[10000]{0} %reshape.12), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %iota.50 = s64[10000,10]{1,0} iota(), iota_dimension=1, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %compare.52 = pred[10000,10]{1,0} compare(s64[10000,10]{1,0} %broadcast.51, s64[10000,10]{1,0} %iota.50), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.48 = f32[] constant(1), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.53 = f32[10000,10]{1,0} broadcast(f32[] %constant.48), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.49 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.54 = f32[10000,10]{1,0} broadcast(f32[] %constant.49), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %select.55 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.52, f32[10000,10]{1,0} %broadcast.53, f32[10000,10]{1,0} %broadcast.54), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.56 = s64[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.57 = s64[10000]{0} broadcast(s64[] %constant.56), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %compare.58 = pred[10000]{0} compare(s64[10000]{0} %broadcast.57, s64[10000]{0} %reshape.12), direction=LE, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.59 = s64[] constant(10), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.60 = s64[10000]{0} broadcast(s64[] %constant.59), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %compare.61 = pred[10000]{0} compare(s64[10000]{0} %reshape.12, s64[10000]{0} %broadcast.60), direction=LT, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %and.62 = pred[10000]{0} and(pred[10000]{0} %compare.58, pred[10000]{0} %compare.61), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.63 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.64 = f32[10000]{0} broadcast(f32[] %constant.63), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.65 = f32[] constant(nan), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.66 = f32[10000]{0} broadcast(f32[] %constant.65), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %select.67 = f32[10000]{0} select(pred[10000]{0} %and.62, f32[10000]{0} %broadcast.64, f32[10000]{0} %broadcast.66), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.68 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %select.67), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.69 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %select.55, f32[10000,10]{1,0} %broadcast.68), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %negate.96 = f32[10000,10]{1,0} negate(f32[10000,10]{1,0} %add.69), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.90 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.91 = f32[10000,10]{1,0} broadcast(f32[] %constant.90), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %compare.92 = pred[10000,10]{1,0} compare(f32[10000,10]{1,0} %add.69, f32[10000,10]{1,0} %broadcast.91), direction=EQ, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.93 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.94 = f32[10000,10]{1,0} broadcast(f32[] %constant.93), dimensions={}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg0.1 = f32[10000,784]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %reshape.11 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %arg0.1)
  %reshape.43 = f32[10000,784]{1,0} reshape(f32[10000,784]{1,0} %reshape.11), metadata={op_type="Reshape" op_name="Reshape" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg2.3 = f32[784,10]{1,0} parameter(2), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %dot.44 = f32[10000,10]{1,0} dot(f32[10000,784]{1,0} %reshape.43, f32[784,10]{1,0} %arg2.3), lhs_contracting_dims={1}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="false"}, metadata={op_type="MatMul" op_name="dense_1/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %transpose.45 = f32[10000,10]{1,0} transpose(f32[10000,10]{1,0} %dot.44), dimensions={0,1}, metadata={op_type="MatMul" op_name="dense_1/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg3.4 = f32[10]{0} parameter(3), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %broadcast.46 = f32[10000,10]{1,0} broadcast(f32[10]{0} %arg3.4), dimensions={1}, metadata={op_type="AddV2" op_name="dense_1/Add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.47 = f32[10000,10]{1,0} add(f32[10000,10]{1,0} %transpose.45, f32[10000,10]{1,0} %broadcast.46), metadata={op_type="AddV2" op_name="dense_1/Add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.70 = f32[] constant(-inf), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reduce.75 = f32[10000]{0} reduce(f32[10000,10]{1,0} %add.47, f32[] %constant.70), dimensions={1}, to_apply=%max_float_.71, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.76 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reduce.75), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.77 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %add.47, f32[10000,10]{1,0} %broadcast.76), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %exponential.78 = f32[10000,10]{1,0} exponential(f32[10000,10]{1,0} %subtract.77), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.79 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %exponential.78), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.80 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reduce.85 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.79, f32[] %constant.80), dimensions={1}, to_apply=%add_float_.81, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.86 = f32[10000]{0} convert(f32[10000]{0} %reduce.85), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %log.87 = f32[10000]{0} log(f32[10000]{0} %convert.86), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.88 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %log.87), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.89 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %subtract.77, f32[10000,10]{1,0} %broadcast.88), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %select.95 = f32[10000,10]{1,0} select(pred[10000,10]{1,0} %compare.92, f32[10000,10]{1,0} %broadcast.94, f32[10000,10]{1,0} %subtract.89), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.97 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %negate.96, f32[10000,10]{1,0} %select.95), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.98 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %multiply.97), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.99 = f32[] constant(0), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reduce.104 = f32[10000]{0} reduce(f32[10000,10]{1,0} %convert.98, f32[] %constant.99), dimensions={1}, to_apply=%add_float_.100, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.105 = f32[10000]{0} convert(f32[10000]{0} %reduce.104), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.109 = f32[10000]{0} convert(f32[10000]{0} %convert.105), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.110 = f32[] constant(0), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.111 = f32[] convert(f32[] %constant.110), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reduce.116 = f32[] reduce(f32[10000]{0} %convert.109, f32[] %convert.111), dimensions={0}, to_apply=%Mean-reduction.112, metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.117 = s32[] constant(10000), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.118 = f32[] convert(s32[] %constant.117), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %divide.119 = f32[] divide(f32[] %reduce.116, f32[] %convert.118), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.120 = f32[] convert(f32[] %divide.119), metadata={op_type="Mean" op_name="Mean" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg6.7 = f32[784,10]{1,0} parameter(6), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.121 = f32[] constant(0.0001), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.122 = f32[10000,1]{1,0} broadcast(f32[] %constant.121), dimensions={}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reshape.123 = f32[10000]{0} reshape(f32[10000,1]{1,0} %broadcast.122), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.124 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %reshape.123), dimensions={0}, metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.106 = f32[10000,10]{1,0} broadcast(f32[10000]{0} %convert.86), dimensions={0}, metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %divide.107 = f32[10000,10]{1,0} divide(f32[10000,10]{1,0} %exponential.78, f32[10000,10]{1,0} %broadcast.106), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.108 = f32[10000,10]{1,0} subtract(f32[10000,10]{1,0} %divide.107, f32[10000,10]{1,0} %add.69), metadata={op_type="SparseSoftmaxCrossEntropyWithLogits" op_name="SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.125 = f32[10000,10]{1,0} multiply(f32[10000,10]{1,0} %broadcast.124, f32[10000,10]{1,0} %subtract.108), metadata={op_type="Mul" op_name="gradient_tape/SparseSoftmaxCrossEntropyWithLogits/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %dot.137 = f32[784,10]{1,0} dot(f32[10000,784]{1,0} %reshape.43, f32[10000,10]{1,0} %multiply.125), lhs_contracting_dims={0}, rhs_contracting_dims={0}, frontend_attributes={grad_x="false",grad_y="true"}, metadata={op_type="MatMul" op_name="gradient_tape/dense_1/MatMul/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %transpose.138 = f32[784,10]{1,0} transpose(f32[784,10]{1,0} %dot.137), dimensions={0,1}, metadata={op_type="MatMul" op_name="gradient_tape/dense_1/MatMul/MatMul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.159 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %transpose.138, f32[784,10]{1,0} %arg6.7), metadata={op_type="Sub" op_name="adam/Sub_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.160 = f32[] constant(0.1), metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.161 = f32[784,10]{1,0} broadcast(f32[] %constant.160), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.162 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.159, f32[784,10]{1,0} %broadcast.161), metadata={op_type="Mul" op_name="adam/Mul_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.163 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg6.7, f32[784,10]{1,0} %multiply.162), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg4.5 = f32[] parameter(4), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.22 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.20 = f32[] constant(0.999), metadata={op_type="Pow" op_name="adam/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg5.6 = s64[] parameter(5), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %constant.13 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.14 = s64[] add(s64[] %arg5.6, s64[] %constant.13), metadata={op_type="AddV2" op_name="adam/add" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.15 = f32[] convert(s64[] %add.14), metadata={op_type="Cast" op_name="adam/Cast_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %power.21 = f32[] power(f32[] %constant.20, f32[] %convert.15), metadata={op_type="Pow" op_name="adam/Pow_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.23 = f32[] subtract(f32[] %constant.22, f32[] %power.21), metadata={op_type="Sub" op_name="adam/sub" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %sqrt.24 = f32[] sqrt(f32[] %subtract.23), metadata={op_type="Sqrt" op_name="adam/Sqrt"}
  %multiply.25 = f32[] multiply(f32[] %arg4.5, f32[] %sqrt.24), metadata={op_type="Mul" op_name="adam/mul" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.18 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.16 = f32[] constant(0.9), metadata={op_type="Pow" op_name="adam/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %power.17 = f32[] power(f32[] %constant.16, f32[] %convert.15), metadata={op_type="Pow" op_name="adam/Pow" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.19 = f32[] subtract(f32[] %constant.18, f32[] %power.17), metadata={op_type="Sub" op_name="adam/sub_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %divide.26 = f32[] divide(f32[] %multiply.25, f32[] %subtract.19), metadata={op_type="RealDiv" op_name="adam/truediv" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.164 = f32[784,10]{1,0} broadcast(f32[] %divide.26), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.165 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %add.163, f32[784,10]{1,0} %broadcast.164), metadata={op_type="Mul" op_name="adam/Mul_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg7.8 = f32[784,10]{1,0} parameter(7), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %multiply.139 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %transpose.138, f32[784,10]{1,0} %transpose.138), metadata={op_type="Square" op_name="adam/Square" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.140 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %multiply.139, f32[784,10]{1,0} %arg7.8), metadata={op_type="Sub" op_name="adam/Sub_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.141 = f32[] constant(0.001), metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.142 = f32[784,10]{1,0} broadcast(f32[] %constant.141), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.143 = f32[784,10]{1,0} multiply(f32[784,10]{1,0} %subtract.140, f32[784,10]{1,0} %broadcast.142), metadata={op_type="Mul" op_name="adam/Mul_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.144 = f32[784,10]{1,0} add(f32[784,10]{1,0} %arg7.8, f32[784,10]{1,0} %multiply.143), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %sqrt.145 = f32[784,10]{1,0} sqrt(f32[784,10]{1,0} %add.144), metadata={op_type="Sqrt" op_name="adam/Sqrt_1"}
  %constant.146 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.147 = f32[784,10]{1,0} broadcast(f32[] %constant.146), dimensions={}, metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.148 = f32[784,10]{1,0} add(f32[784,10]{1,0} %sqrt.145, f32[784,10]{1,0} %broadcast.147), metadata={op_type="AddV2" op_name="adam/Add_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %divide.166 = f32[784,10]{1,0} divide(f32[784,10]{1,0} %multiply.165, f32[784,10]{1,0} %add.148), metadata={op_type="RealDiv" op_name="adam/truediv_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.167 = f32[784,10]{1,0} subtract(f32[784,10]{1,0} %arg2.3, f32[784,10]{1,0} %divide.166), metadata={op_type="AssignSubVariableOp" op_name="adam/AssignSubVariableOp" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reshape.177 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %subtract.167), metadata={op_name="XLA_Retvals"}
  %copy.178 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.177), metadata={op_name="XLA_Retvals"}
  %arg8.9 = f32[10]{0} parameter(8), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %convert.126 = f32[10000,10]{1,0} convert(f32[10000,10]{1,0} %multiply.125), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.127 = f32[] constant(0), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.128 = f32[] convert(f32[] %constant.127), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reduce.133 = f32[10]{0} reduce(f32[10000,10]{1,0} %convert.126, f32[] %convert.128), dimensions={0}, to_apply=%gradient_tape_dense_1_Add_Sum-reduction.129, metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.134 = f32[10]{0} convert(f32[10]{0} %reduce.133), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reshape.135 = f32[1,10]{1,0} reshape(f32[10]{0} %convert.134), metadata={op_type="Sum" op_name="gradient_tape/dense_1/Add/Sum" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reshape.136 = f32[10]{0} reshape(f32[1,10]{1,0} %reshape.135), metadata={op_type="Reshape" op_name="gradient_tape/dense_1/Add/Reshape" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.168 = f32[10]{0} subtract(f32[10]{0} %reshape.136, f32[10]{0} %arg8.9), metadata={op_type="Sub" op_name="adam/Sub_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.169 = f32[] constant(0.1), metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.170 = f32[10]{0} broadcast(f32[] %constant.169), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.171 = f32[10]{0} multiply(f32[10]{0} %subtract.168, f32[10]{0} %broadcast.170), metadata={op_type="Mul" op_name="adam/Mul_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.172 = f32[10]{0} add(f32[10]{0} %arg8.9, f32[10]{0} %multiply.171), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.36 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.34 = f32[] constant(0.999), metadata={op_type="Pow" op_name="adam/Pow_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.27 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/add_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.28 = s64[] add(s64[] %arg5.6, s64[] %constant.27), metadata={op_type="AddV2" op_name="adam/add_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %convert.29 = f32[] convert(s64[] %add.28), metadata={op_type="Cast" op_name="adam/Cast_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %power.35 = f32[] power(f32[] %constant.34, f32[] %convert.29), metadata={op_type="Pow" op_name="adam/Pow_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.37 = f32[] subtract(f32[] %constant.36, f32[] %power.35), metadata={op_type="Sub" op_name="adam/sub_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %sqrt.38 = f32[] sqrt(f32[] %subtract.37), metadata={op_type="Sqrt" op_name="adam/Sqrt_2"}
  %multiply.39 = f32[] multiply(f32[] %arg4.5, f32[] %sqrt.38), metadata={op_type="Mul" op_name="adam/mul_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.32 = f32[] constant(1), metadata={op_type="Sub" op_name="adam/sub_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.30 = f32[] constant(0.9), metadata={op_type="Pow" op_name="adam/Pow_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %power.31 = f32[] power(f32[] %constant.30, f32[] %convert.29), metadata={op_type="Pow" op_name="adam/Pow_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.33 = f32[] subtract(f32[] %constant.32, f32[] %power.31), metadata={op_type="Sub" op_name="adam/sub_5" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %divide.40 = f32[] divide(f32[] %multiply.39, f32[] %subtract.33), metadata={op_type="RealDiv" op_name="adam/truediv_2" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.173 = f32[10]{0} broadcast(f32[] %divide.40), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.174 = f32[10]{0} multiply(f32[10]{0} %add.172, f32[10]{0} %broadcast.173), metadata={op_type="Mul" op_name="adam/Mul_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %arg9.10 = f32[10]{0} parameter(9), parameter_replication={false}, metadata={op_name="XLA_Args"}
  %multiply.149 = f32[10]{0} multiply(f32[10]{0} %reshape.136, f32[10]{0} %reshape.136), metadata={op_type="Square" op_name="adam/Square_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.150 = f32[10]{0} subtract(f32[10]{0} %multiply.149, f32[10]{0} %arg9.10), metadata={op_type="Sub" op_name="adam/Sub_7" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %constant.151 = f32[] constant(0.001), metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.152 = f32[10]{0} broadcast(f32[] %constant.151), dimensions={}, metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %multiply.153 = f32[10]{0} multiply(f32[10]{0} %subtract.150, f32[10]{0} %broadcast.152), metadata={op_type="Mul" op_name="adam/Mul_6" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.154 = f32[10]{0} add(f32[10]{0} %arg9.10, f32[10]{0} %multiply.153), metadata={op_type="AssignAddVariableOp" op_name="adam/AssignAddVariableOp_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %sqrt.155 = f32[10]{0} sqrt(f32[10]{0} %add.154), metadata={op_type="Sqrt" op_name="adam/Sqrt_3"}
  %constant.156 = f32[] constant(1e-07), metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %broadcast.157 = f32[10]{0} broadcast(f32[] %constant.156), dimensions={}, metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.158 = f32[10]{0} add(f32[10]{0} %sqrt.155, f32[10]{0} %broadcast.157), metadata={op_type="AddV2" op_name="adam/Add_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %divide.175 = f32[10]{0} divide(f32[10]{0} %multiply.174, f32[10]{0} %add.158), metadata={op_type="RealDiv" op_name="adam/truediv_3" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %subtract.176 = f32[10]{0} subtract(f32[10]{0} %arg3.4, f32[10]{0} %divide.175), metadata={op_type="AssignSubVariableOp" op_name="adam/AssignSubVariableOp_1" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reshape.179 = f32[10]{0} reshape(f32[10]{0} %subtract.176), metadata={op_name="XLA_Retvals"}
  %copy.180 = f32[10]{0} copy(f32[10]{0} %reshape.179), metadata={op_name="XLA_Retvals"}
  %constant.41 = s64[] constant(1), metadata={op_type="AddV2" op_name="adam/add_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %add.42 = s64[] add(s64[] %arg5.6, s64[] %constant.41), metadata={op_type="AddV2" op_name="adam/add_4" source_file="/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/framework/ops.py" source_line=1177}
  %reshape.181 = s64[] reshape(s64[] %add.42), metadata={op_name="XLA_Retvals"}
  %copy.182 = s64[] copy(s64[] %reshape.181), metadata={op_name="XLA_Retvals"}
  %reshape.183 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.163), metadata={op_name="XLA_Retvals"}
  %copy.184 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.183), metadata={op_name="XLA_Retvals"}
  %reshape.185 = f32[784,10]{1,0} reshape(f32[784,10]{1,0} %add.144), metadata={op_name="XLA_Retvals"}
  %copy.186 = f32[784,10]{1,0} copy(f32[784,10]{1,0} %reshape.185), metadata={op_name="XLA_Retvals"}
  %reshape.187 = f32[10]{0} reshape(f32[10]{0} %add.172), metadata={op_name="XLA_Retvals"}
  %copy.188 = f32[10]{0} copy(f32[10]{0} %reshape.187), metadata={op_name="XLA_Retvals"}
  %reshape.189 = f32[10]{0} reshape(f32[10]{0} %add.154), metadata={op_name="XLA_Retvals"}
  %copy.190 = f32[10]{0} copy(f32[10]{0} %reshape.189), metadata={op_name="XLA_Retvals"}
  ROOT %tuple.191 = (f32[784,10]{1,0}, f32[10]{0}, s64[], f32[784,10]{1,0}, f32[784,10]{1,0}, /*index=5*/f32[10]{0}, f32[10]{0}) tuple(f32[784,10]{1,0} %copy.178, f32[10]{0} %copy.180, s64[] %copy.182, f32[784,10]{1,0} %copy.184, f32[784,10]{1,0} %copy.186, /*index=5*/f32[10]{0} %copy.188, f32[10]{0} %copy.190), metadata={op_name="XLA_Retvals"}
}
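The tail of the HLO dump above is easier to follow next to the arithmetic it encodes. The `adam/...` ops appear to implement one Adam update per variable. The sketch below restates that update for a single scalar in plain Python. It is a reading of the dump, not the Keras implementation: the function name `adam_step` and its argument names are hypothetical, and the mapping of `%arg4.5` to the learning rate, `%arg5.6` to the step counter, and `%arg6.7` / `%arg7.8` to the first and second moment slots is an assumption inferred from the op names.

```python
import math


def adam_step(var, m, v, grad, lr, step, beta1=0.9, beta2=0.999, eps=1e-7):
    """One scalar Adam update, following the op order seen in the HLO dump.

    Hypothetical helper for illustration only. The constants 0.9, 0.999 and
    1e-07 appear in the dump; 0.1 and 0.001 there are read as (1 - beta1)
    and (1 - beta2).
    """
    # adam/add, adam/Cast_1: t = step + 1, converted to float
    t = float(step + 1)

    # adam/Pow_1, adam/sub, adam/Sqrt, adam/mul, adam/Pow, adam/sub_1, adam/truediv
    alpha = lr * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)

    # adam/Sub_2, adam/Mul_1, adam/AssignAddVariableOp: m += (g - m) * (1 - beta1)
    m_new = m + (grad - m) * (1.0 - beta1)

    # adam/Square, adam/Sub_3, adam/Mul_2, adam/AssignAddVariableOp_1
    v_new = v + (grad * grad - v) * (1.0 - beta2)

    # adam/Mul_3, adam/Sqrt_1, adam/Add_1, adam/truediv_1, adam/AssignSubVariableOp
    var_new = var - (m_new * alpha) / (math.sqrt(v_new) + eps)

    # The ROOT tuple returns the updated variable, step and moment slots.
    return var_new, step + 1, m_new, v_new


# Example call with made-up values for a single weight.
var_new, new_step, m_new, v_new = adam_step(
    var=1.0, m=0.0, v=0.0, grad=0.5, lr=0.001, step=0)
```

In the dump the same sequence is emitted twice, once for the `f32[784,10]` kernel and once for the `f32[10]` bias, which is why ops such as `adam/Pow` and `adam/Pow_2` repeat with identical constants.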