diff options
author | 2016-03-18 20:57:35 +0000 | |
---|---|---|
committer | 2016-03-18 20:57:35 +0000 | |
commit | 3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d (patch) | |
tree | fae74c33cfed05de603785294593275f1901c861 | |
download | coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.gz coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.zst coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.zip |
First commit
131 files changed, 15193 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..0dd26ce5d --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +.DS_Store +Thumbs.db +_gitignore/ +Vagrantfile +.vagrant/ + +dist/builds/ +dist/release/ + +error.log +access.log + +/*.conf +Caddyfile + +og_static/
\ No newline at end of file diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..6a2da63db --- /dev/null +++ b/.travis.yml @@ -0,0 +1,16 @@ +language: go + +go: + - 1.6 + - tip + +env: +- CGO_ENABLED=0 + +install: + - go get -t ./... + - go get golang.org/x/tools/cmd/vet + +script: + - go vet ./... + - go test ./... diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..346c6dcb9 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,77 @@ +## Contributing to Caddy + +Welcome! Our community focuses on helping others and making Caddy the best it +can be. We gladly accept contributions and encourage you to get involved! + + +### Join us in chat + +Please direct your discussion to the correct room: + +- **Dev Chat:** [gitter.im/mholt/caddy](https://gitter.im/mholt/caddy) - to chat +with other Caddy developers +- **Support:** +[gitter.im/caddyserver/support](https://gitter.im/caddyserver/support) - to give +and get help +- **General:** +[gitter.im/caddyserver/general](https://gitter.im/caddyserver/general) - for +anything about Web development + + +### Bug reports + +First, please [search this repository](https://github.com/mholt/caddy/search?q=&type=Issues&utf8=%E2%9C%93) +with a variety of keywords to ensure your bug is not already reported. + +If not, [open an issue](https://github.com/mholt/caddy/issues) and answer the +questions so we can understand and reproduce the problematic behavior. + +The burden is on you to convince us that it is actually a bug in Caddy. This is +easiest to do when you write clear, concise instructions so we can reproduce +the behavior (even if it seems obvious). The more detailed and specific you are, +the faster we will be able to help you. Check out +[How to Report Bugs Effectively](http://www.chiark.greenend.org.uk/~sgtatham/bugs.html). + +Please be kind. :smile: Remember that Caddy comes at no cost to you, and you're +getting free help. If we helped you, please consider +[donating](https://caddyserver.com/donate) - it keeps us motivated! + + +### Minor improvements and new tests + +Submit [pull requests](https://github.com/mholt/caddy/pulls) at any time. Make +sure to write tests to assert your change is working properly and is thoroughly +covered. + + +### Proposals, suggestions, ideas, new features + +First, please [search](https://github.com/mholt/caddy/search?q=&type=Issues&utf8=%E2%9C%93) +with a variety of keywords to ensure your suggestion/proposal is new. + +If so, you may open either an issue or a pull request for discussion and +feedback. + +The advantage of issues is that you don't have to spend time actually +implementing your idea, but you should still describe it thoroughly. The +advantage of a pull request is that we can immediately see the impact the change +will have on the project, what the code will look like, and how to improve it. +The disadvantage of pull requests is that they are unlikely to get accepted +without significant changes, or it may be rejected entirely. Don't worry, that +won't happen without an open discussion first. + +If you are going to spend significant time implementing code for a pull request, +best to open an issue first and "claim" it and get feedback before you invest +a lot of time. + + +### Vulnerabilities + +If you've found a vulnerability that is serious, please email me: Matthew dot +Holt at Gmail. If it's not a big deal, a pull request will probably be faster. + + +## Thank you + +Thanks for your help! Caddy would not be what it is today without your +contributions. diff --git a/Caddyfile-simple b/Caddyfile-simple new file mode 100644 index 000000000..610fe9a19 --- /dev/null +++ b/Caddyfile-simple @@ -0,0 +1,7 @@ +.:1053 { + prometheus + rewrite ANY HINFO + + file db.miek.nl miek.nl + reflect +} diff --git a/Corefile b/Corefile new file mode 100644 index 000000000..97a631950 --- /dev/null +++ b/Corefile @@ -0,0 +1,9 @@ +.:1053 { + file db.miek.nl miek.nl + proxy . 8.8.8.8:53 +} + +dns.miek.nl:1053 { + file db.dns.miek.nl + reflect +} diff --git a/Corefile-name-alone b/Corefile-name-alone new file mode 100644 index 000000000..5bd787f3a --- /dev/null +++ b/Corefile-name-alone @@ -0,0 +1 @@ +miek.nl diff --git a/ISSUE_TEMPLATE b/ISSUE_TEMPLATE new file mode 100644 index 000000000..f5b4ec6e2 --- /dev/null +++ b/ISSUE_TEMPLATE @@ -0,0 +1,22 @@ +*If you are filing a bug report, please answer these questions. If your issue is not a bug report, + you do not need to use this template. Either way, please consider donating if we've helped you. + Thanks!* + +#### 1. What version of CoreDNS are you running (`coredns -version`)? + + +#### 2. What are you trying to do? + + +#### 3. What is your entire Corefile? +```text +(Put Corefile here) +``` + +#### 4. How did you run CoreDNS (give the full command and describe the execution environment)? + + +#### 5. What did you expect to see? + + +#### 6. What did you see instead (give full error messages and/or log)? diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 000000000..8dada3eda --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 000000000..1252131c3 --- /dev/null +++ b/README.md @@ -0,0 +1,174 @@ +<!-- +[](https://caddyserver.com) + +[](https://gitter.im/mholt/caddy) +[](https://godoc.org/github.com/mholt/caddy) +[](https://travis-ci.org/mholt/caddy) +[](https://ci.appveyor.com/project/mholt/caddy) +--> + +CoreDNS is a lightweight, general-purpose DNS server for Windows, Mac, Linux, BSD +and [Android](https://github.com/mholt/caddy/wiki/Running-Caddy-on-Android). +It is a capable alternative to other popular and easy to use web servers. +([@caddyserver](https://twitter.com/caddyserver) on Twitter) + +The most notable features are HTTP/2, [Let's Encrypt](https://letsencrypt.org) +support, Virtual Hosts, TLS + SNI, and easy configuration with a +[Caddyfile](https://caddyserver.com/docs/caddyfile). In development, you usually +put one Caddyfile with each site. In production, Caddy serves HTTPS by default +and manages all cryptographic assets for you. + +[Download](https://github.com/mholt/caddy/releases) · +[User Guide](https://caddyserver.com/docs) + + + +### Menu + +- [Getting Caddy](#getting-caddy) +- [Quick Start](#quick-start) +- [Running from Source](#running-from-source) +- [Contributing](#contributing) +- [About the Project](#about-the-project) + + + + +## Getting Caddy + +Caddy binaries have no dependencies and are available for nearly every platform. + +[Latest release](https://github.com/mholt/caddy/releases/latest) + + + +## Quick Start + +The website has [full documentation](https://caddyserver.com/docs) but this will +get you started in about 30 seconds: + +Place a file named "Caddyfile" with your site. Paste this into it and save: + +``` +localhost + +gzip +browse +ext .html +websocket /echo cat +log ../access.log +header /api Access-Control-Allow-Origin * +``` + +Run `caddy` from that directory, and it will automatically use that Caddyfile to +configure itself. + +That simple file enables compression, allows directory browsing (for folders +without an index file), serves clean URLs, hosts a WebSocket echo server at +/echo, logs requests to access.log, and adds the coveted +`Access-Control-Allow-Origin: *` header for all responses from some API. + +Wow! Caddy can do a lot with just a few lines. + + +#### Defining multiple sites + +You can run multiple sites from the same Caddyfile, too: + +``` +site1.com { + # ... +} + +site2.com, sub.site2.com { + # ... +} +``` + +Note that all these sites will automatically be served over HTTPS using Let's +Encrypt as the CA. Caddy will manage the certificates (including renewals) for +you. You don't even have to think about it. + +For more documentation, please view [the website](https://caddyserver.com/docs). +You may also be interested in the [developer guide] +(https://github.com/mholt/caddy/wiki) on this project's GitHub wiki. + + + + +## Running from Source + +Note: You will need **[Go 1.6](https://golang.org/dl/)** or newer. + +1. `$ go get github.com/mholt/caddy` +2. `cd` into your website's directory +3. Run `caddy` (assumes `$GOPATH/bin` is in your `$PATH`) + +If you're tinkering, you can also use `go run main.go`. + +By default, Caddy serves the current directory at +[localhost:2015](http://localhost:2015). You can place a Caddyfile to configure +Caddy for serving your site. + +Caddy accepts some flags from the command line. Run `caddy -h` to view the help + for flags. You can also pipe a Caddyfile into the caddy command. + +**Running as root:** We advise against this; use setcap instead, like so: +`setcap cap_net_bind_service=+ep ./caddy` This will allow you to listen on +ports < 1024 like 80 and 443. + + + +#### Docker Container + +Caddy is available as a Docker container from any of these sources: + +- [abiosoft/caddy](https://hub.docker.com/r/abiosoft/caddy/) +- [darron/caddy](https://hub.docker.com/r/darron/caddy/) +- [joshix/caddy](https://hub.docker.com/r/joshix/caddy/) +- [jumanjiman/caddy](https://hub.docker.com/r/jumanjiman/caddy/) +- [zenithar/nano-caddy](https://hub.docker.com/r/zenithar/nano-caddy/) + + + +#### 3rd-party dependencies + +Although Caddy's binaries are completely static, Caddy relies on some excellent +libraries. [Godoc.org](https://godoc.org/github.com/mholt/caddy) shows the +packages that each Caddy package imports. + + + + +## Contributing + +**[Join our dev chat on Gitter](https://gitter.im/mholt/caddy)** to chat with +other Caddy developers! (Dev chat only; try our +[support room](https://gitter.im/caddyserver/support) for help or +[general](https://gitter.im/caddyserver/general) for anything else.) + +This project would not be what it is without your help. Please see the +[contributing guidelines](https://github.com/mholt/caddy/blob/master/CONTRIBUTING.md) +if you haven't already. + +Thanks for making Caddy -- and the Web -- better! + +Special thanks to +[](https://www.digitalocean.com) +for hosting the Caddy project. + + + + +## About the project + +Caddy was born out of the need for a "batteries-included" web server that runs +anywhere and doesn't have to take its configuration with it. Caddy took +inspiration from [spark](https://github.com/rif/spark), +[nginx](https://github.com/nginx/nginx), lighttpd, +[Websocketd](https://github.com/joewalnes/websocketd) +and [Vagrant](https://www.vagrantup.com/), +which provides a pleasant mixture of features from each of them. + + +*Twitter: [@mholt6](https://twitter.com/mholt6)* @@ -0,0 +1,35 @@ +* Fix file middleware to use a proper zone implementation + * Zone parsing (better zone impl.) +* Zones file parsing is done twice on startup?? +* Might need global middleware state between middlewares +* Cleanup/make middlewares + * Fix complex rewrite to be useful + * Healthcheck middleware + * Slave zone middleware + * SkyDNS middleware, or call it etcd? +* Fix graceful restart +* TESTS; don't compile, need cleanups +* http.FileSystem is half used, half not used. It's a nice abstraction + for finding (zone) files, maybe we should just use it. +* prometheus: + * track the query type + * track the correct zone + +When there is already something running. + +BUG: server/server.go ListenAndServe +Activating privacy features... +.:1053 +panic: close of closed channel + +goroutine 40 [running]: +panic(0x8e5b60, 0xc8201b60b0) + /home/miek/upstream/go/src/runtime/panic.go:464 +0x3e6 +github.com/miekg/daddy/server.(*Server).ListenAndServe.func1.1() + /home/miek/g/src/github.com/miekg/daddy/server/server.go:147 +0x24 +sync.(*Once).Do(0xc82011b830, 0xc8201d3f38) + /home/miek/upstream/go/src/sync/once.go:44 +0xe4 +github.com/miekg/daddy/server.(*Server).ListenAndServe.func1(0xc82011b4c0, 0xc820090800, 0xc82011b830) + /home/miek/g/src/github.com/miekg/daddy/server/server.go:148 +0x1e3 +created by github.com/miekg/daddy/server.(*Server).ListenAndServe + /home/miek/g/src/github.com/miekg/daddy/server/server.go:150 +0xfe diff --git a/build.bash b/build.bash new file mode 100755 index 000000000..b7c97d1ec --- /dev/null +++ b/build.bash @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +# +# Caddy build script. Automates proper versioning. +# +# Usage: +# +# $ ./build.bash [output_filename] +# +# Outputs compiled program in current directory. +# Default file name is 'ecaddy'. +# +set -e + +output="$1" +if [ -z "$output" ]; then + output="ecaddy" +fi + +pkg=main + +# Timestamp of build +builddate_id=$pkg.buildDate +builddate=`date -u` + +# Current tag, if HEAD is on a tag +tag_id=$pkg.gitTag +set +e +tag=`git describe --exact-match HEAD 2> /dev/null` +set -e + +# Nearest tag on branch +lasttag_id=$pkg.gitNearestTag +lasttag=`git describe --abbrev=0 --tags HEAD` + +# Commit SHA +commit_id=$pkg.gitCommit +commit=`git rev-parse --short HEAD` + +# Summary of uncommited changes +shortstat_id=$pkg.gitShortStat +shortstat=`git diff-index --shortstat HEAD` + +# List of modified files +files_id=$pkg.gitFilesModified +files=`git diff-index --name-only HEAD` + + +go build -ldflags " + -X \"$builddate_id=$builddate\" + -X \"$tag_id=$tag\" + -X \"$lasttag_id=$lasttag\" + -X \"$commit_id=$commit\" + -X \"$shortstat_id=$shortstat\" + -X \"$files_id=$files\" +" -o "$output" diff --git a/core/assets/path.go b/core/assets/path.go new file mode 100644 index 000000000..46b883b1c --- /dev/null +++ b/core/assets/path.go @@ -0,0 +1,29 @@ +package assets + +import ( + "os" + "path/filepath" + "runtime" +) + +// Path returns the path to the folder +// where the application may store data. This +// currently resolves to ~/.caddy +func Path() string { + return filepath.Join(userHomeDir(), ".caddy") +} + +// userHomeDir returns the user's home directory according to +// environment variables. +// +// Credit: http://stackoverflow.com/a/7922977/1048862 +func userHomeDir() string { + if runtime.GOOS == "windows" { + home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH") + if home == "" { + home = os.Getenv("USERPROFILE") + } + return home + } + return os.Getenv("HOME") +} diff --git a/core/assets/path_test.go b/core/assets/path_test.go new file mode 100644 index 000000000..374f813af --- /dev/null +++ b/core/assets/path_test.go @@ -0,0 +1,12 @@ +package assets + +import ( + "strings" + "testing" +) + +func TestPath(t *testing.T) { + if actual := Path(); !strings.HasSuffix(actual, ".caddy") { + t.Errorf("Expected path to be a .caddy folder, got: %v", actual) + } +} diff --git a/core/caddy.go b/core/caddy.go new file mode 100644 index 000000000..e76fa28f1 --- /dev/null +++ b/core/caddy.go @@ -0,0 +1,388 @@ +// Package caddy implements the Caddy web server as a service +// in your own Go programs. +// +// To use this package, follow a few simple steps: +// +// 1. Set the AppName and AppVersion variables. +// 2. Call LoadCaddyfile() to get the Caddyfile (it +// might have been piped in as part of a restart). +// You should pass in your own Caddyfile loader. +// 3. Call caddy.Start() to start Caddy, caddy.Stop() +// to stop it, or caddy.Restart() to restart it. +// +// You should use caddy.Wait() to wait for all Caddy servers +// to quit before your process exits. +package core + +import ( + "bytes" + "encoding/gob" + "errors" + "fmt" + "io/ioutil" + "log" + "os" + "path" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/coredns/core/https" + "github.com/miekg/coredns/server" +) + +// Configurable application parameters +var ( + // AppName is the name of the application. + AppName string + + // AppVersion is the version of the application. + AppVersion string + + // Quiet when set to true, will not show any informative output on initialization. + Quiet bool + + // PidFile is the path to the pidfile to create. + PidFile string + + // GracefulTimeout is the maximum duration of a graceful shutdown. + GracefulTimeout time.Duration +) + +var ( + // caddyfile is the input configuration text used for this process + caddyfile Input + + // caddyfileMu protects caddyfile during changes + caddyfileMu sync.Mutex + + // errIncompleteRestart occurs if this process is a fork + // of the parent but no Caddyfile was piped in + errIncompleteRestart = errors.New("incomplete restart") + + // servers is a list of all the currently-listening servers + servers []*server.Server + + // serversMu protects the servers slice during changes + serversMu sync.Mutex + + // wg is used to wait for all servers to shut down + wg sync.WaitGroup + + // loadedGob is used if this is a child process as part of + // a graceful restart; it is used to map listeners to their + // index in the list of inherited file descriptors. This + // variable is not safe for concurrent access. + loadedGob caddyfileGob + + // startedBefore should be set to true if caddy has been started + // at least once (does not indicate whether currently running). + startedBefore bool +) + +const ( + // DefaultHost is the default host. + DefaultHost = "" + // DefaultPort is the default port. + DefaultPort = "53" + // DefaultRoot is the default root folder. + DefaultRoot = "." +) + +// Start starts Caddy with the given Caddyfile. If cdyfile +// is nil, the LoadCaddyfile function will be called to get +// one. +// +// This function blocks until all the servers are listening. +// +// Note (POSIX): If Start is called in the child process of a +// restart more than once within the duration of the graceful +// cutoff (i.e. the child process called Start a first time, +// then called Stop, then Start again within the first 5 seconds +// or however long GracefulTimeout is) and the Caddyfiles have +// at least one listener address in common, the second Start +// may fail with "address already in use" as there's no +// guarantee that the parent process has relinquished the +// address before the grace period ends. +func Start(cdyfile Input) (err error) { + // If we return with no errors, we must do two things: tell the + // parent that we succeeded and write to the pidfile. + defer func() { + if err == nil { + signalSuccessToParent() // TODO: Is doing this more than once per process a bad idea? Start could get called more than once in other apps. + if PidFile != "" { + err := writePidFile() + if err != nil { + log.Printf("[ERROR] Could not write pidfile: %v", err) + } + } + } + }() + + // Input must never be nil; try to load something + if cdyfile == nil { + cdyfile, err = LoadCaddyfile(nil) + if err != nil { + return err + } + } + + caddyfileMu.Lock() + caddyfile = cdyfile + caddyfileMu.Unlock() + + // load the server configs (activates Let's Encrypt) + configs, err := loadConfigs(path.Base(cdyfile.Path()), bytes.NewReader(cdyfile.Body())) + if err != nil { + return err + } + + // group zones by address + groupings, err := arrangeBindings(configs) + if err != nil { + return err + } + + // Start each server with its one or more configurations + err = startServers(groupings) + if err != nil { + return err + } + startedBefore = true + + // Show initialization output + if !Quiet && !IsRestart() { + var checkedFdLimit bool + for _, group := range groupings { + for _, conf := range group.Configs { + // Print address of site + fmt.Println(conf.Address()) + + // Note if non-localhost site resolves to loopback interface + if group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) { + fmt.Printf("Notice: %s is only accessible on this machine (%s)\n", + conf.Host, group.BindAddr.IP.String()) + } + if !checkedFdLimit && !group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) { + checkFdlimit() + checkedFdLimit = true + } + } + } + } + + return nil +} + +// startServers starts all the servers in groupings, +// taking into account whether or not this process is +// a child from a graceful restart or not. It blocks +// until the servers are listening. +func startServers(groupings bindingGroup) error { + var startupWg sync.WaitGroup + errChan := make(chan error, len(groupings)) // must be buffered to allow Serve functions below to return if stopped later + + for _, group := range groupings { + s, err := server.New(group.BindAddr.String(), group.Configs, GracefulTimeout) + if err != nil { + return err + } + // TODO(miek): does not work, because this callback uses http instead of dns + // s.ReqCallback = https.RequestCallback // ensures we can solve ACME challenges while running + if s.OnDemandTLS { + s.TLSConfig.GetCertificate = https.GetOrObtainCertificate // TLS on demand -- awesome! + } else { + s.TLSConfig.GetCertificate = https.GetCertificate + } + + var ln server.ListenerFile + /* + if IsRestart() { + // Look up this server's listener in the map of inherited file descriptors; + // if we don't have one, we must make a new one (later). + if fdIndex, ok := loadedGob.ListenerFds[s.Addr]; ok { + file := os.NewFile(fdIndex, "") + + fln, err := net.FileListener(file) + if err != nil { + return err + } + + ln, ok = fln.(server.ListenerFile) + if !ok { + return errors.New("listener for " + s.Addr + " was not a ListenerFile") + } + + file.Close() + delete(loadedGob.ListenerFds, s.Addr) + } + } + */ + + wg.Add(1) + go func(s *server.Server, ln server.ListenerFile) { + defer wg.Done() + + // run startup functions that should only execute when + // the original parent process is starting. + if !IsRestart() && !startedBefore { + err := s.RunFirstStartupFuncs() + if err != nil { + errChan <- err + return + } + } + + // start the server + // TODO(miek): for now will always be nil, so we will run ListenAndServe() + if ln != nil { + //errChan <- s.Serve(ln) + } else { + errChan <- s.ListenAndServe() + } + }(s, ln) + + startupWg.Add(1) + go func(s *server.Server) { + defer startupWg.Done() + s.WaitUntilStarted() + }(s) + + serversMu.Lock() + servers = append(servers, s) + serversMu.Unlock() + } + + // Close the remaining (unused) file descriptors to free up resources + if IsRestart() { + for key, fdIndex := range loadedGob.ListenerFds { + os.NewFile(fdIndex, "").Close() + delete(loadedGob.ListenerFds, key) + } + } + + // Wait for all servers to finish starting + startupWg.Wait() + + // Return the first error, if any + select { + case err := <-errChan: + // "use of closed network connection" is normal if it was a graceful shutdown + if err != nil && !strings.Contains(err.Error(), "use of closed network connection") { + return err + } + default: + } + + return nil +} + +// Stop stops all servers. It blocks until they are all stopped. +// It does NOT execute shutdown callbacks that may have been +// configured by middleware (they must be executed separately). +func Stop() error { + https.Deactivate() + + serversMu.Lock() + for _, s := range servers { + if err := s.Stop(); err != nil { + log.Printf("[ERROR] Stopping %s: %v", s.Addr, err) + } + } + servers = []*server.Server{} // don't reuse servers + serversMu.Unlock() + + return nil +} + +// Wait blocks until all servers are stopped. +func Wait() { + wg.Wait() +} + +// LoadCaddyfile loads a Caddyfile, prioritizing a Caddyfile +// piped from stdin as part of a restart (only happens on first call +// to LoadCaddyfile). If it is not a restart, this function tries +// calling the user's loader function, and if that returns nil, then +// this function resorts to the default configuration. Thus, if there +// are no other errors, this function always returns at least the +// default Caddyfile. +func LoadCaddyfile(loader func() (Input, error)) (cdyfile Input, err error) { + // If we are a fork, finishing the restart is highest priority; + // piped input is required in this case. + if IsRestart() { + err := gob.NewDecoder(os.Stdin).Decode(&loadedGob) + if err != nil { + return nil, err + } + cdyfile = loadedGob.Caddyfile + atomic.StoreInt32(https.OnDemandIssuedCount, loadedGob.OnDemandTLSCertsIssued) + } + + // Try user's loader + if cdyfile == nil && loader != nil { + cdyfile, err = loader() + } + + // Otherwise revert to default + if cdyfile == nil { + cdyfile = DefaultInput() + } + + return +} + +// CaddyfileFromPipe loads the Caddyfile input from f if f is +// not interactive input. f is assumed to be a pipe or stream, +// such as os.Stdin. If f is not a pipe, no error is returned +// but the Input value will be nil. An error is only returned +// if there was an error reading the pipe, even if the length +// of what was read is 0. +func CaddyfileFromPipe(f *os.File) (Input, error) { + fi, err := f.Stat() + if err == nil && fi.Mode()&os.ModeCharDevice == 0 { + // Note that a non-nil error is not a problem. Windows + // will not create a stdin if there is no pipe, which + // produces an error when calling Stat(). But Unix will + // make one either way, which is why we also check that + // bitmask. + // BUG: Reading from stdin after this fails (e.g. for the let's encrypt email address) (OS X) + confBody, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + return CaddyfileInput{ + Contents: confBody, + Filepath: f.Name(), + }, nil + } + + // not having input from the pipe is not itself an error, + // just means no input to return. + return nil, nil +} + +// Caddyfile returns the current Caddyfile +func Caddyfile() Input { + caddyfileMu.Lock() + defer caddyfileMu.Unlock() + return caddyfile +} + +// Input represents a Caddyfile; its contents and file path +// (which should include the file name at the end of the path). +// If path does not apply (e.g. piped input) you may use +// any understandable value. The path is mainly used for logging, +// error messages, and debugging. +type Input interface { + // Gets the Caddyfile contents + Body() []byte + + // Gets the path to the origin file + Path() string + + // IsFile returns true if the original input was a file on the file system + // that could be loaded again later if requested. + IsFile() bool +} diff --git a/core/caddy_test.go b/core/caddy_test.go new file mode 100644 index 000000000..1dc230a94 --- /dev/null +++ b/core/caddy_test.go @@ -0,0 +1,32 @@ +package core + +import ( + "net/http" + "testing" + "time" +) + +func TestCaddyStartStop(t *testing.T) { + caddyfile := "localhost:1984" + + for i := 0; i < 2; i++ { + err := Start(CaddyfileInput{Contents: []byte(caddyfile)}) + if err != nil { + t.Fatalf("Error starting, iteration %d: %v", i, err) + } + + client := http.Client{ + Timeout: time.Duration(2 * time.Second), + } + resp, err := client.Get("http://localhost:1984") + if err != nil { + t.Fatalf("Expected GET request to succeed (iteration %d), but it failed: %v", i, err) + } + resp.Body.Close() + + err = Stop() + if err != nil { + t.Fatalf("Error stopping, iteration %d: %v", i, err) + } + } +} diff --git a/core/caddyfile/json.go b/core/caddyfile/json.go new file mode 100644 index 000000000..6f4e66771 --- /dev/null +++ b/core/caddyfile/json.go @@ -0,0 +1,173 @@ +package caddyfile + +import ( + "bytes" + "encoding/json" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/miekg/coredns/core/parse" +) + +const filename = "Caddyfile" + +// ToJSON converts caddyfile to its JSON representation. +func ToJSON(caddyfile []byte) ([]byte, error) { + var j Caddyfile + + serverBlocks, err := parse.ServerBlocks(filename, bytes.NewReader(caddyfile), false) + if err != nil { + return nil, err + } + + for _, sb := range serverBlocks { + block := ServerBlock{Body: [][]interface{}{}} + + // Fill up host list + for _, host := range sb.HostList() { + block.Hosts = append(block.Hosts, host) + } + + // Extract directives deterministically by sorting them + var directives = make([]string, len(sb.Tokens)) + for dir := range sb.Tokens { + directives = append(directives, dir) + } + sort.Strings(directives) + + // Convert each directive's tokens into our JSON structure + for _, dir := range directives { + disp := parse.NewDispenserTokens(filename, sb.Tokens[dir]) + for disp.Next() { + block.Body = append(block.Body, constructLine(&disp)) + } + } + + // tack this block onto the end of the list + j = append(j, block) + } + + result, err := json.Marshal(j) + if err != nil { + return nil, err + } + + return result, nil +} + +// constructLine transforms tokens into a JSON-encodable structure; +// but only one line at a time, to be used at the top-level of +// a server block only (where the first token on each line is a +// directive) - not to be used at any other nesting level. +func constructLine(d *parse.Dispenser) []interface{} { + var args []interface{} + + args = append(args, d.Val()) + + for d.NextArg() { + if d.Val() == "{" { + args = append(args, constructBlock(d)) + continue + } + args = append(args, d.Val()) + } + + return args +} + +// constructBlock recursively processes tokens into a +// JSON-encodable structure. To be used in a directive's +// block. Goes to end of block. +func constructBlock(d *parse.Dispenser) [][]interface{} { + block := [][]interface{}{} + + for d.Next() { + if d.Val() == "}" { + break + } + block = append(block, constructLine(d)) + } + + return block +} + +// FromJSON converts JSON-encoded jsonBytes to Caddyfile text +func FromJSON(jsonBytes []byte) ([]byte, error) { + var j Caddyfile + var result string + + err := json.Unmarshal(jsonBytes, &j) + if err != nil { + return nil, err + } + + for sbPos, sb := range j { + if sbPos > 0 { + result += "\n\n" + } + for i, host := range sb.Hosts { + if i > 0 { + result += ", " + } + result += host + } + result += jsonToText(sb.Body, 1) + } + + return []byte(result), nil +} + +// jsonToText recursively transforms a scope of JSON into plain +// Caddyfile text. +func jsonToText(scope interface{}, depth int) string { + var result string + + switch val := scope.(type) { + case string: + if strings.ContainsAny(val, "\" \n\t\r") { + result += `"` + strings.Replace(val, "\"", "\\\"", -1) + `"` + } else { + result += val + } + case int: + result += strconv.Itoa(val) + case float64: + result += fmt.Sprintf("%v", val) + case bool: + result += fmt.Sprintf("%t", val) + case [][]interface{}: + result += " {\n" + for _, arg := range val { + result += strings.Repeat("\t", depth) + jsonToText(arg, depth+1) + "\n" + } + result += strings.Repeat("\t", depth-1) + "}" + case []interface{}: + for i, v := range val { + if block, ok := v.([]interface{}); ok { + result += "{\n" + for _, arg := range block { + result += strings.Repeat("\t", depth) + jsonToText(arg, depth+1) + "\n" + } + result += strings.Repeat("\t", depth-1) + "}" + continue + } + result += jsonToText(v, depth) + if i < len(val)-1 { + result += " " + } + } + } + + return result +} + +// Caddyfile encapsulates a slice of ServerBlocks. +type Caddyfile []ServerBlock + +// ServerBlock represents a server block. +type ServerBlock struct { + Hosts []string `json:"hosts"` + Body [][]interface{} `json:"body"` +} diff --git a/core/caddyfile/json_test.go b/core/caddyfile/json_test.go new file mode 100644 index 000000000..2e44ae2a2 --- /dev/null +++ b/core/caddyfile/json_test.go @@ -0,0 +1,161 @@ +package caddyfile + +import "testing" + +var tests = []struct { + caddyfile, json string +}{ + { // 0 + caddyfile: `foo { + root /bar +}`, + json: `[{"hosts":["foo"],"body":[["root","/bar"]]}]`, + }, + { // 1 + caddyfile: `host1, host2 { + dir { + def + } +}`, + json: `[{"hosts":["host1","host2"],"body":[["dir",[["def"]]]]}]`, + }, + { // 2 + caddyfile: `host1, host2 { + dir abc { + def ghi + jkl + } +}`, + json: `[{"hosts":["host1","host2"],"body":[["dir","abc",[["def","ghi"],["jkl"]]]]}]`, + }, + { // 3 + caddyfile: `host1:1234, host2:5678 { + dir abc { + } +}`, + json: `[{"hosts":["host1:1234","host2:5678"],"body":[["dir","abc",[]]]}]`, + }, + { // 4 + caddyfile: `host { + foo "bar baz" +}`, + json: `[{"hosts":["host"],"body":[["foo","bar baz"]]}]`, + }, + { // 5 + caddyfile: `host, host:80 { + foo "bar \"baz\"" +}`, + json: `[{"hosts":["host","host:80"],"body":[["foo","bar \"baz\""]]}]`, + }, + { // 6 + caddyfile: `host { + foo "bar +baz" +}`, + json: `[{"hosts":["host"],"body":[["foo","bar\nbaz"]]}]`, + }, + { // 7 + caddyfile: `host { + dir 123 4.56 true +}`, + json: `[{"hosts":["host"],"body":[["dir","123","4.56","true"]]}]`, // NOTE: I guess we assume numbers and booleans should be encoded as strings...? + }, + { // 8 + caddyfile: `http://host, https://host { +}`, + json: `[{"hosts":["http://host","https://host"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency + }, + { // 9 + caddyfile: `host { + dir1 a b + dir2 c d +}`, + json: `[{"hosts":["host"],"body":[["dir1","a","b"],["dir2","c","d"]]}]`, + }, + { // 10 + caddyfile: `host { + dir a b + dir c d +}`, + json: `[{"hosts":["host"],"body":[["dir","a","b"],["dir","c","d"]]}]`, + }, + { // 11 + caddyfile: `host { + dir1 a b + dir2 { + c + d + } +}`, + json: `[{"hosts":["host"],"body":[["dir1","a","b"],["dir2",[["c"],["d"]]]]}]`, + }, + { // 12 + caddyfile: `host1 { + dir1 +} + +host2 { + dir2 +}`, + json: `[{"hosts":["host1"],"body":[["dir1"]]},{"hosts":["host2"],"body":[["dir2"]]}]`, + }, +} + +func TestToJSON(t *testing.T) { + for i, test := range tests { + output, err := ToJSON([]byte(test.caddyfile)) + if err != nil { + t.Errorf("Test %d: %v", i, err) + } + if string(output) != test.json { + t.Errorf("Test %d\nExpected:\n'%s'\nActual:\n'%s'", i, test.json, string(output)) + } + } +} + +func TestFromJSON(t *testing.T) { + for i, test := range tests { + output, err := FromJSON([]byte(test.json)) + if err != nil { + t.Errorf("Test %d: %v", i, err) + } + if string(output) != test.caddyfile { + t.Errorf("Test %d\nExpected:\n'%s'\nActual:\n'%s'", i, test.caddyfile, string(output)) + } + } +} + +func TestStandardizeAddress(t *testing.T) { + // host:https should be converted to https://host + output, err := ToJSON([]byte(`host:https`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := `[{"hosts":["https://host"],"body":[]}]`, string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } + + output, err = FromJSON([]byte(`[{"hosts":["https://host"],"body":[]}]`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := "https://host {\n}", string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } + + // host: should be converted to just host + output, err = ToJSON([]byte(`host:`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := `[{"hosts":["host"],"body":[]}]`, string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } + output, err = FromJSON([]byte(`[{"hosts":["host:"],"body":[]}]`)) + if err != nil { + t.Fatal(err) + } + if expected, actual := "host {\n}", string(output); expected != actual { + t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual) + } +} diff --git a/core/config.go b/core/config.go new file mode 100644 index 000000000..8c376d878 --- /dev/null +++ b/core/config.go @@ -0,0 +1,346 @@ +package core + +import ( + "bytes" + "fmt" + "io" + "log" + "net" + "sync" + + "github.com/miekg/coredns/core/https" + "github.com/miekg/coredns/core/parse" + "github.com/miekg/coredns/core/setup" + "github.com/miekg/coredns/server" +) + +const ( + // DefaultConfigFile is the name of the configuration file that is loaded + // by default if no other file is specified. + DefaultConfigFile = "Corefile" +) + +func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Config, []parse.ServerBlock, int, error) { + var configs []server.Config + + // Each server block represents similar hosts/addresses, since they + // were grouped together in the Caddyfile. + serverBlocks, err := parse.ServerBlocks(filename, input, true) + if err != nil { + return nil, nil, 0, err + } + if len(serverBlocks) == 0 { + newInput := DefaultInput() + serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true) + if err != nil { + return nil, nil, 0, err + } + } + + var lastDirectiveIndex int // we set up directives in two parts; this stores where we left off + + // Iterate each server block and make a config for each one, + // executing the directives that were parsed in order up to the tls + // directive; this is because we must activate Let's Encrypt. + for i, sb := range serverBlocks { + onces := makeOnces() + storages := makeStorages() + + for j, addr := range sb.Addresses { + config := server.Config{ + Host: addr.Host, + Port: addr.Port, + Root: Root, + ConfigFile: filename, + AppName: AppName, + AppVersion: AppVersion, + } + + // It is crucial that directives are executed in the proper order. + for k, dir := range directiveOrder { + // Execute directive if it is in the server block + if tokens, ok := sb.Tokens[dir.name]; ok { + // Each setup function gets a controller, from which setup functions + // get access to the config, tokens, and other state information useful + // to set up its own host only. + controller := &setup.Controller{ + Config: &config, + Dispenser: parse.NewDispenserTokens(filename, tokens), + OncePerServerBlock: func(f func() error) error { + var err error + onces[dir.name].Do(func() { + err = f() + }) + return err + }, + ServerBlockIndex: i, + ServerBlockHostIndex: j, + ServerBlockHosts: sb.HostList(), + ServerBlockStorage: storages[dir.name], + } + // execute setup function and append middleware handler, if any + midware, err := dir.setup(controller) + if err != nil { + return nil, nil, lastDirectiveIndex, err + } + if midware != nil { + config.Middleware = append(config.Middleware, midware) + } + storages[dir.name] = controller.ServerBlockStorage // persist for this server block + } + + // Stop after TLS setup, since we need to activate Let's Encrypt before continuing; + // it makes some changes to the configs that middlewares might want to know about. + if dir.name == "tls" { + lastDirectiveIndex = k + break + } + } + + configs = append(configs, config) + } + } + return configs, serverBlocks, lastDirectiveIndex, nil +} + +// loadConfigs reads input (named filename) and parses it, returning the +// server configurations in the order they appeared in the input. As part +// of this, it activates Let's Encrypt for the configs that are produced. +// Thus, the returned configs are already optimally configured for HTTPS. +func loadConfigs(filename string, input io.Reader) ([]server.Config, error) { + configs, serverBlocks, lastDirectiveIndex, err := loadConfigsUpToIncludingTLS(filename, input) + if err != nil { + return nil, err + } + + // Now we have all the configs, but they have only been set up to the + // point of tls. We need to activate Let's Encrypt before setting up + // the rest of the middlewares so they have correct information regarding + // TLS configuration, if necessary. (this only appends, so our iterations + // over server blocks below shouldn't be affected) + if !IsRestart() && !Quiet { + fmt.Println("Activating privacy features...") + } + /* TODO(miek): stopped for now + configs, err = https.Activate(configs) + if err != nil { + return nil, err + } else if !IsRestart() && !Quiet { + fmt.Println(" done.") + } + */ + + // Finish setting up the rest of the directives, now that TLS is + // optimally configured. These loops are similar to above except + // we don't iterate all the directives from the beginning and we + // don't create new configs. + configIndex := -1 + for i, sb := range serverBlocks { + onces := makeOnces() + storages := makeStorages() + + for j := range sb.Addresses { + configIndex++ + + for k := lastDirectiveIndex + 1; k < len(directiveOrder); k++ { + dir := directiveOrder[k] + + if tokens, ok := sb.Tokens[dir.name]; ok { + controller := &setup.Controller{ + Config: &configs[configIndex], + Dispenser: parse.NewDispenserTokens(filename, tokens), + OncePerServerBlock: func(f func() error) error { + var err error + onces[dir.name].Do(func() { + err = f() + }) + return err + }, + ServerBlockIndex: i, + ServerBlockHostIndex: j, + ServerBlockHosts: sb.HostList(), + ServerBlockStorage: storages[dir.name], + } + midware, err := dir.setup(controller) + if err != nil { + return nil, err + } + if midware != nil { + configs[configIndex].Middleware = append(configs[configIndex].Middleware, midware) + } + storages[dir.name] = controller.ServerBlockStorage // persist for this server block + } + } + } + } + + return configs, nil +} + +// makeOnces makes a map of directive name to sync.Once +// instance. This is intended to be called once per server +// block when setting up configs so that Setup functions +// for each directive can perform a task just once per +// server block, even if there are multiple hosts on the block. +// +// We need one Once per directive, otherwise the first +// directive to use it would exclude other directives from +// using it at all, which would be a bug. +func makeOnces() map[string]*sync.Once { + onces := make(map[string]*sync.Once) + for _, dir := range directiveOrder { + onces[dir.name] = new(sync.Once) + } + return onces +} + +// makeStorages makes a map of directive name to interface{} +// so that directives' setup functions can persist state +// between different hosts on the same server block during the +// setup phase. +func makeStorages() map[string]interface{} { + storages := make(map[string]interface{}) + for _, dir := range directiveOrder { + storages[dir.name] = nil + } + return storages +} + +// arrangeBindings groups configurations by their bind address. For example, +// a server that should listen on localhost and another on 127.0.0.1 will +// be grouped into the same address: 127.0.0.1. It will return an error +// if an address is malformed or a TLS listener is configured on the +// same address as a plaintext HTTP listener. The return value is a map of +// bind address to list of configs that would become VirtualHosts on that +// server. Use the keys of the returned map to create listeners, and use +// the associated values to set up the virtualhosts. +func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) { + var groupings bindingGroup + + // Group configs by bind address + for _, conf := range allConfigs { + // use default port if none is specified + if conf.Port == "" { + conf.Port = Port + } + + bindAddr, warnErr, fatalErr := resolveAddr(conf) + if fatalErr != nil { + return groupings, fatalErr + } + if warnErr != nil { + log.Printf("[WARNING] Resolving bind address for %s: %v", conf.Address(), warnErr) + } + + // Make sure to compare the string representation of the address, + // not the pointer, since a new *TCPAddr is created each time. + var existing bool + for i := 0; i < len(groupings); i++ { + if groupings[i].BindAddr.String() == bindAddr.String() { + groupings[i].Configs = append(groupings[i].Configs, conf) + existing = true + break + } + } + if !existing { + groupings = append(groupings, bindingMapping{ + BindAddr: bindAddr, + Configs: []server.Config{conf}, + }) + } + } + + // Don't allow HTTP and HTTPS to be served on the same address + for _, group := range groupings { + isTLS := group.Configs[0].TLS.Enabled + for _, config := range group.Configs { + if config.TLS.Enabled != isTLS { + thisConfigProto, otherConfigProto := "HTTP", "HTTP" + if config.TLS.Enabled { + thisConfigProto = "HTTPS" + } + if group.Configs[0].TLS.Enabled { + otherConfigProto = "HTTPS" + } + return groupings, fmt.Errorf("configuration error: Cannot multiplex %s (%s) and %s (%s) on same address", + group.Configs[0].Address(), otherConfigProto, config.Address(), thisConfigProto) + } + } + } + + return groupings, nil +} + +// resolveAddr determines the address (host and port) that a config will +// bind to. The returned address, resolvAddr, should be used to bind the +// listener or group the config with other configs using the same address. +// The first error, if not nil, is just a warning and should be reported +// but execution may continue. The second error, if not nil, is a real +// problem and the server should not be started. +// +// This function does not handle edge cases like port "http" or "https" if +// they are not known to the system. It does, however, serve on the wildcard +// host if resolving the address of the specific hostname fails. +func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) { + resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.BindHost, conf.Port)) + if warnErr != nil { + // the hostname probably couldn't be resolved, just bind to wildcard then + resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("", conf.Port)) + if fatalErr != nil { + return + } + } + + return +} + +// validDirective returns true if d is a valid +// directive; false otherwise. +func validDirective(d string) bool { + for _, dir := range directiveOrder { + if dir.name == d { + return true + } + } + return false +} + +// DefaultInput returns the default Caddyfile input +// to use when it is otherwise empty or missing. +// It uses the default host and port (depends on +// host, e.g. localhost is 2015, otherwise 443) and +// root. +func DefaultInput() CaddyfileInput { + port := Port + if https.HostQualifies(Host) && port == DefaultPort { + port = "443" + } + return CaddyfileInput{ + Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)), + } +} + +// These defaults are configurable through the command line +var ( + // Root is the site root + Root = DefaultRoot + + // Host is the site host + Host = DefaultHost + + // Port is the site port + Port = DefaultPort +) + +// bindingMapping maps a network address to configurations +// that will bind to it. The order of the configs is important. +type bindingMapping struct { + BindAddr *net.TCPAddr + Configs []server.Config +} + +// bindingGroup maps network addresses to their configurations. +// Preserving the order of the groupings is important +// (related to graceful shutdown and restart) +// so this is a slice, not a literal map. +type bindingGroup []bindingMapping diff --git a/core/config_test.go b/core/config_test.go new file mode 100644 index 000000000..512128c0b --- /dev/null +++ b/core/config_test.go @@ -0,0 +1,159 @@ +package core + +import ( + "reflect" + "sync" + "testing" + + "github.com/miekg/coredns/server" +) + +func TestDefaultInput(t *testing.T) { + if actual, expected := string(DefaultInput().Body()), ":2015\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } + + // next few tests simulate user providing -host and/or -port flags + + Host = "not-localhost.com" + if actual, expected := string(DefaultInput().Body()), "not-localhost.com:443\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } + + Host = "[::1]" + if actual, expected := string(DefaultInput().Body()), "[::1]:2015\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } + + Host = "127.0.1.1" + if actual, expected := string(DefaultInput().Body()), "127.0.1.1:2015\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } + + Host = "not-localhost.com" + Port = "1234" + if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } + + Host = DefaultHost + Port = "1234" + if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected { + t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual) + } +} + +func TestResolveAddr(t *testing.T) { + // NOTE: If tests fail due to comparing to string "127.0.0.1", + // it's possible that system env resolves with IPv6, or ::1. + // If that happens, maybe we should use actualAddr.IP.IsLoopback() + // for the assertion, rather than a direct string comparison. + + // NOTE: Tests with {Host: "", Port: ""} and {Host: "localhost", Port: ""} + // will not behave the same cross-platform, so they have been omitted. + + for i, test := range []struct { + config server.Config + shouldWarnErr bool + shouldFatalErr bool + expectedIP string + expectedPort int + }{ + {server.Config{Host: "127.0.0.1", Port: "1234"}, false, false, "<nil>", 1234}, + {server.Config{Host: "localhost", Port: "80"}, false, false, "<nil>", 80}, + {server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234}, + {server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234}, + {server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "<nil>", 1234}, + {server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80}, + {server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443}, + {server.Config{BindHost: "", Port: "1234"}, false, false, "<nil>", 1234}, + {server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0}, + {server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, + {server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234}, + {server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "<nil>", 1234}, + } { + actualAddr, warnErr, fatalErr := resolveAddr(test.config) + + if test.shouldFatalErr && fatalErr == nil { + t.Errorf("Test %d: Expected error, but there wasn't any", i) + } + if !test.shouldFatalErr && fatalErr != nil { + t.Errorf("Test %d: Expected no error, but there was one: %v", i, fatalErr) + } + if fatalErr != nil { + continue + } + + if test.shouldWarnErr && warnErr == nil { + t.Errorf("Test %d: Expected warning, but there wasn't any", i) + } + if !test.shouldWarnErr && warnErr != nil { + t.Errorf("Test %d: Expected no warning, but there was one: %v", i, warnErr) + } + + if actual, expected := actualAddr.IP.String(), test.expectedIP; actual != expected { + t.Errorf("Test %d: IP was %s but expected %s", i, actual, expected) + } + if actual, expected := actualAddr.Port, test.expectedPort; actual != expected { + t.Errorf("Test %d: Port was %d but expected %d", i, actual, expected) + } + } +} + +func TestMakeOnces(t *testing.T) { + directives := []directive{ + {"dummy", nil}, + {"dummy2", nil}, + } + directiveOrder = directives + onces := makeOnces() + if len(onces) != len(directives) { + t.Errorf("onces had len %d , expected %d", len(onces), len(directives)) + } + expected := map[string]*sync.Once{ + "dummy": new(sync.Once), + "dummy2": new(sync.Once), + } + if !reflect.DeepEqual(onces, expected) { + t.Errorf("onces was %v, expected %v", onces, expected) + } +} + +func TestMakeStorages(t *testing.T) { + directives := []directive{ + {"dummy", nil}, + {"dummy2", nil}, + } + directiveOrder = directives + storages := makeStorages() + if len(storages) != len(directives) { + t.Errorf("storages had len %d , expected %d", len(storages), len(directives)) + } + expected := map[string]interface{}{ + "dummy": nil, + "dummy2": nil, + } + if !reflect.DeepEqual(storages, expected) { + t.Errorf("storages was %v, expected %v", storages, expected) + } +} + +func TestValidDirective(t *testing.T) { + directives := []directive{ + {"dummy", nil}, + {"dummy2", nil}, + } + directiveOrder = directives + for i, test := range []struct { + directive string + valid bool + }{ + {"dummy", true}, + {"dummy2", true}, + {"dummy3", false}, + } { + if actual, expected := validDirective(test.directive), test.valid; actual != expected { + t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected) + } + } +} diff --git a/core/directives.go b/core/directives.go new file mode 100644 index 000000000..96f4d910c --- /dev/null +++ b/core/directives.go @@ -0,0 +1,89 @@ +package core + +import ( + "github.com/miekg/coredns/core/https" + "github.com/miekg/coredns/core/parse" + "github.com/miekg/coredns/core/setup" + "github.com/miekg/coredns/middleware" +) + +func init() { + // The parse package must know which directives + // are valid, but it must not import the setup + // or config package. To solve this problem, we + // fill up this map in our init function here. + // The parse package does not need to know the + // ordering of the directives. + for _, dir := range directiveOrder { + parse.ValidDirectives[dir.name] = struct{}{} + } +} + +// Directives are registered in the order they should be +// executed. Middleware (directives that inject a handler) +// are executed in the order A-B-C-*-C-B-A, assuming +// they all call the Next handler in the chain. +// +// Ordering is VERY important. Every middleware will +// feel the effects of all other middleware below +// (after) them during a request, but they must not +// care what middleware above them are doing. +// +// For example, log needs to know the status code and +// exactly how many bytes were written to the client, +// which every other middleware can affect, so it gets +// registered first. The errors middleware does not +// care if gzip or log modifies its response, so it +// gets registered below them. Gzip, on the other hand, +// DOES care what errors does to the response since it +// must compress every output to the client, even error +// pages, so it must be registered before the errors +// middleware and any others that would write to the +// response. +var directiveOrder = []directive{ + // Essential directives that initialize vital configuration settings + {"root", setup.Root}, + {"bind", setup.BindHost}, + {"tls", https.Setup}, + + // Other directives that don't create HTTP handlers + {"startup", setup.Startup}, + {"shutdown", setup.Shutdown}, + + // Directives that inject handlers (middleware) + {"prometheus", setup.Prometheus}, + {"rewrite", setup.Rewrite}, + {"file", setup.File}, + {"reflect", setup.Reflect}, + {"log", setup.Log}, + {"errors", setup.Errors}, + {"proxy", setup.Proxy}, +} + +// RegisterDirective adds the given directive to caddy's list of directives. +// Pass the name of a directive you want it to be placed after, +// otherwise it will be placed at the bottom of the stack. +func RegisterDirective(name string, setup SetupFunc, after string) { + dir := directive{name: name, setup: setup} + idx := len(directiveOrder) + for i := range directiveOrder { + if directiveOrder[i].name == after { + idx = i + 1 + break + } + } + newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...) + directiveOrder = newDirectives + parse.ValidDirectives[name] = struct{}{} +} + +// directive ties together a directive name with its setup function. +type directive struct { + name string + setup SetupFunc +} + +// SetupFunc takes a controller and may optionally return a middleware. +// If the resulting middleware is not nil, it will be chained into +// the HTTP handlers in the order specified in this package. +type SetupFunc func(c *setup.Controller) (middleware.Middleware, error) diff --git a/core/directives_test.go b/core/directives_test.go new file mode 100644 index 000000000..1bee144f5 --- /dev/null +++ b/core/directives_test.go @@ -0,0 +1,31 @@ +package core + +import ( + "reflect" + "testing" +) + +func TestRegister(t *testing.T) { + directives := []directive{ + {"dummy", nil}, + {"dummy2", nil}, + } + directiveOrder = directives + RegisterDirective("foo", nil, "dummy") + if len(directiveOrder) != 3 { + t.Fatal("Should have 3 directives now") + } + getNames := func() (s []string) { + for _, d := range directiveOrder { + s = append(s, d.name) + } + return s + } + if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) { + t.Fatalf("directive order doesn't match: %s", getNames()) + } + RegisterDirective("bar", nil, "ASDASD") + if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) { + t.Fatalf("directive order doesn't match: %s", getNames()) + } +} diff --git a/core/helpers.go b/core/helpers.go new file mode 100644 index 000000000..8ef6a54cd --- /dev/null +++ b/core/helpers.go @@ -0,0 +1,102 @@ +package core + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync" +) + +// isLocalhost returns true if host looks explicitly like a localhost address. +func isLocalhost(host string) bool { + return host == "localhost" || host == "::1" || strings.HasPrefix(host, "127.") +} + +// checkFdlimit issues a warning if the OS max file descriptors is below a recommended minimum. +func checkFdlimit() { + const min = 4096 + + // Warn if ulimit is too low for production sites + if runtime.GOOS == "linux" || runtime.GOOS == "darwin" { + out, err := exec.Command("sh", "-c", "ulimit -n").Output() // use sh because ulimit isn't in Linux $PATH + if err == nil { + // Note that an error here need not be reported + lim, err := strconv.Atoi(string(bytes.TrimSpace(out))) + if err == nil && lim < min { + fmt.Printf("Warning: File descriptor limit %d is too low for production sites. At least %d is recommended. Set with \"ulimit -n %d\".\n", lim, min, min) + } + } + } +} + +// signalSuccessToParent tells the parent our status using pipe at index 3. +// If this process is not a restart, this function does nothing. +// Calling this function once this process has successfully initialized +// is vital so that the parent process can unblock and kill itself. +// This function is idempotent; it executes at most once per process. +func signalSuccessToParent() { + signalParentOnce.Do(func() { + if IsRestart() { + ppipe := os.NewFile(3, "") // parent is reading from pipe at index 3 + _, err := ppipe.Write([]byte("success")) // we must send some bytes to the parent + if err != nil { + log.Printf("[ERROR] Communicating successful init to parent: %v", err) + } + ppipe.Close() + } + }) +} + +// signalParentOnce is used to make sure that the parent is only +// signaled once; doing so more than once breaks whatever socket is +// at fd 4 (the reason for this is still unclear - to reproduce, +// call Stop() and Start() in succession at least once after a +// restart, then try loading first host of Caddyfile in the browser). +// Do not use this directly - call signalSuccessToParent instead. +var signalParentOnce sync.Once + +// caddyfileGob maps bind address to index of the file descriptor +// in the Files array passed to the child process. It also contains +// the caddyfile contents and other state needed by the new process. +// Used only during graceful restarts where a new process is spawned. +type caddyfileGob struct { + ListenerFds map[string]uintptr + Caddyfile Input + OnDemandTLSCertsIssued int32 +} + +// IsRestart returns whether this process is, according +// to env variables, a fork as part of a graceful restart. +func IsRestart() bool { + return os.Getenv("CADDY_RESTART") == "true" +} + +// writePidFile writes the process ID to the file at PidFile, if specified. +func writePidFile() error { + pid := []byte(strconv.Itoa(os.Getpid()) + "\n") + return ioutil.WriteFile(PidFile, pid, 0644) +} + +// CaddyfileInput represents a Caddyfile as input +// and is simply a convenient way to implement +// the Input interface. +type CaddyfileInput struct { + Filepath string + Contents []byte + RealFile bool +} + +// Body returns c.Contents. +func (c CaddyfileInput) Body() []byte { return c.Contents } + +// Path returns c.Filepath. +func (c CaddyfileInput) Path() string { return c.Filepath } + +// IsFile returns true if the original input was a real file on the file system. +func (c CaddyfileInput) IsFile() bool { return c.RealFile } diff --git a/core/https/certificates.go b/core/https/certificates.go new file mode 100644 index 000000000..0dc3db523 --- /dev/null +++ b/core/https/certificates.go @@ -0,0 +1,234 @@ +package https + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "io/ioutil" + "log" + "strings" + "sync" + "time" + + "github.com/xenolf/lego/acme" + "golang.org/x/crypto/ocsp" +) + +// certCache stores certificates in memory, +// keying certificates by name. +var certCache = make(map[string]Certificate) +var certCacheMu sync.RWMutex + +// Certificate is a tls.Certificate with associated metadata tacked on. +// Even if the metadata can be obtained by parsing the certificate, +// we can be more efficient by extracting the metadata once so it's +// just there, ready to use. +type Certificate struct { + tls.Certificate + + // Names is the list of names this certificate is written for. + // The first is the CommonName (if any), the rest are SAN. + Names []string + + // NotAfter is when the certificate expires. + NotAfter time.Time + + // Managed certificates are certificates that Caddy is managing, + // as opposed to the user specifying a certificate and key file + // or directory and managing the certificate resources themselves. + Managed bool + + // OnDemand certificates are obtained or loaded on-demand during TLS + // handshakes (as opposed to preloaded certificates, which are loaded + // at startup). If OnDemand is true, Managed must necessarily be true. + // OnDemand certificates are maintained in the background just like + // preloaded ones, however, if an OnDemand certificate fails to renew, + // it is removed from the in-memory cache. + OnDemand bool + + // OCSP contains the certificate's parsed OCSP response. + OCSP *ocsp.Response +} + +// getCertificate gets a certificate that matches name (a server name) +// from the in-memory cache. If there is no exact match for name, it +// will be checked against names of the form '*.example.com' (wildcard +// certificates) according to RFC 6125. If a match is found, matched will +// be true. If no matches are found, matched will be false and a default +// certificate will be returned with defaulted set to true. If no default +// certificate is set, defaulted will be set to false. +// +// The logic in this function is adapted from the Go standard library, +// which is by the Go Authors. +// +// This function is safe for concurrent use. +func getCertificate(name string) (cert Certificate, matched, defaulted bool) { + var ok bool + + // Not going to trim trailing dots here since RFC 3546 says, + // "The hostname is represented ... without a trailing dot." + // Just normalize to lowercase. + name = strings.ToLower(name) + + certCacheMu.RLock() + defer certCacheMu.RUnlock() + + // exact match? great, let's use it + if cert, ok = certCache[name]; ok { + matched = true + return + } + + // try replacing labels in the name with wildcards until we get a match + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok = certCache[candidate]; ok { + matched = true + return + } + } + + // if nothing matches, use the default certificate or bust + cert, defaulted = certCache[""] + return +} + +// cacheManagedCertificate loads the certificate for domain into the +// cache, flagging it as Managed and, if onDemand is true, as OnDemand +// (meaning that it was obtained or loaded during a TLS handshake). +// +// This function is safe for concurrent use. +func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) { + cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain)) + if err != nil { + return cert, err + } + cert.Managed = true + cert.OnDemand = onDemand + cacheCertificate(cert) + return cert, nil +} + +// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile +// and keyFile, which must be in PEM format. It stores the certificate in +// memory. The Managed and OnDemand flags of the certificate will be set to +// false. +// +// This function is safe for concurrent use. +func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error { + cert, err := makeCertificateFromDisk(certFile, keyFile) + if err != nil { + return err + } + cacheCertificate(cert) + return nil +} + +// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes +// of the certificate and key, then caches it in memory. +// +// This function is safe for concurrent use. +func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error { + cert, err := makeCertificate(certBytes, keyBytes) + if err != nil { + return err + } + cacheCertificate(cert) + return nil +} + +// makeCertificateFromDisk makes a Certificate by loading the +// certificate and key files. It fills out all the fields in +// the certificate except for the Managed and OnDemand flags. +// (It is up to the caller to set those.) +func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) { + certPEMBlock, err := ioutil.ReadFile(certFile) + if err != nil { + return Certificate{}, err + } + keyPEMBlock, err := ioutil.ReadFile(keyFile) + if err != nil { + return Certificate{}, err + } + return makeCertificate(certPEMBlock, keyPEMBlock) +} + +// makeCertificate turns a certificate PEM bundle and a key PEM block into +// a Certificate, with OCSP and other relevant metadata tagged with it, +// except for the OnDemand and Managed flags. It is up to the caller to +// set those properties. +func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) { + var cert Certificate + + // Convert to a tls.Certificate + tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return cert, err + } + if len(tlsCert.Certificate) == 0 { + return cert, errors.New("certificate is empty") + } + + // Parse leaf certificate and extract relevant metadata + leaf, err := x509.ParseCertificate(tlsCert.Certificate[0]) + if err != nil { + return cert, err + } + if leaf.Subject.CommonName != "" { + cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)} + } + for _, name := range leaf.DNSNames { + if name != leaf.Subject.CommonName { + cert.Names = append(cert.Names, strings.ToLower(name)) + } + } + cert.NotAfter = leaf.NotAfter + + // Staple OCSP + ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock) + if err != nil { + // An error here is not a problem because a certificate may simply + // not contain a link to an OCSP server. But we should log it anyway. + log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err) + } else if ocspResp.Status == ocsp.Good { + tlsCert.OCSPStaple = ocspBytes + cert.OCSP = ocspResp + } + + cert.Certificate = tlsCert + return cert, nil +} + +// cacheCertificate adds cert to the in-memory cache. If the cache is +// empty, cert will be used as the default certificate. If the cache is +// full, random entries are deleted until there is room to map all the +// names on the certificate. +// +// This certificate will be keyed to the names in cert.Names. Any name +// that is already a key in the cache will be replaced with this cert. +// +// This function is safe for concurrent use. +func cacheCertificate(cert Certificate) { + certCacheMu.Lock() + if _, ok := certCache[""]; !ok { + // use as default + cert.Names = append(cert.Names, "") + certCache[""] = cert + } + for len(certCache)+len(cert.Names) > 10000 { + // for simplicity, just remove random elements + for key := range certCache { + if key == "" { // ... but not the default cert + continue + } + delete(certCache, key) + break + } + } + for _, name := range cert.Names { + certCache[name] = cert + } + certCacheMu.Unlock() +} diff --git a/core/https/certificates_test.go b/core/https/certificates_test.go new file mode 100644 index 000000000..dbfb4efc1 --- /dev/null +++ b/core/https/certificates_test.go @@ -0,0 +1,59 @@ +package https + +import "testing" + +func TestUnexportedGetCertificate(t *testing.T) { + defer func() { certCache = make(map[string]Certificate) }() + + // When cache is empty + if _, matched, defaulted := getCertificate("example.com"); matched || defaulted { + t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted) + } + + // When cache has one certificate in it (also is default) + defaultCert := Certificate{Names: []string{"example.com", ""}} + certCache[""] = defaultCert + certCache["example.com"] = defaultCert + if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" { + t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) + } + if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" { + t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) + } + + // When retrieving wildcard certificate + certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}} + if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" { + t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted) + } + + // When no certificate matches, the default is returned + if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted { + t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert) + } else if cert.Names[0] != "example.com" { + t.Errorf("Expected default cert, got: %v", cert) + } +} + +func TestCacheCertificate(t *testing.T) { + defer func() { certCache = make(map[string]Certificate) }() + + cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}}) + if _, ok := certCache["example.com"]; !ok { + t.Error("Expected first cert to be cached by key 'example.com', but it wasn't") + } + if _, ok := certCache["sub.example.com"]; !ok { + t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't") + } + if cert, ok := certCache[""]; !ok || cert.Names[2] != "" { + t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't") + } + + cacheCertificate(Certificate{Names: []string{"example2.com"}}) + if _, ok := certCache["example2.com"]; !ok { + t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't") + } + if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" { + t.Error("Expected second cert to NOT be cached as default, but it was") + } +} diff --git a/core/https/client.go b/core/https/client.go new file mode 100644 index 000000000..e9e8cd82c --- /dev/null +++ b/core/https/client.go @@ -0,0 +1,215 @@ +package https + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net" + "sync" + "time" + + "github.com/miekg/coredns/server" + "github.com/xenolf/lego/acme" +) + +// acmeMu ensures that only one ACME challenge occurs at a time. +var acmeMu sync.Mutex + +// ACMEClient is an acme.Client with custom state attached. +type ACMEClient struct { + *acme.Client + AllowPrompts bool // if false, we assume AlternatePort must be used +} + +// NewACMEClient creates a new ACMEClient given an email and whether +// prompting the user is allowed. Clients should not be kept and +// re-used over long periods of time, but immediate re-use is more +// efficient than re-creating on every iteration. +var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) { + // Look up or create the LE user account + leUser, err := getUser(email) + if err != nil { + return nil, err + } + + // The client facilitates our communication with the CA server. + client, err := acme.NewClient(CAUrl, &leUser, KeyType) + if err != nil { + return nil, err + } + + // If not registered, the user must register an account with the CA + // and agree to terms + if leUser.Registration == nil { + reg, err := client.Register() + if err != nil { + return nil, errors.New("registration error: " + err.Error()) + } + leUser.Registration = reg + + if allowPrompts { // can't prompt a user who isn't there + if !Agreed && reg.TosURL == "" { + Agreed = promptUserAgreement(saURL, false) // TODO - latest URL + } + if !Agreed && reg.TosURL == "" { + return nil, errors.New("user must agree to terms") + } + } + + err = client.AgreeToTOS() + if err != nil { + saveUser(leUser) // Might as well try, right? + return nil, errors.New("error agreeing to terms: " + err.Error()) + } + + // save user to the file system + err = saveUser(leUser) + if err != nil { + return nil, errors.New("could not save user: " + err.Error()) + } + } + + return &ACMEClient{ + Client: client, + AllowPrompts: allowPrompts, + }, nil +} + +// NewACMEClientGetEmail creates a new ACMEClient and gets an email +// address at the same time (a server config is required, since it +// may contain an email address in it). +func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) { + return NewACMEClient(getEmail(config, allowPrompts), allowPrompts) +} + +// Configure configures c according to bindHost, which is the host (not +// whole address) to bind the listener to in solving the http and tls-sni +// challenges. +func (c *ACMEClient) Configure(bindHost string) { + // If we allow prompts, operator must be present. In our case, + // that is synonymous with saying the server is not already + // started. So if the user is still there, we don't use + // AlternatePort because we don't need to proxy the challenges. + // Conversely, if the operator is not there, the server has + // already started and we need to proxy the challenge. + if c.AllowPrompts { + // Operator is present; server is not already listening + c.SetHTTPAddress(net.JoinHostPort(bindHost, "")) + c.SetTLSAddress(net.JoinHostPort(bindHost, "")) + //c.ExcludeChallenges([]acme.Challenge{acme.DNS01}) + } else { + // Operator is not present; server is started, so proxy challenges + c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort)) + c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort)) + //c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) + } + c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS... +} + +// Obtain obtains a single certificate for names. It stores the certificate +// on the disk if successful. +func (c *ACMEClient) Obtain(names []string) error { +Attempts: + for attempts := 0; attempts < 2; attempts++ { + acmeMu.Lock() + certificate, failures := c.ObtainCertificate(names, true, nil) + acmeMu.Unlock() + if len(failures) > 0 { + // Error - try to fix it or report it to the user and abort + var errMsg string // we'll combine all the failures into a single error message + var promptedForAgreement bool // only prompt user for agreement at most once + + for errDomain, obtainErr := range failures { + // TODO: Double-check, will obtainErr ever be nil? + if tosErr, ok := obtainErr.(acme.TOSError); ok { + // Terms of Service agreement error; we can probably deal with this + if !Agreed && !promptedForAgreement && c.AllowPrompts { + Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL + promptedForAgreement = true + } + if Agreed || !c.AllowPrompts { + err := c.AgreeToTOS() + if err != nil { + return errors.New("error agreeing to updated terms: " + err.Error()) + } + continue Attempts + } + } + + // If user did not agree or it was any other kind of error, just append to the list of errors + errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n" + } + return errors.New(errMsg) + } + + // Success - immediately save the certificate resource + err := saveCertResource(certificate) + if err != nil { + return fmt.Errorf("error saving assets for %v: %v", names, err) + } + + break + } + + return nil +} + +// Renew renews the managed certificate for name. Right now our storage +// mechanism only supports one name per certificate, so this function only +// accepts one domain as input. It can be easily modified to support SAN +// certificates if, one day, they become desperately needed enough that our +// storage mechanism is upgraded to be more complex to support SAN certs. +// +// Anyway, this function is safe for concurrent use. +func (c *ACMEClient) Renew(name string) error { + // Prepare for renewal (load PEM cert, key, and meta) + certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name)) + if err != nil { + return err + } + keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name)) + if err != nil { + return err + } + metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name)) + if err != nil { + return err + } + var certMeta acme.CertificateResource + err = json.Unmarshal(metaBytes, &certMeta) + certMeta.Certificate = certBytes + certMeta.PrivateKey = keyBytes + + // Perform renewal and retry if necessary, but not too many times. + var newCertMeta acme.CertificateResource + var success bool + for attempts := 0; attempts < 2; attempts++ { + acmeMu.Lock() + newCertMeta, err = c.RenewCertificate(certMeta, true) + acmeMu.Unlock() + if err == nil { + success = true + break + } + + // If the legal terms changed and need to be agreed to again, + // we can handle that. + if _, ok := err.(acme.TOSError); ok { + err := c.AgreeToTOS() + if err != nil { + return err + } + continue + } + + // For any other kind of error, wait 10s and try again. + time.Sleep(10 * time.Second) + } + + if !success { + return errors.New("too many renewal attempts; last error: " + err.Error()) + } + + return saveCertResource(newCertMeta) +} diff --git a/core/https/crypto.go b/core/https/crypto.go new file mode 100644 index 000000000..bc0ff6373 --- /dev/null +++ b/core/https/crypto.go @@ -0,0 +1,57 @@ +package https + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "io/ioutil" + "os" +) + +// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file. +func loadPrivateKey(file string) (crypto.PrivateKey, error) { + keyBytes, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + keyBlock, _ := pem.Decode(keyBytes) + + switch keyBlock.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(keyBlock.Bytes) + } + + return nil, errors.New("unknown private key type") +} + +// savePrivateKey saves a PEM-encoded ECC/RSA private key to file. +func savePrivateKey(key crypto.PrivateKey, file string) error { + var pemType string + var keyBytes []byte + switch key := key.(type) { + case *ecdsa.PrivateKey: + var err error + pemType = "EC" + keyBytes, err = x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + case *rsa.PrivateKey: + pemType = "RSA" + keyBytes = x509.MarshalPKCS1PrivateKey(key) + } + + pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes} + keyOut, err := os.Create(file) + if err != nil { + return err + } + keyOut.Chmod(0600) + defer keyOut.Close() + return pem.Encode(keyOut, &pemKey) +} diff --git a/core/https/crypto_test.go b/core/https/crypto_test.go new file mode 100644 index 000000000..c1f32b27d --- /dev/null +++ b/core/https/crypto_test.go @@ -0,0 +1,111 @@ +package https + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "os" + "runtime" + "testing" +) + +func TestSaveAndLoadRSAPrivateKey(t *testing.T) { + keyFile := "test.key" + defer os.Remove(keyFile) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + + // test save + err = savePrivateKey(privateKey, keyFile) + if err != nil { + t.Fatal("error saving private key:", err) + } + + // it doesn't make sense to test file permission on windows + if runtime.GOOS != "windows" { + // get info of the key file + info, err := os.Stat(keyFile) + if err != nil { + t.Fatal("error stating private key:", err) + } + // verify permission of key file is correct + if info.Mode().Perm() != 0600 { + t.Error("Expected key file to have permission 0600, but it wasn't") + } + } + + // test load + loadedKey, err := loadPrivateKey(keyFile) + if err != nil { + t.Error("error loading private key:", err) + } + + // verify loaded key is correct + if !PrivateKeysSame(privateKey, loadedKey) { + t.Error("Expected key bytes to be the same, but they weren't") + } +} + +func TestSaveAndLoadECCPrivateKey(t *testing.T) { + keyFile := "test.key" + defer os.Remove(keyFile) + + privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + // test save + err = savePrivateKey(privateKey, keyFile) + if err != nil { + t.Fatal("error saving private key:", err) + } + + // it doesn't make sense to test file permission on windows + if runtime.GOOS != "windows" { + // get info of the key file + info, err := os.Stat(keyFile) + if err != nil { + t.Fatal("error stating private key:", err) + } + // verify permission of key file is correct + if info.Mode().Perm() != 0600 { + t.Error("Expected key file to have permission 0600, but it wasn't") + } + } + + // test load + loadedKey, err := loadPrivateKey(keyFile) + if err != nil { + t.Error("error loading private key:", err) + } + + // verify loaded key is correct + if !PrivateKeysSame(privateKey, loadedKey) { + t.Error("Expected key bytes to be the same, but they weren't") + } +} + +// PrivateKeysSame compares the bytes of a and b and returns true if they are the same. +func PrivateKeysSame(a, b crypto.PrivateKey) bool { + return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b)) +} + +// PrivateKeyBytes returns the bytes of DER-encoded key. +func PrivateKeyBytes(key crypto.PrivateKey) []byte { + var keyBytes []byte + switch key := key.(type) { + case *rsa.PrivateKey: + keyBytes = x509.MarshalPKCS1PrivateKey(key) + case *ecdsa.PrivateKey: + keyBytes, _ = x509.MarshalECPrivateKey(key) + } + return keyBytes +} diff --git a/core/https/handler.go b/core/https/handler.go new file mode 100644 index 000000000..f3139f54e --- /dev/null +++ b/core/https/handler.go @@ -0,0 +1,42 @@ +package https + +import ( + "crypto/tls" + "log" + "net/http" + "net/http/httputil" + "net/url" + "strings" +) + +const challengeBasePath = "/.well-known/acme-challenge" + +// RequestCallback proxies challenge requests to ACME client if the +// request path starts with challengeBasePath. It returns true if it +// handled the request and no more needs to be done; it returns false +// if this call was a no-op and the request still needs handling. +func RequestCallback(w http.ResponseWriter, r *http.Request) bool { + if strings.HasPrefix(r.URL.Path, challengeBasePath) { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Printf("[ERROR] ACME proxy handler: %v", err) + return true + } + + proxy := httputil.NewSingleHostReverseProxy(upstream) + proxy.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs + } + proxy.ServeHTTP(w, r) + + return true + } + + return false +} diff --git a/core/https/handler_test.go b/core/https/handler_test.go new file mode 100644 index 000000000..016799ffb --- /dev/null +++ b/core/https/handler_test.go @@ -0,0 +1,63 @@ +package https + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestRequestCallbackNoOp(t *testing.T) { + // try base paths that aren't handled by this handler + for _, url := range []string{ + "http://localhost/", + "http://localhost/foo.html", + "http://localhost/.git", + "http://localhost/.well-known/", + "http://localhost/.well-known/acme-challenging", + } { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("Could not craft request, got error: %v", err) + } + rw := httptest.NewRecorder() + if RequestCallback(rw, req) { + t.Errorf("Got true with this URL, but shouldn't have: %s", url) + } + } +} + +func TestRequestCallbackSuccess(t *testing.T) { + expectedPath := challengeBasePath + "/asdf" + + // Set up fake acme handler backend to make sure proxying succeeds + var proxySuccess bool + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxySuccess = true + if r.URL.Path != expectedPath { + t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path) + } + })) + + // Custom listener that uses the port we expect + ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort) + if err != nil { + t.Fatalf("Unable to start test server listener: %v", err) + } + ts.Listener = ln + + // Start our engines and run the test + ts.Start() + defer ts.Close() + req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil) + if err != nil { + t.Fatalf("Could not craft request, got error: %v", err) + } + rw := httptest.NewRecorder() + + RequestCallback(rw, req) + + if !proxySuccess { + t.Fatal("Expected request to be proxied, but it wasn't") + } +} diff --git a/core/https/handshake.go b/core/https/handshake.go new file mode 100644 index 000000000..4c1fc22c3 --- /dev/null +++ b/core/https/handshake.go @@ -0,0 +1,320 @@ +package https + +import ( + "bytes" + "crypto/tls" + "encoding/pem" + "errors" + "fmt" + "log" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/miekg/coredns/server" + "github.com/xenolf/lego/acme" +) + +// GetCertificate gets a certificate to satisfy clientHello as long as +// the certificate is already cached in memory. It will not be loaded +// from disk or obtained from the CA during the handshake. +// +// This function is safe for use as a tls.Config.GetCertificate callback. +func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := getCertDuringHandshake(clientHello.ServerName, false, false) + return &cert.Certificate, err +} + +// GetOrObtainCertificate will get a certificate to satisfy clientHello, even +// if that means obtaining a new certificate from a CA during the handshake. +// It first checks the in-memory cache, then accesses disk, then accesses the +// network if it must. An obtained certificate will be stored on disk and +// cached in memory. +// +// This function is safe for use as a tls.Config.GetCertificate callback. +func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := getCertDuringHandshake(clientHello.ServerName, true, true) + return &cert.Certificate, err +} + +// getCertDuringHandshake will get a certificate for name. It first tries +// the in-memory cache. If no certificate for name is in the cache and if +// loadIfNecessary == true, it goes to disk to load it into the cache and +// serve it. If it's not on disk and if obtainIfNecessary == true, the +// certificate will be obtained from the CA, cached, and served. If +// obtainIfNecessary is true, then loadIfNecessary must also be set to true. +// An error will be returned if and only if no certificate is available. +// +// This function is safe for concurrent use. +func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { + // First check our in-memory cache to see if we've already loaded it + cert, matched, defaulted := getCertificate(name) + if matched { + return cert, nil + } + + if loadIfNecessary { + // Then check to see if we have one on disk + loadedCert, err := cacheManagedCertificate(name, true) + if err == nil { + loadedCert, err = handshakeMaintenance(name, loadedCert) + if err != nil { + log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) + } + return loadedCert, nil + } + + if obtainIfNecessary { + // By this point, we need to ask the CA for a certificate + + name = strings.ToLower(name) + + // Make sure aren't over any applicable limits + err := checkLimitsForObtainingNewCerts(name) + if err != nil { + return Certificate{}, err + } + + // Name has to qualify for a certificate + if !HostQualifies(name) { + return cert, errors.New("hostname '" + name + "' does not qualify for certificate") + } + + // Obtain certificate from the CA + return obtainOnDemandCertificate(name) + } + } + + if defaulted { + return cert, nil + } + + return Certificate{}, errors.New("no certificate for " + name) +} + +// checkLimitsForObtainingNewCerts checks to see if name can be issued right +// now according to mitigating factors we keep track of and preferences the +// user has set. If a non-nil error is returned, do not issue a new certificate +// for name. +func checkLimitsForObtainingNewCerts(name string) error { + // User can set hard limit for number of certs for the process to issue + if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue { + return fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue) + } + + // Make sure name hasn't failed a challenge recently + failedIssuanceMu.RLock() + when, ok := failedIssuance[name] + failedIssuanceMu.RUnlock() + if ok { + return fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String()) + } + + // Make sure, if we've issued a few certificates already, that we haven't + // issued any recently + lastIssueTimeMu.Lock() + since := time.Since(lastIssueTime) + lastIssueTimeMu.Unlock() + if atomic.LoadInt32(OnDemandIssuedCount) >= 10 && since < 10*time.Minute { + return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since) + } + + // 👍Good to go + return nil +} + +// obtainOnDemandCertificate obtains a certificate for name for the given +// name. If another goroutine has already started obtaining a cert for +// name, it will wait and use what the other goroutine obtained. +// +// This function is safe for use by multiple concurrent goroutines. +func obtainOnDemandCertificate(name string) (Certificate, error) { + // We must protect this process from happening concurrently, so synchronize. + obtainCertWaitChansMu.Lock() + wait, ok := obtainCertWaitChans[name] + if ok { + // lucky us -- another goroutine is already obtaining the certificate. + // wait for it to finish obtaining the cert and then we'll use it. + obtainCertWaitChansMu.Unlock() + <-wait + return getCertDuringHandshake(name, true, false) + } + + // looks like it's up to us to do all the work and obtain the cert + wait = make(chan struct{}) + obtainCertWaitChans[name] = wait + obtainCertWaitChansMu.Unlock() + + // Unblock waiters and delete waitgroup when we return + defer func() { + obtainCertWaitChansMu.Lock() + close(wait) + delete(obtainCertWaitChans, name) + obtainCertWaitChansMu.Unlock() + }() + + log.Printf("[INFO] Obtaining new certificate for %s", name) + + // obtain cert + client, err := NewACMEClientGetEmail(server.Config{}, false) + if err != nil { + return Certificate{}, errors.New("error creating client: " + err.Error()) + } + client.Configure("") // TODO: which BindHost? + err = client.Obtain([]string{name}) + if err != nil { + // Failed to solve challenge, so don't allow another on-demand + // issue for this name to be attempted for a little while. + failedIssuanceMu.Lock() + failedIssuance[name] = time.Now() + go func(name string) { + time.Sleep(5 * time.Minute) + failedIssuanceMu.Lock() + delete(failedIssuance, name) + failedIssuanceMu.Unlock() + }(name) + failedIssuanceMu.Unlock() + return Certificate{}, err + } + + // Success - update counters and stuff + atomic.AddInt32(OnDemandIssuedCount, 1) + lastIssueTimeMu.Lock() + lastIssueTime = time.Now() + lastIssueTimeMu.Unlock() + + // The certificate is already on disk; now just start over to load it and serve it + return getCertDuringHandshake(name, true, false) +} + +// handshakeMaintenance performs a check on cert for expiration and OCSP +// validity. +// +// This function is safe for use by multiple concurrent goroutines. +func handshakeMaintenance(name string, cert Certificate) (Certificate, error) { + // Check cert expiration + timeLeft := cert.NotAfter.Sub(time.Now().UTC()) + if timeLeft < renewDurationBefore { + log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) + return renewDynamicCertificate(name) + } + + // Check OCSP staple validity + if cert.OCSP != nil { + refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2) + if time.Now().After(refreshTime) { + err := stapleOCSP(&cert, nil) + if err != nil { + // An error with OCSP stapling is not the end of the world, and in fact, is + // quite common considering not all certs have issuer URLs that support it. + log.Printf("[ERROR] Getting OCSP for %s: %v", name, err) + } + certCacheMu.Lock() + certCache[name] = cert + certCacheMu.Unlock() + } + } + + return cert, nil +} + +// renewDynamicCertificate renews currentCert using the clientHello. It returns the +// certificate to use and an error, if any. currentCert may be returned even if an +// error occurs, since we perform renewals before they expire and it may still be +// usable. name should already be lower-cased before calling this function. +// +// This function is safe for use by multiple concurrent goroutines. +func renewDynamicCertificate(name string) (Certificate, error) { + obtainCertWaitChansMu.Lock() + wait, ok := obtainCertWaitChans[name] + if ok { + // lucky us -- another goroutine is already renewing the certificate. + // wait for it to finish, then we'll use the new one. + obtainCertWaitChansMu.Unlock() + <-wait + return getCertDuringHandshake(name, true, false) + } + + // looks like it's up to us to do all the work and renew the cert + wait = make(chan struct{}) + obtainCertWaitChans[name] = wait + obtainCertWaitChansMu.Unlock() + + // unblock waiters and delete waitgroup when we return + defer func() { + obtainCertWaitChansMu.Lock() + close(wait) + delete(obtainCertWaitChans, name) + obtainCertWaitChansMu.Unlock() + }() + + log.Printf("[INFO] Renewing certificate for %s", name) + + client, err := NewACMEClientGetEmail(server.Config{}, false) + if err != nil { + return Certificate{}, err + } + client.Configure("") // TODO: Bind address of relevant listener, yuck + err = client.Renew(name) + if err != nil { + return Certificate{}, err + } + + return getCertDuringHandshake(name, true, false) +} + +// stapleOCSP staples OCSP information to cert for hostname name. +// If you have it handy, you should pass in the PEM-encoded certificate +// bundle; otherwise the DER-encoded cert will have to be PEM-encoded. +// If you don't have the PEM blocks handy, just pass in nil. +// +// Errors here are not necessarily fatal, it could just be that the +// certificate doesn't have an issuer URL. +func stapleOCSP(cert *Certificate, pemBundle []byte) error { + if pemBundle == nil { + // The function in the acme package that gets OCSP requires a PEM-encoded cert + bundle := new(bytes.Buffer) + for _, derBytes := range cert.Certificate.Certificate { + pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + } + pemBundle = bundle.Bytes() + } + + ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle) + if err != nil { + return err + } + + cert.Certificate.OCSPStaple = ocspBytes + cert.OCSP = ocspResp + + return nil +} + +// obtainCertWaitChans is used to coordinate obtaining certs for each hostname. +var obtainCertWaitChans = make(map[string]chan struct{}) +var obtainCertWaitChansMu sync.Mutex + +// OnDemandIssuedCount is the number of certificates that have been issued +// on-demand by this process. It is only safe to modify this count atomically. +// If it reaches onDemandMaxIssue, on-demand issuances will fail. +var OnDemandIssuedCount = new(int32) + +// onDemandMaxIssue is set based on max_certs in tls config. It specifies the +// maximum number of certificates that can be issued. +// TODO: This applies globally, but we should probably make a server-specific +// way to keep track of these limits and counts, since it's specified in the +// Caddyfile... +var onDemandMaxIssue int32 + +// failedIssuance is a set of names that we recently failed to get a +// certificate for from the ACME CA. They are removed after some time. +// When a name is in this map, do not issue a certificate for it on-demand. +var failedIssuance = make(map[string]time.Time) +var failedIssuanceMu sync.RWMutex + +// lastIssueTime records when we last obtained a certificate successfully. +// If this value is recent, do not make any on-demand certificate requests. +var lastIssueTime time.Time +var lastIssueTimeMu sync.Mutex diff --git a/core/https/handshake_test.go b/core/https/handshake_test.go new file mode 100644 index 000000000..cf70eb17d --- /dev/null +++ b/core/https/handshake_test.go @@ -0,0 +1,54 @@ +package https + +import ( + "crypto/tls" + "crypto/x509" + "testing" +) + +func TestGetCertificate(t *testing.T) { + defer func() { certCache = make(map[string]Certificate) }() + + hello := &tls.ClientHelloInfo{ServerName: "example.com"} + helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"} + helloNoSNI := &tls.ClientHelloInfo{} + helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"} + + // When cache is empty + if cert, err := GetCertificate(hello); err == nil { + t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert) + } + if cert, err := GetCertificate(helloNoSNI); err == nil { + t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert) + } + + // When cache has one certificate in it (also is default) + defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}} + certCache[""] = defaultCert + certCache["example.com"] = defaultCert + if cert, err := GetCertificate(hello); err != nil { + t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err) + } else if cert.Leaf.DNSNames[0] != "example.com" { + t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert) + } + if cert, err := GetCertificate(helloNoSNI); err != nil { + t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err) + } else if cert.Leaf.DNSNames[0] != "example.com" { + t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert) + } + + // When retrieving wildcard certificate + certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}} + if cert, err := GetCertificate(helloSub); err != nil { + t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err) + } else if cert.Leaf.DNSNames[0] != "*.example.com" { + t.Errorf("Got wrong certificate, expected wildcard: %v", cert) + } + + // When no certificate matches, the default is returned + if cert, err := GetCertificate(helloNoMatch); err != nil { + t.Errorf("Expected default certificate with no error when no matches, got err: %v", err) + } else if cert.Leaf.DNSNames[0] != "example.com" { + t.Errorf("Expected default cert with no matches, got: %v", cert) + } +} diff --git a/core/https/https.go b/core/https/https.go new file mode 100644 index 000000000..0deb88b86 --- /dev/null +++ b/core/https/https.go @@ -0,0 +1,358 @@ +// Package https facilitates the management of TLS assets and integrates +// Let's Encrypt functionality into Caddy with first-class support for +// creating and renewing certificates automatically. It is designed to +// configure sites for HTTPS by default. +package https + +import ( + "encoding/json" + "errors" + "io/ioutil" + "net" + "os" + "strings" + + "github.com/miekg/coredns/server" + "github.com/xenolf/lego/acme" +) + +// Activate sets up TLS for each server config in configs +// as needed; this consists of acquiring and maintaining +// certificates and keys for qualifying configs and enabling +// OCSP stapling for all TLS-enabled configs. +// +// This function may prompt the user to provide an email +// address if none is available through other means. It +// prefers the email address specified in the config, but +// if that is not available it will check the command line +// argument. If absent, it will use the most recent email +// address from last time. If there isn't one, the user +// will be prompted and shown SA link. +// +// Also note that calling this function activates asset +// management automatically, which keeps certificates +// renewed and OCSP stapling updated. +// +// Activate returns the updated list of configs, since +// some may have been appended, for example, to redirect +// plaintext HTTP requests to their HTTPS counterpart. +// This function only appends; it does not splice. +func Activate(configs []server.Config) ([]server.Config, error) { + // just in case previous caller forgot... + Deactivate() + + // pre-screen each config and earmark the ones that qualify for managed TLS + MarkQualified(configs) + + // place certificates and keys on disk + err := ObtainCerts(configs, true, false) + if err != nil { + return configs, err + } + + // update TLS configurations + err = EnableTLS(configs, true) + if err != nil { + return configs, err + } + + // renew all relevant certificates that need renewal. this is important + // to do right away for a couple reasons, mainly because each restart, + // the renewal ticker is reset, so if restarts happen more often than + // the ticker interval, renewals would never happen. but doing + // it right away at start guarantees that renewals aren't missed. + err = renewManagedCertificates(true) + if err != nil { + return configs, err + } + + // keep certificates renewed and OCSP stapling updated + go maintainAssets(stopChan) + + return configs, nil +} + +// Deactivate cleans up long-term, in-memory resources +// allocated by calling Activate(). Essentially, it stops +// the asset maintainer from running, meaning that certificates +// will not be renewed, OCSP staples will not be updated, etc. +func Deactivate() (err error) { + defer func() { + if rec := recover(); rec != nil { + err = errors.New("already deactivated") + } + }() + close(stopChan) + stopChan = make(chan struct{}) + return +} + +// MarkQualified scans each config and, if it qualifies for managed +// TLS, it sets the Managed field of the TLSConfig to true. +func MarkQualified(configs []server.Config) { + for i := 0; i < len(configs); i++ { + if ConfigQualifies(configs[i]) { + configs[i].TLS.Managed = true + } + } +} + +// ObtainCerts obtains certificates for all these configs as long as a +// certificate does not already exist on disk. It does not modify the +// configs at all; it only obtains and stores certificates and keys to +// the disk. If allowPrompts is true, the user may be shown a prompt. +// If proxyACME is true, the ACME challenges will be proxied to our alt port. +func ObtainCerts(configs []server.Config, allowPrompts, proxyACME bool) error { + // We group configs by email so we don't make the same clients over and + // over. This has the potential to prompt the user for an email, but we + // prevent that by assuming that if we already have a listener that can + // proxy ACME challenge requests, then the server is already running and + // the operator is no longer present. + groupedConfigs := groupConfigsByEmail(configs, allowPrompts) + + for email, group := range groupedConfigs { + // Wait as long as we can before creating the client, because it + // may not be needed, for example, if we already have what we + // need on disk. Creating a client involves the network and + // potentially prompting the user, etc., so only do if necessary. + var client *ACMEClient + + for _, cfg := range group { + if !HostQualifies(cfg.Host) || existingCertAndKey(cfg.Host) { + continue + } + + // Now we definitely do need a client + if client == nil { + var err error + client, err = NewACMEClient(email, allowPrompts) + if err != nil { + return errors.New("error creating client: " + err.Error()) + } + } + + // c.Configure assumes that allowPrompts == !proxyACME, + // but that's not always true. For example, a restart where + // the user isn't present and we're not listening on port 80. + // TODO: This could probably be refactored better. + if proxyACME { + client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, AlternatePort)) + client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, AlternatePort)) + client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) + } else { + client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, "")) + client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, "")) + client.ExcludeChallenges([]acme.Challenge{acme.DNS01}) + } + + err := client.Obtain([]string{cfg.Host}) + if err != nil { + return err + } + } + } + + return nil +} + +// groupConfigsByEmail groups configs by the email address to be used by an +// ACME client. It only groups configs that have TLS enabled and that are +// marked as Managed. If userPresent is true, the operator MAY be prompted +// for an email address. +func groupConfigsByEmail(configs []server.Config, userPresent bool) map[string][]server.Config { + initMap := make(map[string][]server.Config) + for _, cfg := range configs { + if !cfg.TLS.Managed { + continue + } + leEmail := getEmail(cfg, userPresent) + initMap[leEmail] = append(initMap[leEmail], cfg) + } + return initMap +} + +// EnableTLS configures each config to use TLS according to default settings. +// It will only change configs that are marked as managed, and assumes that +// certificates and keys are already on disk. If loadCertificates is true, +// the certificates will be loaded from disk into the cache for this process +// to use. If false, TLS will still be enabled and configured with default +// settings, but no certificates will be parsed loaded into the cache, and +// the returned error value will always be nil. +func EnableTLS(configs []server.Config, loadCertificates bool) error { + for i := 0; i < len(configs); i++ { + if !configs[i].TLS.Managed { + continue + } + configs[i].TLS.Enabled = true + if loadCertificates && HostQualifies(configs[i].Host) { + _, err := cacheManagedCertificate(configs[i].Host, false) + if err != nil { + return err + } + } + setDefaultTLSParams(&configs[i]) + } + return nil +} + +// hostHasOtherPort returns true if there is another config in the list with the same +// hostname that has port otherPort, or false otherwise. All the configs are checked +// against the hostname of allConfigs[thisConfigIdx]. +func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort string) bool { + for i, otherCfg := range allConfigs { + if i == thisConfigIdx { + continue // has to be a config OTHER than the one we're comparing against + } + if otherCfg.Host == allConfigs[thisConfigIdx].Host && otherCfg.Port == otherPort { + return true + } + } + return false +} + +// ConfigQualifies returns true if cfg qualifies for +// fully managed TLS (but not on-demand TLS, which is +// not considered here). It does NOT check to see if a +// cert and key already exist for the config. If the +// config does qualify, you should set cfg.TLS.Managed +// to true and check that instead, because the process of +// setting up the config may make it look like it +// doesn't qualify even though it originally did. +func ConfigQualifies(cfg server.Config) bool { + return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key + + // user can force-disable automatic HTTPS for this host + cfg.Port != "80" && + cfg.TLS.LetsEncryptEmail != "off" && + + // we get can't certs for some kinds of hostnames, but + // on-demand TLS allows empty hostnames at startup + (HostQualifies(cfg.Host) || cfg.TLS.OnDemand) +} + +// HostQualifies returns true if the hostname alone +// appears eligible for automatic HTTPS. For example, +// localhost, empty hostname, and IP addresses are +// not eligible because we cannot obtain certificates +// for those names. +func HostQualifies(hostname string) bool { + return hostname != "localhost" && // localhost is ineligible + + // hostname must not be empty + strings.TrimSpace(hostname) != "" && + + // cannot be an IP address, see + // https://community.letsencrypt.org/t/certificate-for-static-ip/84/2?u=mholt + // (also trim [] from either end, since that special case can sneak through + // for IPv6 addresses using the -host flag and with empty/no Caddyfile) + net.ParseIP(strings.Trim(hostname, "[]")) == nil +} + +// existingCertAndKey returns true if the host has a certificate +// and private key in storage already, false otherwise. +func existingCertAndKey(host string) bool { + _, err := os.Stat(storage.SiteCertFile(host)) + if err != nil { + return false + } + _, err = os.Stat(storage.SiteKeyFile(host)) + if err != nil { + return false + } + return true +} + +// saveCertResource saves the certificate resource to disk. This +// includes the certificate file itself, the private key, and the +// metadata file. +func saveCertResource(cert acme.CertificateResource) error { + err := os.MkdirAll(storage.Site(cert.Domain), 0700) + if err != nil { + return err + } + + // Save cert + err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600) + if err != nil { + return err + } + + // Save private key + err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600) + if err != nil { + return err + } + + // Save cert metadata + jsonBytes, err := json.MarshalIndent(&cert, "", "\t") + if err != nil { + return err + } + err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600) + if err != nil { + return err + } + + return nil +} + +// Revoke revokes the certificate for host via ACME protocol. +func Revoke(host string) error { + if !existingCertAndKey(host) { + return errors.New("no certificate and key for " + host) + } + + email := getEmail(server.Config{Host: host}, true) + if email == "" { + return errors.New("email is required to revoke") + } + + client, err := NewACMEClient(email, true) + if err != nil { + return err + } + + certFile := storage.SiteCertFile(host) + certBytes, err := ioutil.ReadFile(certFile) + if err != nil { + return err + } + + err = client.RevokeCertificate(certBytes) + if err != nil { + return err + } + + err = os.Remove(certFile) + if err != nil { + return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error()) + } + + return nil +} + +var ( + // DefaultEmail represents the Let's Encrypt account email to use if none provided + DefaultEmail string + + // Agreed indicates whether user has agreed to the Let's Encrypt SA + Agreed bool + + // CAUrl represents the base URL to the CA's ACME endpoint + CAUrl string +) + +// AlternatePort is the port on which the acme client will open a +// listener and solve the CA's challenges. If this alternate port +// is used instead of the default port (80 or 443), then the +// default port for the challenge must be forwarded to this one. +const AlternatePort = "5033" + +// KeyType is the type to use for new keys. +// This shouldn't need to change except for in tests; +// the size can be drastically reduced for speed. +var KeyType = acme.EC384 + +// stopChan is used to signal the maintenance goroutine +// to terminate. +var stopChan chan struct{} diff --git a/core/https/https_test.go b/core/https/https_test.go new file mode 100644 index 000000000..40b67367e --- /dev/null +++ b/core/https/https_test.go @@ -0,0 +1,332 @@ +package https + +import ( + "io/ioutil" + "net/http" + "os" + "testing" + + "github.com/miekg/coredns/middleware/redirect" + "github.com/miekg/coredns/server" + "github.com/xenolf/lego/acme" +) + +func TestHostQualifies(t *testing.T) { + for i, test := range []struct { + host string + expect bool + }{ + {"localhost", false}, + {"127.0.0.1", false}, + {"127.0.1.5", false}, + {"::1", false}, + {"[::1]", false}, + {"[::]", false}, + {"::", false}, + {"", false}, + {" ", false}, + {"0.0.0.0", false}, + {"192.168.1.3", false}, + {"10.0.2.1", false}, + {"169.112.53.4", false}, + {"foobar.com", true}, + {"sub.foobar.com", true}, + } { + if HostQualifies(test.host) && !test.expect { + t.Errorf("Test %d: Expected '%s' to NOT qualify, but it did", i, test.host) + } + if !HostQualifies(test.host) && test.expect { + t.Errorf("Test %d: Expected '%s' to qualify, but it did NOT", i, test.host) + } + } +} + +func TestConfigQualifies(t *testing.T) { + for i, test := range []struct { + cfg server.Config + expect bool + }{ + {server.Config{Host: ""}, false}, + {server.Config{Host: "localhost"}, false}, + {server.Config{Host: "123.44.3.21"}, false}, + {server.Config{Host: "example.com"}, true}, + {server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false}, + {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false}, + {server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true}, + {server.Config{Host: "example.com", Scheme: "http"}, false}, + {server.Config{Host: "example.com", Port: "80"}, false}, + {server.Config{Host: "example.com", Port: "1234"}, true}, + {server.Config{Host: "example.com", Scheme: "https"}, true}, + {server.Config{Host: "example.com", Port: "80", Scheme: "https"}, false}, + } { + if test.expect && !ConfigQualifies(test.cfg) { + t.Errorf("Test %d: Expected config to qualify, but it did NOT: %#v", i, test.cfg) + } + if !test.expect && ConfigQualifies(test.cfg) { + t.Errorf("Test %d: Expected config to NOT qualify, but it did: %#v", i, test.cfg) + } + } +} + +func TestRedirPlaintextHost(t *testing.T) { + cfg := redirPlaintextHost(server.Config{ + Host: "example.com", + BindHost: "93.184.216.34", + Port: "1234", + }) + + // Check host and port + if actual, expected := cfg.Host, "example.com"; actual != expected { + t.Errorf("Expected redir config to have host %s but got %s", expected, actual) + } + if actual, expected := cfg.BindHost, "93.184.216.34"; actual != expected { + t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual) + } + if actual, expected := cfg.Port, "80"; actual != expected { + t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual) + } + + // Make sure redirect handler is set up properly + if cfg.Middleware == nil || len(cfg.Middleware) != 1 { + t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware) + } + + handler, ok := cfg.Middleware[0](nil).(redirect.Redirect) + if !ok { + t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler) + } + if len(handler.Rules) != 1 { + t.Fatalf("Expected one redirect rule, got: %#v", handler.Rules) + } + + // Check redirect rule for correctness + if actual, expected := handler.Rules[0].FromScheme, "http"; actual != expected { + t.Errorf("Expected redirect rule to be from scheme '%s' but is actually from '%s'", expected, actual) + } + if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected { + t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual) + } + if actual, expected := handler.Rules[0].To, "https://{host}:1234{uri}"; actual != expected { + t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) + } + if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected { + t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual) + } + + // browsers can infer a default port from scheme, so make sure the port + // doesn't get added in explicitly for default ports like 443 for https. + cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"}) + handler, ok = cfg.Middleware[0](nil).(redirect.Redirect) + if actual, expected := handler.Rules[0].To, "https://{host}{uri}"; actual != expected { + t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual) + } +} + +func TestSaveCertResource(t *testing.T) { + storage = Storage("./le_test_save") + defer func() { + err := os.RemoveAll(string(storage)) + if err != nil { + t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) + } + }() + + domain := "example.com" + certContents := "certificate" + keyContents := "private key" + metaContents := `{ + "domain": "example.com", + "certUrl": "https://example.com/cert", + "certStableUrl": "https://example.com/cert/stable" +}` + + cert := acme.CertificateResource{ + Domain: domain, + CertURL: "https://example.com/cert", + CertStableURL: "https://example.com/cert/stable", + PrivateKey: []byte(keyContents), + Certificate: []byte(certContents), + } + + err := saveCertResource(cert) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain)) + if err != nil { + t.Errorf("Expected no error reading certificate file, got: %v", err) + } + if string(certFile) != certContents { + t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile)) + } + + keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain)) + if err != nil { + t.Errorf("Expected no error reading private key file, got: %v", err) + } + if string(keyFile) != keyContents { + t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile)) + } + + metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain)) + if err != nil { + t.Errorf("Expected no error reading meta file, got: %v", err) + } + if string(metaFile) != metaContents { + t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile)) + } +} + +func TestExistingCertAndKey(t *testing.T) { + storage = Storage("./le_test_existing") + defer func() { + err := os.RemoveAll(string(storage)) + if err != nil { + t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err) + } + }() + + domain := "example.com" + + if existingCertAndKey(domain) { + t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain) + } + + err := saveCertResource(acme.CertificateResource{ + Domain: domain, + PrivateKey: []byte("key"), + Certificate: []byte("cert"), + }) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if !existingCertAndKey(domain) { + t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain) + } +} + +func TestHostHasOtherPort(t *testing.T) { + configs := []server.Config{ + {Host: "example.com", Port: "80"}, + {Host: "sub1.example.com", Port: "80"}, + {Host: "sub1.example.com", Port: "443"}, + } + + if hostHasOtherPort(configs, 0, "80") { + t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`) + } + if hostHasOtherPort(configs, 0, "443") { + t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`) + } + if !hostHasOtherPort(configs, 1, "443") { + t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`) + } +} + +func TestMakePlaintextRedirects(t *testing.T) { + configs := []server.Config{ + // Happy path = standard redirect from 80 to 443 + {Host: "example.com", TLS: server.TLSConfig{Managed: true}}, + + // Host on port 80 already defined; don't change it (no redirect) + {Host: "sub1.example.com", Port: "80", Scheme: "http"}, + {Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}}, + + // Redirect from port 80 to port 5000 in this case + {Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}}, + + // Can redirect from 80 to either 443 or 5001, but choose 443 + {Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}}, + {Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}}, + } + + result := MakePlaintextRedirects(configs) + expectedRedirCount := 3 + + if len(result) != len(configs)+expectedRedirCount { + t.Errorf("Expected %d redirect(s) to be added, but got %d", + expectedRedirCount, len(result)-len(configs)) + } +} + +func TestEnableTLS(t *testing.T) { + configs := []server.Config{ + {Host: "example.com", TLS: server.TLSConfig{Managed: true}}, + {}, // not managed - no changes! + } + + EnableTLS(configs, false) + + if !configs[0].TLS.Enabled { + t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false") + } + if configs[1].TLS.Enabled { + t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true") + } +} + +func TestGroupConfigsByEmail(t *testing.T) { + if groupConfigsByEmail([]server.Config{}, false) == nil { + t.Errorf("With empty input, returned map was nil, but expected non-nil map") + } + + configs := []server.Config{ + {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, + {Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, + {Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, + {Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}}, + {Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}}, + {Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed + } + DefaultEmail = "test@example.com" + + groups := groupConfigsByEmail(configs, true) + + if groups == nil { + t.Fatalf("Returned map was nil, but expected values") + } + + if len(groups) != 2 { + t.Errorf("Expected 2 groups, got %d: %#v", len(groups), groups) + } + if len(groups["foo@bar"]) != 2 { + t.Errorf("Expected 2 configs for foo@bar, got %d: %#v", len(groups["foobar"]), groups["foobar"]) + } + if len(groups[DefaultEmail]) != 3 { + t.Errorf("Expected 3 configs for %s, got %d: %#v", DefaultEmail, len(groups["foobar"]), groups["foobar"]) + } +} + +func TestMarkQualified(t *testing.T) { + // TODO: TestConfigQualifies and this test share the same config list... + configs := []server.Config{ + {Host: ""}, + {Host: "localhost"}, + {Host: "123.44.3.21"}, + {Host: "example.com"}, + {Host: "example.com", TLS: server.TLSConfig{Manual: true}}, + {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, + {Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, + {Host: "example.com", Scheme: "http"}, + {Host: "example.com", Port: "80"}, + {Host: "example.com", Port: "1234"}, + {Host: "example.com", Scheme: "https"}, + {Host: "example.com", Port: "80", Scheme: "https"}, + } + expectedManagedCount := 4 + + MarkQualified(configs) + + count := 0 + for _, cfg := range configs { + if cfg.TLS.Managed { + count++ + } + } + + if count != expectedManagedCount { + t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count) + } +} diff --git a/core/https/maintain.go b/core/https/maintain.go new file mode 100644 index 000000000..46fd3d1f5 --- /dev/null +++ b/core/https/maintain.go @@ -0,0 +1,211 @@ +package https + +import ( + "log" + "time" + + "github.com/miekg/coredns/server" + + "golang.org/x/crypto/ocsp" +) + +const ( + // RenewInterval is how often to check certificates for renewal. + RenewInterval = 12 * time.Hour + + // OCSPInterval is how often to check if OCSP stapling needs updating. + OCSPInterval = 1 * time.Hour +) + +// maintainAssets is a permanently-blocking function +// that loops indefinitely and, on a regular schedule, checks +// certificates for expiration and initiates a renewal of certs +// that are expiring soon. It also updates OCSP stapling and +// performs other maintenance of assets. +// +// You must pass in the channel which you'll close when +// maintenance should stop, to allow this goroutine to clean up +// after itself and unblock. +func maintainAssets(stopChan chan struct{}) { + renewalTicker := time.NewTicker(RenewInterval) + ocspTicker := time.NewTicker(OCSPInterval) + + for { + select { + case <-renewalTicker.C: + log.Println("[INFO] Scanning for expiring certificates") + renewManagedCertificates(false) + log.Println("[INFO] Done checking certificates") + case <-ocspTicker.C: + log.Println("[INFO] Scanning for stale OCSP staples") + updateOCSPStaples() + log.Println("[INFO] Done checking OCSP staples") + case <-stopChan: + renewalTicker.Stop() + ocspTicker.Stop() + log.Println("[INFO] Stopped background maintenance routine") + return + } + } +} + +func renewManagedCertificates(allowPrompts bool) (err error) { + var renewed, deleted []Certificate + var client *ACMEClient + visitedNames := make(map[string]struct{}) + + certCacheMu.RLock() + for name, cert := range certCache { + if !cert.Managed { + continue + } + + // the list of names on this cert should never be empty... + if cert.Names == nil || len(cert.Names) == 0 { + log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names) + deleted = append(deleted, cert) + continue + } + + // skip names whose certificate we've already renewed + if _, ok := visitedNames[name]; ok { + continue + } + for _, name := range cert.Names { + visitedNames[name] = struct{}{} + } + + timeLeft := cert.NotAfter.Sub(time.Now().UTC()) + if timeLeft < renewDurationBefore { + log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) + + if client == nil { + client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts) + if err != nil { + return err + } + client.Configure("") // TODO: Bind address of relevant listener, yuck + } + + err := client.Renew(cert.Names[0]) // managed certs better have only one name + if err != nil { + if client.AllowPrompts && timeLeft < 0 { + // Certificate renewal failed, the operator is present, and the certificate + // is already expired; we should stop immediately and return the error. Note + // that we used to do this any time a renewal failed at startup. However, + // after discussion in https://github.com/miekg/coredns/issues/642 we decided to + // only stop startup if the certificate is expired. We still log the error + // otherwise. + certCacheMu.RUnlock() + return err + } + log.Printf("[ERROR] %v", err) + if cert.OnDemand { + deleted = append(deleted, cert) + } + } else { + renewed = append(renewed, cert) + } + } + } + certCacheMu.RUnlock() + + // Apply changes to the cache + for _, cert := range renewed { + _, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand) + if err != nil { + if client.AllowPrompts { + return err // operator is present, so report error immediately + } + log.Printf("[ERROR] %v", err) + } + } + for _, cert := range deleted { + certCacheMu.Lock() + for _, name := range cert.Names { + delete(certCache, name) + } + certCacheMu.Unlock() + } + + return nil +} + +func updateOCSPStaples() { + // Create a temporary place to store updates + // until we release the potentially long-lived + // read lock and use a short-lived write lock. + type ocspUpdate struct { + rawBytes []byte + parsed *ocsp.Response + } + updated := make(map[string]ocspUpdate) + + // A single SAN certificate maps to multiple names, so we use this + // set to make sure we don't waste cycles checking OCSP for the same + // certificate multiple times. + visited := make(map[string]struct{}) + + certCacheMu.RLock() + for name, cert := range certCache { + // skip this certificate if we've already visited it, + // and if not, mark all the names as visited + if _, ok := visited[name]; ok { + continue + } + for _, n := range cert.Names { + visited[n] = struct{}{} + } + + // no point in updating OCSP for expired certificates + if time.Now().After(cert.NotAfter) { + continue + } + + var lastNextUpdate time.Time + if cert.OCSP != nil { + // start checking OCSP staple about halfway through validity period for good measure + lastNextUpdate = cert.OCSP.NextUpdate + refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2) + + // since OCSP is already stapled, we need only check if we're in that "refresh window" + if time.Now().Before(refreshTime) { + continue + } + } + + err := stapleOCSP(&cert, nil) + if err != nil { + if cert.OCSP != nil { + // if it was no staple before, that's fine, otherwise we should log the error + log.Printf("[ERROR] Checking OCSP for %s: %v", name, err) + } + continue + } + + // By this point, we've obtained the latest OCSP response. + // If there was no staple before, or if the response is updated, make + // sure we apply the update to all names on the certificate. + if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate { + log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s", + cert.Names, lastNextUpdate, cert.OCSP.NextUpdate) + for _, n := range cert.Names { + updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP} + } + } + } + certCacheMu.RUnlock() + + // This write lock should be brief since we have all the info we need now. + certCacheMu.Lock() + for name, update := range updated { + cert := certCache[name] + cert.OCSP = update.parsed + cert.Certificate.OCSPStaple = update.rawBytes + certCache[name] = cert + } + certCacheMu.Unlock() +} + +// renewDurationBefore is how long before expiration to renew certificates. +const renewDurationBefore = (24 * time.Hour) * 30 diff --git a/core/https/setup.go b/core/https/setup.go new file mode 100644 index 000000000..ec90e0284 --- /dev/null +++ b/core/https/setup.go @@ -0,0 +1,321 @@ +package https + +import ( + "bytes" + "crypto/tls" + "encoding/pem" + "io/ioutil" + "log" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/miekg/coredns/core/setup" + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/server" +) + +// Setup sets up the TLS configuration and installs certificates that +// are specified by the user in the config file. All the automatic HTTPS +// stuff comes later outside of this function. +func Setup(c *setup.Controller) (middleware.Middleware, error) { + if c.Port == "80" { + c.TLS.Enabled = false + log.Printf("[WARNING] TLS disabled for %s.", c.Address()) + return nil, nil + } + c.TLS.Enabled = true + + // TODO(miek): disabled for now + return nil, nil + + for c.Next() { + var certificateFile, keyFile, loadDir, maxCerts string + + args := c.RemainingArgs() + switch len(args) { + case 1: + c.TLS.LetsEncryptEmail = args[0] + + // user can force-disable managed TLS this way + if c.TLS.LetsEncryptEmail == "off" { + c.TLS.Enabled = false + return nil, nil + } + case 2: + certificateFile = args[0] + keyFile = args[1] + c.TLS.Manual = true + } + + // Optional block with extra parameters + var hadBlock bool + for c.NextBlock() { + hadBlock = true + switch c.Val() { + case "protocols": + args := c.RemainingArgs() + if len(args) != 2 { + return nil, c.ArgErr() + } + value, ok := supportedProtocols[strings.ToLower(args[0])] + if !ok { + return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val()) + } + c.TLS.ProtocolMinVersion = value + value, ok = supportedProtocols[strings.ToLower(args[1])] + if !ok { + return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val()) + } + c.TLS.ProtocolMaxVersion = value + case "ciphers": + for c.NextArg() { + value, ok := supportedCiphersMap[strings.ToUpper(c.Val())] + if !ok { + return nil, c.Errf("Wrong cipher name or cipher not supported '%s'", c.Val()) + } + c.TLS.Ciphers = append(c.TLS.Ciphers, value) + } + case "clients": + c.TLS.ClientCerts = c.RemainingArgs() + if len(c.TLS.ClientCerts) == 0 { + return nil, c.ArgErr() + } + case "load": + c.Args(&loadDir) + c.TLS.Manual = true + case "max_certs": + c.Args(&maxCerts) + c.TLS.OnDemand = true + default: + return nil, c.Errf("Unknown keyword '%s'", c.Val()) + } + } + + // tls requires at least one argument if a block is not opened + if len(args) == 0 && !hadBlock { + return nil, c.ArgErr() + } + + // set certificate limit if on-demand TLS is enabled + if maxCerts != "" { + maxCertsNum, err := strconv.Atoi(maxCerts) + if err != nil || maxCertsNum < 1 { + return nil, c.Err("max_certs must be a positive integer") + } + if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost... + onDemandMaxIssue = int32(maxCertsNum) + } + } + + // don't try to load certificates unless we're supposed to + if !c.TLS.Enabled || !c.TLS.Manual { + continue + } + + // load a single certificate and key, if specified + if certificateFile != "" && keyFile != "" { + err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile) + if err != nil { + return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err) + } + log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile) + } + + // load a directory of certificates, if specified + if loadDir != "" { + err := loadCertsInDir(c, loadDir) + if err != nil { + return nil, err + } + } + } + + setDefaultTLSParams(c.Config) + + return nil, nil +} + +// loadCertsInDir loads all the certificates/keys in dir, as long as +// the file ends with .pem. This method of loading certificates is +// modeled after haproxy, which expects the certificate and key to +// be bundled into the same file: +// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt +// +// This function may write to the log as it walks the directory tree. +func loadCertsInDir(c *setup.Controller, dir string) error { + return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + log.Printf("[WARNING] Unable to traverse into %s; skipping", path) + return nil + } + if info.IsDir() { + return nil + } + if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") { + certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer) + var foundKey bool // use only the first key in the file + + bundle, err := ioutil.ReadFile(path) + if err != nil { + return err + } + + for { + // Decode next block so we can see what type it is + var derBlock *pem.Block + derBlock, bundle = pem.Decode(bundle) + if derBlock == nil { + break + } + + if derBlock.Type == "CERTIFICATE" { + // Re-encode certificate as PEM, appending to certificate chain + pem.Encode(certBuilder, derBlock) + } else if derBlock.Type == "EC PARAMETERS" { + // EC keys generated from openssl can be composed of two blocks: + // parameters and key (parameter block should come first) + if !foundKey { + // Encode parameters + pem.Encode(keyBuilder, derBlock) + + // Key must immediately follow + derBlock, bundle = pem.Decode(bundle) + if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" { + return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path) + } + pem.Encode(keyBuilder, derBlock) + foundKey = true + } + } else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") { + // RSA key + if !foundKey { + pem.Encode(keyBuilder, derBlock) + foundKey = true + } + } else { + return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type) + } + } + + certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes() + if len(certPEMBytes) == 0 { + return c.Errf("%s: failed to parse PEM data", path) + } + if len(keyPEMBytes) == 0 { + return c.Errf("%s: no private key block found", path) + } + + err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes) + if err != nil { + return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err) + } + log.Printf("[INFO] Successfully loaded TLS assets from %s", path) + } + return nil + }) +} + +// setDefaultTLSParams sets the default TLS cipher suites, protocol versions, +// and server preferences of a server.Config if they were not previously set +// (it does not overwrite; only fills in missing values). It will also set the +// port to 443 if not already set, TLS is enabled, TLS is manual, and the host +// does not equal localhost. +func setDefaultTLSParams(c *server.Config) { + // If no ciphers provided, use default list + if len(c.TLS.Ciphers) == 0 { + c.TLS.Ciphers = defaultCiphers + } + + // Not a cipher suite, but still important for mitigating protocol downgrade attacks + // (prepend since having it at end breaks http2 due to non-h2-approved suites before it) + c.TLS.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, c.TLS.Ciphers...) + + // Set default protocol min and max versions - must balance compatibility and security + if c.TLS.ProtocolMinVersion == 0 { + c.TLS.ProtocolMinVersion = tls.VersionTLS10 + } + if c.TLS.ProtocolMaxVersion == 0 { + c.TLS.ProtocolMaxVersion = tls.VersionTLS12 + } + + // Prefer server cipher suites + c.TLS.PreferServerCipherSuites = true + + // Default TLS port is 443; only use if port is not manually specified, + // TLS is enabled, and the host is not localhost + if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" { + c.Port = "443" + } +} + +// Map of supported protocols. +// SSLv3 will be not supported in future release. +// HTTP/2 only supports TLS 1.2 and higher. +var supportedProtocols = map[string]uint16{ + "ssl3.0": tls.VersionSSL30, + "tls1.0": tls.VersionTLS10, + "tls1.1": tls.VersionTLS11, + "tls1.2": tls.VersionTLS12, +} + +// Map of supported ciphers, used only for parsing config. +// +// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites, +// including all but two of the suites below (the two GCM suites). +// See https://http2.github.io/http2-spec/#BadCipherSuites +// +// TLS_FALLBACK_SCSV is not in this list because we manually ensure +// it is always added (even though it is not technically a cipher suite). +// +// This map, like any map, is NOT ORDERED. Do not range over this map. +var supportedCiphersMap = map[string]uint16{ + "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "ECDHE-RSA-AES128-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "ECDHE-RSA-AES256-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "ECDHE-ECDSA-AES256-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + "ECDHE-ECDSA-AES128-CBC-SHA": tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + "RSA-AES128-CBC-SHA": tls.TLS_RSA_WITH_AES_128_CBC_SHA, + "RSA-AES256-CBC-SHA": tls.TLS_RSA_WITH_AES_256_CBC_SHA, + "ECDHE-RSA-3DES-EDE-CBC-SHA": tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, +} + +// List of supported cipher suites in descending order of preference. +// Ordering is very important! Getting the wrong order will break +// mainstream clients, especially with HTTP/2. +// +// Note that TLS_FALLBACK_SCSV is not in this list since it is always +// added manually. +var supportedCiphers = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, +} + +// List of all the ciphers we want to use by default +var defaultCiphers = []uint16{ + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, +} diff --git a/core/https/setup_test.go b/core/https/setup_test.go new file mode 100644 index 000000000..339fcdb5a --- /dev/null +++ b/core/https/setup_test.go @@ -0,0 +1,232 @@ +package https + +import ( + "crypto/tls" + "io/ioutil" + "log" + "os" + "testing" + + "github.com/miekg/coredns/core/setup" +) + +func TestMain(m *testing.M) { + // Write test certificates to disk before tests, and clean up + // when we're done. + err := ioutil.WriteFile(certFile, testCert, 0644) + if err != nil { + log.Fatal(err) + } + err = ioutil.WriteFile(keyFile, testKey, 0644) + if err != nil { + os.Remove(certFile) + log.Fatal(err) + } + + result := m.Run() + + os.Remove(certFile) + os.Remove(keyFile) + os.Exit(result) +} + +func TestSetupParseBasic(t *testing.T) { + c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``) + + _, err := Setup(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + + // Basic checks + if !c.TLS.Manual { + t.Error("Expected TLS Manual=true, but was false") + } + if !c.TLS.Enabled { + t.Error("Expected TLS Enabled=true, but was false") + } + + // Security defaults + if c.TLS.ProtocolMinVersion != tls.VersionTLS10 { + t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion) + } + if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 { + t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", c.TLS.ProtocolMaxVersion) + } + + // Cipher checks + expectedCiphers := []uint16{ + tls.TLS_FALLBACK_SCSV, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + } + + // Ensure count is correct (plus one for TLS_FALLBACK_SCSV) + if len(c.TLS.Ciphers) != len(expectedCiphers) { + t.Errorf("Expected %v Ciphers (including TLS_FALLBACK_SCSV), got %v", + len(expectedCiphers), len(c.TLS.Ciphers)) + } + + // Ensure ordering is correct + for i, actual := range c.TLS.Ciphers { + if actual != expectedCiphers[i] { + t.Errorf("Expected cipher in position %d to be %0x, got %0x", i, expectedCiphers[i], actual) + } + } + + if !c.TLS.PreferServerCipherSuites { + t.Error("Expected PreferServerCipherSuites = true, but was false") + } +} + +func TestSetupParseIncompleteParams(t *testing.T) { + // Using tls without args is an error because it's unnecessary. + c := setup.NewTestController(`tls`) + _, err := Setup(c) + if err == nil { + t.Error("Expected an error, but didn't get one") + } +} + +func TestSetupParseWithOptionalParams(t *testing.T) { + params := `tls ` + certFile + ` ` + keyFile + ` { + protocols ssl3.0 tls1.2 + ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384 + }` + c := setup.NewTestController(params) + + _, err := Setup(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + + if c.TLS.ProtocolMinVersion != tls.VersionSSL30 { + t.Errorf("Expected 'ssl3.0 (0x0300)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion) + } + + if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 { + t.Errorf("Expected 'tls1.2 (0x0302)' as ProtocolMaxVersion, got %#v", c.TLS.ProtocolMaxVersion) + } + + if len(c.TLS.Ciphers)-1 != 3 { + t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) + } +} + +func TestSetupDefaultWithOptionalParams(t *testing.T) { + params := `tls { + ciphers RSA-3DES-EDE-CBC-SHA + }` + c := setup.NewTestController(params) + + _, err := Setup(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + if len(c.TLS.Ciphers)-1 != 1 { + t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1) + } +} + +// TODO: If we allow this... but probably not a good idea. +// func TestSetupDisableHTTPRedirect(t *testing.T) { +// c := NewTestController(`tls { +// allow_http +// }`) +// _, err := TLS(c) +// if err != nil { +// t.Errorf("Expected no error, but got %v", err) +// } +// if !c.TLS.DisableHTTPRedir { +// t.Error("Expected HTTP redirect to be disabled, but it wasn't") +// } +// } + +func TestSetupParseWithWrongOptionalParams(t *testing.T) { + // Test protocols wrong params + params := `tls ` + certFile + ` ` + keyFile + ` { + protocols ssl tls + }` + c := setup.NewTestController(params) + _, err := Setup(c) + if err == nil { + t.Errorf("Expected errors, but no error returned") + } + + // Test ciphers wrong params + params = `tls ` + certFile + ` ` + keyFile + ` { + ciphers not-valid-cipher + }` + c = setup.NewTestController(params) + _, err = Setup(c) + if err == nil { + t.Errorf("Expected errors, but no error returned") + } +} + +func TestSetupParseWithClientAuth(t *testing.T) { + params := `tls ` + certFile + ` ` + keyFile + ` { + clients client_ca.crt client2_ca.crt + }` + c := setup.NewTestController(params) + _, err := Setup(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + + if count := len(c.TLS.ClientCerts); count != 2 { + t.Fatalf("Expected two client certs, had %d", count) + } + if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" { + t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual) + } + if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" { + t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual) + } + + // Test missing client cert file + params = `tls ` + certFile + ` ` + keyFile + ` { + clients + }` + c = setup.NewTestController(params) + _, err = Setup(c) + if err == nil { + t.Errorf("Expected an error, but no error returned") + } +} + +const ( + certFile = "test_cert.pem" + keyFile = "test_key.pem" +) + +var testCert = []byte(`-----BEGIN CERTIFICATE----- +MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ +bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG +A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7 +9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk +pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG +A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls +b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw +RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N +Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A== +-----END CERTIFICATE----- +`) + +var testKey = []byte(`-----BEGIN EC PARAMETERS----- +BggqhkjOPQMBBw== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49 +AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL +SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g== +-----END EC PRIVATE KEY----- +`) diff --git a/core/https/storage.go b/core/https/storage.go new file mode 100644 index 000000000..5d8e949da --- /dev/null +++ b/core/https/storage.go @@ -0,0 +1,94 @@ +package https + +import ( + "path/filepath" + "strings" + + "github.com/miekg/coredns/core/assets" +) + +// storage is used to get file paths in a consistent, +// cross-platform way for persisting Let's Encrypt assets +// on the file system. +var storage = Storage(filepath.Join(assets.Path(), "letsencrypt")) + +// Storage is a root directory and facilitates +// forming file paths derived from it. +type Storage string + +// Sites gets the directory that stores site certificate and keys. +func (s Storage) Sites() string { + return filepath.Join(string(s), "sites") +} + +// Site returns the path to the folder containing assets for domain. +func (s Storage) Site(domain string) string { + return filepath.Join(s.Sites(), domain) +} + +// SiteCertFile returns the path to the certificate file for domain. +func (s Storage) SiteCertFile(domain string) string { + return filepath.Join(s.Site(domain), domain+".crt") +} + +// SiteKeyFile returns the path to domain's private key file. +func (s Storage) SiteKeyFile(domain string) string { + return filepath.Join(s.Site(domain), domain+".key") +} + +// SiteMetaFile returns the path to the domain's asset metadata file. +func (s Storage) SiteMetaFile(domain string) string { + return filepath.Join(s.Site(domain), domain+".json") +} + +// Users gets the directory that stores account folders. +func (s Storage) Users() string { + return filepath.Join(string(s), "users") +} + +// User gets the account folder for the user with email. +func (s Storage) User(email string) string { + if email == "" { + email = emptyEmail + } + return filepath.Join(s.Users(), email) +} + +// UserRegFile gets the path to the registration file for +// the user with the given email address. +func (s Storage) UserRegFile(email string) string { + if email == "" { + email = emptyEmail + } + fileName := emailUsername(email) + if fileName == "" { + fileName = "registration" + } + return filepath.Join(s.User(email), fileName+".json") +} + +// UserKeyFile gets the path to the private key file for +// the user with the given email address. +func (s Storage) UserKeyFile(email string) string { + if email == "" { + email = emptyEmail + } + fileName := emailUsername(email) + if fileName == "" { + fileName = "private" + } + return filepath.Join(s.User(email), fileName+".key") +} + +// emailUsername returns the username portion of an +// email address (part before '@') or the original +// input if it can't find the "@" symbol. +func emailUsername(email string) string { + at := strings.Index(email, "@") + if at == -1 { + return email + } else if at == 0 { + return email[1:] + } + return email[:at] +} diff --git a/core/https/storage_test.go b/core/https/storage_test.go new file mode 100644 index 000000000..85c2220eb --- /dev/null +++ b/core/https/storage_test.go @@ -0,0 +1,88 @@ +package https + +import ( + "path/filepath" + "testing" +) + +func TestStorage(t *testing.T) { + storage = Storage("./le_test") + + if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected { + t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected { + t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected { + t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected { + t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected { + t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected { + t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected { + t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected { + t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected { + t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual) + } + + // Test with empty emails + if expected, actual := filepath.Join("le_test", "users", emptyEmail), storage.User(emptyEmail); actual != expected { + t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected { + t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual) + } + if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected { + t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual) + } +} + +func TestEmailUsername(t *testing.T) { + for i, test := range []struct { + input, expect string + }{ + { + input: "username@example.com", + expect: "username", + }, + { + input: "plus+addressing@example.com", + expect: "plus+addressing", + }, + { + input: "me+plus-addressing@example.com", + expect: "me+plus-addressing", + }, + { + input: "not-an-email", + expect: "not-an-email", + }, + { + input: "@foobar.com", + expect: "foobar.com", + }, + { + input: emptyEmail, + expect: emptyEmail, + }, + { + input: "", + expect: "", + }, + } { + if actual := emailUsername(test.input); actual != test.expect { + t.Errorf("Test %d: Expected username to be '%s' but was '%s'", i, test.expect, actual) + } + } +} diff --git a/core/https/user.go b/core/https/user.go new file mode 100644 index 000000000..13d93b1da --- /dev/null +++ b/core/https/user.go @@ -0,0 +1,200 @@ +package https + +import ( + "bufio" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "strings" + + "github.com/miekg/coredns/server" + "github.com/xenolf/lego/acme" +) + +// User represents a Let's Encrypt user account. +type User struct { + Email string + Registration *acme.RegistrationResource + key crypto.PrivateKey +} + +// GetEmail gets u's email. +func (u User) GetEmail() string { + return u.Email +} + +// GetRegistration gets u's registration resource. +func (u User) GetRegistration() *acme.RegistrationResource { + return u.Registration +} + +// GetPrivateKey gets u's private key. +func (u User) GetPrivateKey() crypto.PrivateKey { + return u.key +} + +// getUser loads the user with the given email from disk. +// If the user does not exist, it will create a new one, +// but it does NOT save new users to the disk or register +// them via ACME. It does NOT prompt the user. +func getUser(email string) (User, error) { + var user User + + // open user file + regFile, err := os.Open(storage.UserRegFile(email)) + if err != nil { + if os.IsNotExist(err) { + // create a new user + return newUser(email) + } + return user, err + } + defer regFile.Close() + + // load user information + err = json.NewDecoder(regFile).Decode(&user) + if err != nil { + return user, err + } + + // load their private key + user.key, err = loadPrivateKey(storage.UserKeyFile(email)) + if err != nil { + return user, err + } + + return user, nil +} + +// saveUser persists a user's key and account registration +// to the file system. It does NOT register the user via ACME +// or prompt the user. +func saveUser(user User) error { + // make user account folder + err := os.MkdirAll(storage.User(user.Email), 0700) + if err != nil { + return err + } + + // save private key file + err = savePrivateKey(user.key, storage.UserKeyFile(user.Email)) + if err != nil { + return err + } + + // save registration file + jsonBytes, err := json.MarshalIndent(&user, "", "\t") + if err != nil { + return err + } + + return ioutil.WriteFile(storage.UserRegFile(user.Email), jsonBytes, 0600) +} + +// newUser creates a new User for the given email address +// with a new private key. This function does NOT save the +// user to disk or register it via ACME. If you want to use +// a user account that might already exist, call getUser +// instead. It does NOT prompt the user. +func newUser(email string) (User, error) { + user := User{Email: email} + privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return user, errors.New("error generating private key: " + err.Error()) + } + user.key = privateKey + return user, nil +} + +// getEmail does everything it can to obtain an email +// address from the user to use for TLS for cfg. If it +// cannot get an email address, it returns empty string. +// (It will warn the user of the consequences of an +// empty email.) This function MAY prompt the user for +// input. If userPresent is false, the operator will +// NOT be prompted and an empty email may be returned. +func getEmail(cfg server.Config, userPresent bool) string { + // First try the tls directive from the Caddyfile + leEmail := cfg.TLS.LetsEncryptEmail + if leEmail == "" { + // Then try memory (command line flag or typed by user previously) + leEmail = DefaultEmail + } + if leEmail == "" { + // Then try to get most recent user email ~/.caddy/users file + userDirs, err := ioutil.ReadDir(storage.Users()) + if err == nil { + var mostRecent os.FileInfo + for _, dir := range userDirs { + if !dir.IsDir() { + continue + } + if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) { + leEmail = dir.Name() + DefaultEmail = leEmail // save for next time + } + } + } + } + if leEmail == "" && userPresent { + // Alas, we must bother the user and ask for an email address; + // if they proceed they also agree to the SA. + reader := bufio.NewReader(stdin) + fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.") + fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:") + fmt.Println(" " + saURL) // TODO: Show current SA link + fmt.Println("Please enter your email address so you can recover your account if needed.") + fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.") + fmt.Print("Email address: ") + var err error + leEmail, err = reader.ReadString('\n') + if err != nil { + return "" + } + leEmail = strings.TrimSpace(leEmail) + DefaultEmail = leEmail + Agreed = true + } + return leEmail +} + +// promptUserAgreement prompts the user to agree to the agreement +// at agreementURL via stdin. If the agreement has changed, then pass +// true as the second argument. If this is the user's first time +// agreeing, pass false. It returns whether the user agreed or not. +func promptUserAgreement(agreementURL string, changed bool) bool { + if changed { + fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL) + fmt.Print("Do you agree to the new terms? (y/n): ") + } else { + fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL) + fmt.Print("Do you agree to the terms? (y/n): ") + } + + reader := bufio.NewReader(stdin) + answer, err := reader.ReadString('\n') + if err != nil { + return false + } + answer = strings.ToLower(strings.TrimSpace(answer)) + + return answer == "y" || answer == "yes" +} + +// stdin is used to read the user's input if prompted; +// this is changed by tests during tests. +var stdin = io.ReadWriter(os.Stdin) + +// The name of the folder for accounts where the email +// address was not provided; default 'username' if you will. +const emptyEmail = "default" + +// TODO: Use latest +const saURL = "https://letsencrypt.org/documents/LE-SA-v1.0.1-July-27-2015.pdf" diff --git a/core/https/user_test.go b/core/https/user_test.go new file mode 100644 index 000000000..3e1af5007 --- /dev/null +++ b/core/https/user_test.go @@ -0,0 +1,196 @@ +package https + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "io" + "os" + "strings" + "testing" + "time" + + "github.com/miekg/coredns/server" + "github.com/xenolf/lego/acme" +) + +func TestUser(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 128) + if err != nil { + t.Fatalf("Could not generate test private key: %v", err) + } + u := User{ + Email: "me@mine.com", + Registration: new(acme.RegistrationResource), + key: privateKey, + } + + if expected, actual := "me@mine.com", u.GetEmail(); actual != expected { + t.Errorf("Expected email '%s' but got '%s'", expected, actual) + } + if u.GetRegistration() == nil { + t.Error("Expected a registration resource, but got nil") + } + if expected, actual := privateKey, u.GetPrivateKey(); actual != expected { + t.Errorf("Expected the private key at address %p but got one at %p instead ", expected, actual) + } +} + +func TestNewUser(t *testing.T) { + email := "me@foobar.com" + user, err := newUser(email) + if err != nil { + t.Fatalf("Error creating user: %v", err) + } + if user.key == nil { + t.Error("Private key is nil") + } + if user.Email != email { + t.Errorf("Expected email to be %s, but was %s", email, user.Email) + } + if user.Registration != nil { + t.Error("New user already has a registration resource; it shouldn't") + } +} + +func TestSaveUser(t *testing.T) { + storage = Storage("./testdata") + defer os.RemoveAll(string(storage)) + + email := "me@foobar.com" + user, err := newUser(email) + if err != nil { + t.Fatalf("Error creating user: %v", err) + } + + err = saveUser(user) + if err != nil { + t.Fatalf("Error saving user: %v", err) + } + _, err = os.Stat(storage.UserRegFile(email)) + if err != nil { + t.Errorf("Cannot access user registration file, error: %v", err) + } + _, err = os.Stat(storage.UserKeyFile(email)) + if err != nil { + t.Errorf("Cannot access user private key file, error: %v", err) + } +} + +func TestGetUserDoesNotAlreadyExist(t *testing.T) { + storage = Storage("./testdata") + defer os.RemoveAll(string(storage)) + + user, err := getUser("user_does_not_exist@foobar.com") + if err != nil { + t.Fatalf("Error getting user: %v", err) + } + + if user.key == nil { + t.Error("Expected user to have a private key, but it was nil") + } +} + +func TestGetUserAlreadyExists(t *testing.T) { + storage = Storage("./testdata") + defer os.RemoveAll(string(storage)) + + email := "me@foobar.com" + + // Set up test + user, err := newUser(email) + if err != nil { + t.Fatalf("Error creating user: %v", err) + } + err = saveUser(user) + if err != nil { + t.Fatalf("Error saving user: %v", err) + } + + // Expect to load user from disk + user2, err := getUser(email) + if err != nil { + t.Fatalf("Error getting user: %v", err) + } + + // Assert keys are the same + if !PrivateKeysSame(user.key, user2.key) { + t.Error("Expected private key to be the same after loading, but it wasn't") + } + + // Assert emails are the same + if user.Email != user2.Email { + t.Errorf("Expected emails to be equal, but was '%s' before and '%s' after loading", user.Email, user2.Email) + } +} + +func TestGetEmail(t *testing.T) { + // let's not clutter up the output + origStdout := os.Stdout + os.Stdout = nil + defer func() { os.Stdout = origStdout }() + + storage = Storage("./testdata") + defer os.RemoveAll(string(storage)) + DefaultEmail = "test2@foo.com" + + // Test1: Use email in config + config := server.Config{ + TLS: server.TLSConfig{ + LetsEncryptEmail: "test1@foo.com", + }, + } + actual := getEmail(config, true) + if actual != "test1@foo.com" { + t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual) + } + + // Test2: Use default email from flag (or user previously typing it) + actual = getEmail(server.Config{}, true) + if actual != DefaultEmail { + t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual) + } + + // Test3: Get input from user + DefaultEmail = "" + stdin = new(bytes.Buffer) + _, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n")) + if err != nil { + t.Fatalf("Could not simulate user input, error: %v", err) + } + actual = getEmail(server.Config{}, true) + if actual != "test3@foo.com" { + t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual) + } + + // Test4: Get most recent email from before + DefaultEmail = "" + for i, eml := range []string{ + "test4-3@foo.com", + "test4-2@foo.com", + "test4-1@foo.com", + } { + u, err := newUser(eml) + if err != nil { + t.Fatalf("Error creating user %d: %v", i, err) + } + err = saveUser(u) + if err != nil { + t.Fatalf("Error saving user %d: %v", i, err) + } + + // Change modified time so they're all different, so the test becomes deterministic + f, err := os.Stat(storage.User(eml)) + if err != nil { + t.Fatalf("Could not access user folder for '%s': %v", eml, err) + } + chTime := f.ModTime().Add(-(time.Duration(i) * time.Second)) + if err := os.Chtimes(storage.User(eml), chTime, chTime); err != nil { + t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err) + } + } + actual = getEmail(server.Config{}, true) + if actual != "test4-3@foo.com" { + t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual) + } +} diff --git a/core/parse/dispenser.go b/core/parse/dispenser.go new file mode 100644 index 000000000..08aa6e76d --- /dev/null +++ b/core/parse/dispenser.go @@ -0,0 +1,251 @@ +package parse + +import ( + "errors" + "fmt" + "io" + "strings" +) + +// Dispenser is a type that dispenses tokens, similarly to a lexer, +// except that it can do so with some notion of structure and has +// some really convenient methods. +type Dispenser struct { + filename string + tokens []token + cursor int + nesting int +} + +// NewDispenser returns a Dispenser, ready to use for parsing the given input. +func NewDispenser(filename string, input io.Reader) Dispenser { + return Dispenser{ + filename: filename, + tokens: allTokens(input), + cursor: -1, + } +} + +// NewDispenserTokens returns a Dispenser filled with the given tokens. +func NewDispenserTokens(filename string, tokens []token) Dispenser { + return Dispenser{ + filename: filename, + tokens: tokens, + cursor: -1, + } +} + +// Next loads the next token. Returns true if a token +// was loaded; false otherwise. If false, all tokens +// have been consumed. +func (d *Dispenser) Next() bool { + if d.cursor < len(d.tokens)-1 { + d.cursor++ + return true + } + return false +} + +// NextArg loads the next token if it is on the same +// line. Returns true if a token was loaded; false +// otherwise. If false, all tokens on the line have +// been consumed. It handles imported tokens correctly. +func (d *Dispenser) NextArg() bool { + if d.cursor < 0 { + d.cursor++ + return true + } + if d.cursor >= len(d.tokens) { + return false + } + if d.cursor < len(d.tokens)-1 && + d.tokens[d.cursor].file == d.tokens[d.cursor+1].file && + d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) == d.tokens[d.cursor+1].line { + d.cursor++ + return true + } + return false +} + +// NextLine loads the next token only if it is not on the same +// line as the current token, and returns true if a token was +// loaded; false otherwise. If false, there is not another token +// or it is on the same line. It handles imported tokens correctly. +func (d *Dispenser) NextLine() bool { + if d.cursor < 0 { + d.cursor++ + return true + } + if d.cursor >= len(d.tokens) { + return false + } + if d.cursor < len(d.tokens)-1 && + (d.tokens[d.cursor].file != d.tokens[d.cursor+1].file || + d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) < d.tokens[d.cursor+1].line) { + d.cursor++ + return true + } + return false +} + +// NextBlock can be used as the condition of a for loop +// to load the next token as long as it opens a block or +// is already in a block. It returns true if a token was +// loaded, or false when the block's closing curly brace +// was loaded and thus the block ended. Nested blocks are +// not supported. +func (d *Dispenser) NextBlock() bool { + if d.nesting > 0 { + d.Next() + if d.Val() == "}" { + d.nesting-- + return false + } + return true + } + if !d.NextArg() { // block must open on same line + return false + } + if d.Val() != "{" { + d.cursor-- // roll back if not opening brace + return false + } + d.Next() + if d.Val() == "}" { + // Open and then closed right away + return false + } + d.nesting++ + return true +} + +// IncrNest adds a level of nesting to the dispenser. +func (d *Dispenser) IncrNest() { + d.nesting++ + return +} + +// Val gets the text of the current token. If there is no token +// loaded, it returns empty string. +func (d *Dispenser) Val() string { + if d.cursor < 0 || d.cursor >= len(d.tokens) { + return "" + } + return d.tokens[d.cursor].text +} + +// Line gets the line number of the current token. If there is no token +// loaded, it returns 0. +func (d *Dispenser) Line() int { + if d.cursor < 0 || d.cursor >= len(d.tokens) { + return 0 + } + return d.tokens[d.cursor].line +} + +// File gets the filename of the current token. If there is no token loaded, +// it returns the filename originally given when parsing started. +func (d *Dispenser) File() string { + if d.cursor < 0 || d.cursor >= len(d.tokens) { + return d.filename + } + if tokenFilename := d.tokens[d.cursor].file; tokenFilename != "" { + return tokenFilename + } + return d.filename +} + +// Args is a convenience function that loads the next arguments +// (tokens on the same line) into an arbitrary number of strings +// pointed to in targets. If there are fewer tokens available +// than string pointers, the remaining strings will not be changed +// and false will be returned. If there were enough tokens available +// to fill the arguments, then true will be returned. +func (d *Dispenser) Args(targets ...*string) bool { + enough := true + for i := 0; i < len(targets); i++ { + if !d.NextArg() { + enough = false + break + } + *targets[i] = d.Val() + } + return enough +} + +// RemainingArgs loads any more arguments (tokens on the same line) +// into a slice and returns them. Open curly brace tokens also indicate +// the end of arguments, and the curly brace is not included in +// the return value nor is it loaded. +func (d *Dispenser) RemainingArgs() []string { + var args []string + + for d.NextArg() { + if d.Val() == "{" { + d.cursor-- + break + } + args = append(args, d.Val()) + } + + return args +} + +// ArgErr returns an argument error, meaning that another +// argument was expected but not found. In other words, +// a line break or open curly brace was encountered instead of +// an argument. +func (d *Dispenser) ArgErr() error { + if d.Val() == "{" { + return d.Err("Unexpected token '{', expecting argument") + } + return d.Errf("Wrong argument count or unexpected line ending after '%s'", d.Val()) +} + +// SyntaxErr creates a generic syntax error which explains what was +// found and what was expected. +func (d *Dispenser) SyntaxErr(expected string) error { + msg := fmt.Sprintf("%s:%d - Syntax error: Unexpected token '%s', expecting '%s'", d.File(), d.Line(), d.Val(), expected) + return errors.New(msg) +} + +// EOFErr returns an error indicating that the dispenser reached +// the end of the input when searching for the next token. +func (d *Dispenser) EOFErr() error { + return d.Errf("Unexpected EOF") +} + +// Err generates a custom parse error with a message of msg. +func (d *Dispenser) Err(msg string) error { + msg = fmt.Sprintf("%s:%d - Parse error: %s", d.File(), d.Line(), msg) + return errors.New(msg) +} + +// Errf is like Err, but for formatted error messages +func (d *Dispenser) Errf(format string, args ...interface{}) error { + return d.Err(fmt.Sprintf(format, args...)) +} + +// numLineBreaks counts how many line breaks are in the token +// value given by the token index tknIdx. It returns 0 if the +// token does not exist or there are no line breaks. +func (d *Dispenser) numLineBreaks(tknIdx int) int { + if tknIdx < 0 || tknIdx >= len(d.tokens) { + return 0 + } + return strings.Count(d.tokens[tknIdx].text, "\n") +} + +// isNewLine determines whether the current token is on a different +// line (higher line number) than the previous token. It handles imported +// tokens correctly. If there isn't a previous token, it returns true. +func (d *Dispenser) isNewLine() bool { + if d.cursor < 1 { + return true + } + if d.cursor > len(d.tokens)-1 { + return false + } + return d.tokens[d.cursor-1].file != d.tokens[d.cursor].file || + d.tokens[d.cursor-1].line+d.numLineBreaks(d.cursor-1) < d.tokens[d.cursor].line +} diff --git a/core/parse/dispenser_test.go b/core/parse/dispenser_test.go new file mode 100644 index 000000000..20a7ddcac --- /dev/null +++ b/core/parse/dispenser_test.go @@ -0,0 +1,292 @@ +package parse + +import ( + "reflect" + "strings" + "testing" +) + +func TestDispenser_Val_Next(t *testing.T) { + input := `host:port + dir1 arg1 + dir2 arg2 arg3 + dir3` + d := NewDispenser("Testfile", strings.NewReader(input)) + + if val := d.Val(); val != "" { + t.Fatalf("Val(): Should return empty string when no token loaded; got '%s'", val) + } + + assertNext := func(shouldLoad bool, expectedCursor int, expectedVal string) { + if loaded := d.Next(); loaded != shouldLoad { + t.Errorf("Next(): Expected %v but got %v instead (val '%s')", shouldLoad, loaded, d.Val()) + } + if d.cursor != expectedCursor { + t.Errorf("Expected cursor to be %d, but was %d", expectedCursor, d.cursor) + } + if d.nesting != 0 { + t.Errorf("Nesting should be 0, was %d instead", d.nesting) + } + if val := d.Val(); val != expectedVal { + t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) + } + } + + assertNext(true, 0, "host:port") + assertNext(true, 1, "dir1") + assertNext(true, 2, "arg1") + assertNext(true, 3, "dir2") + assertNext(true, 4, "arg2") + assertNext(true, 5, "arg3") + assertNext(true, 6, "dir3") + // Note: This next test simply asserts existing behavior. + // If desired, we may wish to empty the token value after + // reading past the EOF. Open an issue if you want this change. + assertNext(false, 6, "dir3") +} + +func TestDispenser_NextArg(t *testing.T) { + input := `dir1 arg1 + dir2 arg2 arg3 + dir3` + d := NewDispenser("Testfile", strings.NewReader(input)) + + assertNext := func(shouldLoad bool, expectedVal string, expectedCursor int) { + if d.Next() != shouldLoad { + t.Errorf("Next(): Should load token but got false instead (val: '%s')", d.Val()) + } + if d.cursor != expectedCursor { + t.Errorf("Next(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor) + } + if val := d.Val(); val != expectedVal { + t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) + } + } + + assertNextArg := func(expectedVal string, loadAnother bool, expectedCursor int) { + if d.NextArg() != true { + t.Error("NextArg(): Should load next argument but got false instead") + } + if d.cursor != expectedCursor { + t.Errorf("NextArg(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor) + } + if val := d.Val(); val != expectedVal { + t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) + } + if !loadAnother { + if d.NextArg() != false { + t.Fatalf("NextArg(): Should NOT load another argument, but got true instead (val: '%s')", d.Val()) + } + if d.cursor != expectedCursor { + t.Errorf("NextArg(): Expected cursor to remain at %d, but it was %d", expectedCursor, d.cursor) + } + } + } + + assertNext(true, "dir1", 0) + assertNextArg("arg1", false, 1) + assertNext(true, "dir2", 2) + assertNextArg("arg2", true, 3) + assertNextArg("arg3", false, 4) + assertNext(true, "dir3", 5) + assertNext(false, "dir3", 5) +} + +func TestDispenser_NextLine(t *testing.T) { + input := `host:port + dir1 arg1 + dir2 arg2 arg3` + d := NewDispenser("Testfile", strings.NewReader(input)) + + assertNextLine := func(shouldLoad bool, expectedVal string, expectedCursor int) { + if d.NextLine() != shouldLoad { + t.Errorf("NextLine(): Should load token but got false instead (val: '%s')", d.Val()) + } + if d.cursor != expectedCursor { + t.Errorf("NextLine(): Expected cursor to be %d, instead was %d", expectedCursor, d.cursor) + } + if val := d.Val(); val != expectedVal { + t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val) + } + } + + assertNextLine(true, "host:port", 0) + assertNextLine(true, "dir1", 1) + assertNextLine(false, "dir1", 1) + d.Next() // arg1 + assertNextLine(true, "dir2", 3) + assertNextLine(false, "dir2", 3) + d.Next() // arg2 + assertNextLine(false, "arg2", 4) + d.Next() // arg3 + assertNextLine(false, "arg3", 5) +} + +func TestDispenser_NextBlock(t *testing.T) { + input := `foobar1 { + sub1 arg1 + sub2 + } + foobar2 { + }` + d := NewDispenser("Testfile", strings.NewReader(input)) + + assertNextBlock := func(shouldLoad bool, expectedCursor, expectedNesting int) { + if loaded := d.NextBlock(); loaded != shouldLoad { + t.Errorf("NextBlock(): Should return %v but got %v", shouldLoad, loaded) + } + if d.cursor != expectedCursor { + t.Errorf("NextBlock(): Expected cursor to be %d, was %d", expectedCursor, d.cursor) + } + if d.nesting != expectedNesting { + t.Errorf("NextBlock(): Nesting should be %d, not %d", expectedNesting, d.nesting) + } + } + + assertNextBlock(false, -1, 0) + d.Next() // foobar1 + assertNextBlock(true, 2, 1) + assertNextBlock(true, 3, 1) + assertNextBlock(true, 4, 1) + assertNextBlock(false, 5, 0) + d.Next() // foobar2 + assertNextBlock(false, 8, 0) // empty block is as if it didn't exist +} + +func TestDispenser_Args(t *testing.T) { + var s1, s2, s3 string + input := `dir1 arg1 arg2 arg3 + dir2 arg4 arg5 + dir3 arg6 arg7 + dir4` + d := NewDispenser("Testfile", strings.NewReader(input)) + + d.Next() // dir1 + + // As many strings as arguments + if all := d.Args(&s1, &s2, &s3); !all { + t.Error("Args(): Expected true, got false") + } + if s1 != "arg1" { + t.Errorf("Args(): Expected s1 to be 'arg1', got '%s'", s1) + } + if s2 != "arg2" { + t.Errorf("Args(): Expected s2 to be 'arg2', got '%s'", s2) + } + if s3 != "arg3" { + t.Errorf("Args(): Expected s3 to be 'arg3', got '%s'", s3) + } + + d.Next() // dir2 + + // More strings than arguments + if all := d.Args(&s1, &s2, &s3); all { + t.Error("Args(): Expected false, got true") + } + if s1 != "arg4" { + t.Errorf("Args(): Expected s1 to be 'arg4', got '%s'", s1) + } + if s2 != "arg5" { + t.Errorf("Args(): Expected s2 to be 'arg5', got '%s'", s2) + } + if s3 != "arg3" { + t.Errorf("Args(): Expected s3 to be unchanged ('arg3'), instead got '%s'", s3) + } + + // (quick cursor check just for kicks and giggles) + if d.cursor != 6 { + t.Errorf("Cursor should be 6, but is %d", d.cursor) + } + + d.Next() // dir3 + + // More arguments than strings + if all := d.Args(&s1); !all { + t.Error("Args(): Expected true, got false") + } + if s1 != "arg6" { + t.Errorf("Args(): Expected s1 to be 'arg6', got '%s'", s1) + } + + d.Next() // dir4 + + // No arguments or strings + if all := d.Args(); !all { + t.Error("Args(): Expected true, got false") + } + + // No arguments but at least one string + if all := d.Args(&s1); all { + t.Error("Args(): Expected false, got true") + } +} + +func TestDispenser_RemainingArgs(t *testing.T) { + input := `dir1 arg1 arg2 arg3 + dir2 arg4 arg5 + dir3 arg6 { arg7 + dir4` + d := NewDispenser("Testfile", strings.NewReader(input)) + + d.Next() // dir1 + + args := d.RemainingArgs() + if expected := []string{"arg1", "arg2", "arg3"}; !reflect.DeepEqual(args, expected) { + t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args) + } + + d.Next() // dir2 + + args = d.RemainingArgs() + if expected := []string{"arg4", "arg5"}; !reflect.DeepEqual(args, expected) { + t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args) + } + + d.Next() // dir3 + + args = d.RemainingArgs() + if expected := []string{"arg6"}; !reflect.DeepEqual(args, expected) { + t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args) + } + + d.Next() // { + d.Next() // arg7 + d.Next() // dir4 + + args = d.RemainingArgs() + if len(args) != 0 { + t.Errorf("RemainingArgs(): Expected %v, got %v", []string{}, args) + } +} + +func TestDispenser_ArgErr_Err(t *testing.T) { + input := `dir1 { + } + dir2 arg1 arg2` + d := NewDispenser("Testfile", strings.NewReader(input)) + + d.cursor = 1 // { + + if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "{") { + t.Errorf("ArgErr(): Expected an error message with { in it, but got '%v'", err) + } + + d.cursor = 5 // arg2 + + if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "arg2") { + t.Errorf("ArgErr(): Expected an error message with 'arg2' in it; got '%v'", err) + } + + err := d.Err("foobar") + if err == nil { + t.Fatalf("Err(): Expected an error, got nil") + } + + if !strings.Contains(err.Error(), "Testfile:3") { + t.Errorf("Expected error message with filename:line in it; got '%v'", err) + } + + if !strings.Contains(err.Error(), "foobar") { + t.Errorf("Expected error message with custom message in it ('foobar'); got '%v'", err) + } +} diff --git a/core/parse/import_glob0.txt b/core/parse/import_glob0.txt new file mode 100644 index 000000000..e610b5e7c --- /dev/null +++ b/core/parse/import_glob0.txt @@ -0,0 +1,6 @@ +glob0.host0 { + dir2 arg1 +} + +glob0.host1 { +} diff --git a/core/parse/import_glob1.txt b/core/parse/import_glob1.txt new file mode 100644 index 000000000..111eb044d --- /dev/null +++ b/core/parse/import_glob1.txt @@ -0,0 +1,4 @@ +glob1.host0 { + dir1 + dir2 arg1 +} diff --git a/core/parse/import_glob2.txt b/core/parse/import_glob2.txt new file mode 100644 index 000000000..c09f784ec --- /dev/null +++ b/core/parse/import_glob2.txt @@ -0,0 +1,3 @@ +glob2.host0 { + dir2 arg1 +} diff --git a/core/parse/import_test1.txt b/core/parse/import_test1.txt new file mode 100644 index 000000000..dac7b29be --- /dev/null +++ b/core/parse/import_test1.txt @@ -0,0 +1,2 @@ +dir2 arg1 arg2 +dir3
\ No newline at end of file diff --git a/core/parse/import_test2.txt b/core/parse/import_test2.txt new file mode 100644 index 000000000..140c87939 --- /dev/null +++ b/core/parse/import_test2.txt @@ -0,0 +1,4 @@ +host1 { + dir1 + dir2 arg1 +}
\ No newline at end of file diff --git a/core/parse/lexer.go b/core/parse/lexer.go new file mode 100644 index 000000000..d2939eba2 --- /dev/null +++ b/core/parse/lexer.go @@ -0,0 +1,122 @@ +package parse + +import ( + "bufio" + "io" + "unicode" +) + +type ( + // lexer is a utility which can get values, token by + // token, from a Reader. A token is a word, and tokens + // are separated by whitespace. A word can be enclosed + // in quotes if it contains whitespace. + lexer struct { + reader *bufio.Reader + token token + line int + } + + // token represents a single parsable unit. + token struct { + file string + line int + text string + } +) + +// load prepares the lexer to scan an input for tokens. +func (l *lexer) load(input io.Reader) error { + l.reader = bufio.NewReader(input) + l.line = 1 + return nil +} + +// next loads the next token into the lexer. +// A token is delimited by whitespace, unless +// the token starts with a quotes character (") +// in which case the token goes until the closing +// quotes (the enclosing quotes are not included). +// Inside quoted strings, quotes may be escaped +// with a preceding \ character. No other chars +// may be escaped. The rest of the line is skipped +// if a "#" character is read in. Returns true if +// a token was loaded; false otherwise. +func (l *lexer) next() bool { + var val []rune + var comment, quoted, escaped bool + + makeToken := func() bool { + l.token.text = string(val) + return true + } + + for { + ch, _, err := l.reader.ReadRune() + if err != nil { + if len(val) > 0 { + return makeToken() + } + if err == io.EOF { + return false + } + panic(err) + } + + if quoted { + if !escaped { + if ch == '\\' { + escaped = true + continue + } else if ch == '"' { + quoted = false + return makeToken() + } + } + if ch == '\n' { + l.line++ + } + if escaped { + // only escape quotes + if ch != '"' { + val = append(val, '\\') + } + } + val = append(val, ch) + escaped = false + continue + } + + if unicode.IsSpace(ch) { + if ch == '\r' { + continue + } + if ch == '\n' { + l.line++ + comment = false + } + if len(val) > 0 { + return makeToken() + } + continue + } + + if ch == '#' { + comment = true + } + + if comment { + continue + } + + if len(val) == 0 { + l.token = token{line: l.line} + if ch == '"' { + quoted = true + continue + } + } + + val = append(val, ch) + } +} diff --git a/core/parse/lexer_test.go b/core/parse/lexer_test.go new file mode 100644 index 000000000..f12c7e7dc --- /dev/null +++ b/core/parse/lexer_test.go @@ -0,0 +1,165 @@ +package parse + +import ( + "strings" + "testing" +) + +type lexerTestCase struct { + input string + expected []token +} + +func TestLexer(t *testing.T) { + testCases := []lexerTestCase{ + { + input: `host:123`, + expected: []token{ + {line: 1, text: "host:123"}, + }, + }, + { + input: `host:123 + + directive`, + expected: []token{ + {line: 1, text: "host:123"}, + {line: 3, text: "directive"}, + }, + }, + { + input: `host:123 { + directive + }`, + expected: []token{ + {line: 1, text: "host:123"}, + {line: 1, text: "{"}, + {line: 2, text: "directive"}, + {line: 3, text: "}"}, + }, + }, + { + input: `host:123 { directive }`, + expected: []token{ + {line: 1, text: "host:123"}, + {line: 1, text: "{"}, + {line: 1, text: "directive"}, + {line: 1, text: "}"}, + }, + }, + { + input: `host:123 { + #comment + directive + # comment + foobar # another comment + }`, + expected: []token{ + {line: 1, text: "host:123"}, + {line: 1, text: "{"}, + {line: 3, text: "directive"}, + {line: 5, text: "foobar"}, + {line: 6, text: "}"}, + }, + }, + { + input: `a "quoted value" b + foobar`, + expected: []token{ + {line: 1, text: "a"}, + {line: 1, text: "quoted value"}, + {line: 1, text: "b"}, + {line: 2, text: "foobar"}, + }, + }, + { + input: `A "quoted \"value\" inside" B`, + expected: []token{ + {line: 1, text: "A"}, + {line: 1, text: `quoted "value" inside`}, + {line: 1, text: "B"}, + }, + }, + { + input: `"don't\escape"`, + expected: []token{ + {line: 1, text: `don't\escape`}, + }, + }, + { + input: `"don't\\escape"`, + expected: []token{ + {line: 1, text: `don't\\escape`}, + }, + }, + { + input: `A "quoted value with line + break inside" { + foobar + }`, + expected: []token{ + {line: 1, text: "A"}, + {line: 1, text: "quoted value with line\n\t\t\t\t\tbreak inside"}, + {line: 2, text: "{"}, + {line: 3, text: "foobar"}, + {line: 4, text: "}"}, + }, + }, + { + input: `"C:\php\php-cgi.exe"`, + expected: []token{ + {line: 1, text: `C:\php\php-cgi.exe`}, + }, + }, + { + input: `empty "" string`, + expected: []token{ + {line: 1, text: `empty`}, + {line: 1, text: ``}, + {line: 1, text: `string`}, + }, + }, + { + input: "skip those\r\nCR characters", + expected: []token{ + {line: 1, text: "skip"}, + {line: 1, text: "those"}, + {line: 2, text: "CR"}, + {line: 2, text: "characters"}, + }, + }, + } + + for i, testCase := range testCases { + actual := tokenize(testCase.input) + lexerCompare(t, i, testCase.expected, actual) + } +} + +func tokenize(input string) (tokens []token) { + l := lexer{} + l.load(strings.NewReader(input)) + for l.next() { + tokens = append(tokens, l.token) + } + return +} + +func lexerCompare(t *testing.T, n int, expected, actual []token) { + if len(expected) != len(actual) { + t.Errorf("Test case %d: expected %d token(s) but got %d", n, len(expected), len(actual)) + } + + for i := 0; i < len(actual) && i < len(expected); i++ { + if actual[i].line != expected[i].line { + t.Errorf("Test case %d token %d ('%s'): expected line %d but was line %d", + n, i, expected[i].text, expected[i].line, actual[i].line) + break + } + if actual[i].text != expected[i].text { + t.Errorf("Test case %d token %d: expected text '%s' but was '%s'", + n, i, expected[i].text, actual[i].text) + break + } + } +} diff --git a/core/parse/parse.go b/core/parse/parse.go new file mode 100644 index 000000000..faef36c28 --- /dev/null +++ b/core/parse/parse.go @@ -0,0 +1,32 @@ +// Package parse provides facilities for parsing configuration files. +package parse + +import "io" + +// ServerBlocks parses the input just enough to organize tokens, +// in order, by server block. No further parsing is performed. +// If checkDirectives is true, only valid directives will be allowed +// otherwise we consider it a parse error. Server blocks are returned +// in the order in which they appear. +func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) { + p := parser{Dispenser: NewDispenser(filename, input)} + p.checkDirectives = checkDirectives + blocks, err := p.parseAll() + return blocks, err +} + +// allTokens lexes the entire input, but does not parse it. +// It returns all the tokens from the input, unstructured +// and in order. +func allTokens(input io.Reader) (tokens []token) { + l := new(lexer) + l.load(input) + for l.next() { + tokens = append(tokens, l.token) + } + return +} + +// ValidDirectives is a set of directives that are valid (unordered). Populated +// by config package's init function. +var ValidDirectives = make(map[string]struct{}) diff --git a/core/parse/parse_test.go b/core/parse/parse_test.go new file mode 100644 index 000000000..48746300f --- /dev/null +++ b/core/parse/parse_test.go @@ -0,0 +1,22 @@ +package parse + +import ( + "strings" + "testing" +) + +func TestAllTokens(t *testing.T) { + input := strings.NewReader("a b c\nd e") + expected := []string{"a", "b", "c", "d", "e"} + tokens := allTokens(input) + + if len(tokens) != len(expected) { + t.Fatalf("Expected %d tokens, got %d", len(expected), len(tokens)) + } + + for i, val := range expected { + if tokens[i].text != val { + t.Errorf("Token %d should be '%s' but was '%s'", i, val, tokens[i].text) + } + } +} diff --git a/core/parse/parsing.go b/core/parse/parsing.go new file mode 100644 index 000000000..6e73bd584 --- /dev/null +++ b/core/parse/parsing.go @@ -0,0 +1,379 @@ +package parse + +import ( + "net" + "os" + "path/filepath" + "strings" + + "github.com/miekg/dns" +) + +type parser struct { + Dispenser + block ServerBlock // current server block being parsed + eof bool // if we encounter a valid EOF in a hard place + checkDirectives bool // if true, directives must be known +} + +func (p *parser) parseAll() ([]ServerBlock, error) { + var blocks []ServerBlock + + for p.Next() { + err := p.parseOne() + if err != nil { + return blocks, err + } + if len(p.block.Addresses) > 0 { + blocks = append(blocks, p.block) + } + } + + return blocks, nil +} + +func (p *parser) parseOne() error { + p.block = ServerBlock{Tokens: make(map[string][]token)} + + err := p.begin() + if err != nil { + return err + } + + return nil +} + +func (p *parser) begin() error { + if len(p.tokens) == 0 { + return nil + } + + err := p.addresses() + if err != nil { + return err + } + + if p.eof { + // this happens if the Caddyfile consists of only + // a line of addresses and nothing else + return nil + } + + err = p.blockContents() + if err != nil { + return err + } + + return nil +} + +func (p *parser) addresses() error { + var expectingAnother bool + + for { + tkn := replaceEnvVars(p.Val()) + + // special case: import directive replaces tokens during parse-time + if tkn == "import" && p.isNewLine() { + err := p.doImport() + if err != nil { + return err + } + continue + } + + // Open brace definitely indicates end of addresses + if tkn == "{" { + if expectingAnother { + return p.Errf("Expected another address but had '%s' - check for extra comma", tkn) + } + break + } + + if tkn != "" { // empty token possible if user typed "" in Caddyfile + // Trailing comma indicates another address will follow, which + // may possibly be on the next line + if tkn[len(tkn)-1] == ',' { + tkn = tkn[:len(tkn)-1] + expectingAnother = true + } else { + expectingAnother = false // but we may still see another one on this line + } + + // Parse and save this address + addr, err := standardAddress(tkn) + if err != nil { + return err + } + p.block.Addresses = append(p.block.Addresses, addr) + } + + // Advance token and possibly break out of loop or return error + hasNext := p.Next() + if expectingAnother && !hasNext { + return p.EOFErr() + } + if !hasNext { + p.eof = true + break // EOF + } + if !expectingAnother && p.isNewLine() { + break + } + } + + return nil +} + +func (p *parser) blockContents() error { + errOpenCurlyBrace := p.openCurlyBrace() + if errOpenCurlyBrace != nil { + // single-server configs don't need curly braces + p.cursor-- + } + + err := p.directives() + if err != nil { + return err + } + + // Only look for close curly brace if there was an opening + if errOpenCurlyBrace == nil { + err = p.closeCurlyBrace() + if err != nil { + return err + } + } + + return nil +} + +// directives parses through all the lines for directives +// and it expects the next token to be the first +// directive. It goes until EOF or closing curly brace +// which ends the server block. +func (p *parser) directives() error { + for p.Next() { + // end of server block + if p.Val() == "}" { + break + } + + // special case: import directive replaces tokens during parse-time + if p.Val() == "import" { + err := p.doImport() + if err != nil { + return err + } + p.cursor-- // cursor is advanced when we continue, so roll back one more + continue + } + + // normal case: parse a directive on this line + if err := p.directive(); err != nil { + return err + } + } + return nil +} + +// doImport swaps out the import directive and its argument +// (a total of 2 tokens) with the tokens in the specified file +// or globbing pattern. When the function returns, the cursor +// is on the token before where the import directive was. In +// other words, call Next() to access the first token that was +// imported. +func (p *parser) doImport() error { + // syntax check + if !p.NextArg() { + return p.ArgErr() + } + importPattern := p.Val() + if p.NextArg() { + return p.Err("Import takes only one argument (glob pattern or file)") + } + + // do glob + matches, err := filepath.Glob(importPattern) + if err != nil { + return p.Errf("Failed to use import pattern %s: %v", importPattern, err) + } + if len(matches) == 0 { + return p.Errf("No files matching import pattern %s", importPattern) + } + + // splice out the import directive and its argument (2 tokens total) + tokensBefore := p.tokens[:p.cursor-1] + tokensAfter := p.tokens[p.cursor+1:] + + // collect all the imported tokens + var importedTokens []token + for _, importFile := range matches { + newTokens, err := p.doSingleImport(importFile) + if err != nil { + return err + } + importedTokens = append(importedTokens, newTokens...) + } + + // splice the imported tokens in the place of the import statement + // and rewind cursor so Next() will land on first imported token + p.tokens = append(tokensBefore, append(importedTokens, tokensAfter...)...) + p.cursor-- + + return nil +} + +// doSingleImport lexes the individual file at importFile and returns +// its tokens or an error, if any. +func (p *parser) doSingleImport(importFile string) ([]token, error) { + file, err := os.Open(importFile) + if err != nil { + return nil, p.Errf("Could not import %s: %v", importFile, err) + } + defer file.Close() + importedTokens := allTokens(file) + + // Tack the filename onto these tokens so errors show the imported file's name + filename := filepath.Base(importFile) + for i := 0; i < len(importedTokens); i++ { + importedTokens[i].file = filename + } + + return importedTokens, nil +} + +// directive collects tokens until the directive's scope +// closes (either end of line or end of curly brace block). +// It expects the currently-loaded token to be a directive +// (or } that ends a server block). The collected tokens +// are loaded into the current server block for later use +// by directive setup functions. +func (p *parser) directive() error { + dir := p.Val() + nesting := 0 + + if p.checkDirectives { + if _, ok := ValidDirectives[dir]; !ok { + return p.Errf("Unknown directive '%s'", dir) + } + } + + // The directive itself is appended as a relevant token + p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor]) + + for p.Next() { + if p.Val() == "{" { + nesting++ + } else if p.isNewLine() && nesting == 0 { + p.cursor-- // read too far + break + } else if p.Val() == "}" && nesting > 0 { + nesting-- + } else if p.Val() == "}" && nesting == 0 { + return p.Err("Unexpected '}' because no matching opening brace") + } + p.tokens[p.cursor].text = replaceEnvVars(p.tokens[p.cursor].text) + p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor]) + } + + if nesting > 0 { + return p.EOFErr() + } + return nil +} + +// openCurlyBrace expects the current token to be an +// opening curly brace. This acts like an assertion +// because it returns an error if the token is not +// a opening curly brace. It does NOT advance the token. +func (p *parser) openCurlyBrace() error { + if p.Val() != "{" { + return p.SyntaxErr("{") + } + return nil +} + +// closeCurlyBrace expects the current token to be +// a closing curly brace. This acts like an assertion +// because it returns an error if the token is not +// a closing curly brace. It does NOT advance the token. +func (p *parser) closeCurlyBrace() error { + if p.Val() != "}" { + return p.SyntaxErr("}") + } + return nil +} + +// standardAddress parses an address string into a structured format with separate +// host, and port portions, as well as the original input string. +func standardAddress(str string) (address, error) { + var err error + + // first check for scheme and strip it off + input := str + + // separate host and port + host, port, err := net.SplitHostPort(str) + if err != nil { + host, port, err = net.SplitHostPort(str + ":") + // no error check here; return err at end of function + } + + // see if we can set port based off scheme + if port == "" { + port = "53" + } + + return address{Original: input, Host: strings.ToLower(dns.Fqdn(host)), Port: port}, err +} + +// replaceEnvVars replaces environment variables that appear in the token +// and understands both the $UNIX and %WINDOWS% syntaxes. +func replaceEnvVars(s string) string { + s = replaceEnvReferences(s, "{%", "%}") + s = replaceEnvReferences(s, "{$", "}") + return s +} + +// replaceEnvReferences performs the actual replacement of env variables +// in s, given the placeholder start and placeholder end strings. +func replaceEnvReferences(s, refStart, refEnd string) string { + index := strings.Index(s, refStart) + for index != -1 { + endIndex := strings.Index(s, refEnd) + if endIndex != -1 { + ref := s[index : endIndex+len(refEnd)] + s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1) + } else { + return s + } + index = strings.Index(s, refStart) + } + return s +} + +type ( + // ServerBlock associates tokens with a list of addresses + // and groups tokens by directive name. + ServerBlock struct { + Addresses []address + Tokens map[string][]token + } + + address struct { + Original, Host, Port string + } +) + +// HostList converts the list of addresses that are +// associated with this server block into a slice of +// strings, where each address is as it was originally +// read from the input. +func (sb ServerBlock) HostList() []string { + sbHosts := make([]string, len(sb.Addresses)) + for j, addr := range sb.Addresses { + sbHosts[j] = addr.Original + } + return sbHosts +} diff --git a/core/parse/parsing_test.go b/core/parse/parsing_test.go new file mode 100644 index 000000000..493c0fff9 --- /dev/null +++ b/core/parse/parsing_test.go @@ -0,0 +1,477 @@ +package parse + +import ( + "os" + "strings" + "testing" +) + +func TestStandardAddress(t *testing.T) { + for i, test := range []struct { + input string + scheme, host, port string + shouldErr bool + }{ + {`localhost`, "", "localhost", "", false}, + {`localhost:1234`, "", "localhost", "1234", false}, + {`localhost:`, "", "localhost", "", false}, + {`0.0.0.0`, "", "0.0.0.0", "", false}, + {`127.0.0.1:1234`, "", "127.0.0.1", "1234", false}, + {`:1234`, "", "", "1234", false}, + {`[::1]`, "", "::1", "", false}, + {`[::1]:1234`, "", "::1", "1234", false}, + {`:`, "", "", "", false}, + {`localhost:http`, "http", "localhost", "80", false}, + {`localhost:https`, "https", "localhost", "443", false}, + {`:http`, "http", "", "80", false}, + {`:https`, "https", "", "443", false}, + {`http://localhost:https`, "", "", "", true}, // conflict + {`http://localhost:http`, "", "", "", true}, // repeated scheme + {`http://localhost:443`, "", "", "", true}, // not conventional + {`https://localhost:80`, "", "", "", true}, // not conventional + {`http://localhost`, "http", "localhost", "80", false}, + {`https://localhost`, "https", "localhost", "443", false}, + {`http://127.0.0.1`, "http", "127.0.0.1", "80", false}, + {`https://127.0.0.1`, "https", "127.0.0.1", "443", false}, + {`http://[::1]`, "http", "::1", "80", false}, + {`http://localhost:1234`, "http", "localhost", "1234", false}, + {`https://127.0.0.1:1234`, "https", "127.0.0.1", "1234", false}, + {`http://[::1]:1234`, "http", "::1", "1234", false}, + {``, "", "", "", false}, + {`::1`, "", "::1", "", true}, + {`localhost::`, "", "localhost::", "", true}, + {`#$%@`, "", "#$%@", "", true}, + } { + actual, err := standardAddress(test.input) + + if err != nil && !test.shouldErr { + t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err) + } + if err == nil && test.shouldErr { + t.Errorf("Test %d (%s): Expected error, but had none", i, test.input) + } + + if actual.Scheme != test.scheme { + t.Errorf("Test %d (%s): Expected scheme '%s', got '%s'", i, test.input, test.scheme, actual.Scheme) + } + if actual.Host != test.host { + t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host) + } + if actual.Port != test.port { + t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port) + } + } +} + +func TestParseOneAndImport(t *testing.T) { + setupParseTests() + + testParseOne := func(input string) (ServerBlock, error) { + p := testParser(input) + p.Next() // parseOne doesn't call Next() to start, so we must + err := p.parseOne() + return p.block, err + } + + for i, test := range []struct { + input string + shouldErr bool + addresses []address + tokens map[string]int // map of directive name to number of tokens expected + }{ + {`localhost`, false, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{}}, + + {`localhost + dir1`, false, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{ + "dir1": 1, + }}, + + {`localhost:1234 + dir1 foo bar`, false, []address{ + {"localhost:1234", "", "localhost", "1234"}, + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost { + dir1 + }`, false, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{ + "dir1": 1, + }}, + + {`localhost:1234 { + dir1 foo bar + dir2 + }`, false, []address{ + {"localhost:1234", "", "localhost", "1234"}, + }, map[string]int{ + "dir1": 3, + "dir2": 1, + }}, + + {`http://localhost https://localhost + dir1 foo bar`, false, []address{ + {"http://localhost", "http", "localhost", "80"}, + {"https://localhost", "https", "localhost", "443"}, + }, map[string]int{ + "dir1": 3, + }}, + + {`http://localhost https://localhost { + dir1 foo bar + }`, false, []address{ + {"http://localhost", "http", "localhost", "80"}, + {"https://localhost", "https", "localhost", "443"}, + }, map[string]int{ + "dir1": 3, + }}, + + {`http://localhost, https://localhost { + dir1 foo bar + }`, false, []address{ + {"http://localhost", "http", "localhost", "80"}, + {"https://localhost", "https", "localhost", "443"}, + }, map[string]int{ + "dir1": 3, + }}, + + {`http://localhost, { + }`, true, []address{ + {"http://localhost", "http", "localhost", "80"}, + }, map[string]int{}}, + + {`host1:80, http://host2.com + dir1 foo bar + dir2 baz`, false, []address{ + {"host1:80", "", "host1", "80"}, + {"http://host2.com", "http", "host2.com", "80"}, + }, map[string]int{ + "dir1": 3, + "dir2": 2, + }}, + + {`http://host1.com, + http://host2.com, + https://host3.com`, false, []address{ + {"http://host1.com", "http", "host1.com", "80"}, + {"http://host2.com", "http", "host2.com", "80"}, + {"https://host3.com", "https", "host3.com", "443"}, + }, map[string]int{}}, + + {`http://host1.com:1234, https://host2.com + dir1 foo { + bar baz + } + dir2`, false, []address{ + {"http://host1.com:1234", "http", "host1.com", "1234"}, + {"https://host2.com", "https", "host2.com", "443"}, + }, map[string]int{ + "dir1": 6, + "dir2": 1, + }}, + + {`127.0.0.1 + dir1 { + bar baz + } + dir2 { + foo bar + }`, false, []address{ + {"127.0.0.1", "", "127.0.0.1", ""}, + }, map[string]int{ + "dir1": 5, + "dir2": 5, + }}, + + {`127.0.0.1 + unknown_directive`, true, []address{ + {"127.0.0.1", "", "127.0.0.1", ""}, + }, map[string]int{}}, + + {`localhost + dir1 { + foo`, true, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost + dir1 { + }`, false, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost + dir1 { + } }`, true, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{ + "dir1": 3, + }}, + + {`localhost + dir1 { + nested { + foo + } + } + dir2 foo bar`, false, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{ + "dir1": 7, + "dir2": 3, + }}, + + {``, false, []address{}, map[string]int{}}, + + {`localhost + dir1 arg1 + import import_test1.txt`, false, []address{ + {"localhost", "", "localhost", ""}, + }, map[string]int{ + "dir1": 2, + "dir2": 3, + "dir3": 1, + }}, + + {`import import_test2.txt`, false, []address{ + {"host1", "", "host1", ""}, + }, map[string]int{ + "dir1": 1, + "dir2": 2, + }}, + + {`import import_test1.txt import_test2.txt`, true, []address{}, map[string]int{}}, + + {`import not_found.txt`, true, []address{}, map[string]int{}}, + + {`""`, false, []address{}, map[string]int{}}, + + {``, false, []address{}, map[string]int{}}, + } { + result, err := testParseOne(test.input) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected an error, but didn't get one", i) + } + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error, but got: %v", i, err) + } + + if len(result.Addresses) != len(test.addresses) { + t.Errorf("Test %d: Expected %d addresses, got %d", + i, len(test.addresses), len(result.Addresses)) + continue + } + for j, addr := range result.Addresses { + if addr.Host != test.addresses[j].Host { + t.Errorf("Test %d, address %d: Expected host to be '%s', but was '%s'", + i, j, test.addresses[j].Host, addr.Host) + } + if addr.Port != test.addresses[j].Port { + t.Errorf("Test %d, address %d: Expected port to be '%s', but was '%s'", + i, j, test.addresses[j].Port, addr.Port) + } + } + + if len(result.Tokens) != len(test.tokens) { + t.Errorf("Test %d: Expected %d directives, had %d", + i, len(test.tokens), len(result.Tokens)) + continue + } + for directive, tokens := range result.Tokens { + if len(tokens) != test.tokens[directive] { + t.Errorf("Test %d, directive '%s': Expected %d tokens, counted %d", + i, directive, test.tokens[directive], len(tokens)) + continue + } + } + } +} + +func TestParseAll(t *testing.T) { + setupParseTests() + + for i, test := range []struct { + input string + shouldErr bool + addresses [][]address // addresses per server block, in order + }{ + {`localhost`, false, [][]address{ + {{"localhost", "", "localhost", ""}}, + }}, + + {`localhost:1234`, false, [][]address{ + {{"localhost:1234", "", "localhost", "1234"}}, + }}, + + {`localhost:1234 { + } + localhost:2015 { + }`, false, [][]address{ + {{"localhost:1234", "", "localhost", "1234"}}, + {{"localhost:2015", "", "localhost", "2015"}}, + }}, + + {`localhost:1234, http://host2`, false, [][]address{ + {{"localhost:1234", "", "localhost", "1234"}, {"http://host2", "http", "host2", "80"}}, + }}, + + {`localhost:1234, http://host2,`, true, [][]address{}}, + + {`http://host1.com, http://host2.com { + } + https://host3.com, https://host4.com { + }`, false, [][]address{ + {{"http://host1.com", "http", "host1.com", "80"}, {"http://host2.com", "http", "host2.com", "80"}}, + {{"https://host3.com", "https", "host3.com", "443"}, {"https://host4.com", "https", "host4.com", "443"}}, + }}, + + {`import import_glob*.txt`, false, [][]address{ + {{"glob0.host0", "", "glob0.host0", ""}}, + {{"glob0.host1", "", "glob0.host1", ""}}, + {{"glob1.host0", "", "glob1.host0", ""}}, + {{"glob2.host0", "", "glob2.host0", ""}}, + }}, + } { + p := testParser(test.input) + blocks, err := p.parseAll() + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected an error, but didn't get one", i) + } + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error, but got: %v", i, err) + } + + if len(blocks) != len(test.addresses) { + t.Errorf("Test %d: Expected %d server blocks, got %d", + i, len(test.addresses), len(blocks)) + continue + } + for j, block := range blocks { + if len(block.Addresses) != len(test.addresses[j]) { + t.Errorf("Test %d: Expected %d addresses in block %d, got %d", + i, len(test.addresses[j]), j, len(block.Addresses)) + continue + } + for k, addr := range block.Addresses { + if addr.Host != test.addresses[j][k].Host { + t.Errorf("Test %d, block %d, address %d: Expected host to be '%s', but was '%s'", + i, j, k, test.addresses[j][k].Host, addr.Host) + } + if addr.Port != test.addresses[j][k].Port { + t.Errorf("Test %d, block %d, address %d: Expected port to be '%s', but was '%s'", + i, j, k, test.addresses[j][k].Port, addr.Port) + } + } + } + } +} + +func TestEnvironmentReplacement(t *testing.T) { + setupParseTests() + + os.Setenv("PORT", "8080") + os.Setenv("ADDRESS", "servername.com") + os.Setenv("FOOBAR", "foobar") + + // basic test; unix-style env vars + p := testParser(`{$ADDRESS}`) + blocks, _ := p.parseAll() + if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { + t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) + } + + // multiple vars per token + p = testParser(`{$ADDRESS}:{$PORT}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { + t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) + } + if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { + t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) + } + + // windows-style var and unix style in same token + p = testParser(`{%ADDRESS%}:{$PORT}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { + t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) + } + if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { + t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) + } + + // reverse order + p = testParser(`{$ADDRESS}:{%PORT%}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual { + t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) + } + if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { + t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) + } + + // env var in server block body as argument + p = testParser(":{%PORT%}\ndir1 {$FOOBAR}") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual { + t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) + } + if actual, expected := blocks[0].Tokens["dir1"][1].text, "foobar"; expected != actual { + t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) + } + + // combined windows env vars in argument + p = testParser(":{%PORT%}\ndir1 {%ADDRESS%}/{%FOOBAR%}") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Tokens["dir1"][1].text, "servername.com/foobar"; expected != actual { + t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) + } + + // malformed env var (windows) + p = testParser(":1234\ndir1 {%ADDRESS}") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Tokens["dir1"][1].text, "{%ADDRESS}"; expected != actual { + t.Errorf("Expected host to be '%s' but was '%s'", expected, actual) + } + + // malformed (non-existent) env var (unix) + p = testParser(`:{$PORT$}`) + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Addresses[0].Port, ""; expected != actual { + t.Errorf("Expected port to be '%s' but was '%s'", expected, actual) + } + + // in quoted field + p = testParser(":1234\ndir1 \"Test {$FOOBAR} test\"") + blocks, _ = p.parseAll() + if actual, expected := blocks[0].Tokens["dir1"][1].text, "Test foobar test"; expected != actual { + t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual) + } +} + +func setupParseTests() { + // Set up some bogus directives for testing + ValidDirectives = map[string]struct{}{ + "dir1": {}, + "dir2": {}, + "dir3": {}, + } +} + +func testParser(input string) parser { + buf := strings.NewReader(input) + p := parser{Dispenser: NewDispenser("Test", buf), checkDirectives: true} + return p +} diff --git a/core/restart.go b/core/restart.go new file mode 100644 index 000000000..82567d35c --- /dev/null +++ b/core/restart.go @@ -0,0 +1,166 @@ +// +build !windows + +package core + +import ( + "bytes" + "encoding/gob" + "errors" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path" + "sync/atomic" + + "github.com/miekg/coredns/core/https" +) + +func init() { + gob.Register(CaddyfileInput{}) +} + +// Restart restarts the entire application; gracefully with zero +// downtime if on a POSIX-compatible system, or forcefully if on +// Windows but with imperceptibly-short downtime. +// +// The restarted application will use newCaddyfile as its input +// configuration. If newCaddyfile is nil, the current (existing) +// Caddyfile configuration will be used. +// +// Note: The process must exist in the same place on the disk in +// order for this to work. Thus, multiple graceful restarts don't +// work if executing with `go run`, since the binary is cleaned up +// when `go run` sees the initial parent process exit. +func Restart(newCaddyfile Input) error { + log.Println("[INFO] Restarting") + + if newCaddyfile == nil { + caddyfileMu.Lock() + newCaddyfile = caddyfile + caddyfileMu.Unlock() + } + + // Get certificates for any new hosts in the new Caddyfile without causing downtime + err := getCertsForNewCaddyfile(newCaddyfile) + if err != nil { + return errors.New("TLS preload: " + err.Error()) + } + + if len(os.Args) == 0 { // this should never happen, but... + os.Args = []string{""} + } + + // Tell the child that it's a restart + os.Setenv("CADDY_RESTART", "true") + + // Prepare our payload to the child process + cdyfileGob := caddyfileGob{ + ListenerFds: make(map[string]uintptr), + Caddyfile: newCaddyfile, + OnDemandTLSCertsIssued: atomic.LoadInt32(https.OnDemandIssuedCount), + } + + // Prepare a pipe to the fork's stdin so it can get the Caddyfile + rpipe, wpipe, err := os.Pipe() + if err != nil { + return err + } + + // Prepare a pipe that the child process will use to communicate + // its success with us by sending > 0 bytes + sigrpipe, sigwpipe, err := os.Pipe() + if err != nil { + return err + } + + // Pass along relevant file descriptors to child process; ordering + // is very important since we rely on these being in certain positions. + extraFiles := []*os.File{sigwpipe} // fd 3 + + // Add file descriptors of all the sockets + serversMu.Lock() + for i, s := range servers { + extraFiles = append(extraFiles, s.ListenerFd()) + cdyfileGob.ListenerFds[s.Addr] = uintptr(4 + i) // 4 fds come before any of the listeners + } + serversMu.Unlock() + + // Set up the command + cmd := exec.Command(os.Args[0], os.Args[1:]...) + cmd.Stdin = rpipe // fd 0 + cmd.Stdout = os.Stdout // fd 1 + cmd.Stderr = os.Stderr // fd 2 + cmd.ExtraFiles = extraFiles + + // Spawn the child process + err = cmd.Start() + if err != nil { + return err + } + + // Immediately close our dup'ed fds and the write end of our signal pipe + for _, f := range extraFiles { + f.Close() + } + + // Feed Caddyfile to the child + err = gob.NewEncoder(wpipe).Encode(cdyfileGob) + if err != nil { + return err + } + wpipe.Close() + + // Determine whether child startup succeeded + answer, readErr := ioutil.ReadAll(sigrpipe) + if answer == nil || len(answer) == 0 { + cmdErr := cmd.Wait() // get exit status + log.Printf("[ERROR] Restart: child failed to initialize (%v) - changes not applied", cmdErr) + if readErr != nil { + log.Printf("[ERROR] Restart: additionally, error communicating with child process: %v", readErr) + } + return errIncompleteRestart + } + + // Looks like child is successful; we can exit gracefully. + return Stop() +} + +func getCertsForNewCaddyfile(newCaddyfile Input) error { + // parse the new caddyfile only up to (and including) TLS + // so we can know what we need to get certs for. + configs, _, _, err := loadConfigsUpToIncludingTLS(path.Base(newCaddyfile.Path()), bytes.NewReader(newCaddyfile.Body())) + if err != nil { + return errors.New("loading Caddyfile: " + err.Error()) + } + + // first mark the configs that are qualified for managed TLS + https.MarkQualified(configs) + + // since we group by bind address to obtain certs, we must call + // EnableTLS to make sure the port is set properly first + // (can ignore error since we aren't actually using the certs) + https.EnableTLS(configs, false) + + // find out if we can let the acme package start its own challenge listener + // on port 80 + var proxyACME bool + serversMu.Lock() + for _, s := range servers { + _, port, _ := net.SplitHostPort(s.Addr) + if port == "80" { + proxyACME = true + break + } + } + serversMu.Unlock() + + // place certs on the disk + err = https.ObtainCerts(configs, false, proxyACME) + if err != nil { + return errors.New("obtaining certs: " + err.Error()) + } + + return nil +} diff --git a/core/restart_windows.go b/core/restart_windows.go new file mode 100644 index 000000000..c2a4f557a --- /dev/null +++ b/core/restart_windows.go @@ -0,0 +1,31 @@ +package core + +import "log" + +// Restart restarts Caddy forcefully using newCaddyfile, +// or, if nil, the current/existing Caddyfile is reused. +func Restart(newCaddyfile Input) error { + log.Println("[INFO] Restarting") + + if newCaddyfile == nil { + caddyfileMu.Lock() + newCaddyfile = caddyfile + caddyfileMu.Unlock() + } + + wg.Add(1) // barrier so Wait() doesn't unblock + + err := Stop() + if err != nil { + return err + } + + err = Start(newCaddyfile) + if err != nil { + return err + } + + wg.Done() // take down our barrier + + return nil +} diff --git a/core/setup/bindhost.go b/core/setup/bindhost.go new file mode 100644 index 000000000..a3c07e5eb --- /dev/null +++ b/core/setup/bindhost.go @@ -0,0 +1,13 @@ +package setup + +import "github.com/miekg/coredns/middleware" + +// BindHost sets the host to bind the listener to. +func BindHost(c *Controller) (middleware.Middleware, error) { + for c.Next() { + if !c.Args(&c.BindHost) { + return nil, c.ArgErr() + } + } + return nil, nil +} diff --git a/core/setup/controller.go b/core/setup/controller.go new file mode 100644 index 000000000..1c1a93e64 --- /dev/null +++ b/core/setup/controller.go @@ -0,0 +1,83 @@ +package setup + +import ( + "fmt" + "strings" + + "github.com/miekg/coredns/core/parse" + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/server" + "github.com/miekg/dns" +) + +// Controller is given to the setup function of middlewares which +// gives them access to be able to read tokens and set config. Each +// virtualhost gets their own server config and dispenser. +type Controller struct { + *server.Config + parse.Dispenser + + // OncePerServerBlock is a function that executes f + // exactly once per server block, no matter how many + // hosts are associated with it. If it is the first + // time, the function f is executed immediately + // (not deferred) and may return an error which is + // returned by OncePerServerBlock. + OncePerServerBlock func(f func() error) error + + // ServerBlockIndex is the 0-based index of the + // server block as it appeared in the input. + ServerBlockIndex int + + // ServerBlockHostIndex is the 0-based index of this + // host as it appeared in the input at the head of the + // server block. + ServerBlockHostIndex int + + // ServerBlockHosts is a list of hosts that are + // associated with this server block. All these + // hosts, consequently, share the same tokens. + ServerBlockHosts []string + + // ServerBlockStorage is used by a directive's + // setup function to persist state between all + // the hosts on a server block. + ServerBlockStorage interface{} +} + +// NewTestController creates a new *Controller for +// the input specified, with a filename of "Testfile". +// The Config is bare, consisting only of a Root of cwd. +// +// Used primarily for testing but needs to be exported so +// add-ons can use this as a convenience. Does not initialize +// the server-block-related fields. +func NewTestController(input string) *Controller { + return &Controller{ + Config: &server.Config{ + Root: ".", + }, + Dispenser: parse.NewDispenser("Testfile", strings.NewReader(input)), + OncePerServerBlock: func(f func() error) error { + return f() + }, + } +} + +// EmptyNext is a no-op function that can be passed into +// middleware.Middleware functions so that the assignment +// to the Next field of the Handler can be tested. +// +// Used primarily for testing but needs to be exported so +// add-ons can use this as a convenience. +var EmptyNext = middleware.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) (int, error) { + return 0, nil +}) + +// SameNext does a pointer comparison between next1 and next2. +// +// Used primarily for testing but needs to be exported so +// add-ons can use this as a convenience. +func SameNext(next1, next2 middleware.Handler) bool { + return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2) +} diff --git a/core/setup/errors.go b/core/setup/errors.go new file mode 100644 index 000000000..0b392ec99 --- /dev/null +++ b/core/setup/errors.go @@ -0,0 +1,132 @@ +package setup + +import ( + "io" + "log" + "os" + + "github.com/hashicorp/go-syslog" + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/errors" +) + +// Errors configures a new errors middleware instance. +func Errors(c *Controller) (middleware.Middleware, error) { + handler, err := errorsParse(c) + if err != nil { + return nil, err + } + + // Open the log file for writing when the server starts + c.Startup = append(c.Startup, func() error { + var err error + var writer io.Writer + + switch handler.LogFile { + case "visible": + handler.Debug = true + case "stdout": + writer = os.Stdout + case "stderr": + writer = os.Stderr + case "syslog": + writer, err = gsyslog.NewLogger(gsyslog.LOG_ERR, "LOCAL0", "caddy") + if err != nil { + return err + } + default: + if handler.LogFile == "" { + writer = os.Stderr // default + break + } + + var file *os.File + file, err = os.OpenFile(handler.LogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + return err + } + if handler.LogRoller != nil { + file.Close() + + handler.LogRoller.Filename = handler.LogFile + + writer = handler.LogRoller.GetLogWriter() + } else { + writer = file + } + } + + handler.Log = log.New(writer, "", 0) + return nil + }) + + return func(next middleware.Handler) middleware.Handler { + handler.Next = next + return handler + }, nil +} + +func errorsParse(c *Controller) (*errors.ErrorHandler, error) { + // Very important that we make a pointer because the Startup + // function that opens the log file must have access to the + // same instance of the handler, not a copy. + handler := &errors.ErrorHandler{} + + optionalBlock := func() (bool, error) { + var hadBlock bool + + for c.NextBlock() { + hadBlock = true + + what := c.Val() + if !c.NextArg() { + return hadBlock, c.ArgErr() + } + where := c.Val() + + if what == "log" { + if where == "visible" { + handler.Debug = true + } else { + handler.LogFile = where + if c.NextArg() { + if c.Val() == "{" { + c.IncrNest() + logRoller, err := parseRoller(c) + if err != nil { + return hadBlock, err + } + handler.LogRoller = logRoller + } + } + } + } + } + return hadBlock, nil + } + + for c.Next() { + // weird hack to avoid having the handler values overwritten. + if c.Val() == "}" { + continue + } + // Configuration may be in a block + hadBlock, err := optionalBlock() + if err != nil { + return handler, err + } + + // Otherwise, the only argument would be an error log file name or 'visible' + if !hadBlock { + if c.NextArg() { + if c.Val() == "visible" { + handler.Debug = true + } else { + handler.LogFile = c.Val() + } + } + } + } + + return handler, nil +} diff --git a/core/setup/errors_test.go b/core/setup/errors_test.go new file mode 100644 index 000000000..4a079e0b5 --- /dev/null +++ b/core/setup/errors_test.go @@ -0,0 +1,158 @@ +package setup + +import ( + "testing" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/errors" +) + +func TestErrors(t *testing.T) { + c := NewTestController(`errors`) + mid, err := Errors(c) + + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + + if mid == nil { + t.Fatal("Expected middleware, was nil instead") + } + + handler := mid(EmptyNext) + myHandler, ok := handler.(*errors.ErrorHandler) + + if !ok { + t.Fatalf("Expected handler to be type ErrorHandler, got: %#v", handler) + } + + if myHandler.LogFile != "" { + t.Errorf("Expected '%s' as the default LogFile", "") + } + if myHandler.LogRoller != nil { + t.Errorf("Expected LogRoller to be nil, got: %v", *myHandler.LogRoller) + } + if !SameNext(myHandler.Next, EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } + + // Test Startup function + if len(c.Startup) == 0 { + t.Fatal("Expected 1 startup function, had 0") + } + err = c.Startup[0]() + if myHandler.Log == nil { + t.Error("Expected Log to be non-nil after startup because Debug is not enabled") + } +} + +func TestErrorsParse(t *testing.T) { + tests := []struct { + inputErrorsRules string + shouldErr bool + expectedErrorHandler errors.ErrorHandler + }{ + {`errors`, false, errors.ErrorHandler{ + LogFile: "", + }}, + {`errors errors.txt`, false, errors.ErrorHandler{ + LogFile: "errors.txt", + }}, + {`errors visible`, false, errors.ErrorHandler{ + LogFile: "", + Debug: true, + }}, + {`errors { log visible }`, false, errors.ErrorHandler{ + LogFile: "", + Debug: true, + }}, + {`errors { log errors.txt + 404 404.html + 500 500.html +}`, false, errors.ErrorHandler{ + LogFile: "errors.txt", + ErrorPages: map[int]string{ + 404: "404.html", + 500: "500.html", + }, + }}, + {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, errors.ErrorHandler{ + LogFile: "errors.txt", + LogRoller: &middleware.LogRoller{ + MaxSize: 2, + MaxAge: 10, + MaxBackups: 3, + LocalTime: true, + }, + }}, + {`errors { log errors.txt { + size 3 + age 11 + keep 5 + } + 404 404.html + 503 503.html +}`, false, errors.ErrorHandler{ + LogFile: "errors.txt", + ErrorPages: map[int]string{ + 404: "404.html", + 503: "503.html", + }, + LogRoller: &middleware.LogRoller{ + MaxSize: 3, + MaxAge: 11, + MaxBackups: 5, + LocalTime: true, + }, + }}, + } + for i, test := range tests { + c := NewTestController(test.inputErrorsRules) + actualErrorsRule, err := errorsParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } + if actualErrorsRule.LogFile != test.expectedErrorHandler.LogFile { + t.Errorf("Test %d expected LogFile to be %s, but got %s", + i, test.expectedErrorHandler.LogFile, actualErrorsRule.LogFile) + } + if actualErrorsRule.Debug != test.expectedErrorHandler.Debug { + t.Errorf("Test %d expected Debug to be %v, but got %v", + i, test.expectedErrorHandler.Debug, actualErrorsRule.Debug) + } + if actualErrorsRule.LogRoller != nil && test.expectedErrorHandler.LogRoller == nil || actualErrorsRule.LogRoller == nil && test.expectedErrorHandler.LogRoller != nil { + t.Fatalf("Test %d expected LogRoller to be %v, but got %v", + i, test.expectedErrorHandler.LogRoller, actualErrorsRule.LogRoller) + } + if len(actualErrorsRule.ErrorPages) != len(test.expectedErrorHandler.ErrorPages) { + t.Fatalf("Test %d expected %d no of Error pages, but got %d ", + i, len(test.expectedErrorHandler.ErrorPages), len(actualErrorsRule.ErrorPages)) + } + if actualErrorsRule.LogRoller != nil && test.expectedErrorHandler.LogRoller != nil { + if actualErrorsRule.LogRoller.Filename != test.expectedErrorHandler.LogRoller.Filename { + t.Fatalf("Test %d expected LogRoller Filename to be %s, but got %s", + i, test.expectedErrorHandler.LogRoller.Filename, actualErrorsRule.LogRoller.Filename) + } + if actualErrorsRule.LogRoller.MaxAge != test.expectedErrorHandler.LogRoller.MaxAge { + t.Fatalf("Test %d expected LogRoller MaxAge to be %d, but got %d", + i, test.expectedErrorHandler.LogRoller.MaxAge, actualErrorsRule.LogRoller.MaxAge) + } + if actualErrorsRule.LogRoller.MaxBackups != test.expectedErrorHandler.LogRoller.MaxBackups { + t.Fatalf("Test %d expected LogRoller MaxBackups to be %d, but got %d", + i, test.expectedErrorHandler.LogRoller.MaxBackups, actualErrorsRule.LogRoller.MaxBackups) + } + if actualErrorsRule.LogRoller.MaxSize != test.expectedErrorHandler.LogRoller.MaxSize { + t.Fatalf("Test %d expected LogRoller MaxSize to be %d, but got %d", + i, test.expectedErrorHandler.LogRoller.MaxSize, actualErrorsRule.LogRoller.MaxSize) + } + if actualErrorsRule.LogRoller.LocalTime != test.expectedErrorHandler.LogRoller.LocalTime { + t.Fatalf("Test %d expected LogRoller LocalTime to be %t, but got %t", + i, test.expectedErrorHandler.LogRoller.LocalTime, actualErrorsRule.LogRoller.LocalTime) + } + } + } + +} diff --git a/core/setup/file.go b/core/setup/file.go new file mode 100644 index 000000000..76aed5249 --- /dev/null +++ b/core/setup/file.go @@ -0,0 +1,73 @@ +package setup + +import ( + "log" + "os" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/file" + "github.com/miekg/dns" +) + +// File sets up the file middleware. +func File(c *Controller) (middleware.Middleware, error) { + zones, err := fileParse(c) + if err != nil { + return nil, err + } + return func(next middleware.Handler) middleware.Handler { + return file.File{Next: next, Zones: zones} + }, nil + +} + +func fileParse(c *Controller) (file.Zones, error) { + // Maybe multiple, each for each zone. + z := make(map[string]file.Zone) + names := []string{} + for c.Next() { + if c.Val() == "file" { + // file db.file [origin] + if !c.NextArg() { + return file.Zones{}, c.ArgErr() + } + fileName := c.Val() + + origin := c.ServerBlockHosts[c.ServerBlockHostIndex] + if c.NextArg() { + c.Next() + origin = c.Val() + } + // normalize this origin + origin = middleware.Host(origin).StandardHost() + + zone, err := parseZone(origin, fileName) + if err == nil { + z[origin] = zone + } + names = append(names, origin) + } + } + return file.Zones{Z: z, Names: names}, nil +} + +// +// parsrZone parses the zone in filename and returns a []RR or an error. +func parseZone(origin, fileName string) (file.Zone, error) { + f, err := os.Open(fileName) + if err != nil { + return nil, err + } + tokens := dns.ParseZone(f, origin, fileName) + zone := make([]dns.RR, 0, defaultZoneSize) + for x := range tokens { + if x.Error != nil { + log.Printf("[ERROR] failed to parse %s: %v", origin, x.Error) + return nil, x.Error + } + zone = append(zone, x.RR) + } + return file.Zone(zone), nil +} + +const defaultZoneSize = 20 // A made up number. diff --git a/core/setup/log.go b/core/setup/log.go new file mode 100644 index 000000000..32d9f3250 --- /dev/null +++ b/core/setup/log.go @@ -0,0 +1,130 @@ +package setup + +import ( + "io" + "log" + "os" + + "github.com/hashicorp/go-syslog" + "github.com/miekg/coredns/middleware" + caddylog "github.com/miekg/coredns/middleware/log" + "github.com/miekg/coredns/server" +) + +// Log sets up the logging middleware. +func Log(c *Controller) (middleware.Middleware, error) { + rules, err := logParse(c) + if err != nil { + return nil, err + } + + // Open the log files for writing when the server starts + c.Startup = append(c.Startup, func() error { + for i := 0; i < len(rules); i++ { + var err error + var writer io.Writer + + if rules[i].OutputFile == "stdout" { + writer = os.Stdout + } else if rules[i].OutputFile == "stderr" { + writer = os.Stderr + } else if rules[i].OutputFile == "syslog" { + writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "LOCAL0", "caddy") + if err != nil { + return err + } + } else { + var file *os.File + file, err = os.OpenFile(rules[i].OutputFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644) + if err != nil { + return err + } + if rules[i].Roller != nil { + file.Close() + rules[i].Roller.Filename = rules[i].OutputFile + writer = rules[i].Roller.GetLogWriter() + } else { + writer = file + } + } + + rules[i].Log = log.New(writer, "", 0) + } + + return nil + }) + + return func(next middleware.Handler) middleware.Handler { + return caddylog.Logger{Next: next, Rules: rules, ErrorFunc: server.DefaultErrorFunc} + }, nil +} + +func logParse(c *Controller) ([]caddylog.Rule, error) { + var rules []caddylog.Rule + + for c.Next() { + args := c.RemainingArgs() + + var logRoller *middleware.LogRoller + if c.NextBlock() { + if c.Val() == "rotate" { + if c.NextArg() { + if c.Val() == "{" { + var err error + logRoller, err = parseRoller(c) + if err != nil { + return nil, err + } + // This part doesn't allow having something after the rotate block + if c.Next() { + if c.Val() != "}" { + return nil, c.ArgErr() + } + } + } + } + } + } + if len(args) == 0 { + // Nothing specified; use defaults + rules = append(rules, caddylog.Rule{ + PathScope: "/", + OutputFile: caddylog.DefaultLogFilename, + Format: caddylog.DefaultLogFormat, + Roller: logRoller, + }) + } else if len(args) == 1 { + // Only an output file specified + rules = append(rules, caddylog.Rule{ + PathScope: "/", + OutputFile: args[0], + Format: caddylog.DefaultLogFormat, + Roller: logRoller, + }) + } else { + // Path scope, output file, and maybe a format specified + + format := caddylog.DefaultLogFormat + + if len(args) > 2 { + switch args[2] { + case "{common}": + format = caddylog.CommonLogFormat + case "{combined}": + format = caddylog.CombinedLogFormat + default: + format = args[2] + } + } + + rules = append(rules, caddylog.Rule{ + PathScope: args[0], + OutputFile: args[1], + Format: format, + Roller: logRoller, + }) + } + } + + return rules, nil +} diff --git a/core/setup/log_test.go b/core/setup/log_test.go new file mode 100644 index 000000000..2bfcb4e89 --- /dev/null +++ b/core/setup/log_test.go @@ -0,0 +1,175 @@ +package setup + +import ( + "testing" + + "github.com/miekg/coredns/middleware" + caddylog "github.com/miekg/coredns/middleware/log" +) + +func TestLog(t *testing.T) { + + c := NewTestController(`log`) + + mid, err := Log(c) + + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + + if mid == nil { + t.Fatal("Expected middleware, was nil instead") + } + + handler := mid(EmptyNext) + myHandler, ok := handler.(caddylog.Logger) + + if !ok { + t.Fatalf("Expected handler to be type Logger, got: %#v", handler) + } + + if myHandler.Rules[0].PathScope != "/" { + t.Errorf("Expected / as the default PathScope") + } + if myHandler.Rules[0].OutputFile != caddylog.DefaultLogFilename { + t.Errorf("Expected %s as the default OutputFile", caddylog.DefaultLogFilename) + } + if myHandler.Rules[0].Format != caddylog.DefaultLogFormat { + t.Errorf("Expected %s as the default Log Format", caddylog.DefaultLogFormat) + } + if myHandler.Rules[0].Roller != nil { + t.Errorf("Expected Roller to be nil, got: %v", *myHandler.Rules[0].Roller) + } + if !SameNext(myHandler.Next, EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } + +} + +func TestLogParse(t *testing.T) { + tests := []struct { + inputLogRules string + shouldErr bool + expectedLogRules []caddylog.Rule + }{ + {`log`, false, []caddylog.Rule{{ + PathScope: "/", + OutputFile: caddylog.DefaultLogFilename, + Format: caddylog.DefaultLogFormat, + }}}, + {`log log.txt`, false, []caddylog.Rule{{ + PathScope: "/", + OutputFile: "log.txt", + Format: caddylog.DefaultLogFormat, + }}}, + {`log /api log.txt`, false, []caddylog.Rule{{ + PathScope: "/api", + OutputFile: "log.txt", + Format: caddylog.DefaultLogFormat, + }}}, + {`log /serve stdout`, false, []caddylog.Rule{{ + PathScope: "/serve", + OutputFile: "stdout", + Format: caddylog.DefaultLogFormat, + }}}, + {`log /myapi log.txt {common}`, false, []caddylog.Rule{{ + PathScope: "/myapi", + OutputFile: "log.txt", + Format: caddylog.CommonLogFormat, + }}}, + {`log /test accesslog.txt {combined}`, false, []caddylog.Rule{{ + PathScope: "/test", + OutputFile: "accesslog.txt", + Format: caddylog.CombinedLogFormat, + }}}, + {`log /api1 log.txt + log /api2 accesslog.txt {combined}`, false, []caddylog.Rule{{ + PathScope: "/api1", + OutputFile: "log.txt", + Format: caddylog.DefaultLogFormat, + }, { + PathScope: "/api2", + OutputFile: "accesslog.txt", + Format: caddylog.CombinedLogFormat, + }}}, + {`log /api3 stdout {host} + log /api4 log.txt {when}`, false, []caddylog.Rule{{ + PathScope: "/api3", + OutputFile: "stdout", + Format: "{host}", + }, { + PathScope: "/api4", + OutputFile: "log.txt", + Format: "{when}", + }}}, + {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []caddylog.Rule{{ + PathScope: "/", + OutputFile: "access.log", + Format: caddylog.DefaultLogFormat, + Roller: &middleware.LogRoller{ + MaxSize: 2, + MaxAge: 10, + MaxBackups: 3, + LocalTime: true, + }, + }}}, + } + for i, test := range tests { + c := NewTestController(test.inputLogRules) + actualLogRules, err := logParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } + if len(actualLogRules) != len(test.expectedLogRules) { + t.Fatalf("Test %d expected %d no of Log rules, but got %d ", + i, len(test.expectedLogRules), len(actualLogRules)) + } + for j, actualLogRule := range actualLogRules { + + if actualLogRule.PathScope != test.expectedLogRules[j].PathScope { + t.Errorf("Test %d expected %dth LogRule PathScope to be %s , but got %s", + i, j, test.expectedLogRules[j].PathScope, actualLogRule.PathScope) + } + + if actualLogRule.OutputFile != test.expectedLogRules[j].OutputFile { + t.Errorf("Test %d expected %dth LogRule OutputFile to be %s , but got %s", + i, j, test.expectedLogRules[j].OutputFile, actualLogRule.OutputFile) + } + + if actualLogRule.Format != test.expectedLogRules[j].Format { + t.Errorf("Test %d expected %dth LogRule Format to be %s , but got %s", + i, j, test.expectedLogRules[j].Format, actualLogRule.Format) + } + if actualLogRule.Roller != nil && test.expectedLogRules[j].Roller == nil || actualLogRule.Roller == nil && test.expectedLogRules[j].Roller != nil { + t.Fatalf("Test %d expected %dth LogRule Roller to be %v, but got %v", + i, j, test.expectedLogRules[j].Roller, actualLogRule.Roller) + } + if actualLogRule.Roller != nil && test.expectedLogRules[j].Roller != nil { + if actualLogRule.Roller.Filename != test.expectedLogRules[j].Roller.Filename { + t.Fatalf("Test %d expected %dth LogRule Roller Filename to be %s, but got %s", + i, j, test.expectedLogRules[j].Roller.Filename, actualLogRule.Roller.Filename) + } + if actualLogRule.Roller.MaxAge != test.expectedLogRules[j].Roller.MaxAge { + t.Fatalf("Test %d expected %dth LogRule Roller MaxAge to be %d, but got %d", + i, j, test.expectedLogRules[j].Roller.MaxAge, actualLogRule.Roller.MaxAge) + } + if actualLogRule.Roller.MaxBackups != test.expectedLogRules[j].Roller.MaxBackups { + t.Fatalf("Test %d expected %dth LogRule Roller MaxBackups to be %d, but got %d", + i, j, test.expectedLogRules[j].Roller.MaxBackups, actualLogRule.Roller.MaxBackups) + } + if actualLogRule.Roller.MaxSize != test.expectedLogRules[j].Roller.MaxSize { + t.Fatalf("Test %d expected %dth LogRule Roller MaxSize to be %d, but got %d", + i, j, test.expectedLogRules[j].Roller.MaxSize, actualLogRule.Roller.MaxSize) + } + if actualLogRule.Roller.LocalTime != test.expectedLogRules[j].Roller.LocalTime { + t.Fatalf("Test %d expected %dth LogRule Roller LocalTime to be %t, but got %t", + i, j, test.expectedLogRules[j].Roller.LocalTime, actualLogRule.Roller.LocalTime) + } + } + } + } + +} diff --git a/core/setup/prometheus.go b/core/setup/prometheus.go new file mode 100644 index 000000000..3bcb907ce --- /dev/null +++ b/core/setup/prometheus.go @@ -0,0 +1,70 @@ +package setup + +import ( + "sync" + + "github.com/miekg/coredns/middleware" + prom "github.com/miekg/coredns/middleware/prometheus" +) + +const ( + path = "/metrics" + addr = "localhost:9153" +) + +var once sync.Once + +func Prometheus(c *Controller) (middleware.Middleware, error) { + metrics, err := parsePrometheus(c) + if err != nil { + return nil, err + } + if metrics.Addr == "" { + metrics.Addr = addr + } + once.Do(func() { + c.Startup = append(c.Startup, metrics.Start) + }) + + return func(next middleware.Handler) middleware.Handler { + metrics.Next = next + return metrics + }, nil +} + +func parsePrometheus(c *Controller) (*prom.Metrics, error) { + var ( + metrics *prom.Metrics + err error + ) + + for c.Next() { + if metrics != nil { + return nil, c.Err("prometheus: can only have one metrics module per server") + } + metrics = &prom.Metrics{ZoneNames: c.ServerBlockHosts} + args := c.RemainingArgs() + + switch len(args) { + case 0: + case 1: + metrics.Addr = args[0] + default: + return nil, c.ArgErr() + } + for c.NextBlock() { + switch c.Val() { + case "address": + args = c.RemainingArgs() + if len(args) != 1 { + return nil, c.ArgErr() + } + metrics.Addr = args[0] + default: + return nil, c.Errf("prometheus: unknown item: %s", c.Val()) + } + + } + } + return metrics, err +} diff --git a/core/setup/proxy.go b/core/setup/proxy.go new file mode 100644 index 000000000..6753d07ad --- /dev/null +++ b/core/setup/proxy.go @@ -0,0 +1,17 @@ +package setup + +import ( + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/proxy" +) + +// Proxy configures a new Proxy middleware instance. +func Proxy(c *Controller) (middleware.Middleware, error) { + upstreams, err := proxy.NewStaticUpstreams(c.Dispenser) + if err != nil { + return nil, err + } + return func(next middleware.Handler) middleware.Handler { + return proxy.Proxy{Next: next, Client: proxy.Clients(), Upstreams: upstreams} + }, nil +} diff --git a/core/setup/reflect.go b/core/setup/reflect.go new file mode 100644 index 000000000..9ae1d5181 --- /dev/null +++ b/core/setup/reflect.go @@ -0,0 +1,28 @@ +package setup + +import ( + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/reflect" +) + +// Reflect sets up the reflect middleware. +func Reflect(c *Controller) (middleware.Middleware, error) { + if err := reflectParse(c); err != nil { + return nil, err + } + return func(next middleware.Handler) middleware.Handler { + return reflect.Reflect{Next: next} + }, nil + +} + +func reflectParse(c *Controller) error { + for c.Next() { + if c.Val() == "reflect" { + if c.NextArg() { + return c.ArgErr() + } + } + } + return nil +} diff --git a/core/setup/rewrite.go b/core/setup/rewrite.go new file mode 100644 index 000000000..32f5f42a3 --- /dev/null +++ b/core/setup/rewrite.go @@ -0,0 +1,109 @@ +package setup + +import ( + "strconv" + "strings" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/coredns/middleware/rewrite" +) + +// Rewrite configures a new Rewrite middleware instance. +func Rewrite(c *Controller) (middleware.Middleware, error) { + rewrites, err := rewriteParse(c) + if err != nil { + return nil, err + } + + return func(next middleware.Handler) middleware.Handler { + return rewrite.Rewrite{ + Next: next, + Rules: rewrites, + } + }, nil +} + +func rewriteParse(c *Controller) ([]rewrite.Rule, error) { + var simpleRules []rewrite.Rule + var regexpRules []rewrite.Rule + + for c.Next() { + var rule rewrite.Rule + var err error + var base = "/" + var pattern, to string + var status int + var ext []string + + args := c.RemainingArgs() + + var ifs []rewrite.If + + switch len(args) { + case 1: + base = args[0] + fallthrough + case 0: + for c.NextBlock() { + switch c.Val() { + case "r", "regexp": + if !c.NextArg() { + return nil, c.ArgErr() + } + pattern = c.Val() + case "to": + args1 := c.RemainingArgs() + if len(args1) == 0 { + return nil, c.ArgErr() + } + to = strings.Join(args1, " ") + case "ext": + args1 := c.RemainingArgs() + if len(args1) == 0 { + return nil, c.ArgErr() + } + ext = args1 + case "if": + args1 := c.RemainingArgs() + if len(args1) != 3 { + return nil, c.ArgErr() + } + ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2]) + if err != nil { + return nil, err + } + ifs = append(ifs, ifCond) + case "status": + if !c.NextArg() { + return nil, c.ArgErr() + } + status, _ = strconv.Atoi(c.Val()) + if status < 200 || (status > 299 && status < 400) || status > 499 { + return nil, c.Err("status must be 2xx or 4xx") + } + default: + return nil, c.ArgErr() + } + } + // ensure to or status is specified + if to == "" && status == 0 { + return nil, c.ArgErr() + } + // TODO(miek): complex rules + base, pattern, to, status, ext, ifs = base, pattern, to, status, ext, ifs + err = err + // if rule, err = rewrite.NewComplexRule(base, pattern, to, status, ext, ifs); err != nil { + // return nil, err + // } + regexpRules = append(regexpRules, rule) + + // the only unhandled case is 2 and above + default: + rule = rewrite.NewSimpleRule(args[0], strings.Join(args[1:], " ")) + simpleRules = append(simpleRules, rule) + } + } + + // put simple rules in front to avoid regexp computation for them + return append(simpleRules, regexpRules...), nil +} diff --git a/core/setup/rewrite_test.go b/core/setup/rewrite_test.go new file mode 100644 index 000000000..747618305 --- /dev/null +++ b/core/setup/rewrite_test.go @@ -0,0 +1,241 @@ +package setup + +import ( + "fmt" + "regexp" + "testing" + + "github.com/miekg/coredns/middleware/rewrite" +) + +func TestRewrite(t *testing.T) { + c := NewTestController(`rewrite /from /to`) + + mid, err := Rewrite(c) + if err != nil { + t.Errorf("Expected no errors, but got: %v", err) + } + if mid == nil { + t.Fatal("Expected middleware, was nil instead") + } + + handler := mid(EmptyNext) + myHandler, ok := handler.(rewrite.Rewrite) + if !ok { + t.Fatalf("Expected handler to be type Rewrite, got: %#v", handler) + } + + if !SameNext(myHandler.Next, EmptyNext) { + t.Error("'Next' field of handler was not set properly") + } + + if len(myHandler.Rules) != 1 { + t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(myHandler.Rules)) + } +} + +func TestRewriteParse(t *testing.T) { + simpleTests := []struct { + input string + shouldErr bool + expected []rewrite.Rule + }{ + {`rewrite /from /to`, false, []rewrite.Rule{ + rewrite.SimpleRule{From: "/from", To: "/to"}, + }}, + {`rewrite /from /to + rewrite a b`, false, []rewrite.Rule{ + rewrite.SimpleRule{From: "/from", To: "/to"}, + rewrite.SimpleRule{From: "a", To: "b"}, + }}, + {`rewrite a`, true, []rewrite.Rule{}}, + {`rewrite`, true, []rewrite.Rule{}}, + {`rewrite a b c`, false, []rewrite.Rule{ + rewrite.SimpleRule{From: "a", To: "b c"}, + }}, + } + + for i, test := range simpleTests { + c := NewTestController(test.input) + actual, err := rewriteParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, e := range test.expected { + actualRule := actual[j].(rewrite.SimpleRule) + expectedRule := e.(rewrite.SimpleRule) + + if actualRule.From != expectedRule.From { + t.Errorf("Test %d, rule %d: Expected From=%s, got %s", + i, j, expectedRule.From, actualRule.From) + } + + if actualRule.To != expectedRule.To { + t.Errorf("Test %d, rule %d: Expected To=%s, got %s", + i, j, expectedRule.To, actualRule.To) + } + } + } + + regexpTests := []struct { + input string + shouldErr bool + expected []rewrite.Rule + }{ + {`rewrite { + r .* + to /to /index.php? + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")}, + }}, + {`rewrite { + regexp .* + to /to + ext / html txt + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")}, + }}, + {`rewrite /path { + r rr + to /dest + } + rewrite / { + regexp [a-z]+ + to /to /to2 + } + `, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")}, + &rewrite.ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")}, + }}, + {`rewrite { + r .* + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite /`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + to /to + if {path} is a + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{{A: "{path}", Operator: "is", B: "a"}}}, + }}, + {`rewrite { + status 500 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + status 400 + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", Status: 400}, + }}, + {`rewrite { + to /to + status 400 + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", To: "/to", Status: 400}, + }}, + {`rewrite { + status 399 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + status 200 + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", Status: 200}, + }}, + {`rewrite { + to /to + status 200 + }`, false, []rewrite.Rule{ + &rewrite.ComplexRule{Base: "/", To: "/to", Status: 200}, + }}, + {`rewrite { + status 199 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + status 0 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + {`rewrite { + to /to + status 0 + }`, true, []rewrite.Rule{ + &rewrite.ComplexRule{}, + }}, + } + + for i, test := range regexpTests { + c := NewTestController(test.input) + actual, err := rewriteParse(c) + + if err == nil && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if err != nil && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err) + } else if err != nil && test.shouldErr { + continue + } + + if len(actual) != len(test.expected) { + t.Fatalf("Test %d expected %d rules, but got %d", + i, len(test.expected), len(actual)) + } + + for j, e := range test.expected { + actualRule := actual[j].(*rewrite.ComplexRule) + expectedRule := e.(*rewrite.ComplexRule) + + if actualRule.Base != expectedRule.Base { + t.Errorf("Test %d, rule %d: Expected Base=%s, got %s", + i, j, expectedRule.Base, actualRule.Base) + } + + if actualRule.To != expectedRule.To { + t.Errorf("Test %d, rule %d: Expected To=%s, got %s", + i, j, expectedRule.To, actualRule.To) + } + + if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) { + t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v", + i, j, expectedRule.To, actualRule.To) + } + + if actualRule.Regexp != nil { + if actualRule.String() != expectedRule.String() { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, expectedRule.String(), actualRule.String()) + } + } + + if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) { + t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s", + i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs)) + } + + } + } + +} diff --git a/core/setup/roller.go b/core/setup/roller.go new file mode 100644 index 000000000..fd772cc47 --- /dev/null +++ b/core/setup/roller.go @@ -0,0 +1,40 @@ +package setup + +import ( + "strconv" + + "github.com/miekg/coredns/middleware" +) + +func parseRoller(c *Controller) (*middleware.LogRoller, error) { + var size, age, keep int + // This is kind of a hack to support nested blocks: + // As we are already in a block: either log or errors, + // c.nesting > 0 but, as soon as c meets a }, it thinks + // the block is over and return false for c.NextBlock. + for c.NextBlock() { + what := c.Val() + if !c.NextArg() { + return nil, c.ArgErr() + } + value := c.Val() + var err error + switch what { + case "size": + size, err = strconv.Atoi(value) + case "age": + age, err = strconv.Atoi(value) + case "keep": + keep, err = strconv.Atoi(value) + } + if err != nil { + return nil, err + } + } + return &middleware.LogRoller{ + MaxSize: size, + MaxAge: age, + MaxBackups: keep, + LocalTime: true, + }, nil +} diff --git a/core/setup/root.go b/core/setup/root.go new file mode 100644 index 000000000..0fce5f170 --- /dev/null +++ b/core/setup/root.go @@ -0,0 +1,32 @@ +package setup + +import ( + "log" + "os" + + "github.com/miekg/coredns/middleware" +) + +// Root sets up the root file path of the server. +func Root(c *Controller) (middleware.Middleware, error) { + for c.Next() { + if !c.NextArg() { + return nil, c.ArgErr() + } + c.Root = c.Val() + } + + // Check if root path exists + _, err := os.Stat(c.Root) + if err != nil { + if os.IsNotExist(err) { + // Allow this, because the folder might appear later. + // But make sure the user knows! + log.Printf("[WARNING] Root path does not exist: %s", c.Root) + } else { + return nil, c.Errf("Unable to access root path '%s': %v", c.Root, err) + } + } + + return nil, nil +} diff --git a/core/setup/root_test.go b/core/setup/root_test.go new file mode 100644 index 000000000..8b38e6d04 --- /dev/null +++ b/core/setup/root_test.go @@ -0,0 +1,108 @@ +package setup + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRoot(t *testing.T) { + + // Predefined error substrings + parseErrContent := "Parse error:" + unableToAccessErrContent := "Unable to access root path" + + existingDirPath, err := getTempDirPath() + if err != nil { + t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) + } + + nonExistingDir := filepath.Join(existingDirPath, "highly_unlikely_to_exist_dir") + + existingFile, err := ioutil.TempFile("", "root_test") + if err != nil { + t.Fatalf("BeforeTest: Failed to create temp file for testing! Error was: %v", err) + } + defer func() { + existingFile.Close() + os.Remove(existingFile.Name()) + }() + + inaccessiblePath := getInaccessiblePath(existingFile.Name()) + + tests := []struct { + input string + shouldErr bool + expectedRoot string // expected root, set to the controller. Empty for negative cases. + expectedErrContent string // substring from the expected error. Empty for positive cases. + }{ + // positive + { + fmt.Sprintf(`root %s`, nonExistingDir), false, nonExistingDir, "", + }, + { + fmt.Sprintf(`root %s`, existingDirPath), false, existingDirPath, "", + }, + // negative + { + `root `, true, "", parseErrContent, + }, + { + fmt.Sprintf(`root %s`, inaccessiblePath), true, "", unableToAccessErrContent, + }, + { + fmt.Sprintf(`root { + %s + }`, existingDirPath), true, "", parseErrContent, + }, + } + + for i, test := range tests { + c := NewTestController(test.input) + mid, err := Root(c) + + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error but found %s for input %s", i, err, test.input) + } + + if err != nil { + if !test.shouldErr { + t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err) + } + + if !strings.Contains(err.Error(), test.expectedErrContent) { + t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input) + } + } + + // the Root method always returns a nil middleware + if mid != nil { + t.Errorf("Middware, returned from Root() was not nil: %v", mid) + } + + // check c.Root only if we are in a positive test. + if !test.shouldErr && test.expectedRoot != c.Root { + t.Errorf("Root not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedRoot, c.Root) + } + } +} + +// getTempDirPath returnes the path to the system temp directory. If it does not exists - an error is returned. +func getTempDirPath() (string, error) { + tempDir := os.TempDir() + + _, err := os.Stat(tempDir) + if err != nil { + return "", err + } + + return tempDir, nil +} + +func getInaccessiblePath(file string) string { + // null byte in filename is not allowed on Windows AND unix + return filepath.Join("C:", "file\x00name") +} diff --git a/core/setup/startupshutdown.go b/core/setup/startupshutdown.go new file mode 100644 index 000000000..1cf2c62e0 --- /dev/null +++ b/core/setup/startupshutdown.go @@ -0,0 +1,64 @@ +package setup + +import ( + "os" + "os/exec" + "strings" + + "github.com/miekg/coredns/middleware" +) + +// Startup registers a startup callback to execute during server start. +func Startup(c *Controller) (middleware.Middleware, error) { + return nil, registerCallback(c, &c.FirstStartup) +} + +// Shutdown registers a shutdown callback to execute during process exit. +func Shutdown(c *Controller) (middleware.Middleware, error) { + return nil, registerCallback(c, &c.Shutdown) +} + +// registerCallback registers a callback function to execute by +// using c to parse the line. It appends the callback function +// to the list of callback functions passed in by reference. +func registerCallback(c *Controller, list *[]func() error) error { + var funcs []func() error + + for c.Next() { + args := c.RemainingArgs() + if len(args) == 0 { + return c.ArgErr() + } + + nonblock := false + if len(args) > 1 && args[len(args)-1] == "&" { + // Run command in background; non-blocking + nonblock = true + args = args[:len(args)-1] + } + + command, args, err := middleware.SplitCommandAndArgs(strings.Join(args, " ")) + if err != nil { + return c.Err(err.Error()) + } + + fn := func() error { + cmd := exec.Command(command, args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if nonblock { + return cmd.Start() + } + return cmd.Run() + } + + funcs = append(funcs, fn) + } + + return c.OncePerServerBlock(func() error { + *list = append(*list, funcs...) + return nil + }) +} diff --git a/core/setup/startupshutdown_test.go b/core/setup/startupshutdown_test.go new file mode 100644 index 000000000..871a64214 --- /dev/null +++ b/core/setup/startupshutdown_test.go @@ -0,0 +1,59 @@ +package setup + +import ( + "os" + "path/filepath" + "strconv" + "testing" + "time" +) + +// The Startup function's tests are symmetrical to Shutdown tests, +// because the Startup and Shutdown functions share virtually the +// same functionality +func TestStartup(t *testing.T) { + tempDirPath, err := getTempDirPath() + if err != nil { + t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err) + } + + testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown") + defer func() { + // clean up after non-blocking startup function quits + time.Sleep(500 * time.Millisecond) + os.RemoveAll(testDir) + }() + osSenitiveTestDir := filepath.FromSlash(testDir) + os.RemoveAll(osSenitiveTestDir) // start with a clean slate + + tests := []struct { + input string + shouldExecutionErr bool + shouldRemoveErr bool + }{ + // test case #0 tests proper functionality blocking commands + {"startup mkdir " + osSenitiveTestDir, false, false}, + + // test case #1 tests proper functionality of non-blocking commands + {"startup mkdir " + osSenitiveTestDir + " &", false, true}, + + // test case #2 tests handling of non-existent commands + {"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true}, + } + + for i, test := range tests { + c := NewTestController(test.input) + _, err = Startup(c) + if err != nil { + t.Errorf("Expected no errors, got: %v", err) + } + err = c.FirstStartup[0]() + if err != nil && !test.shouldExecutionErr { + t.Errorf("Test %d recieved an error of:\n%v", i, err) + } + err = os.Remove(osSenitiveTestDir) + if err != nil && !test.shouldRemoveErr { + t.Errorf("Test %d recieved an error of:\n%v", i, err) + } + } +} diff --git a/core/setup/testdata/blog/first_post.md b/core/setup/testdata/blog/first_post.md new file mode 100644 index 000000000..f26583b75 --- /dev/null +++ b/core/setup/testdata/blog/first_post.md @@ -0,0 +1 @@ +# Test h1 diff --git a/core/setup/testdata/header.html b/core/setup/testdata/header.html new file mode 100644 index 000000000..9c96e0e37 --- /dev/null +++ b/core/setup/testdata/header.html @@ -0,0 +1 @@ +<h1>Header title</h1> diff --git a/core/setup/testdata/tpl_with_include.html b/core/setup/testdata/tpl_with_include.html new file mode 100644 index 000000000..95eeae0c8 --- /dev/null +++ b/core/setup/testdata/tpl_with_include.html @@ -0,0 +1,10 @@ +<!DOCTYPE html> +<html> +<head> +<title>{{.Doc.title}}</title> +</head> +<body> +{{.Include "header.html"}} +{{.Doc.body}} +</body> +</html> diff --git a/core/sigtrap.go b/core/sigtrap.go new file mode 100644 index 000000000..3b74efb02 --- /dev/null +++ b/core/sigtrap.go @@ -0,0 +1,71 @@ +package core + +import ( + "log" + "os" + "os/signal" + "sync" + + "github.com/miekg/coredns/server" +) + +// TrapSignals create signal handlers for all applicable signals for this +// system. If your Go program uses signals, this is a rather invasive +// function; best to implement them yourself in that case. Signals are not +// required for the caddy package to function properly, but this is a +// convenient way to allow the user to control this package of your program. +func TrapSignals() { + trapSignalsCrossPlatform() + trapSignalsPosix() +} + +// trapSignalsCrossPlatform captures SIGINT, which triggers forceful +// shutdown that executes shutdown callbacks first. A second interrupt +// signal will exit the process immediately. +func trapSignalsCrossPlatform() { + go func() { + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt) + + for i := 0; true; i++ { + <-shutdown + + if i > 0 { + log.Println("[INFO] SIGINT: Force quit") + if PidFile != "" { + os.Remove(PidFile) + } + os.Exit(1) + } + + log.Println("[INFO] SIGINT: Shutting down") + + if PidFile != "" { + os.Remove(PidFile) + } + + go os.Exit(executeShutdownCallbacks("SIGINT")) + } + }() +} + +// executeShutdownCallbacks executes the shutdown callbacks as initiated +// by signame. It logs any errors and returns the recommended exit status. +// This function is idempotent; subsequent invocations always return 0. +func executeShutdownCallbacks(signame string) (exitCode int) { + shutdownCallbacksOnce.Do(func() { + serversMu.Lock() + errs := server.ShutdownCallbacks(servers) + serversMu.Unlock() + + if len(errs) > 0 { + for _, err := range errs { + log.Printf("[ERROR] %s shutdown: %v", signame, err) + } + exitCode = 1 + } + }) + return +} + +var shutdownCallbacksOnce sync.Once diff --git a/core/sigtrap_posix.go b/core/sigtrap_posix.go new file mode 100644 index 000000000..ba24ff4b6 --- /dev/null +++ b/core/sigtrap_posix.go @@ -0,0 +1,79 @@ +// +build !windows + +package core + +import ( + "io/ioutil" + "log" + "os" + "os/signal" + "syscall" +) + +// trapSignalsPosix captures POSIX-only signals. +func trapSignalsPosix() { + go func() { + sigchan := make(chan os.Signal, 1) + signal.Notify(sigchan, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGUSR1) + + for sig := range sigchan { + switch sig { + case syscall.SIGTERM: + log.Println("[INFO] SIGTERM: Terminating process") + if PidFile != "" { + os.Remove(PidFile) + } + os.Exit(0) + + case syscall.SIGQUIT: + log.Println("[INFO] SIGQUIT: Shutting down") + exitCode := executeShutdownCallbacks("SIGQUIT") + err := Stop() + if err != nil { + log.Printf("[ERROR] SIGQUIT stop: %v", err) + exitCode = 1 + } + if PidFile != "" { + os.Remove(PidFile) + } + os.Exit(exitCode) + + case syscall.SIGHUP: + log.Println("[INFO] SIGHUP: Hanging up") + err := Stop() + if err != nil { + log.Printf("[ERROR] SIGHUP stop: %v", err) + } + + case syscall.SIGUSR1: + log.Println("[INFO] SIGUSR1: Reloading") + + var updatedCaddyfile Input + + caddyfileMu.Lock() + if caddyfile == nil { + // Hmm, did spawing process forget to close stdin? Anyhow, this is unusual. + log.Println("[ERROR] SIGUSR1: no Caddyfile to reload (was stdin left open?)") + caddyfileMu.Unlock() + continue + } + if caddyfile.IsFile() { + body, err := ioutil.ReadFile(caddyfile.Path()) + if err == nil { + updatedCaddyfile = CaddyfileInput{ + Filepath: caddyfile.Path(), + Contents: body, + RealFile: true, + } + } + } + caddyfileMu.Unlock() + + err := Restart(updatedCaddyfile) + if err != nil { + log.Printf("[ERROR] SIGUSR1: %v", err) + } + } + } + }() +} diff --git a/core/sigtrap_windows.go b/core/sigtrap_windows.go new file mode 100644 index 000000000..59132cee4 --- /dev/null +++ b/core/sigtrap_windows.go @@ -0,0 +1,3 @@ +package core + +func trapSignalsPosix() {} diff --git a/db.dns.miek.nl b/db.dns.miek.nl new file mode 100644 index 000000000..78d75f10a --- /dev/null +++ b/db.dns.miek.nl @@ -0,0 +1,10 @@ +$TTL 30M +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630058 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + IN NS linode.atoom.net. + +go IN TXT "Hello!" diff --git a/db.miek.nl b/db.miek.nl new file mode 100644 index 000000000..9f0737bb3 --- /dev/null +++ b/db.miek.nl @@ -0,0 +1,29 @@ +$TTL 30M +@ IN SOA linode.atoom.net. miek.miek.nl. ( + 1282630057 ; Serial + 4H ; Refresh + 1H ; Retry + 7D ; Expire + 4H ) ; Negative Cache TTL + IN NS linode.atoom.net. + IN NS ns-ext.nlnetlabs.nl. + IN NS omval.tednet.nl. + IN NS ext.ns.whyscream.net. + + IN MX 1 aspmx.l.google.com. + IN MX 5 alt1.aspmx.l.google.com. + IN MX 5 alt2.aspmx.l.google.com. + IN MX 10 aspmx2.googlemail.com. + IN MX 10 aspmx3.googlemail.com. + + IN HINFO "Intel x64" "Linux" + IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 + +a IN A 139.162.196.78 + IN AAAA 2a01:7e00::f03c:91ff:fef1:6735 +www IN CNAME a +archive IN CNAME a + +; Go DNS ping back +go.dns IN TXT "Hello!" diff --git a/dist/CHANGES.txt b/dist/CHANGES.txt new file mode 100644 index 000000000..34b69d7d5 --- /dev/null +++ b/dist/CHANGES.txt @@ -0,0 +1,190 @@ +CHANGES
+
+0.8.2 (February 25, 2016)
+- On-demand TLS can obtain certificates during handshakes
+- Built with Go 1.6
+- Process log (-log) is rotated when it gets large
+- Managed certificates get renewed 30 days early instead of just 14
+- fastcgi: Allow scheme prefix before address
+- markdown: Support for definition lists
+- proxy: Allow proxy to insecure HTTPS backends
+- proxy: Support proxy to unix socket
+- rewrite: Status code can be 2xx or 4xx
+- templates: New .Markdown action to interpret included file as Markdown
+- templates: .Truncate now truncates from end of string when length is negative
+- tls: Set hard limit for certificates obtained with on-demand TLS
+- tls: Load certificates from directory
+- tls: Add SHA384 cipher suites
+- Multiple bug fixes and internal changes
+
+
+0.8.1 (January 12, 2016)
+- Improved OCSP stapling
+- Better graceful reload when new hosts need certificates from Let's Encrypt
+- Current pidfile is now deleted when Caddy exits
+- browse: New default template
+- gzip: Added min_length setting
+- import: Support for glob patterns (*) to import multiple files
+- rewrite: New complex rules with conditions, regex captures, and status code
+- tls: Removed DES ciphers from default cipher suite list
+- tls: All supported certificates are OCSP-stapled
+- tls: Allow custom configuration without specifying certificate and key
+- tls: No longer allow HTTPS over port 80
+- Dozens of bug fixes, improvements, and more tests across the board
+
+
+0.8.0 (December 4, 2015)
+- HTTPS by default via Let's Encrypt (certs & keys are fully managed)
+- Graceful restarts (on POSIX-compliant systems)
+- Major internal refactoring to allow use of Caddy as library
+- New directive 'mime' to customize Content-Type based on file extension
+- New -accept flag to accept Let's Encrypt SA without prompt
+- New -email flag to customize default email used for ACME transactions
+- New -ca flag to customize ACME CA server URL
+- New -revoke flag to revoke a certificate
+- New -log flag to enable process log
+- New -pidfile flag to enable writing pidfile
+- New -grace flag to customize the graceful shutdown timeout
+- New support for SIGHUP, SIGTERM, and SIGQUIT signals
+- browse: Render filenames with multiple whitespace properly
+- core: Use environment variables in Caddyfile
+- markdown: Include Last-Modified header in response
+- markdown: Render tables, strikethrough, and fenced code blocks
+- proxy: Ability to exclude/ignore paths from proxying
+- startup, shutdown: Better Windows support
+- templates: Bug fix for .Host when port is absent
+- templates: Include Last-Modified header in response
+- templates: Support for custom delimiters
+- tls: For non-local hosts, default port is now 443 unless specified
+- tls: Force-disable HTTPS
+- tls: Specify Let's Encrypt email address
+- Many, many more tests and numerous bug fixes and improvements
+
+
+0.7.6 (September 28, 2015)
+- Pass in simple Caddyfile as command line arguments
+- basicauth: Support for legacy htpasswd files
+- browse: JSON response with file listing
+- core: Caddyfile as command line argument
+- errors: Can write full stack trace to HTTP response for debugging
+- errors, log: Roll log files after certain size or age
+- proxy: Fix for 32-bit architectures
+- rewrite: Better compatibility with fastcgi and PHP apps
+- templates: Added .StripExt and .StripHTML methods
+- Internal improvements and minor bug fixes
+
+
+0.7.5 (August 5, 2015)
+- core: All listeners bind to 0.0.0.0 unless 'bind' directive is used
+- fastcgi: Set HTTPS env variable if connection is secure
+- log: Output to system log (except Windows)
+- markdown: Added dev command to disable caching during development
+- markdown: Fixed error reporting during initial site generation
+- markdown: Fixed crash if path does not exist when server starts
+- markdown: Fixed site generation and link indexing when files change
+- templates: Added .NowDate for use in date-related functions
+- Several bug fixes related to startup and shutdown functions
+
+
+0.7.4 (July 30, 2015)
+- browse: Sorting preference persisted in cookie
+- browse: Added index.txt and default.txt to list of default files
+- browse: Template files may now use Caddy template actions
+- markdown: Template files may now use Caddy template actions
+- markdown: Several bug fixes, especially for large and empty Markdown files
+- markdown: Generate index pages to link to markdown pages (sitegen only)
+- markdown: Flatten structure of front matter, changed template variables
+- redir: Can use variables (placeholders) like log formats can
+- redir: Catch-all redirects no longer preserve path; use {uri} instead
+- redir: Syntax supports redirect tables by opening a block
+- templates: Renamed .Date to .Now and added .Truncate, .Replace actions
+- Other minor internal improvements and more tests
+
+
+0.7.3 (July 15, 2015)
+- errors: Error log now shows timestamp with each entry
+- gzip: Fixed; Default filtering is by extension; removed MIME type filter
+- import: Fixed; works inside and outside server blocks
+- redir: Query string preserved on catch-all redirects
+- templates: Proper 403 or 404 errors for restricted or missing files
+
+
+0.7.2 (July 1, 2015)
+- Custom builds through caddyserver.com - extend Caddy by writing addons
+- browse: Sort by clicking column heading or using query string
+- core: Serving hostname that doesn't resolve issues warning then listens on 0.0.0.0
+- errors: Missing error page during parse time is warning, not error
+- ext: Extension only appended if request path does not end in /
+- fastcgi: Fix for backend responding without status text
+- fastcgi: Fix PATH_TRANSLATED when PATH_INFO is empty (RFC 3875)
+- git: Removed from core (available as add-on)
+- gzip: Enable by file path and/or extension
+- gzip: Customize compression level
+- log: Fix for missing status in log entry when error unhandled
+- proxy: Strip prefix from path for proxy to path
+- redir: Meta tag redirects
+- templates: Support for nested includes
+- Internal improvements and more tests
+
+
+0.7.1 (June 2, 2015)
+- basicauth: Patched timing vulnerability
+- proxy: Support for WebSocket backends
+- tls: Client authentication
+
+
+0.7.0 (May 25, 2015)
+- New directive 'internal' to protect resources with X-Accel-Redirect
+- New -version flag to show program name and version
+- core: Fixed escaped backslash characters inside quoted strings
+- core: Fixed parsing Caddyfile for IPv6 addresses missing ports
+- core: A notice is shown when non-local address resolves to loopback interface
+- core: Warns if file descriptor limit is too low for production site (Mac/Linux)
+- fastcgi: Support for Unix sockets
+- git: Fixed issue that prevented pulling at designated interval
+- header: Remove a header field by prefixing field name with "-"
+- markdown: Simple static site generation
+- markdown: Support for metadata ("front matter") at beginning of files
+- rewrite: Experimental support for regular expressions
+- tls: Customize cipher suites and protocols
+- tls: Removed RC4 ciphers
+- Other internal improvements that are not user-facing (more tests, etc.)
+
+
+0.6.0 (May 7, 2015)
+- New directive 'git' to automatically pull changes
+- New directive 'bind' to override host server binds to
+- New -root flag to specify root path to default site
+- Ability to receive config data piped through stdin
+- core: Warning if root directory doesn't exist at startup
+- core: Entire process dies if any server fails to start
+- gzip: Fixed Content-Length value when proxying requests
+- errors: Error log now includes file and line number of panics
+- fastcgi: Pass custom environment variables
+- fastcgi: Support for HEAD, OPTIONS, PUT, PATCH, and DELETE methods
+- fastcgi: Fixed SERVER_SOFTWARE variables
+- markdown: Support for index files when URL points to a directory
+- proxy: Load balancing with multiple backends, health checks, failovers, and multiple policies
+- proxy: Add custom headers
+- startup/shutdown: Run command in background with '&' at end
+- templates: Added .tpl and .tmpl as default extensions
+- templates: Support for index files when URL points to a directory
+- templates: Changed .RemoteAddr to .IP and stripped out remote port
+- tls: TLS disabled (with warning) for servers that are explicitly http://
+- websocket: Fixed SERVER_SOFTWARE and GATEWAY_INTERFACE variables
+- Many internal improvements
+
+
+0.5.1 (April 30, 2015)
+- Default host is now 0.0.0.0 (wildcard)
+- New -host and -port flags to override default host and port
+- core: Support for binding to 0.0.0.0
+- core: Graceful error handling during heavy load; proper error responses
+- errors: Fixed file path handling
+- errors: Fixed panic due to nil log file
+- fastcgi: Support for index files
+- fastcgi: Fix for handling errors that come from responder
+
+
+0.5.0 (April 28, 2015)
+- Initial release
diff --git a/dist/LICENSES.txt b/dist/LICENSES.txt new file mode 100644 index 000000000..c6ca2e2b0 --- /dev/null +++ b/dist/LICENSES.txt @@ -0,0 +1,539 @@ +The enclosed software makes use of third-party libraries either in full
+or in part, original or modified. This file is part of your download so
+as to be in full compliance with the licenses of all bundled property.
+
+
+
+###
+### github.com/mholt/caddy
+###
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+
+
+
+
+
+
+
+
+###
+### Go standard library and http2
+###
+
+
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+
+
+
+
+###
+### github.com/russross/blackfriday
+###
+
+
+Blackfriday is distributed under the Simplified BSD License:
+
+> Copyright © 2011 Russ Ross
+> All rights reserved.
+>
+> Redistribution and use in source and binary forms, with or without
+> modification, are permitted provided that the following conditions
+> are met:
+>
+> 1. Redistributions of source code must retain the above copyright
+> notice, this list of conditions and the following disclaimer.
+>
+> 2. Redistributions in binary form must reproduce the above
+> copyright notice, this list of conditions and the following
+> disclaimer in the documentation and/or other materials provided with
+> the distribution.
+>
+> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+> "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+> LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+> FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+> COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+> INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+> BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+> LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+> CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+> LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+> ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+> POSSIBILITY OF SUCH DAMAGE.
+
+
+
+
+
+
+
+
+###
+### github.com/dustin/go-humanize
+###
+
+
+Copyright (c) 2005-2008 Dustin Sallings <dustin@spy.net>
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+<http://www.opensource.org/licenses/mit-license.php>
+
+
+
+
+
+
+
+
+###
+### github.com/flynn/go-shlex
+###
+
+Apache 2.0 license as found in this file
+
+
+
+
+
+
+###
+### github.com/go-yaml/yaml
+###
+
+
+Copyright (c) 2011-2014 - Canonical Inc.
+
+This software is licensed under the LGPLv3, included below.
+
+As a special exception to the GNU Lesser General Public License version 3
+("LGPL3"), the copyright holders of this Library give you permission to
+convey to a third party a Combined Work that links statically or dynamically
+to this Library without providing any Minimal Corresponding Source or
+Minimal Application Code as set out in 4d or providing the installation
+information set out in section 4e, provided that you comply with the other
+provisions of LGPL3 and provided that you meet, for the Application the
+terms and conditions of the license(s) which apply to the Application.
+
+Except as stated in this special exception, the provisions of LGPL3 will
+continue to comply in full to this Library. If you modify this Library, you
+may apply this exception to your version of this Library, but you are not
+obliged to do so. If you do not wish to do so, delete this exception
+statement from your version. This exception does not (and cannot) modify any
+license terms which apply to the Application, with which you must still
+comply.
+
+
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/dist/README.txt b/dist/README.txt new file mode 100644 index 000000000..e2ec8a24b --- /dev/null +++ b/dist/README.txt @@ -0,0 +1,30 @@ +CADDY 0.8.2
+
+Website
+ https://caddyserver.com
+
+Twitter
+ @caddyserver
+
+Source Code
+ https://github.com/mholt/caddy
+ https://github.com/caddyserver
+
+
+For instructions on using Caddy, please see the user guide on the website.
+For a list of what's new in this version, see CHANGES.txt.
+
+Please consider donating to the project if you think it is helpful,
+especially if your company is using Caddy. There are also sponsorship
+opportunities available!
+
+If you have a question, bug report, or would like to contribute, please open an
+issue or submit a pull request on GitHub. Your contributions do not go unnoticed!
+
+For a good time, follow @mholt6 on Twitter.
+
+And thanks - you're awesome!
+
+
+---
+(c) 2015 - 2016 Matthew Holt
diff --git a/dist/automate.sh b/dist/automate.sh new file mode 100755 index 000000000..3cc8d41b2 --- /dev/null +++ b/dist/automate.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +set -e +set -o pipefail +shopt -s nullglob # if no files match glob, assume empty list instead of string literal + + +## PACKAGE TO BUILD +Package=github.com/mholt/caddy + + +## PATHS TO USE +DistDir=$GOPATH/src/$Package/dist +BuildDir=$DistDir/builds +ReleaseDir=$DistDir/release + + +## BEGIN + +# Compile binaries +mkdir -p $BuildDir +cd $BuildDir +rm -f caddy* +gox $Package + +# Zip them up with release notes and stuff +mkdir -p $ReleaseDir +cd $ReleaseDir +rm -f caddy* +for f in $BuildDir/* +do + # Name .zip file same as binary, but strip .exe from end + zipname=$(basename ${f%".exe"}) + if [[ $f == *"linux"* ]] || [[ $f == *"bsd"* ]]; then + zipname=${zipname}.tar.gz + else + zipname=${zipname}.zip + fi + + # Binary inside the zip file is simply the project name + binbase=$(basename $Package) + if [[ $f == *.exe ]]; then + binbase=$binbase.exe + fi + bin=$BuildDir/$binbase + mv $f $bin + + # Compress distributable + if [[ $zipname == *.zip ]]; then + zip -j $zipname $bin $DistDir/CHANGES.txt $DistDir/LICENSES.txt $DistDir/README.txt + else + tar -cvzf $zipname -C $BuildDir $binbase -C $DistDir CHANGES.txt LICENSES.txt README.txt + fi + + # Put binary filename back to original + mv $bin $f +done diff --git a/main.go b/main.go new file mode 100644 index 000000000..d35b0a5ee --- /dev/null +++ b/main.go @@ -0,0 +1,232 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "io/ioutil" + "log" + "os" + "runtime" + "strconv" + "strings" + "time" + + "github.com/miekg/coredns/core" + "github.com/miekg/coredns/core/https" + "github.com/xenolf/lego/acme" + "gopkg.in/natefinch/lumberjack.v2" +) + +func init() { + core.TrapSignals() + setVersion() + flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement") + flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server") + flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+core.DefaultConfigFile+")") + flag.StringVar(&cpu, "cpu", "100%", "CPU cap") + flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address") + flag.DurationVar(&core.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown") + flag.StringVar(&core.Host, "host", core.DefaultHost, "Default host") + flag.StringVar(&logfile, "log", "", "Process log file") + flag.StringVar(&core.PidFile, "pidfile", "", "Path to write pid file") + flag.StringVar(&core.Port, "port", core.DefaultPort, "Default port") + flag.BoolVar(&core.Quiet, "quiet", false, "Quiet mode (no initialization output)") + flag.StringVar(&revoke, "revoke", "", "Hostname for which to revoke the certificate") + flag.StringVar(&core.Root, "root", core.DefaultRoot, "Root path to default site") + flag.BoolVar(&version, "version", false, "Show version") +} + +func main() { + flag.Parse() // called here in main() to allow other packages to set flags in their inits + + core.AppName = appName + core.AppVersion = appVersion + acme.UserAgent = appName + "/" + appVersion + + // set up process log before anything bad happens + switch logfile { + case "stdout": + log.SetOutput(os.Stdout) + case "stderr": + log.SetOutput(os.Stderr) + case "": + log.SetOutput(ioutil.Discard) + default: + log.SetOutput(&lumberjack.Logger{ + Filename: logfile, + MaxSize: 100, + MaxAge: 14, + MaxBackups: 10, + }) + } + + if revoke != "" { + err := https.Revoke(revoke) + if err != nil { + log.Fatal(err) + } + fmt.Printf("Revoked certificate for %s\n", revoke) + os.Exit(0) + } + if version { + fmt.Printf("%s %s\n", appName, appVersion) + if devBuild && gitShortStat != "" { + fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified) + } + os.Exit(0) + } + + // Set CPU cap + err := setCPU(cpu) + if err != nil { + mustLogFatal(err) + } + + // Get Corefile input + caddyfile, err := core.LoadCaddyfile(loadCaddyfile) + if err != nil { + mustLogFatal(err) + } + + // Start your engines + err = core.Start(caddyfile) + if err != nil { + mustLogFatal(err) + } + + // Twiddle your thumbs + core.Wait() +} + +// mustLogFatal just wraps log.Fatal() in a way that ensures the +// output is always printed to stderr so the user can see it +// if the user is still there, even if the process log was not +// enabled. If this process is a restart, however, and the user +// might not be there anymore, this just logs to the process log +// and exits. +func mustLogFatal(args ...interface{}) { + if !core.IsRestart() { + log.SetOutput(os.Stderr) + } + log.Fatal(args...) +} + +func loadCaddyfile() (core.Input, error) { + // Try -conf flag + if conf != "" { + if conf == "stdin" { + return core.CaddyfileFromPipe(os.Stdin) + } + + contents, err := ioutil.ReadFile(conf) + if err != nil { + return nil, err + } + + return core.CaddyfileInput{ + Contents: contents, + Filepath: conf, + RealFile: true, + }, nil + } + + // command line args + if flag.NArg() > 0 { + confBody := core.Host + ":" + core.Port + "\n" + strings.Join(flag.Args(), "\n") + return core.CaddyfileInput{ + Contents: []byte(confBody), + Filepath: "args", + }, nil + } + + // Caddyfile in cwd + contents, err := ioutil.ReadFile(core.DefaultConfigFile) + if err != nil { + if os.IsNotExist(err) { + return core.DefaultInput(), nil + } + return nil, err + } + return core.CaddyfileInput{ + Contents: contents, + Filepath: core.DefaultConfigFile, + RealFile: true, + }, nil +} + +// setCPU parses string cpu and sets GOMAXPROCS +// according to its value. It accepts either +// a number (e.g. 3) or a percent (e.g. 50%). +func setCPU(cpu string) error { + var numCPU int + + availCPU := runtime.NumCPU() + + if strings.HasSuffix(cpu, "%") { + // Percent + var percent float32 + pctStr := cpu[:len(cpu)-1] + pctInt, err := strconv.Atoi(pctStr) + if err != nil || pctInt < 1 || pctInt > 100 { + return errors.New("invalid CPU value: percentage must be between 1-100") + } + percent = float32(pctInt) / 100 + numCPU = int(float32(availCPU) * percent) + } else { + // Number + num, err := strconv.Atoi(cpu) + if err != nil || num < 1 { + return errors.New("invalid CPU value: provide a number or percent greater than 0") + } + numCPU = num + } + + if numCPU > availCPU { + numCPU = availCPU + } + + runtime.GOMAXPROCS(numCPU) + return nil +} + +// setVersion figures out the version information based on +// variables set by -ldflags. +func setVersion() { + // A development build is one that's not at a tag or has uncommitted changes + devBuild = gitTag == "" || gitShortStat != "" + + // Only set the appVersion if -ldflags was used + if gitNearestTag != "" || gitTag != "" { + if devBuild && gitNearestTag != "" { + appVersion = fmt.Sprintf("%s (+%s %s)", + strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate) + } else if gitTag != "" { + appVersion = strings.TrimPrefix(gitTag, "v") + } + } +} + +const appName = "Caddy" + +// Flags that control program flow or startup +var ( + conf string + cpu string + logfile string + revoke string + version bool +) + +// Build information obtained with the help of -ldflags +var ( + appVersion = "(untracked dev build)" // inferred at startup + devBuild = true // inferred at startup + + buildDate string // date -u + gitTag string // git describe --exact-match HEAD 2> /dev/null + gitNearestTag string // git describe --abbrev=0 --tags HEAD + gitCommit string // git rev-parse HEAD + gitShortStat string // git diff-index --shortstat + gitFilesModified string // git diff-index --name-only HEAD +) diff --git a/main_test.go b/main_test.go new file mode 100644 index 000000000..01722ed60 --- /dev/null +++ b/main_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "runtime" + "testing" +) + +func TestSetCPU(t *testing.T) { + currentCPU := runtime.GOMAXPROCS(-1) + maxCPU := runtime.NumCPU() + halfCPU := int(0.5 * float32(maxCPU)) + if halfCPU < 1 { + halfCPU = 1 + } + for i, test := range []struct { + input string + output int + shouldErr bool + }{ + {"1", 1, false}, + {"-1", currentCPU, true}, + {"0", currentCPU, true}, + {"100%", maxCPU, false}, + {"50%", halfCPU, false}, + {"110%", currentCPU, true}, + {"-10%", currentCPU, true}, + {"invalid input", currentCPU, true}, + {"invalid input%", currentCPU, true}, + {"9999", maxCPU, false}, // over available CPU + } { + err := setCPU(test.input) + if test.shouldErr && err == nil { + t.Errorf("Test %d: Expected error, but there wasn't any", i) + } + if !test.shouldErr && err != nil { + t.Errorf("Test %d: Expected no error, but there was one: %v", i, err) + } + if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected { + t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected) + } + // teardown + runtime.GOMAXPROCS(currentCPU) + } +} + +func TestSetVersion(t *testing.T) { + setVersion() + if !devBuild { + t.Error("Expected default to assume development build, but it didn't") + } + if got, want := appVersion, "(untracked dev build)"; got != want { + t.Errorf("Expected appVersion='%s', got: '%s'", want, got) + } + + gitTag = "v1.1" + setVersion() + if devBuild { + t.Error("Expected a stable build if gitTag is set with no changes") + } + if got, want := appVersion, "1.1"; got != want { + t.Errorf("Expected appVersion='%s', got: '%s'", want, got) + } + + gitTag = "" + gitNearestTag = "v1.0" + gitCommit = "deadbeef" + buildDate = "Fri Feb 26 06:53:17 UTC 2016" + setVersion() + if !devBuild { + t.Error("Expected inferring a dev build when gitTag is empty") + } + if got, want := appVersion, "1.0 (+deadbeef Fri Feb 26 06:53:17 UTC 2016)"; got != want { + t.Errorf("Expected appVersion='%s', got: '%s'", want, got) + } +} diff --git a/middleware/commands.go b/middleware/commands.go new file mode 100644 index 000000000..5c241161e --- /dev/null +++ b/middleware/commands.go @@ -0,0 +1,120 @@ +package middleware + +import ( + "errors" + "runtime" + "unicode" + + "github.com/flynn/go-shlex" +) + +var runtimeGoos = runtime.GOOS + +// SplitCommandAndArgs takes a command string and parses it +// shell-style into the command and its separate arguments. +func SplitCommandAndArgs(command string) (cmd string, args []string, err error) { + var parts []string + + if runtimeGoos == "windows" { + parts = parseWindowsCommand(command) // parse it Windows-style + } else { + parts, err = parseUnixCommand(command) // parse it Unix-style + if err != nil { + err = errors.New("error parsing command: " + err.Error()) + return + } + } + + if len(parts) == 0 { + err = errors.New("no command contained in '" + command + "'") + return + } + + cmd = parts[0] + if len(parts) > 1 { + args = parts[1:] + } + + return +} + +// parseUnixCommand parses a unix style command line and returns the +// command and its arguments or an error +func parseUnixCommand(cmd string) ([]string, error) { + return shlex.Split(cmd) +} + +// parseWindowsCommand parses windows command lines and +// returns the command and the arguments as an array. It +// should be able to parse commonly used command lines. +// Only basic syntax is supported: +// - spaces in double quotes are not token delimiters +// - double quotes are escaped by either backspace or another double quote +// - except for the above case backspaces are path separators (not special) +// +// Many sources point out that escaping quotes using backslash can be unsafe. +// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 ) +// +// This function has to be used on Windows instead +// of the shlex package because this function treats backslash +// characters properly. +func parseWindowsCommand(cmd string) []string { + const backslash = '\\' + const quote = '"' + + var parts []string + var part string + var inQuotes bool + var lastRune rune + + for i, ch := range cmd { + + if i != 0 { + lastRune = rune(cmd[i-1]) + } + + if ch == backslash { + // put it in the part - for now we don't know if it's an + // escaping char or path separator + part += string(ch) + continue + } + + if ch == quote { + if lastRune == backslash { + // remove the backslash from the part and add the escaped quote instead + part = part[:len(part)-1] + part += string(ch) + continue + } + + if lastRune == quote { + // revert the last change of the inQuotes state + // it was an escaping quote + inQuotes = !inQuotes + part += string(ch) + continue + } + + // normal escaping quotes + inQuotes = !inQuotes + continue + + } + + if unicode.IsSpace(ch) && !inQuotes && len(part) > 0 { + parts = append(parts, part) + part = "" + continue + } + + part += string(ch) + } + + if len(part) > 0 { + parts = append(parts, part) + part = "" + } + + return parts +} diff --git a/middleware/commands_test.go b/middleware/commands_test.go new file mode 100644 index 000000000..3001e65a5 --- /dev/null +++ b/middleware/commands_test.go @@ -0,0 +1,291 @@ +package middleware + +import ( + "fmt" + "runtime" + "strings" + "testing" +) + +func TestParseUnixCommand(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + // 0 - emtpy command + { + input: ``, + expected: []string{}, + }, + // 1 - command without arguments + { + input: `command`, + expected: []string{`command`}, + }, + // 2 - command with single argument + { + input: `command arg1`, + expected: []string{`command`, `arg1`}, + }, + // 3 - command with multiple arguments + { + input: `command arg1 arg2`, + expected: []string{`command`, `arg1`, `arg2`}, + }, + // 4 - command with single argument with space character - in quotes + { + input: `command "arg1 arg1"`, + expected: []string{`command`, `arg1 arg1`}, + }, + // 5 - command with multiple spaces and tab character + { + input: "command arg1 arg2\targ3", + expected: []string{`command`, `arg1`, `arg2`, `arg3`}, + }, + // 6 - command with single argument with space character - escaped with backspace + { + input: `command arg1\ arg2`, + expected: []string{`command`, `arg1 arg2`}, + }, + // 7 - single quotes should escape special chars + { + input: `command 'arg1\ arg2'`, + expected: []string{`command`, `arg1\ arg2`}, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + actual, _ := parseUnixCommand(test.input) + if len(actual) != len(test.expected) { + t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual) + continue + } + for j := 0; j < len(actual); j++ { + if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart { + t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j) + } + } + } +} + +func TestParseWindowsCommand(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + { // 0 - empty command - do not fail + input: ``, + expected: []string{}, + }, + { // 1 - cmd without args + input: `cmd`, + expected: []string{`cmd`}, + }, + { // 2 - multiple args + input: `cmd arg1 arg2`, + expected: []string{`cmd`, `arg1`, `arg2`}, + }, + { // 3 - multiple args with space + input: `cmd "combined arg" arg2`, + expected: []string{`cmd`, `combined arg`, `arg2`}, + }, + { // 4 - path without spaces + input: `mkdir C:\Windows\foo\bar`, + expected: []string{`mkdir`, `C:\Windows\foo\bar`}, + }, + { // 5 - command with space in quotes + input: `"command here"`, + expected: []string{`command here`}, + }, + { // 6 - argument with escaped quotes (two quotes) + input: `cmd ""arg""`, + expected: []string{`cmd`, `"arg"`}, + }, + { // 7 - argument with escaped quotes (backslash) + input: `cmd \"arg\"`, + expected: []string{`cmd`, `"arg"`}, + }, + { // 8 - two quotes (escaped) inside an inQuote element + input: `cmd "a ""quoted value"`, + expected: []string{`cmd`, `a "quoted value`}, + }, + // TODO - see how many quotes are dislayed if we use "", """, """"""" + { // 9 - two quotes outside an inQuote element + input: `cmd a ""quoted value`, + expected: []string{`cmd`, `a`, `"quoted`, `value`}, + }, + { // 10 - path with space in quotes + input: `mkdir "C:\directory name\foobar"`, + expected: []string{`mkdir`, `C:\directory name\foobar`}, + }, + { // 11 - space without quotes + input: `mkdir C:\ space`, + expected: []string{`mkdir`, `C:\`, `space`}, + }, + { // 12 - space in quotes + input: `mkdir "C:\ space"`, + expected: []string{`mkdir`, `C:\ space`}, + }, + { // 13 - UNC + input: `mkdir \\?\C:\Users`, + expected: []string{`mkdir`, `\\?\C:\Users`}, + }, + { // 14 - UNC with space + input: `mkdir "\\?\C:\Program Files"`, + expected: []string{`mkdir`, `\\?\C:\Program Files`}, + }, + + { // 15 - unclosed quotes - treat as if the path ends with quote + input: `mkdir "c:\Program files`, + expected: []string{`mkdir`, `c:\Program files`}, + }, + { // 16 - quotes used inside the argument + input: `mkdir "c:\P"rogra"m f"iles`, + expected: []string{`mkdir`, `c:\Program files`}, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + + actual := parseWindowsCommand(test.input) + if len(actual) != len(test.expected) { + t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual) + continue + } + for j := 0; j < len(actual); j++ { + if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart { + t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j) + } + } + } +} + +func TestSplitCommandAndArgs(t *testing.T) { + + // force linux parsing. It's more robust and covers error cases + runtimeGoos = "linux" + defer func() { + runtimeGoos = runtime.GOOS + }() + + var parseErrorContent = "error parsing command:" + var noCommandErrContent = "no command contained in" + + tests := []struct { + input string + expectedCommand string + expectedArgs []string + expectedErrContent string + }{ + // 0 - emtpy command + { + input: ``, + expectedCommand: ``, + expectedArgs: nil, + expectedErrContent: noCommandErrContent, + }, + // 1 - command without arguments + { + input: `command`, + expectedCommand: `command`, + expectedArgs: nil, + expectedErrContent: ``, + }, + // 2 - command with single argument + { + input: `command arg1`, + expectedCommand: `command`, + expectedArgs: []string{`arg1`}, + expectedErrContent: ``, + }, + // 3 - command with multiple arguments + { + input: `command arg1 arg2`, + expectedCommand: `command`, + expectedArgs: []string{`arg1`, `arg2`}, + expectedErrContent: ``, + }, + // 4 - command with unclosed quotes + { + input: `command "arg1 arg2`, + expectedCommand: "", + expectedArgs: nil, + expectedErrContent: parseErrorContent, + }, + // 5 - command with unclosed quotes + { + input: `command 'arg1 arg2"`, + expectedCommand: "", + expectedArgs: nil, + expectedErrContent: parseErrorContent, + }, + } + + for i, test := range tests { + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input) + actualCommand, actualArgs, actualErr := SplitCommandAndArgs(test.input) + + // test if error matches expectation + if test.expectedErrContent != "" { + if actualErr == nil { + t.Errorf(errorPrefix+"Expected error with content [%s], found no error."+errorSuffix, test.expectedErrContent) + } else if !strings.Contains(actualErr.Error(), test.expectedErrContent) { + t.Errorf(errorPrefix+"Expected error with content [%s], found [%v]."+errorSuffix, test.expectedErrContent, actualErr) + } + } else if actualErr != nil { + t.Errorf(errorPrefix+"Expected no error, found [%v]."+errorSuffix, actualErr) + } + + // test if command matches + if test.expectedCommand != actualCommand { + t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand) + } + + // test if arguments match + if len(test.expectedArgs) != len(actualArgs) { + t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs) + } else { + // test args only if the count matches. + for j, actualArg := range actualArgs { + expectedArg := test.expectedArgs[j] + if actualArg != expectedArg { + t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg) + } + } + } + } +} + +func ExampleSplitCommandAndArgs() { + var commandLine string + var command string + var args []string + + // just for the test - change GOOS and reset it at the end of the test + runtimeGoos = "windows" + defer func() { + runtimeGoos = runtime.GOOS + }() + + commandLine = `mkdir /P "C:\Program Files"` + command, args, _ = SplitCommandAndArgs(commandLine) + + fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ",")) + + // set GOOS to linux + runtimeGoos = "linux" + + commandLine = `mkdir -p /path/with\ space` + command, args, _ = SplitCommandAndArgs(commandLine) + + fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ",")) + + // Output: + // Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files] + // Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space] +} diff --git a/middleware/context.go b/middleware/context.go new file mode 100644 index 000000000..8868c1c03 --- /dev/null +++ b/middleware/context.go @@ -0,0 +1,135 @@ +package middleware + +import ( + "net" + "net/http" + "strings" + "time" + + "github.com/miekg/dns" +) + +// This file contains the context and functions available for +// use in the templates. + +// Context is the context with which Caddy templates are executed. +type Context struct { + Root http.FileSystem // TODO(miek): needed + Req *dns.Msg + W dns.ResponseWriter +} + +// Now returns the current timestamp in the specified format. +func (c Context) Now(format string) string { + return time.Now().Format(format) +} + +// NowDate returns the current date/time that can be used +// in other time functions. +func (c Context) NowDate() time.Time { + return time.Now() +} + +// Header gets the value of a header. +func (c Context) Header() *dns.RR_Header { + // TODO(miek) + return nil +} + +// IP gets the (remote) IP address of the client making the request. +func (c Context) IP() string { + ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String()) + if err != nil { + return c.W.RemoteAddr().String() + } + return ip +} + +// Post gets the (remote) Port of the client making the request. +func (c Context) Port() (string, error) { + _, port, err := net.SplitHostPort(c.W.RemoteAddr().String()) + if err != nil { + return "0", err + } + return port, nil +} + +// Proto gets the protocol used as the transport. This +// will be udp or tcp. +func (c Context) Proto() string { + if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok { + return "udp" + } + if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok { + return "tcp" + } + return "udp" +} + +// Family returns the family of the transport. +// 1 for IPv4 and 2 for IPv6. +func (c Context) Family() int { + var a net.IP + ip := c.W.RemoteAddr() + if i, ok := ip.(*net.UDPAddr); ok { + a = i.IP + } + if i, ok := ip.(*net.TCPAddr); ok { + a = i.IP + } + + if a.To4() != nil { + return 1 + } + return 2 +} + +// Type returns the type of the question as a string. +func (c Context) Type() string { + return dns.Type(c.Req.Question[0].Qtype).String() +} + +// QType returns the type of the question as a uint16. +func (c Context) QType() uint16 { + return c.Req.Question[0].Qtype +} + +// Name returns the name of the question in the request. Note +// this name will always have a closing dot and will be lower cased. +func (c Context) Name() string { + return strings.ToLower(dns.Name(c.Req.Question[0].Name).String()) +} + +// QName returns the name of the question in the request. +func (c Context) QName() string { + return dns.Name(c.Req.Question[0].Name).String() +} + +// Class returns the class of the question in the request. +func (c Context) Class() string { + return dns.Class(c.Req.Question[0].Qclass).String() +} + +// QClass returns the class of the question in the request. +func (c Context) QClass() uint16 { + return c.Req.Question[0].Qclass +} + +// More convience types for extracting stuff from a message? +// Header? + +// ErrorMessage returns an error message suitable for sending +// back to the client. +func (c Context) ErrorMessage(rcode int) *dns.Msg { + m := new(dns.Msg) + m.SetRcode(c.Req, rcode) + return m +} + +// AnswerMessage returns an error message suitable for sending +// back to the client. +func (c Context) AnswerMessage() *dns.Msg { + m := new(dns.Msg) + m.SetReply(c.Req) + return m +} diff --git a/middleware/context_test.go b/middleware/context_test.go new file mode 100644 index 000000000..689c47c13 --- /dev/null +++ b/middleware/context_test.go @@ -0,0 +1,613 @@ +package middleware + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestInclude(t *testing.T) { + context := getContextOrFail(t) + + inputFilename := "test_file" + absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) + defer func() { + err := os.Remove(absInFilePath) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("Failed to clean test file!") + } + }() + + tests := []struct { + fileContent string + expectedContent string + shouldErr bool + expectedErrorContent string + }{ + // Test 0 - all good + { + fileContent: `str1 {{ .Root }} str2`, + expectedContent: fmt.Sprintf("str1 %s str2", context.Root), + shouldErr: false, + expectedErrorContent: "", + }, + // Test 1 - failure on template.Parse + { + fileContent: `str1 {{ .Root } str2`, + expectedContent: "", + shouldErr: true, + expectedErrorContent: `unexpected "}" in operand`, + }, + // Test 3 - failure on template.Execute + { + fileContent: `str1 {{ .InvalidField }} str2`, + expectedContent: "", + shouldErr: true, + expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`, + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // WriteFile truncates the contentt + err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) + if err != nil { + t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) + } + + content, err := context.Include(inputFilename) + if err != nil { + if !test.shouldErr { + t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error()) + } + if !strings.Contains(err.Error(), test.expectedErrorContent) { + t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error()) + } + } + + if err == nil && test.shouldErr { + t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename) + } + + if content != test.expectedContent { + t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) + } + } +} + +func TestIncludeNotExisting(t *testing.T) { + context := getContextOrFail(t) + + _, err := context.Include("not_existing") + if err == nil { + t.Errorf("Expected error but found nil!") + } +} + +func TestMarkdown(t *testing.T) { + context := getContextOrFail(t) + + inputFilename := "test_file" + absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename) + defer func() { + err := os.Remove(absInFilePath) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("Failed to clean test file!") + } + }() + + tests := []struct { + fileContent string + expectedContent string + }{ + // Test 0 - test parsing of markdown + { + fileContent: "* str1\n* str2\n", + expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n", + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // WriteFile truncates the contentt + err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm) + if err != nil { + t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err) + } + + content, _ := context.Markdown(inputFilename) + if content != test.expectedContent { + t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename) + } + } +} + +func TestCookie(t *testing.T) { + + tests := []struct { + cookie *http.Cookie + cookieName string + expectedValue string + }{ + // Test 0 - happy path + { + cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, + cookieName: "cookieName", + expectedValue: "cookieValue", + }, + // Test 1 - try to get a non-existing cookie + { + cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"}, + cookieName: "notExisting", + expectedValue: "", + }, + // Test 2 - partial name match + { + cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"}, + cookieName: "cook", + expectedValue: "", + }, + // Test 3 - cookie with optional fields + { + cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120}, + cookieName: "cookie", + expectedValue: "cookieValue", + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + // reinitialize the context for each test + context := getContextOrFail(t) + + context.Req.AddCookie(test.cookie) + + actualCookieVal := context.Cookie(test.cookieName) + + if actualCookieVal != test.expectedValue { + t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName) + } + } +} + +func TestCookieMultipleCookies(t *testing.T) { + context := getContextOrFail(t) + + cookieNameBase, cookieValueBase := "cookieName", "cookieValue" + + // make sure that there's no state and multiple requests for different cookies return the correct result + for i := 0; i < 10; i++ { + context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)}) + } + + for i := 0; i < 10; i++ { + expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i) + actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i)) + if actualCookieVal != expectedCookieVal { + t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal) + } + } +} + +func TestHeader(t *testing.T) { + context := getContextOrFail(t) + + headerKey, headerVal := "Header1", "HeaderVal1" + context.Req.Header.Add(headerKey, headerVal) + + actualHeaderVal := context.Header(headerKey) + if actualHeaderVal != headerVal { + t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal) + } + + missingHeaderVal := context.Header("not-existing") + if missingHeaderVal != "" { + t.Errorf("Expected empty header value, found %s", missingHeaderVal) + } +} + +func TestIP(t *testing.T) { + context := getContextOrFail(t) + + tests := []struct { + inputRemoteAddr string + expectedIP string + }{ + // Test 0 - ipv4 with port + {"1.1.1.1:1111", "1.1.1.1"}, + // Test 1 - ipv4 without port + {"1.1.1.1", "1.1.1.1"}, + // Test 2 - ipv6 with port + {"[::1]:11", "::1"}, + // Test 3 - ipv6 without port and brackets + {"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"}, + // Test 4 - ipv6 with zone and port + {`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`}, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + + context.Req.RemoteAddr = test.inputRemoteAddr + actualIP := context.IP() + + if actualIP != test.expectedIP { + t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP) + } + } +} + +func TestURL(t *testing.T) { + context := getContextOrFail(t) + + inputURL := "http://localhost" + context.Req.RequestURI = inputURL + + if inputURL != context.URI() { + t.Errorf("Expected url %s, found %s", inputURL, context.URI()) + } +} + +func TestHost(t *testing.T) { + tests := []struct { + input string + expectedHost string + shouldErr bool + }{ + { + input: "localhost:123", + expectedHost: "localhost", + shouldErr: false, + }, + { + input: "localhost", + expectedHost: "localhost", + shouldErr: false, + }, + { + input: "[::]", + expectedHost: "", + shouldErr: true, + }, + } + + for _, test := range tests { + testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr) + } +} + +func TestPort(t *testing.T) { + tests := []struct { + input string + expectedPort string + shouldErr bool + }{ + { + input: "localhost:123", + expectedPort: "123", + shouldErr: false, + }, + { + input: "localhost", + expectedPort: "80", // assuming 80 is the default port + shouldErr: false, + }, + { + input: ":8080", + expectedPort: "8080", + shouldErr: false, + }, + { + input: "[::]", + expectedPort: "", + shouldErr: true, + }, + } + + for _, test := range tests { + testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr) + } +} + +func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) { + context := getContextOrFail(t) + + context.Req.Host = input + var actualResult, testedObject string + var err error + + if isTestingHost { + actualResult, err = context.Host() + testedObject = "host" + } else { + actualResult, err = context.Port() + testedObject = "port" + } + + if shouldErr && err == nil { + t.Errorf("Expected error, found nil!") + return + } + + if !shouldErr && err != nil { + t.Errorf("Expected no error, found %s", err) + return + } + + if actualResult != expectedResult { + t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult) + } +} + +func TestMethod(t *testing.T) { + context := getContextOrFail(t) + + method := "POST" + context.Req.Method = method + + if method != context.Method() { + t.Errorf("Expected method %s, found %s", method, context.Method()) + } + +} + +func TestPathMatches(t *testing.T) { + context := getContextOrFail(t) + + tests := []struct { + urlStr string + pattern string + shouldMatch bool + }{ + // Test 0 + { + urlStr: "http://localhost/", + pattern: "", + shouldMatch: true, + }, + // Test 1 + { + urlStr: "http://localhost", + pattern: "", + shouldMatch: true, + }, + // Test 1 + { + urlStr: "http://localhost/", + pattern: "/", + shouldMatch: true, + }, + // Test 3 + { + urlStr: "http://localhost/?param=val", + pattern: "/", + shouldMatch: true, + }, + // Test 4 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "/dir2", + shouldMatch: false, + }, + // Test 5 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "/dir1", + shouldMatch: true, + }, + // Test 6 + { + urlStr: "http://localhost:444/dir1/dir2", + pattern: "/dir1", + shouldMatch: true, + }, + // Test 7 + { + urlStr: "http://localhost/dir1/dir2", + pattern: "*/dir2", + shouldMatch: false, + }, + } + + for i, test := range tests { + testPrefix := getTestPrefix(i) + var err error + context.Req.URL, err = url.Parse(test.urlStr) + if err != nil { + t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err) + } + + matches := context.PathMatches(test.pattern) + if matches != test.shouldMatch { + t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches) + } + } +} + +func TestTruncate(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + inputString string + inputLength int + expected string + }{ + // Test 0 - small length + { + inputString: "string", + inputLength: 1, + expected: "s", + }, + // Test 1 - exact length + { + inputString: "string", + inputLength: 6, + expected: "string", + }, + // Test 2 - bigger length + { + inputString: "string", + inputLength: 10, + expected: "string", + }, + // Test 3 - zero length + { + inputString: "string", + inputLength: 0, + expected: "", + }, + // Test 4 - negative, smaller length + { + inputString: "string", + inputLength: -5, + expected: "tring", + }, + // Test 5 - negative, exact length + { + inputString: "string", + inputLength: -6, + expected: "string", + }, + // Test 6 - negative, bigger length + { + inputString: "string", + inputLength: -7, + expected: "string", + }, + } + + for i, test := range tests { + actual := context.Truncate(test.inputString, test.inputLength) + if actual != test.expected { + t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength) + } + } +} + +func TestStripHTML(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + input string + expected string + }{ + // Test 0 - no tags + { + input: `h1`, + expected: `h1`, + }, + // Test 1 - happy path + { + input: `<h1>h1</h1>`, + expected: `h1`, + }, + // Test 2 - tag in quotes + { + input: `<h1">">h1</h1>`, + expected: `h1`, + }, + // Test 3 - multiple tags + { + input: `<h1><b>h1</b></h1>`, + expected: `h1`, + }, + // Test 4 - tags not closed + { + input: `<h1`, + expected: `<h1`, + }, + // Test 5 - false start + { + input: `<h1<b>hi`, + expected: `<h1hi`, + }, + } + + for i, test := range tests { + actual := context.StripHTML(test.input) + if actual != test.expected { + t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input) + } + } +} + +func TestStripExt(t *testing.T) { + context := getContextOrFail(t) + tests := []struct { + input string + expected string + }{ + // Test 0 - empty input + { + input: "", + expected: "", + }, + // Test 1 - relative file with ext + { + input: "file.ext", + expected: "file", + }, + // Test 2 - relative file without ext + { + input: "file", + expected: "file", + }, + // Test 3 - absolute file without ext + { + input: "/file", + expected: "/file", + }, + // Test 4 - absolute file with ext + { + input: "/file.ext", + expected: "/file", + }, + // Test 5 - with ext but ends with / + { + input: "/dir.ext/", + expected: "/dir.ext/", + }, + // Test 6 - file with ext under dir with ext + { + input: "/dir.ext/file.ext", + expected: "/dir.ext/file", + }, + } + + for i, test := range tests { + actual := context.StripExt(test.input) + if actual != test.expected { + t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input) + } + } +} + +func initTestContext() (Context, error) { + body := bytes.NewBufferString("request body") + request, err := http.NewRequest("GET", "https://localhost", body) + if err != nil { + return Context{}, err + } + + return Context{Root: http.Dir(os.TempDir()), Req: request}, nil +} + +func getContextOrFail(t *testing.T) Context { + context, err := initTestContext() + if err != nil { + t.Fatalf("Failed to prepare test context") + } + return context +} + +func getTestPrefix(testN int) string { + return fmt.Sprintf("Test [%d]: ", testN) +} diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go new file mode 100644 index 000000000..bf5bc7aae --- /dev/null +++ b/middleware/errors/errors.go @@ -0,0 +1,100 @@ +// Package errors implements an HTTP error handling middleware. +package errors + +import ( + "fmt" + "log" + "runtime" + "strings" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// ErrorHandler handles DNS errors (and errors from other middleware). +type ErrorHandler struct { + Next middleware.Handler + LogFile string + Log *log.Logger + LogRoller *middleware.LogRoller + Debug bool // if true, errors are written out to client rather than to a log +} + +func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + defer h.recovery(w, r) + + rcode, err := h.Next.ServeDNS(w, r) + + if err != nil { + errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err) + + if h.Debug { + // Write error to response as a txt message instead of to log + answer := debugMsg(rcode, r) + txt, _ := dns.NewRR(". IN 0 TXT " + errMsg) + answer.Answer = append(answer.Answer, txt) + w.WriteMsg(answer) + return 0, err + } + h.Log.Println(errMsg) + } + + return rcode, err +} + +func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) { + rec := recover() + if rec == nil { + return + } + + // Obtain source of panic + // From: https://gist.github.com/swdunlop/9629168 + var name, file string // function name, file name + var line int + var pc [16]uintptr + n := runtime.Callers(3, pc[:]) + for _, pc := range pc[:n] { + fn := runtime.FuncForPC(pc) + if fn == nil { + continue + } + file, line = fn.FileLine(pc) + name = fn.Name() + if !strings.HasPrefix(name, "runtime.") { + break + } + } + + // Trim file path + delim := "/coredns/" + pkgPathPos := strings.Index(file, delim) + if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) { + file = file[pkgPathPos+len(delim):] + } + + panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec) + if h.Debug { + // Write error and stack trace to the response rather than to a log + var stackBuf [4096]byte + stack := stackBuf[:runtime.Stack(stackBuf[:], false)] + answer := debugMsg(dns.RcodeServerFailure, r) + // add stack buf in TXT, limited to 255 chars for now. + txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255])) + answer.Answer = append(answer.Answer, txt) + w.WriteMsg(answer) + } else { + // Currently we don't use the function name, since file:line is more conventional + h.Log.Printf(panicMsg) + } +} + +// debugMsg creates a debug message that gets send back to the client. +func debugMsg(rcode int, r *dns.Msg) *dns.Msg { + answer := new(dns.Msg) + answer.SetRcode(r, rcode) + return answer +} + +const timeFormat = "02/Jan/2006:15:04:05 -0700" diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go new file mode 100644 index 000000000..4434e835c --- /dev/null +++ b/middleware/errors/errors_test.go @@ -0,0 +1,168 @@ +package errors + +import ( + "bytes" + "errors" + "fmt" + "log" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/miekg/coredns/middleware" +) + +func TestErrors(t *testing.T) { + // create a temporary page + path := filepath.Join(os.TempDir(), "errors_test.html") + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + defer os.Remove(path) + + const content = "This is a error page" + _, err = f.WriteString(content) + if err != nil { + t.Fatal(err) + } + f.Close() + + buf := bytes.Buffer{} + em := ErrorHandler{ + ErrorPages: map[int]string{ + http.StatusNotFound: path, + http.StatusForbidden: "not_exist_file", + }, + Log: log.New(&buf, "", 0), + } + _, notExistErr := os.Open("not_exist_file") + + testErr := errors.New("test error") + tests := []struct { + next middleware.Handler + expectedCode int + expectedBody string + expectedLog string + expectedErr error + }{ + { + next: genErrorHandler(http.StatusOK, nil, "normal"), + expectedCode: http.StatusOK, + expectedBody: "normal", + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusMovedPermanently, testErr, ""), + expectedCode: http.StatusMovedPermanently, + expectedBody: "", + expectedLog: fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr), + expectedErr: testErr, + }, + { + next: genErrorHandler(http.StatusBadRequest, nil, ""), + expectedCode: 0, + expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest, + http.StatusText(http.StatusBadRequest)), + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusNotFound, nil, ""), + expectedCode: 0, + expectedBody: content, + expectedLog: "", + expectedErr: nil, + }, + { + next: genErrorHandler(http.StatusForbidden, nil, ""), + expectedCode: 0, + expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden, + http.StatusText(http.StatusForbidden)), + expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n", + http.StatusForbidden, notExistErr), + expectedErr: nil, + }, + } + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + for i, test := range tests { + em.Next = test.next + buf.Reset() + rec := httptest.NewRecorder() + code, err := em.ServeHTTP(rec, req) + + if err != test.expectedErr { + t.Errorf("Test %d: Expected error %v, but got %v", + i, test.expectedErr, err) + } + if code != test.expectedCode { + t.Errorf("Test %d: Expected status code %d, but got %d", + i, test.expectedCode, code) + } + if body := rec.Body.String(); body != test.expectedBody { + t.Errorf("Test %d: Expected body %q, but got %q", + i, test.expectedBody, body) + } + if log := buf.String(); !strings.Contains(log, test.expectedLog) { + t.Errorf("Test %d: Expected log %q, but got %q", + i, test.expectedLog, log) + } + } +} + +func TestVisibleErrorWithPanic(t *testing.T) { + const panicMsg = "I'm a panic" + eh := ErrorHandler{ + ErrorPages: make(map[int]string), + Debug: true, + Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + panic(panicMsg) + }), + } + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + rec := httptest.NewRecorder() + + code, err := eh.ServeHTTP(rec, req) + + if code != 0 { + t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code) + } + if err != nil { + t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err) + } + + body := rec.Body.String() + + if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") { + t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body) + } + if !strings.Contains(body, panicMsg) { + t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body) + } + if len(body) < 500 { + t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body)) + } +} + +func genErrorHandler(status int, err error, body string) middleware.Handler { + return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + if len(body) > 0 { + w.Header().Set("Content-Length", strconv.Itoa(len(body))) + fmt.Fprint(w, body) + } + return status, err + }) +} diff --git a/middleware/etcd/TODO b/middleware/etcd/TODO new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/middleware/etcd/TODO diff --git a/middleware/exchange.go b/middleware/exchange.go new file mode 100644 index 000000000..837fa3cdc --- /dev/null +++ b/middleware/exchange.go @@ -0,0 +1,10 @@ +package middleware + +import "github.com/miekg/dns" + +// Exchang sends message m to the server. +// TODO(miek): optionally it can do retries of other silly stuff. +func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) { + r, _, err := c.Exchange(m, server) + return r, err +} diff --git a/middleware/file/file.go b/middleware/file/file.go new file mode 100644 index 000000000..5bc5a3a3a --- /dev/null +++ b/middleware/file/file.go @@ -0,0 +1,89 @@ +package file + +// TODO(miek): the zone's implementation is basically non-existent +// we return a list and when searching for an answer we iterate +// over the list. This must be moved to a tree-like structure and +// have some fluff for DNSSEC (and be memory efficient). + +import ( + "strings" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +type ( + File struct { + Next middleware.Handler + Zones Zones + // Maybe a list of all zones as well, as a []string? + } + + Zone []dns.RR + Zones struct { + Z map[string]Zone // utterly braindead impl. TODO(miek): fix + Names []string + } +) + +func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + context := middleware.Context{W: w, Req: r} + qname := context.Name() + zone := middleware.Zones(f.Zones.Names).Matches(qname) + if zone == "" { + return f.Next.ServeDNS(w, r) + } + + names, nodata := f.Zones.Z[zone].lookup(qname, context.QType()) + var answer *dns.Msg + switch { + case nodata: + answer = context.AnswerMessage() + answer.Ns = names + case len(names) == 0: + answer = context.AnswerMessage() + answer.Ns = names + answer.Rcode = dns.RcodeNameError + case len(names) > 0: + answer = context.AnswerMessage() + answer.Answer = names + default: + answer = context.ErrorMessage(dns.RcodeServerFailure) + } + // Check return size, etc. TODO(miek) + w.WriteMsg(answer) + return 0, nil +} + +// Lookup will try to find qname and qtype in z. It returns the +// records found *or* a boolean saying NODATA. If the answer +// is NODATA then the RR returned is the SOA record. +// +// TODO(miek): EXTREMELY STUPID IMPLEMENTATION. +// Doesn't do much, no delegation, no cname, nothing really, etc. +// TODO(miek): even NODATA looks broken +func (z Zone) lookup(qname string, qtype uint16) ([]dns.RR, bool) { + var ( + nodata bool + rep []dns.RR + soa dns.RR + ) + + for _, rr := range z { + if rr.Header().Rrtype == dns.TypeSOA { + soa = rr + } + // Match function in Go DNS? + if strings.ToLower(rr.Header().Name) == qname { + if rr.Header().Rrtype == qtype { + rep = append(rep, rr) + nodata = false + } + + } + } + if nodata { + return []dns.RR{soa}, true + } + return rep, false +} diff --git a/middleware/file/file_test.go b/middleware/file/file_test.go new file mode 100644 index 000000000..54584b5cc --- /dev/null +++ b/middleware/file/file_test.go @@ -0,0 +1,325 @@ +package file + +import ( + "errors" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +var testDir = filepath.Join(os.TempDir(), "caddy_testdir") +var ErrCustom = errors.New("Custom Error") + +// testFiles is a map with relative paths to test files as keys and file content as values. +// The map represents the following structure: +// - $TEMP/caddy_testdir/ +// '-- file1.html +// '-- dirwithindex/ +// '---- index.html +// '-- dir/ +// '---- file2.html +// '---- hidden.html +var testFiles = map[string]string{ + "file1.html": "<h1>file1.html</h1>", + filepath.Join("dirwithindex", "index.html"): "<h1>dirwithindex/index.html</h1>", + filepath.Join("dir", "file2.html"): "<h1>dir/file2.html</h1>", + filepath.Join("dir", "hidden.html"): "<h1>dir/hidden.html</h1>", +} + +// TestServeHTTP covers positive scenarios when serving files. +func TestServeHTTP(t *testing.T) { + + beforeServeHTTPTest(t) + defer afterServeHTTPTest(t) + + fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"}) + + movedPermanently := "Moved Permanently" + + tests := []struct { + url string + + expectedStatus int + expectedBodyContent string + }{ + // Test 0 - access without any path + { + url: "https://foo", + expectedStatus: http.StatusNotFound, + }, + // Test 1 - access root (without index.html) + { + url: "https://foo/", + expectedStatus: http.StatusNotFound, + }, + // Test 2 - access existing file + { + url: "https://foo/file1.html", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles["file1.html"], + }, + // Test 3 - access folder with index file with trailing slash + { + url: "https://foo/dirwithindex/", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")], + }, + // Test 4 - access folder with index file without trailing slash + { + url: "https://foo/dirwithindex", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 5 - access folder without index file + { + url: "https://foo/dir/", + expectedStatus: http.StatusNotFound, + }, + // Test 6 - access folder without trailing slash + { + url: "https://foo/dir", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 6 - access file with trailing slash + { + url: "https://foo/file1.html/", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + // Test 7 - access not existing path + { + url: "https://foo/not_existing", + expectedStatus: http.StatusNotFound, + }, + // Test 8 - access a file, marked as hidden + { + url: "https://foo/dir/hidden.html", + expectedStatus: http.StatusNotFound, + }, + // Test 9 - access a index file directly + { + url: "https://foo/dirwithindex/index.html", + expectedStatus: http.StatusOK, + expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")], + }, + // Test 10 - send a request with query params + { + url: "https://foo/dir?param1=val", + expectedStatus: http.StatusMovedPermanently, + expectedBodyContent: movedPermanently, + }, + } + + for i, test := range tests { + responseRecorder := httptest.NewRecorder() + request, err := http.NewRequest("GET", test.url, strings.NewReader("")) + status, err := fileserver.ServeHTTP(responseRecorder, request) + + // check if error matches expectations + if err != nil { + t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err) + } + + // check status code + if test.expectedStatus != status { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check body content + if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) { + t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String()) + } + } + +} + +// beforeServeHTTPTest creates a test directory with the structure, defined in the variable testFiles +func beforeServeHTTPTest(t *testing.T) { + // make the root test dir + err := os.Mkdir(testDir, os.ModePerm) + if err != nil { + if !os.IsExist(err) { + t.Fatalf("Failed to create test dir. Error was: %v", err) + return + } + } + + for relFile, fileContent := range testFiles { + absFile := filepath.Join(testDir, relFile) + + // make sure the parent directories exist + parentDir := filepath.Dir(absFile) + _, err = os.Stat(parentDir) + if err != nil { + os.MkdirAll(parentDir, os.ModePerm) + } + + // now create the test files + f, err := os.Create(absFile) + if err != nil { + t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err) + return + } + + // and fill them with content + _, err = f.WriteString(fileContent) + if err != nil { + t.Fatalf("Failed to write to %s. Error was: %v", absFile, err) + return + } + f.Close() + } + +} + +// afterServeHTTPTest removes the test dir and all its content +func afterServeHTTPTest(t *testing.T) { + // cleans up everything under the test dir. No need to clean the individual files. + err := os.RemoveAll(testDir) + if err != nil { + t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err) + } +} + +// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err +type failingFS struct { + err error // the error to return when Open is called + fileImpl http.File // inject the file implementation +} + +// Open returns the assigned failingFile and error +func (f failingFS) Open(path string) (http.File, error) { + return f.fileImpl, f.err +} + +// failingFile implements http.File but returns a predefined error on every Stat() method call. +type failingFile struct { + http.File + err error +} + +// Stat returns nil FileInfo and the provided error on every call +func (ff failingFile) Stat() (os.FileInfo, error) { + return nil, ff.err +} + +// Close is noop and returns no error +func (ff failingFile) Close() error { + return nil +} + +// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors. +func TestServeHTTPFailingFS(t *testing.T) { + + tests := []struct { + fsErr error + expectedStatus int + expectedErr error + expectedHeaders map[string]string + }{ + { + fsErr: os.ErrNotExist, + expectedStatus: http.StatusNotFound, + expectedErr: nil, + }, + { + fsErr: os.ErrPermission, + expectedStatus: http.StatusForbidden, + expectedErr: os.ErrPermission, + }, + { + fsErr: ErrCustom, + expectedStatus: http.StatusServiceUnavailable, + expectedErr: ErrCustom, + expectedHeaders: map[string]string{"Retry-After": "5"}, + }, + } + + for i, test := range tests { + // initialize a file server with the failing FileSystem + fileserver := FileServer(failingFS{err: test.fsErr}, nil) + + // prepare the request and response + request, err := http.NewRequest("GET", "https://foo/", nil) + if err != nil { + t.Fatalf("Failed to build request. Error was: %v", err) + } + responseRecorder := httptest.NewRecorder() + + status, actualErr := fileserver.ServeHTTP(responseRecorder, request) + + // check the status + if status != test.expectedStatus { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check the error + if actualErr != test.expectedErr { + t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + } + + // check the headers - a special case for server under load + if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 { + for expectedKey, expectedVal := range test.expectedHeaders { + actualVal := responseRecorder.Header().Get(expectedKey) + if expectedVal != actualVal { + t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal) + } + } + } + } +} + +// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails. +func TestServeHTTPFailingStat(t *testing.T) { + + tests := []struct { + statErr error + expectedStatus int + expectedErr error + }{ + { + statErr: os.ErrNotExist, + expectedStatus: http.StatusNotFound, + expectedErr: nil, + }, + { + statErr: os.ErrPermission, + expectedStatus: http.StatusForbidden, + expectedErr: os.ErrPermission, + }, + { + statErr: ErrCustom, + expectedStatus: http.StatusInternalServerError, + expectedErr: ErrCustom, + }, + } + + for i, test := range tests { + // initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will + fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil) + + // prepare the request and response + request, err := http.NewRequest("GET", "https://foo/", nil) + if err != nil { + t.Fatalf("Failed to build request. Error was: %v", err) + } + responseRecorder := httptest.NewRecorder() + + status, actualErr := fileserver.ServeHTTP(responseRecorder, request) + + // check the status + if status != test.expectedStatus { + t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status) + } + + // check the error + if actualErr != test.expectedErr { + t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr) + } + } +} diff --git a/middleware/host.go b/middleware/host.go new file mode 100644 index 000000000..17ecedb5f --- /dev/null +++ b/middleware/host.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net" + "strings" + + "github.com/miekg/dns" +) + +// Host represents a host from the Caddyfile, may contain port. +type Host string + +// Standard host will return the host portion of host, stripping +// of any port. The host will also be fully qualified and lowercased. +func (h Host) StandardHost() string { + // separate host and port + host, _, err := net.SplitHostPort(string(h)) + if err != nil { + host, _, _ = net.SplitHostPort(string(h) + ":") + } + return strings.ToLower(dns.Fqdn(host)) +} diff --git a/middleware/log/log.go b/middleware/log/log.go new file mode 100644 index 000000000..109add9f5 --- /dev/null +++ b/middleware/log/log.go @@ -0,0 +1,66 @@ +// Package log implements basic but useful request (access) logging middleware. +package log + +import ( + "log" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// Logger is a basic request logging middleware. +type Logger struct { + Next middleware.Handler + Rules []Rule + ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler +} + +func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + for _, rule := range l.Rules { + /* + if middleware.Path(r.URL.Path).Matches(rule.PathScope) { + responseRecorder := middleware.NewResponseRecorder(w) + status, err := l.Next.ServeHTTP(responseRecorder, r) + if status >= 400 { + // There was an error up the chain, but no response has been written yet. + // The error must be handled here so the log entry will record the response size. + if l.ErrorFunc != nil { + l.ErrorFunc(responseRecorder, r, status) + } else { + // Default failover error handler + responseRecorder.WriteHeader(status) + fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status)) + } + status = 0 + } + rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue) + rule.Log.Println(rep.Replace(rule.Format)) + return status, err + } + */ + rule = rule + } + return l.Next.ServeDNS(w, r) +} + +// Rule configures the logging middleware. +type Rule struct { + PathScope string + OutputFile string + Format string + Log *log.Logger + Roller *middleware.LogRoller +} + +const ( + // DefaultLogFilename is the default log filename. + DefaultLogFilename = "access.log" + // CommonLogFormat is the common log format. + CommonLogFormat = `{remote} ` + CommonLogEmptyValue + ` [{when}] "{type} {name} {proto}" {rcode} {size}` + // CommonLogEmptyValue is the common empty log value. + CommonLogEmptyValue = "-" + // CombinedLogFormat is the combined log format. + CombinedLogFormat = CommonLogFormat + ` "{>Referer}" "{>User-Agent}"` // Something here as well + // DefaultLogFormat is the default log format. + DefaultLogFormat = CommonLogFormat +) diff --git a/middleware/log/log_test.go b/middleware/log/log_test.go new file mode 100644 index 000000000..40560e4c0 --- /dev/null +++ b/middleware/log/log_test.go @@ -0,0 +1,48 @@ +package log + +import ( + "bytes" + "log" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +type erroringMiddleware struct{} + +func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { + return http.StatusNotFound, nil +} + +func TestLoggedStatus(t *testing.T) { + var f bytes.Buffer + var next erroringMiddleware + rule := Rule{ + PathScope: "/", + Format: DefaultLogFormat, + Log: log.New(&f, "", 0), + } + + logger := Logger{ + Rules: []Rule{rule}, + Next: next, + } + + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + rec := httptest.NewRecorder() + + status, err := logger.ServeHTTP(rec, r) + if status != 0 { + t.Error("Expected status to be 0 - was", status) + } + + logged := f.String() + if !strings.Contains(logged, "404 13") { + t.Error("Expected 404 to be logged. Logged string -", logged) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 000000000..436ec86e9 --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,105 @@ +// Package middleware provides some types and functions common among middleware. +package middleware + +import ( + "time" + + "github.com/miekg/dns" +) + +type ( + // Middleware is the middle layer which represents the traditional + // idea of middleware: it chains one Handler to the next by being + // passed the next Handler in the chain. + Middleware func(Handler) Handler + + // Handler is like dns.Handler except ServeDNS may return an rcode + // and/or error. + // + // If ServeDNS writes to the response body, it should return a status + // code of 0. This signals to other handlers above it that the response + // body is already written, and that they should not write to it also. + // + // If ServeDNS encounters an error, it should return the error value + // so it can be logged by designated error-handling middleware. + // + // If writing a response after calling another ServeDNS method, the + // returned rcode SHOULD be used when writing the response. + // + // If handling errors after calling another ServeDNS method, the + // returned error value SHOULD be logged or handled accordingly. + // + // Otherwise, return values should be propagated down the middleware + // chain by returning them unchanged. + Handler interface { + ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error) + } + + // HandlerFunc is a convenience type like dns.HandlerFunc, except + // ServeDNS returns an rcode and an error. See Handler + // documentation for more information. + HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error) +) + +// ServeDNS implements the Handler interface. +func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + return f(w, r) +} + +// IndexFile looks for a file in /root/fpath/indexFile for each string +// in indexFiles. If an index file is found, it returns the root-relative +// path to the file and true. If no index file is found, empty string +// and false is returned. fpath must end in a forward slash '/' +// otherwise no index files will be tried (directory paths must end +// in a forward slash according to HTTP). +// +// All paths passed into and returned from this function use '/' as the +// path separator, just like URLs. IndexFle handles path manipulation +// internally for systems that use different path separators. +/* +func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, bool) { + if fpath[len(fpath)-1] != '/' || root == nil { + return "", false + } + for _, indexFile := range indexFiles { + // func (http.FileSystem).Open wants all paths separated by "/", + // regardless of operating system convention, so use + // path.Join instead of filepath.Join + fp := path.Join(fpath, indexFile) + f, err := root.Open(fp) + if err == nil { + f.Close() + return fp, true + } + } + return "", false +} + +// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it +// as a Last-Modified header to the ResponseWriter. If the modTime is in the future +// the current time is used instead. +func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) { + if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) { + // the time does not appear to be valid. Don't put it in the response + return + } + + // RFC 2616 - Section 14.29 - Last-Modified: + // An origin server MUST NOT send a Last-Modified date which is later than the + // server's time of message origination. In such cases, where the resource's last + // modification would indicate some time in the future, the server MUST replace + // that date with the message origination date. + now := currentTime() + if modTime.After(now) { + modTime = now + } + + w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat)) +} +*/ + +// currentTime, as it is defined here, returns time.Now(). +// It's defined as a variable for mocking time in tests. +var currentTime = func() time.Time { + return time.Now() +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 000000000..62fa4e250 --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,108 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestIndexfile(t *testing.T) { + tests := []struct { + rootDir http.FileSystem + fpath string + indexFiles []string + shouldErr bool + expectedFilePath string //retun value + expectedBoolValue bool //return value + }{ + { + http.Dir("./templates/testdata"), + "/images/", + []string{"img.htm"}, + false, + "/images/img.htm", + true, + }, + } + for i, test := range tests { + actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles) + if actualBoolValue == true && test.shouldErr { + t.Errorf("Test %d didn't error, but it should have", i) + } else if actualBoolValue != true && !test.shouldErr { + t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist") + } + if actualFilePath != test.expectedFilePath { + t.Fatalf("Test %d expected returned filepath to be %s, but got %s ", + i, test.expectedFilePath, actualFilePath) + + } + if actualBoolValue != test.expectedBoolValue { + t.Fatalf("Test %d expected returned bool value to be %v, but got %v ", + i, test.expectedBoolValue, actualBoolValue) + + } + } +} + +func TestSetLastModified(t *testing.T) { + nowTime := time.Now() + + // ovewrite the function to return reliable time + originalGetCurrentTimeFunc := currentTime + currentTime = func() time.Time { + return nowTime + } + defer func() { + currentTime = originalGetCurrentTimeFunc + }() + + pastTime := nowTime.Truncate(1 * time.Hour) + futureTime := nowTime.Add(1 * time.Hour) + + tests := []struct { + inputModTime time.Time + expectedIsHeaderSet bool + expectedLastModified string + }{ + { + inputModTime: pastTime, + expectedIsHeaderSet: true, + expectedLastModified: pastTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: nowTime, + expectedIsHeaderSet: true, + expectedLastModified: nowTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: futureTime, + expectedIsHeaderSet: true, + expectedLastModified: nowTime.UTC().Format(http.TimeFormat), + }, + { + inputModTime: time.Time{}, + expectedIsHeaderSet: false, + }, + } + + for i, test := range tests { + responseRecorder := httptest.NewRecorder() + errorPrefix := fmt.Sprintf("Test [%d]: ", i) + SetLastModifiedHeader(responseRecorder, test.inputModTime) + actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified") + + if test.expectedIsHeaderSet && actualLastModifiedHeader == "" { + t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing") + } + + if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" { + t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader) + } + + if test.expectedLastModified != actualLastModifiedHeader { + t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader) + } + } +} diff --git a/middleware/path.go b/middleware/path.go new file mode 100644 index 000000000..1ffb64b76 --- /dev/null +++ b/middleware/path.go @@ -0,0 +1,18 @@ +package middleware + +import "strings" + + +// TODO(miek): matches for names. + +// Path represents a URI path, maybe with pattern characters. +type Path string + +// Matches checks to see if other matches p. +// +// Path matching will probably not always be a direct +// comparison; this method assures that paths can be +// easily and consistently matched. +func (p Path) Matches(other string) bool { + return strings.HasPrefix(string(p), other) +} diff --git a/middleware/prometheus/handler.go b/middleware/prometheus/handler.go new file mode 100644 index 000000000..eb82b8aff --- /dev/null +++ b/middleware/prometheus/handler.go @@ -0,0 +1,31 @@ +package metrics + +import ( + "strconv" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + context := middleware.Context{W: w, Req: r} + + qname := context.Name() + qtype := context.Type() + zone := middleware.Zones(m.ZoneNames).Matches(qname) + if zone == "" { + zone = "." + } + + // Record response to get status code and size of the reply. + rw := middleware.NewResponseRecorder(w) + status, err := m.Next.ServeDNS(rw, r) + + requestCount.WithLabelValues(zone, qtype).Inc() + requestDuration.WithLabelValues(zone).Observe(float64(time.Since(rw.Start()) / time.Second)) + responseSize.WithLabelValues(zone).Observe(float64(rw.Size())) + responseRcode.WithLabelValues(zone, strconv.Itoa(rw.Rcode())).Inc() + + return status, err +} diff --git a/middleware/prometheus/metrics.go b/middleware/prometheus/metrics.go new file mode 100644 index 000000000..4c989f640 --- /dev/null +++ b/middleware/prometheus/metrics.go @@ -0,0 +1,80 @@ +package metrics + +import ( + "fmt" + "net/http" + "sync" + + "github.com/miekg/coredns/middleware" + "github.com/prometheus/client_golang/prometheus" +) + +const namespace = "daddy" + +var ( + requestCount *prometheus.CounterVec + requestDuration *prometheus.HistogramVec + responseSize *prometheus.HistogramVec + responseRcode *prometheus.CounterVec +) + +const path = "/metrics" + +// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics +type Metrics struct { + Next middleware.Handler + Addr string // where to we listen + Once sync.Once + ZoneNames []string +} + +func (m *Metrics) Start() error { + m.Once.Do(func() { + define("") + + prometheus.MustRegister(requestCount) + prometheus.MustRegister(requestDuration) + prometheus.MustRegister(responseSize) + prometheus.MustRegister(responseRcode) + + http.Handle(path, prometheus.Handler()) + go func() { + fmt.Errorf("%s", http.ListenAndServe(m.Addr, nil)) + }() + }) + return nil +} + +func define(subsystem string) { + if subsystem == "" { + subsystem = "dns" + } + requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "request_count_total", + Help: "Counter of DNS requests made per zone and type.", + }, []string{"zone", "qtype"}) + + requestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "request_duration_seconds", + Help: "Histogram of the time (in seconds) each request took.", + }, []string{"zone"}) + + responseSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "response_size_bytes", + Help: "Size of the returns response in bytes.", + Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3}, + }, []string{"zone"}) + + responseRcode = prometheus.NewCounterVec(prometheus.CounterOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "rcode_code_count_total", + Help: "Counter of response status codes.", + }, []string{"zone", "rcode"}) +} diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go new file mode 100644 index 000000000..a2522bcb1 --- /dev/null +++ b/middleware/proxy/policy.go @@ -0,0 +1,101 @@ +package proxy + +import ( + "math/rand" + "sync/atomic" +) + +// HostPool is a collection of UpstreamHosts. +type HostPool []*UpstreamHost + +// Policy decides how a host will be selected from a pool. +type Policy interface { + Select(pool HostPool) *UpstreamHost +} + +func init() { + RegisterPolicy("random", func() Policy { return &Random{} }) + RegisterPolicy("least_conn", func() Policy { return &LeastConn{} }) + RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} }) +} + +// Random is a policy that selects up hosts from a pool at random. +type Random struct{} + +// Select selects an up host at random from the specified pool. +func (r *Random) Select(pool HostPool) *UpstreamHost { + // instead of just generating a random index + // this is done to prevent selecting a down host + var randHost *UpstreamHost + count := 0 + for _, host := range pool { + if host.Down() { + continue + } + count++ + if count == 1 { + randHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + randHost = host + } + } + } + return randHost +} + +// LeastConn is a policy that selects the host with the least connections. +type LeastConn struct{} + +// Select selects the up host with the least number of connections in the +// pool. If more than one host has the same least number of connections, +// one of the hosts is chosen at random. +func (r *LeastConn) Select(pool HostPool) *UpstreamHost { + var bestHost *UpstreamHost + count := 0 + leastConn := int64(1<<63 - 1) + for _, host := range pool { + if host.Down() { + continue + } + hostConns := host.Conns + if hostConns < leastConn { + bestHost = host + leastConn = hostConns + count = 1 + } else if hostConns == leastConn { + // randomly select host among hosts with least connections + count++ + if count == 1 { + bestHost = host + } else { + r := rand.Int() % count + if r == (count - 1) { + bestHost = host + } + } + } + } + return bestHost +} + +// RoundRobin is a policy that selects hosts based on round robin ordering. +type RoundRobin struct { + Robin uint32 +} + +// Select selects an up host from the pool using a round robin ordering scheme. +func (r *RoundRobin) Select(pool HostPool) *UpstreamHost { + poolLen := uint32(len(pool)) + selection := atomic.AddUint32(&r.Robin, 1) % poolLen + host := pool[selection] + // if the currently selected host is down, just ffwd to up host + for i := uint32(1); host.Down() && i < poolLen; i++ { + host = pool[(selection+i)%poolLen] + } + if host.Down() { + return nil + } + return host +} diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go new file mode 100644 index 000000000..8f4f1f792 --- /dev/null +++ b/middleware/proxy/policy_test.go @@ -0,0 +1,87 @@ +package proxy + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +var workableServer *httptest.Server + +func TestMain(m *testing.M) { + workableServer = httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // do nothing + })) + r := m.Run() + workableServer.Close() + os.Exit(r) +} + +type customPolicy struct{} + +func (r *customPolicy) Select(pool HostPool) *UpstreamHost { + return pool[0] +} + +func testPool() HostPool { + pool := []*UpstreamHost{ + { + Name: workableServer.URL, // this should resolve (healthcheck test) + }, + { + Name: "http://shouldnot.resolve", // this shouldn't + }, + { + Name: "http://C", + }, + } + return HostPool(pool) +} + +func TestRoundRobinPolicy(t *testing.T) { + pool := testPool() + rrPolicy := &RoundRobin{} + h := rrPolicy.Select(pool) + // First selected host is 1, because counter starts at 0 + // and increments before host is selected + if h != pool[1] { + t.Error("Expected first round robin host to be second host in the pool.") + } + h = rrPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected second round robin host to be third host in the pool.") + } + // mark host as down + pool[0].Unhealthy = true + h = rrPolicy.Select(pool) + if h != pool[1] { + t.Error("Expected third round robin host to be first host in the pool.") + } +} + +func TestLeastConnPolicy(t *testing.T) { + pool := testPool() + lcPolicy := &LeastConn{} + pool[0].Conns = 10 + pool[1].Conns = 10 + h := lcPolicy.Select(pool) + if h != pool[2] { + t.Error("Expected least connection host to be third host.") + } + pool[2].Conns = 100 + h = lcPolicy.Select(pool) + if h != pool[0] && h != pool[1] { + t.Error("Expected least connection host to be first or second host.") + } +} + +func TestCustomPolicy(t *testing.T) { + pool := testPool() + customPolicy := &customPolicy{} + h := customPolicy.Select(pool) + if h != pool[0] { + t.Error("Expected custom policy host to be the first host.") + } +} diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go new file mode 100644 index 000000000..169e41b61 --- /dev/null +++ b/middleware/proxy/proxy.go @@ -0,0 +1,120 @@ +// Package proxy is middleware that proxies requests. +package proxy + +import ( + "errors" + "net/http" + "sync/atomic" + "time" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +var errUnreachable = errors.New("unreachable backend") + +// Proxy represents a middleware instance that can proxy requests. +type Proxy struct { + Next middleware.Handler + Client Client + Upstreams []Upstream +} + +type Client struct { + UDP *dns.Client + TCP *dns.Client +} + +// Upstream manages a pool of proxy upstream hosts. Select should return a +// suitable upstream host, or nil if no such hosts are available. +type Upstream interface { + // The domain name this upstream host should be routed on. + From() string + // Selects an upstream host to be routed to. + Select() *UpstreamHost + // Checks if subpdomain is not an ignored. + IsAllowedPath(string) bool +} + +// UpstreamHostDownFunc can be used to customize how Down behaves. +type UpstreamHostDownFunc func(*UpstreamHost) bool + +// UpstreamHost represents a single proxy upstream +type UpstreamHost struct { + Conns int64 // must be first field to be 64-bit aligned on 32-bit systems + Name string // IP address (and port) of this upstream host + Fails int32 + FailTimeout time.Duration + Unhealthy bool + ExtraHeaders http.Header + CheckDown UpstreamHostDownFunc + WithoutPathPrefix string +} + +// Down checks whether the upstream host is down or not. +// Down will try to use uh.CheckDown first, and will fall +// back to some default criteria if necessary. +func (uh *UpstreamHost) Down() bool { + if uh.CheckDown == nil { + // Default settings + return uh.Unhealthy || uh.Fails > 0 + } + return uh.CheckDown(uh) +} + +// tryDuration is how long to try upstream hosts; failures result in +// immediate retries until this duration ends or we get a nil host. +var tryDuration = 60 * time.Second + +// ServeDNS satisfies the middleware.Handler interface. +func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + for _, upstream := range p.Upstreams { + // allowed bla bla bla TODO(miek): fix full proxy spec from caddy + start := time.Now() + + // Since Select() should give us "up" hosts, keep retrying + // hosts until timeout (or until we get a nil host). + for time.Now().Sub(start) < tryDuration { + host := upstream.Select() + if host == nil { + return dns.RcodeServerFailure, errUnreachable + } + // TODO(miek): PORT! + reverseproxy := ReverseProxy{Host: host.Name, Client: p.Client} + + atomic.AddInt64(&host.Conns, 1) + backendErr := reverseproxy.ServeDNS(w, r, nil) + atomic.AddInt64(&host.Conns, -1) + if backendErr == nil { + return 0, nil + } + timeout := host.FailTimeout + if timeout == 0 { + timeout = 10 * time.Second + } + atomic.AddInt32(&host.Fails, 1) + go func(host *UpstreamHost, timeout time.Duration) { + time.Sleep(timeout) + atomic.AddInt32(&host.Fails, -1) + }(host, timeout) + } + return dns.RcodeServerFailure, errUnreachable + } + return p.Next.ServeDNS(w, r) +} + +func Clients() Client { + udp := newClient("udp", defaultTimeout) + tcp := newClient("tcp", defaultTimeout) + return Client{UDP: udp, TCP: tcp} +} + +// newClient returns a new client for proxy requests. +func newClient(net string, timeout time.Duration) *dns.Client { + if timeout == 0 { + timeout = defaultTimeout + } + return &dns.Client{Net: net, ReadTimeout: timeout, WriteTimeout: timeout, SingleInflight: true} +} + +const defaultTimeout = 5 * time.Second diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go new file mode 100644 index 000000000..8066874d2 --- /dev/null +++ b/middleware/proxy/proxy_test.go @@ -0,0 +1,317 @@ +package proxy + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +func init() { + tryDuration = 50 * time.Millisecond // prevent tests from hanging +} + +func TestReverseProxy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, false)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Expected backend to receive request, but it didn't") + } +} + +func TestReverseProxyInsecureSkipVerify(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + var requestReceived bool + backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived = true + w.Write([]byte("Hello, client")) + })) + defer backend.Close() + + // set up proxy + p := &Proxy{ + Upstreams: []Upstream{newFakeUpstream(backend.URL, true)}, + } + + // create request and response recorder + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + w := httptest.NewRecorder() + + p.ServeHTTP(w, r) + + if !requestReceived { + t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't") + } +} + +func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { + // No-op websocket backend simply allows the WS connection to be + // accepted then it will be immediately closed. Perfect for testing. + wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {})) + defer wsNop.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsNop.URL) + + // Create client request + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + r.Header = http.Header{ + "Connection": {"Upgrade"}, + "Upgrade": {"websocket"}, + "Origin": {wsNop.URL}, + "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="}, + "Sec-WebSocket-Version": {"13"}, + } + + // Capture the request + w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)} + + // Booya! Do the test. + p.ServeHTTP(w, r) + + // Make sure the backend accepted the WS connection. + // Mostly interested in the Upgrade and Connection response headers + // and the 101 status code. + expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n") + actual := w.fakeConn.writeBuf.Bytes() + if !bytes.Equal(actual, expected) { + t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual) + } +} + +func TestWebSocketReverseProxyFromWSClient(t *testing.T) { + // Echo server allows us to test that socket bytes are properly + // being proxied. + wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsEcho.URL) + + // This is a full end-end test, so the proxy handler + // has to be part of a server listening on a port. Our + // WS client will connect to this test server, not + // the echo client directly. + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "http://", "ws://", 1) + ws, err := websocket.Dial(url, "", echoProxy.URL) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + websocket.Message.Send(ws, trialMsg) + + // It should be echoed back to us + var actualMsg string + websocket.Message.Receive(ws, &actualMsg) + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func TestUnixSocketProxy(t *testing.T) { + if runtime.GOOS == "windows" { + return + } + + trialMsg := "Is it working?" + + var proxySuccess bool + + // This is our fake "application" we want to proxy to + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Request was proxied when this is called + proxySuccess = true + + fmt.Fprint(w, trialMsg) + })) + + // Get absolute path for unix: socket + socketPath, err := filepath.Abs("./test_socket") + if err != nil { + t.Fatalf("Unable to get absolute path: %v", err) + } + + // Change httptest.Server listener to listen to unix: socket + ln, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("Unable to listen: %v", err) + } + ts.Listener = ln + + ts.Start() + defer ts.Close() + + url := strings.Replace(ts.URL, "http://", "unix:", 1) + p := newWebSocketTestProxy(url) + + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + res, err := http.Get(echoProxy.URL) + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatalf("Unable to GET: %v", err) + } + + actualMsg := fmt.Sprintf("%s", greeting) + + if !proxySuccess { + t.Errorf("Expected request to be proxied, but it wasn't") + } + + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +func newFakeUpstream(name string, insecure bool) *fakeUpstream { + uri, _ := url.Parse(name) + u := &fakeUpstream{ + name: name, + host: &UpstreamHost{ + Name: name, + ReverseProxy: NewSingleHostReverseProxy(uri, ""), + }, + } + if insecure { + u.host.ReverseProxy.Transport = InsecureTransport + } + return u +} + +type fakeUpstream struct { + name string + host *UpstreamHost +} + +func (u *fakeUpstream) From() string { + return "/" +} + +func (u *fakeUpstream) Select() *UpstreamHost { + return u.host +} + +func (u *fakeUpstream) IsAllowedPath(requestPath string) bool { + return true +} + +// newWebSocketTestProxy returns a test proxy that will +// redirect to the specified backendAddr. The function +// also sets up the rules/environment for testing WebSocket +// proxy. +func newWebSocketTestProxy(backendAddr string) *Proxy { + return &Proxy{ + Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}}, + } +} + +type fakeWsUpstream struct { + name string +} + +func (u *fakeWsUpstream) From() string { + return "/" +} + +func (u *fakeWsUpstream) Select() *UpstreamHost { + uri, _ := url.Parse(u.name) + return &UpstreamHost{ + Name: u.name, + ReverseProxy: NewSingleHostReverseProxy(uri, ""), + ExtraHeaders: http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}}, + } +} + +func (u *fakeWsUpstream) IsAllowedPath(requestPath string) bool { + return true +} + +// recorderHijacker is a ResponseRecorder that can +// be hijacked. +type recorderHijacker struct { + *httptest.ResponseRecorder + fakeConn *fakeConn +} + +func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rh.fakeConn, nil, nil +} + +type fakeConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +func (c *fakeConn) LocalAddr() net.Addr { return nil } +func (c *fakeConn) RemoteAddr() net.Addr { return nil } +func (c *fakeConn) SetDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } +func (c *fakeConn) Close() error { return nil } +func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } +func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go new file mode 100644 index 000000000..6d27da042 --- /dev/null +++ b/middleware/proxy/reverseproxy.go @@ -0,0 +1,36 @@ +// Package proxy is middleware that proxies requests. +package proxy + +import ( + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +type ReverseProxy struct { + Host string + Client Client +} + +func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error { + // TODO(miek): use extra! + var ( + reply *dns.Msg + err error + ) + context := middleware.Context{W: w, Req: r} + + // tls+tcp ? + if context.Proto() == "tcp" { + reply, err = middleware.Exchange(p.Client.TCP, r, p.Host) + } else { + reply, err = middleware.Exchange(p.Client.UDP, r, p.Host) + } + + if err != nil { + return err + } + reply.Compress = true + reply.Id = r.Id + w.WriteMsg(reply) + return nil +} diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go new file mode 100644 index 000000000..092e2351d --- /dev/null +++ b/middleware/proxy/upstream.go @@ -0,0 +1,235 @@ +package proxy + +import ( + "io" + "io/ioutil" + "net/http" + "path" + "strconv" + "time" + + "github.com/miekg/coredns/core/parse" + "github.com/miekg/coredns/middleware" +) + +var ( + supportedPolicies = make(map[string]func() Policy) +) + +type staticUpstream struct { + from string + // TODO(miek): allows use to added headers + proxyHeaders http.Header // TODO(miek): kill + Hosts HostPool + Policy Policy + + FailTimeout time.Duration + MaxFails int32 + HealthCheck struct { + Path string + Interval time.Duration + } + WithoutPathPrefix string + IgnoredSubPaths []string +} + +// NewStaticUpstreams parses the configuration input and sets up +// static upstreams for the proxy middleware. +func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { + var upstreams []Upstream + for c.Next() { + upstream := &staticUpstream{ + from: "", + proxyHeaders: make(http.Header), + Hosts: nil, + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + + if !c.Args(&upstream.from) { + return upstreams, c.ArgErr() + } + to := c.RemainingArgs() + if len(to) == 0 { + return upstreams, c.ArgErr() + } + + for c.NextBlock() { + if err := parseBlock(&c, upstream); err != nil { + return upstreams, err + } + } + + upstream.Hosts = make([]*UpstreamHost, len(to)) + for i, host := range to { + uh := &UpstreamHost{ + Name: host, + Conns: 0, + Fails: 0, + FailTimeout: upstream.FailTimeout, + Unhealthy: false, + ExtraHeaders: upstream.proxyHeaders, + CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc { + return func(uh *UpstreamHost) bool { + if uh.Unhealthy { + return true + } + if uh.Fails >= upstream.MaxFails && + upstream.MaxFails != 0 { + return true + } + return false + } + }(upstream), + WithoutPathPrefix: upstream.WithoutPathPrefix, + } + upstream.Hosts[i] = uh + } + + if upstream.HealthCheck.Path != "" { + go upstream.HealthCheckWorker(nil) + } + upstreams = append(upstreams, upstream) + } + return upstreams, nil +} + +// RegisterPolicy adds a custom policy to the proxy. +func RegisterPolicy(name string, policy func() Policy) { + supportedPolicies[name] = policy +} + +func (u *staticUpstream) From() string { + return u.from +} + +func parseBlock(c *parse.Dispenser, u *staticUpstream) error { + switch c.Val() { + case "policy": + if !c.NextArg() { + return c.ArgErr() + } + policyCreateFunc, ok := supportedPolicies[c.Val()] + if !ok { + return c.ArgErr() + } + u.Policy = policyCreateFunc() + case "fail_timeout": + if !c.NextArg() { + return c.ArgErr() + } + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.FailTimeout = dur + case "max_fails": + if !c.NextArg() { + return c.ArgErr() + } + n, err := strconv.Atoi(c.Val()) + if err != nil { + return err + } + u.MaxFails = int32(n) + case "health_check": + if !c.NextArg() { + return c.ArgErr() + } + u.HealthCheck.Path = c.Val() + u.HealthCheck.Interval = 30 * time.Second + if c.NextArg() { + dur, err := time.ParseDuration(c.Val()) + if err != nil { + return err + } + u.HealthCheck.Interval = dur + } + case "proxy_header": + var header, value string + if !c.Args(&header, &value) { + return c.ArgErr() + } + u.proxyHeaders.Add(header, value) + case "websocket": + u.proxyHeaders.Add("Connection", "{>Connection}") + u.proxyHeaders.Add("Upgrade", "{>Upgrade}") + case "without": + if !c.NextArg() { + return c.ArgErr() + } + u.WithoutPathPrefix = c.Val() + case "except": + ignoredPaths := c.RemainingArgs() + if len(ignoredPaths) == 0 { + return c.ArgErr() + } + u.IgnoredSubPaths = ignoredPaths + default: + return c.Errf("unknown property '%s'", c.Val()) + } + return nil +} + +func (u *staticUpstream) healthCheck() { + for _, host := range u.Hosts { + hostURL := host.Name + u.HealthCheck.Path + if r, err := http.Get(hostURL); err == nil { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400 + } else { + host.Unhealthy = true + } + } +} + +func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { + ticker := time.NewTicker(u.HealthCheck.Interval) + u.healthCheck() + for { + select { + case <-ticker.C: + u.healthCheck() + case <-stop: + // TODO: the library should provide a stop channel and global + // waitgroup to allow goroutines started by plugins a chance + // to clean themselves up. + } + } +} + +func (u *staticUpstream) Select() *UpstreamHost { + pool := u.Hosts + if len(pool) == 1 { + if pool[0].Down() { + return nil + } + return pool[0] + } + allDown := true + for _, host := range pool { + if !host.Down() { + allDown = false + break + } + } + if allDown { + return nil + } + + if u.Policy == nil { + return (&Random{}).Select(pool) + } + return u.Policy.Select(pool) +} + +func (u *staticUpstream) IsAllowedPath(requestPath string) bool { + for _, ignoredSubPath := range u.IgnoredSubPaths { + if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) { + return false + } + } + return true +} diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go new file mode 100644 index 000000000..5b2fdb1da --- /dev/null +++ b/middleware/proxy/upstream_test.go @@ -0,0 +1,83 @@ +package proxy + +import ( + "testing" + "time" +) + +func TestHealthCheck(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool(), + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.healthCheck() + if upstream.Hosts[0].Down() { + t.Error("Expected first host in testpool to not fail healthcheck.") + } + if !upstream.Hosts[1].Down() { + t.Error("Expected second host in testpool to fail healthcheck.") + } +} + +func TestSelect(t *testing.T) { + upstream := &staticUpstream{ + from: "", + Hosts: testPool()[:3], + Policy: &Random{}, + FailTimeout: 10 * time.Second, + MaxFails: 1, + } + upstream.Hosts[0].Unhealthy = true + upstream.Hosts[1].Unhealthy = true + upstream.Hosts[2].Unhealthy = true + if h := upstream.Select(); h != nil { + t.Error("Expected select to return nil as all host are down") + } + upstream.Hosts[2].Unhealthy = false + if h := upstream.Select(); h == nil { + t.Error("Expected select to not return nil") + } +} + +func TestRegisterPolicy(t *testing.T) { + name := "custom" + customPolicy := &customPolicy{} + RegisterPolicy(name, func() Policy { return customPolicy }) + if _, ok := supportedPolicies[name]; !ok { + t.Error("Expected supportedPolicies to have a custom policy.") + } + +} + +func TestAllowedPaths(t *testing.T) { + upstream := &staticUpstream{ + from: "/proxy", + IgnoredSubPaths: []string{"/download", "/static"}, + } + tests := []struct { + url string + expected bool + }{ + {"/proxy", true}, + {"/proxy/dl", true}, + {"/proxy/download", false}, + {"/proxy/download/static", false}, + {"/proxy/static", false}, + {"/proxy/static/download", false}, + {"/proxy/something/download", true}, + {"/proxy/something/static", true}, + {"/proxy//static", false}, + {"/proxy//static//download", false}, + {"/proxy//download", false}, + } + + for i, test := range tests { + isAllowed := upstream.IsAllowedPath(test.url) + if test.expected != isAllowed { + t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed) + } + } +} diff --git a/middleware/recorder.go b/middleware/recorder.go new file mode 100644 index 000000000..38a7e0e82 --- /dev/null +++ b/middleware/recorder.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "time" + + "github.com/miekg/dns" +) + +// ResponseRecorder is a type of ResponseWriter that captures +// the rcode code written to it and also the size of the message +// written in the response. A rcode code does not have +// to be written, however, in which case 0 must be assumed. +// It is best to have the constructor initialize this type +// with that default status code. +type ResponseRecorder struct { + dns.ResponseWriter + rcode int + size int + start time.Time +} + +// NewResponseRecorder makes and returns a new responseRecorder, +// which captures the DNS rcode from the ResponseWriter +// and also the length of the response message written through it. +func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder { + return &ResponseRecorder{ + ResponseWriter: w, + rcode: 0, + start: time.Now(), + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error { + r.rcode = res.Rcode + r.size = res.Len() + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *ResponseRecorder) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + if err == nil { + r.size += n + } + return n, err +} + +// Size returns the size. +func (r *ResponseRecorder) Size() int { + return r.size +} + +// Rcode returns the rcode. +func (r *ResponseRecorder) Rcode() int { + return r.rcode +} + +// Start returns the start time of the ResponseRecorder. +func (r *ResponseRecorder) Start() time.Time { + return r.start +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *ResponseRecorder) Hijack() { + r.ResponseWriter.Hijack() + return +} diff --git a/middleware/recorder_test.go b/middleware/recorder_test.go new file mode 100644 index 000000000..a8c8a5d04 --- /dev/null +++ b/middleware/recorder_test.go @@ -0,0 +1,32 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewResponseRecorder(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + if !(recordRequest.ResponseWriter == w) { + t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n") + } + if recordRequest.status != http.StatusOK { + t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status) + } +} + +func TestWrite(t *testing.T) { + w := httptest.NewRecorder() + responseTestString := "test" + recordRequest := NewResponseRecorder(w) + buf := []byte(responseTestString) + recordRequest.Write(buf) + if recordRequest.size != len(buf) { + t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size) + } + if w.Body.String() != responseTestString { + t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String()) + } +} diff --git a/middleware/reflect/reflect.go b/middleware/reflect/reflect.go new file mode 100644 index 000000000..6d5847b81 --- /dev/null +++ b/middleware/reflect/reflect.go @@ -0,0 +1,84 @@ +// Reflect provides middleware that reflects back some client properties. +// This is the default middleware when Caddy is run without configuration. +// +// The left-most label must be `who`. +// When queried for type A (resp. AAAA), it sends back the IPv4 (resp. v6) address. +// In the additional section the port number and transport are shown. +// Basic use pattern: +// +// dig @localhost -p 1053 who.miek.nl A +// +// ;; ANSWER SECTION: +// who.miek.nl. 0 IN A 127.0.0.1 +// +// ;; ADDITIONAL SECTION: +// who.miek.nl. 0 IN TXT "Port: 56195 (udp)" +package reflect + +import ( + "errors" + "net" + "strings" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +type Reflect struct { + Next middleware.Handler +} + +func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + context := middleware.Context{Req: r, W: w} + + class := r.Question[0].Qclass + qname := r.Question[0].Name + i, ok := dns.NextLabel(qname, 0) + + if strings.ToLower(qname[:i]) != who || ok { + err := context.ErrorMessage(dns.RcodeFormatError) + w.WriteMsg(err) + return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError]) + } + + answer := new(dns.Msg) + answer.SetReply(r) + answer.Compress = true + answer.Authoritative = true + + ip := context.IP() + proto := context.Proto() + port, _ := context.Port() + family := context.Family() + var rr dns.RR + + switch family { + case 1: + rr = new(dns.A) + rr.(*dns.A).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeA, Class: class, Ttl: 0} + rr.(*dns.A).A = net.ParseIP(ip).To4() + case 2: + rr = new(dns.AAAA) + rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeAAAA, Class: class, Ttl: 0} + rr.(*dns.AAAA).AAAA = net.ParseIP(ip) + } + + t := new(dns.TXT) + t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0} + t.Txt = []string{"Port: " + port + " (" + proto + ")"} + + switch context.Type() { + case "TXT": + answer.Answer = append(answer.Answer, t) + answer.Extra = append(answer.Extra, rr) + default: + fallthrough + case "AAAA", "A": + answer.Answer = append(answer.Answer, rr) + answer.Extra = append(answer.Extra, t) + } + w.WriteMsg(answer) + return 0, nil +} + +const who = "who." diff --git a/middleware/reflect/reflect_test.go b/middleware/reflect/reflect_test.go new file mode 100644 index 000000000..477a3a573 --- /dev/null +++ b/middleware/reflect/reflect_test.go @@ -0,0 +1 @@ +package reflect diff --git a/middleware/replacer.go b/middleware/replacer.go new file mode 100644 index 000000000..133da74c5 --- /dev/null +++ b/middleware/replacer.go @@ -0,0 +1,98 @@ +package middleware + +import ( + "strconv" + "strings" + "time" + + "github.com/miekg/dns" +) + +// Replacer is a type which can replace placeholder +// substrings in a string with actual values from a +// http.Request and responseRecorder. Always use +// NewReplacer to get one of these. +type Replacer interface { + Replace(string) string + Set(key, value string) +} + +type replacer struct { + replacements map[string]string + emptyValue string +} + +// NewReplacer makes a new replacer based on r and rr. +// Do not create a new replacer until r and rr have all +// the needed values, because this function copies those +// values into the replacer. rr may be nil if it is not +// available. emptyValue should be the string that is used +// in place of empty string (can still be empty string). +func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer { + context := Context{W: rr, Req: r} + rep := replacer{ + replacements: map[string]string{ + "{type}": context.Type(), + "{name}": context.Name(), + "{class}": context.Class(), + "{proto}": context.Proto(), + "{when}": func() string { + return time.Now().Format(timeFormat) + }(), + "{remote}": context.IP(), + "{port}": func() string { + p, _ := context.Port() + return p + }(), + }, + emptyValue: emptyValue, + } + if rr != nil { + rep.replacements["{rcode}"] = strconv.Itoa(rr.rcode) + rep.replacements["{size}"] = strconv.Itoa(rr.size) + rep.replacements["{latency}"] = time.Since(rr.start).String() + } + + return rep +} + +// Replace performs a replacement of values on s and returns +// the string with the replaced values. +func (r replacer) Replace(s string) string { + // Header replacements - these are case-insensitive, so we can't just use strings.Replace() + for strings.Contains(s, headerReplacer) { + idxStart := strings.Index(s, headerReplacer) + endOffset := idxStart + len(headerReplacer) + idxEnd := strings.Index(s[endOffset:], "}") + if idxEnd > -1 { + placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1]) + replacement := r.replacements[placeholder] + if replacement == "" { + replacement = r.emptyValue + } + s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:] + } else { + break + } + } + + // Regular replacements - these are easier because they're case-sensitive + for placeholder, replacement := range r.replacements { + if replacement == "" { + replacement = r.emptyValue + } + s = strings.Replace(s, placeholder, replacement, -1) + } + + return s +} + +// Set sets key to value in the replacements map. +func (r replacer) Set(key, value string) { + r.replacements["{"+key+"}"] = value +} + +const ( + timeFormat = "02/Jan/2006:15:04:05 -0700" + headerReplacer = "{>" +) diff --git a/middleware/replacer_test.go b/middleware/replacer_test.go new file mode 100644 index 000000000..d98bd2de1 --- /dev/null +++ b/middleware/replacer_test.go @@ -0,0 +1,124 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNewReplacer(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + reader := strings.NewReader(`{"username": "dennis"}`) + + request, err := http.NewRequest("POST", "http://localhost", reader) + if err != nil { + t.Fatal("Request Formation Failed\n") + } + replaceValues := NewReplacer(request, recordRequest, "") + + switch v := replaceValues.(type) { + case replacer: + + if v.replacements["{host}"] != "localhost" { + t.Error("Expected host to be localhost") + } + if v.replacements["{method}"] != "POST" { + t.Error("Expected request method to be POST") + } + if v.replacements["{status}"] != "200" { + t.Error("Expected status to be 200") + } + + default: + t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n") + } +} + +func TestReplace(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + reader := strings.NewReader(`{"username": "dennis"}`) + + request, err := http.NewRequest("POST", "http://localhost", reader) + if err != nil { + t.Fatal("Request Formation Failed\n") + } + request.Header.Set("Custom", "foobarbaz") + request.Header.Set("ShorterVal", "1") + repl := NewReplacer(request, recordRequest, "-") + + if expected, actual := "This host is localhost.", repl.Replace("This host is {host}."); expected != actual { + t.Errorf("{host} replacement: expected '%s', got '%s'", expected, actual) + } + if expected, actual := "This request method is POST.", repl.Replace("This request method is {method}."); expected != actual { + t.Errorf("{method} replacement: expected '%s', got '%s'", expected, actual) + } + if expected, actual := "The response status is 200.", repl.Replace("The response status is {status}."); expected != actual { + t.Errorf("{status} replacement: expected '%s', got '%s'", expected, actual) + } + if expected, actual := "The Custom header is foobarbaz.", repl.Replace("The Custom header is {>Custom}."); expected != actual { + t.Errorf("{>Custom} replacement: expected '%s', got '%s'", expected, actual) + } + + // Test header case-insensitivity + if expected, actual := "The cUsToM header is foobarbaz...", repl.Replace("The cUsToM header is {>cUsToM}..."); expected != actual { + t.Errorf("{>cUsToM} replacement: expected '%s', got '%s'", expected, actual) + } + + // Test non-existent header/value + if expected, actual := "The Non-Existent header is -.", repl.Replace("The Non-Existent header is {>Non-Existent}."); expected != actual { + t.Errorf("{>Non-Existent} replacement: expected '%s', got '%s'", expected, actual) + } + + // Test bad placeholder + if expected, actual := "Bad {host placeholder...", repl.Replace("Bad {host placeholder..."); expected != actual { + t.Errorf("bad placeholder: expected '%s', got '%s'", expected, actual) + } + + // Test bad header placeholder + if expected, actual := "Bad {>Custom placeholder", repl.Replace("Bad {>Custom placeholder"); expected != actual { + t.Errorf("bad header placeholder: expected '%s', got '%s'", expected, actual) + } + + // Test bad header placeholder with valid one later + if expected, actual := "Bad -", repl.Replace("Bad {>Custom placeholder {>ShorterVal}"); expected != actual { + t.Errorf("bad header placeholders: expected '%s', got '%s'", expected, actual) + } + + // Test shorter header value with multiple placeholders + if expected, actual := "Short value 1 then foobarbaz.", repl.Replace("Short value {>ShorterVal} then {>Custom}."); expected != actual { + t.Errorf("short value: expected '%s', got '%s'", expected, actual) + } +} + +func TestSet(t *testing.T) { + w := httptest.NewRecorder() + recordRequest := NewResponseRecorder(w) + reader := strings.NewReader(`{"username": "dennis"}`) + + request, err := http.NewRequest("POST", "http://localhost", reader) + if err != nil { + t.Fatalf("Request Formation Failed \n") + } + repl := NewReplacer(request, recordRequest, "") + + repl.Set("host", "getcaddy.com") + repl.Set("method", "GET") + repl.Set("status", "201") + repl.Set("variable", "value") + + if repl.Replace("This host is {host}") != "This host is getcaddy.com" { + t.Error("Expected host replacement failed") + } + if repl.Replace("This request method is {method}") != "This request method is GET" { + t.Error("Expected method replacement failed") + } + if repl.Replace("The response status is {status}") != "The response status is 201" { + t.Error("Expected status replacement failed") + } + if repl.Replace("The value of variable is {variable}") != "The value of variable is value" { + t.Error("Expected variable replacement failed") + } +} diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go new file mode 100644 index 000000000..ddd4c38b1 --- /dev/null +++ b/middleware/rewrite/condition.go @@ -0,0 +1,130 @@ +package rewrite + +import ( + "fmt" + "regexp" + "strings" + + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// Operators +const ( + Is = "is" + Not = "not" + Has = "has" + NotHas = "not_has" + StartsWith = "starts_with" + EndsWith = "ends_with" + Match = "match" + NotMatch = "not_match" +) + +func operatorError(operator string) error { + return fmt.Errorf("Invalid operator %v", operator) +} + +func newReplacer(r *dns.Msg) middleware.Replacer { + return middleware.NewReplacer(r, nil, "") +} + +// condition is a rewrite condition. +type condition func(string, string) bool + +var conditions = map[string]condition{ + Is: isFunc, + Not: notFunc, + Has: hasFunc, + NotHas: notHasFunc, + StartsWith: startsWithFunc, + EndsWith: endsWithFunc, + Match: matchFunc, + NotMatch: notMatchFunc, +} + +// isFunc is condition for Is operator. +// It checks for equality. +func isFunc(a, b string) bool { + return a == b +} + +// notFunc is condition for Not operator. +// It checks for inequality. +func notFunc(a, b string) bool { + return a != b +} + +// hasFunc is condition for Has operator. +// It checks if b is a substring of a. +func hasFunc(a, b string) bool { + return strings.Contains(a, b) +} + +// notHasFunc is condition for NotHas operator. +// It checks if b is not a substring of a. +func notHasFunc(a, b string) bool { + return !strings.Contains(a, b) +} + +// startsWithFunc is condition for StartsWith operator. +// It checks if b is a prefix of a. +func startsWithFunc(a, b string) bool { + return strings.HasPrefix(a, b) +} + +// endsWithFunc is condition for EndsWith operator. +// It checks if b is a suffix of a. +func endsWithFunc(a, b string) bool { + return strings.HasSuffix(a, b) +} + +// matchFunc is condition for Match operator. +// It does regexp matching of a against pattern in b +// and returns if they match. +func matchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return matched +} + +// notMatchFunc is condition for NotMatch operator. +// It does regexp matching of a against pattern in b +// and returns if they do not match. +func notMatchFunc(a, b string) bool { + matched, _ := regexp.MatchString(b, a) + return !matched +} + +// If is statement for a rewrite condition. +type If struct { + A string + Operator string + B string +} + +// True returns true if the condition is true and false otherwise. +// If r is not nil, it replaces placeholders before comparison. +func (i If) True(r *dns.Msg) bool { + if c, ok := conditions[i.Operator]; ok { + a, b := i.A, i.B + if r != nil { + replacer := newReplacer(r) + a = replacer.Replace(i.A) + b = replacer.Replace(i.B) + } + return c(a, b) + } + return false +} + +// NewIf creates a new If condition. +func NewIf(a, operator, b string) (If, error) { + if _, ok := conditions[operator]; !ok { + return If{}, operatorError(operator) + } + return If{ + A: a, + Operator: operator, + B: b, + }, nil +} diff --git a/middleware/rewrite/condition_test.go b/middleware/rewrite/condition_test.go new file mode 100644 index 000000000..3c3b6053a --- /dev/null +++ b/middleware/rewrite/condition_test.go @@ -0,0 +1,106 @@ +package rewrite + +import ( + "net/http" + "strings" + "testing" +) + +func TestConditions(t *testing.T) { + tests := []struct { + condition string + isTrue bool + }{ + {"a is b", false}, + {"a is a", true}, + {"a not b", true}, + {"a not a", false}, + {"a has a", true}, + {"a has b", false}, + {"ba has b", true}, + {"bab has b", true}, + {"bab has bb", false}, + {"a not_has a", false}, + {"a not_has b", true}, + {"ba not_has b", false}, + {"bab not_has b", false}, + {"bab not_has bb", true}, + {"bab starts_with bb", false}, + {"bab starts_with ba", true}, + {"bab starts_with bab", true}, + {"bab ends_with bb", false}, + {"bab ends_with bab", true}, + {"bab ends_with ab", true}, + {"a match *", false}, + {"a match a", true}, + {"a match .*", true}, + {"a match a.*", true}, + {"a match b.*", false}, + {"ba match b.*", true}, + {"ba match b[a-z]", true}, + {"b0 match b[a-z]", false}, + {"b0a match b[a-z]", false}, + {"b0a match b[a-z]+", false}, + {"b0a match b[a-z0-9]+", true}, + {"a not_match *", true}, + {"a not_match a", false}, + {"a not_match .*", false}, + {"a not_match a.*", false}, + {"a not_match b.*", true}, + {"ba not_match b.*", false}, + {"ba not_match b[a-z]", false}, + {"b0 not_match b[a-z]", true}, + {"b0a not_match b[a-z]", true}, + {"b0a not_match b[a-z]+", true}, + {"b0a not_match b[a-z0-9]+", false}, + } + + for i, test := range tests { + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(nil) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } + + invalidOperators := []string{"ss", "and", "if"} + for _, op := range invalidOperators { + _, err := NewIf("a", op, "b") + if err == nil { + t.Errorf("Invalid operator %v used, expected error.", op) + } + } + + replaceTests := []struct { + url string + condition string + isTrue bool + }{ + {"/home", "{uri} match /home", true}, + {"/hom", "{uri} match /home", false}, + {"/hom", "{uri} starts_with /home", false}, + {"/hom", "{uri} starts_with /h", true}, + {"/home/.hiddenfile", `{uri} match \/\.(.*)`, true}, + {"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true}, + } + + for i, test := range replaceTests { + r, err := http.NewRequest("GET", test.url, nil) + if err != nil { + t.Error(err) + } + str := strings.Fields(test.condition) + ifCond, err := NewIf(str[0], str[1], str[2]) + if err != nil { + t.Error(err) + } + isTrue := ifCond.True(r) + if isTrue != test.isTrue { + t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue) + } + } +} diff --git a/middleware/rewrite/reverter.go b/middleware/rewrite/reverter.go new file mode 100644 index 000000000..c3425866e --- /dev/null +++ b/middleware/rewrite/reverter.go @@ -0,0 +1,38 @@ +package rewrite + +import "github.com/miekg/dns" + +// ResponseRevert reverses the operations done on the question section of a packet. +// This is need because the client will otherwise disregards the response, i.e. +// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN' +type ResponseReverter struct { + dns.ResponseWriter + original dns.Question +} + +func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter { + return &ResponseReverter{ + ResponseWriter: w, + original: r.Question[0], + } +} + +// WriteMsg records the status code and calls the +// underlying ResponseWriter's WriteMsg method. +func (r *ResponseReverter) WriteMsg(res *dns.Msg) error { + res.Question[0] = r.original + return r.ResponseWriter.WriteMsg(res) +} + +// Write is a wrapper that records the size of the message that gets written. +func (r *ResponseReverter) Write(buf []byte) (int, error) { + n, err := r.ResponseWriter.Write(buf) + return n, err +} + +// Hijack implements dns.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (r *ResponseReverter) Hijack() { + r.ResponseWriter.Hijack() + return +} diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go new file mode 100644 index 000000000..b3039615b --- /dev/null +++ b/middleware/rewrite/rewrite.go @@ -0,0 +1,223 @@ +// Package rewrite is middleware for rewriting requests internally to +// something different. +package rewrite + +import ( + "github.com/miekg/coredns/middleware" + "github.com/miekg/dns" +) + +// Result is the result of a rewrite +type Result int + +const ( + // RewriteIgnored is returned when rewrite is not done on request. + RewriteIgnored Result = iota + // RewriteDone is returned when rewrite is done on request. + RewriteDone + // RewriteStatus is returned when rewrite is not needed and status code should be set + // for the request. + RewriteStatus +) + +// Rewrite is middleware to rewrite requests internally before being handled. +type Rewrite struct { + Next middleware.Handler + Rules []Rule +} + +// ServeHTTP implements the middleware.Handler interface. +func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) { + wr := NewResponseReverter(w, r) + for _, rule := range rw.Rules { + switch result := rule.Rewrite(r); result { + case RewriteDone: + return rw.Next.ServeDNS(wr, r) + case RewriteIgnored: + break + case RewriteStatus: + // only valid for complex rules. + // if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 { + // return cRule.Status, nil + // } + } + } + return rw.Next.ServeDNS(w, r) +} + +// Rule describes an internal location rewrite rule. +type Rule interface { + // Rewrite rewrites the internal location of the current request. + Rewrite(*dns.Msg) Result +} + +// SimpleRule is a simple rewrite rule. If the From and To look like a type +// the type of the request is rewritten, otherwise the name is. +// Note: TSIG signed requests will be invalid. +type SimpleRule struct { + From, To string + fromType, toType uint16 +} + +// NewSimpleRule creates a new Simple Rule +func NewSimpleRule(from, to string) SimpleRule { + tpf := dns.StringToType[from] + tpt := dns.StringToType[to] + + return SimpleRule{From: from, To: to, fromType: tpf, toType: tpt} +} + +// Rewrite rewrites the the current request. +func (s SimpleRule) Rewrite(r *dns.Msg) Result { + if s.fromType > 0 && s.toType > 0 { + if r.Question[0].Qtype == s.fromType { + r.Question[0].Qtype = s.toType + return RewriteDone + } + + } + + // if the question name matches the full name, or subset rewrite that + // s.Question[0].Name + return RewriteIgnored +} + +/* +// ComplexRule is a rewrite rule based on a regular expression +type ComplexRule struct { + // Path base. Request to this path and subpaths will be rewritten + Base string + + // Path to rewrite to + To string + + // If set, neither performs rewrite nor proceeds + // with request. Only returns code. + Status int + + // Extensions to filter by + Exts []string + + // Rewrite conditions + Ifs []If + + *regexp.Regexp +} + +// NewComplexRule creates a new RegexpRule. It returns an error if regexp +// pattern (pattern) or extensions (ext) are invalid. +func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) { + // validate regexp if present + var r *regexp.Regexp + if pattern != "" { + var err error + r, err = regexp.Compile(pattern) + if err != nil { + return nil, err + } + } + + // validate extensions if present + for _, v := range ext { + if len(v) < 2 || (len(v) < 3 && v[0] == '!') { + // check if no extension is specified + if v != "/" && v != "!/" { + return nil, fmt.Errorf("invalid extension %v", v) + } + } + } + + return &ComplexRule{ + Base: base, + To: to, + Status: status, + Exts: ext, + Ifs: ifs, + Regexp: r, + }, nil +} + +// Rewrite rewrites the internal location of the current request. +func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) { + rPath := req.URL.Path + replacer := newReplacer(req) + + // validate base + if !middleware.Path(rPath).Matches(r.Base) { + return + } + + // validate extensions + if !r.matchExt(rPath) { + return + } + + // validate regexp if present + if r.Regexp != nil { + // include trailing slash in regexp if present + start := len(r.Base) + if strings.HasSuffix(r.Base, "/") { + start-- + } + + matches := r.FindStringSubmatch(rPath[start:]) + switch len(matches) { + case 0: + // no match + return + default: + // set regexp match variables {1}, {2} ... + for i := 1; i < len(matches); i++ { + replacer.Set(fmt.Sprint(i), matches[i]) + } + } + } + + // validate rewrite conditions + for _, i := range r.Ifs { + if !i.True(req) { + return + } + } + + // if status is present, stop rewrite and return it. + if r.Status != 0 { + return RewriteStatus + } + + // attempt rewrite + return To(fs, req, r.To, replacer) +} + +// matchExt matches rPath against registered file extensions. +// Returns true if a match is found and false otherwise. +func (r *ComplexRule) matchExt(rPath string) bool { + f := filepath.Base(rPath) + ext := path.Ext(f) + if ext == "" { + ext = "/" + } + + mustUse := false + for _, v := range r.Exts { + use := true + if v[0] == '!' { + use = false + v = v[1:] + } + + if use { + mustUse = true + } + + if ext == v { + return use + } + } + + if mustUse { + return false + } + return true +} +*/ diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go new file mode 100644 index 000000000..f57dfd602 --- /dev/null +++ b/middleware/rewrite/rewrite_test.go @@ -0,0 +1,159 @@ +package rewrite + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/miekg/coredns/middleware" +) + +func TestRewrite(t *testing.T) { + rw := Rewrite{ + Next: middleware.HandlerFunc(urlPrinter), + Rules: []Rule{ + NewSimpleRule("/from", "/to"), + NewSimpleRule("/a", "/b"), + NewSimpleRule("/b", "/b{uri}"), + }, + FileSys: http.Dir("."), + } + + regexps := [][]string{ + {"/reg/", ".*", "/to", ""}, + {"/r/", "[a-z]+", "/toaz", "!.html|"}, + {"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""}, + {"/ab/", "ab", "/ab?{query}", ".txt|"}, + {"/ab/", "ab", "/ab?type=html&{query}", ".html|"}, + {"/abc/", "ab", "/abc/{file}", ".html|"}, + {"/abcd/", "ab", "/a/{dir}/{file}", ".html|"}, + {"/abcde/", "ab", "/a#{fragment}", ".html|"}, + {"/ab/", `.*\.jpg`, "/ajpg", ""}, + {"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""}, + {"/reg2grp", `(.*)`, "/{1}", ""}, + {"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""}, + } + + for _, regexpRule := range regexps { + var ext []string + if s := strings.Split(regexpRule[3], "|"); len(s) > 1 { + ext = s[:len(s)-1] + } + rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil) + if err != nil { + t.Fatal(err) + } + rw.Rules = append(rw.Rules, rule) + } + + tests := []struct { + from string + expectedTo string + }{ + {"/from", "/to"}, + {"/a", "/b"}, + {"/b", "/b/b"}, + {"/aa", "/aa"}, + {"/", "/"}, + {"/a?foo=bar", "/b?foo=bar"}, + {"/asdf?foo=bar", "/asdf?foo=bar"}, + {"/foo#bar", "/foo#bar"}, + {"/a#foo", "/b#foo"}, + {"/reg/foo", "/to"}, + {"/re", "/re"}, + {"/r/", "/r/"}, + {"/r/123", "/r/123"}, + {"/r/a123", "/toaz"}, + {"/r/abcz", "/toaz"}, + {"/r/z", "/toaz"}, + {"/r/z.html", "/r/z.html"}, + {"/r/z.js", "/toaz"}, + {"/url/asAB", "/to/url/asAB"}, + {"/url/aBsAB", "/url/aBsAB"}, + {"/url/a00sAB", "/to/url/a00sAB"}, + {"/url/a0z0sAB", "/to/url/a0z0sAB"}, + {"/ab/aa", "/ab/aa"}, + {"/ab/ab", "/ab/ab"}, + {"/ab/ab.txt", "/ab"}, + {"/ab/ab.txt?name=name", "/ab?name=name"}, + {"/ab/ab.html?name=name", "/ab?type=html&name=name"}, + {"/abc/ab.html", "/abc/ab.html"}, + {"/abcd/abcd.html", "/a/abcd/abcd.html"}, + {"/abcde/abcde.html", "/a"}, + {"/abcde/abcde.html#1234", "/a#1234"}, + {"/ab/ab.jpg", "/ajpg"}, + {"/reggrp/ad/12", "/a12"}, + {"/reggrp/ad/124a", "/a124/a"}, + {"/reggrp/ad/124abc", "/a124/abc"}, + {"/reg2grp/ad/124abc", "/ad/124abc"}, + {"/reg3grp/ad/aa/66", "/adaa66"}, + {"/reg3grp/ad612/n1n/ab", "/ad612n1nab"}, + } + + for i, test := range tests { + req, err := http.NewRequest("GET", test.from, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + rw.ServeHTTP(rec, req) + + if rec.Body.String() != test.expectedTo { + t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", + i, test.expectedTo, rec.Body.String()) + } + } + + statusTests := []struct { + status int + base string + to string + regexp string + statusExpected bool + }{ + {400, "/status", "", "", true}, + {400, "/ignore", "", "", false}, + {400, "/", "", "^/ignore", false}, + {400, "/", "", "(.*)", true}, + {400, "/status", "", "", true}, + } + + for i, s := range statusTests { + urlPath := fmt.Sprintf("/status%d", i) + rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil) + if err != nil { + t.Fatalf("Test %d: No error expected for rule but found %v", i, err) + } + rw.Rules = []Rule{rule} + req, err := http.NewRequest("GET", urlPath, nil) + if err != nil { + t.Fatalf("Test %d: Could not create HTTP request: %v", i, err) + } + + rec := httptest.NewRecorder() + code, err := rw.ServeHTTP(rec, req) + if err != nil { + t.Fatalf("Test %d: No error expected for handler but found %v", i, err) + } + if s.statusExpected { + if rec.Body.String() != "" { + t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String()) + } + if code != s.status { + t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code) + } + } else { + if code != 0 { + t.Errorf("Test %d: Expected no status code found %d", i, code) + } + } + } +} + +func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) { + fmt.Fprintf(w, r.URL.String()) + return 0, nil +} diff --git a/middleware/rewrite/testdata/testdir/empty b/middleware/rewrite/testdata/testdir/empty new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/middleware/rewrite/testdata/testdir/empty diff --git a/middleware/rewrite/testdata/testfile b/middleware/rewrite/testdata/testfile new file mode 100644 index 000000000..7b4d68d70 --- /dev/null +++ b/middleware/rewrite/testdata/testfile @@ -0,0 +1 @@ +empty
\ No newline at end of file diff --git a/middleware/roller.go b/middleware/roller.go new file mode 100644 index 000000000..995cabf91 --- /dev/null +++ b/middleware/roller.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "io" + + "gopkg.in/natefinch/lumberjack.v2" +) + +// LogRoller implements a middleware that provides a rolling logger. +type LogRoller struct { + Filename string + MaxSize int + MaxAge int + MaxBackups int + LocalTime bool +} + +// GetLogWriter returns an io.Writer that writes to a rolling logger. +func (l LogRoller) GetLogWriter() io.Writer { + return &lumberjack.Logger{ + Filename: l.Filename, + MaxSize: l.MaxSize, + MaxAge: l.MaxAge, + MaxBackups: l.MaxBackups, + LocalTime: l.LocalTime, + } +} diff --git a/middleware/zone.go b/middleware/zone.go new file mode 100644 index 000000000..6798bca8e --- /dev/null +++ b/middleware/zone.go @@ -0,0 +1,21 @@ +package middleware + +import "strings" + +type Zones []string + +// Matches checks to see if other matches p. +// The match will return the most specific zones +// that matches other. The empty string signals a not found +// condition. +func (z Zones) Matches(qname string) string { + zone := "" + for _, zname := range z { + if strings.HasSuffix(qname, zname) { + if len(zname) > len(zone) { + zone = zname + } + } + } + return zone +} diff --git a/server/config.go b/server/config.go new file mode 100644 index 000000000..79a9ba84d --- /dev/null +++ b/server/config.go @@ -0,0 +1,75 @@ +package server + +import ( + "net" + + "github.com/miekg/coredns/middleware" +) + +// Config configuration for a single server. +type Config struct { + // The hostname or IP on which to serve + Host string + + // The host address to bind on - defaults to (virtual) Host if empty + BindHost string + + // The port to listen on + Port string + + // The directory from which to parse db files + Root string + + // HTTPS configuration + TLS TLSConfig + + // Middleware stack + Middleware []middleware.Middleware + + // Startup is a list of functions (or methods) to execute at + // server startup and restart; these are executed before any + // parts of the server are configured, and the functions are + // blocking. These are good for setting up middlewares and + // starting goroutines. + Startup []func() error + + // FirstStartup is like Startup but these functions only execute + // during the initial startup, not on subsequent restarts. + // + // (Note: The server does not ever run these on its own; it is up + // to the calling application to do so, and do so only once, as the + // server itself has no notion whether it's a restart or not.) + FirstStartup []func() error + + // Functions (or methods) to execute when the server quits; + // these are executed in response to SIGINT and are blocking + Shutdown []func() error + + // The path to the configuration file from which this was loaded + ConfigFile string + + // The name of the application + AppName string + + // The application's version + AppVersion string +} + +// Address returns the host:port of c as a string. +func (c Config) Address() string { + return net.JoinHostPort(c.Host, c.Port) +} + +// TLSConfig describes how TLS should be configured and used. +type TLSConfig struct { + Enabled bool // will be set to true if TLS is enabled + LetsEncryptEmail string + Manual bool // will be set to true if user provides own certs and keys + Managed bool // will be set to true if config qualifies for implicit automatic/managed HTTPS + OnDemand bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes) + Ciphers []uint16 + ProtocolMinVersion uint16 + ProtocolMaxVersion uint16 + PreferServerCipherSuites bool + ClientCerts []string +} diff --git a/server/config_test.go b/server/config_test.go new file mode 100644 index 000000000..8787e467b --- /dev/null +++ b/server/config_test.go @@ -0,0 +1,25 @@ +package server + +import "testing" + +func TestConfigAddress(t *testing.T) { + cfg := Config{Host: "foobar", Port: "1234"} + if actual, expected := cfg.Address(), "foobar:1234"; expected != actual { + t.Errorf("Expected '%s' but got '%s'", expected, actual) + } + + cfg = Config{Host: "", Port: "1234"} + if actual, expected := cfg.Address(), ":1234"; expected != actual { + t.Errorf("Expected '%s' but got '%s'", expected, actual) + } + + cfg = Config{Host: "foobar", Port: ""} + if actual, expected := cfg.Address(), "foobar:"; expected != actual { + t.Errorf("Expected '%s' but got '%s'", expected, actual) + } + + cfg = Config{Host: "::1", Port: "443"} + if actual, expected := cfg.Address(), "[::1]:443"; expected != actual { + t.Errorf("Expected '%s' but got '%s'", expected, actual) + } +} diff --git a/server/graceful.go b/server/graceful.go new file mode 100644 index 000000000..5057d039b --- /dev/null +++ b/server/graceful.go @@ -0,0 +1,76 @@ +package server + +import ( + "net" + "sync" + "syscall" +) + +// newGracefulListener returns a gracefulListener that wraps l and +// uses wg (stored in the host server) to count connections. +func newGracefulListener(l ListenerFile, wg *sync.WaitGroup) *gracefulListener { + gl := &gracefulListener{ListenerFile: l, stop: make(chan error), httpWg: wg} + go func() { + <-gl.stop + gl.Lock() + gl.stopped = true + gl.Unlock() + gl.stop <- gl.ListenerFile.Close() + }() + return gl +} + +// gracefuListener is a net.Listener which can +// count the number of connections on it. Its +// methods mainly wrap net.Listener to be graceful. +type gracefulListener struct { + ListenerFile + stop chan error + stopped bool + sync.Mutex // protects the stopped flag + httpWg *sync.WaitGroup // pointer to the host's wg used for counting connections +} + +// Accept accepts a connection. +func (gl *gracefulListener) Accept() (c net.Conn, err error) { + c, err = gl.ListenerFile.Accept() + if err != nil { + return + } + c = gracefulConn{Conn: c, httpWg: gl.httpWg} + gl.httpWg.Add(1) + return +} + +// Close immediately closes the listener. +func (gl *gracefulListener) Close() error { + gl.Lock() + if gl.stopped { + gl.Unlock() + return syscall.EINVAL + } + gl.Unlock() + gl.stop <- nil + return <-gl.stop +} + +// gracefulConn represents a connection on a +// gracefulListener so that we can keep track +// of the number of connections, thus facilitating +// a graceful shutdown. +type gracefulConn struct { + net.Conn + httpWg *sync.WaitGroup // pointer to the host server's connection waitgroup +} + +// Close closes c's underlying connection while updating the wg count. +func (c gracefulConn) Close() error { + err := c.Conn.Close() + if err != nil { + return err + } + // close can fail on http2 connections (as of Oct. 2015, before http2 in std lib) + // so don't decrement count unless close succeeds + c.httpWg.Done() + return nil +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 000000000..7baa74686 --- /dev/null +++ b/server/server.go @@ -0,0 +1,431 @@ +// Package server implements a configurable, general-purpose web server. +// It relies on configurations obtained from the adjacent config package +// and can execute middleware as defined by the adjacent middleware package. +package server + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "runtime" + "sync" + "time" + + "github.com/miekg/dns" +) + +// Server represents an instance of a server, which serves +// DNS requests at a particular address (host and port). A +// server is capable of serving numerous zones on +// the same address and the listener may be stopped for +// graceful termination (POSIX only). +type Server struct { + Addr string // Address we listen on + mux *dns.ServeMux + tls bool // whether this server is serving all HTTPS hosts or not + TLSConfig *tls.Config + OnDemandTLS bool // whether this server supports on-demand TLS (load certs at handshake-time) + zones map[string]zone // zones keyed by their address + listener ListenerFile // the listener which is bound to the socket + listenerMu sync.Mutex // protects listener + dnsWg sync.WaitGroup // used to wait on outstanding connections + startChan chan struct{} // used to block until server is finished starting + connTimeout time.Duration // the maximum duration of a graceful shutdown + ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request + SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) +} + +// ListenerFile represents a listener. +type ListenerFile interface { + net.Listener + File() (*os.File, error) +} + +// OptionalCallback is a function that may or may not handle a request. +// It returns whether or not it handled the request. If it handled the +// request, it is presumed that no further request handling should occur. +type OptionalCallback func(dns.ResponseWriter, *dns.Msg) bool + +// New creates a new Server which will bind to addr and serve +// the sites/hosts configured in configs. Its listener will +// gracefully close when the server is stopped which will take +// no longer than gracefulTimeout. +// +// This function does not start serving. +// +// Do not re-use a server (start, stop, then start again). We +// could probably add more locking to make this possible, but +// as it stands, you should dispose of a server after stopping it. +// The behavior of serving with a spent server is undefined. +func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) { + var useTLS, useOnDemandTLS bool + if len(configs) > 0 { + useTLS = configs[0].TLS.Enabled + useOnDemandTLS = configs[0].TLS.OnDemand + } + + s := &Server{ + Addr: addr, + TLSConfig: new(tls.Config), + // TODO: Make these values configurable? + // ReadTimeout: 2 * time.Minute, + // WriteTimeout: 2 * time.Minute, + // MaxHeaderBytes: 1 << 16, + tls: useTLS, + OnDemandTLS: useOnDemandTLS, + zones: make(map[string]zone), + startChan: make(chan struct{}), + connTimeout: gracefulTimeout, + } + mux := dns.NewServeMux() + mux.Handle(".", s) // wildcard handler, everything will go through here + s.mux = mux + + // We have to bound our wg with one increment + // to prevent a "race condition" that is hard-coded + // into sync.WaitGroup.Wait() - basically, an add + // with a positive delta must be guaranteed to + // occur before Wait() is called on the wg. + // In a way, this kind of acts as a safety barrier. + s.dnsWg.Add(1) + + // Set up each zone + for _, conf := range configs { + // TODO(miek): something better here? + if _, exists := s.zones[conf.Host]; exists { + return nil, fmt.Errorf("cannot serve %s - host already defined for address %s", conf.Address(), s.Addr) + } + + z := zone{config: conf} + + // Build middleware stack + err := z.buildStack() + if err != nil { + return nil, err + } + + s.zones[conf.Host] = z + } + + return s, nil +} + +// Serve starts the server with an existing listener. It blocks until the +// server stops. +/* +func (s *Server) Serve(ln ListenerFile) error { + // TODO(miek): Go DNS has no server stuff that allows you to give it a listener + // and use that. + err := s.setup() + if err != nil { + defer close(s.startChan) // MUST defer so error is properly reported, same with all cases in this file + return err + } + return s.serve(ln) +} +*/ + +// ListenAndServe starts the server with a new listener. It blocks until the server stops. +func (s *Server) ListenAndServe() error { + err := s.setup() + once := sync.Once{} + + if err != nil { + close(s.startChan) + return err + } + + // TODO(miek): redo to make it more like caddy + // - error handling, re-introduce what Caddy did. + go func() { + if err := dns.ListenAndServe(s.Addr, "tcp", s.mux); err != nil { + log.Printf("[ERROR] %v\n", err) + defer once.Do(func() { close(s.startChan) }) + return + } + }() + + go func() { + if err := dns.ListenAndServe(s.Addr, "udp", s.mux); err != nil { + log.Printf("[ERROR] %v\n", err) + defer once.Do(func() { close(s.startChan) }) + return + } + }() + once.Do(func() { close(s.startChan) }) // unblock anyone waiting for this to start listening + // but block here, as this is what caddy expects + for { + select {} + } + return nil +} + +// setup prepares the server s to begin listening; it should be +// called just before the listener announces itself on the network +// and should only be called when the server is just starting up. +func (s *Server) setup() error { + // Execute startup functions now + for _, z := range s.zones { + for _, startupFunc := range z.config.Startup { + err := startupFunc() + if err != nil { + return err + } + } + } + + return nil +} + +/* +TODO(miek): no such thing in the glorious Go DNS. +// serveTLS serves TLS with SNI and client auth support if s has them enabled. It +// blocks until s quits. +func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { + // Customize our TLS configuration + s.TLSConfig.MinVersion = tlsConfigs[0].ProtocolMinVersion + s.TLSConfig.MaxVersion = tlsConfigs[0].ProtocolMaxVersion + s.TLSConfig.CipherSuites = tlsConfigs[0].Ciphers + s.TLSConfig.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites + + // TLS client authentication, if user enabled it + err := setupClientAuth(tlsConfigs, s.TLSConfig) + if err != nil { + defer close(s.startChan) + return err + } + + // Create TLS listener - note that we do not replace s.listener + // with this TLS listener; tls.listener is unexported and does + // not implement the File() method we need for graceful restarts + // on POSIX systems. + ln = tls.NewListener(ln, s.TLSConfig) + + close(s.startChan) // unblock anyone waiting for this to start listening + return s.Serve(ln) +} +*/ + +// Stop stops the server. It blocks until the server is +// totally stopped. On POSIX systems, it will wait for +// connections to close (up to a max timeout of a few +// seconds); on Windows it will close the listener +// immediately. +func (s *Server) Stop() (err error) { + + if runtime.GOOS != "windows" { + // force connections to close after timeout + done := make(chan struct{}) + go func() { + s.dnsWg.Done() // decrement our initial increment used as a barrier + s.dnsWg.Wait() + close(done) + }() + + // Wait for remaining connections to finish or + // force them all to close after timeout + select { + case <-time.After(s.connTimeout): + case <-done: + } + } + + // Close the listener now; this stops the server without delay + s.listenerMu.Lock() + if s.listener != nil { + err = s.listener.Close() + } + s.listenerMu.Unlock() + + return +} + +// WaitUntilStarted blocks until the server s is started, meaning +// that practically the next instruction is to start the server loop. +// It also unblocks if the server encounters an error during startup. +func (s *Server) WaitUntilStarted() { + <-s.startChan +} + +// ListenerFd gets a dup'ed file of the listener. If there +// is no underlying file, the return value will be nil. It +// is the caller's responsibility to close the file. +func (s *Server) ListenerFd() *os.File { + s.listenerMu.Lock() + defer s.listenerMu.Unlock() + if s.listener != nil { + file, _ := s.listener.File() + return file + } + return nil +} + +// ServeDNS is the entry point for every request to the address that s +// is bound to. It acts as a multiplexer for the requests zonename as +// defined in the request so that the correct zone +// (configuration and middleware stack) will handle the request. +func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + defer func() { + // In case the user doesn't enable error middleware, we still + // need to make sure that we stay alive up here + if rec := recover(); rec != nil { + // TODO(miek): serverfailure return? + } + }() + + // Execute the optional request callback if it exists + if s.ReqCallback != nil && s.ReqCallback(w, r) { + return + } + + q := r.Question[0].Name + b := make([]byte, len(q)) + off, end := 0, false + for { + l := len(q[off:]) + for i := 0; i < l; i++ { + b[i] = q[off+i] + // normalize the name for the lookup + if b[i] >= 'A' && b[i] <= 'Z' { + b[i] |= ('a' - 'A') + } + } + + if h, ok := s.zones[string(b[:l])]; ok { + if r.Question[0].Qtype != dns.TypeDS { + rcode, _ := h.stack.ServeDNS(w, r) + if rcode > 0 { + DefaultErrorFunc(w, r, rcode) + } + return + } + } + off, end = dns.NextLabel(q, off) + if end { + break + } + } + // Wildcard match, if we have found nothing try the root zone as a last resort. + if h, ok := s.zones["."]; ok { + rcode, _ := h.stack.ServeDNS(w, r) + if rcode > 0 { + DefaultErrorFunc(w, r, rcode) + } + return + } + + // Still here? Error out with SERVFAIL and some logging + remoteHost := w.RemoteAddr().String() + DefaultErrorFunc(w, r, dns.RcodeServerFailure) + + fmt.Fprintf(w, "No such zone at %s", s.Addr) + log.Printf("[INFO] %s - No such zone at %s (Remote: %s)", q, s.Addr, remoteHost) +} + +// DefaultErrorFunc responds to an HTTP request with a simple description +// of the specified HTTP status code. +func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { + answer := new(dns.Msg) + answer.SetRcode(r, rcode) + w.WriteMsg(answer) +} + +// setupClientAuth sets up TLS client authentication only if +// any of the TLS configs specified at least one cert file. +func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { + var clientAuth bool + for _, cfg := range tlsConfigs { + if len(cfg.ClientCerts) > 0 { + clientAuth = true + break + } + } + + if clientAuth { + pool := x509.NewCertPool() + for _, cfg := range tlsConfigs { + for _, caFile := range cfg.ClientCerts { + caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect + if err != nil { + return err + } + if !pool.AppendCertsFromPEM(caCrt) { + return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile) + } + } + } + config.ClientCAs = pool + config.ClientAuth = tls.RequireAndVerifyClientCert + } + + return nil +} + +// RunFirstStartupFuncs runs all of the server's FirstStartup +// callback functions unless one of them returns an error first. +// It is the caller's responsibility to call this only once and +// at the correct time. The functions here should not be executed +// at restarts or where the user does not explicitly start a new +// instance of the server. +func (s *Server) RunFirstStartupFuncs() error { + for _, z := range s.zones { + for _, f := range z.config.FirstStartup { + if err := f(); err != nil { + return err + } + } + } + return nil +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +// +// Borrowed from the Go standard library. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +// Accept accepts the connection with a keep-alive enabled. +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// File implements ListenerFile; returns the underlying file of the listener. +func (ln tcpKeepAliveListener) File() (*os.File, error) { + return ln.TCPListener.File() +} + +// ShutdownCallbacks executes all the shutdown callbacks +// for all the virtualhosts in servers, and returns all the +// errors generated during their execution. In other words, +// an error executing one shutdown callback does not stop +// execution of others. Only one shutdown callback is executed +// at a time. You must protect the servers that are passed in +// if they are shared across threads. +func ShutdownCallbacks(servers []*Server) []error { + var errs []error + for _, s := range servers { + for _, zone := range s.zones { + for _, shutdownFunc := range zone.config.Shutdown { + err := shutdownFunc() + if err != nil { + errs = append(errs, err) + } + } + } + } + return errs +} diff --git a/server/zones.go b/server/zones.go new file mode 100644 index 000000000..6a5a7a938 --- /dev/null +++ b/server/zones.go @@ -0,0 +1,28 @@ +package server + +import "github.com/miekg/coredns/middleware" + +// zone represents a DNS zone. While a Server +// is what actually binds to the address, a user may want to serve +// multiple zones on a single address, and this is what a +// zone allows us to do. +type zone struct { + config Config + stack middleware.Handler +} + +// buildStack builds the server's middleware stack based +// on its config. This method should be called last before +// ListenAndServe begins. +func (z *zone) buildStack() error { + z.compile(z.config.Middleware) + return nil +} + +// compile is an elegant alternative to nesting middleware function +// calls like handler1(handler2(handler3(finalHandler))). +func (z *zone) compile(layers []middleware.Middleware) { + for i := len(layers) - 1; i >= 0; i-- { + z.stack = layers[i](z.stack) + } +} |