diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c
index 5c8caf4a1244..34ae3f13483a 100644
--- a/net/ipv4/tcp_cong.c
+++ b/net/ipv4/tcp_cong.c
@@ -77,18 +77,19 @@ void tcp_init_congestion_control(struct sock *sk)
 	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct tcp_congestion_ops *ca;
 
-	if (icsk->icsk_ca_ops != &tcp_init_congestion_ops)
-		return;
+	/* if no choice made yet assign the current value set as default */
+	if (icsk->icsk_ca_ops == &tcp_init_congestion_ops) {
+		rcu_read_lock();
+		list_for_each_entry_rcu(ca, &tcp_cong_list, list) {
+			if (try_module_get(ca->owner)) {
+				icsk->icsk_ca_ops = ca;
+				break;
+			}
 
-	rcu_read_lock();
-	list_for_each_entry_rcu(ca, &tcp_cong_list, list) {
-		if (try_module_get(ca->owner)) {
-			icsk->icsk_ca_ops = ca;
-			break;
+			/* fallback to next available */
 		}
-
+		rcu_read_unlock();
 	}
-	rcu_read_unlock();
 
 	if (icsk->icsk_ca_ops->init)
 		icsk->icsk_ca_ops->init(sk);
@@ -236,6 +237,7 @@ int tcp_set_congestion_control(struct sock *sk, const char *name)
 
 	rcu_read_lock();
 	ca = tcp_ca_find(name);
+
 	/* no change asking for existing value */
 	if (ca == icsk->icsk_ca_ops)
 		goto out;
@@ -261,7 +263,8 @@ int tcp_set_congestion_control(struct sock *sk, const char *name)
 	else {
 		tcp_cleanup_congestion_control(sk);
 		icsk->icsk_ca_ops = ca;
-		if (icsk->icsk_ca_ops->init)
+
+		if (sk->sk_state != TCP_CLOSE && icsk->icsk_ca_ops->init)
 			icsk->icsk_ca_ops->init(sk);
 	}
  out: