# Snapshot metadata (from the source viewer): 2021-01-05 21:29:58 +08:00, 96 lines, 3.2 KiB, Python

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier, LinearRegression
from sklearn.metrics import confusion_matrix, classification_report, r2_score, accuracy_score
from sklearn.neural_network import BernoulliRBM, MLPClassifier, MLPRegressor
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
class WinePredict:
    """Train and compare several scikit-learn models on the wine-quality dataset.

    The CSV is loaded once, split into train/test sets and standardized.
    Each model method (``lr``, ``rfc``, ``sgd``, ``svc``, ``mlp``) fits on the
    training split and returns its predictions for the test split; pass one of
    them to ``report`` to print its scores.
    """

    def __init__(self, path='./wine.csv', test_size=0.1, random_state=42):
        """Load the dataset, split it, and standardize the features.

        :param path: semicolon-separated CSV file containing a ``quality`` column
            (the prediction target); every other column is used as a feature.
        :param test_size: fraction of rows held out for the test split.
        :param random_state: seed passed to ``train_test_split``.
        """
        self.wine = pd.read_csv(path, sep=';')
        X = self.wine.drop('quality', axis=1)
        y = self.wine['quality']
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state)
        # Feature standardization: learn mean/std from the training split only
        # and apply those same statistics to the test split. Refitting on the
        # test data would leak test information and put the two splits on
        # different scales.
        sc = StandardScaler()
        self.X_train = sc.fit_transform(self.X_train)
        self.X_test = sc.transform(self.X_test)

    def _fit_predict(self, model):
        """Fit *model* on the training split and return its test-split predictions."""
        model.fit(self.X_train, self.y_train)
        return model.predict(self.X_test)

    def lr(self):
        """Linear regression; predictions are continuous (``report`` rounds them)."""
        return self._fit_predict(LinearRegression())

    def rfc(self):
        """Random forest classifier with 200 trees."""
        return self._fit_predict(RandomForestClassifier(n_estimators=200))

    def sgd(self):
        """Linear classifier trained by stochastic gradient descent, no penalty."""
        # Author's note: score around 0.2 and very unstable between runs
        # (no random_state is set, so each fit differs).
        return self._fit_predict(SGDClassifier(penalty=None))

    def svc(self):
        """RBF support vector classifier using the parameters found by ``gs_svc``."""
        # Author's note: score 0.23 -> 0.25, presumably after switching to the
        # tuned parameters recorded in ``gs_svc``.
        return self._fit_predict(SVC(C=1.4, gamma=0.8, kernel='rbf'))

    def mlp(self):
        """Neural network: multi-layer perceptron with hidden layers of 10 and 6 units."""
        mlp = MLPClassifier(hidden_layer_sizes=(10, 6), learning_rate_init=0.001,
                            activation='relu', solver='adam', alpha=0.0001,
                            max_iter=30000)
        return self._fit_predict(mlp)

    def grid_search(self, model, param, scoring='accuracy', cv=10):
        """Hyper-parameter tuning: exhaustive grid search on the training split.

        :param model: estimator instance to tune. An estimator *class* is also
            accepted and instantiated with its default arguments.
        :param param: parameter grid, as accepted by ``GridSearchCV(param_grid=...)``.
        :param scoring: scoring metric name for ``GridSearchCV``.
        :param cv: number of cross-validation folds.
        :returns: the ``best_params_`` dict of the fitted search.
        """
        if isinstance(model, type):
            # GridSearchCV clones its estimator, which requires an instance.
            model = model()
        search = GridSearchCV(model, param_grid=param, scoring=scoring, cv=cv)
        search.fit(self.X_train, self.y_train)
        return search.best_params_

    def gs_svc(self):
        """Grid-search SVC hyper-parameters and print the best combination."""
        param = {
            'C': [0.1, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4],
            'kernel': ['linear', 'rbf'],
            'gamma': [0.1, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4],
        }
        print(self.grid_search(SVC(), param))
        # Result recorded by the author: {'C': 1.4, 'gamma': 0.8, 'kernel': 'rbf'}

    @staticmethod
    def _round_if_float(pred):
        """Round floating-point predictions to whole numbers; return others unchanged.

        Regression models (e.g. ``lr``) yield continuous values, which must be
        rounded before they can be compared with the integer quality labels.
        """
        if pred.dtype.kind == 'f':
            return pred.round()
        return pred

    def report(self, fc):
        """Run prediction method *fc* and print its R2 and accuracy on the test split.

        :param fc: a zero-argument bound method of this object (e.g. ``self.lr``)
            that returns predictions for ``self.X_test``; its ``__name__`` is
            printed as the heading.
        """
        r = self._round_if_float(fc())
        # For a per-class breakdown: print(classification_report(self.y_test, r))
        print(fc.__name__)
        print(" R2: %f" % r2_score(self.y_test, r))
        print(" accuracy: %f" % accuracy_score(self.y_test, r))

    def showXY(self):
        """Draw one bar plot per feature against ``quality`` in a 4x4 subplot grid.

        NOTE(review): the 4x4 grid holds at most 16 features; more would make
        ``plt.subplot`` fail.
        """
        features = [column for column in self.wine.columns if column != 'quality']
        for position, column in enumerate(features, start=1):
            sns.barplot(x='quality', y=column, data=self.wine,
                        ax=plt.subplot(4, 4, position))
            # The feature name is used as the x label (acting as a caption),
            # and the y label is cleared.
            plt.xlabel(column)
            plt.ylabel('')
        plt.tight_layout()
        plt.show()
if __name__ == '__main__':
    # Script entry point: evaluate the linear-regression baseline.
    predictor = WinePredict()
    predictor.report(predictor.lr)
    # Uncomment to plot every feature against quality:
    # predictor.showXY()