diff options
Diffstat (limited to 'vendor/github.com/lib/pq/conn_test.go')
-rw-r--r-- | vendor/github.com/lib/pq/conn_test.go | 67 |
1 files changed, 61 insertions, 6 deletions
diff --git a/vendor/github.com/lib/pq/conn_test.go b/vendor/github.com/lib/pq/conn_test.go index 030a798c..e654b85b 100644 --- a/vendor/github.com/lib/pq/conn_test.go +++ b/vendor/github.com/lib/pq/conn_test.go @@ -1,6 +1,7 @@ package pq import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -28,7 +29,7 @@ func forceBinaryParameters() bool { } } -func openTestConnConninfo(conninfo string) (*sql.DB, error) { +func testConninfo(conninfo string) string { defaultTo := func(envvar string, value string) { if os.Getenv(envvar) == "" { os.Setenv(envvar, value) @@ -43,8 +44,11 @@ func openTestConnConninfo(conninfo string) (*sql.DB, error) { !strings.HasPrefix(conninfo, "postgresql://") { conninfo = conninfo + " binary_parameters=yes" } + return conninfo +} - return sql.Open("postgres", conninfo) +func openTestConnConninfo(conninfo string) (*sql.DB, error) { + return sql.Open("postgres", testConninfo(conninfo)) } func openTestConn(t Fatalistic) *sql.DB { @@ -637,6 +641,57 @@ func TestErrorDuringStartup(t *testing.T) { } } +type testConn struct { + closed bool + net.Conn +} + +func (c *testConn) Close() error { + c.closed = true + return c.Conn.Close() +} + +type testDialer struct { + conns []*testConn +} + +func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) { + c, err := net.Dial(ntw, addr) + if err != nil { + return nil, err + } + tc := &testConn{Conn: c} + d.conns = append(d.conns, tc) + return tc, nil +} + +func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { + c, err := net.DialTimeout(ntw, addr, timeout) + if err != nil { + return nil, err + } + tc := &testConn{Conn: c} + d.conns = append(d.conns, tc) + return tc, nil +} + +func TestErrorDuringStartupClosesConn(t *testing.T) { + // Don't use the normal connection setup, this is intended to + // blow up in the startup packet from a non-existent user. + var d testDialer + c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist")) + if err == nil { + c.Close() + t.Fatal("expected dial error") + } + if len(d.conns) != 1 { + t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1) + } + if !d.conns[0].closed { + t.Error("connection leaked") + } +} + func TestBadConn(t *testing.T) { var err error @@ -1209,8 +1264,8 @@ func TestParseComplete(t *testing.T) { // Test interface conformance. var ( - _ driver.Execer = (*conn)(nil) - _ driver.Queryer = (*conn)(nil) + _ driver.ExecerContext = (*conn)(nil) + _ driver.QueryerContext = (*conn)(nil) ) func TestNullAfterNonNull(t *testing.T) { @@ -1555,10 +1610,10 @@ func TestRowsResultTag(t *testing.T) { t.Fatal(err) } defer conn.Close() - q := conn.(driver.Queryer) + q := conn.(driver.QueryerContext) for _, test := range tests { - if rows, err := q.Query(test.query, nil); err != nil { + if rows, err := q.QueryContext(context.Background(), test.query, nil); err != nil { t.Fatalf("%s: %s", test.query, err) } else { r := rows.(ResultTag) |