import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# ------------------------------
# Constants and Parameters
# ------------------------------
# Vehicle and environment parameters (SI units presumed — confirm: kg, s, N, m).
m_total = 549_000.0      # initial total mass of the vehicle
m_prop = 385_000.0       # propellant mass consumed over the burn
t_burn = 185             # burn duration; kept an int because it sizes np.linspace(..., t_burn + 1)
m_dot = m_prop / t_burn  # uniform propellant flow rate implied by burning m_prop over t_burn
Thrust = 6_000_000.0     # constant engine thrust
x_area = 40              # reference (cross-sectional) area used in the drag formula
C_d = 0.4                # drag coefficient
rho_0 = 1.29             # air density at h = 0 in the exponential density model
H = 10400                # scale height of the exponential density model
R_e = 6_378_000.0        # Earth radius used by the inverse-square gravity model


# ------------------------------
# Gravity and Resistance
# ------------------------------
def gravity(h):
    """Return gravitational acceleration at altitude ``h`` (inverse-square model).

    The surface value 9.807 is scaled by (R_e / (R_e + h))**2, where ``R_e`` is
    the module-level Earth radius; ``h`` must use the same length unit as ``R_e``.
    """
    surface_g = 9.807
    radius_ratio = R_e / (R_e + h)
    return surface_g * radius_ratio ** 2


def resistance(h, v, C_d):
    """Return the aerodynamic drag force at altitude ``h`` and velocity ``v``.

    Air density follows the exponential model ``rho = rho_0 * exp(-h / H)``.
    The force is ``0.5 * C_d * rho * v * |v| * x_area``. The ``v * |v|`` form
    keeps the sign of ``v``, so drag always opposes the motion: it is positive
    while climbing and negative while descending. For ``v >= 0`` this equals
    the classic ``0.5 * C_d * rho * v**2 * x_area``.

    Args:
        h: altitude, in the same length unit as ``H``.
        v: vertical velocity (positive upward).
        C_d: drag coefficient; 0 disables drag entirely.

    Returns:
        Signed drag force, meant to be subtracted from the net upward force.
    """
    rho = rho_0 * np.exp(-h / H)
    # v * |v| rather than v**2: with v**2 the drag would push the vehicle down
    # even while it is falling, instead of opposing the motion.
    return 0.5 * C_d * rho * (v * np.abs(v)) * x_area


# ------------------------------
# ODE Function
# ------------------------------
def Falcon_Z_dot(t, Z, m_total, m_dot, C_d, g_flag):
    """Right-hand side of the 1-D ascent ODE, in the form expected by ``solve_ivp``.

    Args:
        t: current time; the vehicle mass is ``m_total - m_dot * t``.
        Z: state ``[altitude, vertical speed]``.
        m_total: initial mass.
        m_dot: mass flow rate (0 for a constant-mass vehicle).
        C_d: drag coefficient forwarded to ``resistance``.
        g_flag: 0 selects the altitude-dependent ``gravity(altitude)``; any
            other value is used directly as a constant gravitational acceleration.

    Returns:
        ``[d(altitude)/dt, d(speed)/dt]`` as a list.
    """
    altitude, speed = Z
    # Assumes t stays inside the burn window so the mass stays positive — no guard here.
    mass_now = m_total - m_dot * t
    if g_flag != 0:
        g_local = g_flag
    else:
        g_local = gravity(altitude)
    drag = resistance(altitude, speed, C_d)
    accel = (Thrust - drag - mass_now * g_local) / mass_now
    return [speed, accel]


# ------------------------------
# Simulation Function
# ------------------------------
def run_simulation(m_dot_val, C_d_val, g_flag):
    """Integrate the vertical ascent over the burn and sample it on a fixed time grid.

    Args:
        m_dot_val: propellant mass flow rate; 0 gives a constant-mass vehicle.
        C_d_val: drag coefficient; 0 disables drag.
        g_flag: 0 for altitude-dependent gravity, otherwise a constant g value.

    Returns:
        ``(y_vals, v_vals, max_altitude_km, max_speed)``. ``y_vals`` and
        ``v_vals`` hold altitude and speed at the ``t_burn + 1`` evenly spaced
        times from 0 to ``t_burn``, starting from rest at zero altitude.

    Raises:
        RuntimeError: if the ODE solver reports failure, so a bad trajectory is
            never silently returned.
    """
    t_vals = np.linspace(0, t_burn, t_burn + 1)
    # A single integration with t_eval replaces one restarted solve per interval.
    # The tolerances are explicit because solve_ivp's default rtol=1e-3 is far
    # coarser than the precision the results are printed with.
    sol = solve_ivp(
        Falcon_Z_dot,
        (t_vals[0], t_vals[-1]),
        [0.0, 0.0],
        method='RK45',
        t_eval=t_vals,
        args=(m_total, m_dot_val, C_d_val, g_flag),
        rtol=1e-9,
        atol=1e-6,
    )
    if not sol.success:
        raise RuntimeError(f"Ascent integration failed: {sol.message}")

    y_vals, v_vals = sol.y
    return y_vals, v_vals, y_vals.max() / 1000, v_vals.max()


# ------------------------------
# Run All 4 Cases
# ------------------------------
# Each variant switches off one modelling effect relative to the baseline:
#   baseline         - burning mass, drag, altitude-dependent gravity (g_flag=0)
#   constant mass    - m_dot_val=0
#   no resistance    - C_d_val=0
#   constant gravity - g_flag=9.807 replaces gravity(h)
y_baseline, v_baseline, max_y1, max_v1 = run_simulation(m_dot_val=m_dot, C_d_val=C_d, g_flag=0)
y_const_mass, v_const_mass, max_y2, max_v2 = run_simulation(m_dot_val=0, C_d_val=C_d, g_flag=0)
y_no_drag, v_no_drag, max_y3, max_v3 = run_simulation(m_dot_val=m_dot, C_d_val=0, g_flag=0)
y_const_grav, v_const_grav, max_y4, max_v4 = run_simulation(m_dot_val=m_dot, C_d_val=C_d, g_flag=9.807)

# ------------------------------
# Print Results
# ------------------------------
print("\n>>")
# Values in case order: baseline, constant mass, no resistance, constant gravity.
print("Max altitudes (km):",
      *(f"{peak:.4f}" for peak in (max_y1, max_y2, max_y3, max_y4)))
# NOTE(review): the constant-mass speed is printed with 5 decimals while the
# others use 4 — preserved as-is; confirm whether the difference is intentional.
print("Max speeds (m/s):",
      *(format(peak, spec) for peak, spec in ((max_v1, ".4f"), (max_v2, ".5f"),
                                               (max_v3, ".4f"), (max_v4, ".4f"))))

# ------------------------------
# Plotting
# ------------------------------
plt.figure(100)

# Same sample times run_simulation uses (0 .. t_burn); built once for every curve.
t_plot = np.linspace(0, t_burn, t_burn + 1)

# NOTE(review): only the top row of a 2x2 grid is used — confirm whether a 1x2
# layout was intended.
plt.subplot(2, 2, 1)
plt.plot(t_plot, y_baseline / 1000, '-r', label='Baseline')
plt.plot(t_plot, y_const_mass / 1000, '-b', label='Constant mass')
plt.plot(t_plot, y_no_drag / 1000, '-g', label='No resistance')
plt.plot(t_plot, y_const_grav / 1000, '--k', label='Constant gravity')
plt.xlabel('Time (seconds)')
plt.ylabel('Altitude (km)')
plt.xlim([0, t_burn])
plt.legend()

plt.subplot(2, 2, 2)
# Uses the same colour/line-style key as the altitude panel's legend.
plt.plot(t_plot, v_baseline, '-r')
plt.plot(t_plot, v_const_mass, '-b')
plt.plot(t_plot, v_no_drag, '-g')
plt.plot(t_plot, v_const_grav, '--k')
plt.xlabel('Time (seconds)')
plt.ylabel('Speed v (m/s)')
plt.xlim([0, t_burn])
plt.tight_layout()
plt.show()
