{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Urgences - Image CC0 - pexels.com](img/pexels-pixabay-263402.jpg \"Urgences\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Challenge - [ED Lab Prediction]\n",
    "_Nom à trouver_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Objectif"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ce notebook effectue le pre-processing des données.  \n",
    "Il exploite les données stockées dans la base sqlite, téléchargées à partir de `download_data.py`, et les exporte dans des fichiers parquet à destination de la data-visualisation et de l'entraînement de l'algorithme."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chargement des données"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sqlite3\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connection to the local SQLite database queried throughout this notebook\n",
    "conn = sqlite3.connect(\"./data/mimic-iv.sqlite\")\n",
    "\n",
    "# Lab item classification: rows with any missing value are discarded,\n",
    "# and the item ids are kept as strings to build SQL filters later on\n",
    "items = pd.read_csv(\"./config/lab_items.csv\").dropna()\n",
    "lab_item_ids = items[\"item_id\"]\n",
    "items_list = lab_item_ids.astype(\"str\").tolist()\n",
    "\n",
    "# ATC classification of drugs: one string per distinct gsn value\n",
    "drugs_rules = pd.read_csv(\"./config/atc_items.csv\")\n",
    "distinct_gsn = drugs_rules[\"gsn\"].drop_duplicates()\n",
    "drugs_rules_list = distinct_gsn.astype(\"str\").tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the (stay_id, itemid) index on labevents exists to speed up the queries\n",
    "conn.execute(\n",
    "    \"CREATE INDEX IF NOT EXISTS biological_index ON labevents (stay_id, itemid)\"\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stays dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per ED stay: arrival time, patient gender/age and triage measurements.\n",
    "# The LEFT JOINs keep stays that have no matching patient or triage row.\n",
    "stays_query = \"\"\"\n",
    "    SELECT \n",
    "        s.stay_id,\n",
    "        s.intime intime,\n",
    "        p.gender gender,\n",
    "        p.anchor_age age,\n",
    "        t.temperature,\n",
    "        t.heartrate,\n",
    "        t.resprate,\n",
    "        t.o2sat,\n",
    "        t.sbp,\n",
    "        t.dbp,\n",
    "        t.pain,\n",
    "        t.chiefcomplaint\n",
    "    FROM edstays s\n",
    "    LEFT JOIN patients p\n",
    "        ON p.subject_id = s.subject_id\n",
    "    LEFT Join triage t\n",
    "        ON t.stay_id = s.stay_id\n",
    "\"\"\"\n",
    "\n",
    "stays = pd.read_sql(sql=stays_query, con=conn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flag the stays preceded by another ED visit of the same patient within the last 7 / 30 days.\n",
    "#\n",
    "# The self-join pairs each stay (s1) with the other stays of the same patient (s2) that\n",
    "# started no later than s1 and at most 30 days before it. MIN() keeps the delay to the most\n",
    "# recent of those visits, so last_7 is set as soon as ANY earlier visit falls inside the\n",
    "# 7-day window (MAX() would test the oldest visit of the 30-day window instead).\n",
    "# Stays without any earlier visit in the window are absent from the result.\n",
    "derniers_passages = pd.read_sql(\"\"\"\n",
    "    SELECT\n",
    "        s1.stay_id,\n",
    "        CAST(MIN(julianday(s1.intime)-julianday(s2.intime)) <= 7 AS INT) last_7,\n",
    "        CAST(MIN(julianday(s1.intime)-julianday(s2.intime)) <= 30 AS INT) last_30\n",
    "    FROM edstays s1\n",
    "    INNER JOIN edstays s2\n",
    "        ON s1.subject_id = s2.subject_id\n",
    "            AND s1.stay_id != s2.stay_id\n",
    "            AND s1.intime >= s2.intime\n",
    "    WHERE (julianday(s1.intime)-julianday(s2.intime)) <= 30\n",
    "    GROUP BY s1.stay_id\n",
    "\"\"\", conn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Previous diagnoses\n",
    "from icdcodex import icd2vec, hierarchy\n",
    "import numpy as np\n",
    "\n",
    "# For each stay (s1), count the ICD codes recorded on the other stays (s2) of the same\n",
    "# patient that started no later than s1.\n",
    "dernier_diag_query = \"\"\"\n",
    "    SELECT \n",
    "        s1.stay_id,\n",
    "        d.icd_code,\n",
    "        d.icd_version,\n",
    "        COUNT(1) n\n",
    "    FROM edstays s1\n",
    "    INNER JOIN diagnosis d\n",
    "        ON d.subject_id = s1.subject_id\n",
    "    INNER JOIN edstays s2\n",
    "        ON d.stay_id = s2.stay_id\n",
    "    WHERE \n",
    "        s1.intime >= s2.intime\n",
    "        AND s1.stay_id != s2.stay_id\n",
    "    GROUP BY \n",
    "        s1.stay_id,\n",
    "        d.icd_code,\n",
    "        d.icd_version\n",
    "\"\"\"\n",
    "\n",
    "dernier_diag = pd.read_sql(sql=dernier_diag_query, con=conn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build each ICD hierarchy once and reuse it both to fit the embedder and to filter the\n",
    "# codes (the hierarchy helpers were previously called twice per ICD version).\n",
    "icd9_hierarchy = hierarchy.icd9()\n",
    "icd10_hierarchy = hierarchy.icd10cm(version=\"2020\")\n",
    "\n",
    "# --- ICD-9: 10-dimensional embeddings ---\n",
    "embedder_icd9 = icd2vec.Icd2Vec(num_embedding_dimensions=10, workers=-1)\n",
    "embedder_icd9.fit(*icd9_hierarchy)\n",
    "\n",
    "icd_9 = dernier_diag.query(\"icd_version == 9\")[\"icd_code\"]\n",
    "\n",
    "# Hotfix: code E119 is rewritten as E0119 before the lookup\n",
    "icd_9 = icd_9.replace(\"E119\",\"E0119\")\n",
    "# Keep only the distinct codes listed in the second element of the hierarchy tuple\n",
    "icd_9 = icd_9[icd_9.isin(icd9_hierarchy[1])].drop_duplicates()\n",
    "\n",
    "icd_9_embedding = embedder_icd9.to_vec(icd_9)\n",
    "\n",
    "# --- ICD-10: 10-dimensional embeddings ---\n",
    "embedder_icd10 = icd2vec.Icd2Vec(num_embedding_dimensions=10, workers=-1)\n",
    "embedder_icd10.fit(*icd10_hierarchy)\n",
    "\n",
    "icd_10 = dernier_diag.query(\"icd_version == 10\")[\"icd_code\"]\n",
    "# Insert a dot after the third character of the codes longer than three characters\n",
    "icd_10 = icd_10.apply(lambda x: x[0:3]+\".\"+x[3:] if len(x) > 3 else x)\n",
    "icd_10 = icd_10[icd_10.isin(icd10_hierarchy[1])].drop_duplicates()\n",
    "icd_10_embedding = embedder_icd10.to_vec(icd_10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 320,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary of embedded codes: the position of a key in icd_series is its row in\n",
    "# icd_embeddings_matrix (ICD-9 rows first, then ICD-10 rows).\n",
    "icd_series = pd.concat([\n",
    "    \"ICD9_\"+icd_9,\n",
    "    \"ICD10_\"+icd_10\n",
    "]).reset_index(drop=True)\n",
    "\n",
    "icd_embeddings_matrix = np.concatenate([\n",
    "    icd_9_embedding,\n",
    "    icd_10_embedding\n",
    "], axis=0)\n",
    "\n",
    "# Rewrite the raw codes the same way icd_9 / icd_10 were rewritten before embedding\n",
    "# (dot after the third character for ICD-10, E119 -> E0119 for ICD-9). Without this,\n",
    "# the keys built below never match the vocabulary for those codes and the\n",
    "# corresponding diagnoses are silently dropped.\n",
    "raw_codes = dernier_diag[\"icd_code\"]\n",
    "icd_versions = dernier_diag[\"icd_version\"]\n",
    "normalised_codes = raw_codes.mask(\n",
    "    (icd_versions == 10) & (raw_codes.str.len() > 3),\n",
    "    raw_codes.str[:3]+\".\"+raw_codes.str[3:]\n",
    ")\n",
    "normalised_codes = normalised_codes.mask(\n",
    "    (icd_versions == 9) & (raw_codes == \"E119\"),\n",
    "    \"E0119\"\n",
    ")\n",
    "\n",
    "# One key per diagnosis row, e.g. \"ICD9_<code>\" / \"ICD10_<code>\"\n",
    "dernier_diag_list = (\"ICD\"+icd_versions.astype(\"str\")+\"_\"+normalised_codes).rename(\"icd_str\")\n",
    "\n",
    "# Map every observed key to its row in the embedding matrix (pd.NA when the code has no embedding)\n",
    "vocabulary_positions = {key: position for position, key in enumerate(icd_series)}\n",
    "icd_to_idx = {\n",
    "    key: vocabulary_positions.get(key, pd.NA)\n",
    "    for key in dernier_diag_list.drop_duplicates()\n",
    "}\n",
    "\n",
    "dernier_diag[\"icd_idx\"] = dernier_diag_list.apply(lambda x: icd_to_idx[x])\n",
    "# Rows holding a missing value (codes without an embedding) are dropped;\n",
    "# each remaining stay gets the list of its embedding row indices\n",
    "dernier_diag_idx = dernier_diag.dropna().groupby(\"stay_id\")[\"icd_idx\"].agg(lambda x: x.tolist())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 324,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import Embedding\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "import torch\n",
    "\n",
    "# Append an all-zero row used as the padding vector: its index is the last row of the matrix\n",
    "icd_embeddings_matrix_with_pad = np.concatenate([\n",
    "    icd_embeddings_matrix,\n",
    "    np.zeros((1, icd_embeddings_matrix.shape[1]))\n",
    "])\n",
    "\n",
    "# Frozen lookup table built from the pre-computed vectors.\n",
    "# from_pretrained(freeze=True) disables gradients on the weight itself; assigning\n",
    "# `requires_grad` on the module only creates an unused attribute and freezes nothing.\n",
    "# The embedding size is taken from the matrix instead of being hard-coded.\n",
    "torch_embedding = Embedding.from_pretrained(\n",
    "    torch.tensor(icd_embeddings_matrix_with_pad),\n",
    "    freeze=True,\n",
    "    padding_idx=icd_embeddings_matrix_with_pad.shape[0]-1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-stay index lists into one (stays x longest list) tensor; shorter lists\n",
    "# are right-padded with the index of the last row of the padded embedding matrix\n",
    "dernier_diag_idx_tensor = pad_sequence(\n",
    "    [torch.tensor(stay_indices) for stay_indices in dernier_diag_idx.tolist()],\n",
    "    batch_first=True,\n",
    "    padding_value=icd_embeddings_matrix_with_pad.shape[0]-1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 361,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1 where a position holds a real code, 0 where it holds the padding index\n",
    "dernier_diag_idx_tensor_mask = (\n",
    "    dernier_diag_idx_tensor != icd_embeddings_matrix_with_pad.shape[0]-1\n",
    ").unsqueeze(2)*1\n",
    "\n",
    "# Sum of the code vectors of each stay divided by its number of real codes;\n",
    "# the small constant guards the division against a zero count\n",
    "diag_vector_sums = torch_embedding(dernier_diag_idx_tensor).sum(axis=1)\n",
    "diag_code_counts = dernier_diag_idx_tensor_mask.sum(axis=1)+1e-8\n",
    "dernier_diag_idx_tensor_embeddings = diag_vector_sums/diag_code_counts\n",
    "\n",
    "# Leave torch: plain numpy array for the pandas steps that follow\n",
    "dernier_diag_idx_tensor_embeddings = dernier_diag_idx_tensor_embeddings.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 371,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the pooled diagnosis embedding of each stay as columns diag_0, diag_1, ...\n",
    "# The number of columns follows the embedding width instead of being hard-coded.\n",
    "diag_columns = [\"diag_\"+str(x) for x in range(dernier_diag_idx_tensor_embeddings.shape[1])]\n",
    "diag_features = pd.DataFrame(\n",
    "    dernier_diag_idx_tensor_embeddings,\n",
    "    index=dernier_diag_idx.index,\n",
    "    columns=diag_columns\n",
    ")\n",
    "\n",
    "# Columns left by a previous run are dropped first: join() raises on overlapping\n",
    "# column names, which made this cell fail when executed a second time.\n",
    "stays = stays.drop(columns=diag_columns, errors=\"ignore\").join(\n",
    "    diag_features,\n",
    "    on = \"stay_id\",\n",
    "    how=\"left\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the last_7 / last_30 flags to the stays.\n",
    "passage_flags = derniers_passages.set_index(\"stay_id\")\n",
    "\n",
    "# Flags left by a previous run are dropped first: join() raises on overlapping\n",
    "# column names, which made this cell fail when executed a second time.\n",
    "stays = stays \\\n",
    "    .drop(columns=passage_flags.columns, errors=\"ignore\") \\\n",
    "    .join(passage_flags, on=\"stay_id\")\n",
    "\n",
    "# Stays absent from derniers_passages get 0 for both flags\n",
    "stays[\"last_7\"] = stays[\"last_7\"].fillna(0)\n",
    "stays[\"last_30\"] = stays[\"last_30\"].fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column typing:\n",
    "# - intime is parsed as a datetime\n",
    "# - gender becomes a pandas string column (author's note: no missing values in gender)\n",
    "# - a missing chiefcomplaint is treated as an empty chiefcomplaint\n",
    "stays = stays.assign(\n",
    "    intime=pd.to_datetime(stays[\"intime\"]),\n",
    "    gender=stays[\"gender\"].astype(\"string\"),\n",
    "    chiefcomplaint=stays[\"chiefcomplaint\"].fillna(\"\").astype(\"string\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Comma-separated gsn values taken from the ATC rules, used as the IN (...) filter\n",
    "gsn_filter = ','.join(drugs_rules_list)\n",
    "\n",
    "# medrecon rows restricted to those gsn values; n is a constant 1 summed later on\n",
    "drugs = pd.read_sql(f\"\"\"\n",
    "    SELECT stay_id, gsn, etccode, 1 n\n",
    "    FROM medrecon\n",
    "    WHERE gsn IN ({gsn_filter})\n",
    "\"\"\", conn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ATC codes of each stay: drug rows are matched to the ATC rules on gsn,\n",
    "# then n is summed per (stay, ATC code)\n",
    "drugs_with_atc = drugs.merge(drugs_rules, on=\"gsn\")\n",
    "atc_stays = drugs_with_atc.groupby([\"stay_id\",\"atc\"])[\"n\"].sum().reset_index()\n",
    "\n",
    "# Coarser code: first three characters of the ATC code\n",
    "atc_stays[\"atc_2\"] = atc_stays[\"atc\"].str[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two levels of granularity are considered\n",
    "## The complete ATC code (anatomical, therapeutic and pharmacological), ATC IV\n",
    "\n",
    "# One row per stay, one column per ATC code, 0 where the stay has no such code\n",
    "atc_4_counts = atc_stays[[\"stay_id\",\"atc\", \"n\"]]\n",
    "atc_stays_pivoted_4 = atc_4_counts.pivot_table(\n",
    "    index=[\"stay_id\"],\n",
    "    columns=[\"atc\"],\n",
    "    values=\"n\"\n",
    ")\n",
    "atc_stays_pivoted_4 = atc_stays_pivoted_4.fillna(0).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "## The ATC 2 code (anatomical and therapeutic)\n",
    "\n",
    "# Sum n per (stay, truncated code), exposed under the column name \"atc\"\n",
    "atc_2_counts = atc_stays[[\"stay_id\",\"atc_2\", \"n\"]] \\\n",
    "    .groupby([\"stay_id\",\"atc_2\"])[\"n\"].sum() \\\n",
    "    .reset_index() \\\n",
    "    .rename(columns={\"atc_2\":\"atc\"})\n",
    "\n",
    "# One row per stay, one column per truncated code, 0 where the stay has no such code\n",
    "atc_stays_pivoted_2 = atc_2_counts.pivot_table(\n",
    "    index=[\"stay_id\"],\n",
    "    columns=[\"atc\"],\n",
    "    values=\"n\"\n",
    ")\n",
    "atc_stays_pivoted_2 = atc_stays_pivoted_2.fillna(0).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "## The ETC codes\n",
    "\n",
    "# Rows with a missing value are discarded, the ETC code is rendered as an integer string,\n",
    "# then n is summed per (stay, ETC code) and exposed under the column name \"atc\"\n",
    "etc_counts = drugs[[\"stay_id\",\"etccode\", \"n\"]].dropna() \\\n",
    "    .assign(etccode = lambda x: x[\"etccode\"].astype(\"int\").astype(\"str\")) \\\n",
    "    .groupby([\"stay_id\",\"etccode\"])[\"n\"].sum() \\\n",
    "    .reset_index() \\\n",
    "    .rename(columns={\"etccode\":\"atc\"})\n",
    "\n",
    "# One row per stay, one column per ETC code, 0 where the stay has no such code\n",
    "etc_pivoted = etc_counts.pivot_table(\n",
    "    index=[\"stay_id\"],\n",
    "    columns=[\"atc\"],\n",
    "    values=\"n\"\n",
    ")\n",
    "etc_pivoted = etc_pivoted.fillna(0).reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _left_join_drug_counts(drug_counts):\n",
    "    \"\"\"Left-join a pivoted drug table onto `stays` and fill its columns with 0.\n",
    "\n",
    "    `drug_counts` holds `stay_id` as first column followed by one column per code.\n",
    "    Stays without any row in `drug_counts` get 0 in every code column.\n",
    "    \"\"\"\n",
    "    merged = stays.merge(drug_counts, on=\"stay_id\", how=\"left\")\n",
    "    code_columns = drug_counts.columns[1:]\n",
    "    merged[code_columns] = merged[code_columns].fillna(0)\n",
    "    return merged\n",
    "\n",
    "\n",
    "# One feature table per drug coding granularity\n",
    "stays_atc_4 = _left_join_drug_counts(atc_stays_pivoted_4)\n",
    "stays_atc_2 = _left_join_drug_counts(atc_stays_pivoted_2)\n",
    "stays_etc = _left_join_drug_counts(etc_pivoted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the feature datasets\n",
    "# Parquet is used to optimise storage and I/O time\n",
    "\n",
    "feature_exports = [\n",
    "    (stays_atc_2, \"./data/features_atc2.parquet\"),\n",
    "    (stays_atc_4, \"./data/features_atc4.parquet\"),\n",
    "    (stays_etc, \"./data/features_etc.parquet\"),\n",
    "]\n",
    "\n",
    "# Every table is written sorted by stay_id, with a fresh index that is not stored\n",
    "for features, parquet_path in feature_exports:\n",
    "    features.sort_values(\"stay_id\").reset_index(drop=True).to_parquet(parquet_path, engine=\"pyarrow\", index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lab dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quoted, comma-separated lab item ids used as the IN (...) filter.\n",
    "# Built outside the f-string so that no quote has to be nested inside the placeholder.\n",
    "quoted_item_ids = \"','\".join(items_list)\n",
    "\n",
    "# Distinct (stay, lab item) pairs restricted to the configured lab items\n",
    "labs = pd.read_sql(f\"\"\"\n",
    "    SELECT \n",
    "        le.stay_id,\n",
    "        le.itemid item_id\n",
    "    FROM labevents le\n",
    "    WHERE le.itemid IN ('{quoted_item_ids}')\n",
    "    GROUP BY\n",
    "        le.stay_id,\n",
    "        le.itemid\n",
    "\"\"\", conn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label of each lab item: column \"3\" of the item classification, renamed to \"label\"\n",
    "item_labels = items[[\"item_id\",\"3\"]].rename(columns={\"3\":\"label\"})\n",
    "\n",
    "# Replace item ids by their label and keep a single row per (stay, label)\n",
    "labs_with_labels = item_labels.merge(labs, on=\"item_id\")\n",
    "labs_deduplicate = (\n",
    "    labs_with_labels\n",
    "    .drop_duplicates([\"stay_id\", \"label\"])[[\"stay_id\",\"label\"]]\n",
    "    .reset_index(drop=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per stay, one column per label: 1 when the pair exists, 0 otherwise\n",
    "labs_flagged = labs_deduplicate.assign(value=1)\n",
    "labs_deduplicate_pivot = labs_flagged.pivot_table(\n",
    "    values=\"value\",\n",
    "    index=[\"stay_id\"],\n",
    "    columns=[\"label\"]\n",
    ")\n",
    "labs_deduplicate_pivot = labs_deduplicate_pivot.fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Right join on the stay ids of `stays`: stays absent from the pivot get 0 in every\n",
    "# label column; values are stored as int8 and stay_id goes back to a regular column\n",
    "stay_ids = stays[[\"stay_id\"]].set_index(\"stay_id\")\n",
    "labs_deduplicate_pivot_final = (\n",
    "    labs_deduplicate_pivot\n",
    "    .join(stay_ids, how=\"right\")\n",
    "    .fillna(0)\n",
    "    .astype(\"int8\")\n",
    "    .reset_index()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the labels sorted by stay_id, with a fresh index that is not stored\n",
    "labels_sorted = labs_deduplicate_pivot_final.sort_values(\"stay_id\").reset_index(drop=True)\n",
    "labels_sorted.to_parquet(\"./data/labels.parquet\", index=False)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "c304935560631f5a20c1bdabb506947800ccd82d813704000c078f0735b9b818"
  },
  "kernelspec": {
   "display_name": "R",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}