The Restless WaitGroup


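WaitGroup is a concurrency primitive that Go programs use all the time for task orchestration. It has only a few methods and looks simple to use. Its internal implementation, however, has been rewritten several times. Most of these rewrites optimized the atomic operations on its fields.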

The original WaitGroup implementation

The earliest WaitGroup implementation looked like this:

type WaitGroup struct {
	m       Mutex
	counter int32
	waiters int32
	sema    *uint32
}

func (wg *WaitGroup) Add(delta int) {
	v := atomic.AddInt32(&wg.counter, int32(delta))
	if v < 0 {
		panic("sync: negative WaitGroup count")
	}
	// Nothing to wake: work is still pending, or nobody is waiting.
	if v > 0 || atomic.LoadInt32(&wg.waiters) == 0 {
		return
	}
	// The counter reached zero while waiters exist: release each of them.
	wg.m.Lock()
	for i := int32(0); i < wg.waiters; i++ {
		runtime_Semrelease(wg.sema)
	}
	wg.waiters = 0
	wg.sema = nil
	wg.m.Unlock()
}

The meaning of each field is clear, but the implementation is still somewhat crude. For example, sema is implemented as a pointer.

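Later, the counter and waiters fields were merged into a single 64-bit state. 64-bit atomic operations require 8-byte alignment, and 32-bit compilers do not guarantee it. The code therefore has to locate the 8-byte-aligned position inside state1. The sema field also stopped being a pointer.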

type WaitGroup struct {
	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state.
	state1 [12]byte
	sema   uint32
}

func (wg *WaitGroup) state() *uint64 {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1))
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[4]))
	}
}

After that, WaitGroup settled into the following implementation, which stayed stable for a long time:

type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

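The state1 and sema fields were merged into a single field, state1, which is an array of uint32, 4 bytes per element. The struct is at least 4-byte aligned, so either the first element or the second element starts on an 8-byte boundary. The code takes those aligned 8 bytes as the state and uses the remaining 4 bytes as sema.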

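This implementation is correct, but it is convoluted. You must check state1's alignment before you know which part holds counter and waiters, and which part holds sema.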

A question for you: what is the maximum number of waiters a WaitGroup can have?

Changes in Go 1.18

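In Go 1.18, WaitGroup changed again. On 64-bit architectures, the compiler guarantees that uint64 fields are 8-byte aligned, so the state can be declared directly as a uint64: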

type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers only guarantee that 64-bit fields are 32-bit aligned.
	// For this reason on 32 bit architectures we need to check in state()
	// if state1 is aligned or not, and dynamically "swap" the field order if
	// needed.
	state1 uint64
	state2 uint32
}

Of course, to stay compatible with 32-bit architectures, the code still has to check the alignment:

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if unsafe.Alignof(wg.state1) == 8 || uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// state1 is 64-bit aligned: nothing to do.
		return &wg.state1, &wg.state2
	} else {
		// state1 is 32-bit aligned but not 64-bit aligned: this means that
		// (&state1)+4 is 64-bit aligned.
		state := (*[3]uint32)(unsafe.Pointer(&wg.state1))
		return (*uint64)(unsafe.Pointer(&state[1])), &state[0]
	}
}

Overall, this change brings a 9% to 30% performance improvement on linux/amd64.

Changes in Go 1.20

The optimization was not over. In Go 1.19, Russ Cox implemented atomic.Uint64, which is 8-byte aligned on both 64-bit and 32-bit architectures. Why? Because it holds a trump card: align64.

// An Uint64 is an atomic uint64. The zero value is zero.
type Uint64 struct {
	_ noCopy
	_ align64
	v uint64
}

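On 64-bit architectures this is not an issue. On 32-bit architectures, when the Go compiler sees this field, it automatically aligns the struct to 8 bytes. This is a convention the compiler honors only for this particular type, so adding an align64 field to a struct in your own package has no effect.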

If you want your own struct to be 8-byte aligned, you can use the following technique instead:

import "sync/atomic"

type T struct {
	_ [0]atomic.Int64 // occupies 0 bytes, but gives the struct 8-byte alignment
	x uint64          // x is therefore 8-byte aligned
}

With that, the WaitGroup implementation can be simplified to:

type WaitGroup struct {
	noCopy noCopy

	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32
}

A separate state() method is no longer needed; the state field is used directly (race-detector code removed):

func (wg *WaitGroup) Add(delta int) {
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32)
	w := uint32(state)

	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	wg.state.Store(0)
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}