import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression, SGDClassifier
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    r2_score,
)
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
from sklearn.neural_network import BernoulliRBM, MLPClassifier, MLPRegressor
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import SVC


class WinePredict:
    """Fit several scikit-learn models on a wine CSV and report their scores.

    The constructor loads the data, splits it into train/test sets and
    standardizes the features. Each model method (``lr``, ``rfc``, ``sgd``,
    ``svc``, ``mlp``) fits a fresh estimator on the training split and
    returns its predictions for the test split.
    """

    def __init__(self, csv_path='./wine.csv', sep=';'):
        """Load the dataset, split it and scale the features.

        Args:
            csv_path: Path of the CSV file to read. It must contain a
                ``quality`` column, which is used as the target.
            sep: Field separator used in the CSV file.
        """
        self.wine = pd.read_csv(csv_path, sep=sep)
        X = self.wine.drop('quality', axis=1)
        y = self.wine['quality']
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=0.1, random_state=42
        )

        # Feature standardization. The scaler is fitted on the training split
        # only and then applied unchanged to the test split; re-fitting it on
        # the test data would scale both splits differently and leak test
        # statistics into preprocessing.
        scaler = StandardScaler()
        self.X_train = scaler.fit_transform(self.X_train)
        self.X_test = scaler.transform(self.X_test)

    # Linear regression
    def lr(self):
        """Fit a LinearRegression and return its (float) test predictions."""
        model = LinearRegression()
        model.fit(self.X_train, self.y_train)
        return model.predict(self.X_test)

    # Random forest
    def rfc(self):
        """Fit a 200-tree RandomForestClassifier and return test predictions."""
        model = RandomForestClassifier(n_estimators=200)
        model.fit(self.X_train, self.y_train)
        return model.predict(self.X_test)

    # Stochastic gradient descent
    # Author's note: 0.2, extremely unstable
    def sgd(self):
        """Fit an unpenalized SGDClassifier and return test predictions."""
        model = SGDClassifier(penalty=None)
        model.fit(self.X_train, self.y_train)
        return model.predict(self.X_test)

    # Support vector machine
    # Author's note: 0.23 -> 0.25
    def svc(self):
        """Fit an RBF-kernel SVC (C=1.4, gamma=0.8) and return test predictions."""
        model = SVC(C=1.4, gamma=0.8, kernel='rbf')
        model.fit(self.X_train, self.y_train)
        return model.predict(self.X_test)

    # Neural network
    def mlp(self):
        """Fit an MLPClassifier with hidden layers (10, 6) and return test predictions."""
        model = MLPClassifier(
            [10, 6],
            learning_rate_init=0.001,
            activation='relu',
            solver='adam',
            alpha=0.0001,
            max_iter=30000,
        )
        model.fit(self.X_train, self.y_train)
        return model.predict(self.X_test)

    # Parameter tuning
    def grid_search(self, model, param):
        """Run a 10-fold, accuracy-scored grid search on the training split.

        Args:
            model: An estimator *instance* (not a class) to tune.
            param: Parameter grid passed to ``GridSearchCV`` as ``param_grid``.

        Returns:
            The ``best_params_`` found by the search.
        """
        search = GridSearchCV(model, param_grid=param, scoring='accuracy', cv=10)
        search.fit(self.X_train, self.y_train)
        return search.best_params_

    def gs_svc(self):
        """Grid-search SVC hyper-parameters and print the best combination."""
        param = {
            'C': [0.1, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4],
            'kernel': ['linear', 'rbf'],
            'gamma': [0.1, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4],
        }
        # GridSearchCV needs an estimator instance; passing the SVC class
        # itself fails when the search tries to clone it.
        print(self.grid_search(SVC(), param))
        # Result recorded by the author:
        # {'C': 1.4, 'gamma': 0.8, 'kernel': 'rbf'}

    def report(self, fc):
        """Run one model method and print its name, R2 and accuracy.

        Args:
            fc: A zero-argument callable returning an array of test
                predictions, e.g. ``self.lr``.
        """
        predictions = fc()
        # Float output (e.g. from the regressor) is rounded to the nearest
        # integer so it can be compared against the quality labels.
        if predictions.dtype.kind == 'f':
            predictions = predictions.round()
        print(fc.__name__)
        print(" R2: %f" % r2_score(self.y_test, predictions))
        print(" accuracy: %f" % accuracy_score(self.y_test, predictions))

    def showXY(self):
        """Show one bar plot per column (all but the last) against ``quality``.

        Plots are laid out on a 4x4 subplot grid, so at most 16 columns can
        be drawn. Assumes ``quality`` is the last column of the CSV - TODO
        confirm against the data file.
        """
        feature_names = self.wine.columns[:-1]
        for position, feature in enumerate(feature_names, start=1):
            axis = plt.subplot(4, 4, position)
            sns.barplot(x='quality', y=feature, data=self.wine, ax=axis)
            # Label each panel with the feature it shows instead of 'quality'.
            plt.xlabel(feature)
            plt.ylabel('')
        plt.tight_layout()
        plt.show()


if __name__ == '__main__':
    wp = WinePredict()
    wp.report(wp.lr)
    # wp.showXY()