Files
kubesphere/vendor/github.com/open-policy-agent/opa/internal/compiler/wasm/wasm.go
hongming 9769357005 update
Signed-off-by: hongming <talonwan@yunify.com>
2020-03-20 02:16:11 +08:00

1014 lines
34 KiB
Go

// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
// Package wasm contains an IR->WASM compiler backend.
package wasm
import (
"bytes"
"fmt"
"github.com/pkg/errors"
"github.com/open-policy-agent/opa/internal/compiler/wasm/opa"
"github.com/open-policy-agent/opa/internal/ir"
"github.com/open-policy-agent/opa/internal/wasm/encoding"
"github.com/open-policy-agent/opa/internal/wasm/instruction"
"github.com/open-policy-agent/opa/internal/wasm/module"
"github.com/open-policy-agent/opa/internal/wasm/types"
)
// Type tags that generated code compares against the result of the
// opa_value_type call (see the IsArrayStmt/IsObjectStmt cases in
// compileBlock). Numbering starts at 1 because of the "iota + 1".
// NOTE(review): presumably these values must stay in sync with the type
// constants of the pre-compiled OPA runtime -- confirm against its sources.
const (
	opaTypeNull int32 = iota + 1
	opaTypeBoolean
	opaTypeNumber
	opaTypeString
	opaTypeArray
	opaTypeObject
)
// Names of functions that generated code calls. The compiler resolves a name
// to a function index through Compiler.function, which reads the c.funcs map
// populated in initModule from the module's function imports and exports.
// NOTE(review): not every name below is referenced in this part of the file
// (e.g. opaFuncPrefix, opaJSONParse) -- check other files before removing any.
const (
	opaFuncPrefix        = "opa_"
	opaAbort             = "opa_abort"
	opaJSONParse         = "opa_json_parse"
	opaNull              = "opa_null"
	opaBoolean           = "opa_boolean"
	opaNumberInt         = "opa_number_int"
	opaNumberFloat       = "opa_number_float"
	opaNumberRef         = "opa_number_ref"
	opaNumberSize        = "opa_number_size"
	opaArrayWithCap      = "opa_array_with_cap"
	opaArrayAppend       = "opa_array_append"
	opaObject            = "opa_object"
	opaObjectInsert      = "opa_object_insert"
	opaSet               = "opa_set"
	opaSetAdd            = "opa_set_add"
	opaStringTerminated  = "opa_string_terminated"
	opaValueBooleanSet   = "opa_value_boolean_set"
	opaValueNumberSetInt = "opa_value_number_set_int"
	opaValueCompare      = "opa_value_compare"
	opaValueGet          = "opa_value_get"
	opaValueIter         = "opa_value_iter"
	opaValueLength       = "opa_value_length"
	opaValueMerge        = "opa_value_merge"
	opaValueShallowCopy  = "opa_value_shallow_copy"
	opaValueType         = "opa_value_type"
)
// builtins lists the names of the functions used to dispatch built-in calls.
// The slice is indexed by the number of arguments of the call: a call with N
// arguments is emitted as a call to builtins[N] (see compileBuiltinCall), so
// the length of this array bounds the supported argument count.
var builtins = [...]string{
	"opa_builtin0",
	"opa_builtin1",
	"opa_builtin2",
	"opa_builtin3",
	"opa_builtin4",
}
// Compiler implements an IR->WASM compiler backend. A Compiler is created
// with New, given its input with WithPolicy and run with Compile, which
// executes the stages in order.
type Compiler struct {
	stages               []func() error      // compiler stages to execute
	errors               []error             // compilation errors encountered
	policy               *ir.Policy          // input policy to compile
	module               *module.Module      // output WASM module
	code                 *module.CodeEntry   // output WASM code
	builtinStringAddrs   map[int]uint32      // addresses of built-in string constants
	builtinFuncNameAddrs map[string]int32    // addresses of built-in function names for listing
	builtinFuncs         map[string]int32    // built-in function ids
	stringOffset         int32               // null-terminated string data base offset
	stringAddrs          []uint32            // null-terminated string constant addresses
	funcs                map[string]uint32   // maps imported and exported function names to function indices
	nextLocal            uint32              // next unallocated local index; reset to 0 for each compiled function
	locals               map[ir.Local]uint32 // maps IR locals to allocated local indices; reset for each compiled function
	lctx                 uint32              // local pointing to eval context
	lrs                  uint32              // local pointing to result set
}
// Identifiers of the error messages that generated code can pass to
// opa_abort. Each identifier is paired with its text in errorMessages and is
// resolved to a data-section address by Compiler.builtinStringAddr.
const (
	errVarAssignConflict int = iota
	errObjectInsertConflict
	errObjectMergeConflict
	errWithConflict
)
// errorMessages pairs each error identifier with its message text.
// compileStrings writes every message into the data section as a
// null-terminated string and records its address under the identifier.
var errorMessages = [...]struct {
	id      int
	message string
}{
	{errVarAssignConflict, "var assignment conflict"},
	{errObjectInsertConflict, "object insert conflict"},
	{errObjectMergeConflict, "object merge conflict"},
	{errWithConflict, "with target conflict"},
}
// New constructs a Compiler whose stage pipeline is already wired up. The
// stages run in the listed order when Compile is called.
func New() *Compiler {
	compiler := new(Compiler)

	pipeline := []func() error{
		compiler.initModule,
		compiler.compileStrings,
		compiler.compileBuiltinDecls,
		compiler.compileFuncs,
		compiler.compilePlan,
	}
	compiler.stages = pipeline

	return compiler
}
// WithPolicy records p as the policy that Compile will translate. The
// receiver is returned so that calls can be chained.
func (c *Compiler) WithPolicy(p *ir.Policy) *Compiler {
	c.policy = p
	return c
}
// Compile runs every compiler stage in order and returns the resulting WASM
// module.
//
// Compilation stops at the first stage that either returns an error or leaves
// a compilation error recorded in c.errors; in the latter case only the first
// recorded error is reported. An error is also returned when no policy has
// been supplied through WithPolicy, because the stages dereference the policy
// unconditionally.
func (c *Compiler) Compile() (*module.Module, error) {
	if c.policy == nil {
		return nil, fmt.Errorf("no policy to compile")
	}

	for _, stage := range c.stages {
		if err := stage(); err != nil {
			return nil, err
		}
		if len(c.errors) > 0 {
			return nil, c.errors[0] // TODO(tsandall) return all errors.
		}
	}

	return c.module, nil
}
// initModule loads the pre-compiled OPA binary into c.module and indexes its
// function imports and function exports by name. It then declares (without
// bodies) one function per policy function, followed by the exported "eval"
// and "builtins" entry points, so that later stages can refer to them.
func (c *Compiler) initModule() error {
	raw, err := opa.Bytes()
	if err != nil {
		return err
	}

	if c.module, err = encoding.ReadModule(bytes.NewReader(raw)); err != nil {
		return err
	}

	c.funcs = map[string]uint32{}

	// Function imports are numbered consecutively from zero, in import order.
	next := uint32(0)
	for _, imp := range c.module.Import.Imports {
		if imp.Descriptor.Kind() != module.FunctionImportType {
			continue
		}
		c.funcs[imp.Name] = next
		next++
	}

	// Exported functions carry their own function index.
	exports := c.module.Export.Exports
	for i := range exports {
		if exports[i].Descriptor.Type != module.FunctionExportType {
			continue
		}
		c.funcs[exports[i].Name] = exports[i].Descriptor.Index
	}

	// Every policy function takes one i32 per parameter and returns an i32.
	for _, fn := range c.policy.Funcs.Funcs {
		params := make([]types.ValueType, 0, len(fn.Params))
		for range fn.Params {
			params = append(params, types.I32)
		}
		c.emitFunctionDecl(fn.Name, module.FunctionType{
			Params:  params,
			Results: []types.ValueType{types.I32},
		}, false)
	}

	evalType := module.FunctionType{
		Params:  []types.ValueType{types.I32},
		Results: []types.ValueType{types.I32},
	}
	c.emitFunctionDecl("eval", evalType, true)

	builtinsType := module.FunctionType{
		Params:  nil,
		Results: []types.ValueType{types.I32},
	}
	c.emitFunctionDecl("builtins", builtinsType, true)

	return nil
}
// compileStrings lays out all string constants as null-terminated strings in
// a single data segment and records the address of each one for later
// stages. The segment holds, in order: the policy's static strings, the names
// of the built-in functions the policy uses, and the error messages.
func (c *Compiler) compileStrings() error {
	c.stringOffset = 2048

	var data bytes.Buffer

	// intern appends s and a terminating NUL byte to the segment and returns
	// the address at which s starts.
	intern := func(s string) uint32 {
		addr := uint32(data.Len()) + uint32(c.stringOffset)
		data.WriteString(s)
		data.WriteByte(0)
		return addr
	}

	strs := c.policy.Static.Strings
	c.stringAddrs = make([]uint32, len(strs))
	for i := range strs {
		c.stringAddrs[i] = intern(strs[i].Value)
	}

	decls := c.policy.Static.BuiltinFuncs
	c.builtinFuncNameAddrs = make(map[string]int32, len(decls))
	for _, decl := range decls {
		c.builtinFuncNameAddrs[decl.Name] = int32(intern(decl.Name))
	}

	c.builtinStringAddrs = make(map[int]uint32, len(errorMessages))
	for _, msg := range errorMessages {
		c.builtinStringAddrs[msg.id] = intern(msg.message)
	}

	segment := module.DataSegment{
		Index: 0,
		Offset: module.Expr{
			Instrs: []instruction.Instruction{
				instruction.I32Const{Value: c.stringOffset},
			},
		},
		Init: data.Bytes(),
	}
	c.module.Data.Segments = append(c.module.Data.Segments, segment)

	return nil
}
// compileBuiltinDecls generates the body of the "builtins" function, which
// lists the built-ins required by the policy. The host environment should
// invoke this function to obtain the list of built-in function identifiers
// (represented as integers) that will be used when calling out.
//
// The generated code builds an object that maps each built-in function name
// to its identifier, which is the position of the declaration in the
// policy's list of built-in functions. The same identifiers are recorded in
// c.builtinFuncs for use when compiling calls.
func (c *Compiler) compileBuiltinDecls() error {
	c.code = &module.CodeEntry{}
	c.nextLocal = 0
	c.locals = map[ir.Local]uint32{}

	result := c.genLocal()

	c.appendInstrs([]instruction.Instruction{
		instruction.Call{Index: c.function(opaObject)},
		instruction.SetLocal{Index: result},
	})

	decls := c.policy.Static.BuiltinFuncs
	c.builtinFuncs = make(map[string]int32, len(decls))

	for id, decl := range decls {
		c.appendInstrs([]instruction.Instruction{
			instruction.GetLocal{Index: result},
			instruction.I32Const{Value: c.builtinFuncNameAddrs[decl.Name]},
			instruction.Call{Index: c.function(opaStringTerminated)},
			instruction.I64Const{Value: int64(id)},
			instruction.Call{Index: c.function(opaNumberInt)},
			instruction.Call{Index: c.function(opaObjectInsert)},
		})
		c.builtinFuncs[decl.Name] = int32(id)
	}

	// Leave the object on the stack as the function result.
	c.appendInstr(instruction.GetLocal{Index: result})

	c.code.Func.Locals = []module.LocalDeclaration{
		{Count: c.nextLocal, Type: types.I32},
	}

	return c.emitFunction("builtins", c.code)
}
// compileFuncs translates each policy function in declaration order and
// writes its body into the module. The first failure aborts the stage; the
// returned error is annotated with the name of the offending function.
func (c *Compiler) compileFuncs() error {
	funcs := c.policy.Funcs.Funcs
	for i := range funcs {
		err := c.compileFunc(funcs[i])
		if err != nil {
			return errors.Wrapf(err, "func %v", funcs[i].Name)
		}
	}
	return nil
}
// compilePlan generates the body of the "eval" function from the policy plan
// and writes it into the module.
//
// The generated prologue loads two i32 values from the evaluation context
// (offsets 0 and 4) into the locals for ir.Input and ir.Data, creates the
// result set and stores it into the context at offset 8. Each plan block is
// then emitted as its own WASM block, and the function returns the constant 0.
func (c *Compiler) compilePlan() error {
	c.code = &module.CodeEntry{}
	c.nextLocal = 0
	c.locals = map[ir.Local]uint32{}

	// Allocation order matters: context, result set, input, data.
	c.lctx = c.genLocal()
	c.lrs = c.genLocal()
	input := c.local(ir.Input)
	data := c.local(ir.Data)

	c.appendInstrs([]instruction.Instruction{
		// Initialize the input and data locals.
		instruction.GetLocal{Index: c.lctx},
		instruction.I32Load{Offset: 0, Align: 2},
		instruction.SetLocal{Index: input},
		instruction.GetLocal{Index: c.lctx},
		instruction.I32Load{Offset: 4, Align: 2},
		instruction.SetLocal{Index: data},
		// Initialize the result set.
		instruction.Call{Index: c.function(opaSet)},
		instruction.SetLocal{Index: c.lrs},
		instruction.GetLocal{Index: c.lctx},
		instruction.GetLocal{Index: c.lrs},
		instruction.I32Store{Offset: 8, Align: 2},
	})

	for i, block := range c.policy.Plan.Blocks {
		body, err := c.compileBlock(block)
		if err != nil {
			return errors.Wrapf(err, "block %d", i)
		}
		c.appendInstr(instruction.Block{Instrs: body})
	}

	c.appendInstr(instruction.I32Const{Value: int32(0)})

	c.code.Func.Locals = []module.LocalDeclaration{
		{Count: c.nextLocal, Type: types.I32},
	}

	return c.emitFunction("eval", c.code)
}
// compileFunc translates a single policy function and writes its body into
// the module under fn.Name. The function must already have been declared by
// initModule.
//
// Functions without parameters are rejected. Locals are allocated for the
// parameters first, in order, and then for the return value. Every block
// except the last is wrapped in its own WASM block; the instructions of the
// last block are appended directly to the function body.
func (c *Compiler) compileFunc(fn *ir.Func) error {
	if len(fn.Params) == 0 {
		return fmt.Errorf("illegal function: zero args")
	}

	c.nextLocal = 0
	c.locals = map[ir.Local]uint32{}

	// Allocate locals for the parameters and the return value up front so
	// that the parameters occupy the lowest local indices.
	for _, a := range fn.Params {
		_ = c.local(a)
	}
	_ = c.local(fn.Return)

	c.code = &module.CodeEntry{}

	for i := range fn.Blocks {
		instrs, err := c.compileBlock(fn.Blocks[i])
		if err != nil {
			return errors.Wrapf(err, "block %d", i)
		}
		if i < len(fn.Blocks)-1 {
			c.appendInstr(instruction.Block{Instrs: instrs})
		} else {
			c.appendInstrs(instrs)
		}
	}

	c.code.Func.Locals = []module.LocalDeclaration{
		{
			Count: c.nextLocal,
			Type:  types.I32,
		},
	}

	return c.emitFunction(fn.Name, c.code)
}
// compileBlock translates the statements of an IR block into a flat list of
// WASM instructions. The caller decides how to embed the list (typically by
// wrapping it in an instruction.Block).
//
// Statements that can fail (comparisons, type checks, lookups, calls) emit a
// conditional branch with label index 0, i.e. a branch out of the innermost
// block that encloses the emitted instructions. IR locals are mapped to WASM
// locals lazily through c.local, so the order in which statements are
// processed determines local indices.
//
// On error the returned instruction list is nil. An unsupported statement
// type yields an "illegal statement" error that includes the pretty-printed
// statement.
func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, error) {

	var instrs []instruction.Instruction

	for _, stmt := range block.Stmts {
		switch stmt := stmt.(type) {

		// Result set and control flow.
		case *ir.ResultSetAdd:
			instrs = append(instrs, instruction.GetLocal{Index: c.lrs})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Value)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaSetAdd)})
		case *ir.ReturnLocalStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.Return{})
		case *ir.BlockStmt:
			// Each nested block becomes its own WASM block.
			for i := range stmt.Blocks {
				nested, err := c.compileBlock(stmt.Blocks[i])
				if err != nil {
					return nil, err
				}
				instrs = append(instrs, instruction.Block{Instrs: nested})
			}
		case *ir.BreakStmt:
			instrs = append(instrs, instruction.Br{Index: stmt.Index})
		case *ir.CallStmt:
			if err := c.compileCallStmt(stmt, &instrs); err != nil {
				return nil, err
			}
		case *ir.WithStmt:
			if err := c.compileWithStmt(stmt, &instrs); err != nil {
				return nil, err
			}

		// Assignments.
		case *ir.AssignVarStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.AssignVarOnceStmt:
			// If the target is unset (zero) the inner block is left and the
			// assignment happens. If it is set and compares equal to the
			// source, the outer block is left without assigning. Otherwise
			// the program aborts with a conflict error.
			instrs = append(instrs, instruction.Block{
				Instrs: []instruction.Instruction{
					instruction.Block{
						Instrs: []instruction.Instruction{
							instruction.GetLocal{Index: c.local(stmt.Target)},
							instruction.I32Eqz{},
							instruction.BrIf{Index: 0},
							instruction.GetLocal{Index: c.local(stmt.Target)},
							instruction.GetLocal{Index: c.local(stmt.Source)},
							instruction.Call{Index: c.function(opaValueCompare)},
							instruction.I32Eqz{},
							instruction.BrIf{Index: 1},
							instruction.I32Const{Value: c.builtinStringAddr(errVarAssignConflict)},
							instruction.Call{Index: c.function(opaAbort)},
							instruction.Unreachable{},
						},
					},
					instruction.GetLocal{Index: c.local(stmt.Source)},
					instruction.SetLocal{Index: c.local(stmt.Target)},
				},
			})
		case *ir.AssignBooleanStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Target)})
			if stmt.Value {
				instrs = append(instrs, instruction.I32Const{Value: 1})
			} else {
				instrs = append(instrs, instruction.I32Const{Value: 0})
			}
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueBooleanSet)})
		case *ir.AssignIntStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Target)})
			instrs = append(instrs, instruction.I64Const{Value: stmt.Value})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueNumberSetInt)})

		// Iteration and negation.
		case *ir.ScanStmt:
			if err := c.compileScan(stmt, &instrs); err != nil {
				return nil, err
			}
		case *ir.NotStmt:
			if err := c.compileNot(stmt, &instrs); err != nil {
				return nil, err
			}

		// Lookups.
		case *ir.DotStmt:
			// Break out of the block if the lookup result is zero.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Key)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueGet)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Target)})
			instrs = append(instrs, instruction.I32Eqz{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.LenStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueLength)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaNumberSize)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})

		// Comparisons. Each one calls opa_value_compare and branches out of
		// the block when the comparison result contradicts the statement.
		case *ir.EqualStmt:
			// Break if the comparison result is non-zero.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.LessThanStmt:
			// Break if the comparison result is >= 0.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
			instrs = append(instrs, instruction.I32Const{Value: 0})
			instrs = append(instrs, instruction.I32GeS{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.LessThanEqualStmt:
			// Break if the comparison result is > 0.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
			instrs = append(instrs, instruction.I32Const{Value: 0})
			instrs = append(instrs, instruction.I32GtS{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.GreaterThanStmt:
			// Break if the comparison result is <= 0.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
			instrs = append(instrs, instruction.I32Const{Value: 0})
			instrs = append(instrs, instruction.I32LeS{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.GreaterThanEqualStmt:
			// Break if the comparison result is < 0.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
			instrs = append(instrs, instruction.I32Const{Value: 0})
			instrs = append(instrs, instruction.I32LtS{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.NotEqualStmt:
			// Break if the comparison result is zero.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
			instrs = append(instrs, instruction.I32Eqz{})
			instrs = append(instrs, instruction.BrIf{Index: 0})

		// Value construction.
		case *ir.MakeNullStmt:
			instrs = append(instrs, instruction.Call{Index: c.function(opaNull)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeBooleanStmt:
			instr := instruction.I32Const{}
			if stmt.Value {
				instr.Value = 1
			} else {
				instr.Value = 0
			}
			instrs = append(instrs, instr)
			instrs = append(instrs, instruction.Call{Index: c.function(opaBoolean)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeNumberFloatStmt:
			instrs = append(instrs, instruction.F64Const{Value: stmt.Value})
			instrs = append(instrs, instruction.Call{Index: c.function(opaNumberFloat)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeNumberIntStmt:
			instrs = append(instrs, instruction.I64Const{Value: stmt.Value})
			instrs = append(instrs, instruction.Call{Index: c.function(opaNumberInt)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeNumberRefStmt:
			// Pass the address and the byte length of the static string.
			instrs = append(instrs, instruction.I32Const{Value: c.stringAddr(stmt.Index)})
			instrs = append(instrs, instruction.I32Const{Value: int32(len(c.policy.Static.Strings[stmt.Index].Value))})
			instrs = append(instrs, instruction.Call{Index: c.function(opaNumberRef)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeStringStmt:
			instrs = append(instrs, instruction.I32Const{Value: c.stringAddr(stmt.Index)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaStringTerminated)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeArrayStmt:
			instrs = append(instrs, instruction.I32Const{Value: stmt.Capacity})
			instrs = append(instrs, instruction.Call{Index: c.function(opaArrayWithCap)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeObjectStmt:
			instrs = append(instrs, instruction.Call{Index: c.function(opaObject)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
		case *ir.MakeSetStmt:
			instrs = append(instrs, instruction.Call{Index: c.function(opaSet)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})

		// Type and definedness checks.
		case *ir.IsArrayStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueType)})
			instrs = append(instrs, instruction.I32Const{Value: opaTypeArray})
			instrs = append(instrs, instruction.I32Ne{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.IsObjectStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueType)})
			instrs = append(instrs, instruction.I32Const{Value: opaTypeObject})
			instrs = append(instrs, instruction.I32Ne{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.IsUndefinedStmt:
			// Break if the local is non-zero.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.I32Const{Value: 0})
			instrs = append(instrs, instruction.I32Ne{})
			instrs = append(instrs, instruction.BrIf{Index: 0})
		case *ir.IsDefinedStmt:
			// Break if the local is zero.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Source)})
			instrs = append(instrs, instruction.I32Eqz{})
			instrs = append(instrs, instruction.BrIf{Index: 0})

		// Collection updates.
		case *ir.ArrayAppendStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Array)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Value)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaArrayAppend)})
		case *ir.ObjectInsertStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Object)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Key)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Value)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaObjectInsert)})
		case *ir.ObjectInsertOnceStmt:
			// If the key is absent the inner block is left and the insert
			// happens. If it is present with an equal value, the outer block
			// is left without inserting. Otherwise the program aborts with a
			// conflict error.
			tmp := c.genLocal()
			instrs = append(instrs, instruction.Block{
				Instrs: []instruction.Instruction{
					instruction.Block{
						Instrs: []instruction.Instruction{
							instruction.GetLocal{Index: c.local(stmt.Object)},
							instruction.GetLocal{Index: c.local(stmt.Key)},
							instruction.Call{Index: c.function(opaValueGet)},
							instruction.SetLocal{Index: tmp},
							instruction.GetLocal{Index: tmp},
							instruction.I32Eqz{},
							instruction.BrIf{Index: 0},
							instruction.GetLocal{Index: tmp},
							instruction.GetLocal{Index: c.local(stmt.Value)},
							instruction.Call{Index: c.function(opaValueCompare)},
							instruction.I32Eqz{},
							instruction.BrIf{Index: 1},
							instruction.I32Const{Value: c.builtinStringAddr(errObjectInsertConflict)},
							instruction.Call{Index: c.function(opaAbort)},
							instruction.Unreachable{},
						},
					},
					instruction.GetLocal{Index: c.local(stmt.Object)},
					instruction.GetLocal{Index: c.local(stmt.Key)},
					instruction.GetLocal{Index: c.local(stmt.Value)},
					instruction.Call{Index: c.function(opaObjectInsert)},
				},
			})
		case *ir.ObjectMergeStmt:
			// A zero merge result aborts the program with a conflict error.
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.A)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.B)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaValueMerge)})
			instrs = append(instrs, instruction.SetLocal{Index: c.local(stmt.Target)})
			instrs = append(instrs, instruction.Block{
				Instrs: []instruction.Instruction{
					instruction.GetLocal{Index: c.local(stmt.Target)},
					instruction.BrIf{Index: 0},
					instruction.I32Const{Value: c.builtinStringAddr(errObjectMergeConflict)},
					instruction.Call{Index: c.function(opaAbort)},
					instruction.Unreachable{},
				},
			})
		case *ir.SetAddStmt:
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Set)})
			instrs = append(instrs, instruction.GetLocal{Index: c.local(stmt.Value)})
			instrs = append(instrs, instruction.Call{Index: c.function(opaSetAdd)})

		default:
			var buf bytes.Buffer
			ir.Pretty(&buf, stmt)
			return nil, fmt.Errorf("illegal statement: %v", buf.String())
		}
	}

	return instrs, nil
}
// compileScan appends the code for a scan statement to *result. The key
// local is reset to zero and the loop body produced by compileScanBlock is
// placed in a WASM loop nested inside a block, so that the body can either
// continue the loop or leave it. *result is left untouched when compiling
// the body fails.
func (c *Compiler) compileScan(scan *ir.ScanStmt, result *[]instruction.Instruction) error {
	// Resolve the key local before compiling the body so that local
	// allocation order is stable.
	key := c.local(scan.Key)

	body, err := c.compileScanBlock(scan)
	if err != nil {
		return err
	}

	*result = append(*result,
		instruction.I32Const{Value: 0},
		instruction.SetLocal{Index: key},
		instruction.Block{
			Instrs: []instruction.Instruction{
				instruction.Loop{Instrs: body},
			},
		},
	)

	return nil
}
// compileScanBlock produces the instructions that form the body of a scan
// loop: advance the iterator, leave the enclosing block when the iterator
// yields zero, load the value for the current key, run the nested block and
// finally branch back to the start of the loop.
func (c *Compiler) compileScanBlock(scan *ir.ScanStmt) ([]instruction.Instruction, error) {
	// Locals are resolved in the order source, key, value before the nested
	// block is compiled.
	source := c.local(scan.Source)
	key := c.local(scan.Key)
	value := c.local(scan.Value)

	body := []instruction.Instruction{
		// Execute iterator.
		instruction.GetLocal{Index: source},
		instruction.GetLocal{Index: key},
		instruction.Call{Index: c.function(opaValueIter)},
		// Check for emptiness.
		instruction.SetLocal{Index: key},
		instruction.GetLocal{Index: key},
		instruction.I32Eqz{},
		instruction.BrIf{Index: 1},
		// Load value.
		instruction.GetLocal{Index: source},
		instruction.GetLocal{Index: key},
		instruction.Call{Index: c.function(opaValueGet)},
		instruction.SetLocal{Index: value},
	}

	// Loop body.
	nested, err := c.compileBlock(scan.Block)
	if err != nil {
		return nil, err
	}
	body = append(body, nested...)

	// Continue.
	return append(body, instruction.Br{Index: 0}), nil
}
// compileNot appends the code for a negated block to *result. A fresh
// condition local starts out as 1 and is cleared only when execution reaches
// the end of the negated block. After the block, the enclosing block is left
// if the condition local was cleared. *result is left untouched when
// compiling the nested block fails.
func (c *Compiler) compileNot(not *ir.NotStmt, result *[]instruction.Instruction) error {
	flag := c.genLocal()

	body, err := c.compileBlock(not.Block)
	if err != nil {
		return err
	}

	// Reaching the end of the nested block clears the flag.
	body = append(body,
		instruction.I32Const{Value: 0},
		instruction.SetLocal{Index: flag},
	)

	*result = append(*result,
		// Set the flag before entering the nested block.
		instruction.I32Const{Value: 1},
		instruction.SetLocal{Index: flag},
		instruction.Block{Instrs: body},
		// Break out of the enclosing block if the flag was cleared.
		instruction.GetLocal{Index: flag},
		instruction.I32Eqz{},
		instruction.BrIf{Index: 0},
	)

	return nil
}
// compileWithStmt appends the code for a with statement to *result. The
// current value of the target local is saved, the target is replaced (with
// an empty path) or upserted along the path, the nested block is run, and
// the saved value is restored afterwards. A flag local starts out as 1 and
// is cleared only when the nested block runs to its end; if it is still set
// after the restore, the enclosing block is left. *result is left untouched
// when compiling the nested block fails.
func (c *Compiler) compileWithStmt(with *ir.WithStmt, result *[]instruction.Instruction) error {
	out := *result

	// Save the current value of the target.
	saved := c.genLocal()
	out = append(out,
		instruction.GetLocal{Index: c.local(with.Local)},
		instruction.SetLocal{Index: saved},
	)

	// Replace or upsert the target.
	if len(with.Path) > 0 {
		out = c.compileUpsert(with.Local, with.Path, with.Value, out)
	} else {
		out = append(out,
			instruction.GetLocal{Index: c.local(with.Value)},
			instruction.SetLocal{Index: c.local(with.Local)},
		)
	}

	undefined := c.genLocal()
	out = append(out,
		instruction.I32Const{Value: 1},
		instruction.SetLocal{Index: undefined},
	)

	body, err := c.compileBlock(with.Block)
	if err != nil {
		return err
	}
	body = append(body,
		instruction.I32Const{Value: 0},
		instruction.SetLocal{Index: undefined},
	)

	out = append(out,
		instruction.Block{Instrs: body},
		// Restore the saved value before propagating the outcome.
		instruction.GetLocal{Index: saved},
		instruction.SetLocal{Index: c.local(with.Local)},
		instruction.GetLocal{Index: undefined},
		instruction.BrIf{Index: 0},
	)

	*result = out
	return nil
}
// compileUpsert appends to instrs the code that sets the value at path
// inside local to value, and returns the extended instruction list.
//
// Each element of path is an index into the policy's static strings (it is
// resolved with c.stringAddr). The generated code shallow copies local (or
// starts from an empty object if local is zero) and reassigns local to the
// copy, walks the path while shallow copying each existing intermediate
// node, creates empty objects for missing intermediate nodes, and finally
// inserts value under the last path segment.
//
// NOTE(review): the code indexes lpath[len(path)-1] and is only called from
// compileWithStmt with a non-empty path; an empty path is not handled here.
func (c *Compiler) compileUpsert(local ir.Local, path []int, value ir.Local, instrs []instruction.Instruction) []instruction.Instruction {

	lcopy := c.genLocal() // holds copy of local
	instrs = append(instrs, instruction.GetLocal{Index: c.local(local)})
	instrs = append(instrs, instruction.SetLocal{Index: lcopy})

	// Shallow copy the local if defined otherwise initialize to an empty object.
	instrs = append(instrs, instruction.Block{
		Instrs: []instruction.Instruction{
			instruction.Block{Instrs: []instruction.Instruction{
				// An undefined (zero) local leaves the inner block and falls
				// through to the empty-object initialization below.
				instruction.GetLocal{Index: lcopy},
				instruction.I32Eqz{},
				instruction.BrIf{Index: 0},
				instruction.GetLocal{Index: lcopy},
				instruction.Call{Index: c.function(opaValueShallowCopy)},
				instruction.SetLocal{Index: lcopy},
				instruction.GetLocal{Index: lcopy},
				instruction.SetLocal{Index: c.local(local)},
				// Skip the empty-object initialization.
				instruction.Br{Index: 1},
			}},
			instruction.Call{Index: c.function(opaObject)},
			instruction.SetLocal{Index: lcopy},
			instruction.GetLocal{Index: lcopy},
			instruction.SetLocal{Index: c.local(local)},
		},
	})

	// Initialize the locals that specify the path of the upsert operation.
	lpath := make(map[int]uint32, len(path))

	for i := 0; i < len(path); i++ {
		lpath[i] = c.genLocal()
		instrs = append(instrs, instruction.I32Const{Value: c.stringAddr(path[i])})
		instrs = append(instrs, instruction.Call{Index: c.function(opaStringTerminated)})
		instrs = append(instrs, instruction.SetLocal{Index: lpath[i]})
	}

	// Generate a block that traverses the path of the upsert operation,
	// shallow copying values at each step as needed. Stop before the final
	// segment that will only be inserted.
	var inner []instruction.Instruction
	ltemp := c.genLocal()

	for i := 0; i < len(path)-1; i++ {

		// Lookup the next part of the path.
		inner = append(inner, instruction.GetLocal{Index: lcopy})
		inner = append(inner, instruction.GetLocal{Index: lpath[i]})
		inner = append(inner, instruction.Call{Index: c.function(opaValueGet)})
		inner = append(inner, instruction.SetLocal{Index: ltemp})

		// If the next node is missing, break. The branch depth i selects the
		// i-th enclosing block generated by the nesting loop further below.
		inner = append(inner, instruction.GetLocal{Index: ltemp})
		inner = append(inner, instruction.I32Eqz{})
		inner = append(inner, instruction.BrIf{Index: uint32(i)})

		// If the next node is not an object, generate a conflict error.
		inner = append(inner, instruction.Block{
			Instrs: []instruction.Instruction{
				instruction.GetLocal{Index: ltemp},
				instruction.Call{Index: c.function(opaValueType)},
				instruction.I32Const{Value: opaTypeObject},
				instruction.I32Eq{},
				instruction.BrIf{Index: 0},
				instruction.I32Const{Value: c.builtinStringAddr(errWithConflict)},
				instruction.Call{Index: c.function(opaAbort)},
			},
		})

		// Otherwise, shallow copy the next node and insert it into the copy
		// before continuing.
		inner = append(inner, instruction.GetLocal{Index: ltemp})
		inner = append(inner, instruction.Call{Index: c.function(opaValueShallowCopy)})
		inner = append(inner, instruction.SetLocal{Index: ltemp})
		inner = append(inner, instruction.GetLocal{Index: lcopy})
		inner = append(inner, instruction.GetLocal{Index: lpath[i]})
		inner = append(inner, instruction.GetLocal{Index: ltemp})
		inner = append(inner, instruction.Call{Index: c.function(opaObjectInsert)})
		inner = append(inner, instruction.GetLocal{Index: ltemp})
		inner = append(inner, instruction.SetLocal{Index: lcopy})
	}

	// The whole path existed: branch past all of the blocks that create
	// missing nodes.
	inner = append(inner, instruction.Br{Index: uint32(len(path) - 1)})

	// Generate blocks that handle missing nodes during traversal. Each
	// iteration wraps the code generated so far in a block and follows it
	// with code that creates an empty object for path segment i and descends
	// into it.
	var block []instruction.Instruction
	lval := c.genLocal()

	for i := 0; i < len(path)-1; i++ {
		block = append(block, instruction.Block{Instrs: inner})
		block = append(block, instruction.Call{Index: c.function(opaObject)})
		block = append(block, instruction.SetLocal{Index: lval})
		block = append(block, instruction.GetLocal{Index: lcopy})
		block = append(block, instruction.GetLocal{Index: lpath[i]})
		block = append(block, instruction.GetLocal{Index: lval})
		block = append(block, instruction.Call{Index: c.function(opaObjectInsert)})
		block = append(block, instruction.GetLocal{Index: lval})
		block = append(block, instruction.SetLocal{Index: lcopy})
		inner = block
		block = nil
	}

	// Finish by inserting the statement's value into the shallow copied node.
	instrs = append(instrs, instruction.Block{Instrs: inner})
	instrs = append(instrs, instruction.GetLocal{Index: lcopy})
	instrs = append(instrs, instruction.GetLocal{Index: lpath[len(path)-1]})
	instrs = append(instrs, instruction.GetLocal{Index: c.local(value)})
	instrs = append(instrs, instruction.Call{Index: c.function(opaObjectInsert)})

	return instrs
}
// compileCallStmt appends the code for a call statement to *result. Names
// found in the module's function index take precedence over built-in
// function names. A name that is neither is recorded as a compilation error
// in c.errors; the method itself still returns nil in that case.
func (c *Compiler) compileCallStmt(stmt *ir.CallStmt, result *[]instruction.Instruction) error {
	index, internal := c.funcs[stmt.Func]
	if internal {
		return c.compileInternalCall(stmt, index, result)
	}

	id, builtin := c.builtinFuncs[stmt.Func]
	if !builtin {
		c.errors = append(c.errors, fmt.Errorf("undefined function: %q", stmt.Func))
		return nil
	}

	return c.compileBuiltinCall(stmt, id, result)
}
// compileInternalCall appends a call to the function with the given index to
// *result. The arguments are pushed in order, the return value is stored in
// the statement's result local, and the enclosing block is left when the
// return value is zero. The method always returns nil.
func (c *Compiler) compileInternalCall(stmt *ir.CallStmt, index uint32, result *[]instruction.Instruction) error {
	out := *result

	for i := range stmt.Args {
		out = append(out, instruction.GetLocal{Index: c.local(stmt.Args[i])})
	}

	// The result local is resolved after the arguments.
	res := c.local(stmt.Result)

	*result = append(out,
		instruction.Call{Index: index},
		instruction.SetLocal{Index: res},
		instruction.GetLocal{Index: res},
		instruction.I32Eqz{},
		instruction.BrIf{Index: 0},
	)

	return nil
}
// compileBuiltinCall appends a call to the built-in function with identifier
// id to *result. The call goes through the dispatch function selected by the
// argument count (see builtins) and passes the identifier, a zero context
// value and the arguments. The return value is stored in the statement's
// result local and the enclosing block is left when it is zero.
//
// A call with more arguments than there are dispatch functions is recorded
// as a compilation error in c.errors. The method always returns nil.
func (c *Compiler) compileBuiltinCall(stmt *ir.CallStmt, id int32, result *[]instruction.Instruction) error {
	argc := len(stmt.Args)
	if argc >= len(builtins) {
		c.errors = append(c.errors, fmt.Errorf("too many built-in call arguments: %q", stmt.Func))
		return nil
	}

	out := append(*result,
		instruction.I32Const{Value: id},
		instruction.I32Const{Value: 0}, // unused context parameter
	)

	for i := range stmt.Args {
		out = append(out, instruction.GetLocal{Index: c.local(stmt.Args[i])})
	}

	// The result local is resolved after the arguments.
	res := c.local(stmt.Result)

	*result = append(out,
		instruction.Call{Index: c.funcs[builtins[argc]]},
		instruction.SetLocal{Index: res},
		instruction.GetLocal{Index: res},
		instruction.I32Eqz{},
		instruction.BrIf{Index: 0},
	)

	return nil
}
// emitFunctionDecl declares a function called name with signature tpe. It
// registers the signature, appends an empty code segment to be filled in
// later by emitFunction and records the function index under name. The
// index is the position of the declaration plus the number of function
// imports. When export is true the function is also added to the module's
// exports.
func (c *Compiler) emitFunctionDecl(name string, tpe module.FunctionType, export bool) {
	sig := c.emitFunctionType(tpe)

	c.module.Function.TypeIndices = append(c.module.Function.TypeIndices, sig)
	c.module.Code.Segments = append(c.module.Code.Segments, module.RawCodeSegment{})

	position := len(c.module.Function.TypeIndices) - 1
	c.funcs[name] = uint32(position + c.functionImportCount())

	if !export {
		return
	}

	entry := module.Export{
		Name: name,
		Descriptor: module.ExportDescriptor{
			Type:  module.FunctionExportType,
			Index: c.funcs[name],
		},
	}
	c.module.Export.Exports = append(c.module.Export.Exports, entry)
}
// emitFunctionType returns the index of tpe in the module's list of function
// types. An existing equal entry is reused; otherwise tpe is appended and the
// index of the new entry is returned.
func (c *Compiler) emitFunctionType(tpe module.FunctionType) uint32 {
	known := c.module.Type.Functions

	for i := range known {
		if tpe.Equal(known[i]) {
			return uint32(i)
		}
	}

	// Not found: the new entry lands right after the existing ones.
	c.module.Type.Functions = append(known, tpe)
	return uint32(len(known))
}
// emitFunction encodes entry and stores the bytes as the body of the
// function called name. The function must have been declared beforehand with
// emitFunctionDecl, which reserves the code segment that is filled in here.
//
// An error is returned if the entry cannot be encoded, if name is not a
// known function, or if its function index does not correspond to a code
// segment of the module (for example because name refers to an imported
// function). Without these checks an unknown name would resolve to index 0
// and either underflow the segment index or overwrite an unrelated segment.
func (c *Compiler) emitFunction(name string, entry *module.CodeEntry) error {
	funcidx, ok := c.funcs[name]
	if !ok {
		return fmt.Errorf("undeclared function: %q", name)
	}

	var buf bytes.Buffer
	if err := encoding.WriteCodeEntry(&buf, entry); err != nil {
		return err
	}

	// Code segments are numbered from the first non-imported function.
	imports := uint32(c.functionImportCount())
	if funcidx < imports || int(funcidx-imports) >= len(c.module.Code.Segments) {
		return fmt.Errorf("no code segment for function: %q", name)
	}

	c.module.Code.Segments[funcidx-imports].Code = buf.Bytes()
	return nil
}
// functionImportCount reports how many of the module's imports are function
// imports. The imports are re-scanned on every call.
func (c *Compiler) functionImportCount() int {
	n := 0
	for _, imp := range c.module.Import.Imports {
		if imp.Descriptor.Kind() != module.FunctionImportType {
			continue
		}
		n++
	}
	return n
}
// stringAddr returns the address recorded by compileStrings for the static
// string with the given index, converted to the int32 used by I32Const.
func (c *Compiler) stringAddr(index int) int32 {
	addr := c.stringAddrs[index]
	return int32(addr)
}
// builtinStringAddr returns the address recorded by compileStrings for the
// error message with the given code, converted to the int32 used by
// I32Const. An unknown code yields 0 (the map's zero value).
func (c *Compiler) builtinStringAddr(code int) int32 {
	addr := c.builtinStringAddrs[code]
	return int32(addr)
}
// local returns the WASM local index assigned to the IR local l. The first
// request for a given l allocates the next free index and remembers it;
// later requests return the same index.
func (c *Compiler) local(l ir.Local) uint32 {
	if idx, known := c.locals[l]; known {
		return idx
	}

	idx := c.nextLocal
	c.nextLocal++
	c.locals[l] = idx

	return idx
}
// genLocal allocates and returns the next free WASM local index. The index
// is not associated with any IR local.
func (c *Compiler) genLocal() uint32 {
	next := c.nextLocal
	c.nextLocal = next + 1
	return next
}
// function returns the function index registered under name. An unknown
// name yields 0 (the map's zero value).
func (c *Compiler) function(name string) uint32 {
	idx := c.funcs[name]
	return idx
}
// appendInstr adds instr to the end of the body of the function currently
// being generated (c.code).
func (c *Compiler) appendInstr(instr instruction.Instruction) {
	body := c.code.Func.Expr.Instrs
	c.code.Func.Expr.Instrs = append(body, instr)
}
// appendInstrs adds all of instrs, in order, to the end of the body of the
// function currently being generated (c.code).
func (c *Compiler) appendInstrs(instrs []instruction.Instruction) {
	for i := range instrs {
		c.code.Func.Expr.Instrs = append(c.code.Func.Expr.Instrs, instrs[i])
	}
}