diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index ed277276fa98..9920bae6ee43 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -401,15 +401,14 @@ void vhost_dev_cleanup(struct vhost_dev *dev)
 	kfree(rcu_dereference_protected(dev->memory,
 					lockdep_is_held(&dev->mutex)));
 	RCU_INIT_POINTER(dev->memory, NULL);
-	if (dev->mm)
-		mmput(dev->mm);
-	dev->mm = NULL;
-
 	WARN_ON(!list_empty(&dev->work_list));
 	if (dev->worker) {
 		kthread_stop(dev->worker);
 		dev->worker = NULL;
 	}
+	if (dev->mm)
+		mmput(dev->mm);
+	dev->mm = NULL;
 }
 
 static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)