|
@@ -63,6 +63,8 @@ class torchMLPClassifier_sklearn (BaseEstimator):
|
|
|
self.early_stop_validations_size = early_stop_validations_size
|
|
|
else:
|
|
|
self.early_stop = False
|
|
|
+ self.early_stop_metric = None
|
|
|
+ self.early_stop_validations_size = None
|
|
|
|
|
|
self.class_weight = class_weight
|
|
|
self.learning_rate = learning_rate
|
|
@@ -168,7 +170,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
|
|
|
"""
|
|
|
|
|
|
y_hat_proba = self.predict_raw_proba(X)
|
|
|
- y_hat = (y_hat_proba >= 0.5)*1
|
|
|
+ y_hat = ((y_hat_proba >= 0.5)*1).flatten()
|
|
|
|
|
|
return y_hat
|
|
|
|