aboutsummaryrefslogtreecommitdiff
path: root/plugin/rewrite/cname_target.go
blob: c7d93e8a7c5bdc1521cfdb68f290394360201ffe (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package rewrite

import (
	"context"
	"fmt"
	"regexp"
	"strconv"
	"strings"

	"github.com/coredns/coredns/plugin/pkg/log"
	"github.com/coredns/coredns/plugin/pkg/upstream"
	"github.com/coredns/coredns/request"

	"github.com/miekg/dns"
)

// UpstreamInt wraps the Upstream API for dependency injection during testing.
// cnameTargetRule calls it to resolve the rewritten CNAME target; by default
// newCNAMERule installs upstream.New(), and tests may substitute their own.
type UpstreamInt interface {
	// Lookup resolves name with query type typ in the context of the request
	// state and returns the resulting message, or an error.
	Lookup(ctx context.Context, state request.Request, name string, typ uint16) (*dns.Msg, error)
}

// cnameTargetRule is a cname target rewrite rule: when a CNAME record in a
// response points at a target matched by the rule, the target is rewritten
// and the answer is re-resolved upstream for the new target.
//
// NOTE(review): state and ctx hold per-request values; writing them on a rule
// instance that is shared by concurrent queries is a data race unless Rewrite
// hands out a per-request copy — confirm.
type cnameTargetRule struct {
	rewriteType     string          // one of ExactMatch, PrefixMatch, SuffixMatch, SubstringMatch, RegexMatch
	paramFromTarget string          // lower-cased value (or regex) matched against the CNAME target
	paramToTarget   string          // lower-cased replacement target; {N} placeholders are filled for RegexMatch
	nextAction      string          // value reported by Mode
	state           request.Request // request captured by Rewrite, used by RewriteResponse
	ctx             context.Context // context captured by Rewrite, passed to Upstream.Lookup
	Upstream        UpstreamInt     // Upstream for looking up external names during the resolution process.
}

// getFromAndToTarget computes the rewrite of inputCName according to the
// rule's rewrite type. It returns from, the value the caller compares against
// the CNAME target, and to, the replacement target. When the rule does not
// apply it returns ("", "").
//
// For ExactMatch the configured pair is returned regardless of inputCName; the
// caller decides whether the target equals from. For the other match types,
// from is inputCName itself whenever it matches.
//
// For RegexMatch, placeholders {0}, {1}, ... in the configured target are
// replaced by the whole match and the capture groups respectively.
func (r *cnameTargetRule) getFromAndToTarget(inputCName string) (from string, to string) {
	switch r.rewriteType {
	case ExactMatch:
		return r.paramFromTarget, r.paramToTarget
	case PrefixMatch:
		if strings.HasPrefix(inputCName, r.paramFromTarget) {
			return inputCName, r.paramToTarget + strings.TrimPrefix(inputCName, r.paramFromTarget)
		}
	case SuffixMatch:
		if strings.HasSuffix(inputCName, r.paramFromTarget) {
			return inputCName, strings.TrimSuffix(inputCName, r.paramFromTarget) + r.paramToTarget
		}
	case SubstringMatch:
		if strings.Contains(inputCName, r.paramFromTarget) {
			return inputCName, strings.Replace(inputCName, r.paramFromTarget, r.paramToTarget, -1)
		}
	case RegexMatch:
		// Use Compile rather than MustCompile: this runs while serving a
		// query, and an invalid user-supplied pattern must not panic the
		// server. An uncompilable pattern is treated as "no match".
		// NOTE(review): the pattern is recompiled on every call; caching the
		// compiled regexp on the rule would avoid the repeated work.
		pattern, err := regexp.Compile(r.paramFromTarget)
		if err != nil {
			return "", ""
		}
		regexGroups := pattern.FindStringSubmatch(inputCName)
		if len(regexGroups) == 0 {
			return "", ""
		}
		substitution := r.paramToTarget
		for groupIndex, groupValue := range regexGroups {
			groupIndexStr := "{" + strconv.Itoa(groupIndex) + "}"
			substitution = strings.Replace(substitution, groupIndexStr, groupValue, -1)
		}
		return inputCName, substitution
	}
	return "", ""
}

// RewriteResponse rewrites a CNAME record of the response when its target is
// matched by the rule. On a match it resolves the rewritten target upstream
// with the original query type and replaces res.Answer with:
//   - every CNAME record from the original answer, retargeted to the new name,
//   - followed by the upstream answer records owned by the new name.
//
// Records that are not CNAMEs are ignored. If the upstream lookup fails, or
// returns no message, the error is logged and res is left unmodified instead
// of dereferencing a nil upstream response.
func (r *cnameTargetRule) RewriteResponse(res *dns.Msg, rr dns.RR) {
	if rr.Header().Rrtype != dns.TypeCNAME {
		return
	}
	cname, ok := rr.(*dns.CNAME)
	if !ok {
		return
	}
	fromTarget, toTarget := r.getFromAndToTarget(cname.Target)
	if cname.Target != fromTarget {
		return
	}

	// Create an upstream request for the new target with the same qtype.
	// NOTE(review): this rewrites the question name of the captured request
	// message in place; Upstream implementations that inspect state.Req may
	// rely on it — confirm before removing.
	r.state.Req.Question[0].Name = toTarget
	upRes, err := r.Upstream.Lookup(r.ctx, r.state, toTarget, r.state.Req.Question[0].Qtype)
	if err != nil {
		log.Errorf("Error upstream request %v", err)
		return
	}
	if upRes == nil {
		return
	}

	var newAnswer []dns.RR
	// Keep the CNAME records of the original answer, pointed at the new
	// target; non-CNAME records of the original answer are dropped.
	// NOTE(review): every CNAME is retargeted, not only the matched one,
	// which may not be intended for multi-hop CNAME chains — confirm.
	for _, ans := range res.Answer {
		if c, isCNAME := ans.(*dns.CNAME); isCNAME {
			c.Target = toTarget
			newAnswer = append(newAnswer, ans)
		}
	}
	// Append the upstream records that belong to the new target.
	for _, ans := range upRes.Answer {
		if ans.Header().Name == toTarget {
			newAnswer = append(newAnswer, ans)
		}
	}
	res.Answer = newAnswer
}

// newCNAMERule builds a cname target rewrite rule from the directive
// arguments. Two forms are accepted:
//
//	TYPE FROM TO   TYPE is ExactMatch, PrefixMatch, SuffixMatch, SubstringMatch or RegexMatch
//	FROM TO        implies ExactMatch
//
// TYPE, FROM and TO are lower-cased. For RegexMatch, FROM is compiled here so
// that an invalid pattern is reported at setup time rather than at query time.
// The rule resolves rewritten targets with upstream.New().
func newCNAMERule(nextAction string, args ...string) (Rule, error) {
	var rewriteType, paramFromTarget, paramToTarget string
	switch {
	case len(args) == 3:
		rewriteType = strings.ToLower(args[0])
		switch rewriteType {
		case ExactMatch, PrefixMatch, SuffixMatch, SubstringMatch, RegexMatch:
		default:
			return nil, fmt.Errorf("unknown cname rewrite type: %s", rewriteType)
		}
		paramFromTarget, paramToTarget = strings.ToLower(args[1]), strings.ToLower(args[2])
	case len(args) == 2:
		rewriteType = ExactMatch
		paramFromTarget, paramToTarget = strings.ToLower(args[0]), strings.ToLower(args[1])
	case len(args) > 3:
		return nil, fmt.Errorf("too many (%d) arguments for a cname rule", len(args))
	default:
		return nil, fmt.Errorf("too few (%d) arguments for a cname rule", len(args))
	}
	if rewriteType == RegexMatch {
		// NOTE(review): the pattern has been lower-cased, which also changes
		// upper-case escapes such as \S, \D or \W — confirm this is intended.
		if _, err := regexp.Compile(paramFromTarget); err != nil {
			return nil, fmt.Errorf("invalid cname rewrite regex %q: %v", paramFromTarget, err)
		}
	}
	rule := cnameTargetRule{
		rewriteType:     rewriteType,
		paramFromTarget: paramFromTarget,
		paramToTarget:   paramToTarget,
		nextAction:      nextAction,
		Upstream:        upstream.New(),
	}
	return &rule, nil
}

// Rewrite rewrites the current request. If the rule is fully configured
// (rewrite type, from target and to target all non-empty) it returns a
// response rule that rewrites CNAME targets in the answer, together with
// RewriteDone; otherwise it returns RewriteIgnored.
//
// The returned response rule is a per-request copy of r carrying ctx and
// state. The configured rule is shared by every query, so storing request
// data on r itself would let concurrent queries race on, and overwrite,
// each other's request and context.
func (r *cnameTargetRule) Rewrite(ctx context.Context, state request.Request) (ResponseRules, Result) {
	if len(r.rewriteType) == 0 || len(r.paramFromTarget) == 0 || len(r.paramToTarget) == 0 {
		return nil, RewriteIgnored
	}
	perRequest := *r
	perRequest.state = state
	perRequest.ctx = ctx
	return ResponseRules{&perRequest}, RewriteDone
}

// Mode reports the rule's processing mode, i.e. the next-action value the
// rule was constructed with.
func (r *cnameTargetRule) Mode() string {
	return r.nextAction
}