diff --git a/crud.go b/crud.go new file mode 100644 index 000000000..9e29585ee --- /dev/null +++ b/crud.go @@ -0,0 +1,328 @@ +package sqlx + +import ( + "database/sql" + "encoding/json" + "errors" + "reflect" + "strings" + "time" +) + +type DBI interface { + NamedExec(query string, arg interface{}) (sql.Result, error) + NamedQueryRow(query string, arg interface{}) *Row + NamedQuery(query string, arg interface{}) (*Rows, error) + QueryRowx(query string, args ...interface{}) *Row + QueryRow(query string, args ...interface{}) *sql.Row + Queryx(query string, args ...interface{}) (*Rows, error) + ExecOne(query string, args ...interface{}) error + Rebind(query string) string + Exec(query string, args ...interface{}) (sql.Result, error) +} + +// Return the name of a Struct to tablename +func DefaultTableName(i interface{}) string { + return strings.ToLower(reflect.TypeOf(reflect.Indirect(reflect.ValueOf(i)).Interface()).Name()) +} + +type Helper struct { + DBI +} + +type StructTable interface { + TableName() string + Validate() error +} + +type SafeSelector map[string]interface{} + +// Expand expands the selector into a clause delimited by some space and a list of +// args to append into prepared statements +func Expand(s map[string]interface{}, spacer string) (string, []interface{}) { + args := []interface{}{} + cnt := 0 + query := "" + for key, value := range s { + query += key + query += "=?" + if cnt != len(s)-1 { + query += spacer + } + args = append(args, value) + cnt += 1 + } + return query, args +} + +// Extract takes in a struct object and extracts out the mapping +func Extract(obj StructTable) (map[string]interface{}, error) { + // Validate the schema. + if err := obj.Validate(); err != nil { + return nil, err + } + baseType := reflect.TypeOf(obj) // eg. Parameter + items := map[string]interface{}{} + for i := 0; i < baseType.NumField(); i++ { + fieldName := baseType.Field(i).Name // eg. "Torsion" + possiblyPtr := reflect.ValueOf(obj).FieldByName(fieldName) + // possiblyPtr could also be a struct or pointer + if possiblyPtr.Kind() == reflect.Struct { + subMap, err := Extract(possiblyPtr.Interface().(StructTable)) + if err != nil { + return nil, err + } + for k, v := range subMap { + items[k] = v + } + continue + } + if possiblyPtr.IsNil() { + // pass + } else { + // we are not a nil pointer, then indirect would always work. + fieldValue := reflect.Indirect(possiblyPtr) + concreteValue := fieldValue.Interface() + dbName, _ := parseTag(baseType.Field(i).Tag.Get("json")) + // if tagOptions.Contains("nonzero") && isZeroValue(fieldValue) { + // return nil, errors.New("Zero value found for tagged nonzero field:" + fieldName) + // } + switch item := concreteValue.(type) { + default: + items[dbName] = concreteValue + // dbVals = append(dbVals, ":"+dbName) + case time.Time: + if item.IsZero() { + items[dbName] = "NOW" + } else { + items[dbName] = concreteValue + } + } + } + } + return items, nil +} + +func LookupTag(obj StructTable, field string) string { + b, ok := reflect.TypeOf(obj).FieldByName(field) + if !ok { + return "" + } + tagName, _ := parseTag(b.Tag.Get("json")) + return tagName +} + +/* START ripped from unexported std lib END */ +type tagOptions string + +func parseTag(tag string) (string, tagOptions) { + if idx := strings.Index(tag, ","); idx != -1 { + return tag[:idx], tagOptions(tag[idx+1:]) + } + return tag, tagOptions("") +} + +func (o tagOptions) Contains(optionName string) bool { + if len(o) == 0 { + return false + } + s := string(o) + for s != "" { + var next string + i := strings.Index(s, ",") + if i >= 0 { + s, next = s[:i], s[i+1:] + } + if s == optionName { + return true + } + s = next + } + return false +} + +/* END ripped from unexported std lib END */ + +// MsiToStruct takes in a JSON serializable map[string]interface{} and converts +// it the actual object +func JsonToStruct(input map[string]interface{}, s StructTable) error { + // YT: LOL + b, err := json.Marshal(input) + if err != nil { + return err + } + return json.Unmarshal(b, s) +} + +func MakeStructTable(input map[string]interface{}, obj StructTable) error { + base := reflect.Indirect(reflect.ValueOf(obj)) + baseType := reflect.TypeOf(base.Interface()) + for k, v := range input { + _, ok := baseType.FieldByName(k) + if !ok { + return errors.New("Bad input name: " + k) + } + fv := base.FieldByName(k) + ptr := reflect.New(reflect.TypeOf(v)) + reflect.Indirect(ptr).Set(reflect.ValueOf(v)) + fv.Set(ptr) + } + return nil +} + +// special insertion rules: +// if type is time.Time, and the value is a Zero Value, then CURRENT_TIMESTAMP will be inserted +// if type is a Pointer, and its indirected value is nil, then it is omitted. +func (h *Helper) CreateObject(obj StructTable) error { + msi, err := Extract(obj) + if err != nil { + return err + } + dbKeys := []string{} + dbVals := []interface{}{} + for k, v := range msi { + dbKeys = append(dbKeys, k) + dbVals = append(dbVals, v) + } + query := "INSERT INTO " + obj.TableName() + query += " (" + for idx, key := range dbKeys { + query += key + if idx != len(dbKeys)-1 { + query += "," + } + } + query += ") VALUES (" + for idx, _ := range dbKeys { + query += "?" + if idx != len(dbKeys)-1 { + query += "," + } + } + query += ")" + query = h.Rebind(query) + _, err = h.Exec(query, dbVals...) + return err +} + +// DeleteAll removes all rows in the table matching condition. +// If no matching row was deleted, then an error is returned. +func (h *Helper) DeleteAll(condition StructTable) error { + tableName := condition.TableName() + msi, err := Extract(condition) + if err != nil { + return err + } + query := "DELETE FROM " + tableName + query += " WHERE " + where, args := Expand(msi, " AND ") + query += where + query = h.Rebind(query) + res, err := h.Exec(query, args...) + if err != nil { + return err + } + cnt, err := res.RowsAffected() + if err != nil { + return err + } + if cnt == 0 { + return sql.ErrNoRows + } + return nil +} + +func (h *Helper) buildQuery(condition StructTable, projection []string) (string, []interface{}, error) { + tableName := condition.TableName() + query := "SELECT " + if len(projection) > 0 { + for idx, p := range projection { + query += p + if idx != len(projection)-1 { + query += "," + } + } + } else { + query += "*" + } + query += " FROM " + query += tableName + msi, err := Extract(condition) + if err != nil { + return "", nil, err + } + args := []interface{}{} + if len(msi) > 0 { + query += " WHERE " + var where string + where, args = Expand(msi, " AND ") + query += where + } + return query, args, nil +} + +// QueryOne returns a scanned object corresponding to the first row matching condition. For +// more complicated tasks such as pagination, etc. It's more sensible to build your own SQL. +// objPtr must be some pointer to a StructTable to receive the deserialized value. Projection +// should be json tags. +func (h *Helper) QueryOne(condition StructTable, objPtr StructTable, projection ...string) error { + query, args, err := h.buildQuery(condition, projection) + if err != nil { + return err + } + query += " LIMIT 1" + query = h.Rebind(query) + return h.QueryRowx(query, args...).StructScan(objPtr) +} + +// QueryRows returns a pointer to a sql.Rows object that can iterated over and scanned. Projection +// should be json tags. +func (h *Helper) QueryRows(condition StructTable, projection ...string) (*Rows, error) { + query, args, err := h.buildQuery(condition, projection) + if err != nil { + return nil, err + } + query = h.Rebind(query) + return h.Queryx(query, args...) +} + +// UpdateAll updates rows matching condition with new values given by updates. +// If no matching row was updated, then an error is returned. +func (h *Helper) UpdateAll(update StructTable, condition StructTable) error { + tableName := update.TableName() + msi1, err := Extract(update) + if err != nil { + return err + } + if len(msi1) == 0 { + // nothing to update, all nil + return nil + } + msi2, err := Extract(condition) + if err != nil { + return err + } + query := "UPDATE " + tableName + " SET " + expansion, args := Expand(msi1, ",") + query += expansion + // all_args := append(args + if len(msi2) > 0 { + query += " WHERE " + expansion2, args2 := Expand(msi2, " AND ") + query += expansion2 + args = append(args, args2...) + } + query = h.Rebind(query) + res, err := h.Exec(query, args...) + if err != nil { + return err + } + cnt, err := res.RowsAffected() + if err != nil { + return err + } + if cnt == 0 { + return errors.New("No row was updated.") + } + return nil +} diff --git a/sqlx.go b/sqlx.go index 05dd8b13f..2aa794ec8 100644 --- a/sqlx.go +++ b/sqlx.go @@ -33,10 +33,10 @@ var mpr *reflectx.Mapper // mapper returns a valid mapper using the configured NameMapper func. func mapper() *reflectx.Mapper { if mpr == nil { - mpr = reflectx.NewMapperFunc("db", NameMapper) + mpr = reflectx.NewMapperFunc("json", NameMapper) } else if origMapper != reflect.ValueOf(NameMapper) { // if NameMapper has changed, create a new mapper - mpr = reflectx.NewMapperFunc("db", NameMapper) + mpr = reflectx.NewMapperFunc("json", NameMapper) origMapper = reflect.ValueOf(NameMapper) } return mpr @@ -258,10 +258,25 @@ func MustOpen(driverName, dataSourceName string) *DB { return db } +func ExecOne(ex Execer, query string, args ...interface{}) error { + rows, err := ex.Exec(query, args...) + if err != nil { + return err + } + count, err := rows.RowsAffected() + if err != nil { + return err + } + if count != 1 { + return sql.ErrNoRows + } + return nil +} + // MapperFunc sets a new mapper for this db using the default sqlx struct tag // and the provided mapper function. func (db *DB) MapperFunc(mf func(string) string) { - db.Mapper = reflectx.NewMapperFunc("db", mf) + db.Mapper = reflectx.NewMapperFunc("json", mf) } // Rebind transforms a query from QUESTION to the DB driver's bindvar type. @@ -287,6 +302,12 @@ func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(db, query, arg) } +// NamedQueryRow using this DB. +func (db *DB) NamedQueryRow(query string, arg interface{}) *Row { + rows, err := NamedQuery(db, query, arg) + return &Row{rows: rows.Rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper} +} + // NamedExec using this DB. func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(db, query, arg) @@ -351,6 +372,11 @@ func (db *DB) PrepareNamed(query string) (*NamedStmt, error) { return prepareNamed(db, query) } +// ExecOne expects exactly one row to be affected. +func (db *DB) ExecOne(query string, args ...interface{}) error { + return ExecOne(db, query, args...) +} + // Tx is an sqlx wrapper around sql.Tx with extra functionality type Tx struct { *sql.Tx @@ -385,6 +411,12 @@ func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) { return NamedQuery(tx, query, arg) } +// NamedQueryRow within a transaction. +func (tx *Tx) NamedQueryRow(query string, arg interface{}) *Row { + rows, err := NamedQuery(tx, query, arg) + return &Row{rows: rows.Rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper} +} + // NamedExec a named query within a transaction. func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) { return NamedExec(tx, query, arg) @@ -395,6 +427,11 @@ func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error return Select(tx, dest, query, args...) } +// ExecOne expects exactly one row to be affected. +func (tx *Tx) ExecOne(query string, args ...interface{}) error { + return ExecOne(tx, query, args...) +} + // Queryx within a transaction. func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) { r, err := tx.Tx.Query(query, args...) diff --git a/sqlx_test.go b/sqlx_test.go index c1a125bac..2d3ef0820 100644 --- a/sqlx_test.go +++ b/sqlx_test.go @@ -685,6 +685,32 @@ func TestNamedQuery(t *testing.T) { check(t, rows) + row := db.NamedQueryRow(pdb(` + SELECT * FROM jsperson + WHERE + "FIRST"=:FIRST AND + last_name=:last_name AND + "EMAIL"=:EMAIL + `, db), jp) + + jp2 := JSONPerson{} + err = row.StructScan(&jp2) + if err != nil { + t.Error(err) + } + + if jp2.Email != jp.Email { + t.Errorf("Email Mismatch %s", db.DriverName()) + } + + if jp2.FirstName != jp.FirstName { + t.Errorf("FirstName Mismatch %s", db.DriverName()) + } + + if jp2.LastName != jp.LastName { + t.Errorf("LastName Mismatch %s", db.DriverName()) + } + db.Mapper = &old })