import numpy as np


from scipy.integrate import solve_ivp
from scipy.spatial.transform import Rotation as R

import matplotlib.pyplot as plt


# -------------------------------------------------------------------------- #
#
#
#                             FUNCTIONS
#
#
# -------------------------------------------------------------------------- #



def Mean_to_true_anomaly(M, e):
    """
    Calculate the true anomaly from the mean anomaly on an elliptic orbit.

    Kepler's equation ``M = E - e*sin(E)`` is solved for the eccentric
    anomaly E with Newton's method; E is then converted to the true anomaly
    through ``tan(f/2) = sqrt((1 + e) / (1 - e)) * tan(E/2)``.

    Args:
        M: mean anomaly (rad), scalar.
        e: Eccentricity (the relations used here hold for |e| < 1).


    Returns:
        f: True anomaly (rad), wrapped to [-pi, pi].
    """

    # Work on the principal value of M: the returned angle is wrapped anyway,
    # and a bounded M lets the high-eccentricity starting guess below be used.
    M0 = (M + np.pi) % (2.0 * np.pi) - np.pi

    # Starting guess: E = M is fine for small e; for large e start from +/-pi,
    # from where Newton's iteration approaches the root monotonically.
    E = M0 if abs(e) < 0.8 else np.copysign(np.pi, M0)

    for _ in range(50):  # Newton iterations on g(E) = E - e*sin(E) - M
        dE = (E - e * np.sin(E) - M0) / (1.0 - e * np.cos(E))
        E = E - dE
        if np.abs(dE) < 1e-13:
            break

    # Half-angle relation; sqrt(1 + e) multiplies tan(E/2), sqrt(1 - e) divides.
    f = 2 * np.arctan2(np.sqrt(1 + e) * np.tan(E / 2), np.sqrt(1 - e))  # true anomaly

    return f

# -------------------------------------------------------------------------- #
def cartesian_to_orbitgeometry(r_p, v_p):
    """
    Derive the orbit size and shape (a, e) from a Cartesian state (r_p, v_p).

    Args:
        r_p (numpy.ndarray): Position vector (rx_p, ry_p, rz_p) in the perifocal frame (km).
        v_p (numpy.ndarray): Velocity vector (vx_p, vy_p, vz_p) in the perifocal frame (km/s).


    Returns:
        tuple: (a, e)
            a: Semi-major axis (km); ``np.inf`` when the specific energy is
               not negative (parabolic or hyperbolic case).
            e: Eccentricity (norm of the eccentricity vector).
    """

    # Gravitational parameter (km^3/s^2), hard-coded.
    mu = 398600.4415

    # Accept any array-like input.
    position = np.array(r_p)
    velocity = np.array(v_p)

    radius = np.linalg.norm(position)
    speed = np.linalg.norm(velocity)

    # Specific orbital energy; its sign decides whether the orbit is bound.
    energy = (speed**2) / 2 - mu / radius
    a = -mu / (2 * energy) if energy < 0 else np.inf

    # Eccentricity vector built from the specific angular momentum.
    ang_momentum = np.cross(position, velocity)
    ecc_vector = (1 / mu) * (np.cross(velocity, ang_momentum) - mu * (position / radius))

    return a, np.linalg.norm(ecc_vector)

# -------------------------------------------------------------------------- #
def classical_to_cartesian(a, M, e):
    """
    Build the perifocal Cartesian state (r, v) from the in-plane classical elements.

    Args:
        a (float): Semi-major axis (km)
        M (float): Mean anomaly (rad)
        e (float): Eccentricity

    Returns:
        tuple: Two 3-tuples, position (rx_p, ry_p, rz_p) and velocity
               (vx_p, vy_p, vz_p) in the perifocal frame; both z components
               are 0.0 because the motion lies in the orbital plane.
    """

    # Gravitational parameter (km^3/s^2), hard-coded.
    mu = 398600.4415

    # Mean anomaly -> true anomaly, then cache its sine and cosine.
    f = Mean_to_true_anomaly(M, e)
    cos_f = np.cos(f)
    sin_f = np.sin(f)

    # Conic equation: radius from the semi-latus rectum.
    semi_latus_rectum = a * (1 - e**2)
    radius = semi_latus_rectum / (1 + e * cos_f)

    # Specific angular momentum magnitude, used to scale the velocity.
    h = np.sqrt(a * mu * (1 - e**2))

    position = (radius * cos_f, radius * sin_f, 0.0)
    velocity = (-mu / h * sin_f, mu / h * (e + cos_f), 0.0)

    return position, velocity

# -------------------------------------------------------------------------- #

def rtn_to_perifocal(r_RTN, v_RTN, f):
    """
    Re-express RTN components of a position/velocity pair in the perifocal frame.

    Both vectors are multiplied by the same z-axis rotation matrix; no term
    accounting for the rotation rate of the RTN frame is added to the velocity.

    Args:
        r_RTN (np.ndarray): Position vector in RTN frame [rx, ry, rz] (km).
        v_RTN (np.ndarray): Velocity vector in RTN frame [vx, vy, vz] (km/s).
        f (float): True anomaly of the chief spacecraft (rad).

    Returns:
        tuple: Position and velocity vectors in the perifocal frame (km and km/s).
    """
    cos_f = np.cos(f)
    sin_f = np.sin(f)

    # The R axis sits at angle f from the perifocal x axis and N coincides
    # with perifocal z, hence an elementary rotation about z.
    rtn_to_pqw = np.array([
        [cos_f, -sin_f, 0],
        [sin_f, cos_f, 0],
        [0, 0, 1],
    ])

    return rtn_to_pqw.dot(r_RTN), rtn_to_pqw.dot(v_RTN)

# -------------------------------------------------------------------------- #

def perifocal_to_rtn(r_perifocal, v_perifocal, f):
    """
    Re-express perifocal components of a position/velocity pair in the RTN frame.

    Both vectors are multiplied by the same z-axis rotation matrix; no term
    accounting for the rotation rate of the RTN frame is added to the velocity.

    Args:
        r_perifocal (np.ndarray): Position vector in perifocal frame [rx, ry, rz] (km).
        v_perifocal (np.ndarray): Velocity vector in perifocal frame [vx, vy, vz] (km/s).
        f (float): True anomaly of the chief spacecraft (rad).

    Returns:
        tuple: Position and velocity vectors in the RTN frame (km and km/s).
    """
    cos_f = np.cos(f)
    sin_f = np.sin(f)

    # Projection of perifocal components onto axes rotated by f about z.
    pqw_to_rtn = np.array([
        [cos_f, sin_f, 0],
        [-sin_f, cos_f, 0],
        [0, 0, 1],
    ])

    return pqw_to_rtn.dot(r_perifocal), pqw_to_rtn.dot(v_perifocal)

# -------------------------------------------------------------------------- #


def perifocal_to_eci(r_perifocal, v_perifocal, omega, inc, raan):
    """
    Rotate position and velocity vectors from the perifocal frame into the ECI frame.

    The transformation is the 3-1-3 sequence Rz(raan) * Rx(inc) * Rz(omega).

    Args:
        r_perifocal (numpy.ndarray): Position vector in the perifocal frame (km).
        v_perifocal (numpy.ndarray): Velocity vector in the perifocal frame (km/s).
        omega (float): Argument of perigee (rad).
        inc (float): Inclination (rad).
        raan (float): Right Ascension of the Ascending Node (rad).

    Returns:
        tuple: Two numpy arrays representing the position (r_eci) and velocity (v_eci)
               vectors in the ECI frame.
    """
    def _about_z(angle):
        # Elementary rotation matrix about the z axis.
        c, s = np.cos(angle), np.sin(angle)
        return np.array([
            [c, -s, 0],
            [s, c, 0],
            [0, 0, 1],
        ])

    def _about_x(angle):
        # Elementary rotation matrix about the x axis.
        c, s = np.cos(angle), np.sin(angle)
        return np.array([
            [1, 0, 0],
            [0, c, -s],
            [0, s, c],
        ])

    # Perifocal -> ECI: argument of perigee first, then inclination, then RAAN.
    Q_perifocal_to_eci = _about_z(raan).dot(_about_x(inc).dot(_about_z(omega)))

    return Q_perifocal_to_eci.dot(r_perifocal), Q_perifocal_to_eci.dot(v_perifocal)


# -------------------------------------------------------------------------- #


def eci_to_perifocal(r_eci, v_eci, omega, inc, raan):
    """
    Rotate position and velocity vectors from the ECI frame into the perifocal frame.

    The matrix applied is Rz(omega)^T * Rx(inc)^T * Rz(raan)^T, i.e. the
    transpose of the 3-1-3 sequence Rz(raan) * Rx(inc) * Rz(omega).

    Args:
        r_eci (numpy.ndarray): Position vector in the ECI frame (km).
        v_eci (numpy.ndarray): Velocity vector in the ECI frame (km/s).
        omega (float): Argument of perigee (rad).
        inc (float): Inclination (rad).
        raan (float): Right Ascension of the Ascending Node (rad).

    Returns:
        tuple: Two numpy arrays representing the position (r_perifocal) and velocity (v_perifocal)
               vectors in the perifocal frame.
    """
    def _about_z_transposed(angle):
        # Transpose of the elementary rotation matrix about the z axis.
        c, s = np.cos(angle), np.sin(angle)
        return np.array([
            [c, s, 0],
            [-s, c, 0],
            [0, 0, 1],
        ])

    def _about_x_transposed(angle):
        # Transpose of the elementary rotation matrix about the x axis.
        c, s = np.cos(angle), np.sin(angle)
        return np.array([
            [1, 0, 0],
            [0, c, s],
            [0, -s, c],
        ])

    # ECI -> perifocal: undo RAAN first, then inclination, then argument of perigee.
    Q_eci_to_perifocal = _about_z_transposed(omega).dot(
        _about_x_transposed(inc).dot(_about_z_transposed(raan))
    )

    return Q_eci_to_perifocal.dot(r_eci), Q_eci_to_perifocal.dot(v_eci)



# -------------------------------------------------------------------------- #

def propagate_cw(delta_r0, n, dt, num_steps):
    """
    Propagate a relative state with the closed-form Clohessy-Wiltshire solution.

    Every sample is obtained by applying the state-transition matrix for the
    elapsed time to the initial state (not by chaining step-to-step updates).

    Args:
        delta_r0: Initial relative state vector [x0, y0, z0, vx0, vy0, vz0] in RTN of Chief.
        n (float): Mean motion of the chief spacecraft (rad/s).
        dt (float): Time step for propagation (s).
        num_steps (int): Number of time steps to propagate.

    Returns:
        tuple: (states, times) where states has shape (num_steps + 1, 6), its
               first row being the initial state, and times holds the
               matching epochs starting at 0.0 (s).
    """

    def _transition_matrix(nt):
        # CW state-transition matrix for the phase angle nt = n * elapsed time.
        s = np.sin(nt)
        c = np.cos(nt)
        return np.array([
            [4 - 3*c, 0, 0, s/n, 2*(1 - c)/n, 0],
            [6*(s - nt), 1, 0, 2*(c - 1)/n, (4*s - 3*nt)/n, 0],
            [0, 0, c, 0, 0, s/n],
            [3*n*s, 0, 0, c, 2*s, 0],
            [6*n*(c - 1), 0, 0, -2*s, 4*c - 3, 0],
            [0, 0, -n*s, 0, 0, c],
        ])

    state_history = [delta_r0]
    epochs = [0.0]
    elapsed = 0.0

    for _ in range(num_steps):
        elapsed += dt
        state_history.append(_transition_matrix(n * elapsed).dot(delta_r0))
        epochs.append(elapsed)

    return np.array(state_history), np.array(epochs)




# -------------------------------------------------------------------------- #
def propagate_orbit_truth(r0, v0, t_span, method='DOP853', J2=1.08263e-3, dense_output=True, t_eval=None):
    """
    Numerically integrate an ECI orbit under point-mass gravity plus the J2
    perturbation, using scipy.integrate.solve_ivp.

    Args:
        r0 (np.ndarray): Initial position vector [x0, y0, z0] in ECI (km).
        v0 (np.ndarray): Initial velocity vector [vx0, vy0, vz0] in ECI (km/s).
        t_span (tuple): Time span for the propagation (t0, tf) (s).
        method (str, optional): Numerical integration method to use. Defaults to 'DOP853'.
        J2 (float, optional): J2 perturbation coefficient. Defaults to 1.08263e-3.
        dense_output (bool, optional): Whether to use dense output. Defaults to True.
        t_eval (array_like, optional): Times at which to store the computed solution.

    Returns:
        scipy.integrate._ivp.OdeResult: The result of the integration
        (integrated with rtol=1e-13 and atol=1e-16).
    """
    mu = 398600.447     # Gravitational parameter (km^3/s^2)
    R_earth = 6378.137  # Earth radius (km)

    def _state_derivative(t, y):
        """
        Right-hand side of the equations of motion in the ECI frame.

        Args:
            t (float): Time (s); unused because the dynamics are autonomous.
            y (np.ndarray): State vector [x, y, z, vx, vy, vz] (km, km/s).

        Returns:
            np.ndarray: [vx, vy, vz, ax, ay, az] (km/s, km/s^2).
        """
        pos = y[:3]
        vel = y[3:]
        r_mag = np.linalg.norm(pos)
        z = pos[2]

        # Central (two-body) attraction.
        a_grav = -mu * pos / r_mag**3

        # J2 contribution: common scale factor times the per-axis terms.
        j2_scale = 3 * J2 * mu * R_earth**2 / (2 * r_mag**5)
        z_ratio = 5 * z**2 / r_mag**2
        a_j2 = j2_scale * np.array([
            (z_ratio - 1) * pos[0],
            (z_ratio - 1) * pos[1],
            (z_ratio - 3) * z,
        ])

        return np.concatenate((vel, a_grav + a_j2))

    initial_state = np.concatenate((r0, v0))

    return solve_ivp(
        _state_derivative,
        t_span,
        initial_state,
        method=method,
        dense_output=dense_output,
        rtol=1e-13,
        atol=1e-16,  # tight tolerances for a reference ("truth") trajectory
        t_eval=t_eval,
    )

    # -------------------------------------------------------------------------- #
def propagate_Keplerian_Motion(alpha0, dt, num_steps):
    """
    Propagates a set of Keplerian elements under unperturbed two-body motion.

    Only the mean anomaly changes, M(t) = M0 + n*t with n = sqrt(mu / a^3);
    at every step the elements are also converted to an ECI position and
    velocity through classical_to_cartesian and perifocal_to_eci.

    Args:
        alpha0: Initial Keplerian elements [a, M, e, w, i, ra] (semi-major axis, Mean anomaly, eccentricity, argument of periapsis, inclination, right ascension of the ascending node).
        dt (float): Time step for propagation (s).
        num_steps (int): Number of time steps to propagate.

    Returns:
        tuple: (elements, positions, velocities, times)
            elements: array (num_steps + 1, 6) of propagated Keplerian elements.
            positions: array (num_steps + 1, 3) of ECI positions.
            velocities: array (num_steps + 1, 3) of ECI velocities.
            times: array (num_steps + 1,) of epochs starting at 0.0 (s).
    """
    # mu (float): Gravitational parameter (km^3/s^2)
    # NOTE(review): classical_to_cartesian hard-codes its own value of mu,
    # which is not the same number as this one - confirm which is intended.
    mu = 398600.447

    # Force a floating-point working copy: with integer input (e.g. a list of
    # ints) an integer array would silently truncate the mean anomaly that is
    # written into it at every step.
    current_state = np.array(alpha0, dtype=float)
    a, M0, e, w, inc, ra = current_state

    # Mean motion (rad/s)
    n = np.sqrt(mu / a**3)

    propagated_SingularStates = [current_state.copy()]
    times = [0.0]

    # State at the initial epoch
    r_per0, v_per0 = classical_to_cartesian(a, M0, e)
    r_eci0, v_eci0 = perifocal_to_eci(r_per0, v_per0, w, inc, ra)
    propagated_position = [r_eci0]
    propagated_velocity = [v_eci0]

    t = 0.0
    for _ in range(num_steps):
        t += dt
        M_new = n * t + M0

        current_state[1] = M_new
        propagated_SingularStates.append(current_state.copy())
        times.append(t)

        r_per, v_per = classical_to_cartesian(a, M_new, e)
        r_eci, v_eci = perifocal_to_eci(r_per, v_per, w, inc, ra)
        propagated_position.append(r_eci)
        propagated_velocity.append(v_eci)

    return np.array(propagated_SingularStates), np.array(propagated_position), np.array(propagated_velocity), np.array(times)





















# -------------------------------------------------------------------------- #
#
#
#                           LECTURE EXERCISE
#
#
# -------------------------------------------------------------------------- #

def propagate_ROE(dalpha0, n, alpha_c, dt, num_steps):
    """
    Propagates the relative orbital elements (ROE) assuming Keplerian motion
    only, then maps them to the deputy-minus-chief relative state in ECI.

    In the propagation loop only dM changes (first-order drift
    dM0 - 1.5*n*da0*t); the other five ROEs keep their initial values.

    Args:
        dalpha0: Initial relative state vector [da, dM, de, dw, di, dra].
                 da is normalized by the chief semi-major axis (the deputy SMA
                 is rebuilt as a_c * (1 + da)); the other entries are added
                 directly to the chief elements.
        n (float): Mean motion of the chief spacecraft (rad/s).
        alpha_c: Chief state vector [a, M, e, w, i, ra] propagation - already propagated / known.
                 Rows 0..num_steps are read, so it needs num_steps + 1 rows.
        dt (float): Time step for propagation (s).
        num_steps (int): Number of time steps to propagate.

    Returns:
        tuple: (propagated_ROEstates, relative_states_eci, times)
            propagated_ROEstates: array (num_steps + 1, 6) of ROEs.
            relative_states_eci: array (num_steps + 1, 6) with the relative
                position and velocity (deputy minus chief) in ECI.
            times: array (num_steps + 1,) of epochs starting at 0.0 (s).
    """
    # setup
    propagated_ROEstates = np.tile(dalpha0, (num_steps+1, 1))
    times = [0.0]
    deltaM = [0.0]  # not used further down
    current_ROEstate = np.array(dalpha0)  # not used further down

    # One history list per ROE, each seeded with its initial value
    propagated_da = [dalpha0[0]] # Watch out! Not all of them are needed!
    propagated_dM = [dalpha0[1]]
    propagated_de = [dalpha0[2]]
    propagated_dw = [dalpha0[3]]
    propagated_di = [dalpha0[4]]
    propagated_dra = [dalpha0[5]]

    t = 0.0

    # Relative elements at the instant t = t0 [WATCH OUT! NOT all of them are needed!]
    da0 = dalpha0[0]
    dM0 = dalpha0[1]
    de0 = dalpha0[2]
    dw0 = dalpha0[3]
    dinc0 = dalpha0[4]
    dra0 = dalpha0[5]

    ###########################################################################
    # --- EXERCISE --- #
    """

    Dear reader,

    your objective is to complete the ROE propagation based on what you have learnt during the lecture.
    Please follow the following tasks to guarantee a successful completion:
    (1) identify which ROEs [da, dM, de, dw, di, dra] vary when ONLY the Keplerian motion is considered. Let's call it/them generically "dx"
    (2) For each identified dx, write the evolution of the absolute motion dot(x)=f(alpha) as a function of time.
    (3) Then, since dx = x_d - x_c, dot(dx) = dot(x_d) - dot(x_c), where d and c identify Deputy and Chief.
    (4) Expand at the N-order "x_d = x_c + dot(x)*dx + dot2(x)*dx^2/2 + ..." as a function of the ROEs. The N-value is your choice: we advice to start with N=1.
    (5) Finally you may approximate dx_new = dx0 + dot(dx)*t, where dx0 is an input of this function

    NB: Do not forget that "da = (a_d - a_c) / a_c" is normalized by "/a_c" and it is not just "a_d - a_c"!

    Solution: You may find it at the end of this code. However, before looking at it, try yourself!

    Enjoy,
    The Coder


    """

    # ROE propagation (assuming ONLY Keplerian motion)
    for _ in range(num_steps):
        t += dt

        da_new = da0
        # First order: linear drift of the relative mean anomaly
        dM_new = dM0 - 1.5*n*da0*t
        """
        # Second order
        dM_new = dM0 - 1.5*n*da0*t + 15/8*n*da0**2*t
        # Third order
        dM_new = dM0 - 1.5*n*da0*t + 15/8*n*da0**2*t - 105/48*n*da0**3*t
        # And so on...
        """
        de_new = de0
        dw_new = dw0
        dinc_new = dinc0
        dra_new = dra0

        # Save
        propagated_da.append(da_new)
        propagated_dM.append(dM_new)
        propagated_de.append(de_new)
        propagated_dw.append(dw_new)
        propagated_di.append(dinc_new)
        propagated_dra.append(dra_new)

        times.append(t)


    # ----- END ----- #
    ###########################################################################


    # --- Do not modify --- #
    # Deputy state calculation: copy the per-element histories into the
    # columns of the ROE matrix, then add them to the chief elements.
    propagated_ROEstates[:,0] = propagated_da
    propagated_ROEstates[:,1] = propagated_dM
    propagated_ROEstates[:,2] = propagated_de
    propagated_ROEstates[:,3] = propagated_dw
    propagated_ROEstates[:,4] = propagated_di
    propagated_ROEstates[:,5] = propagated_dra
    alpha_d = alpha_c + propagated_ROEstates
    # da is normalized by a_c, so the deputy SMA is overwritten with a_c * (1 + da)
    alpha_d[:,0] = alpha_c[:,0] * (1 + propagated_ROEstates[:,0])

    # From absolute motion of Chief and Deputy to Relative one in ECI
    relative_states_eci = []
    for ii in range(num_steps+1):
        # Chief: elements -> perifocal state -> ECI state
        alpha_c_vec = alpha_c[ii, :]
        a_c = alpha_c_vec[0]
        M_c = alpha_c_vec[1]
        e_c = alpha_c_vec[2]
        w_c = alpha_c_vec[3]
        inc_c = alpha_c_vec[4]
        ra_c = alpha_c_vec[5]
        r_c_per, v_c_per = classical_to_cartesian(a_c, M_c, e_c)
        r_c_eci, v_c_eci = perifocal_to_eci(r_c_per, v_c_per, w_c, inc_c, ra_c)

        # Deputy: same chain, with its own orientation angles
        alpha_d_vec = alpha_d[ii, :]
        a_d = alpha_d_vec[0]
        M_d = alpha_d_vec[1]
        e_d = alpha_d_vec[2]
        w_d = alpha_d_vec[3]
        inc_d = alpha_d_vec[4]
        ra_d = alpha_d_vec[5]
        r_d_per, v_d_per = classical_to_cartesian(a_d, M_d, e_d)
        r_d_eci, v_d_eci = perifocal_to_eci(r_d_per, v_d_per, w_d, inc_d, ra_d)

        r_c_eci = np.array(r_c_eci)
        v_c_eci = np.array(v_c_eci)
        r_d_eci = np.array(r_d_eci)
        v_d_eci = np.array(v_d_eci)

        # Relative state = deputy minus chief, both expressed in ECI
        delta_r_eci = r_d_eci - r_c_eci
        delta_v_eci = v_d_eci - v_c_eci
        current_relative_state_eci = np.concatenate((delta_r_eci, delta_v_eci))
        relative_states_eci.append(current_relative_state_eci)


    return propagated_ROEstates, np.array(relative_states_eci), np.array(times)





























# -------------------------------------------------------------------------- #
#
#
#                        RELATIVE ORBITAL PROPAGATION
#
#
# -------------------------------------------------------------------------- #




# --- Code Setup --- #
mu = 398600.447 # Gravitational parameter (km^3/s^2)
step = 100      # number of points per orbit
N_orbit = 1     # number of propagated orbits


# --- Define Initial Conditions --- #
# Chief classical orbital elements, ordered [a, M, e, w, i, ra]
# (the "_singular" suffix and the "Equinoctial" wording in the printed labels
# notwithstanding, these are the classical elements used by the functions above)
a_c_singular = 7000 # SMA (km)
M_c_singular = 0.0  # Mean anomaly (rad)
e_c_singular = 0.0 # Eccentricity
w_c_singular = 0.0  # Argument of perigee (rad)
inc_c_singular = 1.0  # Inclination (rad)
ra_c_singular = 0.0  # RAAN (rad)
f_c_singular = Mean_to_true_anomaly(M_c_singular, e_c_singular)  # Chief true anomaly at t0 (rad)
alpha0_c = np.array([a_c_singular, M_c_singular, e_c_singular, w_c_singular, inc_c_singular, ra_c_singular])
print(" \n Initial Chief Equinoctial State:", alpha0_c)

# Relative Orbital Elements [da, dM, de, dw, di, dra]
# da is normalized by the chief SMA: 1 / a_c corresponds to a 1 km SMA difference
da_singular = 1 / a_c_singular
dM_singular = 0.0
de_singular = 0.0
dw_singular = 0.0
dinc_singular = 0.01
dra_singular = 0.0
dalpha0 = np.array([da_singular, dM_singular, de_singular, dw_singular, dinc_singular, dra_singular])
print(" \n Initial ROE State:", dalpha0)


"""
You can play around changing a little bit the orbital elements and relative one.
However, do not forget that those are models valid only for CLOSE orbits, for large initial distances they will not match the truth!

"""






# Deputy classical orbital elements = chief elements + ROE (SMA via a_c * (1 + da))
a_d_singular = a_c_singular * (1 + da_singular)
M_d_singular = M_c_singular + dM_singular
e_d_singular = e_c_singular + de_singular
w_d_singular = w_c_singular + dw_singular
inc_d_singular = inc_c_singular + dinc_singular
ra_d_singular = ra_c_singular + dra_singular
alpha0_d = np.array([a_d_singular, M_d_singular, e_d_singular, w_d_singular, inc_d_singular, ra_d_singular])
print(" \n Initial Deputy Equinoctial State:", alpha0_d)


# Convert the initial orbital elements to Cartesian coordinates (each in its own perifocal frame), then to ECI
r_c_p0, v_c_p0 = classical_to_cartesian(a_c_singular, M_c_singular, e_c_singular)
r_c_eci0, v_c_eci0 = perifocal_to_eci(r_c_p0, v_c_p0, w_c_singular, inc_c_singular, ra_c_singular) # in ECI

r_d_pd0, v_d_pd0 = classical_to_cartesian(a_d_singular, M_d_singular, e_d_singular)                # in the Deputy perifocal frame
r_d_eci0, v_d_eci0 = perifocal_to_eci(r_d_pd0, v_d_pd0, w_d_singular, inc_d_singular, ra_d_singular) # in ECI


# Initial relative Cartesian state: first the deputy-minus-chief difference in ECI ...
delta_r0_p = np.array(r_d_eci0) - np.array(r_c_eci0)
delta_v0_p = np.array(v_d_eci0) - np.array(v_c_eci0)

# ... then the same two vectors re-expressed in the chief's perifocal frame
delta_r0_p, delta_v0_p = eci_to_perifocal(delta_r0_p, delta_v0_p, w_c_singular, inc_c_singular, ra_c_singular)

# Calculate initial relative Cartesian state in the chief's RTN frame
# NOTE(review): perifocal_to_rtn only rotates the components, so delta_v0_rtn
# is still the inertial velocity difference (expressed along the RTN axes), not
# the velocity observed from the rotating RTN frame.
delta_r0_rtn, delta_v0_rtn = perifocal_to_rtn(delta_r0_p, delta_v0_p, f_c_singular) # in the Chief RTN frame


# -------------------------------------------------------------------------- #


# --- Propagation using CW Equations --- #
# Propagation setup
num_steps = N_orbit*step

# Define mean motion of the chief
period_chief = 2 * np.pi * np.sqrt(a_c_singular**3 / mu)
n_chief = 2 * np.pi / period_chief

# Define the time step
dt = period_chief / step

# The CW equations are written in the ROTATING RTN frame, whereas delta_v0_rtn
# is the inertial velocity difference merely projected on the RTN axes.
# The velocity seen from the rotating frame is  v_rot = v_inertial - omega x r,
# with omega = n * N for the circular chief orbit assumed by the CW model.
omega_rtn = np.array([0.0, 0.0, n_chief])
delta_v0_rtn_rot = delta_v0_rtn - np.cross(omega_rtn, delta_r0_rtn)
delta_r0_cw = np.concatenate((delta_r0_rtn, delta_v0_rtn_rot))

# Propagate the relative state using CW
relative_states_cw, time_points_cw = propagate_cw(delta_r0_cw, n_chief, dt, num_steps)

# Mapping to ECI
relative_states_cw_eci = []
deltar_cw_rtn_matrix = relative_states_cw[:, 0:3]
deltav_cw_rtn_matrix = relative_states_cw[:, 3:6]

for ii in range(num_steps+1):
    # Chief true anomaly at this epoch fixes the orientation of the RTN axes
    MM = M_c_singular + n_chief*time_points_cw[ii]
    ff = Mean_to_true_anomaly(MM, e_c_singular)
    deltar_cw_rtn_vec = deltar_cw_rtn_matrix[ii, :]
    # Back to the inertial velocity difference before rotating the components:
    # v_inertial = v_rot + omega x r
    deltav_cw_rtn_vec = deltav_cw_rtn_matrix[ii, :] + np.cross(omega_rtn, deltar_cw_rtn_vec)
    deltar_cw_per, deltav_cw_per = rtn_to_perifocal(deltar_cw_rtn_vec, deltav_cw_rtn_vec, ff)
    deltar_cw_eci, deltav_cw_eci = perifocal_to_eci(deltar_cw_per, deltav_cw_per, w_c_singular, inc_c_singular, ra_c_singular)
    relative_states_cw_eci.append(np.concatenate((deltar_cw_eci, deltav_cw_eci)))

relative_states_cw_eci = np.array(relative_states_cw_eci)


# -------------------------------------------------------------------------- #


# --- Propagation using Truth Model --- #

# Time span for propagation
t_span = (0, period_chief * N_orbit)  # Propagate for N_orbit chief orbital periods
num_steps_truth = N_orbit*step+1      # number of samples, both end points included

# Use the same t_eval for both propagations to ensure the same time points
t_eval = np.linspace(t_span[0], t_span[1], num_steps_truth)

# Propagate orbits for chief and deputy (numerical integration including J2)
solution_chief = propagate_orbit_truth(r_c_eci0, v_c_eci0, t_span, t_eval=t_eval)
solution_deputy = propagate_orbit_truth(r_d_eci0, v_d_eci0, t_span, t_eval=t_eval)


# Calculate relative state (deputy minus chief); rows are components, columns are time samples
deltar_truth_eci = solution_deputy.y[:3, :] - solution_chief.y[:3, :]
deltav_truth_eci = solution_deputy.y[3:, :] - solution_chief.y[3:, :]

t_truth = solution_chief.t


# -------------------------------------------------------------------------- #
# --- Propagation using Kepler Model --- #
# Chief
alpha_evo_c, r_kep_c, v_kep_c, time_points_kep = propagate_Keplerian_Motion(alpha0_c, dt, num_steps)

# Deputy (time_points_kep is reassigned by this second call)
alpha_evo_d, r_kep_d, v_kep_d, time_points_kep = propagate_Keplerian_Motion(alpha0_d, dt, num_steps)

# Relative Motion (deputy minus chief); rows are time samples, columns are components
deltar_kep_eci = r_kep_d - r_kep_c
deltav_kep_eci = v_kep_d - v_kep_c


# -------------------------------------------------------------------------- #

# --- Propagation using ROE Model --- #
propagated_ROEstates, relative_states_roe_eci, time_points_roe = propagate_ROE(dalpha0, n_chief, alpha_evo_c, dt, num_steps)

# Deputy elements rebuilt from the chief elements and the propagated ROEs.
# da is normalized by the chief SMA, so the first column cannot simply be
# added: a_d = a_c * (1 + da), exactly as done inside propagate_ROE.
alpha_d = alpha_evo_c + propagated_ROEstates
alpha_d[:, 0] = alpha_evo_c[:, 0] * (1 + propagated_ROEstates[:, 0])











# -------------------------------------------------------------------------- #
#
#
#                               PLOTTING
#
#
# -------------------------------------------------------------------------- #




# --- Plotting Propagation --- #
plt.figure(figsize=(12, 8))

# One panel per model: (time axis, samples-by-components position array,
# legend labels for the three components, panel title).
_position_panels = (
    # CW Results
    (time_points_cw, relative_states_cw_eci[:, 0:3],
     ('CW - x', 'CW - y', 'CW - z'),
     'Relative Position - CW'),
    # ROE Model Results
    (time_points_roe, relative_states_roe_eci[:, 0:3],
     ('ROE - x', 'ROE - y', 'ROE - z'),
     'Relative Position - ROE Model'),
    # Keplerian Motion Results
    (time_points_kep, deltar_kep_eci,
     ('Keplerian - x', 'Keplerian - y', 'Keplerian - z'),
     'Relative Position - Truth (Keplerian Motion)'),
    # Truth Model (w J2) Results - stored components-by-samples, hence the transpose
    (t_truth, deltar_truth_eci.T,
     ('Truth - x', 'Truth - y', 'Truth - z'),
     'Relative Position - Truth (considering also J2)'),
)

for _row, (_time, _xyz, _labels, _title) in enumerate(_position_panels, start=1):
    plt.subplot(4, 1, _row)
    for _col, _label in enumerate(_labels):
        plt.plot(_time, _xyz[:, _col], label=_label)
    plt.title(_title)
    plt.xlabel('Time (s)')
    plt.ylabel('Position (km)')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

# -------------------------------------------------------------------------- #


# --- Plotting Error --- #
plt.figure(figsize=(12, 8))

# The Keplerian relative position is the reference; where one of its
# components is (almost) zero the percentage error is undefined, so the
# floating-point warnings are silenced and those samples are plotted as NaN.
with np.errstate(divide='ignore', invalid='ignore'):
    # CW Relative Position Error in ECI
    plt.subplot(2, 1, 1)
    err_x_cw, err_y_cw, err_z_cw = [
        (relative_states_cw_eci[:, _k] - deltar_kep_eci[:, _k]) / deltar_kep_eci[:, _k] * 100
        for _k in range(3)
    ]
    _cw_curves = zip((err_x_cw, err_y_cw, err_z_cw),
                     ('CW - err_x', 'CW - err_y', 'CW - err_z'))
    for _col, (_err, _label) in enumerate(_cw_curves):
        _usable = np.abs(deltar_kep_eci[:, _col]) > 1e-9
        plt.plot(time_points_cw, np.where(_usable, _err, np.nan), label=_label)
    plt.title('Relative Position Error in ECI - CW')
    plt.xlabel('Time (s)')
    plt.ylabel('Relative Error (%)')
    plt.legend()
    plt.grid(True)

    # ROE Model Relative Position Error in ECI
    plt.subplot(2, 1, 2)
    err_x_roe, err_y_roe, err_z_roe = [
        (relative_states_roe_eci[:, _k] - deltar_kep_eci[:, _k]) / deltar_kep_eci[:, _k] * 100
        for _k in range(3)
    ]
    _roe_curves = zip((err_x_roe, err_y_roe, err_z_roe),
                      ('ROE - err_x', 'ROE - err_y', 'ROE - err_z'))
    for _col, (_err, _label) in enumerate(_roe_curves):
        _usable = np.abs(deltar_kep_eci[:, _col]) > 1e-9
        plt.plot(time_points_roe, np.where(_usable, _err, np.nan), label=_label)
    plt.title('Relative Position Error in ECI - ROE Model')
    plt.xlabel('Time (s)')
    plt.ylabel('Relative Error (%)')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()


# -------------------------------------------------------------------------- #



# --- Plotting Normalized Cumulative Relative Error --- #
plt.figure(figsize=(12, 8))

with np.errstate(divide='ignore', invalid='ignore'):
    # CW: |relative error| per component (1e-12 added to the reference in the
    # denominator), running sum weighted by dt, then expressed as a
    # percentage of the chief period.
    relerr_cw_x, relerr_cw_y, relerr_cw_z = [
        np.abs((relative_states_cw_eci[:, _k] - deltar_kep_eci[:, _k]) / (deltar_kep_eci[:, _k] + 1e-12))
        for _k in range(3)
    ]
    cumulative_relerr_cw_x_dt, cumulative_relerr_cw_y_dt, cumulative_relerr_cw_z_dt = [
        np.cumsum(_rel * dt) for _rel in (relerr_cw_x, relerr_cw_y, relerr_cw_z)
    ]
    (normalized_cumulative_relerr_cw_x,
     normalized_cumulative_relerr_cw_y,
     normalized_cumulative_relerr_cw_z) = [
        _cum / period_chief * 100
        for _cum in (cumulative_relerr_cw_x_dt, cumulative_relerr_cw_y_dt, cumulative_relerr_cw_z_dt)
    ]

    # ROE: same three stages
    relerr_roe_x, relerr_roe_y, relerr_roe_z = [
        np.abs((relative_states_roe_eci[:, _k] - deltar_kep_eci[:, _k]) / (deltar_kep_eci[:, _k] + 1e-12))
        for _k in range(3)
    ]
    cumulative_relerr_roe_x_dt, cumulative_relerr_roe_y_dt, cumulative_relerr_roe_z_dt = [
        np.cumsum(_rel * dt) for _rel in (relerr_roe_x, relerr_roe_y, relerr_roe_z)
    ]
    (normalized_cumulative_relerr_roe_x,
     normalized_cumulative_relerr_roe_y,
     normalized_cumulative_relerr_roe_z) = [
        _cum / period_chief * 100
        for _cum in (cumulative_relerr_roe_x_dt, cumulative_relerr_roe_y_dt, cumulative_relerr_roe_z_dt)
    ]

    # Plot the normalized cumulative relative error for the CW model
    plt.subplot(2, 1, 1)
    for _curve, _label in (
        (normalized_cumulative_relerr_cw_x, 'CW - Sum |Rel Err| x * dt / T'),
        (normalized_cumulative_relerr_cw_y, 'CW - Sum |Rel Err| y * dt / T'),
        (normalized_cumulative_relerr_cw_z, 'CW - Sum |Rel Err| z * dt / T'),
    ):
        plt.plot(time_points_cw, _curve, label=_label)
    plt.title('Normalized Cumulative Relative Position Error in ECI - CW')
    plt.xlabel('Time (s)')
    plt.ylabel('Normalized Cumulative Relative Error (%)')
    plt.legend()
    plt.grid(True)

    # Plot the normalized cumulative relative error for the ROE Model
    plt.subplot(2, 1, 2)
    for _curve, _label in (
        (normalized_cumulative_relerr_roe_x, 'ROE - Sum |Rel Err| x * dt / T'),
        (normalized_cumulative_relerr_roe_y, 'ROE - Sum |Rel Err| y * dt / T'),
        (normalized_cumulative_relerr_roe_z, 'ROE - Sum |Rel Err| z * dt / T'),
    ):
        plt.plot(time_points_roe, _curve, label=_label)
    plt.title('Normalized Cumulative Relative Position Error in ECI - ROE Model')
    plt.xlabel('Time (s)')
    plt.ylabel('Normalized Cumulative Relative Error (%)')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()



# -------------------------------------------------------------------------- #
#
#
#                      SOLUTION OF THE EXERCISE
#
#
# -------------------------------------------------------------------------- #


#
# THE SOLUTION TO THE EXERCISE IS:
#
#     # ROE propagation (assuming ONLY Keplerian motion)
#     for _ in range(num_steps):
#         t += dt
#
#         da_new = da0
#         # First order
#         dM_new = dM0 - 1.5*n*da0*t
#         # Second order
#         dM_new = dM0 - 1.5*n*da0*t + 15/8*n*da0**2*t
#         # Third order
#         dM_new = dM0 - 1.5*n*da0*t + 15/8*n*da0**2*t - 105/48*n*da0**3*t
#         # And so on...
#         de_new = de0
#         dw_new = dw0
#         dinc_new = dinc0
#         dra_new = dra0
#
#         # Save
#         propagated_da.append(da_new)
#         propagated_dM.append(dM_new)
#         propagated_de.append(de_new)
#         propagated_dw.append(dw_new)
#         propagated_di.append(dinc_new)
#         propagated_dra.append(dra_new)
#
#         times.append(t)
#
# Because only dM varies when ONLY Keplerian Motion is considered.
# Let's derive together dot(dM):
#
# (1) dM = M_d - M_c
# (2) dot(dM) = dot(M_d) - dot(M_c) = n_d - n_c
# (3) dot(M_d) = n_d = sqrt(mu/a_d^3)
# (4) n_d can be expanded to first order around the Chief: n_d ~ n_c + (dn/da)_c * (a_d - a_c)
#     where dn/da = -3/2 sqrt(mu/a^5) = -3/2 n/a since the only variable is a, and "_c" indicates that it is evaluated with the Chief elements
# (5) finally, knowing that da = (a_d - a_c)/a_c, we obtain dot(dM) = n_d - n_c ~ -3/2 n_c * da
#
# Then we can approximate at the first order dM(t) = dM0 + dot(dM) * t = dM0 - 1.5 * n * da0 * t
#
# To add higher orders, one just needs the further derivatives of n with respect to a, i.e. the next terms of the Taylor expansion (see the second- and third-order lines above).
#