Go Concurrency

Table of Contents

1 Introduction to Go Concurrency

This article introduces concurrent programming in Go, drawing mainly on "Concurrency in Go"; reading the original book is recommended.

1.1 Race Conditions

Here is a piece of code containing a race condition:

 1: // WARNING: the code below contains a race condition!
 2: package main
 3: 
 4: import "fmt"
 5: 
 6: func main() {
 7: 	var data int
 8: 
 9: 	go func() {
10: 		data++
11: 	}()
12: 
13: 	if data == 0 {
14: 		fmt.Printf("the value is %v.\n", data)
15: 	}
16: }

The code above contains a race condition, and there are three possible outputs (although you may well get the same output on every run):
1. Nothing is printed. In this case, line 10 executed before line 13.
2. "the value is 0." is printed. In this case, lines 13 and 14 executed before line 10.
3. "the value is 1." is printed. In this case, line 13 executed before line 10, and line 10 executed before line 14.

1.1.1 Race Detection (-race)

Go 1.1 introduced a -race flag for most go commands; it detects data races.

$ go test -race mypkg    # test the package
$ go run -race mysrc.go  # compile and run the program
$ go build -race mycmd   # build the command
$ go install -race mypkg # install the package

Here is how -race detects the race in the previous section's code:

$ go run -race bad.go                   # add -race to detect data races
the value is 0.
==================
WARNING: DATA RACE
Write at 0x00c4200a4008 by goroutine 6:
  main.main.func1()
      /Users/cig01/test/bad.go:9 +0x4e

Previous read at 0x00c4200a4008 by main goroutine:
  main.main()
      /Users/cig01/test/bad.go:12 +0x88

Goroutine 6 (running) created at:
  main.main()
      /Users/cig01/test/bad.go:8 +0x7a
==================
Found 1 data race(s)
exit status 66

Note 1: -race is not a silver bullet; it can only detect a race that actually occurs during the run.
Note 2: by default the report goes to stderr; set log_path in the GORACE environment variable to save the report to a file, e.g.:

$ GORACE="log_path=/tmp/go_race_report" go run -race bad.go
the value is 0.
exit status 66

The report is then saved to the file "/tmp/go_race_report.pid", where pid is the process ID.

参考:Data Race Detector

1.2 The Go Memory Model

Like Java, Go defines its own memory model. The memory model specifies the conditions under which different goroutines can read and write the same variable safely. For details, see: The Go Memory Model

1.3 sync vs. channel

How should you choose between the sync package and channels? Figure 1 can help you decide; for a detailed explanation of the diagram, see Concurrency in Go, Chapter 2.

go_concurrency_sync_vs_channel.png

Figure 1: Decision tree for sync and channel

2 The sync Package

Go's sync package provides basic synchronization primitives. Other than the Once and WaitGroup types, most of the package's types (such as Mutex and Cond) are intended for low-level library routines; higher-level synchronization is better done with channels.

2.1 sync.Once (one-time initialization)

For code that must run only once globally, such as global initialization, Go provides the Once type to guarantee one-time execution. Example:

var a string
var once sync.Once

func setup() {
    a = "hello, world"
}

func doprint() {
    once.Do(setup)           // guarantees setup runs only once, no matter how many times doprint is called
    print(a)
}

func twoprint() {
    go doprint()
    go doprint()
}

2.2 sync.WaitGroup

A WaitGroup can be thought of as maintaining a concurrency-safe internal counter: Add(n) increases the counter by n, Done() decreases it by 1, and Wait() blocks until the counter reaches 0.

Here is an example using sync.WaitGroup:

package main

import (
	"net/http"
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/",
        "http://nosuch/",
	}
	for _, url := range urls {
        // Increment the WaitGroup counter.
        wg.Add(1)
        // Launch a goroutine to fetch the URL.
        go func(url string) {
			// Decrement the counter when the goroutine completes.
			defer wg.Done()
			// Fetch the URL.
			response, err := http.Get(url)
			if err != nil {
				fmt.Printf("Error=%s\n", err)
			} else {
				fmt.Printf("StatusCode=%d, url=%s\n", response.StatusCode, url)
			}
        }(url)
	}
	// Wait for all HTTP fetches to complete.
	wg.Wait()
}

Running the program may produce output like:

StatusCode=502, url=http://nosuch/
StatusCode=200, url=http://www.google.com/
StatusCode=200, url=http://www.somestupidname.com/
StatusCode=200, url=http://www.golang.org/

Here is another sync.WaitGroup example:

package main

import (
	"fmt"
	"sync"
)

func main() {
	hello := func(wg *sync.WaitGroup, id int) {
		defer wg.Done()
		fmt.Printf("Hello from %v!\n", id)
	}

	const numGreeters = 5
	var wg sync.WaitGroup
	wg.Add(numGreeters)
	for i := 0; i < numGreeters; i++ {
		go hello(&wg, i+1)
	}

	wg.Wait()
}

Running the program may produce output like:

Hello from 5!
Hello from 3!
Hello from 1!
Hello from 4!
Hello from 2!

2.3 sync.Mutex and sync.RWMutex (locks)

Here is a concurrency-safe counter implemented with sync.RWMutex:

package main

import (
    "fmt"
    "sync"
    "time"
)

// SafeCounter is safe to use concurrently.
type SafeCounter struct {
    value int
    mux sync.RWMutex
}

// Inc increments the counter
func (c *SafeCounter) Inc() {
    c.mux.Lock()              // acquire the (write) lock
    c.value++
    c.mux.Unlock()
}

// GetValue returns the current value of the counter.
func (c *SafeCounter) GetValue() int {
    c.mux.RLock()              // acquire the read lock
    defer c.mux.RUnlock()
    return c.value
}

func main() {
    counter := SafeCounter{value: 0}

    for i := 0; i < 1000; i++ {        // start 1000 goroutines, each incrementing counter by 1
        go counter.Inc()
    }

    time.Sleep(1 * time.Second)        // wait for the 1000 goroutines to finish; here we simply sleep for 1 second
    fmt.Println(counter.GetValue())    // print the counter's value
}

2.4 sync.Cond (condition variables)

Condition variables help manage goroutines that have acquired a lock but have nothing to do because some condition is not yet satisfied.

First, an example without a condition variable (the producer appends elements to the queue items, and the consumer removes them):

 1: package main
 2: 
 3: import (
 4: 	"fmt"
 5: 	"math/rand"
 6: 	"sync"
 7: 	"time"
 8: )
 9: 
10: func main() {
11: 
12: 	mutex := new(sync.Mutex)
13: 	var items []int
14: 
15: 	go func() {
16: 		// consumer (removes elements from the queue items)
17: 		for {
18: 			mutex.Lock()
19: 			for len(items) > 0 { // drain all elements from items
20: 				lastIndex := len(items) - 1
21: 				lastItem := items[lastIndex]
22: 				items = items[:lastIndex] // remove the last element
23: 				fmt.Printf("consume %d\n", lastItem)
24: 			}
25: 			mutex.Unlock()
26: 
27: 			// Sleep is required; otherwise CPU usage spikes while items is empty
28: 			time.Sleep(10 * time.Millisecond)  // hard to choose a good sleep duration!
29: 		}
30: 	}()
31: 
32: 	// producer (appends elements to the queue items)
33: 	for {
34: 		newItem := rand.Intn(10000)
35: 
36: 		mutex.Lock()
37: 		items = append(items, newItem)
38: 		fmt.Printf("produce %d\n", newItem)
39: 		mutex.Unlock()
40: 
41: 		// Other work
42: 		time.Sleep(time.Duration(rand.Intn(10)) * time.Second)
43: 	}
44: }

In the consumer code, if items stays empty, the loop condition on line 19 never holds, so the consumer endlessly locks, unlocks, and sleeps, wasting CPU. Choosing the sleep duration on line 28 is hard: sleep too little and you waste even more CPU; sleep too long and queued elements are not consumed promptly.

A condition variable solves this problem: the producer can notify the consumer as soon as the queue has elements. Compare the implementation below with the version above.

Using a condition variable (better):

package main

import (
	"fmt"
	"math/rand"
	"sync"
	"time"
)

func main() {

	cond := sync.NewCond(new(sync.Mutex))
	var items []int

	go func() {
		// consumer (removes elements from the queue items)
		for {
			cond.L.Lock()
			for len(items) == 0 {
				cond.Wait() // Wait does three things: (1) unlocks, as if calling cond.L.Unlock()
					// (2) blocks the current goroutine until another goroutine calls cond.Signal()
					// (3) locks again, as if calling cond.L.Lock()
			}
			for len(items) > 0 { // drain all elements from items
				lastIndex := len(items) - 1
				lastItem := items[lastIndex]
				items = items[:lastIndex] // remove the last element
				fmt.Printf("consume %d\n", lastItem)
			}
			cond.L.Unlock()
		}
	}()

	// producer (appends elements to the queue items)
	for {
		newItem := rand.Intn(10000)

		cond.L.Lock()
		items = append(items, newItem)
		fmt.Printf("produce %d\n", newItem)
		cond.L.Unlock()

		cond.Signal() // wakes one cond.Wait(); cond.Broadcast() would also work here

		// Do other work
		time.Sleep(time.Duration(rand.Intn(10)) * time.Second)
	}
}

2.5 sync.Pool (temporary object pool)

Go 1.3 added Pool to the sync package. Its main purpose is to save and reuse temporary objects, reducing allocations and GC pressure.

When the Pool is non-empty, Get() returns an arbitrary object from it, removing that object from the Pool before handing it to the caller. If the Pool is empty, Get calls New and returns the newly created object; if New is not set, Get returns nil. Put(x) places the object x into the Pool.

Note: you cannot control how many objects a Pool holds, and objects placed in a Pool may be removed at any GC cycle. That makes Pool a poor fit for something like a database connection pool: under high concurrency, once the Pool's connections are cleared by GC, later database operations must re-establish connections, which is expensive.

Here is a sync.Pool example:

 1: package main
 2: 
 3: import (
 4: 	"fmt"
 5: 	"sync"
 6: 	"runtime"
 7: )
 8: 
 9: func main() {
10: 	myPool := &sync.Pool{
11: 		New: func() interface{} {
12: 			fmt.Println("Creating new instance.")
13: 			return struct{}{}
14: 		},
15: 	}
16: 
17: 	fmt.Println("line 17")
18: 	myPool.Get()              // the pool is empty, so Get() triggers New()
19: 	fmt.Println("line 19")
20: 	myPool.Get()              // the pool is empty, so Get() triggers New()
21: 	fmt.Println("line 21")
22: 	myPool.Get()              // the pool is empty, so Get() triggers New()
23: 	fmt.Println("line 23")
24: 	myPool.Put(myPool.New())  // put an object into myPool
25: 	fmt.Println("line 25")
26: 	myPool.Put(myPool.New())  // put an object into myPool
27: 	fmt.Println("line 27")
28: 	myPool.Get()              // the pool holds two objects; Get() removes and returns one, leaving one
29: 	fmt.Println("line 29")
30: 	runtime.GC()              // a GC run clears all objects in myPool
31: 	myPool.Get()              // the pool is empty again, so Get() triggers New()
32: 	fmt.Println("line 32")
33: }

Running the program produces the following output:

line 17
Creating new instance.
line 19
Creating new instance.
line 21
Creating new instance.
line 23
Creating new instance.
line 25
Creating new instance.
line 27
line 29
Creating new instance.
line 32

Here is an application of sync.Pool described in the official documentation:

An example of good use of a Pool is in the fmt package, which maintains a dynamically-sized store of temporary output buffers. The store scales under load (when many goroutines are actively printing) and shrinks when quiescent.

2.6 sync.Map

Go 1.9 added a concurrency-safe Map to the sync package. Its Store, Load, and Delete methods add, fetch, and delete map elements respectively.

Here is a simple sync.Map example:

package main

import (
    "fmt"
    "sync"
)

func main() {
	var sm sync.Map

	sm.Store("key1", "value1")
	sm.Store("key2", 2)

	// Fetch an existing item.
	result, ok := sm.Load("key1")
	if ok {
		fmt.Printf("key1=%v\n", result)
	} else {
		fmt.Printf("value not found for key: key1\n")
	}

	result, ok = sm.Load("key2")
	if ok {
		fmt.Printf("key2=%v\n", result)
	} else {
		fmt.Printf("value not found for key: key2\n")
	}

	sm.Delete("key2")

	result, ok = sm.Load("key2")
	if ok {
		fmt.Printf("key2=%v\n", result)
	} else {
		fmt.Printf("value not found for key: key2\n")
	}
}

Running the program produces:

key1=value1
key2=2
value not found for key: key2

Note: Go has no generics at the time of writing, so there is no elegant way to constrain the key and value types of a sync.Map.

3 Concurrency Patterns in Go

3.1 Confinement

3.1.1 Ad hoc confinement

Ad hoc confinement is concurrency safety achieved through convention (a community convention, a team convention, and so on). Such conventions are fragile: as developers change, a convention is easily broken by accident. Ad hoc confinement should be avoided.

For example, in the code below access to data is safe, because only loopData touches it (consider that a convention). But code outside loopData can also reach data, and once such code modifies data, the program's behavior becomes unpredictable.

data := make([]int, 4)

loopData := func(ch chan<- int) {
	defer close(ch)
	for i := range data {
		ch <- data[i]
	}
}

handleData := make(chan int)
go loopData(handleData)
for num := range handleData {
	fmt.Println(num)
}

3.1.2 Lexical confinement

Lexical confinement means using lexical scope to restrict access to data, thereby achieving concurrency safety. It is safer than ad hoc confinement and is the recommended approach. For example:

package main

import "fmt"

func main() {
	chanOwner := func() <-chan int {     // chanOwner returns a read-only channel
		results := make(chan int, 5)     // results is a local variable of chanOwner
		go func() {                      // only the closure defined inside chanOwner can write to results
			defer close(results)         // normally, code outside chanOwner cannot write to results
			for i := 0; i <= 5; i++ {
				results <- i
			}
		}()
		return results
	}

	consumer := func(results <-chan int) { // consumer has read-only access to results
		for result := range results {
			fmt.Printf("Received: %d\n", result)
		}
		fmt.Println("Done receiving!")
	}

	results := chanOwner()
	consumer(results)
}

3.2 Error Handling

In concurrent programs, proper error handling is not easy. "Who should be responsible for handling the error?" is a hard question to answer.

Consider the following program:

 1: package main
 2: 
 3: import (
 4: 	"fmt"
 5: 	"net/http"
 6: )
 7: 
 8: func main() {
 9: 	checkStatus := func(urls ...string) <-chan *http.Response {
10: 		responses := make(chan *http.Response)
11: 		go func() {
12: 			defer close(responses)
13: 			for _, url := range urls {
14: 				resp, err := http.Get(url)
15: 				if err != nil {             // handling the error here is not ideal; a better way is shown later
16: 					fmt.Println(err)
17: 					continue
18: 				}
19: 				responses <- resp
20: 			}
21: 		}()
22: 		return responses
23: 	}
24: 
25: 	urls := []string{"https://www.bing.com", "https://badhost"}
26: 	for response := range checkStatus(urls...) {
27: 		fmt.Printf("Response: %v\n", response.Status)
28: 	}
29: }

In the program above, the error handling sits on lines 15 to 18, which is not a good place for it. We should leave errors to a goroutine with a fuller view of the whole program, so that it can make better decisions based on them.

Here is a better approach:

package main

import (
	"fmt"
	"net/http"
)

type Result struct {
	Error    error
	Response *http.Response
}

func main() {
	checkStatus := func(urls ...string) <-chan Result {
		results := make(chan Result)
		go func() {
			defer close(results)
			for _, url := range urls {
				resp, err := http.Get(url)
				results <- Result{Error: err, Response: resp}    // propagate the error out of the current goroutine
			}
		}()
		return results
	}

	errCount := 0
	urls := []string{"https://www.bing.com", "https://badhost"}
	for result := range checkStatus(urls...) {
		if result.Error != nil {                                 // handling the error here is better: this code has more context, e.g. it can easily count the total number of errors
			fmt.Printf("error: %v\n", result.Error)
			errCount++
			continue
		}
		fmt.Printf("Response: %v\n", result.Response.Status)
	}
}

3.3 Explicit Cancellation (The done channel)


Goroutines are lightweight and use very little memory. Even so, mishandled goroutines may never exit, leaking memory.

If a goroutine is responsible for creating a goroutine, it is also responsible for ensuring it can stop the goroutine.

3.3.1 A channel read can keep a goroutine from exiting

Consider the following program:

// the program below is a counterexample (its goroutine never exits)
package main

import (
	"fmt"
)

func main() {
	doWork := func(strings <-chan string) <-chan interface{} {
		completed := make(chan interface{})
		go func() {
			defer fmt.Println("doWork exited.")
			defer close(completed)
			for s := range strings {
				// Do something interesting
				fmt.Println(s)
			}
		}()
		return completed
	}

	doWork(nil)

	// the two lines below (commented out) would cause a deadlock!
	// completed := doWork(nil)
	// <-completed  // make the main goroutine wait for the goroutine created in doWork

	// Perhaps more work is done here
	fmt.Println("Done.")
}

Running the program prints (note: "doWork exited." is never printed):

Done.

In the program above, the main goroutine passes nil to doWork, so the strings channel inside doWork never delivers a string, and the goroutine started inside doWork never finishes on its own (it lives until the whole program exits; in this example main exits quickly, so it hardly matters, but a real server may run for a long time, and then the leak becomes serious).

A done channel solves the problem of the goroutine never exiting:

package main

import (
	"fmt"
	"time"
)

func main() {
	doWork := func(done <-chan interface{}, strings <-chan string) <-chan interface{} {
		completed := make(chan interface{})
		go func() {
			defer fmt.Println("doWork exited.")
			defer close(completed)
			for {
				select {
				case s := <-strings:
					// Do something interesting
					fmt.Println(s)
				case <-done: // note: writing anything to done, or closing it, takes this branch
					return
				}
			}
		}()
		return completed
	}

	done := make(chan interface{})
	completed := doWork(done, nil)

	// Cancel the operation after 1 second.
	time.Sleep(1 * time.Second)
	fmt.Println("Canceling doWork goroutine...")
	// closing done makes the goroutine started in doWork exit
	// writing anything to done (e.g. done <- "anything") would also make it exit
	close(done)

	<-completed
	time.Sleep(1 * time.Second)

	// Perhaps more work is done here
	fmt.Println("Done.")
}

Running the program prints:

Canceling doWork goroutine...
doWork exited.
Done.

3.3.2 A channel write can keep a goroutine from exiting

Consider the following program:

// the program below is a counterexample (its goroutine never exits)
package main

import (
	"fmt"
	"time"
	"math/rand"
)

func main() {
	newRandStream := func() <-chan int {
		randStream := make(chan int)
		go func() {
			defer fmt.Println("newRandStream closure exited.")
			defer close(randStream)
			for {
				randStream <- rand.Int()
			}
		}()
		return randStream
	}

	randStream := newRandStream()
	fmt.Println("3 random ints:")
	for i := 1; i <= 3; i++ {
		fmt.Printf("%d: %d\n", i, <-randStream)
	}

	// Simulate other ongoing work
	time.Sleep(1 * time.Second)
}

Running the program prints (note: "newRandStream closure exited." is never printed):

3 random ints:
1: 5577006791947779410
2: 8674665223082153551
3: 6129484611666145821

In the program above, the goroutine started inside newRandStream never exits (it lives until the whole program ends).

As in the previous section, a done channel fixes the problem:

package main

import (
	"fmt"
	"time"
	"math/rand"
)

func main() {
	newRandStream := func(done <-chan interface{}) <-chan int {
		randStream := make(chan int)
		go func() {
			defer fmt.Println("newRandStream closure exited.")
			defer close(randStream)
			for {
				select {
				case randStream <- rand.Int():
				case <-done:
					return
				}
			}
		}()
		return randStream
	}

	done := make(chan interface{})
	randStream := newRandStream(done)
	fmt.Println("3 random ints:")
	for i := 1; i <= 3; i++ {
		fmt.Printf("%d: %d\n", i, <-randStream)
	}
	close(done)

	// Simulate other ongoing work
	time.Sleep(1 * time.Second)
}

Running the program prints:

3 random ints:
1: 5577006791947779410
2: 8674665223082153551
3: 6129484611666145821
newRandStream closure exited.

3.4 Pipelines

A pipeline is nothing more than a series of things that take data in, perform an operation on it, and pass the data back out. We call each of these operations a stage of the pipeline.

3.4.1 A pipeline example

Here is a pipeline that multiplies each value by 2, adds 1, multiplies by 2 again, and prints the result.

package main

import "fmt"

func main() {
	generator := func(done <-chan interface{}, integers ...int) <-chan int {
		intStream := make(chan int)
		go func() {
			defer close(intStream)
			for _, i := range integers {
				select {
				case <-done:
					return
				case intStream <- i:
				}
			}
		}()
		return intStream
	}

	// below is one stage of the pipeline
	multiply := func(
		done <-chan interface{},
		intStream <-chan int,
		multiplier int,
	) <-chan int {
		multipliedStream := make(chan int)
		go func() {
			defer close(multipliedStream)
			for i := range intStream {
				select {
				case <-done:
					return
				case multipliedStream <- i * multiplier:
				}
			}
		}()
		return multipliedStream
	}

	// below is another stage of the pipeline
	add := func(
		done <-chan interface{},
		intStream <-chan int,
		additive int,
	) <-chan int {
		addedStream := make(chan int)
		go func() {
			defer close(addedStream)
			for i := range intStream {
				select {
				case <-done:
					return
				case addedStream <- i + additive:
				}
			}
		}()
		return addedStream
	}

	done := make(chan interface{})
	defer close(done)
	intStream := generator(done, 1, 2, 3, 4)

	// the pipeline composes multiple stages
	pipeline := multiply(done, add(done, multiply(done, intStream, 2), 1), 2)

	for v := range pipeline {
		fmt.Println(v)
	}
}

Running the program prints:

6
10
14
18

Table 1 shows when each value enters each channel of the pipeline.

Table 1: when each value enters each channel of the pipeline

 Iteration | Generator | Multiply | Add      | Multiply | Value
-----------+-----------+----------+----------+----------+-------
 0         | 1         |          |          |          |
 0         |           | 1        |          |          |
 0         | 2         |          | 2        |          |
 0         |           | 2        |          | 3        |
 0         | 3         |          | 4        |          | 6
 1         |           | 3        |          | 5        |
 1         | 4         |          | 6        |          | 10
 2         | (closed)  | 4        |          | 7        |
 2         |           | (closed) | 8        |          | 14
 3         |           |          | (closed) | 9        |
 3         |           |          |          | (closed) | 18

The same task can be done without a pipeline. For example, the code below produces the same output, but it is slower than the pipeline version.

// without a pipeline; slower than the pipeline version
// each step runs to completion before the next begins; with a pipeline, a value enters the next stage as soon as it is ready, which is why the pipeline is faster
package main

import "fmt"

func main() {
	multiply := func(values []int, multiplier int) []int {
		multipliedValues := make([]int, len(values))
		for i, v := range values {
			multipliedValues[i] = v * multiplier
		}
		return multipliedValues
	}

	add := func(values []int, additive int) []int {
		addedValues := make([]int, len(values))
		for i, v := range values {
			addedValues[i] = v + additive
		}
		return addedValues
	}

	ints := []int{1, 2, 3, 4}
	for _, v := range multiply(add(multiply(ints, 2), 1), 2) {
		fmt.Println(v)
	}
}

3.5 Fan-Out, Fan-In

In a pipeline, a stage that processes too slowly drags down the throughput of the whole pipeline. In that case, "fan-out, fan-in" can speed the stage up. Multiple functions reading from the same channel and processing its data in parallel is called fan-out (it shortens the time data waits in the channel); a single function reading from multiple channels and sending (merging) the results onto one channel is called fan-in.

Here is a fan-out, fan-in example (code from https://blog.golang.org/pipelines):

package main

import (
	"fmt"
	"sync"
)

func gen(nums ...int) <-chan int {
	out := make(chan int)
	go func() {
		for _, n := range nums {
			out <- n
		}
		close(out)
	}()
	return out
}

func sq(in <-chan int) <-chan int {
	out := make(chan int)
	go func() {
		for n := range in {
			out <- n * n
		}
		close(out)
	}()
	return out
}

func merge(cs ...<-chan int) <-chan int {
	var wg sync.WaitGroup
	out := make(chan int)

	// Start an output goroutine for each input channel in cs.  output
	// copies values from c to out until c is closed, then calls wg.Done.
	output := func(c <-chan int) {
		for n := range c {
			out <- n
		}
		wg.Done()
	}
	wg.Add(len(cs))
	for _, c := range cs {
		go output(c)
	}

	// Start a goroutine to close out once all the output goroutines are
	// done.  This must start after the wg.Add call.
	go func() {
		wg.Wait()
		close(out)
	}()
	return out
}

func main() {
	in := gen(2, 3)

	// Distribute the sq work across two goroutines that both read from in.
	c1 := sq(in)
	c2 := sq(in)

	// Consume the merged output from c1 and c2.
	for n := range merge(c1, c2) {
		fmt.Println(n) // 4 then 9, or 9 then 4
	}
}

In the program above, the fragment:

c1 := sq(in)
c2 := sq(in)

is the fan-out, while merge(c1, c2) is the fan-in.

3.6 The context Package

Section 3.3 showed how to cancel goroutines explicitly via a done channel. Sometimes, however, we also need to carry extra context (for example, why the goroutine was canceled). It would be useful to have one general mechanism for canceling goroutines and passing additional information, and Google provides one: the context package (added to the standard library in Go 1.7).

More precisely, the context package's job is to "store and retrieve request-scoped data". It has many users: the standard http package uses context to carry per-request context data, and gRPC uses context to terminate the goroutine tree spawned by a request.

To use context, pass a value of type context.Context as the first parameter of your function. The context.Context type is defined as follows:

// A Context carries a deadline, cancelation signal, and request-scoped values
// across API boundaries. Its methods are safe for simultaneous use by multiple
// goroutines.
type Context interface {
    // Done returns a channel that is closed when this Context is canceled
    // or times out.
    Done() <-chan struct{}

    // Err indicates why this context was canceled, after the Done channel
    // is closed.
    Err() error

    // Deadline returns the time when this Context will be canceled, if any.
    Deadline() (deadline time.Time, ok bool)

    // Value returns the value associated with key or nil if none.
    Value(key interface{}) interface{}
}

You do not need to implement the context.Context interface yourself; the context package provides two functions that return a Context instance: context.Background() and context.TODO(). Both return an empty Context.

3.6.1 A context.WithValue example: passing request-scoped data

The signature of context.WithValue is func WithValue(parent Context, key, val interface{}) Context. It attaches key/value data to a Context object; the Value(key interface{}) method of context.Context then reads the value back. The example below demonstrates its use:

package main

import (
	"context"
	"fmt"
)

func main() {

	f := func(ctx context.Context) {
		if v := ctx.Value("USERID"); v != nil {
			fmt.Println("found userid:", v)
		} else {
			fmt.Println("not found userid")
		}
		if v := ctx.Value("AUTHTOKEN"); v != nil {
			fmt.Println("found authtoken:", v)
		} else {
			fmt.Println("not found authtoken")
		}
	}

	fmt.Println("first test...")
	ctx1 := context.WithValue(context.Background(), "USERID", "user1") // key is a built-in type: bad!
	ctx1 = context.WithValue(ctx1, "AUTHTOKEN", "token123")            // key is a built-in type: bad!
	f(ctx1)

	fmt.Println("second test...")
	ctx2 := context.WithValue(context.Background(), "USERID", "user2") // key is a built-in type: bad!
	f(ctx2)
}

Running the program prints:

first test...
found userid: user1
found authtoken: token123
second test...
found userid: user2
not found authtoken

Note: the second parameter of WithValue (the key) has type interface{}, so any type can be used, but built-in types are best avoided: keys of built-in types easily collide, for example (pseudocode):

Fun1(ctx) {    // in package package1
    ctx = context.WithValue(ctx, "key1", "value1")
    Fun2(ctx)
}

Fun2(ctx) {    // in package package2
    ctx = context.WithValue(ctx, "key1", "value2")  // this collides with "key1" set in Fun1
    Fun3(ctx)  // Fun3 can no longer obtain "value1" via "key1"
}

In the pseudocode above, Fun2 in package2 "accidentally" sets "key1" on ctx, which prevents Fun3 from seeing the value Fun1 stored under "key1". The next section shows how to avoid this.

3.6.1.1 Define a custom key type (avoid collisions between packages)

Defining a dedicated type for keys avoids the collision problem described in the previous section.

First, recall the following point. Given the code:

type foo int
type bar int

m := make(map[interface{}]int)
m[foo(1)] = 20           // convert 1 to type foo (underlying type int) and use it as a key in m
m[bar(1)] = 30           // convert 1 to type bar (underlying type also int) and use it as a key in m
fmt.Printf("%v", m)      // note: prints map[1:20 1:30], not map[1:30]

It prints "map[1:20 1:30]" rather than "map[1:30]". Why? Because even though foo and bar share the same underlying type (int), Go still considers them distinct types.

The program below demonstrates the collision-avoidance technique:

package main

import (
	"context"
	"fmt"
)

// It is best to define a dedicated type (like ctxKey below) for the second parameter of
// context.WithValue, and keep that type unexported so it is invisible to other packages.
// This avoids collisions. Since the type is unexported, we must export helper functions to
// set or read the values stored in the Context, like UserID and AuthToken below (capitalized
// names are visible to other packages); a full version would also export setters such as
// SetUserID and SetAuthToken. Other packages then cannot set or read these Context entries
// directly, so accidental collisions cannot happen.
type ctxKey string

const ctxUserId ctxKey = "USERID"
const ctxAuthToken ctxKey = "AUTHTOKEN"

func UserID(c context.Context) (string, bool) {
	userId, ok := c.Value(ctxUserId).(string)
	return userId, ok
}

func AuthToken(c context.Context) (string, bool) {
	authToken, ok := c.Value(ctxAuthToken).(string)
	return authToken, ok
}

func main() {

	// For simplicity, f is defined in the same package in this example. In real code, f usually lives in another package.
	f := func(ctx context.Context) {
		if v, ok := UserID(ctx); ok == true {
			fmt.Println("found userid:", v)
		} else {
			fmt.Println("not found userid")
		}
		if v, ok := AuthToken(ctx); ok == true {
			fmt.Println("found authtoken:", v)
		} else {
			fmt.Println("not found authtoken")
		}
	}

	fmt.Println("first test...")
	ctx1 := context.WithValue(context.Background(), ctxUserId, "user1")
	ctx1 = context.WithValue(ctx1, ctxAuthToken, "token123")
	f(ctx1)

	fmt.Println("second test...")
	ctx2 := context.WithValue(context.Background(), ctxUserId, "user2")
	f(ctx2)
}

Running the program (the same as the previous section's code, plus some encapsulation) prints:

first test...
found userid: user1
found authtoken: token123
second test...
found userid: user2
not found authtoken

3.6.2 A context.WithCancel example: avoiding goroutine leaks

Here is an example using context.WithCancel (code from https://golang.org/pkg/context/#example_WithCancel).

package main

import (
	"context"
	"fmt"
)

func main() {
	// gen generates integers in a separate goroutine and
	// sends them to the returned channel.
	// The callers of gen need to cancel the context once
	// they are done consuming generated integers not to leak
	// the internal goroutine started by gen.
	gen := func(ctx context.Context) <-chan int {
		dst := make(chan int)
		n := 1
		go func() {
			for {
				select {
				case <-ctx.Done():
					return // returning not to leak the goroutine
				case dst <- n:
					n++
				}
			}
		}()
		return dst
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // cancel when we are finished consuming integers

	for n := range gen(ctx) {
		fmt.Println(n)
		if n == 5 {
			break
		}
	}
}

3.6.3 context.WithTimeout实例:设置goroutine的timeout时间

下面例子演示了如何设置goroutine的timeout时间:

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	// 上面这行代码等价于下面两行:
	// d := time.Now().Add(50 * time.Millisecond)
	// ctx, cancel := context.WithDeadline(context.Background(), d)

	defer cancel()

	select {
	case <-time.After(1 * time.Second):
		fmt.Println("overslept")
	case <-ctx.Done():
		fmt.Println(ctx.Err()) // prints "context deadline exceeded"
	}

}

上面代码会输出:

context deadline exceeded

注:context.WithTimeout和context.WithDeadline的关系如下:

func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
	return WithDeadline(parent, time.Now().Add(timeout))
}

3.6.4 context包的使用约定

使用context包的程序需要遵循如下的约定来满足接口的一致性以及便于静态分析:

  • Context变量需要作为第一个参数使用,一般命名为ctx;
  • 即使方法允许,也不要传入一个nil的Context;如果不确定要用什么Context,可以传入context.TODO();
  • Value方法只应该用于传递“request-scoped data”,不要用它来传递一些可选的函数参数;
  • 同一个Context可以用来传递到不同的goroutine中,Context在多个goroutine中是安全的。

4 Go Scheduler

4.1 Fork-join

Goroutine的并发采用了 Fork-join 框架,图 2 是Fork-join框架的示意图(图片摘自:A Primer on Scheduling Fork-Join Parallelism with Work Stealing)。

go_concurrency_fork_join.png

Figure 2: Fork-join示意图(spawn/sync来自 Cilk 语言)

在图 2 的记号中,spawn/sync分别代表fork/join。 在Go语言中,用关键字 go 启动goroutine就是fork,而通过channel或者sync包把多个goroutines进行同步则属于join。 下面代码演示了Go的Fork-join:

package main

import (
	"fmt"
	"time"
)

func worker(done chan bool) {
	fmt.Print("working...")
	time.Sleep(time.Second)
	done <- true
}

func main() {
	done := make(chan bool, 1)
	go worker(done) // 这是Fork点

	// other work

	<-done // 这是Join点(这里通过channel同步,通过其它方式同步也都属于Join)

	// other work
}

下面是计算斐波纳契数的cilk实现和golang实现的对比(注意:Go的闭包要递归调用自身时,需要先用var声明fib,不能用:=):

+-------------------------------+---------------------------------------------+
| cilk code                     | golang code                                 |
+-------------------------------+---------------------------------------------+
| cilk int fib(int n) {         | var fib func(n int) <-chan int              |
|     if (n < 2) {              | fib = func(n int) <-chan int {              |
|         return 1;             |     result := make(chan int)                |
|     } else {                  |     go func() {                             |
|         int x, y;             |         defer close(result)                 |
|                               |         if n <= 2 {                         |
|         x = spawn fib(n - 1); |             result <- 1                     |
|         y = spawn fib(n - 2); |             return                          |
|                               |         }                                   |
|         sync;                 |         result <- <-fib(n-1) + <-fib(n-2)   |
|                               |     }()                                     |
|         return x + y;         |     return result                           |
|     }                         | }                                           |
| }                             |                                             |
|                               | // fmt.Printf("fib(4) = %d", <-fib(4))      |
+-------------------------------+---------------------------------------------+

4.2 Work Stealing算法

Work Stealing算法是实现Fork-join框架的主流方式。Work Stealing算法的基本思想为:
1、每个线程关联着一个任务队列;
2、在Fork点,线程把新任务(如图 2 中的函数f)放入自己的任务队列中(这是Child Stealing策略,后面将介绍);
3、如果线程在Join点还需要等待其它任务执行完成才能继续,则线程就从自己的任务队列中拿出任务来执行;
4、如果线程自己的任务队列为空,则从其它线程的任务队列中“窃取”任务来执行。

4.2.1 Child Stealing vs. Continuation Stealing

有两种“任务窃取”策略:Child Stealing和Continuation Stealing(也称为Parent Stealing)。以图 2(为方便查看将其复制为图 3)为例进行说明, 如果在Fork点,把函数f放入自己的任务队列中(它可能被其它线程“窃取”),而当前线程接着执行函数g,则称为Child Stealing;如果在Fork点,把Continuation(即从spawn下一句开始的所有代码)放入自己的任务队列中(它可能被其它线程“窃取”),而当前线程接着执行函数f,则称为Continuation Stealing。

go_concurrency_fork_join.png

Figure 3: Fork-join两种“任务窃取”策略

Intel TBB 和 Microsoft PPL 采用Child Stealing策略,Intel Cilk 采用Continuation Stealing策略;OpenMP默认采用Child Stealing策略,但可以设置为Continuation Stealing策略。

参考:A Primer on Scheduling Fork-Join Parallelism with Work Stealing

4.2.2 Go采用Continuation Stealing策略

Go语言采用的是Continuation Stealing策略。这是因为(摘自“Concurrency in Go”):

Consider this: when creating a goroutine, it is very likely that your program will want the function in that goroutine to execute. It is also reasonably likely that the continuation from that goroutine will at some point want to join with that goroutine. And it’s not uncommon for the continuation to attempt a join before the goroutine has finished completing. Given these axioms, when scheduling a goroutine, it makes sense to immediately begin working on it.

Continuation Stealing策略使得Fork-join看起来像函数调用:如果任务不被“窃取”,图 3 所示代码在同一个线程中的执行顺序为:e() -> f() -> g() -> …,这是不是有点像函数调用?

4.3 “部分”抢占式调度

现代操作系统一般采用“基于定时器中断的抢占式调度”来实现任务调度,当前任务执行完分配给自己的时间片后,会被其它任务抢占。

Go的调度器采用的是一种“部分”抢占式调度,它并不是基于定时器中断来实现抢占的。Go在库函数、系统调用等很多地方增加了一些hook点,在这些位置可以实现抢占。 真实程序中的goroutine,一般都包含带有hook点的代码。不过,你也可以构造没有hook点的goroutine,请看下面的例子:

package main

import (
	"fmt"
	"runtime"
	"time"
)

func main() {
	var x int
	threads := runtime.GOMAXPROCS(0)  // runtime最多使用的执行上下文,假定为8(其它值也无所谓)
	fmt.Println("threads =", threads)
	for i := 0; i < threads; i++ {    // 下面创建8个goroutine(每一个都是死循环)
		go func() {
			for {
				x++
				// runtime.Gosched()  // 取消这一行或者下一行的注释,再执行试试
				// time.Sleep(0)
			}
		}()
	}
	time.Sleep(time.Second)           // goroutine main在sleep 1秒后,打印x,并退出
	fmt.Println("x =", x)
}

上面程序永远不会结束(在Go 1.10中测试如此。注:Go 1.14起,runtime引入了基于信号的异步抢占,这类不含hook点的死循环goroutine也能被抢占,在新版本Go中该程序可以正常结束)!

为什么会这样呢?下面假定threads为8(其它值也无所谓)进行分析。
程序创建了8个goroutine,再加上隐含的goroutine main,一共有9个goroutine。但这8个goroutine的代码很简单,仅是一个死循环,并没有调用库函数等可能包含hook点的代码,这样只要它们开始执行,就不会主动让出CPU了。而这个例子中,runtime最多使用8个执行上下文(OS线程);前面介绍过,Go采用“Continuation Stealing”策略,也就是说用go启动的goroutine会先执行。这样,goroutine main已经没有执行上下文可用,没有机会再执行,从而程序不会结束。

如何使上面程序可以正常结束呢?
办法一:仅显式地创建 threads - 1 个goroutine,goroutine main就有机会执行,程序可以正常结束;
办法二:在goroutine的for循环里调用可以增加hook点的函数,比如 time.Sleep(0) ,goroutine main就有机会执行,程序可以正常结束;
办法三:在goroutine的for循环里调用 runtime.Gosched() ,主动让出cpu,goroutine main就有机会执行,程序可以正常结束。

需要说明的是:上面程序存在竞争条件,就算可以正常结束,fmt.Println输出的"x"值可能每次不一样。

参考:A pitfall of golang scheduler


Author: cig01

Created: <2018-03-07 Wed 00:00>

Last updated: <2018-03-25 Sun 19:55>

Creator: Emacs 25.3.1 (Org mode 9.1.4)