aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Miek Gieben <miek@miek.nl> 2016-03-18 20:57:35 +0000
committerGravatar Miek Gieben <miek@miek.nl> 2016-03-18 20:57:35 +0000
commit3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d (patch)
treefae74c33cfed05de603785294593275f1901c861
downloadcoredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.gz
coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.tar.zst
coredns-3ec0d9fe6b133a64712ae69fd712c14ad1a71f4d.zip
First commit
-rw-r--r--.gitignore16
-rw-r--r--.travis.yml16
-rw-r--r--CONTRIBUTING.md77
-rw-r--r--Caddyfile-simple7
-rw-r--r--Corefile9
-rw-r--r--Corefile-name-alone1
-rw-r--r--ISSUE_TEMPLATE22
-rw-r--r--LICENSE.txt201
-rw-r--r--README.md174
-rw-r--r--TODO35
-rwxr-xr-xbuild.bash55
-rw-r--r--core/assets/path.go29
-rw-r--r--core/assets/path_test.go12
-rw-r--r--core/caddy.go388
-rw-r--r--core/caddy_test.go32
-rw-r--r--core/caddyfile/json.go173
-rw-r--r--core/caddyfile/json_test.go161
-rw-r--r--core/config.go346
-rw-r--r--core/config_test.go159
-rw-r--r--core/directives.go89
-rw-r--r--core/directives_test.go31
-rw-r--r--core/helpers.go102
-rw-r--r--core/https/certificates.go234
-rw-r--r--core/https/certificates_test.go59
-rw-r--r--core/https/client.go215
-rw-r--r--core/https/crypto.go57
-rw-r--r--core/https/crypto_test.go111
-rw-r--r--core/https/handler.go42
-rw-r--r--core/https/handler_test.go63
-rw-r--r--core/https/handshake.go320
-rw-r--r--core/https/handshake_test.go54
-rw-r--r--core/https/https.go358
-rw-r--r--core/https/https_test.go332
-rw-r--r--core/https/maintain.go211
-rw-r--r--core/https/setup.go321
-rw-r--r--core/https/setup_test.go232
-rw-r--r--core/https/storage.go94
-rw-r--r--core/https/storage_test.go88
-rw-r--r--core/https/user.go200
-rw-r--r--core/https/user_test.go196
-rw-r--r--core/parse/dispenser.go251
-rw-r--r--core/parse/dispenser_test.go292
-rw-r--r--core/parse/import_glob0.txt6
-rw-r--r--core/parse/import_glob1.txt4
-rw-r--r--core/parse/import_glob2.txt3
-rw-r--r--core/parse/import_test1.txt2
-rw-r--r--core/parse/import_test2.txt4
-rw-r--r--core/parse/lexer.go122
-rw-r--r--core/parse/lexer_test.go165
-rw-r--r--core/parse/parse.go32
-rw-r--r--core/parse/parse_test.go22
-rw-r--r--core/parse/parsing.go379
-rw-r--r--core/parse/parsing_test.go477
-rw-r--r--core/restart.go166
-rw-r--r--core/restart_windows.go31
-rw-r--r--core/setup/bindhost.go13
-rw-r--r--core/setup/controller.go83
-rw-r--r--core/setup/errors.go132
-rw-r--r--core/setup/errors_test.go158
-rw-r--r--core/setup/file.go73
-rw-r--r--core/setup/log.go130
-rw-r--r--core/setup/log_test.go175
-rw-r--r--core/setup/prometheus.go70
-rw-r--r--core/setup/proxy.go17
-rw-r--r--core/setup/reflect.go28
-rw-r--r--core/setup/rewrite.go109
-rw-r--r--core/setup/rewrite_test.go241
-rw-r--r--core/setup/roller.go40
-rw-r--r--core/setup/root.go32
-rw-r--r--core/setup/root_test.go108
-rw-r--r--core/setup/startupshutdown.go64
-rw-r--r--core/setup/startupshutdown_test.go59
-rw-r--r--core/setup/testdata/blog/first_post.md1
-rw-r--r--core/setup/testdata/header.html1
-rw-r--r--core/setup/testdata/tpl_with_include.html10
-rw-r--r--core/sigtrap.go71
-rw-r--r--core/sigtrap_posix.go79
-rw-r--r--core/sigtrap_windows.go3
-rw-r--r--db.dns.miek.nl10
-rw-r--r--db.miek.nl29
-rw-r--r--dist/CHANGES.txt190
-rw-r--r--dist/LICENSES.txt539
-rw-r--r--dist/README.txt30
-rwxr-xr-xdist/automate.sh56
-rw-r--r--main.go232
-rw-r--r--main_test.go75
-rw-r--r--middleware/commands.go120
-rw-r--r--middleware/commands_test.go291
-rw-r--r--middleware/context.go135
-rw-r--r--middleware/context_test.go613
-rw-r--r--middleware/errors/errors.go100
-rw-r--r--middleware/errors/errors_test.go168
-rw-r--r--middleware/etcd/TODO0
-rw-r--r--middleware/exchange.go10
-rw-r--r--middleware/file/file.go89
-rw-r--r--middleware/file/file_test.go325
-rw-r--r--middleware/host.go22
-rw-r--r--middleware/log/log.go66
-rw-r--r--middleware/log/log_test.go48
-rw-r--r--middleware/middleware.go105
-rw-r--r--middleware/middleware_test.go108
-rw-r--r--middleware/path.go18
-rw-r--r--middleware/prometheus/handler.go31
-rw-r--r--middleware/prometheus/metrics.go80
-rw-r--r--middleware/proxy/policy.go101
-rw-r--r--middleware/proxy/policy_test.go87
-rw-r--r--middleware/proxy/proxy.go120
-rw-r--r--middleware/proxy/proxy_test.go317
-rw-r--r--middleware/proxy/reverseproxy.go36
-rw-r--r--middleware/proxy/upstream.go235
-rw-r--r--middleware/proxy/upstream_test.go83
-rw-r--r--middleware/recorder.go70
-rw-r--r--middleware/recorder_test.go32
-rw-r--r--middleware/reflect/reflect.go84
-rw-r--r--middleware/reflect/reflect_test.go1
-rw-r--r--middleware/replacer.go98
-rw-r--r--middleware/replacer_test.go124
-rw-r--r--middleware/rewrite/condition.go130
-rw-r--r--middleware/rewrite/condition_test.go106
-rw-r--r--middleware/rewrite/reverter.go38
-rw-r--r--middleware/rewrite/rewrite.go223
-rw-r--r--middleware/rewrite/rewrite_test.go159
-rw-r--r--middleware/rewrite/testdata/testdir/empty0
-rw-r--r--middleware/rewrite/testdata/testfile1
-rw-r--r--middleware/roller.go27
-rw-r--r--middleware/zone.go21
-rw-r--r--server/config.go75
-rw-r--r--server/config_test.go25
-rw-r--r--server/graceful.go76
-rw-r--r--server/server.go431
-rw-r--r--server/zones.go28
131 files changed, 15193 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 000000000..0dd26ce5d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+.DS_Store
+Thumbs.db
+_gitignore/
+Vagrantfile
+.vagrant/
+
+dist/builds/
+dist/release/
+
+error.log
+access.log
+
+/*.conf
+Caddyfile
+
+og_static/ \ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 000000000..6a2da63db
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,16 @@
+language: go
+
+go:
+ - 1.6
+ - tip
+
+env:
+- CGO_ENABLED=0
+
+install:
+ - go get -t ./...
+ - go get golang.org/x/tools/cmd/vet
+
+script:
+ - go vet ./...
+ - go test ./...
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..346c6dcb9
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,77 @@
+## Contributing to Caddy
+
+Welcome! Our community focuses on helping others and making Caddy the best it
+can be. We gladly accept contributions and encourage you to get involved!
+
+
+### Join us in chat
+
+Please direct your discussion to the correct room:
+
+- **Dev Chat:** [gitter.im/mholt/caddy](https://gitter.im/mholt/caddy) - to chat
+with other Caddy developers
+- **Support:**
+[gitter.im/caddyserver/support](https://gitter.im/caddyserver/support) - to give
+and get help
+- **General:**
+[gitter.im/caddyserver/general](https://gitter.im/caddyserver/general) - for
+anything about Web development
+
+
+### Bug reports
+
+First, please [search this repository](https://github.com/mholt/caddy/search?q=&type=Issues&utf8=%E2%9C%93)
+with a variety of keywords to ensure your bug is not already reported.
+
+If not, [open an issue](https://github.com/mholt/caddy/issues) and answer the
+questions so we can understand and reproduce the problematic behavior.
+
+The burden is on you to convince us that it is actually a bug in Caddy. This is
+easiest to do when you write clear, concise instructions so we can reproduce
+the behavior (even if it seems obvious). The more detailed and specific you are,
+the faster we will be able to help you. Check out
+[How to Report Bugs Effectively](http://www.chiark.greenend.org.uk/~sgtatham/bugs.html).
+
+Please be kind. :smile: Remember that Caddy comes at no cost to you, and you're
+getting free help. If we helped you, please consider
+[donating](https://caddyserver.com/donate) - it keeps us motivated!
+
+
+### Minor improvements and new tests
+
+Submit [pull requests](https://github.com/mholt/caddy/pulls) at any time. Make
+sure to write tests to assert your change is working properly and is thoroughly
+covered.
+
+
+### Proposals, suggestions, ideas, new features
+
+First, please [search](https://github.com/mholt/caddy/search?q=&type=Issues&utf8=%E2%9C%93)
+with a variety of keywords to ensure your suggestion/proposal is new.
+
+If so, you may open either an issue or a pull request for discussion and
+feedback.
+
+The advantage of issues is that you don't have to spend time actually
+implementing your idea, but you should still describe it thoroughly. The
+advantage of a pull request is that we can immediately see the impact the change
+will have on the project, what the code will look like, and how to improve it.
+The disadvantage of pull requests is that they are unlikely to get accepted
+without significant changes, or it may be rejected entirely. Don't worry, that
+won't happen without an open discussion first.
+
+If you are going to spend significant time implementing code for a pull request,
+best to open an issue first and "claim" it and get feedback before you invest
+a lot of time.
+
+
+### Vulnerabilities
+
+If you've found a vulnerability that is serious, please email me: Matthew dot
+Holt at Gmail. If it's not a big deal, a pull request will probably be faster.
+
+
+## Thank you
+
+Thanks for your help! Caddy would not be what it is today without your
+contributions.
diff --git a/Caddyfile-simple b/Caddyfile-simple
new file mode 100644
index 000000000..610fe9a19
--- /dev/null
+++ b/Caddyfile-simple
@@ -0,0 +1,7 @@
+.:1053 {
+ prometheus
+ rewrite ANY HINFO
+
+ file db.miek.nl miek.nl
+ reflect
+}
diff --git a/Corefile b/Corefile
new file mode 100644
index 000000000..97a631950
--- /dev/null
+++ b/Corefile
@@ -0,0 +1,9 @@
+.:1053 {
+ file db.miek.nl miek.nl
+ proxy . 8.8.8.8:53
+}
+
+dns.miek.nl:1053 {
+ file db.dns.miek.nl
+ reflect
+}
diff --git a/Corefile-name-alone b/Corefile-name-alone
new file mode 100644
index 000000000..5bd787f3a
--- /dev/null
+++ b/Corefile-name-alone
@@ -0,0 +1 @@
+miek.nl
diff --git a/ISSUE_TEMPLATE b/ISSUE_TEMPLATE
new file mode 100644
index 000000000..f5b4ec6e2
--- /dev/null
+++ b/ISSUE_TEMPLATE
@@ -0,0 +1,22 @@
+*If you are filing a bug report, please answer these questions. If your issue is not a bug report,
+ you do not need to use this template. Either way, please consider donating if we've helped you.
+ Thanks!*
+
+#### 1. What version of CoreDNS are you running (`coredns -version`)?
+
+
+#### 2. What are you trying to do?
+
+
+#### 3. What is your entire Corefile?
+```text
+(Put Corefile here)
+```
+
+#### 4. How did you run CoreDNS (give the full command and describe the execution environment)?
+
+
+#### 5. What did you expect to see?
+
+
+#### 6. What did you see instead (give full error messages and/or log)?
diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 000000000..8dada3eda
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..1252131c3
--- /dev/null
+++ b/README.md
@@ -0,0 +1,174 @@
+<!--
+[![Caddy](https://caddyserver.com/resources/images/caddy-boxed.png)](https://caddyserver.com)
+
+[![Dev Chat](https://img.shields.io/badge/dev%20chat-gitter-ff69b4.svg?style=flat-square&label=dev+chat&color=ff69b4)](https://gitter.im/mholt/caddy)
+[![Documentation](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](https://godoc.org/github.com/mholt/caddy)
+[![Linux Build Status](https://img.shields.io/travis/mholt/caddy.svg?style=flat-square&label=linux+build)](https://travis-ci.org/mholt/caddy)
+[![Windows Build Status](https://img.shields.io/appveyor/ci/mholt/caddy.svg?style=flat-square&label=windows+build)](https://ci.appveyor.com/project/mholt/caddy)
+-->
+
+CoreDNS is a lightweight, general-purpose DNS server for Windows, Mac, Linux, BSD
+and [Android](https://github.com/mholt/caddy/wiki/Running-Caddy-on-Android).
+It is a capable alternative to other popular and easy to use web servers.
+([@caddyserver](https://twitter.com/caddyserver) on Twitter)
+
+The most notable features are HTTP/2, [Let's Encrypt](https://letsencrypt.org)
+support, Virtual Hosts, TLS + SNI, and easy configuration with a
+[Caddyfile](https://caddyserver.com/docs/caddyfile). In development, you usually
+put one Caddyfile with each site. In production, Caddy serves HTTPS by default
+and manages all cryptographic assets for you.
+
+[Download](https://github.com/mholt/caddy/releases) ·
+[User Guide](https://caddyserver.com/docs)
+
+
+
+### Menu
+
+- [Getting Caddy](#getting-caddy)
+- [Quick Start](#quick-start)
+- [Running from Source](#running-from-source)
+- [Contributing](#contributing)
+- [About the Project](#about-the-project)
+
+
+
+
+## Getting Caddy
+
+Caddy binaries have no dependencies and are available for nearly every platform.
+
+[Latest release](https://github.com/mholt/caddy/releases/latest)
+
+
+
+## Quick Start
+
+The website has [full documentation](https://caddyserver.com/docs) but this will
+get you started in about 30 seconds:
+
+Place a file named "Caddyfile" with your site. Paste this into it and save:
+
+```
+localhost
+
+gzip
+browse
+ext .html
+websocket /echo cat
+log ../access.log
+header /api Access-Control-Allow-Origin *
+```
+
+Run `caddy` from that directory, and it will automatically use that Caddyfile to
+configure itself.
+
+That simple file enables compression, allows directory browsing (for folders
+without an index file), serves clean URLs, hosts a WebSocket echo server at
+/echo, logs requests to access.log, and adds the coveted
+`Access-Control-Allow-Origin: *` header for all responses from some API.
+
+Wow! Caddy can do a lot with just a few lines.
+
+
+#### Defining multiple sites
+
+You can run multiple sites from the same Caddyfile, too:
+
+```
+site1.com {
+ # ...
+}
+
+site2.com, sub.site2.com {
+ # ...
+}
+```
+
+Note that all these sites will automatically be served over HTTPS using Let's
+Encrypt as the CA. Caddy will manage the certificates (including renewals) for
+you. You don't even have to think about it.
+
+For more documentation, please view [the website](https://caddyserver.com/docs).
+You may also be interested in the
+[developer guide](https://github.com/mholt/caddy/wiki) on this project's GitHub wiki.
+
+
+
+
+## Running from Source
+
+Note: You will need **[Go 1.6](https://golang.org/dl/)** or newer.
+
+1. `$ go get github.com/mholt/caddy`
+2. `cd` into your website's directory
+3. Run `caddy` (assumes `$GOPATH/bin` is in your `$PATH`)
+
+If you're tinkering, you can also use `go run main.go`.
+
+By default, Caddy serves the current directory at
+[localhost:2015](http://localhost:2015). You can place a Caddyfile to configure
+Caddy for serving your site.
+
+Caddy accepts some flags from the command line. Run `caddy -h` to view the help
+ for flags. You can also pipe a Caddyfile into the caddy command.
+
+**Running as root:** We advise against this; use setcap instead, like so:
+`setcap cap_net_bind_service=+ep ./caddy` This will allow you to listen on
+ports < 1024 like 80 and 443.
+
+
+
+#### Docker Container
+
+Caddy is available as a Docker container from any of these sources:
+
+- [abiosoft/caddy](https://hub.docker.com/r/abiosoft/caddy/)
+- [darron/caddy](https://hub.docker.com/r/darron/caddy/)
+- [joshix/caddy](https://hub.docker.com/r/joshix/caddy/)
+- [jumanjiman/caddy](https://hub.docker.com/r/jumanjiman/caddy/)
+- [zenithar/nano-caddy](https://hub.docker.com/r/zenithar/nano-caddy/)
+
+
+
+#### 3rd-party dependencies
+
+Although Caddy's binaries are completely static, Caddy relies on some excellent
+libraries. [Godoc.org](https://godoc.org/github.com/mholt/caddy) shows the
+packages that each Caddy package imports.
+
+
+
+
+## Contributing
+
+**[Join our dev chat on Gitter](https://gitter.im/mholt/caddy)** to chat with
+other Caddy developers! (Dev chat only; try our
+[support room](https://gitter.im/caddyserver/support) for help or
+[general](https://gitter.im/caddyserver/general) for anything else.)
+
+This project would not be what it is without your help. Please see the
+[contributing guidelines](https://github.com/mholt/caddy/blob/master/CONTRIBUTING.md)
+if you haven't already.
+
+Thanks for making Caddy -- and the Web -- better!
+
+Special thanks to
+[![DigitalOcean](http://i.imgur.com/sfGr0eY.png)](https://www.digitalocean.com)
+for hosting the Caddy project.
+
+
+
+
+## About the project
+
+Caddy was born out of the need for a "batteries-included" web server that runs
+anywhere and doesn't have to take its configuration with it. Caddy took
+inspiration from [spark](https://github.com/rif/spark),
+[nginx](https://github.com/nginx/nginx), lighttpd,
+[Websocketd](https://github.com/joewalnes/websocketd)
+and [Vagrant](https://www.vagrantup.com/),
+which provides a pleasant mixture of features from each of them.
+
+
+*Twitter: [@mholt6](https://twitter.com/mholt6)*
diff --git a/TODO b/TODO
new file mode 100644
index 000000000..79047fb5e
--- /dev/null
+++ b/TODO
@@ -0,0 +1,35 @@
+* Fix file middleware to use a proper zone implementation
+ * Zone parsing (better zone impl.)
+* Zones file parsing is done twice on startup??
+* Might need global middleware state between middlewares
+* Cleanup/make middlewares
+ * Fix complex rewrite to be useful
+ * Healthcheck middleware
+ * Slave zone middleware
+ * SkyDNS middleware, or call it etcd?
+* Fix graceful restart
+* TESTS; don't compile, need cleanups
+* http.FileSystem is half used, half not used. It's a nice abstraction
+ for finding (zone) files, maybe we should just use it.
+* prometheus:
+ * track the query type
+ * track the correct zone
+
+When there is already something running.
+
+BUG: server/server.go ListenAndServe
+Activating privacy features...
+.:1053
+panic: close of closed channel
+
+goroutine 40 [running]:
+panic(0x8e5b60, 0xc8201b60b0)
+ /home/miek/upstream/go/src/runtime/panic.go:464 +0x3e6
+github.com/miekg/daddy/server.(*Server).ListenAndServe.func1.1()
+ /home/miek/g/src/github.com/miekg/daddy/server/server.go:147 +0x24
+sync.(*Once).Do(0xc82011b830, 0xc8201d3f38)
+ /home/miek/upstream/go/src/sync/once.go:44 +0xe4
+github.com/miekg/daddy/server.(*Server).ListenAndServe.func1(0xc82011b4c0, 0xc820090800, 0xc82011b830)
+ /home/miek/g/src/github.com/miekg/daddy/server/server.go:148 +0x1e3
+created by github.com/miekg/daddy/server.(*Server).ListenAndServe
+ /home/miek/g/src/github.com/miekg/daddy/server/server.go:150 +0xfe
diff --git a/build.bash b/build.bash
new file mode 100755
index 000000000..b7c97d1ec
--- /dev/null
+++ b/build.bash
@@ -0,0 +1,55 @@
+#!/usr/bin/env bash
+#
+# Caddy build script. Automates proper versioning.
+#
+# Usage:
+#
+# $ ./build.bash [output_filename]
+#
+# Outputs compiled program in current directory.
+# Default file name is 'ecaddy'.
+#
+set -e
+
+output="$1"
+if [ -z "$output" ]; then
+ output="ecaddy"
+fi
+
+pkg=main
+
+# Timestamp of build
+builddate_id=$pkg.buildDate
+builddate=`date -u`
+
+# Current tag, if HEAD is on a tag
+tag_id=$pkg.gitTag
+set +e
+tag=`git describe --exact-match HEAD 2> /dev/null`
+set -e
+
+# Nearest tag on branch
+lasttag_id=$pkg.gitNearestTag
+lasttag=`git describe --abbrev=0 --tags HEAD`
+
+# Commit SHA
+commit_id=$pkg.gitCommit
+commit=`git rev-parse --short HEAD`
+
+# Summary of uncommitted changes
+shortstat_id=$pkg.gitShortStat
+shortstat=`git diff-index --shortstat HEAD`
+
+# List of modified files
+files_id=$pkg.gitFilesModified
+files=`git diff-index --name-only HEAD`
+
+
+go build -ldflags "
+ -X \"$builddate_id=$builddate\"
+ -X \"$tag_id=$tag\"
+ -X \"$lasttag_id=$lasttag\"
+ -X \"$commit_id=$commit\"
+ -X \"$shortstat_id=$shortstat\"
+ -X \"$files_id=$files\"
+" -o "$output"
diff --git a/core/assets/path.go b/core/assets/path.go
new file mode 100644
index 000000000..46b883b1c
--- /dev/null
+++ b/core/assets/path.go
@@ -0,0 +1,29 @@
+package assets
+
+import (
+ "os"
+ "path/filepath"
+ "runtime"
+)
+
+// Path returns the path to the folder
+// where the application may store data. This
+// currently resolves to ~/.caddy
+func Path() string {
+ return filepath.Join(userHomeDir(), ".caddy")
+}
+
+// userHomeDir returns the user's home directory according to
+// environment variables.
+//
+// Credit: http://stackoverflow.com/a/7922977/1048862
+func userHomeDir() string {
+ if runtime.GOOS == "windows" {
+ home := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH")
+ if home == "" {
+ home = os.Getenv("USERPROFILE")
+ }
+ return home
+ }
+ return os.Getenv("HOME")
+}
diff --git a/core/assets/path_test.go b/core/assets/path_test.go
new file mode 100644
index 000000000..374f813af
--- /dev/null
+++ b/core/assets/path_test.go
@@ -0,0 +1,12 @@
+package assets
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestPath(t *testing.T) {
+ if actual := Path(); !strings.HasSuffix(actual, ".caddy") {
+ t.Errorf("Expected path to be a .caddy folder, got: %v", actual)
+ }
+}
diff --git a/core/caddy.go b/core/caddy.go
new file mode 100644
index 000000000..e76fa28f1
--- /dev/null
+++ b/core/caddy.go
@@ -0,0 +1,388 @@
+// Package caddy implements the Caddy web server as a service
+// in your own Go programs.
+//
+// To use this package, follow a few simple steps:
+//
+// 1. Set the AppName and AppVersion variables.
+// 2. Call LoadCaddyfile() to get the Caddyfile (it
+// might have been piped in as part of a restart).
+// You should pass in your own Caddyfile loader.
+// 3. Call caddy.Start() to start Caddy, caddy.Stop()
+// to stop it, or caddy.Restart() to restart it.
+//
+// You should use caddy.Wait() to wait for all Caddy servers
+// to quit before your process exits.
+package core
+
+import (
+ "bytes"
+ "encoding/gob"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "path"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/coredns/core/https"
+ "github.com/miekg/coredns/server"
+)
+
+// Configurable application parameters
+var (
+	// AppName is the name of the application.
+	AppName string
+
+	// AppVersion is the version of the application.
+	AppVersion string
+
+	// Quiet, when set to true, suppresses informative output on initialization.
+	Quiet bool
+
+	// PidFile is the path to the pidfile to create.
+	PidFile string
+
+	// GracefulTimeout is the maximum duration of a graceful shutdown.
+	GracefulTimeout time.Duration
+)
+
+var (
+	// caddyfile is the input configuration text used for this process
+	caddyfile Input
+
+	// caddyfileMu protects caddyfile during changes
+	caddyfileMu sync.Mutex
+
+	// errIncompleteRestart occurs if this process is a fork
+	// of the parent but no Caddyfile was piped in via stdin
+	errIncompleteRestart = errors.New("incomplete restart")
+
+	// servers is a list of all the currently-listening servers
+	servers []*server.Server
+
+	// serversMu protects the servers slice during changes
+	serversMu sync.Mutex
+
+	// wg is used to wait for all servers to shut down;
+	// each server's serve goroutine adds to it in startServers
+	wg sync.WaitGroup
+
+	// loadedGob is used if this is a child process as part of
+	// a graceful restart; it is used to map listeners to their
+	// index in the list of inherited file descriptors. This
+	// variable is not safe for concurrent access.
+	loadedGob caddyfileGob
+
+	// startedBefore should be set to true if caddy has been started
+	// at least once (does not indicate whether currently running).
+	startedBefore bool
+)
+
+const (
+	// DefaultHost is the default host (empty means all interfaces).
+	DefaultHost = ""
+	// DefaultPort is the default port, 53 (the standard DNS port).
+	DefaultPort = "53"
+	// DefaultRoot is the default root folder.
+	DefaultRoot = "."
+)
+
+// Start starts Caddy with the given Caddyfile. If cdyfile
+// is nil, the LoadCaddyfile function will be called to get
+// one.
+//
+// This function blocks until all the servers are listening.
+//
+// Note (POSIX): If Start is called in the child process of a
+// restart more than once within the duration of the graceful
+// cutoff (i.e. the child process called Start a first time,
+// then called Stop, then Start again within the first 5 seconds
+// or however long GracefulTimeout is) and the Caddyfiles have
+// at least one listener address in common, the second Start
+// may fail with "address already in use" as there's no
+// guarantee that the parent process has relinquished the
+// address before the grace period ends.
+func Start(cdyfile Input) (err error) {
+	// If we return with no errors, we must do two things: tell the
+	// parent that we succeeded and write to the pidfile.
+	defer func() {
+		if err == nil {
+			signalSuccessToParent() // TODO: Is doing this more than once per process a bad idea? Start could get called more than once in other apps.
+			if PidFile != "" {
+				err := writePidFile()
+				if err != nil {
+					log.Printf("[ERROR] Could not write pidfile: %v", err)
+				}
+			}
+		}
+	}()
+
+	// Input must never be nil; try to load something
+	if cdyfile == nil {
+		cdyfile, err = LoadCaddyfile(nil)
+		if err != nil {
+			return err
+		}
+	}
+
+	// remember the active input so Caddyfile() can return it
+	caddyfileMu.Lock()
+	caddyfile = cdyfile
+	caddyfileMu.Unlock()
+
+	// load the server configs (TLS/Let's Encrypt activation is
+	// currently disabled inside loadConfigs; see TODO there)
+	configs, err := loadConfigs(path.Base(cdyfile.Path()), bytes.NewReader(cdyfile.Body()))
+	if err != nil {
+		return err
+	}
+
+	// group zones by address
+	groupings, err := arrangeBindings(configs)
+	if err != nil {
+		return err
+	}
+
+	// Start each server with its one or more configurations
+	err = startServers(groupings)
+	if err != nil {
+		return err
+	}
+	startedBefore = true
+
+	// Show initialization output
+	if !Quiet && !IsRestart() {
+		var checkedFdLimit bool
+		for _, group := range groupings {
+			for _, conf := range group.Configs {
+				// Print address of site
+				fmt.Println(conf.Address())
+
+				// Note if non-localhost site resolves to loopback interface
+				if group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) {
+					fmt.Printf("Notice: %s is only accessible on this machine (%s)\n",
+						conf.Host, group.BindAddr.IP.String())
+				}
+				if !checkedFdLimit && !group.BindAddr.IP.IsLoopback() && !isLocalhost(conf.Host) {
+					checkFdlimit()
+					checkedFdLimit = true
+				}
+			}
+		}
+	}
+
+	return nil
+}
+
+// startServers starts all the servers in groupings,
+// taking into account whether or not this process is
+// a child from a graceful restart or not. It blocks
+// until the servers are listening.
+func startServers(groupings bindingGroup) error {
+	var startupWg sync.WaitGroup
+	errChan := make(chan error, len(groupings)) // must be buffered to allow Serve functions below to return if stopped later
+
+	for _, group := range groupings {
+		s, err := server.New(group.BindAddr.String(), group.Configs, GracefulTimeout)
+		if err != nil {
+			return err
+		}
+		// TODO(miek): does not work, because this callback uses http instead of dns
+		// s.ReqCallback = https.RequestCallback // ensures we can solve ACME challenges while running
+		if s.OnDemandTLS {
+			s.TLSConfig.GetCertificate = https.GetOrObtainCertificate // TLS on demand -- awesome!
+		} else {
+			s.TLSConfig.GetCertificate = https.GetCertificate
+		}
+
+		// ln stays nil for now; the restart path that would inherit a
+		// listener file descriptor from the parent is disabled below.
+		var ln server.ListenerFile
+		/*
+			if IsRestart() {
+				// Look up this server's listener in the map of inherited file descriptors;
+				// if we don't have one, we must make a new one (later).
+				if fdIndex, ok := loadedGob.ListenerFds[s.Addr]; ok {
+					file := os.NewFile(fdIndex, "")
+
+					fln, err := net.FileListener(file)
+					if err != nil {
+						return err
+					}
+
+					ln, ok = fln.(server.ListenerFile)
+					if !ok {
+						return errors.New("listener for " + s.Addr + " was not a ListenerFile")
+					}
+
+					file.Close()
+					delete(loadedGob.ListenerFds, s.Addr)
+				}
+			}
+		*/
+
+		wg.Add(1)
+		go func(s *server.Server, ln server.ListenerFile) {
+			defer wg.Done()
+
+			// run startup functions that should only execute when
+			// the original parent process is starting.
+			if !IsRestart() && !startedBefore {
+				err := s.RunFirstStartupFuncs()
+				if err != nil {
+					errChan <- err
+					return
+				}
+			}
+
+			// start the server
+			// TODO(miek): for now will always be nil, so we will run ListenAndServe()
+			if ln != nil {
+				//errChan <- s.Serve(ln)
+			} else {
+				errChan <- s.ListenAndServe()
+			}
+		}(s, ln)
+
+		startupWg.Add(1)
+		go func(s *server.Server) {
+			defer startupWg.Done()
+			s.WaitUntilStarted()
+		}(s)
+
+		serversMu.Lock()
+		servers = append(servers, s)
+		serversMu.Unlock()
+	}
+
+	// Close the remaining (unused) file descriptors to free up resources
+	if IsRestart() {
+		for key, fdIndex := range loadedGob.ListenerFds {
+			os.NewFile(fdIndex, "").Close()
+			delete(loadedGob.ListenerFds, key)
+		}
+	}
+
+	// Wait for all servers to finish starting
+	startupWg.Wait()
+
+	// Return the first error, if any
+	select {
+	case err := <-errChan:
+		// "use of closed network connection" is normal if it was a graceful shutdown
+		if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
+			return err
+		}
+	default:
+	}
+
+	return nil
+}
+
+// Stop stops all servers. It blocks until they are all stopped.
+// It also deactivates the HTTPS maintenance routines and empties
+// the servers list so stopped servers are never reused.
+// It does NOT execute shutdown callbacks that may have been
+// configured by middleware (they must be executed separately).
+func Stop() error {
+	https.Deactivate()
+
+	serversMu.Lock()
+	for _, s := range servers {
+		// errors are logged, not returned, so every server gets a Stop
+		if err := s.Stop(); err != nil {
+			log.Printf("[ERROR] Stopping %s: %v", s.Addr, err)
+		}
+	}
+	servers = []*server.Server{} // don't reuse servers
+	serversMu.Unlock()
+
+	return nil
+}
+
+// Wait blocks until all servers are stopped; it waits on the
+// process-wide WaitGroup that startServers adds to for each
+// server's serve goroutine.
+func Wait() {
+	wg.Wait()
+}
+
+// LoadCaddyfile loads a Caddyfile, prioritizing a Caddyfile
+// piped from stdin as part of a restart (only happens on first call
+// to LoadCaddyfile). If it is not a restart, this function tries
+// calling the user's loader function, and if that returns nil, then
+// this function resorts to the default configuration. Thus, if there
+// are no other errors, this function always returns at least the
+// default Caddyfile.
+func LoadCaddyfile(loader func() (Input, error)) (cdyfile Input, err error) {
+	// If we are a fork, finishing the restart is highest priority;
+	// piped input is required in this case. The gob from the parent
+	// also carries the on-demand TLS certificate counter, which is
+	// restored here.
+	if IsRestart() {
+		err := gob.NewDecoder(os.Stdin).Decode(&loadedGob)
+		if err != nil {
+			return nil, err
+		}
+		cdyfile = loadedGob.Caddyfile
+		atomic.StoreInt32(https.OnDemandIssuedCount, loadedGob.OnDemandTLSCertsIssued)
+	}
+
+	// Try user's loader
+	if cdyfile == nil && loader != nil {
+		cdyfile, err = loader()
+	}
+
+	// Otherwise revert to default
+	if cdyfile == nil {
+		cdyfile = DefaultInput()
+	}
+
+	return
+}
+
+// CaddyfileFromPipe loads the Caddyfile input from f if f is
+// not interactive input. f is assumed to be a pipe or stream,
+// such as os.Stdin. If f is not a pipe, no error is returned
+// but the Input value will be nil. An error is only returned
+// if there was an error reading the pipe, even if the length
+// of what was read is 0.
+func CaddyfileFromPipe(f *os.File) (Input, error) {
+	fi, err := f.Stat()
+	if err == nil && fi.Mode()&os.ModeCharDevice == 0 {
+		// Note that a non-nil error is not a problem. Windows
+		// will not create a stdin if there is no pipe, which
+		// produces an error when calling Stat(). But Unix will
+		// make one either way, which is why we also check that
+		// bitmask.
+		// BUG: Reading from stdin after this fails (e.g. for the let's encrypt email address) (OS X)
+		confBody, err := ioutil.ReadAll(f)
+		if err != nil {
+			return nil, err
+		}
+		return CaddyfileInput{
+			Contents: confBody,
+			Filepath: f.Name(),
+		}, nil
+	}
+
+	// not having input from the pipe is not itself an error,
+	// just means no input to return.
+	return nil, nil
+}
+
+// Caddyfile returns the Caddyfile input currently in use by this
+// process, guarded by the same mutex that Start uses when setting it.
+func Caddyfile() Input {
+	caddyfileMu.Lock()
+	defer caddyfileMu.Unlock()
+	return caddyfile
+}
+
+// Input represents a Caddyfile; its contents and file path
+// (which should include the file name at the end of the path).
+// If path does not apply (e.g. piped input) you may use
+// any understandable value. The path is mainly used for logging,
+// error messages, and debugging.
+type Input interface {
+	// Body returns the Caddyfile contents.
+	Body() []byte
+
+	// Path returns the path to the origin file.
+	Path() string
+
+	// IsFile returns true if the original input was a file on the file system
+	// that could be loaded again later if requested.
+	IsFile() bool
+}
diff --git a/core/caddy_test.go b/core/caddy_test.go
new file mode 100644
index 000000000..1dc230a94
--- /dev/null
+++ b/core/caddy_test.go
@@ -0,0 +1,32 @@
+package core
+
+import (
+ "net/http"
+ "testing"
+ "time"
+)
+
+// TestCaddyStartStop starts and stops the server twice to verify
+// that addresses are released and servers are not reused.
+//
+// NOTE(review): this test issues an HTTP GET, but this fork's
+// DefaultPort is 53 (DNS) — confirm the server at localhost:1984
+// actually speaks HTTP, or this test belongs to the upstream
+// Caddy behavior only.
+func TestCaddyStartStop(t *testing.T) {
+	caddyfile := "localhost:1984"
+
+	for i := 0; i < 2; i++ {
+		err := Start(CaddyfileInput{Contents: []byte(caddyfile)})
+		if err != nil {
+			t.Fatalf("Error starting, iteration %d: %v", i, err)
+		}
+
+		client := http.Client{
+			Timeout: time.Duration(2 * time.Second),
+		}
+		resp, err := client.Get("http://localhost:1984")
+		if err != nil {
+			t.Fatalf("Expected GET request to succeed (iteration %d), but it failed: %v", i, err)
+		}
+		resp.Body.Close()
+
+		err = Stop()
+		if err != nil {
+			t.Fatalf("Error stopping, iteration %d: %v", i, err)
+		}
+	}
+}
diff --git a/core/caddyfile/json.go b/core/caddyfile/json.go
new file mode 100644
index 000000000..6f4e66771
--- /dev/null
+++ b/core/caddyfile/json.go
@@ -0,0 +1,173 @@
+package caddyfile
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/miekg/coredns/core/parse"
+)
+
+// filename is the name reported to the parser/dispenser for input
+// that has no real file on disk (conversion works on byte slices).
+const filename = "Caddyfile"
+
+// ToJSON converts caddyfile to its JSON representation.
+// Directives within a server block are emitted in sorted
+// (deterministic) order, not source order.
+func ToJSON(caddyfile []byte) ([]byte, error) {
+	var j Caddyfile
+
+	serverBlocks, err := parse.ServerBlocks(filename, bytes.NewReader(caddyfile), false)
+	if err != nil {
+		return nil, err
+	}
+
+	for _, sb := range serverBlocks {
+		block := ServerBlock{Body: [][]interface{}{}}
+
+		// Fill up host list
+		for _, host := range sb.HostList() {
+			block.Hosts = append(block.Hosts, host)
+		}
+
+		// Extract directives deterministically by sorting them.
+		// Allocate with zero length (capacity only): a non-zero
+		// length here would leave empty strings at the front of
+		// the slice, which sort first and cause bogus map lookups.
+		directives := make([]string, 0, len(sb.Tokens))
+		for dir := range sb.Tokens {
+			directives = append(directives, dir)
+		}
+		sort.Strings(directives)
+
+		// Convert each directive's tokens into our JSON structure
+		for _, dir := range directives {
+			disp := parse.NewDispenserTokens(filename, sb.Tokens[dir])
+			for disp.Next() {
+				block.Body = append(block.Body, constructLine(&disp))
+			}
+		}
+
+		// tack this block onto the end of the list
+		j = append(j, block)
+	}
+
+	result, err := json.Marshal(j)
+	if err != nil {
+		return nil, err
+	}
+
+	return result, nil
+}
+
+// constructLine transforms tokens into a JSON-encodable structure;
+// but only one line at a time, to be used at the top-level of
+// a server block only (where the first token on each line is a
+// directive) - not to be used at any other nesting level.
+// The current token (the directive name) becomes the first
+// element; an opening "{" argument recurses into constructBlock.
+func constructLine(d *parse.Dispenser) []interface{} {
+	var args []interface{}
+
+	args = append(args, d.Val())
+
+	for d.NextArg() {
+		if d.Val() == "{" {
+			args = append(args, constructBlock(d))
+			continue
+		}
+		args = append(args, d.Val())
+	}
+
+	return args
+}
+
+// constructBlock recursively processes tokens into a
+// JSON-encodable structure. To be used in a directive's
+// block. Goes to end of block. The caller must already have
+// consumed the opening "{" token; this stops at the matching "}".
+func constructBlock(d *parse.Dispenser) [][]interface{} {
+	block := [][]interface{}{}
+
+	for d.Next() {
+		if d.Val() == "}" {
+			break
+		}
+		block = append(block, constructLine(d))
+	}
+
+	return block
+}
+
+// FromJSON converts JSON-encoded jsonBytes to Caddyfile text.
+// Server blocks are separated by blank lines and hosts joined
+// with ", "; bodies are rendered by jsonToText with tab
+// indentation. (String concatenation here is fine for
+// config-sized inputs.)
+func FromJSON(jsonBytes []byte) ([]byte, error) {
+	var j Caddyfile
+	var result string
+
+	err := json.Unmarshal(jsonBytes, &j)
+	if err != nil {
+		return nil, err
+	}
+
+	for sbPos, sb := range j {
+		if sbPos > 0 {
+			result += "\n\n"
+		}
+		for i, host := range sb.Hosts {
+			if i > 0 {
+				result += ", "
+			}
+			result += host
+		}
+		result += jsonToText(sb.Body, 1)
+	}
+
+	return []byte(result), nil
+}
+
+// jsonToText recursively transforms a scope of JSON into plain
+// Caddyfile text. Strings containing quotes or whitespace are
+// quoted (with inner quotes escaped); [][]interface{} renders as
+// a brace-delimited block indented by depth tabs; []interface{}
+// renders as space-separated arguments, recursing into nested
+// blocks.
+func jsonToText(scope interface{}, depth int) string {
+	var result string
+
+	switch val := scope.(type) {
+	case string:
+		if strings.ContainsAny(val, "\" \n\t\r") {
+			result += `"` + strings.Replace(val, "\"", "\\\"", -1) + `"`
+		} else {
+			result += val
+		}
+	case int:
+		result += strconv.Itoa(val)
+	case float64:
+		result += fmt.Sprintf("%v", val)
+	case bool:
+		result += fmt.Sprintf("%t", val)
+	case [][]interface{}:
+		result += " {\n"
+		for _, arg := range val {
+			result += strings.Repeat("\t", depth) + jsonToText(arg, depth+1) + "\n"
+		}
+		result += strings.Repeat("\t", depth-1) + "}"
+	case []interface{}:
+		for i, v := range val {
+			if block, ok := v.([]interface{}); ok {
+				result += "{\n"
+				for _, arg := range block {
+					result += strings.Repeat("\t", depth) + jsonToText(arg, depth+1) + "\n"
+				}
+				result += strings.Repeat("\t", depth-1) + "}"
+				continue
+			}
+			result += jsonToText(v, depth)
+			if i < len(val)-1 {
+				result += " "
+			}
+		}
+	}
+
+	return result
+}
+
+// Caddyfile encapsulates a slice of ServerBlocks.
+type Caddyfile []ServerBlock
+
+// ServerBlock represents a server block: the hosts it applies to
+// and its body, a list of lines where each line is a list of
+// arguments (nested blocks appear as nested lists).
+type ServerBlock struct {
+	Hosts []string        `json:"hosts"`
+	Body  [][]interface{} `json:"body"`
+}
diff --git a/core/caddyfile/json_test.go b/core/caddyfile/json_test.go
new file mode 100644
index 000000000..2e44ae2a2
--- /dev/null
+++ b/core/caddyfile/json_test.go
@@ -0,0 +1,161 @@
+package caddyfile
+
+import "testing"
+
+// tests holds round-trip cases shared by TestToJSON and TestFromJSON:
+// each caddyfile must convert to the paired json, and the json must
+// convert back to the exact caddyfile text (tab-indented).
+var tests = []struct {
+	caddyfile, json string
+}{
+	{ // 0
+		caddyfile: `foo {
+	root /bar
+}`,
+		json: `[{"hosts":["foo"],"body":[["root","/bar"]]}]`,
+	},
+	{ // 1
+		caddyfile: `host1, host2 {
+	dir {
+		def
+	}
+}`,
+		json: `[{"hosts":["host1","host2"],"body":[["dir",[["def"]]]]}]`,
+	},
+	{ // 2
+		caddyfile: `host1, host2 {
+	dir abc {
+		def ghi
+		jkl
+	}
+}`,
+		json: `[{"hosts":["host1","host2"],"body":[["dir","abc",[["def","ghi"],["jkl"]]]]}]`,
+	},
+	{ // 3
+		caddyfile: `host1:1234, host2:5678 {
+	dir abc {
+	}
+}`,
+		json: `[{"hosts":["host1:1234","host2:5678"],"body":[["dir","abc",[]]]}]`,
+	},
+	{ // 4
+		caddyfile: `host {
+	foo "bar baz"
+}`,
+		json: `[{"hosts":["host"],"body":[["foo","bar baz"]]}]`,
+	},
+	{ // 5
+		caddyfile: `host, host:80 {
+	foo "bar \"baz\""
+}`,
+		json: `[{"hosts":["host","host:80"],"body":[["foo","bar \"baz\""]]}]`,
+	},
+	{ // 6
+		caddyfile: `host {
+	foo "bar
+baz"
+}`,
+		json: `[{"hosts":["host"],"body":[["foo","bar\nbaz"]]}]`,
+	},
+	{ // 7
+		caddyfile: `host {
+	dir 123 4.56 true
+}`,
+		json: `[{"hosts":["host"],"body":[["dir","123","4.56","true"]]}]`, // NOTE: I guess we assume numbers and booleans should be encoded as strings...?
+	},
+	{ // 8
+		caddyfile: `http://host, https://host {
+}`,
+		json: `[{"hosts":["http://host","https://host"],"body":[]}]`, // hosts in JSON are always host:port format (if port is specified), for consistency
+	},
+	{ // 9
+		caddyfile: `host {
+	dir1 a b
+	dir2 c d
+}`,
+		json: `[{"hosts":["host"],"body":[["dir1","a","b"],["dir2","c","d"]]}]`,
+	},
+	{ // 10
+		caddyfile: `host {
+	dir a b
+	dir c d
+}`,
+		json: `[{"hosts":["host"],"body":[["dir","a","b"],["dir","c","d"]]}]`,
+	},
+	{ // 11
+		caddyfile: `host {
+	dir1 a b
+	dir2 {
+		c
+		d
+	}
+}`,
+		json: `[{"hosts":["host"],"body":[["dir1","a","b"],["dir2",[["c"],["d"]]]]}]`,
+	},
+	{ // 12
+		caddyfile: `host1 {
+	dir1
+}
+
+host2 {
+	dir2
+}`,
+		json: `[{"hosts":["host1"],"body":[["dir1"]]},{"hosts":["host2"],"body":[["dir2"]]}]`,
+	},
+}
+
+// TestToJSON checks the caddyfile -> JSON direction of every
+// shared round-trip case.
+func TestToJSON(t *testing.T) {
+	for i, test := range tests {
+		output, err := ToJSON([]byte(test.caddyfile))
+		if err != nil {
+			t.Errorf("Test %d: %v", i, err)
+		}
+		if string(output) != test.json {
+			t.Errorf("Test %d\nExpected:\n'%s'\nActual:\n'%s'", i, test.json, string(output))
+		}
+	}
+}
+
+// TestFromJSON checks the JSON -> caddyfile direction of every
+// shared round-trip case; output must match the input text exactly.
+func TestFromJSON(t *testing.T) {
+	for i, test := range tests {
+		output, err := FromJSON([]byte(test.json))
+		if err != nil {
+			t.Errorf("Test %d: %v", i, err)
+		}
+		if string(output) != test.caddyfile {
+			t.Errorf("Test %d\nExpected:\n'%s'\nActual:\n'%s'", i, test.caddyfile, string(output))
+		}
+	}
+}
+
+// TestStandardizeAddress verifies that address forms are normalized
+// during conversion: "host:https" becomes "https://host" and a
+// trailing ":" with no port is dropped, in both directions.
+func TestStandardizeAddress(t *testing.T) {
+	// host:https should be converted to https://host
+	output, err := ToJSON([]byte(`host:https`))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if expected, actual := `[{"hosts":["https://host"],"body":[]}]`, string(output); expected != actual {
+		t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
+	}
+
+	output, err = FromJSON([]byte(`[{"hosts":["https://host"],"body":[]}]`))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if expected, actual := "https://host {\n}", string(output); expected != actual {
+		t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
+	}
+
+	// host: should be converted to just host
+	output, err = ToJSON([]byte(`host:`))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if expected, actual := `[{"hosts":["host"],"body":[]}]`, string(output); expected != actual {
+		t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
+	}
+	output, err = FromJSON([]byte(`[{"hosts":["host:"],"body":[]}]`))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if expected, actual := "host {\n}", string(output); expected != actual {
+		t.Errorf("Expected:\n'%s'\nActual:\n'%s'", expected, actual)
+	}
+}
diff --git a/core/config.go b/core/config.go
new file mode 100644
index 000000000..8c376d878
--- /dev/null
+++ b/core/config.go
@@ -0,0 +1,346 @@
+package core
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "log"
+ "net"
+ "sync"
+
+ "github.com/miekg/coredns/core/https"
+ "github.com/miekg/coredns/core/parse"
+ "github.com/miekg/coredns/core/setup"
+ "github.com/miekg/coredns/server"
+)
+
+const (
+	// DefaultConfigFile is the name of the configuration file that is loaded
+	// by default if no other file is specified (note: "Corefile", not
+	// "Caddyfile", in this fork).
+	DefaultConfigFile = "Corefile"
+)
+
+// loadConfigsUpToIncludingTLS parses the input into server blocks and
+// executes each block's directives in order, stopping after the "tls"
+// directive. It returns the partially set-up configs, the parsed server
+// blocks, and the index of the last directive executed so the caller
+// can resume setup from there after TLS activation.
+func loadConfigsUpToIncludingTLS(filename string, input io.Reader) ([]server.Config, []parse.ServerBlock, int, error) {
+	var configs []server.Config
+
+	// Each server block represents similar hosts/addresses, since they
+	// were grouped together in the Caddyfile.
+	serverBlocks, err := parse.ServerBlocks(filename, input, true)
+	if err != nil {
+		return nil, nil, 0, err
+	}
+	// empty input: fall back to the default Caddyfile
+	if len(serverBlocks) == 0 {
+		newInput := DefaultInput()
+		serverBlocks, err = parse.ServerBlocks(newInput.Path(), bytes.NewReader(newInput.Body()), true)
+		if err != nil {
+			return nil, nil, 0, err
+		}
+	}
+
+	var lastDirectiveIndex int // we set up directives in two parts; this stores where we left off
+
+	// Iterate each server block and make a config for each one,
+	// executing the directives that were parsed in order up to the tls
+	// directive; this is because we must activate Let's Encrypt.
+	for i, sb := range serverBlocks {
+		onces := makeOnces()
+		storages := makeStorages()
+
+		for j, addr := range sb.Addresses {
+			config := server.Config{
+				Host:       addr.Host,
+				Port:       addr.Port,
+				Root:       Root,
+				ConfigFile: filename,
+				AppName:    AppName,
+				AppVersion: AppVersion,
+			}
+
+			// It is crucial that directives are executed in the proper order.
+			for k, dir := range directiveOrder {
+				// Execute directive if it is in the server block
+				if tokens, ok := sb.Tokens[dir.name]; ok {
+					// Each setup function gets a controller, from which setup functions
+					// get access to the config, tokens, and other state information useful
+					// to set up its own host only.
+					controller := &setup.Controller{
+						Config:    &config,
+						Dispenser: parse.NewDispenserTokens(filename, tokens),
+						OncePerServerBlock: func(f func() error) error {
+							var err error
+							onces[dir.name].Do(func() {
+								err = f()
+							})
+							return err
+						},
+						ServerBlockIndex:     i,
+						ServerBlockHostIndex: j,
+						ServerBlockHosts:     sb.HostList(),
+						ServerBlockStorage:   storages[dir.name],
+					}
+					// execute setup function and append middleware handler, if any
+					midware, err := dir.setup(controller)
+					if err != nil {
+						return nil, nil, lastDirectiveIndex, err
+					}
+					if midware != nil {
+						config.Middleware = append(config.Middleware, midware)
+					}
+					storages[dir.name] = controller.ServerBlockStorage // persist for this server block
+				}
+
+				// Stop after TLS setup, since we need to activate Let's Encrypt before continuing;
+				// it makes some changes to the configs that middlewares might want to know about.
+				if dir.name == "tls" {
+					lastDirectiveIndex = k
+					break
+				}
+			}
+
+			configs = append(configs, config)
+		}
+	}
+	return configs, serverBlocks, lastDirectiveIndex, nil
+}
+
+// loadConfigs reads input (named filename) and parses it, returning the
+// server configurations in the order they appeared in the input.
+// NOTE(review): the original intent was to activate Let's Encrypt
+// between the two setup passes, but that activation is currently
+// commented out below (TODO(miek)); only the informative message
+// is still printed.
+func loadConfigs(filename string, input io.Reader) ([]server.Config, error) {
+	configs, serverBlocks, lastDirectiveIndex, err := loadConfigsUpToIncludingTLS(filename, input)
+	if err != nil {
+		return nil, err
+	}
+
+	// Now we have all the configs, but they have only been set up to the
+	// point of tls. We need to activate Let's Encrypt before setting up
+	// the rest of the middlewares so they have correct information regarding
+	// TLS configuration, if necessary. (this only appends, so our iterations
+	// over server blocks below shouldn't be affected)
+	if !IsRestart() && !Quiet {
+		fmt.Println("Activating privacy features...")
+	}
+	/* TODO(miek): stopped for now
+	configs, err = https.Activate(configs)
+	if err != nil {
+		return nil, err
+	} else if !IsRestart() && !Quiet {
+		fmt.Println(" done.")
+	}
+	*/
+
+	// Finish setting up the rest of the directives, now that TLS is
+	// optimally configured. These loops are similar to above except
+	// we don't iterate all the directives from the beginning and we
+	// don't create new configs.
+	configIndex := -1
+	for i, sb := range serverBlocks {
+		onces := makeOnces()
+		storages := makeStorages()
+
+		for j := range sb.Addresses {
+			configIndex++
+
+			for k := lastDirectiveIndex + 1; k < len(directiveOrder); k++ {
+				dir := directiveOrder[k]
+
+				if tokens, ok := sb.Tokens[dir.name]; ok {
+					controller := &setup.Controller{
+						Config:    &configs[configIndex],
+						Dispenser: parse.NewDispenserTokens(filename, tokens),
+						OncePerServerBlock: func(f func() error) error {
+							var err error
+							onces[dir.name].Do(func() {
+								err = f()
+							})
+							return err
+						},
+						ServerBlockIndex:     i,
+						ServerBlockHostIndex: j,
+						ServerBlockHosts:     sb.HostList(),
+						ServerBlockStorage:   storages[dir.name],
+					}
+					midware, err := dir.setup(controller)
+					if err != nil {
+						return nil, err
+					}
+					if midware != nil {
+						configs[configIndex].Middleware = append(configs[configIndex].Middleware, midware)
+					}
+					storages[dir.name] = controller.ServerBlockStorage // persist for this server block
+				}
+			}
+		}
+	}
+
+	return configs, nil
+}
+
+// makeOnces makes a map of directive name to sync.Once
+// instance. This is intended to be called once per server
+// block when setting up configs so that Setup functions
+// for each directive can perform a task just once per
+// server block, even if there are multiple hosts on the block.
+//
+// We need one Once per directive; sharing a single Once would
+// let the first directive to use it exclude all the others,
+// which would be a bug.
+func makeOnces() map[string]*sync.Once {
+	onces := make(map[string]*sync.Once)
+	for _, dir := range directiveOrder {
+		onces[dir.name] = new(sync.Once)
+	}
+	return onces
+}
+
+// makeStorages makes a map of directive name to interface{}
+// (initially nil) so that directives' setup functions can persist
+// state between different hosts on the same server block during
+// the setup phase.
+func makeStorages() map[string]interface{} {
+	storages := make(map[string]interface{})
+	for _, dir := range directiveOrder {
+		storages[dir.name] = nil
+	}
+	return storages
+}
+
+// arrangeBindings groups configurations by their bind address. For example,
+// a server that should listen on localhost and another on 127.0.0.1 will
+// be grouped into the same address: 127.0.0.1. It will return an error
+// if an address is malformed or a TLS listener is configured on the
+// same address as a plaintext HTTP listener. The return value is a map of
+// bind address to list of configs that would become VirtualHosts on that
+// server. Use the keys of the returned map to create listeners, and use
+// the associated values to set up the virtualhosts.
+func arrangeBindings(allConfigs []server.Config) (bindingGroup, error) {
+	var groupings bindingGroup
+
+	// Group configs by bind address
+	for _, conf := range allConfigs {
+		// use default port (the package-level Port) if none is specified
+		if conf.Port == "" {
+			conf.Port = Port
+		}
+
+		bindAddr, warnErr, fatalErr := resolveAddr(conf)
+		if fatalErr != nil {
+			return groupings, fatalErr
+		}
+		if warnErr != nil {
+			log.Printf("[WARNING] Resolving bind address for %s: %v", conf.Address(), warnErr)
+		}
+
+		// Make sure to compare the string representation of the address,
+		// not the pointer, since a new *TCPAddr is created each time.
+		var existing bool
+		for i := 0; i < len(groupings); i++ {
+			if groupings[i].BindAddr.String() == bindAddr.String() {
+				groupings[i].Configs = append(groupings[i].Configs, conf)
+				existing = true
+				break
+			}
+		}
+		if !existing {
+			groupings = append(groupings, bindingMapping{
+				BindAddr: bindAddr,
+				Configs:  []server.Config{conf},
+			})
+		}
+	}
+
+	// Don't allow HTTP and HTTPS to be served on the same address
+	// (compare each config's TLS.Enabled against the group's first)
+	for _, group := range groupings {
+		isTLS := group.Configs[0].TLS.Enabled
+		for _, config := range group.Configs {
+			if config.TLS.Enabled != isTLS {
+				thisConfigProto, otherConfigProto := "HTTP", "HTTP"
+				if config.TLS.Enabled {
+					thisConfigProto = "HTTPS"
+				}
+				if group.Configs[0].TLS.Enabled {
+					otherConfigProto = "HTTPS"
+				}
+				return groupings, fmt.Errorf("configuration error: Cannot multiplex %s (%s) and %s (%s) on same address",
+					group.Configs[0].Address(), otherConfigProto, config.Address(), thisConfigProto)
+			}
+		}
+	}
+
+	return groupings, nil
+}
+
+// resolveAddr determines the address (host and port) that a config will
+// bind to. The returned address, resolvAddr, should be used to bind the
+// listener or group the config with other configs using the same address.
+// The first error, if not nil, is just a warning and should be reported
+// but execution may continue. The second error, if not nil, is a real
+// problem and the server should not be started.
+//
+// This function does not handle edge cases like port "http" or "https" if
+// they are not known to the system. It does, however, fall back to the
+// wildcard host if resolving the address of the specific hostname fails.
+func resolveAddr(conf server.Config) (resolvAddr *net.TCPAddr, warnErr, fatalErr error) {
+	resolvAddr, warnErr = net.ResolveTCPAddr("tcp", net.JoinHostPort(conf.BindHost, conf.Port))
+	if warnErr != nil {
+		// the hostname probably couldn't be resolved, just bind to wildcard then
+		resolvAddr, fatalErr = net.ResolveTCPAddr("tcp", net.JoinHostPort("", conf.Port))
+		if fatalErr != nil {
+			return
+		}
+	}
+
+	return
+}
+
+// validDirective returns true if d is a valid directive,
+// i.e. it appears in the package's directiveOrder list;
+// false otherwise.
+func validDirective(d string) bool {
+	for _, dir := range directiveOrder {
+		if dir.name == d {
+			return true
+		}
+	}
+	return false
+}
+
+// DefaultInput returns the default Caddyfile input
+// to use when it is otherwise empty or missing.
+// It uses the configured Host, Port, and Root; the port is
+// upgraded to 443 when the host qualifies for HTTPS and the
+// port is still the default. (NOTE(review): the defaults here
+// are DNS-oriented — DefaultPort is "53" — unlike upstream Caddy.)
+func DefaultInput() CaddyfileInput {
+	port := Port
+	if https.HostQualifies(Host) && port == DefaultPort {
+		port = "443"
+	}
+	return CaddyfileInput{
+		Contents: []byte(fmt.Sprintf("%s:%s\nroot %s", Host, port, Root)),
+	}
+}
+
+// These defaults are configurable through the command line;
+// they start out as the package's Default* constants.
+var (
+	// Root is the site root
+	Root = DefaultRoot
+
+	// Host is the site host
+	Host = DefaultHost
+
+	// Port is the site port
+	Port = DefaultPort
+)
+
+// bindingMapping maps a network address to the configurations
+// that will bind to it. The order of the configs is important.
+type bindingMapping struct {
+	BindAddr *net.TCPAddr
+	Configs  []server.Config
+}
+
+// bindingGroup maps network addresses to their configurations.
+// Preserving the order of the groupings is important
+// (related to graceful shutdown and restart)
+// so this is a slice, not a literal map.
+type bindingGroup []bindingMapping
diff --git a/core/config_test.go b/core/config_test.go
new file mode 100644
index 000000000..512128c0b
--- /dev/null
+++ b/core/config_test.go
@@ -0,0 +1,159 @@
+package core
+
+import (
+ "reflect"
+ "sync"
+ "testing"
+
+ "github.com/miekg/coredns/server"
+)
+
+func TestDefaultInput(t *testing.T) {
+ if actual, expected := string(DefaultInput().Body()), ":2015\nroot ."; actual != expected {
+ t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
+ }
+
+ // next few tests simulate user providing -host and/or -port flags
+
+ Host = "not-localhost.com"
+ if actual, expected := string(DefaultInput().Body()), "not-localhost.com:443\nroot ."; actual != expected {
+ t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
+ }
+
+ Host = "[::1]"
+ if actual, expected := string(DefaultInput().Body()), "[::1]:2015\nroot ."; actual != expected {
+ t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
+ }
+
+ Host = "127.0.1.1"
+ if actual, expected := string(DefaultInput().Body()), "127.0.1.1:2015\nroot ."; actual != expected {
+ t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
+ }
+
+ Host = "not-localhost.com"
+ Port = "1234"
+ if actual, expected := string(DefaultInput().Body()), "not-localhost.com:1234\nroot ."; actual != expected {
+ t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
+ }
+
+ Host = DefaultHost
+ Port = "1234"
+ if actual, expected := string(DefaultInput().Body()), ":1234\nroot ."; actual != expected {
+ t.Errorf("Host=%s; Port=%s; Root=%s;\nEXPECTED: '%s'\n ACTUAL: '%s'", Host, Port, Root, expected, actual)
+ }
+}
+
// TestResolveAddr drives resolveAddr through a table of configs covering
// BindHost vs. Host precedence, named ports, unresolvable hosts, and the
// wildcard fallback. It performs real DNS/host resolution.
func TestResolveAddr(t *testing.T) {
	// NOTE: If tests fail due to comparing to string "127.0.0.1",
	// it's possible that system env resolves with IPv6, or ::1.
	// If that happens, maybe we should use actualAddr.IP.IsLoopback()
	// for the assertion, rather than a direct string comparison.

	// NOTE: Tests with {Host: "", Port: ""} and {Host: "localhost", Port: ""}
	// will not behave the same cross-platform, so they have been omitted.

	for i, test := range []struct {
		config         server.Config
		shouldWarnErr  bool
		shouldFatalErr bool
		expectedIP     string
		expectedPort   int
	}{
		// Only BindHost is resolved; Host alone yields the wildcard
		// ("<nil>") IP, as the "should-not-be-used" cases below show.
		{server.Config{Host: "127.0.0.1", Port: "1234"}, false, false, "<nil>", 1234},
		{server.Config{Host: "localhost", Port: "80"}, false, false, "<nil>", 80},
		{server.Config{BindHost: "localhost", Port: "1234"}, false, false, "127.0.0.1", 1234},
		{server.Config{BindHost: "127.0.0.1", Port: "1234"}, false, false, "127.0.0.1", 1234},
		{server.Config{BindHost: "should-not-resolve", Port: "1234"}, true, false, "<nil>", 1234},
		{server.Config{BindHost: "localhost", Port: "http"}, false, false, "127.0.0.1", 80},
		{server.Config{BindHost: "localhost", Port: "https"}, false, false, "127.0.0.1", 443},
		{server.Config{BindHost: "", Port: "1234"}, false, false, "<nil>", 1234},
		{server.Config{BindHost: "localhost", Port: "abcd"}, false, true, "", 0},
		{server.Config{BindHost: "127.0.0.1", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
		{server.Config{BindHost: "localhost", Host: "should-not-be-used", Port: "1234"}, false, false, "127.0.0.1", 1234},
		{server.Config{BindHost: "should-not-resolve", Host: "localhost", Port: "1234"}, true, false, "<nil>", 1234},
	} {
		actualAddr, warnErr, fatalErr := resolveAddr(test.config)

		if test.shouldFatalErr && fatalErr == nil {
			t.Errorf("Test %d: Expected error, but there wasn't any", i)
		}
		if !test.shouldFatalErr && fatalErr != nil {
			t.Errorf("Test %d: Expected no error, but there was one: %v", i, fatalErr)
		}
		// On a fatal error the address is unusable; skip the remaining checks.
		if fatalErr != nil {
			continue
		}

		if test.shouldWarnErr && warnErr == nil {
			t.Errorf("Test %d: Expected warning, but there wasn't any", i)
		}
		if !test.shouldWarnErr && warnErr != nil {
			t.Errorf("Test %d: Expected no warning, but there was one: %v", i, warnErr)
		}

		if actual, expected := actualAddr.IP.String(), test.expectedIP; actual != expected {
			t.Errorf("Test %d: IP was %s but expected %s", i, actual, expected)
		}
		if actual, expected := actualAddr.Port, test.expectedPort; actual != expected {
			t.Errorf("Test %d: Port was %d but expected %d", i, actual, expected)
		}
	}
}
+
+func TestMakeOnces(t *testing.T) {
+ directives := []directive{
+ {"dummy", nil},
+ {"dummy2", nil},
+ }
+ directiveOrder = directives
+ onces := makeOnces()
+ if len(onces) != len(directives) {
+ t.Errorf("onces had len %d , expected %d", len(onces), len(directives))
+ }
+ expected := map[string]*sync.Once{
+ "dummy": new(sync.Once),
+ "dummy2": new(sync.Once),
+ }
+ if !reflect.DeepEqual(onces, expected) {
+ t.Errorf("onces was %v, expected %v", onces, expected)
+ }
+}
+
+func TestMakeStorages(t *testing.T) {
+ directives := []directive{
+ {"dummy", nil},
+ {"dummy2", nil},
+ }
+ directiveOrder = directives
+ storages := makeStorages()
+ if len(storages) != len(directives) {
+ t.Errorf("storages had len %d , expected %d", len(storages), len(directives))
+ }
+ expected := map[string]interface{}{
+ "dummy": nil,
+ "dummy2": nil,
+ }
+ if !reflect.DeepEqual(storages, expected) {
+ t.Errorf("storages was %v, expected %v", storages, expected)
+ }
+}
+
+func TestValidDirective(t *testing.T) {
+ directives := []directive{
+ {"dummy", nil},
+ {"dummy2", nil},
+ }
+ directiveOrder = directives
+ for i, test := range []struct {
+ directive string
+ valid bool
+ }{
+ {"dummy", true},
+ {"dummy2", true},
+ {"dummy3", false},
+ } {
+ if actual, expected := validDirective(test.directive), test.valid; actual != expected {
+ t.Errorf("Test %d: valid was %t, expected %t", i, actual, expected)
+ }
+ }
+}
diff --git a/core/directives.go b/core/directives.go
new file mode 100644
index 000000000..96f4d910c
--- /dev/null
+++ b/core/directives.go
@@ -0,0 +1,89 @@
+package core
+
+import (
+ "github.com/miekg/coredns/core/https"
+ "github.com/miekg/coredns/core/parse"
+ "github.com/miekg/coredns/core/setup"
+ "github.com/miekg/coredns/middleware"
+)
+
func init() {
	// The parse package must know which directives
	// are valid, but it must not import the setup
	// or config package. To solve this problem, we
	// fill up this map in our init function here.
	// The parse package does not need to know the
	// ordering of the directives, only their names.
	for _, dir := range directiveOrder {
		parse.ValidDirectives[dir.name] = struct{}{}
	}
}
+
// Directives are registered in the order they should be
// executed. Middleware (directives that inject a handler)
// are executed in the order A-B-C-*-C-B-A, assuming
// they all call the Next handler in the chain.
//
// Ordering is VERY important. Every middleware will
// feel the effects of all other middleware below
// (after) them during a request, but they must not
// care what middleware above them are doing.
//
// For example, log needs to know the status code and
// exactly how many bytes were written to the client,
// which every other middleware can affect, so it gets
// registered first. The errors middleware does not
// care if gzip or log modifies its response, so it
// gets registered below them. Gzip, on the other hand,
// DOES care what errors does to the response since it
// must compress every output to the client, even error
// pages, so it must be registered before the errors
// middleware and any others that would write to the
// response.
//
// NOTE(review): the gzip example above appears inherited from
// Caddy's HTTP middleware stack — confirm it still applies here.
var directiveOrder = []directive{
	// Essential directives that initialize vital configuration settings
	{"root", setup.Root},
	{"bind", setup.BindHost},
	{"tls", https.Setup},

	// Other directives that don't create HTTP handlers
	{"startup", setup.Startup},
	{"shutdown", setup.Shutdown},

	// Directives that inject handlers (middleware)
	{"prometheus", setup.Prometheus},
	{"rewrite", setup.Rewrite},
	{"file", setup.File},
	{"reflect", setup.Reflect},
	{"log", setup.Log},
	{"errors", setup.Errors},
	{"proxy", setup.Proxy},
}
+
+// RegisterDirective adds the given directive to caddy's list of directives.
+// Pass the name of a directive you want it to be placed after,
+// otherwise it will be placed at the bottom of the stack.
+func RegisterDirective(name string, setup SetupFunc, after string) {
+ dir := directive{name: name, setup: setup}
+ idx := len(directiveOrder)
+ for i := range directiveOrder {
+ if directiveOrder[i].name == after {
+ idx = i + 1
+ break
+ }
+ }
+ newDirectives := append(directiveOrder[:idx], append([]directive{dir}, directiveOrder[idx:]...)...)
+ directiveOrder = newDirectives
+ parse.ValidDirectives[name] = struct{}{}
+}
+
// directive ties together a directive name with its setup function.
type directive struct {
	name  string    // the directive keyword as written in the Caddyfile
	setup SetupFunc // the function invoked to configure it
}

// SetupFunc takes a controller and may optionally return a middleware.
// If the resulting middleware is not nil, it will be chained into
// the HTTP handlers in the order specified in this package.
type SetupFunc func(c *setup.Controller) (middleware.Middleware, error)
diff --git a/core/directives_test.go b/core/directives_test.go
new file mode 100644
index 000000000..1bee144f5
--- /dev/null
+++ b/core/directives_test.go
@@ -0,0 +1,31 @@
+package core
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestRegister(t *testing.T) {
+ directives := []directive{
+ {"dummy", nil},
+ {"dummy2", nil},
+ }
+ directiveOrder = directives
+ RegisterDirective("foo", nil, "dummy")
+ if len(directiveOrder) != 3 {
+ t.Fatal("Should have 3 directives now")
+ }
+ getNames := func() (s []string) {
+ for _, d := range directiveOrder {
+ s = append(s, d.name)
+ }
+ return s
+ }
+ if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2"}) {
+ t.Fatalf("directive order doesn't match: %s", getNames())
+ }
+ RegisterDirective("bar", nil, "ASDASD")
+ if !reflect.DeepEqual(getNames(), []string{"dummy", "foo", "dummy2", "bar"}) {
+ t.Fatalf("directive order doesn't match: %s", getNames())
+ }
+}
diff --git a/core/helpers.go b/core/helpers.go
new file mode 100644
index 000000000..8ef6a54cd
--- /dev/null
+++ b/core/helpers.go
@@ -0,0 +1,102 @@
+package core
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "os/exec"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+)
+
// isLocalhost reports whether host looks explicitly like a localhost
// address: the name "localhost", the IPv6 loopback, or any 127.* IPv4.
func isLocalhost(host string) bool {
	switch {
	case host == "localhost":
		return true
	case host == "::1":
		return true
	default:
		return strings.HasPrefix(host, "127.")
	}
}
+
// checkFdlimit issues a warning if the OS max file descriptors is below a
// recommended minimum for production sites.
func checkFdlimit() {
	const min = 4096

	// Only query the limit on the Unix-like systems we know about.
	if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
		return
	}

	// Use sh because ulimit is a shell builtin, not a binary in $PATH.
	out, err := exec.Command("sh", "-c", "ulimit -n").Output()
	if err != nil {
		// Failure to query the limit is not worth reporting.
		return
	}
	lim, err := strconv.Atoi(string(bytes.TrimSpace(out)))
	if err == nil && lim < min {
		fmt.Printf("Warning: File descriptor limit %d is too low for production sites. At least %d is recommended. Set with \"ulimit -n %d\".\n", lim, min, min)
	}
}
+
// signalSuccessToParent tells the parent our status using pipe at index 3.
// If this process is not a restart, this function does nothing.
// Calling this function once this process has successfully initialized
// is vital so that the parent process can unblock and kill itself.
// This function is idempotent; it executes at most once per process
// (guarded by signalParentOnce).
func signalSuccessToParent() {
	signalParentOnce.Do(func() {
		if IsRestart() {
			// fd 3 is inherited from the parent during a graceful restart.
			ppipe := os.NewFile(3, "")               // parent is reading from pipe at index 3
			_, err := ppipe.Write([]byte("success")) // we must send some bytes to the parent
			if err != nil {
				// Log but don't abort; the parent may time out on its own.
				log.Printf("[ERROR] Communicating successful init to parent: %v", err)
			}
			ppipe.Close()
		}
	})
}
+
// signalParentOnce is used to make sure that the parent is only
// signaled once; doing so more than once breaks whatever socket is
// at fd 4 (the reason for this is still unclear - to reproduce,
// call Stop() and Start() in succession at least once after a
// restart, then try loading first host of Caddyfile in the browser).
// Do not use this directly - call signalSuccessToParent instead,
// which wraps its Do call.
var signalParentOnce sync.Once
+
// caddyfileGob maps bind address to index of the file descriptor
// in the Files array passed to the child process. It also contains
// the caddyfile contents and other state needed by the new process.
// Used only during graceful restarts where a new process is spawned.
type caddyfileGob struct {
	// ListenerFds maps a bind address to the fd index inherited by the child.
	ListenerFds map[string]uintptr
	// Caddyfile is the configuration input the child should reuse.
	Caddyfile Input
	// OnDemandTLSCertsIssued carries the on-demand TLS issuance counter
	// across the restart.
	OnDemandTLSCertsIssued int32
}
+
+// IsRestart returns whether this process is, according
+// to env variables, a fork as part of a graceful restart.
+func IsRestart() bool {
+ return os.Getenv("CADDY_RESTART") == "true"
+}
+
// writePidFile writes the process ID to the file at PidFile, if specified.
// PidFile is a package-level path set elsewhere (presumably from a
// command-line flag — confirm against caller). The file is created with
// mode 0644 and any write error is returned to the caller.
func writePidFile() error {
	pid := []byte(strconv.Itoa(os.Getpid()) + "\n")
	return ioutil.WriteFile(PidFile, pid, 0644)
}
+
// CaddyfileInput represents a Caddyfile as input
// and is simply a convenient way to implement
// the Input interface.
type CaddyfileInput struct {
	Filepath string // path to the file, if there is one
	Contents []byte // the raw Caddyfile bytes
	RealFile bool   // whether Filepath names a real file on disk
}

// Body returns c.Contents.
func (c CaddyfileInput) Body() []byte { return c.Contents }

// Path returns c.Filepath.
func (c CaddyfileInput) Path() string { return c.Filepath }

// IsFile returns true if the original input was a real file on the file system.
func (c CaddyfileInput) IsFile() bool { return c.RealFile }
diff --git a/core/https/certificates.go b/core/https/certificates.go
new file mode 100644
index 000000000..0dc3db523
--- /dev/null
+++ b/core/https/certificates.go
@@ -0,0 +1,234 @@
+package https
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "io/ioutil"
+ "log"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/xenolf/lego/acme"
+ "golang.org/x/crypto/ocsp"
+)
+
// certCache stores certificates in memory,
// keying certificates by name. The entry with the empty
// key "" is the default certificate. All access must be
// guarded by certCacheMu.
var certCache = make(map[string]Certificate)
var certCacheMu sync.RWMutex
+
// Certificate is a tls.Certificate with associated metadata tacked on.
// Even if the metadata can be obtained by parsing the certificate,
// we can be more efficient by extracting the metadata once so it's
// just there, ready to use.
type Certificate struct {
	// The underlying certificate/key pair handed to the TLS stack.
	tls.Certificate

	// Names is the list of names this certificate is written for.
	// The first is the CommonName (if any), the rest are SAN.
	Names []string

	// NotAfter is when the certificate expires.
	NotAfter time.Time

	// Managed certificates are certificates that Caddy is managing,
	// as opposed to the user specifying a certificate and key file
	// or directory and managing the certificate resources themselves.
	Managed bool

	// OnDemand certificates are obtained or loaded on-demand during TLS
	// handshakes (as opposed to preloaded certificates, which are loaded
	// at startup). If OnDemand is true, Managed must necessarily be true.
	// OnDemand certificates are maintained in the background just like
	// preloaded ones, however, if an OnDemand certificate fails to renew,
	// it is removed from the in-memory cache.
	OnDemand bool

	// OCSP contains the certificate's parsed OCSP response.
	OCSP *ocsp.Response
}
+
+// getCertificate gets a certificate that matches name (a server name)
+// from the in-memory cache. If there is no exact match for name, it
+// will be checked against names of the form '*.example.com' (wildcard
+// certificates) according to RFC 6125. If a match is found, matched will
+// be true. If no matches are found, matched will be false and a default
+// certificate will be returned with defaulted set to true. If no default
+// certificate is set, defaulted will be set to false.
+//
+// The logic in this function is adapted from the Go standard library,
+// which is by the Go Authors.
+//
+// This function is safe for concurrent use.
+func getCertificate(name string) (cert Certificate, matched, defaulted bool) {
+ var ok bool
+
+ // Not going to trim trailing dots here since RFC 3546 says,
+ // "The hostname is represented ... without a trailing dot."
+ // Just normalize to lowercase.
+ name = strings.ToLower(name)
+
+ certCacheMu.RLock()
+ defer certCacheMu.RUnlock()
+
+ // exact match? great, let's use it
+ if cert, ok = certCache[name]; ok {
+ matched = true
+ return
+ }
+
+ // try replacing labels in the name with wildcards until we get a match
+ labels := strings.Split(name, ".")
+ for i := range labels {
+ labels[i] = "*"
+ candidate := strings.Join(labels, ".")
+ if cert, ok = certCache[candidate]; ok {
+ matched = true
+ return
+ }
+ }
+
+ // if nothing matches, use the default certificate or bust
+ cert, defaulted = certCache[""]
+ return
+}
+
+// cacheManagedCertificate loads the certificate for domain into the
+// cache, flagging it as Managed and, if onDemand is true, as OnDemand
+// (meaning that it was obtained or loaded during a TLS handshake).
+//
+// This function is safe for concurrent use.
+func cacheManagedCertificate(domain string, onDemand bool) (Certificate, error) {
+ cert, err := makeCertificateFromDisk(storage.SiteCertFile(domain), storage.SiteKeyFile(domain))
+ if err != nil {
+ return cert, err
+ }
+ cert.Managed = true
+ cert.OnDemand = onDemand
+ cacheCertificate(cert)
+ return cert, nil
+}
+
+// cacheUnmanagedCertificatePEMFile loads a certificate for host using certFile
+// and keyFile, which must be in PEM format. It stores the certificate in
+// memory. The Managed and OnDemand flags of the certificate will be set to
+// false.
+//
+// This function is safe for concurrent use.
+func cacheUnmanagedCertificatePEMFile(certFile, keyFile string) error {
+ cert, err := makeCertificateFromDisk(certFile, keyFile)
+ if err != nil {
+ return err
+ }
+ cacheCertificate(cert)
+ return nil
+}
+
+// cacheUnmanagedCertificatePEMBytes makes a certificate out of the PEM bytes
+// of the certificate and key, then caches it in memory.
+//
+// This function is safe for concurrent use.
+func cacheUnmanagedCertificatePEMBytes(certBytes, keyBytes []byte) error {
+ cert, err := makeCertificate(certBytes, keyBytes)
+ if err != nil {
+ return err
+ }
+ cacheCertificate(cert)
+ return nil
+}
+
+// makeCertificateFromDisk makes a Certificate by loading the
+// certificate and key files. It fills out all the fields in
+// the certificate except for the Managed and OnDemand flags.
+// (It is up to the caller to set those.)
+func makeCertificateFromDisk(certFile, keyFile string) (Certificate, error) {
+ certPEMBlock, err := ioutil.ReadFile(certFile)
+ if err != nil {
+ return Certificate{}, err
+ }
+ keyPEMBlock, err := ioutil.ReadFile(keyFile)
+ if err != nil {
+ return Certificate{}, err
+ }
+ return makeCertificate(certPEMBlock, keyPEMBlock)
+}
+
// makeCertificate turns a certificate PEM bundle and a key PEM block into
// a Certificate, with OCSP and other relevant metadata tagged with it,
// except for the OnDemand and Managed flags. It is up to the caller to
// set those properties.
func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
	var cert Certificate

	// Convert to a tls.Certificate
	tlsCert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
	if err != nil {
		return cert, err
	}
	if len(tlsCert.Certificate) == 0 {
		return cert, errors.New("certificate is empty")
	}

	// Parse leaf certificate and extract relevant metadata:
	// the (lowercased) CommonName first, then any SANs that differ.
	leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
	if err != nil {
		return cert, err
	}
	if leaf.Subject.CommonName != "" {
		cert.Names = []string{strings.ToLower(leaf.Subject.CommonName)}
	}
	for _, name := range leaf.DNSNames {
		if name != leaf.Subject.CommonName {
			cert.Names = append(cert.Names, strings.ToLower(name))
		}
	}
	cert.NotAfter = leaf.NotAfter

	// Staple OCSP, but only when the responder reports the cert as Good.
	ocspBytes, ocspResp, err := acme.GetOCSPForCert(certPEMBlock)
	if err != nil {
		// An error here is not a problem because a certificate may simply
		// not contain a link to an OCSP server. But we should log it anyway.
		log.Printf("[WARNING] No OCSP stapling for %v: %v", cert.Names, err)
	} else if ocspResp.Status == ocsp.Good {
		tlsCert.OCSPStaple = ocspBytes
		cert.OCSP = ocspResp
	}

	cert.Certificate = tlsCert
	return cert, nil
}
+
+// cacheCertificate adds cert to the in-memory cache. If the cache is
+// empty, cert will be used as the default certificate. If the cache is
+// full, random entries are deleted until there is room to map all the
+// names on the certificate.
+//
+// This certificate will be keyed to the names in cert.Names. Any name
+// that is already a key in the cache will be replaced with this cert.
+//
+// This function is safe for concurrent use.
+func cacheCertificate(cert Certificate) {
+ certCacheMu.Lock()
+ if _, ok := certCache[""]; !ok {
+ // use as default
+ cert.Names = append(cert.Names, "")
+ certCache[""] = cert
+ }
+ for len(certCache)+len(cert.Names) > 10000 {
+ // for simplicity, just remove random elements
+ for key := range certCache {
+ if key == "" { // ... but not the default cert
+ continue
+ }
+ delete(certCache, key)
+ break
+ }
+ }
+ for _, name := range cert.Names {
+ certCache[name] = cert
+ }
+ certCacheMu.Unlock()
+}
diff --git a/core/https/certificates_test.go b/core/https/certificates_test.go
new file mode 100644
index 000000000..dbfb4efc1
--- /dev/null
+++ b/core/https/certificates_test.go
@@ -0,0 +1,59 @@
+package https
+
+import "testing"
+
// TestUnexportedGetCertificate exercises getCertificate's lookup order:
// empty cache, exact (case-insensitive) match, wildcard match, and
// fallback to the default certificate.
func TestUnexportedGetCertificate(t *testing.T) {
	// Reset the shared cache so other tests start clean.
	defer func() { certCache = make(map[string]Certificate) }()

	// When cache is empty
	if _, matched, defaulted := getCertificate("example.com"); matched || defaulted {
		t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
	}

	// When cache has one certificate in it (also is default)
	defaultCert := Certificate{Names: []string{"example.com", ""}}
	certCache[""] = defaultCert
	certCache["example.com"] = defaultCert
	// Lookup should be case-insensitive (names are lowercased).
	if cert, matched, defaulted := getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
		t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
	}
	if cert, matched, defaulted := getCertificate(""); !matched || defaulted || cert.Names[0] != "example.com" {
		t.Errorf("Didn't get a cert for '' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
	}

	// When retrieving wildcard certificate
	certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}}
	if cert, matched, defaulted := getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
		t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
	}

	// When no certificate matches, the default is returned
	if cert, matched, defaulted := getCertificate("nomatch"); matched || !defaulted {
		t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
	} else if cert.Names[0] != "example.com" {
		t.Errorf("Expected default cert, got: %v", cert)
	}
}
+
+func TestCacheCertificate(t *testing.T) {
+ defer func() { certCache = make(map[string]Certificate) }()
+
+ cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}})
+ if _, ok := certCache["example.com"]; !ok {
+ t.Error("Expected first cert to be cached by key 'example.com', but it wasn't")
+ }
+ if _, ok := certCache["sub.example.com"]; !ok {
+ t.Error("Expected first cert to be cached by key 'sub.exmaple.com', but it wasn't")
+ }
+ if cert, ok := certCache[""]; !ok || cert.Names[2] != "" {
+ t.Error("Expected first cert to be cached additionally as the default certificate with empty name added, but it wasn't")
+ }
+
+ cacheCertificate(Certificate{Names: []string{"example2.com"}})
+ if _, ok := certCache["example2.com"]; !ok {
+ t.Error("Expected second cert to be cached by key 'exmaple2.com', but it wasn't")
+ }
+ if cert, ok := certCache[""]; ok && cert.Names[0] == "example2.com" {
+ t.Error("Expected second cert to NOT be cached as default, but it was")
+ }
+}
diff --git a/core/https/client.go b/core/https/client.go
new file mode 100644
index 000000000..e9e8cd82c
--- /dev/null
+++ b/core/https/client.go
@@ -0,0 +1,215 @@
+package https
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/miekg/coredns/server"
+ "github.com/xenolf/lego/acme"
+)
+
// acmeMu ensures that only one ACME challenge occurs at a time.
var acmeMu sync.Mutex

// ACMEClient is an acme.Client with custom state attached.
type ACMEClient struct {
	*acme.Client
	AllowPrompts bool // if false, we assume AlternatePort must be used
}
+
// NewACMEClient creates a new ACMEClient given an email and whether
// prompting the user is allowed. Clients should not be kept and
// re-used over long periods of time, but immediate re-use is more
// efficient than re-creating on every iteration.
//
// It is a package-level variable (not a plain func) so it can be
// swapped out, e.g. for testing.
var NewACMEClient = func(email string, allowPrompts bool) (*ACMEClient, error) {
	// Look up or create the LE user account
	leUser, err := getUser(email)
	if err != nil {
		return nil, err
	}

	// The client facilitates our communication with the CA server.
	client, err := acme.NewClient(CAUrl, &leUser, KeyType)
	if err != nil {
		return nil, err
	}

	// If not registered, the user must register an account with the CA
	// and agree to terms
	if leUser.Registration == nil {
		reg, err := client.Register()
		if err != nil {
			return nil, errors.New("registration error: " + err.Error())
		}
		leUser.Registration = reg

		if allowPrompts { // can't prompt a user who isn't there
			// NOTE(review): saURL is defined elsewhere in the package;
			// confirm it points at the current subscriber agreement URL.
			if !Agreed && reg.TosURL == "" {
				Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
			}
			// NOTE(review): this re-checks the same condition as above, so
			// it only fires when the user declined the prompt — confirm
			// that failing registration here is the intent.
			if !Agreed && reg.TosURL == "" {
				return nil, errors.New("user must agree to terms")
			}
		}

		// AgreeToTOS is called unconditionally for a fresh registration.
		err = client.AgreeToTOS()
		if err != nil {
			saveUser(leUser) // Might as well try, right?
			return nil, errors.New("error agreeing to terms: " + err.Error())
		}

		// save user to the file system
		err = saveUser(leUser)
		if err != nil {
			return nil, errors.New("could not save user: " + err.Error())
		}
	}

	return &ACMEClient{
		Client:       client,
		AllowPrompts: allowPrompts,
	}, nil
}
+
// NewACMEClientGetEmail creates a new ACMEClient and gets an email
// address at the same time (a server config is required, since it
// may contain an email address in it). getEmail may prompt the
// operator when allowPrompts is true.
func NewACMEClientGetEmail(config server.Config, allowPrompts bool) (*ACMEClient, error) {
	return NewACMEClient(getEmail(config, allowPrompts), allowPrompts)
}
+
// Configure configures c according to bindHost, which is the host (not
// whole address) to bind the listener to in solving the http and tls-sni
// challenges.
func (c *ACMEClient) Configure(bindHost string) {
	// If we allow prompts, operator must be present. In our case,
	// that is synonymous with saying the server is not already
	// started. So if the user is still there, we don't use
	// AlternatePort because we don't need to proxy the challenges.
	// Conversely, if the operator is not there, the server has
	// already started and we need to proxy the challenge.
	if c.AllowPrompts {
		// Operator is present; server is not already listening
		c.SetHTTPAddress(net.JoinHostPort(bindHost, ""))
		c.SetTLSAddress(net.JoinHostPort(bindHost, ""))
		//c.ExcludeChallenges([]acme.Challenge{acme.DNS01})
	} else {
		// Operator is not present; server is started, so proxy challenges
		c.SetHTTPAddress(net.JoinHostPort(bindHost, AlternatePort))
		c.SetTLSAddress(net.JoinHostPort(bindHost, AlternatePort))
		//c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
	}
	// This call applies in both branches, superseding the commented-out
	// per-branch exclusions above.
	c.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01}) // TODO: can we proxy TLS challenges? and we should support DNS...
}
+
// Obtain obtains a single certificate for names. It stores the certificate
// on the disk if successful. At most two attempts are made; the second
// attempt only happens after a Terms-of-Service error is resolved.
func (c *ACMEClient) Obtain(names []string) error {
Attempts:
	for attempts := 0; attempts < 2; attempts++ {
		// Serialize ACME challenges across the process.
		acmeMu.Lock()
		certificate, failures := c.ObtainCertificate(names, true, nil)
		acmeMu.Unlock()
		if len(failures) > 0 {
			// Error - try to fix it or report it to the user and abort
			var errMsg string             // we'll combine all the failures into a single error message
			var promptedForAgreement bool // only prompt user for agreement at most once

			for errDomain, obtainErr := range failures {
				// TODO: Double-check, will obtainErr ever be nil?
				if tosErr, ok := obtainErr.(acme.TOSError); ok {
					// Terms of Service agreement error; we can probably deal with this
					if !Agreed && !promptedForAgreement && c.AllowPrompts {
						Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
						promptedForAgreement = true
					}
					if Agreed || !c.AllowPrompts {
						err := c.AgreeToTOS()
						if err != nil {
							return errors.New("error agreeing to updated terms: " + err.Error())
						}
						// Retry the whole obtain now that terms are agreed.
						continue Attempts
					}
				}

				// If user did not agree or it was any other kind of error, just append to the list of errors
				errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
			}
			return errors.New(errMsg)
		}

		// Success - immediately save the certificate resource
		err := saveCertResource(certificate)
		if err != nil {
			return fmt.Errorf("error saving assets for %v: %v", names, err)
		}

		break
	}

	return nil
}
+
+// Renew renews the managed certificate for name. Right now our storage
+// mechanism only supports one name per certificate, so this function only
+// accepts one domain as input. It can be easily modified to support SAN
+// certificates if, one day, they become desperately needed enough that our
+// storage mechanism is upgraded to be more complex to support SAN certs.
+//
+// Anyway, this function is safe for concurrent use.
+func (c *ACMEClient) Renew(name string) error {
+ // Prepare for renewal (load PEM cert, key, and meta)
+ certBytes, err := ioutil.ReadFile(storage.SiteCertFile(name))
+ if err != nil {
+ return err
+ }
+ keyBytes, err := ioutil.ReadFile(storage.SiteKeyFile(name))
+ if err != nil {
+ return err
+ }
+ metaBytes, err := ioutil.ReadFile(storage.SiteMetaFile(name))
+ if err != nil {
+ return err
+ }
+ var certMeta acme.CertificateResource
+ err = json.Unmarshal(metaBytes, &certMeta)
+ certMeta.Certificate = certBytes
+ certMeta.PrivateKey = keyBytes
+
+ // Perform renewal and retry if necessary, but not too many times.
+ var newCertMeta acme.CertificateResource
+ var success bool
+ for attempts := 0; attempts < 2; attempts++ {
+ acmeMu.Lock()
+ newCertMeta, err = c.RenewCertificate(certMeta, true)
+ acmeMu.Unlock()
+ if err == nil {
+ success = true
+ break
+ }
+
+ // If the legal terms changed and need to be agreed to again,
+ // we can handle that.
+ if _, ok := err.(acme.TOSError); ok {
+ err := c.AgreeToTOS()
+ if err != nil {
+ return err
+ }
+ continue
+ }
+
+ // For any other kind of error, wait 10s and try again.
+ time.Sleep(10 * time.Second)
+ }
+
+ if !success {
+ return errors.New("too many renewal attempts; last error: " + err.Error())
+ }
+
+ return saveCertResource(newCertMeta)
+}
diff --git a/core/https/crypto.go b/core/https/crypto.go
new file mode 100644
index 000000000..bc0ff6373
--- /dev/null
+++ b/core/https/crypto.go
@@ -0,0 +1,57 @@
+package https
+
+import (
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "errors"
+ "io/ioutil"
+ "os"
+)
+
+// loadPrivateKey loads a PEM-encoded ECC/RSA private key from file.
+func loadPrivateKey(file string) (crypto.PrivateKey, error) {
+ keyBytes, err := ioutil.ReadFile(file)
+ if err != nil {
+ return nil, err
+ }
+ keyBlock, _ := pem.Decode(keyBytes)
+
+ switch keyBlock.Type {
+ case "RSA PRIVATE KEY":
+ return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
+ case "EC PRIVATE KEY":
+ return x509.ParseECPrivateKey(keyBlock.Bytes)
+ }
+
+ return nil, errors.New("unknown private key type")
+}
+
+// savePrivateKey saves a PEM-encoded ECC/RSA private key to file.
+func savePrivateKey(key crypto.PrivateKey, file string) error {
+ var pemType string
+ var keyBytes []byte
+ switch key := key.(type) {
+ case *ecdsa.PrivateKey:
+ var err error
+ pemType = "EC"
+ keyBytes, err = x509.MarshalECPrivateKey(key)
+ if err != nil {
+ return err
+ }
+ case *rsa.PrivateKey:
+ pemType = "RSA"
+ keyBytes = x509.MarshalPKCS1PrivateKey(key)
+ }
+
+ pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes}
+ keyOut, err := os.Create(file)
+ if err != nil {
+ return err
+ }
+ keyOut.Chmod(0600)
+ defer keyOut.Close()
+ return pem.Encode(keyOut, &pemKey)
+}
diff --git a/core/https/crypto_test.go b/core/https/crypto_test.go
new file mode 100644
index 000000000..c1f32b27d
--- /dev/null
+++ b/core/https/crypto_test.go
@@ -0,0 +1,111 @@
+package https
+
+import (
+ "bytes"
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "os"
+ "runtime"
+ "testing"
+)
+
+func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
+ keyFile := "test.key"
+ defer os.Remove(keyFile)
+
+ privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // test save
+ err = savePrivateKey(privateKey, keyFile)
+ if err != nil {
+ t.Fatal("error saving private key:", err)
+ }
+
+ // it doesn't make sense to test file permission on windows
+ if runtime.GOOS != "windows" {
+ // get info of the key file
+ info, err := os.Stat(keyFile)
+ if err != nil {
+ t.Fatal("error stating private key:", err)
+ }
+ // verify permission of key file is correct
+ if info.Mode().Perm() != 0600 {
+ t.Error("Expected key file to have permission 0600, but it wasn't")
+ }
+ }
+
+ // test load
+ loadedKey, err := loadPrivateKey(keyFile)
+ if err != nil {
+ t.Error("error loading private key:", err)
+ }
+
+ // verify loaded key is correct
+ if !PrivateKeysSame(privateKey, loadedKey) {
+ t.Error("Expected key bytes to be the same, but they weren't")
+ }
+}
+
+func TestSaveAndLoadECCPrivateKey(t *testing.T) {
+ keyFile := "test.key"
+ defer os.Remove(keyFile)
+
+ privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // test save
+ err = savePrivateKey(privateKey, keyFile)
+ if err != nil {
+ t.Fatal("error saving private key:", err)
+ }
+
+ // it doesn't make sense to test file permission on windows
+ if runtime.GOOS != "windows" {
+ // get info of the key file
+ info, err := os.Stat(keyFile)
+ if err != nil {
+ t.Fatal("error stating private key:", err)
+ }
+ // verify permission of key file is correct
+ if info.Mode().Perm() != 0600 {
+ t.Error("Expected key file to have permission 0600, but it wasn't")
+ }
+ }
+
+ // test load
+ loadedKey, err := loadPrivateKey(keyFile)
+ if err != nil {
+ t.Error("error loading private key:", err)
+ }
+
+ // verify loaded key is correct
+ if !PrivateKeysSame(privateKey, loadedKey) {
+ t.Error("Expected key bytes to be the same, but they weren't")
+ }
+}
+
+// PrivateKeysSame compares the bytes of a and b and returns true if they are the same.
+func PrivateKeysSame(a, b crypto.PrivateKey) bool {
+ return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b))
+}
+
+// PrivateKeyBytes returns the bytes of DER-encoded key.
+func PrivateKeyBytes(key crypto.PrivateKey) []byte {
+ var keyBytes []byte
+ switch key := key.(type) {
+ case *rsa.PrivateKey:
+ keyBytes = x509.MarshalPKCS1PrivateKey(key)
+ case *ecdsa.PrivateKey:
+ keyBytes, _ = x509.MarshalECPrivateKey(key)
+ }
+ return keyBytes
+}
diff --git a/core/https/handler.go b/core/https/handler.go
new file mode 100644
index 000000000..f3139f54e
--- /dev/null
+++ b/core/https/handler.go
@@ -0,0 +1,42 @@
+package https
+
+import (
+ "crypto/tls"
+ "log"
+ "net/http"
+ "net/http/httputil"
+ "net/url"
+ "strings"
+)
+
+const challengeBasePath = "/.well-known/acme-challenge"
+
+// RequestCallback proxies challenge requests to ACME client if the
+// request path starts with challengeBasePath. It returns true if it
+// handled the request and no more needs to be done; it returns false
+// if this call was a no-op and the request still needs handling.
+func RequestCallback(w http.ResponseWriter, r *http.Request) bool {
+ if strings.HasPrefix(r.URL.Path, challengeBasePath) {
+ scheme := "http"
+ if r.TLS != nil {
+ scheme = "https"
+ }
+
+ upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort)
+ if err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ log.Printf("[ERROR] ACME proxy handler: %v", err)
+ return true
+ }
+
+ proxy := httputil.NewSingleHostReverseProxy(upstream)
+ proxy.Transport = &http.Transport{
+ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs
+ }
+ proxy.ServeHTTP(w, r)
+
+ return true
+ }
+
+ return false
+}
diff --git a/core/https/handler_test.go b/core/https/handler_test.go
new file mode 100644
index 000000000..016799ffb
--- /dev/null
+++ b/core/https/handler_test.go
@@ -0,0 +1,63 @@
+package https
+
+import (
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+func TestRequestCallbackNoOp(t *testing.T) {
+ // try base paths that aren't handled by this handler
+ for _, url := range []string{
+ "http://localhost/",
+ "http://localhost/foo.html",
+ "http://localhost/.git",
+ "http://localhost/.well-known/",
+ "http://localhost/.well-known/acme-challenging",
+ } {
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ t.Fatalf("Could not craft request, got error: %v", err)
+ }
+ rw := httptest.NewRecorder()
+ if RequestCallback(rw, req) {
+ t.Errorf("Got true with this URL, but shouldn't have: %s", url)
+ }
+ }
+}
+
// TestRequestCallbackSuccess verifies that a request under the challenge
// base path is proxied to the backend listening on AlternatePort.
func TestRequestCallbackSuccess(t *testing.T) {
	expectedPath := challengeBasePath + "/asdf"

	// Set up fake acme handler backend to make sure proxying succeeds
	var proxySuccess bool
	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxySuccess = true
		if r.URL.Path != expectedPath {
			t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path)
		}
	}))

	// Custom listener that uses the port we expect
	// (the listener must be swapped in before ts.Start is called)
	ln, err := net.Listen("tcp", "127.0.0.1:"+AlternatePort)
	if err != nil {
		t.Fatalf("Unable to start test server listener: %v", err)
	}
	ts.Listener = ln

	// Start our engines and run the test
	ts.Start()
	defer ts.Close()
	req, err := http.NewRequest("GET", "http://127.0.0.1:"+AlternatePort+expectedPath, nil)
	if err != nil {
		t.Fatalf("Could not craft request, got error: %v", err)
	}
	rw := httptest.NewRecorder()

	RequestCallback(rw, req)

	// proxySuccess is set by the backend handler above; RequestCallback's
	// reverse proxy call is synchronous, so it is safe to check here.
	if !proxySuccess {
		t.Fatal("Expected request to be proxied, but it wasn't")
	}
}
diff --git a/core/https/handshake.go b/core/https/handshake.go
new file mode 100644
index 000000000..4c1fc22c3
--- /dev/null
+++ b/core/https/handshake.go
@@ -0,0 +1,320 @@
+package https
+
+import (
+ "bytes"
+ "crypto/tls"
+ "encoding/pem"
+ "errors"
+ "fmt"
+ "log"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/coredns/server"
+ "github.com/xenolf/lego/acme"
+)
+
// GetCertificate gets a certificate to satisfy clientHello as long as
// the certificate is already cached in memory. It will not be loaded
// from disk or obtained from the CA during the handshake.
//
// Note: when an error is returned, the returned *tls.Certificate points
// into a zero-value Certificate and must not be used by the caller.
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cert, err := getCertDuringHandshake(clientHello.ServerName, false, false)
	return &cert.Certificate, err
}

// GetOrObtainCertificate will get a certificate to satisfy clientHello, even
// if that means obtaining a new certificate from a CA during the handshake.
// It first checks the in-memory cache, then accesses disk, then accesses the
// network if it must. An obtained certificate will be stored on disk and
// cached in memory.
//
// This function is safe for use as a tls.Config.GetCertificate callback.
func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cert, err := getCertDuringHandshake(clientHello.ServerName, true, true)
	return &cert.Certificate, err
}
+
// getCertDuringHandshake will get a certificate for name. It first tries
// the in-memory cache. If no certificate for name is in the cache and if
// loadIfNecessary == true, it goes to disk to load it into the cache and
// serve it. If it's not on disk and if obtainIfNecessary == true, the
// certificate will be obtained from the CA, cached, and served. If
// obtainIfNecessary is true, then loadIfNecessary must also be set to true.
// An error will be returned if and only if no certificate is available.
//
// This function is safe for concurrent use.
func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
	// First check our in-memory cache to see if we've already loaded it
	cert, matched, defaulted := getCertificate(name)
	if matched {
		return cert, nil
	}

	if loadIfNecessary {
		// Then check to see if we have one on disk
		loadedCert, err := cacheManagedCertificate(name, true)
		if err == nil {
			// A maintenance failure is logged but not fatal: the freshly
			// loaded certificate is still served as-is.
			loadedCert, err = handshakeMaintenance(name, loadedCert)
			if err != nil {
				log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
			}
			return loadedCert, nil
		}

		if obtainIfNecessary {
			// By this point, we need to ask the CA for a certificate

			name = strings.ToLower(name)

			// Make sure we aren't over any applicable limits
			err := checkLimitsForObtainingNewCerts(name)
			if err != nil {
				return Certificate{}, err
			}

			// Name has to qualify for a certificate
			if !HostQualifies(name) {
				return cert, errors.New("hostname '" + name + "' does not qualify for certificate")
			}

			// Obtain certificate from the CA
			return obtainOnDemandCertificate(name)
		}
	}

	// No exact or wildcard match, but a default certificate was found,
	// so serve that rather than failing the handshake.
	if defaulted {
		return cert, nil
	}

	return Certificate{}, errors.New("no certificate for " + name)
}
+
// checkLimitsForObtainingNewCerts checks to see if name can be issued right
// now according to mitigating factors we keep track of and preferences the
// user has set. If a non-nil error is returned, do not issue a new certificate
// for name.
func checkLimitsForObtainingNewCerts(name string) error {
	// User can set hard limit for number of certs for the process to issue
	if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue {
		return fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue)
	}

	// Make sure name hasn't failed a challenge recently
	failedIssuanceMu.RLock()
	when, ok := failedIssuance[name]
	failedIssuanceMu.RUnlock()
	if ok {
		return fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String())
	}

	// Make sure, if we've issued a few certificates already, that we haven't
	// issued any recently (at most one new cert every 10 minutes once the
	// issued count reaches 10)
	lastIssueTimeMu.Lock()
	since := time.Since(lastIssueTime)
	lastIssueTimeMu.Unlock()
	if atomic.LoadInt32(OnDemandIssuedCount) >= 10 && since < 10*time.Minute {
		return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
	}

	// Good to go
	return nil
}
+
// obtainOnDemandCertificate obtains a certificate for name for the given
// name. If another goroutine has already started obtaining a cert for
// name, it will wait and use what the other goroutine obtained.
//
// This function is safe for use by multiple concurrent goroutines.
func obtainOnDemandCertificate(name string) (Certificate, error) {
	// We must protect this process from happening concurrently, so synchronize.
	obtainCertWaitChansMu.Lock()
	wait, ok := obtainCertWaitChans[name]
	if ok {
		// lucky us -- another goroutine is already obtaining the certificate.
		// wait for it to finish obtaining the cert and then we'll use it.
		obtainCertWaitChansMu.Unlock()
		<-wait
		return getCertDuringHandshake(name, true, false)
	}

	// looks like it's up to us to do all the work and obtain the cert
	wait = make(chan struct{})
	obtainCertWaitChans[name] = wait
	obtainCertWaitChansMu.Unlock()

	// Unblock waiters and delete waitgroup when we return
	defer func() {
		obtainCertWaitChansMu.Lock()
		close(wait)
		delete(obtainCertWaitChans, name)
		obtainCertWaitChansMu.Unlock()
	}()

	log.Printf("[INFO] Obtaining new certificate for %s", name)

	// obtain cert
	client, err := NewACMEClientGetEmail(server.Config{}, false)
	if err != nil {
		return Certificate{}, errors.New("error creating client: " + err.Error())
	}
	client.Configure("") // TODO: which BindHost?
	err = client.Obtain([]string{name})
	if err != nil {
		// Failed to solve challenge, so don't allow another on-demand
		// issue for this name to be attempted for a little while.
		// (the background goroutine clears the entry after 5 minutes)
		failedIssuanceMu.Lock()
		failedIssuance[name] = time.Now()
		go func(name string) {
			time.Sleep(5 * time.Minute)
			failedIssuanceMu.Lock()
			delete(failedIssuance, name)
			failedIssuanceMu.Unlock()
		}(name)
		failedIssuanceMu.Unlock()
		return Certificate{}, err
	}

	// Success - update counters and stuff
	atomic.AddInt32(OnDemandIssuedCount, 1)
	lastIssueTimeMu.Lock()
	lastIssueTime = time.Now()
	lastIssueTimeMu.Unlock()

	// The certificate is already on disk; now just start over to load it and serve it
	return getCertDuringHandshake(name, true, false)
}
+
// handshakeMaintenance performs a check on cert for expiration and OCSP
// validity.
//
// This function is safe for use by multiple concurrent goroutines.
func handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
	// Check cert expiration
	timeLeft := cert.NotAfter.Sub(time.Now().UTC())
	if timeLeft < renewDurationBefore {
		log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
		return renewDynamicCertificate(name)
	}

	// Check OCSP staple validity; refresh once the staple is past the
	// halfway point between ThisUpdate and NextUpdate
	if cert.OCSP != nil {
		refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
		if time.Now().After(refreshTime) {
			err := stapleOCSP(&cert, nil)
			if err != nil {
				// An error with OCSP stapling is not the end of the world, and in fact, is
				// quite common considering not all certs have issuer URLs that support it.
				log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
			}
			certCacheMu.Lock()
			certCache[name] = cert
			certCacheMu.Unlock()
		}
	}

	return cert, nil
}
+
// renewDynamicCertificate renews currentCert using the clientHello. It returns the
// certificate to use and an error, if any. currentCert may be returned even if an
// error occurs, since we perform renewals before they expire and it may still be
// usable. name should already be lower-cased before calling this function.
//
// This function is safe for use by multiple concurrent goroutines.
func renewDynamicCertificate(name string) (Certificate, error) {
	// Same coordination scheme as obtainOnDemandCertificate: only one
	// goroutine does the renewal; the rest wait on the channel.
	obtainCertWaitChansMu.Lock()
	wait, ok := obtainCertWaitChans[name]
	if ok {
		// lucky us -- another goroutine is already renewing the certificate.
		// wait for it to finish, then we'll use the new one.
		obtainCertWaitChansMu.Unlock()
		<-wait
		return getCertDuringHandshake(name, true, false)
	}

	// looks like it's up to us to do all the work and renew the cert
	wait = make(chan struct{})
	obtainCertWaitChans[name] = wait
	obtainCertWaitChansMu.Unlock()

	// unblock waiters and delete waitgroup when we return
	defer func() {
		obtainCertWaitChansMu.Lock()
		close(wait)
		delete(obtainCertWaitChans, name)
		obtainCertWaitChansMu.Unlock()
	}()

	log.Printf("[INFO] Renewing certificate for %s", name)

	client, err := NewACMEClientGetEmail(server.Config{}, false)
	if err != nil {
		return Certificate{}, err
	}
	client.Configure("") // TODO: Bind address of relevant listener, yuck
	err = client.Renew(name)
	if err != nil {
		return Certificate{}, err
	}

	// The renewed certificate is on disk; reload it into the cache.
	return getCertDuringHandshake(name, true, false)
}
+
// stapleOCSP staples OCSP information to cert for hostname name.
// If you have it handy, you should pass in the PEM-encoded certificate
// bundle; otherwise the DER-encoded cert will have to be PEM-encoded.
// If you don't have the PEM blocks handy, just pass in nil.
//
// Errors here are not necessarily fatal, it could just be that the
// certificate doesn't have an issuer URL.
func stapleOCSP(cert *Certificate, pemBundle []byte) error {
	if pemBundle == nil {
		// The function in the acme package that gets OCSP requires a PEM-encoded cert;
		// re-encode each DER certificate in the chain into one PEM bundle.
		bundle := new(bytes.Buffer)
		for _, derBytes := range cert.Certificate.Certificate {
			pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
		}
		pemBundle = bundle.Bytes()
	}

	ocspBytes, ocspResp, err := acme.GetOCSPForCert(pemBundle)
	if err != nil {
		return err
	}

	// Attach the raw staple to the TLS certificate and keep the parsed
	// response for later validity checks in handshakeMaintenance.
	cert.Certificate.OCSPStaple = ocspBytes
	cert.OCSP = ocspResp

	return nil
}
+
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
// Guarded by obtainCertWaitChansMu.
var obtainCertWaitChans = make(map[string]chan struct{})
var obtainCertWaitChansMu sync.Mutex

// OnDemandIssuedCount is the number of certificates that have been issued
// on-demand by this process. It is only safe to modify this count atomically.
// If it reaches onDemandMaxIssue, on-demand issuances will fail.
var OnDemandIssuedCount = new(int32)

// onDemandMaxIssue is set based on max_certs in tls config. It specifies the
// maximum number of certificates that can be issued.
// TODO: This applies globally, but we should probably make a server-specific
// way to keep track of these limits and counts, since it's specified in the
// Caddyfile...
var onDemandMaxIssue int32

// failedIssuance is a set of names that we recently failed to get a
// certificate for from the ACME CA. They are removed after some time.
// When a name is in this map, do not issue a certificate for it on-demand.
// Guarded by failedIssuanceMu.
var failedIssuance = make(map[string]time.Time)
var failedIssuanceMu sync.RWMutex

// lastIssueTime records when we last obtained a certificate successfully.
// If this value is recent, do not make any on-demand certificate requests.
// Guarded by lastIssueTimeMu.
var lastIssueTime time.Time
var lastIssueTimeMu sync.Mutex
diff --git a/core/https/handshake_test.go b/core/https/handshake_test.go
new file mode 100644
index 000000000..cf70eb17d
--- /dev/null
+++ b/core/https/handshake_test.go
@@ -0,0 +1,54 @@
+package https
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "testing"
+)
+
// TestGetCertificate exercises the in-memory-cache-only lookup path:
// empty cache, exact match, default (no SNI), wildcard, and fallback.
func TestGetCertificate(t *testing.T) {
	// restore a clean cache so other tests are unaffected
	defer func() { certCache = make(map[string]Certificate) }()

	hello := &tls.ClientHelloInfo{ServerName: "example.com"}
	helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
	helloNoSNI := &tls.ClientHelloInfo{}
	helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}

	// When cache is empty
	if cert, err := GetCertificate(hello); err == nil {
		t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
	}
	if cert, err := GetCertificate(helloNoSNI); err == nil {
		t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
	}

	// When cache has one certificate in it (also is default, via the "" key)
	defaultCert := Certificate{Names: []string{"example.com", ""}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
	certCache[""] = defaultCert
	certCache["example.com"] = defaultCert
	if cert, err := GetCertificate(hello); err != nil {
		t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
	} else if cert.Leaf.DNSNames[0] != "example.com" {
		t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
	}
	if cert, err := GetCertificate(helloNoSNI); err != nil {
		t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
	} else if cert.Leaf.DNSNames[0] != "example.com" {
		t.Errorf("Got wrong certificate for no SNI; expected 'example.com' as default, got: %v", cert)
	}

	// When retrieving wildcard certificate
	certCache["*.example.com"] = Certificate{Names: []string{"*.example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}}}
	if cert, err := GetCertificate(helloSub); err != nil {
		t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
	} else if cert.Leaf.DNSNames[0] != "*.example.com" {
		t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
	}

	// When no certificate matches, the default is returned
	if cert, err := GetCertificate(helloNoMatch); err != nil {
		t.Errorf("Expected default certificate with no error when no matches, got err: %v", err)
	} else if cert.Leaf.DNSNames[0] != "example.com" {
		t.Errorf("Expected default cert with no matches, got: %v", cert)
	}
}
diff --git a/core/https/https.go b/core/https/https.go
new file mode 100644
index 000000000..0deb88b86
--- /dev/null
+++ b/core/https/https.go
@@ -0,0 +1,358 @@
+// Package https facilitates the management of TLS assets and integrates
+// Let's Encrypt functionality into Caddy with first-class support for
+// creating and renewing certificates automatically. It is designed to
+// configure sites for HTTPS by default.
+package https
+
+import (
+ "encoding/json"
+ "errors"
+ "io/ioutil"
+ "net"
+ "os"
+ "strings"
+
+ "github.com/miekg/coredns/server"
+ "github.com/xenolf/lego/acme"
+)
+
// Activate sets up TLS for each server config in configs
// as needed; this consists of acquiring and maintaining
// certificates and keys for qualifying configs and enabling
// OCSP stapling for all TLS-enabled configs.
//
// This function may prompt the user to provide an email
// address if none is available through other means. It
// prefers the email address specified in the config, but
// if that is not available it will check the command line
// argument. If absent, it will use the most recent email
// address from last time. If there isn't one, the user
// will be prompted and shown SA link.
//
// Also note that calling this function activates asset
// management automatically, which keeps certificates
// renewed and OCSP stapling updated.
//
// Activate returns the updated list of configs, since
// some may have been appended, for example, to redirect
// plaintext HTTP requests to their HTTPS counterpart.
// This function only appends; it does not splice.
func Activate(configs []server.Config) ([]server.Config, error) {
	// just in case previous caller forgot... (Deactivate's error is
	// deliberately ignored: it only means we were already deactivated)
	Deactivate()

	// pre-screen each config and earmark the ones that qualify for managed TLS
	MarkQualified(configs)

	// place certificates and keys on disk
	err := ObtainCerts(configs, true, false)
	if err != nil {
		return configs, err
	}

	// update TLS configurations
	err = EnableTLS(configs, true)
	if err != nil {
		return configs, err
	}

	// renew all relevant certificates that need renewal. this is important
	// to do right away for a couple reasons, mainly because each restart,
	// the renewal ticker is reset, so if restarts happen more often than
	// the ticker interval, renewals would never happen. but doing
	// it right away at start guarantees that renewals aren't missed.
	err = renewManagedCertificates(true)
	if err != nil {
		return configs, err
	}

	// keep certificates renewed and OCSP stapling updated
	go maintainAssets(stopChan)

	return configs, nil
}
+
// Deactivate cleans up long-term, in-memory resources
// allocated by calling Activate(). Essentially, it stops
// the asset maintainer from running, meaning that certificates
// will not be renewed, OCSP staples will not be updated, etc.
//
// Closing a closed (or nil) channel panics, so the deferred recover
// turns a redundant Deactivate call into the "already deactivated"
// error instead of crashing the process.
func Deactivate() (err error) {
	defer func() {
		if rec := recover(); rec != nil {
			err = errors.New("already deactivated")
		}
	}()
	close(stopChan)
	// replace the channel so a future Activate cycle can be stopped again
	stopChan = make(chan struct{})
	return
}
+
+// MarkQualified scans each config and, if it qualifies for managed
+// TLS, it sets the Managed field of the TLSConfig to true.
+func MarkQualified(configs []server.Config) {
+ for i := 0; i < len(configs); i++ {
+ if ConfigQualifies(configs[i]) {
+ configs[i].TLS.Managed = true
+ }
+ }
+}
+
// ObtainCerts obtains certificates for all these configs as long as a
// certificate does not already exist on disk. It does not modify the
// configs at all; it only obtains and stores certificates and keys to
// the disk. If allowPrompts is true, the user may be shown a prompt.
// If proxyACME is true, the ACME challenges will be proxied to our alt port.
func ObtainCerts(configs []server.Config, allowPrompts, proxyACME bool) error {
	// We group configs by email so we don't make the same clients over and
	// over. This has the potential to prompt the user for an email, but we
	// prevent that by assuming that if we already have a listener that can
	// proxy ACME challenge requests, then the server is already running and
	// the operator is no longer present.
	groupedConfigs := groupConfigsByEmail(configs, allowPrompts)

	for email, group := range groupedConfigs {
		// Wait as long as we can before creating the client, because it
		// may not be needed, for example, if we already have what we
		// need on disk. Creating a client involves the network and
		// potentially prompting the user, etc., so only do if necessary.
		var client *ACMEClient

		for _, cfg := range group {
			// skip hosts we can't get certs for, or that already have them
			if !HostQualifies(cfg.Host) || existingCertAndKey(cfg.Host) {
				continue
			}

			// Now we definitely do need a client
			if client == nil {
				var err error
				client, err = NewACMEClient(email, allowPrompts)
				if err != nil {
					return errors.New("error creating client: " + err.Error())
				}
			}

			// c.Configure assumes that allowPrompts == !proxyACME,
			// but that's not always true. For example, a restart where
			// the user isn't present and we're not listening on port 80.
			// TODO: This could probably be refactored better.
			// NOTE: the shared client is re-pointed at each cfg's BindHost
			// on every iteration before Obtain is called.
			if proxyACME {
				client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
				client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, AlternatePort))
				client.ExcludeChallenges([]acme.Challenge{acme.TLSSNI01, acme.DNS01})
			} else {
				client.SetHTTPAddress(net.JoinHostPort(cfg.BindHost, ""))
				client.SetTLSAddress(net.JoinHostPort(cfg.BindHost, ""))
				client.ExcludeChallenges([]acme.Challenge{acme.DNS01})
			}

			err := client.Obtain([]string{cfg.Host})
			if err != nil {
				return err
			}
		}
	}

	return nil
}
+
+// groupConfigsByEmail groups configs by the email address to be used by an
+// ACME client. It only groups configs that have TLS enabled and that are
+// marked as Managed. If userPresent is true, the operator MAY be prompted
+// for an email address.
+func groupConfigsByEmail(configs []server.Config, userPresent bool) map[string][]server.Config {
+ initMap := make(map[string][]server.Config)
+ for _, cfg := range configs {
+ if !cfg.TLS.Managed {
+ continue
+ }
+ leEmail := getEmail(cfg, userPresent)
+ initMap[leEmail] = append(initMap[leEmail], cfg)
+ }
+ return initMap
+}
+
+// EnableTLS configures each config to use TLS according to default settings.
+// It will only change configs that are marked as managed, and assumes that
+// certificates and keys are already on disk. If loadCertificates is true,
+// the certificates will be loaded from disk into the cache for this process
+// to use. If false, TLS will still be enabled and configured with default
+// settings, but no certificates will be parsed loaded into the cache, and
+// the returned error value will always be nil.
+func EnableTLS(configs []server.Config, loadCertificates bool) error {
+ for i := 0; i < len(configs); i++ {
+ if !configs[i].TLS.Managed {
+ continue
+ }
+ configs[i].TLS.Enabled = true
+ if loadCertificates && HostQualifies(configs[i].Host) {
+ _, err := cacheManagedCertificate(configs[i].Host, false)
+ if err != nil {
+ return err
+ }
+ }
+ setDefaultTLSParams(&configs[i])
+ }
+ return nil
+}
+
+// hostHasOtherPort returns true if there is another config in the list with the same
+// hostname that has port otherPort, or false otherwise. All the configs are checked
+// against the hostname of allConfigs[thisConfigIdx].
+func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort string) bool {
+ for i, otherCfg := range allConfigs {
+ if i == thisConfigIdx {
+ continue // has to be a config OTHER than the one we're comparing against
+ }
+ if otherCfg.Host == allConfigs[thisConfigIdx].Host && otherCfg.Port == otherPort {
+ return true
+ }
+ }
+ return false
+}
+
+// ConfigQualifies returns true if cfg qualifies for
+// fully managed TLS (but not on-demand TLS, which is
+// not considered here). It does NOT check to see if a
+// cert and key already exist for the config. If the
+// config does qualify, you should set cfg.TLS.Managed
+// to true and check that instead, because the process of
+// setting up the config may make it look like it
+// doesn't qualify even though it originally did.
+func ConfigQualifies(cfg server.Config) bool {
+ return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key
+
+ // user can force-disable automatic HTTPS for this host
+ cfg.Port != "80" &&
+ cfg.TLS.LetsEncryptEmail != "off" &&
+
+ // we get can't certs for some kinds of hostnames, but
+ // on-demand TLS allows empty hostnames at startup
+ (HostQualifies(cfg.Host) || cfg.TLS.OnDemand)
+}
+
// HostQualifies returns true if the hostname alone
// appears eligible for automatic HTTPS. For example,
// localhost, empty hostname, and IP addresses are
// not eligible because we cannot obtain certificates
// for those names.
func HostQualifies(hostname string) bool {
	// localhost is ineligible
	if hostname == "localhost" {
		return false
	}
	// hostname must not be empty
	if strings.TrimSpace(hostname) == "" {
		return false
	}
	// cannot be an IP address, see
	// https://community.letsencrypt.org/t/certificate-for-static-ip/84/2?u=mholt
	// (also trim [] from either end, since that special case can sneak through
	// for IPv6 addresses using the -host flag and with empty/no Caddyfile)
	return net.ParseIP(strings.Trim(hostname, "[]")) == nil
}
+
+// existingCertAndKey returns true if the host has a certificate
+// and private key in storage already, false otherwise.
+func existingCertAndKey(host string) bool {
+ _, err := os.Stat(storage.SiteCertFile(host))
+ if err != nil {
+ return false
+ }
+ _, err = os.Stat(storage.SiteKeyFile(host))
+ if err != nil {
+ return false
+ }
+ return true
+}
+
+// saveCertResource saves the certificate resource to disk. This
+// includes the certificate file itself, the private key, and the
+// metadata file.
+func saveCertResource(cert acme.CertificateResource) error {
+ err := os.MkdirAll(storage.Site(cert.Domain), 0700)
+ if err != nil {
+ return err
+ }
+
+ // Save cert
+ err = ioutil.WriteFile(storage.SiteCertFile(cert.Domain), cert.Certificate, 0600)
+ if err != nil {
+ return err
+ }
+
+ // Save private key
+ err = ioutil.WriteFile(storage.SiteKeyFile(cert.Domain), cert.PrivateKey, 0600)
+ if err != nil {
+ return err
+ }
+
+ // Save cert metadata
+ jsonBytes, err := json.MarshalIndent(&cert, "", "\t")
+ if err != nil {
+ return err
+ }
+ err = ioutil.WriteFile(storage.SiteMetaFile(cert.Domain), jsonBytes, 0600)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
// Revoke revokes the certificate for host via ACME protocol.
// It requires a cert and key in storage and an email address
// (the user may be prompted for one).
//
// On success, only the certificate file is deleted; the private key
// and metadata files remain in storage.
func Revoke(host string) error {
	if !existingCertAndKey(host) {
		return errors.New("no certificate and key for " + host)
	}

	email := getEmail(server.Config{Host: host}, true)
	if email == "" {
		return errors.New("email is required to revoke")
	}

	client, err := NewACMEClient(email, true)
	if err != nil {
		return err
	}

	certFile := storage.SiteCertFile(host)
	certBytes, err := ioutil.ReadFile(certFile)
	if err != nil {
		return err
	}

	err = client.RevokeCertificate(certBytes)
	if err != nil {
		return err
	}

	// revocation succeeded; a failure to delete is reported but the
	// cert is already revoked with the CA at this point
	err = os.Remove(certFile)
	if err != nil {
		return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
	}

	return nil
}
+
var (
	// DefaultEmail represents the Let's Encrypt account email to use if none provided
	DefaultEmail string

	// Agreed indicates whether user has agreed to the Let's Encrypt SA
	Agreed bool

	// CAUrl represents the base URL to the CA's ACME endpoint
	CAUrl string
)

// AlternatePort is the port on which the acme client will open a
// listener and solve the CA's challenges. If this alternate port
// is used instead of the default port (80 or 443), then the
// default port for the challenge must be forwarded to this one.
const AlternatePort = "5033"

// KeyType is the type to use for new keys.
// This shouldn't need to change except for in tests;
// the size can be drastically reduced for speed.
var KeyType = acme.EC384

// stopChan is used to signal the maintenance goroutine
// to terminate. It is nil until Activate runs; Deactivate
// closes it and replaces it with a fresh channel.
var stopChan chan struct{}
diff --git a/core/https/https_test.go b/core/https/https_test.go
new file mode 100644
index 000000000..40b67367e
--- /dev/null
+++ b/core/https/https_test.go
@@ -0,0 +1,332 @@
+package https
+
+import (
+ "io/ioutil"
+ "net/http"
+ "os"
+ "testing"
+
+ "github.com/miekg/coredns/middleware/redirect"
+ "github.com/miekg/coredns/server"
+ "github.com/xenolf/lego/acme"
+)
+
+// TestHostQualifies exercises HostQualifies against a table of hosts
+// that must (or must not) qualify for automatic HTTPS: loopback,
+// unspecified, and private addresses do not qualify; real domain
+// names do.
+func TestHostQualifies(t *testing.T) {
+	cases := []struct {
+		host   string
+		expect bool
+	}{
+		{"localhost", false},
+		{"127.0.0.1", false},
+		{"127.0.1.5", false},
+		{"::1", false},
+		{"[::1]", false},
+		{"[::]", false},
+		{"::", false},
+		{"", false},
+		{" ", false},
+		{"0.0.0.0", false},
+		{"192.168.1.3", false},
+		{"10.0.2.1", false},
+		{"169.112.53.4", false},
+		{"foobar.com", true},
+		{"sub.foobar.com", true},
+	}
+	for i, test := range cases {
+		got := HostQualifies(test.host)
+		if got && !test.expect {
+			t.Errorf("Test %d: Expected '%s' to NOT qualify, but it did", i, test.host)
+		}
+		if !got && test.expect {
+			t.Errorf("Test %d: Expected '%s' to qualify, but it did NOT", i, test.host)
+		}
+	}
+}
+
+// TestConfigQualifies exercises ConfigQualifies with configs that
+// should and should not be eligible for managed TLS: qualifying hosts,
+// manual TLS, the "off" email sentinel, plain-HTTP schemes, and port 80
+// must all disqualify a config.
+func TestConfigQualifies(t *testing.T) {
+	cases := []struct {
+		cfg    server.Config
+		expect bool
+	}{
+		{server.Config{Host: ""}, false},
+		{server.Config{Host: "localhost"}, false},
+		{server.Config{Host: "123.44.3.21"}, false},
+		{server.Config{Host: "example.com"}, true},
+		{server.Config{Host: "example.com", TLS: server.TLSConfig{Manual: true}}, false},
+		{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}}, false},
+		{server.Config{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}}, true},
+		{server.Config{Host: "example.com", Scheme: "http"}, false},
+		{server.Config{Host: "example.com", Port: "80"}, false},
+		{server.Config{Host: "example.com", Port: "1234"}, true},
+		{server.Config{Host: "example.com", Scheme: "https"}, true},
+		{server.Config{Host: "example.com", Port: "80", Scheme: "https"}, false},
+	}
+	for i, test := range cases {
+		got := ConfigQualifies(test.cfg)
+		if test.expect && !got {
+			t.Errorf("Test %d: Expected config to qualify, but it did NOT: %#v", i, test.cfg)
+		}
+		if !test.expect && got {
+			t.Errorf("Test %d: Expected config to NOT qualify, but it did: %#v", i, test.cfg)
+		}
+	}
+}
+
+// TestRedirPlaintextHost verifies that redirPlaintextHost builds a
+// plaintext->HTTPS redirect config with the expected host, bind host,
+// and port, and that the single redirect rule it installs is correct.
+func TestRedirPlaintextHost(t *testing.T) {
+	cfg := redirPlaintextHost(server.Config{
+		Host:     "example.com",
+		BindHost: "93.184.216.34",
+		Port:     "1234",
+	})
+
+	// Check host and port
+	if actual, expected := cfg.Host, "example.com"; actual != expected {
+		t.Errorf("Expected redir config to have host %s but got %s", expected, actual)
+	}
+	if actual, expected := cfg.BindHost, "93.184.216.34"; actual != expected {
+		t.Errorf("Expected redir config to have bindhost %s but got %s", expected, actual)
+	}
+	if actual, expected := cfg.Port, "80"; actual != expected {
+		t.Errorf("Expected redir config to have port '%s' but got '%s'", expected, actual)
+	}
+
+	// Make sure redirect handler is set up properly
+	if cfg.Middleware == nil || len(cfg.Middleware) != 1 {
+		t.Fatalf("Redir config middleware not set up properly; got: %#v", cfg.Middleware)
+	}
+
+	handler, ok := cfg.Middleware[0](nil).(redirect.Redirect)
+	if !ok {
+		t.Fatalf("Expected a redirect.Redirect middleware, but got: %#v", handler)
+	}
+	if len(handler.Rules) != 1 {
+		t.Fatalf("Expected one redirect rule, got: %#v", handler.Rules)
+	}
+
+	// Check redirect rule for correctness
+	if actual, expected := handler.Rules[0].FromScheme, "http"; actual != expected {
+		t.Errorf("Expected redirect rule to be from scheme '%s' but is actually from '%s'", expected, actual)
+	}
+	if actual, expected := handler.Rules[0].FromPath, "/"; actual != expected {
+		t.Errorf("Expected redirect rule to be for path '%s' but is actually for '%s'", expected, actual)
+	}
+	if actual, expected := handler.Rules[0].To, "https://{host}:1234{uri}"; actual != expected {
+		t.Errorf("Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
+	}
+	if actual, expected := handler.Rules[0].Code, http.StatusMovedPermanently; actual != expected {
+		t.Errorf("Expected redirect rule to have code %d but was %d", expected, actual)
+	}
+
+	// browsers can infer a default port from scheme, so make sure the port
+	// doesn't get added in explicitly for default ports like 443 for https.
+	cfg = redirPlaintextHost(server.Config{Host: "example.com", Port: "443"})
+	handler, ok = cfg.Middleware[0](nil).(redirect.Redirect)
+	if !ok {
+		// Fix: the assertion result was previously ignored here; on a
+		// wrong middleware type the access to handler.Rules[0] below
+		// would panic instead of failing the test cleanly.
+		t.Fatalf("(Default Port) Expected a redirect.Redirect middleware, but got: %#v", handler)
+	}
+	if actual, expected := handler.Rules[0].To, "https://{host}{uri}"; actual != expected {
+		t.Errorf("(Default Port) Expected redirect rule to be to URL '%s' but is actually to '%s'", expected, actual)
+	}
+}
+
+// TestSaveCertResource round-trips an acme.CertificateResource through
+// saveCertResource and verifies the certificate, private key, and JSON
+// metadata files are written with the expected contents.
+func TestSaveCertResource(t *testing.T) {
+	// Repoint the package-level storage at a throwaway directory.
+	storage = Storage("./le_test_save")
+	defer func() {
+		err := os.RemoveAll(string(storage))
+		if err != nil {
+			t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err)
+		}
+	}()
+
+	domain := "example.com"
+	certContents := "certificate"
+	keyContents := "private key"
+	// NOTE: indentation inside this literal must be tabs to match the
+	// json.MarshalIndent(&cert, "", "\t") output of saveCertResource.
+	metaContents := `{
+	"domain": "example.com",
+	"certUrl": "https://example.com/cert",
+	"certStableUrl": "https://example.com/cert/stable"
+}`
+
+	cert := acme.CertificateResource{
+		Domain:        domain,
+		CertURL:       "https://example.com/cert",
+		CertStableURL: "https://example.com/cert/stable",
+		PrivateKey:    []byte(keyContents),
+		Certificate:   []byte(certContents),
+	}
+
+	err := saveCertResource(cert)
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+
+	certFile, err := ioutil.ReadFile(storage.SiteCertFile(domain))
+	if err != nil {
+		t.Errorf("Expected no error reading certificate file, got: %v", err)
+	}
+	if string(certFile) != certContents {
+		t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(certFile))
+	}
+
+	keyFile, err := ioutil.ReadFile(storage.SiteKeyFile(domain))
+	if err != nil {
+		t.Errorf("Expected no error reading private key file, got: %v", err)
+	}
+	if string(keyFile) != keyContents {
+		t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(keyFile))
+	}
+
+	metaFile, err := ioutil.ReadFile(storage.SiteMetaFile(domain))
+	if err != nil {
+		t.Errorf("Expected no error reading meta file, got: %v", err)
+	}
+	if string(metaFile) != metaContents {
+		t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(metaFile))
+	}
+}
+
+// TestExistingCertAndKey checks that existingCertAndKey reports false
+// before any assets are saved for a domain, and true after
+// saveCertResource has written both the cert and the key.
+func TestExistingCertAndKey(t *testing.T) {
+	// Repoint the package-level storage at a throwaway directory.
+	storage = Storage("./le_test_existing")
+	defer func() {
+		err := os.RemoveAll(string(storage))
+		if err != nil {
+			t.Fatalf("Could not remove temporary storage directory (%s): %v", storage, err)
+		}
+	}()
+
+	domain := "example.com"
+
+	if existingCertAndKey(domain) {
+		t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain)
+	}
+
+	err := saveCertResource(acme.CertificateResource{
+		Domain:      domain,
+		PrivateKey:  []byte("key"),
+		Certificate: []byte("cert"),
+	})
+	if err != nil {
+		t.Fatalf("Expected no error, got: %v", err)
+	}
+
+	if !existingCertAndKey(domain) {
+		t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain)
+	}
+}
+
+// TestHostHasOtherPort checks that hostHasOtherPort only reports true
+// when a DIFFERENT config entry shares the host of the config at the
+// given index and listens on the queried port.
+func TestHostHasOtherPort(t *testing.T) {
+	configs := []server.Config{
+		{Host: "example.com", Port: "80"},
+		{Host: "sub1.example.com", Port: "80"},
+		{Host: "sub1.example.com", Port: "443"},
+	}
+
+	// example.com only appears once, so no other entry can match it.
+	if hostHasOtherPort(configs, 0, "80") {
+		t.Errorf(`Expected hostHasOtherPort(configs, 0, "80") to be false, but got true`)
+	}
+	if hostHasOtherPort(configs, 0, "443") {
+		t.Errorf(`Expected hostHasOtherPort(configs, 0, "443") to be false, but got true`)
+	}
+	// sub1.example.com (index 1, port 80) also exists on port 443 (index 2).
+	if !hostHasOtherPort(configs, 1, "443") {
+		t.Errorf(`Expected hostHasOtherPort(configs, 1, "443") to be true, but got false`)
+	}
+}
+
+// TestMakePlaintextRedirects checks that MakePlaintextRedirects appends
+// exactly one port-80 redirect config per managed host that does not
+// already serve plaintext on port 80.
+func TestMakePlaintextRedirects(t *testing.T) {
+	configs := []server.Config{
+		// Happy path = standard redirect from 80 to 443
+		{Host: "example.com", TLS: server.TLSConfig{Managed: true}},
+
+		// Host on port 80 already defined; don't change it (no redirect)
+		{Host: "sub1.example.com", Port: "80", Scheme: "http"},
+		{Host: "sub1.example.com", TLS: server.TLSConfig{Managed: true}},
+
+		// Redirect from port 80 to port 5000 in this case
+		{Host: "sub2.example.com", Port: "5000", TLS: server.TLSConfig{Managed: true}},
+
+		// Can redirect from 80 to either 443 or 5001, but choose 443
+		{Host: "sub3.example.com", Port: "443", TLS: server.TLSConfig{Managed: true}},
+		{Host: "sub3.example.com", Port: "5001", Scheme: "https", TLS: server.TLSConfig{Managed: true}},
+	}
+
+	result := MakePlaintextRedirects(configs)
+	// example.com, sub2.example.com, and sub3.example.com each get one
+	// redirect; sub1.example.com already has port 80 covered.
+	expectedRedirCount := 3
+
+	if len(result) != len(configs)+expectedRedirCount {
+		t.Errorf("Expected %d redirect(s) to be added, but got %d",
+			expectedRedirCount, len(result)-len(configs))
+	}
+}
+
+// TestEnableTLS verifies that EnableTLS flips TLS.Enabled only on
+// managed configs, leaving unmanaged configs untouched.
+func TestEnableTLS(t *testing.T) {
+	configs := []server.Config{
+		{Host: "example.com", TLS: server.TLSConfig{Managed: true}},
+		{}, // not managed - no changes!
+	}
+
+	EnableTLS(configs, false)
+
+	if !configs[0].TLS.Enabled {
+		t.Errorf("Expected config 0 to have TLS.Enabled == true, but it was false")
+	}
+	if configs[1].TLS.Enabled {
+		t.Errorf("Expected config 1 to have TLS.Enabled == false, but it was true")
+	}
+}
+
+// TestGroupConfigsByEmail verifies that groupConfigsByEmail buckets
+// managed configs by their LetsEncryptEmail, substituting DefaultEmail
+// when none is set, and excludes unmanaged configs entirely.
+func TestGroupConfigsByEmail(t *testing.T) {
+	if groupConfigsByEmail([]server.Config{}, false) == nil {
+		t.Errorf("With empty input, returned map was nil, but expected non-nil map")
+	}
+
+	configs := []server.Config{
+		{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
+		{Host: "sub1.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
+		{Host: "sub2.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
+		{Host: "sub3.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar", Managed: true}},
+		{Host: "sub4.example.com", TLS: server.TLSConfig{LetsEncryptEmail: "", Managed: true}},
+		{Host: "sub5.example.com", TLS: server.TLSConfig{LetsEncryptEmail: ""}}, // not managed
+	}
+	// Configs with an empty email should be grouped under DefaultEmail.
+	DefaultEmail = "test@example.com"
+
+	groups := groupConfigsByEmail(configs, true)
+
+	if groups == nil {
+		t.Fatalf("Returned map was nil, but expected values")
+	}
+
+	if len(groups) != 2 {
+		t.Errorf("Expected 2 groups, got %d: %#v", len(groups), groups)
+	}
+	// Fix: the diagnostics below previously indexed groups["foobar"],
+	// a key that never exists, so failures printed 0 and an empty map.
+	if len(groups["foo@bar"]) != 2 {
+		t.Errorf("Expected 2 configs for foo@bar, got %d: %#v", len(groups["foo@bar"]), groups["foo@bar"])
+	}
+	if len(groups[DefaultEmail]) != 3 {
+		t.Errorf("Expected 3 configs for %s, got %d: %#v", DefaultEmail, len(groups[DefaultEmail]), groups[DefaultEmail])
+	}
+}
+
+// TestMarkQualified checks that MarkQualified sets TLS.Managed on
+// exactly the configs that ConfigQualifies would accept.
+func TestMarkQualified(t *testing.T) {
+	// TODO: TestConfigQualifies and this test share the same config list...
+	configs := []server.Config{
+		{Host: ""},
+		{Host: "localhost"},
+		{Host: "123.44.3.21"},
+		{Host: "example.com"},
+		{Host: "example.com", TLS: server.TLSConfig{Manual: true}},
+		{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "off"}},
+		{Host: "example.com", TLS: server.TLSConfig{LetsEncryptEmail: "foo@bar.com"}},
+		{Host: "example.com", Scheme: "http"},
+		{Host: "example.com", Port: "80"},
+		{Host: "example.com", Port: "1234"},
+		{Host: "example.com", Scheme: "https"},
+		{Host: "example.com", Port: "80", Scheme: "https"},
+	}
+	// Mirrors the four `true` expectations in TestConfigQualifies.
+	expectedManagedCount := 4
+
+	MarkQualified(configs)
+
+	count := 0
+	for _, cfg := range configs {
+		if cfg.TLS.Managed {
+			count++
+		}
+	}
+
+	if count != expectedManagedCount {
+		t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
+	}
+}
diff --git a/core/https/maintain.go b/core/https/maintain.go
new file mode 100644
index 000000000..46fd3d1f5
--- /dev/null
+++ b/core/https/maintain.go
@@ -0,0 +1,211 @@
+package https
+
+import (
+ "log"
+ "time"
+
+ "github.com/miekg/coredns/server"
+
+ "golang.org/x/crypto/ocsp"
+)
+
+const (
+	// RenewInterval is how often to check certificates for renewal.
+	RenewInterval = 12 * time.Hour
+
+	// OCSPInterval is how often to check if OCSP stapling needs updating.
+	OCSPInterval = 1 * time.Hour
+)
+
+// maintainAssets is a permanently-blocking function
+// that loops indefinitely and, on a regular schedule, checks
+// certificates for expiration and initiates a renewal of certs
+// that are expiring soon. It also updates OCSP stapling and
+// performs other maintenance of assets.
+//
+// You must pass in the channel which you'll close when
+// maintenance should stop, to allow this goroutine to clean up
+// after itself and unblock.
+func maintainAssets(stopChan chan struct{}) {
+	renewalTicker := time.NewTicker(RenewInterval)
+	ocspTicker := time.NewTicker(OCSPInterval)
+
+	for {
+		select {
+		case <-renewalTicker.C:
+			log.Println("[INFO] Scanning for expiring certificates")
+			// Background run: errors are logged inside; no prompts allowed.
+			renewManagedCertificates(false)
+			log.Println("[INFO] Done checking certificates")
+		case <-ocspTicker.C:
+			log.Println("[INFO] Scanning for stale OCSP staples")
+			updateOCSPStaples()
+			log.Println("[INFO] Done checking OCSP staples")
+		case <-stopChan:
+			// Stop the tickers so their resources are released,
+			// then unblock the caller by returning.
+			renewalTicker.Stop()
+			ocspTicker.Stop()
+			log.Println("[INFO] Stopped background maintenance routine")
+			return
+		}
+	}
+}
+
+// renewManagedCertificates scans the certificate cache for managed
+// certificates that expire within renewDurationBefore and attempts to
+// renew them. Certificates with no names, or on-demand certificates
+// whose renewal failed, are removed from the cache afterwards.
+//
+// If allowPrompts is true, an operator is assumed to be present and
+// fatal errors are returned immediately; otherwise errors are logged
+// and the scan continues (best-effort background operation).
+func renewManagedCertificates(allowPrompts bool) (err error) {
+	var renewed, deleted []Certificate
+	var client *ACMEClient
+	visitedNames := make(map[string]struct{})
+
+	certCacheMu.RLock()
+	for name, cert := range certCache {
+		if !cert.Managed {
+			continue
+		}
+
+		// the list of names on this cert should never be empty...
+		if len(cert.Names) == 0 {
+			log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v", name, cert.Names)
+			deleted = append(deleted, cert)
+			continue
+		}
+
+		// skip names whose certificate we've already renewed
+		if _, ok := visitedNames[name]; ok {
+			continue
+		}
+		for _, name := range cert.Names {
+			visitedNames[name] = struct{}{}
+		}
+
+		timeLeft := cert.NotAfter.Sub(time.Now().UTC())
+		if timeLeft < renewDurationBefore {
+			log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
+
+			if client == nil {
+				client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts)
+				if err != nil {
+					// Fix: release the read lock before returning; this
+					// path previously left certCacheMu held, which would
+					// block every future writer (deadlock).
+					certCacheMu.RUnlock()
+					return err
+				}
+				client.Configure("") // TODO: Bind address of relevant listener, yuck
+			}
+
+			err := client.Renew(cert.Names[0]) // managed certs better have only one name
+			if err != nil {
+				if client.AllowPrompts && timeLeft < 0 {
+					// Certificate renewal failed, the operator is present, and the certificate
+					// is already expired; we should stop immediately and return the error. Note
+					// that we used to do this any time a renewal failed at startup. However,
+					// after discussion in https://github.com/miekg/coredns/issues/642 we decided to
+					// only stop startup if the certificate is expired. We still log the error
+					// otherwise.
+					certCacheMu.RUnlock()
+					return err
+				}
+				log.Printf("[ERROR] %v", err)
+				if cert.OnDemand {
+					deleted = append(deleted, cert)
+				}
+			} else {
+				renewed = append(renewed, cert)
+			}
+		}
+	}
+	certCacheMu.RUnlock()
+
+	// Apply changes to the cache
+	for _, cert := range renewed {
+		// Re-cache the renewed certificate under its primary name.
+		// client is non-nil here: renewed is only populated after a
+		// successful client.Renew above.
+		_, err := cacheManagedCertificate(cert.Names[0], cert.OnDemand)
+		if err != nil {
+			if client.AllowPrompts {
+				return err // operator is present, so report error immediately
+			}
+			log.Printf("[ERROR] %v", err)
+		}
+	}
+	for _, cert := range deleted {
+		certCacheMu.Lock()
+		for _, name := range cert.Names {
+			delete(certCache, name)
+		}
+		certCacheMu.Unlock()
+	}
+
+	return nil
+}
+
+// updateOCSPStaples refreshes OCSP staples for cached certificates that
+// are past the midpoint of their staple's validity window. It collects
+// fresh staples under a read lock, then applies them to the cache under
+// a brief write lock.
+func updateOCSPStaples() {
+	// Create a temporary place to store updates
+	// until we release the potentially long-lived
+	// read lock and use a short-lived write lock.
+	type ocspUpdate struct {
+		rawBytes []byte
+		parsed   *ocsp.Response
+	}
+	updated := make(map[string]ocspUpdate)
+
+	// A single SAN certificate maps to multiple names, so we use this
+	// set to make sure we don't waste cycles checking OCSP for the same
+	// certificate multiple times.
+	visited := make(map[string]struct{})
+
+	certCacheMu.RLock()
+	for name, cert := range certCache {
+		// NOTE: cert is a copy of the cached value; stapleOCSP below
+		// mutates this copy, and the result is written back to the
+		// cache via the updated map in the write-lock phase.
+
+		// skip this certificate if we've already visited it,
+		// and if not, mark all the names as visited
+		if _, ok := visited[name]; ok {
+			continue
+		}
+		for _, n := range cert.Names {
+			visited[n] = struct{}{}
+		}
+
+		// no point in updating OCSP for expired certificates
+		if time.Now().After(cert.NotAfter) {
+			continue
+		}
+
+		var lastNextUpdate time.Time
+		if cert.OCSP != nil {
+			// start checking OCSP staple about halfway through validity period for good measure
+			lastNextUpdate = cert.OCSP.NextUpdate
+			refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
+
+			// since OCSP is already stapled, we need only check if we're in that "refresh window"
+			if time.Now().Before(refreshTime) {
+				continue
+			}
+		}
+
+		err := stapleOCSP(&cert, nil)
+		if err != nil {
+			if cert.OCSP != nil {
+				// if it was no staple before, that's fine, otherwise we should log the error
+				log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
+			}
+			continue
+		}
+
+		// By this point, we've obtained the latest OCSP response.
+		// If there was no staple before, or if the response is updated, make
+		// sure we apply the update to all names on the certificate.
+		if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate {
+			log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
+				cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
+			for _, n := range cert.Names {
+				updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
+			}
+		}
+	}
+	certCacheMu.RUnlock()
+
+	// This write lock should be brief since we have all the info we need now.
+	certCacheMu.Lock()
+	for name, update := range updated {
+		cert := certCache[name]
+		cert.OCSP = update.parsed
+		cert.Certificate.OCSPStaple = update.rawBytes
+		certCache[name] = cert
+	}
+	certCacheMu.Unlock()
+}
+
+// renewDurationBefore is how long before expiration to renew certificates.
+const renewDurationBefore = (24 * time.Hour) * 30
diff --git a/core/https/setup.go b/core/https/setup.go
new file mode 100644
index 000000000..ec90e0284
--- /dev/null
+++ b/core/https/setup.go
@@ -0,0 +1,321 @@
+package https
+
+import (
+ "bytes"
+ "crypto/tls"
+ "encoding/pem"
+ "io/ioutil"
+ "log"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+
+ "github.com/miekg/coredns/core/setup"
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/server"
+)
+
+// Setup sets up the TLS configuration and installs certificates that
+// are specified by the user in the config file. All the automatic HTTPS
+// stuff comes later outside of this function.
+//
+// NOTE(review): the early return below (marked TODO) currently disables
+// all directive parsing; everything after it is dead code kept for when
+// TLS support is re-enabled.
+func Setup(c *setup.Controller) (middleware.Middleware, error) {
+	if c.Port == "80" {
+		// Plain-HTTP port: TLS cannot be used here.
+		c.TLS.Enabled = false
+		log.Printf("[WARNING] TLS disabled for %s.", c.Address())
+		return nil, nil
+	}
+	c.TLS.Enabled = true
+
+	// TODO(miek): disabled for now
+	return nil, nil
+
+	for c.Next() {
+		var certificateFile, keyFile, loadDir, maxCerts string
+
+		args := c.RemainingArgs()
+		switch len(args) {
+		case 1:
+			// Single arg: a Let's Encrypt account email (or "off").
+			c.TLS.LetsEncryptEmail = args[0]
+
+			// user can force-disable managed TLS this way
+			if c.TLS.LetsEncryptEmail == "off" {
+				c.TLS.Enabled = false
+				return nil, nil
+			}
+		case 2:
+			// Two args: explicit certificate and key files (manual TLS).
+			certificateFile = args[0]
+			keyFile = args[1]
+			c.TLS.Manual = true
+		}
+
+		// Optional block with extra parameters
+		var hadBlock bool
+		for c.NextBlock() {
+			hadBlock = true
+			switch c.Val() {
+			case "protocols":
+				// Expects exactly two names: minimum and maximum version.
+				args := c.RemainingArgs()
+				if len(args) != 2 {
+					return nil, c.ArgErr()
+				}
+				value, ok := supportedProtocols[strings.ToLower(args[0])]
+				if !ok {
+					return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val())
+				}
+				c.TLS.ProtocolMinVersion = value
+				value, ok = supportedProtocols[strings.ToLower(args[1])]
+				if !ok {
+					return nil, c.Errf("Wrong protocol name or protocol not supported '%s'", c.Val())
+				}
+				c.TLS.ProtocolMaxVersion = value
+			case "ciphers":
+				for c.NextArg() {
+					value, ok := supportedCiphersMap[strings.ToUpper(c.Val())]
+					if !ok {
+						return nil, c.Errf("Wrong cipher name or cipher not supported '%s'", c.Val())
+					}
+					c.TLS.Ciphers = append(c.TLS.Ciphers, value)
+				}
+			case "clients":
+				// One or more CA cert files for client authentication.
+				c.TLS.ClientCerts = c.RemainingArgs()
+				if len(c.TLS.ClientCerts) == 0 {
+					return nil, c.ArgErr()
+				}
+			case "load":
+				c.Args(&loadDir)
+				c.TLS.Manual = true
+			case "max_certs":
+				c.Args(&maxCerts)
+				c.TLS.OnDemand = true
+			default:
+				return nil, c.Errf("Unknown keyword '%s'", c.Val())
+			}
+		}
+
+		// tls requires at least one argument if a block is not opened
+		if len(args) == 0 && !hadBlock {
+			return nil, c.ArgErr()
+		}
+
+		// set certificate limit if on-demand TLS is enabled
+		if maxCerts != "" {
+			maxCertsNum, err := strconv.Atoi(maxCerts)
+			if err != nil || maxCertsNum < 1 {
+				return nil, c.Err("max_certs must be a positive integer")
+			}
+			if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: We have to do this because it is global; should be per-server or per-vhost...
+				onDemandMaxIssue = int32(maxCertsNum)
+			}
+		}
+
+		// don't try to load certificates unless we're supposed to
+		if !c.TLS.Enabled || !c.TLS.Manual {
+			continue
+		}
+
+		// load a single certificate and key, if specified
+		if certificateFile != "" && keyFile != "" {
+			err := cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
+			if err != nil {
+				return nil, c.Errf("Unable to load certificate and key files for %s: %v", c.Host, err)
+			}
+			log.Printf("[INFO] Successfully loaded TLS assets from %s and %s", certificateFile, keyFile)
+		}
+
+		// load a directory of certificates, if specified
+		if loadDir != "" {
+			err := loadCertsInDir(c, loadDir)
+			if err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	setDefaultTLSParams(c.Config)
+
+	return nil, nil
+}
+
+// loadCertsInDir loads all the certificates/keys in dir, as long as
+// the file ends with .pem. This method of loading certificates is
+// modeled after haproxy, which expects the certificate and key to
+// be bundled into the same file:
+// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
+//
+// Only the first private key found in each bundle is used; additional
+// CERTIFICATE blocks are appended to the chain.
+//
+// This function may write to the log as it walks the directory tree.
+func loadCertsInDir(c *setup.Controller, dir string) error {
+	return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
+		if err != nil {
+			// Unreadable entries are skipped, not fatal.
+			log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
+			return nil
+		}
+		if info.IsDir() {
+			return nil
+		}
+		if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
+			certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
+			var foundKey bool // use only the first key in the file
+
+			bundle, err := ioutil.ReadFile(path)
+			if err != nil {
+				return err
+			}
+
+			for {
+				// Decode next block so we can see what type it is
+				var derBlock *pem.Block
+				derBlock, bundle = pem.Decode(bundle)
+				if derBlock == nil {
+					break
+				}
+
+				if derBlock.Type == "CERTIFICATE" {
+					// Re-encode certificate as PEM, appending to certificate chain
+					// (pem.Encode to a bytes.Buffer cannot fail; error ignored)
+					pem.Encode(certBuilder, derBlock)
+				} else if derBlock.Type == "EC PARAMETERS" {
+					// EC keys generated from openssl can be composed of two blocks:
+					// parameters and key (parameter block should come first)
+					if !foundKey {
+						// Encode parameters
+						pem.Encode(keyBuilder, derBlock)
+
+						// Key must immediately follow
+						derBlock, bundle = pem.Decode(bundle)
+						if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" {
+							return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path)
+						}
+						pem.Encode(keyBuilder, derBlock)
+						foundKey = true
+					}
+				} else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") {
+					// RSA key
+					if !foundKey {
+						pem.Encode(keyBuilder, derBlock)
+						foundKey = true
+					}
+				} else {
+					return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type)
+				}
+			}
+
+			certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes()
+			if len(certPEMBytes) == 0 {
+				return c.Errf("%s: failed to parse PEM data", path)
+			}
+			if len(keyPEMBytes) == 0 {
+				return c.Errf("%s: no private key block found", path)
+			}
+
+			err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
+			if err != nil {
+				return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err)
+			}
+			log.Printf("[INFO] Successfully loaded TLS assets from %s", path)
+		}
+		return nil
+	})
+}
+
+// setDefaultTLSParams fills in any TLS settings of the given server
+// config that were left unset: cipher suites, protocol version bounds,
+// and server cipher-suite preference (it does not overwrite; only fills
+// in missing values). It also defaults the port to 443 when no port was
+// given, TLS is enabled, the setup is not purely manual (or is
+// on-demand), and the host is not localhost.
+func setDefaultTLSParams(c *server.Config) {
+	tlsCfg := &c.TLS
+
+	// Fall back to the default cipher list when none were configured.
+	if len(tlsCfg.Ciphers) == 0 {
+		tlsCfg.Ciphers = defaultCiphers
+	}
+
+	// Not a cipher suite, but still important for mitigating protocol downgrade attacks
+	// (prepend since having it at end breaks http2 due to non-h2-approved suites before it)
+	tlsCfg.Ciphers = append([]uint16{tls.TLS_FALLBACK_SCSV}, tlsCfg.Ciphers...)
+
+	// Default protocol bounds balance compatibility and security.
+	if tlsCfg.ProtocolMinVersion == 0 {
+		tlsCfg.ProtocolMinVersion = tls.VersionTLS10
+	}
+	if tlsCfg.ProtocolMaxVersion == 0 {
+		tlsCfg.ProtocolMaxVersion = tls.VersionTLS12
+	}
+
+	// Prefer the server's cipher-suite order over the client's.
+	tlsCfg.PreferServerCipherSuites = true
+
+	// Default TLS port is 443; only use if port is not manually specified,
+	// TLS is enabled, and the host is not localhost
+	if c.Port == "" && tlsCfg.Enabled && (!tlsCfg.Manual || tlsCfg.OnDemand) && c.Host != "localhost" {
+		c.Port = "443"
+	}
+}
+
+// Map of supported protocols, keyed by lowercase config name.
+// SSLv3 will not be supported in a future release.
+// HTTP/2 only supports TLS 1.2 and higher.
+var supportedProtocols = map[string]uint16{
+	"ssl3.0": tls.VersionSSL30,
+	"tls1.0": tls.VersionTLS10,
+	"tls1.1": tls.VersionTLS11,
+	"tls1.2": tls.VersionTLS12,
+}
+
+// Map of supported ciphers, used only for parsing config.
+// Keys are OpenSSL-style names, uppercase.
+//
+// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites,
+// including all but two of the suites below (the two GCM suites).
+// See https://http2.github.io/http2-spec/#BadCipherSuites
+//
+// TLS_FALLBACK_SCSV is not in this list because we manually ensure
+// it is always added (even though it is not technically a cipher suite).
+//
+// This map, like any map, is NOT ORDERED. Do not range over this map.
+var supportedCiphersMap = map[string]uint16{
+	"ECDHE-RSA-AES256-GCM-SHA384":   tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+	"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+	"ECDHE-RSA-AES128-GCM-SHA256":   tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+	"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+	"ECDHE-RSA-AES128-CBC-SHA":      tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+	"ECDHE-RSA-AES256-CBC-SHA":      tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+	"ECDHE-ECDSA-AES256-CBC-SHA":    tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+	"ECDHE-ECDSA-AES128-CBC-SHA":    tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+	"RSA-AES128-CBC-SHA":            tls.TLS_RSA_WITH_AES_128_CBC_SHA,
+	"RSA-AES256-CBC-SHA":            tls.TLS_RSA_WITH_AES_256_CBC_SHA,
+	"ECDHE-RSA-3DES-EDE-CBC-SHA":    tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+	"RSA-3DES-EDE-CBC-SHA":          tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+}
+
+// List of supported cipher suites in descending order of preference.
+// Ordering is very important! Getting the wrong order will break
+// mainstream clients, especially with HTTP/2.
+//
+// Note that TLS_FALLBACK_SCSV is not in this list since it is always
+// added manually.
+var supportedCiphers = []uint16{
+	tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+	tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+	tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+	tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+	tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+	tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+	tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+	tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+	tls.TLS_RSA_WITH_AES_256_CBC_SHA,
+	tls.TLS_RSA_WITH_AES_128_CBC_SHA,
+	tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+	tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+}
+
+// List of all the ciphers we want to use by default
+// (same as supportedCiphers minus the two 3DES suites).
+var defaultCiphers = []uint16{
+	tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+	tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+	tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+	tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+	tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+	tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+	tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+	tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+	tls.TLS_RSA_WITH_AES_256_CBC_SHA,
+	tls.TLS_RSA_WITH_AES_128_CBC_SHA,
+}
diff --git a/core/https/setup_test.go b/core/https/setup_test.go
new file mode 100644
index 000000000..339fcdb5a
--- /dev/null
+++ b/core/https/setup_test.go
@@ -0,0 +1,232 @@
+package https
+
+import (
+ "crypto/tls"
+ "io/ioutil"
+ "log"
+ "os"
+ "testing"
+
+ "github.com/miekg/coredns/core/setup"
+)
+
+// TestMain writes the test certificate and key fixtures to disk before
+// the package's tests run, and removes them afterwards.
+func TestMain(m *testing.M) {
+	// Write test certificates to disk before tests, and clean up
+	// when we're done.
+	err := ioutil.WriteFile(certFile, testCert, 0644)
+	if err != nil {
+		log.Fatal(err)
+	}
+	err = ioutil.WriteFile(keyFile, testKey, 0644)
+	if err != nil {
+		// The cert was already written; remove it before exiting.
+		os.Remove(certFile)
+		log.Fatal(err)
+	}
+
+	result := m.Run()
+
+	// Deferred calls would not run past os.Exit, so clean up explicitly.
+	os.Remove(certFile)
+	os.Remove(keyFile)
+	os.Exit(result)
+}
+
+// TestSetupParseBasic runs Setup with a plain cert+key directive and
+// verifies the manual/enabled flags, the default protocol version
+// bounds, and the default cipher list (with TLS_FALLBACK_SCSV
+// prepended).
+func TestSetupParseBasic(t *testing.T) {
+	c := setup.NewTestController(`tls ` + certFile + ` ` + keyFile + ``)
+
+	_, err := Setup(c)
+	if err != nil {
+		t.Errorf("Expected no errors, got: %v", err)
+	}
+
+	// Basic checks
+	if !c.TLS.Manual {
+		t.Error("Expected TLS Manual=true, but was false")
+	}
+	if !c.TLS.Enabled {
+		t.Error("Expected TLS Enabled=true, but was false")
+	}
+
+	// Security defaults
+	if c.TLS.ProtocolMinVersion != tls.VersionTLS10 {
+		t.Errorf("Expected 'tls1.0 (0x0301)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
+	}
+	if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
+		t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", c.TLS.ProtocolMaxVersion)
+	}
+
+	// Cipher checks
+	expectedCiphers := []uint16{
+		tls.TLS_FALLBACK_SCSV,
+		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+		tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+		tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+		tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+		tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+		tls.TLS_RSA_WITH_AES_256_CBC_SHA,
+		tls.TLS_RSA_WITH_AES_128_CBC_SHA,
+	}
+
+	// Ensure count is correct (expectedCiphers above already includes
+	// the TLS_FALLBACK_SCSV entry that Setup prepends)
+	if len(c.TLS.Ciphers) != len(expectedCiphers) {
+		t.Errorf("Expected %v Ciphers (including TLS_FALLBACK_SCSV), got %v",
+			len(expectedCiphers), len(c.TLS.Ciphers))
+	}
+
+	// Ensure ordering is correct
+	for i, actual := range c.TLS.Ciphers {
+		if actual != expectedCiphers[i] {
+			t.Errorf("Expected cipher in position %d to be %0x, got %0x", i, expectedCiphers[i], actual)
+		}
+	}
+
+	if !c.TLS.PreferServerCipherSuites {
+		t.Error("Expected PreferServerCipherSuites = true, but was false")
+	}
+}
+
+// TestSetupParseIncompleteParams ensures a bare `tls` directive with no
+// arguments and no block is rejected.
+func TestSetupParseIncompleteParams(t *testing.T) {
+	// Using tls without args is an error because it's unnecessary.
+	c := setup.NewTestController(`tls`)
+	_, err := Setup(c)
+	if err == nil {
+		t.Error("Expected an error, but didn't get one")
+	}
+}
+
+// TestSetupParseWithOptionalParams checks that the optional tls block
+// (protocols, ciphers) is parsed into the controller's TLS settings.
+func TestSetupParseWithOptionalParams(t *testing.T) {
+	params := `tls ` + certFile + ` ` + keyFile + ` {
+		protocols ssl3.0 tls1.2
+		ciphers RSA-AES256-CBC-SHA ECDHE-RSA-AES128-GCM-SHA256 ECDHE-ECDSA-AES256-GCM-SHA384
+	}`
+	c := setup.NewTestController(params)
+
+	_, err := Setup(c)
+	if err != nil {
+		t.Errorf("Expected no errors, got: %v", err)
+	}
+
+	if c.TLS.ProtocolMinVersion != tls.VersionSSL30 {
+		t.Errorf("Expected 'ssl3.0 (0x0300)' as ProtocolMinVersion, got %#v", c.TLS.ProtocolMinVersion)
+	}
+
+	if c.TLS.ProtocolMaxVersion != tls.VersionTLS12 {
+		// Fix: the message previously said 0x0302; TLS 1.2 is 0x0303.
+		t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %#v", c.TLS.ProtocolMaxVersion)
+	}
+
+	// Three ciphers were configured; Setup prepends TLS_FALLBACK_SCSV.
+	if len(c.TLS.Ciphers)-1 != 3 {
+		t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
+	}
+}
+
+// TestSetupDefaultWithOptionalParams checks that a tls block with only
+// a ciphers line (no cert/key args) is accepted and parsed.
+func TestSetupDefaultWithOptionalParams(t *testing.T) {
+	params := `tls {
+            ciphers RSA-3DES-EDE-CBC-SHA
+        }`
+	c := setup.NewTestController(params)
+
+	_, err := Setup(c)
+	if err != nil {
+		t.Errorf("Expected no errors, got: %v", err)
+	}
+	// One cipher configured; Setup prepends TLS_FALLBACK_SCSV.
+	if len(c.TLS.Ciphers)-1 != 1 {
+		t.Errorf("Expected 1 ciphers (not including TLS_FALLBACK_SCSV), got %v", len(c.TLS.Ciphers)-1)
+	}
+}
+
+// TODO: If we allow this... but probably not a good idea.
+// func TestSetupDisableHTTPRedirect(t *testing.T) {
+// c := NewTestController(`tls {
+// allow_http
+// }`)
+// _, err := TLS(c)
+// if err != nil {
+// t.Errorf("Expected no error, but got %v", err)
+// }
+// if !c.TLS.DisableHTTPRedir {
+// t.Error("Expected HTTP redirect to be disabled, but it wasn't")
+// }
+// }
+
+// TestSetupParseWithWrongOptionalParams verifies that invalid protocol
+// names and invalid cipher names in the tls block produce errors.
+func TestSetupParseWithWrongOptionalParams(t *testing.T) {
+	// Test protocols wrong params
+	params := `tls ` + certFile + ` ` + keyFile + ` {
+			protocols ssl tls
+		}`
+	c := setup.NewTestController(params)
+	_, err := Setup(c)
+	if err == nil {
+		t.Errorf("Expected errors, but no error returned")
+	}
+
+	// Test ciphers wrong params
+	params = `tls ` + certFile + ` ` + keyFile + ` {
+			ciphers not-valid-cipher
+		}`
+	c = setup.NewTestController(params)
+	_, err = Setup(c)
+	if err == nil {
+		t.Errorf("Expected errors, but no error returned")
+	}
+}
+
+// TestSetupParseWithClientAuth checks that the `clients` sub-directive
+// collects client CA cert files in order, and that an empty `clients`
+// line is rejected.
+func TestSetupParseWithClientAuth(t *testing.T) {
+	params := `tls ` + certFile + ` ` + keyFile + ` {
+			clients client_ca.crt client2_ca.crt
+		}`
+	c := setup.NewTestController(params)
+	_, err := Setup(c)
+	if err != nil {
+		t.Errorf("Expected no errors, got: %v", err)
+	}
+
+	if count := len(c.TLS.ClientCerts); count != 2 {
+		t.Fatalf("Expected two client certs, had %d", count)
+	}
+	if actual := c.TLS.ClientCerts[0]; actual != "client_ca.crt" {
+		t.Errorf("Expected first client cert file to be '%s', but was '%s'", "client_ca.crt", actual)
+	}
+	if actual := c.TLS.ClientCerts[1]; actual != "client2_ca.crt" {
+		t.Errorf("Expected second client cert file to be '%s', but was '%s'", "client2_ca.crt", actual)
+	}
+
+	// Test missing client cert file
+	params = `tls ` + certFile + ` ` + keyFile + ` {
+			clients
+		}`
+	c = setup.NewTestController(params)
+	_, err = Setup(c)
+	if err == nil {
+		t.Errorf("Expected an error, but no error returned")
+	}
+}
+
+const (
+ certFile = "test_cert.pem"
+ keyFile = "test_key.pem"
+)
+
+var testCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBkjCCATmgAwIBAgIJANfFCBcABL6LMAkGByqGSM49BAEwFDESMBAGA1UEAxMJ
+bG9jYWxob3N0MB4XDTE2MDIxMDIyMjAyNFoXDTE4MDIwOTIyMjAyNFowFDESMBAG
+A1UEAxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEs22MtnG7
+9K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDLSiVQvFZ6lUszTlczNxVk
+pEfqrM6xAupB7qN1MHMwHQYDVR0OBBYEFHxYDvAxUwL4XrjPev6qZ/BiLDs5MEQG
+A1UdIwQ9MDuAFHxYDvAxUwL4XrjPev6qZ/BiLDs5oRikFjAUMRIwEAYDVQQDEwls
+b2NhbGhvc3SCCQDXxQgXAAS+izAMBgNVHRMEBTADAQH/MAkGByqGSM49BAEDSAAw
+RQIgRvBqbyJM2JCJqhA1FmcoZjeMocmhxQHTt1c+1N2wFUgCIQDtvrivbBPA688N
+Qh3sMeAKNKPsx5NxYdoWuu9KWcKz9A==
+-----END CERTIFICATE-----
+`)
+
+var testKey = []byte(`-----BEGIN EC PARAMETERS-----
+BggqhkjOPQMBBw==
+-----END EC PARAMETERS-----
+-----BEGIN EC PRIVATE KEY-----
+MHcCAQEEIGLtRmwzYVcrH3J0BnzYbGPdWVF10i9p6mxkA4+b2fURoAoGCCqGSM49
+AwEHoUQDQgAEs22MtnG79K1mvIyjEO9GLx7BFD0tBbGnwQ0VPsuCxC6IeVuXbQDL
+SiVQvFZ6lUszTlczNxVkpEfqrM6xAupB7g==
+-----END EC PRIVATE KEY-----
+`)
diff --git a/core/https/storage.go b/core/https/storage.go
new file mode 100644
index 000000000..5d8e949da
--- /dev/null
+++ b/core/https/storage.go
@@ -0,0 +1,94 @@
+package https
+
+import (
+ "path/filepath"
+ "strings"
+
+ "github.com/miekg/coredns/core/assets"
+)
+
// storage is used to get file paths in a consistent,
// cross-platform way for persisting Let's Encrypt assets
// on the file system. It is rooted at <assets path>/letsencrypt;
// tests reassign this package-level variable to a temp root.
var storage = Storage(filepath.Join(assets.Path(), "letsencrypt"))
+
// Storage is a root directory and facilitates
// forming file paths derived from it.
type Storage string

// Sites returns the directory that stores site certificates and keys.
func (s Storage) Sites() string {
	return filepath.Join(string(s), "sites")
}

// Site returns the folder containing all assets for domain.
func (s Storage) Site(domain string) string {
	return filepath.Join(string(s), "sites", domain)
}

// SiteCertFile returns the path of domain's certificate file.
func (s Storage) SiteCertFile(domain string) string {
	return s.siteFile(domain, ".crt")
}

// SiteKeyFile returns the path of domain's private key file.
func (s Storage) SiteKeyFile(domain string) string {
	return s.siteFile(domain, ".key")
}

// SiteMetaFile returns the path of domain's asset metadata file.
func (s Storage) SiteMetaFile(domain string) string {
	return s.siteFile(domain, ".json")
}

// siteFile builds the path of the file named domain+ext inside
// domain's site folder.
func (s Storage) siteFile(domain, ext string) string {
	return filepath.Join(string(s), "sites", domain, domain+ext)
}

// Users returns the directory that stores one folder per account.
func (s Storage) Users() string {
	return filepath.Join(string(s), "users")
}
+
+// User gets the account folder for the user with email.
+func (s Storage) User(email string) string {
+ if email == "" {
+ email = emptyEmail
+ }
+ return filepath.Join(s.Users(), email)
+}
+
+// UserRegFile gets the path to the registration file for
+// the user with the given email address.
+func (s Storage) UserRegFile(email string) string {
+ if email == "" {
+ email = emptyEmail
+ }
+ fileName := emailUsername(email)
+ if fileName == "" {
+ fileName = "registration"
+ }
+ return filepath.Join(s.User(email), fileName+".json")
+}
+
+// UserKeyFile gets the path to the private key file for
+// the user with the given email address.
+func (s Storage) UserKeyFile(email string) string {
+ if email == "" {
+ email = emptyEmail
+ }
+ fileName := emailUsername(email)
+ if fileName == "" {
+ fileName = "private"
+ }
+ return filepath.Join(s.User(email), fileName+".key")
+}
+
// emailUsername extracts the username portion of an email address
// (the part before '@'). Input without an '@' is returned unchanged;
// input beginning with '@' yields everything after it.
func emailUsername(email string) string {
	switch at := strings.Index(email, "@"); {
	case at < 0:
		return email
	case at == 0:
		return email[1:]
	default:
		return email[:at]
	}
}
diff --git a/core/https/storage_test.go b/core/https/storage_test.go
new file mode 100644
index 000000000..85c2220eb
--- /dev/null
+++ b/core/https/storage_test.go
@@ -0,0 +1,88 @@
+package https
+
+import (
+ "path/filepath"
+ "testing"
+)
+
// TestStorage verifies that every Storage path accessor derives the
// expected location beneath the root, including the empty-email
// fallback folder. Note it reassigns the package-level storage var.
func TestStorage(t *testing.T) {
	storage = Storage("./le_test")

	if expected, actual := filepath.Join("le_test", "sites"), storage.Sites(); actual != expected {
		t.Errorf("Expected Sites() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "sites", "test.com"), storage.Site("test.com"); actual != expected {
		t.Errorf("Expected Site() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.crt"), storage.SiteCertFile("test.com"); actual != expected {
		t.Errorf("Expected SiteCertFile() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.key"), storage.SiteKeyFile("test.com"); actual != expected {
		t.Errorf("Expected SiteKeyFile() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "sites", "test.com", "test.com.json"), storage.SiteMetaFile("test.com"); actual != expected {
		t.Errorf("Expected SiteMetaFile() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "users"), storage.Users(); actual != expected {
		t.Errorf("Expected Users() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "users", "me@example.com"), storage.User("me@example.com"); actual != expected {
		t.Errorf("Expected User() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.json"), storage.UserRegFile("me@example.com"); actual != expected {
		t.Errorf("Expected UserRegFile() to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "users", "me@example.com", "me.key"), storage.UserKeyFile("me@example.com"); actual != expected {
		t.Errorf("Expected UserKeyFile() to return '%s' but got '%s'", expected, actual)
	}

	// Test with empty emails: all paths fall back to the emptyEmail folder.
	if expected, actual := filepath.Join("le_test", "users", emptyEmail), storage.User(emptyEmail); actual != expected {
		t.Errorf("Expected User(\"\") to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".json"), storage.UserRegFile(""); actual != expected {
		t.Errorf("Expected UserRegFile(\"\") to return '%s' but got '%s'", expected, actual)
	}
	if expected, actual := filepath.Join("le_test", "users", emptyEmail, emptyEmail+".key"), storage.UserKeyFile(""); actual != expected {
		t.Errorf("Expected UserKeyFile(\"\") to return '%s' but got '%s'", expected, actual)
	}
}
+
// TestEmailUsername table-tests emailUsername, including plus-addressed
// inputs, input without an '@', a leading '@', and empty input.
func TestEmailUsername(t *testing.T) {
	for i, test := range []struct {
		input, expect string
	}{
		{
			input:  "username@example.com",
			expect: "username",
		},
		{
			input:  "plus+addressing@example.com",
			expect: "plus+addressing",
		},
		{
			input:  "me+plus-addressing@example.com",
			expect: "me+plus-addressing",
		},
		{
			input:  "not-an-email",
			expect: "not-an-email",
		},
		{
			input:  "@foobar.com",
			expect: "foobar.com",
		},
		{
			input:  emptyEmail,
			expect: emptyEmail,
		},
		{
			input:  "",
			expect: "",
		},
	} {
		if actual := emailUsername(test.input); actual != test.expect {
			t.Errorf("Test %d: Expected username to be '%s' but was '%s'", i, test.expect, actual)
		}
	}
}
diff --git a/core/https/user.go b/core/https/user.go
new file mode 100644
index 000000000..13d93b1da
--- /dev/null
+++ b/core/https/user.go
@@ -0,0 +1,200 @@
+package https
+
+import (
+ "bufio"
+ "crypto"
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "strings"
+
+ "github.com/miekg/coredns/server"
+ "github.com/xenolf/lego/acme"
+)
+
// User represents a Let's Encrypt user account. The exported fields
// are what json.Marshal persists in the registration file; the private
// key is unexported and stored separately in a PEM key file.
type User struct {
	Email        string
	Registration *acme.RegistrationResource
	key          crypto.PrivateKey // loaded/saved via loadPrivateKey/savePrivateKey, not JSON
}

// GetEmail gets u's email.
// (The Get* names presumably satisfy lego's acme user interface —
// confirm before renaming.)
func (u User) GetEmail() string {
	return u.Email
}

// GetRegistration gets u's registration resource.
func (u User) GetRegistration() *acme.RegistrationResource {
	return u.Registration
}

// GetPrivateKey gets u's private key.
func (u User) GetPrivateKey() crypto.PrivateKey {
	return u.key
}
+
// getUser loads the user with the given email from disk.
// If the user does not exist, it will create a new one,
// but it does NOT save new users to the disk or register
// them via ACME. It does NOT prompt the user.
func getUser(email string) (User, error) {
	var user User

	// Open the stored registration; a missing file simply means
	// this is a brand-new (in-memory only) account.
	regFile, err := os.Open(storage.UserRegFile(email))
	if err != nil {
		if os.IsNotExist(err) {
			// create a new user
			return newUser(email)
		}
		return user, err
	}
	defer regFile.Close()

	// Decode the registration JSON into the exported fields.
	err = json.NewDecoder(regFile).Decode(&user)
	if err != nil {
		return user, err
	}

	// The private key lives in its own file next to the registration.
	user.key, err = loadPrivateKey(storage.UserKeyFile(email))
	if err != nil {
		return user, err
	}

	return user, nil
}
+
// saveUser persists a user's key and account registration
// to the file system. It does NOT register the user via ACME
// or prompt the user.
func saveUser(user User) error {
	// Ensure the account folder exists; 0700 keeps it owner-only.
	err := os.MkdirAll(storage.User(user.Email), 0700)
	if err != nil {
		return err
	}

	// save private key file
	err = savePrivateKey(user.key, storage.UserKeyFile(user.Email))
	if err != nil {
		return err
	}

	// Marshal the exported fields (email, registration) and write
	// the registration file with owner-only permissions.
	jsonBytes, err := json.MarshalIndent(&user, "", "\t")
	if err != nil {
		return err
	}

	return ioutil.WriteFile(storage.UserRegFile(user.Email), jsonBytes, 0600)
}
+
+// newUser creates a new User for the given email address
+// with a new private key. This function does NOT save the
+// user to disk or register it via ACME. If you want to use
+// a user account that might already exist, call getUser
+// instead. It does NOT prompt the user.
+func newUser(email string) (User, error) {
+ user := User{Email: email}
+ privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
+ if err != nil {
+ return user, errors.New("error generating private key: " + err.Error())
+ }
+ user.key = privateKey
+ return user, nil
+}
+
+// getEmail does everything it can to obtain an email
+// address from the user to use for TLS for cfg. If it
+// cannot get an email address, it returns empty string.
+// (It will warn the user of the consequences of an
+// empty email.) This function MAY prompt the user for
+// input. If userPresent is false, the operator will
+// NOT be prompted and an empty email may be returned.
+func getEmail(cfg server.Config, userPresent bool) string {
+ // First try the tls directive from the Caddyfile
+ leEmail := cfg.TLS.LetsEncryptEmail
+ if leEmail == "" {
+ // Then try memory (command line flag or typed by user previously)
+ leEmail = DefaultEmail
+ }
+ if leEmail == "" {
+ // Then try to get most recent user email ~/.caddy/users file
+ userDirs, err := ioutil.ReadDir(storage.Users())
+ if err == nil {
+ var mostRecent os.FileInfo
+ for _, dir := range userDirs {
+ if !dir.IsDir() {
+ continue
+ }
+ if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) {
+ leEmail = dir.Name()
+ DefaultEmail = leEmail // save for next time
+ }
+ }
+ }
+ }
+ if leEmail == "" && userPresent {
+ // Alas, we must bother the user and ask for an email address;
+ // if they proceed they also agree to the SA.
+ reader := bufio.NewReader(stdin)
+ fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.")
+ fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:")
+ fmt.Println(" " + saURL) // TODO: Show current SA link
+ fmt.Println("Please enter your email address so you can recover your account if needed.")
+ fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
+ fmt.Print("Email address: ")
+ var err error
+ leEmail, err = reader.ReadString('\n')
+ if err != nil {
+ return ""
+ }
+ leEmail = strings.TrimSpace(leEmail)
+ DefaultEmail = leEmail
+ Agreed = true
+ }
+ return leEmail
+}
+
+// promptUserAgreement prompts the user to agree to the agreement
+// at agreementURL via stdin. If the agreement has changed, then pass
+// true as the second argument. If this is the user's first time
+// agreeing, pass false. It returns whether the user agreed or not.
+func promptUserAgreement(agreementURL string, changed bool) bool {
+ if changed {
+ fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL)
+ fmt.Print("Do you agree to the new terms? (y/n): ")
+ } else {
+ fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL)
+ fmt.Print("Do you agree to the terms? (y/n): ")
+ }
+
+ reader := bufio.NewReader(stdin)
+ answer, err := reader.ReadString('\n')
+ if err != nil {
+ return false
+ }
+ answer = strings.ToLower(strings.TrimSpace(answer))
+
+ return answer == "y" || answer == "yes"
+}
+
// stdin is used to read the user's input if prompted;
// this is changed by tests during tests (replaced with a buffer).
var stdin = io.ReadWriter(os.Stdin)

// The name of the folder for accounts where the email
// address was not provided; default 'username' if you will.
const emptyEmail = "default"

// saURL is the Subscriber Agreement shown/linked when prompting.
// TODO: Use latest
const saURL = "https://letsencrypt.org/documents/LE-SA-v1.0.1-July-27-2015.pdf"
diff --git a/core/https/user_test.go b/core/https/user_test.go
new file mode 100644
index 000000000..3e1af5007
--- /dev/null
+++ b/core/https/user_test.go
@@ -0,0 +1,196 @@
+package https
+
+import (
+ "bytes"
+ "crypto/rand"
+ "crypto/rsa"
+ "io"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/miekg/coredns/server"
+ "github.com/xenolf/lego/acme"
+)
+
// TestUser checks the three accessor methods against a hand-built User.
// The 128-bit RSA key is cryptographically worthless but fine here:
// it is only stored in the struct and compared by identity.
func TestUser(t *testing.T) {
	privateKey, err := rsa.GenerateKey(rand.Reader, 128)
	if err != nil {
		t.Fatalf("Could not generate test private key: %v", err)
	}
	u := User{
		Email:        "me@mine.com",
		Registration: new(acme.RegistrationResource),
		key:          privateKey,
	}

	if expected, actual := "me@mine.com", u.GetEmail(); actual != expected {
		t.Errorf("Expected email '%s' but got '%s'", expected, actual)
	}
	if u.GetRegistration() == nil {
		t.Error("Expected a registration resource, but got nil")
	}
	if expected, actual := privateKey, u.GetPrivateKey(); actual != expected {
		t.Errorf("Expected the private key at address %p but got one at %p instead ", expected, actual)
	}
}
+
// TestNewUser verifies newUser fills in the email and a fresh private
// key, and does not fabricate a registration resource.
func TestNewUser(t *testing.T) {
	email := "me@foobar.com"
	user, err := newUser(email)
	if err != nil {
		t.Fatalf("Error creating user: %v", err)
	}
	if user.key == nil {
		t.Error("Private key is nil")
	}
	if user.Email != email {
		t.Errorf("Expected email to be %s, but was %s", email, user.Email)
	}
	if user.Registration != nil {
		t.Error("New user already has a registration resource; it shouldn't")
	}
}
+
// TestSaveUser saves a fresh user into a temporary storage root and
// checks that both the registration file and key file exist afterward.
func TestSaveUser(t *testing.T) {
	storage = Storage("./testdata")
	defer os.RemoveAll(string(storage))

	email := "me@foobar.com"
	user, err := newUser(email)
	if err != nil {
		t.Fatalf("Error creating user: %v", err)
	}

	err = saveUser(user)
	if err != nil {
		t.Fatalf("Error saving user: %v", err)
	}
	_, err = os.Stat(storage.UserRegFile(email))
	if err != nil {
		t.Errorf("Cannot access user registration file, error: %v", err)
	}
	_, err = os.Stat(storage.UserKeyFile(email))
	if err != nil {
		t.Errorf("Cannot access user private key file, error: %v", err)
	}
}
+
// TestGetUserDoesNotAlreadyExist checks that getUser creates an
// in-memory account (with a key) when nothing exists on disk.
func TestGetUserDoesNotAlreadyExist(t *testing.T) {
	storage = Storage("./testdata")
	defer os.RemoveAll(string(storage))

	user, err := getUser("user_does_not_exist@foobar.com")
	if err != nil {
		t.Fatalf("Error getting user: %v", err)
	}

	if user.key == nil {
		t.Error("Expected user to have a private key, but it was nil")
	}
}
+
// TestGetUserAlreadyExists round-trips an account through saveUser and
// getUser and asserts the key and email survive the trip.
func TestGetUserAlreadyExists(t *testing.T) {
	storage = Storage("./testdata")
	defer os.RemoveAll(string(storage))

	email := "me@foobar.com"

	// Set up test
	user, err := newUser(email)
	if err != nil {
		t.Fatalf("Error creating user: %v", err)
	}
	err = saveUser(user)
	if err != nil {
		t.Fatalf("Error saving user: %v", err)
	}

	// Expect to load user from disk
	user2, err := getUser(email)
	if err != nil {
		t.Fatalf("Error getting user: %v", err)
	}

	// Assert keys are the same
	if !PrivateKeysSame(user.key, user2.key) {
		t.Error("Expected private key to be the same after loading, but it wasn't")
	}

	// Assert emails are the same
	if user.Email != user2.Email {
		t.Errorf("Expected emails to be equal, but was '%s' before and '%s' after loading", user.Email, user2.Email)
	}
}
+
// TestGetEmail exercises getEmail's four lookup sources in order:
// config, DefaultEmail, interactive stdin, and the on-disk accounts.
func TestGetEmail(t *testing.T) {
	// let's not clutter up the output
	origStdout := os.Stdout
	os.Stdout = nil
	defer func() { os.Stdout = origStdout }()

	storage = Storage("./testdata")
	defer os.RemoveAll(string(storage))
	DefaultEmail = "test2@foo.com"

	// Test1: Use email in config
	config := server.Config{
		TLS: server.TLSConfig{
			LetsEncryptEmail: "test1@foo.com",
		},
	}
	actual := getEmail(config, true)
	if actual != "test1@foo.com" {
		t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", "test1@foo.com", actual)
	}

	// Test2: Use default email from flag (or user previously typing it)
	actual = getEmail(server.Config{}, true)
	if actual != DefaultEmail {
		t.Errorf("Did not get correct email from config; expected '%s' but got '%s'", DefaultEmail, actual)
	}

	// Test3: Get input from user (stdin is the test-replaceable reader)
	DefaultEmail = ""
	stdin = new(bytes.Buffer)
	_, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
	if err != nil {
		t.Fatalf("Could not simulate user input, error: %v", err)
	}
	actual = getEmail(server.Config{}, true)
	if actual != "test3@foo.com" {
		t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
	}

	// Test4: Get most recent email from before
	// NOTE(review): "test4-3" is both the newest by mtime (i=0, no
	// offset subtracted) and the last name in directory-listing order,
	// so this assertion would pass even if getEmail picked the
	// last-listed folder rather than the most recent — confirm intent.
	DefaultEmail = ""
	for i, eml := range []string{
		"test4-3@foo.com",
		"test4-2@foo.com",
		"test4-1@foo.com",
	} {
		u, err := newUser(eml)
		if err != nil {
			t.Fatalf("Error creating user %d: %v", i, err)
		}
		err = saveUser(u)
		if err != nil {
			t.Fatalf("Error saving user %d: %v", i, err)
		}

		// Change modified time so they're all different, so the test becomes deterministic
		f, err := os.Stat(storage.User(eml))
		if err != nil {
			t.Fatalf("Could not access user folder for '%s': %v", eml, err)
		}
		chTime := f.ModTime().Add(-(time.Duration(i) * time.Second))
		if err := os.Chtimes(storage.User(eml), chTime, chTime); err != nil {
			t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
		}
	}
	actual = getEmail(server.Config{}, true)
	if actual != "test4-3@foo.com" {
		t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
	}
}
diff --git a/core/parse/dispenser.go b/core/parse/dispenser.go
new file mode 100644
index 000000000..08aa6e76d
--- /dev/null
+++ b/core/parse/dispenser.go
@@ -0,0 +1,251 @@
+package parse
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "strings"
+)
+
// Dispenser is a type that dispenses tokens, similarly to a lexer,
// except that it can do so with some notion of structure and has
// some really convenient methods.
type Dispenser struct {
	filename string  // fallback filename reported when a token carries none
	tokens   []token // the complete, pre-lexed token stream
	cursor   int     // index of the current token; -1 before the first Next
	nesting  int     // current block-nesting depth, maintained by NextBlock
}

// NewDispenser returns a Dispenser, ready to use for parsing the given input.
func NewDispenser(filename string, input io.Reader) Dispenser {
	return Dispenser{
		filename: filename,
		tokens:   allTokens(input),
		cursor:   -1, // so the first Next lands on token 0
	}
}

// NewDispenserTokens returns a Dispenser filled with the given tokens.
func NewDispenserTokens(filename string, tokens []token) Dispenser {
	return Dispenser{
		filename: filename,
		tokens:   tokens,
		cursor:   -1,
	}
}
+
+// Next loads the next token. Returns true if a token
+// was loaded; false otherwise. If false, all tokens
+// have been consumed.
+func (d *Dispenser) Next() bool {
+ if d.cursor < len(d.tokens)-1 {
+ d.cursor++
+ return true
+ }
+ return false
+}
+
// NextArg loads the next token if it is on the same
// line. Returns true if a token was loaded; false
// otherwise. If false, all tokens on the line have
// been consumed. It handles imported tokens correctly.
func (d *Dispenser) NextArg() bool {
	// Moving off the initial position (cursor -1) always reports true,
	// even when the token stream is empty (Val() would then return "").
	if d.cursor < 0 {
		d.cursor++
		return true
	}
	if d.cursor >= len(d.tokens) {
		return false
	}
	// Same line means: same source file, and the next token's line equals
	// the current token's line plus any line breaks embedded in its text
	// (quoted tokens may span lines).
	if d.cursor < len(d.tokens)-1 &&
		d.tokens[d.cursor].file == d.tokens[d.cursor+1].file &&
		d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) == d.tokens[d.cursor+1].line {
		d.cursor++
		return true
	}
	return false
}

// NextLine loads the next token only if it is not on the same
// line as the current token, and returns true if a token was
// loaded; false otherwise. If false, there is not another token
// or it is on the same line. It handles imported tokens correctly.
func (d *Dispenser) NextLine() bool {
	if d.cursor < 0 {
		d.cursor++
		return true
	}
	if d.cursor >= len(d.tokens) {
		return false
	}
	// "Different line" is the complement of NextArg's condition: a
	// different source file, or a strictly greater line number after
	// accounting for embedded line breaks.
	if d.cursor < len(d.tokens)-1 &&
		(d.tokens[d.cursor].file != d.tokens[d.cursor+1].file ||
			d.tokens[d.cursor].line+d.numLineBreaks(d.cursor) < d.tokens[d.cursor+1].line) {
		d.cursor++
		return true
	}
	return false
}
+
// NextBlock can be used as the condition of a for loop
// to load the next token as long as it opens a block or
// is already in a block. It returns true if a token was
// loaded, or false when the block's closing curly brace
// was loaded and thus the block ended. Nested blocks are
// not supported.
func (d *Dispenser) NextBlock() bool {
	if d.nesting > 0 {
		// Already inside a block: keep consuming tokens until the
		// closing brace, which ends the block and resets nesting.
		d.Next()
		if d.Val() == "}" {
			d.nesting--
			return false
		}
		return true
	}
	if !d.NextArg() { // block must open on same line
		return false
	}
	if d.Val() != "{" {
		d.cursor-- // roll back if not opening brace
		return false
	}
	d.Next()
	if d.Val() == "}" {
		// Open and then closed right away: treat as if no block existed.
		return false
	}
	d.nesting++
	return true
}
+
+// IncrNest adds a level of nesting to the dispenser.
+func (d *Dispenser) IncrNest() {
+ d.nesting++
+ return
+}
+
+// Val gets the text of the current token. If there is no token
+// loaded, it returns empty string.
+func (d *Dispenser) Val() string {
+ if d.cursor < 0 || d.cursor >= len(d.tokens) {
+ return ""
+ }
+ return d.tokens[d.cursor].text
+}
+
+// Line gets the line number of the current token. If there is no token
+// loaded, it returns 0.
+func (d *Dispenser) Line() int {
+ if d.cursor < 0 || d.cursor >= len(d.tokens) {
+ return 0
+ }
+ return d.tokens[d.cursor].line
+}
+
+// File gets the filename of the current token. If there is no token loaded,
+// it returns the filename originally given when parsing started.
+func (d *Dispenser) File() string {
+ if d.cursor < 0 || d.cursor >= len(d.tokens) {
+ return d.filename
+ }
+ if tokenFilename := d.tokens[d.cursor].file; tokenFilename != "" {
+ return tokenFilename
+ }
+ return d.filename
+}
+
+// Args is a convenience function that loads the next arguments
+// (tokens on the same line) into an arbitrary number of strings
+// pointed to in targets. If there are fewer tokens available
+// than string pointers, the remaining strings will not be changed
+// and false will be returned. If there were enough tokens available
+// to fill the arguments, then true will be returned.
+func (d *Dispenser) Args(targets ...*string) bool {
+ enough := true
+ for i := 0; i < len(targets); i++ {
+ if !d.NextArg() {
+ enough = false
+ break
+ }
+ *targets[i] = d.Val()
+ }
+ return enough
+}
+
+// RemainingArgs loads any more arguments (tokens on the same line)
+// into a slice and returns them. Open curly brace tokens also indicate
+// the end of arguments, and the curly brace is not included in
+// the return value nor is it loaded.
+func (d *Dispenser) RemainingArgs() []string {
+ var args []string
+
+ for d.NextArg() {
+ if d.Val() == "{" {
+ d.cursor--
+ break
+ }
+ args = append(args, d.Val())
+ }
+
+ return args
+}
+
+// ArgErr returns an argument error, meaning that another
+// argument was expected but not found. In other words,
+// a line break or open curly brace was encountered instead of
+// an argument.
+func (d *Dispenser) ArgErr() error {
+ if d.Val() == "{" {
+ return d.Err("Unexpected token '{', expecting argument")
+ }
+ return d.Errf("Wrong argument count or unexpected line ending after '%s'", d.Val())
+}
+
+// SyntaxErr creates a generic syntax error which explains what was
+// found and what was expected.
+func (d *Dispenser) SyntaxErr(expected string) error {
+ msg := fmt.Sprintf("%s:%d - Syntax error: Unexpected token '%s', expecting '%s'", d.File(), d.Line(), d.Val(), expected)
+ return errors.New(msg)
+}
+
+// EOFErr returns an error indicating that the dispenser reached
+// the end of the input when searching for the next token.
+func (d *Dispenser) EOFErr() error {
+ return d.Errf("Unexpected EOF")
+}
+
+// Err generates a custom parse error with a message of msg.
+func (d *Dispenser) Err(msg string) error {
+ msg = fmt.Sprintf("%s:%d - Parse error: %s", d.File(), d.Line(), msg)
+ return errors.New(msg)
+}
+
+// Errf is like Err, but for formatted error messages
+func (d *Dispenser) Errf(format string, args ...interface{}) error {
+ return d.Err(fmt.Sprintf(format, args...))
+}
+
// numLineBreaks counts how many line breaks are in the token
// value given by the token index tknIdx. It returns 0 if the
// token does not exist or there are no line breaks.
// (Quoted tokens may contain embedded newlines.)
func (d *Dispenser) numLineBreaks(tknIdx int) int {
	if tknIdx < 0 || tknIdx >= len(d.tokens) {
		return 0
	}
	return strings.Count(d.tokens[tknIdx].text, "\n")
}

// isNewLine determines whether the current token is on a different
// line (higher line number) than the previous token. It handles imported
// tokens correctly. If there isn't a previous token, it returns true.
func (d *Dispenser) isNewLine() bool {
	if d.cursor < 1 {
		return true
	}
	if d.cursor > len(d.tokens)-1 {
		return false
	}
	// A file change, or a line number beyond the previous token's line
	// plus its embedded line breaks, means a new line began.
	return d.tokens[d.cursor-1].file != d.tokens[d.cursor].file ||
		d.tokens[d.cursor-1].line+d.numLineBreaks(d.cursor-1) < d.tokens[d.cursor].line
}
diff --git a/core/parse/dispenser_test.go b/core/parse/dispenser_test.go
new file mode 100644
index 000000000..20a7ddcac
--- /dev/null
+++ b/core/parse/dispenser_test.go
@@ -0,0 +1,292 @@
+package parse
+
+import (
+ "reflect"
+ "strings"
+ "testing"
+)
+
// TestDispenser_Val_Next walks the whole token stream with Next and
// checks Val, the cursor position, and that nesting stays untouched.
func TestDispenser_Val_Next(t *testing.T) {
	input := `host:port
			  dir1 arg1
			  dir2 arg2 arg3
			  dir3`
	d := NewDispenser("Testfile", strings.NewReader(input))

	if val := d.Val(); val != "" {
		t.Fatalf("Val(): Should return empty string when no token loaded; got '%s'", val)
	}

	assertNext := func(shouldLoad bool, expectedCursor int, expectedVal string) {
		if loaded := d.Next(); loaded != shouldLoad {
			t.Errorf("Next(): Expected %v but got %v instead (val '%s')", shouldLoad, loaded, d.Val())
		}
		if d.cursor != expectedCursor {
			t.Errorf("Expected cursor to be %d, but was %d", expectedCursor, d.cursor)
		}
		if d.nesting != 0 {
			t.Errorf("Nesting should be 0, was %d instead", d.nesting)
		}
		if val := d.Val(); val != expectedVal {
			t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
		}
	}

	assertNext(true, 0, "host:port")
	assertNext(true, 1, "dir1")
	assertNext(true, 2, "arg1")
	assertNext(true, 3, "dir2")
	assertNext(true, 4, "arg2")
	assertNext(true, 5, "arg3")
	assertNext(true, 6, "dir3")
	// Note: This next test simply asserts existing behavior.
	// If desired, we may wish to empty the token value after
	// reading past the EOF. Open an issue if you want this change.
	assertNext(false, 6, "dir3")
}
+
// TestDispenser_NextArg checks NextArg's same-line loading and its
// refusal to cross a line boundary.
func TestDispenser_NextArg(t *testing.T) {
	input := `dir1 arg1
			  dir2 arg2 arg3
			  dir3`
	d := NewDispenser("Testfile", strings.NewReader(input))

	// NOTE(review): the message below hard-codes "but got false", which
	// reads oddly for the final assertNext(false, ...) call — confirm.
	assertNext := func(shouldLoad bool, expectedVal string, expectedCursor int) {
		if d.Next() != shouldLoad {
			t.Errorf("Next(): Should load token but got false instead (val: '%s')", d.Val())
		}
		if d.cursor != expectedCursor {
			t.Errorf("Next(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor)
		}
		if val := d.Val(); val != expectedVal {
			t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
		}
	}

	assertNextArg := func(expectedVal string, loadAnother bool, expectedCursor int) {
		if d.NextArg() != true {
			t.Error("NextArg(): Should load next argument but got false instead")
		}
		if d.cursor != expectedCursor {
			t.Errorf("NextArg(): Expected cursor to be at %d, but it was %d", expectedCursor, d.cursor)
		}
		if val := d.Val(); val != expectedVal {
			t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
		}
		if !loadAnother {
			if d.NextArg() != false {
				t.Fatalf("NextArg(): Should NOT load another argument, but got true instead (val: '%s')", d.Val())
			}
			if d.cursor != expectedCursor {
				t.Errorf("NextArg(): Expected cursor to remain at %d, but it was %d", expectedCursor, d.cursor)
			}
		}
	}

	assertNext(true, "dir1", 0)
	assertNextArg("arg1", false, 1)
	assertNext(true, "dir2", 2)
	assertNextArg("arg2", true, 3)
	assertNextArg("arg3", false, 4)
	assertNext(true, "dir3", 5)
	assertNext(false, "dir3", 5)
}
+
// TestDispenser_NextLine checks that NextLine only advances across
// line boundaries, never within a line.
func TestDispenser_NextLine(t *testing.T) {
	input := `host:port
			  dir1 arg1
			  dir2 arg2 arg3`
	d := NewDispenser("Testfile", strings.NewReader(input))

	assertNextLine := func(shouldLoad bool, expectedVal string, expectedCursor int) {
		if d.NextLine() != shouldLoad {
			t.Errorf("NextLine(): Should load token but got false instead (val: '%s')", d.Val())
		}
		if d.cursor != expectedCursor {
			t.Errorf("NextLine(): Expected cursor to be %d, instead was %d", expectedCursor, d.cursor)
		}
		if val := d.Val(); val != expectedVal {
			t.Errorf("Val(): Expected '%s' but got '%s'", expectedVal, val)
		}
	}

	assertNextLine(true, "host:port", 0)
	assertNextLine(true, "dir1", 1)
	assertNextLine(false, "dir1", 1)
	d.Next() // arg1
	assertNextLine(true, "dir2", 3)
	assertNextLine(false, "dir2", 3)
	d.Next() // arg2
	assertNextLine(false, "arg2", 4)
	d.Next() // arg3
	assertNextLine(false, "arg3", 5)
}
+
// TestDispenser_NextBlock checks block entry, iteration, exit, and
// that an immediately-closed block behaves as if absent.
func TestDispenser_NextBlock(t *testing.T) {
	input := `foobar1 {
			  sub1 arg1
			  sub2
			  }
			  foobar2 {
			  }`
	d := NewDispenser("Testfile", strings.NewReader(input))

	assertNextBlock := func(shouldLoad bool, expectedCursor, expectedNesting int) {
		if loaded := d.NextBlock(); loaded != shouldLoad {
			t.Errorf("NextBlock(): Should return %v but got %v", shouldLoad, loaded)
		}
		if d.cursor != expectedCursor {
			t.Errorf("NextBlock(): Expected cursor to be %d, was %d", expectedCursor, d.cursor)
		}
		if d.nesting != expectedNesting {
			t.Errorf("NextBlock(): Nesting should be %d, not %d", expectedNesting, d.nesting)
		}
	}

	assertNextBlock(false, -1, 0)
	d.Next() // foobar1
	assertNextBlock(true, 2, 1)
	assertNextBlock(true, 3, 1)
	assertNextBlock(true, 4, 1)
	assertNextBlock(false, 5, 0)
	d.Next() // foobar2
	assertNextBlock(false, 8, 0) // empty block is as if it didn't exist
}
+
// TestDispenser_Args covers all three cases of Args: exact match,
// more targets than arguments (extras unchanged, returns false), and
// fewer targets than arguments.
func TestDispenser_Args(t *testing.T) {
	var s1, s2, s3 string
	input := `dir1 arg1 arg2 arg3
			  dir2 arg4 arg5
			  dir3 arg6 arg7
			  dir4`
	d := NewDispenser("Testfile", strings.NewReader(input))

	d.Next() // dir1

	// As many strings as arguments
	if all := d.Args(&s1, &s2, &s3); !all {
		t.Error("Args(): Expected true, got false")
	}
	if s1 != "arg1" {
		t.Errorf("Args(): Expected s1 to be 'arg1', got '%s'", s1)
	}
	if s2 != "arg2" {
		t.Errorf("Args(): Expected s2 to be 'arg2', got '%s'", s2)
	}
	if s3 != "arg3" {
		t.Errorf("Args(): Expected s3 to be 'arg3', got '%s'", s3)
	}

	d.Next() // dir2

	// More strings than arguments
	if all := d.Args(&s1, &s2, &s3); all {
		t.Error("Args(): Expected false, got true")
	}
	if s1 != "arg4" {
		t.Errorf("Args(): Expected s1 to be 'arg4', got '%s'", s1)
	}
	if s2 != "arg5" {
		t.Errorf("Args(): Expected s2 to be 'arg5', got '%s'", s2)
	}
	if s3 != "arg3" {
		t.Errorf("Args(): Expected s3 to be unchanged ('arg3'), instead got '%s'", s3)
	}

	// (quick cursor check just for kicks and giggles)
	if d.cursor != 6 {
		t.Errorf("Cursor should be 6, but is %d", d.cursor)
	}

	d.Next() // dir3

	// More arguments than strings
	if all := d.Args(&s1); !all {
		t.Error("Args(): Expected true, got false")
	}
	if s1 != "arg6" {
		t.Errorf("Args(): Expected s1 to be 'arg6', got '%s'", s1)
	}

	d.Next() // dir4

	// No arguments or strings
	if all := d.Args(); !all {
		t.Error("Args(): Expected true, got false")
	}

	// No arguments but at least one string
	if all := d.Args(&s1); all {
		t.Error("Args(): Expected false, got true")
	}
}
+
// TestDispenser_RemainingArgs checks full-line collection, stopping at
// an opening brace (without consuming it), and the empty result.
func TestDispenser_RemainingArgs(t *testing.T) {
	input := `dir1 arg1 arg2 arg3
			  dir2 arg4 arg5
			  dir3 arg6 { arg7
			  dir4`
	d := NewDispenser("Testfile", strings.NewReader(input))

	d.Next() // dir1

	args := d.RemainingArgs()
	if expected := []string{"arg1", "arg2", "arg3"}; !reflect.DeepEqual(args, expected) {
		t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args)
	}

	d.Next() // dir2

	args = d.RemainingArgs()
	if expected := []string{"arg4", "arg5"}; !reflect.DeepEqual(args, expected) {
		t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args)
	}

	d.Next() // dir3

	// Stops before the '{' and leaves it unconsumed.
	args = d.RemainingArgs()
	if expected := []string{"arg6"}; !reflect.DeepEqual(args, expected) {
		t.Errorf("RemainingArgs(): Expected %v, got %v", expected, args)
	}

	d.Next() // {
	d.Next() // arg7
	d.Next() // dir4

	args = d.RemainingArgs()
	if len(args) != 0 {
		t.Errorf("RemainingArgs(): Expected %v, got %v", []string{}, args)
	}
}
+
// TestDispenser_ArgErr_Err checks that ArgErr mentions the offending
// token and that Err carries filename:line plus the custom message.
func TestDispenser_ArgErr_Err(t *testing.T) {
	input := `dir1 {
			  }
			  dir2 arg1 arg2`
	d := NewDispenser("Testfile", strings.NewReader(input))

	d.cursor = 1 // {

	if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "{") {
		t.Errorf("ArgErr(): Expected an error message with { in it, but got '%v'", err)
	}

	d.cursor = 5 // arg2

	if err := d.ArgErr(); err == nil || !strings.Contains(err.Error(), "arg2") {
		t.Errorf("ArgErr(): Expected an error message with 'arg2' in it; got '%v'", err)
	}

	err := d.Err("foobar")
	if err == nil {
		t.Fatalf("Err(): Expected an error, got nil")
	}

	if !strings.Contains(err.Error(), "Testfile:3") {
		t.Errorf("Expected error message with filename:line in it; got '%v'", err)
	}

	if !strings.Contains(err.Error(), "foobar") {
		t.Errorf("Expected error message with custom message in it ('foobar'); got '%v'", err)
	}
}
diff --git a/core/parse/import_glob0.txt b/core/parse/import_glob0.txt
new file mode 100644
index 000000000..e610b5e7c
--- /dev/null
+++ b/core/parse/import_glob0.txt
@@ -0,0 +1,6 @@
+glob0.host0 {
+ dir2 arg1
+}
+
+glob0.host1 {
+}
diff --git a/core/parse/import_glob1.txt b/core/parse/import_glob1.txt
new file mode 100644
index 000000000..111eb044d
--- /dev/null
+++ b/core/parse/import_glob1.txt
@@ -0,0 +1,4 @@
+glob1.host0 {
+ dir1
+ dir2 arg1
+}
diff --git a/core/parse/import_glob2.txt b/core/parse/import_glob2.txt
new file mode 100644
index 000000000..c09f784ec
--- /dev/null
+++ b/core/parse/import_glob2.txt
@@ -0,0 +1,3 @@
+glob2.host0 {
+ dir2 arg1
+}
diff --git a/core/parse/import_test1.txt b/core/parse/import_test1.txt
new file mode 100644
index 000000000..dac7b29be
--- /dev/null
+++ b/core/parse/import_test1.txt
@@ -0,0 +1,2 @@
+dir2 arg1 arg2
+dir3 \ No newline at end of file
diff --git a/core/parse/import_test2.txt b/core/parse/import_test2.txt
new file mode 100644
index 000000000..140c87939
--- /dev/null
+++ b/core/parse/import_test2.txt
@@ -0,0 +1,4 @@
+host1 {
+ dir1
+ dir2 arg1
+} \ No newline at end of file
diff --git a/core/parse/lexer.go b/core/parse/lexer.go
new file mode 100644
index 000000000..d2939eba2
--- /dev/null
+++ b/core/parse/lexer.go
@@ -0,0 +1,122 @@
+package parse
+
+import (
+ "bufio"
+ "io"
+ "unicode"
+)
+
+type (
+ // lexer is a utility which can get values, token by
+ // token, from a Reader. A token is a word, and tokens
+ // are separated by whitespace. A word can be enclosed
+ // in quotes if it contains whitespace.
+ lexer struct {
+ reader *bufio.Reader // buffered source being scanned rune by rune
+ token token // most recent token produced by next()
+ line int // current line number, starting at 1 (see load)
+ }
+
+ // token represents a single parsable unit.
+ token struct {
+ file string // originating file name (set when imported; see doSingleImport)
+ line int // line on which the token begins
+ text string // the token's literal text, without enclosing quotes
+ }
+)
+
+// load prepares the lexer to scan an input for tokens.
+// The line counter starts at 1. It currently has no
+// failure modes and always returns a nil error.
+func (l *lexer) load(input io.Reader) error {
+ l.reader = bufio.NewReader(input)
+ l.line = 1
+ return nil
+}
+
+// next loads the next token into the lexer.
+// A token is delimited by whitespace, unless
+// the token starts with a quotes character (")
+// in which case the token goes until the closing
+// quotes (the enclosing quotes are not included).
+// Inside quoted strings, quotes may be escaped
+// with a preceding \ character. No other chars
+// may be escaped. The rest of the line is skipped
+// if a "#" character is read in. Returns true if
+// a token was loaded; false otherwise.
+func (l *lexer) next() bool {
+ var val []rune
+ var comment, quoted, escaped bool
+
+ // makeToken finalizes the accumulated runes as the current token's text.
+ // The token's line was recorded when its first rune was read.
+ makeToken := func() bool {
+ l.token.text = string(val)
+ return true
+ }
+
+ for {
+ ch, _, err := l.reader.ReadRune()
+ if err != nil {
+ // flush a partial token at end of input (even an unterminated quote)
+ if len(val) > 0 {
+ return makeToken()
+ }
+ if err == io.EOF {
+ return false
+ }
+ // any read error other than EOF is unexpected and aborts lexing
+ panic(err)
+ }
+
+ if quoted {
+ if !escaped {
+ if ch == '\\' {
+ escaped = true
+ continue
+ } else if ch == '"' {
+ quoted = false
+ return makeToken()
+ }
+ }
+ // newlines are allowed inside quotes but still advance the counter
+ if ch == '\n' {
+ l.line++
+ }
+ if escaped {
+ // only escape quotes; for any other char the backslash is literal
+ if ch != '"' {
+ val = append(val, '\\')
+ }
+ }
+ val = append(val, ch)
+ escaped = false
+ continue
+ }
+
+ if unicode.IsSpace(ch) {
+ // carriage returns are dropped so CRLF input behaves like LF
+ if ch == '\r' {
+ continue
+ }
+ if ch == '\n' {
+ l.line++
+ comment = false
+ }
+ if len(val) > 0 {
+ return makeToken()
+ }
+ continue
+ }
+
+ // '#' starts a comment even mid-token; runes read so far are kept
+ // and emitted as the token when the line ends
+ if ch == '#' {
+ comment = true
+ }
+
+ if comment {
+ continue
+ }
+
+ // first rune of a new token: record its starting line now
+ if len(val) == 0 {
+ l.token = token{line: l.line}
+ if ch == '"' {
+ quoted = true
+ continue
+ }
+ }
+
+ val = append(val, ch)
+ }
+}
diff --git a/core/parse/lexer_test.go b/core/parse/lexer_test.go
new file mode 100644
index 000000000..f12c7e7dc
--- /dev/null
+++ b/core/parse/lexer_test.go
@@ -0,0 +1,165 @@
+package parse
+
+import (
+ "strings"
+ "testing"
+)
+
+// lexerTestCase pairs an input string with the exact sequence of
+// tokens (text and starting line numbers) the lexer must produce.
+type lexerTestCase struct {
+ input string
+ expected []token
+}
+
+// TestLexer drives the lexer over a table of inputs covering whitespace
+// splitting, blank lines, braces, comments, quoting (including escaped
+// quotes, literal backslashes, embedded newlines, and empty strings),
+// CR stripping, and line-number tracking.
+func TestLexer(t *testing.T) {
+ testCases := []lexerTestCase{
+ {
+ input: `host:123`,
+ expected: []token{
+ {line: 1, text: "host:123"},
+ },
+ },
+ {
+ input: `host:123
+
+ directive`,
+ expected: []token{
+ {line: 1, text: "host:123"},
+ {line: 3, text: "directive"},
+ },
+ },
+ {
+ input: `host:123 {
+ directive
+ }`,
+ expected: []token{
+ {line: 1, text: "host:123"},
+ {line: 1, text: "{"},
+ {line: 2, text: "directive"},
+ {line: 3, text: "}"},
+ },
+ },
+ {
+ input: `host:123 { directive }`,
+ expected: []token{
+ {line: 1, text: "host:123"},
+ {line: 1, text: "{"},
+ {line: 1, text: "directive"},
+ {line: 1, text: "}"},
+ },
+ },
+ {
+ input: `host:123 {
+ #comment
+ directive
+ # comment
+ foobar # another comment
+ }`,
+ expected: []token{
+ {line: 1, text: "host:123"},
+ {line: 1, text: "{"},
+ {line: 3, text: "directive"},
+ {line: 5, text: "foobar"},
+ {line: 6, text: "}"},
+ },
+ },
+ {
+ input: `a "quoted value" b
+ foobar`,
+ expected: []token{
+ {line: 1, text: "a"},
+ {line: 1, text: "quoted value"},
+ {line: 1, text: "b"},
+ {line: 2, text: "foobar"},
+ },
+ },
+ {
+ input: `A "quoted \"value\" inside" B`,
+ expected: []token{
+ {line: 1, text: "A"},
+ {line: 1, text: `quoted "value" inside`},
+ {line: 1, text: "B"},
+ },
+ },
+ {
+ input: `"don't\escape"`,
+ expected: []token{
+ {line: 1, text: `don't\escape`},
+ },
+ },
+ {
+ input: `"don't\\escape"`,
+ expected: []token{
+ {line: 1, text: `don't\\escape`},
+ },
+ },
+ {
+ input: `A "quoted value with line
+ break inside" {
+ foobar
+ }`,
+ expected: []token{
+ {line: 1, text: "A"},
+ {line: 1, text: "quoted value with line\n\t\t\t\t\tbreak inside"},
+ {line: 2, text: "{"},
+ {line: 3, text: "foobar"},
+ {line: 4, text: "}"},
+ },
+ },
+ {
+ input: `"C:\php\php-cgi.exe"`,
+ expected: []token{
+ {line: 1, text: `C:\php\php-cgi.exe`},
+ },
+ },
+ {
+ input: `empty "" string`,
+ expected: []token{
+ {line: 1, text: `empty`},
+ {line: 1, text: ``},
+ {line: 1, text: `string`},
+ },
+ },
+ {
+ input: "skip those\r\nCR characters",
+ expected: []token{
+ {line: 1, text: "skip"},
+ {line: 1, text: "those"},
+ {line: 2, text: "CR"},
+ {line: 2, text: "characters"},
+ },
+ },
+ }
+
+ for i, testCase := range testCases {
+ actual := tokenize(testCase.input)
+ lexerCompare(t, i, testCase.expected, actual)
+ }
+}
+
+// tokenize is a test helper that lexes the whole input string and
+// returns every token in order.
+func tokenize(input string) (tokens []token) {
+ l := lexer{}
+ l.load(strings.NewReader(input))
+ for l.next() {
+ tokens = append(tokens, l.token)
+ }
+ return
+}
+
+// lexerCompare reports differences between the expected and actual token
+// sequences for test case n, stopping at the first mismatching token so a
+// single off-by-one does not flood the test output.
+func lexerCompare(t *testing.T, n int, expected, actual []token) {
+ if len(expected) != len(actual) {
+ t.Errorf("Test case %d: expected %d token(s) but got %d", n, len(expected), len(actual))
+ }
+
+ for i := 0; i < len(actual) && i < len(expected); i++ {
+ if actual[i].line != expected[i].line {
+ t.Errorf("Test case %d token %d ('%s'): expected line %d but was line %d",
+ n, i, expected[i].text, expected[i].line, actual[i].line)
+ break
+ }
+ if actual[i].text != expected[i].text {
+ t.Errorf("Test case %d token %d: expected text '%s' but was '%s'",
+ n, i, expected[i].text, actual[i].text)
+ break
+ }
+ }
+}
diff --git a/core/parse/parse.go b/core/parse/parse.go
new file mode 100644
index 000000000..faef36c28
--- /dev/null
+++ b/core/parse/parse.go
@@ -0,0 +1,32 @@
+// Package parse provides facilities for parsing configuration files.
+package parse
+
+import "io"
+
+// ServerBlocks parses the input just enough to organize tokens,
+// in order, by server block. No further parsing is performed.
+// If checkDirectives is true, only valid directives will be allowed
+// otherwise we consider it a parse error. Server blocks are returned
+// in the order in which they appear.
+// The filename is used only for error reporting (file:line positions).
+func ServerBlocks(filename string, input io.Reader, checkDirectives bool) ([]ServerBlock, error) {
+ p := parser{Dispenser: NewDispenser(filename, input)}
+ p.checkDirectives = checkDirectives
+ blocks, err := p.parseAll()
+ return blocks, err
+}
+
+// allTokens lexes the entire input, but does not parse it.
+// It returns all the tokens from the input, unstructured
+// and in order.
+func allTokens(input io.Reader) (tokens []token) {
+ l := new(lexer)
+ // load's error is ignored: it has no failure modes today
+ l.load(input)
+ for l.next() {
+ tokens = append(tokens, l.token)
+ }
+ return
+}
+
+// ValidDirectives is a set of directives that are valid (unordered). Populated
+// by config package's init function.
+// NOTE(review): package-level mutable map; presumably written only during
+// init and read-only while parsing — confirm no concurrent mutation.
+var ValidDirectives = make(map[string]struct{})
diff --git a/core/parse/parse_test.go b/core/parse/parse_test.go
new file mode 100644
index 000000000..48746300f
--- /dev/null
+++ b/core/parse/parse_test.go
@@ -0,0 +1,22 @@
+package parse
+
+import (
+ "strings"
+ "testing"
+)
+
+// TestAllTokens checks that allTokens returns every whitespace-separated
+// token of the input, in order, across multiple lines.
+func TestAllTokens(t *testing.T) {
+ input := strings.NewReader("a b c\nd e")
+ expected := []string{"a", "b", "c", "d", "e"}
+ tokens := allTokens(input)
+
+ if len(tokens) != len(expected) {
+ t.Fatalf("Expected %d tokens, got %d", len(expected), len(tokens))
+ }
+
+ for i, val := range expected {
+ if tokens[i].text != val {
+ t.Errorf("Token %d should be '%s' but was '%s'", i, val, tokens[i].text)
+ }
+ }
+}
diff --git a/core/parse/parsing.go b/core/parse/parsing.go
new file mode 100644
index 000000000..6e73bd584
--- /dev/null
+++ b/core/parse/parsing.go
@@ -0,0 +1,379 @@
+package parse
+
+import (
+ "net"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
+// parser assembles lexed tokens into ServerBlocks. It embeds a Dispenser
+// for token traversal and positioned error reporting.
+type parser struct {
+ Dispenser
+ block ServerBlock // current server block being parsed
+ eof bool // if we encounter a valid EOF in a hard place
+ checkDirectives bool // if true, directives must be known
+}
+
+// parseAll parses server blocks until the token stream is exhausted.
+// Blocks that end up with no addresses (e.g. an empty quoted token)
+// are dropped from the result. On error, the blocks parsed so far
+// are returned along with the error.
+func (p *parser) parseAll() ([]ServerBlock, error) {
+ var blocks []ServerBlock
+
+ for p.Next() {
+ err := p.parseOne()
+ if err != nil {
+ return blocks, err
+ }
+ if len(p.block.Addresses) > 0 {
+ blocks = append(blocks, p.block)
+ }
+ }
+
+ return blocks, nil
+}
+
+// parseOne parses a single server block into p.block. The caller must
+// already have advanced the dispenser to the block's first token
+// (parseOne does not call Next itself).
+func (p *parser) parseOne() error {
+ p.block = ServerBlock{Tokens: make(map[string][]token)}
+
+ err := p.begin()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// begin parses one server block: first its address line(s), then the
+// block contents. An empty token stream is not an error.
+func (p *parser) begin() error {
+ if len(p.tokens) == 0 {
+ return nil
+ }
+
+ err := p.addresses()
+ if err != nil {
+ return err
+ }
+
+ if p.eof {
+ // this happens if the Caddyfile consists of only
+ // a line of addresses and nothing else
+ return nil
+ }
+
+ err = p.blockContents()
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// addresses reads the comma-separated address list that begins a server
+// block, applying env-var substitution and parse-time import expansion.
+// It stops at an opening curly brace, at a newline when no trailing comma
+// promises another address, or at EOF (recorded in p.eof).
+func (p *parser) addresses() error {
+ var expectingAnother bool
+
+ for {
+ tkn := replaceEnvVars(p.Val())
+
+ // special case: import directive replaces tokens during parse-time
+ if tkn == "import" && p.isNewLine() {
+ err := p.doImport()
+ if err != nil {
+ return err
+ }
+ // after doImport the cursor sits on the first imported token
+ continue
+ }
+
+ // Open brace definitely indicates end of addresses
+ if tkn == "{" {
+ if expectingAnother {
+ return p.Errf("Expected another address but had '%s' - check for extra comma", tkn)
+ }
+ break
+ }
+
+ if tkn != "" { // empty token possible if user typed "" in Caddyfile
+ // Trailing comma indicates another address will follow, which
+ // may possibly be on the next line
+ if tkn[len(tkn)-1] == ',' {
+ tkn = tkn[:len(tkn)-1]
+ expectingAnother = true
+ } else {
+ expectingAnother = false // but we may still see another one on this line
+ }
+
+ // Parse and save this address
+ addr, err := standardAddress(tkn)
+ if err != nil {
+ return err
+ }
+ p.block.Addresses = append(p.block.Addresses, addr)
+ }
+
+ // Advance token and possibly break out of loop or return error
+ hasNext := p.Next()
+ if expectingAnother && !hasNext {
+ return p.EOFErr()
+ }
+ if !hasNext {
+ p.eof = true
+ break // EOF
+ }
+ if !expectingAnother && p.isNewLine() {
+ break
+ }
+ }
+
+ return nil
+}
+
+// blockContents parses everything inside a server block. Curly braces are
+// optional: when no opening brace is present the cursor is rolled back and
+// the directives are parsed unbraced, and no closing brace is required.
+func (p *parser) blockContents() error {
+ errOpenCurlyBrace := p.openCurlyBrace()
+ if errOpenCurlyBrace != nil {
+ // single-server configs don't need curly braces
+ p.cursor--
+ }
+
+ err := p.directives()
+ if err != nil {
+ return err
+ }
+
+ // Only look for close curly brace if there was an opening
+ if errOpenCurlyBrace == nil {
+ err = p.closeCurlyBrace()
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// directives parses through all the lines for directives
+// and it expects the next token to be the first
+// directive. It goes until EOF or closing curly brace
+// which ends the server block.
+func (p *parser) directives() error {
+ for p.Next() {
+ // end of server block
+ if p.Val() == "}" {
+ break
+ }
+
+ // special case: import directive replaces tokens during parse-time
+ if p.Val() == "import" {
+ err := p.doImport()
+ if err != nil {
+ return err
+ }
+ p.cursor-- // cursor is advanced when we continue, so roll back one more
+ continue
+ }
+
+ // normal case: parse a directive on this line
+ if err := p.directive(); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// doImport swaps out the import directive and its argument
+// (a total of 2 tokens) with the tokens in the specified file
+// or globbing pattern. When the function returns, the cursor
+// is on the token before where the import directive was. In
+// other words, call Next() to access the first token that was
+// imported.
+// NOTE(review): addresses() reads Val() directly after this returns
+// while directives() rewinds the cursor once more before its loop's
+// Next() — verify both call sites land on the first imported token.
+func (p *parser) doImport() error {
+ // syntax check
+ if !p.NextArg() {
+ return p.ArgErr()
+ }
+ importPattern := p.Val()
+ if p.NextArg() {
+ return p.Err("Import takes only one argument (glob pattern or file)")
+ }
+
+ // do glob
+ matches, err := filepath.Glob(importPattern)
+ if err != nil {
+ return p.Errf("Failed to use import pattern %s: %v", importPattern, err)
+ }
+ if len(matches) == 0 {
+ return p.Errf("No files matching import pattern %s", importPattern)
+ }
+
+ // splice out the import directive and its argument (2 tokens total);
+ // the cursor currently sits on the argument token
+ tokensBefore := p.tokens[:p.cursor-1]
+ tokensAfter := p.tokens[p.cursor+1:]
+
+ // collect all the imported tokens
+ var importedTokens []token
+ for _, importFile := range matches {
+ newTokens, err := p.doSingleImport(importFile)
+ if err != nil {
+ return err
+ }
+ importedTokens = append(importedTokens, newTokens...)
+ }
+
+ // splice the imported tokens in the place of the import statement
+ // and rewind cursor so Next() will land on first imported token
+ p.tokens = append(tokensBefore, append(importedTokens, tokensAfter...)...)
+ p.cursor--
+
+ return nil
+}
+
+// doSingleImport lexes the individual file at importFile and returns
+// its tokens or an error, if any.
+func (p *parser) doSingleImport(importFile string) ([]token, error) {
+ file, err := os.Open(importFile)
+ if err != nil {
+ return nil, p.Errf("Could not import %s: %v", importFile, err)
+ }
+ defer file.Close()
+ importedTokens := allTokens(file)
+
+ // Tack the filename onto these tokens so errors show the imported file's name
+ filename := filepath.Base(importFile)
+ for i := 0; i < len(importedTokens); i++ {
+ importedTokens[i].file = filename
+ }
+
+ return importedTokens, nil
+}
+
+// directive collects tokens until the directive's scope
+// closes (either end of line or end of curly brace block).
+// It expects the currently-loaded token to be a directive
+// (or } that ends a server block). The collected tokens
+// are loaded into the current server block for later use
+// by directive setup functions.
+func (p *parser) directive() error {
+ dir := p.Val()
+ nesting := 0
+
+ // reject unknown directives when validation is enabled
+ if p.checkDirectives {
+ if _, ok := ValidDirectives[dir]; !ok {
+ return p.Errf("Unknown directive '%s'", dir)
+ }
+ }
+
+ // The directive itself is appended as a relevant token
+ p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor])
+
+ for p.Next() {
+ if p.Val() == "{" {
+ nesting++
+ } else if p.isNewLine() && nesting == 0 {
+ p.cursor-- // read too far
+ break
+ } else if p.Val() == "}" && nesting > 0 {
+ nesting--
+ } else if p.Val() == "}" && nesting == 0 {
+ return p.Err("Unexpected '}' because no matching opening brace")
+ }
+ // env vars are substituted in-place on every argument token
+ p.tokens[p.cursor].text = replaceEnvVars(p.tokens[p.cursor].text)
+ p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor])
+ }
+
+ // EOF inside an unclosed brace block is an error
+ if nesting > 0 {
+ return p.EOFErr()
+ }
+ return nil
+}
+
+// openCurlyBrace expects the current token to be an
+// opening curly brace. This acts like an assertion
+// because it returns an error if the token is not
+// a opening curly brace. It does NOT advance the token.
+func (p *parser) openCurlyBrace() error {
+ if p.Val() != "{" {
+ return p.SyntaxErr("{")
+ }
+ return nil
+}
+
+// closeCurlyBrace expects the current token to be
+// a closing curly brace. This acts like an assertion
+// because it returns an error if the token is not
+// a closing curly brace. It does NOT advance the token.
+func (p *parser) closeCurlyBrace() error {
+ if p.Val() != "}" {
+ return p.SyntaxErr("}")
+ }
+ return nil
+}
+
+// standardAddress parses an address string into a structured format with separate
+// host, and port portions, as well as the original input string.
+// The host is lowercased and made fully qualified; the port defaults to
+// 53 (DNS) when none is given. A split error is returned together with
+// the best-effort address.
+// NOTE(review): no URL scheme handling happens here, yet the tests in
+// parsing_test.go exercise http://- and https://-prefixed inputs — confirm
+// whether scheme support was intended in this commit.
+func standardAddress(str string) (address, error) {
+ var err error
+
+ // keep the untouched input for the Original field
+ input := str
+
+ // separate host and port
+ host, port, err := net.SplitHostPort(str)
+ if err != nil {
+ // retry with a trailing colon so a bare host (no port) still splits
+ host, port, err = net.SplitHostPort(str + ":")
+ // no error check here; return err at end of function
+ }
+
+ // default to the DNS port when none was specified
+ if port == "" {
+ port = "53"
+ }
+
+ return address{Original: input, Host: strings.ToLower(dns.Fqdn(host)), Port: port}, err
+}
+
+// replaceEnvVars replaces environment variables that appear in the token
+// and understands both the $UNIX and %WINDOWS% syntaxes.
+// Windows-style {%VAR%} references are expanded before unix-style {$VAR}.
+func replaceEnvVars(s string) string {
+ s = replaceEnvReferences(s, "{%", "%}")
+ s = replaceEnvReferences(s, "{$", "}")
+ return s
+}
+
+// replaceEnvReferences performs the actual replacement of env variables
+// in s, given the placeholder start and placeholder end strings.
+// All occurrences of each discovered reference are replaced at once.
+// NOTE(review): refEnd is located from the start of s rather than from
+// index, so an occurrence of refEnd before refStart would make
+// endIndex < index and panic on the slice below — verify inputs or fix.
+func replaceEnvReferences(s, refStart, refEnd string) string {
+ index := strings.Index(s, refStart)
+ for index != -1 {
+ endIndex := strings.Index(s, refEnd)
+ if endIndex != -1 {
+ ref := s[index : endIndex+len(refEnd)]
+ s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1)
+ } else {
+ return s
+ }
+ index = strings.Index(s, refStart)
+ }
+ return s
+}
+
+type (
+ // ServerBlock associates tokens with a list of addresses
+ // and groups tokens by directive name.
+ ServerBlock struct {
+ Addresses []address // parsed addresses in order of appearance
+ Tokens map[string][]token // directive name -> its tokens (directive token included)
+ }
+
+ // address is one parsed server address: the exact input text,
+ // the lowercased FQDN host, and the port ("53" when unspecified).
+ // NOTE(review): parsing_test.go references an address.Scheme field and
+ // four-element address literals, but this struct has no Scheme — the
+ // package will not compile as-is; confirm which side is intended.
+ address struct {
+ Original, Host, Port string
+ }
+)
+
+// HostList converts the list of addresses that are
+// associated with this server block into a slice of
+// strings, where each address is as it was originally
+// read from the input.
+func (sb ServerBlock) HostList() []string {
+ sbHosts := make([]string, len(sb.Addresses))
+ for j, addr := range sb.Addresses {
+ sbHosts[j] = addr.Original
+ }
+ return sbHosts
+}
diff --git a/core/parse/parsing_test.go b/core/parse/parsing_test.go
new file mode 100644
index 000000000..493c0fff9
--- /dev/null
+++ b/core/parse/parsing_test.go
@@ -0,0 +1,477 @@
+package parse
+
+import (
+ "os"
+ "strings"
+ "testing"
+)
+
+// TestStandardAddress table-tests standardAddress over hosts, ports,
+// IPv6 literals, named ports, and scheme-prefixed inputs.
+// NOTE(review): the table's scheme column and actual.Scheme below assume
+// scheme support that the address struct in this commit does not have,
+// and several expectations (e.g. default port 80/443 rather than 53)
+// match the pre-fork HTTP behavior — confirm intended semantics.
+func TestStandardAddress(t *testing.T) {
+ for i, test := range []struct {
+ input string
+ scheme, host, port string
+ shouldErr bool
+ }{
+ {`localhost`, "", "localhost", "", false},
+ {`localhost:1234`, "", "localhost", "1234", false},
+ {`localhost:`, "", "localhost", "", false},
+ {`0.0.0.0`, "", "0.0.0.0", "", false},
+ {`127.0.0.1:1234`, "", "127.0.0.1", "1234", false},
+ {`:1234`, "", "", "1234", false},
+ {`[::1]`, "", "::1", "", false},
+ {`[::1]:1234`, "", "::1", "1234", false},
+ {`:`, "", "", "", false},
+ {`localhost:http`, "http", "localhost", "80", false},
+ {`localhost:https`, "https", "localhost", "443", false},
+ {`:http`, "http", "", "80", false},
+ {`:https`, "https", "", "443", false},
+ {`http://localhost:https`, "", "", "", true}, // conflict
+ {`http://localhost:http`, "", "", "", true}, // repeated scheme
+ {`http://localhost:443`, "", "", "", true}, // not conventional
+ {`https://localhost:80`, "", "", "", true}, // not conventional
+ {`http://localhost`, "http", "localhost", "80", false},
+ {`https://localhost`, "https", "localhost", "443", false},
+ {`http://127.0.0.1`, "http", "127.0.0.1", "80", false},
+ {`https://127.0.0.1`, "https", "127.0.0.1", "443", false},
+ {`http://[::1]`, "http", "::1", "80", false},
+ {`http://localhost:1234`, "http", "localhost", "1234", false},
+ {`https://127.0.0.1:1234`, "https", "127.0.0.1", "1234", false},
+ {`http://[::1]:1234`, "http", "::1", "1234", false},
+ {``, "", "", "", false},
+ {`::1`, "", "::1", "", true},
+ {`localhost::`, "", "localhost::", "", true},
+ {`#$%@`, "", "#$%@", "", true},
+ } {
+ actual, err := standardAddress(test.input)
+
+ if err != nil && !test.shouldErr {
+ t.Errorf("Test %d (%s): Expected no error, but had error: %v", i, test.input, err)
+ }
+ if err == nil && test.shouldErr {
+ t.Errorf("Test %d (%s): Expected error, but had none", i, test.input)
+ }
+
+ if actual.Scheme != test.scheme {
+ t.Errorf("Test %d (%s): Expected scheme '%s', got '%s'", i, test.input, test.scheme, actual.Scheme)
+ }
+ if actual.Host != test.host {
+ t.Errorf("Test %d (%s): Expected host '%s', got '%s'", i, test.input, test.host, actual.Host)
+ }
+ if actual.Port != test.port {
+ t.Errorf("Test %d (%s): Expected port '%s', got '%s'", i, test.input, test.port, actual.Port)
+ }
+ }
+}
+
+// TestParseOneAndImport table-tests parseOne: address lists (with commas,
+// schemes, and braces), nested directive blocks, unknown directives,
+// unbalanced braces, and parse-time import of single files; it checks the
+// resulting addresses and the token count collected per directive.
+// Imports resolve relative to the working directory (the package's
+// import_test*.txt fixtures).
+func TestParseOneAndImport(t *testing.T) {
+ setupParseTests()
+
+ testParseOne := func(input string) (ServerBlock, error) {
+ p := testParser(input)
+ p.Next() // parseOne doesn't call Next() to start, so we must
+ err := p.parseOne()
+ return p.block, err
+ }
+
+ for i, test := range []struct {
+ input string
+ shouldErr bool
+ addresses []address
+ tokens map[string]int // map of directive name to number of tokens expected
+ }{
+ {`localhost`, false, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{}},
+
+ {`localhost
+ dir1`, false, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{
+ "dir1": 1,
+ }},
+
+ {`localhost:1234
+ dir1 foo bar`, false, []address{
+ {"localhost:1234", "", "localhost", "1234"},
+ }, map[string]int{
+ "dir1": 3,
+ }},
+
+ {`localhost {
+ dir1
+ }`, false, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{
+ "dir1": 1,
+ }},
+
+ {`localhost:1234 {
+ dir1 foo bar
+ dir2
+ }`, false, []address{
+ {"localhost:1234", "", "localhost", "1234"},
+ }, map[string]int{
+ "dir1": 3,
+ "dir2": 1,
+ }},
+
+ {`http://localhost https://localhost
+ dir1 foo bar`, false, []address{
+ {"http://localhost", "http", "localhost", "80"},
+ {"https://localhost", "https", "localhost", "443"},
+ }, map[string]int{
+ "dir1": 3,
+ }},
+
+ {`http://localhost https://localhost {
+ dir1 foo bar
+ }`, false, []address{
+ {"http://localhost", "http", "localhost", "80"},
+ {"https://localhost", "https", "localhost", "443"},
+ }, map[string]int{
+ "dir1": 3,
+ }},
+
+ {`http://localhost, https://localhost {
+ dir1 foo bar
+ }`, false, []address{
+ {"http://localhost", "http", "localhost", "80"},
+ {"https://localhost", "https", "localhost", "443"},
+ }, map[string]int{
+ "dir1": 3,
+ }},
+
+ {`http://localhost, {
+ }`, true, []address{
+ {"http://localhost", "http", "localhost", "80"},
+ }, map[string]int{}},
+
+ {`host1:80, http://host2.com
+ dir1 foo bar
+ dir2 baz`, false, []address{
+ {"host1:80", "", "host1", "80"},
+ {"http://host2.com", "http", "host2.com", "80"},
+ }, map[string]int{
+ "dir1": 3,
+ "dir2": 2,
+ }},
+
+ {`http://host1.com,
+ http://host2.com,
+ https://host3.com`, false, []address{
+ {"http://host1.com", "http", "host1.com", "80"},
+ {"http://host2.com", "http", "host2.com", "80"},
+ {"https://host3.com", "https", "host3.com", "443"},
+ }, map[string]int{}},
+
+ {`http://host1.com:1234, https://host2.com
+ dir1 foo {
+ bar baz
+ }
+ dir2`, false, []address{
+ {"http://host1.com:1234", "http", "host1.com", "1234"},
+ {"https://host2.com", "https", "host2.com", "443"},
+ }, map[string]int{
+ "dir1": 6,
+ "dir2": 1,
+ }},
+
+ {`127.0.0.1
+ dir1 {
+ bar baz
+ }
+ dir2 {
+ foo bar
+ }`, false, []address{
+ {"127.0.0.1", "", "127.0.0.1", ""},
+ }, map[string]int{
+ "dir1": 5,
+ "dir2": 5,
+ }},
+
+ {`127.0.0.1
+ unknown_directive`, true, []address{
+ {"127.0.0.1", "", "127.0.0.1", ""},
+ }, map[string]int{}},
+
+ {`localhost
+ dir1 {
+ foo`, true, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{
+ "dir1": 3,
+ }},
+
+ {`localhost
+ dir1 {
+ }`, false, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{
+ "dir1": 3,
+ }},
+
+ {`localhost
+ dir1 {
+ } }`, true, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{
+ "dir1": 3,
+ }},
+
+ {`localhost
+ dir1 {
+ nested {
+ foo
+ }
+ }
+ dir2 foo bar`, false, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{
+ "dir1": 7,
+ "dir2": 3,
+ }},
+
+ {``, false, []address{}, map[string]int{}},
+
+ {`localhost
+ dir1 arg1
+ import import_test1.txt`, false, []address{
+ {"localhost", "", "localhost", ""},
+ }, map[string]int{
+ "dir1": 2,
+ "dir2": 3,
+ "dir3": 1,
+ }},
+
+ {`import import_test2.txt`, false, []address{
+ {"host1", "", "host1", ""},
+ }, map[string]int{
+ "dir1": 1,
+ "dir2": 2,
+ }},
+
+ {`import import_test1.txt import_test2.txt`, true, []address{}, map[string]int{}},
+
+ {`import not_found.txt`, true, []address{}, map[string]int{}},
+
+ {`""`, false, []address{}, map[string]int{}},
+
+ {``, false, []address{}, map[string]int{}},
+ } {
+ result, err := testParseOne(test.input)
+
+ if test.shouldErr && err == nil {
+ t.Errorf("Test %d: Expected an error, but didn't get one", i)
+ }
+ if !test.shouldErr && err != nil {
+ t.Errorf("Test %d: Expected no error, but got: %v", i, err)
+ }
+
+ if len(result.Addresses) != len(test.addresses) {
+ t.Errorf("Test %d: Expected %d addresses, got %d",
+ i, len(test.addresses), len(result.Addresses))
+ continue
+ }
+ for j, addr := range result.Addresses {
+ if addr.Host != test.addresses[j].Host {
+ t.Errorf("Test %d, address %d: Expected host to be '%s', but was '%s'",
+ i, j, test.addresses[j].Host, addr.Host)
+ }
+ if addr.Port != test.addresses[j].Port {
+ t.Errorf("Test %d, address %d: Expected port to be '%s', but was '%s'",
+ i, j, test.addresses[j].Port, addr.Port)
+ }
+ }
+
+ if len(result.Tokens) != len(test.tokens) {
+ t.Errorf("Test %d: Expected %d directives, had %d",
+ i, len(test.tokens), len(result.Tokens))
+ continue
+ }
+ for directive, tokens := range result.Tokens {
+ if len(tokens) != test.tokens[directive] {
+ t.Errorf("Test %d, directive '%s': Expected %d tokens, counted %d",
+ i, directive, test.tokens[directive], len(tokens))
+ continue
+ }
+ }
+ }
+}
+
+// TestParseAll table-tests parseAll: multiple server blocks, shared
+// (comma-separated) address lists, a trailing-comma error case, and
+// glob-based import that expands to several blocks. Only the addresses
+// per block are verified here.
+func TestParseAll(t *testing.T) {
+ setupParseTests()
+
+ for i, test := range []struct {
+ input string
+ shouldErr bool
+ addresses [][]address // addresses per server block, in order
+ }{
+ {`localhost`, false, [][]address{
+ {{"localhost", "", "localhost", ""}},
+ }},
+
+ {`localhost:1234`, false, [][]address{
+ {{"localhost:1234", "", "localhost", "1234"}},
+ }},
+
+ {`localhost:1234 {
+ }
+ localhost:2015 {
+ }`, false, [][]address{
+ {{"localhost:1234", "", "localhost", "1234"}},
+ {{"localhost:2015", "", "localhost", "2015"}},
+ }},
+
+ {`localhost:1234, http://host2`, false, [][]address{
+ {{"localhost:1234", "", "localhost", "1234"}, {"http://host2", "http", "host2", "80"}},
+ }},
+
+ {`localhost:1234, http://host2,`, true, [][]address{}},
+
+ {`http://host1.com, http://host2.com {
+ }
+ https://host3.com, https://host4.com {
+ }`, false, [][]address{
+ {{"http://host1.com", "http", "host1.com", "80"}, {"http://host2.com", "http", "host2.com", "80"}},
+ {{"https://host3.com", "https", "host3.com", "443"}, {"https://host4.com", "https", "host4.com", "443"}},
+ }},
+
+ {`import import_glob*.txt`, false, [][]address{
+ {{"glob0.host0", "", "glob0.host0", ""}},
+ {{"glob0.host1", "", "glob0.host1", ""}},
+ {{"glob1.host0", "", "glob1.host0", ""}},
+ {{"glob2.host0", "", "glob2.host0", ""}},
+ }},
+ } {
+ p := testParser(test.input)
+ blocks, err := p.parseAll()
+
+ if test.shouldErr && err == nil {
+ t.Errorf("Test %d: Expected an error, but didn't get one", i)
+ }
+ if !test.shouldErr && err != nil {
+ t.Errorf("Test %d: Expected no error, but got: %v", i, err)
+ }
+
+ if len(blocks) != len(test.addresses) {
+ t.Errorf("Test %d: Expected %d server blocks, got %d",
+ i, len(test.addresses), len(blocks))
+ continue
+ }
+ for j, block := range blocks {
+ if len(block.Addresses) != len(test.addresses[j]) {
+ t.Errorf("Test %d: Expected %d addresses in block %d, got %d",
+ i, len(test.addresses[j]), j, len(block.Addresses))
+ continue
+ }
+ for k, addr := range block.Addresses {
+ if addr.Host != test.addresses[j][k].Host {
+ t.Errorf("Test %d, block %d, address %d: Expected host to be '%s', but was '%s'",
+ i, j, k, test.addresses[j][k].Host, addr.Host)
+ }
+ if addr.Port != test.addresses[j][k].Port {
+ t.Errorf("Test %d, block %d, address %d: Expected port to be '%s', but was '%s'",
+ i, j, k, test.addresses[j][k].Port, addr.Port)
+ }
+ }
+ }
+ }
+}
+
+// TestEnvironmentReplacement checks {$UNIX} and {%WINDOWS%} env-var
+// substitution in addresses, directive arguments, and quoted tokens,
+// including mixed styles in one token and malformed references that
+// must be left alone (or expand to empty for unknown variables).
+// It mutates the process environment, so it must not run in parallel.
+func TestEnvironmentReplacement(t *testing.T) {
+ setupParseTests()
+
+ os.Setenv("PORT", "8080")
+ os.Setenv("ADDRESS", "servername.com")
+ os.Setenv("FOOBAR", "foobar")
+
+ // basic test; unix-style env vars
+ p := testParser(`{$ADDRESS}`)
+ blocks, _ := p.parseAll()
+ if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual {
+ t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
+ }
+
+ // multiple vars per token
+ p = testParser(`{$ADDRESS}:{$PORT}`)
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual {
+ t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
+ }
+ if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
+ t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
+ }
+
+ // windows-style var and unix style in same token
+ p = testParser(`{%ADDRESS%}:{$PORT}`)
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual {
+ t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
+ }
+ if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
+ t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
+ }
+
+ // reverse order
+ p = testParser(`{$ADDRESS}:{%PORT%}`)
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Addresses[0].Host, "servername.com"; expected != actual {
+ t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
+ }
+ if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
+ t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
+ }
+
+ // env var in server block body as argument
+ p = testParser(":{%PORT%}\ndir1 {$FOOBAR}")
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Addresses[0].Port, "8080"; expected != actual {
+ t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
+ }
+ if actual, expected := blocks[0].Tokens["dir1"][1].text, "foobar"; expected != actual {
+ t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
+ }
+
+ // combined windows env vars in argument
+ p = testParser(":{%PORT%}\ndir1 {%ADDRESS%}/{%FOOBAR%}")
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Tokens["dir1"][1].text, "servername.com/foobar"; expected != actual {
+ t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
+ }
+
+ // malformed env var (windows)
+ p = testParser(":1234\ndir1 {%ADDRESS}")
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Tokens["dir1"][1].text, "{%ADDRESS}"; expected != actual {
+ t.Errorf("Expected host to be '%s' but was '%s'", expected, actual)
+ }
+
+ // malformed (non-existent) env var (unix)
+ p = testParser(`:{$PORT$}`)
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Addresses[0].Port, ""; expected != actual {
+ t.Errorf("Expected port to be '%s' but was '%s'", expected, actual)
+ }
+
+ // in quoted field
+ p = testParser(":1234\ndir1 \"Test {$FOOBAR} test\"")
+ blocks, _ = p.parseAll()
+ if actual, expected := blocks[0].Tokens["dir1"][1].text, "Test foobar test"; expected != actual {
+ t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
+ }
+}
+
+// setupParseTests replaces the package-level ValidDirectives set with
+// bogus directives used by the parser tests. It overwrites any prior
+// contents, so tests relying on real directives must not run after it.
+func setupParseTests() {
+ // Set up some bogus directives for testing
+ ValidDirectives = map[string]struct{}{
+ "dir1": {},
+ "dir2": {},
+ "dir3": {},
+ }
+}
+
+// testParser builds a parser over the given input string with directive
+// checking enabled, using "Test" as the filename for error messages.
+func testParser(input string) parser {
+ buf := strings.NewReader(input)
+ p := parser{Dispenser: NewDispenser("Test", buf), checkDirectives: true}
+ return p
+}
diff --git a/core/restart.go b/core/restart.go
new file mode 100644
index 000000000..82567d35c
--- /dev/null
+++ b/core/restart.go
@@ -0,0 +1,166 @@
+// +build !windows
+
+package core
+
+import (
+ "bytes"
+ "encoding/gob"
+ "errors"
+ "io/ioutil"
+ "log"
+ "net"
+ "os"
+ "os/exec"
+ "path"
+ "sync/atomic"
+
+ "github.com/miekg/coredns/core/https"
+)
+
+func init() {
+ gob.Register(CaddyfileInput{})
+}
+
+// Restart restarts the entire application; gracefully with zero
+// downtime if on a POSIX-compatible system, or forcefully if on
+// Windows but with imperceptibly-short downtime.
+//
+// The restarted application will use newCaddyfile as its input
+// configuration. If newCaddyfile is nil, the current (existing)
+// Caddyfile configuration will be used.
+//
+// Note: The process must exist in the same place on the disk in
+// order for this to work. Thus, multiple graceful restarts don't
+// work if executing with `go run`, since the binary is cleaned up
+// when `go run` sees the initial parent process exit.
+func Restart(newCaddyfile Input) error {
+ log.Println("[INFO] Restarting")
+
+ if newCaddyfile == nil {
+ caddyfileMu.Lock()
+ newCaddyfile = caddyfile
+ caddyfileMu.Unlock()
+ }
+
+ // Get certificates for any new hosts in the new Caddyfile without causing downtime
+ err := getCertsForNewCaddyfile(newCaddyfile)
+ if err != nil {
+ return errors.New("TLS preload: " + err.Error())
+ }
+
+ if len(os.Args) == 0 { // this should never happen, but...
+ os.Args = []string{""}
+ }
+
+ // Tell the child that it's a restart
+ os.Setenv("CADDY_RESTART", "true")
+
+ // Prepare our payload to the child process
+ cdyfileGob := caddyfileGob{
+ ListenerFds: make(map[string]uintptr),
+ Caddyfile: newCaddyfile,
+ OnDemandTLSCertsIssued: atomic.LoadInt32(https.OnDemandIssuedCount),
+ }
+
+ // Prepare a pipe to the fork's stdin so it can get the Caddyfile
+ rpipe, wpipe, err := os.Pipe()
+ if err != nil {
+ return err
+ }
+
+ // Prepare a pipe that the child process will use to communicate
+ // its success with us by sending > 0 bytes
+ sigrpipe, sigwpipe, err := os.Pipe()
+ if err != nil {
+ return err
+ }
+
+ // Pass along relevant file descriptors to child process; ordering
+ // is very important since we rely on these being in certain positions.
+ extraFiles := []*os.File{sigwpipe} // fd 3
+
+ // Add file descriptors of all the sockets
+ serversMu.Lock()
+ for i, s := range servers {
+ extraFiles = append(extraFiles, s.ListenerFd())
+ cdyfileGob.ListenerFds[s.Addr] = uintptr(4 + i) // 4 fds come before any of the listeners
+ }
+ serversMu.Unlock()
+
+ // Set up the command
+ cmd := exec.Command(os.Args[0], os.Args[1:]...)
+ cmd.Stdin = rpipe // fd 0
+ cmd.Stdout = os.Stdout // fd 1
+ cmd.Stderr = os.Stderr // fd 2
+ cmd.ExtraFiles = extraFiles
+
+ // Spawn the child process
+ err = cmd.Start()
+ if err != nil {
+ return err
+ }
+
+ // Immediately close our dup'ed fds and the write end of our signal pipe
+ for _, f := range extraFiles {
+ f.Close()
+ }
+
+ // Feed Caddyfile to the child
+ err = gob.NewEncoder(wpipe).Encode(cdyfileGob)
+ if err != nil {
+ return err
+ }
+ wpipe.Close()
+
+ // Determine whether child startup succeeded
+ answer, readErr := ioutil.ReadAll(sigrpipe)
+ if answer == nil || len(answer) == 0 {
+ cmdErr := cmd.Wait() // get exit status
+ log.Printf("[ERROR] Restart: child failed to initialize (%v) - changes not applied", cmdErr)
+ if readErr != nil {
+ log.Printf("[ERROR] Restart: additionally, error communicating with child process: %v", readErr)
+ }
+ return errIncompleteRestart
+ }
+
+ // Looks like child is successful; we can exit gracefully.
+ return Stop()
+}
+
+func getCertsForNewCaddyfile(newCaddyfile Input) error {
+ // parse the new caddyfile only up to (and including) TLS
+ // so we can know what we need to get certs for.
+ configs, _, _, err := loadConfigsUpToIncludingTLS(path.Base(newCaddyfile.Path()), bytes.NewReader(newCaddyfile.Body()))
+ if err != nil {
+ return errors.New("loading Caddyfile: " + err.Error())
+ }
+
+ // first mark the configs that are qualified for managed TLS
+ https.MarkQualified(configs)
+
+ // since we group by bind address to obtain certs, we must call
+ // EnableTLS to make sure the port is set properly first
+ // (can ignore error since we aren't actually using the certs)
+ https.EnableTLS(configs, false)
+
+ // find out if we can let the acme package start its own challenge listener
+ // on port 80
+ var proxyACME bool
+ serversMu.Lock()
+ for _, s := range servers {
+ _, port, _ := net.SplitHostPort(s.Addr)
+ if port == "80" {
+ proxyACME = true
+ break
+ }
+ }
+ serversMu.Unlock()
+
+ // place certs on the disk
+ err = https.ObtainCerts(configs, false, proxyACME)
+ if err != nil {
+ return errors.New("obtaining certs: " + err.Error())
+ }
+
+ return nil
+}
diff --git a/core/restart_windows.go b/core/restart_windows.go
new file mode 100644
index 000000000..c2a4f557a
--- /dev/null
+++ b/core/restart_windows.go
@@ -0,0 +1,31 @@
+package core
+
+import "log"
+
+// Restart restarts Caddy forcefully using newCaddyfile,
+// or, if nil, the current/existing Caddyfile is reused.
+func Restart(newCaddyfile Input) error {
+ log.Println("[INFO] Restarting")
+
+ if newCaddyfile == nil {
+ caddyfileMu.Lock()
+ newCaddyfile = caddyfile
+ caddyfileMu.Unlock()
+ }
+
+ wg.Add(1) // barrier so Wait() doesn't unblock
+
+ err := Stop()
+ if err != nil {
+ return err
+ }
+
+ err = Start(newCaddyfile)
+ if err != nil {
+ return err
+ }
+
+ wg.Done() // take down our barrier
+
+ return nil
+}
diff --git a/core/setup/bindhost.go b/core/setup/bindhost.go
new file mode 100644
index 000000000..a3c07e5eb
--- /dev/null
+++ b/core/setup/bindhost.go
@@ -0,0 +1,13 @@
+package setup
+
+import "github.com/miekg/coredns/middleware"
+
+// BindHost sets the host to bind the listener to.
+func BindHost(c *Controller) (middleware.Middleware, error) {
+ for c.Next() {
+ if !c.Args(&c.BindHost) {
+ return nil, c.ArgErr()
+ }
+ }
+ return nil, nil
+}
diff --git a/core/setup/controller.go b/core/setup/controller.go
new file mode 100644
index 000000000..1c1a93e64
--- /dev/null
+++ b/core/setup/controller.go
@@ -0,0 +1,83 @@
+package setup
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/miekg/coredns/core/parse"
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/server"
+ "github.com/miekg/dns"
+)
+
+// Controller is given to the setup function of middlewares which
+// gives them access to be able to read tokens and set config. Each
+// virtualhost gets their own server config and dispenser.
+type Controller struct {
+ *server.Config
+ parse.Dispenser
+
+ // OncePerServerBlock is a function that executes f
+ // exactly once per server block, no matter how many
+ // hosts are associated with it. If it is the first
+ // time, the function f is executed immediately
+ // (not deferred) and may return an error which is
+ // returned by OncePerServerBlock.
+ OncePerServerBlock func(f func() error) error
+
+ // ServerBlockIndex is the 0-based index of the
+ // server block as it appeared in the input.
+ ServerBlockIndex int
+
+ // ServerBlockHostIndex is the 0-based index of this
+ // host as it appeared in the input at the head of the
+ // server block.
+ ServerBlockHostIndex int
+
+ // ServerBlockHosts is a list of hosts that are
+ // associated with this server block. All these
+ // hosts, consequently, share the same tokens.
+ ServerBlockHosts []string
+
+ // ServerBlockStorage is used by a directive's
+ // setup function to persist state between all
+ // the hosts on a server block.
+ ServerBlockStorage interface{}
+}
+
+// NewTestController creates a new *Controller for
+// the input specified, with a filename of "Testfile".
+// The Config is bare, consisting only of a Root of cwd.
+//
+// Used primarily for testing but needs to be exported so
+// add-ons can use this as a convenience. Does not initialize
+// the server-block-related fields.
+func NewTestController(input string) *Controller {
+ return &Controller{
+ Config: &server.Config{
+ Root: ".",
+ },
+ Dispenser: parse.NewDispenser("Testfile", strings.NewReader(input)),
+ OncePerServerBlock: func(f func() error) error {
+ return f()
+ },
+ }
+}
+
+// EmptyNext is a no-op function that can be passed into
+// middleware.Middleware functions so that the assignment
+// to the Next field of the Handler can be tested.
+//
+// Used primarily for testing but needs to be exported so
+// add-ons can use this as a convenience.
+var EmptyNext = middleware.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ return 0, nil
+})
+
+// SameNext does a pointer comparison between next1 and next2.
+//
+// Used primarily for testing but needs to be exported so
+// add-ons can use this as a convenience.
+func SameNext(next1, next2 middleware.Handler) bool {
+ return fmt.Sprintf("%v", next1) == fmt.Sprintf("%v", next2)
+}
diff --git a/core/setup/errors.go b/core/setup/errors.go
new file mode 100644
index 000000000..0b392ec99
--- /dev/null
+++ b/core/setup/errors.go
@@ -0,0 +1,132 @@
+package setup
+
+import (
+ "io"
+ "log"
+ "os"
+
+ "github.com/hashicorp/go-syslog"
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/errors"
+)
+
+// Errors configures a new errors middleware instance.
+func Errors(c *Controller) (middleware.Middleware, error) {
+ handler, err := errorsParse(c)
+ if err != nil {
+ return nil, err
+ }
+
+ // Open the log file for writing when the server starts
+ c.Startup = append(c.Startup, func() error {
+ var err error
+ var writer io.Writer
+
+ switch handler.LogFile {
+ case "visible":
+ handler.Debug = true
+ case "stdout":
+ writer = os.Stdout
+ case "stderr":
+ writer = os.Stderr
+ case "syslog":
+ writer, err = gsyslog.NewLogger(gsyslog.LOG_ERR, "LOCAL0", "caddy")
+ if err != nil {
+ return err
+ }
+ default:
+ if handler.LogFile == "" {
+ writer = os.Stderr // default
+ break
+ }
+
+ var file *os.File
+ file, err = os.OpenFile(handler.LogFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644)
+ if err != nil {
+ return err
+ }
+ if handler.LogRoller != nil {
+ file.Close()
+
+ handler.LogRoller.Filename = handler.LogFile
+
+ writer = handler.LogRoller.GetLogWriter()
+ } else {
+ writer = file
+ }
+ }
+
+ handler.Log = log.New(writer, "", 0)
+ return nil
+ })
+
+ return func(next middleware.Handler) middleware.Handler {
+ handler.Next = next
+ return handler
+ }, nil
+}
+
+func errorsParse(c *Controller) (*errors.ErrorHandler, error) {
+ // Very important that we make a pointer because the Startup
+ // function that opens the log file must have access to the
+ // same instance of the handler, not a copy.
+ handler := &errors.ErrorHandler{}
+
+ optionalBlock := func() (bool, error) {
+ var hadBlock bool
+
+ for c.NextBlock() {
+ hadBlock = true
+
+ what := c.Val()
+ if !c.NextArg() {
+ return hadBlock, c.ArgErr()
+ }
+ where := c.Val()
+
+ if what == "log" {
+ if where == "visible" {
+ handler.Debug = true
+ } else {
+ handler.LogFile = where
+ if c.NextArg() {
+ if c.Val() == "{" {
+ c.IncrNest()
+ logRoller, err := parseRoller(c)
+ if err != nil {
+ return hadBlock, err
+ }
+ handler.LogRoller = logRoller
+ }
+ }
+ }
+ }
+ }
+ return hadBlock, nil
+ }
+
+ for c.Next() {
+ // weird hack to avoid having the handler values overwritten.
+ if c.Val() == "}" {
+ continue
+ }
+ // Configuration may be in a block
+ hadBlock, err := optionalBlock()
+ if err != nil {
+ return handler, err
+ }
+
+ // Otherwise, the only argument would be an error log file name or 'visible'
+ if !hadBlock {
+ if c.NextArg() {
+ if c.Val() == "visible" {
+ handler.Debug = true
+ } else {
+ handler.LogFile = c.Val()
+ }
+ }
+ }
+ }
+
+ return handler, nil
+}
diff --git a/core/setup/errors_test.go b/core/setup/errors_test.go
new file mode 100644
index 000000000..4a079e0b5
--- /dev/null
+++ b/core/setup/errors_test.go
@@ -0,0 +1,158 @@
+package setup
+
+import (
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/errors"
+)
+
+func TestErrors(t *testing.T) {
+ c := NewTestController(`errors`)
+ mid, err := Errors(c)
+
+ if err != nil {
+ t.Errorf("Expected no errors, got: %v", err)
+ }
+
+ if mid == nil {
+ t.Fatal("Expected middleware, was nil instead")
+ }
+
+ handler := mid(EmptyNext)
+ myHandler, ok := handler.(*errors.ErrorHandler)
+
+ if !ok {
+ t.Fatalf("Expected handler to be type ErrorHandler, got: %#v", handler)
+ }
+
+ if myHandler.LogFile != "" {
+ t.Errorf("Expected '%s' as the default LogFile", "")
+ }
+ if myHandler.LogRoller != nil {
+ t.Errorf("Expected LogRoller to be nil, got: %v", *myHandler.LogRoller)
+ }
+ if !SameNext(myHandler.Next, EmptyNext) {
+ t.Error("'Next' field of handler was not set properly")
+ }
+
+ // Test Startup function
+ if len(c.Startup) == 0 {
+ t.Fatal("Expected 1 startup function, had 0")
+ }
+ err = c.Startup[0]()
+ if myHandler.Log == nil {
+ t.Error("Expected Log to be non-nil after startup because Debug is not enabled")
+ }
+}
+
+func TestErrorsParse(t *testing.T) {
+ tests := []struct {
+ inputErrorsRules string
+ shouldErr bool
+ expectedErrorHandler errors.ErrorHandler
+ }{
+ {`errors`, false, errors.ErrorHandler{
+ LogFile: "",
+ }},
+ {`errors errors.txt`, false, errors.ErrorHandler{
+ LogFile: "errors.txt",
+ }},
+ {`errors visible`, false, errors.ErrorHandler{
+ LogFile: "",
+ Debug: true,
+ }},
+ {`errors { log visible }`, false, errors.ErrorHandler{
+ LogFile: "",
+ Debug: true,
+ }},
+ {`errors { log errors.txt
+ 404 404.html
+ 500 500.html
+}`, false, errors.ErrorHandler{
+ LogFile: "errors.txt",
+ ErrorPages: map[int]string{
+ 404: "404.html",
+ 500: "500.html",
+ },
+ }},
+ {`errors { log errors.txt { size 2 age 10 keep 3 } }`, false, errors.ErrorHandler{
+ LogFile: "errors.txt",
+ LogRoller: &middleware.LogRoller{
+ MaxSize: 2,
+ MaxAge: 10,
+ MaxBackups: 3,
+ LocalTime: true,
+ },
+ }},
+ {`errors { log errors.txt {
+ size 3
+ age 11
+ keep 5
+ }
+ 404 404.html
+ 503 503.html
+}`, false, errors.ErrorHandler{
+ LogFile: "errors.txt",
+ ErrorPages: map[int]string{
+ 404: "404.html",
+ 503: "503.html",
+ },
+ LogRoller: &middleware.LogRoller{
+ MaxSize: 3,
+ MaxAge: 11,
+ MaxBackups: 5,
+ LocalTime: true,
+ },
+ }},
+ }
+ for i, test := range tests {
+ c := NewTestController(test.inputErrorsRules)
+ actualErrorsRule, err := errorsParse(c)
+
+ if err == nil && test.shouldErr {
+ t.Errorf("Test %d didn't error, but it should have", i)
+ } else if err != nil && !test.shouldErr {
+ t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
+ }
+ if actualErrorsRule.LogFile != test.expectedErrorHandler.LogFile {
+ t.Errorf("Test %d expected LogFile to be %s, but got %s",
+ i, test.expectedErrorHandler.LogFile, actualErrorsRule.LogFile)
+ }
+ if actualErrorsRule.Debug != test.expectedErrorHandler.Debug {
+ t.Errorf("Test %d expected Debug to be %v, but got %v",
+ i, test.expectedErrorHandler.Debug, actualErrorsRule.Debug)
+ }
+ if actualErrorsRule.LogRoller != nil && test.expectedErrorHandler.LogRoller == nil || actualErrorsRule.LogRoller == nil && test.expectedErrorHandler.LogRoller != nil {
+ t.Fatalf("Test %d expected LogRoller to be %v, but got %v",
+ i, test.expectedErrorHandler.LogRoller, actualErrorsRule.LogRoller)
+ }
+ if len(actualErrorsRule.ErrorPages) != len(test.expectedErrorHandler.ErrorPages) {
+ t.Fatalf("Test %d expected %d no of Error pages, but got %d ",
+ i, len(test.expectedErrorHandler.ErrorPages), len(actualErrorsRule.ErrorPages))
+ }
+ if actualErrorsRule.LogRoller != nil && test.expectedErrorHandler.LogRoller != nil {
+ if actualErrorsRule.LogRoller.Filename != test.expectedErrorHandler.LogRoller.Filename {
+ t.Fatalf("Test %d expected LogRoller Filename to be %s, but got %s",
+ i, test.expectedErrorHandler.LogRoller.Filename, actualErrorsRule.LogRoller.Filename)
+ }
+ if actualErrorsRule.LogRoller.MaxAge != test.expectedErrorHandler.LogRoller.MaxAge {
+ t.Fatalf("Test %d expected LogRoller MaxAge to be %d, but got %d",
+ i, test.expectedErrorHandler.LogRoller.MaxAge, actualErrorsRule.LogRoller.MaxAge)
+ }
+ if actualErrorsRule.LogRoller.MaxBackups != test.expectedErrorHandler.LogRoller.MaxBackups {
+ t.Fatalf("Test %d expected LogRoller MaxBackups to be %d, but got %d",
+ i, test.expectedErrorHandler.LogRoller.MaxBackups, actualErrorsRule.LogRoller.MaxBackups)
+ }
+ if actualErrorsRule.LogRoller.MaxSize != test.expectedErrorHandler.LogRoller.MaxSize {
+ t.Fatalf("Test %d expected LogRoller MaxSize to be %d, but got %d",
+ i, test.expectedErrorHandler.LogRoller.MaxSize, actualErrorsRule.LogRoller.MaxSize)
+ }
+ if actualErrorsRule.LogRoller.LocalTime != test.expectedErrorHandler.LogRoller.LocalTime {
+ t.Fatalf("Test %d expected LogRoller LocalTime to be %t, but got %t",
+ i, test.expectedErrorHandler.LogRoller.LocalTime, actualErrorsRule.LogRoller.LocalTime)
+ }
+ }
+ }
+
+}
diff --git a/core/setup/file.go b/core/setup/file.go
new file mode 100644
index 000000000..76aed5249
--- /dev/null
+++ b/core/setup/file.go
@@ -0,0 +1,73 @@
+package setup
+
+import (
+ "log"
+ "os"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/file"
+ "github.com/miekg/dns"
+)
+
+// File sets up the file middleware.
+func File(c *Controller) (middleware.Middleware, error) {
+ zones, err := fileParse(c)
+ if err != nil {
+ return nil, err
+ }
+ return func(next middleware.Handler) middleware.Handler {
+ return file.File{Next: next, Zones: zones}
+ }, nil
+
+}
+
+func fileParse(c *Controller) (file.Zones, error) {
+ // Maybe multiple, each for each zone.
+ z := make(map[string]file.Zone)
+ names := []string{}
+ for c.Next() {
+ if c.Val() == "file" {
+ // file db.file [origin]
+ if !c.NextArg() {
+ return file.Zones{}, c.ArgErr()
+ }
+ fileName := c.Val()
+
+ origin := c.ServerBlockHosts[c.ServerBlockHostIndex]
+ if c.NextArg() {
+ c.Next()
+ origin = c.Val()
+ }
+ // normalize this origin
+ origin = middleware.Host(origin).StandardHost()
+
+ zone, err := parseZone(origin, fileName)
+ if err == nil {
+ z[origin] = zone
+ }
+ names = append(names, origin)
+ }
+ }
+ return file.Zones{Z: z, Names: names}, nil
+}
+
+//
+// parseZone parses the zone in fileName (rooted at origin) and returns a file.Zone or an error.
+func parseZone(origin, fileName string) (file.Zone, error) {
+ f, err := os.Open(fileName)
+ if err != nil {
+ return nil, err
+ }
+ tokens := dns.ParseZone(f, origin, fileName)
+ zone := make([]dns.RR, 0, defaultZoneSize)
+ for x := range tokens {
+ if x.Error != nil {
+ log.Printf("[ERROR] failed to parse %s: %v", origin, x.Error)
+ return nil, x.Error
+ }
+ zone = append(zone, x.RR)
+ }
+ return file.Zone(zone), nil
+}
+
+const defaultZoneSize = 20 // A made up number.
diff --git a/core/setup/log.go b/core/setup/log.go
new file mode 100644
index 000000000..32d9f3250
--- /dev/null
+++ b/core/setup/log.go
@@ -0,0 +1,130 @@
+package setup
+
+import (
+ "io"
+ "log"
+ "os"
+
+ "github.com/hashicorp/go-syslog"
+ "github.com/miekg/coredns/middleware"
+ caddylog "github.com/miekg/coredns/middleware/log"
+ "github.com/miekg/coredns/server"
+)
+
+// Log sets up the logging middleware.
+func Log(c *Controller) (middleware.Middleware, error) {
+ rules, err := logParse(c)
+ if err != nil {
+ return nil, err
+ }
+
+ // Open the log files for writing when the server starts
+ c.Startup = append(c.Startup, func() error {
+ for i := 0; i < len(rules); i++ {
+ var err error
+ var writer io.Writer
+
+ if rules[i].OutputFile == "stdout" {
+ writer = os.Stdout
+ } else if rules[i].OutputFile == "stderr" {
+ writer = os.Stderr
+ } else if rules[i].OutputFile == "syslog" {
+ writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "LOCAL0", "caddy")
+ if err != nil {
+ return err
+ }
+ } else {
+ var file *os.File
+ file, err = os.OpenFile(rules[i].OutputFile, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0644)
+ if err != nil {
+ return err
+ }
+ if rules[i].Roller != nil {
+ file.Close()
+ rules[i].Roller.Filename = rules[i].OutputFile
+ writer = rules[i].Roller.GetLogWriter()
+ } else {
+ writer = file
+ }
+ }
+
+ rules[i].Log = log.New(writer, "", 0)
+ }
+
+ return nil
+ })
+
+ return func(next middleware.Handler) middleware.Handler {
+ return caddylog.Logger{Next: next, Rules: rules, ErrorFunc: server.DefaultErrorFunc}
+ }, nil
+}
+
+func logParse(c *Controller) ([]caddylog.Rule, error) {
+ var rules []caddylog.Rule
+
+ for c.Next() {
+ args := c.RemainingArgs()
+
+ var logRoller *middleware.LogRoller
+ if c.NextBlock() {
+ if c.Val() == "rotate" {
+ if c.NextArg() {
+ if c.Val() == "{" {
+ var err error
+ logRoller, err = parseRoller(c)
+ if err != nil {
+ return nil, err
+ }
+ // This part doesn't allow having something after the rotate block
+ if c.Next() {
+ if c.Val() != "}" {
+ return nil, c.ArgErr()
+ }
+ }
+ }
+ }
+ }
+ }
+ if len(args) == 0 {
+ // Nothing specified; use defaults
+ rules = append(rules, caddylog.Rule{
+ PathScope: "/",
+ OutputFile: caddylog.DefaultLogFilename,
+ Format: caddylog.DefaultLogFormat,
+ Roller: logRoller,
+ })
+ } else if len(args) == 1 {
+ // Only an output file specified
+ rules = append(rules, caddylog.Rule{
+ PathScope: "/",
+ OutputFile: args[0],
+ Format: caddylog.DefaultLogFormat,
+ Roller: logRoller,
+ })
+ } else {
+ // Path scope, output file, and maybe a format specified
+
+ format := caddylog.DefaultLogFormat
+
+ if len(args) > 2 {
+ switch args[2] {
+ case "{common}":
+ format = caddylog.CommonLogFormat
+ case "{combined}":
+ format = caddylog.CombinedLogFormat
+ default:
+ format = args[2]
+ }
+ }
+
+ rules = append(rules, caddylog.Rule{
+ PathScope: args[0],
+ OutputFile: args[1],
+ Format: format,
+ Roller: logRoller,
+ })
+ }
+ }
+
+ return rules, nil
+}
diff --git a/core/setup/log_test.go b/core/setup/log_test.go
new file mode 100644
index 000000000..2bfcb4e89
--- /dev/null
+++ b/core/setup/log_test.go
@@ -0,0 +1,175 @@
+package setup
+
+import (
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+ caddylog "github.com/miekg/coredns/middleware/log"
+)
+
+func TestLog(t *testing.T) {
+
+ c := NewTestController(`log`)
+
+ mid, err := Log(c)
+
+ if err != nil {
+ t.Errorf("Expected no errors, got: %v", err)
+ }
+
+ if mid == nil {
+ t.Fatal("Expected middleware, was nil instead")
+ }
+
+ handler := mid(EmptyNext)
+ myHandler, ok := handler.(caddylog.Logger)
+
+ if !ok {
+ t.Fatalf("Expected handler to be type Logger, got: %#v", handler)
+ }
+
+ if myHandler.Rules[0].PathScope != "/" {
+ t.Errorf("Expected / as the default PathScope")
+ }
+ if myHandler.Rules[0].OutputFile != caddylog.DefaultLogFilename {
+ t.Errorf("Expected %s as the default OutputFile", caddylog.DefaultLogFilename)
+ }
+ if myHandler.Rules[0].Format != caddylog.DefaultLogFormat {
+ t.Errorf("Expected %s as the default Log Format", caddylog.DefaultLogFormat)
+ }
+ if myHandler.Rules[0].Roller != nil {
+ t.Errorf("Expected Roller to be nil, got: %v", *myHandler.Rules[0].Roller)
+ }
+ if !SameNext(myHandler.Next, EmptyNext) {
+ t.Error("'Next' field of handler was not set properly")
+ }
+
+}
+
+func TestLogParse(t *testing.T) {
+ tests := []struct {
+ inputLogRules string
+ shouldErr bool
+ expectedLogRules []caddylog.Rule
+ }{
+ {`log`, false, []caddylog.Rule{{
+ PathScope: "/",
+ OutputFile: caddylog.DefaultLogFilename,
+ Format: caddylog.DefaultLogFormat,
+ }}},
+ {`log log.txt`, false, []caddylog.Rule{{
+ PathScope: "/",
+ OutputFile: "log.txt",
+ Format: caddylog.DefaultLogFormat,
+ }}},
+ {`log /api log.txt`, false, []caddylog.Rule{{
+ PathScope: "/api",
+ OutputFile: "log.txt",
+ Format: caddylog.DefaultLogFormat,
+ }}},
+ {`log /serve stdout`, false, []caddylog.Rule{{
+ PathScope: "/serve",
+ OutputFile: "stdout",
+ Format: caddylog.DefaultLogFormat,
+ }}},
+ {`log /myapi log.txt {common}`, false, []caddylog.Rule{{
+ PathScope: "/myapi",
+ OutputFile: "log.txt",
+ Format: caddylog.CommonLogFormat,
+ }}},
+ {`log /test accesslog.txt {combined}`, false, []caddylog.Rule{{
+ PathScope: "/test",
+ OutputFile: "accesslog.txt",
+ Format: caddylog.CombinedLogFormat,
+ }}},
+ {`log /api1 log.txt
+ log /api2 accesslog.txt {combined}`, false, []caddylog.Rule{{
+ PathScope: "/api1",
+ OutputFile: "log.txt",
+ Format: caddylog.DefaultLogFormat,
+ }, {
+ PathScope: "/api2",
+ OutputFile: "accesslog.txt",
+ Format: caddylog.CombinedLogFormat,
+ }}},
+ {`log /api3 stdout {host}
+ log /api4 log.txt {when}`, false, []caddylog.Rule{{
+ PathScope: "/api3",
+ OutputFile: "stdout",
+ Format: "{host}",
+ }, {
+ PathScope: "/api4",
+ OutputFile: "log.txt",
+ Format: "{when}",
+ }}},
+ {`log access.log { rotate { size 2 age 10 keep 3 } }`, false, []caddylog.Rule{{
+ PathScope: "/",
+ OutputFile: "access.log",
+ Format: caddylog.DefaultLogFormat,
+ Roller: &middleware.LogRoller{
+ MaxSize: 2,
+ MaxAge: 10,
+ MaxBackups: 3,
+ LocalTime: true,
+ },
+ }}},
+ }
+ for i, test := range tests {
+ c := NewTestController(test.inputLogRules)
+ actualLogRules, err := logParse(c)
+
+ if err == nil && test.shouldErr {
+ t.Errorf("Test %d didn't error, but it should have", i)
+ } else if err != nil && !test.shouldErr {
+ t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
+ }
+ if len(actualLogRules) != len(test.expectedLogRules) {
+ t.Fatalf("Test %d expected %d no of Log rules, but got %d ",
+ i, len(test.expectedLogRules), len(actualLogRules))
+ }
+ for j, actualLogRule := range actualLogRules {
+
+ if actualLogRule.PathScope != test.expectedLogRules[j].PathScope {
+ t.Errorf("Test %d expected %dth LogRule PathScope to be %s , but got %s",
+ i, j, test.expectedLogRules[j].PathScope, actualLogRule.PathScope)
+ }
+
+ if actualLogRule.OutputFile != test.expectedLogRules[j].OutputFile {
+ t.Errorf("Test %d expected %dth LogRule OutputFile to be %s , but got %s",
+ i, j, test.expectedLogRules[j].OutputFile, actualLogRule.OutputFile)
+ }
+
+ if actualLogRule.Format != test.expectedLogRules[j].Format {
+ t.Errorf("Test %d expected %dth LogRule Format to be %s , but got %s",
+ i, j, test.expectedLogRules[j].Format, actualLogRule.Format)
+ }
+ if actualLogRule.Roller != nil && test.expectedLogRules[j].Roller == nil || actualLogRule.Roller == nil && test.expectedLogRules[j].Roller != nil {
+ t.Fatalf("Test %d expected %dth LogRule Roller to be %v, but got %v",
+ i, j, test.expectedLogRules[j].Roller, actualLogRule.Roller)
+ }
+ if actualLogRule.Roller != nil && test.expectedLogRules[j].Roller != nil {
+ if actualLogRule.Roller.Filename != test.expectedLogRules[j].Roller.Filename {
+ t.Fatalf("Test %d expected %dth LogRule Roller Filename to be %s, but got %s",
+ i, j, test.expectedLogRules[j].Roller.Filename, actualLogRule.Roller.Filename)
+ }
+ if actualLogRule.Roller.MaxAge != test.expectedLogRules[j].Roller.MaxAge {
+ t.Fatalf("Test %d expected %dth LogRule Roller MaxAge to be %d, but got %d",
+ i, j, test.expectedLogRules[j].Roller.MaxAge, actualLogRule.Roller.MaxAge)
+ }
+ if actualLogRule.Roller.MaxBackups != test.expectedLogRules[j].Roller.MaxBackups {
+ t.Fatalf("Test %d expected %dth LogRule Roller MaxBackups to be %d, but got %d",
+ i, j, test.expectedLogRules[j].Roller.MaxBackups, actualLogRule.Roller.MaxBackups)
+ }
+ if actualLogRule.Roller.MaxSize != test.expectedLogRules[j].Roller.MaxSize {
+ t.Fatalf("Test %d expected %dth LogRule Roller MaxSize to be %d, but got %d",
+ i, j, test.expectedLogRules[j].Roller.MaxSize, actualLogRule.Roller.MaxSize)
+ }
+ if actualLogRule.Roller.LocalTime != test.expectedLogRules[j].Roller.LocalTime {
+ t.Fatalf("Test %d expected %dth LogRule Roller LocalTime to be %t, but got %t",
+ i, j, test.expectedLogRules[j].Roller.LocalTime, actualLogRule.Roller.LocalTime)
+ }
+ }
+ }
+ }
+
+}
diff --git a/core/setup/prometheus.go b/core/setup/prometheus.go
new file mode 100644
index 000000000..3bcb907ce
--- /dev/null
+++ b/core/setup/prometheus.go
@@ -0,0 +1,70 @@
+package setup
+
+import (
+ "sync"
+
+ "github.com/miekg/coredns/middleware"
+ prom "github.com/miekg/coredns/middleware/prometheus"
+)
+
+const (
+ path = "/metrics"
+ addr = "localhost:9153"
+)
+
+var once sync.Once
+
+func Prometheus(c *Controller) (middleware.Middleware, error) {
+ metrics, err := parsePrometheus(c)
+ if err != nil {
+ return nil, err
+ }
+ if metrics.Addr == "" {
+ metrics.Addr = addr
+ }
+ once.Do(func() {
+ c.Startup = append(c.Startup, metrics.Start)
+ })
+
+ return func(next middleware.Handler) middleware.Handler {
+ metrics.Next = next
+ return metrics
+ }, nil
+}
+
+func parsePrometheus(c *Controller) (*prom.Metrics, error) {
+ var (
+ metrics *prom.Metrics
+ err error
+ )
+
+ for c.Next() {
+ if metrics != nil {
+ return nil, c.Err("prometheus: can only have one metrics module per server")
+ }
+ metrics = &prom.Metrics{ZoneNames: c.ServerBlockHosts}
+ args := c.RemainingArgs()
+
+ switch len(args) {
+ case 0:
+ case 1:
+ metrics.Addr = args[0]
+ default:
+ return nil, c.ArgErr()
+ }
+ for c.NextBlock() {
+ switch c.Val() {
+ case "address":
+ args = c.RemainingArgs()
+ if len(args) != 1 {
+ return nil, c.ArgErr()
+ }
+ metrics.Addr = args[0]
+ default:
+ return nil, c.Errf("prometheus: unknown item: %s", c.Val())
+ }
+
+ }
+ }
+ return metrics, err
+}
diff --git a/core/setup/proxy.go b/core/setup/proxy.go
new file mode 100644
index 000000000..6753d07ad
--- /dev/null
+++ b/core/setup/proxy.go
@@ -0,0 +1,17 @@
+package setup
+
+import (
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/proxy"
+)
+
+// Proxy configures a new Proxy middleware instance.
+func Proxy(c *Controller) (middleware.Middleware, error) {
+ upstreams, err := proxy.NewStaticUpstreams(c.Dispenser)
+ if err != nil {
+ return nil, err
+ }
+ return func(next middleware.Handler) middleware.Handler {
+ return proxy.Proxy{Next: next, Client: proxy.Clients(), Upstreams: upstreams}
+ }, nil
+}
diff --git a/core/setup/reflect.go b/core/setup/reflect.go
new file mode 100644
index 000000000..9ae1d5181
--- /dev/null
+++ b/core/setup/reflect.go
@@ -0,0 +1,28 @@
+package setup
+
+import (
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/reflect"
+)
+
+// Reflect sets up the reflect middleware.
+func Reflect(c *Controller) (middleware.Middleware, error) {
+ if err := reflectParse(c); err != nil {
+ return nil, err
+ }
+ return func(next middleware.Handler) middleware.Handler {
+ return reflect.Reflect{Next: next}
+ }, nil
+
+}
+
+func reflectParse(c *Controller) error {
+ for c.Next() {
+ if c.Val() == "reflect" {
+ if c.NextArg() {
+ return c.ArgErr()
+ }
+ }
+ }
+ return nil
+}
diff --git a/core/setup/rewrite.go b/core/setup/rewrite.go
new file mode 100644
index 000000000..32f5f42a3
--- /dev/null
+++ b/core/setup/rewrite.go
@@ -0,0 +1,109 @@
+package setup
+
+import (
+ "strconv"
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/coredns/middleware/rewrite"
+)
+
+// Rewrite configures a new Rewrite middleware instance.
+func Rewrite(c *Controller) (middleware.Middleware, error) {
+ rewrites, err := rewriteParse(c)
+ if err != nil {
+ return nil, err
+ }
+
+ return func(next middleware.Handler) middleware.Handler {
+ return rewrite.Rewrite{
+ Next: next,
+ Rules: rewrites,
+ }
+ }, nil
+}
+
+func rewriteParse(c *Controller) ([]rewrite.Rule, error) {
+ var simpleRules []rewrite.Rule
+ var regexpRules []rewrite.Rule
+
+ for c.Next() {
+ var rule rewrite.Rule
+ var err error
+ var base = "/"
+ var pattern, to string
+ var status int
+ var ext []string
+
+ args := c.RemainingArgs()
+
+ var ifs []rewrite.If
+
+ switch len(args) {
+ case 1:
+ base = args[0]
+ fallthrough
+ case 0:
+ for c.NextBlock() {
+ switch c.Val() {
+ case "r", "regexp":
+ if !c.NextArg() {
+ return nil, c.ArgErr()
+ }
+ pattern = c.Val()
+ case "to":
+ args1 := c.RemainingArgs()
+ if len(args1) == 0 {
+ return nil, c.ArgErr()
+ }
+ to = strings.Join(args1, " ")
+ case "ext":
+ args1 := c.RemainingArgs()
+ if len(args1) == 0 {
+ return nil, c.ArgErr()
+ }
+ ext = args1
+ case "if":
+ args1 := c.RemainingArgs()
+ if len(args1) != 3 {
+ return nil, c.ArgErr()
+ }
+ ifCond, err := rewrite.NewIf(args1[0], args1[1], args1[2])
+ if err != nil {
+ return nil, err
+ }
+ ifs = append(ifs, ifCond)
+ case "status":
+ if !c.NextArg() {
+ return nil, c.ArgErr()
+ }
+ status, _ = strconv.Atoi(c.Val())
+ if status < 200 || (status > 299 && status < 400) || status > 499 {
+ return nil, c.Err("status must be 2xx or 4xx")
+ }
+ default:
+ return nil, c.ArgErr()
+ }
+ }
+ // ensure to or status is specified
+ if to == "" && status == 0 {
+ return nil, c.ArgErr()
+ }
+ // TODO(miek): complex rules
+ base, pattern, to, status, ext, ifs = base, pattern, to, status, ext, ifs
+ err = err
+ // if rule, err = rewrite.NewComplexRule(base, pattern, to, status, ext, ifs); err != nil {
+ // return nil, err
+ // }
+ regexpRules = append(regexpRules, rule)
+
+ // the only unhandled case is 2 and above
+ default:
+ rule = rewrite.NewSimpleRule(args[0], strings.Join(args[1:], " "))
+ simpleRules = append(simpleRules, rule)
+ }
+ }
+
+ // put simple rules in front to avoid regexp computation for them
+ return append(simpleRules, regexpRules...), nil
+}
diff --git a/core/setup/rewrite_test.go b/core/setup/rewrite_test.go
new file mode 100644
index 000000000..747618305
--- /dev/null
+++ b/core/setup/rewrite_test.go
@@ -0,0 +1,241 @@
+package setup
+
+import (
+ "fmt"
+ "regexp"
+ "testing"
+
+ "github.com/miekg/coredns/middleware/rewrite"
+)
+
+// TestRewrite verifies that the rewrite directive produces a
+// rewrite.Rewrite middleware wired to the next handler with one rule.
+func TestRewrite(t *testing.T) {
+	ctrl := NewTestController(`rewrite /from /to`)
+
+	mw, err := Rewrite(ctrl)
+	if err != nil {
+		t.Errorf("Expected no errors, but got: %v", err)
+	}
+	if mw == nil {
+		t.Fatal("Expected middleware, was nil instead")
+	}
+
+	next := mw(EmptyNext)
+	rw, ok := next.(rewrite.Rewrite)
+	if !ok {
+		t.Fatalf("Expected handler to be type Rewrite, got: %#v", next)
+	}
+
+	if !SameNext(rw.Next, EmptyNext) {
+		t.Error("'Next' field of handler was not set properly")
+	}
+
+	if len(rw.Rules) != 1 {
+		t.Errorf("Expected handler to have %d rule, has %d instead", 1, len(rw.Rules))
+	}
+}
+
+// TestRewriteParse exercises rewriteParse with simple (from/to) rules
+// and with regexp-based complex rules, covering the to, ext, if, and
+// status sub-directives and their error cases.
+func TestRewriteParse(t *testing.T) {
+	simpleTests := []struct {
+		input     string
+		shouldErr bool
+		expected  []rewrite.Rule
+	}{
+		{`rewrite /from /to`, false, []rewrite.Rule{
+			rewrite.SimpleRule{From: "/from", To: "/to"},
+		}},
+		{`rewrite /from /to
+		  rewrite a b`, false, []rewrite.Rule{
+			rewrite.SimpleRule{From: "/from", To: "/to"},
+			rewrite.SimpleRule{From: "a", To: "b"},
+		}},
+		{`rewrite a`, true, []rewrite.Rule{}},
+		{`rewrite`, true, []rewrite.Rule{}},
+		{`rewrite a b c`, false, []rewrite.Rule{
+			rewrite.SimpleRule{From: "a", To: "b c"},
+		}},
+	}
+
+	for i, test := range simpleTests {
+		c := NewTestController(test.input)
+		actual, err := rewriteParse(c)
+
+		if err == nil && test.shouldErr {
+			t.Errorf("Test %d didn't error, but it should have", i)
+		} else if err != nil && !test.shouldErr {
+			t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
+		} else if err != nil && test.shouldErr {
+			continue
+		}
+
+		if len(actual) != len(test.expected) {
+			t.Fatalf("Test %d expected %d rules, but got %d",
+				i, len(test.expected), len(actual))
+		}
+
+		for j, e := range test.expected {
+			actualRule := actual[j].(rewrite.SimpleRule)
+			expectedRule := e.(rewrite.SimpleRule)
+
+			if actualRule.From != expectedRule.From {
+				t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
+					i, j, expectedRule.From, actualRule.From)
+			}
+
+			if actualRule.To != expectedRule.To {
+				t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
+					i, j, expectedRule.To, actualRule.To)
+			}
+		}
+	}
+
+	regexpTests := []struct {
+		input     string
+		shouldErr bool
+		expected  []rewrite.Rule
+	}{
+		{`rewrite {
+			r	.*
+			to	/to /index.php?
+		 }`, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/", To: "/to /index.php?", Regexp: regexp.MustCompile(".*")},
+		}},
+		{`rewrite {
+			regexp	.*
+			to		/to
+			ext		/ html txt
+		 }`, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/", To: "/to", Exts: []string{"/", "html", "txt"}, Regexp: regexp.MustCompile(".*")},
+		}},
+		{`rewrite /path {
+			r	rr
+			to	/dest
+		 }
+		 rewrite / {
+			regexp	[a-z]+
+			to		/to /to2
+		 }
+		 `, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/path", To: "/dest", Regexp: regexp.MustCompile("rr")},
+			&rewrite.ComplexRule{Base: "/", To: "/to /to2", Regexp: regexp.MustCompile("[a-z]+")},
+		}},
+		{`rewrite {
+			r	.*
+		 }`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+		{`rewrite {
+
+		 }`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+		{`rewrite /`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+		{`rewrite {
+			to	/to
+			if {path} is a
+		 }`, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/", To: "/to", Ifs: []rewrite.If{{A: "{path}", Operator: "is", B: "a"}}},
+		}},
+		{`rewrite {
+			status 500
+		 }`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+		{`rewrite {
+			status 400
+		 }`, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/", Status: 400},
+		}},
+		{`rewrite {
+			to	/to
+			status 400
+		 }`, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/", To: "/to", Status: 400},
+		}},
+		{`rewrite {
+			status 399
+		 }`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+		{`rewrite {
+			status 200
+		 }`, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/", Status: 200},
+		}},
+		{`rewrite {
+			to	/to
+			status 200
+		 }`, false, []rewrite.Rule{
+			&rewrite.ComplexRule{Base: "/", To: "/to", Status: 200},
+		}},
+		{`rewrite {
+			status 199
+		 }`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+		{`rewrite {
+			status 0
+		 }`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+		{`rewrite {
+			to	/to
+			status 0
+		 }`, true, []rewrite.Rule{
+			&rewrite.ComplexRule{},
+		}},
+	}
+
+	for i, test := range regexpTests {
+		c := NewTestController(test.input)
+		actual, err := rewriteParse(c)
+
+		if err == nil && test.shouldErr {
+			t.Errorf("Test %d didn't error, but it should have", i)
+		} else if err != nil && !test.shouldErr {
+			t.Errorf("Test %d errored, but it shouldn't have; got '%v'", i, err)
+		} else if err != nil && test.shouldErr {
+			continue
+		}
+
+		if len(actual) != len(test.expected) {
+			t.Fatalf("Test %d expected %d rules, but got %d",
+				i, len(test.expected), len(actual))
+		}
+
+		for j, e := range test.expected {
+			actualRule := actual[j].(*rewrite.ComplexRule)
+			expectedRule := e.(*rewrite.ComplexRule)
+
+			if actualRule.Base != expectedRule.Base {
+				t.Errorf("Test %d, rule %d: Expected Base=%s, got %s",
+					i, j, expectedRule.Base, actualRule.Base)
+			}
+
+			if actualRule.To != expectedRule.To {
+				t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
+					i, j, expectedRule.To, actualRule.To)
+			}
+
+			// Fixed: this message previously printed the To fields
+			// instead of the mismatching Exts values.
+			if fmt.Sprint(actualRule.Exts) != fmt.Sprint(expectedRule.Exts) {
+				t.Errorf("Test %d, rule %d: Expected Ext=%v, got %v",
+					i, j, expectedRule.Exts, actualRule.Exts)
+			}
+
+			if actualRule.Regexp != nil {
+				if actualRule.String() != expectedRule.String() {
+					t.Errorf("Test %d, rule %d: Expected Pattern=%s, got %s",
+						i, j, expectedRule.String(), actualRule.String())
+				}
+			}
+
+			// Fixed: this message previously said "Pattern" although it
+			// compares the If conditions.
+			if fmt.Sprint(actualRule.Ifs) != fmt.Sprint(expectedRule.Ifs) {
+				t.Errorf("Test %d, rule %d: Expected Ifs=%s, got %s",
+					i, j, fmt.Sprint(expectedRule.Ifs), fmt.Sprint(actualRule.Ifs))
+			}
+		}
+	}
+}
diff --git a/core/setup/roller.go b/core/setup/roller.go
new file mode 100644
index 000000000..fd772cc47
--- /dev/null
+++ b/core/setup/roller.go
@@ -0,0 +1,40 @@
+package setup
+
+import (
+ "strconv"
+
+ "github.com/miekg/coredns/middleware"
+)
+
+// parseRoller reads the size/age/keep properties of a nested log-rolling
+// block and returns the resulting LogRoller configuration.
+func parseRoller(c *Controller) (*middleware.LogRoller, error) {
+	var size, age, keep int
+	// This is kind of a hack to support nested blocks:
+	// As we are already in a block: either log or errors,
+	// c.nesting > 0 but, as soon as c meets a }, it thinks
+	// the block is over and return false for c.NextBlock.
+	for c.NextBlock() {
+		prop := c.Val()
+		if !c.NextArg() {
+			return nil, c.ArgErr()
+		}
+		val := c.Val()
+		var convErr error
+		if prop == "size" {
+			size, convErr = strconv.Atoi(val)
+		} else if prop == "age" {
+			age, convErr = strconv.Atoi(val)
+		} else if prop == "keep" {
+			keep, convErr = strconv.Atoi(val)
+		}
+		if convErr != nil {
+			return nil, convErr
+		}
+	}
+	return &middleware.LogRoller{
+		MaxSize:    size,
+		MaxAge:     age,
+		MaxBackups: keep,
+		LocalTime:  true,
+	}, nil
+}
diff --git a/core/setup/root.go b/core/setup/root.go
new file mode 100644
index 000000000..0fce5f170
--- /dev/null
+++ b/core/setup/root.go
@@ -0,0 +1,32 @@
+package setup
+
+import (
+ "log"
+ "os"
+
+ "github.com/miekg/coredns/middleware"
+)
+
+// Root sets up the root file path of the server.
+func Root(c *Controller) (middleware.Middleware, error) {
+	for c.Next() {
+		if !c.NextArg() {
+			return nil, c.ArgErr()
+		}
+		c.Root = c.Val()
+	}
+
+	// Check if root path exists
+	if _, err := os.Stat(c.Root); err != nil {
+		if !os.IsNotExist(err) {
+			return nil, c.Errf("Unable to access root path '%s': %v", c.Root, err)
+		}
+		// Allow this, because the folder might appear later.
+		// But make sure the user knows!
+		log.Printf("[WARNING] Root path does not exist: %s", c.Root)
+	}
+
+	return nil, nil
+}
diff --git a/core/setup/root_test.go b/core/setup/root_test.go
new file mode 100644
index 000000000..8b38e6d04
--- /dev/null
+++ b/core/setup/root_test.go
@@ -0,0 +1,108 @@
+package setup
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+// TestRoot checks that the root directive sets c.Root and that the
+// path is validated: non-existent roots are tolerated with a warning,
+// inaccessible roots and parse errors are reported.
+func TestRoot(t *testing.T) {
+
+	// Predefined error substrings
+	parseErrContent := "Parse error:"
+	unableToAccessErrContent := "Unable to access root path"
+
+	existingDirPath, err := getTempDirPath()
+	if err != nil {
+		t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err)
+	}
+
+	nonExistingDir := filepath.Join(existingDirPath, "highly_unlikely_to_exist_dir")
+
+	existingFile, err := ioutil.TempFile("", "root_test")
+	if err != nil {
+		t.Fatalf("BeforeTest: Failed to create temp file for testing! Error was: %v", err)
+	}
+	defer func() {
+		existingFile.Close()
+		os.Remove(existingFile.Name())
+	}()
+
+	inaccessiblePath := getInaccessiblePath(existingFile.Name())
+
+	tests := []struct {
+		input              string
+		shouldErr          bool
+		expectedRoot       string // expected root, set to the controller. Empty for negative cases.
+		expectedErrContent string // substring from the expected error. Empty for positive cases.
+	}{
+		// positive
+		{
+			fmt.Sprintf(`root %s`, nonExistingDir), false, nonExistingDir, "",
+		},
+		{
+			fmt.Sprintf(`root %s`, existingDirPath), false, existingDirPath, "",
+		},
+		// negative
+		{
+			`root `, true, "", parseErrContent,
+		},
+		{
+			fmt.Sprintf(`root %s`, inaccessiblePath), true, "", unableToAccessErrContent,
+		},
+		{
+			fmt.Sprintf(`root {
+				%s
+			}`, existingDirPath), true, "", parseErrContent,
+		},
+	}
+
+	for i, test := range tests {
+		c := NewTestController(test.input)
+		mid, err := Root(c)
+
+		if test.shouldErr && err == nil {
+			// Fixed: formerly printed the nil error with %s.
+			t.Errorf("Test %d: Expected error but found none for input %s", i, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: Expected no error but found one for input %s. Error was: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErrContent) {
+				t.Errorf("Test %d: Expected error to contain: %v, found error: %v, input: %s", i, test.expectedErrContent, err, test.input)
+			}
+		}
+
+		// the Root method always returns a nil middleware
+		if mid != nil {
+			t.Errorf("Middleware returned from Root() was not nil: %v", mid)
+		}
+
+		// check c.Root only if we are in a positive test.
+		if !test.shouldErr && test.expectedRoot != c.Root {
+			t.Errorf("Root not correctly set for input %s. Expected: %s, actual: %s", test.input, test.expectedRoot, c.Root)
+		}
+	}
+}
+
+// getTempDirPath returns the path to the system temp directory. If it
+// does not exist, an error is returned.
+func getTempDirPath() (string, error) {
+	tempDir := os.TempDir()
+
+	_, err := os.Stat(tempDir)
+	if err != nil {
+		return "", err
+	}
+
+	return tempDir, nil
+}
+
+// getInaccessiblePath returns a path that cannot be accessed on any
+// supported platform. NOTE(review): the file argument is currently
+// unused — the returned path is a fixed invalid name.
+func getInaccessiblePath(file string) string {
+	// null byte in filename is not allowed on Windows AND unix
+	return filepath.Join("C:", "file\x00name")
+}
diff --git a/core/setup/startupshutdown.go b/core/setup/startupshutdown.go
new file mode 100644
index 000000000..1cf2c62e0
--- /dev/null
+++ b/core/setup/startupshutdown.go
@@ -0,0 +1,64 @@
+package setup
+
+import (
+ "os"
+ "os/exec"
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+)
+
+// Startup registers a startup callback to execute during server start.
+// It returns a nil middleware because it only records the callback.
+func Startup(c *Controller) (middleware.Middleware, error) {
+	return nil, registerCallback(c, &c.FirstStartup)
+}
+
+// Shutdown registers a shutdown callback to execute during process exit.
+// It returns a nil middleware because it only records the callback.
+func Shutdown(c *Controller) (middleware.Middleware, error) {
+	return nil, registerCallback(c, &c.Shutdown)
+}
+
+// registerCallback registers a callback function to execute by
+// using c to parse the line. It appends the callback function
+// to the list of callback functions passed in by reference.
+func registerCallback(c *Controller, list *[]func() error) error {
+	var callbacks []func() error
+
+	for c.Next() {
+		args := c.RemainingArgs()
+		if len(args) == 0 {
+			return c.ArgErr()
+		}
+
+		background := false
+		if last := len(args) - 1; len(args) > 1 && args[last] == "&" {
+			// Run command in background; non-blocking
+			background = true
+			args = args[:last]
+		}
+
+		command, cmdArgs, err := middleware.SplitCommandAndArgs(strings.Join(args, " "))
+		if err != nil {
+			return c.Err(err.Error())
+		}
+
+		// command, cmdArgs, and background are fresh per iteration,
+		// so each closure captures its own values.
+		run := func() error {
+			cmd := exec.Command(command, cmdArgs...)
+			cmd.Stdin = os.Stdin
+			cmd.Stdout = os.Stdout
+			cmd.Stderr = os.Stderr
+			if background {
+				return cmd.Start()
+			}
+			return cmd.Run()
+		}
+		callbacks = append(callbacks, run)
+	}
+
+	return c.OncePerServerBlock(func() error {
+		*list = append(*list, callbacks...)
+		return nil
+	})
+}
diff --git a/core/setup/startupshutdown_test.go b/core/setup/startupshutdown_test.go
new file mode 100644
index 000000000..871a64214
--- /dev/null
+++ b/core/setup/startupshutdown_test.go
@@ -0,0 +1,59 @@
+package setup
+
+import (
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+ "time"
+)
+
+// The Startup function's tests are symmetrical to Shutdown tests,
+// because the Startup and Shutdown functions share virtually the
+// same functionality.
+func TestStartup(t *testing.T) {
+	tempDirPath, err := getTempDirPath()
+	if err != nil {
+		t.Fatalf("BeforeTest: Failed to find an existing directory for testing! Error was: %v", err)
+	}
+
+	testDir := filepath.Join(tempDirPath, "temp_dir_for_testing_startupshutdown")
+	defer func() {
+		// clean up after the non-blocking startup function quits
+		time.Sleep(500 * time.Millisecond)
+		os.RemoveAll(testDir)
+	}()
+	// Fixed local-name typo: osSenitiveTestDir -> osSensitiveTestDir.
+	osSensitiveTestDir := filepath.FromSlash(testDir)
+	os.RemoveAll(osSensitiveTestDir) // start with a clean slate
+
+	tests := []struct {
+		input              string
+		shouldExecutionErr bool
+		shouldRemoveErr    bool
+	}{
+		// test case #0 tests proper functionality blocking commands
+		{"startup mkdir " + osSensitiveTestDir, false, false},
+
+		// test case #1 tests proper functionality of non-blocking commands
+		{"startup mkdir " + osSensitiveTestDir + " &", false, true},
+
+		// test case #2 tests handling of non-existent commands
+		{"startup " + strconv.Itoa(int(time.Now().UnixNano())), true, true},
+	}
+
+	for i, test := range tests {
+		c := NewTestController(test.input)
+		_, err = Startup(c)
+		if err != nil {
+			t.Errorf("Expected no errors, got: %v", err)
+		}
+		err = c.FirstStartup[0]()
+		// Fixed "recieved" typos in both messages below.
+		if err != nil && !test.shouldExecutionErr {
+			t.Errorf("Test %d received an error of:\n%v", i, err)
+		}
+		err = os.Remove(osSensitiveTestDir)
+		if err != nil && !test.shouldRemoveErr {
+			t.Errorf("Test %d received an error of:\n%v", i, err)
+		}
+	}
+}
diff --git a/core/setup/testdata/blog/first_post.md b/core/setup/testdata/blog/first_post.md
new file mode 100644
index 000000000..f26583b75
--- /dev/null
+++ b/core/setup/testdata/blog/first_post.md
@@ -0,0 +1 @@
+# Test h1
diff --git a/core/setup/testdata/header.html b/core/setup/testdata/header.html
new file mode 100644
index 000000000..9c96e0e37
--- /dev/null
+++ b/core/setup/testdata/header.html
@@ -0,0 +1 @@
+<h1>Header title</h1>
diff --git a/core/setup/testdata/tpl_with_include.html b/core/setup/testdata/tpl_with_include.html
new file mode 100644
index 000000000..95eeae0c8
--- /dev/null
+++ b/core/setup/testdata/tpl_with_include.html
@@ -0,0 +1,10 @@
+<!DOCTYPE html>
+<html>
+<head>
+<title>{{.Doc.title}}</title>
+</head>
+<body>
+{{.Include "header.html"}}
+{{.Doc.body}}
+</body>
+</html>
diff --git a/core/sigtrap.go b/core/sigtrap.go
new file mode 100644
index 000000000..3b74efb02
--- /dev/null
+++ b/core/sigtrap.go
@@ -0,0 +1,71 @@
+package core
+
+import (
+ "log"
+ "os"
+ "os/signal"
+ "sync"
+
+ "github.com/miekg/coredns/server"
+)
+
+// TrapSignals creates signal handlers for all applicable signals for
+// this system. If your Go program uses signals, this is a rather
+// invasive function; best to implement them yourself in that case.
+// Signals are not required for this package to function properly, but
+// this is a convenient way to allow the user to control the process.
+func TrapSignals() {
+	trapSignalsCrossPlatform()
+	trapSignalsPosix()
+}
+
+// trapSignalsCrossPlatform captures SIGINT, which triggers forceful
+// shutdown that executes shutdown callbacks first. A second interrupt
+// signal will exit the process immediately.
+func trapSignalsCrossPlatform() {
+	go func() {
+		shutdown := make(chan os.Signal, 1)
+		signal.Notify(shutdown, os.Interrupt)
+
+		for i := 0; true; i++ {
+			<-shutdown
+
+			if i > 0 {
+				// Second SIGINT: skip the callbacks and quit right away.
+				log.Println("[INFO] SIGINT: Force quit")
+				if PidFile != "" {
+					os.Remove(PidFile)
+				}
+				os.Exit(1)
+			}
+
+			log.Println("[INFO] SIGINT: Shutting down")
+
+			if PidFile != "" {
+				os.Remove(PidFile)
+			}
+
+			// Exit with the callbacks' status. Run in a goroutine so this
+			// loop can still catch a second SIGINT for a force quit.
+			go os.Exit(executeShutdownCallbacks("SIGINT"))
+		}
+	}()
+}
+
+// executeShutdownCallbacks executes the shutdown callbacks as initiated
+// by signame. It logs any errors and returns the recommended exit status.
+// This function is idempotent; subsequent invocations always return 0.
+func executeShutdownCallbacks(signame string) (exitCode int) {
+	shutdownCallbacksOnce.Do(func() {
+		// Hold serversMu while running the callbacks over the current servers.
+		serversMu.Lock()
+		errs := server.ShutdownCallbacks(servers)
+		serversMu.Unlock()
+
+		if len(errs) > 0 {
+			for _, err := range errs {
+				log.Printf("[ERROR] %s shutdown: %v", signame, err)
+			}
+			exitCode = 1
+		}
+	})
+	return
+}
+
+// shutdownCallbacksOnce ensures the shutdown callbacks run at most once.
+var shutdownCallbacksOnce sync.Once
diff --git a/core/sigtrap_posix.go b/core/sigtrap_posix.go
new file mode 100644
index 000000000..ba24ff4b6
--- /dev/null
+++ b/core/sigtrap_posix.go
@@ -0,0 +1,79 @@
+// +build !windows
+
+package core
+
+import (
+ "io/ioutil"
+ "log"
+ "os"
+ "os/signal"
+ "syscall"
+)
+
+// trapSignalsPosix captures POSIX-only signals.
+func trapSignalsPosix() {
+	go func() {
+		sigchan := make(chan os.Signal, 1)
+		signal.Notify(sigchan, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGUSR1)
+
+		for sig := range sigchan {
+			switch sig {
+			case syscall.SIGTERM:
+				// Terminate immediately; shutdown callbacks are NOT run here
+				// (unlike SIGQUIT below).
+				log.Println("[INFO] SIGTERM: Terminating process")
+				if PidFile != "" {
+					os.Remove(PidFile)
+				}
+				os.Exit(0)
+
+			case syscall.SIGQUIT:
+				// Graceful quit: run shutdown callbacks, stop, then exit.
+				log.Println("[INFO] SIGQUIT: Shutting down")
+				exitCode := executeShutdownCallbacks("SIGQUIT")
+				err := Stop()
+				if err != nil {
+					log.Printf("[ERROR] SIGQUIT stop: %v", err)
+					exitCode = 1
+				}
+				if PidFile != "" {
+					os.Remove(PidFile)
+				}
+				os.Exit(exitCode)
+
+			case syscall.SIGHUP:
+				// Stop serving; the process itself keeps running.
+				log.Println("[INFO] SIGHUP: Hanging up")
+				err := Stop()
+				if err != nil {
+					log.Printf("[ERROR] SIGHUP stop: %v", err)
+				}
+
+			case syscall.SIGUSR1:
+				// Reload the Caddyfile (re-read from disk if it is a real file).
+				log.Println("[INFO] SIGUSR1: Reloading")
+
+				var updatedCaddyfile Input
+
+				caddyfileMu.Lock()
+				if caddyfile == nil {
+					// Hmm, did spawning process forget to close stdin? Anyhow, this is unusual.
+					log.Println("[ERROR] SIGUSR1: no Caddyfile to reload (was stdin left open?)")
+					caddyfileMu.Unlock()
+					continue
+				}
+				if caddyfile.IsFile() {
+					body, err := ioutil.ReadFile(caddyfile.Path())
+					if err == nil {
+						updatedCaddyfile = CaddyfileInput{
+							Filepath: caddyfile.Path(),
+							Contents: body,
+							RealFile: true,
+						}
+					}
+				}
+				caddyfileMu.Unlock()
+
+				err := Restart(updatedCaddyfile)
+				if err != nil {
+					log.Printf("[ERROR] SIGUSR1: %v", err)
+				}
+			}
+		}
+	}()
+}
diff --git a/core/sigtrap_windows.go b/core/sigtrap_windows.go
new file mode 100644
index 000000000..59132cee4
--- /dev/null
+++ b/core/sigtrap_windows.go
@@ -0,0 +1,3 @@
+package core
+
+// trapSignalsPosix is a no-op on Windows, which has no POSIX-only signals.
+func trapSignalsPosix() {}
diff --git a/db.dns.miek.nl b/db.dns.miek.nl
new file mode 100644
index 000000000..78d75f10a
--- /dev/null
+++ b/db.dns.miek.nl
@@ -0,0 +1,10 @@
+$TTL 30M
+@ IN SOA linode.atoom.net. miek.miek.nl. (
+ 1282630058 ; Serial
+ 4H ; Refresh
+ 1H ; Retry
+ 7D ; Expire
+ 4H ) ; Negative Cache TTL
+ IN NS linode.atoom.net.
+
+go IN TXT "Hello!"
diff --git a/db.miek.nl b/db.miek.nl
new file mode 100644
index 000000000..9f0737bb3
--- /dev/null
+++ b/db.miek.nl
@@ -0,0 +1,29 @@
+$TTL 30M
+@ IN SOA linode.atoom.net. miek.miek.nl. (
+ 1282630057 ; Serial
+ 4H ; Refresh
+ 1H ; Retry
+ 7D ; Expire
+ 4H ) ; Negative Cache TTL
+ IN NS linode.atoom.net.
+ IN NS ns-ext.nlnetlabs.nl.
+ IN NS omval.tednet.nl.
+ IN NS ext.ns.whyscream.net.
+
+ IN MX 1 aspmx.l.google.com.
+ IN MX 5 alt1.aspmx.l.google.com.
+ IN MX 5 alt2.aspmx.l.google.com.
+ IN MX 10 aspmx2.googlemail.com.
+ IN MX 10 aspmx3.googlemail.com.
+
+ IN HINFO "Intel x64" "Linux"
+ IN A 139.162.196.78
+ IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
+
+a IN A 139.162.196.78
+ IN AAAA 2a01:7e00::f03c:91ff:fef1:6735
+www IN CNAME a
+archive IN CNAME a
+
+; Go DNS ping back
+go.dns IN TXT "Hello!"
diff --git a/dist/CHANGES.txt b/dist/CHANGES.txt
new file mode 100644
index 000000000..34b69d7d5
--- /dev/null
+++ b/dist/CHANGES.txt
@@ -0,0 +1,190 @@
+CHANGES
+
+0.8.2 (February 25, 2016)
+- On-demand TLS can obtain certificates during handshakes
+- Built with Go 1.6
+- Process log (-log) is rotated when it gets large
+- Managed certificates get renewed 30 days early instead of just 14
+- fastcgi: Allow scheme prefix before address
+- markdown: Support for definition lists
+- proxy: Allow proxy to insecure HTTPS backends
+- proxy: Support proxy to unix socket
+- rewrite: Status code can be 2xx or 4xx
+- templates: New .Markdown action to interpret included file as Markdown
+- templates: .Truncate now truncates from end of string when length is negative
+- tls: Set hard limit for certificates obtained with on-demand TLS
+- tls: Load certificates from directory
+- tls: Add SHA384 cipher suites
+- Multiple bug fixes and internal changes
+
+
+0.8.1 (January 12, 2016)
+- Improved OCSP stapling
+- Better graceful reload when new hosts need certificates from Let's Encrypt
+- Current pidfile is now deleted when Caddy exits
+- browse: New default template
+- gzip: Added min_length setting
+- import: Support for glob patterns (*) to import multiple files
+- rewrite: New complex rules with conditions, regex captures, and status code
+- tls: Removed DES ciphers from default cipher suite list
+- tls: All supported certificates are OCSP-stapled
+- tls: Allow custom configuration without specifying certificate and key
+- tls: No longer allow HTTPS over port 80
+- Dozens of bug fixes, improvements, and more tests across the board
+
+
+0.8.0 (December 4, 2015)
+- HTTPS by default via Let's Encrypt (certs & keys are fully managed)
+- Graceful restarts (on POSIX-compliant systems)
+- Major internal refactoring to allow use of Caddy as library
+- New directive 'mime' to customize Content-Type based on file extension
+- New -accept flag to accept Let's Encrypt SA without prompt
+- New -email flag to customize default email used for ACME transactions
+- New -ca flag to customize ACME CA server URL
+- New -revoke flag to revoke a certificate
+- New -log flag to enable process log
+- New -pidfile flag to enable writing pidfile
+- New -grace flag to customize the graceful shutdown timeout
+- New support for SIGHUP, SIGTERM, and SIGQUIT signals
+- browse: Render filenames with multiple whitespace properly
+- core: Use environment variables in Caddyfile
+- markdown: Include Last-Modified header in response
+- markdown: Render tables, strikethrough, and fenced code blocks
+- proxy: Ability to exclude/ignore paths from proxying
+- startup, shutdown: Better Windows support
+- templates: Bug fix for .Host when port is absent
+- templates: Include Last-Modified header in response
+- templates: Support for custom delimiters
+- tls: For non-local hosts, default port is now 443 unless specified
+- tls: Force-disable HTTPS
+- tls: Specify Let's Encrypt email address
+- Many, many more tests and numerous bug fixes and improvements
+
+
+0.7.6 (September 28, 2015)
+- Pass in simple Caddyfile as command line arguments
+- basicauth: Support for legacy htpasswd files
+- browse: JSON response with file listing
+- core: Caddyfile as command line argument
+- errors: Can write full stack trace to HTTP response for debugging
+- errors, log: Roll log files after certain size or age
+- proxy: Fix for 32-bit architectures
+- rewrite: Better compatibility with fastcgi and PHP apps
+- templates: Added .StripExt and .StripHTML methods
+- Internal improvements and minor bug fixes
+
+
+0.7.5 (August 5, 2015)
+- core: All listeners bind to 0.0.0.0 unless 'bind' directive is used
+- fastcgi: Set HTTPS env variable if connection is secure
+- log: Output to system log (except Windows)
+- markdown: Added dev command to disable caching during development
+- markdown: Fixed error reporting during initial site generation
+- markdown: Fixed crash if path does not exist when server starts
+- markdown: Fixed site generation and link indexing when files change
+- templates: Added .NowDate for use in date-related functions
+- Several bug fixes related to startup and shutdown functions
+
+
+0.7.4 (July 30, 2015)
+- browse: Sorting preference persisted in cookie
+- browse: Added index.txt and default.txt to list of default files
+- browse: Template files may now use Caddy template actions
+- markdown: Template files may now use Caddy template actions
+- markdown: Several bug fixes, especially for large and empty Markdown files
+- markdown: Generate index pages to link to markdown pages (sitegen only)
+- markdown: Flatten structure of front matter, changed template variables
+- redir: Can use variables (placeholders) like log formats can
+- redir: Catch-all redirects no longer preserve path; use {uri} instead
+- redir: Syntax supports redirect tables by opening a block
+- templates: Renamed .Date to .Now and added .Truncate, .Replace actions
+- Other minor internal improvements and more tests
+
+
+0.7.3 (July 15, 2015)
+- errors: Error log now shows timestamp with each entry
+- gzip: Fixed; Default filtering is by extension; removed MIME type filter
+- import: Fixed; works inside and outside server blocks
+- redir: Query string preserved on catch-all redirects
+- templates: Proper 403 or 404 errors for restricted or missing files
+
+
+0.7.2 (July 1, 2015)
+- Custom builds through caddyserver.com - extend Caddy by writing addons
+- browse: Sort by clicking column heading or using query string
+- core: Serving hostname that doesn't resolve issues warning then listens on 0.0.0.0
+- errors: Missing error page during parse time is warning, not error
+- ext: Extension only appended if request path does not end in /
+- fastcgi: Fix for backend responding without status text
+- fastcgi: Fix PATH_TRANSLATED when PATH_INFO is empty (RFC 3875)
+- git: Removed from core (available as add-on)
+- gzip: Enable by file path and/or extension
+- gzip: Customize compression level
+- log: Fix for missing status in log entry when error unhandled
+- proxy: Strip prefix from path for proxy to path
+- redir: Meta tag redirects
+- templates: Support for nested includes
+- Internal improvements and more tests
+
+
+0.7.1 (June 2, 2015)
+- basicauth: Patched timing vulnerability
+- proxy: Support for WebSocket backends
+- tls: Client authentication
+
+
+0.7.0 (May 25, 2015)
+- New directive 'internal' to protect resources with X-Accel-Redirect
+- New -version flag to show program name and version
+- core: Fixed escaped backslash characters inside quoted strings
+- core: Fixed parsing Caddyfile for IPv6 addresses missing ports
+- core: A notice is shown when non-local address resolves to loopback interface
+- core: Warns if file descriptor limit is too low for production site (Mac/Linux)
+- fastcgi: Support for Unix sockets
+- git: Fixed issue that prevented pulling at designated interval
+- header: Remove a header field by prefixing field name with "-"
+- markdown: Simple static site generation
+- markdown: Support for metadata ("front matter") at beginning of files
+- rewrite: Experimental support for regular expressions
+- tls: Customize cipher suites and protocols
+- tls: Removed RC4 ciphers
+- Other internal improvements that are not user-facing (more tests, etc.)
+
+
+0.6.0 (May 7, 2015)
+- New directive 'git' to automatically pull changes
+- New directive 'bind' to override host server binds to
+- New -root flag to specify root path to default site
+- Ability to receive config data piped through stdin
+- core: Warning if root directory doesn't exist at startup
+- core: Entire process dies if any server fails to start
+- gzip: Fixed Content-Length value when proxying requests
+- errors: Error log now includes file and line number of panics
+- fastcgi: Pass custom environment variables
+- fastcgi: Support for HEAD, OPTIONS, PUT, PATCH, and DELETE methods
+- fastcgi: Fixed SERVER_SOFTWARE variables
+- markdown: Support for index files when URL points to a directory
+- proxy: Load balancing with multiple backends, health checks, failovers, and multiple policies
+- proxy: Add custom headers
+- startup/shutdown: Run command in background with '&' at end
+- templates: Added .tpl and .tmpl as default extensions
+- templates: Support for index files when URL points to a directory
+- templates: Changed .RemoteAddr to .IP and stripped out remote port
+- tls: TLS disabled (with warning) for servers that are explicitly http://
+- websocket: Fixed SERVER_SOFTWARE and GATEWAY_INTERFACE variables
+- Many internal improvements
+
+
+0.5.1 (April 30, 2015)
+- Default host is now 0.0.0.0 (wildcard)
+- New -host and -port flags to override default host and port
+- core: Support for binding to 0.0.0.0
+- core: Graceful error handling during heavy load; proper error responses
+- errors: Fixed file path handling
+- errors: Fixed panic due to nil log file
+- fastcgi: Support for index files
+- fastcgi: Fix for handling errors that come from responder
+
+
+0.5.0 (April 28, 2015)
+- Initial release
diff --git a/dist/LICENSES.txt b/dist/LICENSES.txt
new file mode 100644
index 000000000..c6ca2e2b0
--- /dev/null
+++ b/dist/LICENSES.txt
@@ -0,0 +1,539 @@
+The enclosed software makes use of third-party libraries either in full
+or in part, original or modified. This file is part of your download so
+as to be in full compliance with the licenses of all bundled property.
+
+
+
+###
+### github.com/mholt/caddy
+###
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+
+
+
+
+
+
+
+
+###
+### Go standard library and http2
+###
+
+
+Copyright (c) 2012 The Go Authors. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+
+ * Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above
+copyright notice, this list of conditions and the following disclaimer
+in the documentation and/or other materials provided with the
+distribution.
+ * Neither the name of Google Inc. nor the names of its
+contributors may be used to endorse or promote products derived from
+this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+
+
+
+
+
+###
+### github.com/russross/blackfriday
+###
+
+
+Blackfriday is distributed under the Simplified BSD License:
+
+> Copyright © 2011 Russ Ross
+> All rights reserved.
+>
+> Redistribution and use in source and binary forms, with or without
+> modification, are permitted provided that the following conditions
+> are met:
+>
+> 1. Redistributions of source code must retain the above copyright
+> notice, this list of conditions and the following disclaimer.
+>
+> 2. Redistributions in binary form must reproduce the above
+> copyright notice, this list of conditions and the following
+> disclaimer in the documentation and/or other materials provided with
+> the distribution.
+>
+> THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+> "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+> LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+> FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+> COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+> INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+> BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+> LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+> CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+> LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+> ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+> POSSIBILITY OF SUCH DAMAGE.
+
+
+
+
+
+
+
+
+###
+### github.com/dustin/go-humanize
+###
+
+
+Copyright (c) 2005-2008 Dustin Sallings <dustin@spy.net>
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+<http://www.opensource.org/licenses/mit-license.php>
+
+
+
+
+
+
+
+
+###
+### github.com/flynn/go-shlex
+###
+
+Apache 2.0 license as found in this file
+
+
+
+
+
+
+###
+### github.com/go-yaml/yaml
+###
+
+
+Copyright (c) 2011-2014 - Canonical Inc.
+
+This software is licensed under the LGPLv3, included below.
+
+As a special exception to the GNU Lesser General Public License version 3
+("LGPL3"), the copyright holders of this Library give you permission to
+convey to a third party a Combined Work that links statically or dynamically
+to this Library without providing any Minimal Corresponding Source or
+Minimal Application Code as set out in 4d or providing the installation
+information set out in section 4e, provided that you comply with the other
+provisions of LGPL3 and provided that you meet, for the Application the
+terms and conditions of the license(s) which apply to the Application.
+
+Except as stated in this special exception, the provisions of LGPL3 will
+continue to comply in full to this Library. If you modify this Library, you
+may apply this exception to your version of this Library, but you are not
+obliged to do so. If you do not wish to do so, delete this exception
+statement from your version. This exception does not (and cannot) modify any
+license terms which apply to the Application, with which you must still
+comply.
+
+
+ GNU LESSER GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+
+ This version of the GNU Lesser General Public License incorporates
+the terms and conditions of version 3 of the GNU General Public
+License, supplemented by the additional permissions listed below.
+
+ 0. Additional Definitions.
+
+ As used herein, "this License" refers to version 3 of the GNU Lesser
+General Public License, and the "GNU GPL" refers to version 3 of the GNU
+General Public License.
+
+ "The Library" refers to a covered work governed by this License,
+other than an Application or a Combined Work as defined below.
+
+ An "Application" is any work that makes use of an interface provided
+by the Library, but which is not otherwise based on the Library.
+Defining a subclass of a class defined by the Library is deemed a mode
+of using an interface provided by the Library.
+
+ A "Combined Work" is a work produced by combining or linking an
+Application with the Library. The particular version of the Library
+with which the Combined Work was made is also called the "Linked
+Version".
+
+ The "Minimal Corresponding Source" for a Combined Work means the
+Corresponding Source for the Combined Work, excluding any source code
+for portions of the Combined Work that, considered in isolation, are
+based on the Application, and not on the Linked Version.
+
+ The "Corresponding Application Code" for a Combined Work means the
+object code and/or source code for the Application, including any data
+and utility programs needed for reproducing the Combined Work from the
+Application, but excluding the System Libraries of the Combined Work.
+
+ 1. Exception to Section 3 of the GNU GPL.
+
+ You may convey a covered work under sections 3 and 4 of this License
+without being bound by section 3 of the GNU GPL.
+
+ 2. Conveying Modified Versions.
+
+ If you modify a copy of the Library, and, in your modifications, a
+facility refers to a function or data to be supplied by an Application
+that uses the facility (other than as an argument passed when the
+facility is invoked), then you may convey a copy of the modified
+version:
+
+ a) under this License, provided that you make a good faith effort to
+ ensure that, in the event an Application does not supply the
+ function or data, the facility still operates, and performs
+ whatever part of its purpose remains meaningful, or
+
+ b) under the GNU GPL, with none of the additional permissions of
+ this License applicable to that copy.
+
+ 3. Object Code Incorporating Material from Library Header Files.
+
+ The object code form of an Application may incorporate material from
+a header file that is part of the Library. You may convey such object
+code under terms of your choice, provided that, if the incorporated
+material is not limited to numerical parameters, data structure
+layouts and accessors, or small macros, inline functions and templates
+(ten or fewer lines in length), you do both of the following:
+
+ a) Give prominent notice with each copy of the object code that the
+ Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the object code with a copy of the GNU GPL and this license
+ document.
+
+ 4. Combined Works.
+
+ You may convey a Combined Work under terms of your choice that,
+taken together, effectively do not restrict modification of the
+portions of the Library contained in the Combined Work and reverse
+engineering for debugging such modifications, if you also do each of
+the following:
+
+ a) Give prominent notice with each copy of the Combined Work that
+ the Library is used in it and that the Library and its use are
+ covered by this License.
+
+ b) Accompany the Combined Work with a copy of the GNU GPL and this license
+ document.
+
+ c) For a Combined Work that displays copyright notices during
+ execution, include the copyright notice for the Library among
+ these notices, as well as a reference directing the user to the
+ copies of the GNU GPL and this license document.
+
+ d) Do one of the following:
+
+ 0) Convey the Minimal Corresponding Source under the terms of this
+ License, and the Corresponding Application Code in a form
+ suitable for, and under terms that permit, the user to
+ recombine or relink the Application with a modified version of
+ the Linked Version to produce a modified Combined Work, in the
+ manner specified by section 6 of the GNU GPL for conveying
+ Corresponding Source.
+
+ 1) Use a suitable shared library mechanism for linking with the
+ Library. A suitable mechanism is one that (a) uses at run time
+ a copy of the Library already present on the user's computer
+ system, and (b) will operate properly with a modified version
+ of the Library that is interface-compatible with the Linked
+ Version.
+
+ e) Provide Installation Information, but only if you would otherwise
+ be required to provide such information under section 6 of the
+ GNU GPL, and only to the extent that such information is
+ necessary to install and execute a modified version of the
+ Combined Work produced by recombining or relinking the
+ Application with a modified version of the Linked Version. (If
+ you use option 4d0, the Installation Information must accompany
+ the Minimal Corresponding Source and Corresponding Application
+ Code. If you use option 4d1, you must provide the Installation
+ Information in the manner specified by section 6 of the GNU GPL
+ for conveying Corresponding Source.)
+
+ 5. Combined Libraries.
+
+ You may place library facilities that are a work based on the
+Library side by side in a single library together with other library
+facilities that are not Applications and are not covered by this
+License, and convey such a combined library under terms of your
+choice, if you do both of the following:
+
+ a) Accompany the combined library with a copy of the same work based
+ on the Library, uncombined with any other library facilities,
+ conveyed under the terms of this License.
+
+ b) Give prominent notice with the combined library that part of it
+ is a work based on the Library, and explaining where to find the
+ accompanying uncombined form of the same work.
+
+ 6. Revised Versions of the GNU Lesser General Public License.
+
+ The Free Software Foundation may publish revised and/or new versions
+of the GNU Lesser General Public License from time to time. Such new
+versions will be similar in spirit to the present version, but may
+differ in detail to address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Library as you received it specifies that a certain numbered version
+of the GNU Lesser General Public License "or any later version"
+applies to it, you have the option of following the terms and
+conditions either of that published version or of any later version
+published by the Free Software Foundation. If the Library as you
+received it does not specify a version number of the GNU Lesser
+General Public License, you may choose any version of the GNU Lesser
+General Public License ever published by the Free Software Foundation.
+
+ If the Library as you received it specifies that a proxy can decide
+whether future versions of the GNU Lesser General Public License shall
+apply, that proxy's public statement of acceptance of any version is
+permanent authorization for you to choose that version for the
+Library.
diff --git a/dist/README.txt b/dist/README.txt
new file mode 100644
index 000000000..e2ec8a24b
--- /dev/null
+++ b/dist/README.txt
@@ -0,0 +1,30 @@
+CADDY 0.8.2
+
+Website
+ https://caddyserver.com
+
+Twitter
+ @caddyserver
+
+Source Code
+ https://github.com/mholt/caddy
+ https://github.com/caddyserver
+
+
+For instructions on using Caddy, please see the user guide on the website.
+For a list of what's new in this version, see CHANGES.txt.
+
+Please consider donating to the project if you think it is helpful,
+especially if your company is using Caddy. There are also sponsorship
+opportunities available!
+
+If you have a question, bug report, or would like to contribute, please open an
+issue or submit a pull request on GitHub. Your contributions do not go unnoticed!
+
+For a good time, follow @mholt6 on Twitter.
+
+And thanks - you're awesome!
+
+
+---
+(c) 2015 - 2016 Matthew Holt
diff --git a/dist/automate.sh b/dist/automate.sh
new file mode 100755
index 000000000..3cc8d41b2
--- /dev/null
+++ b/dist/automate.sh
@@ -0,0 +1,56 @@
#!/usr/bin/env bash
#
# automate.sh -- builds and packages release archives.
#
# Cross-compiles the package with gox, then bundles each binary with
# CHANGES.txt, LICENSES.txt, and README.txt into a .zip (Windows/macOS)
# or .tar.gz (Linux/BSD) under dist/release.
#
# Requires: go, gox, zip, tar; GOPATH must be set.
set -e
set -o pipefail
shopt -s nullglob # if no files match glob, assume empty list instead of string literal


## PACKAGE TO BUILD
Package=github.com/mholt/caddy


## PATHS TO USE
# Quote all expansions below: GOPATH (and thus these paths) may contain
# spaces, and unquoted expansion would word-split them.
DistDir="$GOPATH/src/$Package/dist"
BuildDir="$DistDir/builds"
ReleaseDir="$DistDir/release"


## BEGIN

# Compile binaries
mkdir -p "$BuildDir"
cd "$BuildDir"
rm -f caddy*
gox "$Package"

# Zip them up with release notes and stuff
mkdir -p "$ReleaseDir"
cd "$ReleaseDir"
rm -f caddy*
for f in "$BuildDir"/*
do
    # Name .zip file same as binary, but strip .exe from end
    zipname=$(basename "${f%.exe}")
    if [[ $f == *"linux"* ]] || [[ $f == *"bsd"* ]]; then
        zipname=${zipname}.tar.gz
    else
        zipname=${zipname}.zip
    fi

    # Binary inside the zip file is simply the project name
    binbase=$(basename "$Package")
    if [[ $f == *.exe ]]; then
        binbase=$binbase.exe
    fi
    bin="$BuildDir/$binbase"
    mv "$f" "$bin"

    # Compress distributable
    if [[ $zipname == *.zip ]]; then
        zip -j "$zipname" "$bin" "$DistDir/CHANGES.txt" "$DistDir/LICENSES.txt" "$DistDir/README.txt"
    else
        tar -cvzf "$zipname" -C "$BuildDir" "$binbase" -C "$DistDir" CHANGES.txt LICENSES.txt README.txt
    fi

    # Put binary filename back to original
    mv "$bin" "$f"
done
diff --git a/main.go b/main.go
new file mode 100644
index 000000000..d35b0a5ee
--- /dev/null
+++ b/main.go
@@ -0,0 +1,232 @@
+package main
+
+import (
+ "errors"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "os"
+ "runtime"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/miekg/coredns/core"
+ "github.com/miekg/coredns/core/https"
+ "github.com/xenolf/lego/acme"
+ "gopkg.in/natefinch/lumberjack.v2"
+)
+
+// init installs signal handling, infers version information, and
+// registers every command-line flag before main calls flag.Parse.
+// Several flags store their values directly in variables owned by the
+// core and core/https packages (e.g. -agree, -ca, -email, -host, -port).
+func init() {
+	core.TrapSignals() // install OS signal handlers
+	setVersion()       // derive appVersion/devBuild from -ldflags data
+	flag.BoolVar(&https.Agreed, "agree", false, "Agree to Let's Encrypt Subscriber Agreement")
+	flag.StringVar(&https.CAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "Certificate authority ACME server")
+	flag.StringVar(&conf, "conf", "", "Configuration file to use (default="+core.DefaultConfigFile+")")
+	flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
+	flag.StringVar(&https.DefaultEmail, "email", "", "Default Let's Encrypt account email address")
+	flag.DurationVar(&core.GracefulTimeout, "grace", 5*time.Second, "Maximum duration of graceful shutdown")
+	flag.StringVar(&core.Host, "host", core.DefaultHost, "Default host")
+	flag.StringVar(&logfile, "log", "", "Process log file")
+	flag.StringVar(&core.PidFile, "pidfile", "", "Path to write pid file")
+	flag.StringVar(&core.Port, "port", core.DefaultPort, "Default port")
+	flag.BoolVar(&core.Quiet, "quiet", false, "Quiet mode (no initialization output)")
+	flag.StringVar(&revoke, "revoke", "", "Hostname for which to revoke the certificate")
+	flag.StringVar(&core.Root, "root", core.DefaultRoot, "Root path to default site")
+	flag.BoolVar(&version, "version", false, "Show version")
+}
+
+// main drives startup: parse flags, configure the process log, handle
+// the one-shot -revoke and -version modes (which exit immediately),
+// apply the CPU cap, load the configuration input, start the server,
+// and block until shutdown.
+func main() {
+	flag.Parse() // called here in main() to allow other packages to set flags in their inits
+
+	core.AppName = appName
+	core.AppVersion = appVersion
+	acme.UserAgent = appName + "/" + appVersion
+
+	// set up process log before anything bad happens
+	switch logfile {
+	case "stdout":
+		log.SetOutput(os.Stdout)
+	case "stderr":
+		log.SetOutput(os.Stderr)
+	case "":
+		log.SetOutput(ioutil.Discard) // no -log flag given: discard the process log
+	default:
+		// write to a rotating log file (rotation handled by lumberjack)
+		log.SetOutput(&lumberjack.Logger{
+			Filename:   logfile,
+			MaxSize:    100,
+			MaxAge:     14,
+			MaxBackups: 10,
+		})
+	}
+
+	// -revoke mode: revoke the certificate for the given host, then exit
+	if revoke != "" {
+		err := https.Revoke(revoke)
+		if err != nil {
+			log.Fatal(err)
+		}
+		fmt.Printf("Revoked certificate for %s\n", revoke)
+		os.Exit(0)
+	}
+	// -version mode: print version (dev builds also show dirty-tree info), then exit
+	if version {
+		fmt.Printf("%s %s\n", appName, appVersion)
+		if devBuild && gitShortStat != "" {
+			fmt.Printf("%s\n%s\n", gitShortStat, gitFilesModified)
+		}
+		os.Exit(0)
+	}
+
+	// Set CPU cap
+	err := setCPU(cpu)
+	if err != nil {
+		mustLogFatal(err)
+	}
+
+	// Get Corefile input
+	caddyfile, err := core.LoadCaddyfile(loadCaddyfile)
+	if err != nil {
+		mustLogFatal(err)
+	}
+
+	// Start your engines
+	err = core.Start(caddyfile)
+	if err != nil {
+		mustLogFatal(err)
+	}
+
+	// Twiddle your thumbs
+	core.Wait()
+}
+
+// mustLogFatal just wraps log.Fatal() in a way that ensures the
+// output is always printed to stderr so the user can see it
+// if the user is still there, even if the process log was not
+// enabled. If this process is a restart, however, and the user
+// might not be there anymore, this just logs to the process log
+// and exits.
+func mustLogFatal(args ...interface{}) {
+ if !core.IsRestart() {
+ log.SetOutput(os.Stderr)
+ }
+ log.Fatal(args...)
+}
+
+// loadCaddyfile selects the configuration source for core.LoadCaddyfile,
+// trying in priority order:
+//  1. the -conf flag (the special value "stdin" reads piped input),
+//  2. non-flag command-line arguments, joined into a config served at
+//     core.Host:core.Port,
+//  3. the default config file in the current working directory,
+//  4. core.DefaultInput() if that file does not exist.
+func loadCaddyfile() (core.Input, error) {
+	// Try -conf flag
+	if conf != "" {
+		if conf == "stdin" {
+			return core.CaddyfileFromPipe(os.Stdin)
+		}
+
+		contents, err := ioutil.ReadFile(conf)
+		if err != nil {
+			return nil, err
+		}
+
+		return core.CaddyfileInput{
+			Contents: contents,
+			Filepath: conf,
+			RealFile: true, // input is backed by an actual file on disk
+		}, nil
+	}
+
+	// command line args
+	if flag.NArg() > 0 {
+		confBody := core.Host + ":" + core.Port + "\n" + strings.Join(flag.Args(), "\n")
+		return core.CaddyfileInput{
+			Contents: []byte(confBody),
+			Filepath: "args",
+		}, nil
+	}
+
+	// Caddyfile in cwd
+	contents, err := ioutil.ReadFile(core.DefaultConfigFile)
+	if err != nil {
+		if os.IsNotExist(err) {
+			// No config file present: fall back to the built-in default.
+			return core.DefaultInput(), nil
+		}
+		return nil, err
+	}
+	return core.CaddyfileInput{
+		Contents: contents,
+		Filepath: core.DefaultConfigFile,
+		RealFile: true,
+	}, nil
+}
+
// setCPU parses string cpu and sets GOMAXPROCS according to its value.
// It accepts either an absolute number of CPUs (e.g. "3") or a
// percentage of the available CPUs (e.g. "50%"). The result is clamped
// to the range [1, runtime.NumCPU()].
//
// It returns an error (leaving GOMAXPROCS untouched) if cpu is not a
// positive number or a percentage in the range 1-100%.
func setCPU(cpu string) error {
	var numCPU int

	availCPU := runtime.NumCPU()

	if strings.HasSuffix(cpu, "%") {
		// Percent of available CPUs
		var percent float32
		pctStr := cpu[:len(cpu)-1]
		pctInt, err := strconv.Atoi(pctStr)
		if err != nil || pctInt < 1 || pctInt > 100 {
			return errors.New("invalid CPU value: percentage must be between 1-100")
		}
		percent = float32(pctInt) / 100
		numCPU = int(float32(availCPU) * percent)
	} else {
		// Absolute number of CPUs
		num, err := strconv.Atoi(cpu)
		if err != nil || num < 1 {
			return errors.New("invalid CPU value: provide a number or percent greater than 0")
		}
		numCPU = num
	}

	if numCPU > availCPU {
		numCPU = availCPU
	}
	if numCPU < 1 {
		// A small percentage on a machine with few CPUs truncates to 0,
		// and runtime.GOMAXPROCS(0) only queries without setting — the
		// flag would be silently ignored. Use at least one CPU.
		numCPU = 1
	}

	runtime.GOMAXPROCS(numCPU)
	return nil
}
+
+// setVersion figures out the version information based on
+// variables set by -ldflags.
+func setVersion() {
+ // A development build is one that's not at a tag or has uncommitted changes
+ devBuild = gitTag == "" || gitShortStat != ""
+
+ // Only set the appVersion if -ldflags was used
+ if gitNearestTag != "" || gitTag != "" {
+ if devBuild && gitNearestTag != "" {
+ appVersion = fmt.Sprintf("%s (+%s %s)",
+ strings.TrimPrefix(gitNearestTag, "v"), gitCommit, buildDate)
+ } else if gitTag != "" {
+ appVersion = strings.TrimPrefix(gitTag, "v")
+ }
+ }
+}
+
+// appName is the human-readable program name, used in version output
+// and as the ACME user-agent prefix.
+const appName = "Caddy"
+
+// Flags that control program flow or startup
+var (
+	conf    string // -conf: config file path ("stdin" reads piped input)
+	cpu     string // -cpu: CPU cap, a number or a percentage
+	logfile string // -log: process log ("stdout", "stderr", or a file path)
+	revoke  string // -revoke: hostname whose certificate should be revoked
+	version bool   // -version: print version information and exit
+)
+
+// Build information obtained with the help of -ldflags
+var (
+	appVersion = "(untracked dev build)" // inferred at startup
+	devBuild   = true                    // inferred at startup
+
+	buildDate        string // date -u
+	gitTag           string // git describe --exact-match HEAD 2> /dev/null
+	gitNearestTag    string // git describe --abbrev=0 --tags HEAD
+	gitCommit        string // git rev-parse HEAD
+	gitShortStat     string // git diff-index --shortstat
+	gitFilesModified string // git diff-index --name-only HEAD
+)
diff --git a/main_test.go b/main_test.go
new file mode 100644
index 000000000..01722ed60
--- /dev/null
+++ b/main_test.go
@@ -0,0 +1,75 @@
+package main
+
+import (
+ "runtime"
+ "testing"
+)
+
// TestSetCPU verifies that setCPU accepts counts and percentages, clamps
// values above the available CPUs, and leaves GOMAXPROCS untouched on
// invalid input. GOMAXPROCS is restored after every case.
func TestSetCPU(t *testing.T) {
	currentCPU := runtime.GOMAXPROCS(-1) // remember the setting so each case can restore it
	maxCPU := runtime.NumCPU()
	halfCPU := int(0.5 * float32(maxCPU))
	if halfCPU < 1 {
		halfCPU = 1
	}
	for i, test := range []struct {
		input     string
		output    int
		shouldErr bool
	}{
		{"1", 1, false},
		{"-1", currentCPU, true},
		{"0", currentCPU, true},
		{"100%", maxCPU, false},
		{"50%", halfCPU, false},
		{"110%", currentCPU, true},
		{"-10%", currentCPU, true},
		{"invalid input", currentCPU, true},
		{"invalid input%", currentCPU, true},
		{"9999", maxCPU, false}, // over available CPU
	} {
		err := setCPU(test.input)
		if test.shouldErr && err == nil {
			t.Errorf("Test %d: Expected error, but there wasn't any", i)
		}
		if !test.shouldErr && err != nil {
			t.Errorf("Test %d: Expected no error, but there was one: %v", i, err)
		}
		if actual, expected := runtime.GOMAXPROCS(-1), test.output; actual != expected {
			t.Errorf("Test %d: GOMAXPROCS was %d but expected %d", i, actual, expected)
		}
		// teardown
		runtime.GOMAXPROCS(currentCPU)
	}
}
+
// TestSetVersion checks version-string inference for untracked, tagged,
// and nearest-tag builds.
// NOTE(review): this test mutates the package-level build variables
// (gitTag, gitNearestTag, gitCommit, buildDate) and never restores them,
// so tests running afterwards observe the modified values.
func TestSetVersion(t *testing.T) {
	setVersion()
	if !devBuild {
		t.Error("Expected default to assume development build, but it didn't")
	}
	if got, want := appVersion, "(untracked dev build)"; got != want {
		t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
	}

	gitTag = "v1.1"
	setVersion()
	if devBuild {
		t.Error("Expected a stable build if gitTag is set with no changes")
	}
	if got, want := appVersion, "1.1"; got != want {
		t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
	}

	gitTag = ""
	gitNearestTag = "v1.0"
	gitCommit = "deadbeef"
	buildDate = "Fri Feb 26 06:53:17 UTC 2016"
	setVersion()
	if !devBuild {
		t.Error("Expected inferring a dev build when gitTag is empty")
	}
	if got, want := appVersion, "1.0 (+deadbeef Fri Feb 26 06:53:17 UTC 2016)"; got != want {
		t.Errorf("Expected appVersion='%s', got: '%s'", want, got)
	}
}
diff --git a/middleware/commands.go b/middleware/commands.go
new file mode 100644
index 000000000..5c241161e
--- /dev/null
+++ b/middleware/commands.go
@@ -0,0 +1,120 @@
+package middleware
+
+import (
+ "errors"
+ "runtime"
+ "unicode"
+
+ "github.com/flynn/go-shlex"
+)
+
+var runtimeGoos = runtime.GOOS
+
+// SplitCommandAndArgs takes a command string and parses it
+// shell-style into the command and its separate arguments.
+func SplitCommandAndArgs(command string) (cmd string, args []string, err error) {
+ var parts []string
+
+ if runtimeGoos == "windows" {
+ parts = parseWindowsCommand(command) // parse it Windows-style
+ } else {
+ parts, err = parseUnixCommand(command) // parse it Unix-style
+ if err != nil {
+ err = errors.New("error parsing command: " + err.Error())
+ return
+ }
+ }
+
+ if len(parts) == 0 {
+ err = errors.New("no command contained in '" + command + "'")
+ return
+ }
+
+ cmd = parts[0]
+ if len(parts) > 1 {
+ args = parts[1:]
+ }
+
+ return
+}
+
// parseUnixCommand parses a unix style command line and returns the
// command and its arguments or an error
func parseUnixCommand(cmd string) ([]string, error) {
	// Delegate entirely to go-shlex, which handles quoting and
	// backslash escapes (see TestParseUnixCommand for examples).
	return shlex.Split(cmd)
}
+
// parseWindowsCommand parses windows command lines and
// returns the command and the arguments as an array. It
// should be able to parse commonly used command lines.
// Only basic syntax is supported:
//  - spaces in double quotes are not token delimiters
//  - double quotes are escaped by either backslash or another double quote
//  - except for the above case backslashes are path separators (not special)
//
// Many sources point out that escaping quotes using backslash can be unsafe.
// Use two double quotes when possible. (Source: http://stackoverflow.com/a/31413730/2616179 )
//
// This function has to be used on Windows instead
// of the shlex package because this function treats backslash
// characters properly.
func parseWindowsCommand(cmd string) []string {
	const backslash = '\\'
	const quote = '"'

	var parts []string
	part := make([]rune, 0, len(cmd)) // current token; a rune buffer avoids O(n²) string concatenation
	var inQuotes bool
	var lastRune rune // rune preceding ch in the input (0 on the first iteration)

	for _, ch := range cmd {
		switch {
		case ch == backslash:
			// Keep it for now - we don't yet know whether it is an
			// escaping character or a path separator.
			part = append(part, ch)

		case ch == quote:
			switch lastRune {
			case backslash:
				// Drop the backslash and store the escaped quote instead.
				part = append(part[:len(part)-1], ch)
			case quote:
				// The previous quote was actually an escaping quote:
				// revert its toggle of inQuotes and keep a literal quote.
				inQuotes = !inQuotes
				part = append(part, ch)
			default:
				// A normal, non-escaped quote toggles quoting mode.
				inQuotes = !inQuotes
			}

		case unicode.IsSpace(ch) && !inQuotes:
			// Unquoted whitespace delimits tokens. (Fix: runs of
			// spaces used to leak a leading space into the next token.)
			if len(part) > 0 {
				parts = append(parts, string(part))
				part = part[:0]
			}

		default:
			part = append(part, ch)
		}

		lastRune = ch
	}

	if len(part) > 0 {
		parts = append(parts, string(part))
	}

	return parts
}
diff --git a/middleware/commands_test.go b/middleware/commands_test.go
new file mode 100644
index 000000000..3001e65a5
--- /dev/null
+++ b/middleware/commands_test.go
@@ -0,0 +1,291 @@
+package middleware
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+ "testing"
+)
+
// TestParseUnixCommand exercises parseUnixCommand (go-shlex splitting)
// against common quoting and escaping forms.
func TestParseUnixCommand(t *testing.T) {
	tests := []struct {
		input    string
		expected []string
	}{
		// 0 - empty command
		{
			input:    ``,
			expected: []string{},
		},
		// 1 - command without arguments
		{
			input:    `command`,
			expected: []string{`command`},
		},
		// 2 - command with single argument
		{
			input:    `command arg1`,
			expected: []string{`command`, `arg1`},
		},
		// 3 - command with multiple arguments
		{
			input:    `command arg1 arg2`,
			expected: []string{`command`, `arg1`, `arg2`},
		},
		// 4 - command with single argument with space character - in quotes
		{
			input:    `command "arg1 arg1"`,
			expected: []string{`command`, `arg1 arg1`},
		},
		// 5 - command with multiple spaces and tab character
		{
			input:    "command arg1 arg2\targ3",
			expected: []string{`command`, `arg1`, `arg2`, `arg3`},
		},
		// 6 - command with single argument with space character - escaped with backslash
		{
			input:    `command arg1\ arg2`,
			expected: []string{`command`, `arg1 arg2`},
		},
		// 7 - single quotes should escape special chars
		{
			input:    `command 'arg1\ arg2'`,
			expected: []string{`command`, `arg1\ arg2`},
		},
	}

	for i, test := range tests {
		errorPrefix := fmt.Sprintf("Test [%d]: ", i)
		errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
		actual, _ := parseUnixCommand(test.input)
		if len(actual) != len(test.expected) {
			t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
			continue
		}
		for j := 0; j < len(actual); j++ {
			if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
				t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
			}
		}
	}
}
+
// TestParseWindowsCommand covers Windows-style quoting: doubled quotes,
// backslash-escaped quotes, backslashes as path separators, UNC paths,
// and unterminated quotes.
func TestParseWindowsCommand(t *testing.T) {
	tests := []struct {
		input    string
		expected []string
	}{
		{ // 0 - empty command - do not fail
			input:    ``,
			expected: []string{},
		},
		{ // 1 - cmd without args
			input:    `cmd`,
			expected: []string{`cmd`},
		},
		{ // 2 - multiple args
			input:    `cmd arg1 arg2`,
			expected: []string{`cmd`, `arg1`, `arg2`},
		},
		{ // 3 - multiple args with space
			input:    `cmd "combined arg" arg2`,
			expected: []string{`cmd`, `combined arg`, `arg2`},
		},
		{ // 4 - path without spaces
			input:    `mkdir C:\Windows\foo\bar`,
			expected: []string{`mkdir`, `C:\Windows\foo\bar`},
		},
		{ // 5 - command with space in quotes
			input:    `"command here"`,
			expected: []string{`command here`},
		},
		{ // 6 - argument with escaped quotes (two quotes)
			input:    `cmd ""arg""`,
			expected: []string{`cmd`, `"arg"`},
		},
		{ // 7 - argument with escaped quotes (backslash)
			input:    `cmd \"arg\"`,
			expected: []string{`cmd`, `"arg"`},
		},
		{ // 8 - two quotes (escaped) inside an inQuote element
			input:    `cmd "a ""quoted value"`,
			expected: []string{`cmd`, `a "quoted value`},
		},
		// TODO - see how many quotes are displayed if we use "", """, """""""
		{ // 9 - two quotes outside an inQuote element
			input:    `cmd a ""quoted value`,
			expected: []string{`cmd`, `a`, `"quoted`, `value`},
		},
		{ // 10 - path with space in quotes
			input:    `mkdir "C:\directory name\foobar"`,
			expected: []string{`mkdir`, `C:\directory name\foobar`},
		},
		{ // 11 - space without quotes
			input:    `mkdir C:\ space`,
			expected: []string{`mkdir`, `C:\`, `space`},
		},
		{ // 12 - space in quotes
			input:    `mkdir "C:\ space"`,
			expected: []string{`mkdir`, `C:\ space`},
		},
		{ // 13 - UNC
			input:    `mkdir \\?\C:\Users`,
			expected: []string{`mkdir`, `\\?\C:\Users`},
		},
		{ // 14 - UNC with space
			input:    `mkdir "\\?\C:\Program Files"`,
			expected: []string{`mkdir`, `\\?\C:\Program Files`},
		},

		{ // 15 - unclosed quotes - treat as if the path ends with quote
			input:    `mkdir "c:\Program files`,
			expected: []string{`mkdir`, `c:\Program files`},
		},
		{ // 16 - quotes used inside the argument
			input:    `mkdir "c:\P"rogra"m f"iles`,
			expected: []string{`mkdir`, `c:\Program files`},
		},
	}

	for i, test := range tests {
		errorPrefix := fmt.Sprintf("Test [%d]: ", i)
		errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)

		actual := parseWindowsCommand(test.input)
		if len(actual) != len(test.expected) {
			t.Errorf(errorPrefix+"Expected %d parts, got %d: %#v."+errorSuffix, len(test.expected), len(actual), actual)
			continue
		}
		for j := 0; j < len(actual); j++ {
			if expectedPart, actualPart := test.expected[j], actual[j]; expectedPart != actualPart {
				t.Errorf(errorPrefix+"Expected: %v Actual: %v (index %d)."+errorSuffix, expectedPart, actualPart, j)
			}
		}
	}
}
+
// TestSplitCommandAndArgs checks command/argument splitting and the two
// error paths (parse failure, empty command). It forces the unix parser
// by overriding runtimeGoos and restores it on exit.
func TestSplitCommandAndArgs(t *testing.T) {

	// force linux parsing. It's more robust and covers error cases
	runtimeGoos = "linux"
	defer func() {
		runtimeGoos = runtime.GOOS
	}()

	var parseErrorContent = "error parsing command:"
	var noCommandErrContent = "no command contained in"

	tests := []struct {
		input              string
		expectedCommand    string
		expectedArgs       []string
		expectedErrContent string
	}{
		// 0 - empty command
		{
			input:              ``,
			expectedCommand:    ``,
			expectedArgs:       nil,
			expectedErrContent: noCommandErrContent,
		},
		// 1 - command without arguments
		{
			input:              `command`,
			expectedCommand:    `command`,
			expectedArgs:       nil,
			expectedErrContent: ``,
		},
		// 2 - command with single argument
		{
			input:              `command arg1`,
			expectedCommand:    `command`,
			expectedArgs:       []string{`arg1`},
			expectedErrContent: ``,
		},
		// 3 - command with multiple arguments
		{
			input:              `command arg1 arg2`,
			expectedCommand:    `command`,
			expectedArgs:       []string{`arg1`, `arg2`},
			expectedErrContent: ``,
		},
		// 4 - command with unclosed quotes
		{
			input:              `command "arg1 arg2`,
			expectedCommand:    "",
			expectedArgs:       nil,
			expectedErrContent: parseErrorContent,
		},
		// 5 - command with unclosed quotes
		{
			input:              `command 'arg1 arg2"`,
			expectedCommand:    "",
			expectedArgs:       nil,
			expectedErrContent: parseErrorContent,
		},
	}

	for i, test := range tests {
		errorPrefix := fmt.Sprintf("Test [%d]: ", i)
		errorSuffix := fmt.Sprintf(" Command to parse: [%s]", test.input)
		actualCommand, actualArgs, actualErr := SplitCommandAndArgs(test.input)

		// test if error matches expectation
		if test.expectedErrContent != "" {
			if actualErr == nil {
				t.Errorf(errorPrefix+"Expected error with content [%s], found no error."+errorSuffix, test.expectedErrContent)
			} else if !strings.Contains(actualErr.Error(), test.expectedErrContent) {
				t.Errorf(errorPrefix+"Expected error with content [%s], found [%v]."+errorSuffix, test.expectedErrContent, actualErr)
			}
		} else if actualErr != nil {
			t.Errorf(errorPrefix+"Expected no error, found [%v]."+errorSuffix, actualErr)
		}

		// test if command matches
		if test.expectedCommand != actualCommand {
			t.Errorf(errorPrefix+"Expected command: [%s], actual: [%s]."+errorSuffix, test.expectedCommand, actualCommand)
		}

		// test if arguments match
		if len(test.expectedArgs) != len(actualArgs) {
			t.Errorf(errorPrefix+"Wrong number of arguments! Expected [%v], actual [%v]."+errorSuffix, test.expectedArgs, actualArgs)
		} else {
			// test args only if the count matches.
			for j, actualArg := range actualArgs {
				expectedArg := test.expectedArgs[j]
				if actualArg != expectedArg {
					t.Errorf(errorPrefix+"Argument at position [%d] differ! Expected [%s], actual [%s]"+errorSuffix, j, expectedArg, actualArg)
				}
			}
		}
	}
}
+
// ExampleSplitCommandAndArgs demonstrates platform-dependent splitting.
// NOTE(review): it mutates the package-level runtimeGoos, so it must not
// run in parallel with other tests that depend on that variable.
func ExampleSplitCommandAndArgs() {
	var commandLine string
	var command string
	var args []string

	// just for the test - change GOOS and reset it at the end of the test
	runtimeGoos = "windows"
	defer func() {
		runtimeGoos = runtime.GOOS
	}()

	commandLine = `mkdir /P "C:\Program Files"`
	command, args, _ = SplitCommandAndArgs(commandLine)

	fmt.Printf("Windows: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))

	// set GOOS to linux
	runtimeGoos = "linux"

	commandLine = `mkdir -p /path/with\ space`
	command, args, _ = SplitCommandAndArgs(commandLine)

	fmt.Printf("Linux: %s: %s [%s]\n", commandLine, command, strings.Join(args, ","))

	// Output:
	// Windows: mkdir /P "C:\Program Files": mkdir [/P,C:\Program Files]
	// Linux: mkdir -p /path/with\ space: mkdir [-p,/path/with space]
}
diff --git a/middleware/context.go b/middleware/context.go
new file mode 100644
index 000000000..8868c1c03
--- /dev/null
+++ b/middleware/context.go
@@ -0,0 +1,135 @@
+package middleware
+
+import (
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
// This file contains the context and functions available for
// use in the templates.

// Context is the context with which Caddy templates are executed.
type Context struct {
	Root http.FileSystem // TODO(miek): needed
	Req  *dns.Msg        // the query being answered (used by SetReply/SetRcode below)
	W    dns.ResponseWriter // the client connection; source of the remote address
}
+
// Now returns the current timestamp in the specified format.
// The format uses Go's reference-time layout (e.g. "2006-01-02").
func (c Context) Now(format string) string {
	return time.Now().Format(format)
}
+
// NowDate returns the current date/time that can be used
// in other time functions.
func (c Context) NowDate() time.Time {
	return time.Now()
}
+
// Header gets the value of a header.
// NOTE(review): not implemented yet — always returns nil.
func (c Context) Header() *dns.RR_Header {
	// TODO(miek)
	return nil
}
+
+// IP gets the (remote) IP address of the client making the request.
+func (c Context) IP() string {
+ ip, _, err := net.SplitHostPort(c.W.RemoteAddr().String())
+ if err != nil {
+ return c.W.RemoteAddr().String()
+ }
+ return ip
+}
+
// Port gets the (remote) port of the client making the request.
func (c Context) Port() (string, error) {
	_, port, err := net.SplitHostPort(c.W.RemoteAddr().String())
	if err != nil {
		// "0" stands in for an unknown port when the address has no port part.
		return "0", err
	}
	return port, nil
}
+
+// Proto gets the protocol used as the transport. This
+// will be udp or tcp.
+func (c Context) Proto() string {
+ if _, ok := c.W.RemoteAddr().(*net.UDPAddr); ok {
+ return "udp"
+ }
+ if _, ok := c.W.RemoteAddr().(*net.TCPAddr); ok {
+ return "tcp"
+ }
+ return "udp"
+}
+
+// Family returns the family of the transport.
+// 1 for IPv4 and 2 for IPv6.
+func (c Context) Family() int {
+ var a net.IP
+ ip := c.W.RemoteAddr()
+ if i, ok := ip.(*net.UDPAddr); ok {
+ a = i.IP
+ }
+ if i, ok := ip.(*net.TCPAddr); ok {
+ a = i.IP
+ }
+
+ if a.To4() != nil {
+ return 1
+ }
+ return 2
+}
+
// Type returns the type of the question as a string (e.g. "A", "MX").
func (c Context) Type() string {
	return dns.Type(c.Req.Question[0].Qtype).String()
}
+
// QType returns the type of the question as a uint16.
func (c Context) QType() uint16 {
	return c.Req.Question[0].Qtype
}
+
// Name returns the name of the question in the request. Note
// this name will always have a closing dot and will be lower cased.
func (c Context) Name() string {
	return strings.ToLower(dns.Name(c.Req.Question[0].Name).String())
}
+
// QName returns the name of the question in the request,
// with its original casing preserved (unlike Name).
func (c Context) QName() string {
	return dns.Name(c.Req.Question[0].Name).String()
}
+
// Class returns the class of the question in the request as a string (e.g. "IN").
func (c Context) Class() string {
	return dns.Class(c.Req.Question[0].Qclass).String()
}
+
// QClass returns the class of the question in the request as a uint16.
func (c Context) QClass() uint16 {
	return c.Req.Question[0].Qclass
}
+
// More convenience types for extracting stuff from a message?
// Header?

// ErrorMessage returns an error message suitable for sending
// back to the client: a reply to c.Req carrying the given rcode.
func (c Context) ErrorMessage(rcode int) *dns.Msg {
	m := new(dns.Msg)
	m.SetRcode(c.Req, rcode)
	return m
}
+
// AnswerMessage returns a reply message (not an error) suitable for
// sending back to the client, pre-filled as a response to c.Req.
func (c Context) AnswerMessage() *dns.Msg {
	m := new(dns.Msg)
	m.SetReply(c.Req)
	return m
}
diff --git a/middleware/context_test.go b/middleware/context_test.go
new file mode 100644
index 000000000..689c47c13
--- /dev/null
+++ b/middleware/context_test.go
@@ -0,0 +1,613 @@
+package middleware
+
+import (
+ "bytes"
+ "fmt"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestInclude(t *testing.T) {
+ context := getContextOrFail(t)
+
+ inputFilename := "test_file"
+ absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
+ defer func() {
+ err := os.Remove(absInFilePath)
+ if err != nil && !os.IsNotExist(err) {
+ t.Fatalf("Failed to clean test file!")
+ }
+ }()
+
+ tests := []struct {
+ fileContent string
+ expectedContent string
+ shouldErr bool
+ expectedErrorContent string
+ }{
+ // Test 0 - all good
+ {
+ fileContent: `str1 {{ .Root }} str2`,
+ expectedContent: fmt.Sprintf("str1 %s str2", context.Root),
+ shouldErr: false,
+ expectedErrorContent: "",
+ },
+ // Test 1 - failure on template.Parse
+ {
+ fileContent: `str1 {{ .Root } str2`,
+ expectedContent: "",
+ shouldErr: true,
+ expectedErrorContent: `unexpected "}" in operand`,
+ },
+ // Test 3 - failure on template.Execute
+ {
+ fileContent: `str1 {{ .InvalidField }} str2`,
+ expectedContent: "",
+ shouldErr: true,
+ expectedErrorContent: `InvalidField is not a field of struct type middleware.Context`,
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ // WriteFile truncates the contentt
+ err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
+ if err != nil {
+ t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
+ }
+
+ content, err := context.Include(inputFilename)
+ if err != nil {
+ if !test.shouldErr {
+ t.Errorf(testPrefix+"Expected no error, found [%s]", test.expectedErrorContent, err.Error())
+ }
+ if !strings.Contains(err.Error(), test.expectedErrorContent) {
+ t.Errorf(testPrefix+"Expected error content [%s], found [%s]", test.expectedErrorContent, err.Error())
+ }
+ }
+
+ if err == nil && test.shouldErr {
+ t.Errorf(testPrefix+"Expected error [%s] but found nil. Input file was: %s", test.expectedErrorContent, inputFilename)
+ }
+
+ if content != test.expectedContent {
+ t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
+ }
+ }
+}
+
// TestIncludeNotExisting expects an error when including a missing file.
// NOTE(review): Context in context.go has no Include method and its Req
// field is *dns.Msg — this test file appears carried over from Caddy's
// HTTP context and will not compile against the DNS Context.
func TestIncludeNotExisting(t *testing.T) {
	context := getContextOrFail(t)

	_, err := context.Include("not_existing")
	if err == nil {
		t.Errorf("Expected error but found nil!")
	}
}
+
+func TestMarkdown(t *testing.T) {
+ context := getContextOrFail(t)
+
+ inputFilename := "test_file"
+ absInFilePath := filepath.Join(fmt.Sprintf("%s", context.Root), inputFilename)
+ defer func() {
+ err := os.Remove(absInFilePath)
+ if err != nil && !os.IsNotExist(err) {
+ t.Fatalf("Failed to clean test file!")
+ }
+ }()
+
+ tests := []struct {
+ fileContent string
+ expectedContent string
+ }{
+ // Test 0 - test parsing of markdown
+ {
+ fileContent: "* str1\n* str2\n",
+ expectedContent: "<ul>\n<li>str1</li>\n<li>str2</li>\n</ul>\n",
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ // WriteFile truncates the contentt
+ err := ioutil.WriteFile(absInFilePath, []byte(test.fileContent), os.ModePerm)
+ if err != nil {
+ t.Fatal(testPrefix+"Failed to create test file. Error was: %v", err)
+ }
+
+ content, _ := context.Markdown(inputFilename)
+ if content != test.expectedContent {
+ t.Errorf(testPrefix+"Expected content [%s] but found [%s]. Input file was: %s", test.expectedContent, content, inputFilename)
+ }
+ }
+}
+
+func TestCookie(t *testing.T) {
+
+ tests := []struct {
+ cookie *http.Cookie
+ cookieName string
+ expectedValue string
+ }{
+ // Test 0 - happy path
+ {
+ cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
+ cookieName: "cookieName",
+ expectedValue: "cookieValue",
+ },
+ // Test 1 - try to get a non-existing cookie
+ {
+ cookie: &http.Cookie{Name: "cookieName", Value: "cookieValue"},
+ cookieName: "notExisting",
+ expectedValue: "",
+ },
+ // Test 2 - partial name match
+ {
+ cookie: &http.Cookie{Name: "cookie", Value: "cookieValue"},
+ cookieName: "cook",
+ expectedValue: "",
+ },
+ // Test 3 - cookie with optional fields
+ {
+ cookie: &http.Cookie{Name: "cookie", Value: "cookieValue", Path: "/path", Domain: "https://localhost", Expires: (time.Now().Add(10 * time.Minute)), MaxAge: 120},
+ cookieName: "cookie",
+ expectedValue: "cookieValue",
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+
+ // reinitialize the context for each test
+ context := getContextOrFail(t)
+
+ context.Req.AddCookie(test.cookie)
+
+ actualCookieVal := context.Cookie(test.cookieName)
+
+ if actualCookieVal != test.expectedValue {
+ t.Errorf(testPrefix+"Expected cookie value [%s] but found [%s] for cookie with name %s", test.expectedValue, actualCookieVal, test.cookieName)
+ }
+ }
+}
+
+func TestCookieMultipleCookies(t *testing.T) {
+ context := getContextOrFail(t)
+
+ cookieNameBase, cookieValueBase := "cookieName", "cookieValue"
+
+ // make sure that there's no state and multiple requests for different cookies return the correct result
+ for i := 0; i < 10; i++ {
+ context.Req.AddCookie(&http.Cookie{Name: fmt.Sprintf("%s%d", cookieNameBase, i), Value: fmt.Sprintf("%s%d", cookieValueBase, i)})
+ }
+
+ for i := 0; i < 10; i++ {
+ expectedCookieVal := fmt.Sprintf("%s%d", cookieValueBase, i)
+ actualCookieVal := context.Cookie(fmt.Sprintf("%s%d", cookieNameBase, i))
+ if actualCookieVal != expectedCookieVal {
+ t.Fatalf("Expected cookie value %s, found %s", expectedCookieVal, actualCookieVal)
+ }
+ }
+}
+
// TestHeader checks lookup of a request header by name.
// NOTE(review): Context.Header in context.go takes no arguments and
// Req is *dns.Msg (no Header field) — stale HTTP-context test.
func TestHeader(t *testing.T) {
	context := getContextOrFail(t)

	headerKey, headerVal := "Header1", "HeaderVal1"
	context.Req.Header.Add(headerKey, headerVal)

	actualHeaderVal := context.Header(headerKey)
	if actualHeaderVal != headerVal {
		t.Errorf("Expected header %s, found %s", headerVal, actualHeaderVal)
	}

	missingHeaderVal := context.Header("not-existing")
	if missingHeaderVal != "" {
		t.Errorf("Expected empty header value, found %s", missingHeaderVal)
	}
}
+
// TestIP checks IP extraction from various remote-address formats.
// NOTE(review): the DNS Context derives the IP from W.RemoteAddr(), not
// from a Req.RemoteAddr field — stale HTTP-context test.
func TestIP(t *testing.T) {
	context := getContextOrFail(t)

	tests := []struct {
		inputRemoteAddr string
		expectedIP      string
	}{
		// Test 0 - ipv4 with port
		{"1.1.1.1:1111", "1.1.1.1"},
		// Test 1 - ipv4 without port
		{"1.1.1.1", "1.1.1.1"},
		// Test 2 - ipv6 with port
		{"[::1]:11", "::1"},
		// Test 3 - ipv6 without port and brackets
		{"[2001:db8:a0b:12f0::1]", "[2001:db8:a0b:12f0::1]"},
		// Test 4 - ipv6 with zone and port
		{`[fe80:1::3%eth0]:44`, `fe80:1::3%eth0`},
	}

	for i, test := range tests {
		testPrefix := getTestPrefix(i)

		context.Req.RemoteAddr = test.inputRemoteAddr
		actualIP := context.IP()

		if actualIP != test.expectedIP {
			t.Errorf(testPrefix+"Expected IP %s, found %s", test.expectedIP, actualIP)
		}
	}
}
+
// TestURL checks that URI() echoes the request URI.
// NOTE(review): the DNS Context defines no URI method — stale HTTP-context test.
func TestURL(t *testing.T) {
	context := getContextOrFail(t)

	inputURL := "http://localhost"
	context.Req.RequestURI = inputURL

	if inputURL != context.URI() {
		t.Errorf("Expected url %s, found %s", inputURL, context.URI())
	}
}
+
+func TestHost(t *testing.T) {
+ tests := []struct {
+ input string
+ expectedHost string
+ shouldErr bool
+ }{
+ {
+ input: "localhost:123",
+ expectedHost: "localhost",
+ shouldErr: false,
+ },
+ {
+ input: "localhost",
+ expectedHost: "localhost",
+ shouldErr: false,
+ },
+ {
+ input: "[::]",
+ expectedHost: "",
+ shouldErr: true,
+ },
+ }
+
+ for _, test := range tests {
+ testHostOrPort(t, true, test.input, test.expectedHost, test.shouldErr)
+ }
+}
+
+func TestPort(t *testing.T) {
+ tests := []struct {
+ input string
+ expectedPort string
+ shouldErr bool
+ }{
+ {
+ input: "localhost:123",
+ expectedPort: "123",
+ shouldErr: false,
+ },
+ {
+ input: "localhost",
+ expectedPort: "80", // assuming 80 is the default port
+ shouldErr: false,
+ },
+ {
+ input: ":8080",
+ expectedPort: "8080",
+ shouldErr: false,
+ },
+ {
+ input: "[::]",
+ expectedPort: "",
+ shouldErr: true,
+ },
+ }
+
+ for _, test := range tests {
+ testHostOrPort(t, false, test.input, test.expectedPort, test.shouldErr)
+ }
+}
+
// testHostOrPort is the shared driver for TestHost and TestPort: it sets
// Req.Host to input and checks the selected accessor's result and error.
func testHostOrPort(t *testing.T, isTestingHost bool, input, expectedResult string, shouldErr bool) {
	context := getContextOrFail(t)

	context.Req.Host = input
	var actualResult, testedObject string
	var err error

	if isTestingHost {
		actualResult, err = context.Host()
		testedObject = "host"
	} else {
		actualResult, err = context.Port()
		testedObject = "port"
	}

	if shouldErr && err == nil {
		t.Errorf("Expected error, found nil!")
		return
	}

	if !shouldErr && err != nil {
		t.Errorf("Expected no error, found %s", err)
		return
	}

	if actualResult != expectedResult {
		t.Errorf("Expected %s %s, found %s", testedObject, expectedResult, actualResult)
	}
}
+
// TestMethod checks that Method() reflects the request method.
// NOTE(review): the DNS Context defines no Method method and Req is
// *dns.Msg — stale HTTP-context test.
func TestMethod(t *testing.T) {
	context := getContextOrFail(t)

	method := "POST"
	context.Req.Method = method

	if method != context.Method() {
		t.Errorf("Expected method %s, found %s", method, context.Method())
	}

}
+
+func TestPathMatches(t *testing.T) {
+ context := getContextOrFail(t)
+
+ tests := []struct {
+ urlStr string
+ pattern string
+ shouldMatch bool
+ }{
+ // Test 0
+ {
+ urlStr: "http://localhost/",
+ pattern: "",
+ shouldMatch: true,
+ },
+ // Test 1
+ {
+ urlStr: "http://localhost",
+ pattern: "",
+ shouldMatch: true,
+ },
+ // Test 1
+ {
+ urlStr: "http://localhost/",
+ pattern: "/",
+ shouldMatch: true,
+ },
+ // Test 3
+ {
+ urlStr: "http://localhost/?param=val",
+ pattern: "/",
+ shouldMatch: true,
+ },
+ // Test 4
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "/dir2",
+ shouldMatch: false,
+ },
+ // Test 5
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "/dir1",
+ shouldMatch: true,
+ },
+ // Test 6
+ {
+ urlStr: "http://localhost:444/dir1/dir2",
+ pattern: "/dir1",
+ shouldMatch: true,
+ },
+ // Test 7
+ {
+ urlStr: "http://localhost/dir1/dir2",
+ pattern: "*/dir2",
+ shouldMatch: false,
+ },
+ }
+
+ for i, test := range tests {
+ testPrefix := getTestPrefix(i)
+ var err error
+ context.Req.URL, err = url.Parse(test.urlStr)
+ if err != nil {
+ t.Fatalf("Failed to prepare test URL from string %s! Error was: %s", test.urlStr, err)
+ }
+
+ matches := context.PathMatches(test.pattern)
+ if matches != test.shouldMatch {
+ t.Errorf(testPrefix+"Expected and actual result differ: expected to match [%t], actual matches [%t]", test.shouldMatch, matches)
+ }
+ }
+}
+
+func TestTruncate(t *testing.T) {
+ context := getContextOrFail(t)
+ tests := []struct {
+ inputString string
+ inputLength int
+ expected string
+ }{
+ // Test 0 - small length
+ {
+ inputString: "string",
+ inputLength: 1,
+ expected: "s",
+ },
+ // Test 1 - exact length
+ {
+ inputString: "string",
+ inputLength: 6,
+ expected: "string",
+ },
+ // Test 2 - bigger length
+ {
+ inputString: "string",
+ inputLength: 10,
+ expected: "string",
+ },
+ // Test 3 - zero length
+ {
+ inputString: "string",
+ inputLength: 0,
+ expected: "",
+ },
+ // Test 4 - negative, smaller length
+ {
+ inputString: "string",
+ inputLength: -5,
+ expected: "tring",
+ },
+ // Test 5 - negative, exact length
+ {
+ inputString: "string",
+ inputLength: -6,
+ expected: "string",
+ },
+ // Test 6 - negative, bigger length
+ {
+ inputString: "string",
+ inputLength: -7,
+ expected: "string",
+ },
+ }
+
+ for i, test := range tests {
+ actual := context.Truncate(test.inputString, test.inputLength)
+ if actual != test.expected {
+ t.Errorf(getTestPrefix(i)+"Expected '%s', found '%s'. Input was Truncate(%q, %d)", test.expected, actual, test.inputString, test.inputLength)
+ }
+ }
+}
+
+func TestStripHTML(t *testing.T) {
+ context := getContextOrFail(t)
+ tests := []struct {
+ input string
+ expected string
+ }{
+ // Test 0 - no tags
+ {
+ input: `h1`,
+ expected: `h1`,
+ },
+ // Test 1 - happy path
+ {
+ input: `<h1>h1</h1>`,
+ expected: `h1`,
+ },
+ // Test 2 - tag in quotes
+ {
+ input: `<h1">">h1</h1>`,
+ expected: `h1`,
+ },
+ // Test 3 - multiple tags
+ {
+ input: `<h1><b>h1</b></h1>`,
+ expected: `h1`,
+ },
+ // Test 4 - tags not closed
+ {
+ input: `<h1`,
+ expected: `<h1`,
+ },
+ // Test 5 - false start
+ {
+ input: `<h1<b>hi`,
+ expected: `<h1hi`,
+ },
+ }
+
+ for i, test := range tests {
+ actual := context.StripHTML(test.input)
+ if actual != test.expected {
+ t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripHTML(%s)", test.expected, actual, test.input)
+ }
+ }
+}
+
+func TestStripExt(t *testing.T) {
+ context := getContextOrFail(t)
+ tests := []struct {
+ input string
+ expected string
+ }{
+ // Test 0 - empty input
+ {
+ input: "",
+ expected: "",
+ },
+ // Test 1 - relative file with ext
+ {
+ input: "file.ext",
+ expected: "file",
+ },
+ // Test 2 - relative file without ext
+ {
+ input: "file",
+ expected: "file",
+ },
+ // Test 3 - absolute file without ext
+ {
+ input: "/file",
+ expected: "/file",
+ },
+ // Test 4 - absolute file with ext
+ {
+ input: "/file.ext",
+ expected: "/file",
+ },
+ // Test 5 - with ext but ends with /
+ {
+ input: "/dir.ext/",
+ expected: "/dir.ext/",
+ },
+ // Test 6 - file with ext under dir with ext
+ {
+ input: "/dir.ext/file.ext",
+ expected: "/dir.ext/file",
+ },
+ }
+
+ for i, test := range tests {
+ actual := context.StripExt(test.input)
+ if actual != test.expected {
+ t.Errorf(getTestPrefix(i)+"Expected %s, found %s. Input was StripExt(%q)", test.expected, actual, test.input)
+ }
+ }
+}
+
// initTestContext builds a Context rooted at the OS temp dir with a
// canned GET request.
// NOTE(review): it assigns an *http.Request to Context.Req, which
// context.go declares as *dns.Msg — this test file will not compile
// against the DNS Context and needs porting.
func initTestContext() (Context, error) {
	body := bytes.NewBufferString("request body")
	request, err := http.NewRequest("GET", "https://localhost", body)
	if err != nil {
		return Context{}, err
	}

	return Context{Root: http.Dir(os.TempDir()), Req: request}, nil
}
+
// getContextOrFail returns a fresh test Context, failing the test
// immediately if construction fails.
func getContextOrFail(t *testing.T) Context {
	context, err := initTestContext()
	if err != nil {
		t.Fatalf("Failed to prepare test context")
	}
	return context
}
+
// getTestPrefix formats the standard "Test [n]: " prefix for error messages.
func getTestPrefix(testN int) string {
	return fmt.Sprintf("Test [%d]: ", testN)
}
diff --git a/middleware/errors/errors.go b/middleware/errors/errors.go
new file mode 100644
index 000000000..bf5bc7aae
--- /dev/null
+++ b/middleware/errors/errors.go
@@ -0,0 +1,100 @@
+// Package errors implements an HTTP error handling middleware.
+package errors
+
+import (
+ "fmt"
+ "log"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
// ErrorHandler handles DNS errors (and errors from other middleware).
type ErrorHandler struct {
	Next      middleware.Handler    // next handler in the middleware chain
	LogFile   string                // path of the file error lines are written to
	Log       *log.Logger           // logger the error/panic messages go to
	LogRoller *middleware.LogRoller // optional rotation settings for LogFile
	Debug     bool // if true, errors are written out to client rather than to a log
}
+
+func (h ErrorHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ defer h.recovery(w, r)
+
+ rcode, err := h.Next.ServeDNS(w, r)
+
+ if err != nil {
+ errMsg := fmt.Sprintf("%s [ERROR %d %s %s] %v", time.Now().Format(timeFormat), rcode, r.Question[0].Name, dns.Type(r.Question[0].Qclass), err)
+
+ if h.Debug {
+ // Write error to response as a txt message instead of to log
+ answer := debugMsg(rcode, r)
+ txt, _ := dns.NewRR(". IN 0 TXT " + errMsg)
+ answer.Answer = append(answer.Answer, txt)
+ w.WriteMsg(answer)
+ return 0, err
+ }
+ h.Log.Println(errMsg)
+ }
+
+ return rcode, err
+}
+
+func (h ErrorHandler) recovery(w dns.ResponseWriter, r *dns.Msg) {
+ rec := recover()
+ if rec == nil {
+ return
+ }
+
+ // Obtain source of panic
+ // From: https://gist.github.com/swdunlop/9629168
+ var name, file string // function name, file name
+ var line int
+ var pc [16]uintptr
+ n := runtime.Callers(3, pc[:])
+ for _, pc := range pc[:n] {
+ fn := runtime.FuncForPC(pc)
+ if fn == nil {
+ continue
+ }
+ file, line = fn.FileLine(pc)
+ name = fn.Name()
+ if !strings.HasPrefix(name, "runtime.") {
+ break
+ }
+ }
+
+ // Trim file path
+ delim := "/coredns/"
+ pkgPathPos := strings.Index(file, delim)
+ if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) {
+ file = file[pkgPathPos+len(delim):]
+ }
+
+ panicMsg := fmt.Sprintf("%s [PANIC %s %s] %s:%d - %v", time.Now().Format(timeFormat), r.Question[0].Name, dns.Type(r.Question[0].Qtype), file, line, rec)
+ if h.Debug {
+ // Write error and stack trace to the response rather than to a log
+ var stackBuf [4096]byte
+ stack := stackBuf[:runtime.Stack(stackBuf[:], false)]
+ answer := debugMsg(dns.RcodeServerFailure, r)
+ // add stack buf in TXT, limited to 255 chars for now.
+ txt, _ := dns.NewRR(". IN 0 TXT " + string(stack[:255]))
+ answer.Answer = append(answer.Answer, txt)
+ w.WriteMsg(answer)
+ } else {
+ // Currently we don't use the function name, since file:line is more conventional
+ h.Log.Printf(panicMsg)
+ }
+}
+
+// debugMsg creates a debug message that gets send back to the client.
+func debugMsg(rcode int, r *dns.Msg) *dns.Msg {
+ answer := new(dns.Msg)
+ answer.SetRcode(r, rcode)
+ return answer
+}
+
+const timeFormat = "02/Jan/2006:15:04:05 -0700"
diff --git a/middleware/errors/errors_test.go b/middleware/errors/errors_test.go
new file mode 100644
index 000000000..4434e835c
--- /dev/null
+++ b/middleware/errors/errors_test.go
@@ -0,0 +1,168 @@
+package errors
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+)
+
// TestErrors drives ErrorHandler through a table of next-handlers that
// succeed, return an error, or return error status codes, checking the
// returned code, the body written, and what was logged.
//
// NOTE(review): this test exercises an HTTP-style API (ErrorPages field,
// ServeHTTP method, HTTP status codes) that the DNS ErrorHandler declared
// in errors.go does not have — it looks carried over from Caddy; confirm
// this file still compiles against the DNS middleware types.
func TestErrors(t *testing.T) {
	// create a temporary page
	path := filepath.Join(os.TempDir(), "errors_test.html")
	f, err := os.Create(path)
	if err != nil {
		t.Fatal(err)
	}
	defer os.Remove(path)

	const content = "This is a error page"
	_, err = f.WriteString(content)
	if err != nil {
		t.Fatal(err)
	}
	f.Close()

	buf := bytes.Buffer{}
	em := ErrorHandler{
		ErrorPages: map[int]string{
			http.StatusNotFound:  path,
			http.StatusForbidden: "not_exist_file",
		},
		Log: log.New(&buf, "", 0),
	}
	// Capture the exact os error for a missing file so the expected log
	// line below matches the platform's wording.
	_, notExistErr := os.Open("not_exist_file")

	testErr := errors.New("test error")
	tests := []struct {
		next         middleware.Handler
		expectedCode int
		expectedBody string
		expectedLog  string
		expectedErr  error
	}{
		{
			next:         genErrorHandler(http.StatusOK, nil, "normal"),
			expectedCode: http.StatusOK,
			expectedBody: "normal",
			expectedLog:  "",
			expectedErr:  nil,
		},
		{
			next:         genErrorHandler(http.StatusMovedPermanently, testErr, ""),
			expectedCode: http.StatusMovedPermanently,
			expectedBody: "",
			expectedLog:  fmt.Sprintf("[ERROR %d %s] %v\n", http.StatusMovedPermanently, "/", testErr),
			expectedErr:  testErr,
		},
		{
			next:         genErrorHandler(http.StatusBadRequest, nil, ""),
			expectedCode: 0,
			expectedBody: fmt.Sprintf("%d %s\n", http.StatusBadRequest,
				http.StatusText(http.StatusBadRequest)),
			expectedLog: "",
			expectedErr: nil,
		},
		{
			next:         genErrorHandler(http.StatusNotFound, nil, ""),
			expectedCode: 0,
			expectedBody: content,
			expectedLog:  "",
			expectedErr:  nil,
		},
		{
			// The configured 403 error page does not exist, so the handler
			// should fall back to the plain status text and log a NOTICE.
			next:         genErrorHandler(http.StatusForbidden, nil, ""),
			expectedCode: 0,
			expectedBody: fmt.Sprintf("%d %s\n", http.StatusForbidden,
				http.StatusText(http.StatusForbidden)),
			expectedLog: fmt.Sprintf("[NOTICE %d /] could not load error page: %v\n",
				http.StatusForbidden, notExistErr),
			expectedErr: nil,
		},
	}

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatal(err)
	}
	for i, test := range tests {
		em.Next = test.next
		buf.Reset()
		rec := httptest.NewRecorder()
		code, err := em.ServeHTTP(rec, req)

		if err != test.expectedErr {
			t.Errorf("Test %d: Expected error %v, but got %v",
				i, test.expectedErr, err)
		}
		if code != test.expectedCode {
			t.Errorf("Test %d: Expected status code %d, but got %d",
				i, test.expectedCode, code)
		}
		if body := rec.Body.String(); body != test.expectedBody {
			t.Errorf("Test %d: Expected body %q, but got %q",
				i, test.expectedBody, body)
		}
		if log := buf.String(); !strings.Contains(log, test.expectedLog) {
			t.Errorf("Test %d: Expected log %q, but got %q",
				i, test.expectedLog, log)
		}
	}
}
+
// TestVisibleErrorWithPanic checks the Debug path: a panic in the next
// handler must be recovered, reported with a 0 status and nil error, and
// its origin plus full stack trace written into the response body.
//
// NOTE(review): like TestErrors, this uses the HTTP-era API (ErrorPages,
// ServeHTTP, http.ResponseWriter) that the DNS ErrorHandler in errors.go
// does not declare — confirm this test still compiles.
func TestVisibleErrorWithPanic(t *testing.T) {
	const panicMsg = "I'm a panic"
	eh := ErrorHandler{
		ErrorPages: make(map[int]string),
		Debug:      true,
		Next: middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
			panic(panicMsg)
		}),
	}

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatal(err)
	}
	rec := httptest.NewRecorder()

	code, err := eh.ServeHTTP(rec, req)

	if code != 0 {
		t.Errorf("Expected error handler to return 0 (it should write to response), got status %d", code)
	}
	if err != nil {
		t.Errorf("Expected error handler to return nil error (it should panic!), but got '%v'", err)
	}

	body := rec.Body.String()

	if !strings.Contains(body, "[PANIC /] middleware/errors/errors_test.go") {
		t.Errorf("Expected response body to contain error log line, but it didn't:\n%s", body)
	}
	if !strings.Contains(body, panicMsg) {
		t.Errorf("Expected response body to contain panic message, but it didn't:\n%s", body)
	}
	// A real stack trace is much longer than 500 bytes; this is a cheap
	// proxy for "the trace was actually included".
	if len(body) < 500 {
		t.Errorf("Expected response body to contain stack trace, but it was too short: len=%d", len(body))
	}
}
+
+func genErrorHandler(status int, err error, body string) middleware.Handler {
+ return middleware.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
+ if len(body) > 0 {
+ w.Header().Set("Content-Length", strconv.Itoa(len(body)))
+ fmt.Fprint(w, body)
+ }
+ return status, err
+ })
+}
diff --git a/middleware/etcd/TODO b/middleware/etcd/TODO
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/middleware/etcd/TODO
diff --git a/middleware/exchange.go b/middleware/exchange.go
new file mode 100644
index 000000000..837fa3cdc
--- /dev/null
+++ b/middleware/exchange.go
@@ -0,0 +1,10 @@
+package middleware
+
+import "github.com/miekg/dns"
+
// Exchange sends message m to the server.
// TODO(miek): optionally it can do retries of other silly stuff.
func Exchange(c *dns.Client, m *dns.Msg, server string) (*dns.Msg, error) {
	// The round-trip time reported by the client is discarded here.
	r, _, err := c.Exchange(m, server)
	return r, err
}
diff --git a/middleware/file/file.go b/middleware/file/file.go
new file mode 100644
index 000000000..5bc5a3a3a
--- /dev/null
+++ b/middleware/file/file.go
@@ -0,0 +1,89 @@
+package file
+
+// TODO(miek): the zone's implementation is basically non-existent
+// we return a list and when searching for an answer we iterate
+// over the list. This must be moved to a tree-like structure and
+// have some fluff for DNSSEC (and be memory efficient).
+
+import (
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
type (
	// File is the middleware that serves authoritative answers from
	// zones loaded from file.
	File struct {
		Next  middleware.Handler
		Zones Zones
		// Maybe a list of all zones as well, as a []string?
	}

	// Zone is a flat list of the resource records in one zone.
	Zone []dns.RR
	// Zones maps fully-qualified zone names to their records and keeps
	// the list of names for matching against query names.
	Zones struct {
		Z     map[string]Zone // utterly braindead impl. TODO(miek): fix
		Names []string
	}
)
+
// ServeDNS answers queries for zones this middleware is authoritative for
// and hands everything else to the next middleware in the chain.
func (f File) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
	context := middleware.Context{W: w, Req: r}
	qname := context.Name()
	zone := middleware.Zones(f.Zones.Names).Matches(qname)
	if zone == "" {
		// Not one of our zones; delegate down the chain.
		return f.Next.ServeDNS(w, r)
	}

	names, nodata := f.Zones.Z[zone].lookup(qname, context.QType())
	var answer *dns.Msg
	// Case order matters: nodata must be tested before len(names) == 0,
	// because a NODATA reply carries the SOA record in names.
	switch {
	case nodata:
		// NODATA: name exists but not with this type; SOA goes in the
		// authority section.
		answer = context.AnswerMessage()
		answer.Ns = names
	case len(names) == 0:
		// NXDOMAIN: nothing found for this name at all.
		answer = context.AnswerMessage()
		answer.Ns = names
		answer.Rcode = dns.RcodeNameError
	case len(names) > 0:
		answer = context.AnswerMessage()
		answer.Answer = names
	default:
		// Unreachable given the cases above; kept as a safety net.
		answer = context.ErrorMessage(dns.RcodeServerFailure)
	}
	// Check return size, etc. TODO(miek)
	w.WriteMsg(answer)
	return 0, nil
}
+
+// Lookup will try to find qname and qtype in z. It returns the
+// records found *or* a boolean saying NODATA. If the answer
+// is NODATA then the RR returned is the SOA record.
+//
+// TODO(miek): EXTREMELY STUPID IMPLEMENTATION.
+// Doesn't do much, no delegation, no cname, nothing really, etc.
+// TODO(miek): even NODATA looks broken
+func (z Zone) lookup(qname string, qtype uint16) ([]dns.RR, bool) {
+ var (
+ nodata bool
+ rep []dns.RR
+ soa dns.RR
+ )
+
+ for _, rr := range z {
+ if rr.Header().Rrtype == dns.TypeSOA {
+ soa = rr
+ }
+ // Match function in Go DNS?
+ if strings.ToLower(rr.Header().Name) == qname {
+ if rr.Header().Rrtype == qtype {
+ rep = append(rep, rr)
+ nodata = false
+ }
+
+ }
+ }
+ if nodata {
+ return []dns.RR{soa}, true
+ }
+ return rep, false
+}
diff --git a/middleware/file/file_test.go b/middleware/file/file_test.go
new file mode 100644
index 000000000..54584b5cc
--- /dev/null
+++ b/middleware/file/file_test.go
@@ -0,0 +1,325 @@
+package file
+
+import (
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
// testDir is the root directory for the temporary fixtures created by
// beforeServeHTTPTest and removed by afterServeHTTPTest.
var testDir = filepath.Join(os.TempDir(), "caddy_testdir")

// ErrCustom stands in for an arbitrary non-os error in the failing
// filesystem tests below.
var ErrCustom = errors.New("Custom Error")

// testFiles is a map with relative paths to test files as keys and file content as values.
// The map represents the following structure:
// - $TEMP/caddy_testdir/
// '-- file1.html
// '-- dirwithindex/
// '---- index.html
// '-- dir/
// '---- file2.html
// '---- hidden.html
var testFiles = map[string]string{
	"file1.html":                                "<h1>file1.html</h1>",
	filepath.Join("dirwithindex", "index.html"): "<h1>dirwithindex/index.html</h1>",
	filepath.Join("dir", "file2.html"):          "<h1>dir/file2.html</h1>",
	filepath.Join("dir", "hidden.html"):         "<h1>dir/hidden.html</h1>",
}
+
+// TestServeHTTP covers positive scenarios when serving files.
+func TestServeHTTP(t *testing.T) {
+
+ beforeServeHTTPTest(t)
+ defer afterServeHTTPTest(t)
+
+ fileserver := FileServer(http.Dir(testDir), []string{"hidden.html"})
+
+ movedPermanently := "Moved Permanently"
+
+ tests := []struct {
+ url string
+
+ expectedStatus int
+ expectedBodyContent string
+ }{
+ // Test 0 - access without any path
+ {
+ url: "https://foo",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 1 - access root (without index.html)
+ {
+ url: "https://foo/",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 2 - access existing file
+ {
+ url: "https://foo/file1.html",
+ expectedStatus: http.StatusOK,
+ expectedBodyContent: testFiles["file1.html"],
+ },
+ // Test 3 - access folder with index file with trailing slash
+ {
+ url: "https://foo/dirwithindex/",
+ expectedStatus: http.StatusOK,
+ expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
+ },
+ // Test 4 - access folder with index file without trailing slash
+ {
+ url: "https://foo/dirwithindex",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ // Test 5 - access folder without index file
+ {
+ url: "https://foo/dir/",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 6 - access folder without trailing slash
+ {
+ url: "https://foo/dir",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ // Test 6 - access file with trailing slash
+ {
+ url: "https://foo/file1.html/",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ // Test 7 - access not existing path
+ {
+ url: "https://foo/not_existing",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 8 - access a file, marked as hidden
+ {
+ url: "https://foo/dir/hidden.html",
+ expectedStatus: http.StatusNotFound,
+ },
+ // Test 9 - access a index file directly
+ {
+ url: "https://foo/dirwithindex/index.html",
+ expectedStatus: http.StatusOK,
+ expectedBodyContent: testFiles[filepath.Join("dirwithindex", "index.html")],
+ },
+ // Test 10 - send a request with query params
+ {
+ url: "https://foo/dir?param1=val",
+ expectedStatus: http.StatusMovedPermanently,
+ expectedBodyContent: movedPermanently,
+ },
+ }
+
+ for i, test := range tests {
+ responseRecorder := httptest.NewRecorder()
+ request, err := http.NewRequest("GET", test.url, strings.NewReader(""))
+ status, err := fileserver.ServeHTTP(responseRecorder, request)
+
+ // check if error matches expectations
+ if err != nil {
+ t.Errorf(getTestPrefix(i)+"Serving file at %s failed. Error was: %v", test.url, err)
+ }
+
+ // check status code
+ if test.expectedStatus != status {
+ t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
+ }
+
+ // check body content
+ if !strings.Contains(responseRecorder.Body.String(), test.expectedBodyContent) {
+ t.Errorf(getTestPrefix(i)+"Expected body to contain %q, found %q", test.expectedBodyContent, responseRecorder.Body.String())
+ }
+ }
+
+}
+
+// beforeServeHTTPTest creates a test directory with the structure, defined in the variable testFiles
+func beforeServeHTTPTest(t *testing.T) {
+ // make the root test dir
+ err := os.Mkdir(testDir, os.ModePerm)
+ if err != nil {
+ if !os.IsExist(err) {
+ t.Fatalf("Failed to create test dir. Error was: %v", err)
+ return
+ }
+ }
+
+ for relFile, fileContent := range testFiles {
+ absFile := filepath.Join(testDir, relFile)
+
+ // make sure the parent directories exist
+ parentDir := filepath.Dir(absFile)
+ _, err = os.Stat(parentDir)
+ if err != nil {
+ os.MkdirAll(parentDir, os.ModePerm)
+ }
+
+ // now create the test files
+ f, err := os.Create(absFile)
+ if err != nil {
+ t.Fatalf("Failed to create test file %s. Error was: %v", absFile, err)
+ return
+ }
+
+ // and fill them with content
+ _, err = f.WriteString(fileContent)
+ if err != nil {
+ t.Fatalf("Failed to write to %s. Error was: %v", absFile, err)
+ return
+ }
+ f.Close()
+ }
+
+}
+
+// afterServeHTTPTest removes the test dir and all its content
+func afterServeHTTPTest(t *testing.T) {
+ // cleans up everything under the test dir. No need to clean the individual files.
+ err := os.RemoveAll(testDir)
+ if err != nil {
+ t.Fatalf("Failed to clean up test dir %s. Error was: %v", testDir, err)
+ }
+}
+
// failingFS implements the http.FileSystem interface. The Open method always returns the error, assigned to err
type failingFS struct {
	err      error     // the error to return when Open is called
	fileImpl http.File // inject the file implementation
}

// Open ignores its path argument and returns the preconfigured
// fileImpl / err pair.
func (f failingFS) Open(path string) (http.File, error) {
	return f.fileImpl, f.err
}
+
// failingFile implements http.File but returns a predefined error on every Stat() method call.
type failingFile struct {
	http.File // embedded so the remaining interface methods are satisfied
	err       error
}

// Stat returns nil FileInfo and the provided error on every call
func (ff failingFile) Stat() (os.FileInfo, error) {
	return nil, ff.err
}

// Close is noop and returns no error
func (ff failingFile) Close() error {
	return nil
}
+
+// TestServeHTTPFailingFS tests error cases where the Open function fails with various errors.
+func TestServeHTTPFailingFS(t *testing.T) {
+
+ tests := []struct {
+ fsErr error
+ expectedStatus int
+ expectedErr error
+ expectedHeaders map[string]string
+ }{
+ {
+ fsErr: os.ErrNotExist,
+ expectedStatus: http.StatusNotFound,
+ expectedErr: nil,
+ },
+ {
+ fsErr: os.ErrPermission,
+ expectedStatus: http.StatusForbidden,
+ expectedErr: os.ErrPermission,
+ },
+ {
+ fsErr: ErrCustom,
+ expectedStatus: http.StatusServiceUnavailable,
+ expectedErr: ErrCustom,
+ expectedHeaders: map[string]string{"Retry-After": "5"},
+ },
+ }
+
+ for i, test := range tests {
+ // initialize a file server with the failing FileSystem
+ fileserver := FileServer(failingFS{err: test.fsErr}, nil)
+
+ // prepare the request and response
+ request, err := http.NewRequest("GET", "https://foo/", nil)
+ if err != nil {
+ t.Fatalf("Failed to build request. Error was: %v", err)
+ }
+ responseRecorder := httptest.NewRecorder()
+
+ status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
+
+ // check the status
+ if status != test.expectedStatus {
+ t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
+ }
+
+ // check the error
+ if actualErr != test.expectedErr {
+ t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
+ }
+
+ // check the headers - a special case for server under load
+ if test.expectedHeaders != nil && len(test.expectedHeaders) > 0 {
+ for expectedKey, expectedVal := range test.expectedHeaders {
+ actualVal := responseRecorder.Header().Get(expectedKey)
+ if expectedVal != actualVal {
+ t.Errorf(getTestPrefix(i)+"Expected header %s: %s, found %s", expectedKey, expectedVal, actualVal)
+ }
+ }
+ }
+ }
+}
+
+// TestServeHTTPFailingStat tests error cases where the initial Open function succeeds, but the Stat method on the opened file fails.
+func TestServeHTTPFailingStat(t *testing.T) {
+
+ tests := []struct {
+ statErr error
+ expectedStatus int
+ expectedErr error
+ }{
+ {
+ statErr: os.ErrNotExist,
+ expectedStatus: http.StatusNotFound,
+ expectedErr: nil,
+ },
+ {
+ statErr: os.ErrPermission,
+ expectedStatus: http.StatusForbidden,
+ expectedErr: os.ErrPermission,
+ },
+ {
+ statErr: ErrCustom,
+ expectedStatus: http.StatusInternalServerError,
+ expectedErr: ErrCustom,
+ },
+ }
+
+ for i, test := range tests {
+ // initialize a file server. The FileSystem will not fail, but calls to the Stat method of the returned File object will
+ fileserver := FileServer(failingFS{err: nil, fileImpl: failingFile{err: test.statErr}}, nil)
+
+ // prepare the request and response
+ request, err := http.NewRequest("GET", "https://foo/", nil)
+ if err != nil {
+ t.Fatalf("Failed to build request. Error was: %v", err)
+ }
+ responseRecorder := httptest.NewRecorder()
+
+ status, actualErr := fileserver.ServeHTTP(responseRecorder, request)
+
+ // check the status
+ if status != test.expectedStatus {
+ t.Errorf(getTestPrefix(i)+"Expected status %d, found %d", test.expectedStatus, status)
+ }
+
+ // check the error
+ if actualErr != test.expectedErr {
+ t.Errorf(getTestPrefix(i)+"Expected err %v, found %v", test.expectedErr, actualErr)
+ }
+ }
+}
diff --git a/middleware/host.go b/middleware/host.go
new file mode 100644
index 000000000..17ecedb5f
--- /dev/null
+++ b/middleware/host.go
@@ -0,0 +1,22 @@
+package middleware
+
+import (
+ "net"
+ "strings"
+
+ "github.com/miekg/dns"
+)
+
// Host represents a host from the Caddyfile, may contain port.
type Host string

// StandardHost returns the host portion of h, stripping
// off any port. The host will also be fully qualified and lowercased.
func (h Host) StandardHost() string {
	// separate host and port
	host, _, err := net.SplitHostPort(string(h))
	if err != nil {
		// No port present: retry with an empty port appended so
		// SplitHostPort accepts the input.
		// NOTE(review): a bare IPv6 address (no brackets) would get split
		// at its last colon here — confirm callers never pass one.
		host, _, _ = net.SplitHostPort(string(h) + ":")
	}
	return strings.ToLower(dns.Fqdn(host))
}
diff --git a/middleware/log/log.go b/middleware/log/log.go
new file mode 100644
index 000000000..109add9f5
--- /dev/null
+++ b/middleware/log/log.go
@@ -0,0 +1,66 @@
+// Package log implements basic but useful request (access) logging middleware.
+package log
+
+import (
+ "log"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
// Logger is a basic request logging middleware.
type Logger struct {
	Next      middleware.Handler // next handler in the middleware chain
	Rules     []Rule             // one entry per configured log directive
	ErrorFunc func(dns.ResponseWriter, *dns.Msg, int) // failover error handler
}
+
+func (l Logger) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ for _, rule := range l.Rules {
+ /*
+ if middleware.Path(r.URL.Path).Matches(rule.PathScope) {
+ responseRecorder := middleware.NewResponseRecorder(w)
+ status, err := l.Next.ServeHTTP(responseRecorder, r)
+ if status >= 400 {
+ // There was an error up the chain, but no response has been written yet.
+ // The error must be handled here so the log entry will record the response size.
+ if l.ErrorFunc != nil {
+ l.ErrorFunc(responseRecorder, r, status)
+ } else {
+ // Default failover error handler
+ responseRecorder.WriteHeader(status)
+ fmt.Fprintf(responseRecorder, "%d %s", status, http.StatusText(status))
+ }
+ status = 0
+ }
+ rep := middleware.NewReplacer(r, responseRecorder, CommonLogEmptyValue)
+ rule.Log.Println(rep.Replace(rule.Format))
+ return status, err
+ }
+ */
+ rule = rule
+ }
+ return l.Next.ServeDNS(w, r)
+}
+
// Rule configures the logging middleware.
type Rule struct {
	PathScope  string // scope (name/path) this rule applies to
	OutputFile string // file the log lines are written to
	Format     string // log line format (see the *LogFormat constants)
	Log        *log.Logger
	Roller     *middleware.LogRoller // optional log rotation settings
}
+
const (
	// DefaultLogFilename is the default log filename.
	DefaultLogFilename = "access.log"
	// CommonLogFormat is the common log format.
	CommonLogFormat = `{remote} ` + CommonLogEmptyValue + ` [{when}] "{type} {name} {proto}" {rcode} {size}`
	// CommonLogEmptyValue is the common empty log value.
	CommonLogEmptyValue = "-"
	// CombinedLogFormat is the combined log format.
	// NOTE(review): Referer/User-Agent are HTTP headers with no DNS
	// equivalent — revisit what belongs in the combined format for DNS.
	CombinedLogFormat = CommonLogFormat + ` "{>Referer}" "{>User-Agent}"`
	// DefaultLogFormat is the default log format.
	DefaultLogFormat = CommonLogFormat
)
diff --git a/middleware/log/log_test.go b/middleware/log/log_test.go
new file mode 100644
index 000000000..40560e4c0
--- /dev/null
+++ b/middleware/log/log_test.go
@@ -0,0 +1,48 @@
+package log
+
+import (
+ "bytes"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
// erroringMiddleware is a stub next-handler that always reports 404
// without writing a response body.
// NOTE(review): it implements the HTTP handler shape (ServeHTTP), while
// Logger in this package exposes ServeDNS — confirm this test file still
// compiles against the DNS middleware types.
type erroringMiddleware struct{}

func (erroringMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
	return http.StatusNotFound, nil
}
+
// TestLoggedStatus checks that a 404 from the next handler is absorbed by
// the default failover path (status becomes 0) and the status plus response
// size ("404 13") are recorded in the rule's log.
// NOTE(review): this calls logger.ServeHTTP, but Logger in log.go only
// defines ServeDNS — this HTTP-era test likely no longer compiles; confirm.
func TestLoggedStatus(t *testing.T) {
	var f bytes.Buffer
	var next erroringMiddleware
	rule := Rule{
		PathScope: "/",
		Format:    DefaultLogFormat,
		Log:       log.New(&f, "", 0),
	}

	logger := Logger{
		Rules: []Rule{rule},
		Next:  next,
	}

	r, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatal(err)
	}

	rec := httptest.NewRecorder()

	status, err := logger.ServeHTTP(rec, r)
	if status != 0 {
		t.Error("Expected status to be 0 - was", status)
	}

	logged := f.String()
	if !strings.Contains(logged, "404 13") {
		t.Error("Expected 404 to be logged. Logged string -", logged)
	}
}
diff --git a/middleware/middleware.go b/middleware/middleware.go
new file mode 100644
index 000000000..436ec86e9
--- /dev/null
+++ b/middleware/middleware.go
@@ -0,0 +1,105 @@
+// Package middleware provides some types and functions common among middleware.
+package middleware
+
+import (
+ "time"
+
+ "github.com/miekg/dns"
+)
+
type (
	// Middleware is the middle layer which represents the traditional
	// idea of middleware: it chains one Handler to the next by being
	// passed the next Handler in the chain.
	Middleware func(Handler) Handler

	// Handler is like dns.Handler except ServeDNS may return an rcode
	// and/or error.
	//
	// If ServeDNS writes to the response body, it should return a status
	// code of 0. This signals to other handlers above it that the response
	// body is already written, and that they should not write to it also.
	//
	// If ServeDNS encounters an error, it should return the error value
	// so it can be logged by designated error-handling middleware.
	//
	// If writing a response after calling another ServeDNS method, the
	// returned rcode SHOULD be used when writing the response.
	//
	// If handling errors after calling another ServeDNS method, the
	// returned error value SHOULD be logged or handled accordingly.
	//
	// Otherwise, return values should be propagated down the middleware
	// chain by returning them unchanged.
	Handler interface {
		ServeDNS(dns.ResponseWriter, *dns.Msg) (int, error)
	}

	// HandlerFunc is a convenience type like dns.HandlerFunc, except
	// ServeDNS returns an rcode and an error. See Handler
	// documentation for more information. It lets an ordinary function
	// be used wherever a Handler is expected.
	HandlerFunc func(dns.ResponseWriter, *dns.Msg) (int, error)
)
+
// ServeDNS implements the Handler interface by calling f itself.
func (f HandlerFunc) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
	return f(w, r)
}
+
+// IndexFile looks for a file in /root/fpath/indexFile for each string
+// in indexFiles. If an index file is found, it returns the root-relative
+// path to the file and true. If no index file is found, empty string
+// and false is returned. fpath must end in a forward slash '/'
+// otherwise no index files will be tried (directory paths must end
+// in a forward slash according to HTTP).
+//
+// All paths passed into and returned from this function use '/' as the
+// path separator, just like URLs. IndexFle handles path manipulation
+// internally for systems that use different path separators.
+/*
+func IndexFile(root http.FileSystem, fpath string, indexFiles []string) (string, bool) {
+ if fpath[len(fpath)-1] != '/' || root == nil {
+ return "", false
+ }
+ for _, indexFile := range indexFiles {
+ // func (http.FileSystem).Open wants all paths separated by "/",
+ // regardless of operating system convention, so use
+ // path.Join instead of filepath.Join
+ fp := path.Join(fpath, indexFile)
+ f, err := root.Open(fp)
+ if err == nil {
+ f.Close()
+ return fp, true
+ }
+ }
+ return "", false
+}
+
+// SetLastModifiedHeader checks if the provided modTime is valid and if it is sets it
+// as a Last-Modified header to the ResponseWriter. If the modTime is in the future
+// the current time is used instead.
+func SetLastModifiedHeader(w http.ResponseWriter, modTime time.Time) {
+ if modTime.IsZero() || modTime.Equal(time.Unix(0, 0)) {
+ // the time does not appear to be valid. Don't put it in the response
+ return
+ }
+
+ // RFC 2616 - Section 14.29 - Last-Modified:
+ // An origin server MUST NOT send a Last-Modified date which is later than the
+ // server's time of message origination. In such cases, where the resource's last
+ // modification would indicate some time in the future, the server MUST replace
+ // that date with the message origination date.
+ now := currentTime()
+ if modTime.After(now) {
+ modTime = now
+ }
+
+ w.Header().Set("Last-Modified", modTime.UTC().Format(http.TimeFormat))
+}
+*/
+
+// currentTime, as it is defined here, returns time.Now().
+// It's defined as a variable for mocking time in tests.
+var currentTime = func() time.Time {
+ return time.Now()
+}
diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go
new file mode 100644
index 000000000..62fa4e250
--- /dev/null
+++ b/middleware/middleware_test.go
@@ -0,0 +1,108 @@
+package middleware
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+)
+
// TestIndexfile checks that IndexFile resolves an index file inside a
// trailing-slash directory path and reports success.
// NOTE(review): IndexFile is currently commented out in middleware.go,
// so this test presumably does not compile — confirm and re-enable both
// together.
func TestIndexfile(t *testing.T) {
	tests := []struct {
		rootDir           http.FileSystem
		fpath             string
		indexFiles        []string
		shouldErr         bool
		expectedFilePath  string //return value
		expectedBoolValue bool   //return value
	}{
		{
			http.Dir("./templates/testdata"),
			"/images/",
			[]string{"img.htm"},
			false,
			"/images/img.htm",
			true,
		},
	}
	for i, test := range tests {
		actualFilePath, actualBoolValue := IndexFile(test.rootDir, test.fpath, test.indexFiles)
		if actualBoolValue == true && test.shouldErr {
			t.Errorf("Test %d didn't error, but it should have", i)
		} else if actualBoolValue != true && !test.shouldErr {
			t.Errorf("Test %d errored, but it shouldn't have; got %s", i, "Please Add a / at the end of fpath or the indexFiles doesnt exist")
		}
		if actualFilePath != test.expectedFilePath {
			t.Fatalf("Test %d expected returned filepath to be %s, but got %s ",
				i, test.expectedFilePath, actualFilePath)

		}
		if actualBoolValue != test.expectedBoolValue {
			t.Fatalf("Test %d expected returned bool value to be %v, but got %v ",
				i, test.expectedBoolValue, actualBoolValue)

		}
	}
}
+
+func TestSetLastModified(t *testing.T) {
+ nowTime := time.Now()
+
+ // ovewrite the function to return reliable time
+ originalGetCurrentTimeFunc := currentTime
+ currentTime = func() time.Time {
+ return nowTime
+ }
+ defer func() {
+ currentTime = originalGetCurrentTimeFunc
+ }()
+
+ pastTime := nowTime.Truncate(1 * time.Hour)
+ futureTime := nowTime.Add(1 * time.Hour)
+
+ tests := []struct {
+ inputModTime time.Time
+ expectedIsHeaderSet bool
+ expectedLastModified string
+ }{
+ {
+ inputModTime: pastTime,
+ expectedIsHeaderSet: true,
+ expectedLastModified: pastTime.UTC().Format(http.TimeFormat),
+ },
+ {
+ inputModTime: nowTime,
+ expectedIsHeaderSet: true,
+ expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
+ },
+ {
+ inputModTime: futureTime,
+ expectedIsHeaderSet: true,
+ expectedLastModified: nowTime.UTC().Format(http.TimeFormat),
+ },
+ {
+ inputModTime: time.Time{},
+ expectedIsHeaderSet: false,
+ },
+ }
+
+ for i, test := range tests {
+ responseRecorder := httptest.NewRecorder()
+ errorPrefix := fmt.Sprintf("Test [%d]: ", i)
+ SetLastModifiedHeader(responseRecorder, test.inputModTime)
+ actualLastModifiedHeader := responseRecorder.Header().Get("Last-Modified")
+
+ if test.expectedIsHeaderSet && actualLastModifiedHeader == "" {
+ t.Fatalf(errorPrefix + "Expected to find Last-Modified header, but found nothing")
+ }
+
+ if !test.expectedIsHeaderSet && actualLastModifiedHeader != "" {
+ t.Fatalf(errorPrefix+"Did not expect to find Last-Modified header, but found one [%s].", actualLastModifiedHeader)
+ }
+
+ if test.expectedLastModified != actualLastModifiedHeader {
+ t.Errorf(errorPrefix+"Expected Last-Modified content [%s], found [%s}", test.expectedLastModified, actualLastModifiedHeader)
+ }
+ }
+}
diff --git a/middleware/path.go b/middleware/path.go
new file mode 100644
index 000000000..1ffb64b76
--- /dev/null
+++ b/middleware/path.go
@@ -0,0 +1,18 @@
+package middleware
+
+import "strings"
+
+
+// TODO(miek): matches for names.
+
// Path represents a URI path, maybe with pattern characters.
type Path string

// Matches checks to see if other matches p.
//
// Path matching will probably not always be a direct
// comparison; this method assures that paths can be
// easily and consistently matched.
func (p Path) Matches(other string) bool {
	// A prefix match: other must occur at the very start of p.
	return strings.Index(string(p), other) == 0
}
diff --git a/middleware/prometheus/handler.go b/middleware/prometheus/handler.go
new file mode 100644
index 000000000..eb82b8aff
--- /dev/null
+++ b/middleware/prometheus/handler.go
@@ -0,0 +1,31 @@
+package metrics
+
+import (
+ "strconv"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// ServeDNS passes the query to the next handler while recording the
+// request count, request duration, response size, and response rcode,
+// all labeled by the matched zone ("." when no configured zone matches).
+func (m *Metrics) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+	context := middleware.Context{W: w, Req: r}
+
+	qname := context.Name()
+	qtype := context.Type()
+	zone := middleware.Zones(m.ZoneNames).Matches(qname)
+	if zone == "" {
+		zone = "."
+	}
+
+	// Record response to get status code and size of the reply.
+	rw := middleware.NewResponseRecorder(w)
+	status, err := m.Next.ServeDNS(rw, r)
+
+	requestCount.WithLabelValues(zone, qtype).Inc()
+	// Use Duration.Seconds() here: dividing one time.Duration by another
+	// is integer division, so the previous form truncated every
+	// sub-second request to an observed value of 0.
+	requestDuration.WithLabelValues(zone).Observe(time.Since(rw.Start()).Seconds())
+	responseSize.WithLabelValues(zone).Observe(float64(rw.Size()))
+	responseRcode.WithLabelValues(zone, strconv.Itoa(rw.Rcode())).Inc()
+
+	return status, err
+}
diff --git a/middleware/prometheus/metrics.go b/middleware/prometheus/metrics.go
new file mode 100644
index 000000000..4c989f640
--- /dev/null
+++ b/middleware/prometheus/metrics.go
@@ -0,0 +1,80 @@
+package metrics
+
+import (
+ "fmt"
+ "net/http"
+ "sync"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/prometheus/client_golang/prometheus"
+)
+
+const namespace = "daddy"
+
+var (
+	// Metric collectors; created by define() and registered once in Start().
+	requestCount    *prometheus.CounterVec
+	requestDuration *prometheus.HistogramVec
+	responseSize    *prometheus.HistogramVec
+	responseRcode   *prometheus.CounterVec
+)
+
+// path is the fixed URL path the metrics handler is mounted on.
+const path = "/metrics"
+
+// Metrics holds the prometheus configuration. The metrics' path is fixed to be /metrics
+type Metrics struct {
+	Next      middleware.Handler // next handler in the chain
+	Addr      string             // where do we listen (address of the metrics HTTP endpoint)
+	Once      sync.Once          // guards one-time registration/listener startup in Start
+	ZoneNames []string           // zones to report metrics for; matched against query names
+}
+
+// Start defines and registers the collectors and, exactly once per
+// process, starts the HTTP endpoint serving them on m.Addr under /metrics.
+// It always returns nil.
+func (m *Metrics) Start() error {
+	m.Once.Do(func() {
+		define("")
+
+		prometheus.MustRegister(requestCount)
+		prometheus.MustRegister(requestDuration)
+		prometheus.MustRegister(responseSize)
+		prometheus.MustRegister(responseRcode)
+
+		http.Handle(path, prometheus.Handler())
+		go func() {
+			// ListenAndServe only returns on failure. The previous code
+			// called fmt.Errorf, which merely constructs an error value
+			// and throws it away; actually report the failure instead.
+			if err := http.ListenAndServe(m.Addr, nil); err != nil {
+				fmt.Println("metrics: ListenAndServe:", err)
+			}
+		}()
+	})
+	return nil
+}
+
+// define creates the collector vectors under the given subsystem
+// (defaulting to "dns"). It must run before the collectors are
+// registered or used.
+func define(subsystem string) {
+	if subsystem == "" {
+		subsystem = "dns"
+	}
+	requestCount = prometheus.NewCounterVec(prometheus.CounterOpts{
+		Namespace: namespace,
+		Subsystem: subsystem,
+		Name:      "request_count_total",
+		Help:      "Counter of DNS requests made per zone and type.",
+	}, []string{"zone", "qtype"})
+
+	requestDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: namespace,
+		Subsystem: subsystem,
+		Name:      "request_duration_seconds",
+		Help:      "Histogram of the time (in seconds) each request took.",
+	}, []string{"zone"})
+
+	responseSize = prometheus.NewHistogramVec(prometheus.HistogramOpts{
+		Namespace: namespace,
+		Subsystem: subsystem,
+		Name:      "response_size_bytes",
+		// Fixed help-text typo: "returns response" -> "returned response".
+		Help:    "Size of the returned response in bytes.",
+		Buckets: []float64{0, 100, 200, 300, 400, 511, 1023, 2047, 4095, 8291, 16e3, 32e3, 48e3, 64e3},
+	}, []string{"zone"})
+
+	responseRcode = prometheus.NewCounterVec(prometheus.CounterOpts{
+		Namespace: namespace,
+		Subsystem: subsystem,
+		Name:      "rcode_code_count_total",
+		Help:      "Counter of response status codes.",
+	}, []string{"zone", "rcode"})
+}
diff --git a/middleware/proxy/policy.go b/middleware/proxy/policy.go
new file mode 100644
index 000000000..a2522bcb1
--- /dev/null
+++ b/middleware/proxy/policy.go
@@ -0,0 +1,101 @@
+package proxy
+
+import (
+ "math/rand"
+ "sync/atomic"
+)
+
+// HostPool is a collection of UpstreamHosts.
+type HostPool []*UpstreamHost
+
+// Policy decides how a host will be selected from a pool.
+type Policy interface {
+	// Select returns a usable host from pool, or nil when none is available.
+	Select(pool HostPool) *UpstreamHost
+}
+
+// init registers the built-in load-balancing policies by name so the
+// config parser can look them up via RegisterPolicy's registry.
+func init() {
+	RegisterPolicy("random", func() Policy { return &Random{} })
+	RegisterPolicy("least_conn", func() Policy { return &LeastConn{} })
+	RegisterPolicy("round_robin", func() Policy { return &RoundRobin{} })
+}
+
+// Random is a policy that selects up hosts from a pool at random.
+type Random struct{}
+
+// Select selects an up host at random from the specified pool.
+// It returns nil when every host in the pool is down.
+func (r *Random) Select(pool HostPool) *UpstreamHost {
+	// instead of just generating a random index
+	// this is done to prevent selecting a down host:
+	// single-pass reservoir selection — the k-th eligible host
+	// replaces the current pick with probability 1/k.
+	var randHost *UpstreamHost
+	count := 0
+	for _, host := range pool {
+		if host.Down() {
+			continue
+		}
+		count++
+		if count == 1 {
+			randHost = host
+		} else {
+			r := rand.Int() % count
+			if r == (count - 1) {
+				randHost = host
+			}
+		}
+	}
+	return randHost
+}
+
+// LeastConn is a policy that selects the host with the least connections.
+type LeastConn struct{}
+
+// Select selects the up host with the least number of connections in the
+// pool. If more than one host has the same least number of connections,
+// one of the hosts is chosen at random (reservoir-style, as in Random).
+func (r *LeastConn) Select(pool HostPool) *UpstreamHost {
+	var bestHost *UpstreamHost
+	count := 0
+	leastConn := int64(1<<63 - 1) // max int64 sentinel
+	for _, host := range pool {
+		if host.Down() {
+			continue
+		}
+		hostConns := host.Conns
+		if hostConns < leastConn {
+			// Strictly fewer connections: new sole candidate.
+			bestHost = host
+			leastConn = hostConns
+			count = 1
+		} else if hostConns == leastConn {
+			// randomly select host among hosts with least connections
+			count++
+			if count == 1 {
+				bestHost = host
+			} else {
+				r := rand.Int() % count
+				if r == (count - 1) {
+					bestHost = host
+				}
+			}
+		}
+	}
+	return bestHost
+}
+
+// RoundRobin is a policy that selects hosts based on round robin ordering.
+type RoundRobin struct {
+	Robin uint32 // monotonically increasing counter, updated atomically
+}
+
+// Select selects an up host from the pool using a round robin ordering
+// scheme. It returns nil when the pool is empty or all hosts are down.
+func (r *RoundRobin) Select(pool HostPool) *UpstreamHost {
+	poolLen := uint32(len(pool))
+	if poolLen == 0 {
+		// Guard: the modulo below would otherwise panic with a
+		// division by zero on an empty pool.
+		return nil
+	}
+	selection := atomic.AddUint32(&r.Robin, 1) % poolLen
+	host := pool[selection]
+	// if the currently selected host is down, just ffwd to up host
+	for i := uint32(1); host.Down() && i < poolLen; i++ {
+		host = pool[(selection+i)%poolLen]
+	}
+	if host.Down() {
+		return nil
+	}
+	return host
+}
diff --git a/middleware/proxy/policy_test.go b/middleware/proxy/policy_test.go
new file mode 100644
index 000000000..8f4f1f792
--- /dev/null
+++ b/middleware/proxy/policy_test.go
@@ -0,0 +1,87 @@
+package proxy
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+)
+
+// workableServer is a live HTTP server shared by the tests; it is the
+// one resolvable, healthy backend in testPool.
+var workableServer *httptest.Server
+
+// TestMain starts the shared backend, runs the tests, then tears it down.
+func TestMain(m *testing.M) {
+	workableServer = httptest.NewServer(http.HandlerFunc(
+		func(w http.ResponseWriter, r *http.Request) {
+			// do nothing
+		}))
+	r := m.Run()
+	workableServer.Close()
+	os.Exit(r)
+}
+
+// customPolicy is a trivial Policy used to exercise RegisterPolicy; it
+// always returns the first host, up or not.
+type customPolicy struct{}
+
+func (r *customPolicy) Select(pool HostPool) *UpstreamHost {
+	return pool[0]
+}
+
+// testPool returns a three-host pool: a live server, an unresolvable
+// name, and a placeholder that never answers.
+func testPool() HostPool {
+	pool := []*UpstreamHost{
+		{
+			Name: workableServer.URL, // this should resolve (healthcheck test)
+		},
+		{
+			Name: "http://shouldnot.resolve", // this shouldn't
+		},
+		{
+			Name: "http://C",
+		},
+	}
+	return HostPool(pool)
+}
+
+// TestRoundRobinPolicy checks the selection order 1, 2, then (with host
+// 0 down) 1 again, since the wrapped counter lands on the down host and
+// Select fast-forwards past it.
+func TestRoundRobinPolicy(t *testing.T) {
+	pool := testPool()
+	rrPolicy := &RoundRobin{}
+	h := rrPolicy.Select(pool)
+	// First selected host is 1, because counter starts at 0
+	// and increments before host is selected
+	if h != pool[1] {
+		t.Error("Expected first round robin host to be second host in the pool.")
+	}
+	h = rrPolicy.Select(pool)
+	if h != pool[2] {
+		t.Error("Expected second round robin host to be third host in the pool.")
+	}
+	// mark the first host as down
+	pool[0].Unhealthy = true
+	h = rrPolicy.Select(pool)
+	if h != pool[1] {
+		// Message fixed: the assertion checks pool[1] (the second host),
+		// but the old text claimed "first host".
+		t.Error("Expected third round robin host to be second host in the pool.")
+	}
+}
+
+// TestLeastConnPolicy checks that the host with the fewest connections
+// wins, and that ties are broken among the tied hosts.
+func TestLeastConnPolicy(t *testing.T) {
+	pool := testPool()
+	lcPolicy := &LeastConn{}
+	pool[0].Conns = 10
+	pool[1].Conns = 10
+	h := lcPolicy.Select(pool)
+	if h != pool[2] {
+		t.Error("Expected least connection host to be third host.")
+	}
+	pool[2].Conns = 100
+	h = lcPolicy.Select(pool)
+	if h != pool[0] && h != pool[1] {
+		t.Error("Expected least connection host to be first or second host.")
+	}
+}
+
+// TestCustomPolicy checks that a user-supplied Policy implementation is
+// honored as-is.
+func TestCustomPolicy(t *testing.T) {
+	pool := testPool()
+	customPolicy := &customPolicy{}
+	h := customPolicy.Select(pool)
+	if h != pool[0] {
+		t.Error("Expected custom policy host to be the first host.")
+	}
+}
diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go
new file mode 100644
index 000000000..169e41b61
--- /dev/null
+++ b/middleware/proxy/proxy.go
@@ -0,0 +1,120 @@
+// Package proxy is middleware that proxies requests.
+package proxy
+
+import (
+ "errors"
+ "net/http"
+ "sync/atomic"
+ "time"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// errUnreachable is returned when no upstream host could be reached.
+var errUnreachable = errors.New("unreachable backend")
+
+// Proxy represents a middleware instance that can proxy requests.
+type Proxy struct {
+	Next      middleware.Handler
+	Client    Client
+	Upstreams []Upstream
+}
+
+// Client bundles the UDP and TCP DNS clients used to reach upstreams;
+// the transport is picked per query to match how the client connected.
+type Client struct {
+	UDP *dns.Client
+	TCP *dns.Client
+}
+
+// Upstream manages a pool of proxy upstream hosts. Select should return a
+// suitable upstream host, or nil if no such hosts are available.
+type Upstream interface {
+	// The domain name this upstream host should be routed on.
+	From() string
+	// Selects an upstream host to be routed to.
+	Select() *UpstreamHost
+	// Reports whether the given (sub)path is not on the ignore list.
+	IsAllowedPath(string) bool
+}
+
+// UpstreamHostDownFunc can be used to customize how Down behaves.
+type UpstreamHostDownFunc func(*UpstreamHost) bool
+
+// UpstreamHost represents a single proxy upstream
+type UpstreamHost struct {
+	Conns             int64  // must be first field to be 64-bit aligned on 32-bit systems
+	Name              string // IP address (and port) of this upstream host
+	Fails             int32  // in-flight failure count; decremented after FailTimeout
+	FailTimeout       time.Duration
+	Unhealthy         bool // set by the health checker
+	ExtraHeaders      http.Header
+	CheckDown         UpstreamHostDownFunc // optional custom down-check; see Down
+	WithoutPathPrefix string
+}
+
+// Down reports whether this upstream host should be treated as down.
+// A custom CheckDown function, when set, takes precedence; otherwise a
+// host is down when it is flagged Unhealthy or has a recent failure.
+func (uh *UpstreamHost) Down() bool {
+	if uh.CheckDown != nil {
+		return uh.CheckDown(uh)
+	}
+	// Default criteria.
+	return uh.Unhealthy || uh.Fails > 0
+}
+
+// tryDuration is how long to try upstream hosts; failures result in
+// immediate retries until this duration ends or we get a nil host.
+// It is a var (not a const) so tests can shorten it.
+var tryDuration = 60 * time.Second
+
+// ServeDNS satisfies the middleware.Handler interface. On success it
+// returns 0 (the reply was already written by the reverse proxy); when
+// no host answers within tryDuration it returns SERVFAIL/errUnreachable.
+//
+// NOTE(review): every path through the loop body returns, so only the
+// first entry of p.Upstreams is ever consulted and p.Next runs only
+// when the slice is empty — confirm this is intended.
+func (p Proxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+	for _, upstream := range p.Upstreams {
+		// allowed bla bla bla TODO(miek): fix full proxy spec from caddy
+		start := time.Now()
+
+		// Since Select() should give us "up" hosts, keep retrying
+		// hosts until timeout (or until we get a nil host).
+		for time.Now().Sub(start) < tryDuration {
+			host := upstream.Select()
+			if host == nil {
+				return dns.RcodeServerFailure, errUnreachable
+			}
+			// TODO(miek): PORT!
+			reverseproxy := ReverseProxy{Host: host.Name, Client: p.Client}
+
+			// Track in-flight queries per host for the least_conn policy.
+			atomic.AddInt64(&host.Conns, 1)
+			backendErr := reverseproxy.ServeDNS(w, r, nil)
+			atomic.AddInt64(&host.Conns, -1)
+			if backendErr == nil {
+				return 0, nil
+			}
+			// Count the failure against the host temporarily; a goroutine
+			// forgives it after the host's fail timeout elapses.
+			timeout := host.FailTimeout
+			if timeout == 0 {
+				timeout = 10 * time.Second
+			}
+			atomic.AddInt32(&host.Fails, 1)
+			go func(host *UpstreamHost, timeout time.Duration) {
+				time.Sleep(timeout)
+				atomic.AddInt32(&host.Fails, -1)
+			}(host, timeout)
+		}
+		return dns.RcodeServerFailure, errUnreachable
+	}
+	return p.Next.ServeDNS(w, r)
+}
+
+// Clients returns a Client holding one UDP and one TCP DNS client, both
+// configured with the default timeout.
+func Clients() Client {
+	return Client{
+		UDP: newClient("udp", defaultTimeout),
+		TCP: newClient("tcp", defaultTimeout),
+	}
+}
+
+// newClient returns a new client for proxy requests on the given
+// network ("udp" or "tcp"). A zero timeout falls back to defaultTimeout;
+// SingleInflight collapses duplicate in-flight queries into one.
+func newClient(net string, timeout time.Duration) *dns.Client {
+	if timeout == 0 {
+		timeout = defaultTimeout
+	}
+	return &dns.Client{Net: net, ReadTimeout: timeout, WriteTimeout: timeout, SingleInflight: true}
+}
+
+// defaultTimeout bounds reads and writes to an upstream host.
+const defaultTimeout = 5 * time.Second
diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go
new file mode 100644
index 000000000..8066874d2
--- /dev/null
+++ b/middleware/proxy/proxy_test.go
@@ -0,0 +1,317 @@
+package proxy
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/net/websocket"
+)
+
+// init shortens tryDuration so retry loops against dead upstreams give
+// up quickly in tests.
+func init() {
+	tryDuration = 50 * time.Millisecond // prevent tests from hanging
+}
+
+// TestReverseProxy checks that a request is forwarded to the backend.
+//
+// NOTE(review): this file still tests an HTTP reverse proxy — it calls
+// p.ServeHTTP and builds *http.Request values — but Proxy in proxy.go
+// only defines ServeDNS. These tests will not compile against the DNS
+// proxy until they are ported.
+func TestReverseProxy(t *testing.T) {
+	log.SetOutput(ioutil.Discard)
+	defer log.SetOutput(os.Stderr)
+
+	var requestReceived bool
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		requestReceived = true
+		w.Write([]byte("Hello, client"))
+	}))
+	defer backend.Close()
+
+	// set up proxy
+	p := &Proxy{
+		Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
+	}
+
+	// create request and response recorder
+	r, err := http.NewRequest("GET", "/", nil)
+	if err != nil {
+		t.Fatalf("Failed to create request: %v", err)
+	}
+	w := httptest.NewRecorder()
+
+	p.ServeHTTP(w, r)
+
+	if !requestReceived {
+		t.Error("Expected backend to receive request, but it didn't")
+	}
+}
+
+// TestReverseProxyInsecureSkipVerify checks proxying to a TLS backend
+// with certificate verification disabled. (HTTP-proxy test; see the
+// NOTE on TestReverseProxy.)
+func TestReverseProxyInsecureSkipVerify(t *testing.T) {
+	log.SetOutput(ioutil.Discard)
+	defer log.SetOutput(os.Stderr)
+
+	var requestReceived bool
+	backend := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		requestReceived = true
+		w.Write([]byte("Hello, client"))
+	}))
+	defer backend.Close()
+
+	// set up proxy
+	p := &Proxy{
+		Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
+	}
+
+	// create request and response recorder
+	r, err := http.NewRequest("GET", "/", nil)
+	if err != nil {
+		t.Fatalf("Failed to create request: %v", err)
+	}
+	w := httptest.NewRecorder()
+
+	p.ServeHTTP(w, r)
+
+	if !requestReceived {
+		t.Error("Even with insecure HTTPS, expected backend to receive request, but it didn't")
+	}
+}
+
+// TestWebSocketReverseProxyServeHTTPHandler checks that a WS upgrade
+// handshake is proxied and the 101 response captured via a hijacked
+// connection. (HTTP-proxy test; see the NOTE on TestReverseProxy.)
+func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
+	// No-op websocket backend simply allows the WS connection to be
+	// accepted then it will be immediately closed. Perfect for testing.
+	wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {}))
+	defer wsNop.Close()
+
+	// Get proxy to use for the test
+	p := newWebSocketTestProxy(wsNop.URL)
+
+	// Create client request
+	r, err := http.NewRequest("GET", "/", nil)
+	if err != nil {
+		t.Fatalf("Failed to create request: %v", err)
+	}
+	r.Header = http.Header{
+		"Connection":            {"Upgrade"},
+		"Upgrade":               {"websocket"},
+		"Origin":                {wsNop.URL},
+		"Sec-WebSocket-Key":     {"x3JJHMbDL1EzLkh9GBhXDw=="},
+		"Sec-WebSocket-Version": {"13"},
+	}
+
+	// Capture the request
+	w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)}
+
+	// Booya! Do the test.
+	p.ServeHTTP(w, r)
+
+	// Make sure the backend accepted the WS connection.
+	// Mostly interested in the Upgrade and Connection response headers
+	// and the 101 status code.
+	expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n")
+	actual := w.fakeConn.writeBuf.Bytes()
+	if !bytes.Equal(actual, expected) {
+		t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual)
+	}
+}
+
+// TestWebSocketReverseProxyFromWSClient is an end-to-end echo test
+// through the proxy using a real WS client. (HTTP-proxy test; see the
+// NOTE on TestReverseProxy.)
+func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
+	// Echo server allows us to test that socket bytes are properly
+	// being proxied.
+	wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {
+		io.Copy(ws, ws)
+	}))
+	defer wsEcho.Close()
+
+	// Get proxy to use for the test
+	p := newWebSocketTestProxy(wsEcho.URL)
+
+	// This is a full end-end test, so the proxy handler
+	// has to be part of a server listening on a port. Our
+	// WS client will connect to this test server, not
+	// the echo client directly.
+	echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		p.ServeHTTP(w, r)
+	}))
+	defer echoProxy.Close()
+
+	// Set up WebSocket client
+	url := strings.Replace(echoProxy.URL, "http://", "ws://", 1)
+	ws, err := websocket.Dial(url, "", echoProxy.URL)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer ws.Close()
+
+	// Send test message
+	trialMsg := "Is it working?"
+	websocket.Message.Send(ws, trialMsg)
+
+	// It should be echoed back to us
+	var actualMsg string
+	websocket.Message.Receive(ws, &actualMsg)
+	if actualMsg != trialMsg {
+		t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
+	}
+}
+
+// TestUnixSocketProxy checks proxying to a backend listening on a unix
+// domain socket; skipped on Windows. (HTTP-proxy test; see the NOTE on
+// TestReverseProxy.)
+func TestUnixSocketProxy(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		return
+	}
+
+	trialMsg := "Is it working?"
+
+	var proxySuccess bool
+
+	// This is our fake "application" we want to proxy to
+	ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Request was proxied when this is called
+		proxySuccess = true
+
+		fmt.Fprint(w, trialMsg)
+	}))
+
+	// Get absolute path for unix: socket
+	socketPath, err := filepath.Abs("./test_socket")
+	if err != nil {
+		t.Fatalf("Unable to get absolute path: %v", err)
+	}
+
+	// Change httptest.Server listener to listen to unix: socket
+	ln, err := net.Listen("unix", socketPath)
+	if err != nil {
+		t.Fatalf("Unable to listen: %v", err)
+	}
+	ts.Listener = ln
+
+	ts.Start()
+	defer ts.Close()
+
+	url := strings.Replace(ts.URL, "http://", "unix:", 1)
+	p := newWebSocketTestProxy(url)
+
+	echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		p.ServeHTTP(w, r)
+	}))
+	defer echoProxy.Close()
+
+	res, err := http.Get(echoProxy.URL)
+	if err != nil {
+		t.Fatalf("Unable to GET: %v", err)
+	}
+
+	greeting, err := ioutil.ReadAll(res.Body)
+	res.Body.Close()
+	if err != nil {
+		t.Fatalf("Unable to GET: %v", err)
+	}
+
+	actualMsg := fmt.Sprintf("%s", greeting)
+
+	if !proxySuccess {
+		t.Errorf("Expected request to be proxied, but it wasn't")
+	}
+
+	if actualMsg != trialMsg {
+		t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg)
+	}
+}
+
+// newFakeUpstream builds a one-host upstream pointing at name.
+//
+// NOTE(review): it sets host.ReverseProxy and uses
+// NewSingleHostReverseProxy / InsecureTransport, none of which exist on
+// the UpstreamHost in proxy.go — another leftover of the HTTP version.
+func newFakeUpstream(name string, insecure bool) *fakeUpstream {
+	uri, _ := url.Parse(name)
+	u := &fakeUpstream{
+		name: name,
+		host: &UpstreamHost{
+			Name:         name,
+			ReverseProxy: NewSingleHostReverseProxy(uri, ""),
+		},
+	}
+	if insecure {
+		u.host.ReverseProxy.Transport = InsecureTransport
+	}
+	return u
+}
+
+// fakeUpstream is a minimal Upstream that always serves its single host
+// from "/" and allows every path.
+type fakeUpstream struct {
+	name string
+	host *UpstreamHost
+}
+
+func (u *fakeUpstream) From() string {
+	return "/"
+}
+
+func (u *fakeUpstream) Select() *UpstreamHost {
+	return u.host
+}
+
+func (u *fakeUpstream) IsAllowedPath(requestPath string) bool {
+	return true
+}
+
+// newWebSocketTestProxy returns a test proxy that will
+// redirect to the specified backendAddr. The function
+// also sets up the rules/environment for testing WebSocket
+// proxy.
+func newWebSocketTestProxy(backendAddr string) *Proxy {
+	return &Proxy{
+		Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr}},
+	}
+}
+
+// fakeWsUpstream builds, on every Select, a fresh host whose extra
+// headers replay the client's Connection/Upgrade headers so the WS
+// handshake passes through. (Relies on the same HTTP-era UpstreamHost
+// fields noted on newFakeUpstream.)
+type fakeWsUpstream struct {
+	name string
+}
+
+func (u *fakeWsUpstream) From() string {
+	return "/"
+}
+
+func (u *fakeWsUpstream) Select() *UpstreamHost {
+	uri, _ := url.Parse(u.name)
+	return &UpstreamHost{
+		Name:         u.name,
+		ReverseProxy: NewSingleHostReverseProxy(uri, ""),
+		ExtraHeaders: http.Header{
+			"Connection": {"{>Connection}"},
+			"Upgrade":    {"{>Upgrade}"}},
+	}
+}
+
+func (u *fakeWsUpstream) IsAllowedPath(requestPath string) bool {
+	return true
+}
+
+// recorderHijacker is a ResponseRecorder that can
+// be hijacked; Hijack hands back the embedded fakeConn so tests can
+// inspect the raw bytes written during a WS upgrade.
+type recorderHijacker struct {
+	*httptest.ResponseRecorder
+	fakeConn *fakeConn
+}
+
+func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	return rh.fakeConn, nil, nil
+}
+
+// fakeConn is an in-memory net.Conn: reads drain readBuf and writes
+// accumulate in writeBuf for later inspection; all other methods no-op.
+type fakeConn struct {
+	readBuf  bytes.Buffer
+	writeBuf bytes.Buffer
+}
+
+func (c *fakeConn) LocalAddr() net.Addr                { return nil }
+func (c *fakeConn) RemoteAddr() net.Addr               { return nil }
+func (c *fakeConn) SetDeadline(t time.Time) error      { return nil }
+func (c *fakeConn) SetReadDeadline(t time.Time) error  { return nil }
+func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil }
+func (c *fakeConn) Close() error                       { return nil }
+func (c *fakeConn) Read(b []byte) (int, error)         { return c.readBuf.Read(b) }
+func (c *fakeConn) Write(b []byte) (int, error)        { return c.writeBuf.Write(b) }
new file mode 100644
index 000000000..6d27da042
--- /dev/null
+++ b/middleware/proxy/reverseproxy.go
@@ -0,0 +1,36 @@
+// Package proxy is middleware that proxies requests.
+package proxy
+
+import (
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
+// ReverseProxy forwards a single DNS query to one upstream host,
+// choosing the UDP or TCP client to match the inbound transport.
+type ReverseProxy struct {
+	Host   string
+	Client Client
+}
+
+// ServeDNS exchanges r with p.Host and writes the reply back to w.
+// The transport mirrors how the client reached us: TCP queries go out
+// over the TCP client, everything else over UDP. A non-nil error means
+// the exchange itself failed and nothing was written.
+func (p ReverseProxy) ServeDNS(w dns.ResponseWriter, r *dns.Msg, extra []dns.RR) error {
+	// TODO(miek): use extra!
+	var (
+		reply *dns.Msg
+		err   error
+	)
+	context := middleware.Context{W: w, Req: r}
+
+	// tls+tcp ?
+	if context.Proto() == "tcp" {
+		reply, err = middleware.Exchange(p.Client.TCP, r, p.Host)
+	} else {
+		reply, err = middleware.Exchange(p.Client.UDP, r, p.Host)
+	}
+
+	if err != nil {
+		return err
+	}
+	reply.Compress = true
+	// Re-stamp the original query ID in case the exchange rewrote it.
+	reply.Id = r.Id
+	w.WriteMsg(reply)
+	return nil
+}
diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go
new file mode 100644
index 000000000..092e2351d
--- /dev/null
+++ b/middleware/proxy/upstream.go
@@ -0,0 +1,235 @@
+package proxy
+
+import (
+ "io"
+ "io/ioutil"
+ "net/http"
+ "path"
+ "strconv"
+ "time"
+
+ "github.com/miekg/coredns/core/parse"
+ "github.com/miekg/coredns/middleware"
+)
+
+var (
+ supportedPolicies = make(map[string]func() Policy)
+)
+
+// staticUpstream is the config-file-driven Upstream implementation: a
+// fixed host pool plus load-balancing policy and health-check settings.
+type staticUpstream struct {
+	from string // domain this upstream is routed on; see From()
+	// TODO(miek): allow users to add headers
+	proxyHeaders http.Header // TODO(miek): kill
+	Hosts        HostPool
+	Policy       Policy
+
+	FailTimeout time.Duration
+	MaxFails    int32
+	HealthCheck struct {
+		Path     string
+		Interval time.Duration
+	}
+	WithoutPathPrefix string
+	IgnoredSubPaths   []string
+}
+
+// NewStaticUpstreams parses the configuration input and sets up
+// static upstreams for the proxy middleware. Each directive contributes
+// one upstream: the first argument is the routed-from name, the rest
+// are host addresses; an optional block tunes policy, timeouts, health
+// checks etc. (see parseBlock).
+func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) {
+	var upstreams []Upstream
+	for c.Next() {
+		upstream := &staticUpstream{
+			from:         "",
+			proxyHeaders: make(http.Header),
+			Hosts:        nil,
+			Policy:       &Random{},
+			FailTimeout:  10 * time.Second,
+			MaxFails:     1,
+		}
+
+		if !c.Args(&upstream.from) {
+			return upstreams, c.ArgErr()
+		}
+		to := c.RemainingArgs()
+		if len(to) == 0 {
+			return upstreams, c.ArgErr()
+		}
+
+		for c.NextBlock() {
+			if err := parseBlock(&c, upstream); err != nil {
+				return upstreams, err
+			}
+		}
+
+		upstream.Hosts = make([]*UpstreamHost, len(to))
+		for i, host := range to {
+			uh := &UpstreamHost{
+				Name:         host,
+				Conns:        0,
+				Fails:        0,
+				FailTimeout:  upstream.FailTimeout,
+				Unhealthy:    false,
+				ExtraHeaders: upstream.proxyHeaders,
+				// Closure captures this upstream's MaxFails so Down()
+				// can apply the per-upstream failure threshold.
+				CheckDown: func(upstream *staticUpstream) UpstreamHostDownFunc {
+					return func(uh *UpstreamHost) bool {
+						if uh.Unhealthy {
+							return true
+						}
+						if uh.Fails >= upstream.MaxFails &&
+							upstream.MaxFails != 0 {
+							return true
+						}
+						return false
+					}
+				}(upstream),
+				WithoutPathPrefix: upstream.WithoutPathPrefix,
+			}
+			upstream.Hosts[i] = uh
+		}
+
+		if upstream.HealthCheck.Path != "" {
+			// NOTE(review): started with a nil stop channel, so this
+			// worker can never be told to stop.
+			go upstream.HealthCheckWorker(nil)
+		}
+		upstreams = append(upstreams, upstream)
+	}
+	return upstreams, nil
+}
+
+// RegisterPolicy adds a custom policy to the proxy. The factory is
+// looked up by name when a config block says "policy <name>".
+func RegisterPolicy(name string, policy func() Policy) {
+	supportedPolicies[name] = policy
+}
+
+// From returns the domain name this upstream is routed on.
+func (u *staticUpstream) From() string {
+	return u.from
+}
+
+// parseBlock handles one property inside an upstream's config block,
+// mutating u accordingly. Unknown properties are an error.
+func parseBlock(c *parse.Dispenser, u *staticUpstream) error {
+	switch c.Val() {
+	case "policy":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		policyCreateFunc, ok := supportedPolicies[c.Val()]
+		if !ok {
+			return c.ArgErr()
+		}
+		u.Policy = policyCreateFunc()
+	case "fail_timeout":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		dur, err := time.ParseDuration(c.Val())
+		if err != nil {
+			return err
+		}
+		u.FailTimeout = dur
+	case "max_fails":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		n, err := strconv.Atoi(c.Val())
+		if err != nil {
+			return err
+		}
+		u.MaxFails = int32(n)
+	case "health_check":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		u.HealthCheck.Path = c.Val()
+		// Default interval; a second argument overrides it.
+		u.HealthCheck.Interval = 30 * time.Second
+		if c.NextArg() {
+			dur, err := time.ParseDuration(c.Val())
+			if err != nil {
+				return err
+			}
+			u.HealthCheck.Interval = dur
+		}
+	case "proxy_header":
+		var header, value string
+		if !c.Args(&header, &value) {
+			return c.ArgErr()
+		}
+		u.proxyHeaders.Add(header, value)
+	case "websocket":
+		// Replay the client's upgrade headers to the backend.
+		u.proxyHeaders.Add("Connection", "{>Connection}")
+		u.proxyHeaders.Add("Upgrade", "{>Upgrade}")
+	case "without":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		u.WithoutPathPrefix = c.Val()
+	case "except":
+		ignoredPaths := c.RemainingArgs()
+		if len(ignoredPaths) == 0 {
+			return c.ArgErr()
+		}
+		u.IgnoredSubPaths = ignoredPaths
+	default:
+		return c.Errf("unknown property '%s'", c.Val())
+	}
+	return nil
+}
+
+// healthCheck probes every host once over HTTP and updates Unhealthy:
+// a 2xx/3xx answer marks the host healthy, anything else (or a failed
+// request) marks it down.
+//
+// NOTE(review): this GETs host.Name + path directly; for DNS upstreams
+// Name is an address without an http:// scheme, so the request would
+// fail and mark every host unhealthy — confirm against real configs.
+func (u *staticUpstream) healthCheck() {
+	for _, host := range u.Hosts {
+		hostURL := host.Name + u.HealthCheck.Path
+		if r, err := http.Get(hostURL); err == nil {
+			// Drain the body so the connection can be reused.
+			io.Copy(ioutil.Discard, r.Body)
+			r.Body.Close()
+			host.Unhealthy = r.StatusCode < 200 || r.StatusCode >= 400
+		} else {
+			host.Unhealthy = true
+		}
+	}
+}
+
+// HealthCheckWorker runs an immediate health check and then one per
+// HealthCheck.Interval, until a value arrives on (or the close of) stop.
+// A nil stop channel blocks that case forever, i.e. the worker never exits.
+func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) {
+	ticker := time.NewTicker(u.HealthCheck.Interval)
+	u.healthCheck()
+	for {
+		select {
+		case <-ticker.C:
+			u.healthCheck()
+		case <-stop:
+			// The original fell through here without returning, which
+			// busy-looped once stop was closed and leaked the ticker.
+			// TODO: the library should provide a stop channel and global
+			// waitgroup to allow goroutines started by plugins a chance
+			// to clean themselves up.
+			ticker.Stop()
+			return
+		}
+	}
+}
+
+// Select picks an up host via the configured Policy (Random when none
+// is set). It short-circuits to nil when the pool is a single down host
+// or when every host is down, so policies only see winnable pools.
+func (u *staticUpstream) Select() *UpstreamHost {
+	pool := u.Hosts
+	if len(pool) == 1 {
+		if pool[0].Down() {
+			return nil
+		}
+		return pool[0]
+	}
+	allDown := true
+	for _, host := range pool {
+		if !host.Down() {
+			allDown = false
+			break
+		}
+	}
+	if allDown {
+		return nil
+	}
+
+	if u.Policy == nil {
+		return (&Random{}).Select(pool)
+	}
+	return u.Policy.Select(pool)
+}
+
+// IsAllowedPath reports whether requestPath is NOT under any of the
+// configured "except" sub-paths (joined onto From and compared after
+// path.Clean, so "//" variants are normalized before matching).
+func (u *staticUpstream) IsAllowedPath(requestPath string) bool {
+	for _, ignoredSubPath := range u.IgnoredSubPaths {
+		if middleware.Path(path.Clean(requestPath)).Matches(path.Join(u.From(), ignoredSubPath)) {
+			return false
+		}
+	}
+	return true
+}
diff --git a/middleware/proxy/upstream_test.go b/middleware/proxy/upstream_test.go
new file mode 100644
index 000000000..5b2fdb1da
--- /dev/null
+++ b/middleware/proxy/upstream_test.go
@@ -0,0 +1,83 @@
+package proxy
+
+import (
+ "testing"
+ "time"
+)
+
+// TestHealthCheck probes the shared test pool: the live httptest server
+// must pass, the unresolvable host must fail.
+func TestHealthCheck(t *testing.T) {
+	upstream := &staticUpstream{
+		from:        "",
+		Hosts:       testPool(),
+		Policy:      &Random{},
+		FailTimeout: 10 * time.Second,
+		MaxFails:    1,
+	}
+	upstream.healthCheck()
+	if upstream.Hosts[0].Down() {
+		t.Error("Expected first host in testpool to not fail healthcheck.")
+	}
+	if !upstream.Hosts[1].Down() {
+		t.Error("Expected second host in testpool to fail healthcheck.")
+	}
+}
+
+// TestSelect checks that Select returns nil when every host is down and
+// a host again once one recovers.
+func TestSelect(t *testing.T) {
+	upstream := &staticUpstream{
+		from:        "",
+		Hosts:       testPool()[:3],
+		Policy:      &Random{},
+		FailTimeout: 10 * time.Second,
+		MaxFails:    1,
+	}
+	upstream.Hosts[0].Unhealthy = true
+	upstream.Hosts[1].Unhealthy = true
+	upstream.Hosts[2].Unhealthy = true
+	if h := upstream.Select(); h != nil {
+		t.Error("Expected select to return nil as all host are down")
+	}
+	upstream.Hosts[2].Unhealthy = false
+	if h := upstream.Select(); h == nil {
+		t.Error("Expected select to not return nil")
+	}
+}
+
+// TestRegisterPolicy checks that a custom policy lands in the registry
+// under its name.
+func TestRegisterPolicy(t *testing.T) {
+	name := "custom"
+	customPolicy := &customPolicy{}
+	RegisterPolicy(name, func() Policy { return customPolicy })
+	if _, ok := supportedPolicies[name]; !ok {
+		t.Error("Expected supportedPolicies to have a custom policy.")
+	}
+
+}
+
+// TestAllowedPaths table-tests IsAllowedPath, including double-slash
+// variants that path.Clean must normalize before matching.
+func TestAllowedPaths(t *testing.T) {
+	upstream := &staticUpstream{
+		from:            "/proxy",
+		IgnoredSubPaths: []string{"/download", "/static"},
+	}
+	tests := []struct {
+		url      string
+		expected bool
+	}{
+		{"/proxy", true},
+		{"/proxy/dl", true},
+		{"/proxy/download", false},
+		{"/proxy/download/static", false},
+		{"/proxy/static", false},
+		{"/proxy/static/download", false},
+		{"/proxy/something/download", true},
+		{"/proxy/something/static", true},
+		{"/proxy//static", false},
+		{"/proxy//static//download", false},
+		{"/proxy//download", false},
+	}
+
+	for i, test := range tests {
+		isAllowed := upstream.IsAllowedPath(test.url)
+		if test.expected != isAllowed {
+			t.Errorf("Test %d: expected %v found %v", i+1, test.expected, isAllowed)
+		}
+	}
+}
diff --git a/middleware/recorder.go b/middleware/recorder.go
new file mode 100644
index 000000000..38a7e0e82
--- /dev/null
+++ b/middleware/recorder.go
@@ -0,0 +1,70 @@
+package middleware
+
+import (
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// ResponseRecorder is a type of ResponseWriter that captures
+// the rcode written to it and also the size of the message
+// written in the response. An rcode does not have
+// to be written, however, in which case 0 must be assumed.
+// It is best to have the constructor initialize this type
+// with that default rcode.
+type ResponseRecorder struct {
+	dns.ResponseWriter
+	rcode int       // rcode of the recorded reply; 0 until WriteMsg runs
+	size  int       // message length from WriteMsg, or bytes accumulated via Write
+	start time.Time // set at construction; used to compute request duration
+}
+
+// NewResponseRecorder makes and returns a new ResponseRecorder,
+// which captures the DNS rcode from the ResponseWriter
+// and also the length of the response message written through it.
+// The start timestamp is taken here.
+func NewResponseRecorder(w dns.ResponseWriter) *ResponseRecorder {
+	return &ResponseRecorder{
+		ResponseWriter: w,
+		rcode:          0,
+		start:          time.Now(),
+	}
+}
+
+// WriteMsg records the reply's rcode and wire length, then calls the
+// underlying ResponseWriter's WriteMsg method.
+func (r *ResponseRecorder) WriteMsg(res *dns.Msg) error {
+	r.rcode = res.Rcode
+	r.size = res.Len()
+	return r.ResponseWriter.WriteMsg(res)
+}
+
+// Write is a wrapper that records the size of the message that gets written.
+// Only successful writes are counted.
+func (r *ResponseRecorder) Write(buf []byte) (int, error) {
+	n, err := r.ResponseWriter.Write(buf)
+	if err == nil {
+		r.size += n
+	}
+	return n, err
+}
+
+// Size returns the recorded response size: the message length captured
+// by WriteMsg, or the byte count accumulated through Write.
+func (r *ResponseRecorder) Size() int {
+	return r.size
+}
+
+// Rcode returns the recorded rcode (0 if WriteMsg was never called).
+func (r *ResponseRecorder) Rcode() int {
+	return r.rcode
+}
+
+// Start returns the start time of the ResponseRecorder.
+func (r *ResponseRecorder) Start() time.Time {
+	return r.start
+}
+
+// Hijack implements dns.Hijacker by delegating to the underlying
+// ResponseWriter's Hijack method. (The previous comment promised an
+// error return, but the method returns nothing; the trailing bare
+// return was also dead code and has been dropped.)
+func (r *ResponseRecorder) Hijack() {
+	r.ResponseWriter.Hijack()
+}
diff --git a/middleware/recorder_test.go b/middleware/recorder_test.go
new file mode 100644
index 000000000..a8c8a5d04
--- /dev/null
+++ b/middleware/recorder_test.go
@@ -0,0 +1,32 @@
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "testing"
+)
+
+// TestNewResponseRecorder checks the recorder wraps the writer it is
+// given and starts from the zero status.
+//
+// NOTE(review): this test predates the DNS port of recorder.go — it
+// passes an httptest recorder where NewResponseRecorder takes a
+// dns.ResponseWriter, and reads a `status` field where the struct has
+// `rcode`. It will not compile until updated.
+func TestNewResponseRecorder(t *testing.T) {
+	w := httptest.NewRecorder()
+	recordRequest := NewResponseRecorder(w)
+	if !(recordRequest.ResponseWriter == w) {
+		t.Fatalf("Expected Response writer in the Recording to be same as the one sent\n")
+	}
+	if recordRequest.status != http.StatusOK {
+		t.Fatalf("Expected recorded status to be http.StatusOK (%d) , but found %d\n ", http.StatusOK, recordRequest.status)
+	}
+}
+
+// TestWrite checks the byte counter tracks what was written through the
+// recorder. (Same http-vs-dns.ResponseWriter mismatch as noted on
+// TestNewResponseRecorder.)
+func TestWrite(t *testing.T) {
+	w := httptest.NewRecorder()
+	responseTestString := "test"
+	recordRequest := NewResponseRecorder(w)
+	buf := []byte(responseTestString)
+	recordRequest.Write(buf)
+	if recordRequest.size != len(buf) {
+		t.Fatalf("Expected the bytes written counter to be %d, but instead found %d\n", len(buf), recordRequest.size)
+	}
+	if w.Body.String() != responseTestString {
+		t.Fatalf("Expected Response Body to be %s , but found %s\n", responseTestString, w.Body.String())
+	}
+}
diff --git a/middleware/reflect/reflect.go b/middleware/reflect/reflect.go
new file mode 100644
index 000000000..6d5847b81
--- /dev/null
+++ b/middleware/reflect/reflect.go
@@ -0,0 +1,84 @@
+// Reflect provides middleware that reflects back some client properties.
+// This is the default middleware when Caddy is run without configuration.
+//
+// The left-most label must be `who`.
+// When queried for type A (resp. AAAA), it sends back the IPv4 (resp. v6) address.
+// In the additional section the port number and transport are shown.
+// Basic use pattern:
+//
+// dig @localhost -p 1053 who.miek.nl A
+//
+// ;; ANSWER SECTION:
+// who.miek.nl. 0 IN A 127.0.0.1
+//
+// ;; ADDITIONAL SECTION:
+// who.miek.nl. 0 IN TXT "Port: 56195 (udp)"
+package reflect
+
+import (
+ "errors"
+ "net"
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
// Reflect is the middleware handler that reflects client properties
// back in the answer.
type Reflect struct {
	// Next is the next handler in the chain; ServeDNS below never calls it.
	Next middleware.Handler
}
+
+func (rl Reflect) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ context := middleware.Context{Req: r, W: w}
+
+ class := r.Question[0].Qclass
+ qname := r.Question[0].Name
+ i, ok := dns.NextLabel(qname, 0)
+
+ if strings.ToLower(qname[:i]) != who || ok {
+ err := context.ErrorMessage(dns.RcodeFormatError)
+ w.WriteMsg(err)
+ return dns.RcodeFormatError, errors.New(dns.RcodeToString[dns.RcodeFormatError])
+ }
+
+ answer := new(dns.Msg)
+ answer.SetReply(r)
+ answer.Compress = true
+ answer.Authoritative = true
+
+ ip := context.IP()
+ proto := context.Proto()
+ port, _ := context.Port()
+ family := context.Family()
+ var rr dns.RR
+
+ switch family {
+ case 1:
+ rr = new(dns.A)
+ rr.(*dns.A).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeA, Class: class, Ttl: 0}
+ rr.(*dns.A).A = net.ParseIP(ip).To4()
+ case 2:
+ rr = new(dns.AAAA)
+ rr.(*dns.AAAA).Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeAAAA, Class: class, Ttl: 0}
+ rr.(*dns.AAAA).AAAA = net.ParseIP(ip)
+ }
+
+ t := new(dns.TXT)
+ t.Hdr = dns.RR_Header{Name: qname, Rrtype: dns.TypeTXT, Class: class, Ttl: 0}
+ t.Txt = []string{"Port: " + port + " (" + proto + ")"}
+
+ switch context.Type() {
+ case "TXT":
+ answer.Answer = append(answer.Answer, t)
+ answer.Extra = append(answer.Extra, rr)
+ default:
+ fallthrough
+ case "AAAA", "A":
+ answer.Answer = append(answer.Answer, rr)
+ answer.Extra = append(answer.Extra, t)
+ }
+ w.WriteMsg(answer)
+ return 0, nil
+}
+
+const who = "who."
diff --git a/middleware/reflect/reflect_test.go b/middleware/reflect/reflect_test.go
new file mode 100644
index 000000000..477a3a573
--- /dev/null
+++ b/middleware/reflect/reflect_test.go
@@ -0,0 +1 @@
+package reflect
diff --git a/middleware/replacer.go b/middleware/replacer.go
new file mode 100644
index 000000000..133da74c5
--- /dev/null
+++ b/middleware/replacer.go
@@ -0,0 +1,98 @@
+package middleware
+
+import (
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
// Replacer is a type which can replace placeholder
// substrings in a string with actual values from a
// dns.Msg and ResponseRecorder. Always use
// NewReplacer to get one of these.
type Replacer interface {
	// Replace substitutes all known placeholders in the string and returns it.
	Replace(string) string
	// Set adds or overrides the value substituted for "{key}".
	Set(key, value string)
}
+
// replacer is the concrete Replacer: a snapshot of placeholder values
// taken when NewReplacer was called.
type replacer struct {
	replacements map[string]string // placeholder -> value; keys include braces, e.g. "{name}"
	emptyValue   string            // substituted whenever a value is the empty string
}
+
+// NewReplacer makes a new replacer based on r and rr.
+// Do not create a new replacer until r and rr have all
+// the needed values, because this function copies those
+// values into the replacer. rr may be nil if it is not
+// available. emptyValue should be the string that is used
+// in place of empty string (can still be empty string).
+func NewReplacer(r *dns.Msg, rr *ResponseRecorder, emptyValue string) Replacer {
+ context := Context{W: rr, Req: r}
+ rep := replacer{
+ replacements: map[string]string{
+ "{type}": context.Type(),
+ "{name}": context.Name(),
+ "{class}": context.Class(),
+ "{proto}": context.Proto(),
+ "{when}": func() string {
+ return time.Now().Format(timeFormat)
+ }(),
+ "{remote}": context.IP(),
+ "{port}": func() string {
+ p, _ := context.Port()
+ return p
+ }(),
+ },
+ emptyValue: emptyValue,
+ }
+ if rr != nil {
+ rep.replacements["{rcode}"] = strconv.Itoa(rr.rcode)
+ rep.replacements["{size}"] = strconv.Itoa(rr.size)
+ rep.replacements["{latency}"] = time.Since(rr.start).String()
+ }
+
+ return rep
+}
+
// Replace performs a replacement of values on s and returns
// the string with the replaced values.
func (r replacer) Replace(s string) string {
	// Header replacements - these are case-insensitive, so we can't just use strings.Replace()
	for strings.Contains(s, headerReplacer) {
		idxStart := strings.Index(s, headerReplacer)
		endOffset := idxStart + len(headerReplacer)
		idxEnd := strings.Index(s[endOffset:], "}")
		if idxEnd > -1 {
			// The whole "{>name}" token, lowercased to match map keys.
			placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1])
			replacement := r.replacements[placeholder]
			if replacement == "" {
				replacement = r.emptyValue
			}
			// Splice the replacement in; the loop re-scans the result for
			// any further "{>" occurrences.
			s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:]
		} else {
			// Unterminated "{>..." — nothing more to do.
			break
		}
	}

	// Regular replacements - these are easier because they're case-sensitive
	for placeholder, replacement := range r.replacements {
		if replacement == "" {
			replacement = r.emptyValue
		}
		s = strings.Replace(s, placeholder, replacement, -1)
	}

	return s
}
+
+// Set sets key to value in the replacements map.
+func (r replacer) Set(key, value string) {
+ r.replacements["{"+key+"}"] = value
+}
+
const (
	// timeFormat is the layout used when rendering the {when} placeholder.
	timeFormat = "02/Jan/2006:15:04:05 -0700"
	// headerReplacer marks the start of a "{>Name}" style placeholder.
	headerReplacer = "{>"
)
diff --git a/middleware/replacer_test.go b/middleware/replacer_test.go
new file mode 100644
index 000000000..d98bd2de1
--- /dev/null
+++ b/middleware/replacer_test.go
@@ -0,0 +1,124 @@
+package middleware
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+)
+
// TestNewReplacer inspects the replacements built by NewReplacer.
// NOTE(review): this test predates the DNS port — NewReplacer now takes a
// *dns.Msg and *ResponseRecorder, and the {host}/{method}/{status} keys it
// checks are no longer populated (see NewReplacer in replacer.go). It will
// not compile as-is; port to dns messages.
func TestNewReplacer(t *testing.T) {
	w := httptest.NewRecorder()
	recordRequest := NewResponseRecorder(w)
	reader := strings.NewReader(`{"username": "dennis"}`)

	request, err := http.NewRequest("POST", "http://localhost", reader)
	if err != nil {
		t.Fatal("Request Formation Failed\n")
	}
	replaceValues := NewReplacer(request, recordRequest, "")

	switch v := replaceValues.(type) {
	case replacer:

		if v.replacements["{host}"] != "localhost" {
			t.Error("Expected host to be localhost")
		}
		if v.replacements["{method}"] != "POST" {
			t.Error("Expected request method to be POST")
		}
		if v.replacements["{status}"] != "200" {
			t.Error("Expected status to be 200")
		}

	default:
		t.Fatal("Return Value from New Replacer expected pass type assertion into a replacer type\n")
	}
}
+
// TestReplace exercises placeholder and {>Header} substitution.
// NOTE(review): HTTP-era test (http.NewRequest, request headers) — the
// replacer now operates on *dns.Msg and stores no header values, so this
// will not compile as-is; port to dns messages.
func TestReplace(t *testing.T) {
	w := httptest.NewRecorder()
	recordRequest := NewResponseRecorder(w)
	reader := strings.NewReader(`{"username": "dennis"}`)

	request, err := http.NewRequest("POST", "http://localhost", reader)
	if err != nil {
		t.Fatal("Request Formation Failed\n")
	}
	request.Header.Set("Custom", "foobarbaz")
	request.Header.Set("ShorterVal", "1")
	repl := NewReplacer(request, recordRequest, "-")

	if expected, actual := "This host is localhost.", repl.Replace("This host is {host}."); expected != actual {
		t.Errorf("{host} replacement: expected '%s', got '%s'", expected, actual)
	}
	if expected, actual := "This request method is POST.", repl.Replace("This request method is {method}."); expected != actual {
		t.Errorf("{method} replacement: expected '%s', got '%s'", expected, actual)
	}
	if expected, actual := "The response status is 200.", repl.Replace("The response status is {status}."); expected != actual {
		t.Errorf("{status} replacement: expected '%s', got '%s'", expected, actual)
	}
	if expected, actual := "The Custom header is foobarbaz.", repl.Replace("The Custom header is {>Custom}."); expected != actual {
		t.Errorf("{>Custom} replacement: expected '%s', got '%s'", expected, actual)
	}

	// Test header case-insensitivity
	if expected, actual := "The cUsToM header is foobarbaz...", repl.Replace("The cUsToM header is {>cUsToM}..."); expected != actual {
		t.Errorf("{>cUsToM} replacement: expected '%s', got '%s'", expected, actual)
	}

	// Test non-existent header/value
	if expected, actual := "The Non-Existent header is -.", repl.Replace("The Non-Existent header is {>Non-Existent}."); expected != actual {
		t.Errorf("{>Non-Existent} replacement: expected '%s', got '%s'", expected, actual)
	}

	// Test bad placeholder
	if expected, actual := "Bad {host placeholder...", repl.Replace("Bad {host placeholder..."); expected != actual {
		t.Errorf("bad placeholder: expected '%s', got '%s'", expected, actual)
	}

	// Test bad header placeholder
	if expected, actual := "Bad {>Custom placeholder", repl.Replace("Bad {>Custom placeholder"); expected != actual {
		t.Errorf("bad header placeholder: expected '%s', got '%s'", expected, actual)
	}

	// Test bad header placeholder with valid one later
	if expected, actual := "Bad -", repl.Replace("Bad {>Custom placeholder {>ShorterVal}"); expected != actual {
		t.Errorf("bad header placeholders: expected '%s', got '%s'", expected, actual)
	}

	// Test shorter header value with multiple placeholders
	if expected, actual := "Short value 1 then foobarbaz.", repl.Replace("Short value {>ShorterVal} then {>Custom}."); expected != actual {
		t.Errorf("short value: expected '%s', got '%s'", expected, actual)
	}
}
+
// TestSet verifies that Set overrides and adds replacement values.
// NOTE(review): HTTP-era test — NewReplacer now takes *dns.Msg, so this
// will not compile as-is; port to dns messages.
func TestSet(t *testing.T) {
	w := httptest.NewRecorder()
	recordRequest := NewResponseRecorder(w)
	reader := strings.NewReader(`{"username": "dennis"}`)

	request, err := http.NewRequest("POST", "http://localhost", reader)
	if err != nil {
		t.Fatalf("Request Formation Failed \n")
	}
	repl := NewReplacer(request, recordRequest, "")

	repl.Set("host", "getcaddy.com")
	repl.Set("method", "GET")
	repl.Set("status", "201")
	repl.Set("variable", "value")

	if repl.Replace("This host is {host}") != "This host is getcaddy.com" {
		t.Error("Expected host replacement failed")
	}
	if repl.Replace("This request method is {method}") != "This request method is GET" {
		t.Error("Expected method replacement failed")
	}
	if repl.Replace("The response status is {status}") != "The response status is 201" {
		t.Error("Expected status replacement failed")
	}
	if repl.Replace("The value of variable is {variable}") != "The value of variable is value" {
		t.Error("Expected variable replacement failed")
	}
}
diff --git a/middleware/rewrite/condition.go b/middleware/rewrite/condition.go
new file mode 100644
index 000000000..ddd4c38b1
--- /dev/null
+++ b/middleware/rewrite/condition.go
@@ -0,0 +1,130 @@
+package rewrite
+
+import (
+ "fmt"
+ "regexp"
+ "strings"
+
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
// Operators that may appear in a rewrite condition.
const (
	Is         = "is"          // exact equality
	Not        = "not"         // inequality
	Has        = "has"         // substring containment
	NotHas     = "not_has"     // substring absence
	StartsWith = "starts_with" // prefix match
	EndsWith   = "ends_with"   // suffix match
	Match      = "match"       // regular-expression match
	NotMatch   = "not_match"   // regular-expression non-match
)
+
// operatorError returns the error reported for an unknown condition
// operator. The message is lowercase per Go error-string convention.
func operatorError(operator string) error {
	return fmt.Errorf("invalid operator %v", operator)
}
+
// newReplacer returns a middleware.Replacer for r with no response
// recorder and the empty string as the empty-value substitution.
func newReplacer(r *dns.Msg) middleware.Replacer {
	return middleware.NewReplacer(r, nil, "")
}
+
// condition is a rewrite condition.
type condition func(string, string) bool

// conditions maps each operator name to its implementation.
var conditions = map[string]condition{
	Is:         isFunc,
	Not:        notFunc,
	Has:        hasFunc,
	NotHas:     notHasFunc,
	StartsWith: startsWithFunc,
	EndsWith:   endsWithFunc,
	Match:      matchFunc,
	NotMatch:   notMatchFunc,
}
+
// isFunc implements the Is operator: it reports whether a and b are
// exactly equal.
func isFunc(a, b string) bool {
	return a == b
}
+
// notFunc implements the Not operator: it reports whether a and b differ.
func notFunc(a, b string) bool {
	return !(a == b)
}
+
// hasFunc implements the Has operator: it reports whether b occurs
// somewhere inside a.
func hasFunc(a, b string) bool {
	return strings.Contains(a, b)
}
+
// notHasFunc implements the NotHas operator: it reports whether b does
// not occur anywhere inside a.
func notHasFunc(a, b string) bool {
	return !strings.Contains(a, b)
}
+
// startsWithFunc implements the StartsWith operator: it reports whether
// b is a prefix of a.
func startsWithFunc(a, b string) bool {
	return strings.HasPrefix(a, b)
}
+
// endsWithFunc implements the EndsWith operator: it reports whether b is
// a suffix of a.
func endsWithFunc(a, b string) bool {
	return strings.HasSuffix(a, b)
}
+
// matchFunc implements the Match operator: it reports whether a is
// matched by the regular expression pattern b. A pattern that fails to
// compile matches nothing.
func matchFunc(a, b string) bool {
	ok, err := regexp.MatchString(b, a)
	if err != nil {
		return false
	}
	return ok
}
+
// notMatchFunc implements the NotMatch operator: it reports whether a is
// not matched by the regular expression pattern b. An invalid pattern
// matches nothing, so it always yields true here.
func notMatchFunc(a, b string) bool {
	ok, _ := regexp.MatchString(b, a)
	return !ok
}
+
// If is statement for a rewrite condition.
type If struct {
	A        string // left-hand operand; may contain placeholders (see True)
	Operator string // one of the operator constants declared in this file
	B        string // right-hand operand; may contain placeholders (see True)
}
+
+// True returns true if the condition is true and false otherwise.
+// If r is not nil, it replaces placeholders before comparison.
+func (i If) True(r *dns.Msg) bool {
+ if c, ok := conditions[i.Operator]; ok {
+ a, b := i.A, i.B
+ if r != nil {
+ replacer := newReplacer(r)
+ a = replacer.Replace(i.A)
+ b = replacer.Replace(i.B)
+ }
+ return c(a, b)
+ }
+ return false
+}
+
+// NewIf creates a new If condition.
+func NewIf(a, operator, b string) (If, error) {
+ if _, ok := conditions[operator]; !ok {
+ return If{}, operatorError(operator)
+ }
+ return If{
+ A: a,
+ Operator: operator,
+ B: b,
+ }, nil
+}
diff --git a/middleware/rewrite/condition_test.go b/middleware/rewrite/condition_test.go
new file mode 100644
index 000000000..3c3b6053a
--- /dev/null
+++ b/middleware/rewrite/condition_test.go
@@ -0,0 +1,106 @@
+package rewrite
+
+import (
+ "net/http"
+ "strings"
+ "testing"
+)
+
// TestConditions exercises every operator through NewIf/True.
// NOTE(review): the replaceTests section still builds an *http.Request,
// but If.True now takes *dns.Msg — this part will not compile until it is
// ported to DNS messages (and {uri} has no DNS equivalent).
func TestConditions(t *testing.T) {
	tests := []struct {
		condition string
		isTrue    bool
	}{
		{"a is b", false},
		{"a is a", true},
		{"a not b", true},
		{"a not a", false},
		{"a has a", true},
		{"a has b", false},
		{"ba has b", true},
		{"bab has b", true},
		{"bab has bb", false},
		{"a not_has a", false},
		{"a not_has b", true},
		{"ba not_has b", false},
		{"bab not_has b", false},
		{"bab not_has bb", true},
		{"bab starts_with bb", false},
		{"bab starts_with ba", true},
		{"bab starts_with bab", true},
		{"bab ends_with bb", false},
		{"bab ends_with bab", true},
		{"bab ends_with ab", true},
		{"a match *", false},
		{"a match a", true},
		{"a match .*", true},
		{"a match a.*", true},
		{"a match b.*", false},
		{"ba match b.*", true},
		{"ba match b[a-z]", true},
		{"b0 match b[a-z]", false},
		{"b0a match b[a-z]", false},
		{"b0a match b[a-z]+", false},
		{"b0a match b[a-z0-9]+", true},
		{"a not_match *", true},
		{"a not_match a", false},
		{"a not_match .*", false},
		{"a not_match a.*", false},
		{"a not_match b.*", true},
		{"ba not_match b.*", false},
		{"ba not_match b[a-z]", false},
		{"b0 not_match b[a-z]", true},
		{"b0a not_match b[a-z]", true},
		{"b0a not_match b[a-z]+", true},
		{"b0a not_match b[a-z0-9]+", false},
	}

	for i, test := range tests {
		str := strings.Fields(test.condition)
		ifCond, err := NewIf(str[0], str[1], str[2])
		if err != nil {
			t.Error(err)
		}
		isTrue := ifCond.True(nil)
		if isTrue != test.isTrue {
			t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
		}
	}

	invalidOperators := []string{"ss", "and", "if"}
	for _, op := range invalidOperators {
		_, err := NewIf("a", op, "b")
		if err == nil {
			t.Errorf("Invalid operator %v used, expected error.", op)
		}
	}

	// NOTE(review): HTTP-era section; see the comment on the test above.
	replaceTests := []struct {
		url       string
		condition string
		isTrue    bool
	}{
		{"/home", "{uri} match /home", true},
		{"/hom", "{uri} match /home", false},
		{"/hom", "{uri} starts_with /home", false},
		{"/hom", "{uri} starts_with /h", true},
		{"/home/.hiddenfile", `{uri} match \/\.(.*)`, true},
		{"/home/.hiddendir/afile", `{uri} match \/\.(.*)`, true},
	}

	for i, test := range replaceTests {
		r, err := http.NewRequest("GET", test.url, nil)
		if err != nil {
			t.Error(err)
		}
		str := strings.Fields(test.condition)
		ifCond, err := NewIf(str[0], str[1], str[2])
		if err != nil {
			t.Error(err)
		}
		isTrue := ifCond.True(r)
		if isTrue != test.isTrue {
			t.Errorf("Test %v: expected %v found %v", i, test.isTrue, isTrue)
		}
	}
}
diff --git a/middleware/rewrite/reverter.go b/middleware/rewrite/reverter.go
new file mode 100644
index 000000000..c3425866e
--- /dev/null
+++ b/middleware/rewrite/reverter.go
@@ -0,0 +1,38 @@
+package rewrite
+
+import "github.com/miekg/dns"
+
// ResponseReverter reverses the operations done on the question section of a packet.
// This is needed because the client will otherwise disregard the response, i.e.
// dig will complain with ';; Question section mismatch: got miek.nl/HINFO/IN'
type ResponseReverter struct {
	dns.ResponseWriter
	original dns.Question // the question exactly as the client sent it, before any rewrite
}
+
+func NewResponseReverter(w dns.ResponseWriter, r *dns.Msg) *ResponseReverter {
+ return &ResponseReverter{
+ ResponseWriter: w,
+ original: r.Question[0],
+ }
+}
+
// WriteMsg restores the original question section of res and then calls
// the underlying ResponseWriter's WriteMsg method.
func (r *ResponseReverter) WriteMsg(res *dns.Msg) error {
	res.Question[0] = r.original
	return r.ResponseWriter.WriteMsg(res)
}
+
+// Write is a wrapper that records the size of the message that gets written.
+func (r *ResponseReverter) Write(buf []byte) (int, error) {
+ n, err := r.ResponseWriter.Write(buf)
+ return n, err
+}
+
+// Hijack implements dns.Hijacker. It simply wraps the underlying
+// ResponseWriter's Hijack method if there is one, or returns an error.
+func (r *ResponseReverter) Hijack() {
+ r.ResponseWriter.Hijack()
+ return
+}
diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go
new file mode 100644
index 000000000..b3039615b
--- /dev/null
+++ b/middleware/rewrite/rewrite.go
@@ -0,0 +1,223 @@
+// Package rewrite is middleware for rewriting requests internally to
+// something different.
+package rewrite
+
+import (
+ "github.com/miekg/coredns/middleware"
+ "github.com/miekg/dns"
+)
+
// Result is the result of a rewrite attempt; see Rule.Rewrite.
type Result int

const (
	// RewriteIgnored is returned when rewrite is not done on request.
	RewriteIgnored Result = iota
	// RewriteDone is returned when rewrite is done on request.
	RewriteDone
	// RewriteStatus is returned when rewrite is not needed and status code should be set
	// for the request.
	RewriteStatus
)
+
// Rewrite is middleware to rewrite requests internally before being handled.
type Rewrite struct {
	Next  middleware.Handler // next handler in the chain
	Rules []Rule             // rules tried in order by ServeDNS
}
+
+// ServeHTTP implements the middleware.Handler interface.
+func (rw Rewrite) ServeDNS(w dns.ResponseWriter, r *dns.Msg) (int, error) {
+ wr := NewResponseReverter(w, r)
+ for _, rule := range rw.Rules {
+ switch result := rule.Rewrite(r); result {
+ case RewriteDone:
+ return rw.Next.ServeDNS(wr, r)
+ case RewriteIgnored:
+ break
+ case RewriteStatus:
+ // only valid for complex rules.
+ // if cRule, ok := rule.(*ComplexRule); ok && cRule.Status != 0 {
+ // return cRule.Status, nil
+ // }
+ }
+ }
+ return rw.Next.ServeDNS(w, r)
+}
+
// Rule describes an internal location rewrite rule.
type Rule interface {
	// Rewrite rewrites the internal location of the current request
	// and reports whether (and how) the message was modified.
	Rewrite(*dns.Msg) Result
}
+
// SimpleRule is a simple rewrite rule. If the From and To look like a type
// the type of the request is rewritten, otherwise the name is.
// Note: TSIG signed requests will be invalid.
// NOTE(review): name rewriting is not implemented yet — Rewrite below only
// handles the type case (see its trailing TODO comment).
type SimpleRule struct {
	From, To         string // textual forms as supplied by the configuration
	fromType, toType uint16 // parsed record types; 0 when From/To is not a known type
}
+
+// NewSimpleRule creates a new Simple Rule
+func NewSimpleRule(from, to string) SimpleRule {
+ tpf := dns.StringToType[from]
+ tpt := dns.StringToType[to]
+
+ return SimpleRule{From: from, To: to, fromType: tpf, toType: tpt}
+}
+
+// Rewrite rewrites the the current request.
+func (s SimpleRule) Rewrite(r *dns.Msg) Result {
+ if s.fromType > 0 && s.toType > 0 {
+ if r.Question[0].Qtype == s.fromType {
+ r.Question[0].Qtype = s.toType
+ return RewriteDone
+ }
+
+ }
+
+ // if the question name matches the full name, or subset rewrite that
+ // s.Question[0].Name
+ return RewriteIgnored
+}
+
+/*
+// ComplexRule is a rewrite rule based on a regular expression
+type ComplexRule struct {
+ // Path base. Request to this path and subpaths will be rewritten
+ Base string
+
+ // Path to rewrite to
+ To string
+
+ // If set, neither performs rewrite nor proceeds
+ // with request. Only returns code.
+ Status int
+
+ // Extensions to filter by
+ Exts []string
+
+ // Rewrite conditions
+ Ifs []If
+
+ *regexp.Regexp
+}
+
+// NewComplexRule creates a new RegexpRule. It returns an error if regexp
+// pattern (pattern) or extensions (ext) are invalid.
+func NewComplexRule(base, pattern, to string, status int, ext []string, ifs []If) (*ComplexRule, error) {
+ // validate regexp if present
+ var r *regexp.Regexp
+ if pattern != "" {
+ var err error
+ r, err = regexp.Compile(pattern)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ // validate extensions if present
+ for _, v := range ext {
+ if len(v) < 2 || (len(v) < 3 && v[0] == '!') {
+ // check if no extension is specified
+ if v != "/" && v != "!/" {
+ return nil, fmt.Errorf("invalid extension %v", v)
+ }
+ }
+ }
+
+ return &ComplexRule{
+ Base: base,
+ To: to,
+ Status: status,
+ Exts: ext,
+ Ifs: ifs,
+ Regexp: r,
+ }, nil
+}
+
+// Rewrite rewrites the internal location of the current request.
+func (r *ComplexRule) Rewrite(req *dns.Msg) (re Result) {
+ rPath := req.URL.Path
+ replacer := newReplacer(req)
+
+ // validate base
+ if !middleware.Path(rPath).Matches(r.Base) {
+ return
+ }
+
+ // validate extensions
+ if !r.matchExt(rPath) {
+ return
+ }
+
+ // validate regexp if present
+ if r.Regexp != nil {
+ // include trailing slash in regexp if present
+ start := len(r.Base)
+ if strings.HasSuffix(r.Base, "/") {
+ start--
+ }
+
+ matches := r.FindStringSubmatch(rPath[start:])
+ switch len(matches) {
+ case 0:
+ // no match
+ return
+ default:
+ // set regexp match variables {1}, {2} ...
+ for i := 1; i < len(matches); i++ {
+ replacer.Set(fmt.Sprint(i), matches[i])
+ }
+ }
+ }
+
+ // validate rewrite conditions
+ for _, i := range r.Ifs {
+ if !i.True(req) {
+ return
+ }
+ }
+
+ // if status is present, stop rewrite and return it.
+ if r.Status != 0 {
+ return RewriteStatus
+ }
+
+ // attempt rewrite
+ return To(fs, req, r.To, replacer)
+}
+
+// matchExt matches rPath against registered file extensions.
+// Returns true if a match is found and false otherwise.
+func (r *ComplexRule) matchExt(rPath string) bool {
+ f := filepath.Base(rPath)
+ ext := path.Ext(f)
+ if ext == "" {
+ ext = "/"
+ }
+
+ mustUse := false
+ for _, v := range r.Exts {
+ use := true
+ if v[0] == '!' {
+ use = false
+ v = v[1:]
+ }
+
+ if use {
+ mustUse = true
+ }
+
+ if ext == v {
+ return use
+ }
+ }
+
+ if mustUse {
+ return false
+ }
+ return true
+}
+*/
diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go
new file mode 100644
index 000000000..f57dfd602
--- /dev/null
+++ b/middleware/rewrite/rewrite_test.go
@@ -0,0 +1,159 @@
+package rewrite
+
+import (
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/miekg/coredns/middleware"
+)
+
// TestRewrite exercises simple and complex rewrite rules end to end.
// NOTE(review): this test predates the DNS port and will not compile:
// Rewrite has no FileSys field, middleware handlers are ServeDNS (not
// ServeHTTP), NewSimpleRule now takes DNS names/types, and NewComplexRule
// is commented out in rewrite.go. Port or delete before enabling.
func TestRewrite(t *testing.T) {
	rw := Rewrite{
		Next: middleware.HandlerFunc(urlPrinter),
		Rules: []Rule{
			NewSimpleRule("/from", "/to"),
			NewSimpleRule("/a", "/b"),
			NewSimpleRule("/b", "/b{uri}"),
		},
		FileSys: http.Dir("."),
	}

	regexps := [][]string{
		{"/reg/", ".*", "/to", ""},
		{"/r/", "[a-z]+", "/toaz", "!.html|"},
		{"/url/", "a([a-z0-9]*)s([A-Z]{2})", "/to/{path}", ""},
		{"/ab/", "ab", "/ab?{query}", ".txt|"},
		{"/ab/", "ab", "/ab?type=html&{query}", ".html|"},
		{"/abc/", "ab", "/abc/{file}", ".html|"},
		{"/abcd/", "ab", "/a/{dir}/{file}", ".html|"},
		{"/abcde/", "ab", "/a#{fragment}", ".html|"},
		{"/ab/", `.*\.jpg`, "/ajpg", ""},
		{"/reggrp", `/ad/([0-9]+)([a-z]*)`, "/a{1}/{2}", ""},
		{"/reg2grp", `(.*)`, "/{1}", ""},
		{"/reg3grp", `(.*)/(.*)/(.*)`, "/{1}{2}{3}", ""},
	}

	for _, regexpRule := range regexps {
		var ext []string
		if s := strings.Split(regexpRule[3], "|"); len(s) > 1 {
			ext = s[:len(s)-1]
		}
		rule, err := NewComplexRule(regexpRule[0], regexpRule[1], regexpRule[2], 0, ext, nil)
		if err != nil {
			t.Fatal(err)
		}
		rw.Rules = append(rw.Rules, rule)
	}

	tests := []struct {
		from       string
		expectedTo string
	}{
		{"/from", "/to"},
		{"/a", "/b"},
		{"/b", "/b/b"},
		{"/aa", "/aa"},
		{"/", "/"},
		{"/a?foo=bar", "/b?foo=bar"},
		{"/asdf?foo=bar", "/asdf?foo=bar"},
		{"/foo#bar", "/foo#bar"},
		{"/a#foo", "/b#foo"},
		{"/reg/foo", "/to"},
		{"/re", "/re"},
		{"/r/", "/r/"},
		{"/r/123", "/r/123"},
		{"/r/a123", "/toaz"},
		{"/r/abcz", "/toaz"},
		{"/r/z", "/toaz"},
		{"/r/z.html", "/r/z.html"},
		{"/r/z.js", "/toaz"},
		{"/url/asAB", "/to/url/asAB"},
		{"/url/aBsAB", "/url/aBsAB"},
		{"/url/a00sAB", "/to/url/a00sAB"},
		{"/url/a0z0sAB", "/to/url/a0z0sAB"},
		{"/ab/aa", "/ab/aa"},
		{"/ab/ab", "/ab/ab"},
		{"/ab/ab.txt", "/ab"},
		{"/ab/ab.txt?name=name", "/ab?name=name"},
		{"/ab/ab.html?name=name", "/ab?type=html&name=name"},
		{"/abc/ab.html", "/abc/ab.html"},
		{"/abcd/abcd.html", "/a/abcd/abcd.html"},
		{"/abcde/abcde.html", "/a"},
		{"/abcde/abcde.html#1234", "/a#1234"},
		{"/ab/ab.jpg", "/ajpg"},
		{"/reggrp/ad/12", "/a12"},
		{"/reggrp/ad/124a", "/a124/a"},
		{"/reggrp/ad/124abc", "/a124/abc"},
		{"/reg2grp/ad/124abc", "/ad/124abc"},
		{"/reg3grp/ad/aa/66", "/adaa66"},
		{"/reg3grp/ad612/n1n/ab", "/ad612n1nab"},
	}

	for i, test := range tests {
		req, err := http.NewRequest("GET", test.from, nil)
		if err != nil {
			t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
		}

		rec := httptest.NewRecorder()
		rw.ServeHTTP(rec, req)

		if rec.Body.String() != test.expectedTo {
			t.Errorf("Test %d: Expected URL to be '%s' but was '%s'",
				i, test.expectedTo, rec.Body.String())
		}
	}

	statusTests := []struct {
		status         int
		base           string
		to             string
		regexp         string
		statusExpected bool
	}{
		{400, "/status", "", "", true},
		{400, "/ignore", "", "", false},
		{400, "/", "", "^/ignore", false},
		{400, "/", "", "(.*)", true},
		{400, "/status", "", "", true},
	}

	for i, s := range statusTests {
		urlPath := fmt.Sprintf("/status%d", i)
		rule, err := NewComplexRule(s.base, s.regexp, s.to, s.status, nil, nil)
		if err != nil {
			t.Fatalf("Test %d: No error expected for rule but found %v", i, err)
		}
		rw.Rules = []Rule{rule}
		req, err := http.NewRequest("GET", urlPath, nil)
		if err != nil {
			t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
		}

		rec := httptest.NewRecorder()
		code, err := rw.ServeHTTP(rec, req)
		if err != nil {
			t.Fatalf("Test %d: No error expected for handler but found %v", i, err)
		}
		if s.statusExpected {
			if rec.Body.String() != "" {
				t.Errorf("Test %d: Expected empty body but found %s", i, rec.Body.String())
			}
			if code != s.status {
				t.Errorf("Test %d: Expected status code %d found %d", i, s.status, code)
			}
		} else {
			if code != 0 {
				t.Errorf("Test %d: Expected no status code found %d", i, code)
			}
		}
	}
}
+
// urlPrinter is a stub handler that echoes the request URL into the body.
// NOTE(review): HTTP-era test support; handlers in this tree are now
// dns-based (ServeDNS), so this helper needs porting along with TestRewrite.
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
	fmt.Fprintf(w, r.URL.String())
	return 0, nil
}
diff --git a/middleware/rewrite/testdata/testdir/empty b/middleware/rewrite/testdata/testdir/empty
new file mode 100644
index 000000000..e69de29bb
--- /dev/null
+++ b/middleware/rewrite/testdata/testdir/empty
diff --git a/middleware/rewrite/testdata/testfile b/middleware/rewrite/testdata/testfile
new file mode 100644
index 000000000..7b4d68d70
--- /dev/null
+++ b/middleware/rewrite/testdata/testfile
@@ -0,0 +1 @@
+empty \ No newline at end of file
diff --git a/middleware/roller.go b/middleware/roller.go
new file mode 100644
index 000000000..995cabf91
--- /dev/null
+++ b/middleware/roller.go
@@ -0,0 +1,27 @@
+package middleware
+
+import (
+ "io"
+
+ "gopkg.in/natefinch/lumberjack.v2"
+)
+
// LogRoller implements a middleware that provides a rolling logger.
// Field semantics follow lumberjack.Logger (sizes in megabytes, ages in
// days) — see the lumberjack documentation to confirm.
type LogRoller struct {
	Filename   string // path of the log file to write to
	MaxSize    int    // maximum size before rotation (megabytes, per lumberjack)
	MaxAge     int    // maximum age of a rotated file before deletion (days, per lumberjack)
	MaxBackups int    // maximum number of rotated files to keep
	LocalTime  bool   // use local time instead of UTC in backup file names
}
+
// GetLogWriter returns an io.Writer that writes to a rolling logger.
// Each call constructs a fresh lumberjack.Logger configured from l.
func (l LogRoller) GetLogWriter() io.Writer {
	return &lumberjack.Logger{
		Filename:   l.Filename,
		MaxSize:    l.MaxSize,
		MaxAge:     l.MaxAge,
		MaxBackups: l.MaxBackups,
		LocalTime:  l.LocalTime,
	}
}
diff --git a/middleware/zone.go b/middleware/zone.go
new file mode 100644
index 000000000..6798bca8e
--- /dev/null
+++ b/middleware/zone.go
@@ -0,0 +1,21 @@
+package middleware
+
+import "strings"
+
// Zones is a collection of zone names, each used as a suffix when
// matching query names.
type Zones []string

// Matches checks if qname falls inside any of the zones in z.
// It returns the most specific (longest) matching zone; the empty
// string signals that no zone matched.
func (z Zones) Matches(qname string) string {
	zone := ""
	for _, zname := range z {
		if !strings.HasSuffix(qname, zname) {
			continue
		}
		if len(zname) > len(zone) {
			zone = zname
		}
	}
	return zone
}
diff --git a/server/config.go b/server/config.go
new file mode 100644
index 000000000..79a9ba84d
--- /dev/null
+++ b/server/config.go
@@ -0,0 +1,75 @@
+package server
+
+import (
+ "net"
+
+ "github.com/miekg/coredns/middleware"
+)
+
// Config holds the configuration for a single server.
type Config struct {
	// The hostname or IP on which to serve
	Host string

	// The host address to bind on - defaults to (virtual) Host if empty
	BindHost string

	// The port to listen on
	Port string

	// The directory from which to parse db files
	Root string

	// TLS configuration.
	// NOTE(review): previously labeled "HTTPS configuration" — wording
	// inherited from Caddy; this server serves DNS.
	TLS TLSConfig

	// Middleware stack
	Middleware []middleware.Middleware

	// Startup is a list of functions (or methods) to execute at
	// server startup and restart; these are executed before any
	// parts of the server are configured, and the functions are
	// blocking. These are good for setting up middlewares and
	// starting goroutines.
	Startup []func() error

	// FirstStartup is like Startup but these functions only execute
	// during the initial startup, not on subsequent restarts.
	//
	// (Note: The server does not ever run these on its own; it is up
	// to the calling application to do so, and do so only once, as the
	// server itself has no notion whether it's a restart or not.)
	FirstStartup []func() error

	// Functions (or methods) to execute when the server quits;
	// these are executed in response to SIGINT and are blocking
	Shutdown []func() error

	// The path to the configuration file from which this was loaded
	ConfigFile string

	// The name of the application
	AppName string

	// The application's version
	AppVersion string
}
+
// Address returns the host:port of c as a string, using net.JoinHostPort
// so IPv6 hosts are bracketed correctly.
func (c Config) Address() string {
	return net.JoinHostPort(c.Host, c.Port)
}
+
// TLSConfig describes how TLS should be configured and used.
// NOTE(review): inherited from Caddy's HTTPS machinery — confirm which of
// these fields are meaningful for a DNS server.
type TLSConfig struct {
	Enabled                  bool // will be set to true if TLS is enabled
	LetsEncryptEmail         string
	Manual                   bool // will be set to true if user provides own certs and keys
	Managed                  bool // will be set to true if config qualifies for implicit automatic/managed HTTPS
	OnDemand                 bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes)
	Ciphers                  []uint16
	ProtocolMinVersion       uint16
	ProtocolMaxVersion       uint16
	PreferServerCipherSuites bool
	ClientCerts              []string
}
diff --git a/server/config_test.go b/server/config_test.go
new file mode 100644
index 000000000..8787e467b
--- /dev/null
+++ b/server/config_test.go
@@ -0,0 +1,25 @@
+package server
+
+import "testing"
+
+func TestConfigAddress(t *testing.T) {
+ cfg := Config{Host: "foobar", Port: "1234"}
+ if actual, expected := cfg.Address(), "foobar:1234"; expected != actual {
+ t.Errorf("Expected '%s' but got '%s'", expected, actual)
+ }
+
+ cfg = Config{Host: "", Port: "1234"}
+ if actual, expected := cfg.Address(), ":1234"; expected != actual {
+ t.Errorf("Expected '%s' but got '%s'", expected, actual)
+ }
+
+ cfg = Config{Host: "foobar", Port: ""}
+ if actual, expected := cfg.Address(), "foobar:"; expected != actual {
+ t.Errorf("Expected '%s' but got '%s'", expected, actual)
+ }
+
+ cfg = Config{Host: "::1", Port: "443"}
+ if actual, expected := cfg.Address(), "[::1]:443"; expected != actual {
+ t.Errorf("Expected '%s' but got '%s'", expected, actual)
+ }
+}
diff --git a/server/graceful.go b/server/graceful.go
new file mode 100644
index 000000000..5057d039b
--- /dev/null
+++ b/server/graceful.go
@@ -0,0 +1,76 @@
+package server
+
+import (
+ "net"
+ "sync"
+ "syscall"
+)
+
+// newGracefulListener returns a gracefulListener that wraps l and
+// uses wg (stored in the host server) to count connections.
+//
+// The goroutine implements a rendezvous with Close: it blocks until a
+// value arrives on gl.stop, marks the listener stopped under the
+// mutex, then closes the wrapped listener and sends the resulting
+// error back on the same channel for Close to return.
+func newGracefulListener(l ListenerFile, wg *sync.WaitGroup) *gracefulListener {
+	gl := &gracefulListener{ListenerFile: l, stop: make(chan error), httpWg: wg}
+	go func() {
+		<-gl.stop
+		gl.Lock()
+		gl.stopped = true
+		gl.Unlock()
+		gl.stop <- gl.ListenerFile.Close()
+	}()
+	return gl
+}
+
+// gracefulListener is a net.Listener which can
+// count the number of connections on it. Its
+// methods mainly wrap net.Listener to be graceful.
+type gracefulListener struct {
+	ListenerFile
+	stop chan error // Close sends a nil here; the goroutine in newGracefulListener replies with the real close error
+	stopped bool // set once the wrapped listener has been closed
+	sync.Mutex // protects the stopped flag
+	httpWg *sync.WaitGroup // pointer to the host's wg used for counting connections
+}
+
+// Accept accepts a connection from the wrapped listener and returns
+// it wrapped in a gracefulConn so the host's WaitGroup tracks it
+// until it is closed.
+func (gl *gracefulListener) Accept() (net.Conn, error) {
+	conn, err := gl.ListenerFile.Accept()
+	if err != nil {
+		return conn, err
+	}
+	gl.httpWg.Add(1)
+	return gracefulConn{Conn: conn, httpWg: gl.httpWg}, nil
+}
+
+// Close immediately closes the listener. A repeated call finds
+// stopped == true and returns syscall.EINVAL, mirroring the error a
+// closed net.Listener would give. Otherwise the nil send on gl.stop
+// wakes the goroutine started in newGracefulListener, which performs
+// the actual close and replies with its error on the same channel.
+//
+// NOTE(review): two concurrent first Closes could both pass the
+// stopped check; the second send would then block forever — confirm
+// callers only ever Close once.
+func (gl *gracefulListener) Close() error {
+	gl.Lock()
+	if gl.stopped {
+		gl.Unlock()
+		return syscall.EINVAL
+	}
+	gl.Unlock()
+	gl.stop <- nil
+	return <-gl.stop
+}
+
+// gracefulConn represents a connection on a
+// gracefulListener so that we can keep track
+// of the number of connections, thus facilitating
+// a graceful shutdown.
+type gracefulConn struct {
+	net.Conn
+	httpWg *sync.WaitGroup // pointer to the host server's connection waitgroup
+}
+
+// Close closes the underlying connection and, only on success,
+// decrements the host's connection WaitGroup.
+func (c gracefulConn) Close() error {
+	if err := c.Conn.Close(); err != nil {
+		// close can fail on http2 connections (as of Oct. 2015, before http2 in std lib)
+		// so don't decrement count unless close succeeds
+		return err
+	}
+	c.httpWg.Done()
+	return nil
+}
diff --git a/server/server.go b/server/server.go
new file mode 100644
index 000000000..7baa74686
--- /dev/null
+++ b/server/server.go
@@ -0,0 +1,431 @@
+// Package server implements a configurable, general-purpose web server.
+// It relies on configurations obtained from the adjacent config package
+// and can execute middleware as defined by the adjacent middleware package.
+package server
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net"
+ "os"
+ "runtime"
+ "sync"
+ "time"
+
+ "github.com/miekg/dns"
+)
+
+// Server represents an instance of a server, which serves
+// DNS requests at a particular address (host and port). A
+// server is capable of serving numerous zones on
+// the same address and the listener may be stopped for
+// graceful termination (POSIX only).
+type Server struct {
+	Addr string // Address we listen on
+	mux *dns.ServeMux // request multiplexer; "." is registered so every query reaches ServeDNS
+	tls bool // whether this server is serving all HTTPS hosts or not
+	TLSConfig *tls.Config // TLS settings applied when tls is enabled
+	OnDemandTLS bool // whether this server supports on-demand TLS (load certs at handshake-time)
+	zones map[string]zone // zones keyed by their address
+	listener ListenerFile // the listener which is bound to the socket
+	listenerMu sync.Mutex // protects listener
+	dnsWg sync.WaitGroup // used to wait on outstanding connections
+	startChan chan struct{} // used to block until server is finished starting
+	connTimeout time.Duration // the maximum duration of a graceful shutdown
+	ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request
+	SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) // optional per-handshake certificate lookup
+}
+
+// ListenerFile represents a listener that can also expose its
+// underlying *os.File (used by Server.ListenerFd for graceful
+// restarts on POSIX systems).
+type ListenerFile interface {
+	net.Listener
+	File() (*os.File, error)
+}
+
+// OptionalCallback is a function that may or may not handle a request.
+// It returns whether or not it handled the request. If it handled the
+// request, it is presumed that no further request handling should occur.
+// Server.ReqCallback is of this type and runs at the start of ServeDNS.
+type OptionalCallback func(dns.ResponseWriter, *dns.Msg) bool
+
+// New creates a new Server which will bind to addr and serve
+// the sites/hosts configured in configs. Its listener will
+// gracefully close when the server is stopped which will take
+// no longer than gracefulTimeout.
+//
+// This function does not start serving.
+//
+// Do not re-use a server (start, stop, then start again). We
+// could probably add more locking to make this possible, but
+// as it stands, you should dispose of a server after stopping it.
+// The behavior of serving with a spent server is undefined.
+func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) {
+	// NOTE(review): TLS enablement is read from configs[0] only; TLS
+	// settings of the remaining configs are ignored here — confirm
+	// this is intended.
+	var useTLS, useOnDemandTLS bool
+	if len(configs) > 0 {
+		useTLS = configs[0].TLS.Enabled
+		useOnDemandTLS = configs[0].TLS.OnDemand
+	}
+
+	s := &Server{
+		Addr: addr,
+		TLSConfig: new(tls.Config),
+		// TODO: Make these values configurable?
+		// ReadTimeout: 2 * time.Minute,
+		// WriteTimeout: 2 * time.Minute,
+		// MaxHeaderBytes: 1 << 16,
+		tls: useTLS,
+		OnDemandTLS: useOnDemandTLS,
+		zones: make(map[string]zone),
+		startChan: make(chan struct{}),
+		connTimeout: gracefulTimeout,
+	}
+	mux := dns.NewServeMux()
+	mux.Handle(".", s) // wildcard handler, everything will go through here
+	s.mux = mux
+
+	// We have to bound our wg with one increment
+	// to prevent a "race condition" that is hard-coded
+	// into sync.WaitGroup.Wait() - basically, an add
+	// with a positive delta must be guaranteed to
+	// occur before Wait() is called on the wg.
+	// In a way, this kind of acts as a safety barrier.
+	// Stop() releases this barrier with a matching Done().
+	s.dnsWg.Add(1)
+
+	// Set up each zone
+	for _, conf := range configs {
+		// TODO(miek): something better here?
+		if _, exists := s.zones[conf.Host]; exists {
+			return nil, fmt.Errorf("cannot serve %s - host already defined for address %s", conf.Address(), s.Addr)
+		}
+
+		z := zone{config: conf}
+
+		// Build middleware stack
+		err := z.buildStack()
+		if err != nil {
+			return nil, err
+		}
+
+		s.zones[conf.Host] = z
+	}
+
+	return s, nil
+}
+
+// Serve starts the server with an existing listener. It blocks until the
+// server stops.
+/*
+func (s *Server) Serve(ln ListenerFile) error {
+ // TODO(miek): Go DNS has no server stuff that allows you to give it a listener
+ // and use that.
+ err := s.setup()
+ if err != nil {
+ defer close(s.startChan) // MUST defer so error is properly reported, same with all cases in this file
+ return err
+ }
+ return s.serve(ln)
+}
+*/
+
+// ListenAndServe starts the server with a new listener. It blocks until the server stops.
+func (s *Server) ListenAndServe() error {
+ err := s.setup()
+ once := sync.Once{}
+
+ if err != nil {
+ close(s.startChan)
+ return err
+ }
+
+ // TODO(miek): redo to make it more like caddy
+ // - error handling, re-introduce what Caddy did.
+ go func() {
+ if err := dns.ListenAndServe(s.Addr, "tcp", s.mux); err != nil {
+ log.Printf("[ERROR] %v\n", err)
+ defer once.Do(func() { close(s.startChan) })
+ return
+ }
+ }()
+
+ go func() {
+ if err := dns.ListenAndServe(s.Addr, "udp", s.mux); err != nil {
+ log.Printf("[ERROR] %v\n", err)
+ defer once.Do(func() { close(s.startChan) })
+ return
+ }
+ }()
+ once.Do(func() { close(s.startChan) }) // unblock anyone waiting for this to start listening
+ // but block here, as this is what caddy expects
+ for {
+ select {}
+ }
+ return nil
+}
+
+// setup prepares the server s to begin listening; it should be
+// called just before the listener announces itself on the network
+// and should only be called when the server is just starting up.
+// It runs every zone's startup functions, aborting on the first error.
+func (s *Server) setup() error {
+	for _, z := range s.zones {
+		for _, fn := range z.config.Startup {
+			if err := fn(); err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+/*
+TODO(miek): no such thing in the glorious Go DNS.
+// serveTLS serves TLS with SNI and client auth support if s has them enabled. It
+// blocks until s quits.
+func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error {
+ // Customize our TLS configuration
+ s.TLSConfig.MinVersion = tlsConfigs[0].ProtocolMinVersion
+ s.TLSConfig.MaxVersion = tlsConfigs[0].ProtocolMaxVersion
+ s.TLSConfig.CipherSuites = tlsConfigs[0].Ciphers
+ s.TLSConfig.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites
+
+ // TLS client authentication, if user enabled it
+ err := setupClientAuth(tlsConfigs, s.TLSConfig)
+ if err != nil {
+ defer close(s.startChan)
+ return err
+ }
+
+ // Create TLS listener - note that we do not replace s.listener
+ // with this TLS listener; tls.listener is unexported and does
+ // not implement the File() method we need for graceful restarts
+ // on POSIX systems.
+ ln = tls.NewListener(ln, s.TLSConfig)
+
+ close(s.startChan) // unblock anyone waiting for this to start listening
+ return s.Serve(ln)
+}
+*/
+
+// Stop stops the server. It blocks until the server is
+// totally stopped. On POSIX systems, it will wait for
+// connections to close (up to a max timeout of a few
+// seconds); on Windows it will close the listener
+// immediately.
+func (s *Server) Stop() (err error) {
+
+	if runtime.GOOS != "windows" {
+		// force connections to close after timeout
+		done := make(chan struct{})
+		go func() {
+			// s.dnsWg was seeded with one increment in New as a
+			// barrier; Done() releases it so Wait() can complete
+			// once all tracked connections finish.
+			s.dnsWg.Done() // decrement our initial increment used as a barrier
+			s.dnsWg.Wait()
+			close(done)
+		}()
+
+		// Wait for remaining connections to finish or
+		// force them all to close after timeout
+		select {
+		case <-time.After(s.connTimeout):
+		case <-done:
+		}
+	}
+
+	// Close the listener now; this stops the server without delay
+	s.listenerMu.Lock()
+	if s.listener != nil {
+		err = s.listener.Close()
+	}
+	s.listenerMu.Unlock()
+
+	return
+}
+
+// WaitUntilStarted blocks until the server s is started, meaning
+// that practically the next instruction is to start the server loop.
+// It also unblocks if the server encounters an error during startup.
+// It relies on startChan being closed in ListenAndServe.
+func (s *Server) WaitUntilStarted() {
+	<-s.startChan
+}
+
+// ListenerFd gets a dup'ed file of the listener. If there
+// is no underlying file, the return value will be nil. It
+// is the caller's responsibility to close the file.
+func (s *Server) ListenerFd() *os.File {
+	s.listenerMu.Lock()
+	defer s.listenerMu.Unlock()
+	if s.listener == nil {
+		return nil
+	}
+	// The dup error is discarded on purpose: a nil file already
+	// signals "no file available" to the caller.
+	file, _ := s.listener.File()
+	return file
+}
+
+// ServeDNS is the entry point for every request to the address that s
+// is bound to. It acts as a multiplexer for the request's zone name as
+// defined in the question so that the correct zone
+// (configuration and middleware stack) will handle the request.
+// Lookup walks the query name label by label toward the root,
+// choosing the longest matching zone.
+func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	defer func() {
+		// In case the user doesn't enable error middleware, we still
+		// need to make sure that we stay alive up here
+		if rec := recover(); rec != nil {
+			// TODO(miek): serverfailure return?
+		}
+	}()
+
+	// Execute the optional request callback if it exists
+	if s.ReqCallback != nil && s.ReqCallback(w, r) {
+		return
+	}
+
+	// A message without a question section cannot be routed to a
+	// zone; indexing r.Question[0] below would panic. Reply FORMERR.
+	if len(r.Question) == 0 {
+		DefaultErrorFunc(w, r, dns.RcodeFormatError)
+		return
+	}
+
+	q := r.Question[0].Name
+	b := make([]byte, len(q))
+	off, end := 0, false
+	for {
+		l := len(q[off:])
+		for i := 0; i < l; i++ {
+			b[i] = q[off+i]
+			// normalize the name for the lookup (ASCII lowercase)
+			if b[i] >= 'A' && b[i] <= 'Z' {
+				b[i] |= ('a' - 'A')
+			}
+		}
+
+		if h, ok := s.zones[string(b[:l])]; ok {
+			// A DS query is deliberately not answered by the exactly
+			// matching zone; the walk continues so a parent zone can
+			// serve it.
+			if r.Question[0].Qtype != dns.TypeDS {
+				rcode, _ := h.stack.ServeDNS(w, r)
+				if rcode > 0 {
+					DefaultErrorFunc(w, r, rcode)
+				}
+				return
+			}
+		}
+		off, end = dns.NextLabel(q, off)
+		if end {
+			break
+		}
+	}
+	// Wildcard match, if we have found nothing try the root zone as a last resort.
+	if h, ok := s.zones["."]; ok {
+		rcode, _ := h.stack.ServeDNS(w, r)
+		if rcode > 0 {
+			DefaultErrorFunc(w, r, rcode)
+		}
+		return
+	}
+
+	// Still here? Error out with SERVFAIL and some logging. Note that
+	// the previous fmt.Fprintf(w, "No such zone at %s", ...) was
+	// removed: w is a DNS transport, not HTTP, and writing free-form
+	// text after WriteMsg puts garbage bytes on the wire.
+	remoteHost := w.RemoteAddr().String()
+	DefaultErrorFunc(w, r, dns.RcodeServerFailure)
+	log.Printf("[INFO] %s - No such zone at %s (Remote: %s)", q, s.Addr, remoteHost)
+}
+
+// DefaultErrorFunc responds to a DNS request with a reply message
+// carrying the specified rcode (e.g. dns.RcodeServerFailure).
+// (The original comment described an HTTP response; this is DNS.)
+// The error from WriteMsg is ignored.
+func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) {
+	answer := new(dns.Msg)
+	answer.SetRcode(r, rcode)
+	w.WriteMsg(answer)
+}
+
+// setupClientAuth sets up TLS client authentication only if
+// any of the TLS configs specified at least one cert file.
+// When enabled, every CA file from every config goes into one pool
+// and client certificates are required and verified against it.
+func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error {
+	var wantClientAuth bool
+	for _, cfg := range tlsConfigs {
+		if len(cfg.ClientCerts) > 0 {
+			wantClientAuth = true
+			break
+		}
+	}
+	if !wantClientAuth {
+		return nil
+	}
+
+	pool := x509.NewCertPool()
+	for _, cfg := range tlsConfigs {
+		for _, caFile := range cfg.ClientCerts {
+			// Anyone that gets a cert from this CA can connect.
+			caCrt, err := ioutil.ReadFile(caFile)
+			if err != nil {
+				return err
+			}
+			if !pool.AppendCertsFromPEM(caCrt) {
+				return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile)
+			}
+		}
+	}
+	config.ClientCAs = pool
+	config.ClientAuth = tls.RequireAndVerifyClientCert
+	return nil
+}
+
+// RunFirstStartupFuncs runs all of the server's FirstStartup
+// callback functions unless one of them returns an error first.
+// It is the caller's responsibility to call this only once and
+// at the correct time. The functions here should not be executed
+// at restarts or where the user does not explicitly start a new
+// instance of the server.
+func (s *Server) RunFirstStartupFuncs() error {
+	for _, z := range s.zones {
+		for _, callback := range z.config.FirstStartup {
+			err := callback()
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
+// connections. It's used by ListenAndServe and ListenAndServeTLS so
+// dead TCP connections (e.g. closing laptop mid-download) eventually
+// go away.
+//
+// Borrowed from the Go standard library.
+type tcpKeepAliveListener struct {
+	*net.TCPListener
+}
+
+// Accept accepts the connection with a keep-alive enabled
+// (3-minute keep-alive period).
+// NOTE(review): SetKeepAlive/SetKeepAlivePeriod errors are silently
+// ignored, same as the stdlib code this was borrowed from.
+func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
+	tc, err := ln.AcceptTCP()
+	if err != nil {
+		return
+	}
+	tc.SetKeepAlive(true)
+	tc.SetKeepAlivePeriod(3 * time.Minute)
+	return tc, nil
+}
+
+// File implements ListenerFile; returns the underlying file of the
+// listener (a dup, per *net.TCPListener.File semantics).
+func (ln tcpKeepAliveListener) File() (*os.File, error) {
+	return ln.TCPListener.File()
+}
+
+// ShutdownCallbacks executes all the shutdown callbacks
+// for all the virtualhosts in servers, and returns all the
+// errors generated during their execution. In other words,
+// an error executing one shutdown callback does not stop
+// execution of others. Only one shutdown callback is executed
+// at a time. You must protect the servers that are passed in
+// if they are shared across threads.
+func ShutdownCallbacks(servers []*Server) []error {
+	var errs []error
+	for _, srv := range servers {
+		for _, z := range srv.zones {
+			for _, callback := range z.config.Shutdown {
+				if err := callback(); err != nil {
+					errs = append(errs, err)
+				}
+			}
+		}
+	}
+	return errs
+}
diff --git a/server/zones.go b/server/zones.go
new file mode 100644
index 000000000..6a5a7a938
--- /dev/null
+++ b/server/zones.go
@@ -0,0 +1,28 @@
+package server
+
+import "github.com/miekg/coredns/middleware"
+
+// zone represents a DNS zone. While a Server
+// is what actually binds to the address, a user may want to serve
+// multiple zones on a single address, and this is what a
+// zone allows us to do.
+type zone struct {
+	config Config // the zone's parsed configuration
+	stack middleware.Handler // composed middleware chain built by buildStack/compile
+}
+
+// buildStack builds the server's middleware stack based
+// on its config. This method should be called last before
+// ListenAndServe begins.
+// It currently cannot fail; the error return is part of the
+// interface for callers such as New.
+func (z *zone) buildStack() error {
+	z.compile(z.config.Middleware)
+	return nil
+}
+
+// compile is an elegant alternative to nesting middleware function
+// calls like handler1(handler2(handler3(finalHandler))).
+// Iterating in reverse makes layers[0] the outermost handler
+// (executed first) and the pre-existing z.stack the innermost.
+func (z *zone) compile(layers []middleware.Middleware) {
+	for i := len(layers) - 1; i >= 0; i-- {
+		z.stack = layers[i](z.stack)
+	}
+}