# Tolerances: the MATLAB settings 'AbsTol',1e-12,'RelTol',1e-12 are approximated here by solve_ivp(rtol=1e-9, atol=1e-9).

import numpy as np
from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import warnings

# Silence every RuntimeWarning for the rest of the run; intended for noisy
# integration warnings such as overflow.
warnings.simplefilter("ignore", category=RuntimeWarning)

# ------------------------
# System Parameters
# ------------------------
# Values appear to be SI (gravity is taken as 9.807 below).
p = dict(
    mass=0.16,
    I_1=5e-5,                      # inertia about the spin axis (multiplies omega_z in H_O)
    I_2=2.6e-5,                    # transverse inertia
    L=0.0,                         # offset used in the parallel-axis and gravity terms
    phi_dot=1e4 * 2 * np.pi / 60,  # spin rate: 1e4 rpm converted to rad/s
    Gam_mag=2,                     # amplitude of the applied torque Gam_mag*sin(freq*t)
    freq=20,                       # angular frequency of that torque
)
# Derived entries: transverse inertia shifted by m*L^2 (parallel-axis form)
# and the gravity-torque coefficient m*g*L.
p.update(
    I_2_eff=p['I_2'] + p['mass'] * p['L']**2,
    mgL=p['mass'] * 9.807 * p['L'],
)

# ------------------------
# Modular Functions
# ------------------------
def M_mat_gyro(q, p):
    """Return the 2x2 diagonal mass matrix for generalized coordinates q.

    q[1] is the tilt angle theta. The first diagonal entry blends I_2_eff and
    I_1 by sin^2/cos^2 of theta; the second is the constant I_2_eff.
    """
    sin_t = np.sin(q[1])
    cos_t = np.cos(q[1])
    psi_entry = p['I_2_eff'] * sin_t**2 + p['I_1'] * cos_t**2
    return np.diag([psi_entry, p['I_2_eff']])

def F_vec_gyro(q, q_dot, t, p):
    """Return the generalized-force vector F so that M(q) @ q_ddot = F.

    Parameters
    ----------
    q : sequence; q[1] is the tilt angle theta.
    q_dot : sequence; q_dot[0] is psi_dot, q_dot[1] is theta_dot.
    t : time, used for the applied torque Gam_mag*sin(freq*t).
    p : parameter dict (I_1, I_2_eff, phi_dot, Gam_mag, freq, mgL).

    Returns
    -------
    ndarray [F_psi, F_theta].
    """
    theta = q[1]
    psi_dot, theta_dot = q_dot[0], q_dot[1]

    inertia_gap = p['I_2_eff'] - p['I_1']
    sin_t = np.sin(theta)
    sin_2t = np.sin(2 * theta)
    # Gyroscopic coupling through the constant spin rate phi_dot.
    spin_term = p['I_1'] * sin_t * p['phi_dot']
    applied_torque = p['Gam_mag'] * np.sin(p['freq'] * t)

    F_psi = (applied_torque
             - inertia_gap * sin_2t * psi_dot * theta_dot
             - spin_term * theta_dot)
    F_theta = (0.5 * inertia_gap * sin_2t * psi_dot**2
               + spin_term * psi_dot
               + p['mgL'] * sin_t)
    return np.array([F_psi, F_theta])

def G_vec_gyro(t, Z, p):
    """Right-hand side dZ/dt for solve_ivp.

    Z holds [q, q_dot] with q = Z[:2] and q_dot = Z[2:]. The accelerations come
    from solving M(q) q_ddot = F(q, q_dot, t) and are stacked after q_dot.
    """
    q, q_dot = Z[:2], Z[2:]
    q_ddot = np.linalg.solve(M_mat_gyro(q, p), F_vec_gyro(q, q_dot, t, p))
    return np.concatenate((q_dot, q_ddot))

# ------------------------
# Integration Setup
# ------------------------
# State Z = [psi, theta, psi_dot, theta_dot] (unpacked in this order below).
Z0 = [0, np.pi / 3, 0, -0.01]
t_span = (0, 2)
t_eval = np.linspace(*t_span, 1500)
sol = solve_ivp(G_vec_gyro, t_span, Z0, args=(p,), method='RK45',
                t_eval=t_eval, rtol=1e-9, atol=1e-9)
# solve_ivp reports failure via sol.success instead of raising, and
# RuntimeWarnings are silenced at the top of the script. Without this check a
# failed run would silently yield a t_ode shorter than t_eval.
if not sol.success:
    raise RuntimeError(f"solve_ivp failed: {sol.message}")
t_ode = sol.t
psi, theta, psi_dot, theta_dot = sol.y

# Interpolator for ψ̇
psi_dot_func = interp1d(t_ode, psi_dot, kind='cubic')

# ------------------------
# Derived Quantities
# ------------------------
# Body-frame angular velocity: omega_x is the transverse component and omega_z
# the spin-axis component. The spin rate enters with a minus sign, matching the
# sign of the phi_dot terms in the equations of motion.
omega_x = psi_dot * np.sin(theta)
omega_z = psi_dot * np.cos(theta) - p['phi_dot']

# Angular momentum about O, body (x, z) components. The transverse inertia
# about O is I_2_eff, the same value the equations of motion use; it equals
# I_2 when L == 0.
H_O = np.vstack((p['I_2_eff'] * omega_x, p['I_1'] * omega_z)).T

# The same vector rotated by theta into the x'y'z' frame.
H_O_xyz_prime = np.vstack((
    np.cos(theta) * H_O[:, 0] - np.sin(theta) * H_O[:, 1],
    np.sin(theta) * H_O[:, 0] + np.cos(theta) * H_O[:, 1]
)).T

# Normalised nutation stiffness: kappa/I_2 = -(dF_theta/dtheta) / I_2_eff,
# with psi_dot held fixed, where F_theta is the theta component of the
# generalized force. The first term depends on psi_dot**2 (a rate squared,
# not the angle psi squared), so the result has units of (rad/s)^2. The
# gravity term vanishes when L == 0.
I_ratio = p['I_1'] / p['I_2_eff']
Kappa_I_2 = ((I_ratio - 1) * np.cos(2 * theta) * psi_dot**2
             - I_ratio * np.cos(theta) * p['phi_dot'] * psi_dot
             - (p['mgL'] / p['I_2_eff']) * np.cos(theta))

#%% Plotting
# ------------------------
# Figure 1: θ, ω, H_O (xyz)
# ------------------------
# Three panels versus t_ode: theta in degrees, the rotation rates
# (psi_dot, omega_z, omega_x), and the body-frame components of H_O.
fig1 = plt.figure(1, figsize=(8, 8))

ax = plt.subplot(3, 1, 1)
ax.plot(t_ode, np.degrees(theta), 'r-', linewidth=0.5)
ax.set_ylabel(r'$\theta$ (degrees)')
ax.set_ylim([0, 360])
ax.grid(True)

ax = plt.subplot(3, 1, 2)
rate_curves = (
    (psi_dot, 'k:', r'$d\psi/dt$'),
    (omega_z, 'r-', r'$\omega_z$'),
    (omega_x, 'b--', r'$\omega_x$'),
)
for curve, fmt, curve_label in rate_curves:
    ax.plot(t_ode, curve, fmt, label=curve_label)
ax.set_ylabel('Rotation rate (rad/s)')
ax.set_ylim([-4000, 4000])
ax.legend()
ax.grid(True)

ax = plt.subplot(3, 1, 3)
for col, fmt, curve_label in ((0, 'b-', r'$(H_O)_x$'), (1, 'r--', r'$(H_O)_z$')):
    ax.plot(t_ode, H_O[:, col], fmt, label=curve_label)
ax.set_ylabel(r'$H_O$ (kg·m²/s)')
ax.set_xlabel('Time (s)')
ax.set_ylim([-0.2, 0.05])
ax.legend()
ax.grid(True)

fig1.suptitle(r"Figure 1: $\theta$, $\omega$ components, and $H_O$ (xyz)")
# Reserve the top 4% of the figure for the suptitle.
fig1.tight_layout(rect=[0, 0, 1, 0.96])

# ------------------------
# Figure 2: θ, ψ̇, H_O in x'y'z'
# ------------------------
# Three panels versus t_ode: theta in degrees, psi_dot, and H_O expressed in
# the x'y'z' frame.
fig2 = plt.figure(2, figsize=(8, 8))

ax = plt.subplot(3, 1, 1)
ax.plot(t_ode, np.degrees(theta), 'r-', linewidth=0.5)
ax.set_ylabel(r'$\theta$ (degrees)')
ax.set_ylim([0, 360])
ax.grid(True)

ax = plt.subplot(3, 1, 2)
ax.plot(t_ode, psi_dot, 'r--', linewidth=0.5)
ax.set_ylabel(r'$d\psi/dt$ (rad/s)')
ax.set_ylim([-3000, 3000])
ax.grid(True)

ax = plt.subplot(3, 1, 3)
H_prime_x = H_O_xyz_prime[:, 0]
H_prime_z = H_O_xyz_prime[:, 1]
ax.plot(t_ode, H_prime_x, 'b--', label="(H_O)'_x")
ax.plot(t_ode, H_prime_z, 'r-', label="(H_O)'_z")
ax.set_ylabel(r"$H'_O$ (kg·m²/s)")
ax.set_xlabel("Time (s)")
ax.set_ylim([-0.1, 0.2])
ax.legend()
ax.grid(True)

fig2.suptitle(r'Figure 2: $\theta$, $\dot{\psi}$, and $H_O$ in $x^{\prime}y^{\prime}z^{\prime}$')
# Reserve the top 4% of the figure for the suptitle.
fig2.tight_layout(rect=[0, 0, 1, 0.96])

# ------------------------
# Figure 3: θ and stiffness ratio
# ------------------------
# Top panel: theta in degrees. Bottom panel: the normalised nutation stiffness
# Kappa_I_2 versus t_ode.
plt.figure(3, figsize=(8, 6))
plt.subplot(2, 1, 1)
plt.plot(t_ode, np.degrees(theta), 'r-', linewidth=0.5)
plt.ylabel(r'$\theta$ (degrees)')
plt.ylim([0, 360])
plt.grid(True)

plt.subplot(2, 1, 2)
plt.plot(t_ode, Kappa_I_2, 'r', linewidth=0.5)
# Kappa_I_2 contains a phi_dot*psi_dot term (a product of two rates), so its
# unit is (rad/s)^2, not rad/s.
plt.ylabel(r'$\kappa/I_2$ ((rad/s)$^2$)')
plt.ylim([-4e6, 8e6])
plt.yticks([-4e6, 0, 4e6, 8e6])
plt.xlabel("Time (s)")
plt.grid(True)

plt.suptitle(r"Figure 3: $\theta$ and stiffness ratio $\kappa/I_2$")
# Reserve the top 4% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])

# ------------------------
# Figure 4: Snapshot Subplots
# ------------------------
plt.figure(4, figsize=(10, 6))

def draw_quiver(ax, u, v, color='k', scale=1):
    """Draw one arrow from the origin to (u, v) on *ax*.

    With angles='xy' and scale_units='xy', the arrow is drawn in data
    coordinates. With the default scale=1, (u, v) is the arrow tip in axis units.
    """
    ax.quiver(
        0, 0,
        u, v,
        angles='xy',
        scale_units='xy',
        scale=scale,
        width=0.015,
        color=color
    )


def _draw_flipped(ax, vec, color, label=None):
    """Draw 2-D *vec* with its first (X) component negated, optionally labelling its tip.

    Negating X renders the X-Z plane as seen looking down the Y axis.
    """
    tip_x, tip_z = -vec[0], vec[1]
    draw_quiver(ax, tip_x, tip_z, color=color)
    if label is not None:
        ax.text(tip_x * 1.05, tip_z * 1.05, label)


times = np.array([0, 0.32, 0.64, 0.96, 1.26, 1.58])  # 6 times -> 2x3 panel grid below

# Inertial X and Z directions (2-D projection) are identical in every panel.
e_X = np.array([1.0, 0.0])
e_Z = np.array([0.0, 1.0])

for nplot, t_target in enumerate(times):
    ax = plt.subplot(2, 3, nplot + 1)
    ax.set_aspect('equal')

    # Output sample closest to the requested snapshot time.
    I_time = np.argmin(np.abs(t_ode - t_target))
    theta_t = theta[I_time]

    print(
        f"I={I_time}, "
        f"theta={theta_t * 180 / np.pi:.2f}, "
        f"t={t_ode[I_time]:.3f}"
    )

    # H_O in the x'y'z' frame at this sample, and its unit direction.
    # Assumes H_O is nonzero at every snapshot.
    H_O_t = H_O_xyz_prime[I_time, :2]
    e_H = H_O_t / np.linalg.norm(H_O_t)

    # Body x and z axes: the inertial axes rotated by theta in the X-Z plane.
    e_x = np.array([np.cos(theta_t), np.sin(theta_t)])
    e_z = np.array([-np.sin(theta_t), np.cos(theta_t)])

    # ---- Reverse X direction (looking down Y axis) ----
    for vec, color, label in ((e_X, 'k', 'X'), (e_Z, 'k', 'Z'),
                              (e_x, 'b', 'x'), (e_z, 'b', 'z')):
        _draw_flipped(ax, vec, color, label)

    # Angular momentum direction (unit length, unlabelled)
    _draw_flipped(ax, e_H, 'red')

    ax.set_title(
        rf"$t={t_ode[I_time]:.2f},\ \dot{{\psi}}={psi_dot[I_time]:.4f}$"
    )

    ax.set_xlim([-1.2, 1.2])
    ax.set_ylim([-1.2, 1.2])
    ax.grid(True)

plt.tight_layout()
plt.show()

# ------------------------
# Print exact snapshot info
# ------------------------
# theta is evaluated at each requested time with a cubic interpolant rather
# than taken from the nearest output sample.
snapshot_times = [0, 0.32, 0.64, 0.96, 1.26, 1.58]

# Sample spacing of t_ode, used to map a time to its nearest sample index.
# Assumes t_ode is uniformly spaced (as the linspace t_eval grid is).
dt = t_ode[1] - t_ode[0]
theta_deg_func = interp1d(t_ode, np.degrees(theta), kind='cubic')

for t_snap in snapshot_times:
    I_val, theta_val = round(t_snap / dt), theta_deg_func(t_snap)
    print(f"I={I_val}, theta={theta_val:.4f}, t ={t_snap}")
