# -*- coding: utf-8 -*-
"""
Animates the free (torque-free) rotation of a body about each of its
principal axes, showing which rotations are stable and which are not.
Rotation is always about the Z-axis (with a small perturbation), so the
initial orientation of the box is varied from case to case instead, by
permuting its principal moments of inertia.
"""

import sys
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import time

#%% Helper Functions:
def F_asym_rot(t, Z, I, H_G_mag):
    """Euler-angle rates for torque-free rotation of an asymmetric body.

    Written as a ``solve_ivp`` right-hand side, ``fun(t, y, *args)``.

    Parameters
    ----------
    t : float
        Time; not used (the equations are autonomous).
    Z : sequence of three values
        Euler angles ``(psi, theta, phi)`` in radians, in the convention of
        ``S_transf``.  The rows of a ``(3, N)`` array also work, because every
        operation below is element-wise.
    I : ndarray, shape (3, 3)
        Inertia matrix; only its diagonal entries are read.
    H_G_mag : float
        Magnitude of the angular momentum, treated as constant.

    Returns
    -------
    list
        ``[psi_dot, theta_dot, phi_dot]``.  ``psi`` itself never appears on
        the right-hand side.
    """
    _, theta, phi = Z
    Ixx, Iyy, Izz = I[0, 0], I[1, 1], I[2, 2]
    cos_sq = np.cos(phi)**2
    sin_sq = np.sin(phi)**2

    precession = H_G_mag * (cos_sq / Ixx + sin_sq / Iyy)
    nutation = (H_G_mag * 0.5 * (1 / Iyy - 1 / Ixx)
                * np.sin(theta) * np.sin(2 * phi))
    spin = H_G_mag * (1 / Izz - cos_sq / Ixx - sin_sq / Iyy) * np.cos(theta)

    return [precession, nutation, spin]

def S_transf(psi, theta, phi):
    """Return the 3-2-3 Euler-angle transformation matrix.

    The result is ``R_phi @ R_theta @ R_psi``: an elementary rotation about
    z by ``psi``, then about y by ``theta``, then about z again by ``phi``.
    Each elementary matrix has the frame-rotation form
    ``[[c, s, 0], [-s, c, 0], [0, 0, 1]]`` (z) or
    ``[[c, 0, -s], [0, 1, 0], [s, 0, c]]`` (y).

    Parameters
    ----------
    psi, theta, phi : float
        Euler angles in radians.

    Returns
    -------
    ndarray, shape (3, 3)
    """
    def _about_z(angle):
        # Elementary rotation about the z-axis.
        c, s = np.cos(angle), np.sin(angle)
        return np.array([
            [c, s, 0],
            [-s, c, 0],
            [0, 0, 1]
        ])

    ct, st = np.cos(theta), np.sin(theta)
    about_y = np.array([
        [ct, 0, -st],
        [0, 1, 0],
        [st, 0, ct]
    ])

    return _about_z(phi) @ about_y @ _about_z(psi)

def Rect_face(ax, box_coords, corner_indices, face_color):
    """Draw one quadrilateral face of the box on a 3-D axes.

    Parameters
    ----------
    ax : Axes3D
        Target 3-D axes.
    box_coords : ndarray, shape (3, n_corners)
        Corner coordinates, one column per corner; rows are X, Y, Z.
    corner_indices : sequence of int
        Column indices of the face's corners, in order around the face.
    face_color : color
        Any Matplotlib color, passed to the collection as ``color``.

    Returns
    -------
    Poly3DCollection
        The artist that was added to ``ax`` (returned so callers can update
        or remove it; existing callers may ignore it).
    """
    square = box_coords[:, corner_indices]
    verts = [list(zip(square[0, :], square[1, :], square[2, :]))]
    face = Poly3DCollection(verts, color=face_color, alpha=0.7)
    ax.add_collection3d(face)
    return face


#%% Main Script

# Principal moments of inertia (Ixx, Iyy, Izz), one row per case.  The body
# always spins about its z-axis (see omega_0), so Izz selects which principal
# axis is being tested.
I_cases = np.array([
    [1, 4, 5],   # Izz is the largest
    [1, 5, 4],   # Izz is the intermediate
    [5, 4, 1]    # Izz is the smallest
])

# One label per row of I_cases; each must describe that row's Izz.
case_names = [
    "Rotation about axis with largest moment of inertia.",
    "Rotation about axis with intermediate moment of inertia.",
    "Rotation about axis with smallest moment of inertia."
]

# Box half-widths (a, b, c) used only for drawing.
# NOTE(review): a uniform box with half-widths (a, b, c) has Ixx ∝ b**2 + c**2
# (and cyclically), so half-widths of sqrt(I) do not reproduce I_cases; the
# drawn box is schematic only.
sides = np.sqrt(I_cases)
L = 2   # length scale for the drawn reference axes and vectors

R_0 = np.eye(3)          # initial orientation matrix (identity)
k_0_vec = R_0[2, :]      # third row of R_0: the reference Z direction

# Initial body-frame angular velocity: a fast spin about z plus a small
# transverse perturbation.
omega_0 = np.array([0.001, 0.001, 5.0])

# Per-case time histories, filled by the loop below.
psi_all = []
theta_all = []
phi_all = []
psi_dot_all = []
theta_dot_all = []
phi_dot_all = []
omega_x_all = []
omega_y_all = []
omega_z_all = []

for nn in range(3):
    I = np.diag(I_cases[nn])
    print(f"Case #{nn+1} I = {I_cases[nn]}")
    print(case_names[nn])

    # Initial angular momentum in body components.  The ODE treats its
    # magnitude as constant (torque-free motion).
    H_G_vec = I @ omega_0
    H_G_mag = np.linalg.norm(H_G_vec)

    # Initial Euler angles from the direction of H_G.  The cross product is
    # nonzero only because omega_0 has a small transverse perturbation, so
    # H_G is not parallel to k_0_vec.
    K_vec = H_G_vec / H_G_mag
    u = np.cross(K_vec, k_0_vec)
    J_vec = u / np.linalg.norm(u)

    theta_0 = np.arccos(K_vec[2])
    # In S_transf's convention J_vec = (sin(phi), cos(phi), 0), so arctan2
    # recovers phi in every quadrant; arccos(J_vec[1]) is only correct when
    # sin(phi) >= 0.
    phi_0 = np.arctan2(J_vec[0], J_vec[1])
    psi_0 = 0.0

    # ~100 samples per nominal spin period, for 10 periods.
    T_avg = 2 * np.pi / np.linalg.norm(omega_0)
    delta_t = T_avg / 100
    T_max = 10 * T_avg
    N_vals = int(np.ceil(T_max / delta_t))

    t_vals = np.linspace(0, T_max, N_vals + 1)
    angles_0 = [psi_0, theta_0, phi_0]

    sol = solve_ivp(
        F_asym_rot,
        [0, T_max],
        angles_0,
        t_eval=t_vals,
        args=(I, H_G_mag),
        rtol=1e-5,
        atol=1e-5
    )
    if not sol.success:
        raise RuntimeError(f"solve_ivp failed for case #{nn+1}: {sol.message}")

    psi, theta, phi = sol.y
    # Reuse the integrated right-hand side (it is element-wise) so the
    # plotted rates cannot drift from the equations actually solved.
    psi_dot, theta_dot, phi_dot = F_asym_rot(t_vals, sol.y, I, H_G_mag)

    # Body-frame angular velocity from the Euler-angle rates.
    omega_x = -psi_dot * np.sin(theta) * np.cos(phi) + theta_dot * np.sin(phi)
    omega_y =  psi_dot * np.sin(theta) * np.sin(phi) + theta_dot * np.cos(phi)
    omega_z =  psi_dot * np.cos(theta) + phi_dot

    psi_all.append(psi)
    theta_all.append(theta)
    phi_all.append(phi)
    psi_dot_all.append(psi_dot)
    theta_dot_all.append(theta_dot)
    phi_dot_all.append(phi_dot)
    omega_x_all.append(omega_x)
    omega_y_all.append(omega_y)
    omega_z_all.append(omega_z)

    # Per-case figure: Euler angles (top) and their rates (bottom).
    plt.figure((nn+1) * 10)
    deg = 180 / np.pi

    plt.subplot(2, 1, 1)
    plt.plot(t_vals, psi * deg, '--k', label=r'$\psi$')
    plt.plot(t_vals, theta * deg, '-r', label=r'$\theta$')
    plt.plot(t_vals[::3], phi[::3] * deg, ':b', label=r'$\phi$')
    plt.ylabel('Angle (deg)')
    plt.legend()
    plt.title(case_names[nn])

    plt.subplot(2, 1, 2)
    plt.plot(t_vals, psi_dot, '--k', label=r'$\dot{\psi}$')
    plt.plot(t_vals, theta_dot, '-r', label=r'$\dot{\theta}$')
    plt.plot(t_vals[::3], phi_dot[::3], ':b', label=r'$\dot{\phi}$')
    plt.xlabel('Time (s)')
    plt.ylabel('Angular velocity (rad/s)')
    plt.legend()

    plt.tight_layout()

# When done, compare all three cases on one figure.  t_vals depends only on
# omega_0, so the copy left over from the last loop iteration applies to all.
deg = 180 / np.pi
t_end = np.max(t_vals)
case_labels = ('Case 1', 'Case 2', 'Case 3')
theta_styles = ('b--', 'r-', 'g:')
rate_styles = ('b-', 'r-', 'g--')

plt.figure(50, figsize=(7, 8))

# Top panel: nutation angle theta (deg) for every case.
ax_theta = plt.subplot(3, 1, 1)
for theta_k, style_k, label_k in zip(theta_all, theta_styles, case_labels):
    ax_theta.plot(t_vals, theta_k * deg, style_k, label=label_k)
ax_theta.set_xlim([0, t_end])
ax_theta.set_ylim([-5, 185])
ax_theta.set_yticks(np.arange(0, 181, 45))
ax_theta.set_ylabel(r'$\theta$ (deg)')
ax_theta.legend()
ax_theta.grid(True)

# Middle panel: psi_dot and phi_dot; both curves of a case share one style.
ax_rates = plt.subplot(3, 1, 2)
for psi_dot_k, phi_dot_k, style_k in zip(psi_dot_all, phi_dot_all, rate_styles):
    ax_rates.plot(t_vals, psi_dot_k, style_k)
    ax_rates.plot(t_vals, phi_dot_k, style_k)
ax_rates.set_xlim([0, t_end])
ax_rates.set_ylabel(r'$\dot{\psi}$, $\dot{\phi}$ (rad/s)')
ax_rates.grid(True)

# Bottom panel: z-component of the body angular velocity.
ax_omega = plt.subplot(3, 1, 3)
for omega_z_k, style_k, label_k in zip(omega_z_all, rate_styles, case_labels):
    ax_omega.plot(t_vals, omega_z_k, style_k, label=label_k)
ax_omega.set_xlim([0, t_end])
ax_omega.set_xlabel('Time (s)')
ax_omega.set_ylabel(r'$\omega_z$ (rad/s)')
ax_omega.grid(True)

plt.tight_layout()
plt.show()

# Equivalent of MATLAB "return": stop here when the whole file is run.
# The animation cell below is meant to be run on its own.
sys.exit()

#%% Run this section to see an animation of a particular case.
# Meant to be run on its own as an IDE cell after the code above has been
# executed; running the whole file stops at the sys.exit() call above.

# Select the case of interest
nn = 1   # 0, 1, 2  → corresponds to MATLAB 1,2,3

# Close other plots, they interfere with this one:
plt.close('all')
# Note, if you manually close fig 100 then it will pop to the top again when initiated.
# Closing it while the animation runs also ends the frame loop below.

I = np.diag(I_cases[nn])
print(f"Case #{nn+1} I = {I_cases[nn]}")
print(case_names[nn])

a, b, c = sides[nn]   # box half-widths used for drawing

H_G_vec = I @ omega_0
H_G_mag = np.linalg.norm(H_G_vec)

# Direction cosines of H_G -> initial Euler angles.  omega_0's transverse
# perturbation keeps H_G off k_0_vec, so the cross product is nonzero.
K_vec = H_G_vec / H_G_mag
u = np.cross(K_vec, k_0_vec)
J_vec = u / np.linalg.norm(u)

theta_0 = np.arccos(K_vec[2])
# In S_transf's convention J_vec = (sin(phi), cos(phi), 0); arctan2 keeps the
# quadrant that arccos(J_vec[1]) would lose when sin(phi) < 0.
phi_0 = np.arctan2(J_vec[0], J_vec[1])
psi_0 = 0.0

# Time parameters: ~100 samples per nominal spin period, for 20 periods.
T_average = 2 * np.pi / np.linalg.norm(omega_0)
delta_t = T_average / 100
T_max = 20 * T_average
N_vals = int(np.ceil(T_max / delta_t))

t_vals = np.linspace(0, T_max, N_vals + 1)
angles_0 = [psi_0, theta_0, phi_0]

sol = solve_ivp(
    F_asym_rot,
    [0, T_max],
    angles_0,
    t_eval=t_vals,
    args=(I, H_G_mag),
    rtol=1e-5,
    atol=1e-5
)
if not sol.success:
    raise RuntimeError(f"solve_ivp failed for case #{nn+1}: {sol.message}")

psi, theta, phi = sol.y
# Euler-angle rates from the same right-hand side that was integrated.
psi_dot, theta_dot, phi_dot = F_asym_rot(t_vals, sol.y, I, H_G_mag)

# Body-frame angular velocity components.
omega_x = -psi_dot * np.sin(theta) * np.cos(phi) + theta_dot * np.sin(phi)
omega_y =  psi_dot * np.sin(theta) * np.sin(phi) + theta_dot * np.cos(phi)
omega_z =  psi_dot * np.cos(theta) + phi_dot

# Body-frame angular momentum, one row per sample (row @ I equals
# (I @ omega).T because I is diagonal, hence symmetric).
H_G_vec_all = np.vstack((omega_x, omega_y, omega_z)).T @ I

# =================================
# Begin animation
# =================================

RR = R_0.T @ S_transf(psi_0, theta_0, phi_0)

# Box corners (columns)
xyz_points = np.array([
    [-a, -b, -c],
    [ a, -b, -c],
    [-a,  b, -c],
    [ a,  b, -c],
    [-a, -b,  c],
    [ a, -b,  c],
    [-a,  b,  c],
    [ a,  b,  c]
]).T

# Body axes drawn as short segments sticking out of the box faces
x_axis = np.array([[a, 0, 0], [a + 0.6, 0, 0]]).T
y_axis = np.array([[0, b, 0], [0, b + 0.6, 0]]).T
z_axis = np.array([[0, 0, c], [0, 0, c + 0.6]]).T
axes_xyz = np.hstack((x_axis, y_axis, z_axis))

fig = plt.figure(100)
ax = fig.add_subplot(111, projection='3d')

n_skip = 5                                  # draw every n_skip-th sample
ax_len = 1.25 * np.linalg.norm([a, b, c])   # constant half-width of the view cube

for n in range(0, N_vals + 1, n_skip):
    if not plt.fignum_exists(fig.number):
        break   # the window was closed: stop instead of drawing off-screen

    ax.cla()
    ax.view_init(elev=20, azim=135)

    # Body-frame components at sample n -> reference (plot) frame.
    R_to_ref = RR @ S_transf(psi[n], theta[n], phi[n]).T

    # Draw angular momentum vector
    H_G_XYZ = R_to_ref @ H_G_vec_all[n]
    e_H_G = H_G_XYZ / np.linalg.norm(H_G_XYZ)

    ax.quiver(0, 0, 0, e_H_G[0]*1.45*L, e_H_G[1]*1.45*L, e_H_G[2]*1.45*L,
        linewidth=2, color=(0, 0.5, 0.5))
    ax.text(*(K_vec * 1.35 * L), r'$H_G$')

    # --- Draw global XYZ axes ---
    ax.plot([-1.4*L, 1.4*L], [0, 0], [0, 0], ':k', linewidth=1.0)
    ax.text(1.45*L, 0, 0, 'X', fontweight='bold')
    ax.plot([0, 0], [-1.4*L, 1.4*L], [0, 0], ':k', linewidth=1.0)
    ax.text(0, 1.45*L, 0, 'Y', fontweight='bold')
    ax.plot([0, 0], [0, 0], [-1.6*L, 1.6*L], ':k', linewidth=1.0)
    ax.text(0, 0, 1.65*L, 'Z', fontweight='bold')

    # Transform box
    XYZ_pts = R_to_ref @ xyz_points
    axes_XYZ = R_to_ref @ axes_xyz

    # --- Draw body-fixed xyz axes ---
    ax.plot(axes_XYZ[0, 0:2], axes_XYZ[1, 0:2], axes_XYZ[2, 0:2],'-k', linewidth=1)
    ax.text(axes_XYZ[0, 1]*1.1, axes_XYZ[1, 1]*1.1, axes_XYZ[2, 1]*1.1, 'x')
    ax.plot(axes_XYZ[0, 2:4], axes_XYZ[1, 2:4], axes_XYZ[2, 2:4], '-k', linewidth=1)
    ax.text(axes_XYZ[0, 3]*1.1, axes_XYZ[1, 3]*1.1, axes_XYZ[2, 3]*1.1, 'y')
    ax.plot(axes_XYZ[0, 4:6], axes_XYZ[1, 4:6], axes_XYZ[2, 4:6], '-k', linewidth=1)
    ax.text(axes_XYZ[0, 5]*1.1, axes_XYZ[1, 5]*1.1, axes_XYZ[2, 5]*1.1, 'z')

    # Draw box faces
    Rect_face(ax, XYZ_pts, [0, 1, 3, 2], (0.4, 0, 0))
    Rect_face(ax, XYZ_pts, [0, 1, 5, 4], (0, 0, 0.4))
    Rect_face(ax, XYZ_pts, [0, 2, 6, 4], (0, 0.4, 0))
    Rect_face(ax, XYZ_pts, [4, 5, 7, 6], (1, 0, 0))
    Rect_face(ax, XYZ_pts, [2, 3, 7, 6], (0, 0, 1))
    Rect_face(ax, XYZ_pts, [1, 3, 7, 5], (0, 1, 0))

    # Angular velocity vector
    omega_n = R_to_ref @ np.array([omega_x[n], omega_y[n], omega_z[n]])
    e_omega = 1.1 * L * omega_n / np.linalg.norm(omega_n)

    ax.quiver(0, 0, 0, *e_omega, linewidth=1.5, color=(0.9, 0.4, 1))
    ax.text(*(1.15 * e_omega), r'$e_\omega$')

    # ax.cla() resets limits, so they are re-applied every frame.
    ax.set_xlim([-ax_len, ax_len])
    ax.set_ylim([-ax_len, ax_len])
    ax.set_zlim([-ax_len, ax_len])
    ax.set_box_aspect([1, 1, 1])
    ax.grid(True)

    ax.set_title(f"H_G = {H_G_vec},   t = {t_vals[n]:.2f}s")

    # Hold the first frame for 2 s, then advance at roughly the sample spacing.
    plt.pause(delta_t if n > 0 else 2)