// GLAD
#include "glad/glad.h"

// STB Image
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"

// GLM
#define GLM_ENABLE_EXPERIMENTAL
#include "glm/ext/matrix_transform.hpp"
#include "glm/ext/vector_float3.hpp"
#include "glm/gtc/type_ptr.hpp"
#include "glm/trigonometric.hpp"
#include "glm/gtx/rotate_vector.hpp"
#include "glm/gtx/string_cast.hpp"

// Assimp
#include "assimp/Importer.hpp"
#include "assimp/postprocess.h"
#include "assimp/scene.h"
#include "assimp/mesh.h"
#include "assimp/material.h"
#include "assimp/types.h"

// SDL
#include <SDL2/SDL.h>
#include <SDL2/SDL_timer.h>
#include <SDL2/SDL_stdinc.h>
#include <SDL2/SDL_error.h>
#include <SDL2/SDL_video.h>
#include <SDL2/SDL_events.h>
#include <SDL2/SDL_mouse.h>
#include <SDL2/SDL_keycode.h>

// STDLIB
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <set>
#include <string>
#include <utility>
#include <vector>

#define WINDOW_WIDTH 1280
#define WINDOW_HEIGHT 720
#define WINDOW_HALF_WIDTH 640
#define WINDOW_HALF_HEIGHT 360

#define min(a, b) (a < b ? a : b)
#define max(a, b) (a > b ? a : b)
#define clamp(v, a, b) (min(max(v, a), b))

// Process exit statuses returned from main() (or passed to exit()) for each
// fatal failure; values are spelled out so they stay stable if entries move.
enum exit_codes : int {
  EXIT_CODE_SUCCESS                 = 0,
  EXIT_CODE_SDL_INIT_FAILED         = 1,
  EXIT_CODE_WINDOW_CREATION_FAILED  = 2,
  EXIT_CODE_OPENGL_CONTEXT_FAILED   = 3,
  EXIT_CODE_GLAD_LOADER_FAILED      = 4,
  EXIT_CODE_INCOMPLETE_FRAME_BUFFER = 5,
};

class Shader {
  public:
    Shader(const std::string &vert_file, const std::string &frag_file);
    ~Shader();
    void activate();
    void set_int(const char *name, int value);
    void set_float(const char *name, float value);
    void set_vec3(const char *name, glm::vec3 vector);
    void set_vec4(const char *name, glm::vec4 vector);
    void set_mat3(const char *name, glm::mat3 matrix);
    void set_mat4(const char *name, glm::mat4 matrix);
    GLuint program;
  private:
    void link_program(GLuint vert, GLuint frag);
    GLuint load_and_compile_shader(const std::string &filepath, GLenum shader_type);
    std::string load_shader_from_file(const std::string &filepath);
    static const char *get_shader_type_string(GLenum shader_type);
};

// Role a texture plays in a material; selects the sampler-uniform name prefix
// ("diffuse" or "specular"). Stored as one byte.
enum TextureType : unsigned char {
  TEXTURE_TYPE_DIFFUSE  = 0,
  TEXTURE_TYPE_SPECULAR = 1,
};

// A 2D GL texture decoded from an image file with stb_image.
// NOTE(review): the destructor does not delete the GL texture, and copies of a
// Texture2D share the same texture name.
class Texture2D {
  public:
    // Decodes `filename` and uploads it while bound to `texture_unit` (a
    // GL_TEXTUREi enum). The pointer is stored as-is, so the string it points
    // to must outlive this object.
    Texture2D(const char *filename, GLint texture_unit, TextureType type);
    ~Texture2D();

    // Makes `texture_unit` active and binds this texture to it.
    void activate() const;

    // Sampler uniform name for this texture, e.g. "material.diffuse1" for index 1.
    std::string name(unsigned int index) const;

    // Decoded image size and channel count; default to 0.
    int width = 0;
    int height = 0;
    int channels = 0;
    const char *filename = nullptr;
    GLint texture_unit = GL_TEXTURE0;
    TextureType type = TEXTURE_TYPE_DIFFUSE;
  private:
    // Stays 0 if the image could not be loaded; binding 0 is harmless, unlike
    // binding an indeterminate name.
    GLuint texture = 0;
    GLint format = 0;
};

// Per-vertex data; Mesh::setup_mesh() uploads arrays of these verbatim and maps
// attribute 0 = position, 1 = normal, 2 = tex_coord with a stride of
// sizeof(Vertex). Field order also matters for the brace-initialized
// Vertex{...} literals in main().
struct Vertex {
  glm::vec3 position;
  glm::vec3 normal;
  glm::vec2 tex_coord;
};

// An indexed triangle mesh: CPU-side vertex/index/texture data plus the
// VAO/VBO/EBO built from it.
// NOTE(review): there is no destructor, so the GL objects are never deleted;
// copies share the same handles.
class Mesh {
  public:
    std::vector<Vertex> vertices;
    std::vector<GLuint> indices;
    std::vector<Texture2D> textures;

    // Stores the given data and uploads it via setup_mesh(); issues GL calls,
    // so a context must be current.
    Mesh(std::vector<Vertex> vertices, std::vector<GLuint> indices, std::vector<Texture2D> textures);

    // Points the material sampler uniforms at this mesh's textures, binds
    // them, and issues one indexed draw.
    void draw(Shader &shader);
  private:
    GLuint vao = 0;
    GLuint vbo = 0;
    GLuint ebo = 0;
    void setup_mesh();
};

// A model file imported with Assimp: every mesh in its node hierarchy is
// loaded at construction time and drawn together.
class Model {
  public:
    // Imports the model at `path`; on failure an error is printed and the
    // model stays empty.
    Model(const char *path) { load_model(path); }

    // Draws each collected mesh, in load order, with `shader`.
    void draw(Shader &shader);

  private:
    // Lookup list consulted by load_material_textures() when resolving paths.
    std::vector<Texture2D> loaded_textures;
    // Meshes in the order they were collected from the node tree.
    std::vector<Mesh> meshes;
    // Folder of the model file; material texture paths are joined onto it.
    std::string directory;
    // Next GL texture unit handed to a newly created texture.
    GLint texture_unit = GL_TEXTURE0;

    void load_model(std::string path);
    void process_node(aiNode *node, const aiScene *scene);
    Mesh process_mesh(aiMesh *mesh, const aiScene *scene);
    std::vector<Texture2D> load_material_textures(aiMaterial *mat, aiTextureType type);
};

// GL object names backing one offscreen render target: a color texture
// (sampled by the post-processing pass) plus a combined depth/stencil
// renderbuffer.
struct FrameBuffer {
  GLuint fbo;
  GLuint color;
  GLuint depth_stencil;
};

// Creates an offscreen render target of the given size; exits the process with
// EXIT_CODE_INCOMPLETE_FRAME_BUFFER if GL reports the framebuffer incomplete.
FrameBuffer create_frame_buffer(unsigned int width, unsigned int height);
// Deletes the framebuffer, color texture and depth/stencil renderbuffer.
void delete_frame_buffer(FrameBuffer &buffer);

int main() {
  if (SDL_Init(SDL_INIT_EVERYTHING) != 0) {
    return EXIT_CODE_SDL_INIT_FAILED;
  }

  SDL_GL_SetAttribute(SDL_GL_BUFFER_SIZE, 24);
  SDL_GL_SetAttribute(SDL_GL_DOUBLEBUFFER, 1);
  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MAJOR_VERSION, 3);
  SDL_GL_SetAttribute(SDL_GL_CONTEXT_MINOR_VERSION, 3);
  SDL_GL_SetAttribute(SDL_GL_CONTEXT_PROFILE_MASK, SDL_GL_CONTEXT_PROFILE_CORE);

  SDL_Window *window = SDL_CreateWindow("Window", SDL_WINDOWPOS_CENTERED, SDL_WINDOWPOS_CENTERED,
                                        WINDOW_WIDTH, WINDOW_HEIGHT, SDL_WINDOW_OPENGL | SDL_WINDOW_RESIZABLE);
  if (!window) {
    return EXIT_CODE_WINDOW_CREATION_FAILED;
  }

  SDL_GLContext context = SDL_GL_CreateContext(window);
  if (!context) {
    return EXIT_CODE_OPENGL_CONTEXT_FAILED;
  }

  if (gladLoadGLLoader(SDL_GL_GetProcAddress) == 0) {
    return EXIT_CODE_GLAD_LOADER_FAILED;
  }

  SDL_SetRelativeMouseMode(SDL_TRUE);
  SDL_WarpMouseInWindow(window, WINDOW_HALF_WIDTH, WINDOW_HALF_HEIGHT);

  glViewport(0, 0, WINDOW_WIDTH, WINDOW_HEIGHT);

  FrameBuffer offscreen_buffer = create_frame_buffer(WINDOW_WIDTH, WINDOW_HEIGHT);

  std::vector<Vertex> vertices = {
    // positions                           // normals                      // texture coords
    Vertex{glm::vec3(-0.5f, -0.5f, -0.5f), glm::vec3(-1.0f,  0.0f,  0.0f), glm::vec2(0.0f, 1.0f)},
    Vertex{glm::vec3( 0.5f,  0.5f,  0.5f), glm::vec3( 1.0f,  0.0f,  0.0f), glm::vec2(1.0f, 0.0f)},
    Vertex{glm::vec3( 0.5f,  0.5f, -0.5f), glm::vec3( 1.0f,  0.0f,  0.0f), glm::vec2(1.0f, 1.0f)},
    Vertex{glm::vec3(-0.5f,  0.5f,  0.5f), glm::vec3( 0.0f,  0.0f,  1.0f), glm::vec2(0.0f, 1.0f)},
    Vertex{glm::vec3( 0.5f, -0.5f,  0.5f), glm::vec3( 1.0f,  0.0f,  0.0f), glm::vec2(0.0f, 0.0f)},
    Vertex{glm::vec3( 0.5f, -0.5f, -0.5f), glm::vec3( 0.0f, -1.0f,  0.0f), glm::vec2(1.0f, 1.0f)},
    Vertex{glm::vec3(-0.5f, -0.5f, -0.5f), glm::vec3( 0.0f,  0.0f, -1.0f), glm::vec2(0.0f, 0.0f)},
    Vertex{glm::vec3( 0.5f, -0.5f,  0.5f), glm::vec3( 0.0f,  0.0f,  1.0f), glm::vec2(1.0f, 0.0f)},
    Vertex{glm::vec3(-0.5f, -0.5f, -0.5f), glm::vec3( 0.0f, -1.0f,  0.0f), glm::vec2(0.0f, 1.0f)},
    Vertex{glm::vec3( 0.5f,  0.5f,  0.5f), glm::vec3( 0.0f,  0.0f,  1.0f), glm::vec2(1.0f, 1.0f)},
    Vertex{glm::vec3( 0.5f, -0.5f,  0.5f), glm::vec3( 0.0f, -1.0f,  0.0f), glm::vec2(1.0f, 0.0f)},
    Vertex{glm::vec3(-0.5f, -0.5f,  0.5f), glm::vec3( 0.0f, -1.0f,  0.0f), glm::vec2(0.0f, 0.0f)},
    Vertex{glm::vec3( 0.5f,  0.5f, -0.5f), glm::vec3( 0.0f,  0.0f, -1.0f), glm::vec2(1.0f, 1.0f)},
    Vertex{glm::vec3( 0.5f,  0.5f,  0.5f), glm::vec3( 0.0f,  1.0f,  0.0f), glm::vec2(1.0f, 0.0f)},
    Vertex{glm::vec3( 0.5f, -0.5f, -0.5f), glm::vec3( 0.0f,  0.0f, -1.0f), glm::vec2(1.0f, 0.0f)},
    Vertex{glm::vec3(-0.5f,  0.5f,  0.5f), glm::vec3( 0.0f,  1.0f,  0.0f), glm::vec2(0.0f, 0.0f)},
    Vertex{glm::vec3(-0.5f,  0.5f, -0.5f), glm::vec3( 0.0f,  1.0f,  0.0f), glm::vec2(0.0f, 1.0f)},
    Vertex{glm::vec3(-0.5f, -0.5f,  0.5f), glm::vec3( 0.0f,  0.0f,  1.0f), glm::vec2(0.0f, 0.0f)},
    Vertex{glm::vec3(-0.5f, -0.5f,  0.5f), glm::vec3(-1.0f,  0.0f,  0.0f), glm::vec2(0.0f, 0.0f)},
    Vertex{glm::vec3(-0.5f,  0.5f, -0.5f), glm::vec3( 0.0f,  0.0f, -1.0f), glm::vec2(0.0f, 1.0f)},
    Vertex{glm::vec3(-0.5f,  0.5f, -0.5f), glm::vec3(-1.0f,  0.0f,  0.0f), glm::vec2(1.0f, 1.0f)},
    Vertex{glm::vec3( 0.5f,  0.5f, -0.5f), glm::vec3( 0.0f,  1.0f,  0.0f), glm::vec2(1.0f, 1.0f)},
    Vertex{glm::vec3(-0.5f,  0.5f,  0.5f), glm::vec3(-1.0f,  0.0f,  0.0f), glm::vec2(1.0f, 0.0f)},
    Vertex{glm::vec3( 0.5f, -0.5f, -0.5f), glm::vec3( 1.0f,  0.0f,  0.0f), glm::vec2(0.0f, 1.0f)},
  };

  std::vector<GLuint> indices = {
     6, 14, 12,
    12, 19,  6,
    17,  7,  9,
     9,  3, 17,
    22, 20,  0,
     0, 18, 22,
     1,  2, 23,
    23,  4,  1,
     8,  5, 10,
    10, 11,  8,
    16, 21, 13,
    13, 15, 16
  };

  Model backpack = {"models/suzanne/suzanne.obj"};
  Mesh light     = {vertices, indices, {}};

  std::vector<Vertex> screen_vertices = {
    Vertex{glm::vec3(-1.0f, -1.0f, 0.0f), glm::vec3(1.0f, 1.0f, 1.0f), glm::vec2(0.0f, 0.0f)},
    Vertex{glm::vec3( 1.0f, -1.0f, 0.0f), glm::vec3(1.0f, 1.0f, 1.0f), glm::vec2(1.0f, 0.0f)},
    Vertex{glm::vec3(-1.0f,  1.0f, 0.0f), glm::vec3(1.0f, 1.0f, 1.0f), glm::vec2(0.0f, 1.0f)},
    Vertex{glm::vec3( 1.0f,  1.0f, 0.0f), glm::vec3(1.0f, 1.0f, 1.0f), glm::vec2(1.0f, 1.0f)},
  };

  std::vector<GLuint> screen_indices = {
    0, 1, 2,
    2, 1, 3,
  };

  Mesh screen = {screen_vertices, screen_indices, {}};

  Shader main_shader     {"shaders/vert.glsl",    "shaders/frag.glsl"};
  Shader light_shader    {"shaders/vert.glsl",    "shaders/light_frag.glsl"};
  Shader post_processing {"shaders/pp_vert.glsl", "shaders/pp_frag.glsl"};

  const float camera_speed   = 25.0f;
  glm::vec3 camera_position  = glm::vec3(-2.0f, 0.0f, 6.0f);
  glm::vec3 camera_forward   = glm::vec3(0.0f);
  glm::vec3 world_up         = glm::vec3(0.0f, 1.0f, 0.0f);
  glm::vec3 light_ambient    = glm::vec3(0.2f, 0.2f, 0.2f);
  glm::vec3 light_diffuse    = glm::vec3(0.75f, 0.75f, 0.75f);
  glm::vec3 light_specular   = glm::vec3(1.0f, 1.0f, 1.0f);

  float yaw   = -70.0f;
  float pitch =   0.0f;

  glm::mat4 model      = glm::mat4(1.0f);
  glm::mat4 view       = glm::mat4(1.0f);
  glm::mat4 projection = glm::perspective(glm::radians(45.0f), (float)WINDOW_WIDTH / (float)WINDOW_HEIGHT, 0.1f, 100.0f);
  glm::mat3 normal_mat = glm::mat3(1.0f);

  main_shader.set_mat4 ("projection", projection);
  main_shader.set_float("material.shininess", 32.0f);

  light_shader.set_mat4("projection", projection);

  std::vector<glm::vec3> point_light_positions = {
    glm::vec3( 0.7f,  0.2f,  2.0f),
    glm::vec3( 2.3f, -3.3f, -4.0f),
    glm::vec3(-4.0f,  2.0f, -12.0f),
    glm::vec3( 1.0f,  0.0f, -18.0f)
  };

  // Setup lights
  main_shader.set_vec3("directional_light.direction", glm::vec3(-0.2f, -1.0f, -0.3f));
  main_shader.set_vec3("directional_light.ambient", light_ambient);
  main_shader.set_vec3("directional_light.diffuse", light_diffuse * 0.25f);
  main_shader.set_vec3("directional_light.specular", light_specular);

  main_shader.set_vec3("spot_light.ambient", light_ambient);
  main_shader.set_vec3("spot_light.diffuse", light_diffuse * 0.5f);
  main_shader.set_vec3("spot_light.specular", light_specular * 0.25f);

  for (int i = 0; i < point_light_positions.size(); ++i) {
    char base[256]      = {0};
    char position[512]  = {0};
    char ambient[512]   = {0};
    char diffuse[512]   = {0};
    char specular[512]  = {0};
    char constant[512]  = {0};
    char linear[512]    = {0};
    char quadratic[512] = {0};

    snprintf(base, sizeof(base) - 1, "point_lights[%d]", i);
    snprintf(position, sizeof(position) - 1, "%s.position", base);
    snprintf(ambient, sizeof(ambient) - 1, "%s.ambient", base);
    snprintf(diffuse, sizeof(diffuse) - 1, "%s.diffuse", base);
    snprintf(specular, sizeof(specular) - 1, "%s.specular", base);
    snprintf(constant, sizeof(constant) - 1, "%s.constant", base);
    snprintf(linear, sizeof(linear) - 1, "%s.linear", base);
    snprintf(quadratic, sizeof(quadratic) - 1, "%s.quadratic", base);

    main_shader.set_vec3(position, point_light_positions[i]);
    main_shader.set_vec3(ambient, light_ambient);
    main_shader.set_vec3(diffuse, light_diffuse * 0.25f);
    main_shader.set_vec3(specular, light_specular * 0.5f);
    main_shader.set_float(constant, 1.0f);
    main_shader.set_float(linear, 0.09f);
    main_shader.set_float(quadratic, 0.032f);

    memset(base,      0, sizeof(base));
    memset(position,  0, sizeof(position));
    memset(ambient,   0, sizeof(ambient));
    memset(diffuse,   0, sizeof(diffuse));
    memset(specular,  0, sizeof(specular));
    memset(constant,  0, sizeof(constant));
    memset(linear,    0, sizeof(linear));
    memset(quadratic, 0, sizeof(quadratic));
  }

  const float sensitivity   = 0.1f;
  int last_mouse_x          = WINDOW_HALF_WIDTH;
  int last_mouse_y          = WINDOW_HALF_HEIGHT;
  uint32_t last_frame       = SDL_GetTicks();
  float delta               = 0.0f;
  bool running              = true;
  SDL_Event event           = {};

  while (running) {
    uint32_t ticks = SDL_GetTicks();
    delta          = (ticks - last_frame) * 0.001f;
    last_frame     = ticks;

    while (SDL_PollEvent(&event)) {
      switch (event.type) {
        case SDL_QUIT:
          running = false;
          break;
        case SDL_KEYDOWN:
          if (event.key.keysym.sym == SDLK_ESCAPE) {
            running = false;
          } else if (event.key.keysym.sym == SDLK_w) {
            camera_position += camera_speed * delta * camera_forward;
          } else if (event.key.keysym.sym == SDLK_s) {
            camera_position -= camera_speed * delta * camera_forward;
          } else if (event.key.keysym.sym == SDLK_d) {
            camera_position += camera_speed * delta * glm::normalize(glm::cross(camera_forward, world_up));
          } else if (event.key.keysym.sym == SDLK_a) {
            camera_position -= camera_speed * delta * glm::normalize(glm::cross(camera_forward, world_up));
          }
          break;
        case SDL_MOUSEMOTION: {
            float x_offset = event.motion.xrel;
            float y_offset = -event.motion.yrel;

            last_mouse_x = last_mouse_x + event.motion.xrel;
            last_mouse_y = last_mouse_y + event.motion.yrel;

            x_offset *= sensitivity;
            y_offset *= sensitivity;

            yaw   += x_offset;
            pitch += y_offset;

            if(pitch > 89.0f) {
              pitch =  89.0f;
            }

            if(pitch < -89.0f) {
              pitch = -89.0f;
            }
          }
          break;
        case SDL_WINDOWEVENT:
          if (event.window.event == SDL_WINDOWEVENT_RESIZED) {
            SDL_Window *wnd = SDL_GetWindowFromID(event.window.windowID);
            if (!wnd) {
              continue;
            }
            int w, h;
            SDL_GL_GetDrawableSize(wnd, &w, &h);
            glViewport(0, 0, w, h);
            SDL_WarpMouseInWindow(wnd, (int)(w * 0.5f), (int)(h * 0.5f));

            // Recreate offscreen frame buffer
            delete_frame_buffer(offscreen_buffer);
            offscreen_buffer = create_frame_buffer(w, h);
          }
          break;
      }
    }

    camera_forward.x = cos(glm::radians(yaw)) * cos(glm::radians(pitch));
    camera_forward.y = sin(glm::radians(pitch));
    camera_forward.z = sin(glm::radians(yaw)) * cos(glm::radians(pitch));
    camera_forward   = glm::normalize(camera_forward);

    view = glm::lookAt(camera_position, camera_position + camera_forward, world_up);
    main_shader.set_vec3("camera_position", camera_position);
    main_shader.set_vec3("spot_light.position", camera_position);
    main_shader.set_vec3("spot_light.direction", camera_forward);
    main_shader.set_float("spot_light.cutoff", glm::cos(glm::radians(12.5)));
    main_shader.set_float("spot_light.outer_cutoff", glm::cos(glm::radians(17.5)));
    main_shader.set_mat4("view", view);
    light_shader.set_mat4("view", view);

    // Main render pass
    glBindFramebuffer(GL_FRAMEBUFFER, offscreen_buffer.fbo);

    glClearColor(0.04f, 0.08f, 0.08f, 1.0f);
    glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
    glEnable(GL_DEPTH_TEST);
    glEnable(GL_CULL_FACE);

    model = glm::translate(glm::mat4(1.0f), glm::vec3(0.0f, 0.0f, 0.0f));
    normal_mat = glm::transpose(glm::inverse(model));
    main_shader.activate();
    main_shader.set_mat4("model", model);
    main_shader.set_mat3("normal_mat", normal_mat);
    backpack.draw(main_shader);

    // Draw light source
    for (int i = 0; i < point_light_positions.size(); ++i) {
      model = glm::translate(glm::mat4(1.0f), point_light_positions[i]);
      model = glm::scale(model, glm::vec3(0.2f));
      light_shader.activate();
      light_shader.set_mat4("model", model);
      light_shader.set_vec3("light_diffuse", light_diffuse);
      light.draw(light_shader);
    }

    // wireframe mode
    // glPolygonMode(GL_FRONT_AND_BACK, GL_LINE);

    // Post processing pass
    glBindFramebuffer(GL_FRAMEBUFFER, 0);

    glClearColor(1.0f, 1.0f, 1.0f, 1.0f);
    glClear(GL_COLOR_BUFFER_BIT);
    glDisable(GL_DEPTH_TEST);
    glDisable(GL_CULL_FACE);

    post_processing.activate();
    glActiveTexture(GL_TEXTURE0);
    glBindTexture(GL_TEXTURE_2D, offscreen_buffer.color);
    post_processing.set_int("image_texture", 0);
    screen.draw(post_processing);

    SDL_GL_SwapWindow(window);
  }

  SDL_GL_DeleteContext(context);
  SDL_DestroyWindow(window);
  SDL_Quit();

  return EXIT_CODE_SUCCESS;
}


Shader::Shader(const std::string &vert_file, const std::string &frag_file) {
  // Compile the vertex stage, then the fragment stage, link them, and finally
  // flag both stage objects for deletion (in the same vertex-then-fragment order).
  const GLuint stages[] = {
    load_and_compile_shader(vert_file, GL_VERTEX_SHADER),
    load_and_compile_shader(frag_file, GL_FRAGMENT_SHADER),
  };
  link_program(stages[0], stages[1]);
  for (const GLuint stage : stages) {
    glDeleteShader(stage);
  }
}

Shader::~Shader() {
  // 0 marks "no program" (e.g. a failed link); only real names are released.
  if (program != 0) {
    glDeleteProgram(program);
  }
}

void Shader::activate() {
  // Skip binding when there is no usable program instead of binding 0.
  if (program == 0) {
    return;
  }
  glUseProgram(program);
}

void Shader::set_int(const char *name, int value) {
  activate();
  glUniform1i(glGetUniformLocation(program, name), value);
}

void Shader::set_float(const char *name, float value) {
  activate();
  glUniform1f(glGetUniformLocation(program, name), value);
}

void Shader::set_vec3(const char *name, glm::vec3 vector) {
  activate();
  glUniform3f(glGetUniformLocation(program, name), vector.x, vector.y, vector.z);
}

void Shader::set_vec4(const char *name, glm::vec4 vector) {
  activate();
  glUniform4f(glGetUniformLocation(program, name), vector.x, vector.y, vector.z, vector.w);
}

void Shader::set_mat3(const char *name, glm::mat3 matrix) {
  activate();
  glUniformMatrix3fv(glGetUniformLocation(program, name), 1, GL_FALSE, glm::value_ptr(matrix));
}

void Shader::set_mat4(const char *name, glm::mat4 matrix) {
  activate();
  glUniformMatrix4fv(glGetUniformLocation(program, name), 1, GL_FALSE, glm::value_ptr(matrix));
}

// Links `vert` and `frag` into `program`. On any failure the error is printed,
// no GL program object is left behind, and `program` is 0.
void Shader::link_program(GLuint vert, GLuint frag) {
  program = 0;

  // A 0 stage means compilation already failed (and was reported); attaching
  // it would only raise GL_INVALID_VALUE.
  if (vert == 0 || frag == 0) {
    printf("Failed to link program: a shader stage failed to compile\n");
    return;
  }

  program = glCreateProgram();
  glAttachShader(program, vert);
  glAttachShader(program, frag);
  glLinkProgram(program);

  // The linked executable no longer needs the stage objects; detaching lets
  // the caller's glDeleteShader free them right away.
  glDetachShader(program, vert);
  glDetachShader(program, frag);

  GLint program_linked = GL_FALSE;
  glGetProgramiv(program, GL_LINK_STATUS, &program_linked);
  if (program_linked != GL_TRUE) {
    GLchar message[1024] = {0};
    glGetProgramInfoLog(program, static_cast<GLsizei>(sizeof(message)), nullptr, message);
    printf("Failed to link program: %s\n", message);
    // Release the failed program object instead of leaking it.
    glDeleteProgram(program);
    program = 0;
  }
}

// Reads `filepath` and compiles it as a `shader_type` stage.
// Returns the shader name, or 0 (after printing why) if the source could not be
// read or failed to compile; no shader object is leaked on failure.
GLuint Shader::load_and_compile_shader(const std::string &filepath, GLenum shader_type) {
  const std::string src = load_shader_from_file(filepath);
  if (src.empty()) {
    printf("Failed to compile %s shader: %s could not be read or is empty\n",
           get_shader_type_string(shader_type), filepath.c_str());
    return 0;
  }
  const char *shader_src = src.c_str();

  GLuint shader = glCreateShader(shader_type);
  glShaderSource(shader, 1, &shader_src, nullptr);
  glCompileShader(shader);

  GLint shader_compiled = GL_FALSE;
  glGetShaderiv(shader, GL_COMPILE_STATUS, &shader_compiled);
  if (shader_compiled != GL_TRUE) {
    GLchar message[1024] = {0};
    glGetShaderInfoLog(shader, static_cast<GLsizei>(sizeof(message)), nullptr, message);
    printf("Failed to compile %s shader: %s\n", get_shader_type_string(shader_type), message);
    glDeleteShader(shader);
    return 0;
  }

  return shader;
}

// Returns the full contents of `filepath`, or "" (after printing an error) if
// the file cannot be opened.
std::string Shader::load_shader_from_file(const std::string &filepath) {
  FILE *fp = fopen(filepath.c_str(), "r");
  if (!fp) {
    printf("Failed to open shader file: %s\n", filepath.c_str());
    return "";
  }

  std::string output;
  char buf[1024];
  size_t bytes_read = 0;
  while ((bytes_read = fread(buf, 1, sizeof(buf), fp)) > 0) {
    output.append(buf, bytes_read);
  }

  // Close the handle; otherwise every load leaks one FILE.
  fclose(fp);
  return output;
}

// Short human-readable label for a GL shader stage enum, used in error logs.
const char *Shader::get_shader_type_string(GLenum shader_type) {
  switch (shader_type) {
    case GL_COMPUTE_SHADER:         return "compute";
    case GL_VERTEX_SHADER:          return "vertex";
    case GL_TESS_CONTROL_SHADER:    return "tess_control";
    case GL_TESS_EVALUATION_SHADER: return "tess_evaluation";
    case GL_GEOMETRY_SHADER:        return "geometry";
    case GL_FRAGMENT_SHADER:        return "fragment";
    default:                        return "UNKNOWN";
  }
}

// Decodes `filename` with stb_image and uploads it as a mipmapped, repeating 2D
// texture on `texture_unit`. On decode failure an error is printed and the
// object keeps texture == 0 (binding it then binds "no texture").
Texture2D::Texture2D(const char *filename, GLint texture_unit, TextureType type)
    : width(0), height(0), channels(0), filename(filename), texture_unit(texture_unit), type(type),
      texture(0), format(0) {
  uint8_t *image = stbi_load(filename, &width, &height, &channels, 0);
  if (!image) {
    printf("Failed to load texture: %s. Error: %s\n", filename, stbi_failure_reason());
    return;
  }

  // Match the GL pixel format to the decoded channel count. Uploading a 1- or
  // 2-channel image as GL_RGB would make glTexImage2D read past `image`.
  switch (channels) {
    case 1:  format = GL_RED;  break;
    case 2:  format = GL_RG;   break;
    case 3:  format = GL_RGB;  break;
    default: format = GL_RGBA; break;
  }

  glGenTextures(1, &texture);
  activate();

  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR_MIPMAP_LINEAR);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_REPEAT);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_REPEAT);

  // Grayscale maps sample as (l, l, l, 1) / (l, l, l, a) rather than (l, 0, 0, 1).
  if (channels == 1) {
    const GLint swizzle[] = {GL_RED, GL_RED, GL_RED, GL_ONE};
    glTexParameteriv(GL_TEXTURE_2D, GL_TEXTURE_SWIZZLE_RGBA, swizzle);
  } else if (channels == 2) {
    const GLint swizzle[] = {GL_RED, GL_RED, GL_RED, GL_GREEN};
    glTexParameteriv(GL_TEXTURE_2D, GL_TEXTURE_SWIZZLE_RGBA, swizzle);
  }

  // stb_image rows are tightly packed; the default unpack alignment of 4 would
  // skew rows whose byte width is not a multiple of 4 (e.g. odd-width RGB).
  GLint previous_alignment = 4;
  glGetIntegerv(GL_UNPACK_ALIGNMENT, &previous_alignment);
  glPixelStorei(GL_UNPACK_ALIGNMENT, 1);
  glTexImage2D(GL_TEXTURE_2D, 0, format, width, height, 0, format, GL_UNSIGNED_BYTE, image);
  glPixelStorei(GL_UNPACK_ALIGNMENT, previous_alignment);
  glGenerateMipmap(GL_TEXTURE_2D);

  glBindTexture(GL_TEXTURE_2D, 0);
  stbi_image_free(image);
}

// NOTE(review): the GL texture is not deleted here; copies of a Texture2D share
// the same texture name, so freeing it would need explicit ownership tracking.
Texture2D::~Texture2D() = default;

void Texture2D::activate() const {
  glActiveTexture(texture_unit);
  glBindTexture(GL_TEXTURE_2D, texture);
}

// Builds the sampler uniform name for this texture, e.g. "material.diffuse1"
// or "material.specular2" for `index` 1 or 2.
std::string Texture2D::name(unsigned int index) const {
  std::string output = "material.";

  switch (type) {
    case TEXTURE_TYPE_DIFFUSE:
      output += "diffuse";
      break;
    case TEXTURE_TYPE_SPECULAR:
      output += "specular";
      break;
  }

  output += std::to_string(index);

  // Return the string itself (NRVO); going through c_str() built a second copy.
  return output;
}

// Takes ownership of the (by-value) data and uploads it to the GPU.
// The parameters are moved into the members rather than copied a second time.
Mesh::Mesh(std::vector<Vertex> vertices, std::vector<GLuint> indices, std::vector<Texture2D> textures)
  : vertices(std::move(vertices)), indices(std::move(indices)), textures(std::move(textures)) {
    // Uses the members (the moved-from parameters are now empty).
    setup_mesh();
}

void Mesh::setup_mesh() {
  glGenVertexArrays(1, &vao);
  glGenBuffers(1, &ebo);
  glGenBuffers(1, &vbo);

  glBindVertexArray(vao);

  glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, ebo);
  glBufferData(GL_ELEMENT_ARRAY_BUFFER, indices.size() * sizeof(GLuint), indices.data(), GL_STATIC_DRAW);

  glBindBuffer(GL_ARRAY_BUFFER, vbo);
  glBufferData(GL_ARRAY_BUFFER, vertices.size() * sizeof(Vertex), vertices.data(), GL_STATIC_DRAW);

  glVertexAttribPointer(0, 3, GL_FLOAT, GL_FALSE, sizeof(Vertex), (void *)0);
  glEnableVertexAttribArray(0);
  glVertexAttribPointer(1, 3, GL_FLOAT, GL_FALSE, sizeof(Vertex), (void *)(offsetof(Vertex, normal)));
  glEnableVertexAttribArray(1);
  glVertexAttribPointer(2, 2, GL_FLOAT, GL_FALSE, sizeof(Vertex), (void *)(offsetof(Vertex, tex_coord)));
  glEnableVertexAttribArray(2);
}

void Mesh::draw(Shader &shader) {
  // Separate 1-based counters per texture type yield uniform names like
  // material.diffuse1, material.diffuse2, ..., material.specular1, ...
  unsigned int next_diffuse = 1;
  unsigned int next_specular = 1;

  for (const Texture2D &texture : textures) {
    const unsigned int slot = (texture.type == TEXTURE_TYPE_DIFFUSE) ? next_diffuse++ : next_specular++;
    // Samplers take a 0-based unit index, hence the GL_TEXTURE0 offset.
    shader.set_int(texture.name(slot).c_str(), texture.texture_unit - GL_TEXTURE0);
    texture.activate();
  }

  glBindVertexArray(vao);
  glDrawElements(GL_TRIANGLES, static_cast<GLsizei>(indices.size()), GL_UNSIGNED_INT, nullptr);
  glBindVertexArray(0);
}

void Model::draw(Shader &shader) {
  // Draw every mesh in the order it was collected from the node tree.
  for (Mesh &mesh : meshes) {
    mesh.draw(shader);
  }
}

// Imports `path` with Assimp and collects all of its meshes. On failure an
// error is printed and the model is left empty.
void Model::load_model(std::string path) {
  Assimp::Importer importer;
  // GenSmoothNormals only runs when a mesh has no normals, so files without
  // normals still yield a non-null mNormals array.
  const aiScene *scene = importer.ReadFile(path, aiProcess_Triangulate | aiProcess_GenSmoothNormals | aiProcess_FlipUVs);
  if (!scene || (scene->mFlags & AI_SCENE_FLAGS_INCOMPLETE) || !scene->mRootNode) {
    printf("Failed to load model: %s. Error: %s\n", path.c_str(), importer.GetErrorString());
    return;
  }

  // Material texture paths are resolved relative to the model's folder; with no
  // separator in `path` the model is in the working directory.
  const std::string::size_type separator = path.find_last_of("/\\");
  directory = (separator == std::string::npos) ? std::string(".") : path.substr(0, separator);
  process_node(scene->mRootNode, scene);
}

void Model::process_node(aiNode *node, const aiScene *scene) {
  // Depth-first, pre-order walk without recursion: a node's own meshes are
  // appended before any descendant's, and children are visited in order.
  std::vector<aiNode *> pending{node};
  while (!pending.empty()) {
    aiNode *current = pending.back();
    pending.pop_back();

    for (unsigned int m = 0; m < current->mNumMeshes; ++m) {
      meshes.push_back(process_mesh(scene->mMeshes[current->mMeshes[m]], scene));
    }

    // Push children in reverse so the first child is popped (visited) first.
    for (unsigned int c = current->mNumChildren; c > 0; --c) {
      pending.push_back(current->mChildren[c - 1]);
    }
  }
}

// Converts one Assimp mesh into a Mesh: positions, normals (zero when absent),
// first UV channel (zero when absent), face indices, and the diffuse/specular
// textures of its material.
Mesh Model::process_mesh(aiMesh *mesh, const aiScene *scene) {
  std::vector<Vertex> vertices;
  std::vector<GLuint> indices;
  std::vector<Texture2D> textures;

  vertices.reserve(mesh->mNumVertices);
  for (unsigned int i = 0; i < mesh->mNumVertices; ++i) {
    Vertex vertex;

    const aiVector3D &position = mesh->mVertices[i];
    vertex.position = glm::vec3(position.x, position.y, position.z);

    // mNormals is null when the mesh carries no normals.
    if (mesh->HasNormals()) {
      const aiVector3D &normal = mesh->mNormals[i];
      vertex.normal = glm::vec3(normal.x, normal.y, normal.z);
    } else {
      vertex.normal = glm::vec3(0.0f, 0.0f, 0.0f);
    }

    if (mesh->HasTextureCoords(0)) {
      vertex.tex_coord = glm::vec2(mesh->mTextureCoords[0][i].x, mesh->mTextureCoords[0][i].y);
    } else {
      vertex.tex_coord = glm::vec2(0.0f, 0.0f);
    }

    vertices.push_back(vertex);
  }

  // Capacity hint only (assumes mostly triangles). Faces are read by reference:
  // copying an aiFace deep-copies its index array.
  indices.reserve(static_cast<size_t>(mesh->mNumFaces) * 3);
  for (unsigned int i = 0; i < mesh->mNumFaces; ++i) {
    const aiFace &face = mesh->mFaces[i];
    indices.insert(indices.end(), face.mIndices, face.mIndices + face.mNumIndices);
  }

  // mMaterialIndex is unsigned, so the meaningful check is that it is in range.
  if (mesh->mMaterialIndex < scene->mNumMaterials) {
    aiMaterial *material = scene->mMaterials[mesh->mMaterialIndex];

    std::vector<Texture2D> diffuse_maps = load_material_textures(material, aiTextureType_DIFFUSE);
    textures.insert(textures.end(), diffuse_maps.begin(), diffuse_maps.end());

    std::vector<Texture2D> specular_maps = load_material_textures(material, aiTextureType_SPECULAR);
    textures.insert(textures.end(), specular_maps.begin(), specular_maps.end());
  }

  return Mesh(std::move(vertices), std::move(indices), std::move(textures));
}

// Returns a Texture2D for every texture of `type` referenced by `material`,
// reusing textures this model already loaded from the same file.
std::vector<Texture2D> Model::load_material_textures(aiMaterial *material, aiTextureType type) {
  // Texture2D keeps only a raw `const char *` to its file name, so the path
  // strings must outlive every texture. std::set is node-based and never
  // relocates its elements, so c_str() pointers into it stay valid for the
  // rest of the program.
  static std::set<std::string> texture_paths;

  std::vector<Texture2D> textures;
  const unsigned int count = material->GetTextureCount(type);
  for (unsigned int i = 0; i < count; ++i) {
    aiString path;
    if (material->GetTexture(type, i, &path) != AI_SUCCESS) {
      printf("Failed to read texture %u of material\n", i);
      continue;
    }
    const std::string &absolute_path = *texture_paths.insert(directory + '/' + path.C_Str()).first;

    const Texture2D *cached = nullptr;
    for (const Texture2D &loaded : loaded_textures) {
      if (absolute_path == loaded.filename) {
        cached = &loaded;
        break;
      }
    }

    if (cached) {
      textures.push_back(*cached);
      continue;
    }

    const TextureType tex_type = type == aiTextureType_DIFFUSE ? TEXTURE_TYPE_DIFFUSE : TEXTURE_TYPE_SPECULAR;
    Texture2D texture = {absolute_path.c_str(), texture_unit++, tex_type};
    // Record it so later meshes sharing this file reuse the upload.
    loaded_textures.push_back(texture);
    textures.push_back(texture);
  }

  return textures;
}

// Creates an offscreen framebuffer of `width` x `height` with an RGB color
// texture and a depth24/stencil8 renderbuffer. Terminates the process with
// EXIT_CODE_INCOMPLETE_FRAME_BUFFER if GL reports it incomplete. Leaves the
// default framebuffer, texture 0 and renderbuffer 0 bound on return.
FrameBuffer create_frame_buffer(unsigned int width, unsigned int height) {
  FrameBuffer buffer = {};

  glGenFramebuffers(1, &buffer.fbo);
  glBindFramebuffer(GL_FRAMEBUFFER, buffer.fbo);

  // Color attachment: a texture, so the post-processing pass can sample it.
  glGenTextures(1, &buffer.color);
  glBindTexture(GL_TEXTURE_2D, buffer.color);
  glTexImage2D(GL_TEXTURE_2D, 0, GL_RGB, width, height, 0, GL_RGB, GL_UNSIGNED_BYTE, nullptr);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
  // The default wrap mode is GL_REPEAT; clamping keeps samples taken past the
  // screen border (e.g. kernel offsets) from wrapping to the opposite edge.
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
  glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);

  glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, buffer.color, 0);

  // Depth and stencil storage as a single renderbuffer.
  glGenRenderbuffers(1, &buffer.depth_stencil);
  glBindRenderbuffer(GL_RENDERBUFFER, buffer.depth_stencil);
  glRenderbufferStorage(GL_RENDERBUFFER, GL_DEPTH24_STENCIL8, width, height);

  glFramebufferRenderbuffer(GL_FRAMEBUFFER, GL_DEPTH_STENCIL_ATTACHMENT, GL_RENDERBUFFER, buffer.depth_stencil);

  const GLenum status = glCheckFramebufferStatus(GL_FRAMEBUFFER);
  if (status != GL_FRAMEBUFFER_COMPLETE) {
    printf("Incomplete frame buffer (status 0x%x, %ux%u)\n", status, width, height);
    exit(EXIT_CODE_INCOMPLETE_FRAME_BUFFER);
  }

  glBindFramebuffer(GL_FRAMEBUFFER, 0);
  glBindTexture(GL_TEXTURE_2D, 0);
  glBindRenderbuffer(GL_RENDERBUFFER, 0);

  return buffer;
}

// Deletes the framebuffer, its color texture and its depth/stencil
// renderbuffer, then zeroes the names in `buffer`.
void delete_frame_buffer(FrameBuffer &buffer) {
  glDeleteFramebuffers(1, &buffer.fbo);
  glDeleteTextures(1, &buffer.color);
  glDeleteRenderbuffers(1, &buffer.depth_stencil);
  // GL silently ignores deleting name 0, whereas a stale name could be deleted
  // again after GL recycles it for an unrelated object.
  buffer = {};
}