#
#  Union-Find data structure for set operations
#
#       Adrian Altenhoff, March 19, 2008
#

UnionFind := proc(Elements, Col, Sizes, ElmInd) option NoIndex:
	# Constructor of the UnionFind structure.
	#
	# Slots of the resulting structure:
	#   Elements - sorted list of all distinct elements
	#   Col      - parent pointers; a root points to itself
	#   Sizes    - Sizes[r] is the size of the set whose root is r
	#   ElmInd   - 0, or a list mapping a position in Elements to its
	#              index in Col/Sizes
	#   (5th)    - unnamed slot, 0 or the cached list behind the
	#              'Clusters' selector
	#
	# Accepted calls:
	#   UnionFind()               empty structure
	#   UnionFind(list of sets)   the sets are the initial disjoint sets
	#   UnionFind(list)           every distinct entry is a singleton set
	#   UnionFind(e,c,s,i,cl)     all five slots given, stored as they are
	# Any other call raises 'invalid arguments'.
	if nargs=5 then noeval(procname(args));
	# a list of sets: [{44,66},{12,23,33},{11,55,64}]
	elif nargs=1 and type(args[1],list(set(anything))) then
		m := length(args[1]):
		elm := [op( union( op(args[1]) ) )]:
		n := length(elm):
		# The sets are pairwise disjoint exactly when their sizes add up
		# to the size of their union; this avoids intersecting all pairs.
		total := 0:
		for i to m do total := total + length(args[1,i]) od:
		if total<>n then
			error('initial sets must be disjoint')
		fi:
		cls := CreateArray(1..n):
		siz := copy(cls):
		for i to m do
			L := length(args[1,i]):
			# an empty set contributes no element and no root
			if L>0 then
				# the first member of each set becomes its root
				k0 := SearchOrderedArray(args[1,i,1], elm):
				cls[k0] := k0: siz[k0] := L:
				for j from 2 to L do
					k := SearchOrderedArray(args[1,i,j], elm):
					cls[k] := k0:
				od:
			fi:
		od:
		noeval(procname( elm, cls, siz, 0, 0 )):
	# an initial list of elements, which all are sets of size 1.
	#    [44,66,12,23,33,11,55,64]
	elif nargs=1 and type(args[1],list(anything)) then
		# going through a set removes duplicates and sorts the elements
		elm := [op({op(args[1])})]:
		n := length(elm):
		clst := [seq(i,i=1..n)]:
		noeval(procname( elm, clst, [seq(1,n)], 0, 0 )):
	elif nargs=0 then noeval(procname([],[],[],0,0))
	else error('invalid arguments') fi:
end:

# Type of a UnionFind structure: three lists (Elements, Col, Sizes), then
# ElmInd and the unnamed fifth slot, each of which is either 0 or a list.
UnionFind_type := noeval(UnionFind(list,list,list,{0,list},{0,list})):

# Find uses path halving for path compression
# (see Tarjan and van Leeuwen, Worst-case analysis of set union algorithms,
#      J. Assoc. Comput. Mach., 31:245-281, 1984)
UFfind := proc(x:posint, arr:list(posint)) option internal;
	# Return the root of the tree containing index x in the parent
	# array arr.  While walking up, every visited node is re-linked to
	# its grandparent, so arr is modified in place.
	node := x:
	while arr[arr[node]] <> arr[node] do
		# hook node onto its grandparent ...
		arr[node] := arr[arr[node]]:
		# ... and continue the walk from there
		node := arr[node]:
	od:
	return( arr[node] ):
end:
	
UnionFind_union := proc(o:UnionFind, elm:[anything,anything] )
	# Merge the sets containing elm[1] and elm[2]; o is changed in place
	# and returned.  Raises an error if either element is not stored in o.
	cnt := length(o['Elements']):
	idx := CreateArray(1..2):
	# locate both elements in the sorted element list, first one first
	for t to 2 do
		p := SearchOrderedArray(elm[t], o['Elements']):
		if p<=0 or p>cnt or o['Elements',p]<>elm[t] then
			error('Element '.string(elm[t]).' not present in UnionFind.')
		fi:
		idx[t] := p:
	od:
	# translate positions in Elements to indices in Col/Sizes if needed
	if o['ElmInd']<>0 then
		idx := [ o['ElmInd',idx[1]], o['ElmInd',idx[2]] ]:
	fi:
	keep := UFfind(idx[1],o['Col']):
	drop := UFfind(idx[2],o['Col']):
	# already in the same set: nothing to do
	if keep=drop then return(o) fi:
	# the larger set (the second one on ties) absorbs the other
	if o['Sizes',drop] >= o['Sizes',keep] then
		swp := keep: keep := drop: drop := swp:
	fi:
	o['Col',drop] := o['Col',keep]:
	o['Sizes',keep] := o['Sizes',keep] + o['Sizes',drop]:
	# the cached cluster list is no longer valid
	o[5] := 0:
	return(o)
end:
	
UnionFind_select := proc(o:UnionFind, sel, val)
	# Selector handler.  The only selector handled here is 'Clusters',
	# which yields the list of current sets; the list is built on first
	# use and cached in slot 5 of o.  Assignment through this handler is
	# rejected.
	if nargs=3 then error('assignments not allowed') fi:
	if sel<>'Clusters' then error('invalid selector') fi:
	if o[5]=0 then
		cnt := length(o['Elements']):
		# one bucket per index; only buckets at root indices get filled
		buckets := CreateArray(1..cnt,[]):
		for pos to cnt do
			if o['ElmInd']=0 then slot := pos
			else slot := o['ElmInd',pos] fi:
			rep := UFfind( slot, o['Col'] ):
			buckets[rep] := append(buckets[rep], o['Elements',pos]):
		od:
		# drop the empty buckets and turn the others into sets
		o[5] := [seq( If(buckets[j]=[],NULL, {op(buckets[j])}), j=1..cnt)]:
	fi:
	return( o[5] )
end:


UnionFind_print := proc(o:UnionFind)
	# Print a short summary of o: number of elements, number of sets and,
	# when there is at least one set, the largest and smallest set size.
	printf('\nOverview on collection of sets under union:\n');
	printf(' # of distinct elements in all sets: %d\n', length(o['Elements']));
	nc := length(o['Clusters']):
	printf(' # of disjoint sets: %d\n', nc ):
	# With no sets there are no sizes to take max/min of, so the two
	# extreme-size lines are skipped for an empty structure.
	if nc>0 then
		lens := seq(length(t), t=o['Clusters']):
		printf(' Biggest cluster has %d elements\n', max( lens )):
		printf(' Smallest cluster has %d elements\n', min( lens )):
	fi:
end:


UnionFind_plus := proc(o:UnionFind, new:{set, anything})
	# Add new elements to o; o is changed in place and returned.
	#   new is a set  : its members are added as ONE new set
	#   new is a list : every distinct entry is added as a singleton set
	#   anything else : new itself is added as a singleton set
	# Raises an error if any of the new elements is already stored in o.
	# Adding nothing (empty set or empty list) leaves o untouched.
	n := length(o['Elements']):
	asOneSet := type(new,set(anything)):
	if asOneSet then elm := new
	else elm := If( type(new,list(anything)), {op(new)}, {new} ) fi:

	dup := intersect(elm, {op(o['Elements'])}):
	if dup<>{} then
		if asOneSet then
			error('new set must be disjoint from all the other sets');
		else
			error( sprintf('Elements %a already exist in '.
			    'UnionFind data structure',[op(dup)]));
		fi:
	fi:

	m := length(elm):
	# Nothing to add: return before touching Col/Sizes, which would
	# otherwise get out of step for an empty set.
	if m=0 then return(o) fi:

	# From now on positions in Elements and indices in Col/Sizes can
	# differ, so the explicit index map is needed.
	if o['ElmInd']=0 then o['ElmInd'] := [seq(i,i=1..n)] fi:

	if asOneSet then
		# all m new entries point to the first of them, which is the root
		o['Col'] := append( o['Col'], seq(n+1, m) );
		o['Sizes'] := append( o['Sizes'], m, seq(0, m-1) ):
	else
		# every new entry is its own root with size 1
		o['Col'] := append( o['Col'], seq(n+i, i=1..m) );
		o['Sizes'] := append( o['Sizes'], seq(1, m) ):
	fi:

	# Merge old and new [element, index] pairs through a set, then split
	# them back into the element list and the index map.
	# NOTE(review): relies on the set ordering the pairs by element so
	# that Elements stays sorted for SearchOrderedArray - confirm.
	tmp := {seq( [o['Elements',i],o['ElmInd',i]], i=1..n ),
	        seq( [elm[i],i+n], i=1..m )}:
	tmp := transpose([op(tmp)]):
	o['Elements'] := tmp[1]:
	o['ElmInd'] := tmp[2]:
	# the cached cluster list is no longer valid
	o[5] := 0:
	return(o):
end:

# Finish the class definition.  CompleteClass is defined outside this file;
# presumably it supplies the methods not defined above - TODO confirm.
CompleteClass(UnionFind):
