aboutsummaryrefslogtreecommitdiff
path: root/vendor/github.com/lib/pq/conn_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/lib/pq/conn_test.go')
-rw-r--r--vendor/github.com/lib/pq/conn_test.go67
1 files changed, 61 insertions, 6 deletions
diff --git a/vendor/github.com/lib/pq/conn_test.go b/vendor/github.com/lib/pq/conn_test.go
index 030a798c..e654b85b 100644
--- a/vendor/github.com/lib/pq/conn_test.go
+++ b/vendor/github.com/lib/pq/conn_test.go
@@ -1,6 +1,7 @@
package pq
import (
+ "context"
"database/sql"
"database/sql/driver"
"fmt"
@@ -28,7 +29,7 @@ func forceBinaryParameters() bool {
}
}
-func openTestConnConninfo(conninfo string) (*sql.DB, error) {
+func testConninfo(conninfo string) string {
defaultTo := func(envvar string, value string) {
if os.Getenv(envvar) == "" {
os.Setenv(envvar, value)
@@ -43,8 +44,11 @@ func openTestConnConninfo(conninfo string) (*sql.DB, error) {
!strings.HasPrefix(conninfo, "postgresql://") {
conninfo = conninfo + " binary_parameters=yes"
}
+ return conninfo
+}
- return sql.Open("postgres", conninfo)
+func openTestConnConninfo(conninfo string) (*sql.DB, error) {
+ return sql.Open("postgres", testConninfo(conninfo))
}
func openTestConn(t Fatalistic) *sql.DB {
@@ -637,6 +641,57 @@ func TestErrorDuringStartup(t *testing.T) {
}
}
+type testConn struct {
+ closed bool
+ net.Conn
+}
+
+func (c *testConn) Close() error {
+ c.closed = true
+ return c.Conn.Close()
+}
+
+type testDialer struct {
+ conns []*testConn
+}
+
+func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) {
+ c, err := net.Dial(ntw, addr)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testConn{Conn: c}
+ d.conns = append(d.conns, tc)
+ return tc, nil
+}
+
+func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
+ c, err := net.DialTimeout(ntw, addr, timeout)
+ if err != nil {
+ return nil, err
+ }
+ tc := &testConn{Conn: c}
+ d.conns = append(d.conns, tc)
+ return tc, nil
+}
+
+func TestErrorDuringStartupClosesConn(t *testing.T) {
+ // Don't use the normal connection setup, this is intended to
+ // blow up in the startup packet from a non-existent user.
+ var d testDialer
+ c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist"))
+ if err == nil {
+ c.Close()
+ t.Fatal("expected dial error")
+ }
+ if len(d.conns) != 1 {
+ t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1)
+ }
+ if !d.conns[0].closed {
+ t.Error("connection leaked")
+ }
+}
+
func TestBadConn(t *testing.T) {
var err error
@@ -1209,8 +1264,8 @@ func TestParseComplete(t *testing.T) {
// Test interface conformance.
var (
- _ driver.Execer = (*conn)(nil)
- _ driver.Queryer = (*conn)(nil)
+ _ driver.ExecerContext = (*conn)(nil)
+ _ driver.QueryerContext = (*conn)(nil)
)
func TestNullAfterNonNull(t *testing.T) {
@@ -1555,10 +1610,10 @@ func TestRowsResultTag(t *testing.T) {
t.Fatal(err)
}
defer conn.Close()
- q := conn.(driver.Queryer)
+ q := conn.(driver.QueryerContext)
for _, test := range tests {
- if rows, err := q.Query(test.query, nil); err != nil {
+ if rows, err := q.QueryContext(context.Background(), test.query, nil); err != nil {
t.Fatalf("%s: %s", test.query, err)
} else {
r := rows.(ResultTag)