import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier, LinearRegression
from sklearn.metrics import confusion_matrix, classification_report, r2_score, accuracy_score
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score


class WinePredict:
    def __init__(self):
        self.wine = pd.read_csv('./wine.csv', sep=';')
        X = self.wine.drop('quality', axis=1)
        y = self.wine['quality']
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=0.1, random_state=42)
        # Feature standardization: fit the scaler on the training set only and
        # reuse it for the test set, so no test-set statistics leak into training
        sc = StandardScaler()
        self.X_train = sc.fit_transform(self.X_train)
        self.X_test = sc.transform(self.X_test)
    def lr(self):
        """
        Linear regression
        """
        lr = LinearRegression()
        lr.fit(self.X_train, self.y_train)
        return lr.predict(self.X_test)
    def rfc(self):
        """
        Random forest
        """
        rfc = RandomForestClassifier(n_estimators=200, random_state=20)
        rfc.fit(self.X_train, self.y_train)
        return rfc.predict(self.X_test)
    def sgd(self):
        """
        Stochastic gradient descent
        """
        sgd = SGDClassifier(penalty=None)
        sgd.fit(self.X_train, self.y_train)
        return sgd.predict(self.X_test)
    def svc(self):
        """
        Support vector machine
        """
        # Hyperparameters taken from the gs_svc grid search below
        svc = SVC(C=1.4, gamma=0.8, kernel='rbf')
        svc.fit(self.X_train, self.y_train)
        return svc.predict(self.X_test)
    def mlp(self):
        """
        Multi-layer perceptron
        """
        # Neural network with two hidden layers (10 and 6 units)
        mlp = MLPClassifier([10, 6], learning_rate_init=0.001, activation='relu', solver='adam', alpha=0.0001,
                            max_iter=30000)
        mlp.fit(self.X_train, self.y_train)
        return mlp.predict(self.X_test)
    # Hyperparameter tuning
    def grid_search(self, model, param):
        print('search %s' % param)
        grid = GridSearchCV(model, param_grid=param, scoring='accuracy', cv=10)
        grid.fit(self.X_train, self.y_train)
        print("%s: %f" % (grid.best_params_, grid.best_score_))
    def gs_svc(self):
        param = {
            'C': [0.1, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4],
            'kernel': ['linear', 'rbf'],
            'gamma': [0.1, 0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4]
        }
        self.grid_search(SVC(), param)
        # Best result: {'C': 1.4, 'gamma': 0.8, 'kernel': 'rbf'}

    def gs_rfc(self):
        self.grid_search(RandomForestClassifier(), {
            'n_estimators': [200],
            'random_state': list(range(0, 200, 10))
        })
    def report(self, fc):
        r = fc()
        # Round continuous predictions (e.g. from LinearRegression) to the nearest quality class
        if r.dtype == 'float64' or r.dtype == 'float32':
            r = r.round()
        # print(classification_report(self.y_test, r))
        print(fc.__doc__.strip())
        print("      R2: %f" % r2_score(self.y_test, r))
        print("accuracy: %f" % accuracy_score(self.y_test, r))
    def showXY(self):
        # Plot each feature against quality as a 4x4 grid of bar plots
        # fig = plt.figure(figsize=(10, 6))
        for i in range(len(self.wine.columns[:-1])):
            sns.barplot(x='quality', y=self.wine.columns[i], data=self.wine, ax=plt.subplot(4, 4, i + 1))
            plt.xlabel(self.wine.columns[i])
            plt.ylabel('')
        plt.tight_layout()
        plt.show()
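
    def gs_mlp(self):
        # A hedged sketch, not part of the original script: tune the MLP the same
        # way as gs_svc/gs_rfc. The grid values below are illustrative assumptions,
        # not results reported by the author.
        self.grid_search(MLPClassifier(max_iter=30000), {
            'hidden_layer_sizes': [(10, 6), (20, 10)],
            'alpha': [0.0001, 0.001],
            'learning_rate_init': [0.001, 0.01]
        })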


if __name__ == '__main__':
    wp = WinePredict()
    wp.gs_rfc()
    for i in [wp.rfc, wp.lr, wp.svc, wp.sgd, wp.mlp][:1]:
        wp.report(i)
    # wp.showXY()
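    # A hedged sketch, not part of the original script: compare a few of the
    # classifiers above with 5-fold cross-validation on the (scaled) training
    # split, using the already-imported cross_val_score. Model settings mirror
    # the methods defined above.
    for name, model in [('rfc', RandomForestClassifier(n_estimators=200, random_state=20)),
                        ('svc', SVC(C=1.4, gamma=0.8, kernel='rbf')),
                        ('sgd', SGDClassifier(penalty=None))]:
        scores = cross_val_score(model, wp.X_train, wp.y_train, cv=5, scoring='accuracy')
        print("%s cv accuracy: %.3f +/- %.3f" % (name, scores.mean(), scores.std()))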