83 lines
2.1 KiB
Python
83 lines
2.1 KiB
Python
import math
import sys
from typing import Callable, Optional

import matplotlib.pyplot as plt
|
|
|
|
def f(x: float, y: float) -> float:
    """Right-hand side of the ODE y' = f(x, y) = x^2 - 2y."""
    squared = x ** 2
    return squared - 2 * y
|
|
|
|
def euler(
    h: float,
    x_min: float,
    x_max: float,
    y_0: float,
    *,
    rhs: Optional[Callable[[float, float], float]] = None,
) -> tuple[list[float], list[float]]:
    """Integrate y' = rhs(x, y) on [x_min, x_max] with the explicit Euler method.

    Update rule: y_(n + 1) = y_n + h * f(x_n, y_n)

    Args:
        h: Step size.
        x_min: Left end of the interval, where the initial value is given.
        x_max: Right end of the interval.
        y_0: Initial value y(x_min).
        rhs: Right-hand side f(x, y) of the ODE. Defaults to the module-level
            ``f`` (x^2 - 2y).

    Returns:
        ``(x_i, y_i)``: the grid points x_min, x_min + h, ... and the
        approximate solution at each of them. Both lists include the initial
        point. If (x_max - x_min) / h is not (within rounding error) a whole
        number, the leftover piece shorter than h is not integrated.
    """
    slope = f if rhs is None else rhs

    # int() alone truncates e.g. 0.3 / 0.1 == 2.9999999999999996 to 2 and
    # silently drops the final step, so snap quotients that are within
    # rounding error of a whole number before truncating.
    quotient = (x_max - x_min) / h
    nearest = round(quotient)
    n_steps = nearest if math.isclose(quotient, nearest) else int(quotient)

    x_i: list[float] = [x_min]
    y_i: list[float] = [y_0]
    for step in range(1, n_steps + 1):
        y_i.append(y_i[-1] + h * slope(x_i[-1], y_i[-1]))
        # Derive x from the step index instead of repeatedly adding h, which
        # would accumulate rounding error over many steps.
        x_i.append(x_min + step * h)
    return x_i, y_i
|
|
|
|
def trapezoid(
    h: float,
    x_min: float,
    x_max: float,
    y_0: float,
    *,
    rhs: Optional[Callable[[float, float], float]] = None,
) -> tuple[list[float], list[float]]:
    """Integrate y' = rhs(x, y) on [x_min, x_max] with Heun's trapezoid method.

    An explicit Euler step predicts y at the next grid point. The trapezoid
    rule then averages the slopes at both ends of the step:

        y~_(n + 1) = y_n + h * f(x_n, y_n)
        y_(n + 1)  = y_n + h * (f(x_n, y_n) + f(x_(n + 1), y~_(n + 1))) / 2

    Args:
        h: Step size.
        x_min: Left end of the interval, where the initial value is given.
        x_max: Right end of the interval.
        y_0: Initial value y(x_min).
        rhs: Right-hand side f(x, y) of the ODE. Defaults to the module-level
            ``f`` (x^2 - 2y).

    Returns:
        ``(x_i, y_i)``: the grid points x_min, x_min + h, ... and the
        approximate solution at each of them. Both lists include the initial
        point. If (x_max - x_min) / h is not (within rounding error) a whole
        number, the leftover piece shorter than h is not integrated.
    """
    slope = f if rhs is None else rhs

    # int() alone truncates e.g. 0.3 / 0.1 == 2.9999999999999996 to 2 and
    # silently drops the final step, so snap quotients that are within
    # rounding error of a whole number before truncating.
    quotient = (x_max - x_min) / h
    nearest = round(quotient)
    n_steps = nearest if math.isclose(quotient, nearest) else int(quotient)

    x_i: list[float] = [x_min]
    y_i: list[float] = [y_0]
    for step in range(1, n_steps + 1):
        x_n, y_n = x_i[-1], y_i[-1]
        # The slope at (x_n, y_n) feeds both predictor and corrector; evaluate
        # it only once.
        k_n = slope(x_n, y_n)
        y_tilde = y_n + h * k_n
        # Derive x from the step index to avoid accumulated rounding drift.
        x_next = x_min + step * h
        y_i.append(y_n + h * (k_n + slope(x_next, y_tilde)) / 2)
        x_i.append(x_next)
    return x_i, y_i
|
|
|
|
def rk4(
    h: float,
    x_min: float,
    x_max: float,
    y_0: float,
    *,
    rhs: Optional[Callable[[float, float], float]] = None,
) -> tuple[list[float], list[float]]:
    """Integrate y' = rhs(x, y) on [x_min, x_max] with classical 4th-order RK.

        k_1 = f(x_n, y_n)
        k_2 = f(x_n + h/2, y_n + h/2 k_1)
        k_3 = f(x_n + h/2, y_n + h/2 k_2)
        k_4 = f(x_n + h, y_n + h k_3)
        y_(n + 1) = y_n + h/6 (k_1 + 2k_2 + 2k_3 + k_4)

    Args:
        h: Step size.
        x_min: Left end of the interval, where the initial value is given.
        x_max: Right end of the interval.
        y_0: Initial value y(x_min).
        rhs: Right-hand side f(x, y) of the ODE. Defaults to the module-level
            ``f`` (x^2 - 2y).

    Returns:
        ``(x_i, y_i)``: the grid points x_min, x_min + h, ... and the
        approximate solution at each of them. Both lists include the initial
        point. If (x_max - x_min) / h is not (within rounding error) a whole
        number, the leftover piece shorter than h is not integrated.
    """
    slope = f if rhs is None else rhs

    # int() alone truncates e.g. 0.3 / 0.1 == 2.9999999999999996 to 2 and
    # silently drops the final step, so snap quotients that are within
    # rounding error of a whole number before truncating.
    quotient = (x_max - x_min) / h
    nearest = round(quotient)
    n_steps = nearest if math.isclose(quotient, nearest) else int(quotient)

    x_i: list[float] = [x_min]
    y_i: list[float] = [y_0]
    for step in range(1, n_steps + 1):
        x_n, y_n = x_i[-1], y_i[-1]
        # Derive x from the step index to avoid accumulated rounding drift.
        x_next = x_min + step * h
        k_1 = slope(x_n, y_n)
        k_2 = slope(x_n + h / 2, y_n + h / 2 * k_1)
        k_3 = slope(x_n + h / 2, y_n + h / 2 * k_2)
        k_4 = slope(x_next, y_n + h * k_3)
        y_i.append(y_n + h / 6 * (k_1 + 2 * k_2 + 2 * k_3 + k_4))
        x_i.append(x_next)
    return x_i, y_i
|
|
|
|
def draw_graph(x: list, y: list):
    """Plot the sampled solution as a line through (x[i], y[i]) and show it.

    Args:
        x: Abscissae (grid points).
        y: Ordinates (solution values), same length as ``x``.
    """
    plt.plot(x, y)
    plt.show()
|
|
|
|
def main() -> None:
    """Solve y' = x^2 - 2y, y(0) = 1 on [0, 1] with step 0.01 and plot it.

    The integration method is chosen by the first command-line argument:
    ``euler``, ``trapezoid`` or ``rk4``. When the argument is missing or
    unrecognised, a usage message is printed to stderr and the process exits
    with status 1.
    """
    methods = {
        "euler": euler,
        "trapezoid": trapezoid,
        "rk4": rk4,
    }

    if len(sys.argv) < 2 or sys.argv[1] not in methods:
        choices = ", ".join(methods)
        # sys.exit() with a string prints it to stderr and exits with status 1.
        sys.exit(f"usage: pass one integration method as argument ({choices})")

    x, y = methods[sys.argv[1]](0.01, 0, 1, 1)
    draw_graph(x, y)
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()
|