visualisation.py

from matplotlib import pyplot as plt
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import roc_curve, roc_auc_score, precision_score, recall_score, accuracy_score, ConfusionMatrixDisplay, f1_score, confusion_matrix
from wordcloud import WordCloud
import seaborn as sns
import itertools
import pandas as pd
import numpy as np
import random

from .preprocessing import get_Xy_df

def plot_all_scatter(X, variables, ncols=3, figsize=(20, 10)):
    """
    This function produces a scatter view of all the variables from a dataset

    Parameters
    ----------
    X: Pandas Dataframe
    variables: [str], list of variable names
    ncols: int, number of columns in the plot
    figsize: (int, int), tuple of the figure size
    """

    # Getting nrows
    nrows = (len(variables) // ncols) + 1 * ((len(variables) % ncols) != 0)

    figs, axs = plt.subplots(nrows, ncols, figsize=figsize)
    axs = axs.flatten()

    for i in range(len(variables)):
        variable = variables[i]

        sns.scatterplot(
            data=X[variable].value_counts(),
            ax=axs[i]
        )

        axs[i].ticklabel_format(style='scientific', axis='x', scilimits=(0, 4))
        axs[i].set_xlabel("Valeur")
        axs[i].set_ylabel("Nombre d'occurences")
        axs[i].set_title(variable)

    plt.tight_layout()
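
# Usage sketch (illustrative only): `X_features` and the column names below are
# placeholders, not objects defined in this module.
#
#   plot_all_scatter(X_features, ["variable_a", "variable_b"], ncols=2, figsize=(12, 5))
#   plt.show()
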
def plot_missing_outcome(X, y, features, labels, figsize=(20, 10)):
    """
    This function produces a line plot of the missing values according to the outcome values

    Parameters
    ----------
    X: Pandas Dataframe of features
    y: Pandas Dataframe of labels
    features: [str], list of variable names
    labels: [str], list of output names
    figsize: (int, int), tuple of the figure size
    """

    Xy = get_Xy_df(X, y)

    # For each row, count the number of missing features, then compute the
    # prescription rate of each label per missing-value count
    data = Xy[labels] \
        .join(pd.DataFrame(Xy[features].isna().astype("int").sum(axis=1))) \
        .rename(columns={0: "n_NA"}) \
        .groupby("n_NA") \
        .agg(lambda x: x.sum() / x.count())

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    sns.lineplot(
        data=pd.melt(data.reset_index(), id_vars="n_NA", value_vars=data.columns),
        hue="variable",
        x="n_NA",
        y="value",
        ax=ax
    )

    ax.set_xlabel("Nombre de valeurs manquantes")
    ax.set_ylabel("Pourcentage d'examen prescrit")
    ax.set_title("% de prescription de bilans en fonction du nombre de variables manquantes")
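
# Usage sketch (illustrative only): `X_features`, `y_labels` and the names below are
# placeholders; the label columns are expected to be binary (0/1) outcomes.
#
#   plot_missing_outcome(X_features, y_labels,
#                        features=["variable_a", "variable_b"],
#                        labels=["label_a", "label_b"], figsize=(15, 8))
#   plt.show()
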
def plot_missing_bar(X, features, figsize=(15, 10)):
    """
    This function produces a bar plot of the missing values

    Parameters
    ----------
    X: Pandas Dataframe of features
    features: [str], list of variable names
    figsize: (int, int), tuple of the figure size
    """

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Share of missing values per feature
    data = (X[features].isna() * 1).mean().reset_index()

    sns.barplot(
        data=data,
        x="index",
        y=0,
        ax=ax
    )

    ax.set_title("% de valeurs manquantes par variable")
    ax.set_xlabel("Variable")
    ax.set_ylabel("% de valeurs manquantes")
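
# Usage sketch (illustrative only): `X_features` and the feature names are placeholders.
#
#   plot_missing_bar(X_features, features=["variable_a", "variable_b", "variable_c"])
#   plt.show()
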
def plot_correlation(X, features, figsize=(10, 6)):
    """
    This function produces a heatmap plot of all variable correlation values

    Parameters
    ----------
    X: Pandas Dataframe of features
    features: [str], list of variable names
    figsize: (int, int), tuple of the figure size
    """

    fig, ax = plt.subplots(figsize=figsize)

    correlation_matrix = X[features].corr()

    sns.heatmap(
        correlation_matrix,
        cmap='YlGn',
        ax=ax
    )

    ax.set_title('Corrélations entre les features')
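
# Usage sketch (illustrative only): `X_features` and the feature names are placeholders;
# only numeric columns should be passed since a pairwise correlation matrix is computed.
#
#   plot_correlation(X_features, features=["variable_a", "variable_b"], figsize=(8, 6))
#   plt.show()
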
def plot_labels_frequencies_and_correlation(y, labels, figsize=(30, 10)):
    """
    This function produces a bar plot of label proportions and a heatmap plot of all label correlation values

    Parameters
    ----------
    y: Pandas Dataframe of labels
    labels: [str], list of label names
    figsize: (int, int), tuple of the figure size
    """

    fig, axs = plt.subplots(1, 2, figsize=figsize)
    axs = axs.flatten()

    # Plotting labels proportion
    labels_data = ((y[labels].sum() / y.shape[0]) * 100).reset_index().round(2)

    sns.barplot(
        data=labels_data,
        x="index",
        y=0,
        ax=axs[0]
    )

    axs[0].tick_params(labelrotation=45)
    axs[0].set_ylim(0, 100)
    axs[0].set_title("Proportion d'examens biologiques réalisés")
    axs[0].set_xlabel("Examens biologiques")
    axs[0].set_ylabel("% d'examens réalisés")

    # Plotting correlation
    correlation_data = y[labels].corr()
    sns.heatmap(correlation_data, ax=axs[1], cmap='YlGn')
    axs[1].set_title('Correlations entre les labels')
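
# Usage sketch (illustrative only): `y_labels` and the label names are placeholders;
# labels are expected to be binary (0/1) columns so that their mean reads as a proportion.
#
#   plot_labels_frequencies_and_correlation(y_labels, labels=["label_a", "label_b"])
#   plt.show()
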
def plot_box_variable_label_distribution(X, y, features, labels):
    """
    This function produces box plots of the feature distributions according to the label status

    Parameters
    ----------
    X: Pandas Dataframe of features
    y: Pandas Dataframe of labels
    features: [str], list of variable names
    labels: [str], list of output names
    """

    # Generating colormap
    colors = sns.color_palette("muted", 2 * len(features))

    # Getting Xy dataframe
    Xy = get_Xy_df(X, y)

    # One subfigure per label, one subplot per feature
    fig = plt.figure(constrained_layout=True, figsize=(5 * len(labels), 2 * len(features)))
    figs = fig.subfigures(len(labels), 1)
    axs = [x.subplots(1, len(features)) for x in figs]

    for i in range(len(labels)):
        figs[i].suptitle(f"Distribution des variables selon le statut {labels[i]} (réalisé (1) ou non (0))")

        for j in range(len(features)):
            feature_name, variable_name = features[j], labels[i]

            axs[i][j].set_title(feature_name)
            axs[i][j].set_xlabel(variable_name)
            axs[i][j].set_ylabel(feature_name)

            sns.boxplot(
                data=Xy,
                ax=axs[i][j],
                x=variable_name,
                y=feature_name,
                showfliers=False,
                palette=colors[j * 2:(j + 1) * 2]
            )

    fig.suptitle("Distribution des features en fonction du label")
    plt.show()
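
# Usage sketch (illustrative only): `X_features`, `y_labels` and the names below are
# placeholders; one subfigure of box plots is drawn per label, one box plot per feature.
#
#   plot_box_variable_label_distribution(X_features, y_labels,
#                                        features=["variable_a", "variable_b"],
#                                        labels=["label_a", "label_b"])
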
def plot_odd_word_wc(X, y, text_column, labels, min_occurrence=3, ncols=5):
    """
    This function produces a word cloud of word odds ratios (odds ratio of seeing the word given the label)

    Parameters
    ----------
    X: Pandas Dataframe of features
    y: Pandas Dataframe of labels
    text_column: str, name of the column containing the text
    labels: [str], list of output names
    min_occurrence: int, minimum number of occurrences of a word
    ncols: int, number of columns in the output plot
    """

    # Computing nrows and getting the structure
    nrows = len(labels) // ncols + 1 * ((len(labels) % ncols) != 0)

    fig = plt.figure(constrained_layout=True, figsize=(4 * ncols, 5 * nrows))
    figs = fig.subfigures(nrows, ncols)
    figs = figs.flatten()
    axs = [x.subplots(2, 1) for x in figs]

    # Color functions for the two word clouds (RGB components must stay within 0-255)
    def rand_color_label0(*args, **kwargs):
        return "rgb(0, 100, {})".format(random.randint(200, 255))

    def rand_color_label1(*args, **kwargs):
        return "rgb({}, 0, 100)".format(random.randint(200, 255))

    color_fn = [rand_color_label0, rand_color_label1]

    # Getting Xy
    Xy = get_Xy_df(X, y)

    # Text preprocessing
    Xy = Xy.dropna(subset=[text_column])
    Xy["text_preprocessed"] = Xy[text_column] \
        .str.replace(",", " ", regex=False) \
        .str.lower()

    # Generating the plots
    for i in range(len(labels)):
        label = labels[i]
        figs[i].suptitle(label)

        # Concatenating the preprocessed text of all rows sharing the same label value
        text_data = Xy[[label, "text_preprocessed"]].dropna().groupby(label).agg(lambda x: " ".join(x))["text_preprocessed"]

        # Training countvectorizer model then computing the odds
        cv = CountVectorizer().fit(text_data)
        text_data_array = (cv.transform(text_data).toarray() + 1)  # Smoothing count
        text_data_array[:, np.where(text_data_array <= (min_occurrence + 1))[1]] = 1  # Set rare words to a neutral odd
        text_data_array = text_data_array / text_data_array.sum(axis=1).reshape(2, -1)

        for j, text in text_data.items():
            values = (text_data_array[j, :] / (text_data_array[1 - j, :])).tolist()

            axs[i][j].imshow(
                WordCloud(background_color="white", relative_scaling=0.2, max_words=25, color_func=color_fn[j]).generate_from_frequencies(
                    frequencies=dict(zip(
                        cv.get_feature_names(),
                        values
                    ))
                )
            )
            axs[i][j].set_xlabel(f"{j}")

    fig.suptitle("WordCloud selon le label")
    plt.show()
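
# Usage sketch (illustrative only): `X_features`, `y_labels` and the label names are
# placeholders; "chiefcomplaint" is the free-text column, and each subfigure shows one
# word cloud per label value, weighted by the odds ratio of each word.
#
#   plot_odd_word_wc(X_features, y_labels, text_column="chiefcomplaint",
#                    labels=["label_a", "label_b"], min_occurrence=5, ncols=2)
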
def vizualize_features_selection(scores, score_name, f_precision=2, n_score_max=5, ncols=3):
    """
    This function produces a heatmap of the metric score according to each variable combination

    Parameters
    ----------
    scores: dict containing, for each label, a list of (feature combination, score) pairs produced by the .models.get_features_selection function
    score_name: str, name of the score
    f_precision: int, floating point precision, i.e. the number of decimals to keep
    n_score_max: int, maximum number of scores to display
    ncols: int, number of columns in the output plot
    """

    # Creating a dataframe containing the scores
    scores_df = []
    for key, value in scores.items():
        # One row per feature combination: feature columns are marked with "x", the score is kept aside
        scores_df_temp = pd.DataFrame(
            [dict(zip(x[0], [x[1] for i in range(len(x[0]))])) for x in value]
        ).assign(score=lambda x: x.max(axis=1))
        scores_df_temp.iloc[:, :-1] = (scores_df_temp.iloc[:, :-1].fillna("") * 0).astype("str").replace("0.0", "x")
        scores_df_temp["name"] = key
        scores_df.append(scores_df_temp.sort_values("score", ascending=False))

    scores_df = pd.concat(scores_df).reset_index(drop=True)
    scores_df["n_features"] = (scores_df == "x").sum(axis=1)
    scores_df[score_name] = scores_df["score"].round(f_precision)
    scores_df = scores_df.sort_values(["name", score_name], ascending=[True, False]).drop_duplicates(["name", score_name])

    # Plotting the dataframe
    scores_list = scores_df["name"].drop_duplicates().values.tolist()
    nrows = len(scores_list) // ncols + (len(scores_list) % ncols != 0) * 1

    fig, axs = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows))
    axs = axs.flatten()

    for i in range(len(scores_list)):
        score = scores_list[i]
        sns.heatmap(
            (scores_df.query(f"name == '{score}'").set_index(score_name).head(n_score_max).iloc[:, :-3] == 'x') * 1,
            ax=axs[i]
        )
        axs[i].set_title(score)

    fig.suptitle(f"{score_name} according to features included in the model")
    plt.tight_layout()
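
# Usage sketch (illustrative only): `selection_scores` is a placeholder with the structure
# described in the docstring, i.e. {label: [([feature, ...], score), ...], ...} as produced
# by .models.get_features_selection.
#
#   vizualize_features_selection(selection_scores, score_name="roc_auc",
#                                f_precision=2, n_score_max=5, ncols=3)
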
def display_model_performances(classifier, X_test, y_test, algorithm_name="", threshold=0.5, ncols=1):
    """
    This function produces a visualization of the model performances

    Parameters
    ----------
    classifier: python object which should contain a predict and a predict_proba method; if there are several labels, a dict in the format {label: classifier, ...} is expected
    X_test: pandas dataframe of the features
    y_test: pandas dataframe of the labels
    algorithm_name: str, name of the algorithm
    threshold: float, threshold for classification
    ncols: int, number of columns
    """

    # Checking type of y_test
    if isinstance(y_test, pd.Series):
        y_test = pd.DataFrame(y_test)

    # Checking if one or many labels
    if len(y_test.shape) > 1 and y_test.shape[1] > 1:
        if not isinstance(classifier, dict) or len(classifier.keys()) != y_test.shape[1]:
            raise ValueError("You should provide as many classifiers as labels")
    else:
        if not isinstance(classifier, dict):
            classifier = {y_test.columns[0]: classifier}

    labels = y_test.columns.tolist()

    # Construction of the pyplot object
    nrows = (len(labels) // ncols) + ((len(labels) % ncols) != 0) * 1

    fig = plt.figure(constrained_layout=True, figsize=(15 * ncols, 7 * nrows))
    figs = fig.subfigures(nrows, ncols, squeeze=False).flatten()
    axs = [x.subplots(1, 2) for x in figs]

    # For each label:
    for i in range(len(labels)):
        label = labels[i]
        label_classifier = classifier[label]

        figs[i].suptitle(label)

        y_test_true = y_test[label].values
        y_test_hat_proba = label_classifier.predict_proba(X_test)[:, 1]
        y_test_hat = (y_test_hat_proba >= threshold) * 1

        # Computation of metrics
        f1_score_, accuracy_score_, recall_score_, precision_score_ = [x(y_test_true, y_test_hat) for x in [f1_score, accuracy_score, recall_score, precision_score]]
        auc_score_ = roc_auc_score(y_test_true, y_test_hat_proba)
        confusion_matrix_ = confusion_matrix(y_test_true, y_test_hat)

        # Plotting
        ## Confusion matrix
        ConfusionMatrixDisplay(
            confusion_matrix_,
            display_labels=[0, 1]
        ).plot(
            ax=axs[i][0]
        )

        ## ROC curve
        fpr, tpr, thresholds = roc_curve(
            y_test_true,
            y_test_hat_proba
        )
        axs[i][1].plot(
            fpr,
            tpr,
            label=f"AUC: {auc_score_:.2f}\nF1-Score: {f1_score_:.2f}\nRecall: {recall_score_:.2f}\nPrecision: {precision_score_:.2f}\nAccuracy: {accuracy_score_:.2f}"
        )
        axs[i][1].legend(loc=4, fontsize="x-large")
        axs[i][1].set_ylabel('Taux de vrais positifs')
        axs[i][1].set_xlabel('Taux de faux positifs')

    fig.suptitle(f"Performance de l'algorithme {algorithm_name} avec un threshold de {threshold}")
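
# Usage sketch (illustrative only): `clf_a`, `clf_b`, `X_test` and `y_test` are
# placeholders; one fitted binary classifier exposing predict_proba is expected per
# label column of `y_test`.
#
#   display_model_performances({"label_a": clf_a, "label_b": clf_b}, X_test, y_test,
#                              algorithm_name="logistic regression", threshold=0.5, ncols=2)
#   plt.show()
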