#include "Vehicle.h"
#include "outVector.h"

#include <iostream>

// Forward declarations for the simulation steps defined later in this file.
// sim() calls burnStartTimeCalc once, then vehicleDynamics, thrustSelection,
// pidController, TVC and state2vec on every step, and write2CSV at the end.

// Pre-computes State.z, State.burnVelocity and State.simTime.
void burnStartTimeCalc(struct Vehicle &);
// Updates State.burnElapsed, State.mass, State.thrust and State.thrustFiring
// for loop counter t.
void thrustSelection(struct Vehicle &, int t);
// Computes the State.PIDx / State.PIDy commands (current state, previous
// state) and limits them to +/- State.thrust.
void pidController(struct Vehicle &, struct Vehicle &);
// Turns the PID commands into servo angles, forces (Fx, Fy, Fz) and moments.
void TVC(struct Vehicle &);
// Integrates rotational and translational state (current state, previous
// state, loop counter t).
void vehicleDynamics(struct Vehicle &, struct Vehicle &, int t);
// Records the current state at index t and copies State into the previous
// state.
void state2vec(struct Vehicle &, struct Vehicle &, struct outVector &, int t);
// Writes the recorded history to "simOut.csv".
void write2CSV(struct outVector &, struct Vehicle &);
// Returns (x2 - x1) / (dt / 1000).
double derivative(double x2, double x1, double dt);
// Returns x2 * dt / 1000 + x1.
double integral(double x2, double x1, double dt);

// Any parameters that are constants should be declared here instead of buried
// in code
// Step used by the thrust-curve integration loop in burnStartTimeCalc
// (presumably seconds -- TODO confirm). The main loop in sim() advances by
// State.stepSize instead, and integral() has a parameter of the same name
// that shadows this constant inside that function.
double const dt = 0.01;
// Gravitational acceleration, negative-valued; multiplied by mass in TVC and
// used in the burn pre-calculation (presumably m/s^2 with +z up -- TODO
// confirm).
double const g = -9.81;

/**
 * Runs one complete simulation and grades the final state.
 *
 * Calls burnStartTimeCalc once, then repeats vehicleDynamics ->
 * thrustSelection -> pidController -> TVC -> state2vec, advancing the loop
 * counter by State.stepSize, until State.z is no longer above zero and
 * State.thrust is no longer above 0.1. The recorded history is then written
 * out with write2CSV and two landing criteria are printed to stdout.
 *
 * @param State     Current vehicle state; updated in place on every step.
 * @param PrevState Vehicle state from the previous step (refreshed by
 *                  state2vec).
 * @return true only if both the landing-angle and landing-velocity checks
 *         pass.
 */
bool sim(struct Vehicle &State, struct Vehicle &PrevState) {

  outVector stateVector;

  // Determine when to burn
  burnStartTimeCalc(State);

  int t = 0;

  // Start Sim
  do {
    vehicleDynamics(State, PrevState, t);
    thrustSelection(State, t);
    pidController(State, PrevState);
    TVC(State);
    state2vec(State, PrevState, stateVector, t);

    t += State.stepSize;
  } while ((State.z > 0.0) || (State.thrust > 0.1));

  write2CSV(stateVector, State);

  bool pass = true;

  // yaw and pitch are stored in radians (they are passed to sin/cos in
  // vehicleDynamics and multiplied by 180 / M_PI in write2CSV), so the
  // combined angle has to be converted before it is compared with the
  // 5 degree limit below.
  const double landing_angle =
      sqrt(State.yaw * State.yaw + State.pitch * State.pitch) * (180 / M_PI);

  const double landing_velocity =
      sqrt(State.vx * State.vx + State.vy * State.vy + State.vz * State.vz);

  if (landing_angle < 5.0) {
    std::cout << "   Landing Angle < 5°    | PASS | ";
  } else {
    std::cout << "   Landing Angle < 5°    | FAIL | ";
    pass = false;
  }
  // The individual angles are printed exactly as stored in State.
  std::cout << "Final Angles: [" << State.yaw << ", " << State.pitch << "]"
            << std::endl;

  if (landing_velocity < 5.0) {
    std::cout << "Landing Velocity < 5 m/s | PASS | ";
  } else {
    std::cout << "Landing Velocity < 5 m/s | FAIL | ";
    pass = false;
  }
  std::cout << "Final Velocity: [" << State.vx << ", " << State.vy << ", "
            << State.vz << "]" << std::endl;

  return pass;
}

/**
 * Integrates the piecewise F15 thrust curve with the file-level step dt to
 * derive the values the main loop starts from.
 *
 * Reads State.vz, State.massInitial, State.mdot and State.burntime.
 * Writes State.z, State.burnVelocity and State.simTime.
 */
void burnStartTimeCalc(Vehicle &State) {
  double speed = State.vz;
  double height = 0;

  // Every time value visited by the loop matches one of the segments below,
  // so this initial value is always overwritten before use.
  double thrust = 0.0;

  // Piecewise functions for F15 thrust curve
  for (double burnTime = 0.148; burnTime < 3.450; burnTime = burnTime + dt) {
    const double mass = State.massInitial - burnTime * State.mdot;

    if ((burnTime > 0.147) && (burnTime < 0.420)) {
      // Initial linear ramp
      thrust = 65.165 * burnTime - 2.3921;
    } else if ((burnTime > 0.419) && (burnTime < 3.383)) {
      // Sixth-order polynomial segment
      thrust = 0.8932 * pow(burnTime, 6) - 11.609 * pow(burnTime, 5) +
               60.739 * pow(burnTime, 4) - 162.99 * pow(burnTime, 3) +
               235.6 * pow(burnTime, 2) - 174.43 * burnTime + 67.17;
    } else if ((burnTime > 3.382) && (burnTime < 3.46)) {
      // Final linear tail-off
      thrust = -195.78 * burnTime + 675.11;
    }

    // Net acceleration is thrust per unit mass plus (negative) gravity
    speed = (((thrust / mass) + g) * dt) + speed;
    height = speed * dt + height;
  }

  State.burnVelocity = speed;                      // terminal velocity
  State.z = height + (pow(speed, 2) / (2 * -g));   // starting height

  const double burnStartTime = State.burnVelocity / -g;
  State.simTime = (State.burntime + burnStartTime) * 1000;
}

/**
 * Advances the vehicle's rotational and translational state by one step.
 *
 * Recomputes the moments of inertia from the current mass, then either
 * initialises the first step (t below 0.1, i.e. the initial value 0 of the
 * loop counter in sim) or integrates angular accelerations, angular rates,
 * attitude angles, linear velocities and positions from PrevState using
 * State.stepSize.
 *
 * @param State     State being updated; its forces (Fx, Fy, Fz) and moments
 *                  (momentX, momentY, momentZ) are read as inputs.
 * @param PrevState State from the previous step; used as the starting point
 *                  of every integration.
 * @param t         Loop counter from sim; only used to detect the first step.
 */
void vehicleDynamics(Vehicle &State, Vehicle &PrevState, int t) {
  // Moment of Inertia
  // 1.0 / 12.0 must be floating point: the integer expression (1 / 12)
  // evaluates to 0 and would silently drop the height term.
  const double transverseInertia =
      State.mass * ((1.0 / 12.0) * pow(State.vehicleHeight, 2) +
                    pow(State.vehicleRadius, 2) / 4);
  State.I11 = transverseInertia;
  State.I22 = transverseInertia;
  State.I33 = State.mass * 0.5 * pow(State.vehicleRadius, 2);

  // Idot
  if (t < 0.1) {
    // First step: no previous state to differentiate or integrate from.
    State.I11dot = 0;
    State.I22dot = 0;
    State.I33dot = 0;

    State.x = 0;
    State.y = 0;

    State.ax = 0;
    State.ay = 0;
    State.az = State.Fz / State.massInitial;

  } else {
    State.I11dot = derivative(State.I11, PrevState.I11, State.stepSize);
    State.I22dot = derivative(State.I22, PrevState.I22, State.stepSize);
    State.I33dot = derivative(State.I33, PrevState.I33, State.stepSize);

    // pdot, qdot, rdot
    // Each angular acceleration uses the applied moment, the rate of change
    // of inertia, and the cross-coupling of the previous body rates.
    State.yawddot = (State.momentX - State.I11dot * PrevState.yawdot +
                     State.I22 * PrevState.pitchdot * PrevState.rolldot -
                     State.I33 * PrevState.pitchdot * PrevState.rolldot) /
                    State.I11;
    State.pitchddot = (State.momentY - State.I22dot * PrevState.pitchdot -
                       State.I11 * PrevState.rolldot * PrevState.yawdot +
                       State.I33 * PrevState.rolldot * PrevState.yawdot) /
                      State.I22;
    State.rollddot = (State.momentZ - State.I33dot * PrevState.rolldot +
                      State.I11 * PrevState.pitchdot * PrevState.yawdot -
                      State.I22 * PrevState.pitchdot * PrevState.yawdot) /
                     State.I33;

    // p, q, r
    State.yawdot = integral(State.yawddot, PrevState.yawdot, State.stepSize);
    State.pitchdot =
        integral(State.pitchddot, PrevState.pitchdot, State.stepSize);
    State.rolldot = integral(State.rollddot, PrevState.rolldot, State.stepSize);

    // Euler Angles
    // Note: these rates are evaluated with State.yaw / State.pitch as they
    // stand before the angle integration just below.
    State.phidot =
        State.yawdot + (sin(State.pitch) * (State.rolldot * cos(State.yaw) +
                                            State.pitchdot * sin(State.yaw))) /
                           cos(State.pitch);
    State.thetadot =
        State.pitchdot * cos(State.yaw) - State.rolldot * sin(State.yaw);
    State.psidot =
        (State.rolldot * cos(State.yaw) + State.pitchdot * sin(State.yaw)) /
        cos(State.pitch);

    State.yaw = integral(State.phidot, PrevState.yaw, State.stepSize);
    State.pitch = integral(State.thetadot, PrevState.pitch, State.stepSize);
    State.roll = integral(State.psidot, PrevState.roll, State.stepSize);

    // ax ay az
    State.ax = (State.Fx / State.mass);
    State.ay = (State.Fy / State.mass);
    State.az = (State.Fz / State.mass);

    // vx vy vz in Earth frame
    State.vx = integral(State.ax, PrevState.vx, State.stepSize);
    State.vy = integral(State.ay, PrevState.vy, State.stepSize);
    State.vz = integral(State.az, PrevState.vz, State.stepSize);

    // Xe
    State.x = integral(State.vx, PrevState.x, State.stepSize);
    State.y = integral(State.vy, PrevState.y, State.stepSize);
    State.z = integral(State.vz, PrevState.z, State.stepSize);
  }
}

/**
 * Updates the burn timer, vehicle mass and motor thrust for the current step.
 *
 * State.burnElapsed doubles as a state flag: the value 2000 means "not
 * burning yet". While it holds 2000 the function checks whether State.vz has
 * reached the pre-computed State.burnVelocity (opposite sign) and, if so,
 * starts the burn. Once started, elapsed burn time and mass are recomputed
 * every step and the thrust is read from the piecewise F15 thrust curve.
 *
 * @param State Vehicle state; burnElapsed, burnStart, mass, thrust and
 *              thrustFiring are updated in place.
 * @param t     Loop counter from sim; converted to seconds by dividing by
 *              1000.
 */
void thrustSelection(Vehicle &State, int t) {

  if (State.burnElapsed != 2000) {
    // determine where in the thrust curve we're at based on elapsed burn time
    // as well as current mass
    // 1000.0 keeps the conversion in floating point even if both operands of
    // the subtraction are integers, so fractional seconds are not truncated.
    State.burnElapsed = (t - State.burnStart) / 1000.0;
    State.mass = State.massInitial - (State.mdot * State.burnElapsed);
  }

  // NOTE(review): unqualified abs may resolve to the integer overload
  // depending on which headers are visible, which would truncate the argument
  // before the comparison -- confirm which overload is intended.
  else if (abs(State.burnVelocity + State.vz) < 0.001) {
    // Start burn
    State.burnStart = t;
    State.burnElapsed = 0;
  }

  else
    State.burnElapsed = 2000; // arbitrary number to ensure we don't burn

  if ((State.burnElapsed > 0.147) && (State.burnElapsed < 0.420)) {
    // Initial linear ramp; this is also where the firing flag is raised
    State.thrustFiring = true;
    State.thrust = 65.165 * State.burnElapsed - 2.3921;

  } else if ((State.burnElapsed > 0.419) && (State.burnElapsed < 3.383))
    State.thrust = 0.8932 * pow(State.burnElapsed, 6) -
                   11.609 * pow(State.burnElapsed, 5) +
                   60.739 * pow(State.burnElapsed, 4) -
                   162.99 * pow(State.burnElapsed, 3) +
                   235.6 * pow(State.burnElapsed, 2) -
                   174.43 * State.burnElapsed + 67.17;

  else if ((State.burnElapsed > 3.382) && (State.burnElapsed < 3.46))
    // Final linear tail-off. The intercept is +675.11, matching the same
    // segment in burnStartTimeCalc; a negative intercept would yield a large
    // negative thrust here.
    State.thrust = -195.78 * State.burnElapsed + 675.11;

  if (State.burnElapsed > 3.45) {
    // Burn finished (this also covers the 2000 "not burning" sentinel)
    State.thrustFiring = false;
    State.thrust = 0;
  }
}

/**
 * Computes the PID commands State.PIDx (from yaw) and State.PIDy (from
 * pitch), then limits each to the range [-State.thrust, +State.thrust].
 *
 * The controller only reacts while State.thrust exceeds 0.01; otherwise both
 * commands are zeroed.
 *
 * @param State     Current state; error terms and PID outputs are written.
 * @param PrevState Previous state; supplies the previous error values used
 *                  for the derivative terms.
 */
void pidController(Vehicle &State, struct Vehicle &PrevState) {
  // Make sure we start reacting when we start burning
  const bool motorProducingThrust = State.thrust > 0.01;

  if (!motorProducingThrust) {
    State.PIDx = 0;
    State.PIDy = 0;
  } else {
    // Yaw axis: error, its integral and derivative, then the command
    State.yError = State.yaw;
    State.i_yError = integral(State.yError, State.i_yError, State.stepSize);
    State.d_yError = derivative(State.yError, PrevState.yError, State.stepSize);
    State.PIDx = (State.Kp * State.yError + State.Ki * State.i_yError +
                  State.Kd * State.d_yError) /
                 State.momentArm;

    // Pitch axis: same structure as yaw
    State.pError = State.pitch;
    State.i_pError = integral(State.pError, State.i_pError, State.stepSize);
    State.d_pError = derivative(State.pError, PrevState.pError, State.stepSize);
    State.PIDy = (State.Kp * State.pError + State.Ki * State.i_pError +
                  State.Kd * State.d_pError) /
                 State.momentArm;
  }

  // The commanded magnitude may not exceed the available thrust
  const auto upperLimit = State.thrust;
  const auto lowerLimit = -1 * State.thrust;

  // PID Force limiter X
  if (State.PIDx > upperLimit) {
    State.PIDx = upperLimit;
  } else if (State.PIDx < lowerLimit) {
    State.PIDx = lowerLimit;
  }

  // PID Force limiter Y
  if (State.PIDy > upperLimit) {
    State.PIDy = upperLimit;
  } else if (State.PIDy < lowerLimit) {
    State.PIDy = lowerLimit;
  }
}

/**
 * Converts the PID commands into servo angles, body forces and moments.
 *
 * When State.thrust is below 0.1 only gravity acts: Fx, Fy and all moments
 * are zero and Fz is g * State.massInitial. Otherwise each command is turned
 * into a servo angle in degrees, limited to +/- State.maxServo, and the
 * resulting thrust components and moments are stored in State.
 */
void TVC(Vehicle &State) {
  if (State.thrust < 0.1) {
    // Define forces and moments for t = 0
    State.Fx = 0;
    State.Fy = 0;
    State.Fz = g * State.massInitial;

    State.momentX = 0;
    State.momentY = 0;
    State.momentZ = 0;
    return;
  }

  // X servo: command -> angle in degrees, then limit to the allowed travel
  State.xServoDegs = (180 / M_PI) * asin(State.PIDx / State.thrust);
  if (State.xServoDegs > State.maxServo) {
    State.xServoDegs = State.maxServo;
  } else if (State.xServoDegs < -1 * State.maxServo) {
    State.xServoDegs = -1 * State.maxServo;
  }

  // Y servo: same conversion and limit
  State.yServoDegs = (180 / M_PI) * asin(State.PIDy / State.thrust);
  if (State.yServoDegs > State.maxServo) {
    State.yServoDegs = State.maxServo;
  } else if (State.yServoDegs < -1 * State.maxServo) {
    State.yServoDegs = -1 * State.maxServo;
  }

  // Lateral thrust components from the (limited) servo angles
  State.Fx = State.thrust * sin(State.xServoDegs * (M_PI / 180));
  State.Fy = State.thrust * sin(State.yServoDegs * (M_PI / 180));

  // Remaining thrust acts along z, with gravity added on top
  State.Fz =
      sqrt(pow(State.thrust, 2) - pow(State.Fx, 2) - pow(State.Fy, 2)) +
      (State.mass * g);

  // Moments produced by the lateral components acting at the moment arm
  State.momentX = State.Fx * State.momentArm;
  State.momentY = State.Fy * State.momentArm;
  State.momentZ = 0;
}

/**
 * Stores the current state in the output history at index t, then makes the
 * current state the "previous" state for the next step.
 *
 * @param State       State to record.
 * @param PrevState   Overwritten with a copy of State.
 * @param stateVector Output history; each member is indexed with t.
 * @param t           Loop counter from sim, used directly as the index.
 */
void state2vec(Vehicle &State, Vehicle &PrevState, outVector &stateVector,
               int t) {
  outVector &record = stateVector;
  const Vehicle &now = State;

  // Position
  record.x[t] = now.x;
  record.y[t] = now.y;
  record.z[t] = now.z;

  // Velocity
  record.vx[t] = now.vx;
  record.vy[t] = now.vy;
  record.vz[t] = now.vz;

  // Acceleration
  record.ax[t] = now.ax;
  record.ay[t] = now.ay;
  record.az[t] = now.az;

  // Attitude angles
  record.yaw[t] = now.yaw;
  record.pitch[t] = now.pitch;
  record.roll[t] = now.roll;

  // Angular rates
  record.yawdot[t] = now.yawdot;
  record.pitchdot[t] = now.pitchdot;
  record.rolldot[t] = now.rolldot;

  // Servo angles
  record.servo1[t] = now.xServoDegs;
  record.servo2[t] = now.yServoDegs;

  // Motor flag
  record.thrustFiring[t] = now.thrustFiring;

  // Controller outputs
  record.PIDx[t] = now.PIDx;
  record.PIDy[t] = now.PIDy;

  // Set "prev" values for next timestep
  PrevState = State;
}

/**
 * Writes the recorded simulation history to "simOut.csv".
 *
 * One row is written for every index from 0 up to (not including)
 * State.simTime in increments of State.stepSize. Angles and angular rates
 * are multiplied by 180 / M_PI on output; every other value is written as
 * stored.
 *
 * @param stateVector Recorded history, indexed by the same counter that
 *                    state2vec used.
 * @param State       Supplies simTime and stepSize for the row loop.
 */
void write2CSV(outVector &stateVector, Vehicle &State) {

  // Deleting any previous output file
  if (remove("simOut.csv") != 0)
    perror("No file deletion necessary");
  else
    puts("Previous output file successfully deleted");

  // Define and open output file "simOut.csv". The file is truncated rather
  // than appended to, so a stale file that could not be deleted above can
  // never end up with a second copy of the data tacked on.
  std::ofstream outfile("simOut.csv", std::ios::out | std::ios::trunc);
  if (!outfile.is_open()) {
    std::cerr << "Unable to open simOut.csv for writing." << std::endl;
    return;
  }

  // Output file header. These are the variables that we output - useful for
  // debugging. The header lists exactly the columns written per row below.
  outfile << "t, x, y, z, vx, vy, vz, ax, ay, az, yaw, pitch, roll, yawdot, "
             "pitchdot, rolldot, Servo1, Servo2, thrustFiring, PIDx, PIDy"
          << '\n';

  // writing to output file ('\n' instead of std::endl avoids flushing the
  // stream once per row)
  for (int t = 0; t < State.simTime; t += State.stepSize) {
    outfile << t << ", ";

    outfile << stateVector.x[t] << ", ";
    outfile << stateVector.y[t] << ", ";
    outfile << stateVector.z[t] << ", ";

    outfile << stateVector.vx[t] << ", ";
    outfile << stateVector.vy[t] << ", ";
    outfile << stateVector.vz[t] << ", ";

    outfile << stateVector.ax[t] << ", ";
    outfile << stateVector.ay[t] << ", ";
    outfile << stateVector.az[t] << ", ";

    outfile << stateVector.yaw[t] * 180 / M_PI << ", ";
    outfile << stateVector.pitch[t] * 180 / M_PI << ", ";
    outfile << stateVector.roll[t] * 180 / M_PI << ", ";

    outfile << stateVector.yawdot[t] * 180 / M_PI << ", ";
    outfile << stateVector.pitchdot[t] * 180 / M_PI << ", ";
    outfile << stateVector.rolldot[t] * 180 / M_PI << ", ";

    outfile << stateVector.servo1[t] << ", ";
    outfile << stateVector.servo2[t] << ", ";

    outfile << stateVector.thrustFiring[t] << ", ";

    outfile << stateVector.PIDx[t] << ", ";
    outfile << stateVector.PIDy[t] << '\n';
  }

  outfile.close();
  if (outfile.fail()) {
    // A write or close error occurred; do not claim success.
    std::cerr << "Error while writing simOut.csv." << std::endl;
    return;
  }
  std::cout << "simOut.csv created successfully.\n" << std::endl;
}

// Finite-difference rate of change between two samples.
// `step` is divided by 1000 before use, so the result is the change per
// (step / 1000) units.
double derivative(double current, double previous, double step) {
  const double scaledStep = step / 1000;
  return (current - previous) / scaledStep;
}

// Single accumulation step: adds currentChange * dt / 1000 to prevValue.
// The parameter `dt` is this call's step, not the file-level constant.
double integral(double currentChange, double prevValue, double dt) {
  const double increment = currentChange * dt / 1000;
  return increment + prevValue;
}