vmi_ops.allocate_page(pfn, VMI_PAGE_L1, 0, 0, 0);
 }
 
-static void vmi_allocate_pd(u32 pfn)
+static void vmi_allocate_pd(struct mm_struct *mm, u32 pfn)
 {
        /*
         * This call comes in very early, before mem_map is setup.
 
        if (!(pgd_val(*pgd) & _PAGE_PRESENT)) {
                pmd_table = (pmd_t *) alloc_bootmem_low_pages(PAGE_SIZE);
 
-               paravirt_alloc_pd(__pa(pmd_table) >> PAGE_SHIFT);
+               paravirt_alloc_pd(&init_mm, __pa(pmd_table) >> PAGE_SHIFT);
                set_pgd(pgd, __pgd(__pa(pmd_table) | _PAGE_PRESENT));
                pud = pud_offset(pgd, 0);
                if (pmd_table != pmd_offset(pud, 0))
        memset(&base[USER_PTRS_PER_PGD], 0,
               KERNEL_PGD_PTRS * sizeof(pgd_t));
 #else
-       paravirt_alloc_pd(__pa(swapper_pg_dir) >> PAGE_SHIFT);
+       paravirt_alloc_pd(&init_mm, __pa(base) >> PAGE_SHIFT);
 #endif
 }
 
 
        if (PTRS_PER_PMD == 1 || !pgd)
                return pgd;
 
+       mm->pgd = pgd;          /* so that alloc_pd can use it */
+
        for (i = 0; i < UNSHARED_PTRS_PER_PGD; ++i) {
                pmd_t *pmd = pmd_cache_alloc(i);
 
                if (!pmd)
                        goto out_oom;
 
-               paravirt_alloc_pd(__pa(pmd) >> PAGE_SHIFT);
+               paravirt_alloc_pd(mm, __pa(pmd) >> PAGE_SHIFT);
                set_pgd(&pgd[i], __pgd(1 + __pa(pmd)));
        }
        return pgd;
 
 
        /* Hooks for allocating/releasing pagetable pages */
        void (*alloc_pt)(struct mm_struct *mm, u32 pfn);
-       void (*alloc_pd)(u32 pfn);
+       void (*alloc_pd)(struct mm_struct *mm, u32 pfn);
        void (*alloc_pd_clone)(u32 pfn, u32 clonepfn, u32 start, u32 count);
        void (*release_pt)(u32 pfn);
        void (*release_pd)(u32 pfn);
        PVOP_VCALL1(pv_mmu_ops.release_pt, pfn);
 }
 
-static inline void paravirt_alloc_pd(unsigned pfn)
+static inline void paravirt_alloc_pd(struct mm_struct *mm, unsigned pfn)
 {
-       PVOP_VCALL1(pv_mmu_ops.alloc_pd, pfn);
+       PVOP_VCALL2(pv_mmu_ops.alloc_pd, mm, pfn);
 }
 
 static inline void paravirt_alloc_pd_clone(unsigned pfn, unsigned clonepfn,
 
 #include <asm/paravirt.h>
 #else
 #define paravirt_alloc_pt(mm, pfn) do { } while (0)
-#define paravirt_alloc_pd(pfn) do { } while (0)
-#define paravirt_alloc_pd(pfn) do { } while (0)
+#define paravirt_alloc_pd(mm, pfn) do { } while (0)
 #define paravirt_alloc_pd_clone(pfn, clonepfn, start, count) do { } while (0)
 #define paravirt_release_pt(pfn) do { } while (0)
 #define paravirt_release_pd(pfn) do { } while (0)