Source file src/cmd/compile/internal/ssa/poset.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"slices"
    11  )
    12  
    13  // If true, check poset integrity after every mutation
    14  var debugPoset = false
    15  
    16  const uintSize = 32 << (^uint(0) >> 63) // 32 or 64
    17  
    18  // bitset is a bit array for dense indexes.
    19  type bitset []uint
    20  
    21  func newBitset(n int) bitset {
    22  	return make(bitset, (n+uintSize-1)/uintSize)
    23  }
    24  
    25  func (bs bitset) Reset() {
    26  	for i := range bs {
    27  		bs[i] = 0
    28  	}
    29  }
    30  
    31  func (bs bitset) Set(idx uint32) {
    32  	bs[idx/uintSize] |= 1 << (idx % uintSize)
    33  }
    34  
    35  func (bs bitset) Clear(idx uint32) {
    36  	bs[idx/uintSize] &^= 1 << (idx % uintSize)
    37  }
    38  
    39  func (bs bitset) Test(idx uint32) bool {
    40  	return bs[idx/uintSize]&(1<<(idx%uintSize)) != 0
    41  }
    42  
// undoType is the discriminant of a posetUndo step: it says which kind of
// mutation the step reverts. undoInvalid is the zero value; presumably it
// marks an uninitialized step rather than a real operation (TODO confirm).
type undoType uint8

const (
	undoInvalid    undoType = iota
	undoCheckpoint          // a checkpoint to group undo passes
	undoSetChl              // change back left child of node undo.idx to undo.edge
	undoSetChr              // change back right child of node undo.idx to undo.edge
	undoNonEqual            // forget that node undo.ID (a node index stored in the ID field) is non-equal to node undo.idx
	undoNewNode             // remove new node undo.idx, created for SSA value undo.ID (0 for extra nodes)
	undoAliasNode           // unalias SSA value undo.ID so that it points back to node index undo.idx
	undoNewRoot             // remove node undo.idx from root list
	undoChangeRoot          // remove node undo.idx from root list, and put back undo.edge.Target instead
	undoMergeRoot           // remove node undo.idx from root list, and put back its children instead
)
    57  
// posetUndo represents an undo pass to be performed.
// It's a union of fields that can be used to store information,
// and typ is the discriminant, that specifies which kind
// of operation must be performed. Not all fields are always used.
// Steps are appended to poset.undo, so the most recent step is last.
type posetUndo struct {
	typ  undoType  // which kind of undo operation to perform
	idx  uint32    // a node index; its role depends on typ
	ID   ID        // an SSA value ID, except for undoNonEqual where it holds a node index
	edge posetEdge // saved edge to restore (undoSetChl, undoSetChr, undoChangeRoot)
}
    68  
const (
	// Make poset handle values as unsigned numbers.
	// This is a bit in poset.flags, set and cleared by SetUnsigned.
	// NOTE(review): no code in this part of the file reads the flag;
	// its effect, if any, lives elsewhere.
	// (TODO: remove?)
	posetFlagUnsigned = 1 << iota
)
    74  
    75  // A poset edge. The zero value is the null/empty edge.
    76  // Packs target node index (31 bits) and strict flag (1 bit).
    77  type posetEdge uint32
    78  
    79  func newedge(t uint32, strict bool) posetEdge {
    80  	s := uint32(0)
    81  	if strict {
    82  		s = 1
    83  	}
    84  	return posetEdge(t<<1 | s)
    85  }
    86  func (e posetEdge) Target() uint32 { return uint32(e) >> 1 }
    87  func (e posetEdge) Strict() bool   { return uint32(e)&1 != 0 }
    88  func (e posetEdge) String() string {
    89  	s := fmt.Sprint(e.Target())
    90  	if e.Strict() {
    91  		s += "*"
    92  	}
    93  	return s
    94  }
    95  
    96  // posetNode is a node of a DAG within the poset.
    97  type posetNode struct {
    98  	l, r posetEdge
    99  }
   100  
// poset is a union-find data structure that can represent a partially ordered set
// of SSA values. Given a binary relation that creates a partial order (e.g. '<'),
// clients can record relations between SSA values using SetOrder, and later
// check relations (in the transitive closure) with Ordered. For instance,
// if SetOrder is called to record that A<B and B<C, Ordered will later confirm
// that A<C.
//
// It is possible to record equality relations between SSA values with SetEqual and check
// equality with Equal. Equality propagates into the transitive closure for the partial
// order so that if we know that A<B<C and later learn that A==D, Ordered will return
// true for D<C.
//
// It is also possible to record inequality relations between nodes with SetNonEqual;
// non-equality relations are not transitive, but they can still be useful: for instance
// if we know that A<=B and later we learn that A!=B, we can deduce that A<B.
// NonEqual can be used to check whether it is known that the nodes are different, either
// because SetNonEqual was called before, or because we know that they are strictly ordered.
//
// poset will refuse to record new relations that contradict existing relations:
// for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
// calling SetEqual for C==A will fail.
//
// poset is implemented as a forest of DAGs; in each DAG, if there is a path (directed)
// from node A to B, it means that A<B (or A<=B). Equality is represented by mapping
// two SSA values to the same DAG node; when a new equality relation is recorded
// between two existing nodes, the nodes are merged, adjusting incoming and outgoing edges.
//
// poset is designed to be memory efficient and to make few allocations during normal
// usage. Most internal data structures are pre-allocated and flat. Note, however, that
// graph walks (dfs, findpaths) allocate temporary bitsets and work lists on each call.
// For performance reasons, each node has only up to two outgoing edges (like a binary
// tree), so intermediate "extra" nodes are required to represent more than two
// relations. For instance, to record that A<I, A<J, A<K (with no known relation
// between I,J,K), we create the following DAG:
//
//	  A
//	 / \
//	I  extra
//	    /  \
//	   J    K
//
// Mutations record steps in the undo chain (see posetUndo) so that they can be
// rolled back.
type poset struct {
	lastidx uint32            // last generated dense index; index 0 is reserved for the null node
	flags   uint8             // internal flags (posetFlagUnsigned)
	values  map[ID]uint32     // map SSA values to dense node indexes; aliased values share a node
	nodes   []posetNode       // nodes (in all DAGs), indexed by dense index; nodes[0] is the null node
	roots   []uint32          // list of root nodes (forest), one per DAG
	noneq   map[uint32]bitset // non-equal relations, keyed by the higher of the two node indexes
	undo    []posetUndo       // undo chain, most recent step last
}
   150  
   151  func newPoset() *poset {
   152  	return &poset{
   153  		values: make(map[ID]uint32),
   154  		nodes:  make([]posetNode, 1, 16),
   155  		roots:  make([]uint32, 0, 4),
   156  		noneq:  make(map[uint32]bitset),
   157  		undo:   make([]posetUndo, 0, 4),
   158  	}
   159  }
   160  
   161  func (po *poset) SetUnsigned(uns bool) {
   162  	if uns {
   163  		po.flags |= posetFlagUnsigned
   164  	} else {
   165  		po.flags &^= posetFlagUnsigned
   166  	}
   167  }
   168  
   169  // Handle children
   170  func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
   171  func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
   172  func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
   173  func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
   174  func (po *poset) children(i uint32) (posetEdge, posetEdge) {
   175  	return po.nodes[i].l, po.nodes[i].r
   176  }
   177  
   178  // upush records a new undo step. It can be used for simple
   179  // undo passes that record up to one index and one edge.
   180  func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
   181  	po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
   182  }
   183  
   184  // upushnew pushes an undo pass for a new node
   185  func (po *poset) upushnew(id ID, idx uint32) {
   186  	po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
   187  }
   188  
   189  // upushneq pushes a new undo pass for a nonequal relation
   190  func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
   191  	po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
   192  }
   193  
   194  // upushalias pushes a new undo pass for aliasing two nodes
   195  func (po *poset) upushalias(id ID, i2 uint32) {
   196  	po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
   197  }
   198  
// addchild adds i2 as direct child of i1. If strict is true the new edge
// records i1<i2, otherwise i1<=i2. addchild performs no cycle or duplicate
// check. Changes to i1's edges are pushed on the undo chain.
func (po *poset) addchild(i1, i2 uint32, strict bool) {
	i1l, i1r := po.children(i1)
	e2 := newedge(i2, strict)

	if i1l == 0 {
		// Left slot is free: use it. Undo restores the empty (0) edge.
		po.setchl(i1, e2)
		po.upush(undoSetChl, i1, 0)
	} else if i1r == 0 {
		// Right slot is free: use it.
		po.setchr(i1, e2)
		po.upush(undoSetChr, i1, 0)
	} else {
		// If i1 already has two children, add an intermediate extra
		// node to record the relation correctly (without relating
		// i2 to other existing nodes). The parity of i1^i2 decides
		// whether the extra node replaces the right or the left child,
		// so that repeated insertions do not always grow the same side
		// into a degenerate chain. The choice is deterministic for a
		// given pair of indexes.
		//
		// (i1^i2)&1 != 0:                  otherwise:
		//
		//      i1                               i1
		//     /  \                             /  \
		//   i1l  extra                     extra  i1r
		//        /   \                     /   \
		//      i1r   i2                  i1l   i2
		//
		// The new edge from i1 to extra is non-strict; the moved child
		// keeps its original edge kind. Only the change to i1 is pushed
		// here; the creation of extra is recorded by newnode.
		extra := po.newnode(nil)
		if (i1^i2)&1 != 0 { // pick a side from index parity
			po.setchl(extra, i1r)
			po.setchr(extra, e2)
			po.setchr(i1, newedge(extra, false))
			po.upush(undoSetChr, i1, i1r)
		} else {
			po.setchl(extra, i1l)
			po.setchr(extra, e2)
			po.setchl(i1, newedge(extra, false))
			po.upush(undoSetChl, i1, i1l)
		}
	}
}
   237  
   238  // newnode allocates a new node bound to SSA value n.
   239  // If n is nil, this is an extra node (= only used internally).
   240  func (po *poset) newnode(n *Value) uint32 {
   241  	i := po.lastidx + 1
   242  	po.lastidx++
   243  	po.nodes = append(po.nodes, posetNode{})
   244  	if n != nil {
   245  		if po.values[n.ID] != 0 {
   246  			panic("newnode for Value already inserted")
   247  		}
   248  		po.values[n.ID] = i
   249  		po.upushnew(n.ID, i)
   250  	} else {
   251  		po.upushnew(0, i)
   252  	}
   253  	return i
   254  }
   255  
   256  // lookup searches for a SSA value into the forest of DAGS, and return its node.
   257  func (po *poset) lookup(n *Value) (uint32, bool) {
   258  	i, f := po.values[n.ID]
   259  	return i, f
   260  }
   261  
   262  // aliasnewnode records that a single node n2 (not in the poset yet) is an alias
   263  // of the master node n1.
   264  func (po *poset) aliasnewnode(n1, n2 *Value) {
   265  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   266  	if i1 == 0 || i2 != 0 {
   267  		panic("aliasnewnode invalid arguments")
   268  	}
   269  
   270  	po.values[n2.ID] = i1
   271  	po.upushalias(n2.ID, 0)
   272  }
   273  
// aliasnodes records that all the nodes i2s are aliases of a single master node n1.
// aliasnodes takes care of rearranging the DAG, changing references of parent/children
// of nodes in i2s, so that they point to n1 instead.
// Complexity is O(n) (with n being the total number of nodes in the poset, not just
// the number of nodes being aliased).
// It panics if n1 has no node, or if n1's own node is in i2s. Every edge
// rewrite, child removal and value remapping is pushed on the undo chain.
// The nodes in i2s end up without children, but they are not removed from
// po.nodes.
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Go through all the nodes to adjust parent/children of nodes in i2s.
	// The range expression is evaluated once, so nodes appended by
	// addchild (via newnode) during the loop are not visited.
	for idx, n := range po.nodes {
		// Do not touch i1 itself, otherwise we can create useless self-loops
		if uint32(idx) == i1 {
			continue
		}
		// Snapshot of idx's edges, taken before any of the changes below.
		l, r := n.l, n.r

		// Rename all references to i2s into i1, preserving the edge kind
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Connect all children of i2s to i1 (unless those children
		// are in i2s as well, in which case it would be useless),
		// then detach the aliased node from its original children.
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Reassign all existing IDs that point to a node in i2s to i1,
	// recording each previous node index for undo.
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}
}
   331  
   332  func (po *poset) isroot(r uint32) bool {
   333  	for i := range po.roots {
   334  		if po.roots[i] == r {
   335  			return true
   336  		}
   337  	}
   338  	return false
   339  }
   340  
   341  func (po *poset) changeroot(oldr, newr uint32) {
   342  	for i := range po.roots {
   343  		if po.roots[i] == oldr {
   344  			po.roots[i] = newr
   345  			return
   346  		}
   347  	}
   348  	panic("changeroot on non-root")
   349  }
   350  
   351  func (po *poset) removeroot(r uint32) {
   352  	for i := range po.roots {
   353  		if po.roots[i] == r {
   354  			po.roots = slices.Delete(po.roots, i, i+1)
   355  			return
   356  		}
   357  	}
   358  	panic("removeroot on non-root")
   359  }
   360  
// dfs performs a depth-first search within the DAG whose root is r.
// f is the visit function, called at most once for each visited node; if
// it returns true, the search is aborted and true is returned.
// If strict is false, the walk starts at r, so the root node is visited too.
// If strict is true, ignore edges across a path until at least one
// strict edge is found, and visit only what is reachable from there.
// For instance, for a chain A<=B<=C<D<=E<F, a strict walk visits D,E,F
// (the root A is not visited).
// If the visit ends without f returning true, false is returned.
// Each call allocates a visited bitset sized for all nodes, plus a work list.
func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
	closed := newBitset(int(po.lastidx + 1))
	open := make([]uint32, 1, 64)
	open[0] = r

	if strict {
		// Do a first DFS; walk all paths and stop when we find a strict
		// edge, building a "next" list of nodes reachable through strict
		// edges. This will be the bootstrap open list for the real DFS.
		next := make([]uint32, 0, 64)

		for len(open) > 0 {
			i := open[len(open)-1]
			open = open[:len(open)-1]

			// Don't visit the same node twice. Notice that all nodes
			// across non-strict paths are still visited at least once, so
			// a non-strict path can never obscure a strict path to the
			// same node.
			if !closed.Test(i) {
				closed.Set(i)

				l, r := po.children(i)
				if l != 0 {
					if l.Strict() {
						next = append(next, l.Target())
					} else {
						open = append(open, l.Target())
					}
				}
				if r != 0 {
					if r.Strict() {
						next = append(next, r.Target())
					} else {
						open = append(open, r.Target())
					}
				}
			}
		}
		open = next
		// The second phase starts from scratch: nodes seen in the first
		// phase may still need to be visited.
		closed.Reset()
	}

	// Main walk: follow both strict and non-strict edges. The open list
	// is used as a LIFO stack, so the right child (pushed last) is
	// visited first.
	for len(open) > 0 {
		i := open[len(open)-1]
		open = open[:len(open)-1]

		if !closed.Test(i) {
			if f(i) {
				return true
			}
			closed.Set(i)
			l, r := po.children(i)
			if l != 0 {
				open = append(open, l.Target())
			}
			if r != 0 {
				open = append(open, r.Target())
			}
		}
	}
	return false
}
   432  
   433  // Returns true if there is a path from i1 to i2.
   434  // If strict ==  true: if the function returns true, then i1 <  i2.
   435  // If strict == false: if the function returns true, then i1 <= i2.
   436  // If the function returns false, no relation is known.
   437  func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
   438  	return po.dfs(i1, strict, func(n uint32) bool {
   439  		return n == i2
   440  	})
   441  }
   442  
   443  // findroot finds i's root, that is which DAG contains i.
   444  // Returns the root; if i is itself a root, it is returned.
   445  // Panic if i is not in any DAG.
   446  func (po *poset) findroot(i uint32) uint32 {
   447  	// TODO(rasky): if needed, a way to speed up this search is
   448  	// storing a bitset for each root using it as a mini bloom filter
   449  	// of nodes present under that root.
   450  	for _, r := range po.roots {
   451  		if po.reaches(r, i, false) {
   452  			return r
   453  		}
   454  	}
   455  	panic("findroot didn't find any root")
   456  }
   457  
   458  // mergeroot merges two DAGs into one DAG by creating a new extra root
   459  func (po *poset) mergeroot(r1, r2 uint32) uint32 {
   460  	r := po.newnode(nil)
   461  	po.setchl(r, newedge(r1, false))
   462  	po.setchr(r, newedge(r2, false))
   463  	po.changeroot(r1, r)
   464  	po.removeroot(r2)
   465  	po.upush(undoMergeRoot, r, 0)
   466  	return r
   467  }
   468  
   469  // collapsepath marks n1 and n2 as equal and collapses as equal all
   470  // nodes across all paths between n1 and n2. If a strict edge is
   471  // found, the function does not modify the DAG and returns false.
   472  // Complexity is O(n).
   473  func (po *poset) collapsepath(n1, n2 *Value) bool {
   474  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   475  	if po.reaches(i1, i2, true) {
   476  		return false
   477  	}
   478  
   479  	// Find all the paths from i1 to i2
   480  	paths := po.findpaths(i1, i2)
   481  	// Mark all nodes in all the paths as aliases of n1
   482  	// (excluding n1 itself)
   483  	paths.Clear(i1)
   484  	po.aliasnodes(n1, paths)
   485  	return true
   486  }
   487  
   488  // findpaths is a recursive function that calculates all paths from cur to dst
   489  // and return them as a bitset (the index of a node is set in the bitset if
   490  // that node is on at least one path from cur to dst).
   491  // We do a DFS from cur (stopping going deep any time we reach dst, if ever),
   492  // and mark as part of the paths any node that has a children which is already
   493  // part of the path (or is dst itself).
   494  func (po *poset) findpaths(cur, dst uint32) bitset {
   495  	seen := newBitset(int(po.lastidx + 1))
   496  	path := newBitset(int(po.lastidx + 1))
   497  	path.Set(dst)
   498  	po.findpaths1(cur, dst, seen, path)
   499  	return path
   500  }
   501  
// findpaths1 is the recursive worker of findpaths. It returns immediately
// at dst; otherwise it marks cur as seen, recurses into each child not yet
// seen, and then adds cur to path if either child is in path (dst is
// pre-marked by findpaths). Because the graph is a DAG, a child that is
// already in seen has been fully processed, so its path bit is final.
// A missing child shows up as index 0 (the null node). Visiting it is
// harmless, assuming (as elsewhere) that the null node never has children.
// Recursion depth grows with the length of the paths explored.
func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
	if cur == dst {
		return
	}
	seen.Set(cur)
	l, r := po.chl(cur), po.chr(cur)
	if !seen.Test(l) {
		po.findpaths1(l, dst, seen, path)
	}
	if !seen.Test(r) {
		po.findpaths1(r, dst, seen, path)
	}
	if path.Test(l) || path.Test(r) {
		path.Set(cur)
	}
}
   518  
   519  // Check whether it is recorded that i1!=i2
   520  func (po *poset) isnoneq(i1, i2 uint32) bool {
   521  	if i1 == i2 {
   522  		return false
   523  	}
   524  	if i1 < i2 {
   525  		i1, i2 = i2, i1
   526  	}
   527  
   528  	// Check if we recorded a non-equal relation before
   529  	if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
   530  		return true
   531  	}
   532  	return false
   533  }
   534  
// setnoneq records that n1!=n2. Values not yet in the poset get new nodes,
// each becoming an independent root. It panics if n1 and n2 map to the same
// node. The relation is stored in po.noneq under the higher of the two node
// indexes, with the lower index set in that entry's bitset. An undo step is
// pushed unless the relation was already recorded.
func (po *poset) setnoneq(n1, n2 *Value) {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	// If any of the nodes do not exist in the poset, allocate them. Since
	// we don't know any relation (in the partial order) about them, they must
	// become independent roots.
	if !f1 {
		i1 = po.newnode(n1)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
	}
	if !f2 {
		i2 = po.newnode(n2)
		po.roots = append(po.roots, i2)
		po.upush(undoNewRoot, i2, 0)
	}

	if i1 == i2 {
		panic("setnoneq on same node")
	}
	// Normalize so that i1 is the higher index (the map key).
	if i1 < i2 {
		i1, i2 = i2, i1
	}
	bs := po.noneq[i1]
	if bs == nil {
		// Given that we record non-equality relations using the
		// higher index as a key, any index stored in this bitset is
		// lower than i1, so a bitset of i1 bits never needs to grow.
		// TODO(rasky): if memory is a problem, consider allocating
		// a small bitset and lazily grow it when higher indices arrive.
		bs = newBitset(int(i1))
		po.noneq[i1] = bs
	} else if bs.Test(i2) {
		// Already recorded
		return
	}
	bs.Set(i2)
	po.upushneq(i1, i2)
}
   575  
   576  // CheckIntegrity verifies internal integrity of a poset. It is intended
   577  // for debugging purposes.
   578  func (po *poset) CheckIntegrity() {
   579  	// Verify that each node appears in a single DAG
   580  	seen := newBitset(int(po.lastidx + 1))
   581  	for _, r := range po.roots {
   582  		if r == 0 {
   583  			panic("empty root")
   584  		}
   585  
   586  		po.dfs(r, false, func(i uint32) bool {
   587  			if seen.Test(i) {
   588  				panic("duplicate node")
   589  			}
   590  			seen.Set(i)
   591  			return false
   592  		})
   593  	}
   594  
   595  	// Verify that values contain the minimum set
   596  	for id, idx := range po.values {
   597  		if !seen.Test(idx) {
   598  			panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
   599  		}
   600  	}
   601  
   602  	// Verify that only existing nodes have non-zero children
   603  	for i, n := range po.nodes {
   604  		if n.l|n.r != 0 {
   605  			if !seen.Test(uint32(i)) {
   606  				panic(fmt.Errorf("children of unknown node %d->%v", i, n))
   607  			}
   608  			if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
   609  				panic(fmt.Errorf("self-loop on node %d", i))
   610  			}
   611  		}
   612  	}
   613  }
   614  
   615  // CheckEmpty checks that a poset is completely empty.
   616  // It can be used for debugging purposes, as a poset is supposed to
   617  // be empty after it's fully rolled back through Undo.
   618  func (po *poset) CheckEmpty() error {
   619  	if len(po.nodes) != 1 {
   620  		return fmt.Errorf("non-empty nodes list: %v", po.nodes)
   621  	}
   622  	if len(po.values) != 0 {
   623  		return fmt.Errorf("non-empty value map: %v", po.values)
   624  	}
   625  	if len(po.roots) != 0 {
   626  		return fmt.Errorf("non-empty root list: %v", po.roots)
   627  	}
   628  	if len(po.undo) != 0 {
   629  		return fmt.Errorf("non-empty undo list: %v", po.undo)
   630  	}
   631  	if po.lastidx != 0 {
   632  		return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
   633  	}
   634  	for _, bs := range po.noneq {
   635  		for _, x := range bs {
   636  			if x != 0 {
   637  				return fmt.Errorf("non-empty noneq map")
   638  			}
   639  		}
   640  	}
   641  	return nil
   642  }
   643  
   644  // DotDump dumps the poset in graphviz format to file fn, with the specified title.
   645  func (po *poset) DotDump(fn string, title string) error {
   646  	f, err := os.Create(fn)
   647  	if err != nil {
   648  		return err
   649  	}
   650  	defer f.Close()
   651  
   652  	// Create reverse index mapping (taking aliases into account)
   653  	names := make(map[uint32]string)
   654  	for id, i := range po.values {
   655  		s := names[i]
   656  		if s == "" {
   657  			s = fmt.Sprintf("v%d", id)
   658  		} else {
   659  			s += fmt.Sprintf(", v%d", id)
   660  		}
   661  		names[i] = s
   662  	}
   663  
   664  	fmt.Fprintf(f, "digraph poset {\n")
   665  	fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
   666  	for ridx, r := range po.roots {
   667  		fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
   668  		po.dfs(r, false, func(i uint32) bool {
   669  			fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
   670  			chl, chr := po.children(i)
   671  			for _, ch := range []posetEdge{chl, chr} {
   672  				if ch != 0 {
   673  					if ch.Strict() {
   674  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
   675  					} else {
   676  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
   677  					}
   678  				}
   679  			}
   680  			return false
   681  		})
   682  		fmt.Fprintf(f, "\t}\n")
   683  	}
   684  	fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
   685  	fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
   686  	fmt.Fprintf(f, "\tlabel=%q\n", title)
   687  	fmt.Fprintf(f, "}\n")
   688  	return nil
   689  }
   690  
   691  // Ordered reports whether n1<n2. It returns false either when it is
   692  // certain that n1<n2 is false, or if there is not enough information
   693  // to tell.
   694  // Complexity is O(n).
   695  func (po *poset) Ordered(n1, n2 *Value) bool {
   696  	if debugPoset {
   697  		defer po.CheckIntegrity()
   698  	}
   699  	if n1.ID == n2.ID {
   700  		panic("should not call Ordered with n1==n2")
   701  	}
   702  
   703  	i1, f1 := po.lookup(n1)
   704  	i2, f2 := po.lookup(n2)
   705  	if !f1 || !f2 {
   706  		return false
   707  	}
   708  
   709  	return i1 != i2 && po.reaches(i1, i2, true)
   710  }
   711  
   712  // OrderedOrEqual reports whether n1<=n2. It returns false either when it is
   713  // certain that n1<=n2 is false, or if there is not enough information
   714  // to tell.
   715  // Complexity is O(n).
   716  func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
   717  	if debugPoset {
   718  		defer po.CheckIntegrity()
   719  	}
   720  	if n1.ID == n2.ID {
   721  		panic("should not call Ordered with n1==n2")
   722  	}
   723  
   724  	i1, f1 := po.lookup(n1)
   725  	i2, f2 := po.lookup(n2)
   726  	if !f1 || !f2 {
   727  		return false
   728  	}
   729  
   730  	return i1 == i2 || po.reaches(i1, i2, false)
   731  }
   732  
   733  // Equal reports whether n1==n2. It returns false either when it is
   734  // certain that n1==n2 is false, or if there is not enough information
   735  // to tell.
   736  // Complexity is O(1).
   737  func (po *poset) Equal(n1, n2 *Value) bool {
   738  	if debugPoset {
   739  		defer po.CheckIntegrity()
   740  	}
   741  	if n1.ID == n2.ID {
   742  		panic("should not call Equal with n1==n2")
   743  	}
   744  
   745  	i1, f1 := po.lookup(n1)
   746  	i2, f2 := po.lookup(n2)
   747  	return f1 && f2 && i1 == i2
   748  }
   749  
   750  // NonEqual reports whether n1!=n2. It returns false either when it is
   751  // certain that n1!=n2 is false, or if there is not enough information
   752  // to tell.
   753  // Complexity is O(n) (because it internally calls Ordered to see if we
   754  // can infer n1!=n2 from n1<n2 or n2<n1).
   755  func (po *poset) NonEqual(n1, n2 *Value) bool {
   756  	if debugPoset {
   757  		defer po.CheckIntegrity()
   758  	}
   759  	if n1.ID == n2.ID {
   760  		panic("should not call NonEqual with n1==n2")
   761  	}
   762  
   763  	// If we never saw the nodes before, we don't
   764  	// have a recorded non-equality.
   765  	i1, f1 := po.lookup(n1)
   766  	i2, f2 := po.lookup(n2)
   767  	if !f1 || !f2 {
   768  		return false
   769  	}
   770  
   771  	// Check if we recorded inequality
   772  	if po.isnoneq(i1, i2) {
   773  		return true
   774  	}
   775  
   776  	// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
   777  	if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
   778  		return true
   779  	}
   780  
   781  	return false
   782  }
   783  
// setOrder records that n1<n2 or n1<=n2 (depending on strict). Returns false
// if this is a contradiction.
// Implements SetOrder() and SetOrderOrEqual()
// Values not yet in the poset are given new nodes. When both values are
// known, the existing DAGs are searched to classify the request (see the
// case tables below), and DAGs may be merged or paths collapsed. All
// mutations go through helpers that push undo steps.
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 are in the poset, so they are not related
		// in any way to existing nodes.
		// Create a new DAG to record the relation.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as a child
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n1 is not in any DAG but n2 is. If n2 is a root, we can put
		// n1 in its place as a root; otherwise, we need to create a new
		// extra root to record the relation.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			// n1 replaces n2 in the root list; the undo step remembers
			// the old root (as the edge target) so it can be restored.
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Search for i2's root; this requires a O(n) search on all
		// DAGs
		r := po.findroot(i2)

		// Re-parent as follows:
		//
		//                  extra
		//     r            /   \
		//      \   ===>   r    i1
		//      i2          \   /
		//                    i2
		//
		extra := po.newnode(nil)
		po.changeroot(r, extra)
		po.upush(undoChangeRoot, extra, newedge(r, false))
		po.addchild(extra, r, false)
		po.addchild(extra, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict order
		// (that is, we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// If we are trying to record n1<=n2 but we learned that n1!=n2,
		// record n1<n2, as it provides more information.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// Both n1 and n2 are in the poset. This is the complex part of the algorithm
		// as we need to find many different cases and DAG shapes.

		// Check if n1 somehow reaches n2
		if po.reaches(i1, i2, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Check if we're in case #2
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Case #1, #3, or #4: nothing to do
			return true
		}

		// Check if n2 somehow reaches n1
		if po.reaches(i2, i1, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8: contradiction
				return false
			}

			// We're in case #5 or #7. Try to collapse path, and that will
			// fail if it realizes that we are in case #7.
			return po.collapsepath(n2, n1)
		}

		// We don't know of any existing relation between n1 and n2. They could
		// be part of the same DAG or not.
		// Find their roots to check whether they are in the same DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			// We need to merge the two DAGs to record a relation between the nodes
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2
		po.addchild(i1, i2, strict)
	}

	return true
}
   912  
   913  // SetOrder records that n1<n2. Returns false if this is a contradiction
   914  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   915  func (po *poset) SetOrder(n1, n2 *Value) bool {
   916  	if debugPoset {
   917  		defer po.CheckIntegrity()
   918  	}
   919  	if n1.ID == n2.ID {
   920  		panic("should not call SetOrder with n1==n2")
   921  	}
   922  	return po.setOrder(n1, n2, true)
   923  }
   924  
   925  // SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
   926  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   927  func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
   928  	if debugPoset {
   929  		defer po.CheckIntegrity()
   930  	}
   931  	if n1.ID == n2.ID {
   932  		panic("should not call SetOrder with n1==n2")
   933  	}
   934  	return po.setOrder(n1, n2, false)
   935  }
   936  
   937  // SetEqual records that n1==n2. Returns false if this is a contradiction
   938  // (that is, if it is already recorded that n1<n2 or n2<n1).
   939  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   940  func (po *poset) SetEqual(n1, n2 *Value) bool {
   941  	if debugPoset {
   942  		defer po.CheckIntegrity()
   943  	}
   944  	if n1.ID == n2.ID {
   945  		panic("should not call Add with n1==n2")
   946  	}
   947  
   948  	i1, f1 := po.lookup(n1)
   949  	i2, f2 := po.lookup(n2)
   950  
   951  	switch {
   952  	case !f1 && !f2:
   953  		i1 = po.newnode(n1)
   954  		po.roots = append(po.roots, i1)
   955  		po.upush(undoNewRoot, i1, 0)
   956  		po.aliasnewnode(n1, n2)
   957  	case f1 && !f2:
   958  		po.aliasnewnode(n1, n2)
   959  	case !f1 && f2:
   960  		po.aliasnewnode(n2, n1)
   961  	case f1 && f2:
   962  		if i1 == i2 {
   963  			// Already aliased, ignore
   964  			return true
   965  		}
   966  
   967  		// If we recorded that n1!=n2, this is a contradiction.
   968  		if po.isnoneq(i1, i2) {
   969  			return false
   970  		}
   971  
   972  		// If we already knew that n1<=n2, we can collapse the path to
   973  		// record n1==n2 (and vice versa).
   974  		if po.reaches(i1, i2, false) {
   975  			return po.collapsepath(n1, n2)
   976  		}
   977  		if po.reaches(i2, i1, false) {
   978  			return po.collapsepath(n2, n1)
   979  		}
   980  
   981  		r1 := po.findroot(i1)
   982  		r2 := po.findroot(i2)
   983  		if r1 != r2 {
   984  			// Merge the two DAGs so we can record relations between the nodes
   985  			po.mergeroot(r1, r2)
   986  		}
   987  
   988  		// Set n2 as alias of n1. This will also update all the references
   989  		// to n2 to become references to n1
   990  		i2s := newBitset(int(po.lastidx) + 1)
   991  		i2s.Set(i2)
   992  		po.aliasnodes(n1, i2s)
   993  	}
   994  	return true
   995  }
   996  
   997  // SetNonEqual records that n1!=n2. Returns false if this is a contradiction
   998  // (that is, if it is already recorded that n1==n2).
   999  // Complexity is O(n).
  1000  func (po *poset) SetNonEqual(n1, n2 *Value) bool {
  1001  	if debugPoset {
  1002  		defer po.CheckIntegrity()
  1003  	}
  1004  	if n1.ID == n2.ID {
  1005  		panic("should not call SetNonEqual with n1==n2")
  1006  	}
  1007  
  1008  	// Check whether the nodes are already in the poset
  1009  	i1, f1 := po.lookup(n1)
  1010  	i2, f2 := po.lookup(n2)
  1011  
  1012  	// If either node wasn't present, we just record the new relation
  1013  	// and exit.
  1014  	if !f1 || !f2 {
  1015  		po.setnoneq(n1, n2)
  1016  		return true
  1017  	}
  1018  
  1019  	// See if we already know this, in which case there's nothing to do.
  1020  	if po.isnoneq(i1, i2) {
  1021  		return true
  1022  	}
  1023  
  1024  	// Check if we're contradicting an existing equality relation
  1025  	if po.Equal(n1, n2) {
  1026  		return false
  1027  	}
  1028  
  1029  	// Record non-equality
  1030  	po.setnoneq(n1, n2)
  1031  
  1032  	// If we know that i1<=i2 but not i1<i2, learn that as we
  1033  	// now know that they are not equal. Do the same for i2<=i1.
  1034  	// Do this check only if both nodes were already in the DAG,
  1035  	// otherwise there cannot be an existing relation.
  1036  	if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
  1037  		po.addchild(i1, i2, true)
  1038  	}
  1039  	if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
  1040  		po.addchild(i2, i1, true)
  1041  	}
  1042  
  1043  	return true
  1044  }
  1045  
  1046  // Checkpoint saves the current state of the DAG so that it's possible
  1047  // to later undo this state.
  1048  // Complexity is O(1).
  1049  func (po *poset) Checkpoint() {
  1050  	po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
  1051  }
  1052  
// Undo restores the state of the poset to the previous checkpoint.
// It pops undo passes off the undo stack, most recent first, reverting
// each one, and stops right after consuming an undoCheckpoint marker.
// If the stack is exhausted without meeting a checkpoint, every recorded
// pass is reverted. Calling Undo with an empty undo stack panics.
// Complexity depends on the type of operations that were performed
// since the last checkpoint; each Set* operation creates an undo
// pass which Undo has to revert with a worst-case complexity of O(n).
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	for len(po.undo) > 0 {
		// Pop the most recent undo pass.
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			// Reached the checkpoint that opened this group of changes.
			return

		case undoSetChl:
			// Restore the previous left child edge of node pass.idx.
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			// Restore the previous right child edge of node pass.idx.
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// Forget that SSA value pass.ID is non-equal to node pass.idx.
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Passes are replayed in LIFO order, so the node being removed
			// must be the last one allocated (lastidx).
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			// A non-zero ID means the node was created for an SSA value;
			// drop that value's mapping (which must still point here).
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			// Clear the node's children and shrink the node array.
			// NOTE(review): clearing presumably keeps the slot clean for
			// reuse by the next newnode — confirm.
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoAliasNode:
			// pass.idx is the node the value mapped to before it was
			// aliased; 0 means it had no node of its own.
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// Born as an alias, die as an alias
				delete(po.values, ID)
			} else {
				// Sanity check: the aliasing must have changed the mapping.
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Give it back previous value
				po.values[ID] = prev
			}

		case undoNewRoot:
			// By now any edges added below this root have been reverted,
			// so it must be childless; drop it from the root list.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Node i replaced another root; with its edges already reverted
			// it must be childless. Put the original root
			// (pass.edge.Target()) back in its place.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// Node i joined two DAGs with the old roots as its children:
			// make the left child a root again in place of i, and put the
			// right child back in the root list.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}

	// Only reached when no checkpoint was found: the whole undo stack was
	// reverted, so in debug mode verify the poset is empty again.
	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
  1142  

View as plain text