#include "obj.h"
#include "aliases.h"
#include "img.h"
#include "mem_arena.h"
#include "pam.h"
#include "typed_list.h"
#include "utils.h"
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define TRIANGLE_VERTICES 3
#define CAMERA_DISTANCE 5.0f

// Small vector helpers shared by the rasteriser. They are macros so that one
// definition serves every vector type (V2i, V2u, V3f, V4f, ...). Every
// argument is parenthesised in the expansion, so expressions such as `*ptr`
// or `a + b` can be passed directly. Arguments may be evaluated more than
// once: do not pass expressions with side effects.

// Compound literal of type T holding the 2D vector from (X0, Y0) to (X1, Y1),
// i.e. P1 - P0. Each coordinate is converted to ELEM_T before subtracting
// (e.g. so that unsigned pixel positions give signed differences).
#define V2(T, ELEM_T, X0, Y0, X1, Y1)                                          \
  ((T){(ELEM_T)(X1) - (ELEM_T)(X0), (ELEM_T)(Y1) - (ELEM_T)(Y0)})

// 3D counterpart of V2: the vector P1 - P0 as a compound literal of type T.
#define V3(T, ELEM_T, X0, Y0, Z0, X1, Y1, Z1)                                  \
  ((T){(ELEM_T)(X1) - (ELEM_T)(X0), (ELEM_T)(Y1) - (ELEM_T)(Y0),               \
       (ELEM_T)(Z1) - (ELEM_T)(Z0)})

// Dot products, accumulated in f32 whatever the component type.
#define dot_v2(A, B) ((f32)(A).x * (f32)(B).x + (f32)(A).y * (f32)(B).y)
#define dot_v3(A, B)                                                           \
  ((f32)(A).x * (f32)(B).x + (f32)(A).y * (f32)(B).y + (f32)(A).z * (f32)(B).z)

// Scales the lvalue V to unit length in place. A zero-length V yields NaN
// components (0 / 0). There is deliberately no semicolon after while (0), so
// the macro behaves as exactly one statement, even in an unbraced if/else.
// The temporary has an unusual name so it cannot capture a caller's variable.
#define normalise_v3(V)                                                        \
  do {                                                                         \
    f32 normalise_v3_len_ = sqrtf(dot_v3((V), (V)));                           \
    (V).x /= normalise_v3_len_;                                                \
    (V).y /= normalise_v3_len_;                                                \
    (V).z /= normalise_v3_len_;                                                \
  } while (0)

// Cross product A x B as a V3f.
#define cross_product(A, B)                                                    \
  ((V3f){                                                                      \
      .x = (A).y * (B).z - (A).z * (B).y,                                      \
      .y = (A).z * (B).x - (A).x * (B).z,                                      \
      .z = (A).x * (B).y - (A).y * (B).x,                                      \
  })

// MAT * V, where MAT is a 4x4 matrix stored as four row vectors row0..row3.
#define mat4x4_mul_vec4(MAT, V)                                                \
  ((V4f){                                                                      \
      .x = (MAT).row0.x * (V).x + (MAT).row0.y * (V).y +                       \
           (MAT).row0.z * (V).z + (MAT).row0.w * (V).w,                        \
      .y = (MAT).row1.x * (V).x + (MAT).row1.y * (V).y +                       \
           (MAT).row1.z * (V).z + (MAT).row1.w * (V).w,                        \
      .z = (MAT).row2.x * (V).x + (MAT).row2.y * (V).y +                       \
           (MAT).row2.z * (V).z + (MAT).row2.w * (V).w,                        \
      .w = (MAT).row3.x * (V).x + (MAT).row3.y * (V).y +                       \
           (MAT).row3.z * (V).z + (MAT).row3.w * (V).w,                        \
  })

// Perspective divide: x, y and z divided by w, returned as a V3f.
#define project_vec4(V)                                                        \
  ((V3f){.x = (V).x / (V).w, .y = (V).y / (V).w, .z = (V).z / (V).w})

// Pixel-space bounding box of a triangle. fill_triangle visits columns
// x0..x1 and rows y0..y1, both inclusive.
typedef struct triangle_bbox {
  u64 x0; // leftmost column
  u64 y0; // top row (smallest row index)
  u64 x1; // rightmost column
  u64 y1; // bottom row (largest row index)
} TriangleBBox;

// Rasterises one model triangle into render->img as a wireframe, flat fill or
// shaded fill, optionally applying the g_cam_matrix perspective first.
internal void render_triangle(const Triangle *triangle, const Model *model,
                              Render *render, Colour colour, RenderType type,
                              ProjectionType projection);
// Fills a triangle given in normalised device coordinates, depth-testing
// against render->depth. Pixels get `colour` scaled by `intensity`, or a
// texel sampled from `texture` when it is non-NULL.
internal void fill_triangle(Render *render, V3f vertices[TRIANGLE_VERTICES],
                            V2f coordinates[TRIANGLE_VERTICES], Colour colour,
                            f32 intensity, Image *texture);
// NOTE(review): no definition or call of reorder_points is visible in this
// file; if none exists elsewhere, this prototype can be removed.
internal void reorder_points(V2u vertices[TRIANGLE_VERTICES],
                             V2f coordinates[TRIANGLE_VERTICES]);
// Pixel-space bounding box of a triangle given in normalised device
// coordinates, limited to the bounds of `img`.
internal TriangleBBox get_triangle_bbox(const Image *img,
                                        V3f vertices[TRIANGLE_VERTICES]);
// Barycentric weights of the point A + ap in the triangle (A, A + ab, A + ac),
// from precomputed edge dot products. See the definition for the component
// order and the degenerate-triangle result.
internal V3f get_barycentric_coords(f32 d00, f32 d01, f32 d11, f32 denom,
                                    const V2i *ab, const V2i *ac,
                                    const V2i *ap);
// Converts a normalised device coordinate pair into pixel indices inside
// `img`, flipping y so that larger NDC y maps to a smaller row.
internal void get_image_coordinates(f32 norm_x, f32 norm_y, const Image *img,
                                    u64 *x, u64 *y);
// Maps one normalised device coordinate onto the integer range [0, max].
internal u64 ndc_to_image_coordinate(f32 value, u64 max);

// Light direction for flat shading: render_triangle sets the intensity to
// dot(face normal, g_light_dir) and skips faces whose intensity is <= 0.
V3f g_light_dir = {.z = -1.0f};

// Minimal perspective camera. It is the identity except for row3, which makes
// w = 1 - z / CAMERA_DISTANCE. After the divide by w (project_vec4), points
// with larger z are scaled up, and w reaches 0 at z = CAMERA_DISTANCE. Every
// entry not listed below is zero.
// clang-format off
M4x4f g_cam_matrix = {
  .row0 = {.x = 1.0f},
  .row1 = {.y = 1.0f},
  .row2 = {.z = 1.0f},
  .row3 = {.z = -1.0f / CAMERA_DISTANCE, .w = 1.0f},
};
// clang-format on

// Loads a Wavefront OBJ mesh from `filename` into lists allocated from
// `arena`.
//
// Recognised records:
//   v  x y z            -> model.vertices
//   vt u v              -> model.texture_coordinates
//   f  v/vt/vn (x3)     -> model.triangles, with the 1-based OBJ indices
//                          stored 0-based. The normal index is read but ignored.
// Other records and blank lines are skipped. A record whose operands do not
// parse is also skipped. Faces must be triangles in the full v/vt/vn form with
// positive indices; anything else is ignored instead of being stored with
// bogus indices. That includes quads, v//vn, plain v, and zero or negative
// (relative) indices.
//
// If `texture` is non-NULL, it is loaded with load_p6_image into model.texture.
// Returns NULL_MODEL if `arena` or `filename` is NULL, the file cannot be
// opened, or one of the lists cannot be created.
Model load_obj_file(Arena *arena, const char *filename, const char *texture) {
  if (!arena || !filename) {
    return NULL_MODEL;
  }

  FILE *fp = fopen(filename, "r");
  if (!fp) {
    return NULL_MODEL;
  }

  Model model = (Model){
      .vertices = list_create(V3f, arena),
      .texture_coordinates = list_create(V2f, arena),
      .triangles = list_create(Triangle, arena),
  };
  if (!(model.vertices) || !(model.texture_coordinates) || !(model.triangles)) {
    fclose(fp);
    return NULL_MODEL;
  }

  // NOTE: fgets reads a line longer than `line` in several chunks, and each
  // chunk is then parsed as if it were a line of its own.
  char line[8192];
  while (fgets(line, sizeof(line), fp) != NULL) {
    // Read the record keyword (at most 7 chars, so it always fits
    // `identifier`) and the offset where it ends. Operands are parsed from
    // that offset, which also copes with leading whitespace. A blank line
    // matches nothing and is skipped. This keeps a stale keyword from the
    // previous line from being reused, and stops old bytes left past the
    // terminator from being parsed again.
    char identifier[8];
    int consumed = 0;
    if (sscanf(line, "%7s%n", identifier, &consumed) != 1) {
      continue;
    }
    const char *args = line + consumed;

    if (strcmp(identifier, "v") == 0) {
      f32 vx, vy, vz;
      if (sscanf(args, "%f %f %f", &vx, &vy, &vz) == 3) {
        V3f vertex = {0};
        vertex.x = vx;
        vertex.y = vy;
        vertex.z = vz;
        list_append(V3f, arena, model.vertices, vertex);
      }
    } else if (strcmp(identifier, "vt") == 0) {
      f32 u, v;
      if (sscanf(args, "%f %f", &u, &v) == 2) {
        V2f coord = {0};
        coord.u = u;
        coord.v = v;
        list_append(V2f, arena, model.texture_coordinates, coord);
      }
    } else if (strcmp(identifier, "f") == 0) {
      // Parsed as `long` so the conversion matches the argument type exactly
      // and negative input is visible (and rejected) rather than wrapping.
      long vi0, vi1, vi2;
      long ti0, ti1, ti2;
      long ni0, ni1, ni2;
      int matched = sscanf(args, "%ld/%ld/%ld %ld/%ld/%ld %ld/%ld/%ld", &vi0,
                           &ti0, &ni0, &vi1, &ti1, &ni1, &vi2, &ti2, &ni2);
      // OBJ indices start from 1; anything smaller would underflow below.
      if (matched != 9 || vi0 < 1 || vi1 < 1 || vi2 < 1 || ti0 < 1 ||
          ti1 < 1 || ti2 < 1) {
        continue;
      }

      Triangle triangle = {0};
      triangle.p0 = vi0 - 1;
      triangle.p1 = vi1 - 1;
      triangle.p2 = vi2 - 1;
      triangle.tx0 = ti0 - 1;
      triangle.tx1 = ti1 - 1;
      triangle.tx2 = ti2 - 1;
      list_append(Triangle, arena, model.triangles, triangle);
    }
  }

  fclose(fp);

  if (texture) {
    model.texture = load_p6_image(arena, texture);
  }

  return model;
}

// Allocates the colour image and the depth buffer of `render` from `arena`,
// both width x height. Every depth entry is set to -infinity. Returns false
// as soon as one of the buffers cannot be initialised.
bool init_render(Arena *arena, Render *render, u64 width, u64 height) {
  render->img = (Image){.width = width, .height = height};
  bool ok = init_buffer(arena, &(render->img));

  if (ok) {
    render->depth = (Depth){.width = width, .height = height};
    ok = init_buffer(arena, &(render->depth));
  }

  if (ok) {
    // fill_triangle keeps a pixel only when its z is greater than the stored
    // depth, so start from the smallest possible value.
    f32 farthest = -INFINITY;
    clear_buffer(&(render->depth), &farthest);
  }

  return ok;
}

// Draws every triangle of `model` into `render`, using the given render type
// and projection. With COLOUR_TYPE_RANDOM, each triangle gets a fresh opaque
// random colour from rand(); otherwise every triangle uses `colour`.
void render_model(const Model *model, Render *render, Colour colour,
                  RenderType type, ColourType colour_type,
                  ProjectionType projection) {
  // NOTE(review): assumes a failed load (NULL_MODEL) leaves these lists NULL.
  // If so, bail out instead of dereferencing them.
  if (!model || !render || !model->triangles || !model->vertices ||
      !model->texture_coordinates) {
    return;
  }

  for (u64 i = 0; i < model->triangles->count; ++i) {
    Triangle triangle = list_get(model->triangles, i);
    if (colour_type == COLOUR_TYPE_RANDOM) {
      // Modulo UINT8_MAX + 1 so each channel can take every value in
      // [0, 255]; `% UINT8_MAX` could never produce 255.
      colour = (Colour){.r = rand() % (UINT8_MAX + 1),
                        .g = rand() % (UINT8_MAX + 1),
                        .b = rand() % (UINT8_MAX + 1),
                        .a = 255};
    }
    render_triangle(&triangle, model, render, colour, type, projection);
  }
}

// Rasterises one triangle of `model` into render->img.
//
// RENDER_TYPE_WIREFRAME draws the three edges in `colour`.
// RENDER_TYPE_FILLED fills the triangle at full intensity.
// RENDER_TYPE_SHADED fills it with flat shading: intensity is the dot
// product of the face normal and g_light_dir. Faces with intensity <= 0 are
// not drawn.
// With PROJECTION_TYPE_PERSPECTIVE the vertices are transformed by
// g_cam_matrix and divided by w before rasterisation.
internal void render_triangle(const Triangle *triangle, const Model *model,
                              Render *render, Colour colour, RenderType type,
                              ProjectionType projection) {
  Image *img = &(render->img);
  V3f vertices[TRIANGLE_VERTICES] = {
      list_get(model->vertices, triangle->p0),
      list_get(model->vertices, triangle->p1),
      list_get(model->vertices, triangle->p2),
  };
  V2f coordinates[TRIANGLE_VERTICES] = {
      list_get(model->texture_coordinates, triangle->tx0),
      list_get(model->texture_coordinates, triangle->tx1),
      list_get(model->texture_coordinates, triangle->tx2),
  };

  // Flat-shading intensity comes from the vertices as loaded, before any
  // projection. The perspective divide below is not an affine map, so a
  // normal taken from the projected vertices would no longer be the face's
  // normal, and the shading would change with depth. With orthographic
  // projection the vertices are never transformed, so nothing changes there.
  f32 intensity = 1.0f;
  if (type == RENDER_TYPE_SHADED) {
    V3f ab = V3(V3f, f32, vertices[0].x, vertices[0].y, vertices[0].z,
                vertices[1].x, vertices[1].y, vertices[1].z);
    V3f ac = V3(V3f, f32, vertices[0].x, vertices[0].y, vertices[0].z,
                vertices[2].x, vertices[2].y, vertices[2].z);

    V3f normal = cross_product(ac, ab);
    normalise_v3(normal);

    intensity = dot_v3(normal, g_light_dir);
  }

  if (projection == PROJECTION_TYPE_PERSPECTIVE) {
    // Basic perspective projection: homogeneous transform, then divide by w.
    for (u64 i = 0; i < TRIANGLE_VERTICES; ++i) {
      V4f vertex = {
          .x = vertices[i].x,
          .y = vertices[i].y,
          .z = vertices[i].z,
          .w = 1.0f,
      };
      V4f transformed = mat4x4_mul_vec4(g_cam_matrix, vertex);
      vertices[i] = project_vec4(transformed);
    }
  }

  if (type == RENDER_TYPE_WIREFRAME) {
    u64 x0, y0, x1, y1;
    for (u64 i = 0; i < TRIANGLE_VERTICES; ++i) {
      V3f v0 = vertices[i];
      V3f v1 = vertices[(i + 1) % TRIANGLE_VERTICES];

      get_image_coordinates(v0.x, v0.y, img, &x0, &y0);
      get_image_coordinates(v1.x, v1.y, img, &x1, &y1);

      draw_line(img, x0, y0, x1, y1, colour);
    }
  } else if (type == RENDER_TYPE_FILLED || type == RENDER_TYPE_SHADED) {
    // A NaN intensity (degenerate triangle, zero-length normal) also fails
    // this test, so such triangles are skipped.
    if (intensity > 0.0f) {
      fill_triangle(render, vertices, coordinates, colour, intensity,
                    model->texture);
    }
  }
}

internal void fill_triangle(Render *render, V3f vertices[TRIANGLE_VERTICES],
                            V2f coordinates[TRIANGLE_VERTICES], Colour colour,
                            f32 intensity, Image *texture) {
  Image *img = &(render->img);
  Depth *depth = &(render->depth);
  TriangleBBox bbox = get_triangle_bbox(img, vertices);

  V2u v0, v1, v2;
  get_image_coordinates(vertices[0].x, vertices[0].y, img, &(v0.x), &(v0.y));
  get_image_coordinates(vertices[1].x, vertices[1].y, img, &(v1.x), &(v1.y));
  get_image_coordinates(vertices[2].x, vertices[2].y, img, &(v2.x), &(v2.y));

  V2i ab = V2(V2i, i64, v0.x, v0.y, v1.x, v1.y);
  V2i ac = V2(V2i, i64, v0.x, v0.y, v2.x, v2.y);
  f32 d00 = dot_v2(ab, ab);
  f32 d01 = dot_v2(ab, ac);
  f32 d11 = dot_v2(ac, ac);
  f32 denom = d00 * d11 - d01 * d01;
  V2i ap;
  V3f coords;
  f32 z;
  f32 zbuf;
  f32 tx_u, tx_v;
  u64 tx_x, tx_y;

  if (!texture) {
    colour.r *= intensity;
    colour.g *= intensity;
    colour.b *= intensity;
  }

  for (u64 y = bbox.y0; y <= bbox.y1; ++y) {
    for (u64 x = bbox.x0; x <= bbox.x1; ++x) {
      ap = V2(V2i, i64, v0.x, v0.y, x, y);
      coords = get_barycentric_coords(d00, d01, d11, denom, &ab, &ac, &ap);
      if (coords.x < 0.0f || coords.y < 0.0f || coords.x + coords.y > 1.0f) {
        continue;
      }

      z = 0.0f;
      z += vertices[0].z * coords.x + vertices[1].z * coords.y +
           vertices[2].z * coords.z;
      zbuf = get_pixel(f32, &(render->depth), x, y);

      if (z > zbuf) {
        if (texture) {
          tx_u = coordinates[0].u * coords.x + coordinates[1].u * coords.y +
                 coordinates[2].u * coords.z;
          tx_v = coordinates[0].v * coords.x + coordinates[1].v * coords.y +
                 coordinates[2].v * coords.z;
          tx_x = tx_u * texture->width;
          tx_y = (1.0f - tx_v) * texture->height;

          colour = get_pixel(Colour, texture, tx_x, tx_y);
          colour.r *= intensity;
          colour.g *= intensity;
          colour.b *= intensity;
        }

        set_pixel(depth, x, y, &z);
        set_pixel(img, x, y, &colour);
      }
    }
  }
}

// Pixel-space bounding box of a triangle whose vertices are in normalised
// device coordinates. Both corners go through get_image_coordinates, so the
// box always lies inside `img`.
internal TriangleBBox get_triangle_bbox(const Image *img,
                                        V3f vertices[TRIANGLE_VERTICES]) {
  const V3f *a = &vertices[0];
  const V3f *b = &vertices[1];
  const V3f *c = &vertices[2];

  f32 ndc_left = min(a->x, min(b->x, c->x));
  f32 ndc_right = max(a->x, max(b->x, c->x));
  // get_image_coordinates flips y (larger NDC y -> smaller row), so the top
  // row comes from the largest y and the bottom row from the smallest.
  f32 ndc_top = max(a->y, max(b->y, c->y));
  f32 ndc_bottom = min(a->y, min(b->y, c->y));

  TriangleBBox box = {0};
  get_image_coordinates(ndc_left, ndc_top, img, &box.x0, &box.y0);
  get_image_coordinates(ndc_right, ndc_bottom, img, &box.x1, &box.y1);
  return box;
}

// Barycentric coordinates of the point P = A + ap in the triangle
// (A, B = A + ab, C = A + ac). The caller precomputes the edge dot products:
//   d00 = ab.ab, d01 = ab.ac, d11 = ac.ac, denom = d00 * d11 - d01 * d01
// (Ericson, "Real-Time Collision Detection", 3.4).
//
// The weights are returned in vertex order, .x for A, .y for B and .z for C,
// so that P = x*A + y*B + z*C. Callers interpolate per-vertex attributes by
// pairing component x/y/z with vertex 0/1/2, so this order must match.
// A degenerate triangle (denom == 0) yields -INFINITY in every component, so
// any inside test that requires non-negative weights rejects every point.
internal V3f get_barycentric_coords(f32 d00, f32 d01, f32 d11, f32 denom,
                                    const V2i *ab, const V2i *ac,
                                    const V2i *ap) {
  if (denom == 0.0f) {
    return (V3f){-INFINITY, -INFINITY, -INFINITY};
  }

  f32 d20 = dot_v2((*ap), (*ab));
  f32 d21 = dot_v2((*ap), (*ac));

  f32 v = (d11 * d20 - d01 * d21) / denom; // weight of B
  f32 w = (d00 * d21 - d01 * d20) / denom; // weight of C
  f32 u = 1.0f - v - w;                    // weight of A

  return (V3f){u, v, w};
}

// Converts a normalised device coordinate pair into pixel indices within
// `img`. y is negated first, so a larger NDC y gives a smaller row index.
// Results are limited to the last column/row of the image.
internal void get_image_coordinates(f32 norm_x, f32 norm_y, const Image *img,
                                    u64 *x, u64 *y) {
  const u64 col = ndc_to_image_coordinate(norm_x, img->width);
  const u64 row = ndc_to_image_coordinate(0.0f - norm_y, img->height);

  // ndc_to_image_coordinate can return the dimension itself (for +1.0), which
  // is one past the last pixel; pull it back onto the final column/row.
  *x = (col < img->width) ? col : img->width - 1;
  *y = (row < img->height) ? row : img->height - 1;
}

// Maps one normalised device coordinate (nominally in [-1, 1]) onto the
// integer range [0, max]: -1 -> 0 and +1 -> max. Values outside the range are
// clamped, and NaN maps to 0.
//
// The clamp runs on the float, before any conversion. Converting a negative
// or too-large float to an unsigned integer is undefined behaviour. A clamp
// applied after the cast could therefore send off-screen geometry at negative
// NDC to the wrong edge of the image instead of to 0.
internal u64 ndc_to_image_coordinate(f32 value, u64 max) {
  f32 result = (value + 1.0f) * max * 0.5f;
  if (!(result > 0.0f)) { // negative, zero or NaN
    return 0;
  }
  if (result >= (f32)max) {
    return max;
  }
  u64 coordinate = (u64)result;
  // (f32)max may have been rounded up for very large dimensions; stay <= max.
  return coordinate < max ? coordinate : max;
}