visualisation.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. from matplotlib import pyplot as plt
  2. import seaborn as sns
  3. import pandas as pd
  4. from .preprocessing import get_Xy_df
  5. def plot_all_scatter (X, variables, ncols=3, figsize=(20,10)):
  6. """
  7. This function produce a scatter view of all the variables from a dataset
  8. Parameters
  9. ----------
  10. X: Pandas Dataframe
  11. variables: [str], list of variables name
  12. n_cols: int, number of columns in the plot
  13. figsize: (int, int), tuple of the figure size
  14. """
  15. # Getting nrows
  16. nrows = (len(variables) // ncols) + 1*((len(variables) % ncols) != 0)
  17. figs, axs = plt.subplots(nrows, ncols, figsize=figsize)
  18. axs = axs.flatten()
  19. for i in range(len(variables)):
  20. variable = variables[i]
  21. sns.scatterplot(
  22. data=X[variable].value_counts(),
  23. ax = axs[i]
  24. )
  25. axs[i].ticklabel_format(style='scientific', axis='x', scilimits=(0, 4))
  26. axs[i].set_xlabel("Valeur")
  27. axs[i].set_ylabel("Nombre d'occurences")
  28. axs[i].set_title(variable)
  29. plt.tight_layout()
  30. def plot_missing_outcome(X, y, features, labels, figsize=(20,10)):
  31. """
  32. This function produce a line plot of all the missings values according to the outcomes values
  33. Parameters
  34. ----------
  35. X: Pandas Dataframe of features
  36. y: Pandas Dataframe of labels
  37. features: [str], list of variables name
  38. labels: [str], list of output name
  39. figsize: (int, int), tuple of the figure size
  40. """
  41. Xy = get_Xy_df(X, y)
  42. data = Xy[labels].join(
  43. pd.DataFrame(Xy[features].isna().astype("int").sum(axis=1))
  44. ).rename(columns={0:"n_NA"}) \
  45. .groupby("n_NA") \
  46. .agg(lambda x: x.sum()/x.count())
  47. fig,ax = plt.subplots(1, 1, figsize=figsize)
  48. sns.lineplot(
  49. data=pd.melt(data.reset_index(), id_vars="n_NA",value_vars=data.columns),
  50. hue="variable",
  51. x="n_NA",
  52. y="value",
  53. ax=ax
  54. )
  55. ax.set_xlabel("Nombre de valeurs manquantes")
  56. ax.set_ylabel("Pourcentage d'examen prescrit")
  57. ax.set_title("% de prescription de bilans en fonction du nombre de variables manquantes")
  58. def plot_missing_bar(X, features, figsize=(15,10)):
  59. fig, ax = plt.subplots(1,1, figsize=figsize)
  60. data = (X[features].isna()*1).mean().reset_index()
  61. sns.barplot(
  62. data=data,
  63. x="index",
  64. y=0,
  65. ax=ax
  66. )
  67. ax.set_title("% de valeurs manquantes par variable")
  68. ax.set_xlabel("Variable")
  69. ax.set_ylabel("% de valeurs manquantes")