(******************************************************************************
 * Copyright (c) 2012-2013, Toyama&Aoto Laboratory, Tohoku University
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 * 
 *  1. Redistributions of source code must retain the above copyright notice, 
 *     this list of conditions and the following disclaimer.
 *  2. Redistributions in binary form must reproduce the above copyright 
 *     notice, this list of conditions and the following disclaimer in the 
 *     documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
 * POSSIBILITY OF SUCH DAMAGE.
 ******************************************************************************)
(******************************************************************************
 * file: rwtools/rwchecker/dp.sml
 * description: ingredients of the dependency pair technique
 * author: AOTO Takahito
 * 
 ******************************************************************************)

(* Interface of the dependency pair toolkit implemented by structure Dp.
 * Rules and dependency pairs are represented as (lhs, rhs) pairs of
 * Term.term; graphs are adjacency arrays indexed by list position. *)
signature DP = 
sig
    (* when set to true, the debug output guarded by this flag is printed *)
    val runDebug: bool ref
   (* dependencyPairs dset rs: dependency pairs of the rules rs, where dset
      is the set of symbols whose subterms of right-hand sides give rise to
      pairs; the roots of both components of a pair are marked *)
   val dependencyPairs: FunSet.set -> (Term.term * Term.term) list -> (Term.term * Term.term) list
   (* dependencyGraph dSymSet dps: entry i of the result lists, in ascending
      order, the positions j in dps such that pair i is estimated to be
      connectable to pair j *)
   val dependencyGraph: FunSet.set -> (Term.term * Term.term) list -> int list array
   (* innermostDependencyGraph dSymSet rules dps: same shape as
      dependencyGraph, but using the innermost connectability test, which
      additionally consults rules for a normal-form check *)
   val innermostDependencyGraph: 
       FunSet.set -> (Term.term * Term.term) list 
       -> (Term.term * Term.term) list -> int list array
   (* applySubtermCriteria path dir dps: the two strings are passed through
      to the propositional solver; the result is (solved, idxs) where idxs
      are the positions in dps that remain after the check (all positions
      when solved is false) *)
   val applySubtermCriteria: string 
			     -> string 
			     -> (Term.term * Term.term) list 
			     -> bool * (int list)
   (* getUsableFunMap trs: map from each symbol of the DSymSet field of trs
      to a set of symbols of that field, obtained through a reflexive
      transitive closure of the "occurs in a right-hand side" relation *)
   val getUsableFunMap: Trs.trs -> FunSet.set FunMap.map
   (* usableFunSet fm fset: union of the sets fm assigns to the members of
      fset; members without an entry contribute nothing *)
   val usableFunSet: FunSet.set FunMap.map -> FunSet.set -> FunSet.set
   (* getUsableRules fm rules dps: positions (0-based) in rules of the rules
      whose left-hand-side root is usable for the right-hand sides of dps *)
   val getUsableRules :FunSet.set FunMap.map 
		       -> (Term.term * Term.term) list 
		       -> (Term.term * Term.term) list 
		       -> int list 
   (* applyAfToTerm piMap colSet t: applies the argument filtering piMap to
      t; symbols in colSet are replaced by one of their arguments *)
   val applyAfToTerm: int list FunMap.map -> FunSet.set -> Term.term -> Term.term

end;

structure Dp: DP =
struct
   local 
       open Term
       open Subst
       structure VS = VarSet
       structure VM = VarMap
       structure FS = FunSet
       structure FM = FunMap
       structure SS = SortSet
       structure IS = IntSet
       structure IM = IntMap
       structure FIS = FunIntSet
       structure FIT = FunIntTable
       structure L = List
       structure LP = ListPair
       structure TS = TermSet
       structure TPS = TermPairSet
       structure P = Prop
       (* mapAppend f xs: concatenation of the lists f x for x in xs, in the
          order of xs.  f is applied to the last element first, then backwards
          towards the head. *)
       fun mapAppend f xs =
           let fun go [] = []
                 | go (y::ys) =
                   let val tail = go ys
                   in List.@ (f y, tail)
                   end
           in
               go xs
           end
   in

   (* global switch for debug output; off by default *)
   val runDebug : bool ref = ref false

   (* debug f: run the thunk f only while runDebug is set *)
   fun debug thunk =
       case !runDebug of
           true => thunk ()
         | false => ()

   (* character appended to a root symbol to obtain its marked version *)
   val markingLetter = ref #"#"

   (* raised when a rule with a variable left-hand side or an unmapped
      symbol is encountered *)
   exception DpError

   (* dependencyPairs dset rs
    * Builds the dependency pairs of the rules rs with respect to the symbol
    * set dset.  For a rule (l,r), every subterm s of r whose root is in dset
    * and which is not a proper subterm of l yields the pair (l#, s#), where
    * t# is t with !markingLetter appended to its root symbol.  The pair
    * lists of the individual rules are concatenated in rule order. *)
   fun dependencyPairs dset rs =
       let
           (* set of all subterms of t (t itself included) whose root
              symbol belongs to dset *)
           fun definedSubterms t =
               let fun walk (Var _, found) = found
                     | walk (u as Fun (f,args,_), found) =
                       L.foldl walk
                               (if FS.member (dset,f)
                                then TS.add (found, u)
                                else found)
                               args
               in
                   walk (t, TS.empty)
               end

           (* append the marking letter to the root symbol; a variable is
              returned unchanged *)
           fun mark (Fun (f,args,a)) =
               Fun (Atom.atom (Atom.toString f ^ str (!markingLetter)), args, a)
             | mark v = v

           (* dependency pairs contributed by a single rule *)
           fun pairsOfRule (l,r) =
               let val candidates = TS.listItems (definedSubterms r)
                   val lhs = mark l
               in
                   L.mapPartial
                       (fn s => if Term.isAProperSubterm s l
                                then NONE
                                else SOME (lhs, mark s))
                       candidates
               end
       in
           mapAppend pairsOfRule rs
       end

  (* subterm criterion *) 
   (* applySubtermCriteria path dir dps
    *
    * Encodes a subterm-criterion check for the dependency pairs dps as a
    * propositional formula and passes it to Solver.propSolver; path and dir
    * are handed to the solver unchanged.
    *
    * Propositional variables (numbered from 1):
    *   - one per (marked root symbol f, argument position i), true when
    *     position i is the one selected for f;
    *   - one per dependency pair, true when that pair is a "proper" case.
    *
    * Returns (result, idxs): result is the solver's answer; idxs are the
    * positions in dps whose "proper" variable is not assigned true, or all
    * positions when result is false.
    *
    * Both components of every pair must be function applications; a
    * variable on either side raises Match. *)
   fun applySubtermCriteria path dir dps =
       (* (root symbol, arity) of both sides of every pair in dps *)
       let val markedSymbolxAritySet =  
	       let fun root (Fun (f,args,_)) = (f,length args)
		   fun join (l,r) set = FIS.add (FIS.add (set, root l), root r)
	       in
		   L.foldr (fn (rule,set) => join rule set) FIS.empty dps
	       end

	   (* sum of the arities, used as the size argument of the table below *)
	   val len =  FIS.foldr (fn ((_,k),n) => n+k) 0 markedSymbolxAritySet

	   exception FunArityTableError
	   (* maps (f,i) to the number of its propositional variable *)
	   val funArityTable = FIT.mkTable (len, FunArityTableError) 

	   (* next unused propositional variable number *)
	   val symCount = ref 1
	   (* allocate one variable for every argument position of every symbol *)
	   val _ = FIS.app (fn (f,ar) =>  
			       (L.app
				    (fn i =>
					(FIT.insert funArityTable ((f,i), !symCount);
					 symCount := !symCount + 1))
 				    (L.tabulate (ar,fn i=>i))))
			   markedSymbolxAritySet

	  (* propositional atom for "position i is selected for f";
	     prints a message and raises FunArityTableError if (f,i) was never
	     registered *)
 	  fun pv (f,i) =  (* i = 0,..., arity(f) - 1  *)
	      case FIT.find funArityTable (f,i)
	       of SOME n => P.Atom n
		| NONE => (print ("subtermCriteria: " ^ (Atom.toString f) ^ " / " 
				  ^ (Int.toString i) ^ "\n");
			   raise FunArityTableError)

	  (* for a pair f(ss) -> g(ts): for every i, if position i is selected
	     for f then some position j must be selected for g such that
	     isASubterm holds for (ts_j, ss_i) -- presumably "ts_j is a subterm
	     of ss_i"; confirm against Term.isASubterm *)
	  fun makePropFromDP (l as Fun (f,ss,_), r as Fun (g,ts,_)) = 
	      let val ilist = L.tabulate (length ss, fn i=>i)
		  val jlist = L.tabulate (length ts, fn j=>j)
	      in
		   P.Conj (L.map (fn i => 
				     let val si  = L.nth (ss,i)
				     in
					 P.Imp (pv (f,i),
						P.Disj (L.mapPartial 
							    (fn j => 
								if isASubterm (L.nth (ts,j)) si
								then SOME (pv (g,j))
								else NONE)
							    jlist))
				     end)
				 ilist)
	      end

	   val prop1 = P.Conj (L.map makePropFromDP dps)

	   (* exactly one argument position is selected by the filter *)
	   val prop2 = P.Conj (L.map (fn (f,ar) => P.one (L.tabulate (ar, fn i=> pv (f,i))))
				   (FIS.listItems markedSymbolxAritySet))

	   (* for a pair f(ss) -> g(ts): one conjunction "i selected for f and
	      j selected for g" for every (i,j) where isAProperSubterm holds for
	      (ts_j, ss_i) *)
	   fun collectProperCaseFromDP (l as Fun (f,ss,_), r as Fun (g,ts,_)) = 
	       let val ilist = L.tabulate (length ss, fn i=>i)
		   val jlist = L.tabulate (length ts, fn j=>j)
	       in
		   List.mapPartial
		       (fn (i,j) => let val si = L.nth (ss,i)
					val tj = L.nth (ts,j)
				    in if isAProperSubterm tj si
				       then SOME (P.Conj [pv (f,i), pv (g,j)])
				       else NONE
				    end)
		       (ListXProd.mapX  (fn (i,j) => (i,j)) (ilist,jlist))
	       end

           (* L.nth(dp,i) is proper <=> Atom (dpsCount+i) is true *)
	   (* the per-pair variables are numbered directly after the
	      per-position variables *)
	   val dpsCount = !symCount 
	   val _ = symCount := !symCount + (L.length dps)

	   val prop3 = P.Conj (L.tabulate (L.length dps,
					   fn i =>
					      P.Iff (P.Atom (dpsCount + i),
						     P.Disj (collectProperCaseFromDP (L.nth (dps,i))))))

	   (* at least one pair is a proper-subterm case *)
	   val prop4 = P.Disj (L.tabulate (L.length dps,
					   fn i => P.Atom (dpsCount + i)))

	   (* count is the largest variable number that was allocated *)
	   val (prop,count) = (P.Conj [prop1,prop2,prop3,prop4], !symCount - 1)

	   val (result,resultAr) = Solver.propSolver path dir (prop,count)

	   (* debug output: the selected argument position of every symbol,
	      printed 1-based *)
	   val _ = debug (fn _ =>
			     if result
			     then
				 (print "argument selection [";
				  FIS.app (fn (f,ar) =>  
					      (L.app
						   (fn i=>
						       case FIT.find funArityTable (f,i)
							of SOME n => if PoSolver.isAssignedByTrue resultAr n
								     then print (" " ^ (Fun.toString f)
										 ^ ":=" ^ (Int.toString (i+1)) )
								     else ()
							 | NONE => (print ("subtermCriteria: " 
									   ^ (Atom.toString f) ^ " / " 
									   ^ (Int.toString i) ^ "\n");
								    raise FunArityTableError))
						   (L.tabulate (ar,fn i=>i))))
					  markedSymbolxAritySet;
				  print " ]\n")
			     else ())
		   
	   (* positions of the pairs that were not shown to be proper; every
	      position when the solver found no solution *)
	   val remainedIdxes = if result
			     then
				 L.filter
				     (fn i => not (PoSolver.isAssignedByTrue resultAr (dpsCount + i)))
				     (L.tabulate (L.length dps, fn i =>i))
			     else (L.tabulate (L.length dps, fn i =>i))
       in
	   (* whether at least one pair was proper, and the index list of the dps that were not proper *)
	   (result,remainedIdxes)
       end



(*   (\* is REN o CAP(s) unifiable with REN o CAP(t)? *\)  *)
(*   (\* --- assumes that the terms have already been marked *\) *)
(*    fun isConnectable _ (Var _, _) = true *)
(*      | isConnectable _ (_,Var _) = true *)
(*      | isConnectable DSymSet (Fun (f,ts,a),Fun (g,ss,b)) = *)
(*        if FS.member (DSymSet, f)  *)
(*        then true *)
(*        else if FS.member (DSymSet, g)  *)
(*        then true *)
(*        else Fun.equal (f,g)  *)
(* 	    andalso LP.all (isConnectable DSymSet) (ts,ss) *)

  (* is REN o CAP(s) unifiable with t? *) 
  (* --- assumes that the terms have already been marked *)
   (* isConnectable DSymSet (s,t)
    * Structural connectability estimate between s and t: true as soon as
    * either term is a variable or the root of s belongs to DSymSet;
    * otherwise the roots must be equal (Fun.equal) and the arguments must
    * be pairwise connectable. *)
   fun isConnectable DSymSet (s, t) =
       case (s, t) of
           (Var _, _) => true
         | (_, Var _) => true
         | (Fun (f,ss,_), Fun (g,ts,_)) =>
           FS.member (DSymSet, f)
           orelse (Fun.equal (f,g)
                   andalso LP.all (isConnectable DSymSet) (ss,ts))

   (* dependencyGraph dSymSet dps
    * Adjacency array of the estimated dependency graph: entry i holds, in
    * ascending order, every position j such that isConnectable dSymSet
    * accepts (rhs of pair i, lhs of pair j). *)
   fun dependencyGraph dSymSet dps =
       let
           (* left-hand sides paired with their position in dps *)
           val targets = LP.zip (L.tabulate (length dps, fn i => i),
                                 L.map (fn (l,_) => l) dps)
           fun successors r =
               L.mapPartial
                   (fn (n,l) => if isConnectable dSymSet (r,l)
                                then SOME n
                                else NONE)
                   targets
       in
           Array.fromList (L.map (fn (_,r) => successors r) dps)
       end


  (* is CAP_l(r) unifiable with t? *) 
  (* i.e. is l >-> r inner-connectable to s >-> ... ? *)
  (* a subterm t of r that is also a subterm of l need not be replaced by a variable *) 
  (* if the mgu instance of CAP_l(r) and t is not a normal form, the pairs are not connectable *) 
  (* --- assumes that the terms have already been marked *)
   (* isInnermostConnectable DSymSet rules (l, r, s)
    * Connectability test for the innermost dependency graph.  A capped copy
    * of r is built by replacing every subterm whose root is in DSymSet and
    * which is not a subterm of l by a fresh variable named "x".  The capped
    * term and s are passed through Subst.renameTerms and unified; the
    * result is true iff a unifier exists and the corresponding instance of
    * s is a normal form with respect to rules (Rewrite.isNormalForm). *)
   fun isInnermostConnectable DSymSet rules (l, r, s) =
       let
           (* fresh variable indices start above every index used in r and s *)
           val fresh = ref (Term.maxVarIndexInTerms [r,s])
           fun cap (v as Var _) = v
             | cap (t as Fun (f,args,ty)) =
               if FS.member (DSymSet, f) andalso not (Term.isASubterm t l)
               then (fresh := !fresh + 1;
                     Var ((Atom.atom "x", !fresh), ty))
               else Fun (f, L.map cap args, ty)
           val [r',s'] = Subst.renameTerms [cap r, s]
       in
           case Subst.unify r' s' of
               SOME sigma => Rewrite.isNormalForm rules (Subst.applySubst sigma s')
             | NONE => false
       end

   (* innermostDependencyGraph dSymSet rules dps
    * Adjacency array of the estimated innermost dependency graph: entry i
    * holds, in ascending order, every position j such that
    * isInnermostConnectable dSymSet rules accepts
    * (lhs of pair i, rhs of pair i, lhs of pair j). *)
   fun innermostDependencyGraph dSymSet rules dps =
       let
           (* left-hand sides paired with their position in dps *)
           val targets = LP.zip (L.tabulate (length dps, fn i => i),
                                 L.map (fn (s,_) => s) dps)
           fun successors (l,r) =
               L.mapPartial
                   (fn (n,s) => if isInnermostConnectable dSymSet rules (l,r,s)
                                then SOME n
                                else NONE)
                   targets
       in
           Array.fromList (L.map successors dps)
       end


(* (\*** usable rules [Arts & Giesl, TCS, 2000] ***\) *)
(*    local *)
(*        (\*** extraction of the root symbol ***\) *)
(*        fun rootSymbol (Term.Var ((x,_),_)) = x *)
(* 	 | rootSymbol (Term.Fun (f,_,_)) = f *)

(*        (\*** extraction of the root of the left-hand side ***\) *)
(*        fun leftHandSide (x,_) = x *)
(*        fun rightHandSide (_,y) = y *)

(*        (\*** list and set of the pairs selected by a given left-hand-side root ***\) *)
(*        fun isDefinedBy f tp = Fun.equal (rootSymbol (leftHandSide tp), f) *)
(*        fun termPairListDefinedBy f tps = List.filter (isDefinedBy f) tps *)
(*        fun termPairSetDefinedBy f tps = TPS.filter (isDefinedBy f) tps *)

(*        fun listToTermPairSet xs = List.foldl TPS.add' TPS.empty xs *)

(*    in *)
(*    fun usableSubset (Term.Var _) ruleset = TPS.empty *)
(*      | usableSubset (Term.Fun (f,ts,_)) ruleset = *)
(*        let val uruleset = termPairSetDefinedBy f ruleset *)
(* 	   val druleset = TPS.difference (ruleset, uruleset) *)
(*        in  TPS.union (uruleset, *)
(* 		      TPS.union (TPS.foldl (fn (x,xs) => TPS.union (usableSubset (rightHandSide x) druleset, xs)) TPS.empty uruleset, *)
(* 				 foldl (fn (t,xs) => TPS.union (usableSubset t druleset, xs)) TPS.empty ts *)
(* 				) *)
(* 		     ) *)
(*        end *)
(*    fun usableSubsetFunSet term rules = *)
(*       Trs.funSetInRules (TPS.listItems (usableSubset term (listToTermPairSet rules))) *)

(*    end *)


   (* getUsableFunMap trs
    * Reads the Rules and DSymSet fields of trs.  Each rule contributes an
    * edge from the root of its lhs to every symbol of DSymSet occurring in
    * its rhs; the resulting graph over DSymSet is closed with
    * Graph.reflexiveTransitiveClosure and returned as a map from each
    * symbol of DSymSet to the set of symbols in its closure entry.
    *
    * Prints a message and raises DpError if a rule has a variable lhs.
    * Raises Option if a symbol of DSymSet is not the lhs root of any rule,
    * or if an lhs root with DSymSet symbols in its rhs is not in DSymSet. *)
   fun getUsableFunMap (trs:Trs.trs) =
       let
           val rules = #Rules trs
           val defined = #DSymSet trs

           (* root symbol f of the lhs, paired with the set of defined
              function symbols occurring in the rhs *)
           fun headAndCallees (Term.Fun (f,_,_), r) =
               (f, FS.intersection (Term.funSetInTerm r, defined))
             | headAndCallees _ =
               (print "getUsableFunMap: variable root lhs?\n";
                raise DpError)

           (* collect the per-rule entries so that every root maps to the
              union of the sets of all its rules *)
           fun merge ((f,fset), fm) =
               case FM.find (fm,f) of
                   SOME gset => FM.insert (fm, f, FS.union (fset,gset))
                 | NONE => FM.insert (fm, f, fset)
           fun registrate entries = L.foldl merge FM.empty entries

           (* bijection between the symbols of DSymSet and 0 .. n-1 *)
           val symbols = FS.listItems defined
           val nSymbols = List.length symbols
           val numbered = LP.zip (symbols, L.tabulate (nSymbols, fn i => i))
           val toIndex = L.foldr (fn ((a,i), m) => FM.insert (m,a,i))
                                 FM.empty numbered
           val toSymbol = L.foldr (fn ((a,i), m) => IM.insert (m,i,a))
                                  IM.empty numbered
           fun indexOf f = valOf (FM.find (toIndex, f))
           fun symbolAt i = valOf (IM.find (toSymbol, i))

           (* close the direct-dependency map fm and translate the node
              numbers of the closed graph back to symbols *)
           fun closure fm =
               let val adjacency = Array.array (nSymbols, [])
                   fun fill f =
                       Array.update (adjacency,
                                     indexOf f,
                                     L.map indexOf
                                           (FS.listItems (valOf (FM.find (fm,f)))))
                   val _ = FS.app fill defined
                   val closed = Graph.reflexiveTransitiveClosure adjacency
                   fun toSet ns =
                       L.foldr (fn (j, set) => FS.add (set, symbolAt j))
                               FS.empty ns
               in
                   Array.foldri
                       (fn (i, ns, acc) => FM.insert (acc, symbolAt i, toSet ns))
                       FM.empty
                       closed
               end
       in
           closure (registrate (L.map headAndCallees rules))
       end
	   

   (* usableFunSet fm fset
    * Union of the sets that fm associates with the members of fset.
    * Members of fset without an entry in fm contribute nothing. *)
   fun usableFunSet fm fset =
       let fun collect (f, acc) =
               case FM.find (fm, f) of
                   NONE => acc
                 | SOME gset => FS.union (gset, acc)
       in
           FS.foldr collect FS.empty fset
       end

   (* getUsableRules usableFM rules dps
    * Returns, in ascending order, the 0-based positions in rules of the
    * rules whose lhs root belongs to
    *   usableFunSet usableFM (function symbols of the rhs of dps).
    *
    * Prints a message and raises DpError when a rule with a variable lhs
    * is reached.
    *
    * The rule list is traversed once with a running index; looking each
    * rule up by position with L.nth would cost time quadratic in the number
    * of rules. *)
   fun getUsableRules usableFM rules dps =
       let val usableFS = usableFunSet usableFM
                                       (Term.funSetInTerms (L.map (fn (_,r) => r) dps))

           fun isUsable (Term.Fun (f,_,_), _) = FS.member (usableFS, f)
             | isUsable _ = (print "getUsableRules: variable lhs root?\n";
                             raise DpError)

           (* acc holds the selected positions in reverse order *)
           fun scan (_, [], acc) = rev acc
             | scan (i, rule::rest, acc) =
               scan (i + 1, rest, if isUsable rule then i :: acc else acc)
       in
           scan (0, rules, [])
       end


   (* applyAfToTerm piMap colSet t
    * Applies an argument filtering to t.  Variables are returned unchanged.
    * For f(ts), piMap must contain the list ns of argument positions of f
    * (0-based, used with L.nth):
    *   - if f is in colSet, the result is the filtered argument at the
    *     first position of ns;
    *   - otherwise the result is f applied to the filtered arguments at the
    *     positions ns, in that order.
    * Prints a message and raises DpError when f has no entry in piMap. *)
   fun applyAfToTerm piMap colSet t =
       case t of
           Var _ => t
         | Fun (f,ts,ty) =>
           let fun filteredArg i = applyAfToTerm piMap colSet (L.nth (ts,i))
           in
               case FM.find (piMap, f) of
                   NONE => (print ("applyAfToTerm: (" ^ (Term.toString t) ^ "\n");
                            raise DpError)
                 | SOME ns =>
                   if FS.member (colSet, f)
                   then filteredArg (hd ns)
                   else Fun (f, L.map filteredArg ns, ty)
           end


   end (* of local *)

end (* of structure *)


