aboutsummaryrefslogtreecommitdiff
path: root/plugin/dnstap/io.go
blob: f95e4b5e893eeee89ed7b8a299f952865fe751cf (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package dnstap

import (
	"crypto/tls"
	"net"
	"sync/atomic"
	"time"

	tap "github.com/dnstap/golang-dnstap"
)

const (
	// tcpWriteBufSize is the socket write-buffer size requested (best effort)
	// for TCP connections in dial. There is no documented rationale for the
	// exact value.
	tcpWriteBufSize = 1 << 20
	// queueSize is the capacity of the channel buffering messages between
	// Dnstap and the I/O loop; likewise chosen without a documented rationale.
	queueSize = 10000

	// tcpTimeout is the default dio.tcpTimeout (dial timeout, also handed to
	// the encoder).
	tcpTimeout = 4 * time.Second
	// flushTimeout is the default dio.flushTimeout, the timer period serve uses
	// to flush buffered output and report dropped messages.
	flushTimeout = time.Second

	// skipVerify is the default for tls.Config.InsecureSkipVerify: false, so
	// every TLS connection verifies the server certificate.
	skipVerify = false
)

// tapper is the minimal sink for dnstap messages. It exists so tests can
// substitute a mock for the Dnstap method.
type tapper interface {
	Dnstap(msg *tap.Dnstap)
}

// dio implements the tapper interface: Dnstap enqueues messages on queue and
// the serve goroutine (started by connect) writes them out through enc.
type dio struct {
	endpoint     string           // address passed to the dialer in dial
	proto        string           // "tls" selects TLS over TCP; otherwise the network name given to net.DialTimeout
	enc          *encoder         // current encoder; nil when none is available
	queue        chan *tap.Dnstap // buffered hand-off from Dnstap to serve
	dropped      uint32           // messages dropped since serve last reported; accessed atomically
	quit         chan struct{}    // closed by close to make serve return
	flushTimeout time.Duration    // period of serve's flush/report timer
	tcpTimeout   time.Duration    // dial timeout; also passed to newEncoder
	skipVerify   bool             // value used for tls.Config.InsecureSkipVerify
}

// newIO allocates a dio for the given network protocol and endpoint, filled in
// with the package defaults for queue capacity, timeouts and TLS verification.
// The result is not connected yet; call connect to dial and start serving.
func newIO(proto, endpoint string) *dio {
	d := new(dio)
	d.proto = proto
	d.endpoint = endpoint
	d.queue = make(chan *tap.Dnstap, queueSize)
	d.quit = make(chan struct{})
	d.flushTimeout, d.tcpTimeout = flushTimeout, tcpTimeout
	d.skipVerify = skipVerify
	return d
}

// dial opens a new connection to d.endpoint and installs a fresh encoder for
// it in d.enc.
//
// When d.proto is "tls" the connection is TLS over TCP (certificate checks are
// skipped only if d.skipVerify is set); otherwise d.proto is used as the
// network name for net.DialTimeout. d.tcpTimeout bounds the dial (for TLS the
// dialer timeout covers the handshake too) and is also handed to newEncoder.
//
// If the dial itself fails, d.enc is left unchanged and the error is returned.
// If the encoder cannot be created, the just-opened connection is closed,
// d.enc is set to nil and the error is returned.
func (d *dio) dial() error {
	var (
		conn net.Conn
		err  error
	)
	switch d.proto {
	case "tls":
		config := &tls.Config{InsecureSkipVerify: d.skipVerify}
		dialer := &net.Dialer{Timeout: d.tcpTimeout}
		conn, err = tls.DialWithDialer(dialer, "tcp", d.endpoint, config)
	default:
		conn, err = net.DialTimeout(d.proto, d.endpoint, d.tcpTimeout)
	}
	if err != nil {
		return err
	}

	// Socket tuning is best effort; a failure here does not prevent sending.
	// NOTE(review): for "tls" conn is a *tls.Conn, so this branch is skipped and
	// the underlying TCP socket keeps its default settings.
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		_ = tcpConn.SetWriteBuffer(tcpWriteBufSize)
		_ = tcpConn.SetNoDelay(false) // keep Nagle's algorithm on to batch small writes
	}

	enc, err := newEncoder(conn, d.tcpTimeout)
	if err != nil {
		// Nothing else references conn; close it rather than leaking the socket.
		_ = conn.Close()
		d.enc = nil
		return err
	}
	d.enc = enc
	return nil
}

// connect dials the dnstap endpoint once and then launches serve in its own
// goroutine. serve is launched even if that first dial fails: its timer
// retries dial while d.enc is nil. The initial dial error is still returned so
// the caller can report it.
func (d *dio) connect() (err error) {
	err = d.dial()
	// Start serve only after dial has returned, so d.enc is settled before the
	// goroutine can observe it.
	go d.serve()
	return err
}

// Dnstap queues payload for the I/O goroutine without ever blocking the
// caller. If the queue is already full, the payload is discarded and counted
// in d.dropped so serve can report it later.
func (d *dio) Dnstap(payload *tap.Dnstap) {
	select {
	case d.queue <- payload:
		return
	default:
	}
	atomic.AddUint32(&d.dropped, 1)
}

// close tells the I/O goroutine started by connect to stop by closing d.quit.
// It does NOT wait for that goroutine: serve flushes and closes the encoder
// asynchronously once it observes the closed channel, so close may return
// before buffered data has been written. close must be called at most once;
// a second call panics because d.quit would be closed twice.
func (d *dio) close() { close(d.quit) }

// write sends one payload through the current encoder. The payload is counted
// as dropped in two cases:
//   - there is no encoder: nil is returned, so serve does not redial here;
//   - the encoder fails: its error is returned so serve can redial.
func (d *dio) write(payload *tap.Dnstap) error {
	var err error
	if d.enc != nil {
		err = d.enc.writeMsg(payload)
		if err == nil {
			return nil
		}
	}
	atomic.AddUint32(&d.dropped, 1)
	return err
}

// serve is the I/O loop started by connect. It writes queued payloads through
// the encoder until d.quit is closed; on quit it flushes and closes the
// encoder (if any) and returns.
//
// Roughly every d.flushTimeout it logs and resets the dropped-message counter.
// It then either flushes the encoder or, when there is no encoder, tries to
// dial again.
func (d *dio) serve() {
	// The timer is re-armed only after it fires. Re-arming it on every loop
	// iteration turns it into an idle timeout that never fires under sustained
	// traffic, so drops would go unreported exactly when they happen. Re-arming
	// it on every iteration would also call Reset while an expired value may
	// still sit in timeout.C, which causes spurious early ticks on pre-1.23
	// runtimes.
	timeout := time.NewTimer(d.flushTimeout)
	defer timeout.Stop()
	for {
		select {
		case <-d.quit:
			if d.enc == nil {
				return
			}
			d.enc.flush()
			d.enc.close()
			return
		case payload := <-d.queue:
			if err := d.write(payload); err != nil {
				// Best-effort reconnect; on failure the next write or tick retries.
				// NOTE(review): the previous encoder (and its connection) is
				// replaced without being closed; consider closing it once the
				// encoder's close semantics on a failed stream are confirmed.
				d.dial()
			}
		case <-timeout.C:
			if dropped := atomic.SwapUint32(&d.dropped, 0); dropped > 0 {
				log.Warningf("Dropped dnstap messages: %d", dropped)
			}
			if d.enc == nil {
				d.dial()
			} else {
				d.enc.flush()
			}
			// timeout.C was just drained by this case, so Reset is safe here.
			timeout.Reset(d.flushTimeout)
		}
	}
}