'''
Simulation of a gear on a rack being squeezed and ejected.

The generalized coordinates are the angle of rack C (theta), the rotation
of gear A (phi), and the position of rack B (x_B); the equations of motion
are defined in G_vec_squeeze.

The system parameters and the initial conditions for the generalized
coordinates must be entered below (in run_simulation), together with the
maximum simulation time and the number of time steps.
'''

import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt


def G_vec_squeeze(Z, t, p):
    """
    Evaluate the state derivative dZ/dt of the squeezed gear/rack system.

    The state is Z = [theta, phi, x_B, dtheta/dt, dphi/dt, dx_B/dt].  The
    accelerations are obtained by solving the constrained system

        [ M    -a^T ] [ q_2dot ]   [ FF               ]
        [ -a    0   ] [ lambda ] = [ (da/dt) @ q_dot  ]

    where M is the 3x3 mass matrix, a is the 2x3 constraint matrix
    (a @ q_dot = 0) and FF holds the generalized forces: the applied
    -Force * L plus the gravity term -m_C * 9.807 * L * cos(theta), both
    acting on the theta coordinate.  The multipliers lambda are discarded.

    Parameters
    ----------
    Z : numpy.ndarray
        Length-6 state vector.
    t : float
        Time; not used by the equations (kept for the odeint signature).
    p : dict
        System parameters; reads 'R', 'L', 'Force', 'm_A', 'm_B', 'm_C',
        'I_A' and 'I_C'.

    Returns
    -------
    numpy.ndarray
        Length-6 vector [q_dot, q_2dot].
    """
    n = 3
    q_dot = Z[n:2 * n]
    theta = Z[0]
    theta_dot = q_dot[0]

    R, L = p['R'], p['L']
    m_A, m_B, m_C = p['m_A'], p['m_B'], p['m_C']
    s, c = np.sin(theta), np.cos(theta)
    s2 = s ** 2

    mass = np.array([
        [p['I_C'], 0.0, 0.0],
        [0.0, p['I_A'] + m_A * R ** 2, -m_A * R],
        [0.0, -m_A * R, m_B + m_A],
    ])

    # Applied force and gravity only enter the theta equation.
    gen_force = np.array([-p['Force'] * L - m_C * 9.807 * L * c, 0.0, 0.0])

    # Constraint matrix a(theta) and its time derivative (da/dtheta * dtheta/dt).
    a = np.array([
        [c * (1 + c), s2, 0.0],
        [R * (1 + c), -R * s2, s2],
    ])
    a_dot = theta_dot * np.array([
        [-s * (1 + 2 * c), 2 * s * c, 0.0],
        [-R * s, -2 * R * s * c, 2 * s * c],
    ])

    system = np.block([
        [mass, -a.T],
        [-a, np.zeros((2, 2))],
    ])
    rhs = np.concatenate((gen_force, a_dot @ q_dot))

    q_2dot = np.linalg.solve(system, rhs)[:n]
    return np.concatenate((q_dot, q_2dot))


def _rack_contact_x(theta, R):
    """
    Return R * cot(theta / 2), evaluated elementwise.

    run_simulation uses this value as x_A (the position rack B must pass
    for ejection) and as the distance s in the no-slip check at point H.
    Where sin(theta / 2) is within machine epsilon of zero (theta near a
    multiple of 2*pi) the cotangent diverges, so +/-inf carrying the sign
    of theta is returned instead.

    Parameters
    ----------
    theta : float or array_like
        Angle(s) of rack C, in radians.
    R : float
        The length parameter p['R'].

    Returns
    -------
    numpy.ndarray
        Same shape as ``theta`` (0-d for a scalar input).
    """
    theta = np.asarray(theta, dtype=float)
    half = theta / 2
    near_zero = np.isclose(np.sin(half), 0, atol=np.finfo(float).eps)
    with np.errstate(divide='ignore'):
        cot_val = R * (1 / np.tan(half))
    # copysign instead of inf * sign(theta): sign(0.0) == 0 would give
    # inf * 0 = nan (plus a RuntimeWarning) at theta == 0.
    return np.where(near_zero, np.copysign(np.inf, theta), cot_val)


def _simulate_force(p, q_0, q_dot_0, t_vals):
    """
    Integrate G_vec_squeeze over t_vals one interval at a time, stopping as
    soon as rack B's position x_B exceeds x_A = R * cot(theta / 2).

    Stepping interval by interval (rather than one odeint call) lets the
    integration stop at the first sample where ejection is detected.

    Parameters
    ----------
    p : dict
        System parameters passed through to G_vec_squeeze.
    q_0, q_dot_0 : numpy.ndarray
        Initial [theta, phi, x_B] and their time derivatives.
    t_vals : numpy.ndarray
        Sample times; t_vals[0] is the initial time.

    Returns
    -------
    tuple
        (t_vals, q_vals, q_dot_vals, q_2dot_vals, ejected).  The histories
        have shape (3, n) and are truncated to include the sample at which
        ejection was detected; ``ejected`` is False if x_B never passed x_A.
    """
    n_coords = len(q_0)
    n_steps = len(t_vals)
    q_vals = np.zeros((n_coords, n_steps))
    q_dot_vals = np.zeros((n_coords, n_steps))
    q_2dot_vals = np.zeros((n_coords, n_steps))

    Z = np.concatenate((q_0, q_dot_0))
    q_vals[:, 0] = q_0
    q_dot_vals[:, 0] = q_dot_0
    q_2dot_vals[:, 0] = G_vec_squeeze(Z, t_vals[0], p)[n_coords:2 * n_coords]

    for m in range(1, n_steps):
        t_span = [t_vals[m - 1], t_vals[m]]
        sol = odeint(G_vec_squeeze, Z, t_span, args=(p,),
                     rtol=1.0e-6, atol=1.0e-6)
        Z = sol[-1, :]

        q_vals[:, m] = Z[0:n_coords]
        q_dot_vals[:, m] = Z[n_coords:2 * n_coords]
        q_2dot_vals[:, m] = G_vec_squeeze(Z, t_span[-1], p)[n_coords:2 * n_coords]

        x_A = float(_rack_contact_x(q_vals[0, m], p['R']))
        x_B = q_vals[2, m]
        if x_B > x_A:
            keep = m + 1
            return (t_vals[:keep], q_vals[:, :keep], q_dot_vals[:, :keep],
                    q_2dot_vals[:, :keep], True)

    return t_vals, q_vals, q_dot_vals, q_2dot_vals, False


def _plot_single_force(t_vals, q_vals, q_dot_vals, q_2dot_vals, R, t_max):
    """
    Plot the time histories for a single force value.

    Figure 100 shows displacements/angles, velocities and accelerations
    against time in ms.  Figure 1000 compares the velocity of contact point
    H computed from the gear side (v_H_g) and from the rack side (v_H_r);
    if the contact does not slip the two curves should coincide.

    Parameters
    ----------
    t_vals : numpy.ndarray
        Time samples in seconds, length n.
    q_vals, q_dot_vals, q_2dot_vals : numpy.ndarray
        (3, n) histories of [theta, phi, x_B] and their first and second
        time derivatives.
    R : float
        The length parameter p['R'].
    t_max : float
        Upper x-axis limit, in seconds.
    """
    theta, phi, x_B = q_vals
    theta_dot, phi_dot, x_B_dot = q_dot_vals
    theta_2dot, _phi_2dot, x_B_2dot = q_2dot_vals
    x_A = _rack_contact_x(theta, R)
    t_ms = t_vals * 1000
    t_max_ms = t_max * 1000

    # ---- Figure 100: displacements, velocities, accelerations ----
    plt.figure(100)

    ax1 = plt.subplot(3, 1, 1)
    ax1.plot(t_ms, x_B, 'r-', label='x_B', linewidth=1)
    ax1.plot(t_ms, x_A, 'b-', label='x_A', linewidth=1)
    ax1.set_ylabel('x (meters)')
    ax1.set_ylim([0, 0.3])
    ax1.set_xlim([0, t_max_ms])
    ax1_twin = ax1.twinx()
    ax1_twin.plot(t_ms, theta * 180 / np.pi, 'k--', label='$\\theta$', linewidth=1)
    ax1_twin.plot(t_ms, phi * 180 / np.pi, 'g--', label='$\\phi$', linewidth=1)
    ax1_twin.set_ylabel('Angle (deg)')
    ax1_twin.set_ylim([0, 100])
    ax1.legend(loc='upper left')
    ax1_twin.legend(loc='upper right')
    ax1.set_title('Displacements and Angles')

    ax2 = plt.subplot(3, 1, 2)
    ax2.plot(t_ms, x_B_dot, '-r', label='dx_B/dt', linewidth=1)
    ax2.set_ylabel('dx_B/dt (m/s)')
    ax2.set_ylim([0, np.max(x_B_dot)])
    ax2.set_yticks(np.arange(0, 30.1, 10))
    ax2.set_xlim([0, t_max_ms])
    ax2_twin = ax2.twinx()
    ax2_twin.plot(t_ms, theta_dot, '--b', label='d$\\theta$/dt', linewidth=1)
    ax2_twin.plot(t_ms, phi_dot, '--g', label='d$\\phi$/dt', linewidth=1)
    ax2_twin.set_ylabel('ang vel (rad/s)')
    ax2_twin.set_ylim([-200, 200])
    ax2.legend(loc='upper left')
    ax2_twin.legend(loc='upper right')
    ax2.set_title('Velocities')

    ax3 = plt.subplot(3, 1, 3)
    ax3.plot(t_ms, x_B_2dot, '-r', label='d$^2$x_b/dt$^2$', linewidth=1)
    ax3.set_ylabel('d$^2$x_b/dt$^2$ (m/s$^2$)')
    ax3.set_ylim([0, np.max(x_B_2dot)])
    ax3.set_yticks(np.arange(0, np.max(x_B_2dot), 2000))
    ax3.set_xlim([0, t_max_ms])
    ax3.set_xlabel('time (ms)')
    ax3_twin = ax3.twinx()
    ax3_twin.plot(t_ms, theta_2dot, '--b', label='d$^2\\theta$/dt$^2$', linewidth=1)
    ax3_twin.set_yticks(np.arange(-60000, 20001, 20000))
    ax3_twin.set_ylabel('d$^2\\theta$/dt$^2$ (rad/s$^2$)')
    ax3.legend(loc='upper left')
    ax3_twin.legend(loc='upper right')
    ax3.set_title('Accelerations')
    plt.tight_layout()

    # ---- Figure 1000: no-slip check at contact point H ----
    v_H_g_x = x_B_dot - R * phi_dot * (1 + np.cos(theta))
    v_H_g_y = -R * phi_dot * np.sin(theta)
    s = x_A  # same R * cot(theta / 2) distance
    v_H_r_x = -s * theta_dot * np.sin(theta)
    v_H_r_y = s * theta_dot * np.cos(theta)

    plt.figure(1000)

    ax_x = plt.subplot(2, 1, 1)
    ax_x.plot(t_vals, v_H_g_x, '-r', label='v_H_g_x (Gear)')
    ax_x.plot(t_vals, v_H_r_x, '--b', label='v_H_r_x (Rack)')
    ax_x.legend()
    ax_x.set_xlim([0, t_max])
    ax_x.set_ylim([0, 16])  # chosen to match the MATLAB plot magnitudes
    ax_x.set_xlabel('Time (s)')

    ax_y = plt.subplot(2, 1, 2)
    ax_y.plot(t_vals, v_H_g_y, '-r', label='v_H_g_y (Gear)')
    ax_y.plot(t_vals, v_H_r_y, '--b', label='v_H_r_y (Rack)')
    ax_y.legend()
    ax_y.set_xlim([0, t_max])
    ax_y.set_ylim([-10, 0.5])  # chosen to match the MATLAB plot magnitudes
    ax_y.set_xlabel('Time (s)')

    plt.tight_layout()
    plt.show()


def _plot_force_sweep(Force_vals, v_B_last, v_B_crit):
    """
    Plot rack B's exit velocity against force and print the estimated force
    at which the exit velocity equals v_B_crit.

    The estimate is a linear interpolation through the neighbours of the
    sample whose exit velocity is closest to v_B_crit (or through the first
    or last two samples when that sample is at an end of the sweep).

    Raises
    ------
    ValueError
        If fewer than two force values are given.
    """
    if len(Force_vals) < 2:
        raise ValueError("At least two force values are needed to estimate the critical force.")

    plt.figure(10)
    plt.subplot(2, 1, 1)
    plt.plot(Force_vals / 1000, v_B_last, linewidth=1)
    plt.xlabel('Force F (kN)')
    plt.ylabel('(v_B)_f (m/s)')

    last = len(Force_vals) - 1
    n_crit = int(np.argmin(np.abs(v_B_last - v_B_crit)))
    if n_crit == 0:
        idx1, idx2 = 0, 1
    elif n_crit == last:
        idx1, idx2 = last - 1, last
    else:
        idx1, idx2 = n_crit - 1, n_crit + 1

    slope = (v_B_last[idx2] - v_B_last[idx1]) / (Force_vals[idx2] - Force_vals[idx1])
    if slope == 0:
        # A flat segment cannot be inverted; avoid a division by zero.
        print("Critical force could not be interpolated: exit velocity is flat near the closest sample.")
    else:
        F_crit = Force_vals[idx1] + (v_B_crit - v_B_last[idx1]) / slope
        print(f"Critical force = {F_crit}")
    plt.show()


def run_simulation():
    """
    Run the squeeze/ejection simulation and plot the results.

    For every value in Force_vals, the equations of motion in G_vec_squeeze
    are integrated from theta = pi/2 at rest.  Integration stops when rack
    B's position x_B first exceeds x_A = R * cot(theta / 2), or when t_max
    is reached.  The exit velocity recorded for each force is dx_B/dt at
    that crossing, or 0 if no crossing occurred.

    With a single force value the full time histories are plotted.  With
    several values, the exit velocity is plotted against force, and the
    force giving v_B_crit is estimated by linear interpolation and printed.
    """
    p = {
        'R': 0.1,
        'L': 0.4,
        'm_A': 2,
        'kappa_A': 0.08,
        'm_B': 2.5,
        'm_C': 3,
        'kappa_C': 0.26,
        'Force': 0,  # overwritten for each entry of Force_vals below
    }
    # Moments of inertia: I = m * kappa**2, the same formula for both bodies.
    p['I_A'] = p['m_A'] * p['kappa_A'] ** 2
    p['I_C'] = p['m_C'] * p['kappa_C'] ** 2

    v_B_crit = 30  # target exit velocity of rack B (m/s)
    Force_vals = np.array([31089])  # single value -> full time-history plots
    # To sweep forces and estimate the critical force instead, use:
    # Force_vals = np.arange(1000, 50001, 10000)

    # Initial conditions and time grid, shared by every force value.
    t_0 = 0
    t_max = 0.013
    N_steps = 10000
    delta_t = (t_max - t_0) / N_steps
    t_vals = t_0 + np.arange(N_steps) * delta_t
    q_0 = np.array([np.pi / 2, 0.0, 0.0])
    q_dot_0 = np.zeros(3)

    v_B_last = np.zeros(len(Force_vals))
    for n_force, force_val in enumerate(Force_vals):
        p['Force'] = force_val
        t_hist, q_vals, q_dot_vals, q_2dot_vals, ejected = _simulate_force(
            p, q_0, q_dot_0, t_vals)
        if ejected:
            v_B_last[n_force] = q_dot_vals[2, -1]

    if len(Force_vals) == 1:
        _plot_single_force(t_hist, q_vals, q_dot_vals, q_2dot_vals, p['R'], t_max)
    else:
        _plot_force_sweep(Force_vals, v_B_last, v_B_crit)


if __name__ == "__main__":
    # Run the simulation only when executed as a script, not on import.
    run_simulation()