import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import matplotlib.animation as animation

# --------------------------------------
# Helper Functions
# --------------------------------------
def F_asym_rot(t, Z, I, H_G_mag):
    """Right-hand side for the Euler angles (psi, theta, phi) used with ``S_transf``.

    These are the rates of the z-y-z angles for a rigid body whose angular momentum,
    of constant magnitude ``H_G_mag``, lies along the reference z axis. Only the
    diagonal entries of the inertia matrix ``I`` are read.

    Parameters
    ----------
    t : float
        Time (unused; present for ``solve_ivp``'s ``fun(t, y, *args)`` signature).
    Z : sequence of three floats or arrays
        ``(psi, theta, phi)``. Arrays are handled element-wise.
    I : ndarray, shape (3, 3)
        Inertia matrix (diagonal terms used).
    H_G_mag : float
        Magnitude of the angular momentum.

    Returns
    -------
    list
        ``[dpsi/dt, dtheta/dt, dphi/dt]``.
    """
    _psi, theta, phi = Z  # psi does not appear in the rates
    I1, I2, I3 = I[0, 0], I[1, 1], I[2, 2]
    cos_sq = np.cos(phi)**2
    sin_sq = np.sin(phi)**2

    precession = H_G_mag * (cos_sq / I1 + sin_sq / I2)
    nutation = H_G_mag * 0.5 * (1/I2 - 1/I1) * np.sin(theta) * np.sin(2 * phi)
    spin = H_G_mag * (1/I3 - cos_sq / I1 - sin_sq / I2) * np.cos(theta)
    return [precession, nutation, spin]

def S_transf(psi, theta, phi):
    """Return the z-y-z (3-2-3) transformation matrix ``R_phi @ R_theta @ R_psi``.

    Each factor is a frame-transformation (passive) matrix: ``R_psi`` and ``R_phi``
    turn about z by ``psi`` and ``phi``; ``R_theta`` turns about y by ``theta``.

    Parameters
    ----------
    psi, theta, phi : float
        Euler angles in radians.

    Returns
    -------
    ndarray, shape (3, 3)
    """
    def about_z(angle):
        cz, sz = np.cos(angle), np.sin(angle)
        return np.array([[cz, sz, 0],
                         [-sz, cz, 0],
                         [0, 0, 1]])

    cy, sy = np.cos(theta), np.sin(theta)
    about_y = np.array([[cy, 0, -sy],
                        [0, 1, 0],
                        [sy, 0, cy]])

    return about_z(phi) @ about_y @ about_z(psi)

def create_box_faces(corners):
    """Group eight box corners into six quadrilateral faces.

    Assumes the corner order used for the animation: bit 0 of the index picks the
    +x side, bit 1 the +y side, bit 2 the +z side. Each face lists its four corners
    in order around the perimeter, suitable for ``Poly3DCollection``.

    Parameters
    ----------
    corners : indexable of length 8

    Returns
    -------
    list of list
        Six faces, each a list of four corners.
    """
    face_loops = (
        (0, 1, 3, 2),  # -z
        (4, 5, 7, 6),  # +z
        (0, 1, 5, 4),  # -y
        (2, 3, 7, 6),  # +y
        (1, 3, 7, 5),  # +x
        (0, 2, 6, 4),  # -x
    )
    return [[corners[idx] for idx in loop] for loop in face_loops]

# --------------------------------------
# System Parameters
# --------------------------------------
rho = 900                       # density (presumably SI, kg/m^3 — units are not stated)
a, b, c = 0.1, 1.0, 0.2         # box edge lengths along body x, y, z
L = np.linalg.norm([a, b, c])   # length of the a-b-c space diagonal; sets the plot scale
m = rho * a * b * c
# Principal moments of inertia of a solid cuboid with full edge lengths a, b, c.
I = np.diag([
    m * (b**2 + c**2) / 12,
    m * (a**2 + c**2) / 12,
    m * (a**2 + b**2) / 12
])

R_0 = np.eye(3)                       # initial attitude (identity: body axes = reference axes)
# NOTE(review): K_vec (body components) and k_0_vec (a row of R_0) are only in the same
# frame because R_0 is the identity — revisit if R_0 is ever changed.
k_0_vec = R_0[2, :]                   # body z axis at t = 0
omega_0 = np.array([1.5, 2, 2.55])    # initial body-frame angular velocity
H_G_vec = I @ omega_0                 # angular momentum about G, body components at t = 0
H_G_mag = np.linalg.norm(H_G_vec)
K_vec = H_G_vec / H_G_mag             # unit vector along H_G

# The theta rotation in S_transf is about the axis j', which is perpendicular to both
# H_G and the body z axis.
u = np.cross(K_vec, k_0_vec)
u_norm = np.linalg.norm(u)
if u_norm > 1e-12:
    j_prime_vec = u / u_norm
else:
    # H_G is (anti)parallel to body z: theta is 0 or pi, only psi + phi matters, and the
    # cross product has no usable direction. Choose j' so that phi_0 = 0.
    j_prime_vec = np.array([0.0, 1.0, 0.0])
# clip guards against |K_z| rounding slightly above 1.
theta_0 = np.arccos(np.clip(K_vec[2], -1.0, 1.0))
# In body components j' = [sin(phi), cos(phi), 0]. arccos(j'_y) alone only returns phi in
# [0, pi] and gets the sign wrong whenever H_G has a negative y component; arctan2 uses
# both components and covers the full circle.
phi_0 = np.arctan2(j_prime_vec[0], j_prime_vec[1])
psi_0 = 0
angles_0 = [psi_0, theta_0, phi_0]

# --------------------------------------
# Time Integration
# --------------------------------------
t_max = 6
N_vals = 500
t_eval = np.linspace(0, t_max, N_vals + 1)  # N_vals + 1 output samples on [0, t_max]
# solve_ivp's defaults (RK45, rtol=1e-3, atol=1e-6) are loose for the energy-conservation
# check below. Use the high-order DOP853 method with tight tolerances so err_KE measures
# the model rather than the integrator.
sol = solve_ivp(F_asym_rot, [0, t_max], angles_0, args=(I, H_G_mag), t_eval=t_eval,
                method='DOP853', rtol=1e-10, atol=1e-12)
if not sol.success:
    # Fail loudly instead of unpacking a truncated or empty solution.
    raise RuntimeError(f"solve_ivp failed: {sol.message}")
psi, theta, phi = sol.y  # each has shape (N_vals + 1,)

# --------------------------------------
# Angular Velocities
# --------------------------------------
# Euler-angle rates along the trajectory. F_asym_rot works element-wise on arrays, so
# reuse it rather than repeating its formulas here (the two copies could drift apart).
psi_dot, theta_dot, phi_dot = F_asym_rot(t_eval, (psi, theta, phi), I, H_G_mag)

# Body-frame angular velocity from the z-y-z kinematic relation matching S_transf.
omega_x = -psi_dot * np.sin(theta) * np.cos(phi) + theta_dot * np.sin(phi)
omega_y = psi_dot * np.sin(theta) * np.sin(phi) + theta_dot * np.cos(phi)
omega_z = psi_dot * np.cos(theta) + phi_dot
omega_vecs = np.column_stack((omega_x, omega_y, omega_z))  # one row per sample: (N_vals + 1, 3)

# --------------------------------------
# Error Checks
# --------------------------------------
# Body-frame angular momentum for every sample, one row per sample. Right-multiplying
# the row vectors by I equals (I @ omega) per row because I is diagonal (symmetric).
# NOTE(review): omega is rebuilt from the same rate equations that were integrated, so
# |I @ omega| equals H_G_mag algebraically; err_H_g therefore only shows rounding.
# err_KE is the check that actually reflects integration error.
H_G_check = omega_vecs @ I
H_G_mags = np.linalg.norm(H_G_check, axis=1)
err_H_g = np.abs(H_G_mags - H_G_mags[0]).max()

# Rotational kinetic energy per sample: KE = 1/2 * omega^T I omega.
KE = 0.5 * np.sum(omega_vecs * (omega_vecs @ I), axis=1)
err_KE = np.abs(KE - KE[0]).max()

for check_name, check_value in (("err_H_g", err_H_g), ("err_KE", err_KE)):
    print(f"{check_name} =\n\n   {check_value:.4e}\n")

# --------------------------------------
# Euler Angle Plots
# --------------------------------------
# Top panel: Euler angles in degrees. Bottom panel: their rates in rad/s.
fig_angles, (ax_angles, ax_rates) = plt.subplots(2, 1, figsize=(10, 6))

angle_series = ((psi, '--k', 'ψ'), (theta, '-r', 'θ'), (phi, ':b', 'ϕ'))
for values, line_style, curve_label in angle_series:
    ax_angles.plot(t_eval, np.degrees(values), line_style, label=curve_label)
ax_angles.set_ylabel("Angle (deg)")
ax_angles.legend()

rate_series = ((psi_dot, '--k', 'dψ/dt'), (theta_dot, '-r', 'dθ/dt'), (phi_dot, ':b', 'dϕ/dt'))
for values, line_style, curve_label in rate_series:
    ax_rates.plot(t_eval, values, line_style, label=curve_label)
ax_rates.set_xlabel("Time (s)")
ax_rates.set_ylabel("Angular velocity (rad/s)")
ax_rates.legend()

fig_angles.tight_layout()
plt.show()

# --------------------------------------
# 3D Animation
# --------------------------------------
fig, ax = plt.subplots(figsize=(8, 6), subplot_kw={'projection': '3d'})
ax.set_box_aspect([1, 1, 1])  # equal scaling on all three axes so the box is not distorted

# Box corners in body coordinates, one column per corner. Bit 0/1/2 of the corner index
# selects the +x/+y/+z side, which is the ordering create_box_faces expects.
# NOTE(review): m and I treat a, b, c as full edge lengths, but these corners span
# ±a, ±b, ±c, so the drawn box is twice the physical size. The [-L, L] axis limits
# used in update() are sized to this drawing.
corner_signs = np.array([[1 if idx & bit else -1 for bit in (1, 2, 4)] for idx in range(8)])
corners = (corner_signs * np.array([a, b, c])).T

# Frame at t = 0. Combined with S_transf(...).T in update(), this maps body-frame
# components into the fixed reference frame.
RR = np.matmul(R_0.T, S_transf(psi_0, theta_0, phi_0))

def update(n):
    """Redraw animation frame ``n``.

    Clears the module-level 3D axes ``ax`` and redraws, in the fixed reference frame:
    the box rotated to sample ``n``, the directions of omega and H_G, and the global
    X/Y/Z axes. Reads module-level results (``psi``, ``theta``, ``phi``, ``omega_vecs``,
    ``H_G_check``, ``t_eval``) and geometry (``corners``, ``RR``, ``L``).

    Parameters
    ----------
    n : int
        Sample index supplied by ``FuncAnimation``.
    """
    ax.cla()  # cla() also resets limits and grid, so re-apply them every frame
    ax.set_xlim([-L, L])
    ax.set_ylim([-L, L])
    ax.set_zlim([-L, L])
    ax.grid(True)
    ax.set_title(f'Time: {t_eval[n]:.2f} s')

    # Body-frame components -> reference-frame components at sample n.
    R_to_ref = RR @ S_transf(psi[n], theta[n], phi[n]).T
    new_corners = R_to_ref @ corners
    verts = create_box_faces(new_corners.T)
    box = Poly3DCollection(verts, facecolors='cyan', edgecolors='k', alpha=0.7)
    ax.add_collection3d(box)

    # quiver() with normalize=False draws vector * length. The raw omega and H_G vectors
    # (magnitudes of several units here) gave arrows that could extend past the [-L, L]
    # limits, and labels placed at 0.45 * vector instead of near the tip at 0.4 * L * vector.
    # Draw unit directions (the e_omega label already denotes a unit vector) with a fixed
    # arrow length and put each label just beyond its arrow tip.
    arrow_len = 0.4 * L
    label_dist = 0.45 * L
    arrows = (
        (omega_vecs[n], 'purple', r'$e_\omega$'),  # angular velocity
        (H_G_check[n], 'teal', r'$H_G$'),          # angular momentum
    )
    for vec_body, color, label in arrows:
        vec_ref = R_to_ref @ vec_body
        norm = np.linalg.norm(vec_ref)
        if norm == 0:
            continue  # no direction to draw
        direction = vec_ref / norm
        ax.quiver(0, 0, 0, *direction, color=color, linewidth=2, length=arrow_len)
        ax.text(*(direction * label_dist), label, color=color)

    # Global XYZ
    ax.quiver(0, 0, 0, L, 0, 0, color='black', linestyle='dotted')
    ax.quiver(0, 0, 0, 0, L, 0, color='black', linestyle='dotted')
    ax.quiver(0, 0, 0, 0, 0, L, color='black', linestyle='dotted')
    ax.text(L, 0, 0, 'X', color='black')
    ax.text(0, L, 0, 'Y', color='black')
    ax.text(0, 0, L, 'Z', color='black')

# One update() call per stored sample, about 10 ms apart. Keep the reference in `ani`
# so the animation is not garbage-collected before plt.show() displays it.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=len(t_eval),
    interval=10,
)
plt.show()
