classifier.py 403 B

1234567891011121314151617181920
  1. import numpy as np
  2. from sklearn.base import BaseEstimator
  3. class Classifier(BaseEstimator):
  4. def fit(self, X, y):
  5. self.pred = y.mean(axis=0)
  6. return self
  7. def predict_proba(self, X):
  8. y_pred = np.repeat(self.pred.reshape(1, -1), X.shape[0], axis=0)
  9. return y_pred
  10. def predict(self, X):
  11. y_pred = self.predict_proba(X)
  12. return (y_pred >= 0.5)*1