Ali Bellamine 2 năm trước cách đây
mục cha
commit
8dd43b0f9a
3 tập tin đã thay đổi với 8 bổ sung6 xóa
  1. 4 4
      Project_Report.ipynb
  2. 3 1
      bop_scripts/nn_models.py
  3. 1 1
      bop_scripts/visualisation.py

Những thai đổi đã bị hủy bỏ vì nó quá lớn
+ 4 - 4
Project_Report.ipynb


+ 3 - 1
bop_scripts/nn_models.py

@@ -63,6 +63,8 @@ class torchMLPClassifier_sklearn (BaseEstimator):
             self.early_stop_validations_size = early_stop_validations_size
         else:
             self.early_stop = False
+            self.early_stop_metric = None
+            self.early_stop_validations_size = None
 
         self.class_weight = class_weight
         self.learning_rate = learning_rate
@@ -168,7 +170,7 @@ class torchMLPClassifier_sklearn (BaseEstimator):
         """
 
         y_hat_proba = self.predict_raw_proba(X)
-        y_hat = (y_hat_proba >= 0.5)*1
+        y_hat = ((y_hat_proba >= 0.5)*1).flatten()
 
         return y_hat
 

+ 1 - 1
bop_scripts/visualisation.py

@@ -345,7 +345,7 @@ def display_model_performances(classifier, X_test, y_test, algorithm_name="", th
     fig = plt.figure(constrained_layout=True, figsize=(15*ncols,7*nrows))
     figs = fig.subfigures(nrows, ncols)
     figs = figs.flatten()
-    if len(labels) == 1:
+    if (ncols+nrows) == 1:
         figs = [figs]
     axs = [x.subplots(1, 2) for x in figs]
 

Một số tệp đã không được hiển thị bởi vì quá nhiều tập tin thay đổi trong này khác