diff options
Diffstat (limited to 'server/server.go')
-rw-r--r-- | server/server.go | 431 |
1 files changed, 431 insertions, 0 deletions
diff --git a/server/server.go b/server/server.go new file mode 100644 index 000000000..7baa74686 --- /dev/null +++ b/server/server.go @@ -0,0 +1,431 @@ +// Package server implements a configurable, general-purpose web server. +// It relies on configurations obtained from the adjacent config package +// and can execute middleware as defined by the adjacent middleware package. +package server + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "runtime" + "sync" + "time" + + "github.com/miekg/dns" +) + +// Server represents an instance of a server, which serves +// DNS requests at a particular address (host and port). A +// server is capable of serving numerous zones on +// the same address and the listener may be stopped for +// graceful termination (POSIX only). +type Server struct { + Addr string // Address we listen on + mux *dns.ServeMux + tls bool // whether this server is serving all HTTPS hosts or not + TLSConfig *tls.Config + OnDemandTLS bool // whether this server supports on-demand TLS (load certs at handshake-time) + zones map[string]zone // zones keyed by their address + listener ListenerFile // the listener which is bound to the socket + listenerMu sync.Mutex // protects listener + dnsWg sync.WaitGroup // used to wait on outstanding connections + startChan chan struct{} // used to block until server is finished starting + connTimeout time.Duration // the maximum duration of a graceful shutdown + ReqCallback OptionalCallback // if non-nil, is executed at the beginning of every request + SNICallback func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) +} + +// ListenerFile represents a listener. +type ListenerFile interface { + net.Listener + File() (*os.File, error) +} + +// OptionalCallback is a function that may or may not handle a request. +// It returns whether or not it handled the request. If it handled the +// request, it is presumed that no further request handling should occur. +type OptionalCallback func(dns.ResponseWriter, *dns.Msg) bool + +// New creates a new Server which will bind to addr and serve +// the sites/hosts configured in configs. Its listener will +// gracefully close when the server is stopped which will take +// no longer than gracefulTimeout. +// +// This function does not start serving. +// +// Do not re-use a server (start, stop, then start again). We +// could probably add more locking to make this possible, but +// as it stands, you should dispose of a server after stopping it. +// The behavior of serving with a spent server is undefined. +func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server, error) { + var useTLS, useOnDemandTLS bool + if len(configs) > 0 { + useTLS = configs[0].TLS.Enabled + useOnDemandTLS = configs[0].TLS.OnDemand + } + + s := &Server{ + Addr: addr, + TLSConfig: new(tls.Config), + // TODO: Make these values configurable? + // ReadTimeout: 2 * time.Minute, + // WriteTimeout: 2 * time.Minute, + // MaxHeaderBytes: 1 << 16, + tls: useTLS, + OnDemandTLS: useOnDemandTLS, + zones: make(map[string]zone), + startChan: make(chan struct{}), + connTimeout: gracefulTimeout, + } + mux := dns.NewServeMux() + mux.Handle(".", s) // wildcard handler, everything will go through here + s.mux = mux + + // We have to bound our wg with one increment + // to prevent a "race condition" that is hard-coded + // into sync.WaitGroup.Wait() - basically, an add + // with a positive delta must be guaranteed to + // occur before Wait() is called on the wg. + // In a way, this kind of acts as a safety barrier. + s.dnsWg.Add(1) + + // Set up each zone + for _, conf := range configs { + // TODO(miek): something better here? + if _, exists := s.zones[conf.Host]; exists { + return nil, fmt.Errorf("cannot serve %s - host already defined for address %s", conf.Address(), s.Addr) + } + + z := zone{config: conf} + + // Build middleware stack + err := z.buildStack() + if err != nil { + return nil, err + } + + s.zones[conf.Host] = z + } + + return s, nil +} + +// Serve starts the server with an existing listener. It blocks until the +// server stops. +/* +func (s *Server) Serve(ln ListenerFile) error { + // TODO(miek): Go DNS has no server stuff that allows you to give it a listener + // and use that. + err := s.setup() + if err != nil { + defer close(s.startChan) // MUST defer so error is properly reported, same with all cases in this file + return err + } + return s.serve(ln) +} +*/ + +// ListenAndServe starts the server with a new listener. It blocks until the server stops. +func (s *Server) ListenAndServe() error { + err := s.setup() + once := sync.Once{} + + if err != nil { + close(s.startChan) + return err + } + + // TODO(miek): redo to make it more like caddy + // - error handling, re-introduce what Caddy did. + go func() { + if err := dns.ListenAndServe(s.Addr, "tcp", s.mux); err != nil { + log.Printf("[ERROR] %v\n", err) + defer once.Do(func() { close(s.startChan) }) + return + } + }() + + go func() { + if err := dns.ListenAndServe(s.Addr, "udp", s.mux); err != nil { + log.Printf("[ERROR] %v\n", err) + defer once.Do(func() { close(s.startChan) }) + return + } + }() + once.Do(func() { close(s.startChan) }) // unblock anyone waiting for this to start listening + // but block here, as this is what caddy expects + for { + select {} + } + return nil +} + +// setup prepares the server s to begin listening; it should be +// called just before the listener announces itself on the network +// and should only be called when the server is just starting up. +func (s *Server) setup() error { + // Execute startup functions now + for _, z := range s.zones { + for _, startupFunc := range z.config.Startup { + err := startupFunc() + if err != nil { + return err + } + } + } + + return nil +} + +/* +TODO(miek): no such thing in the glorious Go DNS. +// serveTLS serves TLS with SNI and client auth support if s has them enabled. It +// blocks until s quits. +func serveTLS(s *Server, ln net.Listener, tlsConfigs []TLSConfig) error { + // Customize our TLS configuration + s.TLSConfig.MinVersion = tlsConfigs[0].ProtocolMinVersion + s.TLSConfig.MaxVersion = tlsConfigs[0].ProtocolMaxVersion + s.TLSConfig.CipherSuites = tlsConfigs[0].Ciphers + s.TLSConfig.PreferServerCipherSuites = tlsConfigs[0].PreferServerCipherSuites + + // TLS client authentication, if user enabled it + err := setupClientAuth(tlsConfigs, s.TLSConfig) + if err != nil { + defer close(s.startChan) + return err + } + + // Create TLS listener - note that we do not replace s.listener + // with this TLS listener; tls.listener is unexported and does + // not implement the File() method we need for graceful restarts + // on POSIX systems. + ln = tls.NewListener(ln, s.TLSConfig) + + close(s.startChan) // unblock anyone waiting for this to start listening + return s.Serve(ln) +} +*/ + +// Stop stops the server. It blocks until the server is +// totally stopped. On POSIX systems, it will wait for +// connections to close (up to a max timeout of a few +// seconds); on Windows it will close the listener +// immediately. +func (s *Server) Stop() (err error) { + + if runtime.GOOS != "windows" { + // force connections to close after timeout + done := make(chan struct{}) + go func() { + s.dnsWg.Done() // decrement our initial increment used as a barrier + s.dnsWg.Wait() + close(done) + }() + + // Wait for remaining connections to finish or + // force them all to close after timeout + select { + case <-time.After(s.connTimeout): + case <-done: + } + } + + // Close the listener now; this stops the server without delay + s.listenerMu.Lock() + if s.listener != nil { + err = s.listener.Close() + } + s.listenerMu.Unlock() + + return +} + +// WaitUntilStarted blocks until the server s is started, meaning +// that practically the next instruction is to start the server loop. +// It also unblocks if the server encounters an error during startup. +func (s *Server) WaitUntilStarted() { + <-s.startChan +} + +// ListenerFd gets a dup'ed file of the listener. If there +// is no underlying file, the return value will be nil. It +// is the caller's responsibility to close the file. +func (s *Server) ListenerFd() *os.File { + s.listenerMu.Lock() + defer s.listenerMu.Unlock() + if s.listener != nil { + file, _ := s.listener.File() + return file + } + return nil +} + +// ServeDNS is the entry point for every request to the address that s +// is bound to. It acts as a multiplexer for the requests zonename as +// defined in the request so that the correct zone +// (configuration and middleware stack) will handle the request. +func (s *Server) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + defer func() { + // In case the user doesn't enable error middleware, we still + // need to make sure that we stay alive up here + if rec := recover(); rec != nil { + // TODO(miek): serverfailure return? + } + }() + + // Execute the optional request callback if it exists + if s.ReqCallback != nil && s.ReqCallback(w, r) { + return + } + + q := r.Question[0].Name + b := make([]byte, len(q)) + off, end := 0, false + for { + l := len(q[off:]) + for i := 0; i < l; i++ { + b[i] = q[off+i] + // normalize the name for the lookup + if b[i] >= 'A' && b[i] <= 'Z' { + b[i] |= ('a' - 'A') + } + } + + if h, ok := s.zones[string(b[:l])]; ok { + if r.Question[0].Qtype != dns.TypeDS { + rcode, _ := h.stack.ServeDNS(w, r) + if rcode > 0 { + DefaultErrorFunc(w, r, rcode) + } + return + } + } + off, end = dns.NextLabel(q, off) + if end { + break + } + } + // Wildcard match, if we have found nothing try the root zone as a last resort. + if h, ok := s.zones["."]; ok { + rcode, _ := h.stack.ServeDNS(w, r) + if rcode > 0 { + DefaultErrorFunc(w, r, rcode) + } + return + } + + // Still here? Error out with SERVFAIL and some logging + remoteHost := w.RemoteAddr().String() + DefaultErrorFunc(w, r, dns.RcodeServerFailure) + + fmt.Fprintf(w, "No such zone at %s", s.Addr) + log.Printf("[INFO] %s - No such zone at %s (Remote: %s)", q, s.Addr, remoteHost) +} + +// DefaultErrorFunc responds to an HTTP request with a simple description +// of the specified HTTP status code. +func DefaultErrorFunc(w dns.ResponseWriter, r *dns.Msg, rcode int) { + answer := new(dns.Msg) + answer.SetRcode(r, rcode) + w.WriteMsg(answer) +} + +// setupClientAuth sets up TLS client authentication only if +// any of the TLS configs specified at least one cert file. +func setupClientAuth(tlsConfigs []TLSConfig, config *tls.Config) error { + var clientAuth bool + for _, cfg := range tlsConfigs { + if len(cfg.ClientCerts) > 0 { + clientAuth = true + break + } + } + + if clientAuth { + pool := x509.NewCertPool() + for _, cfg := range tlsConfigs { + for _, caFile := range cfg.ClientCerts { + caCrt, err := ioutil.ReadFile(caFile) // Anyone that gets a cert from this CA can connect + if err != nil { + return err + } + if !pool.AppendCertsFromPEM(caCrt) { + return fmt.Errorf("error loading client certificate '%s': no certificates were successfully parsed", caFile) + } + } + } + config.ClientCAs = pool + config.ClientAuth = tls.RequireAndVerifyClientCert + } + + return nil +} + +// RunFirstStartupFuncs runs all of the server's FirstStartup +// callback functions unless one of them returns an error first. +// It is the caller's responsibility to call this only once and +// at the correct time. The functions here should not be executed +// at restarts or where the user does not explicitly start a new +// instance of the server. +func (s *Server) RunFirstStartupFuncs() error { + for _, z := range s.zones { + for _, f := range z.config.FirstStartup { + if err := f(); err != nil { + return err + } + } + } + return nil +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe and ListenAndServeTLS so +// dead TCP connections (e.g. closing laptop mid-download) eventually +// go away. +// +// Borrowed from the Go standard library. +type tcpKeepAliveListener struct { + *net.TCPListener +} + +// Accept accepts the connection with a keep-alive enabled. +func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { + tc, err := ln.AcceptTCP() + if err != nil { + return + } + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(3 * time.Minute) + return tc, nil +} + +// File implements ListenerFile; returns the underlying file of the listener. +func (ln tcpKeepAliveListener) File() (*os.File, error) { + return ln.TCPListener.File() +} + +// ShutdownCallbacks executes all the shutdown callbacks +// for all the virtualhosts in servers, and returns all the +// errors generated during their execution. In other words, +// an error executing one shutdown callback does not stop +// execution of others. Only one shutdown callback is executed +// at a time. You must protect the servers that are passed in +// if they are shared across threads. +func ShutdownCallbacks(servers []*Server) []error { + var errs []error + for _, s := range servers { + for _, zone := range s.zones { + for _, shutdownFunc := range zone.config.Shutdown { + err := shutdownFunc() + if err != nil { + errs = append(errs, err) + } + } + } + } + return errs +} |