|
@@ -0,0 +1,378 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 22,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import sqlite3\n",
|
|
|
+ "import pandas as pd\n",
|
|
|
+ "import numpy as np"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 23,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Sqlite connection\n",
|
|
|
+ "conn = sqlite3.connect(\"./data/mimic-iv.sqlite\")\n",
|
|
|
+ "\n",
|
|
|
+ "# Classification ATC des médicaments\n",
|
|
|
+ "drugs_rules = pd.read_csv(\"./config/atc_items.csv\")\n",
|
|
|
+ "drugs_rules_list = drugs_rules[\"gsn\"].drop_duplicates().astype(\"str\").tolist()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 24,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Récupération des codes\n",
|
|
|
+ "\n",
|
|
|
+ "drugs = pd.read_sql(f\"\"\"\n",
|
|
|
+ " SELECT stay_id, gsn, etccode, 1 n\n",
|
|
|
+ " FROM medrecon\n",
|
|
|
+ " WHERE gsn IN ({','.join(drugs_rules_list)})\n",
|
|
|
+ "\"\"\", conn)\n",
|
|
|
+ "\n",
|
|
|
+ "# Liste des codes pour chaque séjour\n",
|
|
|
+ "stays_code = pd.merge(\n",
|
|
|
+ " drugs,\n",
|
|
|
+ " drugs_rules,\n",
|
|
|
+ " left_on=\"gsn\",\n",
|
|
|
+ " right_on=\"gsn\"\n",
|
|
|
+ ") \\\n",
|
|
|
+ " .reset_index()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 25,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "stays_code[\"ATC_4\"] = stays_code[\"atc\"]\n",
|
|
|
+ "stays_code[\"ATC_2\"] = stays_code[\"atc\"].str.slice(0,3)\n",
|
|
|
+ "stays_code[\"ETC\"] = stays_code[\"etccode\"]"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "# Création de l'encodeur et des embeddings"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 41,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "variable = \"ETC\""
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 42,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "from sklearn.preprocessing import OrdinalEncoder"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 43,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "stays_code_dropped = stays_code.dropna(subset=[variable]).drop_duplicates([\"stay_id\", variable]).reset_index(drop=True)\n",
|
|
|
+ "stays_code_dropped = stays_code_dropped[[\"stay_id\", \"gsn\", variable]] \\\n",
|
|
|
+ " .rename(columns={variable:\"code\"})\n",
|
|
|
+ "stays_code_dropped[\"code\"] = stays_code_dropped[\"code\"].astype(\"int\").astype(\"str\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 51,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Creation de l'encodeur\n",
|
|
|
+ "encoder = OrdinalEncoder().fit(stays_code_dropped[[\"code\"]])"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 52,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Entrainement des embeddings\n",
|
|
|
+ "stays_code_dropped[\"code_id\"] = encoder.transform(stays_code_dropped[[\"code\"]]).astype(\"int32\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 53,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "pair_matrix = pd.merge(\n",
|
|
|
+ " stays_code_dropped[[\"stay_id\",\"gsn\", \"code_id\"]],\n",
|
|
|
+ " stays_code_dropped[[\"stay_id\",\"gsn\", \"code_id\"]],\n",
|
|
|
+ " left_on=\"stay_id\",\n",
|
|
|
+ " right_on=\"stay_id\"\n",
|
|
|
+ ").query(\"gsn_x != gsn_y\")[[\"code_id_x\", \"code_id_y\"]]\n",
|
|
|
+ "\n",
|
|
|
+ "pair_matrix_probability = pair_matrix.assign(n = 1).groupby([\"code_id_x\", \"code_id_y\"]).sum() \\\n",
|
|
|
+ " .reset_index() \\\n",
|
|
|
+ " .join(\n",
|
|
|
+ " pair_matrix.assign(n_total=1).groupby(\"code_id_x\")[\"n_total\"].sum(),\n",
|
|
|
+ " on=\"code_id_x\"\n",
|
|
|
+ " ) \\\n",
|
|
|
+ " .assign(prob=lambda x: x[\"n\"]/x[\"n_total\"])[[\"code_id_x\", \"code_id_y\", \"prob\"]] \\\n",
|
|
|
+ " .values"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 54,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import torch\n",
|
|
|
+ "from torch import nn, optim\n",
|
|
|
+ "from torch.utils.data import DataLoader"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 55,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "class embeddingTrainer (nn.Module):\n",
|
|
|
+ " def __init__ (self, embedding_size=100):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ "\n",
|
|
|
+ " # Le dernier index correspond au pad token\n",
|
|
|
+ " self.embeddings = nn.Embedding(num_embeddings=encoder.categories_[0].shape[0]+1, embedding_dim=embedding_size)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " self.network = nn.Sequential(*[\n",
|
|
|
+ " nn.Linear(embedding_size, 50),\n",
|
|
|
+ " nn.ReLU(),\n",
|
|
|
+ " nn.Linear(50, 200),\n",
|
|
|
+ " nn.ReLU()\n",
|
|
|
+ " ])\n",
|
|
|
+ "\n",
|
|
|
+ " self.proba = nn.Sequential(*[\n",
|
|
|
+ " nn.Linear(400, 200),\n",
|
|
|
+ " nn.ReLU(),\n",
|
|
|
+ " nn.Linear(200,50),\n",
|
|
|
+ " nn.ReLU(),\n",
|
|
|
+ " nn.Linear(50, 10),\n",
|
|
|
+ " nn.ReLU(),\n",
|
|
|
+ " nn.Linear(10,1),\n",
|
|
|
+ " nn.Sigmoid()\n",
|
|
|
+ " ])\n",
|
|
|
+ "\n",
|
|
|
+ " self.loss = nn.BCELoss()\n",
|
|
|
+ " self.optimizer = optim.Adam(self.parameters(), lr=5e-5)\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, x):\n",
|
|
|
+ "\n",
|
|
|
+ " word_1 = x[:,0]\n",
|
|
|
+ " word_2 = x[:,1]\n",
|
|
|
+ "\n",
|
|
|
+ " embedding_1 = self.network(self.embeddings(word_1))\n",
|
|
|
+ " embedding_2 = self.network(self.embeddings(word_2))\n",
|
|
|
+ "\n",
|
|
|
+ " merged_data = torch.concat([embedding_1, embedding_2], axis=1)\n",
|
|
|
+ "\n",
|
|
|
+ " y_hat = self.proba(merged_data)\n",
|
|
|
+ "\n",
|
|
|
+ " return y_hat\n",
|
|
|
+ " \n",
|
|
|
+ " def fit(self, x, y):\n",
|
|
|
+ "\n",
|
|
|
+ " self.train()\n",
|
|
|
+ "\n",
|
|
|
+ " self.optimizer.zero_grad()\n",
|
|
|
+ "\n",
|
|
|
+ " y_hat = self.forward(x)\n",
|
|
|
+ " loss = self.loss(y_hat, y)\n",
|
|
|
+ "\n",
|
|
|
+ " loss.backward()\n",
|
|
|
+ "\n",
|
|
|
+ " self.optimizer.step()\n",
|
|
|
+ "\n",
|
|
|
+ " loss_detach = loss.detach().cpu()\n",
|
|
|
+ "\n",
|
|
|
+ " return loss_detach\n",
|
|
|
+ " \n",
|
|
|
+ " def predict(self, x):\n",
|
|
|
+ "\n",
|
|
|
+ " self.eval()\n",
|
|
|
+ " with torch.no_grad():\n",
|
|
|
+ "\n",
|
|
|
+ " y_hat = self.forward(x)\n",
|
|
|
+ "\n",
|
|
|
+ " return y_hat"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 56,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "loader = DataLoader(pair_matrix_probability, shuffle=True, batch_size=1000)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 57,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "embedding_trainer = embeddingTrainer(embedding_size=100)\n",
|
|
|
+ "embedding_trainer = embedding_trainer.to(\"cuda:0\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 58,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Epoch 0 - Batch 0 - Loss : 0.7970905303955078\n",
|
|
|
+ "Epoch 0 - Loss : 0.45808276534080505\n",
|
|
|
+ "Epoch 1 - Batch 0 - Loss : 0.0421447679400444\n",
|
|
|
+ "Epoch 2 - Batch 0 - Loss : 0.025797121226787567\n",
|
|
|
+ "Epoch 3 - Batch 0 - Loss : 0.027662230655550957\n",
|
|
|
+ "Epoch 4 - Batch 0 - Loss : 0.02129991166293621\n",
|
|
|
+ "Epoch 5 - Batch 0 - Loss : 0.02649623528122902\n",
|
|
|
+ "Epoch 6 - Batch 0 - Loss : 0.025592397898435593\n",
|
|
|
+ "Epoch 7 - Batch 0 - Loss : 0.02580280229449272\n",
|
|
|
+ "Epoch 8 - Batch 0 - Loss : 0.0239135529845953\n",
|
|
|
+ "Epoch 9 - Batch 0 - Loss : 0.025206178426742554\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "n_epoch = 10\n",
|
|
|
+ "\n",
|
|
|
+ "n_print_epoch = 10\n",
|
|
|
+ "n_print_batch = 1000\n",
|
|
|
+ "\n",
|
|
|
+ "for i in range(n_epoch):\n",
|
|
|
+ " losses = []\n",
|
|
|
+ "\n",
|
|
|
+ " j = 0\n",
|
|
|
+ " for x in loader:\n",
|
|
|
+ " x_batch = x[:,[0,1]].int()\n",
|
|
|
+ " x_batch = x_batch.to(\"cuda:0\")\n",
|
|
|
+ " y_batch = x[:,2].float().unsqueeze(dim=1)\n",
|
|
|
+ " y_batch = y_batch.to(\"cuda:0\")\n",
|
|
|
+ "\n",
|
|
|
+ " loss = embedding_trainer.fit(x_batch, y_batch)\n",
|
|
|
+ " losses.append(loss)\n",
|
|
|
+ "\n",
|
|
|
+ " if j%n_print_batch == 0:\n",
|
|
|
+ " loss_mean = np.array(losses).mean()\n",
|
|
|
+ " print(f\"Epoch {i} - Batch {j} - Loss : {loss_mean}\")\n",
|
|
|
+ "\n",
|
|
|
+ " j += 1\n",
|
|
|
+ "\n",
|
|
|
+ " if i%n_print_epoch == 0:\n",
|
|
|
+ " loss_mean = np.array(losses).mean()\n",
|
|
|
+ " print(f\"Epoch {i} - Loss : {loss_mean}\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "# Export"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 59,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import pickle"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Encoder"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 60,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "with open(f\"./models/{variable}_encoder.model\",\"wb\") as f:\n",
|
|
|
+ " pickle.dump(encoder, f)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## Modele d'embedding"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 61,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "with open(f\"./models/{variable}_embedding.model\",\"wb\") as f:\n",
|
|
|
+ " torch.save(embedding_trainer.embeddings, f)"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "interpreter": {
|
|
|
+ "hash": "c304935560631f5a20c1bdabb506947800ccd82d813704000c078f0735b9b818"
|
|
|
+ },
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3.9.9 ('base')",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.9.9"
|
|
|
+ },
|
|
|
+ "orig_nbformat": 4
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 2
|
|
|
+}
|