#include "types.h"
#include "param.h"
#include "memlayout.h"
#include "spinlock.h"
#include "riscv.h"
#include "proc.h"
#include "defs.h"

//
// Race detector using gcc's thread sanitizer. It delays all stores
// and loads and monitors if any other CPU is using the same address.
// If so, we have a race and print out the backtrace of the thread
// that raced and the thread that set the watchpoint.
//

//
// To run with kcsan:
// make clean
// make KCSAN=1 qemu
//

// The number of watch points.
#define NWATCH (NCPU)

// The number of cycles to delay stores, whatever that means on qemu.
//#define DELAY_CYCLES 20000
#define DELAY_CYCLES 200000

#define MAXTRACE 20

// Capture a backtrace of the current kernel stack by walking the
// saved frame-pointer chain.  Writes up to maxtrace return
// addresses into trace[] and returns the number written.
// Interrupts are disabled for the duration of the walk so the
// frames can't change underneath us.
//
// The fp-8 / fp-16 offsets assume the frame layout xv6's RISC-V
// compiler emits: return address at fp-8, saved fp at fp-16.
int
trace(uint64 *trace, int maxtrace)
{
  uint64 i = 0;

  // Guard against maxtrace <= 0: i is unsigned, so the i >= maxtrace
  // check below would convert a non-positive maxtrace to a huge
  // uint64 and the loop could overrun the caller's buffer.
  if(maxtrace <= 0)
    return 0;

  push_off();
  
  uint64 fp = r_fp();
  // Only follow frames that stay within the current stack page and
  // are 8-byte aligned; anything else means the chain is corrupt or
  // we've reached the bottom.
  uint64 ra, low = PGROUNDDOWN(fp) + 16, high = PGROUNDUP(fp);

  while(!(fp & 7) && fp >= low && fp < high){
    ra = *(uint64*)(fp - 8);
    fp = *(uint64*)(fp - 16);
    trace[i++] = ra;
    if(i >= maxtrace)
      break;
  }

  pop_off();
  
  return i;
}

// One watchpoint: an address some CPU is currently delaying an
// access to, plus the backtrace captured when it was installed.
struct watch {
  uint64 addr;   // watched address; 0 marks the slot free
  int write;     // nonzero if the watched access is a store
  int race;      // NOTE(review): never read or written in this file — appears unused
  uint64 trace[MAXTRACE];  // backtrace of the thread that set the watchpoint
  int tracesz;   // number of valid entries in trace[]
};
  
// Global detector state: a fixed table of watchpoints (one slot per
// CPU, see NWATCH) guarded by a single lock, plus an enable flag.
struct {
  struct spinlock lock;         // protects points[]
  struct watch points[NWATCH];  // active watchpoints; addr==0 means free
  int on;                       // set by kcsaninit(); gate for all __tsan_* hooks
} tsan;

// Return the watchpoint installed on addr, or 0 if there is none.
// Caller must hold tsan.lock.
static struct watch*
wp_lookup(uint64 addr)
{
  int i;

  for(i = 0; i < NWATCH; i++){
    if(tsan.points[i].addr == addr)
      return &tsan.points[i];
  }
  return 0;
}

// Claim a free slot and install a watchpoint on addr, recording a
// backtrace of the installer.  Returns 1 on success; panics if the
// table is full.  Caller must hold tsan.lock.
static int
wp_install(uint64 addr, int write)
{
  int i;

  for(i = 0; i < NWATCH; i++){
    struct watch *slot = &tsan.points[i];
    if(slot->addr == 0){
      slot->addr = addr;
      slot->write = write;
      slot->tracesz = trace(slot->trace, MAXTRACE);
      return 1;
    }
  }
  panic("wp_install");
  return 0;
}

// Remove the watchpoint on addr, freeing its slot.  Panics if no
// such watchpoint exists, since that indicates detector bookkeeping
// went wrong.  Caller must hold tsan.lock.
static void
wp_remove(uint64 addr)
{
  for(struct watch *w = &tsan.points[0]; w < &tsan.points[NWATCH]; w++) {
    if(w->addr == addr) {
      w->addr = 0;
      w->tracesz = 0;
      return;
    }
  }
  // Name the function in the panic message, matching wp_install()'s
  // convention (was the ambiguous "remove").
  panic("wp_remove");
}

// Print each of the n saved program counters, one per line.
static void
printtrace(uint64 *t, int n)
{
  for(int j = 0; j < n; j++)
    printf("%p\n", t[j]);
}

// Report a detected race: print the backtrace of the access that
// just collided with a watchpoint (s is "load" or "store"),
// followed by the backtrace saved when the watchpoint was set.
static void
race(char *s, struct watch *w) {
  uint64 pcs[MAXTRACE];
  int n = trace(pcs, MAXTRACE);

  printf("== race detected ==\n");
  printf("backtrace for racing %s\n", s);
  printtrace(pcs, n);
  printf("backtrace for watchpoint:\n");
  printtrace(w->trace, w->tracesz);
  printf("==========\n");
}

// cycle counter
// Read the RISC-V cycle CSR via the rdcycle pseudo-instruction;
// used only for the busy-wait in delay().
static inline uint64
r_cycle()
{
  uint64 x;
  asm volatile("rdcycle %0" : "=r" (x) );
  return x;
}

// Busy-wait for roughly DELAY_CYCLES cycles, giving other CPUs a
// window to hit the watchpoint.  noinline keeps the delay from
// being folded away into the caller.
static void delay(void) __attribute__((noinline));
static void delay() {
  uint64 deadline = r_cycle() + DELAY_CYCLES;

  while(r_cycle() < deadline)
    ;
}

// Instrumented load of sz bytes at addr.  A load only races with a
// pending store, so report only if an existing watchpoint was set
// by a writer.  Loads never install watchpoints themselves.
static void
kcsan_read(uint64 addr, int sz)
{
  struct watch *w;

  acquire(&tsan.lock);
  w = wp_lookup(addr);
  if(w != 0 && w->write)
    race("load", w);
  release(&tsan.lock);
}

// Instrumented store of sz bytes at addr.  A store races with any
// pending access, read or write.  If no watchpoint exists, install
// one, delay with the lock released so other CPUs can trip it, then
// remove it.
static void
kcsan_write(uint64 addr, int sz)
{
  struct watch *w;
  
  acquire(&tsan.lock);
  if((w = wp_lookup(addr)) != 0) {
    race("store", w);
    release(&tsan.lock);
    // BUG FIX: this return was missing, so after reporting a race we
    // fell through to wp_install() with tsan.lock already released
    // (unlocked mutation of tsan.points) and then released the lock
    // a second time at the bottom of the function.
    return;
  }

  // no watchpoint; try to install one
  if(wp_install(addr, 1)) {

    release(&tsan.lock);

    // XXX maybe read value at addr before and after delay to catch
    // races of unknown origins (e.g., device).

    delay(); 

    acquire(&tsan.lock);

    wp_remove(addr);
  }
  release(&tsan.lock);
}

// tsan.on will only have effect with "make KCSAN=1"
// Turn the detector on: initialize the watchpoint-table lock and
// set the enable flag.  The fence publishes tsan.on before any
// instrumented accesses on other CPUs consult it.
void
kcsaninit(void)
{
  initlock(&tsan.lock, "tsan");
  tsan.on = 1;
  __sync_synchronize();
}

//
// Calls inserted by compiler into kernel binary, except for this file.
//

void
__tsan_init(void)
{
  // Nothing to do here; real setup happens in kcsaninit(), which
  // runs once locking is usable.
}

void
__tsan_read1(uint64 addr)
{
  if(!tsan.on)
    return;
  // 1-byte loads are deliberately left unchecked (call commented out).
  // kcsan_read(addr, 1);
}

// Compiler hook for 2-byte loads.
void
__tsan_read2(uint64 addr)
{
  if(tsan.on)
    kcsan_read(addr, 2);
}

// Compiler hook for 4-byte loads.
void
__tsan_read4(uint64 addr)
{
  if(tsan.on)
    kcsan_read(addr, 4);
}

// Compiler hook for 8-byte loads.
void
__tsan_read8(uint64 addr)
{
  if(tsan.on)
    kcsan_read(addr, 8);
}

// Compiler hook for multi-byte load ranges; only the start address
// is watched.
void
__tsan_read_range(uint64 addr, uint64 size)
{
  if(tsan.on)
    kcsan_read(addr, size);
}

void
__tsan_write1(uint64 addr)
{
  if(!tsan.on)
    return;
  // 1-byte stores are deliberately left unchecked (call commented out),
  // mirroring __tsan_read1.
  // kcsan_write(addr, 1);
}

// Compiler hook for 2-byte stores.
void
__tsan_write2(uint64 addr)
{
  if(tsan.on)
    kcsan_write(addr, 2);
}

// Compiler hook for 4-byte stores.
void
__tsan_write4(uint64 addr)
{
  if(tsan.on)
    kcsan_write(addr, 4);
}

// Compiler hook for 8-byte stores.
void
__tsan_write8(uint64 addr)
{
  if(tsan.on)
    kcsan_write(addr, 8);
}

// Compiler hook for multi-byte store ranges; only the start address
// is watched.
void
__tsan_write_range(uint64 addr, uint64 size)
{
  if(tsan.on)
    kcsan_write(addr, size);
}

void
__tsan_atomic_thread_fence(int order)
{
  // Ignore the requested memory order and issue a full barrier,
  // which is at least as strong as anything the caller asked for.
  __sync_synchronize();
}

// Sequentially-consistent 32-bit atomic load.  The val and order
// parameters are part of the instrumentation ABI but are ignored:
// the load is always performed at SEQ_CST strength.
uint32
__tsan_atomic32_load(uint *ptr, uint *val, int order)
{
  return __atomic_load_n(ptr, __ATOMIC_SEQ_CST);
}

// Sequentially-consistent 32-bit atomic store.  The order parameter
// is ignored; the store is always performed at SEQ_CST strength.
void
__tsan_atomic32_store(uint *ptr, uint val, int order)
{
  __atomic_store_n(ptr, val, __ATOMIC_SEQ_CST);
}

// We don't use this
// Emitted by the compiler at function entry; intentionally empty.
void
__tsan_func_entry(uint64 pc)
{
}

// We don't use this
// Emitted by the compiler at function exit; intentionally empty.
void
__tsan_func_exit(void)
{
}