#include <bits/types/FILE.h>
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Number of elements in an array object. Only valid on real arrays: applied
// to a pointer (e.g. an array function parameter) it silently yields
// sizeof(pointer) / sizeof(element). The argument is parenthesized in both
// uses so that any argument expression binds as a single operand.
#define ARRAY_LEN(ARR) (sizeof(ARR) / sizeof((ARR)[0]))

// First-byte opcode patterns understood by the disassembler. An incoming byte
// is compared with these values after masking it with the matching INST_MASKS
// entry, so flag bits (d, w, s) and embedded register fields are zero here.
enum class INST_BITS {
  // register/memory with register: 0b??????dw
  ADD_REG_MEM_REG = 0x00,
  SUB_REG_MEM_REG = 0x28,
  CMP_REG_MEM_REG = 0x38,
  MOV_REG_MEM_REG = 0x88,

  // immediate with accumulator: 0b???????w
  ADD_IMM_TO_ACC = 0x04,
  SUB_IMM_FROM_ACC = 0x2C,
  CMP_IMM_WITH_ACC = 0x3C,

  // accumulator <-> direct memory: 0b1010000w / 0b1010001w
  MOV_MEM_TO_ACC = 0xA0,
  MOV_ACC_TO_MEM = 0xA2,

  // immediate to register: 0b1011wrrr
  MOV_IMM_TO_REG = 0xB0,

  // immediate to register/memory: 0b1100011w (mov), 0b100000sw (arith group)
  MOV_IMM_TO_REG_MEM = 0xC6,
  ARITHMETIC_IMM_TO_REG_MEM = 0x80,

  // short conditional jumps, listed in opcode order
  JO = 0x70,
  JNO = 0x71,
  JB = 0x72,
  JNB = 0x73,
  JE = 0x74,
  JNE_JNZ = 0x75,
  JBE = 0x76,
  JA = 0x77,
  JS = 0x78,
  JNS = 0x79,
  JP = 0x7A,
  JNP = 0x7B,
  JL = 0x7C,
  JNL = 0x7D,
  JLE = 0x7E,
  JG = 0x7F,

  // loop family and jcxz
  LOOPNZ = 0xE0,
  LOOPZ = 0xE1,
  LOOP = 0xE2,
  JCXZ = 0xE3,
};

// Opcodes decoded as "op reg/mem, reg" (matched under INST_MASKS::REG_MEM_REG).
uint8_t reg_mem_reg_insts[] = {
    static_cast<uint8_t>(INST_BITS::ADD_REG_MEM_REG),
    static_cast<uint8_t>(INST_BITS::SUB_REG_MEM_REG),
    static_cast<uint8_t>(INST_BITS::CMP_REG_MEM_REG),
    static_cast<uint8_t>(INST_BITS::MOV_REG_MEM_REG),
};

// Opcodes whose register operand is the accumulator
// (matched under INST_MASKS::ACCUMULATOR).
uint8_t to_accumulator_insts[] = {
    static_cast<uint8_t>(INST_BITS::MOV_MEM_TO_ACC),
    static_cast<uint8_t>(INST_BITS::ADD_IMM_TO_ACC),
    static_cast<uint8_t>(INST_BITS::SUB_IMM_FROM_ACC),
    static_cast<uint8_t>(INST_BITS::CMP_IMM_WITH_ACC),
};

// Jump/loop opcodes, each followed by one signed displacement byte
// (matched exactly under INST_MASKS::JUMPS).
uint8_t jump_insts[] = {
    static_cast<uint8_t>(INST_BITS::JE),
    static_cast<uint8_t>(INST_BITS::JL),
    static_cast<uint8_t>(INST_BITS::JLE),
    static_cast<uint8_t>(INST_BITS::JB),
    static_cast<uint8_t>(INST_BITS::JBE),
    static_cast<uint8_t>(INST_BITS::JP),
    static_cast<uint8_t>(INST_BITS::JO),
    static_cast<uint8_t>(INST_BITS::JS),
    static_cast<uint8_t>(INST_BITS::JNE_JNZ),
    static_cast<uint8_t>(INST_BITS::JNL),
    static_cast<uint8_t>(INST_BITS::JG),
    static_cast<uint8_t>(INST_BITS::JNB),
    static_cast<uint8_t>(INST_BITS::JA),
    static_cast<uint8_t>(INST_BITS::JNP),
    static_cast<uint8_t>(INST_BITS::JNO),
    static_cast<uint8_t>(INST_BITS::JNS),
    static_cast<uint8_t>(INST_BITS::LOOP),
    static_cast<uint8_t>(INST_BITS::LOOPZ),
    static_cast<uint8_t>(INST_BITS::LOOPNZ),
    static_cast<uint8_t>(INST_BITS::JCXZ),
};

// Operation selector for the ARITHMETIC_IMM_TO_REG_MEM group: the reg field
// (bits 5-3) of the ModR/M byte, left in place so it can be compared after
// masking with INST_MASKS::ARITHMETIC.
enum class ARITHMETIC {
  ADD = 0 << 3, // reg = 000
  SUB = 5 << 3, // reg = 101
  CMP = 7 << 3, // reg = 111
};

// Masks applied to an opcode (or ModR/M byte) before comparing it with the
// corresponding INST_BITS / ARITHMETIC value.
enum class INST_MASKS {
  REG_MEM_REG = 0xFC,               // ignore d and w
  IMM_TO_REG = 0xF0,                // ignore w and the reg field
  MOV_IMM_TO_REG_MEM = 0xFE,        // ignore w
  ARITHMETIC_IMM_TO_REG_MEM = 0xFC, // ignore s and w
  ACCUMULATOR = 0xFE,               // ignore w
  ARITHMETIC = 0x38,                // keep only the ModR/M reg field
  JUMPS = 0xFF,                     // exact opcode match
};

// mod field (bits 7-6) of the ModR/M byte, left in place.
enum class MODE {
  MEM = 0 << 6,   // memory, no displacement (r/m 110 is a direct address)
  MEM8 = 1 << 6,  // memory with an 8-bit displacement
  MEM16 = 2 << 6, // memory with a 16-bit displacement
  REG = 3 << 6,   // register operand
};

// ---- Opcode matching ------------------------------------------------------

// True when `instruction & mask` equals `inst_bits`.
bool mask_instruction(uint8_t instruction, uint8_t inst_bits, uint8_t mask);
// True when some entry of `instructions` matches `inst` under `mask`.
bool instruction_in_array(uint8_t inst, uint8_t *instructions, size_t arr_size,
                          uint8_t mask);
// First entry of `instructions` matching `inst` under `mask`; `mask` if none.
uint8_t get_instruction_from_array(uint8_t inst, uint8_t *instructions,
                                   size_t arr_size, uint8_t mask);

// ---- Operand decoding -----------------------------------------------------

// Writes the register named by the low 3 bits of `instruction` into `dest`.
void decode_register(uint8_t instruction, bool word, char *dest);
// Writes the base/index expression named by the low 3 bits into `dest`.
void decode_rm(uint8_t instruction, char *dest);
// Writes the memory-operand text for ModR/M byte `operands` into `rm`,
// consuming any displacement bytes from `fp`.
void stringify_rm_and_disp(FILE *fp, uint8_t operands, char *rm,
                           uint32_t buff_size);

// ---- Accumulator forms ----------------------------------------------------

// Formats an accumulator <-> memory MOV into `dest`, reading its operand
// bytes from `fp`.
void handle_accumulator_mov_instructions(FILE *fp, uint8_t inst, bool reg_dest,
                                         char *dest);
// Formats an "op accumulator, immediate" instruction into `dest`, reading the
// immediate from `fp`.
void handle_accumulator_arithmetic_instructions(FILE *fp, uint8_t inst,
                                                char *dest);

// Reads `nbytes` (at most 2) bytes from `fp` as a little-endian value and
// returns it as a signed 16-bit integer; a single byte is sign-extended, which
// matches how the byte immediates were printed ((int8_t) casts). Assembling
// the value byte by byte keeps the result independent of host endianness.
// Bytes missing at end of file read as zero; nbytes == 0 yields 0.
static int16_t read_signed_le(FILE *fp, uint8_t nbytes) {
  uint8_t bytes[2] = {0, 0};
  if (nbytes > 2) {
    nbytes = 2;
  }
  size_t got = fread(bytes, 1, nbytes, fp);
  (void)got; // a short read leaves the remaining bytes zero

  if (nbytes == 1) {
    return (int8_t)bytes[0];
  }
  return (int16_t)(uint16_t)(bytes[0] | (bytes[1] << 8));
}

// Mnemonic of the immediate-to-reg/mem group (opcodes 0x80-0x83), selected by
// the reg field (bits 5-3) of the ModR/M byte. All eight encodings are
// covered so `op` can never carry over from the previous instruction.
static const char *immediate_group_mnemonic(uint8_t operands) {
  static const char *const mnemonics[8] = {"add", "or",  "adc", "sbb",
                                           "and", "sub", "xor", "cmp"};
  return mnemonics[(operands >> 3) & 0x07];
}

// Mnemonic of a jump/loop opcode listed in jump_insts ("" for anything else).
static const char *jump_mnemonic(uint8_t opcode) {
  switch (opcode) {
  case (uint8_t)INST_BITS::JE:
    return "je";
  case (uint8_t)INST_BITS::JL:
    return "jl";
  case (uint8_t)INST_BITS::JLE:
    return "jle";
  case (uint8_t)INST_BITS::JB:
    return "jb";
  case (uint8_t)INST_BITS::JBE:
    return "jbe";
  case (uint8_t)INST_BITS::JP:
    return "jp";
  case (uint8_t)INST_BITS::JO:
    return "jo";
  case (uint8_t)INST_BITS::JS:
    return "js";
  case (uint8_t)INST_BITS::JNE_JNZ:
    return "jnz";
  case (uint8_t)INST_BITS::JNL:
    return "jnl";
  case (uint8_t)INST_BITS::JG:
    return "jg";
  case (uint8_t)INST_BITS::JNB:
    return "jnb";
  case (uint8_t)INST_BITS::JA:
    return "ja";
  case (uint8_t)INST_BITS::JNP:
    return "jnp";
  case (uint8_t)INST_BITS::JNO:
    return "jno";
  case (uint8_t)INST_BITS::JNS:
    return "jns";
  case (uint8_t)INST_BITS::LOOP:
    return "loop";
  case (uint8_t)INST_BITS::LOOPZ:
    return "loopz";
  case (uint8_t)INST_BITS::LOOPNZ:
    return "loopnz";
  case (uint8_t)INST_BITS::JCXZ:
    return "jcxz";
  default:
    return "";
  }
}

// Disassembles the 8086 machine code in argv[1] and writes NASM-style source
// to "<argv[1]>_out.asm". Returns 0 on success and 1 on usage or I/O errors;
// unrecognized opcodes are reported on stderr and skipped.
int main(int argc, char *argv[]) {
  if (argc < 2) {
    fprintf(stderr, "Please provide a file to disassemble\n");
    return 1;
  }

  const char *filename = argv[1];

  FILE *fp = fopen(filename, "rb");
  if (!fp) {
    fprintf(stderr, "Failed to open the selected file\n");
    return 1;
  }

  // argv[1] is arbitrary length: bound the write instead of using sprintf.
  char out_filename[4096] = {0};
  int name_len =
      snprintf(out_filename, sizeof(out_filename), "%s_out.asm", filename);
  if (name_len < 0 || (size_t)name_len >= sizeof(out_filename)) {
    fprintf(stderr, "Output file name is too long\n");
    fclose(fp);
    return 1;
  }

  FILE *out = fopen(out_filename, "w");
  if (!out) {
    fprintf(stderr, "Failed to open output file\n");
    fclose(fp);
    return 1;
  }

  fprintf(out, "; Disassembled by DASM\n\nbits 16\n\n");

  uint8_t inst = 0;
  const char *op = "";

  while (fread(&inst, sizeof(inst), 1, fp) == 1) {
    if (instruction_in_array(inst, reg_mem_reg_insts,
                             ARRAY_LEN(reg_mem_reg_insts),
                             (uint8_t)INST_MASKS::REG_MEM_REG)) {
      switch (get_instruction_from_array(inst, reg_mem_reg_insts,
                                         ARRAY_LEN(reg_mem_reg_insts),
                                         (uint8_t)INST_MASKS::REG_MEM_REG)) {
      case (uint8_t)INST_BITS::MOV_REG_MEM_REG:
        op = "mov";
        break;
      case (uint8_t)INST_BITS::ADD_REG_MEM_REG:
        op = "add";
        break;
      case (uint8_t)INST_BITS::SUB_REG_MEM_REG:
        op = "sub";
        break;
      case (uint8_t)INST_BITS::CMP_REG_MEM_REG:
        op = "cmp";
        break;
      }

      uint8_t operands = 0;
      fread(&operands, sizeof(operands), 1, fp);

      // Bit 1 (d): the reg field is the destination. Bit 0 (w): 16-bit.
      bool reg_dest = mask_instruction(inst, 0x02, 0x02);
      bool word = mask_instruction(inst, 0x01, 0x01);

      char reg[3] = {0};
      decode_register(operands >> 3, word, reg);

      if (mask_instruction(operands, (uint8_t)MODE::REG,
                           (uint8_t)MODE::REG)) {
        char rm[3] = {0};
        decode_register(operands, word, rm);

        fprintf(out, "%s %s, %s\n", op, reg_dest ? reg : rm,
                reg_dest ? rm : reg);
      } else {
        char rm[20] = {0};
        stringify_rm_and_disp(fp, operands, rm, sizeof(rm));

        if (reg_dest) {
          fprintf(out, "%s %s, [%s]\n", op, reg, rm);
        } else {
          fprintf(out, "%s [%s], %s\n", op, rm, reg);
        }
      }
    } else if (mask_instruction(inst, (uint8_t)INST_BITS::MOV_IMM_TO_REG,
                                (uint8_t)INST_MASKS::IMM_TO_REG)) {
      op = "mov";

      // Bit pattern is:
      //   7   6   5   4   3   2   1   0
      //  -------------------------------
      // | 1 | 0 | 1 | 1 | w |    reg    |
      //  -------------------------------
      //
      // So, we need to mask the fourth bit to check the w flag
      bool word = mask_instruction(inst, 0x08, 0x08);

      char reg[3] = {0};
      decode_register(inst, word, reg);

      int16_t data = read_signed_le(fp, word ? 2 : 1);

      fprintf(out, "%s %s, %d\n", op, reg, data);
    } else if (mask_instruction(inst, (uint8_t)INST_BITS::MOV_IMM_TO_REG_MEM,
                                (uint8_t)INST_MASKS::MOV_IMM_TO_REG_MEM)) {
      op = "mov";

      uint8_t operands = 0;
      fread(&operands, sizeof(operands), 1, fp);

      // Instruction bit pattern is:
      //   7   6   5   4   3   2   1   0
      //  -------------------------------
      // | 1 | 1 | 0 | 0 | 0 | 1 | 1 | w |
      //  -------------------------------
      //
      // Operands bit pattern is:
      //   7   6   5   4   3   2   1   0
      //  -------------------------------
      // |  mod  |    000    |    r/m    |
      //  -------------------------------

      bool word = mask_instruction(inst, 0x01, 0x01);

      char rm[20] = {0};
      stringify_rm_and_disp(fp, operands, rm, sizeof(rm));

      int16_t data = read_signed_le(fp, word ? 2 : 1);

      fprintf(out, "%s [%s], %s %d\n", op, rm, word ? "word" : "byte", data);
    } else if (mask_instruction(
                   inst, (uint8_t)INST_BITS::ARITHMETIC_IMM_TO_REG_MEM,
                   (uint8_t)INST_MASKS::ARITHMETIC_IMM_TO_REG_MEM)) {
      uint8_t operands = 0;
      fread(&operands, sizeof(operands), 1, fp);

      op = immediate_group_mnemonic(operands);

      bool word = mask_instruction(inst, 0x01, 0x01);
      bool sign = mask_instruction(inst, 0x02, 0x02);

      // Only s=0,w=1 carries a full 16-bit immediate. Every other form
      // (including s=1,w=0, opcode 0x82) has one byte, sign-extended.
      uint8_t imm_bytes = (word && !sign) ? 2 : 1;

      if (mask_instruction(operands, (uint8_t)MODE::REG,
                           (uint8_t)MODE::REG)) {
        char rm[3] = {0};
        decode_register(operands, word, rm);

        int16_t data = read_signed_le(fp, imm_bytes);

        fprintf(out, "%s %s, %d\n", op, rm, data);
      } else {
        char rm[20] = {0};
        stringify_rm_and_disp(fp, operands, rm, sizeof(rm));

        int16_t data = read_signed_le(fp, imm_bytes);

        fprintf(out, "%s %s [%s], %d\n", op, word ? "word" : "byte", rm,
                data);
      }
    } else if (instruction_in_array(inst, to_accumulator_insts,
                                    ARRAY_LEN(to_accumulator_insts),
                                    (uint8_t)INST_MASKS::ACCUMULATOR)) {
      char inst_out[256] = {0};

      switch (get_instruction_from_array(inst, to_accumulator_insts,
                                         ARRAY_LEN(to_accumulator_insts),
                                         (uint8_t)INST_MASKS::ACCUMULATOR)) {
      case (uint8_t)INST_BITS::MOV_MEM_TO_ACC:
        handle_accumulator_mov_instructions(fp, inst, true, inst_out);
        break;
      case (uint8_t)INST_BITS::ADD_IMM_TO_ACC:
      case (uint8_t)INST_BITS::SUB_IMM_FROM_ACC:
      case (uint8_t)INST_BITS::CMP_IMM_WITH_ACC:
        handle_accumulator_arithmetic_instructions(fp, inst, inst_out);
        break;
      }

      fprintf(out, "%s\n", inst_out);
    } else if (mask_instruction(inst, (uint8_t)INST_BITS::MOV_ACC_TO_MEM,
                                (uint8_t)INST_MASKS::ACCUMULATOR)) {
      char inst_out[256] = {0};
      handle_accumulator_mov_instructions(fp, inst, false, inst_out);

      fprintf(out, "%s\n", inst_out);
    } else if (instruction_in_array(inst, jump_insts, ARRAY_LEN(jump_insts),
                                    (uint8_t)INST_MASKS::JUMPS)) {
      op = jump_mnemonic(inst);

      int8_t inc = 0;
      fread(&inc, sizeof(inc), 1, fp);

      // NASM treats a bare jump operand as an absolute target address. The
      // displacement is relative to the end of this 2-byte instruction, so
      // emit "$+2<disp>" to make the output reassemble to the same bytes.
      fprintf(out, "%s $+2%+d\n", op, inc);
    } else {
      fprintf(stderr, "Invalid instruction: 0x%02x\n", (unsigned)inst);
    }
  }

  fclose(fp);

  // fclose flushes buffered output; a failure here means a truncated file.
  if (fclose(out) != 0) {
    fprintf(stderr, "Failed to write output file\n");
    return 1;
  }

  return 0;
}

// Reports whether the bits of `instruction` selected by `mask` equal
// `inst_bits`. If `inst_bits` has any bit set outside `mask`, the result is
// always false.
bool mask_instruction(uint8_t instruction, uint8_t inst_bits, uint8_t mask) {
  const uint8_t selected = (uint8_t)(instruction & mask);
  return inst_bits == selected;
}

// Returns true if any of the `arr_size` entries of `instructions` matches
// `inst` under `mask` (as judged by mask_instruction).
bool instruction_in_array(uint8_t inst, uint8_t *instructions, size_t arr_size,
                          uint8_t mask) {
  bool matched = false;
  for (size_t idx = 0; idx < arr_size && !matched; ++idx) {
    matched = mask_instruction(inst, instructions[idx], mask);
  }
  return matched;
}

// Returns the first entry of `instructions` that matches `inst` under `mask`.
// When nothing matches, `mask` itself is returned as a sentinel, so callers
// should confirm a match with instruction_in_array first. The sentinel is
// only unambiguous while no table contains its own mask value.
uint8_t get_instruction_from_array(uint8_t inst, uint8_t *instructions,
                                   size_t arr_size, uint8_t mask) {
  uint8_t found = mask;
  size_t pos = 0;
  while (pos < arr_size) {
    if (mask_instruction(inst, instructions[pos], mask)) {
      found = instructions[pos];
      break;
    }
    ++pos;
  }
  return found;
}

// Copies the name of the register encoded in the low three bits of
// `instruction` into `dest` (which must hold at least 3 bytes). `word`
// selects the 16-bit register set; otherwise the 8-bit set is used.
void decode_register(uint8_t instruction, bool word, char *dest) {
  static const char *const byte_regs[8] = {"al", "cl", "dl", "bl",
                                           "ah", "ch", "dh", "bh"};
  static const char *const word_regs[8] = {"ax", "cx", "dx", "bx",
                                           "sp", "bp", "si", "di"};

  const unsigned index = instruction & 0x07u;
  const char *name = word ? word_regs[index] : byte_regs[index];
  strcpy(dest, name);
}

// Copies the base/index expression selected by the r/m field (low three bits
// of `instruction`) into `dest`, which must hold at least 8 bytes
// ("bx + si" plus the terminator). The displacement is not handled here.
void decode_rm(uint8_t instruction, char *dest) {
  const char *base;

  switch (instruction & 0x07) {
  case 0:
    base = "bx + si";
    break;
  case 1:
    base = "bx + di";
    break;
  case 2:
    base = "bp + si";
    break;
  case 3:
    base = "bp + di";
    break;
  case 4:
    base = "si";
    break;
  case 5:
    base = "di";
    break;
  case 6:
    base = "bp";
    break;
  default:
    base = "bx";
    break;
  }

  strcpy(dest, base);
}

// Writes the text of the memory operand described by ModR/M byte `operands`
// into `rm` (without brackets), consuming any displacement bytes from `fp`.
//
//   operands  - ModR/M byte: mod in bits 7-6, r/m in bits 2-0.
//   rm        - output buffer of `buff_size` bytes. decode_rm writes into it
//               unbounded, so it must hold at least 8 bytes ("bx + si").
//
// mod 00: no displacement, except r/m 110, which is a 16-bit direct address
//         printed as an unsigned number (even when it is 0: "[bp]" would be a
//         different instruction).
// mod 01: 8-bit displacement, sign-extended.
// mod 10: 16-bit displacement.
// mod 11 denotes a register operand; callers decode that themselves, and here
//         it consumes no bytes.
// Displacements are appended as " + N" / " - N" and omitted when zero.
void stringify_rm_and_disp(FILE *fp, uint8_t operands, char *rm,
                           uint32_t buff_size) {
  const uint8_t mod = operands >> 6;
  const bool direct_address = mod == 0 && (operands & 0x07) == 0x06;

  uint8_t disp_bytes = 0;
  if (direct_address || mod == 2) {
    disp_bytes = 2;
  } else if (mod == 1) {
    disp_bytes = 1;
  }

  // Read into a byte buffer (never more than it holds) and assemble the value
  // little-endian, independent of host byte order. Short reads leave zeros.
  uint8_t bytes[2] = {0, 0};
  size_t got = fread(bytes, 1, disp_bytes, fp);
  (void)got;

  if (direct_address) {
    unsigned address = (unsigned)bytes[0] | ((unsigned)bytes[1] << 8);
    snprintf(rm, buff_size, "%u", address);
    return;
  }

  decode_rm(operands, rm);

  int disp = 0;
  if (disp_bytes == 1) {
    disp = (int8_t)bytes[0];
  } else if (disp_bytes == 2) {
    disp = (int16_t)(uint16_t)(bytes[0] | (bytes[1] << 8));
  }

  if (disp == 0) {
    return;
  }

  size_t used = strlen(rm);
  if (used < buff_size) {
    snprintf(rm + used, buff_size - used, " %c %d", disp < 0 ? '-' : '+',
             disp < 0 ? -disp : disp);
  }
}

// Formats the accumulator <-> direct-memory MOV forms (opcodes 0xa0-0xa3)
// into `dest`, consuming the address bytes that follow from `fp`.
//
//   inst     - opcode byte; bit 0 (w) selects ax (1) or al (0).
//   reg_dest - true for memory -> accumulator, false for accumulator ->
//              memory.
//   dest     - output buffer; the longest result, "mov [65535], ax", needs
//              17 bytes.
//
// On the 8086 the address is always 16 bits; w only selects the register
// width. The previous version read a 1-byte address for the byte forms, which
// desynchronized the stream, and it always printed "ax".
void handle_accumulator_mov_instructions(FILE *fp, uint8_t inst, bool reg_dest,
                                         char *dest) {
  const bool word = (inst & 0x01) != 0;
  const char *acc = word ? "ax" : "al";

  // Little-endian address, assembled independently of host byte order.
  uint8_t bytes[2] = {0, 0};
  size_t got = fread(bytes, 1, sizeof(bytes), fp);
  (void)got; // a short read leaves the missing bytes zero
  const unsigned addr = (unsigned)bytes[0] | ((unsigned)bytes[1] << 8);

  char addr_out[16] = {0};
  snprintf(addr_out, sizeof(addr_out), "[%u]", addr);

  sprintf(dest, "mov %s, %s", reg_dest ? acc : addr_out,
          reg_dest ? addr_out : acc);
}

// Formats an "op al/ax, immediate" instruction (e.g. add 0x04/0x05,
// sub 0x2c/0x2d, cmp 0x3c/0x3d) into `dest`, consuming the 1-byte (w=0) or
// 2-byte (w=1) little-endian immediate from `fp`.
//
// These opcodes follow the pattern 00ooo10w, where ooo uses the same
// operation numbering as the reg field of the 0x80 immediate group. So the
// mnemonic is read from bits 5-3 rather than looked up again. Opcodes outside
// that pattern produce an empty mnemonic.
// The immediate is printed signed at its operand width, matching the other
// immediate forms in this disassembler (a word immediate was previously
// printed unsigned).
void handle_accumulator_arithmetic_instructions(FILE *fp, uint8_t inst,
                                                char *dest) {
  static const char *const mnemonics[8] = {"add", "or",  "adc", "sbb",
                                           "and", "sub", "xor", "cmp"};
  const char *op = (inst & 0xc6) == 0x04 ? mnemonics[(inst >> 3) & 0x07] : "";

  const bool word = (inst & 0x01) != 0;

  uint8_t bytes[2] = {0, 0};
  size_t got = fread(bytes, 1, word ? 2 : 1, fp);
  (void)got; // a short read leaves the missing bytes zero

  const int value = word ? (int)(int16_t)(uint16_t)(bytes[0] | (bytes[1] << 8))
                         : (int)(int8_t)bytes[0];

  sprintf(dest, "%s %s, %d", op, word ? "ax" : "al", value);
}