steevehook
1/10/2019 - 12:48 PM

Pipeline concurrency pattern

package main

import (
	"fmt"
	"sync"
)

// main wires a small fan-out/fan-in pipeline.
//
// gen produces the source values. Three sq workers all read from that one
// source channel, so each value is squared by whichever worker receives it.
// merge combines the workers' outputs. Only the first merged value is
// printed.
//
// The deferred close(done) signals the stages that are still running to
// stop. The program exits as soon as main returns anyway.
func main() {
	done := make(chan struct{})
	defer close(done)

	src := gen(2, 4, 6)
	workers := []<-chan int{
		sq(done, src),
		sq(done, src),
		sq(done, src),
	}

	first := <-merge(done, workers...)
	fmt.Println(first)
}

// gen returns a channel that already holds every value of numbers, in
// order, and is already closed.
//
// The channel's buffer is exactly len(numbers), so all sends complete
// without a receiver. No goroutine is needed; callers just drain the
// channel until it reports closed.
func gen(numbers ...int) <-chan int {
	ch := make(chan int, len(numbers))
	// Runs after the sends below and before the caller sees ch.
	defer close(ch)

	for i := range numbers {
		ch <- numbers[i]
	}
	return ch
}

// sq squares every value received from in and sends the results on the
// returned channel. That channel is closed when the worker goroutine exits.
//
// The goroutine exits when in is closed or when done is closed. done is
// checked in two places: while waiting for the next input, and while
// waiting to deliver a result. Because of that, neither a stalled upstream
// nor an abandoned downstream can keep the goroutine alive once the caller
// cancels.
func sq(done <-chan struct{}, in <-chan int) <-chan int {
	out := make(chan int)
	go func() {
		defer close(out)
		for {
			select {
			case <-done:
				return
			case n, ok := <-in:
				if !ok {
					return
				}
				select {
				case out <- n * n:
				case <-done:
					// Cancelled while the result was undelivered: stop
					// instead of going back to read more input.
					return
				}
			}
		}
	}()

	return out
}

// merge fans in the values from every channel in `in` onto a single
// returned channel.
//
// One goroutine forwards each input. The returned channel is closed once
// all of those goroutines have finished. Each one finishes when its input
// is closed or when done is closed.
//
// Each forwarded value is logged with "Adding". A value that is dropped
// because of cancellation is logged with "Exiting from".
func merge(done <-chan struct{}, in ...<-chan int) <-chan int {
	var wg sync.WaitGroup
	out := make(chan int)

	// output forwards c to out until c is closed or done is closed. It also
	// watches done while waiting on c. Without that, an input that never
	// closes would block wg.Wait, and therefore the close of out, even
	// after cancellation.
	output := func(c <-chan int) {
		defer wg.Done()
		for {
			select {
			case <-done:
				return
			case n, ok := <-c:
				if !ok {
					return
				}
				select {
				case out <- n:
					fmt.Println("Adding ", n)
				case <-done:
					fmt.Println("Exiting from ", n)
					return
				}
			}
		}
	}

	wg.Add(len(in))
	for _, c := range in {
		go output(c)
	}

	// Close out only after every forwarder has returned, so that no send
	// can hit a closed channel.
	go func() {
		wg.Wait()
		close(out)
	}()

	return out
}