#include "c_cpp_aliases/aliases.h"
#include "vector/vec.h"
#include "window/window.h"
#include <SDL2/SDL_events.h>
#include <math.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

// Small positive ray parameter. compute_lighting() passes it as the lower
// bound (t_min) of shadow rays, so hits closer than this to the shaded point
// are ignored — presumably to keep a surface from shadowing itself.
#define EPSILON 0.001f

// Number of elements in the array ARR. Only meaningful for true arrays (not
// pointers or array parameters, which have decayed). The argument and the
// whole expansion are parenthesised so the macro can be used safely inside a
// larger expression such as `n / ARR_LEN(a)`.
#define ARR_LEN(ARR) (sizeof(ARR) / sizeof((ARR)[0]))

// A sphere in the scene.
typedef struct {
  f32 radius;      // sphere radius
  vec3f_t centre;  // position of the sphere's centre
  colour_t colour; // base colour; trace_ray() scales r/g/b by the light value
  f32 specular;    // specular exponent passed to powf() in light_specular();
                   // compute_lighting() skips the specular term when it is -1
} sphere_t;

// Kinds of light handled by compute_lighting().
typedef enum {
  LIGHT_TYPE_POINT,       // uses light_t.position
  LIGHT_TYPE_DIRECTIONAL, // uses light_t.direction
  LIGHT_TYPE_AMBIENT,     // intensity is added unconditionally

  COUNT_LIGHT_TYPE, // trailing entry; equals the number of light types above
} light_type_t;

// A single light source.
typedef struct {
  light_type_t type; // selects which union member (if any) is meaningful
  f32 intensity;     // contribution weight of this light
  // position and direction share the same storage; which one is read depends
  // on `type`. Ambient lights use neither.
  union {
    vec3f_t position;  // read for LIGHT_TYPE_POINT
    vec3f_t direction; // read for LIGHT_TYPE_DIRECTIONAL
  };
} light_t;

// The renderable scene: arrays of spheres and lights plus their lengths.
// The struct only stores pointers; in this file they point at arrays that
// live on main()'s stack.
typedef struct {
  sphere_t *spheres; // array of spheres_count spheres
  light_t *lights;   // array of lights_count lights
  u32 spheres_count;
  u32 lights_count;
} scene_t;

// The two ray parameters returned by ray_intersects_sphere(). Both are
// INFINITY when the ray misses the sphere. Field order matters: the struct is
// initialised positionally.
typedef struct {
  f32 t1; // root computed with +sqrt(discriminant)
  f32 t2; // root computed with -sqrt(discriminant)
} solutions_t;

// Result of find_closest_intersection(). Field order matters: the struct is
// initialised positionally.
typedef struct {
  f32 closest_t;            // ray parameter of the nearest hit; INFINITY if none
  sphere_t *closest_sphere; // sphere that was hit, or NULL when nothing was hit
} intersection_t;

// Returns the shaded colour seen along the ray origin + t * direction for
// t in [t_min, t_max], or default_colour when no sphere is hit.
colour_t trace_ray(vec3f_t origin, vec3f_t direction, f32 t_min, f32 t_max,
                   const scene_t *scene, colour_t default_colour);
// Finds the sphere with the smallest ray parameter t in [t_min, t_max].
// closest_sphere is NULL in the result when nothing is hit.
intersection_t find_closest_intersection(vec3f_t origin, vec3f_t direction,
                                         f32 t_min, f32 t_max,
                                         const scene_t *scene);
// Computes the light value at the given intersection. The intersection must
// have a non-NULL closest_sphere.
f32 calculate_lighting_for_intersection(vec3f_t origin, vec3f_t direction,
                                        intersection_t intersection,
                                        const scene_t *scene);
// Solves the ray/sphere quadratic; both solutions are INFINITY on a miss.
solutions_t ray_intersects_sphere(vec3f_t origin, vec3f_t direction,
                                  sphere_t sphere);
// Sums the contribution of every light in the scene at `position`, skipping
// point/directional lights whose shadow ray hits a sphere.
f32 compute_lighting(vec3f_t position, vec3f_t surface_normal,
                     vec3f_t view_vector, f32 specular_exponent,
                     const scene_t *scene);
// Diffuse term: intensity scaled by the clamped cosine between the light
// direction and the surface normal.
f32 light_diffuse(f32 light_intensity, vec3f_t light_direction,
                  vec3f_t surface_normal);
// Specular term: intensity scaled by the clamped cosine between the reflected
// light direction and the view vector, raised to specular_exponent.
f32 light_specular(f32 light_intensity, vec3f_t light_direction,
                   vec3f_t surface_normal, vec3f_t view_vector,
                   f32 specular_exponent);
// Cosine of the angle between v1 and v2; 0 when the dot product is negative
// or the product of the magnitudes is zero.
f32 cos_angle_between_vectors(vec3f_t v1, vec3f_t v2);
// Limits value to the range [min, max].
f32 clamp(f32 value, f32 min, f32 max);

/**
 * Program entry point.
 *
 * Opens an 800x800 window, builds a fixed scene (four spheres, three lights)
 * and ray-traces that scene into the window on every frame until an SDL_QUIT
 * event arrives. argc and argv are not used.
 *
 * Returns EXIT_FAILURE when init_window() fails, EXIT_SUCCESS otherwise.
 */
i32 main(i32 argc, char *argv[]) {
  window_t window = {0};
  if (!init_window(&window, 800, 800, "CG From Scratch")) {
    return EXIT_FAILURE;
  }

  // Colour used both to clear the window and for rays that hit nothing.
  colour_t background = {
      .rgba.r = 27, .rgba.g = 38, .rgba.b = 79, .rgba.a = 255};
  vec3f_t eye = {.x = 0.0f, .y = 0.0f, .z = 0.0f};
  vec3f_t viewport_size = {.x = 1.0f, .y = 1.0f, .z = 1.0f};

  sphere_t scene_spheres[] = {
      {
          .radius = 1.0f,
          .centre = {.x = 0.0f, .y = -1.0f, .z = 3.0f},
          .colour = {.rgba.r = 245,
                     .rgba.g = 238,
                     .rgba.b = 158,
                     .rgba.a = 255},
          .specular = 500.0f,
      },
      {
          .radius = 1.0f,
          .centre = {.x = -2.0f, .y = 0.0f, .z = 4.0f},
          .colour = {.rgba.r = 59,
                     .rgba.g = 142,
                     .rgba.b = 165,
                     .rgba.a = 255},
          .specular = 10.0f,
      },
      {
          .radius = 1.0f,
          .centre = {.x = 2.0f, .y = 0.0f, .z = 4.0f},
          .colour = {.rgba.r = 171,
                     .rgba.g = 52,
                     .rgba.b = 40,
                     .rgba.a = 255},
          .specular = 500.0f,
      },
      // Very large sphere placed far below the others.
      {
          .radius = 5000.0f,
          .centre = {.x = 0.0f, .y = -5001.0f, .z = 0.0f},
          .colour = {.rgba.r = 255,
                     .rgba.g = 255,
                     .rgba.b = 0,
                     .rgba.a = 255},
          .specular = 1000.0f,
      },
  };

  light_t scene_lights[] = {
      {
          .type = LIGHT_TYPE_AMBIENT,
          .intensity = 0.2f,
      },
      {
          .type = LIGHT_TYPE_POINT,
          .intensity = 0.6f,
          .position = {.x = 2.0f, .y = 1.0f, .z = 0.0f},
      },
      {
          .type = LIGHT_TYPE_DIRECTIONAL,
          .intensity = 0.2f,
          .direction = {.x = 1.0f, .y = 4.0f, .z = 4.0f},
      },
  };

  scene_t scene = {
      .spheres = scene_spheres,
      .lights = scene_lights,
      .spheres_count = ARR_LEN(scene_spheres),
      .lights_count = ARR_LEN(scene_lights),
  };

  // Pixel coordinates are centred on the window: [-half, half) on both axes.
  const i32 x_end = (i32)window.half_width;
  const i32 x_begin = -x_end;
  const i32 y_end = (i32)window.half_height;
  const i32 y_begin = -y_end;

  SDL_Event event = {0};
  bool keep_running = true;

  while (keep_running) {
    // Drain pending events; only a quit request is acted upon.
    while (SDL_PollEvent(&event)) {
      if (event.type == SDL_QUIT) {
        keep_running = false;
      }
    }

    clear_window(&window, background);

    for (i32 py = y_begin; py < y_end; ++py) {
      for (i32 px = x_begin; px < x_end; ++px) {
        vec3f_t ray_direction =
            window_to_viewport(&window, px, py, viewport_size);
        colour_t pixel =
            trace_ray(eye, ray_direction, 1, INFINITY, &scene, background);
        set_pixel(&window, px, py, pixel);
      }
    }

    swap_buffers(&window);
  }

  close_window(&window);

  return EXIT_SUCCESS;
}

/**
 * Traces one ray through the scene and returns the colour it sees.
 *
 * origin, direction: the ray, as origin + t * direction.
 * t_min, t_max:      range of ray parameters accepted as hits.
 * scene:             spheres and lights to render.
 * default_colour:    returned unchanged when the ray hits no sphere.
 *
 * For a hit, the sphere's r/g/b channels are multiplied by the computed light
 * value and clamped to [0, UINT8_MAX]; the alpha channel is copied as is.
 */
colour_t trace_ray(vec3f_t origin, vec3f_t direction, f32 t_min, f32 t_max,
                   const scene_t *scene, colour_t default_colour) {
  const intersection_t hit =
      find_closest_intersection(origin, direction, t_min, t_max, scene);

  if (hit.closest_sphere == NULL) {
    return default_colour;
  }

  const f32 intensity =
      calculate_lighting_for_intersection(origin, direction, hit, scene);

  const colour_t base = hit.closest_sphere->colour;
  const f32 channel_max = (f32)UINT8_MAX;

  // Scale each colour channel by the light value, then clamp before narrowing.
  const f32 red = clamp((f32)(base.rgba.r) * intensity, 0.0f, channel_max);
  const f32 green = clamp((f32)(base.rgba.g) * intensity, 0.0f, channel_max);
  const f32 blue = clamp((f32)(base.rgba.b) * intensity, 0.0f, channel_max);

  colour_t shaded = {
      .rgba.r = (u8)red,
      .rgba.g = (u8)green,
      .rgba.b = (u8)blue,
      .rgba.a = base.rgba.a,
  };

  return shaded;
}

/**
 * Finds the nearest sphere hit by the ray origin + t * direction.
 *
 * Only solutions with t_min <= t <= t_max are considered. Returns the
 * smallest such t together with the sphere it belongs to; when nothing is
 * hit, closest_t is INFINITY and closest_sphere is NULL.
 */
intersection_t find_closest_intersection(vec3f_t origin, vec3f_t direction,
                                         f32 t_min, f32 t_max,
                                         const scene_t *scene) {
  intersection_t best = {INFINITY, NULL};

  for (u32 idx = 0; idx < scene->spheres_count; ++idx) {
    sphere_t *candidate = &(scene->spheres[idx]);
    const solutions_t roots =
        ray_intersects_sphere(origin, direction, *candidate);

    // Check both roots in the same order (t1, then t2).
    const f32 ts[2] = {roots.t1, roots.t2};

    for (u32 k = 0; k < 2; ++k) {
      const bool in_range = ts[k] >= t_min && ts[k] <= t_max;

      if (in_range && ts[k] < best.closest_t) {
        best.closest_t = ts[k];
        best.closest_sphere = candidate;
      }
    }
  }

  return best;
}

/**
 * Computes the light value at the point where a ray hit a sphere.
 *
 * origin, direction: the ray that produced `intersection`.
 * intersection:      must refer to a sphere (closest_sphere is dereferenced).
 * scene:             provides the lights and the potential shadow casters.
 *
 * Derives the hit point, the unit surface normal at that point and the view
 * vector (the negated ray direction), then delegates to compute_lighting().
 */
f32 calculate_lighting_for_intersection(vec3f_t origin, vec3f_t direction,
                                        intersection_t intersection,
                                        const scene_t *scene) {
  sphere_t *sphere = intersection.closest_sphere;

  // Points from the surface back along the incoming ray.
  vec3f_t view_vector = vec_mul_num(vec3f_t, direction, -1.0f);

  // Hit point = origin + closest_t * direction.
  vec3f_t ray_offset = vec_mul_num(vec3f_t, direction, intersection.closest_t);
  vec3f_t hit_point = vec_add(vec3f_t, origin, ray_offset);

  // Normal = (hit point - centre), scaled to unit length.
  vec3f_t normal = vec_sub(vec3f_t, hit_point, sphere->centre);
  f32 normal_length = vec_magnitude(vec3f_t, normal);
  vec3f_t unit_normal = vec_div_num(vec3f_t, normal, normal_length);

  return compute_lighting(hit_point, unit_normal, view_vector,
                          sphere->specular, scene);
}

/**
 * Intersects the ray origin + t * direction with a sphere.
 *
 * Solves a*t^2 + b*t + c = 0 with
 *   a = direction . direction
 *   b = 2 * (CO . direction)
 *   c = CO . CO - radius^2
 * where CO = origin - sphere.centre.
 *
 * Returns both roots (t1 uses +sqrt, t2 uses -sqrt). When the discriminant is
 * negative the ray misses and both roots are INFINITY.
 */
solutions_t ray_intersects_sphere(vec3f_t origin, vec3f_t direction,
                                  sphere_t sphere) {
  vec3f_t centre_to_origin = vec_sub(vec3f_t, origin, sphere.centre);

  f32 a = vec_dot(vec3f_t, direction, direction);
  f32 b = 2.0f * vec_dot(vec3f_t, centre_to_origin, direction);
  f32 c = vec_dot(vec3f_t, centre_to_origin, centre_to_origin) -
          sphere.radius * sphere.radius;

  f32 discriminant = b * b - 4 * a * c;
  if (discriminant < 0) {
    return (solutions_t){.t1 = INFINITY, .t2 = INFINITY};
  }

  f32 root = sqrtf(discriminant);
  f32 denominator = 2 * a;

  return (solutions_t){
      .t1 = (-b + root) / denominator,
      .t2 = (-b - root) / denominator,
  };
}

/**
 * Sums the light arriving at a surface point from every light in the scene.
 *
 * position:          the surface point being shaded.
 * surface_normal:    normal at that point.
 * view_vector:       vector from the point towards the viewer.
 * specular_exponent: shininess; the specular term is skipped when it is -1.
 * scene:             lights to accumulate and spheres that may cast shadows.
 *
 * Ambient lights add their intensity directly. For point and directional
 * lights a shadow ray is cast from `position` towards the light; if any
 * sphere is hit the light is skipped, otherwise its diffuse and (optionally)
 * specular terms are added.
 */
f32 compute_lighting(vec3f_t position, vec3f_t surface_normal,
                     vec3f_t view_vector, f32 specular_exponent,
                     const scene_t *scene) {
  f32 total = 0.0f;

  for (u32 idx = 0; idx < scene->lights_count; ++idx) {
    const light_t *light = &(scene->lights[idx]);

    if (light->type == LIGHT_TYPE_AMBIENT) {
      total += light->intensity;
      continue;
    }

    vec3f_t to_light = {0};
    f32 shadow_t_max = EPSILON;

    if (light->type == LIGHT_TYPE_POINT) {
      // t = 1 along (light position - point) is the light itself.
      to_light = vec_sub(vec3f_t, light->position, position);
      shadow_t_max = 1;
    } else if (light->type == LIGHT_TYPE_DIRECTIONAL) {
      to_light = light->direction;
      shadow_t_max = INFINITY;
    }

    // Anything between the point and the light blocks this light entirely.
    const intersection_t blocker = find_closest_intersection(
        position, to_light, EPSILON, shadow_t_max, scene);
    if (blocker.closest_sphere != NULL) {
      continue;
    }

    total += light_diffuse(light->intensity, to_light, surface_normal);

    if (specular_exponent != -1.0f) {
      total += light_specular(light->intensity, to_light, surface_normal,
                              view_vector, specular_exponent);
    }
  }

  return total;
}

/**
 * Diffuse contribution of one light: the light intensity scaled by the
 * cosine of the angle between the light direction and the surface normal
 * (0 when the dot product of the two is negative).
 */
f32 light_diffuse(f32 light_intensity, vec3f_t light_direction,
                  vec3f_t surface_normal) {
  f32 cos_incidence =
      cos_angle_between_vectors(light_direction, surface_normal);

  return light_intensity * cos_incidence;
}

/**
 * Specular contribution of one light.
 *
 * Builds R = 2 * N * (N . L) - L from the surface normal N and the light
 * direction L, then returns
 *   light_intensity * cos(R, view_vector) ^ specular_exponent
 * where the cosine is 0 when the dot product of R and view_vector is negative.
 */
f32 light_specular(f32 light_intensity, vec3f_t light_direction,
                   vec3f_t surface_normal, vec3f_t view_vector,
                   f32 specular_exponent) {
  f32 n_dot_l = vec_dot(vec3f_t, light_direction, surface_normal);

  vec3f_t doubled_normal = vec_mul_num(vec3f_t, surface_normal, 2.0f);
  vec3f_t scaled_normal = vec_mul_num(vec3f_t, doubled_normal, n_dot_l);
  vec3f_t reflected = vec_sub(vec3f_t, scaled_normal, light_direction);

  f32 cos_alpha = cos_angle_between_vectors(reflected, view_vector);

  return light_intensity * powf(cos_alpha, specular_exponent);
}

/**
 * Cosine of the angle between v1 and v2, limited to non-negative values.
 *
 * Returns 0 when the dot product is negative, and also when the product of
 * the two magnitudes is exactly zero (avoiding a division by zero).
 */
f32 cos_angle_between_vectors(vec3f_t v1, vec3f_t v2) {
  const f32 dot = vec_dot(vec3f_t, v1, v2);

  if (dot < 0.0f) {
    return 0.0f;
  }

  const f32 length_product =
      vec_magnitude(vec3f_t, v1) * vec_magnitude(vec3f_t, v2);

  return (length_product == 0.0f) ? 0.0f : dot / length_product;
}

/**
 * Limits value to [min, max]: returns min when value is below min, max when
 * value is above max, and value itself otherwise.
 */
f32 clamp(f32 value, f32 min, f32 max) {
  return (value < min) ? min : ((value > max) ? max : value);
}