diff --git a/experiment_6/weight_loading_code_ex_6.ipynb b/experiment_6/weight_loading_code_ex_6.ipynb
deleted file mode 100644
index f960f04419c7bec60471c8f88486ab01ed63e508..0000000000000000000000000000000000000000
--- a/experiment_6/weight_loading_code_ex_6.ipynb
+++ /dev/null
@@ -1,1311 +0,0 @@
-{
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "provenance": [],
-      "toc_visible": true,
-      "gpuType": "T4"
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    },
-    "accelerator": "GPU"
-  },
-  "cells": [
-    {
-      "cell_type": "markdown",
-      "source": [
-        "# Environment Setting"
-      ],
-      "metadata": {
-        "id": "cQtpOYGIzAa5"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "import torch\n",
-        "\n",
-        "if torch.__version__ != '2.5.1+cu124':\n",
-        "    !pip install torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124 -U --quiet\n",
-        "    print(\"PyTorch version updated to 2.5.1.\")\n",
-        "else:\n",
-        "    print(\"PyTorch is already at the correct version (2.5.1).\")"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "fgItXb9XzDXk",
-        "outputId": "48748967-b2f2-44b9-8919-0e5a27023f46"
-      },
-      "execution_count": 1,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m908.3/908.3 MB\u001b[0m \u001b[31m840.4 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.3/7.3 MB\u001b[0m \u001b[31m122.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m94.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m93.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m59.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m125.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m1.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m106.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.5/209.5 MB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25hPyTorch version updated to 2.5.1.\n"
-          ]
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "!pip install d2l==1.0.3 --quiet\n",
-        "!pip install scipy --quiet\n",
-        "!pip install torchmetrics --quiet"
-      ],
-      "metadata": {
-        "id": "hyr4OflizGsC",
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "outputId": "a3f4f6b6-62f7-41e5-d409-25ce0cf143f4"
-      },
-      "execution_count": 2,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.9/58.9 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.7/111.7 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.6/11.6 MB\u001b[0m \u001b[31m48.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m17.1/17.1 MB\u001b[0m \u001b[31m46.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.2/12.2 MB\u001b[0m \u001b[31m48.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.6/62.6 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m34.1/34.1 MB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.3/98.3 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m125.0/125.0 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m95.0/95.0 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m52.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
-            "google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.0.3 which is incompatible.\n",
-            "google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.31.0 which is incompatible.\n",
-            "imbalanced-learn 0.13.0 requires numpy<3,>=1.24.3, but you have numpy 1.23.5 which is incompatible.\n",
-            "jaxlib 0.5.1 requires numpy>=1.25, but you have numpy 1.23.5 which is incompatible.\n",
-            "jaxlib 0.5.1 requires scipy>=1.11.1, but you have scipy 1.10.1 which is incompatible.\n",
-            "cvxpy 1.6.5 requires scipy>=1.11.0, but you have scipy 1.10.1 which is incompatible.\n",
-            "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 1.23.5 which is incompatible.\n",
-            "xarray 2025.1.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.\n",
-            "xarray 2025.1.2 requires pandas>=2.1, but you have pandas 2.0.3 which is incompatible.\n",
-            "thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.23.5 which is incompatible.\n",
-            "bigframes 2.1.0 requires numpy>=1.24.0, but you have numpy 1.23.5 which is incompatible.\n",
-            "jax 0.5.2 requires numpy>=1.25, but you have numpy 1.23.5 which is incompatible.\n",
-            "jax 0.5.2 requires scipy>=1.11.1, but you have scipy 1.10.1 which is incompatible.\n",
-            "plotnine 0.14.5 requires matplotlib>=3.8.0, but you have matplotlib 3.7.2 which is incompatible.\n",
-            "plotnine 0.14.5 requires pandas>=2.2.0, but you have pandas 2.0.3 which is incompatible.\n",
-            "mizani 0.13.3 requires pandas>=2.2.0, but you have pandas 2.0.3 which is incompatible.\n",
-            "blosc2 3.3.1 requires numpy>=1.26, but you have numpy 1.23.5 which is incompatible.\n",
-            "albumentations 2.0.5 requires numpy>=1.24.4, but you have numpy 1.23.5 which is incompatible.\n",
-            "chex 0.1.89 requires numpy>=1.24.1, but you have numpy 1.23.5 which is incompatible.\n",
-            "pymc 5.21.2 requires numpy>=1.25.0, but you have numpy 1.23.5 which is incompatible.\n",
-            "scikit-image 0.25.2 requires numpy>=1.24, but you have numpy 1.23.5 which is incompatible.\n",
-            "scikit-image 0.25.2 requires scipy>=1.11.4, but you have scipy 1.10.1 which is incompatible.\n",
-            "treescope 0.1.9 requires numpy>=1.25.2, but you have numpy 1.23.5 which is incompatible.\n",
-            "albucore 0.0.23 requires numpy>=1.24.4, but you have numpy 1.23.5 which is incompatible.\u001b[0m\u001b[31m\n",
-            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m961.5/961.5 kB\u001b[0m \u001b[31m51.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h"
-          ]
-        }
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "**Now restart the runtime for the pytorch version change to take effect.**"
-      ],
-      "metadata": {
-        "id": "PdFfGJ92EeY0"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Check torch version\n",
-        "import torch\n",
-        "torch.__version__"
-      ],
-      "metadata": {
-        "id": "yqDp1fm8DzI4",
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 35
-        },
-        "outputId": "d1257111-725b-422b-824f-7d34855db4fb"
-      },
-      "execution_count": 1,
-      "outputs": [
-        {
-          "output_type": "execute_result",
-          "data": {
-            "text/plain": [
-              "'2.5.1+cu124'"
-            ],
-            "application/vnd.google.colaboratory.intrinsic+json": {
-              "type": "string"
-            }
-          },
-          "metadata": {},
-          "execution_count": 1
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Set up device\n",
-        "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
-        "print(f\"Using device: {device}\")"
-      ],
-      "metadata": {
-        "id": "odsutW3jIl8Z",
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "outputId": "475f1871-9be1-4903-af33-debede7d744b"
-      },
-      "execution_count": 2,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "Using device: cuda:0\n"
-          ]
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "import os\n",
-        "import glob\n",
-        "import json\n",
-        "import random\n",
-        "from collections import Counter\n",
-        "\n",
-        "import torch.nn as nn\n",
-        "import torch.nn.functional as F\n",
-        "import torchvision.transforms.functional as TF\n",
-        "from torch.utils.data import Dataset, DataLoader, ConcatDataset\n",
-        "from d2l import torch as d2l\n",
-        "import numpy as np\n",
-        "from scipy.optimize import linear_sum_assignment\n",
-        "\n",
-        "import torchvision\n",
-        "import torchvision.models as models\n",
-        "from torchvision import datasets, transforms\n",
-        "from torchvision.ops.boxes import box_area\n",
-        "from torchvision.transforms import ToPILImage\n",
-        "from torchmetrics.detection import MeanAveragePrecision\n",
-        "from PIL import Image, ImageDraw\n",
-        "\n",
-        "import matplotlib.pyplot as plt\n",
-        "from mpl_toolkits.axes_grid1 import ImageGrid\n",
-        "from tqdm import tqdm\n",
-        "import csv"
-      ],
-      "metadata": {
-        "id": "SaOd5Y_tzH2j"
-      },
-      "execution_count": 3,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "**Here we also need to load functions in the appendix at the end.**\n",
-        "\n",
-        "\n"
-      ],
-      "metadata": {
-        "id": "bLSMisPp7DF8"
-      }
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "# Datasets Loading"
-      ],
-      "metadata": {
-        "id": "fu6S7M8KyM6B"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "!git clone https://git.wur.nl/wei044/deeplearning-mbe-8.git"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "eXZRwLl84T0a",
-        "outputId": "a70dcd64-f3d0-4413-92e2-409a757a5344"
-      },
-      "execution_count": 5,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "Cloning into 'deeplearning-mbe-8'...\n",
-            "remote: Enumerating objects: 1392, done.\u001b[K\n",
-            "remote: Counting objects: 100% (450/450), done.\u001b[K\n",
-            "remote: Compressing objects: 100% (447/447), done.\u001b[K\n",
-            "remote: Total 1392 (delta 49), reused 329 (delta 1), pack-reused 942 (from 1)\u001b[K\n",
-            "Receiving objects: 100% (1392/1392), 1.04 GiB | 16.51 MiB/s, done.\n",
-            "Resolving deltas: 100% (176/176), done.\n",
-            "Updating files: 100% (423/423), done.\n"
-          ]
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Get the current working directory\n",
-        "cwd = os.getcwd()\n",
-        "print(f\"Current working directory: {cwd}\")\n",
-        "print(\"Files in current directory:\")\n",
-        "print(os.listdir(\".\"))"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "mdzQqYtqa0uq",
-        "outputId": "c6de2968-3426-4d45-a6cd-6beb5d3ecd02"
-      },
-      "execution_count": 6,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "Current working directory: /content\n",
-            "Files in current directory:\n",
-            "['.config', 'deeplearning-mbe-8', 'sample_data']\n"
-          ]
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "data_dir = 'deeplearning-mbe-8'\n",
-        "data_dir"
-      ],
-      "metadata": {
-        "id": "JfBNHsIpyUWN",
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 35
-        },
-        "outputId": "c3bde184-3806-432c-c1f3-7bd0c72272f8"
-      },
-      "execution_count": 7,
-      "outputs": [
-        {
-          "output_type": "execute_result",
-          "data": {
-            "text/plain": [
-              "'deeplearning-mbe-8'"
-            ],
-            "application/vnd.google.colaboratory.intrinsic+json": {
-              "type": "string"
-            }
-          },
-          "metadata": {},
-          "execution_count": 7
-        }
-      ]
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "# Model Inference"
-      ],
-      "metadata": {
-        "id": "7RhM893zFTpi"
-      }
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "#### Load pre-trained model"
-      ],
-      "metadata": {
-        "id": "37x-wyRYxyuC"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "class PredictionHead(nn.Module):\n",
-        "    def __init__(self, in_channels, num_classes):\n",
-        "        super(PredictionHead, self).__init__()\n",
-        "\n",
-        "        self.conv_layers = nn.Sequential(\n",
-        "            nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, stride=2, padding=1),\n",
-        "            nn.BatchNorm2d(in_channels // 2),\n",
-        "            nn.LeakyReLU(negative_slope=0.1),\n",
-        "            nn.Conv2d(in_channels // 2, in_channels // 4, kernel_size=3, stride=1, padding=1),\n",
-        "            nn.BatchNorm2d(in_channels // 4),\n",
-        "            nn.LeakyReLU(negative_slope=0.1),\n",
-        "            nn.Conv2d(in_channels // 4, in_channels // 4, kernel_size=1),\n",
-        "            nn.BatchNorm2d(in_channels // 4),\n",
-        "            nn.LeakyReLU(negative_slope=0.1),\n",
-        "        )\n",
-        "\n",
-        "        self.box_predictor = nn.Sequential(\n",
-        "            nn.Conv2d(in_channels=in_channels // 4, out_channels=in_channels // 4, kernel_size=(3, 3), stride=(1, 1), padding=1),\n",
-        "            nn.BatchNorm2d(in_channels // 4),\n",
-        "            nn.LeakyReLU(negative_slope=0.1),\n",
-        "            nn.Conv2d(in_channels=in_channels // 4, out_channels=4, kernel_size=(3, 3), stride=(1, 1), padding=1),\n",
-        "            nn.Sigmoid(),\n",
-        "        )\n",
-        "\n",
-        "        # Classifier\n",
-        "        self.classifier = nn.Sequential(\n",
-        "            nn.Conv2d(in_channels=in_channels // 4, out_channels=in_channels // 4, kernel_size=(3, 3), stride=(1, 1), padding=1),\n",
-        "            nn.BatchNorm2d(in_channels // 4),\n",
-        "            nn.LeakyReLU(negative_slope=0.1),\n",
-        "            nn.Conv2d(in_channels=in_channels // 4, out_channels=num_classes, kernel_size=(3, 3), stride=(1, 1), padding=1),\n",
-        "        )\n",
-        "\n",
-        "        # Objectness score predictor\n",
-        "        self.objectness_predictor = nn.Sequential(\n",
-        "            nn.Conv2d(in_channels=in_channels // 4, out_channels=in_channels // 4, kernel_size=(3, 3), stride=(1, 1), padding=1),\n",
-        "            nn.BatchNorm2d(in_channels // 4),\n",
-        "            nn.LeakyReLU(negative_slope=0.1),\n",
-        "            nn.Conv2d(in_channels=in_channels // 4, out_channels=1, kernel_size=(3, 3), stride=(1, 1), padding=1),\n",
-        "            nn.Sigmoid(),\n",
-        "        )\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        x = self.conv_layers(x)\n",
-        "        cls_logits = self.classifier(x)\n",
-        "        bbox_pred = self.box_predictor(x)\n",
-        "        objectness_pred = self.objectness_predictor(x)\n",
-        "\n",
-        "        # Add relative grid position to the cx and cy predictions of each box\n",
-        "        grid_size = x.shape[-1]\n",
-        "        grid_y, grid_x = torch.meshgrid(\n",
-        "            torch.arange(grid_size, device=x.device), torch.arange(grid_size, device=x.device), indexing=\"ij\"\n",
-        "        )\n",
-        "        cx = bbox_pred[:, 0, :, :]\n",
-        "        cy = bbox_pred[:, 1, :, :]\n",
-        "        w = bbox_pred[:, 2, :, :]\n",
-        "        h = bbox_pred[:, 3, :, :]\n",
-        "        cx = (cx + (grid_x)) / grid_size\n",
-        "        cy = (cy + (grid_y)) / grid_size\n",
-        "        bbox_pred = torch.stack([cx, cy, w, h], dim=1)\n",
-        "\n",
-        "        return cls_logits, bbox_pred, objectness_pred"
-      ],
-      "metadata": {
-        "id": "rhjLe42Byg2d"
-      },
-      "execution_count": 8,
-      "outputs": []
-    },
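-    {
-      "cell_type": "markdown",
-      "source": [
-        "*Optional sanity check (added for clarity, not part of the original training code):* run the head on a random feature map to confirm the three output branches and the grid-offset decoding. The 512-channel 8x8 feature map below is an assumption matching what the truncated ResNet-34 backbone produces for 256x256 inputs."
-      ],
-      "metadata": {}
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Illustrative sketch only: assumes a 512-channel 8x8 feature map (as for 256x256 inputs).\n",
-        "dummy_feat = torch.randn(2, 512, 8, 8)\n",
-        "head = PredictionHead(in_channels=512, num_classes=4)\n",
-        "with torch.no_grad():\n",
-        "    cls_logits, bbox_pred, obj_pred = head(dummy_feat)\n",
-        "# The first stride-2 conv halves the grid, so a 4x4 grid of predictions is expected here.\n",
-        "print(cls_logits.shape, bbox_pred.shape, obj_pred.shape)\n",
-        "# After the grid offsets are added, the decoded box centres lie in [0, 1].\n",
-        "print(bbox_pred[:, :2].min().item(), bbox_pred[:, :2].max().item())"
-      ],
-      "metadata": {},
-      "execution_count": null,
-      "outputs": []
-    },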
-    {
-      "cell_type": "code",
-      "source": [
-        "class FruitDetectorMultiScale(nn.Module):\n",
-        "    def __init__(self, n_classes, pretrained=True):\n",
-        "        super(FruitDetectorMultiScale, self).__init__()\n",
-        "\n",
-        "        self.n_classes = n_classes\n",
-        "\n",
-        "        # Backbone\n",
-        "        backbone = models.resnet34(pretrained=pretrained)\n",
-        "        self.backbone = nn.Sequential(*list(backbone.children())[:-3])\n",
-        "\n",
-        "        # Layer 1\n",
-        "        self.block_1 = nn.Sequential(*list(backbone.children())[-3])\n",
-        "\n",
-        "        self.prediction_head = PredictionHead(512, self.n_classes)\n",
-        "        self.prediction_head_2 = PredictionHead(256, self.n_classes)\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        features = self.backbone(x)\n",
-        "        y_ = self.block_1(features)\n",
-        "\n",
-        "        cls_logits, bbox_pred, objectness_pred = self.prediction_head(y_)\n",
-        "        cls_preds = cls_logits.flatten(2, 3).permute(0, 2, 1)\n",
-        "        box_preds = bbox_pred.flatten(2, 3).permute(0, 2, 1)\n",
-        "        obj_preds = objectness_pred.flatten(2, 3).permute(0, 2, 1)\n",
-        "\n",
-        "        cls_logits2, bbox_pred2, objectness_pred2 = self.prediction_head_2(features)\n",
-        "        cls_preds2 = cls_logits2.flatten(2, 3).permute(0, 2, 1)\n",
-        "        box_preds2 = bbox_pred2.flatten(2, 3).permute(0, 2, 1)\n",
-        "        obj_preds2 = objectness_pred2.flatten(2, 3).permute(0, 2, 1)\n",
-        "\n",
-        "        # Concatenate predictions from all heads\n",
-        "        pred_logits = torch.cat([cls_preds, cls_preds2], dim=1)\n",
-        "        pred_boxes = torch.cat([box_preds, box_preds2], dim=1)\n",
-        "        pred_objectness = torch.cat([obj_preds, obj_preds2], dim=1)\n",
-        "\n",
-        "        return {\"pred_logits\": pred_logits, \"pred_boxes\": pred_boxes, \"pred_objectness\": pred_objectness}"
-      ],
-      "metadata": {
-        "id": "Gj3Hxg9yk8Z3"
-      },
-      "execution_count": 9,
-      "outputs": []
-    },
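-    {
-      "cell_type": "markdown",
-      "source": [
-        "*Back-of-the-envelope check (added for clarity, assuming the 256x256 inputs used later in this notebook):* the truncated ResNet-34 backbone ends at layer3 (16x16 grid, 256 channels) and `block_1` is layer4 (8x8 grid, 512 channels); each `PredictionHead` downsamples once more, so the two heads predict on 8x8 and 4x4 grids respectively."
-      ],
-      "metadata": {}
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Illustrative arithmetic only (assumes 256x256 inputs).\n",
-        "fine_grid = 256 // 32    # head on the layer3 features (stride-16 backbone + stride-2 head)\n",
-        "coarse_grid = 256 // 64  # head on the layer4 features (stride-32 backbone + stride-2 head)\n",
-        "print(\"predictions per image:\", fine_grid**2 + coarse_grid**2)"
-      ],
-      "metadata": {},
-      "execution_count": null,
-      "outputs": []
-    },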
-    {
-      "cell_type": "code",
-      "source": [
-        "model = FruitDetectorMultiScale(n_classes=4)\n",
-        "model = model.to(device)"
-      ],
-      "metadata": {
-        "id": "2qsGhgi1DDM5",
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "outputId": "9fb34326-ebf9-4104-fbc6-448bd0d8d69c"
-      },
-      "execution_count": 10,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stderr",
-          "text": [
-            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
-            "  warnings.warn(\n",
-            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet34_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet34_Weights.DEFAULT` to get the most up-to-date weights.\n",
-            "  warnings.warn(msg)\n",
-            "Downloading: \"https://download.pytorch.org/models/resnet34-b627a593.pth\" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth\n",
-            "100%|██████████| 83.3M/83.3M [00:00<00:00, 141MB/s]\n"
-          ]
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "best_weights_path = os.path.join(data_dir, 'experiment_6', 'best_weights_ex_6.params')\n",
-        "best_weights_path"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 35
-        },
-        "id": "YLz3mTXz6aCB",
-        "outputId": "5c2ce5a0-9099-4ee5-82a5-7c9a57db3eb6"
-      },
-      "execution_count": 11,
-      "outputs": [
-        {
-          "output_type": "execute_result",
-          "data": {
-            "text/plain": [
-              "'deeplearning-mbe-8/experiment_6/best_weights_ex_6.params'"
-            ],
-            "application/vnd.google.colaboratory.intrinsic+json": {
-              "type": "string"
-            }
-          },
-          "metadata": {},
-          "execution_count": 11
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Load the weights\n",
-        "model.load_state_dict(torch.load(best_weights_path))\n",
-        "# Set the model to evaluation mode\n",
-        "model.eval()\n",
-        "print(\"Pre-trained model loaded successfully!\")"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "mGTtzpt-7nUh",
-        "outputId": "374e893e-6247-4e75-f8ef-a9e6d55c1b99"
-      },
-      "execution_count": 12,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "Pre-trained model loaded successfully!\n"
-          ]
-        },
-        {
-          "output_type": "stream",
-          "name": "stderr",
-          "text": [
-            "<ipython-input-12-24cc7d2a743b>:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
-            "  model.load_state_dict(torch.load(best_weights_path))\n"
-          ]
-        }
-      ]
-    },
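-    {
-      "cell_type": "markdown",
-      "source": [
-        "*Optional sanity check (not required for the submission):* forward a random image-sized tensor through the loaded model to confirm the output keys and per-image prediction count. The 256x256 size is an assumption matching the `image_size` used below."
-      ],
-      "metadata": {}
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Sketch only: verify the output dictionary of the loaded model on a random input.\n",
-        "with torch.no_grad():\n",
-        "    dummy_out = model(torch.randn(1, 3, 256, 256, device=device))\n",
-        "for k, v in dummy_out.items():\n",
-        "    print(k, tuple(v.shape))"
-      ],
-      "metadata": {},
-      "execution_count": null,
-      "outputs": []
-    },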
-    {
-      "cell_type": "markdown",
-      "source": [
-        "#### Load test set"
-      ],
-      "metadata": {
-        "id": "wh3aRZlA0dcK"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "class FruitDetectionTestDataset(torch.utils.data.Dataset):\n",
-        "    def __init__(self, img_folder, transforms=None):\n",
-        "        \"\"\"\n",
-        "        Constructor for the test dataset\n",
-        "        :param img_folder: Folder containing the test images\n",
-        "        :param transforms: List of transformations to be applied to the data\n",
-        "        \"\"\"\n",
-        "        self.img_folder = img_folder\n",
-        "        self.transforms = transforms\n",
-        "\n",
-        "        # Create a list of image paths\n",
-        "        self.img_files = [os.path.join(img_folder, f) for f in sorted(os.listdir(img_folder))\n",
-        "                         if f.lower().endswith(('.png'))]\n",
-        "\n",
-        "    def __getitem__(self, idx):\n",
-        "        # Get image path\n",
-        "        img_path = self.img_files[idx]\n",
-        "\n",
-        "        # Load image\n",
-        "        img = Image.open(img_path).convert('RGB')  # PIL Image format\n",
-        "        #img = transforms.ToTensor()(img)  # Convert PIL to Tensor\n",
-        "\n",
-        "        # Apply transforms if provided\n",
-        "        if self.transforms:\n",
-        "            img = self.transforms(img)\n",
-        "\n",
-        "        # Return the image and its path for reference\n",
-        "        return img, img_path\n",
-        "\n",
-        "    def __len__(self):\n",
-        "        return len(self.img_files)"
-      ],
-      "metadata": {
-        "id": "nVdoDQxiIDgw"
-      },
-      "execution_count": 13,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "test_dir = os.path.join(data_dir, 'test')\n",
-        "image_size = 256\n",
-        "\n",
-        "test_dataset = FruitDetectionTestDataset(\n",
-        "    img_folder=test_dir,\n",
-        "    transforms=None\n",
-        ")"
-      ],
-      "metadata": {
-        "id": "fJdxEB2vJ5lg"
-      },
-      "execution_count": 18,
-      "outputs": []
-    },
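-    {
-      "cell_type": "markdown",
-      "source": [
-        "*Quick check (illustrative only):* how many test images were found and what a single dataset item looks like. Since no transforms were passed above, each item is a PIL image plus its file path."
-      ],
-      "metadata": {}
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Sketch only: peek at the first item of the test dataset.\n",
-        "print(\"number of test images:\", len(test_dataset))\n",
-        "sample_img, sample_path = test_dataset[0]\n",
-        "print(sample_path, sample_img.size)"
-      ],
-      "metadata": {},
-      "execution_count": null,
-      "outputs": []
-    },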
-    {
-      "cell_type": "markdown",
-      "source": [
-        "#### Make inference on test set"
-      ],
-      "metadata": {
-        "id": "1Rpt14RpgJvC"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "def predict_test_submission(model, test_dataset, output_txt_path, img_size, score_threshold=0.3, nms_threshold=0.3):\n",
-        "    \"\"\"\n",
-        "    Predict bounding boxes for all images in the test dataset and save to a text file:\n",
-        "      file_name, cx, cy, w, h, label, score\n",
-        "    \"\"\"\n",
-        "    model.eval()\n",
-        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-        "    model.to(device)\n",
-        "\n",
-        "    # Define standard image preprocessing\n",
-        "    transform = transforms.Compose([\n",
-        "        transforms.Resize((img_size, img_size)),\n",
-        "        transforms.ToTensor(),\n",
-        "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
-        "    ])\n",
-        "\n",
-        "    with open(output_txt_path, \"w\") as out_file:\n",
-        "        for i in range(len(test_dataset)):\n",
-        "            img_path = test_dataset.img_files[i]\n",
-        "            pil_img = Image.open(img_path).convert(\"RGB\")\n",
-        "            tensor_img = transform(pil_img).unsqueeze(0).to(device)\n",
-        "            boxes, scores, top_class = predict(model, tensor_img, score_threshold, nms_threshold)\n",
-        "\n",
-        "            # Write results\n",
-        "            file_name = os.path.basename(img_path)\n",
-        "            for box, lbl, sc in zip(boxes, top_class, scores):\n",
-        "                box = box / img_size\n",
-        "                cx, cy, w, h = box.tolist()\n",
-        "                out_file.write(f\"{file_name}, {cx}, {cy}, {w}, {h}, {lbl.item()+1}, {sc.item()}\\n\") # submission expects 1-based indexing for labels\n",
-        "\n",
-        "    print(f\"Predictions saved to {output_txt_path}\")"
-      ],
-      "metadata": {
-        "id": "cugazvv0HfTW"
-      },
-      "execution_count": 30,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Load the Drive helper and mount\n",
-        "from google.colab import drive\n",
-        "drive.mount('/content/drive', force_remount=True)"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "outputId": "cb0af785-67d9-47e0-face-5c9a344ca9dc",
-        "id": "ZQ2hXFQ6HSgp"
-      },
-      "execution_count": 22,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "Mounted at /content/drive\n"
-          ]
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "output_dir = \"/content/drive/My Drive\"\n",
-        "os.makedirs(output_dir, exist_ok=True)\n",
-        "\n",
-        "output_txt_path = os.path.join(output_dir, \"output_results_ex_6.txt\")\n",
-        "output_txt_path"
-      ],
-      "metadata": {
-        "id": "Ax-_Fpk2GUEm",
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 35
-        },
-        "outputId": "6e7d4208-43e0-4314-f433-fbb14547d080"
-      },
-      "execution_count": 34,
-      "outputs": [
-        {
-          "output_type": "execute_result",
-          "data": {
-            "text/plain": [
-              "'/content/drive/My Drive/output_results_ex_6.txt'"
-            ],
-            "application/vnd.google.colaboratory.intrinsic+json": {
-              "type": "string"
-            }
-          },
-          "metadata": {},
-          "execution_count": 34
-        }
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "predict_test_submission(model, test_dataset, output_txt_path, image_size, score_threshold=0.5, nms_threshold=0.5)"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "outputId": "b74055e5-65f0-4910-dca0-9c25ed819c26",
-        "id": "mD-J0Do1Lxrf"
-      },
-      "execution_count": 35,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "Predictions saved to /content/drive/My Drive/output_results_ex_6.txt\n"
-          ]
-        }
-      ]
-    },
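-    {
-      "cell_type": "markdown",
-      "source": [
-        "*Final check (optional sketch):* read back the first few lines of the saved file to confirm the `file_name, cx, cy, w, h, label, score` format expected for the submission."
-      ],
-      "metadata": {}
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "# Sketch only: inspect the first few lines of the submission file written above.\n",
-        "with open(output_txt_path) as f:\n",
-        "    for line in list(f)[:5]:\n",
-        "        print(line.strip())"
-      ],
-      "metadata": {},
-      "execution_count": null,
-      "outputs": []
-    },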
-    {
-      "cell_type": "markdown",
-      "source": [
-        "# Appendix: object_detection_utils"
-      ],
-      "metadata": {
-        "id": "-Nby7mGQBI0c"
-      }
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "def box_iou(boxes1, boxes2):\n",
-        "    \"\"\"\n",
-        "    Compute the Intersection over Union (IoU) between two sets of bounding boxes. The format of the bounding boxes\n",
-        "    should be in (x1, y1, x2, y2) format.\n",
-        "\n",
-        "    Args:\n",
-        "        boxes1 (torch.Tensor): A tensor of shape (N, 4) in (x1, y1, x2, y2) format.\n",
-        "        boxes2 (torch.Tensor): A tensor of shape (M, 4) in (x1, y1, x2, y2) format.\n",
-        "\n",
-        "    Returns:\n",
-        "        tuple[torch.Tensor, torch.Tensor]: A tuple of (iou, union) where:\n",
-        "            iou (torch.Tensor): A tensor of shape (N, M) containing the pairwise IoU values\n",
-        "            between the boxes in boxes1 and boxes2.\n",
-        "            union (torch.Tensor): A tensor of shape (N, M) containing the pairwise union\n",
-        "            areas between the boxes in boxes1 and boxes2.\n",
-        "    \"\"\"\n",
-        "    # Calculate boxes area\n",
-        "    area1 = box_area(boxes1)  # [N,]\n",
-        "    area2 = box_area(boxes2)  # [M,]\n",
-        "\n",
-        "    # Compute the coordinates of the intersection of each pair of bounding boxes\n",
-        "    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]\n",
-        "    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]\n",
-        "    # Need clamp(min=0) in case they do not intersect, then we want intersection to be 0\n",
-        "    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n",
-        "    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]\n",
-        "\n",
-        "    # Since the size of the variables is different, pytorch broadcast them\n",
-        "    # area1[:, None] converts size from [N,] to [N,1] to help broadcasting\n",
-        "    union = area1[:, None] + area2 - inter  # [N,M]\n",
-        "\n",
-        "    iou = inter / union\n",
-        "    return iou, union\n",
-        "\n",
-        "\n",
-        "def generalized_box_iou(boxes1, boxes2):\n",
-        "    \"\"\"\n",
-        "    Computes the generalized box intersection over union (IoU) between two sets of bounding boxes.\n",
-        "    The IoU is defined as the area of overlap between the two bounding boxes divided by the area of union.\n",
-        "\n",
-        "    Args:\n",
-        "        boxes1: A tensor containing the coordinates of the bounding boxes for the first set.\n",
-        "            Shape: [batch_size, num_boxes, 4]\n",
-        "        boxes2: A tensor containing the coordinates of the bounding boxes for the second set.\n",
-        "            Shape: [batch_size, num_boxes, 4]\n",
-        "\n",
-        "    Returns:\n",
-        "        A tensor containing the generalized IoU between `boxes1` and `boxes2`.\n",
-        "            Shape: [batch_size, num_boxes1, num_boxes2]\n",
-        "    \"\"\"\n",
-        "    # Check for degenerate boxes that give Inf/NaN results\n",
-        "    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()\n",
-        "    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()\n",
-        "\n",
-        "    # Calculate the IoU and union of each pair of bounding boxes\n",
-        "    # TODO: put your code here (~1 line)\n",
-        "    iou, union = box_iou(boxes1, boxes2)\n",
-        "\n",
-        "    # Compute the coordinates of the intersection of each pair of bounding boxes\n",
-        "    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])\n",
-        "    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])\n",
-        "    wh = (rb - lt).clamp(min=0)  # [N,M,2]\n",
-        "\n",
-        "    # Compute the area of the bounding box that encloses both input boxes\n",
-        "    C = wh[:, :, 0] * wh[:, :, 1]\n",
-        "\n",
-        "    # TODO: put your code here (~1 line)\n",
-        "    return iou - (C - union) / C\n",
-        "\n",
-        "def get_src_permutation_idx(indices):\n",
-        "    # permute predictions following indices\n",
-        "    batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])\n",
-        "    src_idx = torch.cat([src for (src, _) in indices])\n",
-        "    return batch_idx, src_idx\n",
-        "\n",
-        "\n",
-        "class APCalculator:\n",
-        "    \"\"\"A class for calculating average precision (AP).\n",
-        "\n",
-        "    This class is built to be used in a training loop, allowing the ground truth (GTs)\n",
-        "    to be initialized once in the __init__ constructor, and then reused many times by calling\n",
-        "    the `calculate_map()` method.\n",
-        "\n",
-        "    Attributes:\n",
-        "        iou_threshold (float): The intersection over union (IoU) threshold used for the AP calculation. Defaults to 0.5.\n",
-        "        data_iter (torch.utils.data.DataLoader): A PyTorch dataloader that provides images and targets (e.g., bounding boxes)\n",
-        "            for the dataset.\n",
-        "        n_classes (int): The number of object classes.\n",
-        "        metric (MeanAveragePrecision): An instance of the MeanAveragePrecision class from torchmetrics to compute mAP.\n",
-        "\n",
-        "    Args:\n",
-        "        data_iter (torch.utils.data.DataLoader): A PyTorch dataloader that provides images and targets (e.g., bounding boxes)\n",
-        "            for the dataset.\n",
-        "        n_classes (int): The number of object classes.\n",
-        "        iou_threshold (float, optional): The intersection over union (IoU) threshold used for the AP calculation.\n",
-        "            Defaults to 0.5.\n",
-        "    \"\"\"\n",
-        "\n",
-        "    def __init__(self, data_iter):\n",
-        "        \"\"\"Initializes the APCalculator object with the specified data iterator, number of classes,\n",
-        "        and IoU threshold.\"\"\"\n",
-        "        self.data_iter = data_iter\n",
-        "        self.metric = MeanAveragePrecision(iou_type=\"bbox\", box_format=\"cxcywh\", class_metrics=True)\n",
-        "\n",
-        "    def calculate_map(self, net, nms_threshold=0.1): # Modified\n",
-        "        \"\"\"Calculates the mean average precision (mAP) for the given object detection network.\n",
-        "\n",
-        "        Args:\n",
-        "            net (torch.nn.Module): The object detection network.\n",
-        "            nms_threshold (float, optional): The non-maximum suppression (NMS) threshold. Defaults to 0.1.\n",
-        "\n",
-        "        Returns:\n",
-        "            dict: A dictionary containing the mAP and other related metrics.\n",
-        "        \"\"\"\n",
-        "        net.eval()\n",
-        "        for i, (images, targets) in enumerate(self.data_iter):\n",
-        "            preds = []\n",
-        "            GTs = []\n",
-        "            new_targets = []\n",
-        "            \"\"\"\n",
-        "            for idx in range(targets[\"labels\"].shape[0]):\n",
-        "                labels = targets[\"labels\"][idx]\n",
-        "                boxes = targets[\"boxes\"][idx]\n",
-        "                new_targets.append(\n",
-        "                    {\n",
-        "                        \"labels\": labels[labels != -1].cpu().detach(),\n",
-        "                        \"boxes\": boxes[labels != -1].cpu().detach(),\n",
-        "                    }\n",
-        "                )\n",
-        "            \"\"\"\n",
-        "            # Process each item in the batch\n",
-        "            for idx in range(len(targets)):\n",
-        "                labels = targets[idx][\"labels\"]\n",
-        "                boxes = targets[idx][\"boxes\"]\n",
-        "                new_targets.append(\n",
-        "                    {\n",
-        "                        \"labels\": labels[labels != -1].cpu().detach(),\n",
-        "                        \"boxes\": boxes[labels != -1].cpu().detach(),\n",
-        "                    }\n",
-        "                )\n",
-        "\n",
-        "            for j in range(images.shape[0]):\n",
-        "                GTs.append(\n",
-        "                    {\n",
-        "                        \"boxes\": new_targets[j][\"boxes\"],\n",
-        "                        \"labels\": new_targets[j][\"labels\"],\n",
-        "                    }\n",
-        "                )\n",
-        "            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-        "            images = images.to(device)\n",
-        "            outputs = net(images)\n",
-        "            outputs[\"pred_logits\"] = outputs[\"pred_logits\"].cpu().detach()  # [Bs, N, C]\n",
-        "            outputs[\"pred_boxes\"] = outputs[\"pred_boxes\"].cpu().detach()  # [Bs, N, 4]\n",
-        "            outputs[\"pred_objectness\"] = outputs[\"pred_objectness\"].cpu().detach()  # [Bs, N, 1]\n",
-        "            for j in range(images.shape[0]):\n",
-        "                prob = F.softmax(outputs[\"pred_logits\"][j], dim=1)\n",
-        "                top_p, top_class = prob.topk(1, dim=1)\n",
-        "                boxes = outputs[\"pred_boxes\"][j]\n",
-        "                scores = top_p.squeeze()\n",
-        "                top_class = top_class.squeeze()\n",
-        "                scores = outputs[\"pred_objectness\"][j].squeeze()\n",
-        "                sel_boxes_idx = torchvision.ops.nms(\n",
-        "                    boxes=box_cxcywh_to_xyxy(boxes), scores=scores, iou_threshold=nms_threshold\n",
-        "                )\n",
-        "                preds.append(\n",
-        "                    {\n",
-        "                        \"boxes\": boxes[sel_boxes_idx],\n",
-        "                        \"scores\": scores[sel_boxes_idx],\n",
-        "                        \"labels\": top_class[sel_boxes_idx],\n",
-        "                    }\n",
-        "                )\n",
-        "            self.metric.update(preds, GTs)\n",
-        "        result = self.metric.compute()\n",
-        "        self.metric.reset()\n",
-        "        return result\n",
-        "\n",
-        "\n",
-        "def plot_bbox(img, boxes, labels):\n",
-        "    \"\"\"\n",
-        "    Plot bounding boxes on the given image with labels.\n",
-        "\n",
-        "    Bounding boxes are defined as tuples (x, y, w, h), where:\n",
-        "        - (x, y) are the center coordinates of the bounding box\n",
-        "        - w is the width of the bounding box\n",
-        "        - h is the height of the bounding box\n",
-        "\n",
-        "    The bounding boxes are drawn with a unique color for each label.\n",
-        "\n",
-        "    Args:\n",
-        "        img (PIL.Image): The image to plot bounding boxes on.\n",
-        "        boxes (List[Tuple[int, int, int, int]]): A list of bounding boxes.\n",
-        "        labels (List[int]): A list of labels corresponding to the bounding boxes.\n",
-        "\n",
-        "    Returns:\n",
-        "        A PIL.Image object representing the original image with the bounding boxes plotted on it.\n",
-        "    \"\"\"\n",
-        "    draw = ImageDraw.Draw(img)\n",
-        "    colors = [\"red\", \"blue\", \"green\", \"yellow\", \"purple\", \"orange\", \"pink\", \"cyan\", \"magenta\", \"lime\"]\n",
-        "    for box, label in zip(boxes, labels):\n",
-        "        x, y, w, h = box\n",
-        "        color = colors[label % len(colors)]\n",
-        "        draw.rectangle(\n",
-        "            (x - w / 2, y - h / 2, x + w / 2, y + h / 2),\n",
-        "            outline=color,\n",
-        "            width=3,\n",
-        "        )\n",
-        "        draw.text((x - w / 2, y - h / 2), str(label), fill=color)\n",
-        "    return img\n",
-        "\n",
-        "\n",
-        "# Well use this function later\n",
-        "def plot_grid(imgs, nrows, ncols):\n",
-        "    \"\"\"\n",
-        "    This function plots a grid of images using the given list of images.\n",
-        "    The grid has the specified number of rows and columns.\n",
-        "    The size of the figure is set to 10x10 inches.\n",
-        "    Returns None.\n",
-        "\n",
-        "    Parameters:\n",
-        "        - imgs (List[PIL.Image]): a list of PIL.Image objects to plot\n",
-        "        - nrows (int): the number of rows in the grid\n",
-        "        - ncols (int): the number of columns in the grid\n",
-        "\n",
-        "    Returns:\n",
-        "        None.\n",
-        "    \"\"\"\n",
-        "    assert len(imgs) == nrows * ncols, \"nrows*ncols must be equal to the number of images\"\n",
-        "    fig = plt.figure(figsize=(10.0, 10.0))\n",
-        "    grid = ImageGrid(\n",
-        "        fig,\n",
-        "        111,  # similar to subplot(111)\n",
-        "        nrows_ncols=(nrows, ncols),\n",
-        "        axes_pad=0.1,  # pad between axes in inch.\n",
-        "    )\n",
-        "    for ax, im in zip(grid, imgs):\n",
-        "        # Iterating over the grid returns the Axes.\n",
-        "        ax.imshow(im)\n",
-        "    plt.show()\n",
-        "\n",
-        "\n",
-        "def box_cxcywh_to_xyxy(x):\n",
-        "    \"\"\"\n",
-        "    Convert bounding boxes from (center x, center y, width, height) format to\n",
-        "    (x1, y1, x2, y2) format.\n",
-        "\n",
-        "    Args:\n",
-        "        x (torch.Tensor): A tensor of shape (N, 4) in (center x, center y,\n",
-        "            width, height) format.\n",
-        "\n",
-        "    Returns:\n",
-        "        torch.Tensor: A tensor of shape (N, 4) in (x1, y1, x2, y2) format.\n",
-        "    \"\"\"\n",
-        "    x_c, y_c, w, h = x.unbind(-1)\n",
-        "    x1 = x_c - 0.5 * w\n",
-        "    y1 = y_c - 0.5 * h\n",
-        "    x2 = x_c + 0.5 * w\n",
-        "    y2 = y_c + 0.5 * h\n",
-        "    b = torch.stack([x1, y1, x2, y2], dim=-1)\n",
-        "    return b\n",
-        "\n",
-        "\n",
-        "def box_xyxy_to_cxcywh(x):\n",
-        "    \"\"\"\n",
-        "    Convert bounding boxes from (x1, y1, x2, y2) format to (center_x, center_y, width, height)\n",
-        "    format.\n",
-        "\n",
-        "    Args:\n",
-        "        x (torch.Tensor): A tensor of shape (N, 4) in (x1, y1, x2, y2) format.\n",
-        "\n",
-        "    Returns:\n",
-        "        torch.Tensor: A tensor of shape (N, 4) in (center_x, center_y, width, height) format.\n",
-        "    \"\"\"\n",
-        "    x0, y0, x1, y1 = x.unbind(-1)\n",
-        "    center_x = (x0 + x1) / 2\n",
-        "    center_y = (y0 + y1) / 2\n",
-        "    width = x1 - x0\n",
-        "    height = y1 - y0\n",
-        "    b = torch.stack([center_x, center_y, width, height], dim=-1)\n",
-        "    return b\n",
-        "\n",
-        "\n",
-        "def box_xywh_to_xyxy(x): # Modified to handle list inputs\n",
-        "    \"\"\"\n",
-        "    Convert bounding box from (x, y, w, h) format to (x1, y1, x2, y2) format.\n",
-        "\n",
-        "    Args:\n",
-        "        x (torch.Tensor): A tensor of shape (N, 4) in (x, y, w, h) format.\n",
-        "\n",
-        "    Returns:\n",
-        "        torch.Tensor: A tensor of shape (N, 4) in (x1, y1, x2, y2) format.\n",
-        "    \"\"\"\n",
-        "    if not isinstance(x, torch.Tensor):\n",
-        "        x = torch.tensor(x, dtype=torch.float32)\n",
-        "    x_min, y_min, w, h = x.unbind(-1)\n",
-        "    x_max = x_min + w\n",
-        "    y_max = y_min + h\n",
-        "    b = torch.stack([x_min, y_min, x_max, y_max], dim=-1)\n",
-        "    return b\n",
-        "\n",
-        "\n",
-        "def predict(model, img, n_classes, nms_threshold=0.1, conf_threshold=0.25): # Modified\n",
-        "    model.eval()\n",
-        "    img_size = img.shape[-1]\n",
-        "    outputs = model(img.to(\"cuda\"))\n",
-        "    outputs[\"pred_logits\"] = outputs[\"pred_logits\"].cpu()  # [Bs, N, C]\n",
-        "    outputs[\"pred_boxes\"] = outputs[\"pred_boxes\"].cpu()  # [Bs, N, 4]\n",
-        "    outputs[\"pred_objectness\"] = outputs[\"pred_objectness\"].cpu()  # [Bs, N, 1]\n",
-        "    prob = F.softmax(outputs[\"pred_logits\"][0], dim=1)\n",
-        "    top_p, top_class = prob.topk(1, dim=1)\n",
-        "    boxes = outputs[\"pred_boxes\"][0]\n",
-        "    scores = top_p.squeeze()\n",
-        "    top_class = top_class.squeeze()\n",
-        "    keep = outputs[\"pred_objectness\"][0].squeeze() >= conf_threshold\n",
-        "    boxes = boxes[keep]\n",
-        "    scores = scores[keep]\n",
-        "    top_class = top_class[keep]\n",
-        "    if len(outputs[\"pred_logits\"]) == 0:\n",
-        "        return {\n",
-        "            \"boxes\": torch.tensor([]),\n",
-        "            \"scores\": torch.tensor([]),\n",
-        "            \"labels\": torch.tensor([]),\n",
-        "        }\n",
-        "    sel_boxes_idx = torchvision.ops.nms(boxes=box_cxcywh_to_xyxy(boxes), scores=scores, iou_threshold=nms_threshold)\n",
-        "    return boxes[sel_boxes_idx] * img_size, scores[sel_boxes_idx], top_class[sel_boxes_idx]\n",
-        "\n",
-        "\n",
-        "class ResizeWithBBox(object):\n",
-        "    \"\"\"\n",
-        "    Resizes an image and its corresponding bounding boxes.\n",
-        "    \"\"\"\n",
-        "\n",
-        "    def __init__(self, size):\n",
-        "        \"\"\"\n",
-        "        Initializes the transform.\n",
-        "\n",
-        "        Args:\n",
-        "            size: tuple, containing the new size of the image.\n",
-        "        \"\"\"\n",
-        "        self.size = size\n",
-        "\n",
-        "    def __call__(self, image, boxes):\n",
-        "        \"\"\"\n",
-        "        Applies the transform to an image and its corresponding bounding boxes.\n",
-        "\n",
-        "        Args:\n",
-        "            image: PIL.Image object, containing the original image.\n",
-        "            boxes: a list of bounding box coordinates in the format [cx, cy, width, height].\n",
-        "\n",
-        "        Returns:\n",
-        "        new_image: PIL.Image object, containing the resized image.\n",
-        "        new_boxes: the bounding box coordinates scaled to the new image size. Range [0-1]\n",
-        "        \"\"\"\n",
-        "\n",
-        "        width_scale = self.size[0] / image.size[0]\n",
-        "        height_scale = self.size[1] / image.size[1]\n",
-        "        new_image = image.resize(self.size)\n",
-        "\n",
-        "        new_boxes = []\n",
-        "        for box in boxes:\n",
-        "            x1, y1, w, h = box\n",
-        "            new_x1 = x1 * width_scale\n",
-        "            new_y1 = y1 * height_scale\n",
-        "            new_w = w * width_scale\n",
-        "            new_h = h * height_scale\n",
-        "            new_boxes.append([new_x1 / self.size[0], new_y1 / self.size[1], new_w / self.size[0], new_h / self.size[1]])\n",
-        "\n",
-        "        return new_image, new_boxes\n",
-        "\n",
-        "\n",
-        "class FileBasedAPCalculator:\n",
-        "    \"\"\"A class for calculating average precision (AP) from text files containing detections.\n",
-        "\n",
-        "    This class reads ground truth and prediction bounding boxes from text files and calculates\n",
-        "    the mean average precision (mAP).\n",
-        "\n",
-        "    Attributes:\n",
-        "        gt_file (str): Path to the ground truth file.\n",
-        "        pred_file (str): Path to the prediction file.\n",
-        "        metric (MeanAveragePrecision): An instance of MeanAveragePrecision class from torchmetrics.\n",
-        "\n",
-        "    Args:\n",
-        "        gt_file (str): Path to the ground truth file with format \"file_name, cx, cy, w, h, class_id\" per line.\n",
-        "        pred_file (str): Path to the prediction file with format \"file_name, cx, cy, w, h, class_id, score\" per line.\n",
-        "        box_format (str, optional): Format of the bounding boxes. Defaults to \"cxcywh\".\n",
-        "    \"\"\"\n",
-        "\n",
-        "    def __init__(self, gt_file, pred_file, box_format=\"cxcywh\"):\n",
-        "        \"\"\"Initializes the FileBasedAPCalculator object with the ground truth and prediction files.\"\"\"\n",
-        "        self.gt_file = gt_file\n",
-        "        self.pred_file = pred_file\n",
-        "        self.metric = MeanAveragePrecision(iou_type=\"bbox\", box_format=box_format, class_metrics=True)\n",
-        "\n",
-        "    def _parse_file(self, file_path, is_pred=False):\n",
-        "        \"\"\"Parses a text file containing bounding box information.\n",
-        "\n",
-        "        Args:\n",
-        "            file_path (str): Path to the file to parse.\n",
-        "            is_pred (bool, optional): Whether the file contains predictions (with confidence scores).\n",
-        "                                     Defaults to False.\n",
-        "\n",
-        "        Returns:\n",
-        "            dict: A dictionary mapping file names to lists of bounding boxes and labels.\n",
-        "        \"\"\"\n",
-        "        result = {}\n",
-        "        with open(file_path, \"r\") as f:\n",
-        "            for line in f:\n",
-        "                parts = line.strip().split(\",\")\n",
-        "\n",
-        "                # Skip empty lines or malformed entries\n",
-        "                if len(parts) < 6:\n",
-        "                    continue\n",
-        "\n",
-        "                file_name = parts[0].strip()\n",
-        "\n",
-        "                # Parse coordinates and convert to float\n",
-        "                cx, cy, w, h = map(float, parts[1:5])\n",
-        "                class_id = int(parts[5])\n",
-        "\n",
-        "                # Initialize entry for this file if it doesn't exist\n",
-        "                if file_name not in result:\n",
-        "                    if is_pred:\n",
-        "                        result[file_name] = {\"boxes\": [], \"labels\": [], \"scores\": []}\n",
-        "                    else:\n",
-        "                        result[file_name] = {\"boxes\": [], \"labels\": []}\n",
-        "\n",
-        "                # Add bounding box and label\n",
-        "                result[file_name][\"boxes\"].append([cx, cy, w, h])\n",
-        "                result[file_name][\"labels\"].append(class_id)\n",
-        "\n",
-        "                # Add score if it's a prediction file\n",
-        "                if is_pred and len(parts) > 6:\n",
-        "                    score = float(parts[6])\n",
-        "                    result[file_name][\"scores\"].append(score)\n",
-        "\n",
-        "        return result\n",
-        "\n",
-        "    def calculate_map(self):\n",
-        "        \"\"\"Calculates the mean average precision (mAP) using the ground truth and prediction files.\n",
-        "\n",
-        "        Returns:\n",
-        "            dict: A dictionary containing the mAP and other related metrics.\n",
-        "        \"\"\"\n",
-        "        # Parse the ground truth and prediction files\n",
-        "        gt_data = self._parse_file(self.gt_file)\n",
-        "        pred_data = self._parse_file(self.pred_file, is_pred=True)\n",
-        "\n",
-        "        # Convert the parsed data to the format expected by the metric\n",
-        "        for file_name in gt_data:\n",
-        "            GTs = []\n",
-        "            preds = []\n",
-        "\n",
-        "            # Convert ground truth to tensors\n",
-        "            if file_name in gt_data:\n",
-        "                gt_boxes = torch.tensor(gt_data[file_name][\"boxes\"], dtype=torch.float32)\n",
-        "                gt_labels = torch.tensor(gt_data[file_name][\"labels\"], dtype=torch.int64)\n",
-        "                GTs.append({\"boxes\": gt_boxes, \"labels\": gt_labels})\n",
-        "\n",
-        "            # Convert predictions to tensors if available\n",
-        "            if file_name in pred_data:\n",
-        "                pred_boxes = torch.tensor(pred_data[file_name][\"boxes\"], dtype=torch.float32)\n",
-        "                pred_labels = torch.tensor(pred_data[file_name][\"labels\"], dtype=torch.int64)\n",
-        "                pred_scores = torch.tensor(pred_data[file_name][\"scores\"], dtype=torch.float32)\n",
-        "                preds.append({\"boxes\": pred_boxes, \"scores\": pred_scores, \"labels\": pred_labels})\n",
-        "\n",
-        "            # Skip if we don't have both predictions and ground truth\n",
-        "            if not GTs or not preds:\n",
-        "                continue\n",
-        "\n",
-        "            # Update the metric\n",
-        "            self.metric.update(preds, GTs)\n",
-        "\n",
-        "        # Compute the results\n",
-        "        result = self.metric.compute()\n",
-        "        self.metric.reset()\n",
-        "        return result"
-      ],
-      "metadata": {
-        "id": "AMsY3AD87vVC"
-      },
-      "execution_count": 4,
-      "outputs": []
-    }
-  ]
-}
\ No newline at end of file