#include "aliases.h"
#include "img.h"
#include "obj.h"
#include "render.h"
#include "shader.h"
#include "utils.h"
#include "vec.h"

/**
 * Shader state shared between the vertex stage and the fragment stages.
 *
 * This file keeps one instance per projection (`perspective` and
 * `orthographic`). general_shader_vertex() stores each transformed vertex of
 * the current triangle in `vertices`, and the fragment stages read them back.
 */
typedef struct shader {
  V3f light_dir;   // g_light_dir transformed by mv_proj, then normalised
  M4x4f mv_proj;   // projection matrix multiplied by the model-view matrix
  VertexData vertices[TRIANGLE_VERTICES]; // per-vertex outputs for the triangle
} Shader;

// Handles returned by register_shader() in load_shaders(); they keep their
// zero value until load_shaders() has run.
ShaderID perspective_diffuse = {0};
ShaderID perspective_albedo = {0};
ShaderID orthographic_diffuse = {0};
ShaderID orthographic_albedo = {0};

// State objects handed to register_shader(). File-scope objects have static
// storage duration and are therefore zero-initialised without an explicit
// initialiser; load_shaders() fills in mv_proj and light_dir.
internal Shader perspective;
internal Shader orthographic;

// Scene configuration read by load_shaders(), get_projection_matrix() and the
// diffuse fragment stage.
internal V3f g_eye = {0.2f, 0.1f, 0.75f};          // lookat() eye position
internal V3f g_target;                             // lookat() target (all zero)
internal V3f g_up = {0.0f, 1.0f, 0.0f};            // lookat() up vector
internal V3f g_light_dir = {1.0f, 1.0f, 1.0f};     // light dir before mv_proj
internal V3f g_ambient_light = {0.1f, 0.1f, 0.1f}; // ambient added per channel

// Forward declarations: load_shaders() refers to these before their
// definitions appear later in the file.
internal VertexData
general_shader_vertex(void *shader, const VertexData *vert, u8 index,
                      const Model *model);
internal FragmentResult
diffuse_shader_fragment(void *shader, const V3f *barycentric,
                        const Colour *colour, const Model *model);
internal FragmentResult
albedo_shader_fragment(void *shader, const V3f *barycentric,
                       const Colour *colour, const Model *model);
internal M4x4f
get_projection_matrix(ProjectionType projection_type);

void load_shaders(void) {
  M4x4f model_view = lookat(g_eye, g_target, g_up);
  M4x4f orthographic_projection =
      get_projection_matrix(PROJECTION_TYPE_ORTHOGRAPHIC);
  M4x4f perspective_projection =
      get_projection_matrix(PROJECTION_TYPE_PERSPECTIVE);

  perspective.mv_proj = mat4x4_mul(perspective_projection, model_view);
  orthographic.mv_proj = mat4x4_mul(orthographic_projection, model_view);

  perspective.light_dir = mat3x3_mul_vec3(perspective.mv_proj, g_light_dir);
  normalise_v3(perspective.light_dir);
  orthographic.light_dir = mat3x3_mul_vec3(orthographic.mv_proj, g_light_dir);
  normalise_v3(orthographic.light_dir);

  perspective_diffuse = register_shader(&perspective, general_shader_vertex,
                                        diffuse_shader_fragment);
  perspective_albedo = register_shader(&perspective, general_shader_vertex,
                                       albedo_shader_fragment);
  orthographic_diffuse = register_shader(&orthographic, general_shader_vertex,
                                         diffuse_shader_fragment);
  orthographic_albedo = register_shader(&orthographic, general_shader_vertex,
                                        albedo_shader_fragment);
}

/**
 * Vertex stage shared by every shader registered in this file.
 *
 * The position is transformed by mv_proj and brought back to 3D with
 * project_vec4(). The normal is transformed by the inverse of mv_proj's
 * transpose, projected the same way, and then normalised. The UV is copied
 * unchanged. The result is cached in vertices[index] for the fragment stages
 * and also returned.
 *
 * `model` is not used by this stage.
 */
internal VertexData general_shader_vertex(void *shader, const VertexData *vert,
                                          u8 index, const Model *model) {
  Shader *state = shader;
  VertexData *out = &state->vertices[index];

  // Normals use the inverse transpose so they stay perpendicular to the
  // transformed surface.
  M4x4f transposed = mat4x4_transpose(state->mv_proj);
  M4x4f normal_matrix = mat4x4_inv(transposed);
  V4f normal_h = V3_to_V4(vert->normal);
  V4f normal_out = mat4x4_mul_vec4(normal_matrix, normal_h);

  V4f position_h = V3_to_V4(vert->position);
  V4f position_out = mat4x4_mul_vec4(state->mv_proj, position_h);

  out->position = project_vec4(position_out);
  out->uv = vert->uv;
  out->normal = project_vec4(normal_out);
  normalise_v3(out->normal);

  return *out;
}

internal FragmentResult diffuse_shader_fragment(void *shader,
                                                const V3f *barycentric,
                                                const Colour *colour,
                                                const Model *model) {
  Shader *shdr = (Shader *)shader;

  // clang-format off
  M3x3f pos_mat = {.rows = {shdr->vertices[0].position, shdr->vertices[1].position, shdr->vertices[2].position}};
  pos_mat = mat3x3_transpose(pos_mat);
  M3x3f normal_mat = {.rows = {shdr->vertices[0].normal, shdr->vertices[1].normal, shdr->vertices[2].normal}};
  normal_mat = mat3x3_transpose(normal_mat);
  M3x2f uvs = {shdr->vertices[0].uv, shdr->vertices[1].uv, shdr->vertices[2].uv};
  M2x3f uv_mat = mat3x2_transpose(uvs);
  // clang-format on

  V3f position = mat3x3_mul_vec3(pos_mat, (*barycentric));
  V3f normal = mat3x3_mul_vec3(normal_mat, (*barycentric));
  V2f uv = mat2x3_mul_vec3(uv_mat, (*barycentric));

#pragma region darboux_frame_tangent_normals
  /**
   * Based on the following section of the tinyrenderer tutorial
   * https://github.com/ssloy/tinyrenderer/wiki/Lesson-6bis:-tangent-space-normal-mapping#starting-point-phong-shading
   */

  if (model->normal) {
    u64 nm_x = uv.u * model->normal->width;
    u64 nm_y = uv.v * model->normal->height;

    Colour pixel = get_pixel(Colour, model->normal, nm_x, nm_y);
    V3f tangent = (V3f){
        .x = pixel.r / 255.f * 2.f - 1.f,
        .y = pixel.g / 255.f * 2.f - 1.f,
        .z = pixel.b / 255.f * 2.f - 1.f,
    };

    V3f p0p1 = sub_v3(shdr->vertices[1].position, shdr->vertices[0].position);
    V3f p0p2 = sub_v3(shdr->vertices[2].position, shdr->vertices[0].position);

    M3x3f A = {.rows = {p0p1, p0p2, normal}};
    M3x3f A_inv = mat3x3_inv(A);

    V2f uv0 = shdr->vertices[0].uv;
    V2f uv1 = shdr->vertices[1].uv;
    V2f uv2 = shdr->vertices[2].uv;

    V3f u_vec = {uv1.u - uv0.u, uv2.u - uv0.u, 0};
    V3f v_vec = {uv1.v - uv0.v, uv2.v - uv0.v, 0};

    V3f i = mat3x3_mul_vec3(A_inv, u_vec);
    normalise_v3(i);
    V3f j = mat3x3_mul_vec3(A_inv, v_vec);
    normalise_v3(j);

    M3x3f B = {.rows = {i, j, normal}};
    B = mat3x3_transpose(B);

    normal = mat3x3_mul_vec3(B, tangent);
    normalise_v3(normal);
  }
#pragma endregion darboux_frame_tangent_normals

  Colour output;
  if (model->texture) {
    u64 tx_x = uv.u * model->texture->width;
    u64 tx_y = uv.v * model->texture->height;
    output = get_pixel(Colour, model->texture, tx_x, tx_y);
  } else {
    output = *colour;
  }

  f32 intensity = max(0.001f, dot_v3(normal, shdr->light_dir));
  f32 r = clamp(intensity + g_ambient_light.r, 0.0f, 1.0f);
  f32 g = clamp(intensity + g_ambient_light.g, 0.0f, 1.0f);
  f32 b = clamp(intensity + g_ambient_light.b, 0.0f, 1.0f);

  output.r *= r;
  output.g *= g;
  output.b *= b;

  return (FragmentResult){.colour = output};
}

internal FragmentResult albedo_shader_fragment(void *shader,
                                               const V3f *barycentric,
                                               const Colour *colour,
                                               const Model *model) {
  Shader *shdr = (Shader *)shader;

  // clang-format off
  M3x2f uvs = {shdr->vertices[0].uv, shdr->vertices[1].uv, shdr->vertices[2].uv};
  M2x3f uv_mat = mat3x2_transpose(uvs);
  // clang-format on

  V2f uv = mat2x3_mul_vec3(uv_mat, (*barycentric));

  Colour output;
  if (model->texture) {
    u64 tx_x = uv.u * model->texture->width;
    u64 tx_y = uv.v * model->texture->height;
    output = get_pixel(Colour, model->texture, tx_x, tx_y);
  } else {
    output = *colour;
  }

  return (FragmentResult){.colour = output};
}

/**
 * Return the projection matrix for `projection_type`.
 *
 * Perspective: projection() with coefficient -0.5 / |eye - target|. This is
 * tinyrenderer's -1/|eye - center| scaled by 0.5, so perspective strength
 * follows the camera distance. If eye and target coincide, no distance is
 * defined, so the identity matrix is returned rather than dividing by zero.
 *
 * Any other projection type (orthographic) uses the identity matrix.
 */
internal M4x4f get_projection_matrix(ProjectionType projection_type) {
  if (projection_type != PROJECTION_TYPE_PERSPECTIVE) {
    return mat4x4_identity;
  }

  // Vector between target and eye; only its length is used. Do not
  // normalise it first, or the distance collapses to 1 and the coefficient
  // becomes a constant.
  V3f cam = V3(V3f, f32, g_target.x, g_target.y, g_target.z, g_eye.x, g_eye.y,
               g_eye.z);
  f32 distance = magnitude_v3(cam);
  if (!(distance > 0.0f)) {
    return mat4x4_identity;
  }

  return projection(-0.5f / distance);
}