#include "mem_arena.h"
#include "aliases.h"
#include "mem_utils.h"
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

#ifndef DEFAULT_ALIGNMENT
// Alignment (in bytes) that mem_arena_alloc passes to mem_arena_alloc_aligned.
// It can be overridden by defining DEFAULT_ALIGNMENT before this point.
//
// For why this is 2 * sizeof(void *) rather than sizeof(void *), see:
// https://handmade.network/forums/t/6860-alignment_arena_allocator
#define DEFAULT_ALIGNMENT (2 * sizeof(void *))
#endif /* ifndef DEFAULT_ALIGNMENT */

// A single fixed-capacity block of memory. Blocks are chained through
// prev/next to form the growing arena.
typedef struct base_arena BaseArena;
struct base_arena {
  // Start of the block's backing storage (heap-allocated in base_arena_init).
  u8 *buf;
  // Position inside buf from which the next allocation is attempted.
  u8 *offset;
  // Size of buf in bytes.
  u64 capacity;
  // Neighbouring blocks in the chain; NULL at either end.
  BaseArena *prev;
  BaseArena *next;
};

// The growing arena: a handle onto a chain of BaseArena blocks.
struct growing_arena {
  // Block that allocations are currently served from.
  BaseArena *active_arena;
  // Number of blocks created so far (set to 1 by mem_arena_init and
  // incremented whenever a block is appended).
  u64 count;
  // Capacity passed to mem_arena_init; used when sizing blocks appended later.
  u64 initial_capacity;
};

// Helpers operating on a single block; definitions are at the end of the file.

// Allocates the backing buffer of an empty block.
internal bool base_arena_init(BaseArena *arena, u64 capacity);
// Carves an aligned region out of one block; returns NULL if it does not fit.
internal void *base_arena_alloc_aligned(BaseArena *arena, u64 size,
                                        u64 alignment);
// Zeroes the used part of a block and rewinds it to empty.
internal void base_arena_clear(BaseArena *arena);
// Releases a block's buffer and resets its fields (not the BaseArena itself).
internal void base_arena_free(BaseArena *arena);

// PUBLIC API

// Creates a growing arena together with its first block.
//
// arena:         address of an Arena pointer that must currently be NULL; it
//                receives the new arena on success.
// base_capacity: size in bytes of the first block; also stored as the arena's
//                initial_capacity.
//
// Returns true on success. Returns false if `arena` is NULL, if `*arena` is
// already non-NULL, or if any allocation fails. On failure `*arena` is left
// NULL and every partial allocation is released.
bool mem_arena_init(Arena **arena, u64 base_capacity) {
  if (!arena || *arena) {
    return false;
  }

  // Build everything through locals and publish to *arena only at the end, so
  // the failure paths never have to tear down a half-constructed arena.
  Arena *arena_ptr = calloc(1, sizeof *arena_ptr);
  if (!arena_ptr) {
    return false;
  }

  BaseArena *first = calloc(1, sizeof *first);
  if (!first) {
    free(arena_ptr);
    return false;
  }

  if (!base_arena_init(first, base_capacity)) {
    // base_arena_init leaves buf NULL on failure, so only the node itself
    // needs to be released.
    free(first);
    free(arena_ptr);
    return false;
  }

  arena_ptr->active_arena = first;
  arena_ptr->count = 1;
  arena_ptr->initial_capacity = base_capacity;

  *arena = arena_ptr;

  return true;
}

// Allocates `size` zeroed bytes from the arena using DEFAULT_ALIGNMENT.
// Returns whatever mem_arena_alloc_aligned returns (NULL on failure).
void *mem_arena_alloc(Arena *arena, u64 size) {
  const u64 alignment = DEFAULT_ALIGNMENT;

  return mem_arena_alloc_aligned(arena, size, alignment);
}

// Allocates `size` zeroed bytes aligned to `alignment` from the arena.
//
// The request is first tried against the active block. If it does not fit,
// the arena advances to the next block in the chain, appending a new block
// when there is none, and retries once.
//
// A newly appended block is sized to max(initial_capacity, size + alignment)
// so that a request larger than the base capacity can still be satisfied
// instead of appending a useless block on every call.
//
// Returns NULL if `arena` is NULL, has no active block, memory cannot be
// obtained, or the request does not fit in the block that was advanced to.
void *mem_arena_alloc_aligned(Arena *arena, u64 size, u64 alignment) {
  if (!arena || !(arena->active_arena)) {
    return NULL;
  }

  void *output = base_arena_alloc_aligned(arena->active_arena, size, alignment);
  if (!output) {
    BaseArena *next = arena->active_arena->next;

    if (!next) {
      u64 capacity = arena->initial_capacity;
      u64 needed = size + alignment;
      // `needed >= size` rejects unsigned wrap-around of size + alignment.
      if (needed >= size && needed > capacity) {
        capacity = needed;
      }

      next = calloc(1, sizeof *next);
      if (!next) {
        return NULL;
      }

      if (!base_arena_init(next, capacity)) {
        // The node has not been linked yet, so freeing it leaves no dangling
        // pointer behind in the chain.
        free(next);
        return NULL;
      }

      // Link only after init: base_arena_init resets prev/next to NULL.
      next->prev = arena->active_arena;
      arena->active_arena->next = next;

      ++(arena->count);
    }

    arena->active_arena = next;

    output = base_arena_alloc_aligned(next, size, alignment);
    if (!output) {
      return NULL;
    }
  }

  memset(output, 0, size);

  return output;
}

// Resets the arena for reuse without releasing any memory.
//
// Walks from the active block back to the first block via `prev`, clearing
// each one, then makes the first block active again. Blocks after the active
// one are not visited.
//
// Does nothing if `arena` is NULL or has no active block.
void mem_arena_clear(Arena *arena) {
  if (!arena) {
    return;
  }

  // `head` tracks the last block visited. Starting it at the active block
  // keeps the arena usable when there is only one block (prev == NULL).
  BaseArena *head = arena->active_arena;
  BaseArena *current = arena->active_arena;
  while (current) {
    base_arena_clear(current);

    head = current;
    current = current->prev;
  }

  arena->active_arena = head;
}

// Destroys the arena: releases every block in the chain, then the arena
// itself, and sets `*arena` to NULL.
//
// Safe to call with a NULL `arena`, a NULL `*arena`, or an arena whose
// active block is NULL.
void mem_arena_free(Arena **arena) {
  if (!arena || !*arena) {
    return;
  }

  Arena *arena_ptr = *arena;
  BaseArena *active = arena_ptr->active_arena;

  if (active) {
    // The active block may sit anywhere in the chain, so release what follows
    // it, then what precedes it, then the block itself. Each link is read
    // before base_arena_free() resets the node's prev/next.
    BaseArena *current = active->next;
    while (current) {
      BaseArena *next = current->next;

      base_arena_free(current);
      free(current);

      current = next;
    }

    current = active->prev;
    while (current) {
      BaseArena *prev = current->prev;

      base_arena_free(current);
      free(current);

      current = prev;
    }

    base_arena_free(active);
    free(active);
  }

  free(arena_ptr);
  *arena = NULL;
}

// INTERNAL FUNCTIONS

// Gives an empty block a zero-filled backing buffer of `capacity` bytes.
//
// Returns false when `arena` is NULL, when the block already owns a buffer,
// or when the allocation fails; in those cases the block is left untouched.
// On success the block is empty (offset == buf) and unlinked (prev and next
// are NULL).
internal bool base_arena_init(BaseArena *arena, u64 capacity) {
  if (arena == NULL) {
    return false;
  }
  if (arena->buf != NULL) {
    return false;
  }

  u64 byte_count = capacity * sizeof(u8);
  u8 *storage = (u8 *)malloc(byte_count);
  if (storage == NULL) {
    return false;
  }
  memset(storage, 0, byte_count);

  arena->buf = storage;
  arena->offset = storage;
  arena->capacity = capacity;
  arena->prev = NULL;
  arena->next = NULL;

  return true;
}

// Carves `size` bytes, aligned to `alignment`, out of a single block.
//
// Returns the aligned address and advances the block's offset past the
// alignment padding and the allocation, so later allocations never overlap
// earlier ones. Returns NULL when `arena` is NULL, when alignment fails, or
// when padding plus `size` exceeds the space left. An allocation that exactly
// fills the block is accepted.
internal void *base_arena_alloc_aligned(BaseArena *arena, u64 size,
                                        u64 alignment) {
  if (!arena) {
    return NULL;
  }

  u8 *output = mem_util_align_forward((void *)(arena->offset), alignment);
  if (!output) {
    return NULL;
  }

  // NOTE(review): assumes mem_util_align_forward never returns an address
  // below its input, so `padding` is non-negative - confirm in mem_utils.
  u64 used = (u64)(arena->offset - arena->buf);
  u64 remaining = arena->capacity - used;
  u64 padding = (u64)(output - arena->offset);

  // Compare sizes rather than pointers so a huge `size` cannot overflow the
  // address computation.
  if (padding > remaining || size > remaining - padding) {
    return NULL;
  }

  arena->offset = output + size;

  return (void *)output;
}

// Empties a block: zeroes the bytes between buf and offset, then rewinds
// offset to the start of the buffer. A NULL `arena` is ignored.
internal void base_arena_clear(BaseArena *arena) {
  if (arena) {
    u8 *start = arena->buf;

    memset(start, 0, arena->offset - start);
    arena->offset = start;
  }
}

// Releases a block's backing buffer and resets every field to its empty
// state. The BaseArena struct itself is not freed; that is left to the
// caller. A NULL `arena` is ignored.
internal void base_arena_free(BaseArena *arena) {
  if (arena == NULL) {
    return;
  }

  // free(NULL) is a no-op, so no guard is needed for a block without a buffer.
  free(arena->buf);

  arena->buf = NULL;
  arena->offset = NULL;
  arena->capacity = 0;
  arena->prev = NULL;
  arena->next = NULL;
}