aboutsummaryrefslogtreecommitdiff
path: root/middleware/rewrite/rewrite.go
blob: 44e8e43c7cff5399343ac1b0272162e352c80d74 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package rewrite

import (
	"fmt"
	"strings"

	"github.com/coredns/coredns/middleware"

	"github.com/miekg/dns"

	"golang.org/x/net/context"
)

// Result is the result of a rewrite: the value a Rule reports back after
// being offered a request.
type Result int

const (
	// RewriteIgnored is returned when rewrite is not done on request.
	// It is the zero value of Result.
	RewriteIgnored Result = iota
	// RewriteDone is returned when rewrite is done on request.
	// ServeDNS stops trying further rules at the first rule that reports it.
	RewriteDone
	// RewriteStatus is returned when rewrite is not needed and status code should be set
	// for the request.
	// NOTE(review): ServeDNS has its handling of this value commented out, so it
	// currently has the same effect as RewriteIgnored.
	RewriteStatus
)

// Rewrite is middleware to rewrite requests internally before being handled.
//
// Next is the handler the request is passed to after the rules have been
// tried. Rules are tried in slice order. noRevert selects which response
// writer is handed to Next after a successful rewrite: the original writer
// when true, the writer produced by NewResponseReverter when false.
type Rewrite struct {
	Next     middleware.Handler
	Rules    []Rule
	noRevert bool
}

// ServeDNS implements the middleware.Handler interface.
//
// Each rule is offered the request in turn. The first rule that reports
// RewriteDone ends the search and the request is passed to the next handler,
// using the reverting writer unless noRevert is set. If no rule reports
// RewriteDone the request is passed on with the original writer.
func (rw Rewrite) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	// Built up front, i.e. before any rule has had a chance to touch r.
	reverter := NewResponseReverter(w, r)

	for _, rule := range rw.Rules {
		// RewriteIgnored and RewriteStatus both just move on to the next rule;
		// returning a status code for RewriteStatus is not implemented here.
		if rule.Rewrite(w, r) != RewriteDone {
			continue
		}

		if rw.noRevert {
			return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
		}
		return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, reverter, r)
	}

	return middleware.NextOrFailure(rw.Name(), rw.Next, ctx, w, r)
}

// Name implements the Handler interface. It always reports "rewrite".
func (rw Rewrite) Name() string {
	return "rewrite"
}

// Rule describes a rewrite rule.
type Rule interface {
	// Rewrite rewrites the current request.
	// The returned Result tells the caller whether the request was rewritten
	// (RewriteDone), left alone (RewriteIgnored), or should get a status code
	// (RewriteStatus).
	Rewrite(dns.ResponseWriter, *dns.Msg) Result
}

// newRule builds a Rule from the arguments of a rewrite directive.
//
// args[0] selects the rule type and is matched case-insensitively: "name",
// "class", "type" or "edns0". The first three require exactly two further
// arguments, which are handed to the matching constructor. For "edns0" all
// remaining arguments are forwarded to newEdns0Rule and no argument-count
// check is applied here.
//
// An error is returned when args is empty, when the rule type is unknown, or
// when a name/class/type rule does not have exactly two arguments. The rule
// type is validated before the argument count so that an unknown type is
// always reported as such rather than as an argument-count problem.
func newRule(args ...string) (Rule, error) {
	if len(args) == 0 {
		return nil, fmt.Errorf("no rule type specified for rewrite")
	}

	ruleType := strings.ToLower(args[0])
	switch ruleType {
	case "edns0":
		return newEdns0Rule(args[1:]...)
	case "name", "class", "type":
		if len(args) != 3 {
			return nil, fmt.Errorf("%s rules must have exactly two arguments", ruleType)
		}
	default:
		return nil, fmt.Errorf("invalid rule type %q", args[0])
	}

	// Only name, class and type reach this point, each with exactly two arguments.
	switch ruleType {
	case "name":
		return newNameRule(args[1], args[2])
	case "class":
		return newClassRule(args[1], args[2])
	default: // "type"
		return newTypeRule(args[1], args[2])
	}
}