aboutsummaryrefslogtreecommitdiff
path: root/plugin/pkg/tls/tls_test.go
blob: a5635c1770f358df9660b5dcd1df0ec2c5e6b2df (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package tls

import (
	"os"
	"path/filepath"
	"testing"

	"github.com/coredns/coredns/plugin/test"
)

// getPEMFiles writes a CA, certificate and private key into the directory
// returned by test.WritePEMFiles and returns the paths of cert.pem, key.pem
// and ca.pem, in that order. It aborts the calling test if the files cannot
// be written.
func getPEMFiles(t *testing.T) (cert, key, ca string) {
	// Attribute any failure to the caller's line rather than to this helper.
	t.Helper()

	tempDir, err := test.WritePEMFiles(t)
	if err != nil {
		t.Fatalf("Could not write PEM files: %s", err)
	}

	return filepath.Join(tempDir, "cert.pem"),
		filepath.Join(tempDir, "key.pem"),
		filepath.Join(tempDir, "ca.pem")
}

// TestNewTLSConfig verifies that NewTLSConfig accepts a certificate file,
// its key file and a CA file without returning an error.
func TestNewTLSConfig(t *testing.T) {
	certFile, keyFile, caFile := getPEMFiles(t)

	if _, err := NewTLSConfig(certFile, keyFile, caFile); err != nil {
		t.Errorf("Failed to create TLSConfig: %s", err)
	}
}

// TestNewTLSClientConfig verifies that NewTLSClientConfig can be built from
// a CA file alone without returning an error.
func TestNewTLSClientConfig(t *testing.T) {
	_, _, caFile := getPEMFiles(t)

	if _, err := NewTLSClientConfig(caFile); err != nil {
		t.Errorf("Failed to create TLSConfig: %s", err)
	}
}

// TestNewTLSConfigFromArgs calls NewTLSConfigFromArgs with zero, one, two and
// three arguments and checks the resulting config:
//   - no args:       no error
//   - CA:            RootCAs is set
//   - cert, key:     exactly one certificate and RootCAs left nil
//   - cert, key, CA: exactly one certificate and RootCAs set
func TestNewTLSConfigFromArgs(t *testing.T) {
	cert, key, ca := getPEMFiles(t)

	if _, err := NewTLSConfigFromArgs(); err != nil {
		t.Errorf("Failed to create TLSConfig: %s", err)
	}

	// The returned config cannot be relied on when err != nil, so each creation
	// failure stops the test instead of dereferencing c below.
	c, err := NewTLSConfigFromArgs(ca)
	if err != nil {
		t.Fatalf("Failed to create TLSConfig: %s", err)
	}
	if c.RootCAs == nil {
		t.Error("RootCAs should not be nil when one arg passed")
	}

	c, err = NewTLSConfigFromArgs(cert, key)
	if err != nil {
		t.Fatalf("Failed to create TLSConfig: %s", err)
	}
	if c.RootCAs != nil {
		t.Error("RootCAs should be nil when two args passed")
	}
	if len(c.Certificates) != 1 {
		t.Error("Certificates should have a single entry when two args passed")
	}

	args := []string{cert, key, ca}
	c, err = NewTLSConfigFromArgs(args...)
	if err != nil {
		t.Fatalf("Failed to create TLSConfig: %s", err)
	}
	if c.RootCAs == nil {
		t.Error("RootCAs should not be nil when three args passed")
	}
	if len(c.Certificates) != 1 {
		t.Error("Certificates should have a single entry when three args passed")
	}
}

// TestNewTLSConfigFromArgsWithRoot checks that certificate, key and CA paths
// expressed relative to a root directory are resolved by joining them with
// that root, and that the resulting files load via NewTLSConfigFromArgs.
func TestNewTLSConfigFromArgsWithRoot(t *testing.T) {
	cert, key, ca := getPEMFiles(t)

	// Check the error before scheduling cleanup; without a root directory the
	// rest of the test is meaningless, so stop immediately.
	tempDir, err := os.MkdirTemp("", "go-test-pemfiles")
	if err != nil {
		t.Fatal("failed to create temporary directory", err)
	}
	defer func() {
		if err := os.RemoveAll(tempDir); err != nil {
			t.Error("failed to clean up temporary directory", err)
		}
	}()
	root := tempDir

	// The PEM paths from getPEMFiles are presumably absolute, in which case the
	// join loop below would skip them and root would go untested. Rewrite each
	// one relative to root so the join is actually exercised.
	args := []string{cert, key, ca}
	for i, p := range args {
		rel, err := filepath.Rel(root, p)
		if err != nil {
			t.Fatalf("failed to express %q relative to %q: %s", p, root, err)
		}
		args[i] = rel
	}

	for i := range args {
		if !filepath.IsAbs(args[i]) && root != "" {
			args[i] = filepath.Join(root, args[i])
		}
	}

	c, err := NewTLSConfigFromArgs(args...)
	if err != nil {
		// c cannot be relied on when err != nil; stop before dereferencing it.
		t.Fatalf("Failed to create TLSConfig: %s", err)
	}
	if c.RootCAs == nil {
		t.Error("RootCAs should not be nil when three args passed")
	}
	if len(c.Certificates) != 1 {
		t.Error("Certificates should have a single entry when three args passed")
	}
}

// TestNewHTTPSTransport checks that NewHTTPSTransport returns a non-nil
// transport both when given a client TLS config and when given nil.
func TestNewHTTPSTransport(t *testing.T) {
	_, _, ca := getPEMFiles(t)

	cc, err := NewTLSClientConfig(ca)
	if err != nil {
		// Stop here: continuing would run the "with cc" check against a config
		// that was never successfully built.
		t.Fatalf("Failed to create TLSConfig: %s", err)
	}

	if tr := NewHTTPSTransport(cc); tr == nil {
		t.Error("Failed to create https transport with cc")
	}

	if tr := NewHTTPSTransport(nil); tr == nil {
		t.Error("Failed to create https transport without cc")
	}
}