From 9dac382fbfb9cc0e92ea9d77504b103f08826b7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Torkel=20=C3=96degaard?= Date: Wed, 10 Feb 2016 10:09:24 +0100 Subject: [PATCH] dependency(): updated go lib go-xorm --- Godeps/Godeps.json | 9 +- .../src/github.com/go-xorm/core/LICENSE | 27 + .../src/github.com/go-xorm/core/cache.go | 19 +- .../src/github.com/go-xorm/core/column.go | 60 +- .../src/github.com/go-xorm/core/db.go | 125 ++- .../src/github.com/go-xorm/core/db_test.go | 8 +- .../src/github.com/go-xorm/core/dialect.go | 63 +- .../src/github.com/go-xorm/core/error.go | 2 +- .../src/github.com/go-xorm/core/scan.go | 52 + .../src/github.com/go-xorm/core/type.go | 10 +- .../src/github.com/go-xorm/xorm/LICENSE | 2 +- .../src/github.com/go-xorm/xorm/README.md | 186 +++- .../src/github.com/go-xorm/xorm/README_CN.md | 178 +++- .../src/github.com/go-xorm/xorm/VERSION | 2 +- .../src/github.com/go-xorm/xorm/doc.go | 2 +- .../src/github.com/go-xorm/xorm/engine.go | 137 ++- .../src/github.com/go-xorm/xorm/error.go | 4 + .../xorm/examples/goroutine.db-journal | Bin 2576 -> 0 bytes .../go-xorm/xorm/examples/goroutine.go | 4 +- .../github.com/go-xorm/xorm/goracle_driver.go | 4 + .../src/github.com/go-xorm/xorm/helpers.go | 73 +- .../github.com/go-xorm/xorm/helpers_test.go | 22 + .../src/github.com/go-xorm/xorm/logger.go | 4 + .../src/github.com/go-xorm/xorm/lru_cacher.go | 5 +- .../xorm/{memroy_store.go => memory_store.go} | 5 +- .../github.com/go-xorm/xorm/mssql_dialect.go | 19 +- .../github.com/go-xorm/xorm/mymysql_driver.go | 4 + .../github.com/go-xorm/xorm/mysql_dialect.go | 6 + .../github.com/go-xorm/xorm/mysql_driver.go | 4 + .../github.com/go-xorm/xorm/oci8_driver.go | 4 + .../github.com/go-xorm/xorm/odbc_driver.go | 4 + .../github.com/go-xorm/xorm/oracle_dialect.go | 12 +- .../go-xorm/xorm/postgres_dialect.go | 13 +- .../src/github.com/go-xorm/xorm/pq_driver.go | 8 +- .../src/github.com/go-xorm/xorm/processors.go | 8 + .../src/github.com/go-xorm/xorm/rows.go | 6 +- .../src/github.com/go-xorm/xorm/session.go | 955 ++++++++++++------ .../go-xorm/xorm/sqlite3_dialect.go | 30 +- .../github.com/go-xorm/xorm/sqlite3_driver.go | 4 + .../src/github.com/go-xorm/xorm/statement.go | 635 ++++++------ .../src/github.com/go-xorm/xorm/syslogger.go | 4 + .../src/github.com/go-xorm/xorm/xorm.go | 8 +- 42 files changed, 1893 insertions(+), 834 deletions(-) create mode 100644 Godeps/_workspace/src/github.com/go-xorm/core/LICENSE create mode 100644 Godeps/_workspace/src/github.com/go-xorm/core/scan.go delete mode 100644 Godeps/_workspace/src/github.com/go-xorm/xorm/examples/goroutine.db-journal create mode 100644 Godeps/_workspace/src/github.com/go-xorm/xorm/helpers_test.go rename Godeps/_workspace/src/github.com/go-xorm/xorm/{memroy_store.go => memory_store.go} (82%) diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 9502b16e0e5..6449d1941c4 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -1,6 +1,6 @@ { "ImportPath": "github.com/grafana/grafana", - "GoVersion": "go1.5", + "GoVersion": "go1.5.1", "Packages": [ "./pkg/..." ], @@ -106,12 +106,13 @@ }, { "ImportPath": "github.com/go-xorm/core", - "Rev": "be6e7ac47dc57bd0ada25322fa526944f66ccaa6" + "Comment": "v0.4.4-7-g9e608f7", + "Rev": "9e608f7330b9d16fe2818cfe731128b3f156cb9a" }, { "ImportPath": "github.com/go-xorm/xorm", - "Comment": "v0.4.2-58-ge2889e5", - "Rev": "e2889e5517600b82905f1d2ba8b70deb71823ffe" + "Comment": "v0.4.4-44-gf561133", + "Rev": "f56113384f2c63dfe4cd8e768e349f1c35122b58" }, { "ImportPath": "github.com/gosimple/slug", diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/LICENSE b/Godeps/_workspace/src/github.com/go-xorm/core/LICENSE new file mode 100644 index 00000000000..1130797806c --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-xorm/core/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2013 - 2015 Lunny Xiao +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the {organization} nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/cache.go b/Godeps/_workspace/src/github.com/go-xorm/core/cache.go index 2a61ef7521d..bf81bd52ba4 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/core/cache.go +++ b/Godeps/_workspace/src/github.com/go-xorm/core/cache.go @@ -1,10 +1,11 @@ package core import ( - "encoding/json" "errors" "fmt" "time" + "bytes" + "encoding/gob" ) const ( @@ -47,16 +48,20 @@ type Cacher interface { } func encodeIds(ids []PK) (string, error) { - b, err := json.Marshal(ids) - if err != nil { - return "", err - } - return string(b), nil + buf := new(bytes.Buffer) + enc := gob.NewEncoder(buf) + err := enc.Encode(ids) + + return buf.String(), err } + func decodeIds(s string) ([]PK, error) { pks := make([]PK, 0) - err := json.Unmarshal([]byte(s), &pks) + + dec := gob.NewDecoder(bytes.NewBufferString(s)) + err := dec.Decode(&pks) + return pks, err } diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/column.go b/Godeps/_workspace/src/github.com/go-xorm/core/column.go index 52468aa20e6..0e0def08c2f 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/core/column.go +++ b/Godeps/_workspace/src/github.com/go-xorm/core/column.go @@ -1,10 +1,10 @@ package core import ( - "errors" "fmt" "reflect" "strings" + "time" ) const ( @@ -35,6 +35,8 @@ type Column struct { DefaultIsEmpty bool EnumOptions map[string]int SetOptions map[string]int + DisableTimeZone bool + TimeZone *time.Location // column specified time zone } func NewColumn(name, fieldName string, sqlType SQLType, len1, len2 int, nullable bool) *Column { @@ -122,50 +124,34 @@ func (col *Column) ValueOfV(dataStruct *reflect.Value) (*reflect.Value, error) { } if dataStruct.Type().Kind() == reflect.Map { - var keyValue reflect.Value - - if len(col.fieldPath) == 1 { - keyValue = reflect.ValueOf(col.FieldName) - } else if len(col.fieldPath) == 2 { - keyValue = reflect.ValueOf(col.fieldPath[1]) - } else { - return nil, fmt.Errorf("Unsupported mutliderive %v", col.FieldName) - } - + keyValue := reflect.ValueOf(col.fieldPath[len(col.fieldPath)-1]) fieldValue = dataStruct.MapIndex(keyValue) return &fieldValue, nil + } else if dataStruct.Type().Kind() == reflect.Interface { + structValue := reflect.ValueOf(dataStruct.Interface()) + dataStruct = &structValue } - if len(col.fieldPath) == 1 { - fieldValue = dataStruct.FieldByName(col.FieldName) - } else if len(col.fieldPath) == 2 { - parentField := dataStruct.FieldByName(col.fieldPath[0]) - if parentField.IsValid() { - if parentField.Kind() == reflect.Struct { - fieldValue = parentField.FieldByName(col.fieldPath[1]) - } else if parentField.Kind() == reflect.Ptr { - if parentField.IsNil() { - parentField.Set(reflect.New(parentField.Type().Elem())) - fieldValue = parentField.Elem().FieldByName(col.fieldPath[1]) - } else { - parentField = parentField.Elem() - if parentField.IsValid() { - fieldValue = parentField.FieldByName(col.fieldPath[1]) - } else { - return nil, fmt.Errorf("field %v is not valid", col.FieldName) - } - } - } - } else { - // so we can use a different struct as conditions - fieldValue = dataStruct.FieldByName(col.fieldPath[1]) + level := len(col.fieldPath) + fieldValue = dataStruct.FieldByName(col.fieldPath[0]) + for i := 0; i < level-1; i++ { + if !fieldValue.IsValid() { + break + } + if fieldValue.Kind() == reflect.Struct { + fieldValue = fieldValue.FieldByName(col.fieldPath[i+1]) + } else if fieldValue.Kind() == reflect.Ptr { + if fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } + fieldValue = fieldValue.Elem().FieldByName(col.fieldPath[i+1]) + } else { + return nil, fmt.Errorf("field %v is not valid", col.FieldName) } - } else { - return nil, fmt.Errorf("Unsupported mutliderive %v", col.FieldName) } if !fieldValue.IsValid() { - return nil, errors.New("no find field matched") + return nil, fmt.Errorf("field %v is not valid", col.FieldName) } return &fieldValue, nil diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/db.go b/Godeps/_workspace/src/github.com/go-xorm/core/db.go index 1f70a926116..169d855315e 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/core/db.go +++ b/Godeps/_workspace/src/github.com/go-xorm/core/db.go @@ -2,6 +2,7 @@ package core import ( "database/sql" + "database/sql/driver" "errors" "reflect" "regexp" @@ -29,10 +30,24 @@ func StructToSlice(query string, st interface{}) (string, []interface{}, error) } args := make([]interface{}, 0) + var err error query = re.ReplaceAllStringFunc(query, func(src string) string { - args = append(args, vv.Elem().FieldByName(src[1:]).Interface()) + fv := vv.Elem().FieldByName(src[1:]).Interface() + if v, ok := fv.(driver.Valuer); ok { + var value driver.Value + value, err = v.Value() + if err != nil { + return "?" + } + args = append(args, value) + } else { + args = append(args, fv) + } return "?" }) + if err != nil { + return "", []interface{}{}, err + } return query, args, nil } @@ -43,12 +58,25 @@ type DB struct { func Open(driverName, dataSourceName string) (*DB, error) { db, err := sql.Open(driverName, dataSourceName) - return &DB{db, NewCacheMapper(&SnakeMapper{})}, err + if err != nil { + return nil, err + } + return &DB{db, NewCacheMapper(&SnakeMapper{})}, nil +} + +func FromDB(db *sql.DB) *DB { + return &DB{db, NewCacheMapper(&SnakeMapper{})} } func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { rows, err := db.DB.Query(query, args...) - return &Rows{rows, db.Mapper}, err + if err != nil { + if rows != nil { + rows.Close() + } + return nil, err + } + return &Rows{rows, db.Mapper}, nil } func (db *DB) QueryMap(query string, mp interface{}) (*Rows, error) { @@ -68,28 +96,87 @@ func (db *DB) QueryStruct(query string, st interface{}) (*Rows, error) { } type Row struct { - *sql.Row + rows *Rows // One of these two will be non-nil: - err error // deferred error for easy chaining - Mapper IMapper + err error // deferred error for easy chaining +} + +func (row *Row) Columns() ([]string, error) { + if row.err != nil { + return nil, row.err + } + return row.rows.Columns() } func (row *Row) Scan(dest ...interface{}) error { if row.err != nil { return row.err } - return row.Row.Scan(dest...) + defer row.rows.Close() + + for _, dp := range dest { + if _, ok := dp.(*sql.RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !row.rows.Next() { + if err := row.rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + err := row.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + if err := row.rows.Close(); err != nil { + return err + } + + return nil +} + +func (row *Row) ScanStructByName(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanStructByName(dest) +} + +func (row *Row) ScanStructByIndex(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanStructByIndex(dest) +} + +// scan data to a slice's pointer, slice's length should equal to columns' number +func (row *Row) ScanSlice(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanSlice(dest) +} + +// scan data to a map's pointer +func (row *Row) ScanMap(dest interface{}) error { + if row.err != nil { + return row.err + } + return row.rows.ScanMap(dest) } func (db *DB) QueryRow(query string, args ...interface{}) *Row { - row := db.DB.QueryRow(query, args...) - return &Row{row, nil, db.Mapper} + rows, err := db.Query(query, args...) + return &Row{rows, err} } func (db *DB) QueryRowMap(query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { - return &Row{nil, err, db.Mapper} + return &Row{nil, err} } return db.QueryRow(query, args...) } @@ -97,7 +184,7 @@ func (db *DB) QueryRowMap(query string, mp interface{}) *Row { func (db *DB) QueryRowStruct(query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { - return &Row{nil, err, db.Mapper} + return &Row{nil, err} } return db.QueryRow(query, args...) } @@ -187,14 +274,14 @@ func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) { } func (s *Stmt) QueryRow(args ...interface{}) *Row { - row := s.Stmt.QueryRow(args...) - return &Row{row, nil, s.Mapper} + rows, err := s.Query(args...) + return &Row{rows, err} } func (s *Stmt) QueryRowMap(mp interface{}) *Row { vv := reflect.ValueOf(mp) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map { - return &Row{nil, errors.New("mp should be a map's pointer"), s.Mapper} + return &Row{nil, errors.New("mp should be a map's pointer")} } args := make([]interface{}, len(s.names)) @@ -208,7 +295,7 @@ func (s *Stmt) QueryRowMap(mp interface{}) *Row { func (s *Stmt) QueryRowStruct(st interface{}) *Row { vv := reflect.ValueOf(st) if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct { - return &Row{nil, errors.New("st should be a struct's pointer"), s.Mapper} + return &Row{nil, errors.New("st should be a struct's pointer")} } args := make([]interface{}, len(s.names)) @@ -540,14 +627,14 @@ func (tx *Tx) QueryStruct(query string, st interface{}) (*Rows, error) { } func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { - row := tx.Tx.QueryRow(query, args...) - return &Row{row, nil, tx.Mapper} + rows, err := tx.Query(query, args...) + return &Row{rows, err} } func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { query, args, err := MapToSlice(query, mp) if err != nil { - return &Row{nil, err, tx.Mapper} + return &Row{nil, err} } return tx.QueryRow(query, args...) } @@ -555,7 +642,7 @@ func (tx *Tx) QueryRowMap(query string, mp interface{}) *Row { func (tx *Tx) QueryRowStruct(query string, st interface{}) *Row { query, args, err := StructToSlice(query, st) if err != nil { - return &Row{nil, err, tx.Mapper} + return &Row{nil, err} } return tx.QueryRow(query, args...) } diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/db_test.go b/Godeps/_workspace/src/github.com/go-xorm/core/db_test.go index 65bf241446d..94c4ea4f9f6 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/core/db_test.go +++ b/Godeps/_workspace/src/github.com/go-xorm/core/db_test.go @@ -24,7 +24,7 @@ type User struct { Age float32 Alias string NickName string - Created time.Time + Created NullTime } func init() { @@ -85,7 +85,7 @@ func BenchmarkOriQuery(b *testing.B) { var Id int64 var Name, Title, Alias, NickName string var Age float32 - var Created time.Time + var Created NullTime err = rows.Scan(&Id, &Name, &Title, &Age, &Alias, &NickName, &Created) if err != nil { b.Error(err) @@ -600,7 +600,7 @@ func TestExecStruct(t *testing.T) { Age: 1.2, Alias: "lunny", NickName: "lunny xiao", - Created: time.Now(), + Created: NullTime(time.Now()), } _, err = db.ExecStruct("insert into user (`name`, title, age, alias, nick_name,created) "+ @@ -645,7 +645,7 @@ func BenchmarkExecStruct(b *testing.B) { Age: 1.2, Alias: "lunny", NickName: "lunny xiao", - Created: time.Now(), + Created: NullTime(time.Now()), } for i := 0; i < b.N; i++ { diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/dialect.go b/Godeps/_workspace/src/github.com/go-xorm/core/dialect.go index 43a22670913..45bc5c20883 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/core/dialect.go +++ b/Godeps/_workspace/src/github.com/go-xorm/core/dialect.go @@ -54,7 +54,7 @@ type Dialect interface { IndexCheckSql(tableName, idxName string) (string, []interface{}) TableCheckSql(tableName string) (string, []interface{}) - IsColumnExist(tableName string, col *Column) (bool, error) + IsColumnExist(tableName string, colName string) (bool, error) CreateTableSql(table *Table, tableName, storeEngine, charset string) string DropTableSql(tableName string) string @@ -63,6 +63,8 @@ type Dialect interface { ModifyColumnSql(tableName string, col *Column) string + ForUpdateSql(query string) string + //CreateTableIfNotExists(table *Table, tableName, storeEngine, charset string) error //MustDropTable(tableName string) error @@ -164,10 +166,10 @@ func (db *Base) HasRecords(query string, args ...interface{}) (bool, error) { return false, nil } -func (db *Base) IsColumnExist(tableName string, col *Column) (bool, error) { +func (db *Base) IsColumnExist(tableName, colName string) (bool, error) { query := "SELECT `COLUMN_NAME` FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = ? AND `TABLE_NAME` = ? AND `COLUMN_NAME` = ?" query = strings.Replace(query, "`", db.dialect.QuoteStr(), -1) - return db.HasRecords(query, db.DbName, tableName, col.Name) + return db.HasRecords(query, db.DbName, tableName, colName) } /* @@ -229,28 +231,33 @@ func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset stri tableName = table.Name } - sql += b.dialect.Quote(tableName) + " (" + sql += b.dialect.Quote(tableName) + sql += " (" - pkList := table.PrimaryKeys + if len(table.ColumnsSeq()) > 0 { + pkList := table.PrimaryKeys - for _, colName := range table.ColumnsSeq() { - col := table.GetColumn(colName) - if col.IsPrimaryKey && len(pkList) == 1 { - sql += col.String(b.dialect) - } else { - sql += col.StringNoPk(b.dialect) + for _, colName := range table.ColumnsSeq() { + col := table.GetColumn(colName) + if col.IsPrimaryKey && len(pkList) == 1 { + sql += col.String(b.dialect) + } else { + sql += col.StringNoPk(b.dialect) + } + sql = strings.TrimSpace(sql) + sql += ", " } - sql = strings.TrimSpace(sql) - sql += ", " - } - if len(pkList) > 1 { - sql += "PRIMARY KEY ( " - sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(","))) - sql += " ), " - } + if len(pkList) > 1 { + sql += "PRIMARY KEY ( " + sql += b.dialect.Quote(strings.Join(pkList, b.dialect.Quote(","))) + sql += " ), " + } + + sql = sql[:len(sql)-2] + } + sql += ")" - sql = sql[:len(sql)-2] + ")" if b.dialect.SupportEngine() && storeEngine != "" { sql += " ENGINE=" + storeEngine } @@ -262,21 +269,25 @@ func (b *Base) CreateTableSql(table *Table, tableName, storeEngine, charset stri sql += " DEFAULT CHARSET " + charset } } - sql += ";" + return sql } +func (b *Base) ForUpdateSql(query string) string { + return query + " FOR UPDATE" +} + var ( - dialects = map[DbType]Dialect{} + dialects = map[DbType]func() Dialect{} ) -func RegisterDialect(dbName DbType, dialect Dialect) { - if dialect == nil { +func RegisterDialect(dbName DbType, dialectFunc func() Dialect) { + if dialectFunc == nil { panic("core: Register dialect is nil") } - dialects[dbName] = dialect // !nashtsai! allow override dialect + dialects[dbName] = dialectFunc // !nashtsai! allow override dialect } func QueryDialect(dbName DbType) Dialect { - return dialects[dbName] + return dialects[dbName]() } diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/error.go b/Godeps/_workspace/src/github.com/go-xorm/core/error.go index 414ba037d1e..13c179251d9 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/core/error.go +++ b/Godeps/_workspace/src/github.com/go-xorm/core/error.go @@ -4,7 +4,7 @@ import "errors" var ( ErrNoMapPointer = errors.New("mp should be a map's pointer") - ErrNoStructPointer = errors.New("mp should be a map's pointer") + ErrNoStructPointer = errors.New("mp should be a struct's pointer") //ErrNotExist = errors.New("Not exist") //ErrIgnore = errors.New("Ignore") ) diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/scan.go b/Godeps/_workspace/src/github.com/go-xorm/core/scan.go new file mode 100644 index 00000000000..7da338d8645 --- /dev/null +++ b/Godeps/_workspace/src/github.com/go-xorm/core/scan.go @@ -0,0 +1,52 @@ +package core + +import ( + "database/sql/driver" + "fmt" + "time" +) + +type NullTime time.Time + +var ( + _ driver.Valuer = NullTime{} +) + +func (ns *NullTime) Scan(value interface{}) error { + if value == nil { + return nil + } + return convertTime(ns, value) +} + +// Value implements the driver Valuer interface. +func (ns NullTime) Value() (driver.Value, error) { + if (time.Time)(ns).IsZero() { + return nil, nil + } + return (time.Time)(ns).Format("2006-01-02 15:04:05"), nil +} + +func convertTime(dest *NullTime, src interface{}) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + t, err := time.Parse("2006-01-02 15:04:05", s) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case []uint8: + t, err := time.Parse("2006-01-02 15:04:05", string(s)) + if err != nil { + return err + } + *dest = NullTime(t) + return nil + case nil: + default: + return fmt.Errorf("unsupported driver -> Scan pair: %T -> %T", src, dest) + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/go-xorm/core/type.go b/Godeps/_workspace/src/github.com/go-xorm/core/type.go index 73b9921ee63..c7c4d7bd08f 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/core/type.go +++ b/Godeps/_workspace/src/github.com/go-xorm/core/type.go @@ -53,6 +53,10 @@ func (s *SQLType) IsNumeric() bool { return s.IsType(NUMERIC_TYPE) } +func (s *SQLType) IsJson() bool { + return s.Name == Json +} + var ( Bit = "BIT" TinyInt = "TINYINT" @@ -101,6 +105,8 @@ var ( Serial = "SERIAL" BigSerial = "BIGSERIAL" + Json = "JSON" + SqlTypes = map[string]int{ Bit: NUMERIC_TYPE, TinyInt: NUMERIC_TYPE, @@ -112,6 +118,7 @@ var ( Enum: TEXT_TYPE, Set: TEXT_TYPE, + Json: TEXT_TYPE, Char: TEXT_TYPE, Varchar: TEXT_TYPE, @@ -229,7 +236,6 @@ var ( ) func Type2SQLType(t reflect.Type) (st SQLType) { - switch k := t.Kind(); k { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: st = SQLType{Int, 0, 0} @@ -252,7 +258,7 @@ func Type2SQLType(t reflect.Type) (st SQLType) { case reflect.String: st = SQLType{Varchar, 255, 0} case reflect.Struct: - if t.ConvertibleTo(reflect.TypeOf(c_TIME_DEFAULT)) { + if t.ConvertibleTo(TimeType) { st = SQLType{DateTime, 0, 0} } else { // TODO need to handle association struct diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/LICENSE b/Godeps/_workspace/src/github.com/go-xorm/xorm/LICENSE index 9ac0c261f59..84d2ae5386d 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/LICENSE +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013 - 2014 +Copyright (c) 2013 - 2015 The Xorm Authors All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/README.md b/Godeps/_workspace/src/github.com/go-xorm/xorm/README.md index fe8aca3c374..4b4685ea72b 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/README.md +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/README.md @@ -2,6 +2,8 @@ Xorm is a simple and powerful ORM for Go. +[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) + [![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/lunny/xorm/trend.png)](https://bitdeli.com/free "Bitdeli Badge") # Features @@ -9,7 +11,7 @@ Xorm is a simple and powerful ORM for Go. * Struct <-> Table Mapping Support * Chainable APIs - + * Transaction Support * Both ORM and raw SQL operation Support @@ -33,38 +35,44 @@ Drivers for Go's sql package which currently support database/sql includes: * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) -* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - * Postgres: [github.com/lib/pq](https://github.com/lib/pq) +* Tidb: [github.com/pingcap/tidb](https://github.com/pingcap/tidb) + +* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + * MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) * MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) +* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (experiment) +* ql: [github.com/cznic/ql](https://github.com/cznic/ql) (experiment) # Changelog -* **v0.4.1** - Features: - * Add deleted xorm tag for soft delete and add unscoped +* **v0.4.4** + * ql database expriment support + * tidb database expriment support + * sql.NullString and etc. field support + * select ForUpdate support + * many bugs fixed -* **v0.4.0 RC1** - Changes: - * moved xorm cmd to [github.com/go-xorm/cmd](github.com/go-xorm/cmd) - * refactored general DB operation a core lib at [github.com/go-xorm/core](https://github.com/go-xorm/core) - * moved tests to github.com/go-xorm/tests [github.com/go-xorm/tests](github.com/go-xorm/tests) +* **v0.4.3** + * Json column type support + * oracle expirement support + * bug fixed - Improvements: - * Prepared statement cache - * Add Incr API - * Specify Timezone Location +* **v0.4.2** + * Transaction will auto rollback if not Rollback or Commit be called. + * Gonic Mapper support + * bug fixed -[More changelogs ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-15) +[More changelogs ...](https://github.com/go-xorm/manual-en-US/tree/master/chapter-16) # Installation -If you have [gopm](https://github.com/gpmgo/gopm) installed, +If you have [gopm](https://github.com/gpmgo/gopm) installed, gopm get github.com/go-xorm/xorm @@ -80,8 +88,152 @@ Or * [GoWalker](http://gowalker.org/github.com/go-xorm/xorm) +# Quick Start + +* Create Engine + +```Go +engine, err := xorm.NewEngine(driverName, dataSourceName) +``` + +* Define a struct and Sync2 table struct to database + +```Go +type User struct { + Id int64 + Name string + Salt string + Age int + Passwd string `xorm:"varchar(200)"` + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` +} + +err := engine.Sync2(new(User)) +``` + +* Query a SQL string, the returned results is []map[string][]byte + +```Go +results, err := engine.Query("select * from user") +``` + +* Execute a SQL string, the returned results + +```Go +affected, err := engine.Exec("update user set age = ? where name = ?", age, name) +``` + +* Insert one or multipe records to database + +```Go +affected, err := engine.Insert(&user) +// INSERT INTO struct () values () +affected, err := engine.Insert(&user1, &user2) +// INSERT INTO struct1 () values () +// INSERT INTO struct2 () values () +affected, err := engine.Insert(&users) +// INSERT INTO struct () values (),(),() +affected, err := engine.Insert(&user1, &users) +// INSERT INTO struct1 () values () +// INSERT INTO struct2 () values (),(),() +``` + +* Query one record from database + +```Go +has, err := engine.Get(&user) +// SELECT * FROM user LIMIT 1 +has, err := engine.Where("name = ?", name).Desc("id").Get(&user) +// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 +``` + +* Query multiple records from database, also you can use join and extends + +```Go +var users []User +err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users) +// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10 + +type Detail struct { + Id int64 + UserId int64 `xorm:"index"` +} + +type UserDetail struct { + User `xorm:"extends"` + Detail `xorm:"extends"` +} + +var users []UserDetail +err := engine.Table("user").Select("user.*, detail.*") + Join("INNER", "detail", "detail.user_id = user.id"). + Where("user.name = ?", name).Limit(10, 0). + Find(&users) +// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10 +``` + +* Query multiple records and record by record handle, there two methods Iterate and Rows + +```Go +err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error { + user := bean.(*User) + return nil +}) +// SELECT * FROM user + +rows, err := engine.Rows(&User{Name:name}) +// SELECT * FROM user +defer rows.Close() +bean := new(Struct) +for rows.Next() { + err = rows.Scan(bean) +} +``` + +* Update one or more records, default will update non-empty and non-zero fields except to use Cols, AllCols and etc. + +```Go +affected, err := engine.Id(1).Update(&user) +// UPDATE user SET ... Where id = ? + +affected, err := engine.Update(&user, &User{Name:name}) +// UPDATE user SET ... Where name = ? + +var ids = []int64{1, 2, 3} +affected, err := engine.In(ids).Update(&user) +// UPDATE user SET ... Where id IN (?, ?, ?) + +// force update indicated columns by Cols +affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12}) +// UPDATE user SET age = ?, updated=? Where id = ? + +// force NOT update indicated columns by Omit +affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12}) +// UPDATE user SET age = ?, updated=? Where id = ? + +affected, err := engine.Id(1).AllCols().Update(&user) +// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? +``` + +* Delete one or more records, Delete MUST has conditon + +```Go +affected, err := engine.Where(...).Delete(&user) +// DELETE FROM user Where ... +``` + +* Count records + +```Go +counts, err := engine.Count(&user) +// SELECT count(*) AS total FROM user +``` + # Cases +* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) + * [Wego](http://github.com/go-tango/wego) * [Docker.cn](https://docker.cn/) diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/README_CN.md b/Godeps/_workspace/src/github.com/go-xorm/xorm/README_CN.md index 5a167f9b148..4466bc707cd 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/README_CN.md +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/README_CN.md @@ -4,6 +4,8 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作非常简便。 +[![Gitter](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/go-xorm/xorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) + [![Build Status](https://drone.io/github.com/go-xorm/tests/status.png)](https://drone.io/github.com/go-xorm/tests/latest) [![Go Walker](http://gowalker.org/api/v1/badge)](http://gowalker.org/github.com/go-xorm/xorm) ## 特性 @@ -18,7 +20,7 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * 支持使用Id, In, Where, Limit, Join, Having, Table, Sql, Cols等函数和结构体等方式作为条件 -* 支持级联加载Struct +* 支持级联加载Struct * 支持缓存 @@ -34,29 +36,42 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * MyMysql: [github.com/ziutek/mymysql/godrv](https://github.com/ziutek/mymysql/godrv) -* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) - * Postgres: [github.com/lib/pq](https://github.com/lib/pq) +* Tidb: [github.com/pingcap/tidb](https://github.com/pingcap/tidb) + +* SQLite: [github.com/mattn/go-sqlite3](https://github.com/mattn/go-sqlite3) + * MsSql: [github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) * MsSql: [github.com/lunny/godbc](https://github.com/lunny/godbc) +* Oracle: [github.com/mattn/go-oci8](https://github.com/mattn/go-oci8) (试验性支持) + +* ql: [github.com/cznic/ql](https://github.com/cznic/ql) (试验性支持) + ## 更新日志 -* **v0.4.2** - 新特性: - * deleted标记 - * bug fixed +* **v0.4.4** + * Tidb 数据库支持 + * QL 试验性支持 + * sql.NullString支持 + * ForUpdate 支持 + * bug修正 + +* **v0.4.3** + * Json 字段类型支持 + * oracle实验性支持 + * bug修正 [更多更新日志...](https://github.com/go-xorm/manual-zh-CN/tree/master/chapter-16) ## 安装 -推荐使用 [gopm](https://github.com/gpmgo/gopm) 进行安装: +推荐使用 [gopm](https://github.com/gpmgo/gopm) 进行安装: gopm get github.com/go-xorm/xorm - + 或者您也可以使用go工具进行安装: go get github.com/go-xorm/xorm @@ -69,8 +84,151 @@ xorm是一个简单而强大的Go语言ORM库. 通过它可以使数据库操作 * [Godoc代码文档](http://godoc.org/github.com/go-xorm/xorm) +# 快速开始 -## 案例 +* 第一步创建引擎,driverName, dataSourceName和database/sql接口相同 + +```Go +engine, err := xorm.NewEngine(driverName, dataSourceName) +``` + +* 定义一个和表同步的结构体,并且自动同步结构体到数据库 + +```Go +type User struct { + Id int64 + Name string + Salt string + Age int + Passwd string `xorm:"varchar(200)"` + Created time.Time `xorm:"created"` + Updated time.Time `xorm:"updated"` +} + +err := engine.Sync2(new(User)) +``` + +* 最原始的也支持SQL语句查询,返回的结果类型为 []map[string][]byte + +```Go +results, err := engine.Query("select * from user") +``` + +* 执行一个SQL语句 + +```Go +affected, err := engine.Exec("update user set age = ? where name = ?", age, name) +``` + +* 插入一条或者多条记录 + +```Go +affected, err := engine.Insert(&user) +// INSERT INTO struct () values () +affected, err := engine.Insert(&user1, &user2) +// INSERT INTO struct1 () values () +// INSERT INTO struct2 () values () +affected, err := engine.Insert(&users) +// INSERT INTO struct () values (),(),() +affected, err := engine.Insert(&user1, &users) +// INSERT INTO struct1 () values () +// INSERT INTO struct2 () values (),(),() +``` + +* 查询单条记录 + +```Go +has, err := engine.Get(&user) +// SELECT * FROM user LIMIT 1 +has, err := engine.Where("name = ?", name).Desc("id").Get(&user) +// SELECT * FROM user WHERE name = ? ORDER BY id DESC LIMIT 1 +``` + +* 查询多条记录,当然可以使用Join和extends来组合使用 + +```Go +var users []User +err := engine.Where("name = ?", name).And("age > 10").Limit(10, 0).Find(&users) +// SELECT * FROM user WHERE name = ? AND age > 10 limit 0 offset 10 + +type Detail struct { + Id int64 + UserId int64 `xorm:"index"` +} + +type UserDetail struct { + User `xorm:"extends"` + Detail `xorm:"extends"` +} + +var users []UserDetail +err := engine.Table("user").Select("user.*, detail.*") + Join("INNER", "detail", "detail.user_id = user.id"). + Where("user.name = ?", name).Limit(10, 0). + Find(&users) +// SELECT user.*, detail.* FROM user INNER JOIN detail WHERE user.name = ? limit 0 offset 10 +``` + +* 根据条件遍历数据库,可以有两种方式: Iterate and Rows + +```Go +err := engine.Iterate(&User{Name:name}, func(idx int, bean interface{}) error { + user := bean.(*User) + return nil +}) +// SELECT * FROM user + +rows, err := engine.Rows(&User{Name:name}) +// SELECT * FROM user +defer rows.Close() +bean := new(Struct) +for rows.Next() { + err = rows.Scan(bean) +} +``` + +* 更新数据,除非使用Cols,AllCols函数指明,默认只更新非空和非0的字段 + +```Go +affected, err := engine.Id(1).Update(&user) +// UPDATE user SET ... Where id = ? + +affected, err := engine.Update(&user, &User{Name:name}) +// UPDATE user SET ... Where name = ? + +var ids = []int64{1, 2, 3} +affected, err := engine.In(ids).Update(&user) +// UPDATE user SET ... Where id IN (?, ?, ?) + +// force update indicated columns by Cols +affected, err := engine.Id(1).Cols("age").Update(&User{Name:name, Age: 12}) +// UPDATE user SET age = ?, updated=? Where id = ? + +// force NOT update indicated columns by Omit +affected, err := engine.Id(1).Omit("name").Update(&User{Name:name, Age: 12}) +// UPDATE user SET age = ?, updated=? Where id = ? + +affected, err := engine.Id(1).AllCols().Update(&user) +// UPDATE user SET name=?,age=?,salt=?,passwd=?,updated=? Where id = ? +``` + +* 删除记录,需要注意,删除必须至少有一个条件,否则会报错。要清空数据库可以用EmptyTable + +```Go +affected, err := engine.Where(...).Delete(&user) +// DELETE FROM user Where ... +``` + +* 获取记录条数 + +```Go +counts, err := engine.Count(&user) +// SELECT count(*) AS total FROM user +``` + +# 案例 + +* [github.com/m3ng9i/qreader](https://github.com/m3ng9i/qreader) * [Wego](http://github.com/go-tango/wego) diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/VERSION b/Godeps/_workspace/src/github.com/go-xorm/xorm/VERSION index 4e64e0e9e47..d1a98fc7b2e 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/VERSION +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/VERSION @@ -1 +1 @@ -xorm v0.4.2.0225 +xorm v0.4.5.0204 diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/doc.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/doc.go index 722088ca775..54ce2b985b7 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/doc.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/doc.go @@ -1,4 +1,4 @@ -// Copyright 2013 - 2014 The XORM Authors. All rights reserved. +// Copyright 2013 - 2015 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD // license that can be found in the LICENSE file. diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/engine.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/engine.go index 8f0c805dc34..afb38766db3 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/engine.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/engine.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( @@ -316,6 +320,12 @@ func (engine *Engine) NoAutoTime() *Session { return session.NoAutoTime() } +func (engine *Engine) NoAutoCondition(no ...bool) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.NoAutoCondition(no...) +} + // Retrieve all tables, columns, indexes' informations from database. func (engine *Engine) DBMetas() ([]*core.Table, error) { tables, err := engine.dialect.GetTables() @@ -373,13 +383,25 @@ func (engine *Engine) DumpAll(w io.Writer) error { return err } - for _, table := range tables { - _, err = io.WriteString(w, engine.dialect.CreateTableSql(table, "", table.StoreEngine, "")+"\n\n") + _, err = io.WriteString(w, fmt.Sprintf("/*Generated by xorm v%s %s*/\n\n", + Version, time.Now().In(engine.TZLocation).Format("2006-01-02 15:04:05"))) + if err != nil { + return err + } + + for i, table := range tables { + if i > 0 { + _, err = io.WriteString(w, "\n") + if err != nil { + return err + } + } + _, err = io.WriteString(w, engine.dialect.CreateTableSql(table, "", table.StoreEngine, "")+";\n") if err != nil { return err } for _, index := range table.Indexes { - _, err = io.WriteString(w, engine.dialect.CreateIndexSql(table.Name, index)+"\n\n") + _, err = io.WriteString(w, engine.dialect.CreateIndexSql(table.Name, index)+";\n") if err != nil { return err } @@ -439,7 +461,7 @@ func (engine *Engine) DumpAll(w io.Writer) error { } } } - _, err = io.WriteString(w, temp[2:]+");\n\n") + _, err = io.WriteString(w, temp[2:]+");\n") if err != nil { return err } @@ -506,6 +528,12 @@ func (engine *Engine) Distinct(columns ...string) *Session { return session.Distinct(columns...) } +func (engine *Engine) Select(str string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Select(str) +} + // only use the paramters as select or update columns func (engine *Engine) Cols(columns ...string) *Session { session := engine.NewSession() @@ -543,6 +571,13 @@ func (engine *Engine) Omit(columns ...string) *Session { return session.Omit(columns...) } +// Set null when column is zero-value and nullable for update +func (engine *Engine) Nullable(columns ...string) *Session { + session := engine.NewSession() + session.IsAutoClose = true + return session.Nullable(columns...) +} + // This method will generate "column IN (?, ?)" func (engine *Engine) In(column string, args ...interface{}) *Session { session := engine.NewSession() @@ -642,20 +677,20 @@ func (engine *Engine) Having(conditions string) *Session { func (engine *Engine) autoMapType(v reflect.Value) *core.Table { t := v.Type() - engine.mutex.RLock() + engine.mutex.Lock() table, ok := engine.Tables[t] - engine.mutex.RUnlock() if !ok { table = engine.mapType(v) - engine.mutex.Lock() engine.Tables[t] = table - if v.CanAddr() { - engine.GobRegister(v.Addr().Interface()) - } else { - engine.GobRegister(v.Interface()) + if engine.Cacher != nil { + if v.CanAddr() { + engine.GobRegister(v.Addr().Interface()) + } else { + engine.GobRegister(v.Interface()) + } } - engine.mutex.Unlock() } + engine.mutex.Unlock() return table } @@ -691,33 +726,24 @@ func (engine *Engine) newTable() *core.Table { return table } +type TableName interface { + TableName() string +} + func (engine *Engine) mapType(v reflect.Value) *core.Table { t := v.Type() table := engine.newTable() - method := v.MethodByName("TableName") - if !method.IsValid() { - if v.CanAddr() { - method = v.Addr().MethodByName("TableName") - } - } - if method.IsValid() { - params := []reflect.Value{} - results := method.Call(params) - if len(results) == 1 { - table.Name = results[0].Interface().(string) - } - } - - if table.Name == "" { + if tb, ok := v.Interface().(TableName); ok { + table.Name = tb.TableName() + } else { table.Name = engine.TableMapper.Obj2Table(t.Name()) } + table.Type = t var idFieldColName string var err error - - hasCacheTag := false - hasNoCacheTag := false + var hasCacheTag, hasNoCacheTag bool for i := 0; i < t.NumField(); i++ { tag := t.Field(i).Tag @@ -730,7 +756,7 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { if ormTagStr != "" { col = &core.Column{FieldName: t.Field(i).Name, Nullable: true, IsPrimaryKey: false, IsAutoIncrement: false, MapType: core.TWOSIDES, Indexes: make(map[string]bool)} - tags := strings.Split(ormTagStr, " ") + tags := splitTag(ormTagStr) if len(tags) > 0 { if tags[0] == "-" { @@ -804,6 +830,16 @@ func (engine *Engine) mapType(v reflect.Value) *core.Table { case k == "VERSION": col.IsVersion = true col.Default = "1" + case k == "UTC": + col.TimeZone = time.UTC + case k == "LOCAL": + col.TimeZone = time.Local + case strings.HasPrefix(k, "LOCALE(") && strings.HasSuffix(k, ")"): + location := k[len("INDEX")+1 : len(k)-1] + col.TimeZone, err = time.LoadLocation(location) + if err != nil { + engine.LogError(err) + } case k == "UPDATED": col.IsUpdated = true case k == "DELETED": @@ -1112,7 +1148,7 @@ func (engine *Engine) Sync(beans ...interface{}) error { session := engine.NewSession() session.Statement.RefTable = table defer session.Close() - isExist, err := session.Engine.dialect.IsColumnExist(table.Name, col) + isExist, err := session.Engine.dialect.IsColumnExist(table.Name, col.Name) if err != nil { return err } @@ -1387,7 +1423,6 @@ var ( ) func (engine *Engine) TZTime(t time.Time) time.Time { - if NULL_TIME != t { // if time is not initialized it's not suitable for Time.In() return t.In(engine.TZLocation) } @@ -1405,35 +1440,51 @@ func (engine *Engine) NowTime2(sqlTypeName string) (interface{}, time.Time) { } func (engine *Engine) FormatTime(sqlTypeName string, t time.Time) (v interface{}) { + return engine.formatTime(engine.TZLocation, sqlTypeName, t) +} + +func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{}) { + if col.DisableTimeZone { + return engine.formatTime(nil, col.SQLType.Name, t) + } else if col.TimeZone != nil { + return engine.formatTime(col.TimeZone, col.SQLType.Name, t) + } + return engine.formatTime(engine.TZLocation, col.SQLType.Name, t) +} + +func (engine *Engine) formatTime(tz *time.Location, sqlTypeName string, t time.Time) (v interface{}) { if engine.dialect.DBType() == core.ORACLE { return t } + if tz != nil { + t = engine.TZTime(t) + } switch sqlTypeName { case core.Time: - s := engine.TZTime(t).Format("2006-01-02 15:04:05") //time.RFC3339 + s := t.Format("2006-01-02 15:04:05") //time.RFC3339 v = s[11:19] case core.Date: - v = engine.TZTime(t).Format("2006-01-02") + v = t.Format("2006-01-02") case core.DateTime, core.TimeStamp: if engine.dialect.DBType() == "ql" { - v = engine.TZTime(t) + v = t } else if engine.dialect.DBType() == "sqlite3" { - v = engine.TZTime(t).UTC().Format("2006-01-02 15:04:05") + v = t.UTC().Format("2006-01-02 15:04:05") } else { - v = engine.TZTime(t).Format("2006-01-02 15:04:05") + v = t.Format("2006-01-02 15:04:05") } case core.TimeStampz: if engine.dialect.DBType() == core.MSSQL { - v = engine.TZTime(t).Format("2006-01-02T15:04:05.9999999Z07:00") + v = t.Format("2006-01-02T15:04:05.9999999Z07:00") } else if engine.DriverName() == "mssql" { - v = engine.TZTime(t) + v = t } else { - v = engine.TZTime(t).Format(time.RFC3339Nano) + v = t.Format(time.RFC3339Nano) } case core.BigInt, core.Int: - v = engine.TZTime(t).Unix() + v = t.Unix() default: - v = engine.TZTime(t) + v = t } return } diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/error.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/error.go index c868173f708..61537a34626 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/error.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/error.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/examples/goroutine.db-journal b/Godeps/_workspace/src/github.com/go-xorm/xorm/examples/goroutine.db-journal deleted file mode 100644 index 95fc238e2941a7bdc3c6c855336d5289be8828d5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2576 zcmeI!J5K^Z6b0bn99-AODvBt+@r`d(NGvTWEw`jGp|BtrBPAuBlyp*3Qc_Ykbsk+iI!eJoV0w9_wr7r 0 ORDER BY f.attnum;` +WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.attnum > 0 ORDER BY f.attnum;` rows, err := db.DB().Query(s, args...) if db.Logger != nil { diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/pq_driver.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/pq_driver.go index c8dd5aa009c..a4e269756b6 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/pq_driver.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/pq_driver.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( @@ -37,7 +41,7 @@ func parseURL(connstr string) (string, error) { return "", err } - if u.Scheme != "postgres" { + if u.Scheme != "postgresql" && u.Scheme != "postgres" { return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) } @@ -99,7 +103,7 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { db := &core.Uri{DbType: core.POSTGRES} o := make(values) var err error - if strings.HasPrefix(dataSourceName, "postgres://") { + if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { dataSourceName, err = parseURL(dataSourceName) if err != nil { return nil, err diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/processors.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/processors.go index 03ae8e0f80c..8f95ae3be73 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/processors.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/processors.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm // Executed before an object is initially persisted to the database @@ -19,6 +23,10 @@ type BeforeSetProcessor interface { BeforeSet(string, Cell) } +type AfterSetProcessor interface { + AfterSet(string, Cell) +} + // !nashtsai! TODO enable BeforeValidateProcessor when xorm start to support validations //// Executed before an object is validated //type BeforeValidateProcessor interface { diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/rows.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/rows.go index 0def55757c8..fb18454d1dc 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/rows.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/rows.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( @@ -41,7 +45,7 @@ func newRows(session *Session, bean interface{}) (*Rows, error) { sqlStr = filter.Do(sqlStr, session.Engine.dialect, rows.session.Statement.RefTable) } - rows.session.Engine.logSQL(sqlStr, args) + rows.session.saveLastSQL(sqlStr, args) var err error rows.stmt, err = rows.session.DB().Prepare(sqlStr) if err != nil { diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/session.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/session.go index 0d11d99fd0f..b7ce7ff6498 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/session.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/session.go @@ -1,7 +1,12 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( "database/sql" + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -39,18 +44,25 @@ type Session struct { beforeClosures []func(interface{}) afterClosures []func(interface{}) + prepareStmt bool stmtCache map[uint32]*core.Stmt //key: hash.Hash32 of (queryStr, len(queryStr)) cascadeDeep int + + // !evalphobia! stored the last executed query on this session + //beforeSQLExec func(string, ...interface{}) + lastSQL string + lastSQLArgs []interface{} } // Method Init reset the session as the init status. func (session *Session) Init() { - session.Statement = Statement{Engine: session.Engine} session.Statement.Init() + session.Statement.Engine = session.Engine session.IsAutoCommit = true session.IsCommitedOrRollbacked = false session.IsAutoClose = false session.AutoResetStatement = true + session.prepareStmt = false // !nashtsai! is lazy init better? session.afterInsertBeans = make(map[interface{}]*[]func(interface{}), 0) @@ -58,6 +70,9 @@ func (session *Session) Init() { session.afterDeleteBeans = make(map[interface{}]*[]func(interface{}), 0) session.beforeClosures = make([]func(interface{}), 0) session.afterClosures = make([]func(interface{}), 0) + + session.lastSQL = "" + session.lastSQLArgs = []interface{}{} } // Method Close release the connection from pool @@ -67,11 +82,15 @@ func (session *Session) Close() { } if session.db != nil { - //session.Engine.Pool.ReleaseDB(session.Engine, session.Db) - session.db = nil + // When Close be called, if session is a transaction and do not call + // Commit or Rollback, then call Rollback. + if session.Tx != nil && !session.IsCommitedOrRollbacked { + session.Rollback() + } session.Tx = nil session.stmtCache = nil session.Init() + session.db = nil } } @@ -81,6 +100,12 @@ func (session *Session) resetStatement() { } } +// Prepare +func (session *Session) Prepare() *Session { + session.prepareStmt = true + return session +} + // Method Sql provides raw sql input parameter. When you have a complex SQL statement // and cannot use Where, Id, In and etc. Methods to describe, you can use Sql. func (session *Session) Sql(querystring string, args ...interface{}) *Session { @@ -164,6 +189,12 @@ func (session *Session) SetExpr(column string, expression string) *Session { return session } +// Method Cols provides some columns to special +func (session *Session) Select(str string) *Session { + session.Statement.Select(str) + return session +} + // Method Cols provides some columns to special func (session *Session) Cols(columns ...string) *Session { session.Statement.Cols(columns...) @@ -203,12 +234,24 @@ func (session *Session) Distinct(columns ...string) *Session { return session } +// Set Read/Write locking for UPDATE +func (session *Session) ForUpdate() *Session { + session.Statement.IsForUpdate = true + return session +} + // Only not use the paramters as select or update columns func (session *Session) Omit(columns ...string) *Session { session.Statement.Omit(columns...) return session } +// Set null when column is zero-value and nullable for update +func (session *Session) Nullable(columns ...string) *Session { + session.Statement.Nullable(columns...) + return session +} + // Method NoAutoTime means do not automatically give created field and updated field // the current time on the current session temporarily func (session *Session) NoAutoTime() *Session { @@ -216,6 +259,11 @@ func (session *Session) NoAutoTime() *Session { return session } +func (session *Session) NoAutoCondition(no ...bool) *Session { + session.Statement.NoAutoCondition(no...) + return session +} + // Method Limit provide limit and offset query condition func (session *Session) Limit(limit int, start ...int) *Session { session.Statement.Limit(limit, start...) @@ -304,8 +352,7 @@ func (session *Session) Begin() error { session.IsAutoCommit = false session.IsCommitedOrRollbacked = false session.Tx = tx - - session.Engine.logSQL("BEGIN TRANSACTION") + session.saveLastSQL("BEGIN TRANSACTION") } return nil } @@ -313,7 +360,7 @@ func (session *Session) Begin() error { // When using transaction, you can rollback if any error func (session *Session) Rollback() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.logSQL(session.Engine.dialect.RollBackStr()) + session.saveLastSQL(session.Engine.dialect.RollBackStr()) session.IsCommitedOrRollbacked = true return session.Tx.Rollback() } @@ -323,7 +370,7 @@ func (session *Session) Rollback() error { // When using transaction, Commit will commit all operations. func (session *Session) Commit() error { if !session.IsAutoCommit && !session.IsCommitedOrRollbacked { - session.Engine.logSQL("COMMIT") + session.saveLastSQL("COMMIT") session.IsCommitedOrRollbacked = true var err error if err = session.Tx.Commit(); err == nil { @@ -426,17 +473,20 @@ func (session *Session) scanMapIntoStruct(obj interface{}, objMap map[string][]b //Execute sql func (session *Session) innerExec(sqlStr string, args ...interface{}) (sql.Result, error) { - stmt, err := session.doPrepare(sqlStr) - if err != nil { - return nil, err - } - //defer stmt.Close() + if session.prepareStmt { + stmt, err := session.doPrepare(sqlStr) + if err != nil { + return nil, err + } - res, err := stmt.Exec(args...) - if err != nil { - return nil, err + res, err := stmt.Exec(args...) + if err != nil { + return nil, err + } + return res, nil } - return res, nil + + return session.DB().Exec(sqlStr, args...) } func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, error) { @@ -444,7 +494,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er sqlStr = filter.Do(sqlStr, session.Engine.dialect, session.Statement.RefTable) } - session.Engine.logSQL(sqlStr, args...) + session.saveLastSQL(sqlStr, args...) return session.Engine.LogSQLExecutionTime(sqlStr, args, func() (sql.Result, error) { if session.IsAutoCommit { @@ -587,11 +637,15 @@ func (session *Session) DropTable(beanOrTableName interface{}) error { return nil } -func (statement *Statement) JoinColumns(cols []*core.Column) string { +func (statement *Statement) JoinColumns(cols []*core.Column, includeTableName bool) string { var colnames = make([]string, len(cols)) for i, col := range cols { - colnames[i] = statement.Engine.Quote(statement.TableName()) + - "." + statement.Engine.Quote(col.Name) + if includeTableName { + colnames[i] = statement.Engine.Quote(statement.TableName()) + + "." + statement.Engine.Quote(col.Name) + } else { + colnames[i] = statement.Engine.Quote(col.Name) + } } return strings.Join(colnames, ", ") } @@ -603,21 +657,33 @@ func (statement *Statement) convertIdSql(sqlStr string) string { return "" } - colstrs := statement.JoinColumns(cols) - sqls := splitNNoCase(sqlStr, "from", 2) + colstrs := statement.JoinColumns(cols, false) + sqls := splitNNoCase(sqlStr, " from ", 2) if len(sqls) != 2 { return "" } + if statement.Engine.dialect.DBType() == "ql" { + return fmt.Sprintf("SELECT id() FROM %v", sqls[1]) + } return fmt.Sprintf("SELECT %s FROM %v", colstrs, sqls[1]) } return "" } -func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { - // if has no reftable, then don't use cache currently +func (session *Session) canCache() bool { if session.Statement.RefTable == nil || session.Statement.JoinStr != "" || - session.Statement.RawSQL != "" { + session.Statement.RawSQL != "" || + session.Tx != nil || + len(session.Statement.selectStr) > 0 { + return false + } + return true +} + +func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interface{}) (has bool, err error) { + // if has no reftable, then don't use cache currently + if !session.canCache() { return false, ErrCacheFailed } @@ -715,7 +781,7 @@ func (session *Session) cacheGet(bean interface{}, sqlStr string, args ...interf } func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr interface{}, args ...interface{}) (err error) { - if session.Statement.RefTable == nil || + if !session.canCache() || indexNoCase(sqlStr, "having") != -1 || indexNoCase(sqlStr, "group by") != -1 { return ErrCacheFailed @@ -859,7 +925,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in } temps[ididxes[sid]] = bean - session.Engine.LogDebug("[cacheFind] cache bean:", tableName, id, bean) + session.Engine.LogDebug("[cacheFind] cache bean:", tableName, id, bean, temps) cacher.PutBean(tableName, sid, bean) } } @@ -867,7 +933,7 @@ func (session *Session) cacheFind(t reflect.Type, sqlStr string, rowsSlicePtr in for j := 0; j < len(temps); j++ { bean := temps[j] if bean == nil { - session.Engine.LogWarn("[cacheFind] cache no hit:", tableName, ides[j]) + session.Engine.LogWarn("[cacheFind] cache no hit:", tableName, ids[j], temps) // return errors.New("cache error") // !nashtsai! no need to return error, but continue instead continue } @@ -992,12 +1058,16 @@ func (session *Session) Get(bean interface{}) (bool, error) { var err error session.queryPreprocess(&sqlStr, args...) if session.IsAutoCommit { - stmt, err := session.doPrepare(sqlStr) - if err != nil { - return false, err + if session.prepareStmt { + stmt, errPrepare := session.doPrepare(sqlStr) + if errPrepare != nil { + return false, errPrepare + } + // defer stmt.Close() // !nashtsai! don't close due to stmt is cached and bounded to this session + rawRows, err = stmt.Query(args...) + } else { + rawRows, err = session.DB().Query(sqlStr, args...) } - // defer stmt.Close() // !nashtsai! don't close due to stmt is cached and bounded to this session - rawRows, err = stmt.Query(args...) } else { rawRows, err = session.Tx.Query(sqlStr, args...) } @@ -1154,18 +1224,25 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) table = session.Statement.RefTable } - if len(condiBean) > 0 { - colNames, args := buildConditions(session.Engine, table, condiBean[0], true, true, - false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap) + var addedTableName = (len(session.Statement.JoinStr) > 0) + if !session.Statement.noAutoCondition && len(condiBean) > 0 { + colNames, args := session.Statement.buildConditions(table, condiBean[0], true, true, false, true, addedTableName) session.Statement.ConditionStr = strings.Join(colNames, " AND ") session.Statement.BeanArgs = args } else { // !oinume! Add " IS NULL" to WHERE whatever condiBean is given. // See https://github.com/go-xorm/xorm/issues/179 if col := table.DeletedColumn(); col != nil && !session.Statement.unscoped { // tag "deleted" is enabled - session.Statement.ConditionStr = fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00') ", - session.Engine.Quote(col.Name), session.Engine.Quote(col.Name)) + var colName string = session.Engine.Quote(col.Name) + if addedTableName { + var nm = session.Statement.TableName() + if len(session.Statement.TableAlias) > 0 { + nm = session.Statement.TableAlias + } + colName = session.Engine.Quote(nm) + "." + colName + } + session.Statement.ConditionStr = fmt.Sprintf("%v IS NULL OR %v = '0001-01-01 00:00:00'", + colName, colName) } } @@ -1173,20 +1250,24 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) var args []interface{} if session.Statement.RawSQL == "" { var columnStr string = session.Statement.ColumnStr - if session.Statement.JoinStr == "" { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = session.Statement.genColumnStr() - } - } + if len(session.Statement.selectStr) > 0 { + columnStr = session.Statement.selectStr } else { - if columnStr == "" { - if session.Statement.GroupByStr != "" { - columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) - } else { - columnStr = "*" + if session.Statement.JoinStr == "" { + if columnStr == "" { + if session.Statement.GroupByStr != "" { + columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) + } else { + columnStr = session.Statement.genColumnStr() + } + } + } else { + if columnStr == "" { + if session.Statement.GroupByStr != "" { + columnStr = session.Statement.Engine.Quote(strings.Replace(session.Statement.GroupByStr, ",", session.Engine.Quote(","), -1)) + } else { + columnStr = "*" + } } } } @@ -1227,11 +1308,15 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) session.queryPreprocess(&sqlStr, args...) if session.IsAutoCommit { - stmt, err = session.doPrepare(sqlStr) - if err != nil { - return err + if session.prepareStmt { + stmt, err = session.doPrepare(sqlStr) + if err != nil { + return err + } + rawRows, err = stmt.Query(args...) + } else { + rawRows, err = session.DB().Query(sqlStr, args...) } - rawRows, err = stmt.Query(args...) } else { rawRows, err = session.Tx.Query(sqlStr, args...) } @@ -1279,7 +1364,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } table := session.Engine.autoMapType(dataStruct) - return session.rows2Beans(rawRows, fields, fieldsCount, table, newElemFunc, sliceValueSetFunc) } else { resultsSlice, err := session.query(sqlStr, args...) @@ -1333,20 +1417,6 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) return nil } -// func (session *Session) queryRows(rawStmt **sql.Stmt, rawRows **sql.Rows, sqlStr string, args ...interface{}) error { -// var err error -// if session.IsAutoCommit { -// *rawStmt, err = session.doPrepare(sqlStr) -// if err != nil { -// return err -// } -// *rawRows, err = (*rawStmt).Query(args...) -// } else { -// *rawRows, err = session.Tx.Query(sqlStr, args...) -// } -// return err -// } - // Test if database is ok func (session *Session) Ping() error { defer session.resetStatement() @@ -1422,7 +1492,7 @@ func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 sql := fmt.Sprintf("select count(*) from %s", session.Engine.Quote(tableName)) err := session.DB().QueryRow(sql).Scan(&total) - session.Engine.logSQL(sql) + session.saveLastSQL(sql) if err != nil { return true, err } @@ -1550,7 +1620,6 @@ type Cell *interface{} func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount int, table *core.Table, newElemFunc func() reflect.Value, sliceValueSetFunc func(*reflect.Value)) error { - for rows.Next() { var newValue reflect.Value = newElemFunc() bean := newValue.Interface() @@ -1560,7 +1629,6 @@ func (session *Session) rows2Beans(rows *core.Rows, fields []string, fieldsCount return err } sliceValueSetFunc(&newValue) - } return nil } @@ -1591,6 +1659,14 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount } } + defer func() { + if b, hasAfterSet := bean.(AfterSetProcessor); hasAfterSet { + for ii, key := range fields { + b.AfterSet(key, Cell(scanResults[ii].(*interface{}))) + } + } + }() + var tempMap = make(map[string]int) for ii, key := range fields { var idx int @@ -1640,7 +1716,6 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount hasAssigned := false switch fieldType.Kind() { - case reflect.Complex64, reflect.Complex128: if rawValueType.Kind() == reflect.String { hasAssigned = true @@ -1651,6 +1726,15 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount return err } fieldValue.Set(x.Elem()) + } else if rawValueType.Kind() == reflect.Slice { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + session.Engine.LogError(err) + return err + } + fieldValue.Set(x.Elem()) } case reflect.Slice, reflect.Array: switch rawValueType.Kind() { @@ -1695,6 +1779,7 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount fieldValue.SetUint(uint64(vv.Int())) } case reflect.Struct: + col := table.GetColumn(key) if fieldType.ConvertibleTo(core.TimeType) { if rawValueType == core.TimeType { hasAssigned = true @@ -1702,12 +1787,16 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount t := vv.Convert(core.TimeType).Interface().(time.Time) z, _ := t.Zone() if len(z) == 0 || t.Year() == 0 { // !nashtsai! HACK tmp work around for lib/pq doesn't properly time with location - session.Engine.LogDebug("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) + session.Engine.LogDebugf("empty zone key[%v] : %v | zone: %v | location: %+v\n", key, t, z, *t.Location()) t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local) } // !nashtsai! convert to engine location - t = t.In(session.Engine.TZLocation) + if col.TimeZone == nil { + t = t.In(session.Engine.TZLocation) + } else { + t = t.In(col.TimeZone) + } fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) // t = fieldValue.Interface().(time.Time) @@ -1716,17 +1805,66 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount } else if rawValueType == core.IntType || rawValueType == core.Int64Type || rawValueType == core.Int32Type { hasAssigned = true - t := time.Unix(vv.Int(), 0).In(session.Engine.TZLocation) - vv = reflect.ValueOf(t) - fieldValue.Set(vv) + var tz *time.Location + if col.TimeZone == nil { + tz = session.Engine.TZLocation + } else { + tz = col.TimeZone + } + t := time.Unix(vv.Int(), 0).In(tz) + //vv = reflect.ValueOf(t) + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } else { + if d, ok := vv.Interface().([]uint8); ok { + hasAssigned = true + t, err := session.byte2Time(col, d) + //fmt.Println(string(d), t, err) + if err != nil { + session.Engine.LogError("byte2Time error:", err.Error()) + hasAssigned = false + } else { + fieldValue.Set(reflect.ValueOf(t).Convert(fieldType)) + } + } else { + panic(fmt.Sprintf("rawValueType is %v, value is %v", rawValueType, vv.Interface())) + } + } + } else if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + // !! 增加支持sql.Scanner接口的结构,如sql.NullString + hasAssigned = true + if err := nulVal.Scan(vv.Interface()); err != nil { + //fmt.Println("sql.Sanner error:", err.Error()) + session.Engine.LogError("sql.Sanner error:", err.Error()) + hasAssigned = false + } + } else if col.SQLType.IsJson() { + if rawValueType.Kind() == reflect.String { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal([]byte(vv.String()), x.Interface()) + if err != nil { + session.Engine.LogError(err) + return err + } + fieldValue.Set(x.Elem()) + } else if rawValueType.Kind() == reflect.Slice { + hasAssigned = true + x := reflect.New(fieldType) + err := json.Unmarshal(vv.Bytes(), x.Interface()) + if err != nil { + session.Engine.LogError(err) + return err + } + fieldValue.Set(x.Elem()) } } else if session.Statement.UseCascade { table := session.Engine.autoMapType(*fieldValue) if table != nil { - if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + if len(table.PrimaryKeys) != 1 { + panic("unsupported non or composited primary key cascade") } var pk = make(core.PK, len(table.PrimaryKeys)) + switch rawValueType.Kind() { case reflect.Int64: pk[0] = vv.Int() @@ -1750,8 +1888,10 @@ func (session *Session) _row2Bean(rows *core.Rows, fields []string, fieldsCount pk[0] = uint8(vv.Uint()) case reflect.String: pk[0] = vv.String() + case reflect.Slice: + pk[0], _ = strconv.ParseInt(string(rawValue.Interface().([]byte)), 10, 64) default: - panic("unsupported primary key type cascade") + panic(fmt.Sprintf("unsupported primary key type: %v, %v", rawValueType, fieldValue)) } if !isPKZero(pk) { @@ -1914,7 +2054,7 @@ func (session *Session) queryPreprocess(sqlStr *string, paramStr ...interface{}) *sqlStr = filter.Do(*sqlStr, session.Engine.dialect, session.Statement.RefTable) } - session.Engine.logSQL(*sqlStr, paramStr...) + session.saveLastSQL(*sqlStr, paramStr...) } func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSlice []map[string][]byte, err error) { @@ -1922,7 +2062,7 @@ func (session *Session) query(sqlStr string, paramStr ...interface{}) (resultsSl session.queryPreprocess(&sqlStr, paramStr...) if session.IsAutoCommit { - return session.innerQuery(session.DB(), sqlStr, paramStr...) + return session.innerQuery(sqlStr, paramStr...) } return session.txQuery(session.Tx, sqlStr, paramStr...) } @@ -1937,22 +2077,33 @@ func (session *Session) txQuery(tx *core.Tx, sqlStr string, params ...interface{ return rows2maps(rows) } -func (session *Session) innerQuery(db *core.DB, sqlStr string, params ...interface{}) (resultsSlice []map[string][]byte, err error) { - stmt, rows, err := session.Engine.LogSQLQueryTime(sqlStr, params, func() (*core.Stmt, *core.Rows, error) { - stmt, err := db.Prepare(sqlStr) - if err != nil { - return stmt, nil, err +func (session *Session) innerQuery(sqlStr string, params ...interface{}) ([]map[string][]byte, error) { + var callback func() (*core.Stmt, *core.Rows, error) + if session.prepareStmt { + callback = func() (*core.Stmt, *core.Rows, error) { + stmt, err := session.doPrepare(sqlStr) + if err != nil { + return nil, nil, err + } + rows, err := stmt.Query(params...) + if err != nil { + return nil, nil, err + } + return stmt, rows, nil } - rows, err := stmt.Query(params...) - - return stmt, rows, err - }) + } else { + callback = func() (*core.Stmt, *core.Rows, error) { + rows, err := session.DB().Query(sqlStr, params...) + if err != nil { + return nil, nil, err + } + return nil, rows, err + } + } + _, rows, err := session.Engine.LogSQLQueryTime(sqlStr, params, callback) if rows != nil { defer rows.Close() } - if stmt != nil { - defer stmt.Close() - } if err != nil { return nil, err } @@ -2044,7 +2195,9 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error cols := make([]*core.Column, 0) for i := 0; i < size; i++ { - elemValue := sliceValue.Index(i).Interface() + v := sliceValue.Index(i) + vv := reflect.Indirect(v) + elemValue := v.Interface() colPlaces := make([]string, 0) // handle BeforeInsertProcessor @@ -2060,8 +2213,12 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error if i == 0 { for _, col := range table.Columns() { - fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) - if col.IsAutoIncrement && fieldValue.Int() == 0 { + ptrFieldValue, err := col.ValueOfV(&vv) + if err != nil { + return 0, err + } + fieldValue := *ptrFieldValue + if col.IsAutoIncrement && isZero(fieldValue.Interface()) { continue } if col.MapType == core.ONLYFROMDB { @@ -2103,8 +2260,13 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } } else { for _, col := range cols { - fieldValue := reflect.Indirect(reflect.ValueOf(elemValue)).FieldByName(col.FieldName) - if col.IsAutoIncrement && fieldValue.Int() == 0 { + ptrFieldValue, err := col.ValueOfV(&vv) + if err != nil { + return 0, err + } + fieldValue := *ptrFieldValue + + if col.IsAutoIncrement && isZero(fieldValue.Interface()) { continue } if col.MapType == core.ONLYFROMDB { @@ -2167,7 +2329,8 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error lenAfterClosures := len(session.afterClosures) for i := 0; i < size; i++ { - elemValue := sliceValue.Index(i).Interface() + elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface() + // handle AfterInsertProcessor if session.IsAutoCommit { // !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi?? @@ -2226,17 +2389,20 @@ func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.T // time stamp sd, err := strconv.ParseInt(sdata, 10, 64) if err == nil { - x = time.Unix(0, sd) + x = time.Unix(sd, 0) // !nashtsai! HACK mymysql driver is casuing Local location being change to CHAT and cause wrong time conversion - x = x.In(time.UTC) - x = time.Date(x.Year(), x.Month(), x.Day(), x.Hour(), - x.Minute(), x.Second(), x.Nanosecond(), session.Engine.TZLocation) + //fmt.Println(x.In(session.Engine.TZLocation), "===") + if col.TimeZone == nil { + x = x.In(session.Engine.TZLocation) + } else { + x = x.In(col.TimeZone) + } + //fmt.Println(x, "=====") session.Engine.LogDebugf("time(0) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else { session.Engine.LogDebugf("time(0) err key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } - } else if len(sdata) > 19 { - + } else if len(sdata) > 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation(time.RFC3339Nano, sdata, session.Engine.TZLocation) session.Engine.LogDebugf("time(1) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) if err != nil { @@ -2248,7 +2414,7 @@ func (session *Session) byte2Time(col *core.Column, data []byte) (outTime time.T session.Engine.LogDebugf("time(3) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } - } else if len(sdata) == 19 { + } else if len(sdata) == 19 && strings.Contains(sdata, "-") { x, err = time.ParseInLocation("2006-01-02 15:04:05", sdata, session.Engine.TZLocation) session.Engine.LogDebugf("time(4) key[%v]: %+v | sdata: [%v]\n", col.FieldName, x, sdata) } else if len(sdata) == 10 && sdata[4] == '-' && sdata[7] == '-' { @@ -2352,7 +2518,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x = 0 } - //fmt.Println("######", x, data) } else if strings.HasPrefix(sdata, "0x") { x, err = strconv.ParseInt(sdata, 16, 64) } else if strings.HasPrefix(sdata, "0") { @@ -2382,108 +2547,115 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, fieldValue.SetUint(x) //Currently only support Time type case reflect.Struct: - if fieldType.ConvertibleTo(core.TimeType) { - x, err := session.byte2Time(col, data) - if err != nil { - return err + // !! 增加支持sql.Scanner接口的结构,如sql.NullString + if nulVal, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { + if err := nulVal.Scan(data); err != nil { + return fmt.Errorf("sql.Scan(%v) failed: %s ", data, err.Error()) } - v = x - fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) - } else if session.Statement.UseCascade { - table := session.Engine.autoMapType(*fieldValue) - if table != nil { - if len(table.PrimaryKeys) > 1 { - panic("unsupported composited primary key cascade") + } else { + if fieldType.ConvertibleTo(core.TimeType) { + x, err := session.byte2Time(col, data) + if err != nil { + return err } - var pk = make(core.PK, len(table.PrimaryKeys)) - rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) - switch rawValueType.Kind() { - case reflect.Int64: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) + v = x + fieldValue.Set(reflect.ValueOf(v).Convert(fieldType)) + } else if session.Statement.UseCascade { + table := session.Engine.autoMapType(*fieldValue) + if table != nil { + if len(table.PrimaryKeys) > 1 { + panic("unsupported composited primary key cascade") } - pk[0] = x - case reflect.Int: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) + var pk = make(core.PK, len(table.PrimaryKeys)) + rawValueType := table.ColumnType(table.PKColumns()[0].FieldName) + switch rawValueType.Kind() { + case reflect.Int64: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = x + case reflect.Int: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int(x) + case reflect.Int32: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int32(x) + case reflect.Int16: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int16(x) + case reflect.Int8: + x, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = int8(x) + case reflect.Uint64: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = x + case reflect.Uint: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint(x) + case reflect.Uint32: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint32(x) + case reflect.Uint16: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint16(x) + case reflect.Uint8: + x, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return fmt.Errorf("arg %v as int: %s", key, err.Error()) + } + pk[0] = uint8(x) + case reflect.String: + pk[0] = string(data) + default: + panic("unsupported primary key type cascade") } - pk[0] = int(x) - case reflect.Int32: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = int32(x) - case reflect.Int16: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = int16(x) - case reflect.Int8: - x, err := strconv.ParseInt(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = int8(x) - case reflect.Uint64: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = x - case reflect.Uint: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint(x) - case reflect.Uint32: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint32(x) - case reflect.Uint16: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint16(x) - case reflect.Uint8: - x, err := strconv.ParseUint(string(data), 10, 64) - if err != nil { - return fmt.Errorf("arg %v as int: %s", key, err.Error()) - } - pk[0] = uint8(x) - case reflect.String: - pk[0] = string(data) - default: - panic("unsupported primary key type cascade") - } - if !isPKZero(pk) { - // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch - // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne - // property to be fetched lazily - structInter := reflect.New(fieldValue.Type()) - newsession := session.Engine.NewSession() - defer newsession.Close() - has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) - if err != nil { - return err - } - if has { - v = structInter.Elem().Interface() - fieldValue.Set(reflect.ValueOf(v)) - } else { - return errors.New("cascade obj is not exist!") + if !isPKZero(pk) { + // !nashtsai! TODO for hasOne relationship, it's preferred to use join query for eager fetch + // however, also need to consider adding a 'lazy' attribute to xorm tag which allow hasOne + // property to be fetched lazily + structInter := reflect.New(fieldValue.Type()) + newsession := session.Engine.NewSession() + defer newsession.Close() + has, err := newsession.Id(pk).NoCascade().Get(structInter.Interface()) + if err != nil { + return err + } + if has { + v = structInter.Elem().Interface() + fieldValue.Set(reflect.ValueOf(v)) + } else { + return errors.New("cascade obj is not exist!") + } } + } else { + return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } - } else { - return fmt.Errorf("unsupported struct type in Scan: %s", fieldValue.Type().String()) } } case reflect.Ptr: @@ -2601,7 +2773,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x = 0 } - //fmt.Println("######", x, data) } else if strings.HasPrefix(sdata, "0x") { x, err = strconv.ParseInt(sdata, 16, 64) } else if strings.HasPrefix(sdata, "0") { @@ -2627,7 +2798,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x = 0 } - //fmt.Println("######", x, data) } else if strings.HasPrefix(sdata, "0x") { x1, err = strconv.ParseInt(sdata, 16, 64) x = int(x1) @@ -2656,7 +2826,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x = 0 } - //fmt.Println("######", x, data) } else if strings.HasPrefix(sdata, "0x") { x1, err = strconv.ParseInt(sdata, 16, 64) x = int32(x1) @@ -2685,7 +2854,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x = 0 } - //fmt.Println("######", x, data) } else if strings.HasPrefix(sdata, "0x") { x1, err = strconv.ParseInt(sdata, 16, 64) x = int8(x1) @@ -2714,7 +2882,6 @@ func (session *Session) bytes2Value(col *core.Column, fieldValue *reflect.Value, } else { x = 0 } - //fmt.Println("######", x, data) } else if strings.HasPrefix(sdata, "0x") { x1, err = strconv.ParseInt(sdata, 16, 64) x = int16(x1) @@ -2892,31 +3059,47 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val case reflect.String: return fieldValue.String(), nil case reflect.Struct: - if fieldType == core.TimeType { - switch fieldValue.Interface().(type) { - case time.Time: - t := fieldValue.Interface().(time.Time) - if session.Engine.dialect.DBType() == core.MSSQL { - if t.IsZero() { - return nil, nil - } + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) + if session.Engine.dialect.DBType() == core.MSSQL { + if t.IsZero() { + return nil, nil } - tf := session.Engine.FormatTime(col.SQLType.Name, t) - return tf, nil - default: - return fieldValue.Interface(), nil } + tf := session.Engine.FormatTime(col.SQLType.Name, t) + return tf, nil } - if fieldTable, ok := session.Engine.Tables[fieldValue.Type()]; ok { + + if !col.SQLType.IsJson() { + // !! 增加支持driver.Valuer接口的结构,如sql.NullString + if v, ok := fieldValue.Interface().(driver.Valuer); ok { + return v.Value() + } + + fieldTable := session.Engine.autoMapType(fieldValue) if len(fieldTable.PrimaryKeys) == 1 { pkField := reflect.Indirect(fieldValue).FieldByName(fieldTable.PKColumns()[0].FieldName) return pkField.Interface(), nil - } else { - return 0, fmt.Errorf("no primary key for col %v", col.Name) } - } else { - return 0, fmt.Errorf("Unsupported type %v\n", fieldValue.Type()) + return 0, fmt.Errorf("no primary key for col %v", col.Name) } + + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogError(err) + return 0, err + } + return string(bytes), nil + } else if col.SQLType.IsBlob() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + session.Engine.LogError(err) + return 0, err + } + return bytes, nil + } + return nil, fmt.Errorf("Unsupported type %v", fieldValue.Type()) case reflect.Complex64, reflect.Complex128: bytes, err := json.Marshal(fieldValue.Interface()) if err != nil { @@ -2950,9 +3133,8 @@ func (session *Session) value2Interface(col *core.Column, fieldValue reflect.Val } } return bytes, nil - } else { - return nil, ErrUnSupportedType } + return nil, ErrUnSupportedType case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: return int64(fieldValue.Uint()), nil default: @@ -2974,12 +3156,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { processor.BeforeInsert() } // -- - colNames, args, err := genCols(table, session, bean, false, false) if err != nil { return 0, err } - // insert expr columns, override if exists exprColumns := session.Statement.getExpr() exprColVals := make([]string, 0, len(exprColumns)) @@ -3044,9 +3224,10 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { // for postgres, many of them didn't implement lastInsertId, so we should // implemented it ourself. + if session.Engine.dialect.DBType() == core.ORACLE && len(table.AutoIncrement) > 0 { + //assert table.AutoIncrement != "" + res, err := session.query("select seq_atable.currval from dual", args...) - if session.Engine.DriverName() != core.POSTGRES || table.AutoIncrement == "" { - res, err := session.exec(sqlStr, args...) if err != nil { return 0, err } else { @@ -3066,14 +3247,14 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { } } - if table.AutoIncrement == "" { - return res.RowsAffected() + if len(res) < 1 { + return 0, errors.New("insert no error but not returned id") } - var id int64 = 0 - id, err = res.LastInsertId() - if err != nil || id <= 0 { - return res.RowsAffected() + idByte := res[0][table.AutoIncrement] + id, err := strconv.ParseInt(string(idByte), 10, 64) + if err != nil { + return 1, err } aiValue, err := table.AutoIncrColumn().ValueOf(bean) @@ -3081,27 +3262,15 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { session.Engine.LogError(err) } - if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { - return res.RowsAffected() + if aiValue == nil || !aiValue.IsValid() /*|| aiValue. != 0*/ || !aiValue.CanSet() { + return 1, nil } - var v interface{} = id - switch aiValue.Type().Kind() { - case reflect.Int32: - v = int32(id) - case reflect.Int: - v = int(id) - case reflect.Uint32: - v = uint32(id) - case reflect.Uint64: - v = uint64(id) - case reflect.Uint: - v = uint(id) - } + v := int64ToInt(id, aiValue.Type().Kind()) aiValue.Set(reflect.ValueOf(v)) - return res.RowsAffected() - } else { + return 1, nil + } else if session.Engine.dialect.DBType() == core.POSTGRES && len(table.AutoIncrement) > 0 { //assert table.AutoIncrement != "" sqlStr = sqlStr + " RETURNING " + session.Engine.Quote(table.AutoIncrement) res, err := session.query(sqlStr, args...) @@ -3144,22 +3313,54 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { return 1, nil } - var v interface{} = id - switch aiValue.Type().Kind() { - case reflect.Int32: - v = int32(id) - case reflect.Int: - v = int(id) - case reflect.Uint32: - v = uint32(id) - case reflect.Uint64: - v = uint64(id) - case reflect.Uint: - v = uint(id) - } + v := int64ToInt(id, aiValue.Type().Kind()) aiValue.Set(reflect.ValueOf(v)) return 1, nil + } else { + res, err := session.exec(sqlStr, args...) + if err != nil { + return 0, err + } else { + handleAfterInsertProcessorFunc(bean) + } + + if cacher := session.Engine.getCacher2(table); cacher != nil && session.Statement.UseCache { + session.cacheInsert(session.Statement.TableName()) + } + + if table.Version != "" && session.Statement.checkVersion { + verValue, err := table.VersionColumn().ValueOf(bean) + if err != nil { + session.Engine.LogError(err) + } else if verValue.IsValid() && verValue.CanSet() { + verValue.SetInt(1) + } + } + + if table.AutoIncrement == "" { + return res.RowsAffected() + } + + var id int64 = 0 + id, err = res.LastInsertId() + if err != nil || id <= 0 { + return res.RowsAffected() + } + + aiValue, err := table.AutoIncrColumn().ValueOf(bean) + if err != nil { + session.Engine.LogError(err) + } + + if aiValue == nil || !aiValue.IsValid() /*|| aiValue.Int() != 0*/ || !aiValue.CanSet() { + return res.RowsAffected() + } + + v := int64ToInt(id, aiValue.Type().Kind()) + aiValue.Set(reflect.ValueOf(v)) + + return res.RowsAffected() } } @@ -3180,7 +3381,7 @@ func (statement *Statement) convertUpdateSql(sqlStr string) (string, string) { return "", "" } - colstrs := statement.JoinColumns(statement.RefTable.PKColumns()) + colstrs := statement.JoinColumns(statement.RefTable.PKColumns(), true) sqls := splitNNoCase(sqlStr, "where", 2) if len(sqls) != 2 { if len(sqls) == 1 { @@ -3233,7 +3434,8 @@ func (session *Session) cacheInsert(tables ...string) error { } func (session *Session) cacheUpdate(sqlStr string, args ...interface{}) error { - if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { + if session.Statement.RefTable == nil || + session.Tx != nil { return ErrCacheFailed } @@ -3382,21 +3584,24 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 // -- var err error - if t.Kind() == reflect.Struct { + var isMap = t.Kind() == reflect.Map + var isStruct = t.Kind() == reflect.Struct + if isStruct { table = session.Engine.TableInfo(bean) session.Statement.RefTable = table if session.Statement.ColumnStr == "" { colNames, args = buildUpdates(session.Engine, table, bean, false, false, false, false, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.mustColumnMap, session.Statement.columnMap, true) + session.Statement.mustColumnMap, session.Statement.nullableMap, + session.Statement.columnMap, true, session.Statement.unscoped) } else { colNames, args, err = genCols(table, session, bean, true, true) if err != nil { return 0, err } } - } else if t.Kind() == reflect.Map { + } else if isMap { if session.Statement.RefTable == nil { return 0, ErrTableNotFound } @@ -3420,10 +3625,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 args = append(args, val) var colName = col.Name - session.afterClosures = append(session.afterClosures, func(bean interface{}) { - col := table.GetColumn(colName) - setColumnTime(bean, col, t) - }) + if isStruct { + session.afterClosures = append(session.afterClosures, func(bean interface{}) { + col := table.GetColumn(colName) + setColumnTime(bean, col, t) + }) + } } //for update action to like "column = column + ?" @@ -3447,10 +3654,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 var condiColNames []string var condiArgs []interface{} - if len(condiBean) > 0 { - condiColNames, condiArgs = buildConditions(session.Engine, session.Statement.RefTable, condiBean[0], true, true, - false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap) + if !session.Statement.noAutoCondition && len(condiBean) > 0 { + condiColNames, condiArgs = session.Statement.buildConditions(session.Statement.RefTable, condiBean[0], true, true, false, true, false) } var condition = "" @@ -3567,6 +3772,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } else { afterClosures := make([]func(interface{}), lenAfterClosures) copy(afterClosures, session.afterClosures) + // FIXME: if bean is a map type, it will panic because map cannot be as map key session.afterUpdateBeans[bean] = &afterClosures } @@ -3583,7 +3789,8 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { - if session.Statement.RefTable == nil || len(session.Statement.RefTable.PrimaryKeys) != 1 { + if session.Statement.RefTable == nil || + session.Tx != nil { return ErrCacheFailed } @@ -3608,15 +3815,25 @@ func (session *Session) cacheDelete(sqlStr string, args ...interface{}) error { if len(resultsSlice) > 0 { for _, data := range resultsSlice { var id int64 - if v, ok := data[session.Statement.RefTable.PrimaryKeys[0]]; !ok { - return errors.New("no id") - } else { - id, err = strconv.ParseInt(string(v), 10, 64) - if err != nil { - return err + var pk core.PK = make([]interface{}, 0) + for _, col := range session.Statement.RefTable.PKColumns() { + if v, ok := data[col.Name]; !ok { + return errors.New("no id") + } else { + if col.SQLType.IsText() { + pk = append(pk, string(v)) + } else if col.SQLType.IsNumeric() { + id, err = strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + pk = append(pk, id) + } else { + return errors.New("not supported primary key type") + } } } - ids = append(ids, core.PK{id}) + ids = append(ids, pk) } } } /*else { @@ -3657,10 +3874,12 @@ func (session *Session) Delete(bean interface{}) (int64, error) { table := session.Engine.TableInfo(bean) session.Statement.RefTable = table - colNames, args := buildConditions(session.Engine, table, bean, true, true, - false, true, session.Statement.allUseBool, session.Statement.useAllCols, - session.Statement.unscoped, session.Statement.mustColumnMap) + var colNames []string + var args []interface{} + if !session.Statement.noAutoCondition { + colNames, args = session.Statement.buildConditions(table, bean, true, true, false, true, false) + } var condition = "" var andStr = session.Engine.dialect.AndStr() @@ -3681,32 +3900,91 @@ func (session *Session) Delete(bean interface{}) (int64, error) { condition += inSql args = append(args, inArgs...) } - if len(condition) == 0 { + if len(condition) == 0 && session.Statement.LimitN == 0 { return 0, ErrNeedDeletedCond } - sqlStr, sqlStrForCache := "", "" + var deleteSql, realSql string + var tableName = session.Engine.Quote(session.Statement.TableName()) + + if len(condition) > 0 { + deleteSql = fmt.Sprintf("DELETE FROM %v WHERE %v", tableName, condition) + } else { + deleteSql = fmt.Sprintf("DELETE FROM %v", tableName) + } + + var orderSql string + if len(session.Statement.OrderStr) > 0 { + orderSql += fmt.Sprintf(" ORDER BY %s", session.Statement.OrderStr) + } + if session.Statement.LimitN > 0 { + orderSql += fmt.Sprintf(" LIMIT %d", session.Statement.LimitN) + } + + if len(orderSql) > 0 { + switch session.Engine.dialect.DBType() { + case core.POSTGRES: + inSql := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSql) + if len(condition) > 0 { + deleteSql += " AND " + inSql + } else { + deleteSql += " WHERE " + inSql + } + case core.SQLITE: + inSql := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSql) + if len(condition) > 0 { + deleteSql += " AND " + inSql + } else { + deleteSql += " WHERE " + inSql + } + // TODO: how to handle delete limit on mssql? + case core.MSSQL: + return 0, ErrNotImplemented + default: + deleteSql += orderSql + } + } + argsForCache := make([]interface{}, 0, len(args)*2) if session.Statement.unscoped || table.DeletedColumn() == nil { // tag "deleted" is disabled - sqlStr = fmt.Sprintf("DELETE FROM %v WHERE %v", - session.Engine.Quote(session.Statement.TableName()), condition) - - sqlStrForCache = sqlStr + realSql = deleteSql copy(argsForCache, args) argsForCache = append(session.Statement.Params, argsForCache...) } else { // !oinume! sqlStrForCache and argsForCache is needed to behave as executing "DELETE FROM ..." for cache. - sqlStrForCache = fmt.Sprintf("DELETE FROM %v WHERE %v", - session.Engine.Quote(session.Statement.TableName()), condition) copy(argsForCache, args) argsForCache = append(session.Statement.Params, argsForCache...) deletedColumn := table.DeletedColumn() - sqlStr = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", + realSql = fmt.Sprintf("UPDATE %v SET %v = ? WHERE %v", session.Engine.Quote(session.Statement.TableName()), session.Engine.Quote(deletedColumn.Name), condition) + if len(orderSql) > 0 { + switch session.Engine.dialect.DBType() { + case core.POSTGRES: + inSql := fmt.Sprintf("ctid IN (SELECT ctid FROM %s%s)", tableName, orderSql) + if len(condition) > 0 { + realSql += " AND " + inSql + } else { + realSql += " WHERE " + inSql + } + case core.SQLITE: + inSql := fmt.Sprintf("rowid IN (SELECT rowid FROM %s%s)", tableName, orderSql) + if len(condition) > 0 { + realSql += " AND " + inSql + } else { + realSql += " WHERE " + inSql + } + // TODO: how to handle delete limit on mssql? + case core.MSSQL: + return 0, ErrNotImplemented + default: + realSql += orderSql + } + } + // !oinume! Insert NowTime to the head of session.Statement.Params session.Statement.Params = append(session.Statement.Params, "") paramsLen := len(session.Statement.Params) @@ -3725,10 +4003,10 @@ func (session *Session) Delete(bean interface{}) (int64, error) { args = append(session.Statement.Params, args...) if cacher := session.Engine.getCacher2(session.Statement.RefTable); cacher != nil && session.Statement.UseCache { - session.cacheDelete(sqlStrForCache, argsForCache...) + session.cacheDelete(deleteSql, argsForCache...) } - res, err := session.exec(sqlStr, args...) + res, err := session.exec(realSql, args...) if err != nil { return 0, err } @@ -3763,6 +4041,18 @@ func (session *Session) Delete(bean interface{}) (int64, error) { return res.RowsAffected() } +// saveLastSQL stores executed query information +func (session *Session) saveLastSQL(sql string, args ...interface{}) { + session.lastSQL = sql + session.lastSQLArgs = args + session.Engine.logSQL(sql, args...) +} + +// LastSQL returns last query information +func (session *Session) LastSQL() (string, []interface{}) { + return session.lastSQL, session.lastSQLArgs +} + func (s *Session) Sync2(beans ...interface{}) error { engine := s.Engine @@ -3779,7 +4069,7 @@ func (s *Session) Sync2(beans ...interface{}) error { var oriTable *core.Table for _, tb := range tables { - if tb.Name == table.Name { + if equalNoCase(tb.Name, table.Name) { oriTable = tb break } @@ -3804,7 +4094,7 @@ func (s *Session) Sync2(beans ...interface{}) error { for _, col := range table.Columns() { var oriCol *core.Column for _, col2 := range oriTable.Columns() { - if col.Name == col2.Name { + if equalNoCase(col.Name, col2.Name) { oriCol = col2 break } @@ -3826,10 +4116,26 @@ func (s *Session) Sync2(beans ...interface{}) error { engine.LogWarnf("Table %s column %s db type is %s, struct type is %s\n", table.Name, col.Name, curType, expectedType) } + } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) { + if engine.dialect.DBType() == core.MYSQL { + if oriCol.Length < col.Length { + engine.LogInfof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + table.Name, col.Name, oriCol.Length, col.Length) + _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + } + } } else { engine.LogWarnf("Table %s column %s db type is %s, struct type is %s", table.Name, col.Name, curType, expectedType) } + } else if expectedType == core.Varchar { + if engine.dialect.DBType() == core.MYSQL { + if oriCol.Length < col.Length { + engine.LogInfof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", + table.Name, col.Name, oriCol.Length, col.Length) + _, err = engine.Exec(engine.dialect.ModifyColumnSql(table.Name, col)) + } + } } if col.Default != oriCol.Default { engine.LogWarnf("Table %s Column %s db default is %s, struct default is %s", @@ -3851,6 +4157,7 @@ func (s *Session) Sync2(beans ...interface{}) error { } var foundIndexNames = make(map[string]bool) + var addedNames = make(map[string]*core.Index) for name, index := range table.Indexes { var oriIndex *core.Index @@ -3874,20 +4181,7 @@ func (s *Session) Sync2(beans ...interface{}) error { } if oriIndex == nil { - if index.Type == core.UniqueType { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - err = session.addUnique(table.Name, name) - } else if index.Type == core.IndexType { - session := engine.NewSession() - session.Statement.RefTable = table - defer session.Close() - err = session.addIndex(table.Name, name) - } - if err != nil { - return err - } + addedNames[name] = index } } @@ -3900,13 +4194,30 @@ func (s *Session) Sync2(beans ...interface{}) error { } } } + + for name, index := range addedNames { + if index.Type == core.UniqueType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addUnique(table.Name, name) + } else if index.Type == core.IndexType { + session := engine.NewSession() + session.Statement.RefTable = table + defer session.Close() + err = session.addIndex(table.Name, name) + } + if err != nil { + return err + } + } } } for _, table := range tables { var oriTable *core.Table for _, structTable := range structTables { - if table.Name == structTable.Name { + if equalNoCase(table.Name, structTable.Name) { oriTable = structTable break } diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_dialect.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_dialect.go index cb9e7f54a78..9f29d587cdb 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_dialect.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_dialect.go @@ -1,9 +1,14 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( "database/sql" "errors" "fmt" + "regexp" "strings" "github.com/go-xorm/core" @@ -152,13 +157,21 @@ func (db *sqlite3) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName st func (db *sqlite3) SqlType(c *core.Column) string { switch t := c.SQLType.Name; t { + case core.Bool: + if c.Default == "true" { + c.Default = "1" + } else if c.Default == "false" { + c.Default = "0" + } + return core.Integer case core.Date, core.DateTime, core.TimeStamp, core.Time: return core.DateTime case core.TimeStampz: return core.Text - case core.Char, core.Varchar, core.NVarchar, core.TinyText, core.Text, core.MediumText, core.LongText: + case core.Char, core.Varchar, core.NVarchar, core.TinyText, + core.Text, core.MediumText, core.LongText, core.Json: return core.Text - case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt, core.Bool: + case core.Bit, core.TinyInt, core.SmallInt, core.MediumInt, core.Int, core.Integer, core.BigInt: return core.Integer case core.Float, core.Double, core.Real: return core.Real @@ -238,15 +251,19 @@ func (db *sqlite3) DropIndexSql(tableName string, index *core.Index) string { return fmt.Sprintf("DROP INDEX %v", quote(idxName)) } +func (db *sqlite3) ForUpdateSql(query string) string { + return query +} + /*func (db *sqlite3) ColumnCheckSql(tableName, colName string) (string, []interface{}) { args := []interface{}{tableName} sql := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" return sql, args }*/ -func (db *sqlite3) IsColumnExist(tableName string, col *core.Column) (bool, error) { +func (db *sqlite3) IsColumnExist(tableName, colName string) (bool, error) { args := []interface{}{tableName} - query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + col.Name + "`%') or (sql like '%[" + col.Name + "]%'))" + query := "SELECT name FROM sqlite_master WHERE type='table' and name = ? and ((sql like '%`" + colName + "`%') or (sql like '%[" + colName + "]%'))" rows, err := db.DB().Query(query, args...) if db.Logger != nil { db.Logger.Info("[sql]", query, args) @@ -290,10 +307,13 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu nStart := strings.Index(name, "(") nEnd := strings.LastIndex(name, ")") - colCreates := strings.Split(name[nStart+1:nEnd], ",") + reg := regexp.MustCompile(`[^\(,\)]*(\([^\(]*\))?`) + colCreates := reg.FindAllString(name[nStart+1:nEnd], -1) cols := make(map[string]*core.Column) colSeq := make([]string, 0) for _, colStr := range colCreates { + reg = regexp.MustCompile(`,\s`) + colStr = reg.ReplaceAllString(colStr, ",") fields := strings.Fields(strings.TrimSpace(colStr)) col := new(core.Column) col.Indexes = make(map[string]bool) diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_driver.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_driver.go index 2ecd9edf5d5..6ae19569ef8 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_driver.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/sqlite3_driver.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/statement.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/statement.go index b8f859714ea..9b9042fa9ad 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/statement.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/statement.go @@ -1,6 +1,12 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( + "bytes" + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -33,42 +39,46 @@ type exprParam struct { // statement save all the sql info for executing SQL type Statement struct { - RefTable *core.Table - Engine *Engine - Start int - LimitN int - WhereStr string - IdParam *core.PK - Params []interface{} - OrderStr string - JoinStr string - GroupByStr string - HavingStr string - ColumnStr string - columnMap map[string]bool - useAllCols bool - OmitStr string - ConditionStr string - AltTableName string - RawSQL string - RawParams []interface{} - UseCascade bool - UseAutoJoin bool - StoreEngine string - Charset string - BeanArgs []interface{} - UseCache bool - UseAutoTime bool - IsDistinct bool - TableAlias string - allUseBool bool - checkVersion bool - unscoped bool - mustColumnMap map[string]bool - inColumns map[string]*inParam - incrColumns map[string]incrParam - decrColumns map[string]decrParam - exprColumns map[string]exprParam + RefTable *core.Table + Engine *Engine + Start int + LimitN int + WhereStr string + IdParam *core.PK + Params []interface{} + OrderStr string + JoinStr string + GroupByStr string + HavingStr string + ColumnStr string + selectStr string + columnMap map[string]bool + useAllCols bool + OmitStr string + ConditionStr string + AltTableName string + RawSQL string + RawParams []interface{} + UseCascade bool + UseAutoJoin bool + StoreEngine string + Charset string + BeanArgs []interface{} + UseCache bool + UseAutoTime bool + noAutoCondition bool + IsDistinct bool + IsForUpdate bool + TableAlias string + allUseBool bool + checkVersion bool + unscoped bool + mustColumnMap map[string]bool + nullableMap map[string]bool + inColumns map[string]*inParam + incrColumns map[string]incrParam + decrColumns map[string]decrParam + exprColumns map[string]exprParam } // init @@ -94,11 +104,15 @@ func (statement *Statement) Init() { statement.BeanArgs = make([]interface{}, 0) statement.UseCache = true statement.UseAutoTime = true + statement.noAutoCondition = false statement.IsDistinct = false + statement.IsForUpdate = false statement.TableAlias = "" + statement.selectStr = "" statement.allUseBool = false statement.useAllCols = false statement.mustColumnMap = make(map[string]bool) + statement.nullableMap = make(map[string]bool) statement.checkVersion = true statement.unscoped = false statement.inColumns = make(map[string]*inParam) @@ -107,20 +121,29 @@ func (statement *Statement) Init() { statement.exprColumns = make(map[string]exprParam) } -// add the raw sql statement +// NoAutoCondition if you do not want convert bean's field as query condition, then use this function +func (statement *Statement) NoAutoCondition(no ...bool) *Statement { + statement.noAutoCondition = true + if len(no) > 0 { + statement.noAutoCondition = no[0] + } + return statement +} + +// Sql add the raw sql statement func (statement *Statement) Sql(querystring string, args ...interface{}) *Statement { statement.RawSQL = querystring statement.RawParams = args return statement } -// set the table alias +// Alias set the table alias func (statement *Statement) Alias(alias string) *Statement { statement.TableAlias = alias return statement } -// add Where statment +// Where add Where statment func (statement *Statement) Where(querystring string, args ...interface{}) *Statement { if !strings.Contains(querystring, statement.Engine.dialect.EqStr()) { querystring = strings.Replace(querystring, "=", statement.Engine.dialect.EqStr(), -1) @@ -130,11 +153,13 @@ func (statement *Statement) Where(querystring string, args ...interface{}) *Stat return statement } -// add Where & and statment +// And add Where & and statment func (statement *Statement) And(querystring string, args ...interface{}) *Statement { - if statement.WhereStr != "" { - statement.WhereStr = fmt.Sprintf("(%v) %s (%v)", statement.WhereStr, + if len(statement.WhereStr) > 0 { + var buf bytes.Buffer + fmt.Fprintf(&buf, "(%v) %s (%v)", statement.WhereStr, statement.Engine.dialect.AndStr(), querystring) + statement.WhereStr = buf.String() } else { statement.WhereStr = querystring } @@ -142,11 +167,13 @@ func (statement *Statement) And(querystring string, args ...interface{}) *Statem return statement } -// add Where & Or statment +// Or add Where & Or statment func (statement *Statement) Or(querystring string, args ...interface{}) *Statement { - if statement.WhereStr != "" { - statement.WhereStr = fmt.Sprintf("(%v) %s (%v)", statement.WhereStr, + if len(statement.WhereStr) > 0 { + var buf bytes.Buffer + fmt.Fprintf(&buf, "(%v) %s (%v)", statement.WhereStr, statement.Engine.dialect.OrStr(), querystring) + statement.WhereStr = buf.String() } else { statement.WhereStr = querystring } @@ -154,7 +181,7 @@ func (statement *Statement) Or(querystring string, args ...interface{}) *Stateme return statement } -// tempororily set table name +// Table tempororily set table name, the parameter could be a string or a pointer of struct func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { v := rValue(tableNameOrBean) t := v.Type() @@ -166,127 +193,12 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { return statement } -/*func (statement *Statement) genFields(bean interface{}) map[string]interface{} { - results := make(map[string]interface{}) - table := statement.Engine.TableInfo(bean) - for _, col := range table.Columns { - fieldValue := col.ValueOf(bean) - fieldType := reflect.TypeOf(fieldValue.Interface()) - var val interface{} - switch fieldType.Kind() { - case reflect.Bool: - if allUseBool { - val = fieldValue.Interface() - } else if _, ok := boolColumnMap[col.Name]; ok { - val = fieldValue.Interface() - } else { - // if a bool in a struct, it will not be as a condition because it default is false, - // please use Where() instead - continue - } - case reflect.String: - if fieldValue.String() == "" { - continue - } - // for MyString, should convert to string or panic - if fieldType.String() != reflect.String.String() { - val = fieldValue.String() - } else { - val = fieldValue.Interface() - } - case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64: - if fieldValue.Int() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Float32, reflect.Float64: - if fieldValue.Float() == 0.0 { - continue - } - val = fieldValue.Interface() - case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64: - if fieldValue.Uint() == 0 { - continue - } - val = fieldValue.Interface() - case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) - if t.IsZero() || !fieldValue.IsValid() { - continue - } - var str string - if col.SQLType.Name == Time { - s := t.UTC().Format("2006-01-02 15:04:05") - val = s[11:19] - } else if col.SQLType.Name == Date { - str = t.Format("2006-01-02") - val = str - } else { - val = t - } - } else { - engine.autoMapType(fieldValue.Type()) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumn().FieldName) - if pkField.Int() != 0 { - val = pkField.Interface() - } else { - continue - } - } else { - val = fieldValue.Interface() - } - } - case reflect.Array, reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() { - continue - } - - if col.SQLType.IsText() { - bytes, err := json.Marshal(fieldValue.Interface()) - if err != nil { - engine.LogError(err) - continue - } - val = string(bytes) - } else if col.SQLType.IsBlob() { - var bytes []byte - var err error - if (fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice) && - fieldType.Elem().Kind() == reflect.Uint8 { - if fieldValue.Len() > 0 { - val = fieldValue.Bytes() - } else { - continue - } - } else { - bytes, err = json.Marshal(fieldValue.Interface()) - if err != nil { - engine.LogError(err) - continue - } - val = bytes - } - } else { - continue - } - default: - val = fieldValue.Interface() - } - results[col.Name] = val - } - return results -}*/ - -// Auto generating conditions according a struct +// Auto generating update columnes and values according a struct func buildUpdates(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, - mustColumnMap map[string]bool, columnMap map[string]bool, update bool) ([]string, []interface{}) { + mustColumnMap map[string]bool, nullableMap map[string]bool, + columnMap map[string]bool, update, unscoped bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) @@ -303,17 +215,13 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, if !includeAutoIncr && col.IsAutoIncrement { continue } - if col.IsDeleted { + if col.IsDeleted && !unscoped { continue } if use, ok := columnMap[col.Name]; ok && !use { continue } - if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { - continue - } - fieldValuePtr, err := col.ValueOf(bean) if err != nil { engine.LogError(err) @@ -325,7 +233,9 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, requiredField := useAllCols includeNil := useAllCols - if b, ok := mustColumnMap[strings.ToLower(col.Name)]; ok { + lColName := strings.ToLower(col.Name) + + if b, ok := mustColumnMap[lColName]; ok { if b { requiredField = true } else { @@ -333,6 +243,16 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, } } + // !evalphobia! set fieldValue as nil when column is nullable and zero-value + if b, ok := nullableMap[lColName]; ok { + if b && col.Nullable && isZero(fieldValue.Interface()) { + var nilValue *int + fieldValue = reflect.ValueOf(nilValue) + fieldType = reflect.TypeOf(fieldValue.Interface()) + includeNil = true + } + } + var val interface{} if fieldValue.CanAddr() { @@ -410,38 +330,53 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, t := int64(fieldValue.Uint()) val = reflect.ValueOf(&t).Interface() case reflect.Struct: - if fieldType == reflect.TypeOf(time.Now()) { - t := fieldValue.Interface().(time.Time) + if fieldType.ConvertibleTo(core.TimeType) { + t := fieldValue.Convert(core.TimeType).Interface().(time.Time) if !requiredField && (t.IsZero() || !fieldValue.IsValid()) { continue } val = engine.FormatTime(col.SQLType.Name, t) + } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = nulType.Value() } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() + if !col.SQLType.IsJson() { + engine.autoMapType(fieldValue) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + if pkField.IsValid() && !isZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } } else { - continue + //TODO: how to handler? + panic("not supported") } } else { - //TODO: how to handler? - panic("not supported") + val = fieldValue.Interface() } } else { - val = fieldValue.Interface() + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface())) + } + if col.SQLType.IsText() { + val = string(bytes) + } else if col.SQLType.IsBlob() { + val = bytes + } } } case reflect.Array, reflect.Slice, reflect.Map: - if fieldValue == reflect.Zero(fieldType) { - continue - } - if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { - continue + if !requiredField { + if fieldValue == reflect.Zero(fieldType) { + continue + } + if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 { + continue + } } if col.SQLType.IsText() { @@ -492,8 +427,7 @@ func buildUpdates(engine *Engine, table *core.Table, bean interface{}, func buildConditions(engine *Engine, table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, allUseBool bool, useAllCols bool, unscoped bool, - mustColumnMap map[string]bool) ([]string, []interface{}) { - + mustColumnMap map[string]bool, tableName, aliasName string, addedTableName bool) ([]string, []interface{}) { colNames := make([]string, 0) var args = make([]interface{}, 0) for _, col := range table.Columns() { @@ -510,6 +444,21 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if engine.dialect.DBType() == core.MSSQL && col.SQLType.Name == core.Text { continue } + if col.SQLType.IsJson() { + continue + } + + var colName string + if addedTableName { + var nm = tableName + if len(aliasName) > 0 { + nm = aliasName + } + colName = engine.Quote(nm) + "." + engine.Quote(col.Name) + } else { + colName = engine.Quote(col.Name) + } + fieldValuePtr, err := col.ValueOf(bean) if err != nil { engine.LogError(err) @@ -517,7 +466,8 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, } if col.IsDeleted && !unscoped { // tag "deleted" is enabled - colNames = append(colNames, fmt.Sprintf("(%v IS NULL or %v = '0001-01-01 00:00:00')", engine.Quote(col.Name), engine.Quote(col.Name))) + colNames = append(colNames, fmt.Sprintf("%v IS NULL or %v = '0001-01-01 00:00:00'", + colName, colName)) } fieldValue := *fieldValuePtr @@ -539,7 +489,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if fieldValue.IsNil() { if includeNil { args = append(args, nil) - colNames = append(colNames, fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr())) + colNames = append(colNames, fmt.Sprintf("%v %s ?", colName, engine.dialect.EqStr())) } continue } else if !fieldValue.IsValid() { @@ -597,24 +547,49 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, val = engine.FormatTime(col.SQLType.Name, t) } else if _, ok := reflect.New(fieldType).Interface().(core.Conversion); ok { continue + } else if valNul, ok := fieldValue.Interface().(driver.Valuer); ok { + val, _ = valNul.Value() + if val == nil { + continue + } } else { - engine.autoMapType(fieldValue) - if table, ok := engine.Tables[fieldValue.Type()]; ok { - if len(table.PrimaryKeys) == 1 { - pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) - // fix non-int pk issues - //if pkField.Int() != 0 { - if pkField.IsValid() && !isZero(pkField.Interface()) { - val = pkField.Interface() - } else { + if col.SQLType.IsJson() { + if col.SQLType.IsText() { + bytes, err := json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogError(err) continue } - } else { - //TODO: how to handler? - panic("not supported") + val = string(bytes) + } else if col.SQLType.IsBlob() { + var bytes []byte + var err error + bytes, err = json.Marshal(fieldValue.Interface()) + if err != nil { + engine.LogError(err) + continue + } + val = bytes } } else { - val = fieldValue.Interface() + engine.autoMapType(fieldValue) + if table, ok := engine.Tables[fieldValue.Type()]; ok { + if len(table.PrimaryKeys) == 1 { + pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName) + // fix non-int pk issues + //if pkField.Int() != 0 { + if pkField.IsValid() && !isZero(pkField.Interface()) { + val = pkField.Interface() + } else { + continue + } + } else { + //TODO: how to handler? + panic(fmt.Sprintln("not supported", fieldValue.Interface(), "as", table.PrimaryKeys)) + } + } else { + val = fieldValue.Interface() + } } } case reflect.Array, reflect.Slice, reflect.Map: @@ -662,7 +637,7 @@ func buildConditions(engine *Engine, table *core.Table, bean interface{}, if col.IsPrimaryKey && engine.dialect.DBType() == "ql" { condi = "id() == ?" } else { - condi = fmt.Sprintf("%v %s ?", engine.Quote(col.Name), engine.dialect.EqStr()) + condi = fmt.Sprintf("%v %s ?", colName, engine.dialect.EqStr()) } colNames = append(colNames, condi) } @@ -709,7 +684,7 @@ func (statement *Statement) Id(id interface{}) *Statement { return statement } -// Generate "Update ... Set column = column + arg" statment +// Incr Generate "Update ... Set column = column + arg" statment func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { k := strings.ToLower(column) if len(arg) > 0 { @@ -720,7 +695,7 @@ func (statement *Statement) Incr(column string, arg ...interface{}) *Statement { return statement } -// Generate "Update ... Set column = column - arg" statment +// Decr Generate "Update ... Set column = column - arg" statment func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { k := strings.ToLower(column) if len(arg) > 0 { @@ -731,7 +706,7 @@ func (statement *Statement) Decr(column string, arg ...interface{}) *Statement { return statement } -// Generate "Update ... Set column = {expression}" statment +// SetExpr Generate "Update ... Set column = {expression}" statment func (statement *Statement) SetExpr(column string, expression string) *Statement { k := strings.ToLower(column) statement.exprColumns[k] = exprParam{column, expression} @@ -755,9 +730,14 @@ func (statement *Statement) getExpr() map[string]exprParam { // Generate "Where column IN (?) " statment func (statement *Statement) In(column string, args ...interface{}) *Statement { + length := len(args) + if length == 0 { + return statement + } + k := strings.ToLower(column) var newargs []interface{} - if len(args) == 1 && + if length == 1 && reflect.TypeOf(args[0]).Kind() == reflect.Slice { newargs = make([]interface{}, 0) v := reflect.ValueOf(args[0]) @@ -781,12 +761,17 @@ func (statement *Statement) genInSql() (string, []interface{}) { return "", []interface{}{} } - inStrs := make([]string, 0, len(statement.inColumns)) + inStrs := make([]string, len(statement.inColumns), len(statement.inColumns)) args := make([]interface{}, 0) + var buf bytes.Buffer + var i int for _, params := range statement.inColumns { - inStrs = append(inStrs, fmt.Sprintf("(%v IN (%v))", + buf.Reset() + fmt.Fprintf(&buf, "(%v IN (%v))", statement.Engine.autoQuote(params.colName), - strings.Join(makeArray("?", len(params.args)), ","))) + strings.Join(makeArray("?", len(params.args)), ",")) + inStrs[i] = buf.String() + i++ args = append(args, params.args...) } @@ -799,7 +784,7 @@ func (statement *Statement) genInSql() (string, []interface{}) { func (statement *Statement) attachInSql() { inSql, inArgs := statement.genInSql() if len(inSql) > 0 { - if statement.ConditionStr != "" { + if len(statement.ConditionStr) > 0 { statement.ConditionStr += " " + statement.Engine.dialect.AndStr() + " " } statement.ConditionStr += inSql @@ -858,6 +843,18 @@ func (statement *Statement) Distinct(columns ...string) *Statement { return statement } +// Generate "SELECT ... FOR UPDATE" statment +func (statement *Statement) ForUpdate() *Statement { + statement.IsForUpdate = true + return statement +} + +// replace select +func (s *Statement) Select(str string) *Statement { + s.selectStr = str + return s +} + // Generate "col1, col2" statement func (statement *Statement) Cols(columns ...string) *Statement { newColumns := col2NewCols(columns...) @@ -868,6 +865,7 @@ func (statement *Statement) Cols(columns ...string) *Statement { if strings.Contains(statement.ColumnStr, ".") { statement.ColumnStr = strings.Replace(statement.ColumnStr, ".", statement.Engine.Quote("."), -1) } + statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.Quote("*"), "*", -1) return statement } @@ -886,15 +884,6 @@ func (statement *Statement) MustCols(columns ...string) *Statement { return statement } -// Update use only: not update columns -/*func (statement *Statement) NotCols(columns ...string) *Statement { - newColumns := col2NewCols(columns...) - for _, nc := range newColumns { - statement.mustColumnMap[strings.ToLower(nc)] = false - } - return statement -}*/ - // indicates that use bool fields as update contents and query contiditions func (statement *Statement) UseBool(columns ...string) *Statement { if len(columns) > 0 { @@ -914,6 +903,14 @@ func (statement *Statement) Omit(columns ...string) { statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", "))) } +// Update use only: update columns to null when value is nullable and zero-value +func (statement *Statement) Nullable(columns ...string) { + newColumns := col2NewCols(columns...) + for _, nc := range newColumns { + statement.nullableMap[strings.ToLower(nc)] = true + } +} + // Generate LIMIT limit statement func (statement *Statement) Top(limit int) *Statement { statement.Limit(limit) @@ -931,7 +928,7 @@ func (statement *Statement) Limit(limit int, start ...int) *Statement { // Generate "Order By order" statement func (statement *Statement) OrderBy(order string) *Statement { - if statement.OrderStr != "" { + if len(statement.OrderStr) > 0 { statement.OrderStr += ", " } statement.OrderStr += order @@ -939,44 +936,51 @@ func (statement *Statement) OrderBy(order string) *Statement { } func (statement *Statement) Desc(colNames ...string) *Statement { - if statement.OrderStr != "" { - statement.OrderStr += ", " + var buf bytes.Buffer + fmt.Fprintf(&buf, statement.OrderStr) + if len(statement.OrderStr) > 0 { + fmt.Fprint(&buf, ", ") } newColNames := statement.col2NewColsWithQuote(colNames...) - sqlStr := strings.Join(newColNames, " DESC, ") - statement.OrderStr += sqlStr + " DESC" + fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) + statement.OrderStr = buf.String() return statement } // Method Asc provide asc order by query condition, the input parameters are columns. func (statement *Statement) Asc(colNames ...string) *Statement { - if statement.OrderStr != "" { - statement.OrderStr += ", " + var buf bytes.Buffer + fmt.Fprintf(&buf, statement.OrderStr) + if len(statement.OrderStr) > 0 { + fmt.Fprint(&buf, ", ") } newColNames := statement.col2NewColsWithQuote(colNames...) - sqlStr := strings.Join(newColNames, " ASC, ") - statement.OrderStr += sqlStr + " ASC" + fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) + statement.OrderStr = buf.String() return statement } //The join_operator should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(join_operator string, tablename interface{}, condition string) *Statement { - var joinTable string + var buf bytes.Buffer + if len(statement.JoinStr) > 0 { + fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, join_operator) + } else { + fmt.Fprintf(&buf, "%v JOIN ", join_operator) + } + switch tablename.(type) { case []string: t := tablename.([]string) - l := len(t) - if l > 1 { - table := t[0] - joinTable = statement.Engine.Quote(table) + " AS " + statement.Engine.Quote(t[1]) - } else if l == 1 { - table := t[0] - joinTable = statement.Engine.Quote(table) + if len(t) > 1 { + fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(t[0]), statement.Engine.Quote(t[1])) + } else if len(t) == 1 { + fmt.Fprintf(&buf, statement.Engine.Quote(t[0])) } case []interface{}: t := tablename.([]interface{}) l := len(t) - table := "" + var table string if l > 0 { f := t[0] v := rValue(f) @@ -989,21 +993,17 @@ func (statement *Statement) Join(join_operator string, tablename interface{}, co } } if l > 1 { - joinTable = statement.Engine.Quote(table) + " AS " + statement.Engine.Quote(fmt.Sprintf("%v", t[1])) + fmt.Fprintf(&buf, "%v AS %v", statement.Engine.Quote(table), + statement.Engine.Quote(fmt.Sprintf("%v", t[1]))) } else if l == 1 { - joinTable = statement.Engine.Quote(table) + fmt.Fprintf(&buf, statement.Engine.Quote(table)) } default: - t := fmt.Sprintf("%v", tablename) - joinTable = statement.Engine.Quote(t) - } - if statement.JoinStr != "" { - statement.JoinStr = statement.JoinStr + fmt.Sprintf(" %v JOIN %v ON %v", join_operator, - joinTable, condition) - } else { - statement.JoinStr = fmt.Sprintf("%v JOIN %v ON %v", join_operator, - joinTable, condition) + fmt.Fprintf(&buf, statement.Engine.Quote(fmt.Sprintf("%v", tablename))) } + + fmt.Fprintf(&buf, " ON %v", condition) + statement.JoinStr = buf.String() return statement } @@ -1120,11 +1120,6 @@ func (s *Statement) genDelIndexSQL() []string { return sqls } -/* -func (s *Statement) genDropSQL() string { - return s.Engine.dialect.MustDropTa(s.TableName()) + ";" -}*/ - func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) { var table *core.Table if statement.RefTable == nil { @@ -1134,28 +1129,34 @@ func (statement *Statement) genGetSql(bean interface{}) (string, []interface{}) table = statement.RefTable } - colNames, args := buildConditions(statement.Engine, table, bean, true, true, - false, true, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap) + var addedTableName = (len(statement.JoinStr) > 0) - statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ") - statement.BeanArgs = args + if !statement.noAutoCondition { + colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName) + + statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.dialect.AndStr()+" ") + statement.BeanArgs = args + } var columnStr string = statement.ColumnStr - if len(statement.JoinStr) == 0 { - if len(columnStr) == 0 { - if statement.GroupByStr != "" { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = statement.genColumnStr() - } - } + if len(statement.selectStr) > 0 { + columnStr = statement.selectStr } else { - if len(columnStr) == 0 { - if statement.GroupByStr != "" { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } else { - columnStr = "*" + if len(statement.JoinStr) == 0 { + if len(columnStr) == 0 { + if statement.GroupByStr != "" { + columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + } else { + columnStr = statement.genColumnStr() + } + } + } else { + if len(columnStr) == 0 { + if statement.GroupByStr != "" { + columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + } else { + columnStr = "*" + } } } } @@ -1185,16 +1186,23 @@ func (s *Statement) genAddUniqueStr(uqeName string, cols []string) (string, []in return sql, []interface{}{} }*/ +func (statement *Statement) buildConditions(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) ([]string, []interface{}) { + return buildConditions(statement.Engine, table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols, + statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) +} + func (statement *Statement) genCountSql(bean interface{}) (string, []interface{}) { table := statement.Engine.TableInfo(bean) statement.RefTable = table - colNames, args := buildConditions(statement.Engine, table, bean, true, true, false, - true, statement.allUseBool, statement.useAllCols, - statement.unscoped, statement.mustColumnMap) + var addedTableName = (len(statement.JoinStr) > 0) - statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ") - statement.BeanArgs = args + if !statement.noAutoCondition { + colNames, args := statement.buildConditions(table, bean, true, true, false, true, addedTableName) + + statement.ConditionStr = strings.Join(colNames, " "+statement.Engine.Dialect().AndStr()+" ") + statement.BeanArgs = args + } // count(index fieldname) > count(0) > count(*) var id string = "*" @@ -1206,47 +1214,46 @@ func (statement *Statement) genCountSql(bean interface{}) (string, []interface{} } func (statement *Statement) genSelectSql(columnStr string) (a string) { - /*if statement.GroupByStr != "" { - if columnStr == "" { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) - } - //statement.GroupByStr = columnStr - }*/ var distinct string if statement.IsDistinct { distinct = "DISTINCT " } + var dialect = statement.Engine.Dialect() + var quote = statement.Engine.Quote var top string var mssqlCondi string - /*var orderBy string - if statement.OrderStr != "" { - orderBy = fmt.Sprintf(" ORDER BY %v", statement.OrderStr) - }*/ + statement.processIdParam() - var whereStr string - if statement.WhereStr != "" { - whereStr = fmt.Sprintf(" WHERE %v", statement.WhereStr) - if statement.ConditionStr != "" { - whereStr = fmt.Sprintf("%v %s %v", whereStr, statement.Engine.Dialect().AndStr(), - statement.ConditionStr) - } - } else if statement.ConditionStr != "" { - whereStr = fmt.Sprintf(" WHERE %v", statement.ConditionStr) - } - var fromStr string = " FROM " + statement.Engine.Quote(statement.TableName()) - if statement.TableAlias != "" { - if statement.Engine.dialect.DBType() == core.ORACLE { - fromStr += " " + statement.Engine.Quote(statement.TableAlias) + + var buf bytes.Buffer + if len(statement.WhereStr) > 0 { + if len(statement.ConditionStr) > 0 { + fmt.Fprintf(&buf, " WHERE (%v)", statement.WhereStr) } else { - fromStr += " AS " + statement.Engine.Quote(statement.TableAlias) + fmt.Fprintf(&buf, " WHERE %v", statement.WhereStr) + } + if statement.ConditionStr != "" { + fmt.Fprintf(&buf, " %s (%v)", dialect.AndStr(), statement.ConditionStr) + } + } else if len(statement.ConditionStr) > 0 { + fmt.Fprintf(&buf, " WHERE %v", statement.ConditionStr) + } + var whereStr = buf.String() + + var fromStr string = " FROM " + quote(statement.TableName()) + if statement.TableAlias != "" { + if dialect.DBType() == core.ORACLE { + fromStr += " " + quote(statement.TableAlias) + } else { + fromStr += " AS " + quote(statement.TableAlias) } } if statement.JoinStr != "" { fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr) } - if statement.Engine.dialect.DBType() == core.MSSQL { + if dialect.DBType() == core.MSSQL { if statement.LimitN > 0 { top = fmt.Sprintf(" TOP %d ", statement.LimitN) } @@ -1277,10 +1284,9 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { } // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern - a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, - fromStr, whereStr) - if mssqlCondi != "" { - if whereStr != "" { + a = fmt.Sprintf("SELECT %v%v%v%v%v", top, distinct, columnStr, fromStr, whereStr) + if len(mssqlCondi) > 0 { + if len(whereStr) > 0 { a += " AND " + mssqlCondi } else { a += " WHERE " + mssqlCondi @@ -1296,17 +1302,20 @@ func (statement *Statement) genSelectSql(columnStr string) (a string) { if statement.OrderStr != "" { a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) } - if statement.Engine.dialect.DBType() != core.MSSQL && statement.Engine.dialect.DBType() != core.ORACLE { + if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if statement.Start > 0 { a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) } else if statement.LimitN > 0 { a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) } - } else if statement.Engine.dialect.DBType() == core.ORACLE { + } else if dialect.DBType() == core.ORACLE { if statement.Start != 0 || statement.LimitN != 0 { a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) } } + if statement.IsForUpdate { + a = dialect.ForUpdateSql(a) + } return } diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/syslogger.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/syslogger.go index 5c78fecb0d2..eff69d11194 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/syslogger.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/syslogger.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + // +build !windows,!nacl,!plan9 package xorm diff --git a/Godeps/_workspace/src/github.com/go-xorm/xorm/xorm.go b/Godeps/_workspace/src/github.com/go-xorm/xorm/xorm.go index 71644e6c039..2da9949ebc2 100644 --- a/Godeps/_workspace/src/github.com/go-xorm/xorm/xorm.go +++ b/Godeps/_workspace/src/github.com/go-xorm/xorm/xorm.go @@ -1,3 +1,7 @@ +// Copyright 2015 The Xorm Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + package xorm import ( @@ -13,7 +17,7 @@ import ( ) const ( - Version string = "0.4.2.0225" + Version string = "0.4.5.0204" ) func regDrvsNDialects() bool { @@ -35,7 +39,7 @@ func regDrvsNDialects() bool { for driverName, v := range providedDrvsNDialects { if driver := core.QueryDriver(driverName); driver == nil { core.RegisterDriver(driverName, v.getDriver()) - core.RegisterDialect(v.dbType, v.getDialect()) + core.RegisterDialect(v.dbType, v.getDialect) } } return true