Ali 2 года назад
Родитель
Коммит
0963226b71
2 изменённых файла: 20 добавлений и 5 удалений
  1. 7 2
      bop_scripts/models.py
  2. 13 3
      bop_scripts/nn_models.py

+ 7 - 2
bop_scripts/models.py

@@ -158,16 +158,17 @@ def get_features_selection (X, y, classifier, categorical_variables, continuous_
 
     return scores
 
-def fit_all_classifiers(classifier, X_train, y_train, hide_warnings=True):
+def fit_all_classifiers(classifier_fn, X_train, y_train, hide_warnings=True, verbose=False):
     """
         This function fill all the models for each label.
 
         Parameters:
         ----------
-        model: Classifier with a fit method
+        classifier_fn: Function to raise a new classifier with fit method
         X: Pandas Dataframe of features
         y: Pandas Dataframe of labels
         hide_warnings: boolean, if true the warnings will be hidden
+        verbose: boolean, if true the trained model are printed
 
         Output:
         -------
@@ -180,6 +181,10 @@ def fit_all_classifiers(classifier, X_train, y_train, hide_warnings=True):
     labels = y_train.columns.tolist()
     classifiers = {}
     for label in labels:
+        if verbose:
+            print(f"Training model {label}")
+        
+        classifier = classifier_fn()
         classifiers[label] = classifier.fit(X_train, y_train[label])
 
     return classifiers

+ 13 - 3
bop_scripts/nn_models.py

@@ -22,7 +22,11 @@ class torchMLP (nn.Module):
             nn.ReLU(),
             nn.Linear(200, 50),
             nn.ReLU(),
-            nn.Linear(50, n_labels),
+            nn.Linear(50, 20),
+            nn.ReLU(),
+            nn.Linear(20, 10),
+            nn.ReLU(),
+            nn.Linear(10, n_labels),
             nn.Sigmoid()
         ])
 
@@ -38,7 +42,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
         Pytorch neural network with a sklearn-like API
     """
 
-    def __init__ (self, model, n_epochs=50, early_stop=True, early_stop_metric="accuracy", early_stop_validations_size=0.1, batch_size=1024, learning_rate=1e-3, class_weight=None, device_train="cpu", device_predict="cpu", verbose=False):
+    def __init__ (self, model, n_epochs=50, early_stop=True, early_stop_metric="accuracy", early_stop_validations_size=0.1, early_stop_tol=2, batch_size=1024, learning_rate=1e-3, class_weight=None, device_train="cpu", device_predict="cpu", verbose=False):
         """
             Parameters:
             -----------
@@ -47,6 +51,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
             early_stop: boolean, if true an evaluation dataset is created and used to stop the training
             early_stop_metric: str, metric score to evaluate the model, according to sklearn.metrics.SCORERS.keys()
             early_stop_validations_size: int or float, if float percentage of the train dataset used for validation, otherwise number of sample to use 
+            early_stop_tol: int, number of epoch with metric decrease before stopping training
             batch_size: int, size of the training batch
             learning_rate: float, Adam optimizer learning rate
             class_weight: dict or str, same as the sklearn API
@@ -63,6 +68,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
             self.early_stop_metric_name = early_stop_metric
             self.early_stop_metric = SCORERS[early_stop_metric]
             self.early_stop_validations_size = early_stop_validations_size
+            self.early_stop_tol = early_stop_tol
         else:
             self.early_stop = False
             self.early_stop_metric = None
@@ -124,6 +130,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
 
         # Running train
         last_score = 0
+        early_stop_count = 0
         for i in range(self.n_epochs):
 
             self.network = self.network.to(self.device_train)
@@ -158,8 +165,11 @@ class torchMLPClassifier_sklearn (BaseEstimator):
                 score = self.early_stop_metric(self, X_val, y_val)
 
                 if score < last_score:
-                    return self
+                    early_stop_count += 1
+                    if early_stop_count >= self.early_stop_tol:
+                        return self
                 else:
+                    early_stop_count = 0
                     last_score = score
 
             if self.verbose: