185 lines
3.3 KiB
Go
185 lines
3.3 KiB
Go
// Copyright 2016 The OPA Authors. All rights reserved.
|
|
// Use of this source code is governed by an Apache2
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package ast
|
|
|
|
// Unify returns a set of variables that will be unified when the equality expression defined by
|
|
// terms a and b is evaluated. The unifier assumes that variables in the VarSet safe are already
|
|
// unified.
|
|
func Unify(safe VarSet, a *Term, b *Term) VarSet {
|
|
u := &unifier{
|
|
safe: safe,
|
|
unified: VarSet{},
|
|
unknown: map[Var]VarSet{},
|
|
}
|
|
u.unify(a, b)
|
|
return u.unified
|
|
}
|
|
|
|
type unifier struct {
|
|
safe VarSet
|
|
unified VarSet
|
|
unknown map[Var]VarSet
|
|
}
|
|
|
|
func (u *unifier) isSafe(x Var) bool {
|
|
return u.safe.Contains(x) || u.unified.Contains(x)
|
|
}
|
|
|
|
func (u *unifier) unify(a *Term, b *Term) {
|
|
|
|
switch a := a.Value.(type) {
|
|
|
|
case Var:
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
if u.isSafe(b) {
|
|
u.markSafe(a)
|
|
} else if u.isSafe(a) {
|
|
u.markSafe(b)
|
|
} else {
|
|
u.markUnknown(a, b)
|
|
u.markUnknown(b, a)
|
|
}
|
|
case Array, Object:
|
|
u.unifyAll(a, b)
|
|
case Ref:
|
|
if u.isSafe(b[0].Value.(Var)) {
|
|
u.markSafe(a)
|
|
}
|
|
default:
|
|
u.markSafe(a)
|
|
}
|
|
|
|
case Ref:
|
|
if u.isSafe(a[0].Value.(Var)) {
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
u.markSafe(b)
|
|
case Array, Object:
|
|
u.markAllSafe(b)
|
|
}
|
|
}
|
|
|
|
case *ArrayComprehension:
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
u.markSafe(b)
|
|
case Array:
|
|
u.markAllSafe(b)
|
|
}
|
|
case *ObjectComprehension:
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
u.markSafe(b)
|
|
case Object:
|
|
u.markAllSafe(b)
|
|
}
|
|
case *SetComprehension:
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
u.markSafe(b)
|
|
}
|
|
|
|
case Array:
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
u.unifyAll(b, a)
|
|
case Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension:
|
|
u.markAllSafe(a)
|
|
case Array:
|
|
if len(a) == len(b) {
|
|
for i := range a {
|
|
u.unify(a[i], b[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
case Object:
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
u.unifyAll(b, a)
|
|
case Ref:
|
|
u.markAllSafe(a)
|
|
case Object:
|
|
if a.Len() == b.Len() {
|
|
a.Iter(func(k, v *Term) error {
|
|
if v2 := b.Get(k); v2 != nil {
|
|
u.unify(v, v2)
|
|
}
|
|
return nil
|
|
})
|
|
}
|
|
}
|
|
|
|
default:
|
|
switch b := b.Value.(type) {
|
|
case Var:
|
|
u.markSafe(b)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *unifier) markAllSafe(x Value) {
|
|
vis := u.varVisitor()
|
|
vis.Walk(x)
|
|
for v := range vis.Vars() {
|
|
u.markSafe(v)
|
|
}
|
|
}
|
|
|
|
func (u *unifier) markSafe(x Var) {
|
|
u.unified.Add(x)
|
|
|
|
// Add dependencies of 'x' to safe set
|
|
vs := u.unknown[x]
|
|
delete(u.unknown, x)
|
|
for v := range vs {
|
|
u.markSafe(v)
|
|
}
|
|
|
|
// Add dependants of 'x' to safe set if they have no more
|
|
// dependencies.
|
|
for v, deps := range u.unknown {
|
|
if deps.Contains(x) {
|
|
delete(deps, x)
|
|
if len(deps) == 0 {
|
|
u.markSafe(v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *unifier) markUnknown(a, b Var) {
|
|
if _, ok := u.unknown[a]; !ok {
|
|
u.unknown[a] = NewVarSet()
|
|
}
|
|
u.unknown[a].Add(b)
|
|
}
|
|
|
|
func (u *unifier) unifyAll(a Var, b Value) {
|
|
if u.isSafe(a) {
|
|
u.markAllSafe(b)
|
|
} else {
|
|
vis := u.varVisitor()
|
|
vis.Walk(b)
|
|
unsafe := vis.Vars().Diff(u.safe).Diff(u.unified)
|
|
if len(unsafe) == 0 {
|
|
u.markSafe(a)
|
|
} else {
|
|
for v := range unsafe {
|
|
u.markUnknown(a, v)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *unifier) varVisitor() *VarVisitor {
|
|
return NewVarVisitor().WithParams(VarVisitorParams{
|
|
SkipRefHead: true,
|
|
SkipObjectKeys: true,
|
|
SkipClosures: true,
|
|
})
|
|
}
|