import numpy as np
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt
''' 
Solution of the differential equations for a wobbling (rolling) disk.
Radius is the radius of the disk; kappa is its radius of gyration.
Units are SI.  To use U.S. Customary units, change the gravitational
constant g = 9.807, which is hard-coded in F_vec_disk, compute_energy and
Rolling_disk_ic_case.

{Z} is the current state-space vector:
    the number of generalized coordinates is 5, so the size of {Z} is 10.
Sequence of generalized coordinates (1-based, MATLAB-style numbering;
the Python index is one less, e.g. Z(1) is Z[0]):
    Z(1) = X, Z(2) = Y, Z(3) = psi, Z(4) = theta, Z(5) = phi
Sequence of generalized velocities:
    Z(6) = d(X), Z(7) = d(Y), Z(8) = d(psi), Z(9) = d(theta), Z(10) = d(phi)

The top-level code at the end of the file asks the user which initial-condition
case to run and passes it to simulate_and_plot_case, which calls
Rolling_disk_ic_case to build the initial state, integrates the equations of
motion, and plots the results.  The remaining functions define the EOM.
'''

from math import pi, sin, cos, sqrt, tan
from mpl_toolkits.mplot3d import Axes3D
import time

# Define the functions that comprise the EOM in the solution to the problem.
def a_a_dot_disk(Z, Radius):
    """Return the rolling-constraint matrix ``a`` and its time derivative.

    The velocity constraints have the form ``a(q) @ q_dot = 0``.  Both
    returned arrays have shape (2, 5), one column per generalized coordinate
    (X, Y, psi, theta, phi).

    Parameters
    ----------
    Z : sequence of length 10
        State vector [X, Y, psi, theta, phi, X_dot, Y_dot, psi_dot,
        theta_dot, phi_dot].
    Radius : float
        Disk radius.

    Returns
    -------
    (a, a_dot) : tuple of ndarray
        The constraint matrix and its total time derivative, evaluated at Z.
    """
    psi, theta = Z[2], Z[3]
    psi_rate, theta_rate = Z[7], Z[8]
    s_psi, c_psi = sin(psi), cos(psi)
    s_th, c_th = sin(theta), cos(theta)
    R = Radius

    constraint = np.array([
        [1, 0, -R * s_psi * c_th, -R * c_psi * s_th, -R * s_psi],
        [0, 1, R * c_psi * c_th, -R * s_psi * s_th, R * c_psi],
    ])

    # Entry-by-entry d/dt of `constraint` (only psi and theta appear in it).
    constraint_rate = np.array([
        [0, 0,
         R * (-psi_rate * c_psi * c_th + theta_rate * s_psi * s_th),
         R * (psi_rate * s_psi * s_th - theta_rate * c_psi * c_th),
         -R * psi_rate * c_psi],
        [0, 0,
         R * (-psi_rate * s_psi * c_th - theta_rate * c_psi * s_th),
         R * (-psi_rate * c_psi * s_th - theta_rate * s_psi * c_th),
         -R * psi_rate * s_psi],
    ])
    return constraint, constraint_rate

def F_vec_disk(Z, Radius, kappa):
    """Return the right-hand-side vector F of ``M(q) @ q_ddot = F`` (no constraints).

    F holds the velocity-dependent terms that come from the kinetic energy
    and the gravity term ``9.807 * Radius * cos(theta)`` (the theta-derivative
    of the potential ``9.807 * Radius * sin(theta)``).  The X and Y entries are
    always zero.

    Parameters
    ----------
    Z : sequence of length 10
        State vector [X, Y, psi, theta, phi, X_dot, Y_dot, psi_dot,
        theta_dot, phi_dot].
    Radius : float
        Disk radius.
    kappa : float
        Radius of gyration.

    Returns
    -------
    ndarray, shape (5, 1)
        Column vector of generalized forces.
    """
    theta = Z[3]
    psi_d, theta_d, phi_d = Z[7], Z[8], Z[9]
    k2 = kappa**2
    s1, s2 = sin(theta), sin(2 * theta)

    psi_term = k2 * (0.5 * psi_d * theta_d * s2 + theta_d * phi_d * s1)
    theta_velocity_term = (-0.5 * (0.5 * k2 * psi_d**2 - Radius**2 * theta_d**2) * s2
                           - k2 * psi_d * phi_d * s1)
    theta_gravity_term = 9.807 * Radius * cos(theta)
    phi_term = k2 * psi_d * theta_d * s1

    F = np.zeros((5, 1))
    F[2] = psi_term
    F[3] = theta_velocity_term - theta_gravity_term
    F[4] = phi_term
    return F

def M_mat_disk(q, Radius, kappa):
    """Return the 5x5 symmetric mass matrix M(q) of the unconstrained disk.

    Only the nutation angle ``theta = q[3]`` enters.  The psi/phi block is
    coupled through the off-diagonal entries ``kappa**2 * cos(theta)``.

    Parameters
    ----------
    q : sequence of length >= 4
        Generalized coordinates [X, Y, psi, theta, phi].
    Radius : float
        Disk radius.
    kappa : float
        Radius of gyration.

    Returns
    -------
    ndarray, shape (5, 5)
    """
    c = cos(q[3])
    k2 = kappa**2
    M = np.diag([
        1.0,                          # X
        1.0,                          # Y
        0.5 * k2 * (1 + c**2),        # psi
        0.5 * k2 + (Radius * c)**2,   # theta
        k2,                           # phi
    ])
    M[2, 4] = M[4, 2] = k2 * c        # psi-phi coupling
    return M

def G_vec_disk_aug(t, Z, Radius, kappa):
    """Return dZ/dt for ``solve_ivp`` using the augmented (multiplier) form.

    Solves the linear system::

        [  M   -a^T ] [ q_ddot ]   [      F       ]
        [ -a     0  ] [ lambda ] = [ a_dot q_dot  ]

    i.e. ``M q_ddot = F + a^T lambda`` together with the differentiated
    constraint ``a q_ddot + a_dot q_dot = 0``.  It returns
    ``[q_dot, q_ddot]``; the multipliers are discarded.

    ``t`` is unused but is required by the ``solve_ivp`` callback signature.
    """
    n_q = 5
    q, q_dot = Z[:n_q], Z[n_q:]
    a, a_dot = a_a_dot_disk(Z, Radius)
    n_c = a.shape[0]

    system = np.zeros((n_q + n_c, n_q + n_c))
    system[:n_q, :n_q] = M_mat_disk(q, Radius, kappa)
    system[:n_q, n_q:] = -a.T
    system[n_q:, :n_q] = -a

    rhs = np.vstack((F_vec_disk(Z, Radius, kappa), a_dot @ q_dot.reshape(-1, 1)))
    q_ddot = np.linalg.solve(system, rhs)[:n_q, 0]
    return np.concatenate((q_dot, q_ddot))

def _steady_precession_rates(theta, rho, Radius, kappa, g):
    """Return (psi_dot, phi_dot) of the steady-precession solution at nutation angle theta."""
    psi_dot = sqrt(2 * g * Radius ** 2 * (1 / tan(theta)) /
                   (2 * rho * (Radius ** 2 + kappa ** 2) + kappa ** 2 * Radius * cos(theta)))
    phi_dot = -(rho / Radius + cos(theta)) * psi_dot
    return psi_dot, phi_dot


def Rolling_disk_ic_case(case, Radius, kappa):
    """Build the initial state for one of the predefined simulation cases.

    Parameters
    ----------
    case : int
        1 = steady precession, 2 = disturbed steady precession,
        3 = stability of planar rolling.
    Radius : float
        Disk radius.
    kappa : float
        Radius of gyration.

    Returns
    -------
    q : ndarray, shape (5,)
        Initial generalized coordinates [X, Y, psi, theta, phi].
    q_dot : ndarray, shape (5,)
        Initial generalized velocities.  X_dot and Y_dot are solved from the
        constraint ``a @ q_dot = 0`` so the initial state is consistent.
    t_max : float
        Final integration time.

    Raises
    ------
    ValueError
        If ``case`` is not 1, 2 or 3.
    """
    g = 9.807
    rho = 5 * Radius
    psi = 0
    phi = 0
    if case == 1:
        # ---------------  First initial condition case  ------------------------
        # Initial conditions match the steady precession solution.
        theta = pi / 3  # nominal value of the nutation angle
        psi_dot, phi_dot = _steady_precession_rates(theta, rho, Radius, kappa, g)
        theta_dot = 0
        X, Y = 0, rho
        t_max = 10.0
        print("\nphi_dot = {:.4f}".format(phi_dot))
        print("\ntheta_dot = {:.4f}".format(theta_dot))
    elif case == 2:
        # ---------------  Second initial condition case  ------------------------
        # Disturb the steady precession solution.  psi_dot and phi_dot take their
        # steady-precession values for theta = pi/3.  The disk instead starts at
        # theta = pi/6 with a nonzero nutation rate.
        theta_steady = pi / 3
        psi_dot, phi_dot = _steady_precession_rates(theta_steady, rho, Radius, kappa, g)
        theta_dot = -0.5 * psi_dot  # nonzero nutational velocity
        theta = pi / 6              # non-steady starting value of theta
        X = -rho * sin(psi)
        Y = rho * cos(psi)
        t_max = 10
        print("\nphi_dot = {:.4f}".format(phi_dot))
    elif case == 3:
        # ---------------  Third initial condition case  ------------------------
        # Examine stability of planar rolling: psi_dot = 0, a low phi_dot,
        # theta slightly smaller than pi/2, and theta_dot = 0.
        X, Y = 0, 0
        theta = (pi / 2) - 0.01
        v_cr = sqrt(g * Radius / 3)  # presumably the critical rolling speed
        v = 0.2 * v_cr
        psi_dot = 0
        phi_dot = -0.2 * sqrt(g / (3 * Radius))
        theta_dot = 0
        t_max = 10
        print("\nv = {:.4f}".format(v))
    else:
        raise ValueError("Case must be 1, 2, or 3")

    q = np.array([X, Y, psi, theta, phi])
    angular_rates = np.array([psi_dot, theta_dot, phi_dot])
    # The constraint is a @ q_dot = 0 with a[:, :2] multiplying (X_dot, Y_dot);
    # solve it for the translational rates given the angular ones.
    a, _ = a_a_dot_disk(np.concatenate((q, np.zeros(5))), Radius)
    XY_dot = -np.linalg.solve(a[:, :2], a[:, 2:] @ angular_rates)
    q_dot = np.concatenate((XY_dot, angular_rates))
    return q, q_dot, t_max

def compute_energy(Z, Radius, kappa):
    """Return the total mechanical energy T + V per unit mass for state Z.

    The mass is normalized to 1 (the X/Y kinetic terms are 0.5*X_dot**2 and
    0.5*Y_dot**2).  The potential is ``V = 9.807 * Radius * sin(theta)``.

    Parameters
    ----------
    Z : sequence of length 10
        State vector [X, Y, psi, theta, phi, X_dot, Y_dot, psi_dot,
        theta_dot, phi_dot].
    Radius : float
        Disk radius.
    kappa : float
        Radius of gyration.
    """
    theta = Z[3]
    x_rate, y_rate, psi_rate, theta_rate, phi_rate = Z[5:]
    k2 = kappa**2
    c_th, s_th = cos(theta), sin(theta)

    # Kinetic-energy contributions, before the overall factor 1/2.
    translation = x_rate**2 + y_rate**2
    theta_term = (0.5 * k2 + (Radius * c_th)**2) * theta_rate**2
    psi_term = 0.5 * k2 * psi_rate**2 * s_th**2
    axial_term = k2 * (psi_rate * c_th + phi_rate)**2
    kinetic = 0.5 * (translation + theta_term + psi_term + axial_term)

    potential = 9.807 * Radius * s_th
    return kinetic + potential

def plot_rolling_disk(t_vals, Z_vals):
    """Plot a simulated rolling-disk trajectory in figures 1-3.

    Figure 1 has three panels:
      * X and Y against time;
      * theta (left axis) and psi (right axis), in degrees, against time;
      * psi_dot and phi_dot against time.

    Figure 2 shows the path in the X-Y plane.  Figure 3 plots the X-Y path
    against time in 3-D.

    Parameters
    ----------
    t_vals : array, shape (m,)
        Sample times.
    Z_vals : array, shape (m, 10)
        State history, one row per sample: [X, Y, psi, theta, phi,
        X_dot, Y_dot, psi_dot, theta_dot, phi_dot].
    """
    # Close figures if this was run already
    for fig_num in (1, 2, 3):
        plt.close(fig_num)
    mark_idx = np.linspace(0, len(t_vals) - 1, 20, dtype=int)
    plt.figure(1)
    plt.subplot(3, 1, 1)
    plt.plot(t_vals, Z_vals[:, 0], 'r-', label='X (meter)')
    plt.plot(t_vals, Z_vals[:, 1], 'b-', label='Y (meter)')
    plt.legend()
    ax2 = plt.subplot(3, 1, 2)

    ax2.plot(t_vals, np.rad2deg(Z_vals[:, 3]), 'r-', label='θ (deg)')
    ax2.set_ylabel('θ (deg)', color='r')
    ax2.tick_params(axis='y', labelcolor='r')
    ax2_right = ax2.twinx()
    psi_deg = np.rad2deg(Z_vals[:, 2])
    ax2_right.plot(t_vals, psi_deg, 'b-o', label='ψ (deg)', markevery=mark_idx, markersize=2)
    ax2_right.set_ylabel('ψ (deg)', color='b')
    ax2_right.tick_params(axis='y', labelcolor='b')
    # Keep the original -1000 deg floor for layout, but lower it when psi
    # actually goes below that, so the curve is never clipped.
    ax2_right.set_ylim([min(-1000, psi_deg.min()), psi_deg.max()])
    ax2.legend(loc='upper left')
    ax2_right.legend(loc='upper right')

    plt.subplot(3, 1, 3)
    # State columns: 7 = psi_dot, 8 = theta_dot, 9 = phi_dot.
    plt.plot(t_vals, Z_vals[:, 7], 'r-', label='dψ/dt (rad/s)')
    plt.plot(t_vals, Z_vals[:, 9], 'b-o', label='dϕ/dt (rad/s)', markevery=mark_idx)
    plt.legend()
    plt.xlabel('Time (seconds)')

    plt.figure(2)
    plt.plot(Z_vals[:, 0], Z_vals[:, 1], 'k-')
    plt.xlabel('X (meter)')
    plt.ylabel('Y (meter)')
    plt.gca().set_aspect('equal', adjustable='box')

    fig = plt.figure(3)
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(Z_vals[:, 0], Z_vals[:, 1], t_vals, 'b')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Time (s)')

    plt.show()

def simulate_and_plot_case(case=1):
    """Simulate one initial-condition case, report its accuracy, and plot it.

    The EOM are integrated with ``solve_ivp`` (Radau, rtol = atol = 1e-8) on
    2000 evenly spaced output times.  The function then prints:
      * the CPU time;
      * the maximum relative energy drift;
      * the maximum constraint violation ``norm(a @ q_dot)``.
    Finally it plots the results.

    Parameters
    ----------
    case : int
        Initial-condition case passed to ``Rolling_disk_ic_case`` (1, 2 or 3).

    Raises
    ------
    ValueError
        If ``case`` is not 1, 2 or 3 (raised by ``Rolling_disk_ic_case``).
    """
    Radius = 0.25
    kappa = Radius / sqrt(2)

    # Build the consistent initial state for the requested case.
    q, q_dot, t_max = Rolling_disk_ic_case(case, Radius, kappa)
    Z_0 = np.concatenate((q, q_dot))

    m_max = 2000  # number of output samples
    t_eval = np.linspace(0, t_max, m_max)

    t0_cpu = time.process_time()
    sol = solve_ivp(lambda t, Z: G_vec_disk_aug(t, Z, Radius, kappa),
                    [0, t_max], Z_0, method='Radau', t_eval=t_eval,
                    rtol=1e-8, atol=1e-8)
    cpu_time = time.process_time() - t0_cpu
    if not sol.success:
        # Still report and plot whatever was integrated before the failure.
        print(f"\nWarning: integration stopped early: {sol.message}")
    Z_vals = sol.y.T

    # Accuracy diagnostics.  The exact solution conserves energy and satisfies
    # a @ q_dot = 0, so any deviation is numerical drift.
    E_all = np.array([compute_energy(Z, Radius, kappa) for Z in Z_vals])
    E_0 = E_all[0]
    epsilon_E = np.max(np.abs(E_all - E_0) / E_0)
    constraint_err = [np.linalg.norm(a_a_dot_disk(Z, Radius)[0] @ Z[5:]) for Z in Z_vals]
    epsilon_v = np.max(constraint_err)

    print(f"\nCPU time = {cpu_time:.4f} sec")
    print(f"Errors: max(E - E0)/E0 = {epsilon_E:.2e}\n"
          f"        max(norm([a]*dq/dt)) = {epsilon_v:.2e}")

    plot_rolling_disk(sol.t, Z_vals)

if __name__ == "__main__":
    # Ask which initial-condition case to run.  Only input errors are reported
    # as "Invalid input"; errors raised during simulation or plotting are left
    # to propagate with a full traceback.
    try:
        user_input = int(input("Enter the initial condition case\n"
                               "1 ==> Steady precession\n"
                               "2 ==> Disturbance of steady precession\n"
                               "3 ==> Stability of planar rolling\n"
                               "Enter 1, 2 or 3 __ "))
        if user_input not in (1, 2, 3):
            raise ValueError("Case must be 1, 2, or 3")
    except (ValueError, EOFError) as e:
        print("Invalid input. Error:", e)
    else:
        simulate_and_plot_case(case=user_input)