|
@@ -22,7 +22,11 @@ class torchMLP (nn.Module):
|
|
|
nn.ReLU(),
|
|
|
nn.Linear(200, 50),
|
|
|
nn.ReLU(),
|
|
|
- nn.Linear(50, n_labels),
|
|
|
+ nn.Linear(50, 20),
|
|
|
+ nn.ReLU(),
|
|
|
+ nn.Linear(20, 10),
|
|
|
+ nn.ReLU(),
|
|
|
+ nn.Linear(10, n_labels),
|
|
|
nn.Sigmoid()
|
|
|
])
|
|
|
|
|
@@ -38,7 +42,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
|
|
|
Pytorch neural network with a sklearn-like API
|
|
|
"""
|
|
|
|
|
|
- def __init__ (self, model, n_epochs=50, early_stop=True, early_stop_metric="accuracy", early_stop_validations_size=0.1, batch_size=1024, learning_rate=1e-3, class_weight=None, device_train="cpu", device_predict="cpu", verbose=False):
|
|
|
+ def __init__ (self, model, n_epochs=50, early_stop=True, early_stop_metric="accuracy", early_stop_validations_size=0.1, early_stop_tol=2, batch_size=1024, learning_rate=1e-3, class_weight=None, device_train="cpu", device_predict="cpu", verbose=False):
|
|
|
"""
|
|
|
Parameters:
|
|
|
-----------
|
|
@@ -47,6 +51,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
|
|
|
early_stop: boolean, if true an evaluation dataset is created and used to stop the training
|
|
|
early_stop_metric: str, metric score to evaluate the model, according to sklearn.metrics.SCORERS.keys()
|
|
|
early_stop_validations_size: int or float, if float percentage of the train dataset used for validation, otherwise number of sample to use
|
|
|
+ early_stop_tol: int, number of epoch with metric decrease before stopping training
|
|
|
batch_size: int, size of the training batch
|
|
|
learning_rate: float, Adam optimizer learning rate
|
|
|
class_weight: dict or str, same as the sklearn API
|
|
@@ -63,6 +68,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
|
|
|
self.early_stop_metric_name = early_stop_metric
|
|
|
self.early_stop_metric = SCORERS[early_stop_metric]
|
|
|
self.early_stop_validations_size = early_stop_validations_size
|
|
|
+ self.early_stop_tol = early_stop_tol
|
|
|
else:
|
|
|
self.early_stop = False
|
|
|
self.early_stop_metric = None
|
|
@@ -124,6 +130,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
|
|
|
|
|
|
# Running train
|
|
|
last_score = 0
|
|
|
+ early_stop_count = 0
|
|
|
for i in range(self.n_epochs):
|
|
|
|
|
|
self.network = self.network.to(self.device_train)
|
|
@@ -158,8 +165,11 @@ class torchMLPClassifier_sklearn (BaseEstimator):
|
|
|
score = self.early_stop_metric(self, X_val, y_val)
|
|
|
|
|
|
if score < last_score:
|
|
|
- return self
|
|
|
+ early_stop_count += 1
|
|
|
+ if early_stop_count >= self.early_stop_tol:
|
|
|
+ return self
|
|
|
else:
|
|
|
+ early_stop_count = 0
|
|
|
last_score = score
|
|
|
|
|
|
if self.verbose:
|