Recently I was building an application that would tick on an interval and, on each tick, produce potentially thousands of goroutines. I wanted to ensure that when the application was terminated, it would exit gracefully and quickly, even if particular goroutines were processing slowly.
Initially, I was using sync.WaitGroup to control flow, primarily around how I log output, but I quickly realized that if I created many goroutines and even a small collection of them did not return immediately, my application would hang when being terminated. This led me to review context.WithCancel and to understand how I could adjust my application to be well suited for quick and graceful termination!
We can demonstrate this by building up from an application which, initially, does not use either:
package main

import (
    "fmt"
    "log"
    "math/rand"
    "os"
    "os/signal"
    "syscall"
    "time"
)

func doSomething(ch chan int) {
    fmt.Printf("Received job %d\n", <-ch)
}

func init() {
    rand.Seed(time.Now().Unix())
}

func main() {
    var (
        closing   = make(chan struct{})
        ticker    = time.NewTicker(1 * time.Second)
        logger    = log.New(os.Stderr, "", log.LstdFlags)
        batchSize = 6
        jobs      = make(chan int, batchSize)
    )
    go func() {
        signals := make(chan os.Signal, 1)
        signal.Notify(signals, syscall.SIGTERM, os.Interrupt)
        <-signals
        close(closing)
    }()
loop:
    for {
        select {
        case <-closing:
            break loop
        case <-ticker.C:
            for n := 0; n < batchSize; n++ {
                jobs <- n
                go doSomething(jobs)
            }
            logger.Printf("Completed doing %d things.", batchSize)
        }
    }
}
When the program is run, we observe the sequence of "Received job ..." messages alongside the completion message ("Completed doing .."). It might look something like this:
Received job 0
Received job 1
Received job 2
2021/02/08 21:30:59 Completed doing 6 things.
Received job 3
Received job 4
Received job 5
2021/02/08 21:31:00 Completed doing 6 things.
The results don't print consistently! This makes sense: starting a goroutine doesn't block, so unless we do something about it, the code after the go statements executes immediately. We can add a sync.WaitGroup to improve flow. First, define it in the var block:
var (
    ..
    wg sync.WaitGroup
)
Adjust the loop:
for n := 0; n < batchSize; n++ {
    wg.Add(1)
    jobs <- n
    go doSomething(&wg, jobs)
}
wg.Wait()
logger.Printf("Completed doing %d things.", batchSize)
And finally, the goroutine:
func doSomething(wg *sync.WaitGroup, ch chan int) {
    defer wg.Done()
    fmt.Printf("Received job %d\n", <-ch)
}
WaitGroups wait for a collection of goroutines to finish. If we read the code out loud, we can see that:
- On each iteration of the loop, we add 1 to the WaitGroup. We add 1 because our goroutine will call wg.Done() once, which decrements the WaitGroup counter by one. It balances out as each goroutine returns.
- Before the logger call, we add wg.Wait(). This tells our Go program to block until the WaitGroup counter is zero. The counter will be zero when all goroutines have called wg.Done().
Simple, right? If we run the program again we can see the results print more consistently:
2021/02/08 21:46:47 Completed doing 6 things.
Received job 0
Received job 1
Received job 2
Received job 4
Received job 5
Received job 3
2021/02/08 21:46:48 Completed doing 6 things.
Received job 0
Received job 2
Received job 3
Received job 4
Received job 5
Received job 1
By the way, it's expected the jobs won't be ordered! We haven't done anything to ensure that.
Before we continue, run the application as it is thus far and try to terminate it; usually this is done by hitting Ctrl-C, which sends an interrupt signal. The program should exit without issue.
To demonstrate further need for control, let's add a piece of code that's more akin to a real-world
scenario. We'll make a new function which calls out to an API and expects a response.
We'll then use context.WithCancel
to cancel the request while it's in flight.
First, create the new function without any context. It's going to be heavier, so follow the in-line commentary as necessary:
func doAPICall(wg *sync.WaitGroup) error {
    defer wg.Done()
    req, err := http.NewRequest("GET", "https://httpstat.us/200", nil)
    if err != nil {
        return err
    }
    // The httpstat.us API accepts a sleep parameter which sleeps the request for the
    // passed time in ms
    q := req.URL.Query()
    sleepMin := 1000
    sleepMax := 4000
    q.Set("sleep", fmt.Sprintf("%d", rand.Intn(sleepMax-sleepMin)+sleepMin))
    req.URL.RawQuery = q.Encode()
    // Make the request to the API in an anonymous function, using a channel to
    // communicate the results
    c := make(chan error, 1)
    go func() {
        // We're not using the response in this example, but we still close its
        // body so the underlying connection isn't leaked.
        resp, err := http.DefaultClient.Do(req)
        if err == nil {
            resp.Body.Close()
        }
        c <- err
    }()
    // Block until the channel is populated
    return <-c
}
Next, modify the loop body: remove the previous call to doSomething, optionally drop the jobs channel (we won't use it further), and add a call to doAPICall:
for n := 0; n < batchSize; n++ {
    wg.Add(1)
    go doAPICall(&wg)
}
Run the application and try to exit again.
- The WaitGroup continues to wait until all goroutines are finished.
- The doAPICall function blocks until a response is returned from the httpstat.us API, and that can range anywhere from 1000 to 4000 ms.
- Depending on when you try to terminate the application, it can be difficult to do so (you may not notice this on first pass; run it a few times and try to terminate at different times).
Now to demonstrate how context.WithCancel provides further control over program cancellation. When context.WithCancel is initialised, it provides a context and a CancelFunc. This cancel func can cancel the context, which sounds confusing at first pass; reading Go Concurrency Patterns: Context from The Go Blog helped, and I recommend checking it out after this post!
Ok, back to it. There's little that needs to be done to the application to have it support this cancellation flow. First, create a new context along with its cancel function:
var (
    ctx, cancel = context.WithCancel(context.Background())
    ...
)
Then, in the anonymous function where we watch for program termination, call the CancelFunc
after
the signals
channel is notified. This means that the context will be considered cancelled:
go func() {
    signals := make(chan os.Signal, 1)
    signal.Notify(signals, syscall.SIGTERM, os.Interrupt)
    <-signals
    logger.Println("Initiating shutdown of producer.")
    cancel()
    close(closing)
}()
Then, adjust the doAPICall
function to accept a context, and modify the return statement
to use a blocking select
, waiting on either the ctx.Done
channel or the request
response. Parts of the function snipped for brevity:
func doAPICall(ctx context.Context, ....) {
    // Cancel the request if ctx.Done is closed or await the response
    select {
    case <-ctx.Done():
        return ctx.Err()
    case err := <-c:
        return err
    }
}
Finally, ensure the call to doAPICall has been adjusted to pass the context. Now, run the application and terminate it at different points.
What happens now? The application terminates immediately. The blocking select
call
watches for the closure of ctx.Done
or the response on c
, whichever comes first. When
the application is terminated, ctx.Done
takes precedence and the function returns early,
not concerning itself with the response of the request. The WaitGroup continues to do its
specific job and the flow of the application during termination is much improved!
One of the underlying philosophies in Go is "Don't communicate by sharing memory; share memory by communicating." Here, we used channels to pass results between goroutines, and that allowed us to improve the flow of our application. There are many ways we could take this further; for example, we don't collect the responses or errors from the API calls across goroutines. Thankfully, Go makes that trivial to arrange, so consider this a starting point, and if you'd like to take it further, try playing with such ideas!
For reference, here's the code in its entirety:
package main

import (
    "context"
    "fmt"
    "log"
    "math/rand"
    "net/http"
    "os"
    "os/signal"
    "sync"
    "syscall"
    "time"
)

func doAPICall(ctx context.Context, wg *sync.WaitGroup) error {
    defer wg.Done()
    req, err := http.NewRequest("GET", "https://httpstat.us/200", nil)
    if err != nil {
        return err
    }
    // The httpstat.us API accepts a sleep parameter which sleeps the request for the
    // passed time in ms
    q := req.URL.Query()
    sleepMin := 1000
    sleepMax := 4000
    q.Set("sleep", fmt.Sprintf("%d", rand.Intn(sleepMax-sleepMin)+sleepMin))
    req.URL.RawQuery = q.Encode()
    c := make(chan error, 1)
    go func() {
        // We're not using the response in this example, but we still close its
        // body so the underlying connection isn't leaked.
        resp, err := http.DefaultClient.Do(req)
        if err == nil {
            resp.Body.Close()
        }
        c <- err
    }()
    // Block until either channel is populated or closed
    select {
    case <-ctx.Done():
        return ctx.Err()
    case err := <-c:
        return err
    }
}

func init() {
    rand.Seed(time.Now().Unix())
}

func main() {
    var (
        closing     = make(chan struct{})
        ticker      = time.NewTicker(1 * time.Second)
        logger      = log.New(os.Stderr, "", log.LstdFlags)
        batchSize   = 6
        wg          sync.WaitGroup
        ctx, cancel = context.WithCancel(context.Background())
    )
    go func() {
        signals := make(chan os.Signal, 1)
        signal.Notify(signals, syscall.SIGTERM, os.Interrupt)
        <-signals
        cancel()
        close(closing)
    }()
loop:
    for {
        select {
        case <-closing:
            break loop
        case <-ticker.C:
            for n := 0; n < batchSize; n++ {
                wg.Add(1)
                go doAPICall(ctx, &wg)
            }
            wg.Wait()
            logger.Printf("Completed doing %d things.", batchSize)
        }
    }
}
As a final note, a portion of this code was inspired by the Go Concurrency Patterns: Context blog post, which I, again, recommend. It introduces further controls like context.WithTimeout, and, well, the Go blog is a treasure that everyone should read!