From 09ba9112386d5d59d7f2a31c469768c582acb939 Mon Sep 17 00:00:00 2001
From: Mole Shang <135e2@135e2.dev>
Date: Mon, 19 Feb 2024 21:51:26 +0800
Subject: lab mmap: finish

---
 kernel/defs.h      |  10 +++-
 kernel/file.c      |  31 +++++++++++
 kernel/memlayout.h |   2 +
 kernel/proc.c      | 156 ++++++++++++++++++++++++++++++++++++++++++++++++++++-
 kernel/proc.h      |  16 ++++++
 kernel/syscall.c   |   6 +++
 kernel/sysfile.c   |   2 +-
 kernel/sysproc.c   |  36 ++++++++++++-
 kernel/vm.c        |   8 +--
 9 files changed, 260 insertions(+), 7 deletions(-)

(limited to 'kernel')

diff --git a/kernel/defs.h b/kernel/defs.h
index 541c97e..6257437 100644
--- a/kernel/defs.h
+++ b/kernel/defs.h
@@ -45,6 +45,9 @@ void            fileinit(void);
 int             fileread(struct file*, uint64, int n);
 int             filestat(struct file*, uint64 addr);
 int             filewrite(struct file*, uint64, int n);
+int             fileperm(struct file*);
+int             mmap_read(struct file*, uint64, uint64, int);
+int             munmap_write(struct file*, uint64, uint64, int);
 
 // fs.c
 void            fsinit(int);
@@ -124,6 +127,8 @@ int             either_copyin(void *dst, int user_src, uint64 src, uint64 len);
 void            procdump(void);
 int             get_nproc(void);
 int             pgaccess(uint64 base, int len, uint64 mask);
+uint64 mmap(uint64 addr, uint64 len, int prot, int flags, int fd, struct file* file, uint64 offset);
+int munmap(uint64 addr, uint64 len);
 
 // swtch.S
 void            swtch(struct context*, struct context*);
@@ -163,6 +168,9 @@ int             fetchstr(uint64, char*, int);
 int             fetchaddr(uint64, uint64*);
 void            syscall();
 
+// sysfile.c
+int            argfd(int n, int *pfd, struct file **pf);
+
 // sysinfo.c
 int             sys_info(uint64);
 
@@ -189,7 +197,7 @@ pagetable_t     uvmcreate(void);
 void            uvmfirst(pagetable_t, uchar *, uint);
 uint64          uvmalloc(pagetable_t, uint64, uint64, int);
 uint64          uvmdealloc(pagetable_t, uint64, uint64);
-int             uvmcopy(pagetable_t, pagetable_t, uint64);
+int             uvmcopy(pagetable_t, pagetable_t, uint64, uint64);
 void            uvmfree(pagetable_t, uint64);
 void            uvmunmap(pagetable_t, uint64, uint64, int);
 void            uvmclear(pagetable_t, uint64);
diff --git a/kernel/file.c b/kernel/file.c
index 0fba21b..add6575 100644
--- a/kernel/file.c
+++ b/kernel/file.c
@@ -6,6 +6,7 @@
 #include "riscv.h"
 #include "defs.h"
 #include "param.h"
+#include "fcntl.h"
 #include "fs.h"
 #include "spinlock.h"
 #include "sleeplock.h"
@@ -197,3 +198,33 @@ filewrite(struct file *f, uint64 addr, int n)
   return ret;
 }
 
+int fileperm(struct file* f) // PROT_* bits that the open mode of f permits
+{
+  int perm = 0;
+  if(f->readable)
+    perm |= PROT_READ;
+  if(f->writable)
+    perm |= PROT_WRITE;
+  return perm;
+}
+
+// read up to size bytes of file at off into the buffer at pa (needed by mmap); returns the offset after the read
+int mmap_read(struct file *file, uint64 pa, uint64 off, int size) {
+  int nread;
+  ilock(file->ip);
+  nread = readi(file->ip, 0, pa, off, size); // 0: presumably tells readi that pa is not a user address
+  iunlock(file->ip);
+  off += nread;
+  return off; // NOTE(review): off is uint64 but is narrowed to int on return
+}
+
+// write size bytes from virtual address va to file at off (needed by munmap); returns whatever writei returns
+int munmap_write(struct file *file, uint64 va, uint64 off, int size) {
+  struct inode *ip = file->ip;
+  begin_op(); // NOTE(review): the whole range goes through a single begin_op/end_op pair
+  ilock(ip);
+  int written = writei(ip, 1, va, off, size); // 1: presumably marks va as a user address
+  iunlock(ip);
+  end_op();
+  return written;
+}
diff --git a/kernel/memlayout.h b/kernel/memlayout.h
index 74d2fd4..517a69e 100644
--- a/kernel/memlayout.h
+++ b/kernel/memlayout.h
@@ -77,3 +77,5 @@ struct usyscall {
   int pid;  // Process ID
 };
 #endif
+
+#define MAX_VM_ADDR (USYSCALL - PGSIZE)
diff --git a/kernel/proc.c b/kernel/proc.c
index 9a9bae9..923729d 100644
--- a/kernel/proc.c
+++ b/kernel/proc.c
@@ -1,4 +1,5 @@
 #include "types.h"
+#include "fcntl.h"
 #include "param.h"
 #include "memlayout.h"
 #include "riscv.h"
@@ -160,6 +161,14 @@ found:
     return 0;
   }
 
+  // setup mmap vm area
+  p->cur_max_vm_addr = MAX_VM_ADDR;
+  for(int i = 0; i < MAX_VM_AREA; i++){
+    struct vm_area* vma = &p->vma[i];
+    memset(&p->vma[i], 0, sizeof(struct vm_area));
+    vma->start_addr = -1;
+  }
+
   // Set up new context to start executing at forkret,
   // which returns to user space.
   memset(&p->context, 0, sizeof(p->context));
@@ -328,7 +337,7 @@ fork(void)
   }
 
   // Copy user memory from parent to child.
-  if(uvmcopy(p->pagetable, np->pagetable, p->sz) < 0){
+  if(uvmcopy(p->pagetable, np->pagetable, 0, p->sz) < 0){
     freeproc(np);
     release(&np->lock);
     return -1;
@@ -350,6 +359,17 @@ fork(void)
       np->ofile[i] = filedup(p->ofile[i]);
   np->cwd = idup(p->cwd);
 
+  // copy vm areas
+  // TODO: cow & remap vm to pa
+  for(int i = 0; i < MAX_VM_AREA; i++){
+    if(p->vma[i].start_addr != -1){
+      memmove(&np->vma[i], &p->vma[i], sizeof(struct vm_area));
+      uvmcopy(p->pagetable, np->pagetable, p->vma[i].start_addr, p->vma[i].len);
+      if(p->vma[i].file)
+        filedup(p->vma[i].file);
+    }
+  }
+
   safestrcpy(np->name, p->name, sizeof(p->name));
 
   pid = np->pid;
@@ -393,6 +413,12 @@ exit(int status)
   if(p == initproc)
     panic("init exiting");
 
+  for(int i = 0; i < MAX_VM_AREA; i++){
+    if(p->vma[i].start_addr != -1){
+      uvmunmap(p->pagetable, p->vma[i].start_addr, p->vma[i].len/PGSIZE, 1);
+    }
+  }
+
   // Close all open files.
   for(int fd = 0; fd < NOFILE; fd++){
     if(p->ofile[fd]){
@@ -768,3 +794,131 @@ pgaccess(uint64 base, int len, uint64 mask_addr)
   // now copyout the mask to user memory
   return copyout(pgtbl, mask_addr, (char *)&mask, sizeof(mask));
 }
+
+// lab mmap
+// map len bytes of file, starting at page-aligned offset, into the caller; returns the start va, or -1 on failure.
+uint64 mmap(uint64 addr, uint64 len, int prot, int flags, int fd, struct file* file, uint64 offset)
+{
+  struct proc *p = myproc();
+  struct vm_area *vma = 0;
+  uint64 start, a, roff = offset; // roff: next file offset to read from
+
+  len = PGROUNDUP(len);
+  addr = PGROUNDDOWN(addr);
+
+  if(p->sz+len > MAXVA)
+    return -1;
+
+  if(offset % PGSIZE) // offset is unsigned, so only its alignment needs checking
+    return -1;
+
+  for(int i = 0; i < MAX_VM_AREA; i++){
+    if(p->vma[i].start_addr == -1){
+      vma = &p->vma[i];
+      break;
+    }
+  }
+  if(!vma)
+    return -1;
+
+  // honour an explicit hint above p->sz, else place the area just below the previous one
+  if(addr >= p->sz && addr <= MAXVA)
+    start = addr;
+  else if(addr == 0 && p->sz+len <= p->cur_max_vm_addr) // must fit between p->sz and the previous area
+    start = p->cur_max_vm_addr - len;
+  else
+    return -1;
+
+  int pte_flags = PTE_U|PTE_V;
+
+  if(prot & PROT_READ)
+    pte_flags |= PTE_R;
+  if(prot & PROT_WRITE)
+    pte_flags |= PTE_W;
+  if(prot & PROT_EXEC)
+    pte_flags |= PTE_X;
+
+  // populate every page now; the vma slot stays free until all of them are mapped
+  for(a = start; a < start+len; a += PGSIZE){
+    char *mem = kalloc();
+    if(mem == 0)
+      goto mmap_bad;
+    memset(mem, 0, PGSIZE); // bytes past the end of the file read as zero
+    roff = mmap_read(file, (uint64)mem, roff, PGSIZE);
+    if(mappages(p->pagetable, a, PGSIZE, (uint64)mem, pte_flags) != 0){
+      kfree(mem); // page a was never mapped, so there is nothing to unmap for it
+      goto mmap_bad;
+    }
+  }
+
+  // commit: take the file reference and record the area only once nothing can fail
+  filedup(file);
+  vma->start_addr = start;
+  vma->len = len;
+  vma->fd = fd;
+  vma->flags = flags;
+  vma->prot = prot;
+  vma->file = file;
+  vma->roff = roff; // offset just past the last byte read; munmap compares against it
+  p->cur_max_vm_addr = start;
+  return start;
+
+mmap_bad:
+  if(a > start) // release the pages that were mapped before the failure
+    uvmunmap(p->pagetable, start, (a - start)/PGSIZE, 1);
+  return 0xffffffffffffffff;
+}
+
+// unmap [addr, addr+len) from the mmap area that contains it; returns 0 on success, -1 on error.
+// addr is rounded down and len is rounded up to whole pages.
+int munmap(uint64 addr, uint64 len)
+{
+  struct proc *p = myproc();
+  struct vm_area *vma = 0;
+  int r = 0;
+
+  // bounding both operands first keeps addr+len from wrapping around
+  if(addr > MAXVA || len > MAXVA || addr+len > MAXVA)
+    return -1;
+
+  addr = PGROUNDDOWN(addr);
+  len = PGROUNDUP(len);
+
+  for(int i = 0; i < MAX_VM_AREA; i++){
+    if((p->vma[i].start_addr != -1) && addr >= p->vma[i].start_addr && addr+len <= p->vma[i].start_addr+p->vma[i].len){
+      vma = &p->vma[i];
+      break;
+    }
+  }
+
+  if(!vma)
+    return -1;
+
+  // only the head or the tail may be trimmed; a hole in the middle would need a second vm_area
+  if(addr != vma->start_addr && addr+len != vma->start_addr+vma->len)
+    return -1;
+
+  if((vma->flags & MAP_SHARED) && (vma->prot & PROT_WRITE) && len > 0){
+    // write back; woff is the file offset of start_addr (advanced below whenever the head is trimmed)
+    // NOTE(review): the offset passed to mmap is not stored in the vm_area, so it is not added here
+    r = munmap_write(vma->file, addr, vma->woff + (addr - vma->start_addr), len);
+    if(r < 0 || (uint64)r != len) // munmap_write returns writei's result, presumably the byte count written
+      return -1;
+  }
+
+  uvmunmap(p->pagetable, addr, len/PGSIZE, 1);
+
+  if(addr == vma->start_addr){
+    vma->start_addr += len;
+    vma->woff += len;
+  }
+  vma->len -= len;
+
+  if(vma->len == 0){
+    // nothing left mapped: drop the reference mmap took with filedup and mark the slot free
+    fileclose(vma->file);
+    memset(vma, 0, sizeof(struct vm_area));
+    vma->start_addr = -1;
+  }
+  return 0;
+}
diff --git a/kernel/proc.h b/kernel/proc.h
index a195b02..ebdbb7a 100644
--- a/kernel/proc.h
+++ b/kernel/proc.h
@@ -81,6 +81,20 @@ struct trapframe {
 
 enum procstate { UNUSED, USED, SLEEPING, RUNNABLE, RUNNING, ZOMBIE };
 
+// mmap vm area
+#define MAX_VM_AREA 0x40
+
+struct vm_area {
+  uint64 start_addr; // first virtual address of the area; -1 marks the slot as free
+  uint64 roff; // file read offset: starts at mmap's offset and is advanced as mmap reads pages in
+  uint64 woff; // file offset bookkeeping for write-back, maintained by munmap
+  uint64 len; // bytes currently mapped; mmap rounds it up to whole pages, munmap shrinks it
+  int prot; // PROT_* bits passed to mmap
+  int flags; // MAP_* flags passed to mmap
+  struct file* file; // backing file; mmap calls filedup on it
+  int fd; // descriptor number that was passed to mmap
+};
+
 // Per-process state
 struct proc {
   struct spinlock lock;
@@ -111,4 +125,6 @@ struct proc {
   int alarm_tickspassed;       // record how many ticks passed since last sigalarm handler call
   int alarm_caninvoke;         // prevent re-entrant calls to handler
   struct trapframe *atpfm;     // trapframe to resume after handling, must hold p->lock
+  struct vm_area vma[MAX_VM_AREA]; // vm_area
+  uint64 cur_max_vm_addr;      // current max vm addr, used by mmap
 };
diff --git a/kernel/syscall.c b/kernel/syscall.c
index c39ebd8..3c8d3d8 100644
--- a/kernel/syscall.c
+++ b/kernel/syscall.c
@@ -120,6 +120,8 @@ extern uint64 sys_connect(void);
 extern uint64 sys_pgaccess(void);
 #endif
 extern uint64 sys_symlink(void);
+extern uint64 sys_mmap(void);
+extern uint64 sys_munmap(void);
 
 // An array mapping syscall numbers from syscall.h
 // to the function that handles the system call.
@@ -156,6 +158,8 @@ static uint64 (*syscalls[])(void) = {
 [SYS_sigalarm] sys_sigalarm,
 [SYS_sigreturn] sys_sigreturn,
 [SYS_symlink] sys_symlink,
+[SYS_mmap]    sys_mmap,
+[SYS_munmap]  sys_munmap,
 };
 
 // syscall name maps for SYS_trace:
@@ -192,6 +196,8 @@ static char *syscall_names[] = {
 [SYS_sigalarm]  "sigalarm",
 [SYS_sigreturn] "sigreturn",
 [SYS_symlink]  "symlink",
+[SYS_mmap]    "mmap",
+[SYS_munmap]  "munmap",
 };
 
 
diff --git a/kernel/sysfile.c b/kernel/sysfile.c
index 9c12d44..639c02b 100644
--- a/kernel/sysfile.c
+++ b/kernel/sysfile.c
@@ -27,7 +27,7 @@ struct symlink {
 
 // Fetch the nth word-sized system call argument as a file descriptor
 // and return both the descriptor and the corresponding struct file.
-static int
+int
 argfd(int n, int *pfd, struct file **pf)
 {
   int fd;
diff --git a/kernel/sysproc.c b/kernel/sysproc.c
index 715a511..abe1cdd 100644
--- a/kernel/sysproc.c
+++ b/kernel/sysproc.c
@@ -1,4 +1,5 @@
 #include "types.h"
+#include "fcntl.h"
 #include "riscv.h"
 #include "param.h"
 #include "defs.h"
@@ -140,7 +141,8 @@ sys_sigalarm(void)
   return 0;
 }
 
-uint64 sys_sigreturn(void)
+uint64
+sys_sigreturn(void)
 {
   struct proc *p = myproc();
   // retore saved trapframe to resume
@@ -150,3 +152,35 @@ uint64 sys_sigreturn(void)
   // make sure return the original a0 in trapframe to pass test3
   return p->trapframe->a0;
 }
+
+uint64
+sys_mmap(void)
+{
+  uint64 addr, len, offset;
+  int prot, flags, fd;
+  struct file* file;
+  // args: addr, len, prot, flags, fd, offset; every failure is reported as (uint64)-1
+  argaddr(0, &addr);
+  argaddr(1, &len);
+  argint(2, &prot);
+  argint(3, &flags);
+  if(argfd(4, &fd, &file) == -1)
+    return 0xffffffffffffffff;
+  argaddr(5, &offset);
+  // flags is a bit mask (munmap tests it with &), so test the MAP_SHARED bit rather than equality
+  if(!(fileperm(file) & PROT_WRITE) && (prot & PROT_WRITE) && (flags & MAP_SHARED))
+    return 0xffffffffffffffff;
+
+  return mmap(addr, len, prot, flags, fd, file, offset);
+}
+
+uint64
+sys_munmap(void)
+{
+  uint64 va, nbytes;
+  // args: start address and byte count of the range to unmap
+  argaddr(0, &va);
+  argaddr(1, &nbytes);
+
+  return munmap(va, nbytes);
+}
diff --git a/kernel/vm.c b/kernel/vm.c
index be7d042..08859ae 100644
--- a/kernel/vm.c
+++ b/kernel/vm.c
@@ -182,6 +182,7 @@ uvmunmap(pagetable_t pagetable, uint64 va, uint64 npages, int do_free)
 {
   uint64 a;
   pte_t *pte;
+  int cow_pg = 0;
 
   if((va % PGSIZE) != 0)
     panic("uvmunmap: not aligned");
@@ -192,10 +193,11 @@ uvmunmap(pagetable_t pagetable, uint64 va, uint64 npages, int do_free)
     if((*pte & PTE_V) == 0) {
       printf("va=%p pte=%p\n", a, *pte);
       panic("uvmunmap: not mapped");
+      cow_pg = 1;
     }
     if(PTE_FLAGS(*pte) == PTE_V)
       panic("uvmunmap: not a leaf");
-    if(do_free){
+    if(do_free && !cow_pg){
       uint64 pa = PTE2PA(*pte);
       kfree((void*)pa);
     }
@@ -315,14 +317,14 @@ uvmfree(pagetable_t pagetable, uint64 sz)
 // returns 0 on success, -1 on failure.
 // frees any allocated pages on failure.
 int
-uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
+uvmcopy(pagetable_t old, pagetable_t new, uint64 start, uint64 sz)
 {
   pte_t *pte;
   uint64 pa, i;
   uint flags;
   // char *mem;
 
-  for(i = 0; i < sz; i += PGSIZE){
+  for(i = start; i < start+sz; i += PGSIZE){
     if((pte = walk(old, i, 0)) == 0)
       panic("uvmcopy: pte should exist");
     if((*pte & PTE_V) == 0)
-- 
cgit v1.2.3