
Adding preprocessing script

Ali 3 years ago
parent
commit
74679678a5
4 changed files with 778 additions and 25 deletions
  1. 0_1_Drug_Embedding.ipynb (+378, -0)
  2. 0_Preprocessing.ipynb (+160, -14)
  3. 1_Prototype_MLP.ipynb (+52, -11)
  4. scripts/preprocessing.py (+188, -0)

0_1_Drug_Embedding.ipynb (+378, -0)

@@ -0,0 +1,378 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sqlite3\n",
+    "import pandas as pd\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sqlite connection\n",
+    "conn = sqlite3.connect(\"./data/mimic-iv.sqlite\")\n",
+    "\n",
+    "# Classification ATC des médicaments\n",
+    "drugs_rules = pd.read_csv(\"./config/atc_items.csv\")\n",
+    "drugs_rules_list = drugs_rules[\"gsn\"].drop_duplicates().astype(\"str\").tolist()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Récupération des codes\n",
+    "\n",
+    "drugs = pd.read_sql(f\"\"\"\n",
+    "    SELECT stay_id, gsn, etccode, 1 n\n",
+    "    FROM medrecon\n",
+    "    WHERE gsn IN ({','.join(drugs_rules_list)})\n",
+    "\"\"\", conn)\n",
+    "\n",
+    "# Liste des codes pour chaque séjour\n",
+    "stays_code = pd.merge(\n",
+    "    drugs,\n",
+    "    drugs_rules,\n",
+    "    left_on=\"gsn\",\n",
+    "    right_on=\"gsn\"\n",
+    ") \\\n",
+    " .reset_index()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "stays_code[\"ATC_4\"] = stays_code[\"atc\"]\n",
+    "stays_code[\"ATC_2\"] = stays_code[\"atc\"].str.slice(0,3)\n",
+    "stays_code[\"ETC\"] = stays_code[\"etccode\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Création de l'encodeur et des embeddings"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "variable = \"ETC\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 42,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.preprocessing import OrdinalEncoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "stays_code_dropped = stays_code.dropna(subset=[variable]).drop_duplicates([\"stay_id\", variable]).reset_index(drop=True)\n",
+    "stays_code_dropped = stays_code_dropped[[\"stay_id\", \"gsn\", variable]] \\\n",
+    "    .rename(columns={variable:\"code\"})\n",
+    "stays_code_dropped[\"code\"] = stays_code_dropped[\"code\"].astype(\"int\").astype(\"str\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 51,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Creation de l'encodeur\n",
+    "encoder = OrdinalEncoder().fit(stays_code_dropped[[\"code\"]])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 52,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Entrainement des embeddings\n",
+    "stays_code_dropped[\"code_id\"] = encoder.transform(stays_code_dropped[[\"code\"]]).astype(\"int32\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pair_matrix = pd.merge(\n",
+    "    stays_code_dropped[[\"stay_id\",\"gsn\", \"code_id\"]],\n",
+    "    stays_code_dropped[[\"stay_id\",\"gsn\", \"code_id\"]],\n",
+    "    left_on=\"stay_id\",\n",
+    "    right_on=\"stay_id\"\n",
+    ").query(\"gsn_x != gsn_y\")[[\"code_id_x\", \"code_id_y\"]]\n",
+    "\n",
+    "pair_matrix_probability = pair_matrix.assign(n = 1).groupby([\"code_id_x\", \"code_id_y\"]).sum() \\\n",
+    "           .reset_index() \\\n",
+    "           .join(\n",
+    "               pair_matrix.assign(n_total=1).groupby(\"code_id_x\")[\"n_total\"].sum(),\n",
+    "               on=\"code_id_x\"\n",
+    "           ) \\\n",
+    "           .assign(prob=lambda x: x[\"n\"]/x[\"n_total\"])[[\"code_id_x\", \"code_id_y\", \"prob\"]] \\\n",
+    "           .values"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 54,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch import nn, optim\n",
+    "from torch.utils.data import DataLoader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class embeddingTrainer (nn.Module):\n",
+    "    def __init__ (self, embedding_size=100):\n",
+    "        super().__init__()\n",
+    "\n",
+    "        # Le dernier index correspond au pad token\n",
+    "        self.embeddings = nn.Embedding(num_embeddings=encoder.categories_[0].shape[0]+1, embedding_dim=embedding_size)\n",
+    "\n",
+    "\n",
+    "        self.network = nn.Sequential(*[\n",
+    "            nn.Linear(embedding_size, 50),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(50, 200),\n",
+    "            nn.ReLU()\n",
+    "        ])\n",
+    "\n",
+    "        self.proba = nn.Sequential(*[\n",
+    "            nn.Linear(400, 200),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(200,50),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(50, 10),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(10,1),\n",
+    "            nn.Sigmoid()\n",
+    "        ])\n",
+    "\n",
+    "        self.loss = nn.BCELoss()\n",
+    "        self.optimizer = optim.Adam(self.parameters(), lr=5e-5)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "\n",
+    "        word_1 = x[:,0]\n",
+    "        word_2 = x[:,1]\n",
+    "\n",
+    "        embedding_1 = self.network(self.embeddings(word_1))\n",
+    "        embedding_2 = self.network(self.embeddings(word_2))\n",
+    "\n",
+    "        merged_data = torch.concat([embedding_1, embedding_2], axis=1)\n",
+    "\n",
+    "        y_hat = self.proba(merged_data)\n",
+    "\n",
+    "        return y_hat\n",
+    "    \n",
+    "    def fit(self, x, y):\n",
+    "\n",
+    "        self.train()\n",
+    "\n",
+    "        self.optimizer.zero_grad()\n",
+    "\n",
+    "        y_hat = self.forward(x)\n",
+    "        loss = self.loss(y_hat, y)\n",
+    "\n",
+    "        loss.backward()\n",
+    "\n",
+    "        self.optimizer.step()\n",
+    "\n",
+    "        loss_detach = loss.detach().cpu()\n",
+    "\n",
+    "        return loss_detach\n",
+    "    \n",
+    "    def predict(self, x):\n",
+    "\n",
+    "        self.eval()\n",
+    "        with torch.no_grad():\n",
+    "\n",
+    "            y_hat = self.forward(x)\n",
+    "\n",
+    "        return y_hat"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 56,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loader = DataLoader(pair_matrix_probability, shuffle=True, batch_size=1000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 57,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embedding_trainer = embeddingTrainer(embedding_size=100)\n",
+    "embedding_trainer = embedding_trainer.to(\"cuda:0\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 58,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 0 - Batch 0 - Loss : 0.7970905303955078\n",
+      "Epoch 0 - Loss : 0.45808276534080505\n",
+      "Epoch 1 - Batch 0 - Loss : 0.0421447679400444\n",
+      "Epoch 2 - Batch 0 - Loss : 0.025797121226787567\n",
+      "Epoch 3 - Batch 0 - Loss : 0.027662230655550957\n",
+      "Epoch 4 - Batch 0 - Loss : 0.02129991166293621\n",
+      "Epoch 5 - Batch 0 - Loss : 0.02649623528122902\n",
+      "Epoch 6 - Batch 0 - Loss : 0.025592397898435593\n",
+      "Epoch 7 - Batch 0 - Loss : 0.02580280229449272\n",
+      "Epoch 8 - Batch 0 - Loss : 0.0239135529845953\n",
+      "Epoch 9 - Batch 0 - Loss : 0.025206178426742554\n"
+     ]
+    }
+   ],
+   "source": [
+    "n_epoch = 10\n",
+    "\n",
+    "n_print_epoch = 10\n",
+    "n_print_batch = 1000\n",
+    "\n",
+    "for i in range(n_epoch):\n",
+    "    losses = []\n",
+    "\n",
+    "    j = 0\n",
+    "    for x in loader:\n",
+    "        x_batch = x[:,[0,1]].int()\n",
+    "        x_batch = x_batch.to(\"cuda:0\")\n",
+    "        y_batch = x[:,2].float().unsqueeze(dim=1)\n",
+    "        y_batch = y_batch.to(\"cuda:0\")\n",
+    "\n",
+    "        loss = embedding_trainer.fit(x_batch, y_batch)\n",
+    "        losses.append(loss)\n",
+    "\n",
+    "        if j%n_print_batch == 0:\n",
+    "            loss_mean = np.array(losses).mean()\n",
+    "            print(f\"Epoch {i} - Batch {j} - Loss : {loss_mean}\")\n",
+    "\n",
+    "        j += 1\n",
+    "\n",
+    "    if i%n_print_epoch == 0:\n",
+    "        loss_mean = np.array(losses).mean()\n",
+    "        print(f\"Epoch {i} - Loss : {loss_mean}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Export"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 59,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pickle"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Encoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 60,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(f\"./models/{variable}_encoder.model\",\"wb\") as f:\n",
+    "    pickle.dump(encoder, f)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Modele d'embedding"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 61,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open(f\"./models/{variable}_embedding.model\",\"wb\") as f:\n",
+    "    torch.save(embedding_trainer.embeddings, f)"
+   ]
+  }
+ ],
+ "metadata": {
+  "interpreter": {
+   "hash": "c304935560631f5a20c1bdabb506947800ccd82d813704000c078f0735b9b818"
+  },
+  "kernelspec": {
+   "display_name": "Python 3.9.9 ('base')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.9"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
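
The exported artifacts can be reloaded outside the notebook. A minimal sketch, assuming the "./models/" paths written above; the example ETC codes are hypothetical:

    import pickle
    import torch

    variable = "ETC"

    # Reload the fitted OrdinalEncoder
    with open(f"./models/{variable}_encoder.model", "rb") as f:
        encoder = pickle.load(f)

    # Reload the trained embedding layer (an nn.Embedding)
    with open(f"./models/{variable}_embedding.model", "rb") as f:
        embeddings = torch.load(f)

    # Codes must be strings already known to the encoder (hypothetical values)
    codes = [["6038"], ["4140"]]
    idx = torch.tensor(encoder.transform(codes)).long().squeeze(1)
    vectors = embeddings(idx)  # one 100-dimensional vector per code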

0_Preprocessing.ipynb (+160, -14)

@@ -39,7 +39,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -49,7 +49,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -67,7 +67,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -84,7 +84,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -112,7 +112,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -135,7 +135,153 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 29,
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Derniers diagnostic\n",
+    "from icdcodex import icd2vec, hierarchy\n",
+    "import numpy as np\n",
+    "\n",
+    "dernier_diag = pd.read_sql(f\"\"\"\n",
+    "    SELECT \n",
+    "        s1.stay_id,\n",
+    "        d.icd_code,\n",
+    "        d.icd_version,\n",
+    "        COUNT(1) n\n",
+    "    FROM edstays s1\n",
+    "    INNER JOIN diagnosis d\n",
+    "        ON d.subject_id = s1.subject_id\n",
+    "    INNER JOIN edstays s2\n",
+    "        ON d.stay_id = s2.stay_id\n",
+    "    WHERE \n",
+    "        s1.intime >= s2.intime\n",
+    "        AND s1.stay_id != s2.stay_id\n",
+    "    GROUP BY \n",
+    "        s1.stay_id,\n",
+    "        d.icd_code,\n",
+    "        d.icd_version\n",
+    "\"\"\", conn)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "embedder_icd9 = icd2vec.Icd2Vec(num_embedding_dimensions=10, workers=-1)\n",
+    "embedder_icd9.fit(*hierarchy.icd9())\n",
+    "\n",
+    "icd_9 = dernier_diag.query(\"icd_version == 9\")[\"icd_code\"]\n",
+    "\n",
+    "# Hotfix\n",
+    "icd_9 = icd_9.replace(\"E119\",\"E0119\")\n",
+    "icd_9 = icd_9[icd_9.isin(hierarchy.icd9()[1])].drop_duplicates()\n",
+    "\n",
+    "\n",
+    "icd_9_embedding = embedder_icd9.to_vec(icd_9)\n",
+    "\n",
+    "embedder_icd10 = icd2vec.Icd2Vec(num_embedding_dimensions=10, workers=-1)\n",
+    "embedder_icd10.fit(*hierarchy.icd10cm(version=\"2020\"))\n",
+    "\n",
+    "icd_10 = dernier_diag.query(\"icd_version == 10\")[\"icd_code\"]\n",
+    "icd_10 = icd_10.apply(lambda x: x[0:3]+\".\"+x[3:] if len(x) > 3 else x)\n",
+    "icd_10 = icd_10[icd_10.isin(hierarchy.icd10cm(version=\"2020\")[1])].drop_duplicates()\n",
+    "icd_10_embedding = embedder_icd10.to_vec(icd_10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 320,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "icd_series = pd.concat([\n",
+    "    \"ICD9_\"+icd_9,\n",
+    "    \"ICD10_\"+icd_10\n",
+    "]).reset_index(drop=True)\n",
+    "\n",
+    "dernier_diag_list = dernier_diag.assign(\n",
+    "    icd_str = lambda x: \"ICD\"+x[\"icd_version\"].astype(\"str\")+\"_\"+x[\"icd_code\"],\n",
+    ")[\"icd_str\"]\n",
+    "\n",
+    "icd_embeddings_matrix = np.concatenate([\n",
+    "    icd_9_embedding,\n",
+    "    icd_10_embedding\n",
+    "], axis=0)\n",
+    "\n",
+    "icd_to_idx = icd_series.reset_index().set_index(\"icd_code\").join(\n",
+    "    dernier_diag_list.drop_duplicates().reset_index().set_index(\"icd_str\").drop(columns=\"index\"),\n",
+    "    how=\"right\"\n",
+    ").fillna(pd.NA).astype(pd.Int64Dtype())[\"index\"].to_dict()\n",
+    "\n",
+    "dernier_diag[\"icd_idx\"] = dernier_diag_list.apply(lambda x: icd_to_idx[x])\n",
+    "dernier_diag_idx = dernier_diag.dropna().groupby(\"stay_id\")[\"icd_idx\"].agg(lambda x: x.tolist())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 324,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from torch.nn import Embedding\n",
+    "from torch.nn.utils.rnn import pad_sequence\n",
+    "import torch\n",
+    "\n",
+    "icd_embeddings_matrix_with_pad = np.concatenate([\n",
+    "    icd_embeddings_matrix,\n",
+    "    np.zeros((1, icd_embeddings_matrix.shape[1]))\n",
+    "])\n",
+    "\n",
+    "torch_embedding = Embedding(\n",
+    "    icd_embeddings_matrix_with_pad.shape[0],\n",
+    "    embedding_dim=10,\n",
+    "    _weight=torch.tensor(icd_embeddings_matrix_with_pad)\n",
+    ")\n",
+    "torch_embedding.requires_grad = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 325,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dernier_diag_idx_tensor = pad_sequence([torch.tensor(x) for x in dernier_diag_idx.tolist()],\n",
+    "             batch_first=True,\n",
+    "             padding_value=icd_embeddings_matrix_with_pad.shape[0]-1\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 361,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dernier_diag_idx_tensor_mask = (dernier_diag_idx_tensor != icd_embeddings_matrix_with_pad.shape[0]-1).unsqueeze(2)*1\n",
+    "dernier_diag_idx_tensor_embeddings = torch_embedding(dernier_diag_idx_tensor).sum(axis=1)/(dernier_diag_idx_tensor_mask.sum(axis=1)+1e-8)\n",
+    "dernier_diag_idx_tensor_embeddings = dernier_diag_idx_tensor_embeddings.detach().numpy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 371,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "stays = stays.join(\n",
+    "    pd.DataFrame(dernier_diag_idx_tensor_embeddings, index=dernier_diag_idx.index, columns=[\"diag_\"+str(x) for x in range(10)]),\n",
+    "    on = \"stay_id\",\n",
+    "    how=\"left\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -147,7 +293,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 30,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -158,7 +304,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -171,7 +317,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 32,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -189,7 +335,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -206,7 +352,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -225,7 +371,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -245,7 +391,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -280,7 +426,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [

File diff suppressed because it is too large
1_Prototype_MLP.ipynb (+52, -11)


scripts/preprocessing.py (+188, -0)

@@ -0,0 +1,188 @@
+"""
+    This script the preprocessing functions
+"""
+
+import sqlite3
+import pandas as pd
+
+def generate_features_dataset(database, get_drugs=True, get_diseases=True):
+
+    """
+        Generate the features dataset from the database
+
+        Parameters
+        ----------
+        database: str, path of the SQLite database file
+        get_drugs: boolean, if True the drug history is included
+        get_diseases: boolean, if True the disease history is included
+    """
+
+    to_merge = []
+
+    # SQLite connection
+    conn = sqlite3.connect(database)
+
+    ## Getting the features
+    features = pd.read_sql(f"""
+        SELECT 
+            s.stay_id,
+            s.intime intime,
+            p.gender gender,
+            p.anchor_age age,
+            t.temperature,
+            t.heartrate,
+            t.resprate,
+            t.o2sat,
+            t.sbp,
+            t.dbp,
+            t.pain,
+            t.chiefcomplaint
+        FROM edstays s
+        LEFT JOIN patients p
+            ON p.subject_id = s.subject_id
+        LEFT JOIN triage t
+            ON t.stay_id = s.stay_id
+    """, conn)
+
+    ## Additional features
+    ### Last visit
+    last_visit = pd.read_sql(f"""
+        SELECT DISTINCT
+            s1.stay_id,
+            CAST(MAX((julianday(s1.intime)-julianday(s2.intime))) <= 7 AS INT) last_7,
+            CAST(MAX((julianday(s1.intime)-julianday(s2.intime))) <= 30 AS INT) last_30
+        FROM edstays s1
+        INNER JOIN edstays s2
+            ON s1.subject_id = s2.subject_id
+                AND s1.stay_id != s2.stay_id
+                AND s1.intime >= s2.intime
+        WHERE (julianday(s1.intime)-julianday(s2.intime)) <= 30
+        GROUP BY s1.stay_id 
+    """, conn)
+    to_merge.append(last_visit)
+
+    ### Past diagnosis
+    if get_diseases:
+        past_diagnosis = pd.read_sql(f"""
+            SELECT 
+                s1.stay_id,
+                d.icd_code,
+                d.icd_version,
+                COUNT(1) n
+            FROM edstays s1
+            INNER JOIN diagnosis d
+                ON d.subject_id = s1.subject_id
+            INNER JOIN edstays s2
+                ON d.stay_id = s2.stay_id
+            WHERE 
+                s1.intime >= s2.intime
+                AND s1.stay_id != s2.stay_id
+            GROUP BY 
+                s1.stay_id,
+                d.icd_code,
+                d.icd_version
+        """, conn)
+        past_diagnosis = pd.pivot_table(
+            past_diagnosis.groupby(["stay_id","icd_version"])["icd_code"].agg(lambda x: x.tolist()) \
+                    .reset_index(),
+                index="stay_id",
+                columns="icd_version",
+                values="icd_code",
+                aggfunc=lambda x: x
+        ).reset_index().rename(columns={
+            9:"icd9",
+            10:"icd10"
+        })
+        to_merge.append(past_diagnosis)
+
+    ### Drugs
+    if get_drugs:
+        drugs = pd.read_sql(f"""
+            SELECT stay_id, gsn, 1 n
+            FROM medrecon
+        """, conn)
+        drugs = drugs.groupby("stay_id")["gsn"].agg(lambda x: x.tolist()).reset_index()
+        to_merge.append(drugs)
+
+    ### Merging all together
+    for df_to_merge in to_merge:
+        features = pd.merge(
+            features,
+            df_to_merge,
+            left_on="stay_id",
+            right_on="stay_id",
+            how="left"
+        )
+
+    features = features.sort_values("stay_id").reset_index(drop=True)
+
+    return features
+    
+
+def generate_labels_dataset(database, lab_dictionnary):
+
+    """
+        Generate the labels dataset from the database
+
+        Parameters
+        ----------
+        database: str, path of the SQLite database file
+        lab_dictionnary: dictionary mapping the itemid (keys) to the label (values) of the biological exams to predict
+    """
+
+    # SQLite connection
+    conn = sqlite3.connect(database)
+
+    # Getting biological values
+    lab_dictionnary_pd = pd.DataFrame.from_dict(lab_dictionnary, orient="index").reset_index()
+    lab_dictionnary_list = [str(x) for x in lab_dictionnary.keys()]
+
+    ## Let's create an index to speed up queries
+    conn.execute("CREATE INDEX IF NOT EXISTS biological_index ON labevents (stay_id, itemid)")
+
+    # 1. Generating features
+
+    ## Getting list of stay_id
+    stays = pd.read_sql(
+        "SELECT DISTINCT stay_id FROM edstays",
+        conn
+    )
+
+    ## Getting the features
+    labs = pd.read_sql(f"""
+        SELECT 
+            le.stay_id,
+            le.itemid item_id
+        FROM labevents le
+        WHERE le.itemid IN ('{"','".join(lab_dictionnary_list)}')
+        GROUP BY
+            le.stay_id,
+            le.itemid
+    """, conn)
+
+    labs_deduplicate = pd.merge(
+        lab_dictionnary_pd.rename(columns={0:"label"}),
+        labs,
+        left_on="index",
+        right_on="item_id"
+    ) \
+    .drop_duplicates(["stay_id", "label"])[["stay_id","label"]] \
+    .reset_index(drop=True)
+
+    labs_deduplicate_pivot = pd.pivot_table(
+        labs_deduplicate.assign(value=1),
+        index="stay_id",
+        columns="label",
+        values="value"
+    ).fillna(0)
+
+    labs_deduplicate_pivot_final = labs_deduplicate_pivot.join(
+        stays[["stay_id"]].set_index("stay_id"),
+        how="right"
+    ).fillna(0).astype("int8").reset_index()
+
+    labels = labs_deduplicate_pivot_final.sort_values("stay_id").reset_index(drop=True)
+
+    return labels

Some files were not shown because too many files changed in this diff