import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

# MATLAB's `clear all; close all` has no counterpart here: when run as a
# standalone script, Python starts with a fresh namespace and no open figures.

# Launch state: altitude y (m) and vertical velocity dy/dt (m/s), both zero.
y0, ydot0 = 0, 0

# Equations of motion (Python port of the MATLAB function 'eom_rocket')
def eom_rocket(t, x, *, m=0.060, g=9.81, thrust=6.25, t_burn=1.6, c=2.2e-4):
    """Return the time derivative of the rocket state, in ``solve_ivp`` form.

    Vertical 1-D model with constant thrust until burnout, constant gravity
    and quadratic aerodynamic drag::

        m * y'' = F(t) - m*g - c * y' * |y'|,   F(t) = thrust if t < t_burn else 0

    Parameters
    ----------
    t : float
        Time, s.
    x : sequence of two floats
        State ``[y, ydot]``: altitude (m) and vertical velocity (m/s).
    m : float, optional
        Mass, kg.
    g : float, optional
        Gravitational acceleration, m/s^2.
    thrust : float, optional
        Thrust applied while ``t < t_burn``.
    t_burn : float, optional
        Burnout time, s; thrust is zero from this time on.
    c : float, optional
        Lumped drag coefficient, kg/m. The default was computed as
        c = (1/2)*C_D*Ap*rho_f = (1/2)*0.75*(0.025**2*pi/4)*1.2 = 2.2e-4 kg/m.

    Returns
    -------
    list of float
        ``[dy/dt, d(ydot)/dt]``.
    """
    _y, ydot = x  # altitude does not enter the dynamics
    F = thrust if t < t_burn else 0.0
    dydt = ydot
    # ydot*|ydot| keeps the drag force opposed to the direction of motion.
    dydotdt = F / m - g - (c / m) * ydot * abs(ydot)
    return [dydt, dydotdt]

# Solve numerically (scipy's explicit RK45 is the counterpart of MATLAB's ode45).
# Tolerances are tightened from solve_ivp's defaults (rtol=1e-3, atol=1e-6)
# because the numerical curve is compared against an exact analytical solution
# below, and the thrust cut-off at burnout makes the right-hand side
# discontinuous.
t_final = 10.0
t_eval = np.linspace(0.0, t_final, 1000)
sol = solve_ivp(
    eom_rocket,
    (0.0, t_final),
    [y0, ydot0],
    method="RK45",
    t_eval=t_eval,
    rtol=1e-8,
    atol=1e-8,
)
# A failed integration returns truncated arrays; fail loudly instead of
# silently indexing into a shorter solution later on.
if not sol.success:
    raise RuntimeError(f"solve_ivp failed: {sol.message}")
tout = sol.t
xout = sol.y.T  # one row per time sample: column 0 altitude, column 1 velocity

# Analytical solution, valid while the rocket ascends (ydot >= 0), where the
# drag term -c*ydot*|ydot| reduces to -c*ydot**2.
m = 0.060   # Mass, kg
g = 9.81    # Gravity, m/s^2
c = 2.2e-4  # Drag coefficient, kg/m
t_b = 1.6   # Burnout time, s
F_b = 6.25  # Thrust during the burn

# kappa: speed at which thrust, gravity and drag balance during the burn.
# lambda_: speed at which drag equals weight (sets the coasting time scale).
kappa = np.sqrt((F_b - m * g) / c)
lambda_ = np.sqrt((m * g) / c)

# Velocity at burnout; independent of the sample time, so computed once.
v_b = kappa * np.tanh((kappa * c / m) * t_b)

N_time = len(tout)
vv = np.zeros(N_time)
acc = np.zeros(N_time)

# Index of the last sample with non-negative velocity. Defaults to the final
# sample so it is defined even if the rocket is still ascending at the end of
# the time grid (the loop below would then never break).
N_pos = N_time - 1

# Loop through time
for j, tt in enumerate(tout):
    if tt <= t_b:
        # Powered ascent: m*dv/dt = F_b - m*g - c*v**2
        vv[j] = kappa * np.tanh((kappa * c / m) * tt)
        acc[j] = (c / m) * (kappa ** 2 - vv[j] ** 2)
        continue
    # Coasting ascent: m*dv/dt = -m*g - c*v**2
    v_formula = lambda_ * np.tan(
        np.arctan(v_b / lambda_) - (lambda_ * c / m) * (tt - t_b)
    )
    if v_formula < 0:
        # Upward motion has ended; past this point drag changes sign and the
        # ascent formula no longer applies.
        N_pos = j - 1
        break
    vv[j] = v_formula
    acc[j] = -(c / m) * (lambda_ ** 2 + vv[j] ** 2)

# Trim arrays to the ascending portion
tt = tout[:N_pos + 1]
vv = vv[:N_pos + 1]
acc = acc[:N_pos + 1]

# Identify maxima.
# NOTE: indices are 0-based, so printed i_* values are one less than the
# corresponding MATLAB output.

# Analytical: peak of the trimmed analytical velocity curve.
i_v_max = np.argmax(vv)
v_max = vv[i_v_max]
t_max_v = tt[i_v_max]

# Numerical: peak velocity from the ODE solution.
i_v_max_comp = np.argmax(xout[:, 1])
v_max_comp = xout[i_v_max_comp, 1]
t_max_v_comp = tout[i_v_max_comp]

# Numerical peak altitude (column 0 of xout is altitude).
y_max_comp = np.max(xout[:, 0])
# Apogee time taken from the analytical solution: the last sample at which the
# analytical velocity is still non-negative.
t_max_y_comp = np.max(tt)
# Numerical altitude at that sample. Column 0 is altitude; column 1 would be
# velocity (MATLAB's xout(:,1) maps to Python's xout[:, 0]).
max_y = xout[N_pos, 0]

# Print results (formatted like MATLAB output)
print(f"\nv_max =\n\n  {v_max:.4f}\n")
print(f"i_v_max =\n\n  {i_v_max}\n")
print(f"t_max_v =\n\n  {t_max_v:.4f}\n")
print(f"v_max_comp =\n\n  {v_max_comp:.4f}\n")
print(f"i_v_max_comp =\n\n  {i_v_max_comp}\n")
print(f"t_max_v_comp =\n\n  {t_max_v_comp:.4f}\n")
print(f"y_max_comp =\n\n  {y_max_comp:.4f}\n")
print(f"t_max_y_comp =\n\n  {t_max_y_comp:.4f}\n")

# Plot results
plt.figure(1)

# Panel 1: altitude from the numerical solution.
plt.subplot(3, 1, 1)
plt.plot(tout, xout[:, 0], 'k.-')
plt.xlabel('Time (s)')
plt.ylabel('Altitude y(t) (m)')

# Panel 2: velocity, numerical vs. analytical, with the analytical peak marked.
plt.subplot(3, 1, 2)
plt.plot(tout, xout[:, 1], 'k.-', linewidth=0.75)
plt.plot(tt, vv, 'r--', linewidth=2)
plt.ylim([-25, 130])
plt.plot(t_max_v, v_max, 'ro', markersize=6)
plt.xlabel('Time (s)')
plt.ylabel('Velocity dy/dt (m/s)')
# The marker is the maximum *velocity* (v_max at t_max_v), not maximum altitude.
plt.legend(['Numerical', 'Analytical', 'Maximum velocity'])

# Panel 3: analytical acceleration over the ascent. This fills the third slot
# of the 3x1 grid with the acceleration computed above, which was otherwise
# never displayed.
plt.subplot(3, 1, 3)
plt.plot(tt, acc, 'r--', linewidth=2)
plt.xlabel('Time (s)')
plt.ylabel('Acceleration d^2y/dt^2 (m/s^2)')
plt.legend(['Analytical'])

plt.tight_layout()
plt.show()
