158 lines
3.2 KiB
Go
158 lines
3.2 KiB
Go
package dbr
|
|
|
|
import (
|
|
"database/sql/driver"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type interpolator struct {
|
|
Buffer
|
|
Dialect
|
|
IgnoreBinary bool
|
|
N int
|
|
}
|
|
|
|
// InterpolateForDialect replaces placeholder in query with corresponding value in dialect
|
|
func InterpolateForDialect(query string, value []interface{}, d Dialect) (string, error) {
|
|
i := interpolator{
|
|
Buffer: NewBuffer(),
|
|
Dialect: d,
|
|
}
|
|
err := i.interpolate(query, value)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return i.String(), nil
|
|
}
|
|
|
|
func (i *interpolator) interpolate(query string, value []interface{}) error {
|
|
if strings.Count(query, placeholder) != len(value) {
|
|
return ErrPlaceholderCount
|
|
}
|
|
|
|
valueIndex := 0
|
|
|
|
for {
|
|
index := strings.Index(query, placeholder)
|
|
if index == -1 {
|
|
break
|
|
}
|
|
|
|
i.WriteString(query[:index])
|
|
if _, ok := value[valueIndex].([]byte); ok && i.IgnoreBinary {
|
|
i.WriteString(i.Placeholder(i.N))
|
|
i.N++
|
|
i.WriteValue(value[valueIndex])
|
|
} else {
|
|
err := i.encodePlaceholder(value[valueIndex])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
query = query[index+len(placeholder):]
|
|
valueIndex++
|
|
}
|
|
|
|
// placeholder not found; write remaining query
|
|
i.WriteString(query)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (i *interpolator) encodePlaceholder(value interface{}) error {
|
|
if builder, ok := value.(Builder); ok {
|
|
pbuf := NewBuffer()
|
|
err := builder.Build(i.Dialect, pbuf)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
paren := true
|
|
switch value.(type) {
|
|
case *SelectStmt:
|
|
case *union:
|
|
default:
|
|
paren = false
|
|
}
|
|
if paren {
|
|
i.WriteString("(")
|
|
}
|
|
err = i.interpolate(pbuf.String(), pbuf.Value())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if paren {
|
|
i.WriteString(")")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if valuer, ok := value.(driver.Valuer); ok {
|
|
// get driver.Valuer's data
|
|
var err error
|
|
value, err = valuer.Value()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if value == nil {
|
|
i.WriteString("NULL")
|
|
return nil
|
|
}
|
|
v := reflect.ValueOf(value)
|
|
switch v.Kind() {
|
|
case reflect.String:
|
|
i.WriteString(i.EncodeString(v.String()))
|
|
return nil
|
|
case reflect.Bool:
|
|
i.WriteString(i.EncodeBool(v.Bool()))
|
|
return nil
|
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
i.WriteString(strconv.FormatInt(v.Int(), 10))
|
|
return nil
|
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
i.WriteString(strconv.FormatUint(v.Uint(), 10))
|
|
return nil
|
|
case reflect.Float32, reflect.Float64:
|
|
i.WriteString(strconv.FormatFloat(v.Float(), 'f', -1, 64))
|
|
return nil
|
|
case reflect.Struct:
|
|
if v.Type() == reflect.TypeOf(time.Time{}) {
|
|
i.WriteString(i.EncodeTime(v.Interface().(time.Time)))
|
|
return nil
|
|
}
|
|
case reflect.Slice:
|
|
if v.Type().Elem().Kind() == reflect.Uint8 {
|
|
// []byte
|
|
i.WriteString(i.EncodeBytes(v.Bytes()))
|
|
return nil
|
|
}
|
|
if v.Len() == 0 {
|
|
// FIXME: support zero-length slice
|
|
return ErrInvalidSliceLength
|
|
}
|
|
i.WriteString("(")
|
|
for n := 0; n < v.Len(); n++ {
|
|
if n > 0 {
|
|
i.WriteString(",")
|
|
}
|
|
err := i.encodePlaceholder(v.Index(n).Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
i.WriteString(")")
|
|
return nil
|
|
case reflect.Ptr:
|
|
if v.IsNil() {
|
|
i.WriteString("NULL")
|
|
return nil
|
|
}
|
|
return i.encodePlaceholder(v.Elem().Interface())
|
|
}
|
|
return ErrNotSupported
|
|
}
|