2021-01-17 04:08:04 +08:00

154 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import math
from math import e
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint, solve_bvp, solve_ivp
def runge_kutta(y, x, dx, f):
""" y is the initial value for y
x is the initial value for x
dx is the time step in x
f is derivative of function y(t)
"""
k1 = dx * f(y, x)
k2 = dx * f(y + 0.5 * k1, x + 0.5 * dx)
k3 = dx * f(y + 0.5 * k2, x + 0.5 * dx)
k4 = dx * f(y + k3, x + dx)
return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6.
'''
为了兼容solve_ivp的参数形式微分方程函数定义的参数顺序为(t,y)因此使用odeint函数时需要使参数tfirst=True
二阶甚至高阶微分方程组都可以变量替换成一阶方程组的形式,再调用相关函数进行求解,因此编写函数的时候,不同于一阶微分方程,二阶或者高阶微分方程返回的是低阶到高阶组成的方程组,
'''
y0 = [1 / (e + 1 / e), (e - 1 / e) / ((e + 1 / e) ** 2)] # 初值条件
# 初值[2,0]表示y(0)=2,y'(0)=0
def fvdp1(t, y):
'''
要把y看出一个向量y = [dy0,dy1,dy2,...]分别表示y的n阶导那么
y[0]就是需要求解的函数y[1]表示一阶导y[2]表示二阶导,以此类推
对于二阶微分方程肯定是由0阶和1阶函数组合而成的所以下面把y看成向量的话y0表示最初始的函数也就是我们要求解的函数y1表示一阶导对于高阶微分方程也可以以此类推
'''
dy1 = y[1] # y[1]=dy/dt一阶导
# dy2 = -3 * y[1] - 2 * y[0] + np.exp(-1 * t)
dy2 = 2 * y[1] ** 2 / y[0] - y[0]
# y[0]是最初始,也就是需要求解的函数
# 注意返回的顺序是[一阶导, 二阶导],这就形成了一阶微分方程组
return [dy1, dy2]
def solve0():
'''
内置求解器1
'''
t2 = np.linspace(-1, 1, 1000)
return odeint(fvdp1, y0, t2, tfirst=True)[:, 0]
def solve01(seq):
f0 = [y0[0]]
f1 = [y0[1]]
f2 = [fvdp1(-1, [f0[0], f1[0]])[1]]
for i in range(1, len(seq)):
h = seq[i] - seq[i - 1]
k21 = f2[i - 1]
k22 = fvdp1(seq[i - 1] + h / 2, [f0[i - 1] + h * k21 / 2, f1[i - 1] + h * k21 / 2])[1]
k23 = fvdp1(seq[i - 1] + h / 2, [f0[i - 1] + h * k22 / 2, f1[i - 1] + h * k22 / 2])[1]
k24 = fvdp1(seq[i - 1] + h / 2, [f0[i - 1] + h * k23, f1[i - 1] + h * k23])[1]
f1.append(f1[i - 1] + h * (k21 + k22 + k23 + k24) / 6)
f0.append(f0[i - 1] + h * f1[i - 1])
f2.append(fvdp1(seq[i], [f0[i], f1[i]])[1])
return f0
def solve1(seq):
'''
向前差分
'''
f0 = [y0[0]]
f1 = [y0[1]]
f2 = [fvdp1(-1, [f0[0], f1[0]])[1]]
for i in range(1, len(seq)):
h = seq[i] - seq[i - 1]
f0.append(f0[i - 1] + h * f1[i - 1])
f1.append(f1[i - 1] + h * f2[i - 1])
f2.append(fvdp1(seq[i], [f0[i], f1[i]])[1])
return f0
def solve2(seq):
'''
向后差分
'''
f0 = [y0[0]]
f1 = [y0[1]]
f2 = [fvdp1(-1, [f0[0], f1[0]])[1]]
for i in range(1, len(seq)):
h = seq[i] - seq[i - 1]
f2.append(fvdp1(seq[i], [f0[i - 1], f1[i - 1]])[1])
f1.append(f1[i - 1] + h * f2[i])
f0.append(f0[i - 1] + h * f1[i])
return f0
def runge_kutta(y, x, dx, f):
""" y is the initial value for y
x is the initial value for x
dx is the time step in x
f is derivative of function y(t)
"""
k1 = dx * f(y, x)
k2 = dx * f(y + 0.5 * k1, x + 0.5 * dx)
k3 = dx * f(y + 0.5 * k2, x + 0.5 * dx)
k4 = dx * f(y + k3, x + dx)
return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6.
def solve3(seq):
'''
rk4
'''
return solve_ivp(fvdp1, t_span=(-1, 1.0), y0=y0, t_eval=seq).y.T[:, 0]
def show():
t0 = np.linspace(-1, 1, 1000)
r0 = solve0()
t1 = np.linspace(-1, 1, 6)
r1 = solve1(t1)
t2 = np.linspace(-1, 1, 6)
r2 = solve2(t2)
t3 = np.linspace(-1, 1, 6)
r3 = solve3(t3)
plt.plot(t0, r0, label='0')
plt.plot(t1, r1, label='1')
plt.plot(t2, r2, label='2')
plt.plot(t3, r3, label='3')
plt.legend()
plt.show()
def showN():
t0 = np.linspace(-1, 1, 1000)
r0 = solve0()
plt.plot(t0, r0, label='0: N = 1000')
solves = [solve1, solve2, solve3]
for j in range(3):
for i in range(1, 5):
n = 2 ** i
t = np.linspace(-1, 1, n + 1)
plt.plot(t, solves[j](t), label='%s:N=%s' % (j, n))
plt.legend()
plt.show()
if __name__ == '__main__':
showN()