summary refs log tree commit diff
path: root/arch/sparc/mm/gup.c
diff options
context:
space:
mode:
Diffstat (limited to 'arch/sparc/mm/gup.c')
-rw-r--r--arch/sparc/mm/gup.c59
1 files changed, 57 insertions, 2 deletions
diff --git a/arch/sparc/mm/gup.c b/arch/sparc/mm/gup.c
index 42c55df3aec3..01ee23dd724d 100644
--- a/arch/sparc/mm/gup.c
+++ b/arch/sparc/mm/gup.c
@@ -66,6 +66,56 @@ static noinline int gup_pte_range(pmd_t pmd, unsigned long addr,
 	return 1;
 }
 
+static int gup_huge_pmd(pmd_t *pmdp, pmd_t pmd, unsigned long addr,
+			unsigned long end, int write, struct page **pages,
+			int *nr)
+{
+	struct page *head, *page, *tail;
+	u32 mask;
+	int refs;
+
+	mask = PMD_HUGE_PRESENT;
+	if (write)
+		mask |= PMD_HUGE_WRITE;
+	if ((pmd_val(pmd) & mask) != mask)
+		return 0;
+
+	refs = 0;
+	head = pmd_page(pmd);
+	page = head + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
+	tail = page;
+	do {
+		VM_BUG_ON(compound_head(page) != head);
+		pages[*nr] = page;
+		(*nr)++;
+		page++;
+		refs++;
+	} while (addr += PAGE_SIZE, addr != end);
+
+	if (!page_cache_add_speculative(head, refs)) {
+		*nr -= refs;
+		return 0;
+	}
+
+	if (unlikely(pmd_val(pmd) != pmd_val(*pmdp))) {
+		*nr -= refs;
+		while (refs--)
+			put_page(head);
+		return 0;
+	}
+
+	/* Any tail page need their mapcount reference taken before we
+	 * return.
+	 */
+	while (refs--) {
+		if (PageTail(tail))
+			get_huge_page_tail(tail);
+		tail++;
+	}
+
+	return 1;
+}
+
 static int gup_pmd_range(pud_t pud, unsigned long addr, unsigned long end,
 		int write, struct page **pages, int *nr)
 {
@@ -77,9 +127,14 @@ static int gup_pmd_range(pud_t pud, unsigned long addr, unsigned long end,
 		pmd_t pmd = *pmdp;
 
 		next = pmd_addr_end(addr, end);
-		if (pmd_none(pmd))
+		if (pmd_none(pmd) || pmd_trans_splitting(pmd))
 			return 0;
-		if (!gup_pte_range(pmd, addr, next, write, pages, nr))
+		if (unlikely(pmd_large(pmd))) {
+			if (!gup_huge_pmd(pmdp, pmd, addr, next,
+					  write, pages, nr))
+				return 0;
+		} else if (!gup_pte_range(pmd, addr, next, write,
+					  pages, nr))
 			return 0;
 	} while (pmdp++, addr = next, addr != end);