#include "include/aliases.h"
#include "include/flag_access.h"
#include "include/reg_access.h"
#include "include/sim86_instruction.h"
#include "include/sim86_lib.h"
#include <bits/types/FILE.h>
#include <stdio.h>
#include <string.h>

// Size of the simulated address space in bytes (64 KiB).
#define MEM_SIZE (1 << 16)
// Shift used to place the high byte of a 16-bit word.
#define BITS_PER_BYTE 8

// Fixed-capacity string returned by value from formatting helpers, so callers
// never have to allocate or free anything.
struct basic_string {
  char str[4096];
};

// The simulated memory image. main() loads the program bytes at the start of
// `buffer`; every data access made by the simulated program is offset by
// `mem_start`, which main() sets to just past the loaded program.
struct membuf {
  u8 buffer[MEM_SIZE];
  u64 mem_start;
};

// Outcome of a simulated memory read or write.
//   value: the byte/word that was read or written (0 on error).
//   error: non-zero when the access fell outside `membuf::buffer`.
struct mem_access_result {
  u16 value;
  u32 error;
};

// Forward declarations for the helpers defined after main().

// Evaluates an operand (register, memory or immediate) to a 16-bit value.
u16 get_operand_value(instruction_operand operand, bool wide);
// Renders an operand in assembly syntax for the disassembly listing.
basic_string get_operand_string(instruction_operand operand, bool wide);
// Prints one decoded instruction as a listing line.
void print_instruction(instruction inst);
// Executes `mov reg, source`.
void mov_to_register(const register_access &reg,
                     const instruction_operand &source, bool wide);
// Executes `mov [addrexp], source`.
void mov_to_memory(const effective_address_expression &addrexp,
                   const instruction_operand &source, bool wide);
// Reads a byte/word from simulated memory; sets `error` when out of range.
mem_access_result get_mem_value(const effective_address_expression &addrexp,
                                bool wide);
// Writes a byte/word to simulated memory; sets `error` when out of range.
mem_access_result set_mem_value(const effective_address_expression &addrexp,
                                u16 value, bool wide);
// Computes the 16-bit effective address (displacement + register terms).
u16 get_mem_index(const effective_address_expression &addrexp);

// The single simulated memory image shared by every helper in this file.
static membuf memory;

int main(int argc, char *argv[]) {
  if (argc < 2) {
    printf("Usage: sim86 BINARY_FILE\n");
    return 1;
  }

  memset((void *)memory.buffer, 0, MEM_SIZE);
  memory.mem_start = 0;

  const char *filename = argv[1];

  printf("Filename: %s\n", filename);

  FILE *fp = fopen(filename, "rb");
  if (!fp) {
    printf("Failed to open file %s\n", filename);
  }

  fseek(fp, 0, SEEK_END);

  u32 size = ftell(fp);

  fseek(fp, 0, SEEK_SET);

  fread((void *)memory.buffer, sizeof(u8), size, fp);
  memory.mem_start = size + 1;

  fclose(fp);

  instruction_table table;
  Sim86_Get8086InstructionTable(&table);

  u32 offset = 0;

  bool accessed_registers[REGISTER_COUNT] = {false};

  printf("\nDisassembly:\n");

  while (offset < size) {
    instruction decoded;
    Sim86_Decode8086Instruction(size - offset, memory.buffer + offset,
                                &decoded);

    if (decoded.Op) {
      offset += decoded.Size;
      bool wide = (decoded.Flags & Inst_Wide) == Inst_Wide;

      print_instruction(decoded);

      instruction_operand dest = decoded.Operands[0];
      instruction_operand source = decoded.Operands[1];

      switch (decoded.Op) {
      case Op_mov: {
        if (dest.Type == Operand_Register) {
          mov_to_register(dest.Register, source, wide);

          accessed_registers[dest.Register.Index] = true;
        } else if (dest.Type == Operand_Memory) {
          mov_to_memory(dest.Address, source, wide);
        }

        break;
      }
      case Op_add: {
        if (dest.Type == Operand_Register) {
          u16 value = get_register(dest.Register);

          value += get_operand_value(source, wide);
          set_flags(value);

          set_register(dest.Register, value);
        }

        break;
      }
      case Op_sub:
      case Op_cmp: {
        if (dest.Type == Operand_Register) {
          u16 value = get_register(dest.Register);

          value -= get_operand_value(source, wide);
          set_flags(value);

          if (decoded.Op == Op_sub) {
            set_register(dest.Register, value);
          }
        }

        break;
      }
      case Op_jne: {
        if (!get_flag(FLAG_ZERO)) {
          i16 inst_offset = get_operand_value(dest, wide);

          offset += inst_offset;
        }
      }
      default:
        break;
      }
    }
  }

  printf("\nFinal registers:\n");

  for (u32 i = 0; i < REGISTER_COUNT; ++i) {
    if (accessed_registers[i]) {
      register_access reg = {i, 0, 2};
      u16 value = get_register(reg);

      printf("\t%s: 0x%04x (%d)\n", get_register_name(reg), value, value);
    }
  }

  // Print the instruction pointer register
  printf("\tip: 0x%04x (%d)\n", offset, offset);

  printf("\nFinal flags:\n");
  print_flags();

#if 0 // Only needed (and working) for listing 0054
#define SIZE 64
#define BYTES SIZE * 4 * SIZE

  u8 image[BYTES];
  mempcpy(image, &(memory.buffer[memory.mem_start + (SIZE * 4)]), BYTES);

  FILE *out = fopen("image.data", "wb");

  fwrite(image, sizeof(u8), BYTES, out);

  fclose(out);
#endif

  return 0;
}

// Resolves `operand` to the 16-bit value it denotes:
//   register  -> current register contents,
//   immediate -> the immediate value truncated to 16 bits,
//   memory    -> the byte/word at the effective address (`wide` selects size).
// A memory read that fails its range check, and any other operand type,
// evaluates to 0.
u16 get_operand_value(instruction_operand operand, bool wide) {
  if (operand.Type == Operand_Register) {
    return get_register(operand.Register);
  }

  if (operand.Type == Operand_Immediate) {
    return operand.Immediate.Value;
  }

  if (operand.Type == Operand_Memory) {
    const mem_access_result fetched = get_mem_value(operand.Address, wide);
    return fetched.error ? 0 : fetched.value;
  }

  return 0;
}

// Formats `operand` in assembly syntax for the disassembly listing:
//   register  -> its name,
//   immediate -> signed decimal,
//   memory    -> "word [...]" or "byte [...]" (chosen by `wide`) containing
//                the register terms and the displacement.
// Any other operand type yields an empty string.
basic_string get_operand_string(instruction_operand operand, bool wide) {
  basic_string output = {""};

  switch (operand.Type) {
  case Operand_Register:
    snprintf(output.str, sizeof(output.str), "%s",
             get_register_name(operand.Register));
    break;
  case Operand_Memory: {
    const effective_address_expression &addr = operand.Address;
    char mem_string[1024] = {0};

    // A register index of 0 marks an unused address term.
    const register_access &reg1 = addr.Terms[0].Register;
    if (reg1.Index != 0) {
      snprintf(mem_string, sizeof(mem_string), "%s + ",
               get_register_name(reg1));
    }

    // Append at strlen() with the remaining capacity so every write stays
    // inside mem_string.
    size_t length = strlen(mem_string);
    const register_access &reg2 = addr.Terms[1].Register;
    if (reg2.Index != 0) {
      snprintf(mem_string + length, sizeof(mem_string) - length, "%s",
               get_register_name(reg2));

      // Two-register forms such as [bp + si + 4] can also carry a
      // displacement; show it whenever it is non-zero.
      if (addr.Displacement != 0) {
        length = strlen(mem_string);
        snprintf(mem_string + length, sizeof(mem_string) - length, " + %d",
                 addr.Displacement);
      }
    } else {
      snprintf(mem_string + length, sizeof(mem_string) - length, "%d",
               addr.Displacement);
    }

    snprintf(output.str, sizeof(output.str), "%s [%s]",
             wide ? "word" : "byte", mem_string);
    break;
  }
  case Operand_Immediate:
    snprintf(output.str, sizeof(output.str), "%d", operand.Immediate.Value);
    break;
  default:
    break;
  }

  return output;
}

// Prints one decoded instruction on its own line as
// "\t<mnemonic> <operand0>, <operand1>".
// Operands that format to an empty string are skipped (get_operand_string
// yields "" for operand types it does not format), so instructions with fewer
// than two formatted operands get no dangling ", ".
void print_instruction(instruction inst) {
  bool wide = (inst.Flags & Inst_Wide) == Inst_Wide;

  printf("\t%s", Sim86_MnemonicFromOperationType(inst.Op));

  const char *separator = " ";
  const u32 operand_count = sizeof(inst.Operands) / sizeof(inst.Operands[0]);
  for (u32 i = 0; i < operand_count; ++i) {
    basic_string text = get_operand_string(inst.Operands[i], wide);
    if (text.str[0] == '\0') {
      continue;
    }

    printf("%s%s", separator, text.str);
    separator = ", ";
  }

  printf("\n");
}

void mov_to_register(const register_access &reg,
                     const instruction_operand &source, bool wide) {
  switch (source.Type) {
  case Operand_Immediate:
    set_register(reg, source.Immediate.Value);
    break;
  case Operand_Register:
    set_register(reg, get_register(source.Register));
    break;
  case Operand_Memory: {
    mem_access_result result = get_mem_value(source.Address, wide);

    if (!result.error) {
      set_register(reg, result.value);
    }

    break;
  }
  default:
    break;
  }
}

void mov_to_memory(const effective_address_expression &addrexp,
                   const instruction_operand &source, bool wide) {
  switch (source.Type) {
  case Operand_Immediate:
    set_mem_value(addrexp, source.Immediate.Value, wide);
    break;
  case Operand_Register:
    set_mem_value(addrexp, get_register(source.Register), wide);
    break;
  case Operand_Memory: {
    mem_access_result result = get_mem_value(source.Address, wide);

    if (!result.error) {
      set_mem_value(addrexp, result.value, wide);
    }

    break;
  }
  default:
    break;
  }
}

// Reads from the data region of the simulated memory at the effective address
// described by `addrexp` (offset by memory.mem_start).
// Reads one byte, or, when `wide`, a word stored low byte first.
// Returns {value, 0} on success, or {0, 1} when any byte of the access would
// fall outside memory.buffer.
mem_access_result get_mem_value(const effective_address_expression &addrexp,
                                bool wide) {
  const u64 address = memory.mem_start + get_mem_index(addrexp);
  const u64 width = wide ? 2 : 1;

  mem_access_result result = {0, 0};

  // Check the last byte touched, not only the first: a word read starting at
  // MEM_SIZE - 1 would otherwise read one byte past the buffer.
  if (address + width > MEM_SIZE) {
    result.error = 1;
    return result;
  }

  result.value = memory.buffer[address];

  if (wide) {
    result.value |= (u16)(memory.buffer[address + 1] << BITS_PER_BYTE);
  }

  return result;
}

// Writes `value` to the data region of the simulated memory at the effective
// address described by `addrexp` (offset by memory.mem_start).
// Writes the low byte only, or, when `wide`, the low byte followed by the
// high byte.
// Returns {value, 0} on success. Returns {0, 1} and writes nothing when any
// byte of the access would fall outside memory.buffer.
mem_access_result set_mem_value(const effective_address_expression &addrexp,
                                u16 value, bool wide) {
  const u64 address = memory.mem_start + get_mem_index(addrexp);
  const u64 width = wide ? 2 : 1;

  mem_access_result result = {0, 0};

  // Check the last byte touched, not only the first: a word write starting
  // at MEM_SIZE - 1 would otherwise write one byte past the buffer.
  if (address + width > MEM_SIZE) {
    result.error = 1;
    return result;
  }

  memory.buffer[address] = (u8)value;

  if (wide) {
    memory.buffer[address + 1] = (u8)(value >> BITS_PER_BYTE);
  }

  result.value = value;
  return result;
}

// Computes the 16-bit effective address for `addrexp`: the displacement plus
// the current value of each register term that is in use (a register index
// of 0 marks an unused term). The sum is held in a u16, so it wraps modulo
// 2^16.
u16 get_mem_index(const effective_address_expression &addrexp) {
  u16 address = addrexp.Displacement;

  const register_access &first = addrexp.Terms[0].Register;
  if (first.Index != 0) {
    address += get_register(first);
  }

  const register_access &second = addrexp.Terms[1].Register;
  if (second.Index != 0) {
    address += get_register(second);
  }

  return address;
}