如何使用ThreadLocal避免线程安全问题?

语言: CN / TW / HK

这篇文章是关于ThreadLocal的第二篇文章。

在上一篇文章,Yasin给大家介绍了什么是ThreadLocal,以及ThreadLocal的基本原理。

那在实际工作中,ThreadLocal一般用来做什么呢?今天我们以一个简单的应用场景为例,给大家介绍如何用ThreadLocal来帮助我们解决多线程的安全问题。

这是一个简单的统计计数的问题。比如说我们想要统计一段时间内某个接口的调用量。每次访问接口,统计量都 +1。首先我们来一个最简单的线程不安全的基础实现:

@RestController
@RequestMapping("orders")
public class OrderController {

    private Integer count = 0;

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        count++;
        Thread.sleep(100);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return count;
    }
}
复制代码

这里我们假设调用这个接口会有100毫秒的消耗(模拟同步IO操作)。稍微懂一点多线程知识的同学都知道,这个时候是线程不安全的。假如同时多个线程来访问这个接口,就会造成数据不一致问题。我们试着用ab来测试一下。

# 总共调用10000次,100并发
$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
9953(base)
复制代码

我们预期调用stat应该返回10000才对,但实际返回了9953。为什么会造成这样的结果呢?是因为 count++ 这个操作不是线程安全的。这里涉及到一个内存模型的知识,对于这个操作,首先我们是从内存里面读取原来的值,放在了线程本地内存里。然后进行 +1 操作,再写回到内存里。

Java内存模型
Java内存模型

这个时候如果多个线程操作的话,有可能线程A这边还没来得及写,线程B那边读取的是原来的值。这样子的话就会造成数据不一致的问题。结果就会比预期的小。

那如何解决这个线程安全的问题呢?解决办法有很多种。我们先尝试用一个最简单的办法,就是加锁。上篇文章我们聊到解决多线程问题有几种思路,其中一个是排队。使用锁就是排队的理念,它可以绝对的保证线程安全。我们先来看一下使用锁之后的效果。

@GetMapping("/visit")
public Integer visit() throws InterruptedException {
    Thread.sleep(100);
    this.add();
    return 0;
}

private synchronized void add() {
    count++;
}
复制代码

同样压测一下。可以看到结果是正确的,符合我们期望的10000。

$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
10000(base)
复制代码

那有没有什么其它办法可以做到线程安全呢。

前面我们说到,对于这个case来说,使用count++会造成线程不安全,那是因为多个线程都在争用同一个资源count。我们可以使用“避免”的思想,使得一个线程只用自己的资源,不去用别人的资源就好啦,这样子就不会存在线程安全问题了。

我们使用ThreadLocal,修改一下代码:

@RestController
@RequestMapping("orders")
public class OrderController {

    private static final ThreadLocal<Integer> TL = ThreadLocal.withInitial(() -> 0);

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        Thread.sleep(100);
        TL.set(TL.get() + 1);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return TL.get();
    }
}
复制代码

同样用ab测一下。

$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
99(base)
复制代码

当我们访问统计量接口,发现只能得到当前线程的统计量。那我们怎么才能得到所有线程加起来的统计量总和呢?

这个功能ThreadLocal并没有实现,需要我们自己编写代码辅助。其实思路很简单,我们只需要把每个线程对应的value的引用,放到一个统一的容器里面,然后我们需要用的时候从这个容器取出来遍历一遍就好了。

首先,我们尝试使用一个HashSet来保存这个值。这里需要注意的是我们在初始化这个值的时候需要加锁。因为HashSet并不是线程安全的。

@RestController
@RequestMapping("orders")
public class OrderController {

    private static final Set<Integer> SET = new HashSet<>();
    private static final ThreadLocal<Integer> TL = ThreadLocal.withInitial(() -> {
        Integer value = 0;
        addSet(value);
        return value;
    });
    
    private static synchronized void addSet(Val<Integer> val) {
        SET.add(val);
    }

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        Thread.sleep(100);
        TL.set(TL.get() + 1);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return SET.stream().reduce(Integer::sum).orElse(-1);
    }
}
复制代码

但是我们测试一下发现,好像并不生效,stat结果总是0。为什么呢?

因为Integer有些特殊,它是一个原生类型int的封装类,它内部有一个缓存,当它的值比较小(-128~127)的时候,使用的是同一个对象。而+1操作也不会改变原来引用对应的值。所以它不能作为一个正常的引用对象来使用。

那如何解决这个问题?很简单,我们在外面给他包一层对象就好了。

public class Val<T> {
    T v;

    public T getV() {
        return v;
    }

    public void setV(T v) {
        this.v = v;
    }
}

@RestController
@RequestMapping("orders")
public class OrderController {

    private static final Set<Val<Integer>> SET = new HashSet<>();
    private static final ThreadLocal<Val<Integer>> TL = ThreadLocal.withInitial(() -> {
        Val<Integer> val = new Val<>();
        val.setV(0);
        addSet(val);
        return val;
    });

    private static synchronized void addSet(Val<Integer> val) {
        SET.add(val);
    }

    @GetMapping("/visit")
    public Integer visit() throws InterruptedException {
        Thread.sleep(100);
        Val<Integer> val = TL.get();
        val.setV(val.getV() + 1);
        return 0;
    }

    @GetMapping("/stat")
    public Integer stat() {
        return SET.stream().map(Val::getV).reduce(Integer::sum).orElse(-1);
    }
}
复制代码

然后我们再测试一下,发现可以得到我们预期的结果。

$ ab -n 10000 -c 100 localhost:8080/orders/visit

$ curl localhost:8080/orders/stat
10000(base)
复制代码

有些同学可能会疑惑。那这个比起直接使用synchronized或者原子类,孰优孰劣呢?

其实两者用的思想不一样,上锁和原子类使用的是排队的思想,而ThreadLocal使用的是避免的思想。它通过自己的一个设计哲学避免了线程的争用,所以效率也会比较高。要知道,排队是很危险的,一旦你的临界区比较耗时,很有可能造成大量线程阻塞,导致系统不可用。

临界区:多个线程争用资源的区域,同时只能有一个线程运行那部分代码。

我们这个case由于线程争用的资源很简单,临界区就是一个Integer类型的变量,所以看不太出来使用ThreadLocal的优势。但如果临界区的消耗较大,ThreadLocal的优势就体现出来了。大家可以尝试在前面的synchronized方法中sleep 100ms试一下效果。

虽然ThreadLocal不一定能避免所有的线程安全问题,比如这个case,我们在初始化addSet的时候,仍然要同步上锁。但是他可以把线程安全的问题缩小范围,提升性能。

那么你get到使用ThreadLocal的精髓了吗?还有哪些场景可以使用ThreadLocal呢?下篇文章我们会解析主流框架的源码,看看大神们是如何使用ThreadLocal的。

关于作者

我是Yasin,一个有颜有料又有趣的程序员。

微信公众号:编了个程

个人网站:http://yasinshaw.com

关注我的公众号,和我一起成长~

公众号
公众号