diff --git a/include/net/mac80211.h b/include/net/mac80211.h
index e6521261a585..982d2cd80166 100644
--- a/include/net/mac80211.h
+++ b/include/net/mac80211.h
@@ -3411,14 +3411,20 @@ void ieee80211_tx_status_irqsafe(struct ieee80211_hw *hw,
  */
 void ieee80211_report_low_ack(struct ieee80211_sta *sta, u32 num_packets);
 
+#define IEEE80211_MAX_CSA_COUNTERS_NUM 2
+
 /**
  * struct ieee80211_mutable_offsets - mutable beacon offsets
  * @tim_offset: position of TIM element
  * @tim_length: size of TIM element
+ * @csa_offs: array of IEEE80211_MAX_CSA_COUNTERS_NUM offsets to CSA counters.
+ *	This array can contain zero values which should be ignored.
  */
 struct ieee80211_mutable_offsets {
 	u16 tim_offset;
 	u16 tim_length;
+
+	u16 csa_counter_offs[IEEE80211_MAX_CSA_COUNTERS_NUM];
 };
 
 /**
@@ -3433,7 +3439,8 @@ struct ieee80211_mutable_offsets {
  *
  * This function should be used if the beacon frames are generated by the
  * device, and then the driver must use the returned beacon as the template
- * The driver is responsible to update the DTIM count.
+ * The driver or the device are responsible to update the DTIM and, when
+ * applicable, the CSA count.
  *
  * The driver is responsible for freeing the returned skb.
  *
@@ -3485,6 +3492,20 @@ static inline struct sk_buff *ieee80211_beacon_get(struct ieee80211_hw *hw,
 	return ieee80211_beacon_get_tim(hw, vif, NULL, NULL);
 }
 
+/**
+ * ieee80211_csa_update_counter - request mac80211 to decrement the csa counter
+ * @vif: &struct ieee80211_vif pointer from the add_interface callback.
+ *
+ * The csa counter should be updated after each beacon transmission.
+ * This function is called implicitly when
+ * ieee80211_beacon_get/ieee80211_beacon_get_tim are called, however if the
+ * beacon frames are generated by the device, the driver should call this
+ * function after each beacon transmission to sync mac80211's csa counters.
+ *
+ * Return: new csa counter value
+ */
+u8 ieee80211_csa_update_counter(struct ieee80211_vif *vif);
+
 /**
  * ieee80211_csa_finish - notify mac80211 about channel switch
  * @vif: &struct ieee80211_vif pointer from the add_interface callback.
diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c
index d44dca56b8ff..bfd2534e5a4d 100644
--- a/net/mac80211/cfg.c
+++ b/net/mac80211/cfg.c
@@ -3502,10 +3502,10 @@ static int ieee80211_mgmt_tx(struct wiphy *wiphy, struct wireless_dev *wdev,
 	     sdata->vif.type == NL80211_IFTYPE_ADHOC) &&
 	    params->n_csa_offsets) {
 		int i;
+		u8 c = sdata->csa_current_counter;
 
 		for (i = 0; i < params->n_csa_offsets; i++)
-			data[params->csa_offsets[i]] =
-					sdata->csa_current_counter;
+			data[params->csa_offsets[i]] = c;
 	}
 
 	IEEE80211_SKB_CB(skb)->flags = flags;
diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h
index 05ed592d8bbe..57e0b2682cef 100644
--- a/net/mac80211/ieee80211_i.h
+++ b/net/mac80211/ieee80211_i.h
@@ -70,8 +70,6 @@ struct ieee80211_local;
 
 #define IEEE80211_DEAUTH_FRAME_LEN	(24 /* hdr */ + 2 /* reason */)
 
-#define IEEE80211_MAX_CSA_COUNTERS_NUM 2
-
 struct ieee80211_fragment_entry {
 	unsigned long first_frag_time;
 	unsigned int seq;
diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c
index 509456e5722d..5214686d9fd1 100644
--- a/net/mac80211/tx.c
+++ b/net/mac80211/tx.c
@@ -2416,13 +2416,14 @@ static int ieee80211_beacon_add_tim(struct ieee80211_sub_if_data *sdata,
 	return 0;
 }
 
-static void ieee80211_update_csa(struct ieee80211_sub_if_data *sdata,
-				 struct beacon_data *beacon)
+static void ieee80211_set_csa(struct ieee80211_sub_if_data *sdata,
+			      struct beacon_data *beacon)
 {
 	struct probe_resp *resp;
 	u8 *beacon_data;
 	size_t beacon_data_len;
 	int i;
+	u8 count = sdata->csa_current_counter;
 
 	switch (sdata->vif.type) {
 	case NL80211_IFTYPE_AP:
@@ -2450,16 +2451,7 @@ static void ieee80211_update_csa(struct ieee80211_sub_if_data *sdata,
 			if (WARN_ON(counter_offset_beacon >= beacon_data_len))
 				return;
 
-			/* Warn if the driver did not check for/react to csa
-			 * completeness.  A beacon with CSA counter set to 0
-			 * should never occur, because a counter of 1 means
-			 * switch just before the next beacon.
-			 */
-			if (WARN_ON(beacon_data[counter_offset_beacon] == 1))
-				return;
-
-			beacon_data[counter_offset_beacon] =
-				sdata->csa_current_counter - 1;
+			beacon_data[counter_offset_beacon] = count;
 		}
 
 		if (sdata->vif.type == NL80211_IFTYPE_AP &&
@@ -2474,14 +2466,24 @@ static void ieee80211_update_csa(struct ieee80211_sub_if_data *sdata,
 				rcu_read_unlock();
 				return;
 			}
-			resp->data[counter_offset_presp] =
-				sdata->csa_current_counter - 1;
+			resp->data[counter_offset_presp] = count;
 			rcu_read_unlock();
 		}
 	}
+}
+
+u8 ieee80211_csa_update_counter(struct ieee80211_vif *vif)
+{
+	struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif);
 
 	sdata->csa_current_counter--;
+
+	/* the counter should never reach 0 */
+	WARN_ON(!sdata->csa_current_counter);
+
+	return sdata->csa_current_counter;
 }
+EXPORT_SYMBOL(ieee80211_csa_update_counter);
 
 bool ieee80211_csa_is_complete(struct ieee80211_vif *vif)
 {
@@ -2552,6 +2554,7 @@ __ieee80211_beacon_get(struct ieee80211_hw *hw,
 	enum ieee80211_band band;
 	struct ieee80211_tx_rate_control txrc;
 	struct ieee80211_chanctx_conf *chanctx_conf;
+	int csa_off_base = 0;
 
 	rcu_read_lock();
 
@@ -2569,8 +2572,12 @@ __ieee80211_beacon_get(struct ieee80211_hw *hw,
 		struct beacon_data *beacon = rcu_dereference(ap->beacon);
 
 		if (beacon) {
-			if (sdata->vif.csa_active)
-				ieee80211_update_csa(sdata, beacon);
+			if (sdata->vif.csa_active) {
+				if (!is_template)
+					ieee80211_csa_update_counter(vif);
+
+				ieee80211_set_csa(sdata, beacon);
+			}
 
 			/*
 			 * headroom, head length,
@@ -2593,6 +2600,9 @@ __ieee80211_beacon_get(struct ieee80211_hw *hw,
 			if (offs) {
 				offs->tim_offset = beacon->head_len;
 				offs->tim_length = skb->len - beacon->head_len;
+
+				/* for AP the csa offsets are from tail */
+				csa_off_base = skb->len;
 			}
 
 			if (beacon->tail)
@@ -2608,9 +2618,12 @@ __ieee80211_beacon_get(struct ieee80211_hw *hw,
 		if (!presp)
 			goto out;
 
-		if (sdata->vif.csa_active)
-			ieee80211_update_csa(sdata, presp);
+		if (sdata->vif.csa_active) {
+			if (!is_template)
+				ieee80211_csa_update_counter(vif);
 
+			ieee80211_set_csa(sdata, presp);
+		}
 
 		skb = dev_alloc_skb(local->tx_headroom + presp->head_len +
 				    local->hw.extra_beacon_tailroom);
@@ -2630,8 +2643,17 @@ __ieee80211_beacon_get(struct ieee80211_hw *hw,
 		if (!bcn)
 			goto out;
 
-		if (sdata->vif.csa_active)
-			ieee80211_update_csa(sdata, bcn);
+		if (sdata->vif.csa_active) {
+			if (!is_template)
+				/* TODO: For mesh csa_counter is in TU, so
+				 * decrementing it by one isn't correct, but
+				 * for now we leave it consistent with overall
+				 * mac80211's behavior.
+				 */
+				ieee80211_csa_update_counter(vif);
+
+			ieee80211_set_csa(sdata, bcn);
+		}
 
 		if (ifmsh->sync_ops)
 			ifmsh->sync_ops->adjust_tbtt(sdata, bcn);
@@ -2658,6 +2680,20 @@ __ieee80211_beacon_get(struct ieee80211_hw *hw,
 		goto out;
 	}
 
+	/* CSA offsets */
+	if (offs) {
+		int i;
+
+		for (i = 0; i < IEEE80211_MAX_CSA_COUNTERS_NUM; i++) {
+			u16 csa_off = sdata->csa_counter_offset_beacon[i];
+
+			if (!csa_off)
+				continue;
+
+			offs->csa_counter_offs[i] = csa_off_base + csa_off;
+		}
+	}
+
 	band = chanctx_conf->def.chan->band;
 
 	info = IEEE80211_SKB_CB(skb);