#include "repetition_testing/reptester.h"
#include "profiler/timer.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Ensures the tester has a read buffer when the allocation type asks for one.
 *
 * For ALLOC_TYPE_WITH_MALLOC, and only when tester->params.buffer is currently
 * NULL, a zero-filled buffer of (read_size + 1) bytes is allocated and stored
 * in tester->params.buffer. Every other allocation type is a no-op.
 *
 * If the allocation fails, a message is written to stderr and
 * tester->params.buffer is left NULL.
 */
void handle_alloc(reptester *tester, alloc_type type) {
  switch (type) {
  case ALLOC_TYPE_WITH_MALLOC:
    if (!(tester->params.buffer)) {
      // calloc returns zeroed memory, so no separate memset is needed and a
      // failed allocation is never written through.
      // NOTE(review): the extra byte presumably leaves room for a terminator
      // - confirm against the readers that fill this buffer.
      tester->params.buffer = (char *)calloc(tester->params.read_size + 1, 1);

      if (!(tester->params.buffer)) {
        fprintf(stderr, "Failed to allocate the read buffer\n");
      }
    }

    break;
  default:
    break;
  }
}

/*
 * Releases the tester's read buffer for allocation types that own one.
 *
 * For ALLOC_TYPE_WITH_MALLOC the buffer in tester->params.buffer is freed and
 * the pointer is reset to NULL. Every other allocation type is a no-op.
 */
void handle_free(reptester *tester, alloc_type type) {
  if (type != ALLOC_TYPE_WITH_MALLOC) {
    return;
  }

  // free(NULL) is defined to do nothing, so an unset buffer needs no guard.
  free(tester->params.buffer);
  tester->params.buffer = NULL;
}

/*
 * Runs `func` repeatedly against the tester and prints timing statistics.
 *
 * The tester's timing state (start time, elapsed seconds, run counter, stats
 * and results) is reset first. `func` is then called in a loop until
 * tester->test_time_secs exceeds tester->wait_time_secs. The start time is
 * reset every time a new minimum read time is recorded, so the loop ends once
 * wait_time_secs pass without a faster run.
 *
 * For ALLOC_TYPE_WITH_MALLOC the caller's tester->params.buffer is swapped
 * for a private zero-filled buffer of (read_size + 1) bytes for the duration
 * of the test. The original pointer is restored on every exit path.
 *
 * If a run reads fewer than read_size * read_count bytes, a message is
 * printed and the test stops without printing results.
 *
 * tester    - test state; reset and updated in place.
 * func      - the function under test, called as func(tester, type).
 * func_name - label passed to print_results.
 * type      - allocation strategy for the read buffer.
 */
void run_func_test(reptester *tester, reptest_func func, const char *func_name,
                   alloc_type type) {
  tester->test_start_time = read_cpu_timer();
  tester->test_time_secs = 0.0;
  tester->current_run = 1;
  tester->tstats = {
      UINT64_MAX, // min_time
      0,          // max_time
      0,          // avg_time
      0,          // total_time
  };
  tester->results = {};

  char *saved_buffer = NULL;

  if (type == ALLOC_TYPE_WITH_MALLOC) {
    // Allocate before touching the tester so a failure leaves the caller's
    // buffer pointer exactly as it was.
    char *scratch = (char *)calloc(tester->params.read_size + 1, 1);

    if (!scratch) {
      fprintf(stderr, "Failed to allocate the read buffer for %s\n",
              func_name);

      return;
    }

    saved_buffer = tester->params.buffer;
    tester->params.buffer = scratch;
  }

  int completed = 1;

  while (tester->test_time_secs <= tester->wait_time_secs) {
    func(tester, type);

    if (tester->results.bytes_read <
        tester->params.read_size * tester->params.read_count) {
      printf("Failed to read the entire file (Total size: %lu, Bytes read: "
             "%lu)\n",
             tester->params.read_size, tester->results.bytes_read);

      // Leave the loop instead of returning so the scratch buffer is released
      // and the caller's buffer is restored below.
      completed = 0;
      break;
    }

    tester->tstats.total_time += tester->results.read_time;

    // Max and min are checked independently: the first run is both the
    // slowest and the fastest seen so far.
    if (tester->results.read_time > tester->tstats.max_time) {
      tester->tstats.max_time = tester->results.read_time;
    }

    if (tester->results.read_time < tester->tstats.min_time) {
      // A new minimum restarts the wait window.
      tester->test_start_time = read_cpu_timer();
      tester->tstats.min_time = tester->results.read_time;
    }

    tester->test_time_secs = time_in_seconds(
        read_cpu_timer() - tester->test_start_time, tester->cpu_freq);

    ++(tester->current_run);
  }

  if (type == ALLOC_TYPE_WITH_MALLOC) {
    free(tester->params.buffer);
    tester->params.buffer = saved_buffer;
  }

  if (completed) {
    print_results(tester, func_name);
  }
}

/*
 * Prints the run count and the minimum, maximum and average read times for a
 * finished test, each with the throughput it corresponds to in GB/s.
 *
 * The number of completed runs is tester->current_run - 1. The average is
 * computed here as total_time / run_count and stored in
 * tester->tstats.avg_time. Throughput is based on read_size * read_count
 * bytes, with 1 GB taken as 1024^3 bytes.
 *
 * When no run has completed, only the run-count line is printed, because the
 * average would otherwise divide by zero.
 *
 * tester - test state holding the collected statistics; avg_time is updated.
 * name   - label printed in front of the run count.
 */
void print_results(reptester *tester, const char *name) {
  f64 gb = 1024.0 * 1024.0 * 1024.0;

  f64 size_in_gb =
      (f64)(tester->params.read_size * tester->params.read_count) / gb;

  u64 run_count = tester->current_run - 1;

  if (run_count == 0) {
    // Nothing was measured, so there are no statistics to report.
    printf("\n%s: %lu runs\n", name, run_count);

    return;
  }

  tester->tstats.avg_time = tester->tstats.total_time / run_count;

  printf("\n%s: %lu runs\n", name, run_count);
  printf("MIN: %lu (%fGB/s)\n", tester->tstats.min_time,
         size_in_gb /
             time_in_seconds(tester->tstats.min_time, tester->cpu_freq));
  printf("MAX: %lu (%fGB/s)\n", tester->tstats.max_time,
         size_in_gb /
             time_in_seconds(tester->tstats.max_time, tester->cpu_freq));
  printf("AVG: %lu (%fGB/s)\n", tester->tstats.avg_time,
         size_in_gb /
             time_in_seconds(tester->tstats.avg_time, tester->cpu_freq));
}