如何使用ThreadLocal避免線程安全問題?

語言: CN / TW / HK

這篇文章是關於ThreadLocal的第二篇文章。

在上一篇文章,Yasin給大家介紹了什麼是ThreadLocal,以及ThreadLocal的基本原理。

那在實際工作中,ThreadLocal一般用來做什麼呢?今天我們以一個簡單的應用場景為例,給大家介紹如何用ThreadLocal來幫助我們解決多線程的安全問題。

這是一個簡單的統計計數的問題。比如説我們想要統計一段時間內某個接口的調用量。每次訪問接口,統計量都 +1。首先我們來一個最簡單的線程不安全的基礎實現:

@RestController
@RequestMapping("orders")
public class OrderController {

    private Integer count = 0;

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        count++;
        Thread.sleep(100);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return count;
    }
}
複製代碼

這裏我們假設調用這個接口會有100毫秒的消耗(模擬同步IO操作)。稍微懂一點多線程知識的同學都知道,這個時候是線程不安全的。假如同時多個線程來訪問這個接口,就會造成數據不一致問題。我們試着用ab來測試一下。

# 總共調用10000次,100併發
$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
9953(base)
複製代碼

我們預期調用stat應該返回10000才對,但實際返回了9953。為什麼會造成這樣的結果呢?是因為 count++ 這個操作不是線程安全的。這裏涉及到一個內存模型的知識,對於這個操作,首先我們是從內存裏面讀取原來的值,放在了線程本地內存裏。然後進行 +1 操作,再寫回到內存裏。

Java內存模型
Java內存模型

這個時候如果多個線程操作的話,有可能線程A這邊還沒來得及寫,線程B那邊讀取的是原來的值。這樣子的話就會造成數據不一致的問題。結果就會比預期的小。

那如何解決這個線程安全的問題呢?解決辦法有很多種。我們先嚐試用一個最簡單的辦法,就是加鎖。上篇文章我們聊到解決多線程問題有幾種思路,其中一個是排隊。使用鎖就是排隊的理念,它可以絕對的保證線程安全。我們先來看一下使用鎖之後的效果。

@GetMapping("/visit")
public Integer visit() throws InterruptedException {
    Thread.sleep(100);
    this.add();
    return 0;
}

private synchronized void add() {
    count++;
}
複製代碼

同樣壓測一下。可以看到結果是正確的,符合我們期望的10000。

$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
10000(base)
複製代碼

那有沒有什麼其它辦法可以做到線程安全呢。

前面我們説到,對於這個case來説,使用count++會造成線程不安全,那是因為多個線程都在爭用同一個資源count。我們可以使用“避免”的思想,使得一個線程只用自己的資源,不去用別人的資源就好啦,這樣子就不會存在線程安全問題了。

我們使用ThreadLocal,修改一下代碼:

@RestController
@RequestMapping("orders")
public class OrderController {

    private static final ThreadLocal<Integer> TL = ThreadLocal.withInitial(() -> 0);

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        Thread.sleep(100);
        TL.set(TL.get() + 1);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return TL.get();
    }
}
複製代碼

同樣用ab測一下。

$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
99(base)
複製代碼

當我們訪問統計量接口,發現只能得到當前線程的統計量。那我們怎麼才能得到所有線程加起來的統計量總和呢?

這個功能ThreadLocal並沒有實現,需要我們自己編寫代碼輔助。其實思路很簡單,我們只需要把每個線程對應的value的引用,放到一個統一的容器裏面,然後我們需要用的時候從這個容器取出來遍歷一遍就好了。

首先,我們嘗試使用一個HashSet來保存這個值。這裏需要注意的是我們在初始化這個值的時候需要加鎖。因為HashSet並不是線程安全的。

@RestController
@RequestMapping("orders")
public class OrderController {

    private static final Set<Integer> SET = new HashSet<>();
    private static final ThreadLocal<Integer> TL = ThreadLocal.withInitial(() -> {
        Integer value = 0;
        addSet(value);
        return value;
    });
    
    private static synchronized void addSet(Val<Integer> val) {
        SET.add(val);
    }

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        Thread.sleep(100);
        TL.set(TL.get() + 1);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return SET.stream().reduce(Integer::sum).orElse(-1);
    }
}
複製代碼

但是我們測試一下發現,好像並不生效,stat結果總是0。為什麼呢?

因為Integer有些特殊,它是一個原生類型int的封裝類,它內部有一個緩存,當它的值比較小(-128~127)的時候,使用的是同一個對象。而+1操作也不會改變原來引用對應的值。所以它不能作為一個正常的引用對象來使用。

那如何解決這個問題?很簡單,我們在外面給他包一層對象就好了。

public class Val<T> {
    T v;

    public T getV() {
        return v;
    }

    public void setV(T v) {
        this.v = v;
    }
}

@RestController
@RequestMapping("orders")
public class OrderController {

    private static final Set<Val<Integer>> SET = new HashSet<>();
    private static final ThreadLocal<Val<Integer>> TL = ThreadLocal.withInitial(() -> {
        Val<Integer> val = new Val<>();
        val.setV(0);
        addSet(val);
        return val;
    });

    private static synchronized void addSet(Val<Integer> val) {
        SET.add(val);
    }

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        Thread.sleep(100);
        Val<Integer> val = TL.get();
        val.setV(val.getV() + 1);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return SET.stream().map(Val::getV).reduce(Integer::sum).orElse(-1);
    }
}
複製代碼

然後我們再測試一下,發現可以得到我們預期的結果。

$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
10000(base)
複製代碼

有些同學可能會疑惑。那這個比起直接使用synchronized或者原子類,孰優孰劣呢?

其實兩者用的思想不一樣,上鎖和原子類使用的是排隊的思想,而ThreadLocal使用的是避免的思想。它通過自己的一個設計哲學避免了線程的爭用,所以效率也會比較高。要知道,排隊是很危險的,一旦你的臨界區比較耗時,很有可能造成大量線程阻塞,導致系統不可用。

臨界區:多個線程爭用資源的區域,同時只能有一個線程運行那部分代碼。

我們這個case由於線程爭用的資源很簡單,臨界區就是一個Integer類型的變量,所以看不太出來使用ThreadLocal的優勢。但如果臨界區的消耗較大,ThreadLocal的優勢就體現出來了。大家可以嘗試在前面的synchronized方法中sleep 100ms試一下效果。

雖然ThreadLocal不一定能避免所有的線程安全問題,比如這個case,我們在初始化addSet的時候,仍然要同步上鎖。但是他可以把線程安全的問題縮小範圍,提升性能。

那麼你get到使用ThreadLocal的精髓了嗎?還有哪些場景可以使用ThreadLocal呢?下篇文章我們會解析主流框架的源碼,看看大神們是如何使用ThreadLocal的。

關於作者

我是Yasin,一個有顏有料又有趣的程序員。

微信公眾號:編了個程

個人網站:http://yasinshaw.com

關注我的公眾號,和我一起成長~

公眾號
公眾號