diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index cd687395e4e7..27c6ece91da6 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -3267,54 +3267,6 @@ static u32 get_tss_base_addr(struct kvm_vcpu *vcpu,
 	return vcpu->arch.mmu.gva_to_gpa(vcpu, base_addr);
 }
 
-static int load_tss_segment32(struct kvm_vcpu *vcpu,
-			      struct desc_struct *seg_desc,
-			      struct tss_segment_32 *tss)
-{
-	u32 base_addr;
-
-	base_addr = get_tss_base_addr(vcpu, seg_desc);
-
-	return kvm_read_guest(vcpu->kvm, base_addr, tss,
-			      sizeof(struct tss_segment_32));
-}
-
-static int save_tss_segment32(struct kvm_vcpu *vcpu,
-			      struct desc_struct *seg_desc,
-			      struct tss_segment_32 *tss)
-{
-	u32 base_addr;
-
-	base_addr = get_tss_base_addr(vcpu, seg_desc);
-
-	return kvm_write_guest(vcpu->kvm, base_addr, tss,
-			       sizeof(struct tss_segment_32));
-}
-
-static int load_tss_segment16(struct kvm_vcpu *vcpu,
-			      struct desc_struct *seg_desc,
-			      struct tss_segment_16 *tss)
-{
-	u32 base_addr;
-
-	base_addr = get_tss_base_addr(vcpu, seg_desc);
-
-	return kvm_read_guest(vcpu->kvm, base_addr, tss,
-			      sizeof(struct tss_segment_16));
-}
-
-static int save_tss_segment16(struct kvm_vcpu *vcpu,
-			      struct desc_struct *seg_desc,
-			      struct tss_segment_16 *tss)
-{
-	u32 base_addr;
-
-	base_addr = get_tss_base_addr(vcpu, seg_desc);
-
-	return kvm_write_guest(vcpu->kvm, base_addr, tss,
-			       sizeof(struct tss_segment_16));
-}
-
 static u16 get_segment_selector(struct kvm_vcpu *vcpu, int seg)
 {
 	struct kvm_segment kvm_seg;
@@ -3472,20 +3424,26 @@ static int load_state_from_tss16(struct kvm_vcpu *vcpu,
 }
 
 static int kvm_task_switch_16(struct kvm_vcpu *vcpu, u16 tss_selector,
-		       struct desc_struct *cseg_desc,
+		       u32 old_tss_base,
 		       struct desc_struct *nseg_desc)
 {
 	struct tss_segment_16 tss_segment_16;
 	int ret = 0;
 
-	if (load_tss_segment16(vcpu, cseg_desc, &tss_segment_16))
+	if (kvm_read_guest(vcpu->kvm, old_tss_base, &tss_segment_16,
+			   sizeof tss_segment_16))
 		goto out;
 
 	save_state_to_tss16(vcpu, &tss_segment_16);
-	save_tss_segment16(vcpu, cseg_desc, &tss_segment_16);
 
-	if (load_tss_segment16(vcpu, nseg_desc, &tss_segment_16))
+	if (kvm_write_guest(vcpu->kvm, old_tss_base, &tss_segment_16,
+			    sizeof tss_segment_16))
 		goto out;
+
+	if (kvm_read_guest(vcpu->kvm, get_tss_base_addr(vcpu, nseg_desc),
+			   &tss_segment_16, sizeof tss_segment_16))
+		goto out;
+
 	if (load_state_from_tss16(vcpu, &tss_segment_16))
 		goto out;
 
@@ -3495,20 +3453,26 @@ out:
 }
 
 static int kvm_task_switch_32(struct kvm_vcpu *vcpu, u16 tss_selector,
-		       struct desc_struct *cseg_desc,
+		       u32 old_tss_base,
 		       struct desc_struct *nseg_desc)
 {
 	struct tss_segment_32 tss_segment_32;
 	int ret = 0;
 
-	if (load_tss_segment32(vcpu, cseg_desc, &tss_segment_32))
+	if (kvm_read_guest(vcpu->kvm, old_tss_base, &tss_segment_32,
+			   sizeof tss_segment_32))
 		goto out;
 
 	save_state_to_tss32(vcpu, &tss_segment_32);
-	save_tss_segment32(vcpu, cseg_desc, &tss_segment_32);
 
-	if (load_tss_segment32(vcpu, nseg_desc, &tss_segment_32))
+	if (kvm_write_guest(vcpu->kvm, old_tss_base, &tss_segment_32,
+			    sizeof tss_segment_32))
 		goto out;
+
+	if (kvm_read_guest(vcpu->kvm, get_tss_base_addr(vcpu, nseg_desc),
+			   &tss_segment_32, sizeof tss_segment_32))
+		goto out;
+
 	if (load_state_from_tss32(vcpu, &tss_segment_32))
 		goto out;
 
@@ -3523,16 +3487,20 @@ int kvm_task_switch(struct kvm_vcpu *vcpu, u16 tss_selector, int reason)
 	struct desc_struct cseg_desc;
 	struct desc_struct nseg_desc;
 	int ret = 0;
+	u32 old_tss_base = get_segment_base(vcpu, VCPU_SREG_TR);
+	u16 old_tss_sel = get_segment_selector(vcpu, VCPU_SREG_TR);
 
-	kvm_get_segment(vcpu, &tr_seg, VCPU_SREG_TR);
+	old_tss_base = vcpu->arch.mmu.gva_to_gpa(vcpu, old_tss_base);
 
+	/* FIXME: Handle errors. Failure to read either TSS or their
+	 * descriptors should generate a pagefault.
+	 */
 	if (load_guest_segment_descriptor(vcpu, tss_selector, &nseg_desc))
 		goto out;
 
-	if (load_guest_segment_descriptor(vcpu, tr_seg.selector, &cseg_desc))
+	if (load_guest_segment_descriptor(vcpu, old_tss_sel, &cseg_desc))
 		goto out;
 
-
 	if (reason != TASK_SWITCH_IRET) {
 		int cpl;
 
@@ -3550,8 +3518,7 @@ int kvm_task_switch(struct kvm_vcpu *vcpu, u16 tss_selector, int reason)
 
 	if (reason == TASK_SWITCH_IRET || reason == TASK_SWITCH_JMP) {
 		cseg_desc.type &= ~(1 << 1); //clear the B flag
-		save_guest_segment_descriptor(vcpu, tr_seg.selector,
-					      &cseg_desc);
+		save_guest_segment_descriptor(vcpu, old_tss_sel, &cseg_desc);
 	}
 
 	if (reason == TASK_SWITCH_IRET) {
@@ -3563,10 +3530,10 @@ int kvm_task_switch(struct kvm_vcpu *vcpu, u16 tss_selector, int reason)
 	kvm_x86_ops->cache_regs(vcpu);
 
 	if (nseg_desc.type & 8)
-		ret = kvm_task_switch_32(vcpu, tss_selector, &cseg_desc,
+		ret = kvm_task_switch_32(vcpu, tss_selector, old_tss_base,
 					 &nseg_desc);
 	else
-		ret = kvm_task_switch_16(vcpu, tss_selector, &cseg_desc,
+		ret = kvm_task_switch_16(vcpu, tss_selector, old_tss_base,
 					 &nseg_desc);
 
 	if (reason == TASK_SWITCH_CALL || reason == TASK_SWITCH_GATE) {