Skip to content
Snippets Groups Projects
mnist_perceiver.ipynb 11.3 KiB
Newer Older
heilep's avatar
heilep committed
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/heilep/.local/lib/python3.9/site-packages/tensorflow_addons/utils/ensure_tf_install.py:53: UserWarning: Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.9.0 and strictly below 2.12.0 (nightly versions are not supported). \n",
      " The versions of TensorFlow you are currently using is 2.7.1 and is not supported. \n",
      "Some things might work, some things might not.\n",
      "If you were to encounter a bug, do not file an issue.\n",
      "If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. \n",
      "You can find the compatibility matrix in TensorFlow Addon's readme:\n",
      "https://github.com/tensorflow/addons\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# TF_CPP_MIN_LOG_LEVEL must be set before TensorFlow is imported to silence its C++ logs.\n",
    "import os\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "\n",
    "# Project-local model and dataset helpers.\n",
    "from nn.models.perceiver.model import LogitsPerceiver, Perceiver\n",
    "from nn.data_utils.datasets import mnist\n",
    "\n",
    "# Quiet TF's Python logger and any Python warnings raised from here on.\n",
    "# NOTE(review): the filter is installed only after tensorflow_addons is imported, so the\n",
    "# TFA/TF version-compatibility UserWarning (TF 2.7.1 is outside TFA's supported range)\n",
    "# is still printed below -- that mismatch is worth resolving rather than hiding.\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the global TensorFlow and NumPy seeds before the dataset and model are built so that\n",
    "# weight initialisation (and any tf.data / NumPy shuffling done inside mnist()) is repeatable.\n",
    "# NOTE(review): this does not make every GPU kernel deterministic.\n",
    "SEED = 42\n",
    "tf.random.set_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# The argument appears to be the batch size: the fit log reports 469 steps per epoch,\n",
    "# i.e. ceil(60000 / 128) for the MNIST training split.\n",
    "BATCH_SIZE = 128\n",
    "train, val_data = mnist(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perceiver backbone hyper-parameters, kept in one place so they are easy to tweak and log.\n",
    "perceiver_hparams = dict(\n",
    "    freq_base=2,\n",
    "    input_channels=1,\n",
    "    input_axis=2,\n",
    "    num_freq_bands=32,\n",
    "    max_freq=112.,\n",
    "    blocks=1,\n",
    "    latent_attentions=1,\n",
    "    num_latents=32,\n",
    "    latent_dim=32,\n",
    "    cross_dim_head=32,\n",
    "    latent_dim_head=32,\n",
    "    cross_heads=8,\n",
    "    latent_heads=1,\n",
    "    share_weights=True,\n",
    "    share_weights_layer_0=False,\n",
    "    attn_dropout=0.0,\n",
    "    dropout=0.0,\n",
    "    outputs=10,\n",
    ")\n",
    "perceiver = Perceiver(**perceiver_hparams)\n",
    "\n",
    "# Classification head with no output activation, i.e. it emits raw logits --\n",
    "# which is why the loss below is constructed with from_logits=True.\n",
    "model = LogitsPerceiver(perceiver, units=None, dropout=0.4, flatten=False, output_activation=None)\n",
    "model.compile(\n",
    "    optimizer=tfa.optimizers.LAMB(learning_rate=1e-3, weight_decay=1e-5),\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
    ")\n",
    "\n",
    "# No callbacks for this run; add e.g. EarlyStopping or TensorBoard here.\n",
    "cbs = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"logits_perceiver\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " perceiver_1 (Perceiver)     multiple                  169430    \n",
      "                                                                 \n",
      " dense (Dense)               multiple                  330       \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 169,760\n",
      "Trainable params: 169,760\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"perceiver_1\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " BlockReapeaterPerceiverBloc  multiple                 168406    \n",
      " k (BlockRepeater)                                               \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 169,430\n",
      "Trainable params: 169,430\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"BlockReapeaterPerceiverBlock\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " PerceiverBlock_0 (Perceiver  multiple                 168406    \n",
      " Block)                                                          \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 168,406\n",
      "Trainable params: 168,406\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"PerceiverBlock_0\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " CrossAttention (Attention)  multiple                  141056    \n",
      "                                                                 \n",
      " BlockReapeaterLatentAttenti  multiple                 25632     \n",
      " on (BlockRepeater)                                              \n",
      "                                                                 \n",
      " layer_normalization_12 (Lay  multiple                 262       \n",
      " erNormalization)                                                \n",
      "                                                                 \n",
      " layer_normalization_13 (Lay  multiple                 64        \n",
      " erNormalization)                                                \n",
      "                                                                 \n",
      " layer_normalization_14 (Lay  multiple                 512       \n",
      " erNormalization)                                                \n",
      "                                                                 \n",
      " layer_normalization_15 (Lay  multiple                 64        \n",
      " erNormalization)                                                \n",
      "                                                                 \n",
      " LatentFeedForward_0 (Dense)  multiple                 528       \n",
      "                                                                 \n",
      " LatentFeedForward_1 (Dense)  multiple                 288       \n",
      "                                                                 \n",
      " geglu_3 (geglu)             multiple                  0         \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 168,406\n",
      "Trainable params: 168,406\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Model: \"BlockReapeaterLatentAttention\"\n",
      "_________________________________________________________________\n",
      " Layer (type)                Output Shape              Param #   \n",
      "=================================================================\n",
      " LatentAttention_0 (Attentio  multiple                 25632     \n",
      " n)                                                              \n",
      "                                                                 \n",
      "=================================================================\n",
      "Total params: 25,632\n",
      "Trainable params: 25,632\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Run one evaluation step on a real batch so the model's weights get created;\n",
    "# presumably LogitsPerceiver is a subclassed Keras model that builds lazily, and\n",
    "# summary() needs built weights to report parameter counts.\n",
    "sample_batch = next(iter(train.take(1)))\n",
    "model.test_step(sample_batch)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "469/469 [==============================] - 11s 15ms/step - sparse_categorical_accuracy: 0.4946 - loss: 1.4429 - val_sparse_categorical_accuracy: 0.7998 - val_val_loss: 0.4963\n",
      "Epoch 2/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.8569 - loss: 0.4699 - val_sparse_categorical_accuracy: 0.8739 - val_val_loss: 0.2862\n",
      "Epoch 3/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.8961 - loss: 0.3313 - val_sparse_categorical_accuracy: 0.9002 - val_val_loss: 0.1654\n",
      "Epoch 4/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.9183 - loss: 0.2659 - val_sparse_categorical_accuracy: 0.9298 - val_val_loss: 0.1699\n",
      "Epoch 5/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.9284 - loss: 0.2294 - val_sparse_categorical_accuracy: 0.9338 - val_val_loss: 0.1797\n",
      "Epoch 6/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.9378 - loss: 0.1996 - val_sparse_categorical_accuracy: 0.9460 - val_val_loss: 0.1128\n",
      "Epoch 7/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.9433 - loss: 0.1824 - val_sparse_categorical_accuracy: 0.9461 - val_val_loss: 0.0863\n",
      "Epoch 8/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.9470 - loss: 0.1688 - val_sparse_categorical_accuracy: 0.9496 - val_val_loss: 0.1161\n",
      "Epoch 9/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.9514 - loss: 0.1544 - val_sparse_categorical_accuracy: 0.9544 - val_val_loss: 0.0592\n",
      "Epoch 10/10\n",
      "469/469 [==============================] - 6s 13ms/step - sparse_categorical_accuracy: 0.9548 - loss: 0.1461 - val_sparse_categorical_accuracy: 0.9552 - val_val_loss: 0.0256\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fc250eb8a90>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): the log reports `val_val_loss`, which suggests LogitsPerceiver.test_step already\n",
    "# names its loss 'val_loss' before Keras adds its own 'val_' prefix. Assuming loss and accuracy\n",
    "# come from the same predictions, those values also cannot be a mean cross-entropy over the\n",
    "# validation set: at 95.52% accuracy every misclassified sample adds at least ln 2 ~ 0.69, giving\n",
    "# a floor of ~0.031, yet epoch 10 reports 0.0256. They may be last-batch values rather than\n",
    "# epoch averages -- check how test_step aggregates its loss.\n",
    "EPOCHS = 10\n",
    "\n",
    "# Keep the History so per-epoch curves (history.history) remain available for plotting;\n",
    "# binding it also stops the cell from echoing the History object's repr.\n",
    "history = model.fit(train, epochs=EPOCHS, validation_data=val_data, callbacks=cbs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}