import numpy as np

def skew(vector):
    """Return the 3x3 skew-symmetric (cross-product) matrix of ``vector``.

    The returned matrix ``K`` satisfies ``K @ w == np.cross(vector, w)`` for
    any 3-vector ``w``. Only ``vector[0]``, ``vector[1]`` and ``vector[2]``
    are read; the element dtype is whatever ``np.array`` infers from them.
    """
    x, y, z = vector[0], vector[1], vector[2]
    rows = (
        (0, -z, y),
        (z, 0, -x),
        (-y, x, 0),
    )
    return np.array(rows)

def disp_out(var_string, var_value):
    """Print ``var_string`` immediately followed by a formatted ``var_value``.

    - NumPy arrays are rendered with ``np.array2string`` (tab separator,
      4-digit precision, tiny values suppressed).
    - Floating-point scalars (Python ``float`` and any NumPy floating scalar,
      e.g. ``np.float64`` or ``np.float32``) are shown with 4 decimal places.
    - Anything else is printed through its default string formatting.
    """
    if isinstance(var_value, np.ndarray):
        text = np.array2string(var_value, separator='\t', precision=4, suppress_small=True)
    elif isinstance(var_value, (float, np.floating)):
        # np.float64 subclasses float, but np.float32/np.float16 do not;
        # checking np.floating formats every floating scalar consistently.
        text = f"{var_value:.4f}"
    else:
        text = var_value
    print(f"{var_string}{text}")

# Constants and initial conditions
# NOTE(review): no units are stated anywhere; the magnitudes suggest SI
# (kg, kg*m^2, m/s, rad/s, m, s) -- confirm before changing any value.
m_s = 5000 # Mass of satellite
I = np.diag([32000, 40000, 3600]) # Inertia matrix of satellite (diagonal)
m_m = 2 # Mass of meteorite
eps_co = 0.3 # Coefficient of restitution

v_C_0 = np.array([0, 0, 8000]) # Initial velocity of satellite point C
om_0 = np.array([0, 0, 3]) # Initial angular velocity of satellite
v_m_0 = np.array([-9000, 12000, 0]) # Initial velocity of meteorite
r_A_C = np.array([0, -1.2, -9]) # Position of impact point A relative to C
delta_t = 0.40 # Duration of the impact (the impulse is divided by it to get an average force)

# Cross-product matrix of r_A_C: r_AC_x @ w equals np.cross(r_A_C, w).
r_AC_x = skew(r_A_C)
# 1x3 row selecting the y-component: the direction along which the contact
# impulse acts and along which restitution is enforced.
S = np.array([[0, 1, 0]])
Z0 = np.zeros((3, 3))  # 3x3 zero block for assembling the system matrix
U = np.eye(3)  # 3x3 identity block for assembling the system matrix

# Assemble the 10x10 linear system A @ Y = B. The unknowns are stacked as
#   Y = [v_C_f (3) | om_f (3) | v_m_f (3) | P (1)],
# where P is the impulse magnitude along S (the meteorite receives -P*S^T,
# the satellite +P*S^T). Block rows, top to bottom:
#   1. meteorite momentum:   m_m*v_m_f + P*S^T            = m_m*v_m_0
#   2. total momentum:       m_s*v_C_f + m_m*v_m_f        = m_s*v_C_0 + m_m*v_m_0
#   3. satellite angular:    I*om_f - [r_A_C x]*S^T*P     = I*om_0
#   4. restitution along S:  S*(v_C_f - [r_A_C x]*om_f - v_m_f)
#                              = -eps_co*S*(v_C_0 - [r_A_C x]*om_0 - v_m_0)
# In row 4, v_C - [r_A_C x]*om is the velocity of the satellite's material
# point at A, so the relative normal velocity is reversed and scaled by eps_co.
top = np.concatenate((Z0, Z0, m_m * U, S.T), axis=1)
middle = np.concatenate((m_s * U, Z0, m_m * U, np.zeros((3, 1))), axis=1)
third = np.concatenate((Z0, I, Z0, -r_AC_x @ S.T), axis=1)
bottom = S @ np.concatenate((U, -r_AC_x, -U, np.zeros((3, 1))), axis=1)
A = np.concatenate((top, middle, third, bottom), axis=0)

# Right-hand side, one segment per block row above (3 + 3 + 3 + 1 entries).
B = np.hstack([
    m_m * v_m_0,
    m_s * v_C_0 + m_m * v_m_0,
    I @ om_0,
    -eps_co * S @ (v_C_0 - r_AC_x @ om_0 - v_m_0),
])

# Solve A @ Y = B; Y stacks [v_C_f (3) | om_f (3) | v_m_f (3) | P (1)].
Y = np.linalg.solve(A, B)
v_C_f, om_f, v_m_f = Y[:3], Y[3:6], Y[6:9]
# The last unknown is the impulse P along S; dividing by the contact
# duration gives the average contact force.
F = Y[-1] / delta_t

# Kinetic energies before (_0) and after (_f) the impact: translational
# 0.5*m*|v|^2, plus rotational 0.5*om^T*I*om for the satellite.
T_s_0, T_s_f = (
    0.5 * m_s * np.linalg.norm(v) ** 2 + 0.5 * w @ I @ w
    for v, w in ((v_C_0, om_0), (v_C_f, om_f))
)
T_m_0, T_m_f = (0.5 * m_m * np.linalg.norm(v) ** 2 for v in (v_m_0, v_m_f))

# Report the post-impact state, the average contact force, and each body's
# initial kinetic energy together with its change across the impact.
print()
for _label, _value in (
    ('v_C_f = ', v_C_f),
    ('om_f = ', om_f),
    ('v_m_f = ', v_m_f),
    ('F = ', F),
    ('T_s_0 = ', T_s_0),
    ('T_s_f - T_s_0 = ', T_s_f - T_s_0),
    ('T_m_0 = ', T_m_0),
    ('T_m_f - T_m_0 = ', T_m_f - T_m_0),
):
    disp_out(_label, _value)
