nn_models.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. from random import sample
  2. from sklearn.base import BaseEstimator
  3. from sklearn.metrics import SCORERS
  4. from sklearn.model_selection import train_test_split
  5. from sklearn.utils import check_X_y, check_array
  6. import torch
  7. from torch import nn, optim
  8. from torch.utils.data import DataLoader
  9. from scipy.sparse import issparse
  10. import numpy as np
  11. class torchMLP (nn.Module):
  12. """
  13. Neural network model for
  14. """
  15. def __init__(self, n_features, n_labels):
  16. super().__init__()
  17. self.network = nn.Sequential(*[
  18. nn.Linear(n_features, 200),
  19. nn.ReLU(),
  20. nn.Linear(200, 50),
  21. nn.ReLU(),
  22. nn.Linear(50, 20),
  23. nn.ReLU(),
  24. nn.Linear(20, 10),
  25. nn.ReLU(),
  26. nn.Linear(10, n_labels),
  27. nn.Sigmoid()
  28. ])
  29. def forward(self, x):
  30. y_hat = self.network(x)
  31. return y_hat
  32. class torchMLPClassifier_sklearn (BaseEstimator):
  33. """
  34. Pytorch neural network with a sklearn-like API
  35. """
  36. def __init__ (self, model, n_epochs=50, early_stop=True, early_stop_metric="accuracy", early_stop_validations_size=0.1, early_stop_tol=2, batch_size=1024, learning_rate=1e-3, class_weight=None, device_train="cpu", device_predict="cpu", verbose=False):
  37. """
  38. Parameters:
  39. -----------
  40. model: non instanciated pytorch neural network model with a n_features and n_labels parameter
  41. n_epochs: int, number of epochs
  42. early_stop: boolean, if true an evaluation dataset is created and used to stop the training
  43. early_stop_metric: str, metric score to evaluate the model, according to sklearn.metrics.SCORERS.keys()
  44. early_stop_validations_size: int or float, if float percentage of the train dataset used for validation, otherwise number of sample to use
  45. early_stop_tol: int, number of epoch with metric decrease before stopping training
  46. batch_size: int, size of the training batch
  47. learning_rate: float, Adam optimizer learning rate
  48. class_weight: dict or str, same as the sklearn API
  49. device_train: str, device on which to train
  50. device_predict: str, device on which to predict
  51. verbose: boolean, if true the loss and score are printed
  52. """
  53. self.model = model
  54. self.n_epochs = n_epochs
  55. if early_stop and (early_stop_metric is not None) and (early_stop_metric in SCORERS.keys()) and (isinstance(early_stop_validations_size, int) or isinstance(early_stop_validations_size, float)):
  56. self.early_stop = early_stop
  57. self.early_stop_metric_name = early_stop_metric
  58. self.early_stop_metric = SCORERS[early_stop_metric]
  59. self.early_stop_validations_size = early_stop_validations_size
  60. self.early_stop_tol = early_stop_tol
  61. else:
  62. self.early_stop = False
  63. self.early_stop_metric = None
  64. self.early_stop_validations_size = None
  65. self.class_weight = class_weight
  66. self.learning_rate = learning_rate
  67. self.device_train = device_train
  68. self.device_predict = device_predict
  69. self.batch_size = batch_size
  70. self.verbose = verbose
  71. def fit(self, X, y):
  72. """
  73. Training the model
  74. Parameters:
  75. -----------
  76. X_test: pandas dataframe of the features
  77. y_test: pandas dataframe of the labels
  78. """
  79. X, y = check_X_y(X, y, accept_sparse=True, multi_output=True)
  80. if y.ndim == 1:
  81. y = np.expand_dims(y, 1)
  82. # Validation split if early stopping
  83. if self.early_stop:
  84. X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=self.early_stop_validations_size)
  85. if issparse(X_val): # To deal with the sparse matrix situations
  86. X_val = X_val.toarray()
  87. else:
  88. X_train, y_train = X, y
  89. n_samples = y_train.shape[0]
  90. n_labels_values = len(np.unique(y_train))
  91. n_labels = y_train.shape[1]
  92. n_features = X.shape[1]
  93. # Raising the model
  94. self.network = self.model(n_features=n_features, n_labels=n_labels)
  95. self.optimizer = optim.Adam(self.network.parameters(), lr=self.learning_rate)
  96. # Creating dataloader for X_train, y_train
  97. data_loader = DataLoader(range(X_train.shape[0]), shuffle=True, batch_size=self.batch_size)
  98. # Initializing loss function
  99. ## Getting weights
  100. if self.class_weight is not None:
  101. if self.class_weight == "balanced":
  102. weights = n_samples/(n_labels_values*np.bincount(y_train[:,0]))
  103. weights_dict = dict(zip(range(len(weights)), weights))
  104. else:
  105. weights_dict = self.class_weight
  106. else:
  107. weights_dict = None
  108. criterion = nn.BCELoss()
  109. # Running train
  110. last_score = 0
  111. early_stop_count = 0
  112. for i in range(self.n_epochs):
  113. self.network = self.network.to(self.device_train)
  114. # Starting an epoch
  115. for indices in data_loader:
  116. self.optimizer.zero_grad()
  117. X_train_sample, y_train_sample = X_train[indices, :], y_train[indices, :]
  118. if issparse(X_train_sample): # To deal with the sparse matrix situations
  119. X_train_sample = X_train_sample.toarray()
  120. X_train_sample_tensor, y_train_sample_tensor = [torch.tensor(x, dtype=torch.float32).to(self.device_train) for x in [X_train_sample, y_train_sample]]
  121. # Weighting the loss
  122. if self.class_weight is not None:
  123. sample_weights = y_train_sample.copy()
  124. for x, y in weights_dict.items():
  125. sample_weights[sample_weights == x] = y
  126. criterion.weigths = sample_weights
  127. # Get prediction
  128. X_train_sample_tensor, y_train_sample_tensor = X_train_sample_tensor.to(self.device_train), y_train_sample_tensor.to(self.device_train)
  129. y_train_sample_hat = self.network(X_train_sample_tensor)
  130. loss = criterion(y_train_sample_hat, y_train_sample_tensor)
  131. loss.backward()
  132. self.optimizer.step()
  133. # End of the Epoch : evaluating the score
  134. if self.early_stop:
  135. score = self.early_stop_metric(self, X_val, y_val)
  136. if score < last_score:
  137. early_stop_count += 1
  138. if early_stop_count >= self.early_stop_tol:
  139. return self
  140. else:
  141. early_stop_count = 0
  142. last_score = score
  143. if self.verbose:
  144. if self.early_stop:
  145. print(f"Epoch {i} : Loss {loss.item():.3f} - {self.early_stop_metric_name} {score:.3f}")
  146. else:
  147. print(f"Epoch {i} : Loss {loss.item():.3f}")
  148. return self
  149. def predict(self, X):
  150. """
  151. Getting the prediction
  152. Parameters:
  153. -----------
  154. X_test: pandas dataframe of the features
  155. """
  156. y_hat_proba = self.predict_raw_proba(X)
  157. y_hat = ((y_hat_proba >= 0.5)*1).flatten()
  158. return y_hat
  159. def predict_raw_proba(self, X):
  160. """
  161. Getting the prediction score in tensor format
  162. Parameters:
  163. -----------
  164. X_test: pandas dataframe of the features
  165. """
  166. X = check_array(X, accept_sparse=True)
  167. if issparse(X): # To deal with the sparse matrix situations
  168. X = X.toarray()
  169. with torch.no_grad():
  170. model_predict = self.network.to(self.device_predict)
  171. model_predict.eval()
  172. # Create a tensor from X
  173. X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device_predict)
  174. y_hat_proba_torch = model_predict(X_tensor)
  175. y_hat_proba_torch = y_hat_proba_torch.detach().cpu().numpy()
  176. return y_hat_proba_torch
  177. def predict_proba(self, X):
  178. """
  179. Getting the prediction score in sklearn format
  180. Parameters:
  181. -----------
  182. X_test: pandas dataframe of the features
  183. """
  184. y_hat_proba_torch = self.predict_raw_proba(X)
  185. y_hat_proba_torch = np.concatenate([
  186. 1-y_hat_proba_torch,
  187. y_hat_proba_torch
  188. ], axis=1)
  189. y_hat_proba = y_hat_proba_torch
  190. return y_hat_proba