ggzeng
12/30/2019 - 2:01 AM

go pipeline demo

https://gist.github.com/claudiofahey/3afcf4f4fb3d8d3b35cadb100d4fb9b7 https://blog.golang.org/pipelines

基本思想:

  1. pipeline中各stage的组合通过自定义函数实现
  2. 每个stage中都通过goroutine并行处理,处理结果与处理过程中的错误都通过channel返回
  3. 每个stage都会传入ctx和输入channel
  4. 下一个stage使用上一个stage的输出channel作为输入
  5. pipeline完成的时候聚合收集所有的error
//
// Companion code to https://medium.com/statuscode/pipeline-patterns-in-go-a37bb3a7e61d
//
// To run:
//   go get github.com/pkg/errors
//   go run -race pipeline_demo.go
//
package main

import (
	"context"
	"fmt"
	"math/rand"
	"strconv"
	"sync"
	"time"

	"github.com/pkg/errors"
)

// MergeErrors fans in any number of error channels onto a single channel.
// Every value received from an input is forwarded unchanged, and the returned
// channel is closed once all inputs have been closed and drained.
// Based on https://blog.golang.org/pipelines.
func MergeErrors(cs ...<-chan error) <-chan error {
	// One buffer slot per input channel: as long as each input delivers at
	// most one error, forwarding never blocks, even when the consumer
	// (e.g. WaitForPipeline) stops reading early.
	merged := make(chan error, len(cs))

	var forwarders sync.WaitGroup
	forwarders.Add(len(cs))
	for _, c := range cs {
		// One forwarding goroutine per input: relay values until the input
		// is closed, then report completion.
		go func(src <-chan error) {
			defer forwarders.Done()
			for err := range src {
				merged <- err
			}
		}(c)
	}

	// Close the merged channel only after every forwarder has finished.
	// This goroutine is started after the Add call above.
	go func() {
		forwarders.Wait()
		close(merged)
	}()
	return merged
}

// WaitForPipeline merges the given error channels and blocks until either a
// non-nil error arrives, which is returned immediately, or every channel has
// been closed, in which case it returns nil.
func WaitForPipeline(errs ...<-chan error) error {
	merged := MergeErrors(errs...)
	for {
		err, open := <-merged
		switch {
		case !open:
			// All stages finished without reporting an error.
			return nil
		case err != nil:
			// First failure wins; remaining channels are not awaited.
			return err
		}
	}
}

// minimalPipelineStage is a template showing what every pipeline stage has:
//   - a context as the first parameter, for cancellation;
//   - any required inputs after ctx (plain values, channels, or gRPC input
//     streams);
//   - any outputs before the error channel (plain values, channels, or gRPC
//     output streams);
//   - a channel of errors for failures that happen after the function returns;
//   - a plain error for failures that happen before the function returns.
//
// This template starts a goroutine that does no work and closes the error
// channel when it finishes.
func minimalPipelineStage(ctx context.Context) (<-chan error, error) {
	asyncErrs := make(chan error, 1)
	go func() {
		// Closing the channel tells the caller that this stage is done.
		defer close(asyncErrs)
		// Do something useful here.
	}()
	return asyncErrs, nil
}

// lineListSource is a source stage that emits the given lines, in order, on
// the returned string channel.
//
// It returns an error immediately when no lines are provided. Once running,
// an empty line stops the stage and reports "line N is empty" (N is 1-based)
// on the error channel. The stage also stops when ctx is cancelled while it
// is waiting to deliver a line. Both channels are closed when the stage ends.
func lineListSource(ctx context.Context, lines ...string) (<-chan string, <-chan error, error) {
	// Failures detected before the goroutine starts are returned directly.
	if len(lines) == 0 {
		return nil, nil, errors.Errorf("no lines provided")
	}

	lineCh := make(chan string)
	asyncErrs := make(chan error, 1)

	go func() {
		defer close(lineCh)
		defer close(asyncErrs)
		for i := range lines {
			text := lines[i]
			if text == "" {
				// Failures detected while running go to the error channel;
				// its one-slot buffer means this send never blocks.
				asyncErrs <- errors.Errorf("line %v is empty", i+1)
				return
			}
			// Deliver the line unless the pipeline is cancelled first.
			select {
			case <-ctx.Done():
				return
			case lineCh <- text:
			}
		}
	}()
	return lineCh, asyncErrs, nil
}

// lineParser is a transformer stage that parses each string received on in
// as a signed 64-bit integer in the given base and emits the result on the
// returned int64 channel.
//
// base must be in the range 2..36, the explicit bases strconv.ParseInt
// accepts; any other value is rejected immediately with an "invalid base"
// error instead of surfacing later as a parse failure on the first line.
//
// Once running, the first line that fails to parse stops the stage and the
// strconv error is reported on the error channel. The stage also stops when
// in is closed or when ctx is cancelled while it is waiting to deliver a
// number. Both returned channels are closed when the stage ends.
func lineParser(ctx context.Context, base int, in <-chan string) (<-chan int64, <-chan error, error) {
	if base < 2 || base > 36 {
		// Handle an error that occurs before the goroutine begins.
		return nil, nil, errors.Errorf("invalid base %v", base)
	}
	out := make(chan int64)
	errc := make(chan error, 1)
	go func() {
		defer close(out)
		defer close(errc)
		for line := range in {
			n, err := strconv.ParseInt(line, base, 64)
			if err != nil {
				// Handle an error that occurs during the goroutine. The
				// one-slot buffer on errc means this send never blocks.
				errc <- err
				return
			}
			// Send the data to the output channel but return early
			// if the context has been cancelled.
			select {
			case out <- n:
			case <-ctx.Done():
				return
			}
		}
	}()
	return out, errc, nil
}

// splitter is a transformer stage that duplicates its input: every number
// received on in is delivered to the first output channel and then to the
// second one, so both consumers must keep reading for the stage to progress.
//
// The stage stops when in is closed or when ctx is cancelled while it is
// waiting to deliver a number. It never reports an error; the error channel
// and both outputs are closed when the stage ends.
func splitter(ctx context.Context, in <-chan int64) (<-chan int64, <-chan int64, <-chan error, error) {
	first := make(chan int64)
	second := make(chan int64)
	asyncErrs := make(chan error, 1)

	go func() {
		// Close the error channel first, then the second output, then the
		// first output.
		defer func() {
			close(asyncErrs)
			close(second)
			close(first)
		}()

		targets := [...]chan<- int64{first, second}
		for n := range in {
			// Deliver to each output in turn, giving up as soon as the
			// pipeline is cancelled.
			for _, dst := range targets {
				select {
				case dst <- n:
				case <-ctx.Done():
					return
				}
			}
		}
	}()
	return first, second, asyncErrs, nil
}

// squarer is a transformer stage that emits n*n for every number n received
// on in.
//
// The stage stops when in is closed or when ctx is cancelled while it is
// waiting to deliver a result. It never reports an error; the output and
// error channels are closed when the stage ends.
func squarer(ctx context.Context, in <-chan int64) (<-chan int64, <-chan error, error) {
	squares := make(chan int64)
	asyncErrs := make(chan error, 1)

	go func() {
		defer close(squares)
		defer close(asyncErrs)
		for {
			n, more := <-in
			if !more {
				return
			}
			sq := n * n
			// Deliver the square unless the pipeline is cancelled first.
			select {
			case <-ctx.Done():
				return
			case squares <- sq:
			}
		}
	}()
	return squares, asyncErrs, nil
}

// sink is a terminal stage that prints every number received on in as
// "sink: N".
//
// A number of 100 or more stops the stage and reports "number N is too large"
// on the error channel. Otherwise the stage runs until in is closed. ctx is
// accepted to match the other stages, but this stage does not consult it.
// The error channel is closed when the stage ends.
func sink(ctx context.Context, in <-chan int64) (<-chan error, error) {
	asyncErrs := make(chan error, 1)
	go func() {
		defer close(asyncErrs)
		for {
			n, more := <-in
			if !more {
				return
			}
			if n >= 100 {
				// Failures detected while running go to the error channel;
				// its one-slot buffer means this send never blocks.
				asyncErrs <- errors.Errorf("number %v is too large", n)
				return
			}
			fmt.Printf("sink: %v\n", n)
		}
	}()
	return asyncErrs, nil
}

// runSimplePipeline wires up a three-stage linear pipeline
// (lineListSource -> lineParser -> sink) over the given lines and blocks
// until it completes. It returns the first error reported, either while
// constructing a stage or by a running stage, or nil on success.
func runSimplePipeline(base int, lines []string) error {
	fmt.Printf("runSimplePipeline: base=%v, lines=%v\n", base, lines)

	ctx, cancel := context.WithCancel(context.Background())
	// Cancelling on return releases any stage still running after an early exit.
	defer cancel()

	// Source stage.
	lineCh, sourceErrs, err := lineListSource(ctx, lines...)
	if err != nil {
		return err
	}

	// Transformer stage.
	numberCh, parseErrs, err := lineParser(ctx, base, lineCh)
	if err != nil {
		return err
	}

	// Sink stage.
	sinkErrs, err := sink(ctx, numberCh)
	if err != nil {
		return err
	}

	fmt.Println("Pipeline started. Waiting for pipeline to complete.")

	return WaitForPipeline(sourceErrs, parseErrs, sinkErrs)
}

// runComplexPipeline wires up a branching pipeline and blocks until it
// completes:
//
//	lineListSource -> lineParser -> splitter -+-> squarer -> sink
//	                                          +-> sink
//
// It returns the first error reported, either while constructing a stage or
// by a running stage, or nil on success.
func runComplexPipeline(base int, lines []string) error {
	fmt.Printf("runComplexPipeline: base=%v, lines=%v\n", base, lines)

	ctx, cancel := context.WithCancel(context.Background())
	// Cancelling on return releases any stage still running after an early exit.
	defer cancel()

	// Source stage.
	lineCh, sourceErrs, err := lineListSource(ctx, lines...)
	if err != nil {
		return err
	}

	// Transformer stage 1: parse the lines into numbers.
	numberCh, parseErrs, err := lineParser(ctx, base, lineCh)
	if err != nil {
		return err
	}

	// Transformer stage 2: duplicate the number stream into two branches.
	branchA, branchB, splitErrs, err := splitter(ctx, numberCh)
	if err != nil {
		return err
	}

	// Transformer stage 3: square the numbers on the first branch.
	squaredCh, squareErrs, err := squarer(ctx, branchA)
	if err != nil {
		return err
	}

	// Sink stage 1: consume the squared numbers.
	squaredSinkErrs, err := sink(ctx, squaredCh)
	if err != nil {
		return err
	}

	// Sink stage 2: consume the untouched numbers from the second branch.
	plainSinkErrs, err := sink(ctx, branchB)
	if err != nil {
		return err
	}

	fmt.Println("Pipeline started. Waiting for pipeline to complete.")

	return WaitForPipeline(
		sourceErrs,
		parseErrs,
		splitErrs,
		squareErrs,
		squaredSinkErrs,
		plainSinkErrs,
	)
}

// randomNumberSource is a source stage that emits an endless stream of
// pseudo-random integers in [0, 100), formatted as decimal strings, pausing
// one second after each delivered value. The sequence is determined by seed.
//
// The stage runs until ctx is cancelled and reacts to cancellation promptly,
// both while waiting for a consumer and during the pause between values. It
// never reports an error; the output and error channels are closed when the
// stage ends.
func randomNumberSource(ctx context.Context, seed int64) (<-chan string, <-chan error, error) {
	out := make(chan string)
	errc := make(chan error, 1)
	random := rand.New(rand.NewSource(seed))
	go func() {
		defer close(out)
		defer close(errc)
		for {
			line := strconv.Itoa(random.Intn(100))
			// Send the data to the output channel but return if the context has been cancelled.
			select {
			case out <- line:
			case <-ctx.Done():
				return
			}
			// Pause between values with a timer rather than time.Sleep so
			// that cancellation interrupts the wait instead of holding the
			// channels open for up to a second.
			pause := time.NewTimer(1 * time.Second)
			select {
			case <-pause.C:
			case <-ctx.Done():
				pause.Stop()
				return
			}
		}
	}()
	return out, errc, nil
}

// runPipelineWithTimeout wires up an open-ended pipeline
// (randomNumberSource -> lineParser -> sink) and cancels it from a background
// goroutine after 10 seconds. It blocks until the pipeline completes and
// returns the first error reported, either while constructing a stage or by
// a running stage, or nil otherwise.
func runPipelineWithTimeout() error {
	fmt.Printf("runPipelineWithTimeout\n")

	ctx, cancel := context.WithCancel(context.Background())
	// Cancelling on return releases any stage still running after an early exit.
	defer cancel()

	// Source stage.
	lineCh, sourceErrs, err := randomNumberSource(ctx, 3)
	if err != nil {
		return err
	}

	// Transformer stage.
	numberCh, parseErrs, err := lineParser(ctx, 10, lineCh)
	if err != nil {
		return err
	}

	// Sink stage.
	sinkErrs, err := sink(ctx, numberCh)
	if err != nil {
		return err
	}

	fmt.Println("Pipeline started. Waiting for pipeline to complete.")

	// Cancel the pipeline from a background goroutine once 10 seconds have elapsed.
	go func() {
		time.Sleep(10 * time.Second)
		fmt.Println("Cancelling context.")
		cancel()
	}()

	return WaitForPipeline(sourceErrs, parseErrs, sinkErrs)
}

// main runs each demo pipeline in sequence and prints any error a run
// returns: five simple pipelines, one complex pipeline, and finally the
// pipeline that is cancelled after a timeout.
func main() {
	type pipelineRun struct {
		run   func(base int, lines []string) error
		base  int
		lines []string
	}

	runs := []pipelineRun{
		{runSimplePipeline, 10, []string{"3", "2", "1"}},
		{runSimplePipeline, 1, []string{"3", "2", "1"}},
		{runSimplePipeline, 2, []string{"1010", "1100", "1000"}},
		{runSimplePipeline, 2, []string{"1010", "1100", "2000", "1111"}},
		{runSimplePipeline, 10, []string{"1", "10", "100", "1000"}},
		{runComplexPipeline, 10, []string{"5", "4", "3"}},
	}
	for _, r := range runs {
		if err := r.run(r.base, r.lines); err != nil {
			fmt.Println(err)
		}
	}

	if err := runPipelineWithTimeout(); err != nil {
		fmt.Println(err)
	}
}