aboutsummaryrefslogtreecommitdiff
path: root/plugin/rewrite/rewrite.go
blob: edb11813e920f8920ba0312f75be8a24df405fd0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package rewrite

import (
	"context"
	"fmt"
	"strings"

	"github.com/coredns/coredns/plugin"
	"github.com/coredns/coredns/request"

	"github.com/miekg/dns"
)

// Result is the outcome of applying a rewrite rule to a request: it tells
// whether the rule rewrote the request or left it alone.
type Result int

// Possible values of Result.
const (
	// RewriteIgnored is returned when a rule did not rewrite the request.
	RewriteIgnored Result = iota
	// RewriteDone is returned when a rule rewrote the request.
	RewriteDone
)

// Processing modes that a rule reports through its Mode method.
const (
	// Stop means rule processing should stop after completing this rule.
	Stop = "stop"
	// Continue means rule processing should continue with the next rule.
	Continue = "continue"
)

// Rewrite is a plugin to rewrite requests internally before being handled.
//
// Next is the handler the (possibly rewritten) request is passed on to, Rules
// are the rewrite rules ServeDNS applies in slice order, and the embedded
// RevertPolicy decides (via DoRevert) whether the response is routed through a
// ResponseReverter. A nil RevertPolicy is replaced by
// NewRevertPolicy(false, false) when a request is served.
type Rewrite struct {
	Next  plugin.Handler
	Rules []Rule
	RevertPolicy
}

// ServeDNS implements the plugin.Handler interface.
//
// The rules are applied to the request in order. Whenever a rule rewrites the
// request, the resulting question name is validated and the rule's response
// rules are collected on a ResponseReverter. A rewriting rule whose mode is Stop
// ends rule processing and hands the request to the next plugin right away;
// otherwise processing moves on to the following rule. The ResponseReverter is
// only used as the response writer when the revert policy asks for it.
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	if rw.RevertPolicy == nil {
		rw.RevertPolicy = NewRevertPolicy(false, false)
	}
	reverter := NewResponseReverter(w, r, rw.RevertPolicy)
	state := request.Request{W: w, Req: r}

	for _, rule := range rw.Rules {
		respRules, result := rule.Rewrite(ctx, state)
		if result != RewriteDone {
			continue
		}

		// A rule may have produced a name that is not a valid domain name; in
		// that case put the original question back and fail the request.
		rewritten := state.Req.Question[0].Name
		if _, valid := dns.IsDomainName(rewritten); !valid {
			err := fmt.Errorf("invalid name after rewrite: %s", rewritten)
			state.Req.Question[0] = reverter.originalQuestion
			return dns.RcodeServerFailure, err
		}

		reverter.ResponseRules = append(reverter.ResponseRules, respRules...)
		if rule.Mode() != Stop {
			continue
		}

		// Stop mode: no further rules are evaluated.
		if !rw.RevertPolicy.DoRevert() {
			return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
		}
		rcode, err := plugin.NextOrFailure(rw.Name(), rw.Next, ctx, reverter, r)
		if plugin.ClientWrite(rcode) {
			return rcode, err
		}
		// The next plugins didn't write a response, so write one now with the ResponseReverter.
		// If server.ServeDNS does this then it will create an answer mismatch.
		reply := new(dns.Msg).SetRcode(r, rcode)
		state.SizeAndDo(reply)
		reverter.WriteMsg(reply)
		// return success, so server does not write a second error response to client
		return dns.RcodeSuccess, err
	}

	// No Stop rule fired. Only go through the reverter when reverting is
	// enabled and at least one response rule was collected.
	if rw.RevertPolicy.DoRevert() && len(reverter.ResponseRules) > 0 {
		return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, reverter, r)
	}
	return plugin.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}

// Name implements the Handler interface and returns the name of this plugin.
func (rw Rewrite) Name() string {
	return "rewrite"
}

// Rule describes a rewrite rule.
type Rule interface {
	// Rewrite rewrites the current request. It returns the response rules
	// associated with the rewrite and a Result telling whether the request
	// was rewritten (RewriteDone) or left untouched (RewriteIgnored).
	Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result)
	// Mode returns the processing mode stop or continue, i.e. whether rule
	// processing ends after this rule has rewritten a request or carries on
	// with the next rule.
	Mode() string
}

// newRule builds a Rule from the arguments of a rewrite directive.
//
// The first argument may be a processing mode ("stop" or "continue"), in which
// case the second argument is the rule type. Without a leading mode the first
// argument is the rule type and the mode defaults to Stop (kept for backward
// compatibility). Mode and rule type are matched case-insensitively. All
// arguments after the rule type are passed to the constructor of that rule
// type.
//
// An error is returned when no arguments are given, when a mode is not followed
// by a rule type, when a "class" or "type" rule does not have exactly two
// arguments, when the rule type is "answer" (response rewrites must follow a
// name rule), when the rule type is unknown, or when the selected constructor
// rejects its arguments.
func newRule(args ...string) (Rule, error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("no rule type specified for rewrite")
	}

	// typeIdx is the position of the rule type: 0 normally, 1 when the
	// directive starts with an explicit processing mode.
	mode := Stop
	typeIdx := 0
	if first := strings.ToLower(args[0]); first == Stop || first == Continue {
		if len(args) < 2 {
			return nil, fmt.Errorf("%s rule must begin with a rule type", first)
		}
		mode = first
		typeIdx = 1
	}
	ruleType := strings.ToLower(args[typeIdx])
	ruleArgs := args[typeIdx+1:]

	switch ruleType {
	case "answer":
		return nil, fmt.Errorf("response rewrites must begin with a name rule")
	case "name":
		return newNameRule(mode, ruleArgs...)
	case "class":
		if len(ruleArgs) != 2 {
			return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
		}
		return newClassRule(mode, ruleArgs...)
	case "type":
		if len(ruleArgs) != 2 {
			return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
		}
		return newTypeRule(mode, ruleArgs...)
	case "edns0":
		return newEdns0Rule(mode, ruleArgs...)
	case "ttl":
		return newTTLRule(mode, ruleArgs...)
	case "cname":
		return newCNAMERule(mode, ruleArgs...)
	case "rcode":
		return newRCodeRule(mode, ruleArgs...)
	default:
		// Report the rule type token as the user wrote it, not a leading
		// mode keyword: "continue foo ..." must complain about "foo".
		return nil, fmt.Errorf("invalid rule type %q", args[typeIdx])
	}
}