Update dependencies (#5518)

This commit is contained in:
hongming
2023-02-12 23:09:20 +08:00
committed by GitHub
parent d3b35fb2da
commit a979342f56
1486 changed files with 126660 additions and 71128 deletions

View File

@@ -10,6 +10,7 @@ import (
"syscall"
"github.com/Azure/go-ansiterm"
windows "golang.org/x/sys/windows"
)
// Windows keyboard constants
@@ -162,15 +163,28 @@ func ensureInRange(n int16, min int16, max int16) int16 {
func GetStdFile(nFile int) (*os.File, uintptr) {
var file *os.File
switch nFile {
case syscall.STD_INPUT_HANDLE:
// syscall uses negative numbers
// windows package uses very big uint32
// Keep these switches split so we don't have to convert ints too much.
switch uint32(nFile) {
case windows.STD_INPUT_HANDLE:
file = os.Stdin
case syscall.STD_OUTPUT_HANDLE:
case windows.STD_OUTPUT_HANDLE:
file = os.Stdout
case syscall.STD_ERROR_HANDLE:
case windows.STD_ERROR_HANDLE:
file = os.Stderr
default:
panic(fmt.Errorf("Invalid standard handle identifier: %v", nFile))
switch nFile {
case syscall.STD_INPUT_HANDLE:
file = os.Stdin
case syscall.STD_OUTPUT_HANDLE:
file = os.Stdout
case syscall.STD_ERROR_HANDLE:
file = os.Stderr
default:
panic(fmt.Errorf("Invalid standard handle identifier: %v", nFile))
}
}
fd, err := syscall.GetStdHandle(nFile)

View File

@@ -1,5 +1,2 @@
TAGS
tags
.*.swp
tomlcheck/tomlcheck
toml.test
/toml-test

View File

@@ -1,15 +0,0 @@
language: go
go:
- 1.1
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- tip
install:
- go install ./...
- go get github.com/BurntSushi/toml-test
script:
- export PATH="$PATH:$HOME/gopath/bin"
- make test

View File

@@ -1,3 +1 @@
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/v0.4.0/versions/en/toml-v0.4.0.md)
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0).

View File

@@ -1,19 +0,0 @@
install:
go install ./...
test: install
go test -v
toml-test toml-test-decoder
toml-test -encoder toml-test-encoder
fmt:
gofmt -w *.go */*.go
colcheck *.go */*.go
tags:
find ./ -name '*.go' -print0 | xargs -0 gotags > TAGS
push:
git push origin master
git push github master

View File

@@ -1,46 +1,36 @@
## TOML parser and encoder for Go with reflection
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a
reflection interface similar to Go's standard library `json` and `xml`
packages. This package also supports the `encoding.TextUnmarshaler` and
`encoding.TextMarshaler` interfaces so that you can define custom data
representations. (There is an example of this below.)
packages.
Spec: https://github.com/toml-lang/toml
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0).
Compatible with TOML version
[v0.4.0](https://github.com/toml-lang/toml/blob/master/versions/en/toml-v0.4.0.md)
Documentation: https://godocs.io/github.com/BurntSushi/toml
Documentation: https://godoc.org/github.com/BurntSushi/toml
See the [releases page](https://github.com/BurntSushi/toml/releases) for a
changelog; this information is also in the git tag annotations (e.g. `git show
v0.4.0`).
Installation:
This library requires Go 1.13 or newer; install it with:
```bash
go get github.com/BurntSushi/toml
```
% go get github.com/BurntSushi/toml@latest
Try the toml validator:
It also comes with a TOML validator CLI tool:
```bash
go get github.com/BurntSushi/toml/cmd/tomlv
tomlv some-toml-file.toml
```
[![Build Status](https://travis-ci.org/BurntSushi/toml.svg?branch=master)](https://travis-ci.org/BurntSushi/toml) [![GoDoc](https://godoc.org/github.com/BurntSushi/toml?status.svg)](https://godoc.org/github.com/BurntSushi/toml)
% go install github.com/BurntSushi/toml/cmd/tomlv@latest
% tomlv some-toml-file.toml
### Testing
This package passes all tests in [toml-test] for both the decoder and the
encoder.
This package passes all tests in
[toml-test](https://github.com/BurntSushi/toml-test) for both the decoder
and the encoder.
[toml-test]: https://github.com/BurntSushi/toml-test
### Examples
This package works similar to how the Go standard library handles XML and JSON.
Namely, data is loaded into Go values via reflection.
This package works similarly to how the Go standard library handles `XML`
and `JSON`. Namely, data is loaded into Go values via reflection.
For the simplest example, consider some TOML file as just a list of keys
and values:
For the simplest example, consider some TOML file as just a list of keys and
values:
```toml
Age = 25
@@ -54,11 +44,11 @@ Which could be defined in Go as:
```go
type Config struct {
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
Age int
Cats []string
Pi float64
Perfection []int
DOB time.Time // requires `import time`
}
```
@@ -66,9 +56,8 @@ And then decoded with:
```go
var conf Config
if _, err := toml.Decode(tomlData, &conf); err != nil {
// handle error
}
_, err := toml.Decode(tomlData, &conf)
// handle error
```
You can also use struct tags if your struct field name doesn't map to a TOML
@@ -80,12 +69,14 @@ some_key_NAME = "wat"
```go
type TOML struct {
ObscureKey string `toml:"some_key_NAME"`
ObscureKey string `toml:"some_key_NAME"`
}
```
### Using the `encoding.TextUnmarshaler` interface
Beware that like other most other decoders **only exported fields** are
considered when encoding and decoding; private fields are silently ignored.
### Using the `Marshaler` and `encoding.TextUnmarshaler` interfaces
Here's an example that automatically parses duration strings into
`time.Duration` values:
@@ -103,19 +94,19 @@ Which can be decoded with:
```go
type song struct {
Name string
Duration duration
Name string
Duration duration
}
type songs struct {
Song []song
Song []song
}
var favorites songs
if _, err := toml.Decode(blob, &favorites); err != nil {
log.Fatal(err)
log.Fatal(err)
}
for _, s := range favorites.Song {
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
fmt.Printf("%s (%s)\n", s.Name, s.Duration)
}
```
@@ -134,8 +125,10 @@ func (d *duration) UnmarshalText(text []byte) error {
}
```
### More complex usage
To target TOML specifically you can implement `UnmarshalTOML` TOML interface in
a similar way.
### More complex usage
Here's an example of how to load the example from the official spec page:
```toml
@@ -180,23 +173,23 @@ And the corresponding Go types are:
```go
type tomlConfig struct {
Title string
Owner ownerInfo
DB database `toml:"database"`
Title string
Owner ownerInfo
DB database `toml:"database"`
Servers map[string]server
Clients clients
}
type ownerInfo struct {
Name string
Org string `toml:"organization"`
Bio string
DOB time.Time
Org string `toml:"organization"`
Bio string
DOB time.Time
}
type database struct {
Server string
Ports []int
Server string
Ports []int
ConnMax int `toml:"connection_max"`
Enabled bool
}
@@ -207,7 +200,7 @@ type server struct {
}
type clients struct {
Data [][]interface{}
Data [][]interface{}
Hosts []string
}
```
@@ -215,4 +208,4 @@ type clients struct {
Note that a case insensitive match will be tried if an exact match can't be
found.
A working example of the above can be found in `_examples/example.{go,toml}`.
A working example of the above can be found in `_example/example.{go,toml}`.

View File

@@ -1,19 +1,17 @@
package toml
import (
"bytes"
"encoding"
"fmt"
"io"
"io/ioutil"
"math"
"os"
"reflect"
"strings"
"time"
)
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
@@ -21,34 +19,145 @@ type Unmarshaler interface {
}
// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
func Unmarshal(p []byte, v interface{}) error {
_, err := Decode(string(p), v)
func Unmarshal(data []byte, v interface{}) error {
_, err := NewDecoder(bytes.NewReader(data)).Decode(v)
return err
}
// Decode the TOML data in to the pointer v.
//
// See the documentation on Decoder for a description of the decoding process.
func Decode(data string, v interface{}) (MetaData, error) {
return NewDecoder(strings.NewReader(data)).Decode(v)
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at path and decode it for you.
func DecodeFile(path string, v interface{}) (MetaData, error) {
fp, err := os.Open(path)
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
// When using the various `Decode*` functions, the type `Primitive` may
// be given to any value, and its decoding will be delayed.
//
// A `Primitive` value can be decoded using the `PrimitiveDecode` function.
// This type can be used for any value, which will cause decoding to be delayed.
// You can use the PrimitiveDecode() function to "manually" decode these values.
//
// The underlying representation of a `Primitive` value is subject to change.
// Do not rely on it.
// NOTE: The underlying representation of a `Primitive` value is subject to
// change. Do not rely on it.
//
// N.B. Primitive values are still parsed, so using them will only avoid
// the overhead of reflection. They can be useful when you don't know the
// exact type of TOML data until run time.
// NOTE: Primitive values are still parsed, so using them will only avoid the
// overhead of reflection. They can be useful when you don't know the exact type
// of TOML data until runtime.
type Primitive struct {
undecoded interface{}
context Key
}
// DEPRECATED!
// The significand precision for float32 and float64 is 24 and 53 bits; this is
// the range a natural number can be stored in a float without loss of data.
const (
maxSafeFloat32Int = 16777215 // 2^24-1
maxSafeFloat64Int = int64(9007199254740991) // 2^53-1
)
// Decoder decodes TOML data.
//
// Use MetaData.PrimitiveDecode instead.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]bool)}
return md.unify(primValue.undecoded, rvalue(v))
// TOML tables correspond to Go structs or maps (dealer's choice they can be
// used interchangeably).
//
// TOML table arrays correspond to either a slice of structs or a slice of maps.
//
// TOML datetimes correspond to Go time.Time values. Local datetimes are parsed
// in the local timezone.
//
// All other TOML types (float, string, int, bool and array) correspond to the
// obvious Go types.
//
// An exception to the above rules is if a type implements the TextUnmarshaler
// interface, in which case any primitive TOML value (floats, strings, integers,
// booleans, datetimes) will be converted to a []byte and given to the value's
// UnmarshalText method. See the Unmarshaler example for a demonstration with
// time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go struct.
// The special `toml` struct tag can be used to map TOML keys to struct fields
// that don't match the key name exactly (see the example). A case insensitive
// match to struct names will be tried if an exact match can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there may
// exist TOML values that cannot be placed into your representation, and there
// may be parts of your representation that do not correspond to TOML values.
// This loose mapping can be made stricter by using the IsDefined and/or
// Undecoded methods on the MetaData returned.
//
// This decoder does not handle cyclic types. Decode will not terminate if a
// cyclic type is passed.
type Decoder struct {
r io.Reader
}
// NewDecoder creates a new Decoder.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
var (
unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)
// Decode TOML data in to the pointer `v`.
func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
s := "%q"
if reflect.TypeOf(v) == nil {
s = "%v"
}
return MetaData{}, e("cannot decode to non-pointer "+s, reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("cannot decode to nil value of %q", reflect.TypeOf(v))
}
// Check if this is a supported type: struct, map, interface{}, or something
// that implements UnmarshalTOML or UnmarshalText.
rv = indirect(rv)
rt := rv.Type()
if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
return MetaData{}, e("cannot decode to type %s", rt)
}
// TODO: parser should read from io.Reader? Or at the very least, make it
// read from []byte rather than string
data, err := ioutil.ReadAll(dec.r)
if err != nil {
return MetaData{}, err
}
p, err := parse(string(data))
if err != nil {
return MetaData{}, err
}
md := MetaData{
mapping: p.mapping,
types: p.types,
keys: p.ordered,
decoded: make(map[string]struct{}, len(p.ordered)),
context: nil,
}
return md, md.unify(p.mapping, rv)
}
// PrimitiveDecode is just like the other `Decode*` functions, except it
@@ -68,89 +177,14 @@ func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
return md.unify(primValue.undecoded, rvalue(v))
}
// Decode will decode the contents of `data` in TOML format into a pointer
// `v`.
//
// TOML hashes correspond to Go structs or maps. (Dealer's choice. They can be
// used interchangeably.)
//
// TOML arrays of tables correspond to either a slice of structs or a slice
// of maps.
//
// TOML datetimes correspond to Go `time.Time` values.
//
// All other TOML types (float, string, int, bool and array) correspond
// to the obvious Go types.
//
// An exception to the above rules is if a type implements the
// encoding.TextUnmarshaler interface. In this case, any primitive TOML value
// (floats, strings, integers, booleans and datetimes) will be converted to
// a byte string and given to the value's UnmarshalText method. See the
// Unmarshaler example for a demonstration with time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go
// struct. The special `toml` struct tag may be used to map TOML keys to
// struct fields that don't match the key name exactly. (See the example.)
// A case insensitive match to struct names will be tried if an exact match
// can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there
// may exist TOML values that cannot be placed into your representation, and
// there may be parts of your representation that do not correspond to
// TOML values. This loose mapping can be made stricter by using the IsDefined
// and/or Undecoded methods on the MetaData returned.
//
// This decoder will not handle cyclic types. If a cyclic type is passed,
// `Decode` will not terminate.
func Decode(data string, v interface{}) (MetaData, error) {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return MetaData{}, e("Decode of non-pointer %s", reflect.TypeOf(v))
}
if rv.IsNil() {
return MetaData{}, e("Decode of nil %s", reflect.TypeOf(v))
}
p, err := parse(data)
if err != nil {
return MetaData{}, err
}
md := MetaData{
p.mapping, p.types, p.ordered,
make(map[string]bool, len(p.ordered)), nil,
}
return md, md.unify(p.mapping, indirect(rv))
}
// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at `fpath` and decode it for you.
func DecodeFile(fpath string, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadFile(fpath)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// DecodeReader is just like Decode, except it will consume all bytes
// from the reader and decode it for you.
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) {
bs, err := ioutil.ReadAll(r)
if err != nil {
return MetaData{}, err
}
return Decode(string(bs), v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
// Special case. Look for a `Primitive` value.
// TODO: #76 would make this superfluous after implemented.
if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
// Save the undecoded data and the key context into the primitive
// value.
@@ -170,25 +204,17 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
}
}
// Special case. Handle time.Time values specifically.
// TODO: Remove this code when we decide to drop support for Go 1.1.
// This isn't necessary in Go 1.2 because time.Time satisfies the encoding
// interfaces.
if rv.Type().AssignableTo(rvalue(time.Time{}).Type()) {
return md.unifyDatetime(data, rv)
}
// Special case. Look for a value satisfying the TextUnmarshaler interface.
if v, ok := rv.Interface().(TextUnmarshaler); ok {
if v, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
return md.unifyText(data, v)
}
// BUG(burntsushi)
// TODO:
// The behavior here is incorrect whenever a Go type satisfies the
// encoding.TextUnmarshaler interface but also corresponds to a TOML
// hash or array. In particular, the unmarshaler should only be applied
// to primitive TOML values. But at this point, it will be applied to
// all kinds of values and produce an incorrect error whenever those values
// are hashes or arrays (including arrays of tables).
// encoding.TextUnmarshaler interface but also corresponds to a TOML hash or
// array. In particular, the unmarshaler should only be applied to primitive
// TOML values. But at this point, it will be applied to all kinds of values
// and produce an incorrect error whenever those values are hashes or arrays
// (including arrays of tables).
k := rv.Kind()
@@ -223,9 +249,7 @@ func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
return e("unsupported type %s", rv.Type())
}
return md.unifyAnything(data, rv)
case reflect.Float32:
fallthrough
case reflect.Float64:
case reflect.Float32, reflect.Float64:
return md.unifyFloat64(data, rv)
}
return e("unsupported type %s", rv.Kind())
@@ -259,17 +283,17 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
for _, i := range f.index {
subv = indirect(subv.Field(i))
}
if isUnifiable(subv) {
md.decoded[md.context.add(key).String()] = true
md.decoded[md.context.add(key).String()] = struct{}{}
md.context = append(md.context, key)
if err := md.unify(datum, subv); err != nil {
err := md.unify(datum, subv)
if err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
} else if f.name != "" {
// Bad user! No soup for you!
return e("cannot write unexported field %s.%s",
rv.Type().String(), f.name)
return e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
}
}
}
@@ -277,27 +301,33 @@ func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error {
if k := rv.Type().Key().Kind(); k != reflect.String {
return fmt.Errorf(
"toml: cannot decode to a map with non-string key type (%s in %q)",
k, rv.Type())
}
tmap, ok := mapping.(map[string]interface{})
if !ok {
if tmap == nil {
return nil
}
return badtype("map", mapping)
return md.badtype("map", mapping)
}
if rv.IsNil() {
rv.Set(reflect.MakeMap(rv.Type()))
}
for k, v := range tmap {
md.decoded[md.context.add(k).String()] = true
md.decoded[md.context.add(k).String()] = struct{}{}
md.context = append(md.context, k)
rvkey := indirect(reflect.New(rv.Type().Key()))
rvval := reflect.Indirect(reflect.New(rv.Type().Elem()))
if err := md.unify(v, rvval); err != nil {
return err
}
md.context = md.context[0 : len(md.context)-1]
rvkey := indirect(reflect.New(rv.Type().Key()))
rvkey.SetString(k)
rv.SetMapIndex(rvkey, rvval)
}
@@ -310,12 +340,10 @@ func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
return md.badtype("slice", data)
}
sliceLen := datav.Len()
if sliceLen != rv.Len() {
return e("expected array length %d; got TOML array of length %d",
rv.Len(), sliceLen)
if l := datav.Len(); l != rv.Len() {
return e("expected array length %d; got TOML array of length %d", rv.Len(), l)
}
return md.unifySliceArray(datav, rv)
}
@@ -326,7 +354,7 @@ func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
if !datav.IsValid() {
return nil
}
return badtype("slice", data)
return md.badtype("slice", data)
}
n := datav.Len()
if rv.IsNil() || rv.Cap() < n {
@@ -337,37 +365,31 @@ func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error {
}
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error {
sliceLen := data.Len()
for i := 0; i < sliceLen; i++ {
v := data.Index(i).Interface()
sliceval := indirect(rv.Index(i))
if err := md.unify(v, sliceval); err != nil {
l := data.Len()
for i := 0; i < l; i++ {
err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i)))
if err != nil {
return err
}
}
return nil
}
func (md *MetaData) unifyDatetime(data interface{}, rv reflect.Value) error {
if _, ok := data.(time.Time); ok {
rv.Set(reflect.ValueOf(data))
return nil
}
return badtype("time.Time", data)
}
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error {
if s, ok := data.(string); ok {
rv.SetString(s)
return nil
}
return badtype("string", data)
return md.badtype("string", data)
}
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
if num, ok := data.(float64); ok {
switch rv.Kind() {
case reflect.Float32:
if num < -math.MaxFloat32 || num > math.MaxFloat32 {
return e("value %f is out of range for float32", num)
}
fallthrough
case reflect.Float64:
rv.SetFloat(num)
@@ -376,7 +398,26 @@ func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error {
}
return nil
}
return badtype("float", data)
if num, ok := data.(int64); ok {
switch rv.Kind() {
case reflect.Float32:
if num < -maxSafeFloat32Int || num > maxSafeFloat32Int {
return e("value %d is out of range for float32", num)
}
fallthrough
case reflect.Float64:
if num < -maxSafeFloat64Int || num > maxSafeFloat64Int {
return e("value %d is out of range for float64", num)
}
rv.SetFloat(float64(num))
default:
panic("bug")
}
return nil
}
return md.badtype("float", data)
}
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
@@ -423,7 +464,7 @@ func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error {
}
return nil
}
return badtype("integer", data)
return md.badtype("integer", data)
}
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
@@ -431,7 +472,7 @@ func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error {
rv.SetBool(b)
return nil
}
return badtype("boolean", data)
return md.badtype("boolean", data)
}
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
@@ -439,9 +480,15 @@ func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error {
return nil
}
func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) error {
var s string
switch sdata := data.(type) {
case Marshaler:
text, err := sdata.MarshalTOML()
if err != nil {
return err
}
s = string(text)
case TextMarshaler:
text, err := sdata.MarshalText()
if err != nil {
@@ -459,7 +506,7 @@ func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
case float64:
s = fmt.Sprintf("%f", sdata)
default:
return badtype("primitive (string-like)", data)
return md.badtype("primitive (string-like)", data)
}
if err := v.UnmarshalText([]byte(s)); err != nil {
return err
@@ -467,22 +514,27 @@ func (md *MetaData) unifyText(data interface{}, v TextUnmarshaler) error {
return nil
}
func (md *MetaData) badtype(dst string, data interface{}) error {
return e("incompatible types: TOML key %q has type %T; destination has type %s", md.context, data, dst)
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved.
func rvalue(v interface{}) reflect.Value {
return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
// Pointers are followed until the value is not a pointer.
// New values are allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of
// interest to us (like encoding.TextUnmarshaler).
// Pointers are followed until the value is not a pointer. New values are
// allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of interest
// to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
if v.Kind() != reflect.Ptr {
if v.CanSet() {
pv := v.Addr()
if _, ok := pv.Interface().(TextUnmarshaler); ok {
if _, ok := pv.Interface().(encoding.TextUnmarshaler); ok {
return pv
}
}
@@ -498,12 +550,12 @@ func isUnifiable(rv reflect.Value) bool {
if rv.CanSet() {
return true
}
if _, ok := rv.Interface().(TextUnmarshaler); ok {
if _, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
return true
}
return false
}
func badtype(expected string, data interface{}) error {
return e("cannot load TOML value of type %T into a Go %s", data, expected)
func e(format string, args ...interface{}) error {
return fmt.Errorf("toml: "+format, args...)
}

19
vendor/github.com/BurntSushi/toml/decode_go116.go generated vendored Normal file
View File

@@ -0,0 +1,19 @@
//go:build go1.16
// +build go1.16
package toml
import (
"io/fs"
)
// DecodeFS is just like Decode, except it will automatically read the contents
// of the file at `path` from a fs.FS instance.
func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) {
fp, err := fsys.Open(path)
if err != nil {
return MetaData{}, err
}
defer fp.Close()
return NewDecoder(fp).Decode(v)
}

21
vendor/github.com/BurntSushi/toml/deprecated.go generated vendored Normal file
View File

@@ -0,0 +1,21 @@
package toml
import (
"encoding"
"io"
)
// Deprecated: use encoding.TextMarshaler
type TextMarshaler encoding.TextMarshaler
// Deprecated: use encoding.TextUnmarshaler
type TextUnmarshaler encoding.TextUnmarshaler
// Deprecated: use MetaData.PrimitiveDecode.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
md := MetaData{decoded: make(map[string]struct{})}
return md.unify(primValue.undecoded, rvalue(v))
}
// Deprecated: use NewDecoder(reader).Decode(&value).
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { return NewDecoder(r).Decode(v) }

View File

@@ -1,27 +1,13 @@
/*
Package toml provides facilities for decoding and encoding TOML configuration
files via reflection. There is also support for delaying decoding with
the Primitive type, and querying the set of keys in a TOML document with the
MetaData type.
Package toml implements decoding and encoding of TOML files.
The specification implemented: https://github.com/toml-lang/toml
This package supports TOML v1.0.0, as listed on https://toml.io
The sub-command github.com/BurntSushi/toml/cmd/tomlv can be used to verify
whether a file is a valid TOML document. It can also be used to print the
type of each key in a TOML document.
There is also support for delaying decoding with the Primitive type, and
querying the set of keys in a TOML document with the MetaData type.
Testing
There are two important types of tests used for this package. The first is
contained inside '*_test.go' files and uses the standard Go unit testing
framework. These tests are primarily devoted to holistically testing the
decoder and encoder.
The second type of testing is used to verify the implementation's adherence
to the TOML specification. These tests have been factored into their own
project: https://github.com/BurntSushi/toml-test
The reason the tests are in a separate project is so that they can be used by
any implementation of TOML. Namely, it is language agnostic.
The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator,
and can be used to verify if TOML document is valid. It can also be used to
print the type of each key.
*/
package toml

View File

@@ -2,57 +2,106 @@ package toml
import (
"bufio"
"encoding"
"errors"
"fmt"
"io"
"math"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/BurntSushi/toml/internal"
)
type tomlEncodeError struct{ error }
var (
errArrayMixedElementTypes = errors.New(
"toml: cannot encode array with mixed element types")
errArrayNilElement = errors.New(
"toml: cannot encode array with nil element")
errNonString = errors.New(
"toml: cannot encode a map with non-string key type")
errAnonNonStruct = errors.New(
"toml: cannot encode an anonymous field that is not a struct")
errArrayNoTable = errors.New(
"toml: TOML array element cannot contain a table")
errNoKey = errors.New(
"toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
errArrayNilElement = errors.New("toml: cannot encode array with nil element")
errNonString = errors.New("toml: cannot encode a map with non-string key type")
errNoKey = errors.New("toml: top-level values must be Go maps or structs")
errAnything = errors.New("") // used in testing
)
var quotedReplacer = strings.NewReplacer(
"\t", "\\t",
"\n", "\\n",
"\r", "\\r",
var dblQuotedReplacer = strings.NewReplacer(
"\"", "\\\"",
"\\", "\\\\",
"\x00", `\u0000`,
"\x01", `\u0001`,
"\x02", `\u0002`,
"\x03", `\u0003`,
"\x04", `\u0004`,
"\x05", `\u0005`,
"\x06", `\u0006`,
"\x07", `\u0007`,
"\b", `\b`,
"\t", `\t`,
"\n", `\n`,
"\x0b", `\u000b`,
"\f", `\f`,
"\r", `\r`,
"\x0e", `\u000e`,
"\x0f", `\u000f`,
"\x10", `\u0010`,
"\x11", `\u0011`,
"\x12", `\u0012`,
"\x13", `\u0013`,
"\x14", `\u0014`,
"\x15", `\u0015`,
"\x16", `\u0016`,
"\x17", `\u0017`,
"\x18", `\u0018`,
"\x19", `\u0019`,
"\x1a", `\u001a`,
"\x1b", `\u001b`,
"\x1c", `\u001c`,
"\x1d", `\u001d`,
"\x1e", `\u001e`,
"\x1f", `\u001f`,
"\x7f", `\u007f`,
)
// Encoder controls the encoding of Go values to a TOML document to some
// io.Writer.
//
// The indentation level can be controlled with the Indent field.
type Encoder struct {
// A single indentation level. By default it is two spaces.
Indent string
// hasWritten is whether we have written any output to w yet.
hasWritten bool
w *bufio.Writer
// Marshaler is the interface implemented by types that can marshal themselves
// into valid TOML.
type Marshaler interface {
MarshalTOML() ([]byte, error)
}
// NewEncoder returns a TOML encoder that encodes Go values to the io.Writer
// given. By default, a single indentation level is 2 spaces.
// Encoder encodes a Go to a TOML document.
//
// The mapping between Go values and TOML values should be precisely the same as
// for the Decode* functions.
//
// The toml.Marshaler and encoder.TextMarshaler interfaces are supported to
// encoding the value as custom TOML.
//
// If you want to write arbitrary binary data then you will need to use
// something like base64 since TOML does not have any binary types.
//
// When encoding TOML hashes (Go maps or structs), keys without any sub-hashes
// are encoded first.
//
// Go maps will be sorted alphabetically by key for deterministic output.
//
// Encoding Go values without a corresponding TOML representation will return an
// error. Examples of this includes maps with non-string keys, slices with nil
// elements, embedded non-struct types, and nested slices containing maps or
// structs. (e.g. [][]map[string]string is not allowed but []map[string]string
// is okay, as is []map[string][]string).
//
// NOTE: only exported keys are encoded due to the use of reflection. Unexported
// keys are silently discarded.
type Encoder struct {
// String to use for a single indentation level; default is two spaces.
Indent string
w *bufio.Writer
hasWritten bool // written any output to w yet?
}
// NewEncoder create a new Encoder.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{
w: bufio.NewWriter(w),
@@ -60,29 +109,10 @@ func NewEncoder(w io.Writer) *Encoder {
}
}
// Encode writes a TOML representation of the Go value to the underlying
// io.Writer. If the value given cannot be encoded to a valid TOML document,
// then an error is returned.
// Encode writes a TOML representation of the Go value to the Encoder's writer.
//
// The mapping between Go values and TOML values should be precisely the same
// as for the Decode* functions. Similarly, the TextMarshaler interface is
// supported by encoding the resulting bytes as strings. (If you want to write
// arbitrary binary data then you will need to use something like base64 since
// TOML does not have any binary types.)
//
// When encoding TOML hashes (i.e., Go maps or structs), keys without any
// sub-hashes are encoded first.
//
// If a Go map is encoded, then its keys are sorted alphabetically for
// deterministic output. More control over this behavior may be provided if
// there is demand for it.
//
// Encoding Go values without a corresponding TOML representation---like map
// types with non-string keys---will cause an error to be returned. Similarly
// for mixed arrays/slices, arrays/slices with nil elements, embedded
// non-struct types and nested slices containing maps or structs.
// (e.g., [][]map[string]string is not allowed but []map[string]string is OK
// and so is []map[string][]string.)
// An error is returned if the value given cannot be encoded to a valid TOML
// document.
func (enc *Encoder) Encode(v interface{}) error {
rv := eindirect(reflect.ValueOf(v))
if err := enc.safeEncode(Key([]string{}), rv); err != nil {
@@ -106,13 +136,18 @@ func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
}
func (enc *Encoder) encode(key Key, rv reflect.Value) {
// Special case. Time needs to be in ISO8601 format.
// Special case. If we can marshal the type to text, then we used that.
// Basically, this prevents the encoder for handling these types as
// generic structs (or whatever the underlying type of a TextMarshaler is).
switch rv.Interface().(type) {
case time.Time, TextMarshaler:
enc.keyEqElement(key, rv)
// Special case: time needs to be in ISO8601 format.
//
// Special case: if we can marshal the type to text, then we used that. This
// prevents the encoder for handling these types as generic structs (or
// whatever the underlying type of a TextMarshaler is).
switch t := rv.Interface().(type) {
case time.Time, encoding.TextMarshaler, Marshaler:
enc.writeKeyValue(key, rv, false)
return
// TODO: #76 would make this superfluous after implemented.
case Primitive:
enc.encode(key, reflect.ValueOf(t.undecoded))
return
}
@@ -123,12 +158,12 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.keyEqElement(key, rv)
enc.writeKeyValue(key, rv, false)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.keyEqElement(key, rv)
enc.writeKeyValue(key, rv, false)
}
case reflect.Interface:
if rv.IsNil() {
@@ -148,55 +183,88 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
case reflect.Struct:
enc.eTable(key, rv)
default:
panic(e("unsupported type for key '%s': %s", key, k))
encPanic(fmt.Errorf("unsupported type for key '%s': %s", key, k))
}
}
// eElement encodes any value that can be an array element (primitives and
// arrays).
// eElement encodes any value that can be an array element.
func (enc *Encoder) eElement(rv reflect.Value) {
switch v := rv.Interface().(type) {
case time.Time:
// Special case time.Time as a primitive. Has to come before
// TextMarshaler below because time.Time implements
// encoding.TextMarshaler, but we need to always use UTC.
enc.wf(v.UTC().Format("2006-01-02T15:04:05Z"))
return
case TextMarshaler:
// Special case. Use text marshaler if it's available for this value.
if s, err := v.MarshalText(); err != nil {
encPanic(err)
} else {
enc.writeQuoted(string(s))
case time.Time: // Using TextMarshaler adds extra quotes, which we don't want.
format := time.RFC3339Nano
switch v.Location() {
case internal.LocalDatetime:
format = "2006-01-02T15:04:05.999999999"
case internal.LocalDate:
format = "2006-01-02"
case internal.LocalTime:
format = "15:04:05.999999999"
}
switch v.Location() {
default:
enc.wf(v.Format(format))
case internal.LocalDatetime, internal.LocalDate, internal.LocalTime:
enc.wf(v.In(time.UTC).Format(format))
}
return
case Marshaler:
s, err := v.MarshalTOML()
if err != nil {
encPanic(err)
}
enc.w.Write(s)
return
case encoding.TextMarshaler:
s, err := v.MarshalText()
if err != nil {
encPanic(err)
}
enc.writeQuoted(string(s))
return
}
switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16,
reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 32)))
case reflect.Float64:
enc.wf(floatAddDecimal(strconv.FormatFloat(rv.Float(), 'f', -1, 64)))
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Interface:
enc.eElement(rv.Elem())
case reflect.String:
enc.writeQuoted(rv.String())
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
enc.wf(strconv.FormatInt(rv.Int(), 10))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
enc.wf(strconv.FormatUint(rv.Uint(), 10))
case reflect.Float32:
f := rv.Float()
if math.IsNaN(f) {
enc.wf("nan")
} else if math.IsInf(f, 0) {
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
} else {
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 32)))
}
case reflect.Float64:
f := rv.Float()
if math.IsNaN(f) {
enc.wf("nan")
} else if math.IsInf(f, 0) {
enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
} else {
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 64)))
}
case reflect.Array, reflect.Slice:
enc.eArrayOrSliceElement(rv)
case reflect.Struct:
enc.eStruct(nil, rv, true)
case reflect.Map:
enc.eMap(nil, rv, true)
case reflect.Interface:
enc.eElement(rv.Elem())
default:
panic(e("unexpected primitive type: %s", rv.Kind()))
encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface()))
}
}
// By the TOML spec, all floats must have a decimal with at least one
// number on either side.
// By the TOML spec, all floats must have a decimal with at least one number on
// either side.
func floatAddDecimal(fstr string) string {
if !strings.Contains(fstr, ".") {
return fstr + ".0"
@@ -205,7 +273,7 @@ func floatAddDecimal(fstr string) string {
}
func (enc *Encoder) writeQuoted(s string) {
enc.wf("\"%s\"", quotedReplacer.Replace(s))
enc.wf("\"%s\"", dblQuotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
@@ -230,40 +298,39 @@ func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
if isNil(trv) {
continue
}
panicIfInvalidKey(key)
enc.newline()
enc.wf("%s[[%s]]", enc.indentStr(key), key.maybeQuotedAll())
enc.wf("%s[[%s]]", enc.indentStr(key), key)
enc.newline()
enc.eMapOrStruct(key, trv)
enc.eMapOrStruct(key, trv, false)
}
}
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
panicIfInvalidKey(key)
if len(key) == 1 {
// Output an extra newline between top-level tables.
// (The newline isn't written if nothing else has been written though.)
enc.newline()
}
if len(key) > 0 {
enc.wf("%s[%s]", enc.indentStr(key), key.maybeQuotedAll())
enc.wf("%s[%s]", enc.indentStr(key), key)
enc.newline()
}
enc.eMapOrStruct(key, rv)
enc.eMapOrStruct(key, rv, false)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value) {
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) {
switch rv := eindirect(rv); rv.Kind() {
case reflect.Map:
enc.eMap(key, rv)
enc.eMap(key, rv, inline)
case reflect.Struct:
enc.eStruct(key, rv)
enc.eStruct(key, rv, inline)
default:
// Should never happen?
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String())
}
}
func (enc *Encoder) eMap(key Key, rv reflect.Value) {
func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
rt := rv.Type()
if rt.Key().Kind() != reflect.String {
encPanic(errNonString)
@@ -274,114 +341,163 @@ func (enc *Encoder) eMap(key Key, rv reflect.Value) {
var mapKeysDirect, mapKeysSub []string
for _, mapKey := range rv.MapKeys() {
k := mapKey.String()
if typeIsHash(tomlTypeOfGo(rv.MapIndex(mapKey))) {
if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) {
mapKeysSub = append(mapKeysSub, k)
} else {
mapKeysDirect = append(mapKeysDirect, k)
}
}
var writeMapKeys = func(mapKeys []string) {
var writeMapKeys = func(mapKeys []string, trailC bool) {
sort.Strings(mapKeys)
for _, mapKey := range mapKeys {
mrv := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(mrv) {
// Don't write anything for nil fields.
for i, mapKey := range mapKeys {
val := rv.MapIndex(reflect.ValueOf(mapKey))
if isNil(val) {
continue
}
enc.encode(key.add(mapKey), mrv)
if inline {
enc.writeKeyValue(Key{mapKey}, val, true)
if trailC || i != len(mapKeys)-1 {
enc.wf(", ")
}
} else {
enc.encode(key.add(mapKey), val)
}
}
}
writeMapKeys(mapKeysDirect)
writeMapKeys(mapKeysSub)
if inline {
enc.wf("{")
}
writeMapKeys(mapKeysDirect, len(mapKeysSub) > 0)
writeMapKeys(mapKeysSub, false)
if inline {
enc.wf("}")
}
}
func (enc *Encoder) eStruct(key Key, rv reflect.Value) {
const is32Bit = (32 << (^uint(0) >> 63)) == 32
func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) {
// Write keys for fields directly under this key first, because if we write
// a field that creates a new table, then all keys under it will be in that
// a field that creates a new table then all keys under it will be in that
// table (not the one we're writing here).
rt := rv.Type()
var fieldsDirect, fieldsSub [][]int
var addFields func(rt reflect.Type, rv reflect.Value, start []int)
//
// Fields is a [][]int: for fieldsDirect this always has one entry (the
// struct index). For fieldsSub it contains two entries: the parent field
// index from tv, and the field indexes for the fields of the sub.
var (
rt = rv.Type()
fieldsDirect, fieldsSub [][]int
addFields func(rt reflect.Type, rv reflect.Value, start []int)
)
addFields = func(rt reflect.Type, rv reflect.Value, start []int) {
for i := 0; i < rt.NumField(); i++ {
f := rt.Field(i)
// skip unexported fields
if f.PkgPath != "" && !f.Anonymous {
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields.
continue
}
opts := getOptions(f.Tag)
if opts.skip {
continue
}
frv := rv.Field(i)
// Treat anonymous struct fields with tag names as though they are
// not anonymous, like encoding/json does.
//
// Non-struct anonymous fields use the normal encoding logic.
if f.Anonymous {
t := f.Type
switch t.Kind() {
case reflect.Struct:
// Treat anonymous struct fields with
// tag names as though they are not
// anonymous, like encoding/json does.
if getOptions(f.Tag).name == "" {
addFields(t, frv, f.Index)
addFields(t, frv, append(start, f.Index...))
continue
}
case reflect.Ptr:
if t.Elem().Kind() == reflect.Struct &&
getOptions(f.Tag).name == "" {
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" {
if !frv.IsNil() {
addFields(t.Elem(), frv.Elem(), f.Index)
addFields(t.Elem(), frv.Elem(), append(start, f.Index...))
}
continue
}
// Fall through to the normal field encoding logic below
// for non-struct anonymous fields.
}
}
if typeIsHash(tomlTypeOfGo(frv)) {
if typeIsTable(tomlTypeOfGo(frv)) {
fieldsSub = append(fieldsSub, append(start, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
// Copy so it works correct on 32bit archs; not clear why this
// is needed. See #314, and https://www.reddit.com/r/golang/comments/pnx8v4
// This also works fine on 64bit, but 32bit archs are somewhat
// rare and this is a wee bit faster.
if is32Bit {
copyStart := make([]int, len(start))
copy(copyStart, start)
fieldsDirect = append(fieldsDirect, append(copyStart, f.Index...))
} else {
fieldsDirect = append(fieldsDirect, append(start, f.Index...))
}
}
}
}
addFields(rt, rv, nil)
var writeFields = func(fields [][]int) {
writeFields := func(fields [][]int) {
for _, fieldIndex := range fields {
sft := rt.FieldByIndex(fieldIndex)
sf := rv.FieldByIndex(fieldIndex)
if isNil(sf) {
// Don't write anything for nil fields.
fieldType := rt.FieldByIndex(fieldIndex)
fieldVal := rv.FieldByIndex(fieldIndex)
if isNil(fieldVal) { /// Don't write anything for nil fields.
continue
}
opts := getOptions(sft.Tag)
opts := getOptions(fieldType.Tag)
if opts.skip {
continue
}
keyName := sft.Name
keyName := fieldType.Name
if opts.name != "" {
keyName = opts.name
}
if opts.omitempty && isEmpty(sf) {
if opts.omitempty && isEmpty(fieldVal) {
continue
}
if opts.omitzero && isZero(sf) {
if opts.omitzero && isZero(fieldVal) {
continue
}
enc.encode(key.add(keyName), sf)
if inline {
enc.writeKeyValue(Key{keyName}, fieldVal, true)
if fieldIndex[0] != len(fields)-1 {
enc.wf(", ")
}
} else {
enc.encode(key.add(keyName), fieldVal)
}
}
}
if inline {
enc.wf("{")
}
writeFields(fieldsDirect)
writeFields(fieldsSub)
if inline {
enc.wf("}")
}
}
// tomlTypeName returns the TOML type name of the Go value's type. It is
// used to determine whether the types of array elements are mixed (which is
// forbidden). If the Go value is nil, then it is illegal for it to be an array
// element, and valueIsNil is returned as true.
// Returns the TOML type of a Go value. The type may be `nil`, which means
// no concrete TOML type could be found.
// tomlTypeOfGo returns the TOML type name of the Go value's type.
//
// It is used to determine whether the types of array elements are mixed (which
// is forbidden). If the Go value is nil, then it is illegal for it to be an
// array element, and valueIsNil is returned as true.
//
// The type may be `nil`, which means no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() {
return nil
@@ -408,19 +524,43 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
case reflect.Map:
return tomlHash
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
if _, ok := rv.Interface().(time.Time); ok {
return tomlDatetime
case TextMarshaler:
return tomlString
default:
return tomlHash
}
if isMarshaler(rv) {
return tomlString
}
return tomlHash
default:
panic("unexpected reflect.Kind: " + rv.Kind().String())
if isMarshaler(rv) {
return tomlString
}
encPanic(errors.New("unsupported type: " + rv.Kind().String()))
panic("unreachable")
}
}
func isMarshaler(rv reflect.Value) bool {
switch rv.Interface().(type) {
case encoding.TextMarshaler:
return true
case Marshaler:
return true
}
// Someone used a pointer receiver: we can make it work for pointer values.
if rv.CanAddr() {
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok {
return true
}
if _, ok := rv.Addr().Interface().(Marshaler); ok {
return true
}
}
return false
}
// tomlArrayType returns the element type of a TOML array. The type returned
// may be nil if it cannot be determined (e.g., a nil slice or a zero length
// slize). This function may also panic if it finds a type that cannot be
@@ -430,30 +570,19 @@ func tomlArrayType(rv reflect.Value) tomlType {
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 {
return nil
}
/// Don't allow nil.
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
if tomlTypeOfGo(rv.Index(i)) == nil {
encPanic(errArrayNilElement)
}
}
firstType := tomlTypeOfGo(rv.Index(0))
if firstType == nil {
encPanic(errArrayNilElement)
}
rvlen := rv.Len()
for i := 1; i < rvlen; i++ {
elem := rv.Index(i)
switch elemType := tomlTypeOfGo(elem); {
case elemType == nil:
encPanic(errArrayNilElement)
case !typeEqual(firstType, elemType):
encPanic(errArrayMixedElementTypes)
}
}
// If we have a nested array, then we must make sure that the nested
// array contains ONLY primitives.
// This checks arbitrarily nested arrays.
if typeEqual(firstType, tomlArray) || typeEqual(firstType, tomlArrayHash) {
nest := tomlArrayType(eindirect(rv.Index(0)))
if typeEqual(nest, tomlHash) || typeEqual(nest, tomlArrayHash) {
encPanic(errArrayNoTable)
}
}
return firstType
}
@@ -511,18 +640,32 @@ func (enc *Encoder) newline() {
}
}
func (enc *Encoder) keyEqElement(key Key, val reflect.Value) {
// Write a key/value pair:
//
// key = <any value>
//
// This is also used for "k = v" in inline tables; so something like this will
// be written in three calls:
//
// ┌────────────────────┐
// │ ┌───┐ ┌─────┐│
// v v v v vv
// key = {k = v, k2 = v2}
//
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
if len(key) == 0 {
encPanic(errNoKey)
}
panicIfInvalidKey(key)
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
enc.eElement(val)
enc.newline()
if !inline {
enc.newline()
}
}
func (enc *Encoder) wf(format string, v ...interface{}) {
if _, err := fmt.Fprintf(enc.w, format, v...); err != nil {
_, err := fmt.Fprintf(enc.w, format, v...)
if err != nil {
encPanic(err)
}
enc.hasWritten = true
@@ -553,16 +696,3 @@ func isNil(rv reflect.Value) bool {
return false
}
}
func panicIfInvalidKey(key Key) {
for _, k := range key {
if len(k) == 0 {
encPanic(e("Key '%s' is not a valid table name. Key names "+
"cannot be empty.", key.maybeQuotedAll()))
}
}
}
func isValidKeyName(s string) bool {
return len(s) != 0
}

View File

@@ -1,19 +0,0 @@
// +build go1.2
package toml
// In order to support Go 1.1, we define our own TextMarshaler and
// TextUnmarshaler types. For Go 1.2+, we just alias them with the
// standard library interfaces.
import (
"encoding"
)
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler encoding.TextUnmarshaler

View File

@@ -1,18 +0,0 @@
// +build !go1.2
package toml
// These interfaces were introduced in Go 1.2, so we add them manually when
// compiling for Go 1.1.
// TextMarshaler is a synonym for encoding.TextMarshaler. It is defined here
// so that Go 1.1 can be supported.
type TextMarshaler interface {
MarshalText() (text []byte, err error)
}
// TextUnmarshaler is a synonym for encoding.TextUnmarshaler. It is defined
// here so that Go 1.1 can be supported.
type TextUnmarshaler interface {
UnmarshalText(text []byte) error
}

229
vendor/github.com/BurntSushi/toml/error.go generated vendored Normal file
View File

@@ -0,0 +1,229 @@
package toml
import (
"fmt"
"strings"
)
// ParseError is returned when there is an error parsing the TOML syntax.
//
// For example invalid syntax, duplicate keys, etc.
//
// In addition to the error message itself, you can also print detailed location
// information with context by using ErrorWithPosition():
//
// toml: error: Key 'fruit' was already created and cannot be used as an array.
//
// At line 4, column 2-7:
//
// 2 | fruit = []
// 3 |
// 4 | [[fruit]] # Not allowed
// ^^^^^
//
// Furthermore, the ErrorWithUsage() can be used to print the above with some
// more detailed usage guidance:
//
// toml: error: newlines not allowed within inline tables
//
// At line 1, column 18:
//
// 1 | x = [{ key = 42 #
// ^
//
// Error help:
//
// Inline tables must always be on a single line:
//
// table = {key = 42, second = 43}
//
// It is invalid to split them over multiple lines like so:
//
// # INVALID
// table = {
// key = 42,
// second = 43
// }
//
// Use regular for this:
//
// [table]
// key = 42
// second = 43
type ParseError struct {
Message string // Short technical message.
Usage string // Longer message with usage guidance; may be blank.
Position Position // Position of the error
LastKey string // Last parsed key, may be blank.
Line int // Line the error occurred. Deprecated: use Position.
err error
input string
}
// Position of an error.
type Position struct {
Line int // Line number, starting at 1.
Start int // Start of error, as byte offset starting at 0.
Len int // Lenght in bytes.
}
func (pe ParseError) Error() string {
msg := pe.Message
if msg == "" { // Error from errorf()
msg = pe.err.Error()
}
if pe.LastKey == "" {
return fmt.Sprintf("toml: line %d: %s", pe.Position.Line, msg)
}
return fmt.Sprintf("toml: line %d (last key %q): %s",
pe.Position.Line, pe.LastKey, msg)
}
// ErrorWithUsage() returns the error with detailed location context.
//
// See the documentation on ParseError.
func (pe ParseError) ErrorWithPosition() string {
if pe.input == "" { // Should never happen, but just in case.
return pe.Error()
}
var (
lines = strings.Split(pe.input, "\n")
col = pe.column(lines)
b = new(strings.Builder)
)
msg := pe.Message
if msg == "" {
msg = pe.err.Error()
}
// TODO: don't show control characters as literals? This may not show up
// well everywhere.
if pe.Position.Len == 1 {
fmt.Fprintf(b, "toml: error: %s\n\nAt line %d, column %d:\n\n",
msg, pe.Position.Line, col+1)
} else {
fmt.Fprintf(b, "toml: error: %s\n\nAt line %d, column %d-%d:\n\n",
msg, pe.Position.Line, col, col+pe.Position.Len)
}
if pe.Position.Line > 2 {
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-2, lines[pe.Position.Line-3])
}
if pe.Position.Line > 1 {
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-1, lines[pe.Position.Line-2])
}
fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line, lines[pe.Position.Line-1])
fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col), strings.Repeat("^", pe.Position.Len))
return b.String()
}
// ErrorWithUsage() returns the error with detailed location context and usage
// guidance.
//
// See the documentation on ParseError.
func (pe ParseError) ErrorWithUsage() string {
m := pe.ErrorWithPosition()
if u, ok := pe.err.(interface{ Usage() string }); ok && u.Usage() != "" {
return m + "Error help:\n\n " +
strings.ReplaceAll(strings.TrimSpace(u.Usage()), "\n", "\n ") +
"\n"
}
return m
}
func (pe ParseError) column(lines []string) int {
var pos, col int
for i := range lines {
ll := len(lines[i]) + 1 // +1 for the removed newline
if pos+ll >= pe.Position.Start {
col = pe.Position.Start - pos
if col < 0 { // Should never happen, but just in case.
col = 0
}
break
}
pos += ll
}
return col
}
type (
errLexControl struct{ r rune }
errLexEscape struct{ r rune }
errLexUTF8 struct{ b byte }
errLexInvalidNum struct{ v string }
errLexInvalidDate struct{ v string }
errLexInlineTableNL struct{}
errLexStringNL struct{}
)
func (e errLexControl) Error() string {
return fmt.Sprintf("TOML files cannot contain control characters: '0x%02x'", e.r)
}
func (e errLexControl) Usage() string { return "" }
func (e errLexEscape) Error() string { return fmt.Sprintf(`invalid escape in string '\%c'`, e.r) }
func (e errLexEscape) Usage() string { return usageEscape }
func (e errLexUTF8) Error() string { return fmt.Sprintf("invalid UTF-8 byte: 0x%02x", e.b) }
func (e errLexUTF8) Usage() string { return "" }
func (e errLexInvalidNum) Error() string { return fmt.Sprintf("invalid number: %q", e.v) }
func (e errLexInvalidNum) Usage() string { return "" }
func (e errLexInvalidDate) Error() string { return fmt.Sprintf("invalid date: %q", e.v) }
func (e errLexInvalidDate) Usage() string { return "" }
func (e errLexInlineTableNL) Error() string { return "newlines not allowed within inline tables" }
func (e errLexInlineTableNL) Usage() string { return usageInlineNewline }
func (e errLexStringNL) Error() string { return "strings cannot contain newlines" }
func (e errLexStringNL) Usage() string { return usageStringNewline }
const usageEscape = `
A '\' inside a "-delimited string is interpreted as an escape character.
The following escape sequences are supported:
\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX
To prevent a '\' from being recognized as an escape character, use either:
- a ' or '''-delimited string; escape characters aren't processed in them; or
- write two backslashes to get a single backslash: '\\'.
If you're trying to add a Windows path (e.g. "C:\Users\martin") then using '/'
instead of '\' will usually also work: "C:/Users/martin".
`
const usageInlineNewline = `
Inline tables must always be on a single line:
table = {key = 42, second = 43}
It is invalid to split them over multiple lines like so:
# INVALID
table = {
key = 42,
second = 43
}
Use regular for this:
[table]
key = 42
second = 43
`
const usageStringNewline = `
Strings must always be on a single line, and cannot span more than one line:
# INVALID
string = "Hello,
world!"
Instead use """ or ''' to split strings over multiple lines:
string = """Hello,
world!"""
`

36
vendor/github.com/BurntSushi/toml/internal/tz.go generated vendored Normal file
View File

@@ -0,0 +1,36 @@
package internal
import "time"
// Timezones used for local datetime, date, and time TOML types.
//
// The exact way times and dates without a timezone should be interpreted is not
// well-defined in the TOML specification and left to the implementation. These
// defaults to current local timezone offset of the computer, but this can be
// changed by changing these variables before decoding.
//
// TODO:
// Ideally we'd like to offer people the ability to configure the used timezone
// by setting Decoder.Timezone and Encoder.Timezone; however, this is a bit
// tricky: the reason we use three different variables for this is to support
// round-tripping without these specific TZ names we wouldn't know which
// format to use.
//
// There isn't a good way to encode this right now though, and passing this sort
// of information also ties in to various related issues such as string format
// encoding, encoding of comments, etc.
//
// So, for the time being, just put this in internal until we can write a good
// comprehensive API for doing all of this.
//
// The reason they're exported is because they're referred from in e.g.
// internal/tag.
//
// Note that this behaviour is valid according to the TOML spec as the exact
// behaviour is left up to implementations.
var (
localOffset = func() int { _, o := time.Now().Zone(); return o }()
LocalDatetime = time.FixedZone("datetime-local", localOffset)
LocalDate = time.FixedZone("date-local", localOffset)
LocalTime = time.FixedZone("time-local", localOffset)
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,33 +1,39 @@
package toml
import "strings"
import (
"strings"
)
// MetaData allows access to meta information about TOML data that may not
// be inferrable via reflection. In particular, whether a key has been defined
// and the TOML type of a key.
// MetaData allows access to meta information about TOML data that's not
// accessible otherwise.
//
// It allows checking if a key is defined in the TOML data, whether any keys
// were undecoded, and the TOML type of a key.
type MetaData struct {
context Key // Used only during decoding.
mapping map[string]interface{}
types map[string]tomlType
keys []Key
decoded map[string]bool
context Key // Used only during decoding.
decoded map[string]struct{}
}
// IsDefined returns true if the key given exists in the TOML data. The key
// should be specified hierarchially. e.g.,
// IsDefined reports if the key exists in the TOML data.
//
// // access the TOML key 'a.b.c'
// IsDefined("a", "b", "c")
// The key should be specified hierarchically, for example to access the TOML
// key "a.b.c" you would use IsDefined("a", "b", "c"). Keys are case sensitive.
//
// IsDefined will return false if an empty key given. Keys are case sensitive.
// Returns false for an empty key.
func (md *MetaData) IsDefined(key ...string) bool {
if len(key) == 0 {
return false
}
var hash map[string]interface{}
var ok bool
var hashOrVal interface{} = md.mapping
var (
hash map[string]interface{}
ok bool
hashOrVal interface{} = md.mapping
)
for _, k := range key {
if hash, ok = hashOrVal.(map[string]interface{}); !ok {
return false
@@ -41,58 +47,20 @@ func (md *MetaData) IsDefined(key ...string) bool {
// Type returns a string representation of the type of the key specified.
//
// Type will return the empty string if given an empty key or a key that
// does not exist. Keys are case sensitive.
// Type will return the empty string if given an empty key or a key that does
// not exist. Keys are case sensitive.
func (md *MetaData) Type(key ...string) string {
fullkey := strings.Join(key, ".")
if typ, ok := md.types[fullkey]; ok {
if typ, ok := md.types[Key(key).String()]; ok {
return typ.typeString()
}
return ""
}
// Key is the type of any TOML key, including key groups. Use (MetaData).Keys
// to get values of this type.
type Key []string
func (k Key) String() string {
return strings.Join(k, ".")
}
func (k Key) maybeQuotedAll() string {
var ss []string
for i := range k {
ss = append(ss, k.maybeQuoted(i))
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
quote := false
for _, c := range k[i] {
if !isBareKeyChar(c) {
quote = true
break
}
}
if quote {
return "\"" + strings.Replace(k[i], "\"", "\\\"", -1) + "\""
}
return k[i]
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}
// Keys returns a slice of every key in the TOML data, including key groups.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific.
//
// The list will have the same order as the keys appeared in the TOML data.
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific. The list will have the same
// order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
func (md *MetaData) Keys() []Key {
@@ -113,9 +81,40 @@ func (md *MetaData) Keys() []Key {
func (md *MetaData) Undecoded() []Key {
undecoded := make([]Key, 0, len(md.keys))
for _, key := range md.keys {
if !md.decoded[key.String()] {
if _, ok := md.decoded[key.String()]; !ok {
undecoded = append(undecoded, key)
}
}
return undecoded
}
// Key represents any TOML key, including key groups. Use (MetaData).Keys to get
// values of this type.
type Key []string
func (k Key) String() string {
ss := make([]string, len(k))
for i := range k {
ss[i] = k.maybeQuoted(i)
}
return strings.Join(ss, ".")
}
func (k Key) maybeQuoted(i int) string {
if k[i] == "" {
return `""`
}
for _, c := range k[i] {
if !isBareKeyChar(c) {
return `"` + dblQuotedReplacer.Replace(k[i]) + `"`
}
}
return k[i]
}
func (k Key) add(piece string) Key {
newKey := make(Key, len(k)+1)
copy(newKey, k)
newKey[len(k)] = piece
return newKey
}

View File

@@ -5,54 +5,63 @@ import (
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
"github.com/BurntSushi/toml/internal"
)
type parser struct {
mapping map[string]interface{}
types map[string]tomlType
lx *lexer
lx *lexer
context Key // Full key for the current hash in scope.
currentKey string // Base key name for everything except hashes.
pos Position // Current position in the TOML file.
// A list of keys in the order that they appear in the TOML data.
ordered []Key
// the full key for the current hash in scope
context Key
// the base key name for everything except hashes
currentKey string
// rough approximation of line number
approxLine int
// A map of 'key.group.names' to whether they were created implicitly.
implicits map[string]bool
}
type parseError string
func (pe parseError) Error() string {
return string(pe)
ordered []Key // List of keys in the order that they appear in the TOML data.
mapping map[string]interface{} // Map keyname → key value.
types map[string]tomlType // Map keyname → TOML type.
implicits map[string]struct{} // Record implicit keys (e.g. "key.group.names").
}
func parse(data string) (p *parser, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
if err, ok = r.(parseError); ok {
if pErr, ok := r.(ParseError); ok {
pErr.input = data
err = pErr
return
}
panic(r)
}
}()
// Read over BOM; do this here as the lexer calls utf8.DecodeRuneInString()
// which mangles stuff.
if strings.HasPrefix(data, "\xff\xfe") || strings.HasPrefix(data, "\xfe\xff") {
data = data[2:]
}
// Examine first few bytes for NULL bytes; this probably means it's a UTF-16
// file (second byte in surrogate pair being NULL). Again, do this here to
// avoid having to deal with UTF-8/16 stuff in the lexer.
ex := 6
if len(data) < 6 {
ex = len(data)
}
if i := strings.IndexRune(data[:ex], 0); i > -1 {
return nil, ParseError{
Message: "files cannot contain NULL bytes; probably using UTF-16; TOML files must be UTF-8",
Position: Position{Line: 1, Start: i, Len: 1},
Line: 1,
input: data,
}
}
p = &parser{
mapping: make(map[string]interface{}),
types: make(map[string]tomlType),
lx: lex(data),
ordered: make([]Key, 0),
implicits: make(map[string]bool),
implicits: make(map[string]struct{}),
}
for {
item := p.next()
@@ -65,20 +74,48 @@ func parse(data string) (p *parser, err error) {
return p, nil
}
func (p *parser) panicItemf(it item, format string, v ...interface{}) {
panic(ParseError{
Message: fmt.Sprintf(format, v...),
Position: it.pos,
Line: it.pos.Len,
LastKey: p.current(),
})
}
func (p *parser) panicf(format string, v ...interface{}) {
msg := fmt.Sprintf("Near line %d (last key parsed '%s'): %s",
p.approxLine, p.current(), fmt.Sprintf(format, v...))
panic(parseError(msg))
panic(ParseError{
Message: fmt.Sprintf(format, v...),
Position: p.pos,
Line: p.pos.Line,
LastKey: p.current(),
})
}
func (p *parser) next() item {
it := p.lx.nextItem()
//fmt.Printf("ITEM %-18s line %-3d │ %q\n", it.typ, it.line, it.val)
if it.typ == itemError {
p.panicf("%s", it.val)
if it.err != nil {
panic(ParseError{
Position: it.pos,
Line: it.pos.Line,
LastKey: p.current(),
err: it.err,
})
}
p.panicItemf(it, "%s", it.val)
}
return it
}
func (p *parser) nextPos() item {
it := p.next()
p.pos = it.pos
return it
}
func (p *parser) bug(format string, v ...interface{}) {
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...))
}
@@ -97,44 +134,59 @@ func (p *parser) assertEqual(expected, got itemType) {
func (p *parser) topLevel(item item) {
switch item.typ {
case itemCommentStart:
p.approxLine = item.line
case itemCommentStart: // # ..
p.expect(itemText)
case itemTableStart:
kg := p.next()
p.approxLine = kg.line
case itemTableStart: // [ .. ]
name := p.nextPos()
var key Key
for ; kg.typ != itemTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
for ; name.typ != itemTableEnd && name.typ != itemEOF; name = p.next() {
key = append(key, p.keyString(name))
}
p.assertEqual(itemTableEnd, kg.typ)
p.assertEqual(itemTableEnd, name.typ)
p.establishContext(key, false)
p.addContext(key, false)
p.setType("", tomlHash)
p.ordered = append(p.ordered, key)
case itemArrayTableStart:
kg := p.next()
p.approxLine = kg.line
case itemArrayTableStart: // [[ .. ]]
name := p.nextPos()
var key Key
for ; kg.typ != itemArrayTableEnd && kg.typ != itemEOF; kg = p.next() {
key = append(key, p.keyString(kg))
for ; name.typ != itemArrayTableEnd && name.typ != itemEOF; name = p.next() {
key = append(key, p.keyString(name))
}
p.assertEqual(itemArrayTableEnd, kg.typ)
p.assertEqual(itemArrayTableEnd, name.typ)
p.establishContext(key, true)
p.addContext(key, true)
p.setType("", tomlArrayHash)
p.ordered = append(p.ordered, key)
case itemKeyStart:
kname := p.next()
p.approxLine = kname.line
p.currentKey = p.keyString(kname)
case itemKeyStart: // key = ..
outerContext := p.context
/// Read all the key parts (e.g. 'a' and 'b' in 'a.b')
k := p.nextPos()
var key Key
for ; k.typ != itemKeyEnd && k.typ != itemEOF; k = p.next() {
key = append(key, p.keyString(k))
}
p.assertEqual(itemKeyEnd, k.typ)
val, typ := p.value(p.next())
p.setValue(p.currentKey, val)
p.setType(p.currentKey, typ)
/// The current key is the last part.
p.currentKey = key[len(key)-1]
/// All the other parts (if any) are the context; need to set each part
/// as implicit.
context := key[:len(key)-1]
for i := range context {
p.addImplicitContext(append(p.context, context[i:i+1]...))
}
/// Set value.
val, typ := p.value(p.next(), false)
p.set(p.currentKey, val, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
/// Remove the context we added (preserving any context from [tbl] lines).
p.context = outerContext
p.currentKey = ""
default:
p.bug("Unexpected type at top level: %s", item.typ)
@@ -148,180 +200,262 @@ func (p *parser) keyString(it item) string {
return it.val
case itemString, itemMultilineString,
itemRawString, itemRawMultilineString:
s, _ := p.value(it)
s, _ := p.value(it, false)
return s.(string)
default:
p.bug("Unexpected key type: %s", it.typ)
panic("unreachable")
}
panic("unreachable")
}
var datetimeRepl = strings.NewReplacer(
"z", "Z",
"t", "T",
" ", "T")
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
func (p *parser) value(it item) (interface{}, tomlType) {
func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
switch it.typ {
case itemString:
return p.replaceEscapes(it.val), p.typeOfPrimitive(it)
return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it)
case itemMultilineString:
trimmed := stripFirstNewline(stripEscapedWhitespace(it.val))
return p.replaceEscapes(trimmed), p.typeOfPrimitive(it)
return p.replaceEscapes(it, stripFirstNewline(p.stripEscapedNewlines(it.val))), p.typeOfPrimitive(it)
case itemRawString:
return it.val, p.typeOfPrimitive(it)
case itemRawMultilineString:
return stripFirstNewline(it.val), p.typeOfPrimitive(it)
case itemInteger:
return p.valueInteger(it)
case itemFloat:
return p.valueFloat(it)
case itemBool:
switch it.val {
case "true":
return true, p.typeOfPrimitive(it)
case "false":
return false, p.typeOfPrimitive(it)
default:
p.bug("Expected boolean value, but got '%s'.", it.val)
}
p.bug("Expected boolean value, but got '%s'.", it.val)
case itemInteger:
if !numUnderscoresOK(it.val) {
p.panicf("Invalid integer %q: underscores must be surrounded by digits",
it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseInt(val, 10, 64)
if err != nil {
// Distinguish integer values. Normally, it'd be a bug if the lexer
// provides an invalid integer, but it's possible that the number is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Integer '%s' is out of the range of 64-bit "+
"signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemFloat:
parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
return true
}
return false
})
for _, part := range parts {
if !numUnderscoresOK(part) {
p.panicf("Invalid float %q: underscores must be "+
"surrounded by digits", it.val)
}
}
if !numPeriodsOK(it.val) {
// As a special case, numbers like '123.' or '1.e2',
// which are valid as far as Go/strconv are concerned,
// must be rejected because TOML says that a fractional
// part consists of '.' followed by 1+ digits.
p.panicf("Invalid float %q: '.' must be followed "+
"by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
num, err := strconv.ParseFloat(val, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok &&
e.Err == strconv.ErrRange {
p.panicf("Float '%s' is out of the range of 64-bit "+
"IEEE-754 floating-point numbers.", it.val)
} else {
p.panicf("Invalid float value: %q", it.val)
}
}
return num, p.typeOfPrimitive(it)
case itemDatetime:
var t time.Time
var ok bool
var err error
for _, format := range []string{
"2006-01-02T15:04:05Z07:00",
"2006-01-02T15:04:05",
"2006-01-02",
} {
t, err = time.ParseInLocation(format, it.val, time.Local)
if err == nil {
ok = true
break
}
}
if !ok {
p.panicf("Invalid TOML Datetime: %q.", it.val)
}
return t, p.typeOfPrimitive(it)
return p.valueDatetime(it)
case itemArray:
array := make([]interface{}, 0)
types := make([]tomlType, 0)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it)
array = append(array, val)
types = append(types, typ)
}
return array, p.typeOfArray(types)
return p.valueArray(it)
case itemInlineTableStart:
var (
hash = make(map[string]interface{})
outerContext = p.context
outerKey = p.currentKey
)
p.context = append(p.context, p.currentKey)
p.currentKey = ""
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
if it.typ != itemKeyStart {
p.bug("Expected key start but instead found %q, around line %d",
it.val, p.approxLine)
}
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
// retrieve key
k := p.next()
p.approxLine = k.line
kname := p.keyString(k)
// retrieve value
p.currentKey = kname
val, typ := p.value(p.next())
// make sure we keep metadata up to date
p.setType(kname, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[kname] = val
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
return p.valueInlineTable(it, parentIsArray)
default:
p.bug("Unexpected value type: %s", it.typ)
}
p.bug("Unexpected value type: %s", it.typ)
panic("unreachable")
}
func (p *parser) valueInteger(it item) (interface{}, tomlType) {
if !numUnderscoresOK(it.val) {
p.panicItemf(it, "Invalid integer %q: underscores must be surrounded by digits", it.val)
}
if numHasLeadingZero(it.val) {
p.panicItemf(it, "Invalid integer %q: cannot have leading zeroes", it.val)
}
num, err := strconv.ParseInt(it.val, 0, 64)
if err != nil {
// Distinguish integer values. Normally, it'd be a bug if the lexer
// provides an invalid integer, but it's possible that the number is
// out of range of valid values (which the lexer cannot determine).
// So mark the former as a bug but the latter as a legitimate user
// error.
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
p.panicItemf(it, "Integer '%s' is out of the range of 64-bit signed integers.", it.val)
} else {
p.bug("Expected integer value, but got '%s'.", it.val)
}
}
return num, p.typeOfPrimitive(it)
}
func (p *parser) valueFloat(it item) (interface{}, tomlType) {
parts := strings.FieldsFunc(it.val, func(r rune) bool {
switch r {
case '.', 'e', 'E':
return true
}
return false
})
for _, part := range parts {
if !numUnderscoresOK(part) {
p.panicItemf(it, "Invalid float %q: underscores must be surrounded by digits", it.val)
}
}
if len(parts) > 0 && numHasLeadingZero(parts[0]) {
p.panicItemf(it, "Invalid float %q: cannot have leading zeroes", it.val)
}
if !numPeriodsOK(it.val) {
// As a special case, numbers like '123.' or '1.e2',
// which are valid as far as Go/strconv are concerned,
// must be rejected because TOML says that a fractional
// part consists of '.' followed by 1+ digits.
p.panicItemf(it, "Invalid float %q: '.' must be followed by one or more digits", it.val)
}
val := strings.Replace(it.val, "_", "", -1)
if val == "+nan" || val == "-nan" { // Go doesn't support this, but TOML spec does.
val = "nan"
}
num, err := strconv.ParseFloat(val, 64)
if err != nil {
if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
p.panicItemf(it, "Float '%s' is out of the range of 64-bit IEEE-754 floating-point numbers.", it.val)
} else {
p.panicItemf(it, "Invalid float value: %q", it.val)
}
}
return num, p.typeOfPrimitive(it)
}
var dtTypes = []struct {
fmt string
zone *time.Location
}{
{time.RFC3339Nano, time.Local},
{"2006-01-02T15:04:05.999999999", internal.LocalDatetime},
{"2006-01-02", internal.LocalDate},
{"15:04:05.999999999", internal.LocalTime},
}
func (p *parser) valueDatetime(it item) (interface{}, tomlType) {
it.val = datetimeRepl.Replace(it.val)
var (
t time.Time
ok bool
err error
)
for _, dt := range dtTypes {
t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone)
if err == nil {
ok = true
break
}
}
if !ok {
p.panicItemf(it, "Invalid TOML Datetime: %q.", it.val)
}
return t, p.typeOfPrimitive(it)
}
func (p *parser) valueArray(it item) (interface{}, tomlType) {
p.setType(p.currentKey, tomlArray)
// p.setType(p.currentKey, typ)
var (
types []tomlType
// Initialize to a non-nil empty slice. This makes it consistent with
// how S = [] decodes into a non-nil slice inside something like struct
// { S []string }. See #338
array = []interface{}{}
)
for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
val, typ := p.value(it, true)
array = append(array, val)
types = append(types, typ)
// XXX: types isn't used here, we need it to record the accurate type
// information.
//
// Not entirely sure how to best store this; could use "key[0]",
// "key[1]" notation, or maybe store it on the Array type?
}
return array, tomlArray
}
func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tomlType) {
var (
hash = make(map[string]interface{})
outerContext = p.context
outerKey = p.currentKey
)
p.context = append(p.context, p.currentKey)
prevContext := p.context
p.currentKey = ""
p.addImplicit(p.context)
p.addContext(p.context, parentIsArray)
/// Loop over all table key/value pairs.
for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
if it.typ == itemCommentStart {
p.expect(itemText)
continue
}
/// Read all key parts.
k := p.nextPos()
var key Key
for ; k.typ != itemKeyEnd && k.typ != itemEOF; k = p.next() {
key = append(key, p.keyString(k))
}
p.assertEqual(itemKeyEnd, k.typ)
/// The current key is the last part.
p.currentKey = key[len(key)-1]
/// All the other parts (if any) are the context; need to set each part
/// as implicit.
context := key[:len(key)-1]
for i := range context {
p.addImplicitContext(append(p.context, context[i:i+1]...))
}
/// Set the value.
val, typ := p.value(p.next(), false)
p.set(p.currentKey, val, typ)
p.ordered = append(p.ordered, p.context.add(p.currentKey))
hash[p.currentKey] = val
/// Restore context.
p.context = prevContext
}
p.context = outerContext
p.currentKey = outerKey
return hash, tomlHash
}
// numHasLeadingZero checks if this number has leading zeroes, allowing for '0',
// +/- signs, and base prefixes.
func numHasLeadingZero(s string) bool {
if len(s) > 1 && s[0] == '0' && !(s[1] == 'b' || s[1] == 'o' || s[1] == 'x') { // Allow 0b, 0o, 0x
return true
}
if len(s) > 2 && (s[0] == '-' || s[0] == '+') && s[1] == '0' {
return true
}
return false
}
// numUnderscoresOK checks whether each underscore in s is surrounded by
// characters that are not underscores.
func numUnderscoresOK(s string) bool {
switch s {
case "nan", "+nan", "-nan", "inf", "-inf", "+inf":
return true
}
accept := false
for _, r := range s {
if r == '_' {
if !accept {
return false
}
accept = false
continue
}
accept = true
// isHexadecimal is a superset of all the permissable characters
// surrounding an underscore.
accept = isHexadecimal(r)
}
return accept
}
@@ -338,13 +472,12 @@ func numPeriodsOK(s string) bool {
return !period
}
// establishContext sets the current context of the parser,
// where the context is either a hash or an array of hashes. Which one is
// set depends on the value of the `array` parameter.
// Set the current context of the parser, where the context is either a hash or
// an array of hashes, depending on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) establishContext(key Key, array bool) {
func (p *parser) addContext(key Key, array bool) {
var ok bool
// Always start at the top level and drill down for our context.
@@ -383,7 +516,7 @@ func (p *parser) establishContext(key Key, array bool) {
// list of tables for it.
k := key[len(key)-1]
if _, ok := hashContext[k]; !ok {
hashContext[k] = make([]map[string]interface{}, 0, 5)
hashContext[k] = make([]map[string]interface{}, 0, 4)
}
// Add a new table. But make sure the key hasn't already been used
@@ -391,8 +524,7 @@ func (p *parser) establishContext(key Key, array bool) {
if hash, ok := hashContext[k].([]map[string]interface{}); ok {
hashContext[k] = append(hash, make(map[string]interface{}))
} else {
p.panicf("Key '%s' was already created and cannot be used as "+
"an array.", keyContext)
p.panicf("Key '%s' was already created and cannot be used as an array.", key)
}
} else {
p.setValue(key[len(key)-1], make(map[string]interface{}))
@@ -400,15 +532,22 @@ func (p *parser) establishContext(key Key, array bool) {
p.context = append(p.context, key[len(key)-1])
}
// set calls setValue and setType.
func (p *parser) set(key string, val interface{}, typ tomlType) {
p.setValue(key, val)
p.setType(key, typ)
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
var tmpHash interface{}
var ok bool
hash := p.mapping
keyContext := make(Key, 0)
var (
tmpHash interface{}
ok bool
hash = p.mapping
keyContext Key
)
for _, k := range p.context {
keyContext = append(keyContext, k)
if tmpHash, ok = hash[k]; !ok {
@@ -422,24 +561,26 @@ func (p *parser) setValue(key string, value interface{}) {
case map[string]interface{}:
hash = t
default:
p.bug("Expected hash to have type 'map[string]interface{}', but "+
"it has '%T' instead.", tmpHash)
p.panicf("Key '%s' has already been defined.", keyContext)
}
}
keyContext = append(keyContext, key)
if _, ok := hash[key]; ok {
// Typically, if the given key has already been set, then we have
// to raise an error since duplicate keys are disallowed. However,
// it's possible that a key was previously defined implicitly. In this
// case, it is allowed to be redefined concretely. (See the
// `tests/valid/implicit-and-explicit-after.toml` test in `toml-test`.)
// Normally redefining keys isn't allowed, but the key could have been
// defined implicitly and it's allowed to be redefined concretely. (See
// the `valid/implicit-and-explicit-after.toml` in toml-test)
//
// But we have to make sure to stop marking it as an implicit. (So that
// another redefinition provokes an error.)
//
// Note that since it has already been defined (as a hash), we don't
// want to overwrite it. So our business is done.
if p.isArray(keyContext) {
p.removeImplicit(keyContext)
hash[key] = value
return
}
if p.isImplicit(keyContext) {
p.removeImplicit(keyContext)
return
@@ -449,40 +590,39 @@ func (p *parser) setValue(key string, value interface{}) {
// key, which is *always* wrong.
p.panicf("Key '%s' has already been defined.", keyContext)
}
hash[key] = value
}
// setType sets the type of a particular value at a given key.
// It should be called immediately AFTER setValue.
// setType sets the type of a particular value at a given key. It should be
// called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
keyContext := make(Key, 0, len(p.context)+1)
for _, k := range p.context {
keyContext = append(keyContext, k)
}
keyContext = append(keyContext, p.context...)
if len(key) > 0 { // allow type setting for hashes
keyContext = append(keyContext, key)
}
// Special case to make empty keys ("" = 1) work.
// Without it it will set "" rather than `""`.
// TODO: why is this needed? And why is this only needed here?
if len(keyContext) == 0 {
keyContext = Key{""}
}
p.types[keyContext.String()] = typ
}
// addImplicit sets the given Key as having been created implicitly.
func (p *parser) addImplicit(key Key) {
p.implicits[key.String()] = true
}
// removeImplicit stops tagging the given key as having been implicitly
// created.
func (p *parser) removeImplicit(key Key) {
p.implicits[key.String()] = false
}
// isImplicit returns true if the key group pointed to by the key was created
// implicitly.
func (p *parser) isImplicit(key Key) bool {
return p.implicits[key.String()]
// Implicit keys need to be created when tables are implied in "a.b.c.d = 1" and
// "[a.b.c]" (the "a", "b", and "c" hashes are never created explicitly).
func (p *parser) addImplicit(key Key) { p.implicits[key.String()] = struct{}{} }
func (p *parser) removeImplicit(key Key) { delete(p.implicits, key.String()) }
func (p *parser) isImplicit(key Key) bool { _, ok := p.implicits[key.String()]; return ok }
func (p *parser) isArray(key Key) bool { return p.types[key.String()] == tomlArray }
func (p *parser) addImplicitContext(key Key) {
p.addImplicit(key)
p.addContext(key, false)
}
// current returns the full key name of the current context.
@@ -497,24 +637,62 @@ func (p *parser) current() string {
}
func stripFirstNewline(s string) string {
if len(s) == 0 || s[0] != '\n' {
if len(s) > 0 && s[0] == '\n' {
return s[1:]
}
if len(s) > 1 && s[0] == '\r' && s[1] == '\n' {
return s[2:]
}
return s
}
// Remove newlines inside triple-quoted strings if a line ends with "\".
func (p *parser) stripEscapedNewlines(s string) string {
split := strings.Split(s, "\n")
if len(split) < 1 {
return s
}
return s[1:]
}
func stripEscapedWhitespace(s string) string {
esc := strings.Split(s, "\\\n")
if len(esc) > 1 {
for i := 1; i < len(esc); i++ {
esc[i] = strings.TrimLeftFunc(esc[i], unicode.IsSpace)
escNL := false // Keep track of the last non-blank line was escaped.
for i, line := range split {
line = strings.TrimRight(line, " \t\r")
if len(line) == 0 || line[len(line)-1] != '\\' {
split[i] = strings.TrimRight(split[i], "\r")
if !escNL && i != len(split)-1 {
split[i] += "\n"
}
continue
}
escBS := true
for j := len(line) - 1; j >= 0 && line[j] == '\\'; j-- {
escBS = !escBS
}
if escNL {
line = strings.TrimLeft(line, " \t\r")
}
escNL = !escBS
if escBS {
split[i] += "\n"
continue
}
if i == len(split)-1 {
p.panicf("invalid escape: '\\ '")
}
split[i] = line[:len(line)-1] // Remove \
if len(split)-1 > i {
split[i+1] = strings.TrimLeft(split[i+1], " \t\r")
}
}
return strings.Join(esc, "")
return strings.Join(split, "")
}
func (p *parser) replaceEscapes(str string) string {
var replaced []rune
func (p *parser) replaceEscapes(it item, str string) string {
replaced := make([]rune, 0, len(str))
s := []byte(str)
r := 0
for r < len(s) {
@@ -533,6 +711,9 @@ func (p *parser) replaceEscapes(str string) string {
default:
p.bug("Expected valid escape code after \\, but got %q.", s[r])
return ""
case ' ', '\t':
p.panicItemf(it, "invalid escape: '\\%c'", s[r])
return ""
case 'b':
replaced = append(replaced, rune(0x0008))
r += 1
@@ -558,14 +739,14 @@ func (p *parser) replaceEscapes(str string) string {
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+5). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+5])
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+5])
replaced = append(replaced, escaped)
r += 5
case 'U':
// At this point, we know we have a Unicode escape of the form
// `uXXXX` at [r, r+9). (Because the lexer guarantees this
// for us.)
escaped := p.asciiEscapeToUnicode(s[r+1 : r+9])
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+9])
replaced = append(replaced, escaped)
r += 9
}
@@ -573,20 +754,14 @@ func (p *parser) replaceEscapes(str string) string {
return string(replaced)
}
func (p *parser) asciiEscapeToUnicode(bs []byte) rune {
func (p *parser) asciiEscapeToUnicode(it item, bs []byte) rune {
s := string(bs)
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32)
if err != nil {
p.bug("Could not parse '%s' as a hexadecimal number, but the "+
"lexer claims it's OK: %s", s, err)
p.bug("Could not parse '%s' as a hexadecimal number, but the lexer claims it's OK: %s", s, err)
}
if !utf8.ValidRune(rune(hex)) {
p.panicf("Escaped character '\\u%s' is not valid UTF-8.", s)
p.panicItemf(it, "Escaped character '\\u%s' is not valid UTF-8.", s)
}
return rune(hex)
}
func isStringType(ty itemType) bool {
return ty == itemString || ty == itemMultilineString ||
ty == itemRawString || ty == itemRawMultilineString
}

View File

@@ -1 +0,0 @@
au BufWritePost *.go silent!make tags > /dev/null 2>&1

View File

@@ -70,8 +70,8 @@ func typeFields(t reflect.Type) []field {
next := []field{{typ: t}}
// Count of queued names for current level and the next.
count := map[reflect.Type]int{}
nextCount := map[reflect.Type]int{}
var count map[reflect.Type]int
var nextCount map[reflect.Type]int
// Types already visited at an earlier level.
visited := map[reflect.Type]bool{}

View File

@@ -16,7 +16,7 @@ func typeEqual(t1, t2 tomlType) bool {
return t1.typeString() == t2.typeString()
}
func typeIsHash(t tomlType) bool {
func typeIsTable(t tomlType) bool {
return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}
@@ -68,24 +68,3 @@ func (p *parser) typeOfPrimitive(lexItem item) tomlType {
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem)
panic("unreachable")
}
// typeOfArray returns a tomlType for an array given a list of types of its
// values.
//
// In the current spec, if an array is homogeneous, then its type is always
// "Array". If the array is not homogeneous, an error is generated.
func (p *parser) typeOfArray(types []tomlType) tomlType {
// Empty arrays are cool.
if len(types) == 0 {
return tomlArray
}
theType := types[0]
for _, t := range types[1:] {
if !typeEqual(theType, t) {
p.panicf("Array contains values of type '%s' and '%s', but "+
"arrays must be homogeneous.", theType, t)
}
}
return tomlArray
}

View File

@@ -1,6 +1,6 @@
The MIT License (MIT)
Copyright (c) 2014-2017 TSUYUSATO Kitsune
Copyright (c) 2014-2019 TSUYUSATO Kitsune
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -1,4 +1,6 @@
# heredoc [![CircleCI](https://circleci.com/gh/MakeNowJust/heredoc.svg?style=svg)](https://circleci.com/gh/MakeNowJust/heredoc) [![Go Walker](http://gowalker.org/api/v1/badge)](https://gowalker.org/github.com/MakeNowJust/heredoc)
# heredoc
[![Build Status](https://circleci.com/gh/MakeNowJust/heredoc.svg?style=svg)](https://circleci.com/gh/MakeNowJust/heredoc) [![GoDoc](https://godoc.org/github.com/MakeNowJusti/heredoc?status.svg)](https://godoc.org/github.com/MakeNowJust/heredoc)
## About
@@ -15,8 +17,6 @@ $ go get github.com/MakeNowJust/heredoc
```go
// usual
import "github.com/MakeNowJust/heredoc"
// shortcuts
import . "github.com/MakeNowJust/heredoc/dot"
```
## Example
@@ -26,11 +26,11 @@ package main
import (
"fmt"
. "github.com/MakeNowJust/heredoc/dot"
"github.com/MakeNowJust/heredoc"
)
func main() {
fmt.Println(D(`
fmt.Println(heredoc.Doc(`
Lorem ipsum dolor sit amet, consectetur adipisicing elit,
sed do eiusmod tempor incididunt ut labore et dolore magna
aliqua. Ut enim ad minim veniam, ...
@@ -45,8 +45,7 @@ func main() {
## API Document
- [Go Walker - github.com/MakeNowJust/heredoc](https://gowalker.org/github.com/MakeNowJust/heredoc)
- [Go Walker - github.com/MakeNowJust/heredoc/dot](https://gowalker.org/github.com/MakeNowJust/heredoc/dot)
- [heredoc - GoDoc](https://godoc.org/github.com/MakeNowJust/heredoc)
## License

View File

@@ -1,24 +1,31 @@
// Copyright (c) 2014-2017 TSUYUSATO Kitsune
// Copyright (c) 2014-2019 TSUYUSATO Kitsune
// This software is released under the MIT License.
// http://opensource.org/licenses/mit-license.php
// Package heredoc provides creation of here-documents from raw strings.
//
// Golang supports raw-string syntax.
//
// doc := `
// Foo
// Bar
// `
//
// But raw-string cannot recognize indentation. Thus such content is an indented string, equivalent to
//
// "\n\tFoo\n\tBar\n"
//
// I dont't want this!
//
// However this problem is solved by package heredoc.
//
// doc := heredoc.Doc(`
// Foo
// Bar
// `)
//
// Is equivalent to
//
// "Foo\nBar\n"
package heredoc

View File

@@ -1,27 +0,0 @@
language: go
# Testing and linting occuring via go modules does not really work well prior
# to Go 1.12. This is what can happen with experiments.
go:
- 1.11.x
- 1.12.x
- 1.13.x
- tip
# Setting sudo access to false will let Travis CI use containers rather than
# VMs to run the tests. For more details see:
# - http://docs.travis-ci.com/user/workers/container-based-infrastructure/
# - http://docs.travis-ci.com/user/workers/standard-infrastructure/
sudo: false
script:
- make lint
- make test-cover
notifications:
webhooks:
urls:
- https://webhooks.gitter.im/e/06e3328629952dabe3e0
on_success: change # options: [always|never|change] default: always
on_failure: always # options: [always|never|change] default: always
on_start: never # options: [always|never|change] default: always

View File

@@ -1,5 +1,33 @@
# Changelog
## 3.1.1 (2020-11-23)
### Fixed
- #158: Fixed issue with generated regex operation order that could cause problem
## 3.1.0 (2020-04-15)
### Added
- #131: Add support for serializing/deserializing SQL (thanks @ryancurrah)
### Changed
- #148: More accurate validation messages on constraints
## 3.0.3 (2019-12-13)
### Fixed
- #141: Fixed issue with <= comparison
## 3.0.2 (2019-11-14)
### Fixed
- #134: Fixed broken constraint checking with ^0.0 (thanks @krmichelos)
## 3.0.1 (2019-09-13)
### Fixed

View File

@@ -9,7 +9,9 @@ The `semver` package provides the ability to work with [Semantic Versions](http:
[![Stability:
Active](https://masterminds.github.io/stability/active.svg)](https://masterminds.github.io/stability/active.html)
[![Build Status](https://travis-ci.org/Masterminds/semver.svg)](https://travis-ci.org/Masterminds/semver) [![Build status](https://ci.appveyor.com/api/projects/status/jfk66lib7hb985k8/branch/master?svg=true&passingText=windows%20build%20passing&failingText=windows%20build%20failing)](https://ci.appveyor.com/project/mattfarina/semver/branch/master) [![GoDoc](https://godoc.org/github.com/Masterminds/semver?status.svg)](https://godoc.org/github.com/Masterminds/semver) [![Go Report Card](https://goreportcard.com/badge/github.com/Masterminds/semver)](https://goreportcard.com/report/github.com/Masterminds/semver)
[![](https://github.com/Masterminds/semver/workflows/Tests/badge.svg)](https://github.com/Masterminds/semver/actions)
[![GoDoc](https://img.shields.io/static/v1?label=godoc&message=reference&color=blue)](https://pkg.go.dev/github.com/Masterminds/semver/v3)
[![Go Report Card](https://goreportcard.com/badge/github.com/Masterminds/semver)](https://goreportcard.com/report/github.com/Masterminds/semver)
If you are looking for a command line tool for version comparisons please see
[vert](https://github.com/Masterminds/vert) which uses this library.
@@ -219,7 +221,7 @@ if err != nil {
// Handle constraint not being parseable.
}
v, _ := semver.NewVersion("1.3")
v, err := semver.NewVersion("1.3")
if err != nil {
// Handle version not being parseable.
}

View File

@@ -1,17 +0,0 @@
version: build-{build}.{branch}
shallow_clone: true
environment:
GOPATH: C:\gopath
platform:
- x64
build_script:
- go install -v ./...
test_script:
- go test -v
deploy: off

View File

@@ -54,11 +54,13 @@ func NewConstraint(c string) (*Constraints, error) {
// Check tests if a version satisfies the constraints.
func (cs Constraints) Check(v *Version) bool {
// TODO(mattfarina): For v4 of this library consolidate the Check and Validate
// functions as the underlying functions make that possible now.
// loop over the ORs and check the inner ANDs
for _, o := range cs.constraints {
joy := true
for _, c := range o {
if !c.check(v) {
if check, _ := c.check(v); !check {
joy = false
break
}
@@ -96,9 +98,8 @@ func (cs Constraints) Validate(v *Version) (bool, []error) {
} else {
if !c.check(v) {
em := fmt.Errorf(constraintMsg[c.origfunc], v, c.orig)
e = append(e, em)
if _, err := c.check(v); err != nil {
e = append(e, err)
joy = false
}
}
@@ -134,7 +135,6 @@ func (cs Constraints) String() string {
}
var constraintOps map[string]cfunc
var constraintMsg map[string]string
var constraintRegex *regexp.Regexp
var constraintRangeRegex *regexp.Regexp
@@ -164,29 +164,11 @@ func init() {
"^": constraintCaret,
}
constraintMsg = map[string]string{
"": "%s is not equal to %s",
"=": "%s is not equal to %s",
"!=": "%s is equal to %s",
">": "%s is less than or equal to %s",
"<": "%s is greater than or equal to %s",
">=": "%s is less than %s",
"=>": "%s is less than %s",
"<=": "%s is greater than %s",
"=<": "%s is greater than %s",
"~": "%s does not have same major and minor version as %s",
"~>": "%s does not have same major and minor version as %s",
"^": "%s does not have same major version as %s",
}
ops := make([]string, 0, len(constraintOps))
for k := range constraintOps {
ops = append(ops, regexp.QuoteMeta(k))
}
ops := `=||!=|>|<|>=|=>|<=|=<|~|~>|\^`
constraintRegex = regexp.MustCompile(fmt.Sprintf(
`^\s*(%s)\s*(%s)\s*$`,
strings.Join(ops, "|"),
ops,
cvRegex))
constraintRangeRegex = regexp.MustCompile(fmt.Sprintf(
@@ -195,12 +177,12 @@ func init() {
findConstraintRegex = regexp.MustCompile(fmt.Sprintf(
`(%s)\s*(%s)`,
strings.Join(ops, "|"),
ops,
cvRegex))
validConstraintRegex = regexp.MustCompile(fmt.Sprintf(
`^(\s*(%s)\s*(%s)\s*\,?)+$`,
strings.Join(ops, "|"),
ops,
cvRegex))
}
@@ -223,7 +205,7 @@ type constraint struct {
}
// Check if a version meets the constraint
func (c *constraint) check(v *Version) bool {
func (c *constraint) check(v *Version) (bool, error) {
return constraintOps[c.origfunc](v, c)
}
@@ -232,7 +214,7 @@ func (c *constraint) string() string {
return c.origfunc + c.orig
}
type cfunc func(v *Version, c *constraint) bool
type cfunc func(v *Version, c *constraint) (bool, error)
func parseConstraint(c string) (*constraint, error) {
if len(c) > 0 {
@@ -301,111 +283,148 @@ func parseConstraint(c string) (*constraint, error) {
}
// Constraint functions
func constraintNotEqual(v *Version, c *constraint) bool {
func constraintNotEqual(v *Version, c *constraint) (bool, error) {
if c.dirty {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
if c.con.Major() != v.Major() {
return true
return true, nil
}
if c.con.Minor() != v.Minor() && !c.minorDirty {
return true
return true, nil
} else if c.minorDirty {
return false
return false, fmt.Errorf("%s is equal to %s", v, c.orig)
} else if c.con.Patch() != v.Patch() && !c.patchDirty {
return true
return true, nil
} else if c.patchDirty {
// Need to handle prereleases if present
if v.Prerelease() != "" || c.con.Prerelease() != "" {
return comparePrerelease(v.Prerelease(), c.con.Prerelease()) != 0
eq := comparePrerelease(v.Prerelease(), c.con.Prerelease()) != 0
if eq {
return true, nil
}
return false, fmt.Errorf("%s is equal to %s", v, c.orig)
}
return false
return false, fmt.Errorf("%s is equal to %s", v, c.orig)
}
}
return !v.Equal(c.con)
eq := v.Equal(c.con)
if eq {
return false, fmt.Errorf("%s is equal to %s", v, c.orig)
}
return true, nil
}
func constraintGreaterThan(v *Version, c *constraint) bool {
func constraintGreaterThan(v *Version, c *constraint) (bool, error) {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
var eq bool
if !c.dirty {
return v.Compare(c.con) == 1
eq = v.Compare(c.con) == 1
if eq {
return true, nil
}
return false, fmt.Errorf("%s is less than or equal to %s", v, c.orig)
}
if v.Major() > c.con.Major() {
return true
return true, nil
} else if v.Major() < c.con.Major() {
return false
return false, fmt.Errorf("%s is less than or equal to %s", v, c.orig)
} else if c.minorDirty {
// This is a range case such as >11. When the version is something like
// 11.1.0 is it not > 11. For that we would need 12 or higher
return false
return false, fmt.Errorf("%s is less than or equal to %s", v, c.orig)
} else if c.patchDirty {
// This is for ranges such as >11.1. A version of 11.1.1 is not greater
// which one of 11.2.1 is greater
return v.Minor() > c.con.Minor()
eq = v.Minor() > c.con.Minor()
if eq {
return true, nil
}
return false, fmt.Errorf("%s is less than or equal to %s", v, c.orig)
}
// If we have gotten here we are not comparing pre-preleases and can use the
// Compare function to accomplish that.
return v.Compare(c.con) == 1
eq = v.Compare(c.con) == 1
if eq {
return true, nil
}
return false, fmt.Errorf("%s is less than or equal to %s", v, c.orig)
}
func constraintLessThan(v *Version, c *constraint) bool {
func constraintLessThan(v *Version, c *constraint) (bool, error) {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
return v.Compare(c.con) < 0
eq := v.Compare(c.con) < 0
if eq {
return true, nil
}
return false, fmt.Errorf("%s is greater than or equal to %s", v, c.orig)
}
func constraintGreaterThanEqual(v *Version, c *constraint) bool {
func constraintGreaterThanEqual(v *Version, c *constraint) (bool, error) {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
return v.Compare(c.con) >= 0
eq := v.Compare(c.con) >= 0
if eq {
return true, nil
}
return false, fmt.Errorf("%s is less than %s", v, c.orig)
}
func constraintLessThanEqual(v *Version, c *constraint) bool {
func constraintLessThanEqual(v *Version, c *constraint) (bool, error) {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
var eq bool
if !c.dirty {
return v.Compare(c.con) <= 0
eq = v.Compare(c.con) <= 0
if eq {
return true, nil
}
return false, fmt.Errorf("%s is greater than %s", v, c.orig)
}
if v.Major() > c.con.Major() {
return false
} else if v.Minor() > c.con.Minor() && !c.minorDirty {
return false
return false, fmt.Errorf("%s is greater than %s", v, c.orig)
} else if v.Major() == c.con.Major() && v.Minor() > c.con.Minor() && !c.minorDirty {
return false, fmt.Errorf("%s is greater than %s", v, c.orig)
}
return true
return true, nil
}
// ~*, ~>* --> >= 0.0.0 (any)
@@ -414,51 +433,56 @@ func constraintLessThanEqual(v *Version, c *constraint) bool {
// ~1.2, ~1.2.x, ~>1.2, ~>1.2.x --> >=1.2.0, <1.3.0
// ~1.2.3, ~>1.2.3 --> >=1.2.3, <1.3.0
// ~1.2.0, ~>1.2.0 --> >=1.2.0, <1.3.0
func constraintTilde(v *Version, c *constraint) bool {
func constraintTilde(v *Version, c *constraint) (bool, error) {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
if v.LessThan(c.con) {
return false
return false, fmt.Errorf("%s is less than %s", v, c.orig)
}
// ~0.0.0 is a special case where all constraints are accepted. It's
// equivalent to >= 0.0.0.
if c.con.Major() == 0 && c.con.Minor() == 0 && c.con.Patch() == 0 &&
!c.minorDirty && !c.patchDirty {
return true
return true, nil
}
if v.Major() != c.con.Major() {
return false
return false, fmt.Errorf("%s does not have same major version as %s", v, c.orig)
}
if v.Minor() != c.con.Minor() && !c.minorDirty {
return false
return false, fmt.Errorf("%s does not have same major and minor version as %s", v, c.orig)
}
return true
return true, nil
}
// When there is a .x (dirty) status it automatically opts in to ~. Otherwise
// it's a straight =
func constraintTildeOrEqual(v *Version, c *constraint) bool {
func constraintTildeOrEqual(v *Version, c *constraint) (bool, error) {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
if c.dirty {
return constraintTilde(v, c)
}
return v.Equal(c.con)
eq := v.Equal(c.con)
if eq {
return true, nil
}
return false, fmt.Errorf("%s is not equal to %s", v, c.orig)
}
// ^* --> (any)
@@ -470,37 +494,54 @@ func constraintTildeOrEqual(v *Version, c *constraint) bool {
// ^0.0.3 --> >=0.0.3 <0.0.4
// ^0.0 --> >=0.0.0 <0.1.0
// ^0 --> >=0.0.0 <1.0.0
func constraintCaret(v *Version, c *constraint) bool {
func constraintCaret(v *Version, c *constraint) (bool, error) {
// If there is a pre-release on the version but the constraint isn't looking
// for them assume that pre-releases are not compatible. See issue 21 for
// more details.
if v.Prerelease() != "" && c.con.Prerelease() == "" {
return false
return false, fmt.Errorf("%s is a prerelease version and the constraint is only looking for release versions", v)
}
// This less than handles prereleases
if v.LessThan(c.con) {
return false
return false, fmt.Errorf("%s is less than %s", v, c.orig)
}
var eq bool
// ^ when the major > 0 is >=x.y.z < x+1
if c.con.Major() > 0 || c.minorDirty {
// ^ has to be within a major range for > 0. Everything less than was
// filtered out with the LessThan call above. This filters out those
// that greater but not within the same major range.
return v.Major() == c.con.Major()
eq = v.Major() == c.con.Major()
if eq {
return true, nil
}
return false, fmt.Errorf("%s does not have same major version as %s", v, c.orig)
}
// ^ when the major is 0 and minor > 0 is >=0.y.z < 0.y+1
if c.con.Major() == 0 && v.Major() > 0 {
return false, fmt.Errorf("%s does not have same major version as %s", v, c.orig)
}
// If the con Minor is > 0 it is not dirty
if c.con.Minor() > 0 || c.patchDirty {
return v.Minor() == c.con.Minor()
eq = v.Minor() == c.con.Minor()
if eq {
return true, nil
}
return false, fmt.Errorf("%s does not have same minor version as %s. Expected minor versions to match when constraint major version is 0", v, c.orig)
}
// At this point the major is 0 and the minor is 0 and not dirty. The patch
// is not dirty so we need to check if they are equal. If they are not equal
return c.con.Patch() == v.Patch()
eq = c.con.Patch() == v.Patch()
if eq {
return true, nil
}
return false, fmt.Errorf("%s does not equal %s. Expect version and constraint to equal when major and minor versions are 0", v, c.orig)
}
func isX(x string) bool {

View File

@@ -21,7 +21,7 @@ that can be sorted, compared, and used in constraints.
When parsing a version an optional error can be returned if there is an issue
parsing the version. For example,
v, err := semver.NewVersion("1.2.3-beta.1+build345")
v, err := semver.NewVersion("1.2.3-beta.1+b345")
The version object has methods to get the parts of the version, compare it to
other versions, convert the version back into a string, and get the original

View File

@@ -2,6 +2,7 @@ package semver
import (
"bytes"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
@@ -435,6 +436,28 @@ func (v Version) MarshalJSON() ([]byte, error) {
return json.Marshal(v.String())
}
// Scan implements the SQL.Scanner interface.
func (v *Version) Scan(value interface{}) error {
var s string
s, _ = value.(string)
temp, err := NewVersion(s)
if err != nil {
return err
}
v.major = temp.major
v.minor = temp.minor
v.patch = temp.patch
v.pre = temp.pre
v.metadata = temp.metadata
v.original = temp.original
return nil
}
// Value implements the Driver.Valuer interface.
func (v Version) Value() (driver.Value, error) {
return v.String(), nil
}
func compareSegment(v, o uint64) int {
if v < o {
return -1

View File

@@ -1,9 +0,0 @@
language: go
go:
- 1.12.x
- 1.13.x
- tip
script:
- make test-cover

View File

@@ -1,5 +1,77 @@
# Changelog
## Release 3.2.1 (2021-02-04)
### Changed
- Upgraded `Masterminds/goutils` to `v1.1.1`. see the [Security Advisory](https://github.com/Masterminds/goutils/security/advisories/GHSA-xg2h-wx96-xgxr)
## Release 3.2.0 (2020-12-14)
### Added
- #211: Added randInt function (thanks @kochurovro)
- #223: Added fromJson and mustFromJson functions (thanks @mholt)
- #242: Added a bcrypt function (thanks @robbiet480)
- #253: Added randBytes function (thanks @MikaelSmith)
- #254: Added dig function for dicts (thanks @nyarly)
- #257: Added regexQuoteMeta for quoting regex metadata (thanks @rheaton)
- #261: Added filepath functions osBase, osDir, osExt, osClean, osIsAbs (thanks @zugl)
- #268: Added and and all functions for testing conditions (thanks @phuslu)
- #181: Added float64 arithmetic addf, add1f, subf, divf, mulf, maxf, and minf
(thanks @andrewmostello)
- #265: Added chunk function to split array into smaller arrays (thanks @karelbilek)
- #270: Extend certificate functions to handle non-RSA keys + add support for
ed25519 keys (thanks @misberner)
### Changed
- Removed testing and support for Go 1.12. ed25519 support requires Go 1.13 or newer
- Using semver 3.1.1 and mergo 0.3.11
### Fixed
- #249: Fix htmlDateInZone example (thanks @spawnia)
NOTE: The dependency github.com/imdario/mergo reverted the breaking change in
0.3.9 via 0.3.10 release.
## Release 3.1.0 (2020-04-16)
NOTE: The dependency github.com/imdario/mergo made a behavior change in 0.3.9
that impacts sprig functionality. Do not use sprig with a version newer than 0.3.8.
### Added
- #225: Added support for generating htpasswd hash (thanks @rustycl0ck)
- #224: Added duration filter (thanks @frebib)
- #205: Added `seq` function (thanks @thadc23)
### Changed
- #203: Unlambda functions with correct signature (thanks @muesli)
- #236: Updated the license formatting for GitHub display purposes
- #238: Updated package dependency versions. Note, mergo not updated to 0.3.9
as it causes a breaking change for sprig. That issue is tracked at
https://github.com/imdario/mergo/issues/139
### Fixed
- #229: Fix `seq` example in docs (thanks @kalmant)
## Release 3.0.2 (2019-12-13)
### Fixed
- #220: Updating to semver v3.0.3 to fix issue with <= ranges
- #218: fix typo elyptical->elliptic in ecdsa key description (thanks @laverya)
## Release 3.0.1 (2019-12-08)
### Fixed
- #212: Updated semver fixing broken constraint checking with ^0.0
## Release 3.0.0 (2019-10-02)
### Added

View File

@@ -1,5 +1,4 @@
Sprig
Copyright (C) 2013 Masterminds
Copyright (C) 2013-2020 Masterminds
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@@ -1,7 +1,9 @@
# Sprig: Template functions for Go templates [![GoDoc](https://godoc.org/github.com/Masterminds/sprig?status.svg)](https://godoc.org/github.com/Masterminds/sprig) [![Go Report Card](https://goreportcard.com/badge/github.com/Masterminds/sprig)](https://goreportcard.com/report/github.com/Masterminds/sprig)
# Sprig: Template functions for Go templates
[![GoDoc](https://img.shields.io/static/v1?label=godoc&message=reference&color=blue)](https://pkg.go.dev/github.com/Masterminds/sprig/v3)
[![Go Report Card](https://goreportcard.com/badge/github.com/Masterminds/sprig)](https://goreportcard.com/report/github.com/Masterminds/sprig)
[![Stability: Sustained](https://masterminds.github.io/stability/sustained.svg)](https://masterminds.github.io/stability/sustained.html)
[![Build Status](https://travis-ci.org/Masterminds/sprig.svg?branch=master)](https://travis-ci.org/Masterminds/sprig)
[![](https://github.com/Masterminds/sprig/workflows/Tests/badge.svg)](https://github.com/Masterminds/sprig/actions)
The Go language comes with a [built-in template
language](http://golang.org/pkg/text/template/), but not
@@ -12,6 +14,14 @@ It is inspired by the template functions found in
[Twig](http://twig.sensiolabs.org/documentation) and in various
JavaScript libraries, such as [underscore.js](http://underscorejs.org/).
## IMPORTANT NOTES
Sprig leverages [mergo](https://github.com/imdario/mergo) to handle merges. In
its v0.3.9 release there was a behavior change that impacts merging template
functions in sprig. It is currently recommended to use v0.3.8 of that package.
Using v0.3.9 will cause sprig tests to fail. The issue in mergo is tracked at
https://github.com/imdario/mergo/issues/139.
## Package Versions
There are two active major versions of the `sprig` package.

View File

@@ -1,22 +0,0 @@
version: build-{build}.{branch}
shallow_clone: true
environment:
GOPATH: C:\gopath
platform:
- x64
install:
- go version
- go env
build_script:
- go install ./...
test_script:
- go test -v
deploy: off

View File

@@ -2,10 +2,12 @@ package sprig
import (
"bytes"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/dsa"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
@@ -27,7 +29,10 @@ import (
"net"
"time"
"strings"
"github.com/google/uuid"
bcrypt_lib "golang.org/x/crypto/bcrypt"
"golang.org/x/crypto/scrypt"
)
@@ -46,6 +51,30 @@ func adler32sum(input string) string {
return fmt.Sprintf("%d", hash)
}
func bcrypt(input string) string {
hash, err := bcrypt_lib.GenerateFromPassword([]byte(input), bcrypt_lib.DefaultCost)
if err != nil {
return fmt.Sprintf("failed to encrypt string with bcrypt: %s", err)
}
return string(hash)
}
func htpasswd(username string, password string) string {
if strings.Contains(username, ":") {
return fmt.Sprintf("invalid username: %s", username)
}
return fmt.Sprintf("%s:%s", username, bcrypt(password))
}
func randBytes(count int) (string, error) {
buf := make([]byte, count)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(buf), nil
}
// uuidv4 provides a safe and secure UUID v4 implementation
func uuidv4() string {
return uuid.New().String()
@@ -133,6 +162,8 @@ func generatePrivateKey(typ string) string {
case "ecdsa":
// again, good enough for government work
priv, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case "ed25519":
_, priv, err = ed25519.GenerateKey(rand.Reader)
default:
return "Unknown type " + typ
}
@@ -165,7 +196,73 @@ func pemBlockForKey(priv interface{}) *pem.Block {
b, _ := x509.MarshalECPrivateKey(k)
return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b}
default:
return nil
// attempt PKCS#8 format for all other keys
b, err := x509.MarshalPKCS8PrivateKey(k)
if err != nil {
return nil
}
return &pem.Block{Type: "PRIVATE KEY", Bytes: b}
}
}
func parsePrivateKeyPEM(pemBlock string) (crypto.PrivateKey, error) {
block, _ := pem.Decode([]byte(pemBlock))
if block == nil {
return nil, errors.New("no PEM data in input")
}
if block.Type == "PRIVATE KEY" {
priv, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("decoding PEM as PKCS#8: %s", err)
}
return priv, nil
} else if !strings.HasSuffix(block.Type, " PRIVATE KEY") {
return nil, fmt.Errorf("no private key data in PEM block of type %s", block.Type)
}
switch block.Type[:len(block.Type)-12] { // strip " PRIVATE KEY"
case "RSA":
priv, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parsing RSA private key from PEM: %s", err)
}
return priv, nil
case "EC":
priv, err := x509.ParseECPrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parsing EC private key from PEM: %s", err)
}
return priv, nil
case "DSA":
var k DSAKeyFormat
_, err := asn1.Unmarshal(block.Bytes, &k)
if err != nil {
return nil, fmt.Errorf("parsing DSA private key from PEM: %s", err)
}
priv := &dsa.PrivateKey{
PublicKey: dsa.PublicKey{
Parameters: dsa.Parameters{
P: k.P, Q: k.Q, G: k.G,
},
Y: k.Y,
},
X: k.X,
}
return priv, nil
default:
return nil, fmt.Errorf("invalid private key type %s", block.Type)
}
}
func getPublicKey(priv crypto.PrivateKey) (crypto.PublicKey, error) {
switch k := priv.(type) {
case interface{ Public() crypto.PublicKey }:
return k.Public(), nil
case *dsa.PrivateKey:
return &k.PublicKey, nil
default:
return nil, fmt.Errorf("unable to get public key for type %T", priv)
}
}
@@ -199,14 +296,10 @@ func buildCustomCertificate(b64cert string, b64key string) (certificate, error)
)
}
decodedKey, _ := pem.Decode(key)
if decodedKey == nil {
return crt, errors.New("unable to decode key")
}
_, err = x509.ParsePKCS1PrivateKey(decodedKey.Bytes)
_, err = parsePrivateKeyPEM(string(key))
if err != nil {
return crt, fmt.Errorf(
"error parsing prive key: decodedKey.Bytes: %s",
"error parsing private key: %s",
err,
)
}
@@ -220,6 +313,31 @@ func buildCustomCertificate(b64cert string, b64key string) (certificate, error)
func generateCertificateAuthority(
cn string,
daysValid int,
) (certificate, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return certificate{}, fmt.Errorf("error generating rsa key: %s", err)
}
return generateCertificateAuthorityWithKeyInternal(cn, daysValid, priv)
}
func generateCertificateAuthorityWithPEMKey(
cn string,
daysValid int,
privPEM string,
) (certificate, error) {
priv, err := parsePrivateKeyPEM(privPEM)
if err != nil {
return certificate{}, fmt.Errorf("parsing private key: %s", err)
}
return generateCertificateAuthorityWithKeyInternal(cn, daysValid, priv)
}
func generateCertificateAuthorityWithKeyInternal(
cn string,
daysValid int,
priv crypto.PrivateKey,
) (certificate, error) {
ca := certificate{}
@@ -233,11 +351,6 @@ func generateCertificateAuthority(
x509.KeyUsageCertSign
template.IsCA = true
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return ca, fmt.Errorf("error generating rsa key: %s", err)
}
ca.Cert, ca.Key, err = getCertAndKey(template, priv, template, priv)
return ca, err
@@ -248,6 +361,34 @@ func generateSelfSignedCertificate(
ips []interface{},
alternateDNS []interface{},
daysValid int,
) (certificate, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return certificate{}, fmt.Errorf("error generating rsa key: %s", err)
}
return generateSelfSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, priv)
}
func generateSelfSignedCertificateWithPEMKey(
cn string,
ips []interface{},
alternateDNS []interface{},
daysValid int,
privPEM string,
) (certificate, error) {
priv, err := parsePrivateKeyPEM(privPEM)
if err != nil {
return certificate{}, fmt.Errorf("parsing private key: %s", err)
}
return generateSelfSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, priv)
}
func generateSelfSignedCertificateWithKeyInternal(
cn string,
ips []interface{},
alternateDNS []interface{},
daysValid int,
priv crypto.PrivateKey,
) (certificate, error) {
cert := certificate{}
@@ -256,11 +397,6 @@ func generateSelfSignedCertificate(
return cert, err
}
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return cert, fmt.Errorf("error generating rsa key: %s", err)
}
cert.Cert, cert.Key, err = getCertAndKey(template, priv, template, priv)
return cert, err
@@ -272,6 +408,36 @@ func generateSignedCertificate(
alternateDNS []interface{},
daysValid int,
ca certificate,
) (certificate, error) {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return certificate{}, fmt.Errorf("error generating rsa key: %s", err)
}
return generateSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, ca, priv)
}
func generateSignedCertificateWithPEMKey(
cn string,
ips []interface{},
alternateDNS []interface{},
daysValid int,
ca certificate,
privPEM string,
) (certificate, error) {
priv, err := parsePrivateKeyPEM(privPEM)
if err != nil {
return certificate{}, fmt.Errorf("parsing private key: %s", err)
}
return generateSignedCertificateWithKeyInternal(cn, ips, alternateDNS, daysValid, ca, priv)
}
func generateSignedCertificateWithKeyInternal(
cn string,
ips []interface{},
alternateDNS []interface{},
daysValid int,
ca certificate,
priv crypto.PrivateKey,
) (certificate, error) {
cert := certificate{}
@@ -286,14 +452,10 @@ func generateSignedCertificate(
err,
)
}
decodedSignerKey, _ := pem.Decode([]byte(ca.Key))
if decodedSignerKey == nil {
return cert, errors.New("unable to decode key")
}
signerKey, err := x509.ParsePKCS1PrivateKey(decodedSignerKey.Bytes)
signerKey, err := parsePrivateKeyPEM(ca.Key)
if err != nil {
return cert, fmt.Errorf(
"error parsing prive key: decodedSignerKey.Bytes: %s",
"error parsing private key: %s",
err,
)
}
@@ -303,11 +465,6 @@ func generateSignedCertificate(
return cert, err
}
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return cert, fmt.Errorf("error generating rsa key: %s", err)
}
cert.Cert, cert.Key, err = getCertAndKey(
template,
priv,
@@ -320,15 +477,19 @@ func generateSignedCertificate(
func getCertAndKey(
template *x509.Certificate,
signeeKey *rsa.PrivateKey,
signeeKey crypto.PrivateKey,
parent *x509.Certificate,
signingKey *rsa.PrivateKey,
signingKey crypto.PrivateKey,
) (string, string, error) {
signeePubKey, err := getPublicKey(signeeKey)
if err != nil {
return "", "", fmt.Errorf("error retrieving public key from signee key: %s", err)
}
derBytes, err := x509.CreateCertificate(
rand.Reader,
template,
parent,
&signeeKey.PublicKey,
signeePubKey,
signingKey,
)
if err != nil {
@@ -346,10 +507,7 @@ func getCertAndKey(
keyBuffer := bytes.Buffer{}
if err := pem.Encode(
&keyBuffer,
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(signeeKey),
},
pemBlockForKey(signeeKey),
); err != nil {
return "", "", fmt.Errorf("error pem-encoding key: %s", err)
}

View File

@@ -81,6 +81,19 @@ func dateAgo(date interface{}) string {
return duration.String()
}
func duration(sec interface{}) string {
var n int64
switch value := sec.(type) {
default:
n = 0
case string:
n, _ = strconv.ParseInt(value, 10, 64)
case int64:
n = value
}
return (time.Duration(n) * time.Second).String()
}
func durationRound(duration interface{}) string {
var d time.Duration
switch duration := duration.(type) {

View File

@@ -3,10 +3,16 @@ package sprig
import (
"bytes"
"encoding/json"
"math/rand"
"reflect"
"strings"
"time"
)
func init() {
rand.Seed(time.Now().UnixNano())
}
// dfault checks whether `given` is set, and returns default if not set.
//
// This returns `d` if `given` appears not to be set, and `given` otherwise.
@@ -63,6 +69,41 @@ func coalesce(v ...interface{}) interface{} {
return nil
}
// all returns true if empty(x) is false for all values x in the list.
// If the list is empty, return true.
func all(v ...interface{}) bool {
for _, val := range v {
if empty(val) {
return false
}
}
return true
}
// any returns true if empty(x) is false for any x in the list.
// If the list is empty, return false.
func any(v ...interface{}) bool {
for _, val := range v {
if !empty(val) {
return true
}
}
return false
}
// fromJson decodes JSON into a structured value, ignoring errors.
func fromJson(v string) interface{} {
output, _ := mustFromJson(v)
return output
}
// mustFromJson decodes JSON into a structured value, returning errors.
func mustFromJson(v string) (interface{}, error) {
var output interface{}
err := json.Unmarshal([]byte(v), &output)
return output, err
}
// toJson encodes an item into a JSON string
func toJson(v interface{}) string {
output, _ := json.Marshal(v)

View File

@@ -146,3 +146,29 @@ func deepCopy(i interface{}) interface{} {
func mustDeepCopy(i interface{}) (interface{}, error) {
return copystructure.Copy(i)
}
func dig(ps ...interface{}) (interface{}, error) {
if len(ps) < 3 {
panic("dig needs at least three arguments")
}
dict := ps[len(ps)-1].(map[string]interface{})
def := ps[len(ps)-2]
ks := make([]string, len(ps)-2)
for i := 0; i < len(ks); i++ {
ks[i] = ps[i].(string)
}
return digFromDict(dict, def, ks)
}
func digFromDict(dict map[string]interface{}, d interface{}, ks []string) (interface{}, error) {
k, ns := ks[0], ks[1:len(ks)]
step, has := dict[k]
if !has {
return d, nil
}
if len(ns) == 0 {
return step, nil
}
return digFromDict(step.(map[string]interface{}), d, ns)
}

View File

@@ -3,8 +3,10 @@ package sprig
import (
"errors"
"html/template"
"math/rand"
"os"
"path"
"path/filepath"
"reflect"
"strconv"
"strings"
@@ -13,6 +15,7 @@ import (
util "github.com/Masterminds/goutils"
"github.com/huandu/xstrings"
"github.com/shopspring/decimal"
)
// FuncMap produces the function map.
@@ -80,6 +83,7 @@ var nonhermeticFunctions = []string{
"randAlpha",
"randAscii",
"randNumeric",
"randBytes",
"uuidv4",
// OS
@@ -100,13 +104,14 @@ var genericMap = map[string]interface{}{
"date_modify": dateModify,
"dateInZone": dateInZone,
"dateModify": dateModify,
"duration": duration,
"durationRound": durationRound,
"htmlDate": htmlDate,
"htmlDateInZone": htmlDateInZone,
"must_date_modify": mustDateModify,
"mustDateModify": mustDateModify,
"mustToDate": mustToDate,
"now": func() time.Time { return time.Now() },
"now": time.Now,
"toDate": toDate,
"unixEpoch": unixEpoch,
@@ -162,6 +167,7 @@ var genericMap = map[string]interface{}{
"int64": toInt64,
"int": toInt,
"float64": toFloat64,
"seq": seq,
"toDecimal": toDecimal,
//"gt": func(a, b int) bool {return a > b},
@@ -198,9 +204,28 @@ var genericMap = map[string]interface{}{
}
return val
},
"randInt": func(min, max int) int { return rand.Intn(max-min) + min },
"add1f": func(i interface{}) float64 {
return execDecimalOp(i, []interface{}{1}, func(d1, d2 decimal.Decimal) decimal.Decimal { return d1.Add(d2) })
},
"addf": func(i ...interface{}) float64 {
a := interface{}(float64(0))
return execDecimalOp(a, i, func(d1, d2 decimal.Decimal) decimal.Decimal { return d1.Add(d2) })
},
"subf": func(a interface{}, v ...interface{}) float64 {
return execDecimalOp(a, v, func(d1, d2 decimal.Decimal) decimal.Decimal { return d1.Sub(d2) })
},
"divf": func(a interface{}, v ...interface{}) float64 {
return execDecimalOp(a, v, func(d1, d2 decimal.Decimal) decimal.Decimal { return d1.Div(d2) })
},
"mulf": func(a interface{}, v ...interface{}) float64 {
return execDecimalOp(a, v, func(d1, d2 decimal.Decimal) decimal.Decimal { return d1.Mul(d2) })
},
"biggest": max,
"max": max,
"min": min,
"maxf": maxf,
"minf": minf,
"ceil": ceil,
"floor": floor,
"round": round,
@@ -214,11 +239,15 @@ var genericMap = map[string]interface{}{
"default": dfault,
"empty": empty,
"coalesce": coalesce,
"all": all,
"any": any,
"compact": compact,
"mustCompact": mustCompact,
"fromJson": fromJson,
"toJson": toJson,
"toPrettyJson": toPrettyJson,
"toRawJson": toRawJson,
"mustFromJson": mustFromJson,
"mustToJson": mustToJson,
"mustToPrettyJson": mustToPrettyJson,
"mustToRawJson": mustToRawJson,
@@ -235,19 +264,26 @@ var genericMap = map[string]interface{}{
"deepEqual": reflect.DeepEqual,
// OS:
"env": func(s string) string { return os.Getenv(s) },
"expandenv": func(s string) string { return os.ExpandEnv(s) },
"env": os.Getenv,
"expandenv": os.ExpandEnv,
// Network:
"getHostByName": getHostByName,
// File Paths:
// Paths:
"base": path.Base,
"dir": path.Dir,
"clean": path.Clean,
"ext": path.Ext,
"isAbs": path.IsAbs,
// Filepaths:
"osBase": filepath.Base,
"osClean": filepath.Clean,
"osDir": filepath.Dir,
"osExt": filepath.Ext,
"osIsAbs": filepath.IsAbs,
// Encoding:
"b64enc": base64encode,
"b64dec": base64decode,
@@ -295,16 +331,25 @@ var genericMap = map[string]interface{}{
"slice": slice,
"mustSlice": mustSlice,
"concat": concat,
"dig": dig,
"chunk": chunk,
"mustChunk": mustChunk,
// Crypto:
"bcrypt": bcrypt,
"htpasswd": htpasswd,
"genPrivateKey": generatePrivateKey,
"derivePassword": derivePassword,
"buildCustomCert": buildCustomCertificate,
"genCA": generateCertificateAuthority,
"genCAWithKey": generateCertificateAuthorityWithPEMKey,
"genSelfSignedCert": generateSelfSignedCertificate,
"genSelfSignedCertWithKey": generateSelfSignedCertificateWithPEMKey,
"genSignedCert": generateSignedCertificate,
"genSignedCertWithKey": generateSignedCertificateWithPEMKey,
"encryptAES": encryptAES,
"decryptAES": decryptAES,
"randBytes": randBytes,
// UUIDs:
"uuidv4": uuidv4,
@@ -329,6 +374,7 @@ var genericMap = map[string]interface{}{
"mustRegexReplaceAllLiteral": mustRegexReplaceAllLiteral,
"regexSplit": regexSplit,
"mustRegexSplit": mustRegexSplit,
"regexQuoteMeta": regexQuoteMeta,
// URLs:
"urlParse": urlParse,

View File

@@ -2,6 +2,7 @@ package sprig
import (
"fmt"
"math"
"reflect"
"sort"
)
@@ -72,6 +73,50 @@ func mustPrepend(list interface{}, v interface{}) ([]interface{}, error) {
}
}
func chunk(size int, list interface{}) [][]interface{} {
l, err := mustChunk(size, list)
if err != nil {
panic(err)
}
return l
}
func mustChunk(size int, list interface{}) ([][]interface{}, error) {
tp := reflect.TypeOf(list).Kind()
switch tp {
case reflect.Slice, reflect.Array:
l2 := reflect.ValueOf(list)
l := l2.Len()
cs := int(math.Floor(float64(l-1)/float64(size)) + 1)
nl := make([][]interface{}, cs)
for i := 0; i < cs; i++ {
clen := size
if i == cs-1 {
clen = int(math.Floor(math.Mod(float64(l), float64(size))))
if clen == 0 {
clen = size
}
}
nl[i] = make([]interface{}, clen)
for j := 0; j < clen; j++ {
ix := i*size + j
nl[i][j] = l2.Index(ix).Interface()
}
}
return nl, nil
default:
return nil, fmt.Errorf("Cannot chunk type %s", tp)
}
}
func last(list interface{}) interface{} {
l, err := mustLast(list)
if err != nil {

View File

@@ -7,6 +7,6 @@ import (
func getHostByName(name string) string {
addrs, _ := net.LookupHost(name)
//TODO: add error handing when release v3 cames out
//TODO: add error handing when release v3 comes out
return addrs[rand.Intn(len(addrs))]
}

View File

@@ -4,8 +4,10 @@ import (
"fmt"
"math"
"strconv"
"strings"
"github.com/spf13/cast"
"github.com/shopspring/decimal"
)
// toFloat64 converts 64-bit floats
@@ -33,6 +35,15 @@ func max(a interface{}, i ...interface{}) int64 {
return aa
}
func maxf(a interface{}, i ...interface{}) float64 {
aa := toFloat64(a)
for _, b := range i {
bb := toFloat64(b)
aa = math.Max(aa, bb)
}
return aa
}
func min(a interface{}, i ...interface{}) int64 {
aa := toInt64(a)
for _, b := range i {
@@ -44,6 +55,15 @@ func min(a interface{}, i ...interface{}) int64 {
return aa
}
func minf(a interface{}, i ...interface{}) float64 {
aa := toFloat64(a)
for _, b := range i {
bb := toFloat64(b)
aa = math.Min(aa, bb)
}
return aa
}
func until(count int) []int {
step := 1
if count < 0 {
@@ -112,3 +132,55 @@ func toDecimal(v interface{}) int64 {
}
return result
}
func seq(params ...int) string {
increment := 1
switch len(params) {
case 0:
return ""
case 1:
start := 1
end := params[0]
if end < start {
increment = -1
}
return intArrayToString(untilStep(start, end+increment, increment), " ")
case 3:
start := params[0]
end := params[2]
step := params[1]
if end < start {
increment = -1
if step > 0 {
return ""
}
}
return intArrayToString(untilStep(start, end+increment, step), " ")
case 2:
start := params[0]
end := params[1]
step := 1
if end < start {
step = -1
}
return intArrayToString(untilStep(start, end+step, step), " ")
default:
return ""
}
}
func intArrayToString(slice []int, delimeter string) string {
return strings.Trim(strings.Join(strings.Fields(fmt.Sprint(slice)), delimeter), "[]")
}
// performs a float and subsequent decimal.Decimal conversion on inputs,
// and iterates through a and b executing the mathmetical operation f
func execDecimalOp(a interface{}, b []interface{}, f func(d1, d2 decimal.Decimal) decimal.Decimal) float64 {
prt := decimal.NewFromFloat(toFloat64(a))
for _, x := range b {
dx := decimal.NewFromFloat(toFloat64(x))
prt = f(prt, dx)
}
rslt, _ := prt.Float64()
return rslt
}

View File

@@ -77,3 +77,7 @@ func mustRegexSplit(regex string, s string, n int) ([]string, error) {
}
return r.Split(s, n), nil
}
func regexQuoteMeta(s string) string {
return regexp.QuoteMeta(s)
}

View File

@@ -1,12 +1,13 @@
language: go
go:
- 1.1
- 1.2
- 1.3
- 1.4
- 1.5
- tip
- 1.11.x
- 1.12.x
- 1.13.x
services:
- mysql
- postgresql
# Setting sudo access to false will let Travis CI use containers rather than
# VMs to run the tests. For more details see:
@@ -14,9 +15,16 @@ go:
# - http://docs.travis-ci.com/user/workers/standard-infrastructure/
sudo: false
install:
- go get
- go get github.com/stretchr/testify/assert
before_script:
- mysql -e 'CREATE DATABASE squirrel;'
- psql -c 'CREATE DATABASE squirrel;' -U postgres
script:
- go test
- cd integration
- go test -args -driver sqlite3
- go test -args -driver mysql -dataSource travis@/squirrel
- go test -args -driver postgres -dataSource 'postgres://postgres@localhost/squirrel?sslmode=disable'
notifications:
irc: "irc.freenode.net#masterminds"

View File

@@ -1,24 +1,20 @@
[![Stability: Maintenance](https://masterminds.github.io/stability/maintenance.svg)](https://masterminds.github.io/stability/maintenance.html)
### Squirrel is "complete".
Bug fixes will still be merged (slowly). Bug reports are welcome, but I will not necessarily respond to them. If another fork (or substantially similar project) actively improves on what Squirrel does, let me know and I may link to it here.
# Squirrel - fluent SQL generator for Go
```go
import "gopkg.in/Masterminds/squirrel.v1"
```
or if you prefer using `master` (which may be arbitrarily ahead of or behind `v1`):
**NOTE:** as of Go 1.6, `go get` correctly clones the Github default branch (which is `v1` in this repo).
```go
import "github.com/Masterminds/squirrel"
```
[![GoDoc](https://godoc.org/github.com/Masterminds/squirrel?status.png)](https://godoc.org/github.com/Masterminds/squirrel)
[![Build Status](https://travis-ci.org/Masterminds/squirrel.svg?branch=v1)](https://travis-ci.org/Masterminds/squirrel)
_**Note:** This project has moved from `github.com/lann/squirrel` to
`github.com/Masterminds/squirrel`. Lann remains the architect of the
project, but we're helping him curate.
[![GoDoc](https://godoc.org/github.com/Masterminds/squirrel?status.png)](https://godoc.org/github.com/Masterminds/squirrel)
[![Build Status](https://api.travis-ci.org/Masterminds/squirrel.svg?branch=master)](https://travis-ci.org/Masterminds/squirrel)
**Squirrel is not an ORM.** For an application of Squirrel, check out
[structable, a table-struct mapper](https://github.com/technosophos/structable)
[structable, a table-struct mapper](https://github.com/Masterminds/structable)
Squirrel helps you build SQL queries from composable parts:
@@ -68,7 +64,7 @@ Squirrel wants to make your life easier:
```go
// StmtCache caches Prepared Stmts for you
dbCache := sq.NewStmtCacher(db)
dbCache := sq.NewStmtCache(db)
// StatementBuilder keeps your syntax neat
mydb := sq.StatementBuilder.RunWith(dbCache)
@@ -81,7 +77,7 @@ Squirrel loves PostgreSQL:
psql := sq.StatementBuilder.PlaceholderFormat(sq.Dollar)
// You use question marks for placeholders...
sql, _, _ := psql.Select("*").From("elephants").Where("name IN (?,?)", "Dumbo", "Verna")
sql, _, _ := psql.Select("*").From("elephants").Where("name IN (?,?)", "Dumbo", "Verna").ToSql()
/// ...squirrel replaces them using PlaceholderFormat.
sql == "SELECT * FROM elephants WHERE name IN ($1,$2)"
@@ -98,7 +94,7 @@ query := sq.Insert("nodes").
query.QueryRow().Scan(&node.id)
```
You can escape question mask by inserting two question marks:
You can escape question marks by inserting two question marks:
```sql
SELECT * FROM nodes WHERE meta->'format' ??| array[?,?]
@@ -110,7 +106,35 @@ will generate with the Dollar Placeholder:
SELECT * FROM nodes WHERE meta->'format' ?| array[$1,$2]
```
## FAQ
* **How can I build an IN query on composite keys / tuples, e.g. `WHERE (col1, col2) IN ((1,2),(3,4))`? ([#104](https://github.com/Masterminds/squirrel/issues/104))**
Squirrel does not explicitly support tuples, but you can get the same effect with e.g.:
```go
sq.Or{
sq.Eq{"col1": 1, "col2": 2},
sq.Eq{"col1": 3, "col2": 4}}
```
```sql
WHERE (col1 = 1 AND col2 = 2) OR (col1 = 3 AND col2 = 4)
```
(which should produce the same query plan as the tuple version)
* **Why doesn't `Eq{"mynumber": []uint8{1,2,3}}` turn into an `IN` query? ([#114](https://github.com/Masterminds/squirrel/issues/114))**
Values of type `[]byte` are handled specially by `database/sql`. In Go, [`byte` is just an alias of `uint8`](https://golang.org/pkg/builtin/#byte), so there is no way to distinguish `[]uint8` from `[]byte`.
* **Some features are poorly documented!**
This isn't a frequent complaints section!
* **Some features are poorly documented?**
Yes. The tests should be considered a part of the documentation; take a look at those for ideas on how to express more complex queries.
## License

View File

@@ -27,7 +27,7 @@ func (b *sqlizerBuffer) WriteSql(item Sqlizer) {
var str string
var args []interface{}
str, args, b.err = item.ToSql()
str, args, b.err = nestedToSql(item)
if b.err != nil {
return
@@ -100,6 +100,16 @@ func (b CaseBuilder) ToSql() (string, []interface{}, error) {
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b CaseBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// what sets optional value for CASE construct "CASE [value] ..."
func (b CaseBuilder) what(expr interface{}) CaseBuilder {
return builder.Set(b, "What", newPart(expr)).(CaseBuilder)

View File

@@ -4,20 +4,21 @@ import (
"bytes"
"database/sql"
"fmt"
"github.com/lann/builder"
"strings"
"github.com/lann/builder"
)
type deleteData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes exprs
Prefixes []Sqlizer
From string
WhereParts []Sqlizer
OrderBys []string
Limit string
Offset string
Suffixes exprs
Suffixes []Sqlizer
}
func (d *deleteData) Exec() (sql.Result, error) {
@@ -36,7 +37,11 @@ func (d *deleteData) ToSql() (sqlStr string, args []interface{}, err error) {
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, _ = d.Prefixes.AppendToSql(sql, " ", args)
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
@@ -68,14 +73,16 @@ func (d *deleteData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, _ = d.Suffixes.AppendToSql(sql, " ", args)
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
return
}
// Builder
// DeleteBuilder builds SQL DELETE statements.
@@ -114,9 +121,24 @@ func (b DeleteBuilder) ToSql() (string, []interface{}, error) {
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b DeleteBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b DeleteBuilder) Prefix(sql string, args ...interface{}) DeleteBuilder {
return builder.Append(b, "Prefixes", Expr(sql, args...)).(DeleteBuilder)
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b DeleteBuilder) PrefixExpr(expr Sqlizer) DeleteBuilder {
return builder.Append(b, "Prefixes", expr).(DeleteBuilder)
}
// From sets the table to be deleted from.
@@ -148,5 +170,22 @@ func (b DeleteBuilder) Offset(offset uint64) DeleteBuilder {
// Suffix adds an expression to the end of the query
func (b DeleteBuilder) Suffix(sql string, args ...interface{}) DeleteBuilder {
return builder.Append(b, "Suffixes", Expr(sql, args...)).(DeleteBuilder)
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b DeleteBuilder) SuffixExpr(expr Sqlizer) DeleteBuilder {
return builder.Append(b, "Suffixes", expr).(DeleteBuilder)
}
func (b DeleteBuilder) Query() (*sql.Rows, error) {
data := builder.GetStruct(b).(deleteData)
return data.Query()
}
func (d *deleteData) Query() (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
return QueryWith(d.RunWith, d)
}

69
vendor/github.com/Masterminds/squirrel/delete_ctx.go generated vendored Normal file
View File

@@ -0,0 +1,69 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *deleteData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *deleteData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *deleteData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b DeleteBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(deleteData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b DeleteBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(deleteData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b DeleteBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(deleteData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b DeleteBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@@ -1,47 +1,114 @@
package squirrel
import (
"bytes"
"database/sql/driver"
"fmt"
"io"
"reflect"
"sort"
"strings"
)
const (
// Portable true/false literals.
sqlTrue = "(1=1)"
sqlFalse = "(1=0)"
)
type expr struct {
sql string
args []interface{}
}
// Expr builds value expressions for InsertBuilder and UpdateBuilder.
// Expr builds an expression from a SQL fragment and arguments.
//
// Ex:
// .Values(Expr("FROM_UNIXTIME(?)", t))
func Expr(sql string, args ...interface{}) expr {
// Expr("FROM_UNIXTIME(?)", t)
func Expr(sql string, args ...interface{}) Sqlizer {
return expr{sql: sql, args: args}
}
func (e expr) ToSql() (sql string, args []interface{}, err error) {
return e.sql, e.args, nil
simple := true
for _, arg := range e.args {
if _, ok := arg.(Sqlizer); ok {
simple = false
}
}
if simple {
return e.sql, e.args, nil
}
buf := &bytes.Buffer{}
ap := e.args
sp := e.sql
var isql string
var iargs []interface{}
for err == nil && len(ap) > 0 && len(sp) > 0 {
i := strings.Index(sp, "?")
if i < 0 {
// no more placeholders
break
}
if len(sp) > i+1 && sp[i+1:i+2] == "?" {
// escaped "??"; append it and step past
buf.WriteString(sp[:i+2])
sp = sp[i+2:]
continue
}
if as, ok := ap[0].(Sqlizer); ok {
// sqlizer argument; expand it and append the result
isql, iargs, err = as.ToSql()
buf.WriteString(sp[:i])
buf.WriteString(isql)
args = append(args, iargs...)
} else {
// normal argument; append it and the placeholder
buf.WriteString(sp[:i+1])
args = append(args, ap[0])
}
// step past the argument and placeholder
ap = ap[1:]
sp = sp[i+1:]
}
// append the remaining sql and arguments
buf.WriteString(sp)
return buf.String(), append(args, ap...), err
}
type exprs []expr
type concatExpr []interface{}
func (es exprs) AppendToSql(w io.Writer, sep string, args []interface{}) ([]interface{}, error) {
for i, e := range es {
if i > 0 {
_, err := io.WriteString(w, sep)
func (ce concatExpr) ToSql() (sql string, args []interface{}, err error) {
for _, part := range ce {
switch p := part.(type) {
case string:
sql += p
case Sqlizer:
pSql, pArgs, err := p.ToSql()
if err != nil {
return nil, err
return "", nil, err
}
sql += pSql
args = append(args, pArgs...)
default:
return "", nil, fmt.Errorf("%#v is not a string or Sqlizer", part)
}
_, err := io.WriteString(w, e.sql)
if err != nil {
return nil, err
}
args = append(args, e.args...)
}
return args, nil
return
}
// ConcatExpr builds an expression by concatenating strings and other expressions.
//
// Ex:
// name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName)
// ConcatExpr("COALESCE(full_name,", name_expr, ")")
func ConcatExpr(parts ...interface{}) concatExpr {
return concatExpr(parts)
}
// aliasExpr helps to alias part of SQL query generated with underlying "expr"
@@ -67,26 +134,34 @@ func (e aliasExpr) ToSql() (sql string, args []interface{}, err error) {
}
// Eq is syntactic sugar for use with Where/Having/Set methods.
// Ex:
// .Where(Eq{"id": 1})
type Eq map[string]interface{}
func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) {
func (eq Eq) toSQL(useNotOpr bool) (sql string, args []interface{}, err error) {
if len(eq) == 0 {
// Empty Sql{} evaluates to true.
sql = sqlTrue
return
}
var (
exprs []string
equalOpr string = "="
inOpr string = "IN"
nullOpr string = "IS"
exprs []string
equalOpr = "="
inOpr = "IN"
nullOpr = "IS"
inEmptyExpr = sqlFalse
)
if useNotOpr {
equalOpr = "<>"
inOpr = "NOT IN"
nullOpr = "IS NOT"
inEmptyExpr = sqlTrue
}
for key, val := range eq {
expr := ""
sortedKeys := getSortedKeys(eq)
for _, key := range sortedKeys {
var expr string
val := eq[key]
switch v := val.(type) {
case driver.Valuer:
@@ -95,13 +170,22 @@ func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) {
}
}
r := reflect.ValueOf(val)
if r.Kind() == reflect.Ptr {
if r.IsNil() {
val = nil
} else {
val = r.Elem().Interface()
}
}
if val == nil {
expr = fmt.Sprintf("%s %s NULL", key, nullOpr)
} else {
valVal := reflect.ValueOf(val)
if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice {
if isListType(val) {
valVal := reflect.ValueOf(val)
if valVal.Len() == 0 {
expr = fmt.Sprintf("%s %s (NULL)", key, inOpr)
expr = inEmptyExpr
if args == nil {
args = []interface{}{}
}
@@ -123,7 +207,7 @@ func (eq Eq) toSql(useNotOpr bool) (sql string, args []interface{}, err error) {
}
func (eq Eq) ToSql() (sql string, args []interface{}, err error) {
return eq.toSql(false)
return eq.toSQL(false)
}
// NotEq is syntactic sugar for use with Where/Having/Set methods.
@@ -132,7 +216,73 @@ func (eq Eq) ToSql() (sql string, args []interface{}, err error) {
type NotEq Eq
func (neq NotEq) ToSql() (sql string, args []interface{}, err error) {
return Eq(neq).toSql(true)
return Eq(neq).toSQL(true)
}
// Like is syntactic sugar for use with LIKE conditions.
// Ex:
// .Where(Like{"name": "%irrel"})
type Like map[string]interface{}
func (lk Like) toSql(opr string) (sql string, args []interface{}, err error) {
var exprs []string
for key, val := range lk {
expr := ""
switch v := val.(type) {
case driver.Valuer:
if val, err = v.Value(); err != nil {
return
}
}
if val == nil {
err = fmt.Errorf("cannot use null with like operators")
return
} else {
if isListType(val) {
err = fmt.Errorf("cannot use array or slice with like operators")
return
} else {
expr = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
}
}
exprs = append(exprs, expr)
}
sql = strings.Join(exprs, " AND ")
return
}
func (lk Like) ToSql() (sql string, args []interface{}, err error) {
return lk.toSql("LIKE")
}
// NotLike is syntactic sugar for use with LIKE conditions.
// Ex:
// .Where(NotLike{"name": "%irrel"})
type NotLike Like
func (nlk NotLike) ToSql() (sql string, args []interface{}, err error) {
return Like(nlk).toSql("NOT LIKE")
}
// ILike is syntactic sugar for use with ILIKE conditions.
// Ex:
// .Where(ILike{"name": "sq%"})
type ILike Like
func (ilk ILike) ToSql() (sql string, args []interface{}, err error) {
return Like(ilk).toSql("ILIKE")
}
// NotILike is syntactic sugar for use with ILIKE conditions.
// Ex:
// .Where(NotILike{"name": "sq%"})
type NotILike Like
func (nilk NotILike) ToSql() (sql string, args []interface{}, err error) {
return Like(nilk).toSql("NOT ILIKE")
}
// Lt is syntactic sugar for use with Where/Having/Set methods.
@@ -143,7 +293,7 @@ type Lt map[string]interface{}
func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err error) {
var (
exprs []string
opr string = "<"
opr = "<"
)
if opposite {
@@ -154,8 +304,10 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err
opr = fmt.Sprintf("%s%s", opr, "=")
}
for key, val := range lt {
expr := ""
sortedKeys := getSortedKeys(lt)
for _, key := range sortedKeys {
var expr string
val := lt[key]
switch v := val.(type) {
case driver.Valuer:
@@ -167,16 +319,14 @@ func (lt Lt) toSql(opposite, orEq bool) (sql string, args []interface{}, err err
if val == nil {
err = fmt.Errorf("cannot use null with less than or greater than operators")
return
} else {
valVal := reflect.ValueOf(val)
if valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice {
err = fmt.Errorf("cannot use array or slice with less than or greater than operators")
return
} else {
expr = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
}
}
if isListType(val) {
err = fmt.Errorf("cannot use array or slice with less than or greater than operators")
return
}
expr = fmt.Sprintf("%s %s ?", key, opr)
args = append(args, val)
exprs = append(exprs, expr)
}
sql = strings.Join(exprs, " AND ")
@@ -216,15 +366,18 @@ func (gtOrEq GtOrEq) ToSql() (sql string, args []interface{}, err error) {
type conj []Sqlizer
func (c conj) join(sep string) (sql string, args []interface{}, err error) {
func (c conj) join(sep, defaultExpr string) (sql string, args []interface{}, err error) {
if len(c) == 0 {
return defaultExpr, []interface{}{}, nil
}
var sqlParts []string
for _, sqlizer := range c {
partSql, partArgs, err := sqlizer.ToSql()
partSQL, partArgs, err := nestedToSql(sqlizer)
if err != nil {
return "", nil, err
}
if partSql != "" {
sqlParts = append(sqlParts, partSql)
if partSQL != "" {
sqlParts = append(sqlParts, partSQL)
args = append(args, partArgs...)
}
}
@@ -234,14 +387,33 @@ func (c conj) join(sep string) (sql string, args []interface{}, err error) {
return
}
// And conjunction Sqlizers
type And conj
func (a And) ToSql() (string, []interface{}, error) {
return conj(a).join(" AND ")
return conj(a).join(" AND ", sqlTrue)
}
// Or conjunction Sqlizers
type Or conj
func (o Or) ToSql() (string, []interface{}, error) {
return conj(o).join(" OR ")
return conj(o).join(" OR ", sqlFalse)
}
func getSortedKeys(exp map[string]interface{}) []string {
sortedKeys := make([]string, 0, len(exp))
for k := range exp {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
return sortedKeys
}
func isListType(val interface{}) bool {
if driver.IsValue(val) {
return false
}
valVal := reflect.ValueOf(val)
return valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice
}

View File

@@ -3,20 +3,26 @@ package squirrel
import (
"bytes"
"database/sql"
"errors"
"fmt"
"github.com/lann/builder"
"io"
"sort"
"strings"
"github.com/lann/builder"
)
type insertData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes exprs
Prefixes []Sqlizer
StatementKeyword string
Options []string
Into string
Columns []string
Values [][]interface{}
Suffixes exprs
Suffixes []Sqlizer
Select *SelectBuilder
}
func (d *insertData) Exec() (sql.Result, error) {
@@ -46,22 +52,31 @@ func (d *insertData) QueryRow() RowScanner {
func (d *insertData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.Into) == 0 {
err = fmt.Errorf("insert statements must specify a table")
err = errors.New("insert statements must specify a table")
return
}
if len(d.Values) == 0 {
err = fmt.Errorf("insert statements must have at least one set of values")
if len(d.Values) == 0 && d.Select == nil {
err = errors.New("insert statements must have at least one set of values or select clause")
return
}
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, _ = d.Prefixes.AppendToSql(sql, " ", args)
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
sql.WriteString("INSERT ")
if d.StatementKeyword == "" {
sql.WriteString("INSERT ")
} else {
sql.WriteString(d.StatementKeyword)
sql.WriteString(" ")
}
if len(d.Options) > 0 {
sql.WriteString(strings.Join(d.Options, " "))
@@ -78,16 +93,45 @@ func (d *insertData) ToSql() (sqlStr string, args []interface{}, err error) {
sql.WriteString(") ")
}
sql.WriteString("VALUES ")
if d.Select != nil {
args, err = d.appendSelectToSQL(sql, args)
} else {
args, err = d.appendValuesToSQL(sql, args)
}
if err != nil {
return
}
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
return
}
func (d *insertData) appendValuesToSQL(w io.Writer, args []interface{}) ([]interface{}, error) {
if len(d.Values) == 0 {
return args, errors.New("values for insert statements are not set")
}
io.WriteString(w, "VALUES ")
valuesStrings := make([]string, len(d.Values))
for r, row := range d.Values {
valueStrings := make([]string, len(row))
for v, val := range row {
e, isExpr := val.(expr)
if isExpr {
valueStrings[v] = e.sql
args = append(args, e.args...)
if vs, ok := val.(Sqlizer); ok {
vsql, vargs, err := vs.ToSql()
if err != nil {
return nil, err
}
valueStrings[v] = vsql
args = append(args, vargs...)
} else {
valueStrings[v] = "?"
args = append(args, val)
@@ -95,15 +139,26 @@ func (d *insertData) ToSql() (sqlStr string, args []interface{}, err error) {
}
valuesStrings[r] = fmt.Sprintf("(%s)", strings.Join(valueStrings, ","))
}
sql.WriteString(strings.Join(valuesStrings, ","))
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, _ = d.Suffixes.AppendToSql(sql, " ", args)
io.WriteString(w, strings.Join(valuesStrings, ","))
return args, nil
}
func (d *insertData) appendSelectToSQL(w io.Writer, args []interface{}) ([]interface{}, error) {
if d.Select == nil {
return args, errors.New("select clause for insert statements are not set")
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
return
selectClause, sArgs, err := d.Select.ToSql()
if err != nil {
return args, err
}
io.WriteString(w, selectClause)
args = append(args, sArgs...)
return args, nil
}
// Builder
@@ -161,9 +216,24 @@ func (b InsertBuilder) ToSql() (string, []interface{}, error) {
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b InsertBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b InsertBuilder) Prefix(sql string, args ...interface{}) InsertBuilder {
return builder.Append(b, "Prefixes", Expr(sql, args...)).(InsertBuilder)
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b InsertBuilder) PrefixExpr(expr Sqlizer) InsertBuilder {
return builder.Append(b, "Prefixes", expr).(InsertBuilder)
}
// Options adds keyword options before the INTO clause of the query.
@@ -188,20 +258,41 @@ func (b InsertBuilder) Values(values ...interface{}) InsertBuilder {
// Suffix adds an expression to the end of the query
func (b InsertBuilder) Suffix(sql string, args ...interface{}) InsertBuilder {
return builder.Append(b, "Suffixes", Expr(sql, args...)).(InsertBuilder)
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b InsertBuilder) SuffixExpr(expr Sqlizer) InsertBuilder {
return builder.Append(b, "Suffixes", expr).(InsertBuilder)
}
// SetMap set columns and values for insert builder from a map of column name and value
// note that it will reset all previous columns and values was set if any
func (b InsertBuilder) SetMap(clauses map[string]interface{}) InsertBuilder {
// Keep the columns in a consistent order by sorting the column key string.
cols := make([]string, 0, len(clauses))
vals := make([]interface{}, 0, len(clauses))
for col, val := range clauses {
for col := range clauses {
cols = append(cols, col)
vals = append(vals, val)
}
sort.Strings(cols)
vals := make([]interface{}, 0, len(clauses))
for _, col := range cols {
vals = append(vals, clauses[col])
}
b = builder.Set(b, "Columns", cols).(InsertBuilder)
b = builder.Set(b, "Values", [][]interface{}{vals}).(InsertBuilder)
return b
}
// Select set Select clause for insert query
// If Values and Select are used, then Select has higher priority
func (b InsertBuilder) Select(sb SelectBuilder) InsertBuilder {
return builder.Set(b, "Select", &sb).(InsertBuilder)
}
func (b InsertBuilder) statementKeyword(keyword string) InsertBuilder {
return builder.Set(b, "StatementKeyword", keyword).(InsertBuilder)
}

69
vendor/github.com/Masterminds/squirrel/insert_ctx.go generated vendored Normal file
View File

@@ -0,0 +1,69 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *insertData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *insertData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *insertData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b InsertBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(insertData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b InsertBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(insertData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b InsertBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(insertData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b InsertBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@@ -19,7 +19,7 @@ func (p part) ToSql() (sql string, args []interface{}, err error) {
case nil:
// no-op
case Sqlizer:
sql, args, err = pred.ToSql()
sql, args, err = nestedToSql(pred)
case string:
sql = pred
args = p.args
@@ -29,9 +29,17 @@ func (p part) ToSql() (sql string, args []interface{}, err error) {
return
}
func nestedToSql(s Sqlizer) (string, []interface{}, error) {
if raw, ok := s.(rawSqlizer); ok {
return raw.toSqlRaw()
} else {
return s.ToSql()
}
}
func appendToSql(parts []Sqlizer, w io.Writer, sep string, args []interface{}) ([]interface{}, error) {
for i, p := range parts {
partSql, partArgs, err := p.ToSql()
partSql, partArgs, err := nestedToSql(p)
if err != nil {
return nil, err
} else if len(partSql) == 0 {

View File

@@ -14,6 +14,10 @@ type PlaceholderFormat interface {
ReplacePlaceholders(sql string) (string, error)
}
type placeholderDebugger interface {
debugPlaceholder() string
}
var (
// Question is a PlaceholderFormat instance that leaves placeholders as
// question marks.
@@ -22,17 +26,66 @@ var (
// Dollar is a PlaceholderFormat instance that replaces placeholders with
// dollar-prefixed positional placeholders (e.g. $1, $2, $3).
Dollar = dollarFormat{}
// Colon is a PlaceholderFormat instance that replaces placeholders with
// colon-prefixed positional placeholders (e.g. :1, :2, :3).
Colon = colonFormat{}
// AtP is a PlaceholderFormat instance that replaces placeholders with
// "@p"-prefixed positional placeholders (e.g. @p1, @p2, @p3).
AtP = atpFormat{}
)
type questionFormat struct{}
func (_ questionFormat) ReplacePlaceholders(sql string) (string, error) {
func (questionFormat) ReplacePlaceholders(sql string) (string, error) {
return sql, nil
}
func (questionFormat) debugPlaceholder() string {
return "?"
}
type dollarFormat struct{}
func (_ dollarFormat) ReplacePlaceholders(sql string) (string, error) {
func (dollarFormat) ReplacePlaceholders(sql string) (string, error) {
return replacePositionalPlaceholders(sql, "$")
}
func (dollarFormat) debugPlaceholder() string {
return "$"
}
type colonFormat struct{}
func (colonFormat) ReplacePlaceholders(sql string) (string, error) {
return replacePositionalPlaceholders(sql, ":")
}
func (colonFormat) debugPlaceholder() string {
return ":"
}
type atpFormat struct{}
func (atpFormat) ReplacePlaceholders(sql string) (string, error) {
return replacePositionalPlaceholders(sql, "@p")
}
func (atpFormat) debugPlaceholder() string {
return "@p"
}
// Placeholders returns a string with count ? placeholders joined with commas.
func Placeholders(count int) string {
if count < 1 {
return ""
}
return strings.Repeat(",?", count)[1:]
}
func replacePositionalPlaceholders(sql, prefix string) (string, error) {
buf := &bytes.Buffer{}
i := 0
for {
@@ -51,7 +104,7 @@ func (_ dollarFormat) ReplacePlaceholders(sql string) (string, error) {
} else {
i++
buf.WriteString(sql[:p])
fmt.Fprintf(buf, "$%d", i)
fmt.Fprintf(buf, "%s%d", prefix, i)
sql = sql[p+1:]
}
}
@@ -59,12 +112,3 @@ func (_ dollarFormat) ReplacePlaceholders(sql string) (string, error) {
buf.WriteString(sql)
return buf.String(), nil
}
// Placeholders returns a string with count ? placeholders joined with commas.
func Placeholders(count int) string {
if count < 1 {
return ""
}
return strings.Repeat(",?", count)[1:]
}

View File

@@ -12,7 +12,7 @@ import (
type selectData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes exprs
Prefixes []Sqlizer
Options []string
Columns []Sqlizer
From Sqlizer
@@ -20,10 +20,10 @@ type selectData struct {
WhereParts []Sqlizer
GroupBys []string
HavingParts []Sqlizer
OrderBys []string
OrderByParts []Sqlizer
Limit string
Offset string
Suffixes exprs
Suffixes []Sqlizer
}
func (d *selectData) Exec() (sql.Result, error) {
@@ -52,6 +52,16 @@ func (d *selectData) QueryRow() RowScanner {
}
func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
sqlStr, args, err = d.toSqlRaw()
if err != nil {
return
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sqlStr)
return
}
func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) {
if len(d.Columns) == 0 {
err = fmt.Errorf("select statements must have at least one result column")
return
@@ -60,7 +70,11 @@ func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, _ = d.Prefixes.AppendToSql(sql, " ", args)
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
@@ -115,9 +129,12 @@ func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
}
}
if len(d.OrderBys) > 0 {
if len(d.OrderByParts) > 0 {
sql.WriteString(" ORDER BY ")
sql.WriteString(strings.Join(d.OrderBys, ", "))
args, err = appendToSql(d.OrderByParts, sql, ", ", args)
if err != nil {
return
}
}
if len(d.Limit) > 0 {
@@ -132,10 +149,14 @@ func (d *selectData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, _ = d.Suffixes.AppendToSql(sql, " ", args)
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
sqlStr = sql.String()
return
}
@@ -159,6 +180,9 @@ func (b SelectBuilder) PlaceholderFormat(f PlaceholderFormat) SelectBuilder {
// Runner methods
// RunWith sets a Runner (like database/sql.DB) to be used with e.g. Exec.
// For most cases runner will be a database connection.
//
// Internally we use this to mock out the database connection for testing.
func (b SelectBuilder) RunWith(runner BaseRunner) SelectBuilder {
return setRunWith(b, runner).(SelectBuilder)
}
@@ -194,9 +218,29 @@ func (b SelectBuilder) ToSql() (string, []interface{}, error) {
return data.ToSql()
}
func (b SelectBuilder) toSqlRaw() (string, []interface{}, error) {
data := builder.GetStruct(b).(selectData)
return data.toSqlRaw()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b SelectBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b SelectBuilder) Prefix(sql string, args ...interface{}) SelectBuilder {
return builder.Append(b, "Prefixes", Expr(sql, args...)).(SelectBuilder)
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b SelectBuilder) PrefixExpr(expr Sqlizer) SelectBuilder {
return builder.Append(b, "Prefixes", expr).(SelectBuilder)
}
// Distinct adds a DISTINCT clause to the query.
@@ -211,7 +255,7 @@ func (b SelectBuilder) Options(options ...string) SelectBuilder {
// Columns adds result columns to the query.
func (b SelectBuilder) Columns(columns ...string) SelectBuilder {
var parts []interface{}
parts := make([]interface{}, 0, len(columns))
for _, str := range columns {
parts = append(parts, newPart(str))
}
@@ -233,6 +277,8 @@ func (b SelectBuilder) From(from string) SelectBuilder {
// FromSelect sets a subquery into the FROM clause of the query.
func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilder {
// Prevent misnumbered parameters in nested selects (#183).
from = from.PlaceholderFormat(Question)
return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder)
}
@@ -256,6 +302,16 @@ func (b SelectBuilder) RightJoin(join string, rest ...interface{}) SelectBuilder
return b.JoinClause("RIGHT JOIN "+join, rest...)
}
// InnerJoin adds a INNER JOIN clause to the query.
func (b SelectBuilder) InnerJoin(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("INNER JOIN "+join, rest...)
}
// CrossJoin adds a CROSS JOIN clause to the query.
func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder {
return b.JoinClause("CROSS JOIN "+join, rest...)
}
// Where adds an expression to the WHERE clause of the query.
//
// Expressions are ANDed together in the generated SQL.
@@ -277,6 +333,9 @@ func (b SelectBuilder) RightJoin(join string, rest ...interface{}) SelectBuilder
//
// Where will panic if pred isn't any of the above types.
func (b SelectBuilder) Where(pred interface{}, args ...interface{}) SelectBuilder {
if pred == nil || pred == "" {
return b
}
return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(SelectBuilder)
}
@@ -292,9 +351,18 @@ func (b SelectBuilder) Having(pred interface{}, rest ...interface{}) SelectBuild
return builder.Append(b, "HavingParts", newWherePart(pred, rest...)).(SelectBuilder)
}
// OrderByClause adds ORDER BY clause to the query.
func (b SelectBuilder) OrderByClause(pred interface{}, args ...interface{}) SelectBuilder {
return builder.Append(b, "OrderByParts", newPart(pred, args...)).(SelectBuilder)
}
// OrderBy adds ORDER BY expressions to the query.
func (b SelectBuilder) OrderBy(orderBys ...string) SelectBuilder {
return builder.Extend(b, "OrderBys", orderBys).(SelectBuilder)
for _, orderBy := range orderBys {
b = b.OrderByClause(orderBy)
}
return b
}
// Limit sets a LIMIT clause on the query.
@@ -302,12 +370,27 @@ func (b SelectBuilder) Limit(limit uint64) SelectBuilder {
return builder.Set(b, "Limit", fmt.Sprintf("%d", limit)).(SelectBuilder)
}
// Limit ALL allows to access all records with limit
func (b SelectBuilder) RemoveLimit() SelectBuilder {
return builder.Delete(b, "Limit").(SelectBuilder)
}
// Offset sets a OFFSET clause on the query.
func (b SelectBuilder) Offset(offset uint64) SelectBuilder {
return builder.Set(b, "Offset", fmt.Sprintf("%d", offset)).(SelectBuilder)
}
// RemoveOffset removes OFFSET clause.
func (b SelectBuilder) RemoveOffset() SelectBuilder {
return builder.Delete(b, "Offset").(SelectBuilder)
}
// Suffix adds an expression to the end of the query
func (b SelectBuilder) Suffix(sql string, args ...interface{}) SelectBuilder {
return builder.Append(b, "Suffixes", Expr(sql, args...)).(SelectBuilder)
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b SelectBuilder) SuffixExpr(expr Sqlizer) SelectBuilder {
return builder.Append(b, "Suffixes", expr).(SelectBuilder)
}

69
vendor/github.com/Masterminds/squirrel/select_ctx.go generated vendored Normal file
View File

@@ -0,0 +1,69 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *selectData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *selectData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *selectData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b SelectBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(selectData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b SelectBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(selectData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b SelectBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(selectData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b SelectBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@@ -1,6 +1,6 @@
// Package squirrel provides a fluent SQL generator.
//
// See https://github.com/lann/squirrel for examples.
// See https://github.com/Masterminds/squirrel for examples.
package squirrel
import (
@@ -20,6 +20,12 @@ type Sqlizer interface {
ToSql() (string, []interface{}, error)
}
// rawSqlizer is expected to do what Sqlizer does, but without finalizing placeholders.
// This is useful for nested queries.
type rawSqlizer interface {
toSqlRaw() (string, []interface{}, error)
}
// Execer is the interface that wraps the Exec method.
//
// Exec executes the given query as implemented by database/sql.Exec.
@@ -54,32 +60,34 @@ type Runner interface {
QueryRower
}
// DBRunner wraps sql.DB to implement Runner.
type dbRunner struct {
*sql.DB
// WrapStdSql wraps a type implementing the standard SQL interface with methods that
// squirrel expects.
func WrapStdSql(stdSql StdSql) Runner {
return &stdsqlRunner{stdSql}
}
func (r *dbRunner) QueryRow(query string, args ...interface{}) RowScanner {
return r.DB.QueryRow(query, args...)
// StdSql encompasses the standard methods of the *sql.DB type, and other types that
// wrap these methods.
type StdSql interface {
Query(string, ...interface{}) (*sql.Rows, error)
QueryRow(string, ...interface{}) *sql.Row
Exec(string, ...interface{}) (sql.Result, error)
}
type txRunner struct {
*sql.Tx
type stdsqlRunner struct {
StdSql
}
func (r *txRunner) QueryRow(query string, args ...interface{}) RowScanner {
return r.Tx.QueryRow(query, args...)
func (r *stdsqlRunner) QueryRow(query string, args ...interface{}) RowScanner {
return r.StdSql.QueryRow(query, args...)
}
func setRunWith(b interface{}, baseRunner BaseRunner) interface{} {
var runner Runner
switch r := baseRunner.(type) {
case Runner:
runner = r
case *sql.DB:
runner = &dbRunner{r}
case *sql.Tx:
runner = &txRunner{r}
func setRunWith(b interface{}, runner BaseRunner) interface{} {
switch r := runner.(type) {
case StdSqlCtx:
runner = WrapStdSqlCtx(r)
case StdSql:
runner = WrapStdSql(r)
}
return builder.Set(b, "RunWith", runner)
}
@@ -129,11 +137,18 @@ func DebugSqlizer(s Sqlizer) string {
return fmt.Sprintf("[ToSql error: %s]", err)
}
var placeholder string
downCast, ok := s.(placeholderDebugger)
if !ok {
placeholder = "?"
} else {
placeholder = downCast.debugPlaceholder()
}
// TODO: dedupe this with placeholder.go
buf := &bytes.Buffer{}
i := 0
for {
p := strings.Index(sql, "?")
p := strings.Index(sql, placeholder)
if p == -1 {
break
}
@@ -152,6 +167,7 @@ func DebugSqlizer(s Sqlizer) string {
}
buf.WriteString(sql[:p])
fmt.Fprintf(buf, "'%v'", args[i])
// advance our sql string "cursor" beyond the arg we placed
sql = sql[p+1:]
i++
}
@@ -161,6 +177,7 @@ func DebugSqlizer(s Sqlizer) string {
"[DebugSqlizer error: not enough placeholders in %#v for %d args]",
sql, len(args))
}
// "append" any remaning sql that won't need interpolating
buf.WriteString(sql)
return buf.String()
}

93
vendor/github.com/Masterminds/squirrel/squirrel_ctx.go generated vendored Normal file
View File

@@ -0,0 +1,93 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"errors"
)
// NoContextSupport is returned if a db doesn't support Context.
var NoContextSupport = errors.New("DB does not support Context")
// ExecerContext is the interface that wraps the ExecContext method.
//
// Exec executes the given query as implemented by database/sql.ExecContext.
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
}
// QueryerContext is the interface that wraps the QueryContext method.
//
// QueryContext executes the given query as implemented by database/sql.QueryContext.
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// QueryRowerContext is the interface that wraps the QueryRowContext method.
//
// QueryRowContext executes the given query as implemented by database/sql.QueryRowContext.
type QueryRowerContext interface {
QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner
}
// RunnerContext groups the Runner interface, along with the Context versions of each of
// its methods
type RunnerContext interface {
Runner
QueryerContext
QueryRowerContext
ExecerContext
}
// WrapStdSqlCtx wraps a type implementing the standard SQL interface plus the context
// versions of the methods with methods that squirrel expects.
func WrapStdSqlCtx(stdSqlCtx StdSqlCtx) RunnerContext {
return &stdsqlCtxRunner{stdSqlCtx}
}
// StdSqlCtx encompasses the standard methods of the *sql.DB type, along with the Context
// versions of those methods, and other types that wrap these methods.
type StdSqlCtx interface {
StdSql
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
}
type stdsqlCtxRunner struct {
StdSqlCtx
}
func (r *stdsqlCtxRunner) QueryRow(query string, args ...interface{}) RowScanner {
return r.StdSqlCtx.QueryRow(query, args...)
}
func (r *stdsqlCtxRunner) QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner {
return r.StdSqlCtx.QueryRowContext(ctx, query, args...)
}
// ExecContextWith ExecContexts the SQL returned by s with db.
func ExecContextWith(ctx context.Context, db ExecerContext, s Sqlizer) (res sql.Result, err error) {
query, args, err := s.ToSql()
if err != nil {
return
}
return db.ExecContext(ctx, query, args...)
}
// QueryContextWith QueryContexts the SQL returned by s with db.
func QueryContextWith(ctx context.Context, db QueryerContext, s Sqlizer) (rows *sql.Rows, err error) {
query, args, err := s.ToSql()
if err != nil {
return
}
return db.QueryContext(ctx, query, args...)
}
// QueryRowContextWith QueryRowContexts the SQL returned by s with db.
func QueryRowContextWith(ctx context.Context, db QueryRowerContext, s Sqlizer) RowScanner {
query, args, err := s.ToSql()
return &Row{RowScanner: db.QueryRowContext(ctx, query, args...), err: err}
}

View File

@@ -15,6 +15,12 @@ func (b StatementBuilderType) Insert(into string) InsertBuilder {
return InsertBuilder(b).Into(into)
}
// Replace returns a InsertBuilder for this StatementBuilderType with the
// statement keyword set to "REPLACE".
func (b StatementBuilderType) Replace(into string) InsertBuilder {
return InsertBuilder(b).statementKeyword("REPLACE").Into(into)
}
// Update returns a UpdateBuilder for this StatementBuilderType.
func (b StatementBuilderType) Update(table string) UpdateBuilder {
return UpdateBuilder(b).Table(table)
@@ -35,6 +41,13 @@ func (b StatementBuilderType) RunWith(runner BaseRunner) StatementBuilderType {
return setRunWith(b, runner).(StatementBuilderType)
}
// Where adds WHERE expressions to the query.
//
// See SelectBuilder.Where for more information.
func (b StatementBuilderType) Where(pred interface{}, args ...interface{}) StatementBuilderType {
return builder.Append(b, "WhereParts", newWherePart(pred, args...)).(StatementBuilderType)
}
// StatementBuilder is a parent builder for other builders, e.g. SelectBuilder.
var StatementBuilder = StatementBuilderType(builder.EmptyBuilder).PlaceholderFormat(Question)
@@ -52,6 +65,14 @@ func Insert(into string) InsertBuilder {
return StatementBuilder.Insert(into)
}
// Replace returns a new InsertBuilder with the statement keyword set to
// "REPLACE" and with the given table name.
//
// See InsertBuilder.Into.
func Replace(into string) InsertBuilder {
return StatementBuilder.Replace(into)
}
// Update returns a new UpdateBuilder with the given table name.
//
// See UpdateBuilder.Table.

View File

@@ -2,6 +2,7 @@ package squirrel
import (
"database/sql"
"fmt"
"sync"
)
@@ -20,22 +21,25 @@ type DBProxy interface {
Preparer
}
type stmtCacher struct {
// NOTE: NewStmtCache is defined in stmtcacher_ctx.go (Go >= 1.8) or stmtcacher_noctx.go (Go < 1.8).
// StmtCache wraps and delegates down to a Preparer type
//
// It also automatically prepares all statements sent to the underlying Preparer calls
// for Exec, Query and QueryRow and caches the returns *sql.Stmt using the provided
// query as the key. So that it can be automatically re-used.
type StmtCache struct {
prep Preparer
cache map[string]*sql.Stmt
mu sync.Mutex
}
// NewStmtCacher returns a DBProxy wrapping prep that caches Prepared Stmts.
//
// Stmts are cached based on the string value of their queries.
func NewStmtCacher(prep Preparer) DBProxy {
return &stmtCacher{prep: prep, cache: make(map[string]*sql.Stmt)}
}
func (sc *stmtCacher) Prepare(query string) (*sql.Stmt, error) {
// Prepare delegates down to the underlying Preparer and caches the result
// using the provided query as a key
func (sc *StmtCache) Prepare(query string) (*sql.Stmt, error) {
sc.mu.Lock()
defer sc.mu.Unlock()
stmt, ok := sc.cache[query]
if ok {
return stmt, nil
@@ -47,7 +51,8 @@ func (sc *stmtCacher) Prepare(query string) (*sql.Stmt, error) {
return stmt, err
}
func (sc *stmtCacher) Exec(query string, args ...interface{}) (res sql.Result, err error) {
// Exec delegates down to the underlying Preparer using a prepared statement
func (sc *StmtCache) Exec(query string, args ...interface{}) (res sql.Result, err error) {
stmt, err := sc.Prepare(query)
if err != nil {
return
@@ -55,7 +60,8 @@ func (sc *stmtCacher) Exec(query string, args ...interface{}) (res sql.Result, e
return stmt.Exec(args...)
}
func (sc *stmtCacher) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
// Query delegates down to the underlying Preparer using a prepared statement
func (sc *StmtCache) Query(query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := sc.Prepare(query)
if err != nil {
return
@@ -63,7 +69,8 @@ func (sc *stmtCacher) Query(query string, args ...interface{}) (rows *sql.Rows,
return stmt.Query(args...)
}
func (sc *stmtCacher) QueryRow(query string, args ...interface{}) RowScanner {
// QueryRow delegates down to the underlying Preparer using a prepared statement
func (sc *StmtCache) QueryRow(query string, args ...interface{}) RowScanner {
stmt, err := sc.Prepare(query)
if err != nil {
return &Row{err: err}
@@ -71,6 +78,30 @@ func (sc *stmtCacher) QueryRow(query string, args ...interface{}) RowScanner {
return stmt.QueryRow(args...)
}
// Clear removes and closes all the currently cached prepared statements
func (sc *StmtCache) Clear() (err error) {
sc.mu.Lock()
defer sc.mu.Unlock()
for key, stmt := range sc.cache {
delete(sc.cache, key)
if stmt == nil {
continue
}
if cerr := stmt.Close(); cerr != nil {
err = cerr
}
}
if err != nil {
return fmt.Errorf("one or more Stmt.Close failed; last error: %v", err)
}
return
}
type DBProxyBeginner interface {
DBProxy
Begin() (*sql.Tx, error)
@@ -82,7 +113,7 @@ type stmtCacheProxy struct {
}
func NewStmtCacheProxy(db *sql.DB) DBProxyBeginner {
return &stmtCacheProxy{DBProxy: NewStmtCacher(db), db: db}
return &stmtCacheProxy{DBProxy: NewStmtCache(db), db: db}
}
func (sp *stmtCacheProxy) Begin() (*sql.Tx, error) {

View File

@@ -0,0 +1,86 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
)
// PrepareerContext is the interface that wraps the Prepare and PrepareContext methods.
//
// Prepare executes the given query as implemented by database/sql.Prepare.
// PrepareContext executes the given query as implemented by database/sql.PrepareContext.
type PreparerContext interface {
Preparer
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
}
// DBProxyContext groups the Execer, Queryer, QueryRower and PreparerContext interfaces.
type DBProxyContext interface {
Execer
Queryer
QueryRower
PreparerContext
}
// NewStmtCache returns a *StmtCache wrapping a PreparerContext that caches Prepared Stmts.
//
// Stmts are cached based on the string value of their queries.
func NewStmtCache(prep PreparerContext) *StmtCache {
return &StmtCache{prep: prep, cache: make(map[string]*sql.Stmt)}
}
// NewStmtCacher is deprecated
//
// Use NewStmtCache instead
func NewStmtCacher(prep PreparerContext) DBProxyContext {
return NewStmtCache(prep)
}
// PrepareContext delegates down to the underlying PreparerContext and caches the result
// using the provided query as a key
func (sc *StmtCache) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
ctxPrep, ok := sc.prep.(PreparerContext)
if !ok {
return nil, NoContextSupport
}
sc.mu.Lock()
defer sc.mu.Unlock()
stmt, ok := sc.cache[query]
if ok {
return stmt, nil
}
stmt, err := ctxPrep.PrepareContext(ctx, query)
if err == nil {
sc.cache[query] = stmt
}
return stmt, err
}
// ExecContext delegates down to the underlying PreparerContext using a prepared statement
func (sc *StmtCache) ExecContext(ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) {
stmt, err := sc.PrepareContext(ctx, query)
if err != nil {
return
}
return stmt.ExecContext(ctx, args...)
}
// QueryContext delegates down to the underlying PreparerContext using a prepared statement
func (sc *StmtCache) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
stmt, err := sc.PrepareContext(ctx, query)
if err != nil {
return
}
return stmt.QueryContext(ctx, args...)
}
// QueryRowContext delegates down to the underlying PreparerContext using a prepared statement
func (sc *StmtCache) QueryRowContext(ctx context.Context, query string, args ...interface{}) RowScanner {
stmt, err := sc.PrepareContext(ctx, query)
if err != nil {
return &Row{err: err}
}
return stmt.QueryRowContext(ctx, args...)
}

View File

@@ -0,0 +1,21 @@
// +build !go1.8
package squirrel
import (
"database/sql"
)
// NewStmtCacher returns a DBProxy wrapping prep that caches Prepared Stmts.
//
// Stmts are cached based on the string value of their queries.
func NewStmtCache(prep Preparer) *StmtCache {
return &StmtCacher{prep: prep, cache: make(map[string]*sql.Stmt)}
}
// NewStmtCacher is deprecated
//
// Use NewStmtCache instead
func NewStmtCacher(prep Preparer) DBProxy {
return NewStmtCache(prep)
}

View File

@@ -13,14 +13,14 @@ import (
type updateData struct {
PlaceholderFormat PlaceholderFormat
RunWith BaseRunner
Prefixes exprs
Prefixes []Sqlizer
Table string
SetClauses []setClause
WhereParts []Sqlizer
OrderBys []string
Limit string
Offset string
Suffixes exprs
Suffixes []Sqlizer
}
type setClause struct {
@@ -66,7 +66,11 @@ func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) {
sql := &bytes.Buffer{}
if len(d.Prefixes) > 0 {
args, _ = d.Prefixes.AppendToSql(sql, " ", args)
args, err = appendToSql(d.Prefixes, sql, " ", args)
if err != nil {
return
}
sql.WriteString(" ")
}
@@ -77,10 +81,17 @@ func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) {
setSqls := make([]string, len(d.SetClauses))
for i, setClause := range d.SetClauses {
var valSql string
e, isExpr := setClause.value.(expr)
if isExpr {
valSql = e.sql
args = append(args, e.args...)
if vs, ok := setClause.value.(Sqlizer); ok {
vsql, vargs, err := vs.ToSql()
if err != nil {
return "", nil, err
}
if _, ok := vs.(SelectBuilder); ok {
valSql = fmt.Sprintf("(%s)", vsql)
} else {
valSql = vsql
}
args = append(args, vargs...)
} else {
valSql = "?"
args = append(args, setClause.value)
@@ -114,7 +125,10 @@ func (d *updateData) ToSql() (sqlStr string, args []interface{}, err error) {
if len(d.Suffixes) > 0 {
sql.WriteString(" ")
args, _ = d.Suffixes.AppendToSql(sql, " ", args)
args, err = appendToSql(d.Suffixes, sql, " ", args)
if err != nil {
return
}
}
sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String())
@@ -173,9 +187,24 @@ func (b UpdateBuilder) ToSql() (string, []interface{}, error) {
return data.ToSql()
}
// MustSql builds the query into a SQL string and bound args.
// It panics if there are any errors.
func (b UpdateBuilder) MustSql() (string, []interface{}) {
sql, args, err := b.ToSql()
if err != nil {
panic(err)
}
return sql, args
}
// Prefix adds an expression to the beginning of the query
func (b UpdateBuilder) Prefix(sql string, args ...interface{}) UpdateBuilder {
return builder.Append(b, "Prefixes", Expr(sql, args...)).(UpdateBuilder)
return b.PrefixExpr(Expr(sql, args...))
}
// PrefixExpr adds an expression to the very beginning of the query
func (b UpdateBuilder) PrefixExpr(expr Sqlizer) UpdateBuilder {
return builder.Append(b, "Prefixes", expr).(UpdateBuilder)
}
// Table sets the table to be updated.
@@ -228,5 +257,10 @@ func (b UpdateBuilder) Offset(offset uint64) UpdateBuilder {
// Suffix adds an expression to the end of the query
func (b UpdateBuilder) Suffix(sql string, args ...interface{}) UpdateBuilder {
return builder.Append(b, "Suffixes", Expr(sql, args...)).(UpdateBuilder)
return b.SuffixExpr(Expr(sql, args...))
}
// SuffixExpr adds an expression to the end of the query
func (b UpdateBuilder) SuffixExpr(expr Sqlizer) UpdateBuilder {
return builder.Append(b, "Suffixes", expr).(UpdateBuilder)
}

69
vendor/github.com/Masterminds/squirrel/update_ctx.go generated vendored Normal file
View File

@@ -0,0 +1,69 @@
// +build go1.8
package squirrel
import (
"context"
"database/sql"
"github.com/lann/builder"
)
func (d *updateData) ExecContext(ctx context.Context) (sql.Result, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(ExecerContext)
if !ok {
return nil, NoContextSupport
}
return ExecContextWith(ctx, ctxRunner, d)
}
func (d *updateData) QueryContext(ctx context.Context) (*sql.Rows, error) {
if d.RunWith == nil {
return nil, RunnerNotSet
}
ctxRunner, ok := d.RunWith.(QueryerContext)
if !ok {
return nil, NoContextSupport
}
return QueryContextWith(ctx, ctxRunner, d)
}
func (d *updateData) QueryRowContext(ctx context.Context) RowScanner {
if d.RunWith == nil {
return &Row{err: RunnerNotSet}
}
queryRower, ok := d.RunWith.(QueryRowerContext)
if !ok {
if _, ok := d.RunWith.(QueryerContext); !ok {
return &Row{err: RunnerNotQueryRunner}
}
return &Row{err: NoContextSupport}
}
return QueryRowContextWith(ctx, queryRower, d)
}
// ExecContext builds and ExecContexts the query with the Runner set by RunWith.
func (b UpdateBuilder) ExecContext(ctx context.Context) (sql.Result, error) {
data := builder.GetStruct(b).(updateData)
return data.ExecContext(ctx)
}
// QueryContext builds and QueryContexts the query with the Runner set by RunWith.
func (b UpdateBuilder) QueryContext(ctx context.Context) (*sql.Rows, error) {
data := builder.GetStruct(b).(updateData)
return data.QueryContext(ctx)
}
// QueryRowContext builds and QueryRowContexts the query with the Runner set by RunWith.
func (b UpdateBuilder) QueryRowContext(ctx context.Context) RowScanner {
data := builder.GetStruct(b).(updateData)
return data.QueryRowContext(ctx)
}
// ScanContext is a shortcut for QueryRowContext().Scan.
func (b UpdateBuilder) ScanContext(ctx context.Context, dest ...interface{}) error {
return b.QueryRowContext(ctx).Scan(dest...)
}

View File

@@ -14,6 +14,8 @@ func (p wherePart) ToSql() (sql string, args []interface{}, err error) {
switch pred := p.pred.(type) {
case nil:
// no-op
case rawSqlizer:
return pred.toSqlRaw()
case Sqlizer:
return pred.ToSql()
case map[string]interface{}:

1
vendor/github.com/Microsoft/go-winio/CODEOWNERS generated vendored Normal file
View File

@@ -0,0 +1 @@
* @microsoft/containerplat

View File

@@ -1,4 +1,4 @@
# go-winio
# go-winio [![Build Status](https://github.com/microsoft/go-winio/actions/workflows/ci.yml/badge.svg)](https://github.com/microsoft/go-winio/actions/workflows/ci.yml)
This repository contains utilities for efficiently performing Win32 IO operations in
Go. Currently, this is focused on accessing named pipes and other file handles, and
@@ -11,12 +11,27 @@ package.
Please see the LICENSE file for licensing information.
This project has adopted the [Microsoft Open Source Code of
Conduct](https://opensource.microsoft.com/codeofconduct/). For more information
see the [Code of Conduct
FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
[opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional
questions or comments.
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA)
declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com.
When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR
appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
We also require that contributors sign their commits using git commit -s or git commit --signoff to certify they either authored the work themselves
or otherwise have permission to use it in this project. Please see https://developercertificate.org/ for more info, as well as to make sure that you can
attest to the rules listed. Our CI uses the DCO Github app to ensure that all commits in a given PR are signed-off.
## Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
## Special Thanks
Thanks to natefinch for the inspiration for this library. See https://github.com/natefinch/npipe
for another named pipe implementation.

View File

@@ -1,3 +1,4 @@
//go:build windows
// +build windows
package winio
@@ -16,6 +17,7 @@ import (
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32
@@ -79,6 +81,7 @@ type win32File struct {
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomicBool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
}
@@ -109,7 +112,13 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
}
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h)
// If we return the result of makeWin32File directly, it can result in an
// interface-wrapped nil, rather than a nil interface value.
f, err := makeWin32File(h)
if err != nil {
return nil, err
}
return f, nil
}
// closeHandle closes the resources associated with a Win32 handle
@@ -135,6 +144,11 @@ func (f *win32File) Close() error {
return nil
}
// IsClosed checks if the file has been closed
func (f *win32File) IsClosed() bool {
return f.closing.isSet()
}
// prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIo() (*ioOperation, error) {
@@ -190,6 +204,10 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
if f.closing.isSet() {
err = ErrFileClosed
}
} else if err != nil && f.socket {
// err is from Win32. Query the overlapped structure to get the winsock error.
var bytes, flags uint32
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
}
case <-timeout:
cancelIoEx(f.handle, &c.o)
@@ -265,6 +283,10 @@ func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle)
}
func (f *win32File) Fd() uintptr {
return uintptr(f.handle)
}
func (d *deadlineHandler) set(deadline time.Time) error {
d.setLock.Lock()
defer d.setLock.Unlock()

View File

@@ -5,21 +5,14 @@ package winio
import (
"os"
"runtime"
"syscall"
"unsafe"
)
//sys getFileInformationByHandleEx(h syscall.Handle, class uint32, buffer *byte, size uint32) (err error) = GetFileInformationByHandleEx
//sys setFileInformationByHandle(h syscall.Handle, class uint32, buffer *byte, size uint32) (err error) = SetFileInformationByHandle
const (
fileBasicInfo = 0
fileIDInfo = 0x12
"golang.org/x/sys/windows"
)
// FileBasicInfo contains file access time and file attributes information.
type FileBasicInfo struct {
CreationTime, LastAccessTime, LastWriteTime, ChangeTime syscall.Filetime
CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime
FileAttributes uint32
pad uint32 // padding
}
@@ -27,7 +20,7 @@ type FileBasicInfo struct {
// GetFileBasicInfo retrieves times and attributes for a file.
func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
bi := &FileBasicInfo{}
if err := getFileInformationByHandleEx(syscall.Handle(f.Fd()), fileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil {
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
}
runtime.KeepAlive(f)
@@ -36,13 +29,32 @@ func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) {
// SetFileBasicInfo sets times and attributes for a file.
func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error {
if err := setFileInformationByHandle(syscall.Handle(f.Fd()), fileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil {
if err := windows.SetFileInformationByHandle(windows.Handle(f.Fd()), windows.FileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil {
return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err}
}
runtime.KeepAlive(f)
return nil
}
// FileStandardInfo contains extended information for the file.
// FILE_STANDARD_INFO in WinBase.h
// https://docs.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-file_standard_info
type FileStandardInfo struct {
AllocationSize, EndOfFile int64
NumberOfLinks uint32
DeletePending, Directory bool
}
// GetFileStandardInfo retrieves ended information for the file.
func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) {
si := &FileStandardInfo{}
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileStandardInfo, (*byte)(unsafe.Pointer(si)), uint32(unsafe.Sizeof(*si))); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
}
runtime.KeepAlive(f)
return si, nil
}
// FileIDInfo contains the volume serial number and file ID for a file. This pair should be
// unique on a system.
type FileIDInfo struct {
@@ -53,7 +65,7 @@ type FileIDInfo struct {
// GetFileID retrieves the unique (volume, file ID) pair for a file.
func GetFileID(f *os.File) (*FileIDInfo, error) {
fileID := &FileIDInfo{}
if err := getFileInformationByHandleEx(syscall.Handle(f.Fd()), fileIDInfo, (*byte)(unsafe.Pointer(fileID)), uint32(unsafe.Sizeof(*fileID))); err != nil {
if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileIdInfo, (*byte)(unsafe.Pointer(fileID)), uint32(unsafe.Sizeof(*fileID))); err != nil {
return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err}
}
runtime.KeepAlive(f)

316
vendor/github.com/Microsoft/go-winio/hvsock.go generated vendored Normal file
View File

@@ -0,0 +1,316 @@
//go:build windows
// +build windows
package winio
import (
"fmt"
"io"
"net"
"os"
"syscall"
"time"
"unsafe"
"github.com/Microsoft/go-winio/pkg/guid"
)
//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
const (
afHvSock = 34 // AF_HYPERV
socketError = ^uintptr(0)
)
// An HvsockAddr is an address for a AF_HYPERV socket.
type HvsockAddr struct {
VMID guid.GUID
ServiceID guid.GUID
}
type rawHvsockAddr struct {
Family uint16
_ uint16
VMID guid.GUID
ServiceID guid.GUID
}
// Network returns the address's network name, "hvsock".
func (addr *HvsockAddr) Network() string {
return "hvsock"
}
func (addr *HvsockAddr) String() string {
return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
}
// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
func VsockServiceID(port uint32) guid.GUID {
g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3")
g.Data1 = port
return g
}
func (addr *HvsockAddr) raw() rawHvsockAddr {
return rawHvsockAddr{
Family: afHvSock,
VMID: addr.VMID,
ServiceID: addr.ServiceID,
}
}
func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
addr.VMID = raw.VMID
addr.ServiceID = raw.ServiceID
}
// HvsockListener is a socket listener for the AF_HYPERV address family.
type HvsockListener struct {
sock *win32File
addr HvsockAddr
}
// HvsockConn is a connected socket of the AF_HYPERV address family.
type HvsockConn struct {
sock *win32File
local, remote HvsockAddr
}
func newHvSocket() (*win32File, error) {
fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1)
if err != nil {
return nil, os.NewSyscallError("socket", err)
}
f, err := makeWin32File(fd)
if err != nil {
syscall.Close(fd)
return nil, err
}
f.socket = true
return f, nil
}
// ListenHvsock listens for connections on the specified hvsock address.
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
l := &HvsockListener{addr: *addr}
sock, err := newHvSocket()
if err != nil {
return nil, l.opErr("listen", err)
}
sa := addr.raw()
err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa)))
if err != nil {
return nil, l.opErr("listen", os.NewSyscallError("socket", err))
}
err = syscall.Listen(sock.handle, 16)
if err != nil {
return nil, l.opErr("listen", os.NewSyscallError("listen", err))
}
return &HvsockListener{sock: sock, addr: *addr}, nil
}
func (l *HvsockListener) opErr(op string, err error) error {
return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
}
// Addr returns the listener's network address.
func (l *HvsockListener) Addr() net.Addr {
return &l.addr
}
// Accept waits for the next connection and returns it.
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
sock, err := newHvSocket()
if err != nil {
return nil, l.opErr("accept", err)
}
defer func() {
if sock != nil {
sock.Close()
}
}()
c, err := l.sock.prepareIo()
if err != nil {
return nil, l.opErr("accept", err)
}
defer l.sock.wg.Done()
// AcceptEx, per documentation, requires an extra 16 bytes per address.
const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
var addrbuf [addrlen * 2]byte
var bytes uint32
err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o)
_, err = l.sock.asyncIo(c, nil, bytes, err)
if err != nil {
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
}
conn := &HvsockConn{
sock: sock,
}
conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
sock = nil
return conn, nil
}
// Close closes the listener, causing any pending Accept calls to fail.
func (l *HvsockListener) Close() error {
return l.sock.Close()
}
/* Need to finish ConnectEx handling
func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) {
sock, err := newHvSocket()
if err != nil {
return nil, err
}
defer func() {
if sock != nil {
sock.Close()
}
}()
c, err := sock.prepareIo()
if err != nil {
return nil, err
}
defer sock.wg.Done()
var bytes uint32
err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o)
_, err = sock.asyncIo(ctx, c, nil, bytes, err)
if err != nil {
return nil, err
}
conn := &HvsockConn{
sock: sock,
remote: *addr,
}
sock = nil
return conn, nil
}
*/
func (conn *HvsockConn) opErr(op string, err error) error {
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
}
func (conn *HvsockConn) Read(b []byte) (int, error) {
c, err := conn.sock.prepareIo()
if err != nil {
return 0, conn.opErr("read", err)
}
defer conn.sock.wg.Done()
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var flags, bytes uint32
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
if err != nil {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("wsarecv", err)
}
return 0, conn.opErr("read", err)
} else if n == 0 {
err = io.EOF
}
return n, err
}
func (conn *HvsockConn) Write(b []byte) (int, error) {
t := 0
for len(b) != 0 {
n, err := conn.write(b)
if err != nil {
return t + n, err
}
t += n
b = b[n:]
}
return t, nil
}
func (conn *HvsockConn) write(b []byte) (int, error) {
c, err := conn.sock.prepareIo()
if err != nil {
return 0, conn.opErr("write", err)
}
defer conn.sock.wg.Done()
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var bytes uint32
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
if err != nil {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("wsasend", err)
}
return 0, conn.opErr("write", err)
}
return n, err
}
// Close closes the socket connection, failing any pending read or write calls.
func (conn *HvsockConn) Close() error {
return conn.sock.Close()
}
func (conn *HvsockConn) IsClosed() bool {
return conn.sock.IsClosed()
}
func (conn *HvsockConn) shutdown(how int) error {
if conn.IsClosed() {
return ErrFileClosed
}
err := syscall.Shutdown(conn.sock.handle, how)
if err != nil {
return os.NewSyscallError("shutdown", err)
}
return nil
}
// CloseRead shuts down the read end of the socket, preventing future read operations.
func (conn *HvsockConn) CloseRead() error {
err := conn.shutdown(syscall.SHUT_RD)
if err != nil {
return conn.opErr("close", err)
}
return nil
}
// CloseWrite shuts down the write end of the socket, preventing future write operations and
// notifying the other endpoint that no more data will be written.
func (conn *HvsockConn) CloseWrite() error {
err := conn.shutdown(syscall.SHUT_WR)
if err != nil {
return conn.opErr("close", err)
}
return nil
}
// LocalAddr returns the local address of the connection.
func (conn *HvsockConn) LocalAddr() net.Addr {
return &conn.local
}
// RemoteAddr returns the remote address of the connection.
func (conn *HvsockConn) RemoteAddr() net.Addr {
return &conn.remote
}
// SetDeadline implements the net.Conn SetDeadline method.
func (conn *HvsockConn) SetDeadline(t time.Time) error {
conn.SetReadDeadline(t)
conn.SetWriteDeadline(t)
return nil
}
// SetReadDeadline implements the net.Conn SetReadDeadline method.
func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
return conn.sock.SetReadDeadline(t)
}
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
return conn.sock.SetWriteDeadline(t)
}

View File

@@ -3,10 +3,13 @@
package winio
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"syscall"
"time"
"unsafe"
@@ -18,6 +21,48 @@ import (
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
type ioStatusBlock struct {
Status, Information uintptr
}
type objectAttributes struct {
Length uintptr
RootDirectory uintptr
ObjectName *unicodeString
Attributes uintptr
SecurityDescriptor *securityDescriptor
SecurityQoS uintptr
}
type unicodeString struct {
Length uint16
MaximumLength uint16
Buffer uintptr
}
type securityDescriptor struct {
Revision byte
Sbz1 byte
Control uint16
Owner uintptr
Group uintptr
Sacl uintptr
Dacl uintptr
}
type ntstatus int32
func (status ntstatus) Err() error {
if status >= 0 {
return nil
}
return rtlNtStatusToDosError(status)
}
const (
cERROR_PIPE_BUSY = syscall.Errno(231)
@@ -25,21 +70,20 @@ const (
cERROR_PIPE_CONNECTED = syscall.Errno(535)
cERROR_SEM_TIMEOUT = syscall.Errno(121)
cPIPE_ACCESS_DUPLEX = 0x3
cFILE_FLAG_FIRST_PIPE_INSTANCE = 0x80000
cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0
cPIPE_REJECT_REMOTE_CLIENTS = 0x8
cPIPE_UNLIMITED_INSTANCES = 255
cNMPWAIT_USE_DEFAULT_WAIT = 0
cNMPWAIT_NOWAIT = 1
cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0
cPIPE_TYPE_MESSAGE = 4
cPIPE_READMODE_MESSAGE = 2
cFILE_OPEN = 1
cFILE_CREATE = 2
cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
cSE_DACL_PRESENT = 4
)
var (
@@ -137,33 +181,60 @@ func (s pipeAddress) String() string {
return string(s)
}
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string, access uint32) (syscall.Handle, error) {
for {
select {
case <-ctx.Done():
return syscall.Handle(0), ctx.Err()
default:
h, err := createFile(*path, access, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != cERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is a rather simplistic
// view, as we always try each 10 milliseconds.
time.Sleep(10 * time.Millisecond)
}
}
}
// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 5 seconds. (We do not use WaitNamedPipe.)
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
var absTimeout time.Time
if timeout != nil {
absTimeout = time.Now().Add(*timeout)
} else {
absTimeout = time.Now().Add(time.Second * 2)
absTimeout = time.Now().Add(2 * time.Second)
}
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialPipeContext(ctx, path)
if err == context.DeadlineExceeded {
return nil, ErrTimeout
}
return conn, err
}
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
return DialPipeAccess(ctx, path, syscall.GENERIC_READ|syscall.GENERIC_WRITE)
}
// DialPipeAccess attempts to connect to a named pipe by `path` with `access` until `ctx`
// cancellation or timeout.
func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, error) {
var err error
var h syscall.Handle
for {
h, err = createFile(path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err != cERROR_PIPE_BUSY {
break
}
if time.Now().After(absTimeout) {
return nil, ErrTimeout
}
// Wait 10 msec and try again. This is a rather simplistic
// view, as we always try each 10 milliseconds.
time.Sleep(time.Millisecond * 10)
}
h, err = tryDialPipe(ctx, &path, access)
if err != nil {
return nil, &os.PathError{Op: "open", Path: path, Err: err}
return nil, err
}
var flags uint32
@@ -194,43 +265,87 @@ type acceptResponse struct {
}
type win32PipeListener struct {
firstHandle syscall.Handle
path string
securityDescriptor []byte
config PipeConfig
acceptCh chan (chan acceptResponse)
closeCh chan int
doneCh chan int
firstHandle syscall.Handle
path string
config PipeConfig
acceptCh chan (chan acceptResponse)
closeCh chan int
doneCh chan int
}
func makeServerPipeHandle(path string, securityDescriptor []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
var flags uint32 = cPIPE_ACCESS_DUPLEX | syscall.FILE_FLAG_OVERLAPPED
if first {
flags |= cFILE_FLAG_FIRST_PIPE_INSTANCE
}
var mode uint32 = cPIPE_REJECT_REMOTE_CLIENTS
if c.MessageMode {
mode |= cPIPE_TYPE_MESSAGE
}
sa := &syscall.SecurityAttributes{}
sa.Length = uint32(unsafe.Sizeof(*sa))
if securityDescriptor != nil {
len := uint32(len(securityDescriptor))
sa.SecurityDescriptor = localAlloc(0, len)
defer localFree(sa.SecurityDescriptor)
copy((*[0xffff]byte)(unsafe.Pointer(sa.SecurityDescriptor))[:], securityDescriptor)
}
h, err := createNamedPipe(path, flags, mode, cPIPE_UNLIMITED_INSTANCES, uint32(c.OutputBufferSize), uint32(c.InputBufferSize), 0, sa)
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
path16, err := syscall.UTF16FromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
var oa objectAttributes
oa.Length = unsafe.Sizeof(oa)
var ntPath unicodeString
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer localFree(ntPath.Buffer)
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if first {
if sd != nil {
len := uint32(len(sd))
sdb := localAlloc(0, len)
defer localFree(sdb)
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
} else {
// Construct the default named pipe security descriptor.
var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
}
defer localFree(dacl)
sdb := &securityDescriptor{
Revision: 1,
Control: cSE_DACL_PRESENT,
Dacl: dacl,
}
oa.SecurityDescriptor = sdb
}
}
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode {
typ |= cFILE_PIPE_MESSAGE_TYPE
}
disposition := uint32(cFILE_OPEN)
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
if first {
disposition = cFILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false.
access = syscall.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms
var (
h syscall.Handle
iosb ioStatusBlock
)
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
runtime.KeepAlive(ntPath)
return h, nil
}
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
h, err := makeServerPipeHandle(l.path, l.securityDescriptor, &l.config, false)
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
if err != nil {
return nil, err
}
@@ -314,10 +429,10 @@ type PipeConfig struct {
// when the pipe is in message mode.
MessageMode bool
// InputBufferSize specifies the size the input buffer, in bytes.
// InputBufferSize specifies the size of the input buffer, in bytes.
InputBufferSize int32
// OutputBufferSize specifies the size the input buffer, in bytes.
// OutputBufferSize specifies the size of the output buffer, in bytes.
OutputBufferSize int32
}
@@ -341,32 +456,13 @@ func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
if err != nil {
return nil, err
}
// Create a client handle and connect it. This results in the pipe
// instance always existing, so that clients see ERROR_PIPE_BUSY
// rather than ERROR_FILE_NOT_FOUND. This ties the first instance
// up so that no other instances can be used. This would have been
// cleaner if the Win32 API matched CreateFile with ConnectNamedPipe
// instead of CreateNamedPipe. (Apparently created named pipes are
// considered to be in listening state regardless of whether any
// active calls to ConnectNamedPipe are outstanding.)
h2, err := createFile(path, 0, 0, nil, syscall.OPEN_EXISTING, cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err != nil {
syscall.Close(h)
return nil, err
}
// Close the client handle. The server side of the instance will
// still be busy, leading to ERROR_PIPE_BUSY instead of
// ERROR_NOT_FOUND, as long as we don't close the server handle,
// or disconnect the client with DisconnectNamedPipe.
syscall.Close(h2)
l := &win32PipeListener{
firstHandle: h,
path: path,
securityDescriptor: sd,
config: *c,
acceptCh: make(chan (chan acceptResponse)),
closeCh: make(chan int),
doneCh: make(chan int),
firstHandle: h,
path: path,
config: *c,
acceptCh: make(chan (chan acceptResponse)),
closeCh: make(chan int),
doneCh: make(chan int),
}
go l.listenerRoutine()
return l, nil

228
vendor/github.com/Microsoft/go-winio/pkg/guid/guid.go generated vendored Normal file
View File

@@ -0,0 +1,228 @@
// +build windows
// Package guid provides a GUID type. The backing structure for a GUID is
// identical to that used by the golang.org/x/sys/windows GUID type.
// There are two main binary encodings used for a GUID, the big-endian encoding,
// and the Windows (mixed-endian) encoding. See here for details:
// https://en.wikipedia.org/wiki/Universally_unique_identifier#Encoding
package guid
import (
"crypto/rand"
"crypto/sha1"
"encoding"
"encoding/binary"
"fmt"
"strconv"
)
// Variant specifies which GUID variant (or "type") of the GUID. It determines
// how the entirety of the rest of the GUID is interpreted.
type Variant uint8
// The variants specified by RFC 4122.
const (
// VariantUnknown specifies a GUID variant which does not conform to one of
// the variant encodings specified in RFC 4122.
VariantUnknown Variant = iota
VariantNCS
VariantRFC4122
VariantMicrosoft
VariantFuture
)
// Version specifies how the bits in the GUID were generated. For instance, a
// version 4 GUID is randomly generated, and a version 5 is generated from the
// hash of an input string.
type Version uint8
var _ = (encoding.TextMarshaler)(GUID{})
var _ = (encoding.TextUnmarshaler)(&GUID{})
// NewV4 returns a new version 4 (pseudorandom) GUID, as defined by RFC 4122.
func NewV4() (GUID, error) {
var b [16]byte
if _, err := rand.Read(b[:]); err != nil {
return GUID{}, err
}
g := FromArray(b)
g.setVersion(4) // Version 4 means randomly generated.
g.setVariant(VariantRFC4122)
return g, nil
}
// NewV5 returns a new version 5 (generated from a string via SHA-1 hashing)
// GUID, as defined by RFC 4122. The RFC is unclear on the encoding of the name,
// and the sample code treats it as a series of bytes, so we do the same here.
//
// Some implementations, such as those found on Windows, treat the name as a
// big-endian UTF16 stream of bytes. If that is desired, the string can be
// encoded as such before being passed to this function.
func NewV5(namespace GUID, name []byte) (GUID, error) {
b := sha1.New()
namespaceBytes := namespace.ToArray()
b.Write(namespaceBytes[:])
b.Write(name)
a := [16]byte{}
copy(a[:], b.Sum(nil))
g := FromArray(a)
g.setVersion(5) // Version 5 means generated from a string.
g.setVariant(VariantRFC4122)
return g, nil
}
func fromArray(b [16]byte, order binary.ByteOrder) GUID {
var g GUID
g.Data1 = order.Uint32(b[0:4])
g.Data2 = order.Uint16(b[4:6])
g.Data3 = order.Uint16(b[6:8])
copy(g.Data4[:], b[8:16])
return g
}
func (g GUID) toArray(order binary.ByteOrder) [16]byte {
b := [16]byte{}
order.PutUint32(b[0:4], g.Data1)
order.PutUint16(b[4:6], g.Data2)
order.PutUint16(b[6:8], g.Data3)
copy(b[8:16], g.Data4[:])
return b
}
// FromArray constructs a GUID from a big-endian encoding array of 16 bytes.
func FromArray(b [16]byte) GUID {
return fromArray(b, binary.BigEndian)
}
// ToArray returns an array of 16 bytes representing the GUID in big-endian
// encoding.
func (g GUID) ToArray() [16]byte {
return g.toArray(binary.BigEndian)
}
// FromWindowsArray constructs a GUID from a Windows encoding array of bytes.
func FromWindowsArray(b [16]byte) GUID {
return fromArray(b, binary.LittleEndian)
}
// ToWindowsArray returns an array of 16 bytes representing the GUID in Windows
// encoding.
func (g GUID) ToWindowsArray() [16]byte {
return g.toArray(binary.LittleEndian)
}
func (g GUID) String() string {
return fmt.Sprintf(
"%08x-%04x-%04x-%04x-%012x",
g.Data1,
g.Data2,
g.Data3,
g.Data4[:2],
g.Data4[2:])
}
// FromString parses a string containing a GUID and returns the GUID. The only
// format currently supported is the `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`
// format.
func FromString(s string) (GUID, error) {
if len(s) != 36 {
return GUID{}, fmt.Errorf("invalid GUID %q", s)
}
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return GUID{}, fmt.Errorf("invalid GUID %q", s)
}
var g GUID
data1, err := strconv.ParseUint(s[0:8], 16, 32)
if err != nil {
return GUID{}, fmt.Errorf("invalid GUID %q", s)
}
g.Data1 = uint32(data1)
data2, err := strconv.ParseUint(s[9:13], 16, 16)
if err != nil {
return GUID{}, fmt.Errorf("invalid GUID %q", s)
}
g.Data2 = uint16(data2)
data3, err := strconv.ParseUint(s[14:18], 16, 16)
if err != nil {
return GUID{}, fmt.Errorf("invalid GUID %q", s)
}
g.Data3 = uint16(data3)
for i, x := range []int{19, 21, 24, 26, 28, 30, 32, 34} {
v, err := strconv.ParseUint(s[x:x+2], 16, 8)
if err != nil {
return GUID{}, fmt.Errorf("invalid GUID %q", s)
}
g.Data4[i] = uint8(v)
}
return g, nil
}
func (g *GUID) setVariant(v Variant) {
d := g.Data4[0]
switch v {
case VariantNCS:
d = (d & 0x7f)
case VariantRFC4122:
d = (d & 0x3f) | 0x80
case VariantMicrosoft:
d = (d & 0x1f) | 0xc0
case VariantFuture:
d = (d & 0x0f) | 0xe0
case VariantUnknown:
fallthrough
default:
panic(fmt.Sprintf("invalid variant: %d", v))
}
g.Data4[0] = d
}
// Variant returns the GUID variant, as defined in RFC 4122.
func (g GUID) Variant() Variant {
b := g.Data4[0]
if b&0x80 == 0 {
return VariantNCS
} else if b&0xc0 == 0x80 {
return VariantRFC4122
} else if b&0xe0 == 0xc0 {
return VariantMicrosoft
} else if b&0xe0 == 0xe0 {
return VariantFuture
}
return VariantUnknown
}
func (g *GUID) setVersion(v Version) {
g.Data3 = (g.Data3 & 0x0fff) | (uint16(v) << 12)
}
// Version returns the GUID version, as defined in RFC 4122.
func (g GUID) Version() Version {
return Version((g.Data3 & 0xF000) >> 12)
}
// MarshalText returns the textual representation of the GUID.
func (g GUID) MarshalText() ([]byte, error) {
return []byte(g.String()), nil
}
// UnmarshalText takes the textual representation of a GUID, and unmarhals it
// into this GUID.
func (g *GUID) UnmarshalText(text []byte) error {
g2, err := FromString(string(text))
if err != nil {
return err
}
*g = g2
return nil
}

View File

@@ -0,0 +1,15 @@
// +build !windows
package guid
// GUID represents a GUID/UUID. It has the same structure as
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
// that type. It is defined as its own type as that is only available to builds
// targeted at `windows`. The representation matches that used by native Windows
// code.
type GUID struct {
Data1 uint32
Data2 uint16
Data3 uint16
Data4 [8]byte
}

View File

@@ -0,0 +1,10 @@
package guid
import "golang.org/x/sys/windows"
// GUID represents a GUID/UUID. It has the same structure as
// golang.org/x/sys/windows.GUID so that it can be used with functions expecting
// that type. It is defined as its own type so that stringification and
// marshaling can be supported. The representation matches that used by native
// Windows code.
type GUID windows.GUID

View File

@@ -28,8 +28,9 @@ const (
ERROR_NOT_ALL_ASSIGNED syscall.Errno = 1300
SeBackupPrivilege = "SeBackupPrivilege"
SeRestorePrivilege = "SeRestorePrivilege"
SeBackupPrivilege = "SeBackupPrivilege"
SeRestorePrivilege = "SeRestorePrivilege"
SeSecurityPrivilege = "SeSecurityPrivilege"
)
const (

View File

@@ -1,3 +1,3 @@
package winio
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go file.go pipe.go sd.go fileinfo.go privilege.go backup.go
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go file.go pipe.go sd.go fileinfo.go privilege.go backup.go hvsock.go

View File

@@ -1,4 +1,4 @@
// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT
// Code generated by 'go generate'; DO NOT EDIT.
package winio
@@ -19,6 +19,7 @@ const (
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
errERROR_EINVAL error = syscall.EINVAL
)
// errnoErr returns common boxed Errno values, to prevent
@@ -26,7 +27,7 @@ var (
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
return errERROR_EINVAL
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
@@ -37,213 +38,62 @@ func errnoErr(e syscall.Errno) error {
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modntdll = windows.NewLazySystemDLL("ntdll.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procCreateFileW = modkernel32.NewProc("CreateFileW")
procWaitNamedPipeW = modkernel32.NewProc("WaitNamedPipeW")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW")
procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW")
procLocalFree = modkernel32.NewProc("LocalFree")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
procGetFileInformationByHandleEx = modkernel32.NewProc("GetFileInformationByHandleEx")
procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle")
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW")
procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW")
procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW")
procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW")
procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW")
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procBackupRead = modkernel32.NewProc("BackupRead")
procBackupWrite = modkernel32.NewProc("BackupWrite")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procCreateFileW = modkernel32.NewProc("CreateFileW")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procGetCurrentThread = modkernel32.NewProc("GetCurrentThread")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
procLocalFree = modkernel32.NewProc("LocalFree")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
procbind = modws2_32.NewProc("bind")
)
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
var _p0 uint32
if releaseAll {
_p0 = 1
}
r0, _, e1 := syscall.Syscall6(procAdjustTokenPrivileges.Addr(), 6, uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
success = r0 != 0
if true {
err = errnoErr(e1)
}
return
}
func convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertSecurityDescriptorToStringSecurityDescriptorW.Addr(), 5, uintptr(unsafe.Pointer(sd)), uintptr(revision), uintptr(secInfo), uintptr(unsafe.Pointer(sddl)), uintptr(unsafe.Pointer(sddlSize)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = syscall.Handle(r0)
if newport == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
}
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
}
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func waitNamedPipe(name string, timeout uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _waitNamedPipe(_p0, timeout)
}
func _waitNamedPipe(name *uint16, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall(procWaitNamedPipeW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(timeout), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
ptr = uintptr(r0)
return
}
func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(accountName)
if err != nil {
return
}
return _lookupAccountName(systemName, _p0, sid, sidSize, refDomain, refDomainSize, sidNameUse)
}
func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procLookupAccountNameW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
@@ -251,11 +101,7 @@ func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidS
func convertSidToStringSid(sid *byte, str **uint16) (err error) {
r1, _, e1 := syscall.Syscall(procConvertSidToStringSidW.Addr(), 2, uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(str)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
@@ -272,126 +118,73 @@ func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision ui
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
func convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertSecurityDescriptorToStringSecurityDescriptorW.Addr(), 5, uintptr(unsafe.Pointer(sd)), uintptr(revision), uintptr(secInfo), uintptr(unsafe.Pointer(sddl)), uintptr(unsafe.Pointer(sddlSize)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0)
return
}
func getFileInformationByHandleEx(h syscall.Handle, class uint32, buffer *byte, size uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetFileInformationByHandleEx.Addr(), 4, uintptr(h), uintptr(class), uintptr(unsafe.Pointer(buffer)), uintptr(size), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func setFileInformationByHandle(h syscall.Handle, class uint32, buffer *byte, size uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procSetFileInformationByHandle.Addr(), 4, uintptr(h), uintptr(class), uintptr(unsafe.Pointer(buffer)), uintptr(size), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func adjustTokenPrivileges(token windows.Token, releaseAll bool, input *byte, outputSize uint32, output *byte, requiredSize *uint32) (success bool, err error) {
var _p0 uint32
if releaseAll {
_p0 = 1
} else {
_p0 = 0
}
r0, _, e1 := syscall.Syscall6(procAdjustTokenPrivileges.Addr(), 6, uintptr(token), uintptr(_p0), uintptr(unsafe.Pointer(input)), uintptr(outputSize), uintptr(unsafe.Pointer(output)), uintptr(unsafe.Pointer(requiredSize)))
success = r0 != 0
if true {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func impersonateSelf(level uint32) (err error) {
r1, _, e1 := syscall.Syscall(procImpersonateSelf.Addr(), 1, uintptr(level), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
func revertToSelf() (err error) {
r1, _, e1 := syscall.Syscall(procRevertToSelf.Addr(), 0, 0, 0, 0)
func lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(accountName)
if err != nil {
return
}
return _lookupAccountName(systemName, _p0, sid, sidSize, refDomain, refDomainSize, sidNameUse)
}
func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procLookupAccountNameW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(accountName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(sidSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
func openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) {
var _p0 uint32
if openAsSelf {
_p0 = 1
} else {
_p0 = 0
func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(systemName)
if err != nil {
return
}
r1, _, e1 := syscall.Syscall6(procOpenThreadToken.Addr(), 4, uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)), 0, 0)
return _lookupPrivilegeDisplayName(_p0, name, buffer, size, languageId)
}
func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeDisplayNameW.Addr(), 5, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
func getCurrentThread() (h syscall.Handle) {
r0, _, _ := syscall.Syscall(procGetCurrentThread.Addr(), 0, 0, 0, 0)
h = syscall.Handle(r0)
func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(systemName)
if err != nil {
return
}
return _lookupPrivilegeName(_p0, luid, buffer, size)
}
func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeNameW.Addr(), 4, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
@@ -412,53 +205,27 @@ func lookupPrivilegeValue(systemName string, name string, luid *uint64) (err err
func _lookupPrivilegeValue(systemName *uint16, name *uint16, luid *uint64) (err error) {
r1, _, e1 := syscall.Syscall(procLookupPrivilegeValueW.Addr(), 3, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(luid)))
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
func lookupPrivilegeName(systemName string, luid *uint64, buffer *uint16, size *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(systemName)
if err != nil {
return
func openThreadToken(thread syscall.Handle, accessMask uint32, openAsSelf bool, token *windows.Token) (err error) {
var _p0 uint32
if openAsSelf {
_p0 = 1
}
return _lookupPrivilegeName(_p0, luid, buffer, size)
}
func _lookupPrivilegeName(systemName *uint16, luid *uint64, buffer *uint16, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeNameW.Addr(), 4, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(luid)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), 0, 0)
r1, _, e1 := syscall.Syscall6(procOpenThreadToken.Addr(), 4, uintptr(thread), uintptr(accessMask), uintptr(_p0), uintptr(unsafe.Pointer(token)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(systemName)
if err != nil {
return
}
return _lookupPrivilegeDisplayName(_p0, name, buffer, size, languageId)
}
func _lookupPrivilegeDisplayName(systemName *uint16, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procLookupPrivilegeDisplayNameW.Addr(), 5, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buffer)), uintptr(unsafe.Pointer(size)), uintptr(unsafe.Pointer(languageId)), 0)
func revertToSelf() (err error) {
r1, _, e1 := syscall.Syscall(procRevertToSelf.Addr(), 0, 0, 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
@@ -471,22 +238,14 @@ func backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, proce
var _p1 uint32
if abort {
_p1 = 1
} else {
_p1 = 0
}
var _p2 uint32
if processSecurity {
_p2 = 1
} else {
_p2 = 0
}
r1, _, e1 := syscall.Syscall9(procBackupRead.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesRead)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
@@ -499,22 +258,170 @@ func backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, p
var _p1 uint32
if abort {
_p1 = 1
} else {
_p1 = 0
}
var _p2 uint32
if processSecurity {
_p2 = 1
} else {
_p2 = 0
}
r1, _, e1 := syscall.Syscall9(procBackupWrite.Addr(), 7, uintptr(h), uintptr(unsafe.Pointer(_p0)), uintptr(len(b)), uintptr(unsafe.Pointer(bytesWritten)), uintptr(_p1), uintptr(_p2), uintptr(unsafe.Pointer(context)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
err = errnoErr(e1)
}
return
}
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
}
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
err = errnoErr(e1)
}
return
}
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = syscall.Handle(r0)
if newport == 0 {
err = errnoErr(e1)
}
return
}
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
}
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
err = errnoErr(e1)
}
return
}
func getCurrentThread() (h syscall.Handle) {
r0, _, _ := syscall.Syscall(procGetCurrentThread.Addr(), 0, 0, 0, 0)
h = syscall.Handle(r0)
return
}
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
ptr = uintptr(r0)
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0)
return
}
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
status = ntstatus(r0)
return
}
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
status = ntstatus(r0)
return
}
func rtlNtStatusToDosError(status ntstatus) (winerr error) {
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
if r0 != 0 {
winerr = syscall.Errno(r0)
}
return
}
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32
if wait {
_p0 = 1
}
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) {
r1, _, e1 := syscall.Syscall(procbind.Addr(), 3, uintptr(s), uintptr(name), uintptr(namelen))
if r1 == socketError {
err = errnoErr(e1)
}
return
}

View File

@@ -48,7 +48,6 @@ CRC64ISOShort-8 22.2ns ± 3%
Fnv64-8 2.34µs ± 1%
Fnv64Short-8 74.7ns ± 8%
#
```
## Usage
@@ -63,7 +62,7 @@ Fnv64Short-8 74.7ns ± 8%
fmt.Println("File checksum:", h.Sum64())
```
[<kbd>playground</kbd>](http://play.golang.org/p/rhRN3RdQyd)
[<kbd>playground</kbd>](https://play.golang.org/p/wHKBwfu6CPV)
## TODO
@@ -72,4 +71,4 @@ Fnv64Short-8 74.7ns ± 8%
## License
This project is released under the Apache v2. licence. See [LICENCE](LICENCE) for more details.
This project is released under the Apache v2. license. See [LICENSE](LICENSE) for more details.

View File

@@ -1,6 +1,10 @@
package xxhash
import "hash"
import (
"encoding/binary"
"errors"
"hash"
)
const (
prime32x1 uint32 = 2654435761
@@ -24,6 +28,13 @@ const (
zero64x4 = 0x61c8864e7a143579
)
const (
magic32 = "xxh\x07"
magic64 = "xxh\x08"
marshaled32Size = len(magic32) + 4*7 + 16
marshaled64Size = len(magic64) + 8*6 + 32 + 1
)
func NewHash32() hash.Hash { return New32() }
func NewHash64() hash.Hash { return New64() }
@@ -86,6 +97,41 @@ func (xx *XXHash32) Sum(in []byte) []byte {
return append(in, byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (xx *XXHash32) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaled32Size)
b = append(b, magic32...)
b = appendUint32(b, xx.v1)
b = appendUint32(b, xx.v2)
b = appendUint32(b, xx.v3)
b = appendUint32(b, xx.v4)
b = appendUint32(b, xx.seed)
b = appendInt32(b, xx.ln)
b = appendInt32(b, xx.memIdx)
b = append(b, xx.mem[:]...)
return b, nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
func (xx *XXHash32) UnmarshalBinary(b []byte) error {
if len(b) < len(magic32) || string(b[:len(magic32)]) != magic32 {
return errors.New("xxhash: invalid hash state identifier")
}
if len(b) != marshaled32Size {
return errors.New("xxhash: invalid hash state size")
}
b = b[len(magic32):]
b, xx.v1 = consumeUint32(b)
b, xx.v2 = consumeUint32(b)
b, xx.v3 = consumeUint32(b)
b, xx.v4 = consumeUint32(b)
b, xx.seed = consumeUint32(b)
b, xx.ln = consumeInt32(b)
b, xx.memIdx = consumeInt32(b)
copy(xx.mem[:], b)
return nil
}
// Checksum64 an alias for Checksum64S(in, 0)
func Checksum64(in []byte) uint64 {
return Checksum64S(in, 0)
@@ -143,6 +189,60 @@ func (xx *XXHash64) Sum(in []byte) []byte {
return append(in, byte(s>>56), byte(s>>48), byte(s>>40), byte(s>>32), byte(s>>24), byte(s>>16), byte(s>>8), byte(s))
}
// MarshalBinary implements the encoding.BinaryMarshaler interface.
func (xx *XXHash64) MarshalBinary() ([]byte, error) {
b := make([]byte, 0, marshaled64Size)
b = append(b, magic64...)
b = appendUint64(b, xx.v1)
b = appendUint64(b, xx.v2)
b = appendUint64(b, xx.v3)
b = appendUint64(b, xx.v4)
b = appendUint64(b, xx.seed)
b = appendUint64(b, xx.ln)
b = append(b, byte(xx.memIdx))
b = append(b, xx.mem[:]...)
return b, nil
}
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface.
func (xx *XXHash64) UnmarshalBinary(b []byte) error {
if len(b) < len(magic64) || string(b[:len(magic64)]) != magic64 {
return errors.New("xxhash: invalid hash state identifier")
}
if len(b) != marshaled64Size {
return errors.New("xxhash: invalid hash state size")
}
b = b[len(magic64):]
b, xx.v1 = consumeUint64(b)
b, xx.v2 = consumeUint64(b)
b, xx.v3 = consumeUint64(b)
b, xx.v4 = consumeUint64(b)
b, xx.seed = consumeUint64(b)
b, xx.ln = consumeUint64(b)
xx.memIdx = int8(b[0])
b = b[1:]
copy(xx.mem[:], b)
return nil
}
func appendInt32(b []byte, x int32) []byte { return appendUint32(b, uint32(x)) }
func appendUint32(b []byte, x uint32) []byte {
var a [4]byte
binary.LittleEndian.PutUint32(a[:], x)
return append(b, a[:]...)
}
func appendUint64(b []byte, x uint64) []byte {
var a [8]byte
binary.LittleEndian.PutUint64(a[:], x)
return append(b, a[:]...)
}
func consumeInt32(b []byte) ([]byte, int32) { bn, x := consumeUint32(b); return bn, int32(x) }
func consumeUint32(b []byte) ([]byte, uint32) { x := u32(b); return b[4:], x }
func consumeUint64(b []byte) ([]byte, uint64) { x := u64(b); return b[8:], x }
// force the compiler to use ROTL instructions
func rotl32_1(x uint32) uint32 { return (x << 1) | (x >> (32 - 1)) }

View File

@@ -58,10 +58,9 @@ func checksum64(in []byte, seed uint64) uint64 {
wordsLen = len(in) >> 3
words = ((*[maxInt32 / 8]uint64)(unsafe.Pointer(&in[0])))[:wordsLen:wordsLen]
h uint64 = prime64x5
v1, v2, v3, v4 = resetVs64(seed)
h uint64
i int
)

View File

@@ -1,5 +0,0 @@
*.sublime-*
.DS_Store
*.swp
*.swo
tags

View File

@@ -1,12 +0,0 @@
language: go
go:
- 1.4.x
- 1.5.x
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- "1.10.x"
- "1.11.x"
- tip

View File

@@ -1,12 +0,0 @@
Copyright (c) 2012, Martin Angers
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of the author nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,188 +0,0 @@
# Purell
Purell is a tiny Go library to normalize URLs. It returns a pure URL. Pure-ell. Sanitizer and all. Yeah, I know...
Based on the [wikipedia paper][wiki] and the [RFC 3986 document][rfc].
[![build status](https://travis-ci.org/PuerkitoBio/purell.svg?branch=master)](http://travis-ci.org/PuerkitoBio/purell)
## Install
`go get github.com/PuerkitoBio/purell`
## Changelog
* **v1.1.1** : Fix failing test due to Go1.12 changes (thanks to @ianlancetaylor).
* **2016-11-14 (v1.1.0)** : IDN: Conform to RFC 5895: Fold character width (thanks to @beeker1121).
* **2016-07-27 (v1.0.0)** : Normalize IDN to ASCII (thanks to @zenovich).
* **2015-02-08** : Add fix for relative paths issue ([PR #5][pr5]) and add fix for unnecessary encoding of reserved characters ([see issue #7][iss7]).
* **v0.2.0** : Add benchmarks, Attempt IDN support.
* **v0.1.0** : Initial release.
## Examples
From `example_test.go` (note that in your code, you would import "github.com/PuerkitoBio/purell", and would prefix references to its methods and constants with "purell."):
```go
package purell
import (
"fmt"
"net/url"
)
func ExampleNormalizeURLString() {
if normalized, err := NormalizeURLString("hTTp://someWEBsite.com:80/Amazing%3f/url/",
FlagLowercaseScheme|FlagLowercaseHost|FlagUppercaseEscapes); err != nil {
panic(err)
} else {
fmt.Print(normalized)
}
// Output: http://somewebsite.com:80/Amazing%3F/url/
}
func ExampleMustNormalizeURLString() {
normalized := MustNormalizeURLString("hTTpS://someWEBsite.com:443/Amazing%fa/url/",
FlagsUnsafeGreedy)
fmt.Print(normalized)
// Output: http://somewebsite.com/Amazing%FA/url
}
func ExampleNormalizeURL() {
if u, err := url.Parse("Http://SomeUrl.com:8080/a/b/.././c///g?c=3&a=1&b=9&c=0#target"); err != nil {
panic(err)
} else {
normalized := NormalizeURL(u, FlagsUsuallySafeGreedy|FlagRemoveDuplicateSlashes|FlagRemoveFragment)
fmt.Print(normalized)
}
// Output: http://someurl.com:8080/a/c/g?c=3&a=1&b=9&c=0
}
```
## API
As seen in the examples above, purell offers three methods, `NormalizeURLString(string, NormalizationFlags) (string, error)`, `MustNormalizeURLString(string, NormalizationFlags) (string)` and `NormalizeURL(*url.URL, NormalizationFlags) (string)`. They all normalize the provided URL based on the specified flags. Here are the available flags:
```go
const (
// Safe normalizations
FlagLowercaseScheme NormalizationFlags = 1 << iota // HTTP://host -> http://host, applied by default in Go1.1
FlagLowercaseHost // http://HOST -> http://host
FlagUppercaseEscapes // http://host/t%ef -> http://host/t%EF
FlagDecodeUnnecessaryEscapes // http://host/t%41 -> http://host/tA
FlagEncodeNecessaryEscapes // http://host/!"#$ -> http://host/%21%22#$
FlagRemoveDefaultPort // http://host:80 -> http://host
FlagRemoveEmptyQuerySeparator // http://host/path? -> http://host/path
// Usually safe normalizations
FlagRemoveTrailingSlash // http://host/path/ -> http://host/path
FlagAddTrailingSlash // http://host/path -> http://host/path/ (should choose only one of these add/remove trailing slash flags)
FlagRemoveDotSegments // http://host/path/./a/b/../c -> http://host/path/a/c
// Unsafe normalizations
FlagRemoveDirectoryIndex // http://host/path/index.html -> http://host/path/
FlagRemoveFragment // http://host/path#fragment -> http://host/path
FlagForceHTTP // https://host -> http://host
FlagRemoveDuplicateSlashes // http://host/path//a///b -> http://host/path/a/b
FlagRemoveWWW // http://www.host/ -> http://host/
FlagAddWWW // http://host/ -> http://www.host/ (should choose only one of these add/remove WWW flags)
FlagSortQuery // http://host/path?c=3&b=2&a=1&b=1 -> http://host/path?a=1&b=1&b=2&c=3
// Normalizations not in the wikipedia article, required to cover tests cases
// submitted by jehiah
FlagDecodeDWORDHost // http://1113982867 -> http://66.102.7.147
FlagDecodeOctalHost // http://0102.0146.07.0223 -> http://66.102.7.147
FlagDecodeHexHost // http://0x42660793 -> http://66.102.7.147
FlagRemoveUnnecessaryHostDots // http://.host../path -> http://host/path
FlagRemoveEmptyPortSeparator // http://host:/path -> http://host/path
// Convenience set of safe normalizations
FlagsSafe NormalizationFlags = FlagLowercaseHost | FlagLowercaseScheme | FlagUppercaseEscapes | FlagDecodeUnnecessaryEscapes | FlagEncodeNecessaryEscapes | FlagRemoveDefaultPort | FlagRemoveEmptyQuerySeparator
// For convenience sets, "greedy" uses the "remove trailing slash" and "remove www. prefix" flags,
// while "non-greedy" uses the "add (or keep) the trailing slash" and "add www. prefix".
// Convenience set of usually safe normalizations (includes FlagsSafe)
FlagsUsuallySafeGreedy NormalizationFlags = FlagsSafe | FlagRemoveTrailingSlash | FlagRemoveDotSegments
FlagsUsuallySafeNonGreedy NormalizationFlags = FlagsSafe | FlagAddTrailingSlash | FlagRemoveDotSegments
// Convenience set of unsafe normalizations (includes FlagsUsuallySafe)
FlagsUnsafeGreedy NormalizationFlags = FlagsUsuallySafeGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagRemoveWWW | FlagSortQuery
FlagsUnsafeNonGreedy NormalizationFlags = FlagsUsuallySafeNonGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagAddWWW | FlagSortQuery
// Convenience set of all available flags
FlagsAllGreedy = FlagsUnsafeGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
FlagsAllNonGreedy = FlagsUnsafeNonGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
)
```
For convenience, the set of flags `FlagsSafe`, `FlagsUsuallySafe[Greedy|NonGreedy]`, `FlagsUnsafe[Greedy|NonGreedy]` and `FlagsAll[Greedy|NonGreedy]` are provided for the similarly grouped normalizations on [wikipedia's URL normalization page][wiki]. You can add (using the bitwise OR `|` operator) or remove (using the bitwise AND NOT `&^` operator) individual flags from the sets if required, to build your own custom set.
The [full godoc reference is available on gopkgdoc][godoc].
Some things to note:
* `FlagDecodeUnnecessaryEscapes`, `FlagEncodeNecessaryEscapes`, `FlagUppercaseEscapes` and `FlagRemoveEmptyQuerySeparator` are always implicitly set, because internally, the URL string is parsed as an URL object, which automatically decodes unnecessary escapes, uppercases and encodes necessary ones, and removes empty query separators (an unnecessary `?` at the end of the url). So this operation cannot **not** be done. For this reason, `FlagRemoveEmptyQuerySeparator` (as well as the other three) has been included in the `FlagsSafe` convenience set, instead of `FlagsUnsafe`, where Wikipedia puts it.
* The `FlagDecodeUnnecessaryEscapes` decodes the following escapes (*from -> to*):
- %24 -> $
- %26 -> &
- %2B-%3B -> +,-./0123456789:;
- %3D -> =
- %40-%5A -> @ABCDEFGHIJKLMNOPQRSTUVWXYZ
- %5F -> _
- %61-%7A -> abcdefghijklmnopqrstuvwxyz
- %7E -> ~
* When the `NormalizeURL` function is used (passing an URL object), this source URL object is modified (that is, after the call, the URL object will be modified to reflect the normalization).
* The *replace IP with domain name* normalization (`http://208.77.188.166/ → http://www.example.com/`) is obviously not possible for a library without making some network requests. This is not implemented in purell.
* The *remove unused query string parameters* and *remove default query parameters* are also not implemented, since this is a very case-specific normalization, and it is quite trivial to do with an URL object.
### Safe vs Usually Safe vs Unsafe
Purell allows you to control the level of risk you take while normalizing an URL. You can aggressively normalize, play it totally safe, or anything in between.
Consider the following URL:
`HTTPS://www.RooT.com/toto/t%45%1f///a/./b/../c/?z=3&w=2&a=4&w=1#invalid`
Normalizing with the `FlagsSafe` gives:
`https://www.root.com/toto/tE%1F///a/./b/../c/?z=3&w=2&a=4&w=1#invalid`
With the `FlagsUsuallySafeGreedy`:
`https://www.root.com/toto/tE%1F///a/c?z=3&w=2&a=4&w=1#invalid`
And with `FlagsUnsafeGreedy`:
`http://root.com/toto/tE%1F/a/c?a=4&w=1&w=2&z=3`
## TODOs
* Add a class/default instance to allow specifying custom directory index names? At the moment, removing directory index removes `(^|/)((?:default|index)\.\w{1,4})$`.
## Thanks / Contributions
@rogpeppe
@jehiah
@opennota
@pchristopher1275
@zenovich
@beeker1121
## License
The [BSD 3-Clause license][bsd].
[bsd]: http://opensource.org/licenses/BSD-3-Clause
[wiki]: http://en.wikipedia.org/wiki/URL_normalization
[rfc]: http://tools.ietf.org/html/rfc3986#section-6
[godoc]: http://go.pkgdoc.org/github.com/PuerkitoBio/purell
[pr5]: https://github.com/PuerkitoBio/purell/pull/5
[iss7]: https://github.com/PuerkitoBio/purell/issues/7

View File

@@ -1,379 +0,0 @@
/*
Package purell offers URL normalization as described on the wikipedia page:
http://en.wikipedia.org/wiki/URL_normalization
*/
package purell
import (
"bytes"
"fmt"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
"github.com/PuerkitoBio/urlesc"
"golang.org/x/net/idna"
"golang.org/x/text/unicode/norm"
"golang.org/x/text/width"
)
// A set of normalization flags determines how a URL will
// be normalized.
type NormalizationFlags uint
const (
// Safe normalizations
FlagLowercaseScheme NormalizationFlags = 1 << iota // HTTP://host -> http://host, applied by default in Go1.1
FlagLowercaseHost // http://HOST -> http://host
FlagUppercaseEscapes // http://host/t%ef -> http://host/t%EF
FlagDecodeUnnecessaryEscapes // http://host/t%41 -> http://host/tA
FlagEncodeNecessaryEscapes // http://host/!"#$ -> http://host/%21%22#$
FlagRemoveDefaultPort // http://host:80 -> http://host
FlagRemoveEmptyQuerySeparator // http://host/path? -> http://host/path
// Usually safe normalizations
FlagRemoveTrailingSlash // http://host/path/ -> http://host/path
FlagAddTrailingSlash // http://host/path -> http://host/path/ (should choose only one of these add/remove trailing slash flags)
FlagRemoveDotSegments // http://host/path/./a/b/../c -> http://host/path/a/c
// Unsafe normalizations
FlagRemoveDirectoryIndex // http://host/path/index.html -> http://host/path/
FlagRemoveFragment // http://host/path#fragment -> http://host/path
FlagForceHTTP // https://host -> http://host
FlagRemoveDuplicateSlashes // http://host/path//a///b -> http://host/path/a/b
FlagRemoveWWW // http://www.host/ -> http://host/
FlagAddWWW // http://host/ -> http://www.host/ (should choose only one of these add/remove WWW flags)
FlagSortQuery // http://host/path?c=3&b=2&a=1&b=1 -> http://host/path?a=1&b=1&b=2&c=3
// Normalizations not in the wikipedia article, required to cover tests cases
// submitted by jehiah
FlagDecodeDWORDHost // http://1113982867 -> http://66.102.7.147
FlagDecodeOctalHost // http://0102.0146.07.0223 -> http://66.102.7.147
FlagDecodeHexHost // http://0x42660793 -> http://66.102.7.147
FlagRemoveUnnecessaryHostDots // http://.host../path -> http://host/path
FlagRemoveEmptyPortSeparator // http://host:/path -> http://host/path
// Convenience set of safe normalizations
FlagsSafe NormalizationFlags = FlagLowercaseHost | FlagLowercaseScheme | FlagUppercaseEscapes | FlagDecodeUnnecessaryEscapes | FlagEncodeNecessaryEscapes | FlagRemoveDefaultPort | FlagRemoveEmptyQuerySeparator
// For convenience sets, "greedy" uses the "remove trailing slash" and "remove www. prefix" flags,
// while "non-greedy" uses the "add (or keep) the trailing slash" and "add www. prefix".
// Convenience set of usually safe normalizations (includes FlagsSafe)
FlagsUsuallySafeGreedy NormalizationFlags = FlagsSafe | FlagRemoveTrailingSlash | FlagRemoveDotSegments
FlagsUsuallySafeNonGreedy NormalizationFlags = FlagsSafe | FlagAddTrailingSlash | FlagRemoveDotSegments
// Convenience set of unsafe normalizations (includes FlagsUsuallySafe)
FlagsUnsafeGreedy NormalizationFlags = FlagsUsuallySafeGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagRemoveWWW | FlagSortQuery
FlagsUnsafeNonGreedy NormalizationFlags = FlagsUsuallySafeNonGreedy | FlagRemoveDirectoryIndex | FlagRemoveFragment | FlagForceHTTP | FlagRemoveDuplicateSlashes | FlagAddWWW | FlagSortQuery
// Convenience set of all available flags
FlagsAllGreedy = FlagsUnsafeGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
FlagsAllNonGreedy = FlagsUnsafeNonGreedy | FlagDecodeDWORDHost | FlagDecodeOctalHost | FlagDecodeHexHost | FlagRemoveUnnecessaryHostDots | FlagRemoveEmptyPortSeparator
)
const (
defaultHttpPort = ":80"
defaultHttpsPort = ":443"
)
// Regular expressions used by the normalizations
var rxPort = regexp.MustCompile(`(:\d+)/?$`)
var rxDirIndex = regexp.MustCompile(`(^|/)((?:default|index)\.\w{1,4})$`)
var rxDupSlashes = regexp.MustCompile(`/{2,}`)
var rxDWORDHost = regexp.MustCompile(`^(\d+)((?:\.+)?(?:\:\d*)?)$`)
var rxOctalHost = regexp.MustCompile(`^(0\d*)\.(0\d*)\.(0\d*)\.(0\d*)((?:\.+)?(?:\:\d*)?)$`)
var rxHexHost = regexp.MustCompile(`^0x([0-9A-Fa-f]+)((?:\.+)?(?:\:\d*)?)$`)
var rxHostDots = regexp.MustCompile(`^(.+?)(:\d+)?$`)
var rxEmptyPort = regexp.MustCompile(`:+$`)
// Map of flags to implementation function.
// FlagDecodeUnnecessaryEscapes has no action, since it is done automatically
// by parsing the string as an URL. Same for FlagUppercaseEscapes and FlagRemoveEmptyQuerySeparator.
// Since maps have undefined traversing order, make a slice of ordered keys
var flagsOrder = []NormalizationFlags{
FlagLowercaseScheme,
FlagLowercaseHost,
FlagRemoveDefaultPort,
FlagRemoveDirectoryIndex,
FlagRemoveDotSegments,
FlagRemoveFragment,
FlagForceHTTP, // Must be after remove default port (because https=443/http=80)
FlagRemoveDuplicateSlashes,
FlagRemoveWWW,
FlagAddWWW,
FlagSortQuery,
FlagDecodeDWORDHost,
FlagDecodeOctalHost,
FlagDecodeHexHost,
FlagRemoveUnnecessaryHostDots,
FlagRemoveEmptyPortSeparator,
FlagRemoveTrailingSlash, // These two (add/remove trailing slash) must be last
FlagAddTrailingSlash,
}
// ... and then the map, where order is unimportant
var flags = map[NormalizationFlags]func(*url.URL){
FlagLowercaseScheme: lowercaseScheme,
FlagLowercaseHost: lowercaseHost,
FlagRemoveDefaultPort: removeDefaultPort,
FlagRemoveDirectoryIndex: removeDirectoryIndex,
FlagRemoveDotSegments: removeDotSegments,
FlagRemoveFragment: removeFragment,
FlagForceHTTP: forceHTTP,
FlagRemoveDuplicateSlashes: removeDuplicateSlashes,
FlagRemoveWWW: removeWWW,
FlagAddWWW: addWWW,
FlagSortQuery: sortQuery,
FlagDecodeDWORDHost: decodeDWORDHost,
FlagDecodeOctalHost: decodeOctalHost,
FlagDecodeHexHost: decodeHexHost,
FlagRemoveUnnecessaryHostDots: removeUnncessaryHostDots,
FlagRemoveEmptyPortSeparator: removeEmptyPortSeparator,
FlagRemoveTrailingSlash: removeTrailingSlash,
FlagAddTrailingSlash: addTrailingSlash,
}
// MustNormalizeURLString returns the normalized string, and panics if an error occurs.
// It takes an URL string as input, as well as the normalization flags.
func MustNormalizeURLString(u string, f NormalizationFlags) string {
result, e := NormalizeURLString(u, f)
if e != nil {
panic(e)
}
return result
}
// NormalizeURLString returns the normalized string, or an error if it can't be parsed into an URL object.
// It takes an URL string as input, as well as the normalization flags.
func NormalizeURLString(u string, f NormalizationFlags) (string, error) {
parsed, err := url.Parse(u)
if err != nil {
return "", err
}
if f&FlagLowercaseHost == FlagLowercaseHost {
parsed.Host = strings.ToLower(parsed.Host)
}
// The idna package doesn't fully conform to RFC 5895
// (https://tools.ietf.org/html/rfc5895), so we do it here.
// Taken from Go 1.8 cycle source, courtesy of bradfitz.
// TODO: Remove when (if?) idna package conforms to RFC 5895.
parsed.Host = width.Fold.String(parsed.Host)
parsed.Host = norm.NFC.String(parsed.Host)
if parsed.Host, err = idna.ToASCII(parsed.Host); err != nil {
return "", err
}
return NormalizeURL(parsed, f), nil
}
// NormalizeURL returns the normalized string.
// It takes a parsed URL object as input, as well as the normalization flags.
func NormalizeURL(u *url.URL, f NormalizationFlags) string {
for _, k := range flagsOrder {
if f&k == k {
flags[k](u)
}
}
return urlesc.Escape(u)
}
func lowercaseScheme(u *url.URL) {
if len(u.Scheme) > 0 {
u.Scheme = strings.ToLower(u.Scheme)
}
}
func lowercaseHost(u *url.URL) {
if len(u.Host) > 0 {
u.Host = strings.ToLower(u.Host)
}
}
func removeDefaultPort(u *url.URL) {
if len(u.Host) > 0 {
scheme := strings.ToLower(u.Scheme)
u.Host = rxPort.ReplaceAllStringFunc(u.Host, func(val string) string {
if (scheme == "http" && val == defaultHttpPort) || (scheme == "https" && val == defaultHttpsPort) {
return ""
}
return val
})
}
}
func removeTrailingSlash(u *url.URL) {
if l := len(u.Path); l > 0 {
if strings.HasSuffix(u.Path, "/") {
u.Path = u.Path[:l-1]
}
} else if l = len(u.Host); l > 0 {
if strings.HasSuffix(u.Host, "/") {
u.Host = u.Host[:l-1]
}
}
}
func addTrailingSlash(u *url.URL) {
if l := len(u.Path); l > 0 {
if !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
}
} else if l = len(u.Host); l > 0 {
if !strings.HasSuffix(u.Host, "/") {
u.Host += "/"
}
}
}
func removeDotSegments(u *url.URL) {
if len(u.Path) > 0 {
var dotFree []string
var lastIsDot bool
sections := strings.Split(u.Path, "/")
for _, s := range sections {
if s == ".." {
if len(dotFree) > 0 {
dotFree = dotFree[:len(dotFree)-1]
}
} else if s != "." {
dotFree = append(dotFree, s)
}
lastIsDot = (s == "." || s == "..")
}
// Special case if host does not end with / and new path does not begin with /
u.Path = strings.Join(dotFree, "/")
if u.Host != "" && !strings.HasSuffix(u.Host, "/") && !strings.HasPrefix(u.Path, "/") {
u.Path = "/" + u.Path
}
// Special case if the last segment was a dot, make sure the path ends with a slash
if lastIsDot && !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
}
}
}
func removeDirectoryIndex(u *url.URL) {
if len(u.Path) > 0 {
u.Path = rxDirIndex.ReplaceAllString(u.Path, "$1")
}
}
func removeFragment(u *url.URL) {
u.Fragment = ""
}
func forceHTTP(u *url.URL) {
if strings.ToLower(u.Scheme) == "https" {
u.Scheme = "http"
}
}
func removeDuplicateSlashes(u *url.URL) {
if len(u.Path) > 0 {
u.Path = rxDupSlashes.ReplaceAllString(u.Path, "/")
}
}
func removeWWW(u *url.URL) {
if len(u.Host) > 0 && strings.HasPrefix(strings.ToLower(u.Host), "www.") {
u.Host = u.Host[4:]
}
}
func addWWW(u *url.URL) {
if len(u.Host) > 0 && !strings.HasPrefix(strings.ToLower(u.Host), "www.") {
u.Host = "www." + u.Host
}
}
func sortQuery(u *url.URL) {
q := u.Query()
if len(q) > 0 {
arKeys := make([]string, len(q))
i := 0
for k := range q {
arKeys[i] = k
i++
}
sort.Strings(arKeys)
buf := new(bytes.Buffer)
for _, k := range arKeys {
sort.Strings(q[k])
for _, v := range q[k] {
if buf.Len() > 0 {
buf.WriteRune('&')
}
buf.WriteString(fmt.Sprintf("%s=%s", k, urlesc.QueryEscape(v)))
}
}
// Rebuild the raw query string
u.RawQuery = buf.String()
}
}
func decodeDWORDHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxDWORDHost.FindStringSubmatch(u.Host); len(matches) > 2 {
var parts [4]int64
dword, _ := strconv.ParseInt(matches[1], 10, 0)
for i, shift := range []uint{24, 16, 8, 0} {
parts[i] = dword >> shift & 0xFF
}
u.Host = fmt.Sprintf("%d.%d.%d.%d%s", parts[0], parts[1], parts[2], parts[3], matches[2])
}
}
}
func decodeOctalHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxOctalHost.FindStringSubmatch(u.Host); len(matches) > 5 {
var parts [4]int64
for i := 1; i <= 4; i++ {
parts[i-1], _ = strconv.ParseInt(matches[i], 8, 0)
}
u.Host = fmt.Sprintf("%d.%d.%d.%d%s", parts[0], parts[1], parts[2], parts[3], matches[5])
}
}
}
func decodeHexHost(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxHexHost.FindStringSubmatch(u.Host); len(matches) > 2 {
// Conversion is safe because of regex validation
parsed, _ := strconv.ParseInt(matches[1], 16, 0)
// Set host as DWORD (base 10) encoded host
u.Host = fmt.Sprintf("%d%s", parsed, matches[2])
// The rest is the same as decoding a DWORD host
decodeDWORDHost(u)
}
}
}
func removeUnncessaryHostDots(u *url.URL) {
if len(u.Host) > 0 {
if matches := rxHostDots.FindStringSubmatch(u.Host); len(matches) > 1 {
// Trim the leading and trailing dots
u.Host = strings.Trim(matches[1], ".")
if len(matches) > 2 {
u.Host += matches[2]
}
}
}
}
func removeEmptyPortSeparator(u *url.URL) {
if len(u.Host) > 0 {
u.Host = rxEmptyPort.ReplaceAllString(u.Host, "")
}
}

View File

@@ -1,27 +0,0 @@
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,16 +0,0 @@
urlesc [![Build Status](https://travis-ci.org/PuerkitoBio/urlesc.svg?branch=master)](https://travis-ci.org/PuerkitoBio/urlesc) [![GoDoc](http://godoc.org/github.com/PuerkitoBio/urlesc?status.svg)](http://godoc.org/github.com/PuerkitoBio/urlesc)
======
Package urlesc implements query escaping as per RFC 3986.
It contains some parts of the net/url package, modified so as to allow
some reserved characters incorrectly escaped by net/url (see [issue 5684](https://github.com/golang/go/issues/5684)).
## Install
go get github.com/PuerkitoBio/urlesc
## License
Go license (BSD-3-Clause)

View File

@@ -1,180 +0,0 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package urlesc implements query escaping as per RFC 3986.
// It contains some parts of the net/url package, modified so as to allow
// some reserved characters incorrectly escaped by net/url.
// See https://github.com/golang/go/issues/5684
package urlesc
import (
"bytes"
"net/url"
"strings"
)
type encoding int
const (
encodePath encoding = 1 + iota
encodeUserPassword
encodeQueryComponent
encodeFragment
)
// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
func shouldEscape(c byte, mode encoding) bool {
// §2.3 Unreserved characters (alphanum)
if 'A' <= c && c <= 'Z' || 'a' <= c && c <= 'z' || '0' <= c && c <= '9' {
return false
}
switch c {
case '-', '.', '_', '~': // §2.3 Unreserved characters (mark)
return false
// §2.2 Reserved characters (reserved)
case ':', '/', '?', '#', '[', ']', '@', // gen-delims
'!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=': // sub-delims
// Different sections of the URL allow a few of
// the reserved characters to appear unescaped.
switch mode {
case encodePath: // §3.3
// The RFC allows sub-delims and : @.
// '/', '[' and ']' can be used to assign meaning to individual path
// segments. This package only manipulates the path as a whole,
// so we allow those as well. That leaves only ? and # to escape.
return c == '?' || c == '#'
case encodeUserPassword: // §3.2.1
// The RFC allows : and sub-delims in
// userinfo. The parsing of userinfo treats ':' as special so we must escape
// all the gen-delims.
return c == ':' || c == '/' || c == '?' || c == '#' || c == '[' || c == ']' || c == '@'
case encodeQueryComponent: // §3.4
// The RFC allows / and ?.
return c != '/' && c != '?'
case encodeFragment: // §4.1
// The RFC text is silent but the grammar allows
// everything, so escape nothing but #
return c == '#'
}
}
// Everything else must be escaped.
return true
}
// QueryEscape escapes the string so it can be safely placed
// inside a URL query.
func QueryEscape(s string) string {
return escape(s, encodeQueryComponent)
}
func escape(s string, mode encoding) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
c := s[i]
if shouldEscape(c, mode) {
if c == ' ' && mode == encodeQueryComponent {
spaceCount++
} else {
hexCount++
}
}
}
if spaceCount == 0 && hexCount == 0 {
return s
}
t := make([]byte, len(s)+2*hexCount)
j := 0
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == ' ' && mode == encodeQueryComponent:
t[j] = '+'
j++
case shouldEscape(c, mode):
t[j] = '%'
t[j+1] = "0123456789ABCDEF"[c>>4]
t[j+2] = "0123456789ABCDEF"[c&15]
j += 3
default:
t[j] = s[i]
j++
}
}
return string(t)
}
var uiReplacer = strings.NewReplacer(
"%21", "!",
"%27", "'",
"%28", "(",
"%29", ")",
"%2A", "*",
)
// unescapeUserinfo unescapes some characters that need not to be escaped as per RFC3986.
func unescapeUserinfo(s string) string {
return uiReplacer.Replace(s)
}
// Escape reassembles the URL into a valid URL string.
// The general form of the result is one of:
//
// scheme:opaque
// scheme://userinfo@host/path?query#fragment
//
// If u.Opaque is non-empty, String uses the first form;
// otherwise it uses the second form.
//
// In the second form, the following rules apply:
// - if u.Scheme is empty, scheme: is omitted.
// - if u.User is nil, userinfo@ is omitted.
// - if u.Host is empty, host/ is omitted.
// - if u.Scheme and u.Host are empty and u.User is nil,
// the entire scheme://userinfo@host/ is omitted.
// - if u.Host is non-empty and u.Path begins with a /,
// the form host/path does not add its own /.
// - if u.RawQuery is empty, ?query is omitted.
// - if u.Fragment is empty, #fragment is omitted.
func Escape(u *url.URL) string {
var buf bytes.Buffer
if u.Scheme != "" {
buf.WriteString(u.Scheme)
buf.WriteByte(':')
}
if u.Opaque != "" {
buf.WriteString(u.Opaque)
} else {
if u.Scheme != "" || u.Host != "" || u.User != nil {
buf.WriteString("//")
if ui := u.User; ui != nil {
buf.WriteString(unescapeUserinfo(ui.String()))
buf.WriteByte('@')
}
if h := u.Host; h != "" {
buf.WriteString(h)
}
}
if u.Path != "" && u.Path[0] != '/' && u.Host != "" {
buf.WriteByte('/')
}
buf.WriteString(escape(u.Path, encodePath))
}
if u.RawQuery != "" {
buf.WriteByte('?')
buf.WriteString(u.RawQuery)
}
if u.Fragment != "" {
buf.WriteByte('#')
buf.WriteString(escape(u.Fragment, encodeFragment))
}
return buf.String()
}

View File

@@ -1,7 +1,23 @@
language: go
# See https://travis-ci.community/t/goos-js-goarch-wasm-go-run-fails-panic-newosproc-not-implemented/1651
#addons:
# chrome: stable
before_install:
- export GO111MODULE=on
#install:
#- go get github.com/agnivade/wasmbrowsertest
#- mv $GOPATH/bin/wasmbrowsertest $GOPATH/bin/go_js_wasm_exec
#- export PATH=$GOPATH/bin:$PATH
go:
- 1.9.x
- 1.10.x
- 1.11.x
- 1.13.x
- 1.14.x
- 1.15.x
- tip
script:
#- GOOS=js GOARCH=wasm go test -v
- go test -v

View File

@@ -4,10 +4,12 @@ install:
go install
lint:
gofmt -l -s -w . && go tool vet -all . && golint
gofmt -l -s -w . && go vet . && golint -set_exit_status=1 .
test:
go test -race -v -coverprofile=coverage.txt -covermode=atomic
test: # The first 2 go gets are to support older Go versions
go get github.com/arbovm/levenshtein
go get github.com/dgryski/trifles/leven
GO111MODULE=on go test -race -v -coverprofile=coverage.txt -covermode=atomic
bench:
go test -run=XXX -bench=. -benchmem
go test -run=XXX -bench=. -benchmem -count=5

View File

@@ -1,4 +1,4 @@
levenshtein [![Build Status](https://travis-ci.org/agnivade/levenshtein.svg?branch=master)](https://travis-ci.org/agnivade/levenshtein) [![Go Report Card](https://goreportcard.com/badge/github.com/agnivade/levenshtein)](https://goreportcard.com/report/github.com/agnivade/levenshtein) [![GoDoc](https://godoc.org/github.com/agnivade/levenshtein?status.svg)](https://godoc.org/github.com/agnivade/levenshtein)
levenshtein [![Build Status](https://travis-ci.org/agnivade/levenshtein.svg?branch=master)](https://travis-ci.org/agnivade/levenshtein) [![Go Report Card](https://goreportcard.com/badge/github.com/agnivade/levenshtein)](https://goreportcard.com/report/github.com/agnivade/levenshtein) [![PkgGoDev](https://pkg.go.dev/badge/github.com/agnivade/levenshtein)](https://pkg.go.dev/github.com/agnivade/levenshtein)
===========
[Go](http://golang.org) package to calculate the [Levenshtein Distance](http://en.wikipedia.org/wiki/Levenshtein_distance)
@@ -6,6 +6,10 @@ levenshtein [![Build Status](https://travis-ci.org/agnivade/levenshtein.svg?bran
The library is fully capable of working with non-ascii strings. But the strings are not normalized. That is left as a user-dependant use case. Please normalize the strings before passing it to the library if you have such a requirement.
- https://blog.golang.org/normalization
#### Limitation
As a performance optimization, the library can handle strings only up to 65536 characters (runes). If you need to handle strings larger than that, please pin to version 1.0.3.
Install
-------
@@ -38,10 +42,10 @@ Benchmarks
```
name time/op
Simple/ASCII-4 537ns ± 2%
Simple/French-4 956ns ± 0%
Simple/Nordic-4 1.95µs ± 1%
Simple/Tibetan-4 1.53µs ± 2%
Simple/ASCII-4 330ns ± 2%
Simple/French-4 617ns ± 2%
Simple/Nordic-4 1.16µs ± 4%
Simple/Tibetan-4 1.05µs ± 1%
name alloc/op
Simple/ASCII-4 96.0B ± 0%
@@ -55,3 +59,22 @@ Simple/French-4 1.00 ± 0%
Simple/Nordic-4 1.00 ± 0%
Simple/Tibetan-4 1.00 ± 0%
```
Comparisons with other libraries
--------------------------------
```
name time/op
Leven/ASCII/agniva-4 353ns ± 1%
Leven/ASCII/arbovm-4 485ns ± 1%
Leven/ASCII/dgryski-4 395ns ± 0%
Leven/French/agniva-4 648ns ± 1%
Leven/French/arbovm-4 791ns ± 0%
Leven/French/dgryski-4 682ns ± 0%
Leven/Nordic/agniva-4 1.28µs ± 1%
Leven/Nordic/arbovm-4 1.52µs ± 1%
Leven/Nordic/dgryski-4 1.32µs ± 1%
Leven/Tibetan/agniva-4 1.12µs ± 1%
Leven/Tibetan/arbovm-4 1.31µs ± 0%
Leven/Tibetan/dgryski-4 1.16µs ± 0%
```

View File

@@ -6,12 +6,17 @@ package levenshtein
import "unicode/utf8"
// minLengthThreshold is the length of the string beyond which
// an allocation will be made. Strings smaller than this will be
// zero alloc.
const minLengthThreshold = 32
// ComputeDistance computes the levenshtein distance between the two
// strings passed as an argument. The return value is the levenshtein distance
//
// Works on runes (Unicode code points) but does not normalize
// the input strings. See https://blog.golang.org/normalization
// and the golang.org/x/text/unicode/norm pacage.
// and the golang.org/x/text/unicode/norm package.
func ComputeDistance(a, b string) int {
if len(a) == 0 {
return utf8.RuneCountInString(b)
@@ -25,12 +30,10 @@ func ComputeDistance(a, b string) int {
return 0
}
// We need to convert to []rune if the strings are non-ascii.
// We need to convert to []rune if the strings are non-ASCII.
// This could be avoided by using utf8.RuneCountInString
// and then doing some juggling with rune indices.
// The primary challenge is keeping track of the previous rune.
// With a range loop, its not that easy. And with a for-loop
// we need to keep track of the inter-rune width using utf8.DecodeRuneInString
// and then doing some juggling with rune indices,
// but leads to far more bounds checks. It is a reasonable trade-off.
s1 := []rune(a)
s2 := []rune(b)
@@ -41,22 +44,33 @@ func ComputeDistance(a, b string) int {
lenS1 := len(s1)
lenS2 := len(s2)
// init the row
x := make([]int, lenS1+1)
for i := 0; i <= lenS1; i++ {
x[i] = i
// Init the row.
var x []uint16
if lenS1+1 > minLengthThreshold {
x = make([]uint16, lenS1+1)
} else {
// We make a small optimization here for small strings.
// Because a slice of constant length is effectively an array,
// it does not allocate. So we can re-slice it to the right length
// as long as it is below a desired threshold.
x = make([]uint16, minLengthThreshold)
x = x[:lenS1+1]
}
// we start from 1 because index 0 is already 0.
for i := 1; i < len(x); i++ {
x[i] = uint16(i)
}
// make a dummy bounds check to prevent the 2 bounds check down below.
// The one inside the loop is particularly costly.
_ = x[lenS1]
// fill in the rest
for i := 1; i <= lenS2; i++ {
prev := i
var current int
prev := uint16(i)
for j := 1; j <= lenS1; j++ {
if s2[i-1] == s1[j-1] {
current = x[j-1] // match
} else {
current := x[j-1] // match
if s2[i-1] != s1[j-1] {
current = min(min(x[j-1]+1, prev+1), x[j]+1)
}
x[j-1] = prev
@@ -64,10 +78,10 @@ func ComputeDistance(a, b string) int {
}
x[lenS1] = prev
}
return x[lenS1]
return int(x[lenS1])
}
func min(a, b int) int {
func min(a, b uint16) uint16 {
if a < b {
return a
}

View File

@@ -83,14 +83,14 @@ This was changed to prevent data races when accessing custom validators.
import "github.com/asaskevich/govalidator"
// before
govalidator.CustomTypeTagMap["customByteArrayValidator"] = CustomTypeValidator(func(i interface{}, o interface{}) bool {
govalidator.CustomTypeTagMap["customByteArrayValidator"] = func(i interface{}, o interface{}) bool {
// ...
})
}
// after
govalidator.CustomTypeTagMap.Set("customByteArrayValidator", CustomTypeValidator(func(i interface{}, o interface{}) bool {
govalidator.CustomTypeTagMap.Set("customByteArrayValidator", func(i interface{}, o interface{}) bool {
// ...
}))
})
```
#### List of functions:
@@ -238,7 +238,7 @@ func Trim(str, chars string) string
func Truncate(str string, length int, ending string) string
func TruncatingErrorf(str string, args ...interface{}) error
func UnderscoreToCamelCase(s string) string
func ValidateMap(s map[string]interface{}, m map[string]interface{}) (bool, error)
func ValidateMap(inputMap map[string]interface{}, validationMap map[string]interface{}) (bool, error)
func ValidateStruct(s interface{}) (bool, error)
func WhiteList(str, chars string) string
type ConditionIterator
@@ -461,7 +461,7 @@ var inputMap = map[string]interface{}{
},
}
result, err := govalidator.ValidateMap(mapTemplate, inputMap)
result, err := govalidator.ValidateMap(inputMap, mapTemplate)
if err != nil {
println("error: " + err.Error())
}
@@ -487,7 +487,7 @@ type StructWithCustomByteArray struct {
CustomMinLength int `valid:"-"`
}
govalidator.CustomTypeTagMap.Set("customByteArrayValidator", CustomTypeValidator(func(i interface{}, context interface{}) bool {
govalidator.CustomTypeTagMap.Set("customByteArrayValidator", func(i interface{}, context interface{}) bool {
switch v := context.(type) { // you can type switch on the context interface being validated
case StructWithCustomByteArray:
// you can check and validate against some other field in the context,
@@ -507,14 +507,25 @@ govalidator.CustomTypeTagMap.Set("customByteArrayValidator", CustomTypeValidator
}
}
return false
}))
govalidator.CustomTypeTagMap.Set("customMinLengthValidator", CustomTypeValidator(func(i interface{}, context interface{}) bool {
})
govalidator.CustomTypeTagMap.Set("customMinLengthValidator", func(i interface{}, context interface{}) bool {
switch v := context.(type) { // this validates a field against the value in another field, i.e. dependent validation
case StructWithCustomByteArray:
return len(v.ID) >= v.CustomMinLength
}
return false
}))
})
```
###### Loop over Error()
By default .Error() returns all errors in a single String. To access each error you can do this:
```go
if err != nil {
errs := err.(govalidator.Errors).Errors()
for _, e := range errs {
fmt.Println(e.Error())
}
}
```
###### Custom error messages
@@ -602,4 +613,4 @@ Support this project by becoming a sponsor. Your logo will show up here with a l
## License
[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator?ref=badge_large)
[![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2Fasaskevich%2Fgovalidator?ref=badge_large)

View File

@@ -1,6 +1,9 @@
package govalidator
import "strings"
import (
"sort"
"strings"
)
// Errors is an array of multiple errors and conforms to the error interface.
type Errors []error
@@ -15,6 +18,7 @@ func (es Errors) Error() string {
for _, e := range es {
errs = append(errs, e.Error())
}
sort.Strings(errs)
return strings.Join(errs, ";")
}

View File

@@ -48,6 +48,7 @@ const (
hasUpperCase string = ".*[[:upper:]]"
hasWhitespace string = ".*[[:space:]]"
hasWhitespaceOnly string = "^[[:space:]]+$"
IMEI string = "^[0-9a-f]{14}$|^\\d{15}$|^\\d{18}$"
)
// Used by IsFilePath func
@@ -100,4 +101,5 @@ var (
rxHasUpperCase = regexp.MustCompile(hasUpperCase)
rxHasWhitespace = regexp.MustCompile(hasWhitespace)
rxHasWhitespaceOnly = regexp.MustCompile(hasWhitespaceOnly)
rxIMEI = regexp.MustCompile(IMEI)
)

View File

@@ -162,6 +162,7 @@ var TagMap = map[string]Validator{
"ISO3166Alpha2": IsISO3166Alpha2,
"ISO3166Alpha3": IsISO3166Alpha3,
"ISO4217": IsISO4217,
"IMEI": IsIMEI,
}
// ISO3166Entry stores country codes

View File

@@ -282,7 +282,7 @@ func HasLowerCase(str string) bool {
return rxHasLowerCase.MatchString(str)
}
// HasUpperCase check if the string contians as least 1 uppercase. Empty string is valid.
// HasUpperCase check if the string contains as least 1 uppercase. Empty string is valid.
func HasUpperCase(str string) bool {
if IsNull(str) {
return true
@@ -575,7 +575,7 @@ func IsDNSName(str string) bool {
// IsHash checks if a string is a hash of type algorithm.
// Algorithm is one of ['md4', 'md5', 'sha1', 'sha256', 'sha384', 'sha512', 'ripemd128', 'ripemd160', 'tiger128', 'tiger160', 'tiger192', 'crc32', 'crc32b']
func IsHash(str string, algorithm string) bool {
len := "0"
var len string
algo := strings.ToLower(algorithm)
if algo == "crc32" || algo == "crc32b" {
@@ -737,6 +737,11 @@ func IsLongitude(str string) bool {
return rxLongitude.MatchString(str)
}
// IsIMEI check if a string is valid IMEI
func IsIMEI(str string) bool {
return rxIMEI.MatchString(str)
}
// IsRsaPublicKey check if a string is valid public key with provided length
func IsRsaPublicKey(str string, keylen int) bool {
bb := bytes.NewBufferString(str)
@@ -808,8 +813,9 @@ func PrependPathToErrors(err error, path string) error {
// ValidateMap use validation map for fields.
// result will be equal to `false` if there are any errors.
// m is the validation map in the form
// map[string]interface{}{"name":"required,alpha","address":map[string]interface{}{"line1":"required,alphanum"}}
// s is the map containing the data to be validated.
// m is the validation map in the form:
// map[string]interface{}{"name":"required,alpha","address":map[string]interface{}{"line1":"required,alphanum"}}
func ValidateMap(s map[string]interface{}, m map[string]interface{}) (bool, error) {
if s == nil {
return true, nil
@@ -1498,11 +1504,11 @@ func ErrorsByField(e error) map[string]string {
}
// prototype for ValidateStruct
switch e.(type) {
switch e := e.(type) {
case Error:
m[e.(Error).Name] = e.(Error).Err.Error()
m[e.Name] = e.Err.Error()
case Errors:
for _, item := range e.(Errors).Errors() {
for _, item := range e.Errors() {
n := ErrorsByField(item)
for k, v := range n {
m[k] = v

View File

@@ -1,8 +0,0 @@
language: go
go:
- "1.x"
- master
env:
- TAGS=""
- TAGS="-tags purego"
script: go test $TAGS -v ./...

View File

@@ -1,7 +1,7 @@
# xxhash
[![GoDoc](https://godoc.org/github.com/cespare/xxhash?status.svg)](https://godoc.org/github.com/cespare/xxhash)
[![Build Status](https://travis-ci.org/cespare/xxhash.svg?branch=master)](https://travis-ci.org/cespare/xxhash)
[![Go Reference](https://pkg.go.dev/badge/github.com/cespare/xxhash/v2.svg)](https://pkg.go.dev/github.com/cespare/xxhash/v2)
[![Test](https://github.com/cespare/xxhash/actions/workflows/test.yml/badge.svg)](https://github.com/cespare/xxhash/actions/workflows/test.yml)
xxhash is a Go implementation of the 64-bit
[xxHash](http://cyan4973.github.io/xxHash/) algorithm, XXH64. This is a
@@ -64,4 +64,6 @@ $ go test -benchtime 10s -bench '/xxhash,direct,bytes'
- [InfluxDB](https://github.com/influxdata/influxdb)
- [Prometheus](https://github.com/prometheus/prometheus)
- [VictoriaMetrics](https://github.com/VictoriaMetrics/VictoriaMetrics)
- [FreeCache](https://github.com/coocood/freecache)
- [FastCache](https://github.com/VictoriaMetrics/fastcache)

Some files were not shown because too many files have changed in this diff Show More