2021-01-17 05:14:17 +08:00

47 lines
1.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
import matplotlib.pyplot as plt
import numpy as np
import random
from sklearn.linear_model import LinearRegression, Ridge, RidgeCV
from sklearn.preprocessing import PolynomialFeatures
def f1(x, e):
    """Return a noisy sample of the damped sine exp(-x) * sin(x).

    :param x: input value.
    :param e: standard deviation of the zero-mean Gaussian noise that is
        added (passed as ``sigma`` to ``random.normalvariate``).
    """
    damped = math.exp(-x) * math.sin(x)
    jitter = random.normalvariate(0, e)
    return damped + jitter
def f2(x, y, e):
    """Return a noisy sample of exp(-(x^2 + y^2)) * sin(x * y).

    :param x: first input coordinate.
    :param y: second input coordinate.
    :param e: standard deviation of the zero-mean Gaussian noise that is
        added (passed as ``sigma`` to ``random.normalvariate``).
    """
    # Note: ``-x ** 2`` parses as ``-(x ** 2)``, so this is a Gaussian bump.
    envelope = math.exp(-x ** 2 - y ** 2)
    oscillation = math.sin(x * y)
    noise = random.normalvariate(0, e)
    return envelope * oscillation + noise
def solve1(x, y):
    """Fit and return an ordinary least-squares ``LinearRegression`` on (x, y).

    :param x: feature matrix passed straight to ``LinearRegression.fit``.
    :param y: target values.
    :return: the fitted estimator.
    """
    regressor = LinearRegression()
    regressor.fit(x, y)
    return regressor
def solve2(x, y, degree=4):
    """Fit and return a polynomial regression model on (x, y).

    The features are expanded with ``PolynomialFeatures(degree=degree)`` and
    then regressed with ``LinearRegression``. Both steps are chained in one
    pipeline, so the returned object exposes ``predict`` on raw ``x`` just
    like the estimators returned by ``solve1`` and ``solve3``.

    :param x: feature matrix (same form as accepted by ``solve1``).
    :param y: target values.
    :param degree: polynomial degree of the feature expansion (default 4).
    :return: the fitted pipeline.
    """
    # Local import keeps this fix self-contained; sklearn is already a
    # dependency of this file.
    from sklearn.pipeline import make_pipeline

    # Previously the transformer was re-fitted on its own output and returned
    # without any regressor, so ``y`` was never used and nothing could predict.
    model = make_pipeline(PolynomialFeatures(degree=degree), LinearRegression())
    model.fit(x, y)
    return model
def solve3(x, y):
    """Fit and return a ridge regression whose penalty is chosen by cross-validation.

    ``RidgeCV`` accepts several candidate regularization strengths and uses
    cross-validation to pick the best one, so no manual tuning is needed.

    :param x: feature matrix passed to ``RidgeCV.fit``.
    :param y: target values.
    :return: the fitted estimator.
    """
    candidate_alphas = [0.1, 1.0, 10.0]
    return RidgeCV(alphas=candidate_alphas).fit(x, y)
if __name__ == '__main__':
    # Sample the noisy 1-D target f1 (noise std 1) on 100 evenly spaced points in [0, 10].
    x = np.linspace(0, 10, 100)
    data = [f1(i, 1) for i in x]
    # scikit-learn estimators take a 2-D (n_samples, n_features) matrix.
    features = x.reshape((-1, 1))
    model = solve3(features, data)
    # Predict the whole grid in one vectorized call instead of one sample per call.
    y1 = model.predict(features)
    print(model.predict([[1]]))
    # Curve '0' is the noisy samples; curve '1' is the fitted model's prediction.
    plt.plot(x, data, label='0')
    plt.plot(x, y1, label='1')
    plt.legend()
    plt.show()