Files
2025-05-14 14:44:13 +08:00

966 lines
19 KiB
Go

// Copyright 2018 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package encoding
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/open-policy-agent/opa/internal/leb128"
"github.com/open-policy-agent/opa/internal/wasm/constant"
"github.com/open-policy-agent/opa/internal/wasm/instruction"
"github.com/open-policy-agent/opa/internal/wasm/module"
"github.com/open-policy-agent/opa/internal/wasm/opcode"
"github.com/open-policy-agent/opa/internal/wasm/types"
)
// ReadModule decodes a binary-encoded WASM module from r. If decoding fails,
// the returned error is prefixed with the hexadecimal byte offset reached in r
// at the time of the failure.
func ReadModule(r io.Reader) (*module.Module, error) {
	counted := &reader{r: r}
	m, err := readModule(counted)
	if err == nil {
		return m, nil
	}
	return nil, fmt.Errorf("offset 0x%x: %w", counted.n, err)
}
// ReadCodeEntry decodes a single binary-encoded WASM code entry (a function
// body) from r. If decoding fails, the returned error is prefixed with the
// hexadecimal byte offset reached in r at the time of the failure.
func ReadCodeEntry(r io.Reader) (*module.CodeEntry, error) {
	counted := &reader{r: r}
	entry, err := readCodeEntry(counted)
	if err == nil {
		return entry, nil
	}
	return nil, fmt.Errorf("offset 0x%x: %w", counted.n, err)
}
// CodeEntries decodes every raw code segment in m's code section into a
// CodeEntry. The result has one entry per segment, in segment order. Decoding
// stops at the first segment that fails, and the returned error (which wraps
// the ReadCodeEntry error) names that segment's index.
func CodeEntries(m *module.Module) ([]*module.CodeEntry, error) {
	entries := make([]*module.CodeEntry, len(m.Code.Segments))
	for i, s := range m.Code.Segments {
		// bytes.Reader is the read-only view; the segment bytes are never modified.
		entry, err := ReadCodeEntry(bytes.NewReader(s.Code))
		if err != nil {
			return nil, fmt.Errorf("code segment %d: %w", i, err)
		}
		entries[i] = entry
	}
	return entries, nil
}
// reader wraps an io.Reader and keeps a running total (n) of the bytes that
// have been read through it. ReadModule and ReadCodeEntry use it to report the
// offset at which decoding failed.
type reader struct {
	r io.Reader
	n int
}

// Read forwards to the wrapped reader and adds the byte count it reports to
// the running total, even when the wrapped reader also returns an error.
func (cr *reader) Read(p []byte) (n int, err error) {
	n, err = cr.r.Read(p)
	cr.n += n
	return n, err
}
// readModule checks the module preamble (magic number, then version) and
// decodes every section that follows. readSections reports the end of input as
// a bare io.EOF, which is treated as successful completion. Any other error
// aborts decoding.
func readModule(r io.Reader) (*module.Module, error) {
	for _, check := range []func(io.Reader) error{readMagic, readVersion} {
		if err := check(r); err != nil {
			return nil, err
		}
	}
	m := new(module.Module)
	err := readSections(r, m)
	if err == nil || err == io.EOF {
		return m, nil
	}
	return nil, err
}
// readCodeEntry decodes one function body: its local variable declarations,
// followed by the instruction sequence read by readExpr. On any error it
// returns a nil entry instead of a partially decoded one.
func readCodeEntry(r io.Reader) (*module.CodeEntry, error) {
	var entry module.CodeEntry
	if err := readLocals(r, &entry.Func.Locals); err != nil {
		return nil, fmt.Errorf("local declarations: %w", err)
	}
	if err := readExpr(r, &entry.Func.Expr); err != nil {
		return nil, err
	}
	return &entry, nil
}
// readMagic reads a little-endian uint32 from r and checks that it equals
// constant.Magic.
func readMagic(r io.Reader) error {
	var magic uint32
	err := binary.Read(r, binary.LittleEndian, &magic)
	switch {
	case err != nil:
		return err
	case magic != constant.Magic:
		return errors.New("illegal magic value")
	default:
		return nil
	}
}
// readVersion reads a little-endian uint32 from r and checks that it equals
// constant.Version.
func readVersion(r io.Reader) error {
	var version uint32
	err := binary.Read(r, binary.LittleEndian, &version)
	switch {
	case err != nil:
		return err
	case version != constant.Version:
		return errors.New("illegal wasm version")
	default:
		return nil
	}
}
// readSections decodes sections from r into m until the input runs out.
//
// Each section is an id byte followed by a length-prefixed payload. The payload
// is decoded from its own in-memory reader, so a section decoder cannot consume
// bytes belonging to the next section.
//
// The loop only ends with an error. A bare io.EOF means r was exhausted exactly
// where the next section id would start; readModule treats that as success. If
// the input ends after an id byte but before the section's length or payload is
// complete, io.ErrUnexpectedEOF is returned instead, so a truncated module is
// not mistaken for a complete one.
func readSections(r io.Reader, m *module.Module) error {
	for {
		id, err := readByte(r)
		if err != nil {
			return err
		}
		var payload []byte
		if err := readByteVector(r, &payload); err != nil {
			if err == io.EOF {
				// A section header has already started, so running out of
				// input here is truncation, not a clean end of module.
				return io.ErrUnexpectedEOF
			}
			return err
		}
		bufr := bytes.NewReader(payload)
		switch id {
		case constant.StartSectionID:
			if err := readStartSection(bufr, &m.Start); err != nil {
				return fmt.Errorf("start section: %w", err)
			}
		case constant.CustomSectionID:
			var name string
			if err := readByteVectorString(bufr, &name); err != nil {
				return fmt.Errorf("read custom section type: %w", err)
			}
			if name == "name" {
				if err := readCustomNameSections(bufr, &m.Names); err != nil {
					return fmt.Errorf("custom 'name' section: %w", err)
				}
			} else {
				if err := readCustomSection(bufr, name, &m.Customs); err != nil {
					return fmt.Errorf("custom section: %w", err)
				}
			}
		case constant.TypeSectionID:
			if err := readTypeSection(bufr, &m.Type); err != nil {
				return fmt.Errorf("type section: %w", err)
			}
		case constant.ImportSectionID:
			if err := readImportSection(bufr, &m.Import); err != nil {
				return fmt.Errorf("import section: %w", err)
			}
		case constant.TableSectionID:
			if err := readTableSection(bufr, &m.Table); err != nil {
				return fmt.Errorf("table section: %w", err)
			}
		case constant.MemorySectionID:
			if err := readMemorySection(bufr, &m.Memory); err != nil {
				return fmt.Errorf("memory section: %w", err)
			}
		case constant.GlobalSectionID:
			if err := readGlobalSection(bufr, &m.Global); err != nil {
				return fmt.Errorf("global section: %w", err)
			}
		case constant.FunctionSectionID:
			if err := readFunctionSection(bufr, &m.Function); err != nil {
				return fmt.Errorf("function section: %w", err)
			}
		case constant.ExportSectionID:
			if err := readExportSection(bufr, &m.Export); err != nil {
				return fmt.Errorf("export section: %w", err)
			}
		case constant.ElementSectionID:
			if err := readElementSection(bufr, &m.Element); err != nil {
				return fmt.Errorf("element section: %w", err)
			}
		case constant.DataSectionID:
			if err := readDataSection(bufr, &m.Data); err != nil {
				return fmt.Errorf("data section: %w", err)
			}
		case constant.CodeSectionID:
			if err := readRawCodeSection(bufr, &m.Code); err != nil {
				return fmt.Errorf("code section: %w", err)
			}
		default:
			return errors.New("illegal section id")
		}
	}
}
// readCustomSection appends a custom section called name to *s. All bytes
// remaining in r become the section's Data.
func readCustomSection(r io.Reader, name string, s *[]module.CustomSection) error {
	data, err := io.ReadAll(r)
	if err != nil {
		return err
	}
	section := module.CustomSection{Name: name, Data: data}
	*s = append(*s, section)
	return nil
}
// readCustomNameSections decodes the subsections of the custom "name" section
// into s. Each subsection is an id byte followed by a length-prefixed payload.
// The module, function and local name subsections are decoded; any other id is
// skipped. Running out of input where the next id byte would be ends decoding
// successfully.
func readCustomNameSections(r io.Reader, s *module.NameSection) error {
	for {
		id, err := readByte(r)
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		var payload []byte
		if err := readByteVector(r, &payload); err != nil {
			return err
		}
		var decode func(io.Reader, *module.NameSection) error
		switch id {
		case constant.NameSectionModuleType:
			decode = readNameSectionModule
		case constant.NameSectionFunctionsType:
			decode = readNameSectionFunctions
		case constant.NameSectionLocalsType:
			decode = readNameSectionLocals
		default:
			// Unknown subsection: its payload has already been consumed.
			continue
		}
		if err := decode(bytes.NewReader(payload), s); err != nil {
			return err
		}
	}
}
// readNameSectionModule decodes the module-name subsection, a single
// length-prefixed string, into s.Module. On error s.Module is left unchanged.
func readNameSectionModule(r io.Reader, s *module.NameSection) error {
	var name string
	if err := readByteVectorString(r, &name); err != nil {
		return err
	}
	s.Module = name
	return nil
}
// readNameSectionFunctions decodes the function-names subsection, a name map
// from function index to name. On success it replaces s.Functions with the
// result; on error s is left untouched.
func readNameSectionFunctions(r io.Reader, s *module.NameSection) error {
	var err error
	var names []module.NameMap
	if names, err = readNameMap(r); err == nil {
		s.Functions = names
	}
	return err
}
// readNameMap decodes a count-prefixed vector of (index, name) pairs and
// returns them in input order. The result is non-nil, though possibly empty.
//
// The count is an untrusted uint32 (up to 2^32-1). Only a bounded capacity is
// reserved up front, and the slice grows as entries are actually decoded. A
// bogus count therefore fails on missing input instead of forcing a huge
// allocation.
func readNameMap(r io.Reader) ([]module.NameMap, error) {
	n, err := leb128.ReadVarUint32(r)
	if err != nil {
		return nil, err
	}
	const maxPrealloc = 1024
	nm := make([]module.NameMap, 0, min(n, maxPrealloc))
	for range n {
		id, err := leb128.ReadVarUint32(r)
		if err != nil {
			return nil, err
		}
		var name string
		if err := readByteVectorString(r, &name); err != nil {
			return nil, err
		}
		nm = append(nm, module.NameMap{Index: id, Name: name})
	}
	return nm, nil
}
// readNameSectionLocals decodes the local-names subsection: a count-prefixed
// vector of (function index, name map) pairs. Every local name is appended to
// s.Locals as a flat entry tagged with its function's index.
func readNameSectionLocals(r io.Reader, s *module.NameSection) error {
	funcs, err := leb128.ReadVarUint32(r) // length of vec(indirectnameassoc)
	if err != nil {
		return err
	}
	for i := uint32(0); i < funcs; i++ {
		funcIdx, err := leb128.ReadVarUint32(r)
		if err != nil {
			return err
		}
		locals, err := readNameMap(r)
		if err != nil {
			return err
		}
		for _, local := range locals {
			entry := module.LocalNameMap{FuncIndex: funcIdx}
			entry.NameMap = module.NameMap{Index: local.Index, Name: local.Name}
			s.Locals = append(s.Locals, entry)
		}
	}
	return nil
}
// readStartSection decodes the start section, which holds a single function
// index, and stores a pointer to it in s.FuncIndex.
func readStartSection(r io.Reader, s *module.StartSection) error {
	var idx uint32
	if err := readVarUint32(r, &idx); err != nil {
		return err
	}
	s.FuncIndex = &idx
	return nil
}
// readTypeSection decodes a count-prefixed vector of function signatures and
// appends them to s.Functions in input order.
func readTypeSection(r io.Reader, s *module.TypeSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for i := uint32(0); i < count; i++ {
		var sig module.FunctionType
		if err := readFunctionType(r, &sig); err != nil {
			return err
		}
		s.Functions = append(s.Functions, sig)
	}
	return nil
}
// readImportSection decodes a count-prefixed vector of import entries and
// appends them to s.Imports in input order.
func readImportSection(r io.Reader, s *module.ImportSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for remaining := count; remaining > 0; remaining-- {
		var entry module.Import
		if err := readImport(r, &entry); err != nil {
			return err
		}
		s.Imports = append(s.Imports, entry)
	}
	return nil
}
// readTableSection decodes a count-prefixed vector of tables. Each table is an
// element-type byte, which must be constant.ElementTypeAnyFunc, followed by
// its limits. Tables are appended to s.Tables in input order.
func readTableSection(r io.Reader, s *module.TableSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for i := uint32(0); i < count; i++ {
		elem, err := readByte(r)
		if err != nil {
			return err
		}
		if elem != constant.ElementTypeAnyFunc {
			return errors.New("illegal element type")
		}
		table := module.Table{Type: types.Anyfunc}
		if err := readLimits(r, &table.Lim); err != nil {
			return err
		}
		s.Tables = append(s.Tables, table)
	}
	return nil
}
// readMemorySection decodes a count-prefixed vector of memory limits and
// appends one module.Memory per entry to s.Memories.
func readMemorySection(r io.Reader, s *module.MemorySection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for i := uint32(0); i < count; i++ {
		var mem module.Memory
		if err := readLimits(r, &mem.Lim); err != nil {
			return err
		}
		s.Memories = append(s.Memories, mem)
	}
	return nil
}
// readGlobalSection decodes a count-prefixed vector of global definitions and
// appends them to s.Globals in input order.
func readGlobalSection(r io.Reader, s *module.GlobalSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for remaining := count; remaining > 0; remaining-- {
		var g module.Global
		if err := readGlobal(r, &g); err != nil {
			return err
		}
		s.Globals = append(s.Globals, g)
	}
	return nil
}
// readFunctionSection decodes the function section, a vector of type indices,
// into s.TypeIndices. On error s is left untouched.
func readFunctionSection(r io.Reader, s *module.FunctionSection) error {
	var indices []uint32
	if err := readVarUint32Vector(r, &indices); err != nil {
		return err
	}
	s.TypeIndices = indices
	return nil
}
// readExportSection decodes a count-prefixed vector of export entries and
// appends them to s.Exports in input order.
func readExportSection(r io.Reader, s *module.ExportSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for remaining := count; remaining > 0; remaining-- {
		var exp module.Export
		if err := readExport(r, &exp); err != nil {
			return err
		}
		s.Exports = append(s.Exports, exp)
	}
	return nil
}
// readElementSection decodes a count-prefixed vector of element segments and
// appends them to s.Segments in input order.
func readElementSection(r io.Reader, s *module.ElementSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for i := uint32(0); i < count; i++ {
		var seg module.ElementSegment
		if err := readElementSegment(r, &seg); err != nil {
			return err
		}
		s.Segments = append(s.Segments, seg)
	}
	return nil
}
// readDataSection decodes a count-prefixed vector of data segments and
// appends them to s.Segments in input order.
func readDataSection(r io.Reader, s *module.DataSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for i := uint32(0); i < count; i++ {
		var seg module.DataSegment
		if err := readDataSegment(r, &seg); err != nil {
			return err
		}
		s.Segments = append(s.Segments, seg)
	}
	return nil
}
// readRawCodeSection decodes a count-prefixed vector of function bodies. Each
// body is kept as undecoded bytes and appended to s.Segments in input order.
func readRawCodeSection(r io.Reader, s *module.RawCodeSection) error {
	count, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	for remaining := count; remaining > 0; remaining-- {
		var seg module.RawCodeSegment
		if err := readRawCodeSegment(r, &seg); err != nil {
			return err
		}
		s.Segments = append(s.Segments, seg)
	}
	return nil
}
// readFunctionType decodes a function signature. It expects the marker byte
// constant.FunctionTypeID, followed by the parameter value-type vector and
// then the result value-type vector.
func readFunctionType(r io.Reader, ftype *module.FunctionType) error {
	marker, err := readByte(r)
	if err != nil {
		return err
	}
	if marker != constant.FunctionTypeID {
		return fmt.Errorf("illegal function type id 0x%x", marker)
	}
	if err := readValueTypeVector(r, &ftype.Params); err != nil {
		return err
	}
	return readValueTypeVector(r, &ftype.Results)
}
// readGlobal decodes one global definition: its value type, then a
// mutability flag (0 = immutable, 1 = mutable, anything else is an error),
// then its constant initializer expression.
func readGlobal(r io.Reader, global *module.Global) error {
	if err := readValueType(r, &global.Type); err != nil {
		return err
	}
	flag, err := readByte(r)
	if err != nil {
		return err
	}
	switch flag {
	case 0:
		// immutable: leave Mutable as-is
	case 1:
		global.Mutable = true
	default:
		return errors.New("illegal mutability flag")
	}
	return readConstantExpr(r, &global.Init)
}
// readImport decodes one import entry: the module name, the field name, a
// descriptor kind byte, and a kind-specific descriptor that is stored in
// imp.Descriptor:
//   - function: a type index (module.FunctionImport)
//   - table:    an element type, which must be anyfunc, plus limits (module.TableImport)
//   - memory:   limits (module.MemoryImport)
//   - global:   a value type plus a mutability flag of 0 or 1 (module.GlobalImport)
//
// Any other kind byte is rejected.
func readImport(r io.Reader, imp *module.Import) error {
	if err := readByteVectorString(r, &imp.Module); err != nil {
		return err
	}
	if err := readByteVectorString(r, &imp.Name); err != nil {
		return err
	}
	b, err := readByte(r)
	if err != nil {
		return err
	}
	switch b {
	case constant.ImportDescType:
		index, err := leb128.ReadVarUint32(r)
		if err != nil {
			return err
		}
		imp.Descriptor = module.FunctionImport{Func: index}
	case constant.ImportDescTable:
		elem, err := readByte(r)
		if err != nil {
			return err
		}
		if elem != constant.ElementTypeAnyFunc {
			return errors.New("illegal element type")
		}
		desc := module.TableImport{Type: types.Anyfunc}
		if err := readLimits(r, &desc.Lim); err != nil {
			return err
		}
		imp.Descriptor = desc
	case constant.ImportDescMem:
		desc := module.MemoryImport{}
		if err := readLimits(r, &desc.Mem.Lim); err != nil {
			return err
		}
		imp.Descriptor = desc
	case constant.ImportDescGlobal:
		desc := module.GlobalImport{}
		if err := readValueType(r, &desc.Type); err != nil {
			return err
		}
		flag, err := readByte(r)
		if err != nil {
			return err
		}
		switch flag {
		case 0:
		case 1:
			desc.Mutable = true
		default:
			return errors.New("illegal mutability flag")
		}
		// Store the decoded global descriptor like every other import kind.
		imp.Descriptor = desc
	default:
		return errors.New("illegal import descriptor type")
	}
	return nil
}
// readExport decodes one export entry: its name, a descriptor kind byte
// (function, table, memory or global), and the index of the exported item.
// The index is stored in exp.Descriptor.Index even if decoding it failed.
func readExport(r io.Reader, exp *module.Export) error {
	if err := readByteVectorString(r, &exp.Name); err != nil {
		return err
	}
	kind, err := readByte(r)
	if err != nil {
		return err
	}
	switch kind {
	case constant.ExportDescType:
		exp.Descriptor.Type = module.FunctionExportType
	case constant.ExportDescTable:
		exp.Descriptor.Type = module.TableExportType
	case constant.ExportDescMem:
		exp.Descriptor.Type = module.MemoryExportType
	case constant.ExportDescGlobal:
		exp.Descriptor.Type = module.GlobalExportType
	default:
		return errors.New("illegal export descriptor type")
	}
	idx, err := leb128.ReadVarUint32(r)
	exp.Descriptor.Index = idx
	return err
}
// readElementSegment decodes an element segment in three steps: the table
// index, the constant offset expression, and the vector of function indices.
// It stops at the first step that fails.
func readElementSegment(r io.Reader, seg *module.ElementSegment) error {
	err := readVarUint32(r, &seg.Index)
	if err == nil {
		err = readConstantExpr(r, &seg.Offset)
	}
	if err == nil {
		err = readVarUint32Vector(r, &seg.Indices)
	}
	return err
}
// readDataSegment decodes a data segment in three steps: the memory index,
// the constant offset expression, and the length-prefixed initial bytes. It
// stops at the first step that fails.
func readDataSegment(r io.Reader, seg *module.DataSegment) error {
	err := readVarUint32(r, &seg.Index)
	if err == nil {
		err = readConstantExpr(r, &seg.Offset)
	}
	if err == nil {
		err = readByteVector(r, &seg.Init)
	}
	return err
}
// readRawCodeSegment reads one length-prefixed function body into seg.Code
// without decoding it. ReadCodeEntry performs the decoding later.
func readRawCodeSegment(r io.Reader, seg *module.RawCodeSegment) error {
	var code []byte
	if err := readByteVector(r, &code); err != nil {
		return err
	}
	seg.Code = code
	return nil
}
// readConstantExpr decodes a constant initializer expression: zero or more
// i32.const / i64.const instructions terminated by `end`. Any other opcode is
// rejected. expr.Instrs is written (as a non-nil slice) only once `end` has
// been read.
func readConstantExpr(r io.Reader, expr *module.Expr) error {
	instrs := []instruction.Instruction{}
	for {
		b, err := readByte(r)
		if err != nil {
			return err
		}
		op := opcode.Opcode(b)
		if op == opcode.End {
			expr.Instrs = instrs
			return nil
		}
		var instr instruction.Instruction
		switch op {
		case opcode.I32Const:
			v, err := leb128.ReadVarInt32(r)
			if err != nil {
				return err
			}
			instr = instruction.I32Const{Value: v}
		case opcode.I64Const:
			v, err := leb128.ReadVarInt64(r)
			if err != nil {
				return err
			}
			instr = instruction.I64Const{Value: v}
		default:
			return fmt.Errorf("illegal constant expr opcode 0x%x", b)
		}
		instrs = append(instrs, instr)
	}
}
// readExpr decodes an instruction sequence terminated by `end` into
// expr.Instrs. readInstructions uses panicking helpers (mustReadByte and the
// leb128.Must* readers), so panics are recovered here. An error-valued panic
// is returned as the error; any other panic value becomes "unknown panic".
func readExpr(r io.Reader, expr *module.Expr) (err error) {
	defer func() {
		rec := recover()
		if rec == nil {
			return
		}
		if e, ok := rec.(error); ok {
			err = e
		} else {
			err = errors.New("unknown panic")
		}
	}()
	return readInstructions(r, &expr.Instrs)
}
// readInstructions decodes instructions from r until the `end` opcode that
// closes the current sequence. It then stores them in *instrs as a non-nil,
// possibly empty slice; *instrs is not written on error.
//
// Nested block and loop bodies are decoded recursively, and each consumes its
// own `end`. Immediates are read with panicking helpers (leb128.Must*,
// mustReadByte), so callers must recover, as readExpr does.
func readInstructions(r io.Reader, instrs *[]instruction.Instruction) error {
	ret := make([]instruction.Instruction, 0)
	for {
		b, err := readByte(r)
		if err != nil {
			return err
		}
		switch opcode.Opcode(b) {
		case opcode.I32Const:
			ret = append(ret, instruction.I32Const{Value: leb128.MustReadVarInt32(r)})
		case opcode.I64Const:
			ret = append(ret, instruction.I64Const{Value: leb128.MustReadVarInt64(r)})
		case opcode.I32Eqz:
			ret = append(ret, instruction.I32Eqz{})
		case opcode.GetLocal:
			ret = append(ret, instruction.GetLocal{Index: leb128.MustReadVarUint32(r)})
		case opcode.SetLocal:
			ret = append(ret, instruction.SetLocal{Index: leb128.MustReadVarUint32(r)})
		case opcode.Call:
			ret = append(ret, instruction.Call{Index: leb128.MustReadVarUint32(r)})
		case opcode.CallIndirect:
			ret = append(ret, instruction.CallIndirect{
				Index:    leb128.MustReadVarUint32(r),
				Reserved: mustReadByte(r),
			})
		case opcode.BrIf:
			ret = append(ret, instruction.BrIf{Index: leb128.MustReadVarUint32(r)})
		case opcode.Return:
			ret = append(ret, instruction.Return{})
		case opcode.Block:
			block := instruction.Block{}
			// block.Type is a *types.ValueType that stays nil for the empty
			// block type. readBlockType allocates it only when a result type
			// is present, instead of writing through the nil field.
			if block.Type, err = readBlockType(r); err != nil {
				return err
			}
			if err := readInstructions(r, &block.Instrs); err != nil {
				return err
			}
			ret = append(ret, block)
		case opcode.Loop:
			loop := instruction.Loop{}
			if loop.Type, err = readBlockType(r); err != nil {
				return err
			}
			if err := readInstructions(r, &loop.Instrs); err != nil {
				return err
			}
			ret = append(ret, loop)
		case opcode.End:
			*instrs = ret
			return nil
		default:
			return fmt.Errorf("illegal opcode 0x%x", b)
		}
	}
}

// readBlockType decodes the result-type immediate of a block or loop. It
// returns nil for constant.BlockTypeEmpty; otherwise it returns a newly
// allocated value type.
func readBlockType(r io.Reader) (*types.ValueType, error) {
	b, err := readByte(r)
	if err != nil {
		return nil, err
	}
	if b == constant.BlockTypeEmpty {
		return nil, nil
	}
	vt := new(types.ValueType)
	// Replay the already-consumed byte through readBlockValueType so the
	// byte-to-type mapping and its error message live in one place.
	if err := readBlockValueType(bytes.NewReader([]byte{b}), vt); err != nil {
		return nil, err
	}
	return vt, nil
}
// mustReadByte is the panicking form of readByte. It panics with readByte's
// error, which readExpr recovers and returns as an ordinary error.
func mustReadByte(r io.Reader) byte {
	b, err := readByte(r)
	if err == nil {
		return b
	}
	panic(err)
}
// readLimits decodes a limits descriptor: a flag byte, then the minimum, then
// a maximum only when the flag is 1. Flag 0 means there is no maximum, and any
// other flag is an error. l.Min is stored before the flag is validated.
func readLimits(r io.Reader, l *module.Limit) error {
	flag, err := readByte(r)
	if err != nil {
		return err
	}
	lo, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	l.Min = lo
	switch flag {
	case 0:
		return nil
	case 1:
		hi, err := leb128.ReadVarUint32(r)
		if err != nil {
			return err
		}
		l.Max = &hi
		return nil
	default:
		return errors.New("illegal limit flag")
	}
}
// readLocals decodes a function body's local declarations: a count-prefixed
// vector of (count, value type) pairs. On success it stores them in *locals as
// a non-nil, possibly empty slice; *locals is not written on error.
//
// The vector length is an untrusted uint32. Only a bounded capacity is
// reserved up front, and the slice grows as declarations are actually
// decoded, so a bogus length cannot force a huge allocation.
func readLocals(r io.Reader, locals *[]module.LocalDeclaration) error {
	n, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	const maxPrealloc = 1024
	ret := make([]module.LocalDeclaration, 0, min(n, maxPrealloc))
	for range n {
		var decl module.LocalDeclaration
		if err := readVarUint32(r, &decl.Count); err != nil {
			return err
		}
		if err := readValueType(r, &decl.Type); err != nil {
			return err
		}
		ret = append(ret, decl)
	}
	*locals = ret
	return nil
}
// readByteVector reads a LEB128 length n followed by n raw bytes into *v.
// *v is written only on success. As with io.ReadFull, a vector whose bytes are
// missing entirely yields io.EOF, and one that is cut short yields
// io.ErrUnexpectedEOF.
//
// n comes from the input and can be close to 4 GiB. It is trusted for an
// exact up-front allocation only when it is small. Larger vectors are
// accumulated as bytes actually arrive, so a bogus length in a short input
// cannot force a huge allocation.
func readByteVector(r io.Reader, v *[]byte) error {
	n, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	const maxUpfront = 1 << 20
	if n <= maxUpfront {
		buf := make([]byte, n)
		if _, err := io.ReadFull(r, buf); err != nil {
			return err
		}
		*v = buf
		return nil
	}
	buf, err := io.ReadAll(io.LimitReader(r, int64(n)))
	if err != nil {
		return err
	}
	// Compare in uint64: int may be 32 bits wide and unable to hold n.
	if uint64(len(buf)) < uint64(n) {
		if len(buf) == 0 {
			return io.EOF
		}
		return io.ErrUnexpectedEOF
	}
	*v = buf
	return nil
}
// readByteVectorString reads a length-prefixed byte vector and stores it in
// *v as a string. On error *v is left untouched.
func readByteVectorString(r io.Reader, v *string) error {
	var raw []byte
	err := readByteVector(r, &raw)
	if err == nil {
		*v = string(raw)
	}
	return err
}
// readVarUint32Vector decodes a count-prefixed vector of LEB128 uint32 values.
// On success it stores them in *v as a non-nil, possibly empty slice; *v is
// not written on error.
//
// The count is untrusted, so only a bounded capacity is reserved up front and
// the slice grows as values are actually decoded.
func readVarUint32Vector(r io.Reader, v *[]uint32) error {
	n, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	const maxPrealloc = 1024
	ret := make([]uint32, 0, min(n, maxPrealloc))
	for range n {
		var x uint32
		if err := readVarUint32(r, &x); err != nil {
			return err
		}
		ret = append(ret, x)
	}
	*v = ret
	return nil
}
// readValueTypeVector decodes a count-prefixed vector of value-type bytes.
// On success it stores them in *v as a non-nil, possibly empty slice; *v is
// not written on error.
//
// The count is untrusted, so only a bounded capacity is reserved up front and
// the slice grows as types are actually decoded.
func readValueTypeVector(r io.Reader, v *[]types.ValueType) error {
	n, err := leb128.ReadVarUint32(r)
	if err != nil {
		return err
	}
	const maxPrealloc = 1024
	ret := make([]types.ValueType, 0, min(n, maxPrealloc))
	for range n {
		var vt types.ValueType
		if err := readValueType(r, &vt); err != nil {
			return err
		}
		ret = append(ret, vt)
	}
	*v = ret
	return nil
}
// readVarUint32 decodes one unsigned LEB128 value into *v. *v receives
// whatever leb128.ReadVarUint32 returned, even when it also returned an error.
func readVarUint32(r io.Reader, v *uint32) error {
	x, err := leb128.ReadVarUint32(r)
	*v = x
	return err
}
// readValueType decodes one value-type byte (i32, i64, f32 or f64) into *v.
// Any other byte is reported as an illegal value type, leaving *v unchanged.
func readValueType(r io.Reader, v *types.ValueType) error {
	b, err := readByte(r)
	if err != nil {
		return err
	}
	switch b {
	case constant.ValueTypeI32:
		*v = types.I32
	case constant.ValueTypeI64:
		*v = types.I64
	case constant.ValueTypeF32:
		*v = types.F32
	case constant.ValueTypeF64:
		*v = types.F64
	default:
		return fmt.Errorf("illegal value type: 0x%x", b)
	}
	return nil
}
// readBlockValueType decodes a block type byte: either one of the four value
// types, which is stored in *v, or constant.BlockTypeEmpty, which leaves *v
// untouched. Any other byte is an error. v must be non-nil whenever the byte
// is a value type.
func readBlockValueType(r io.Reader, v *types.ValueType) error {
	b, err := readByte(r)
	if err != nil {
		return err
	}
	switch b {
	case constant.ValueTypeI32:
		*v = types.I32
	case constant.ValueTypeI64:
		*v = types.I64
	case constant.ValueTypeF32:
		*v = types.F32
	case constant.ValueTypeF64:
		*v = types.F64
	case constant.BlockTypeEmpty:
		// empty block type: nothing to store
	default:
		return fmt.Errorf("illegal value type: 0x%x", b)
	}
	return nil
}
// readByte reads exactly one byte from r. When r implements io.ByteReader
// (e.g. the *bytes.Reader section buffers used by this decoder), it calls
// ReadByte directly. Otherwise it reads into a fixed one-byte array instead of
// heap-allocating a slice per call, since this runs once per opcode and
// immediate. At end of input the error is io.EOF; callers must not use the
// byte when err is non-nil.
func readByte(r io.Reader) (byte, error) {
	if br, ok := r.(io.ByteReader); ok {
		return br.ReadByte()
	}
	var buf [1]byte
	_, err := io.ReadFull(r, buf[:])
	return buf[0], err
}