diff options
Diffstat (limited to 'libgo/go/exp/sql/sql.go')
-rw-r--r-- | libgo/go/exp/sql/sql.go | 78 |
1 files changed, 57 insertions, 21 deletions
diff --git a/libgo/go/exp/sql/sql.go b/libgo/go/exp/sql/sql.go index 291af7f67dc..c055fdd68c6 100644 --- a/libgo/go/exp/sql/sql.go +++ b/libgo/go/exp/sql/sql.go @@ -88,8 +88,9 @@ type DB struct { driver driver.Driver dsn string - mu sync.Mutex + mu sync.Mutex // protects freeConn and closed freeConn []driver.Conn + closed bool } // Open opens a database specified by its database driver name and a @@ -106,6 +107,22 @@ func Open(driverName, dataSourceName string) (*DB, error) { return &DB{driver: driver, dsn: dataSourceName}, nil } +// Close closes the database, releasing any open resources. +func (db *DB) Close() error { + db.mu.Lock() + defer db.mu.Unlock() + var err error + for _, c := range db.freeConn { + err1 := c.Close() + if err1 != nil { + err = err1 + } + } + db.freeConn = nil + db.closed = true + return err +} + func (db *DB) maxIdleConns() int { const defaultMaxIdleConns = 2 // TODO(bradfitz): ask driver, if supported, for its default preference @@ -116,6 +133,9 @@ func (db *DB) maxIdleConns() int { // conn returns a newly-opened or cached driver.Conn func (db *DB) conn() (driver.Conn, error) { db.mu.Lock() + if db.closed { + return nil, errors.New("sql: database is closed") + } if n := len(db.freeConn); n > 0 { conn := db.freeConn[n-1] db.freeConn = db.freeConn[:n-1] @@ -140,11 +160,13 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { } func (db *DB) putConn(c driver.Conn) { - if n := len(db.freeConn); n < db.maxIdleConns() { + db.mu.Lock() + defer db.mu.Unlock() + if n := len(db.freeConn); !db.closed && n < db.maxIdleConns() { db.freeConn = append(db.freeConn, c) return } - db.closeConn(c) + db.closeConn(c) // TODO(bradfitz): release lock before calling this? } func (db *DB) closeConn(c driver.Conn) { @@ -180,17 +202,11 @@ func (db *DB) Prepare(query string) (*Stmt, error) { // Exec executes a query without returning any rows. func (db *DB) Exec(query string, args ...interface{}) (Result, error) { - // Optional fast path, if the driver implements driver.Execer. - if execer, ok := db.driver.(driver.Execer); ok { - resi, err := execer.Exec(query, args) - if err != nil { - return nil, err - } - return result{resi}, nil + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err } - // If the driver does not implement driver.Execer, we need - // a connection. ci, err := db.conn() if err != nil { return nil, err @@ -198,11 +214,13 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { defer db.putConn(ci) if execer, ok := ci.(driver.Execer); ok { - resi, err := execer.Exec(query, args) - if err != nil { - return nil, err + resi, err := execer.Exec(query, sargs) + if err != driver.ErrSkip { + if err != nil { + return nil, err + } + return result{resi}, nil } - return result{resi}, nil } sti, err := ci.Prepare(query) @@ -210,7 +228,8 @@ func (db *DB) Exec(query string, args ...interface{}) (Result, error) { return nil, err } defer sti.Close() - resi, err := sti.Exec(args) + + resi, err := sti.Exec(sargs) if err != nil { return nil, err } @@ -386,7 +405,13 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { return nil, err } defer sti.Close() - resi, err := sti.Exec(args) + + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err + } + + resi, err := sti.Exec(sargs) if err != nil { return nil, err } @@ -449,7 +474,10 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { } defer releaseConn() - if want := si.NumInput(); len(args) != want { + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := si.NumInput(); want != -1 && len(args) != want { return nil, fmt.Errorf("db: expected %d arguments, got %d", want, len(args)) } @@ -545,10 +573,18 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { if err != nil { return nil, err } - if len(args) != si.NumInput() { + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + if want := si.NumInput(); want != -1 && len(args) != want { return nil, fmt.Errorf("db: statement expects %d inputs; got %d", si.NumInput(), len(args)) } - rowsi, err := si.Query(args) + sargs, err := subsetTypeArgs(args) + if err != nil { + return nil, err + } + rowsi, err := si.Query(sargs) if err != nil { s.db.putConn(ci) return nil, err |