(******************************************************************************
 * Copyright (c) 2012-2014, Toyama&Aoto Laboratory, Tohoku University
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 * 
 *  1. Redistributions of source code must retain the above copyright notice, 
 *     this list of conditions and the following disclaimer.
 *  2. Redistributions in binary form must reproduce the above copyright 
 *     notice, this list of conditions and the following disclaimer in the 
 *     documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
 * POSSIBILITY OF SUCH DAMAGE.
 ******************************************************************************)
(******************************************************************************
 * file: rwtools/rwchecker/cr_direct.sml
 * description: confluence check by direct sum decomposition
 * author: AOTO Takahito
 * 
 ******************************************************************************)

(* Confluence checks by decomposing a TRS (lists of rules (lhs, rhs)):
   direct-sum decomposition by disjoint function symbols, and persistent
   decomposition by an inferred many-sorted typing. *)
signature CR_DIRECT  = 
   sig
   (* Groups the rules into components whose function-symbol sets are
      pairwise disjoint; each component is returned with its symbol set. *)
   val disjointComponents: (Term.term * Term.term) list 
			   -> ((Term.term * Term.term) list * FunSet.set) list		     

   (* Checks every direct-sum component with Cr.checkConfluenceConditions,
      forwarding the given predicate to it (the implementation names it
      isTerminating).  CR if all components are CR, else NotCR if some
      component is NotCR, else Unknown; Unknown also when there is only one
      component. *)
   val proveConfluenceByDirectSumDecomposition:
	   ((Term.term * Term.term) list -> bool) 
	   -> (Term.term * Term.term) list 
	   ->  Cr.ConfluenceResult

   (* tryDirectSumDecomposition rs (direct, decomp): runs direct on the
      direct-sum components (CR if all are CR, NotCR if one is NotCR,
      Unknown otherwise).  If rs forms a single component, Unknown is
      reported and the result of decomp rs is returned instead. *)
   val tryDirectSumDecomposition:
       (Term.term * Term.term) list
       -> ((Term.term * Term.term) list -> Cr.ConfluenceResult) 
	  * ((Term.term * Term.term) list -> Cr.ConfluenceResult) 
       -> Cr.ConfluenceResult

   (* Renames the rules (Trs.renameRules) and assigns fresh integer type
      indices.  Returns (variable env, function env of (f, argument types,
      result type), constraints (term, type) tying both sides of each rule
      to a common type, renamed rules). *)
   val makeInitialEnv: (Term.term * Term.term) list 
          -> (VarSet.item * int) list * (Fun.ord_key * int list * int) list 
	     * (Term.term * int) list * (Term.term * Term.term) list

   (* Solves constraints (t, n) -- "the root type of t is n" -- by merging
      type indices, and returns the updated variable and function envs. *)
   val inferTypesFromTermConstraints:
          (VarSet.item * int) list * (Fun.ord_key * int list * int) list 
	  -> (Term.term * int) list
          -> (VarSet.item * int) list * (Fun.ord_key * int list * int) list 

   (* Type index of a variable.  Prints a message and raises the
      structure-internal exception CrDirectError if the variable is absent. *)
   val lookupVarEnv:  (VarSet.item * int) list -> VarSet.item -> int

   (* Type index of the root of a term: the variable's type, or the result
      type of the head function symbol. *)
   val typesOfTerm: (VarSet.item * int) list * (Fun.ord_key * int list * int) list 
		    -> Term.term -> int

   (* One component per maximal type set: the set together with every rule
      whose lhs root type is in it (a rule may occur in several components). *)
   val persistentComponents: (Term.term * Term.term) list 
			     -> (IntSet.set * (Term.term * Term.term) list) list

   (* Same protocol as tryDirectSumDecomposition, over persistentComponents. *)
   val tryPersistentDecomposition:
       (Term.term * Term.term) list
       -> ((Term.term * Term.term) list -> Cr.ConfluenceResult) 
	  * ((Term.term * Term.term) list -> Cr.ConfluenceResult) 
       -> Cr.ConfluenceResult
       
   (* Sufficient confluence test for non-left-linear TRSs based on the
      inferred sorts; true means confluent, false means not shown.
      NOTE: the critical-pair argument is currently ignored by the
      implementation, which recomputes critical pairs itself. *)
   val isConfluentWeakLeftLinearSystem:
	   (bool -> (Term.term * Term.term) list -> bool)
       -> (Term.term * Term.term) list * (Term.term * Term.term) list 
       -> (Term.term * Term.term) list 
       -> bool

end;


structure CrDirect : CR_DIRECT = 
   struct

   local 
       open Term
       open Trs
       open Rewrite
       open Subst
       open Cr
       structure VS = VarSet
       structure VM = VarMap
       structure FS = FunSet
       structure FM = FunMap
       structure SS = SortSet
       structure FIS = FunIntSet
       structure IS = IntSet
       structure L = List
       structure LP = ListPair
       structure LU = ListUtil
       structure TS = TermSet

       (* Concatenates the lists f x for x in xs.  The tail is computed before
          f x, so f is invoked on the elements from right to left (callers
          rely on this for side-effect ordering). *)
       fun mapAppend f [] = []
	 | mapAppend f (x::xs) =
	   let val rest = mapAppend f xs
	   in List.@ (f x, rest)
	   end

       (* xs without its n-th element (0-based); unchanged if n is out of range *)
       fun exceptNth _ [] = []
	 | exceptNth 0 (_::xs) = xs
	 | exceptNth n (x::xs) = x :: exceptNth (n-1) xs

       (* membership by structural equality *)
       fun member x ys = List.exists (fn y => x = y) ys

       (* adds each element of xs not yet present, consing onto ys *)
       fun union xs ys =
	   List.foldl (fn (x, acc) => if member x acc then acc else x :: acc) ys xs

       (* true iff some element of ys also occurs in xs *)
       fun notDisjoint xs ys = not (List.all (fn y => not (member y xs)) ys)

       (* "{f,g,...}" followed by a newline *)
       fun prFunSet set =
	   let val body = PrintUtil.prSeq Fun.toString (FS.listItems set)
	   in "{" ^ body ^ "}\n"
	   end

       (* elements of xs that do not occur in ys, in their original order *)
       fun delete xs ys = List.filter (fn x => not (member x ys)) xs

       (* "{1,2,...}" without a trailing newline *)
       fun prIntSet set =
	   let val body = PrintUtil.prSeq Int.toString (IS.listItems set)
	   in "{" ^ body ^ "}"
	   end

       open PrintUtil

   in

   (* debugging switch, off by default *)
   val runDebug : bool ref = ref false
   (* runs the thunk only while runDebug is set *)
   fun debug f = if not (!runDebug) then () else f ()
   (* raised (after a diagnostic print) when a lookup or sort attachment fails *)
   exception CrDirectError;


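  (* Split the rewrite rules rs = [r1,...,rn] into direct-sum components,
     e.g. [([r1,r2],funSet1),([r4],funSet2),([r3,r5],funSet3)],
     where funSetI collects the function symbols of component I and
     funSetI \cap funSetJ = \emptyset for all I \neq J. *)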
  
   fun disjointComponents rs =
       let
	   (* does a component share at least one function symbol with syms? *)
	   fun sharesSymbol syms (_,syms') =
	       not (FunSet.isEmpty (FunSet.intersection (syms,syms')))
	   (* folds one component into an accumulated component *)
	   fun absorb ((rules,syms),(rulesAcc,symsAcc)) =
	       (rules @ rulesAcc, FunSet.union (syms,symsAcc))
	   (* Take the first pending component.  If it touches none of the
	      others it is final; otherwise merge everything it touches into it
	      and put the merged component back at the front of the queue. *)
	   fun loop done [] = done
	     | loop done ((comp as (_,syms))::rest) =
	       (case L.partition (sharesSymbol syms) rest of
		    ([], _) => loop (comp::done) rest
		  | (touching, apart) =>
		    loop done (L.foldr absorb comp touching :: apart))
       in
	   (* start from one singleton component per rule *)
	   loop [] (map (fn rule => ([rule], funSetInRule rule)) rs)
       end


  (* Confluence check by direct sum decomposition: each component is checked separately *)
   fun proveConfluenceByDirectSumDecomposition isTerminating rs =
       let
	   val _ = print "\nTry Direct Sum Decomposition...\n"
	   val components = disjointComponents rs
	   (* print a component and run the generic confluence check on it *)
	   fun checkComponent (rules, syms) =
	       (print (prFunSet syms);
		print (Trs.prRules rules);
		checkConfluenceConditions isTerminating rules)
	   (* CR only if every component is CR; NotCR if any component is NotCR *)
	   fun combine results =
	       if L.all (fn c => c = CR) results then CR
	       else if L.exists (fn c => c = NotCR) results then NotCR
	       else Unknown
       in
	   case components of
	       [_] => (print "Direct Sum Decomposition failed"; report Unknown)
	     | _ => let val results = L.map checkComponent components
		    in print "Proof by Direct Sum Decomposition";
		       report (combine results)
		    end
       end

   (* Runs direct on each direct-sum component until the answer is settled:
      CR if every component is CR, NotCR as soon as one is NotCR, Unknown
      otherwise.  With a single component, Unknown is reported and decomp is
      applied to the whole rule list instead. *)
   fun tryDirectSumDecomposition rs (direct,decomp) =
       let
	   val _ = print "\nTry Direct Sum Decomposition for...\n"
	   val _ = print (Trs.prRules rs)
	   val components = disjointComponents rs
	   (* announce a component, then prove it *)
	   fun runOn (rules, syms) =
	       (print (prFunSet syms);
		print "(ds)";
		direct rules)
	   (* once one component is inconclusive only NotCR can still decide *)
	   fun searchNotCR [] = Unknown
	     | searchNotCR (c::cs) =
	       (case runOn c of
		    NotCR => NotCR
		  | _ => searchNotCR cs)
	   fun proveAll [] = CR
	     | proveAll (c::cs) =
	       (case runOn c of
		    CR => proveAll cs
		  | NotCR => NotCR
		  | Unknown => searchNotCR cs)
       in
	   case components of
	       [_] => (print "Direct Sum Decomposition failed";
		       report Unknown;
		       decomp rs)
	     | _ => let val result = proveAll components
		    in print "Result by Direct Sum Decomposition";
		       report result
		    end
       end


   (* Renames the rules and hands out fresh type indices (1, 2, ...) in this
      order: one per variable, then the argument and result types of every
      (symbol, arity) pair, then one shared type per rule for its two sides.
      Returns (varEnv, funEnv, constraints, renamed rules). *)
   fun makeInitialEnv rs =
       let
	   val renamed = Trs.renameRules rs
	   val vars = Trs.varSetInRules renamed
	   val funArities = Trs.funAritySetInRules renamed
	   val counter = ref 0
	   (* next unused type index *)
	   fun fresh () = (counter := !counter + 1; !counter)
	   (* argument types left to right, then the result type *)
	   fun typingFun (f,n) =
	       let val argTys = L.tabulate (n, fn _ => fresh ())
		   val retTy = fresh ()
	       in (f, argTys, retTy)
	       end
	   (* lhs and rhs of a rule must share one type *)
	   fun typingRules (l,r) =
	       let val ty = fresh ()
	       in [(l,ty), (r,ty)]
	       end
	   val varEnv = L.map (fn v => (v, fresh ())) (VarSet.listItems vars)
	   val funEnv = L.map typingFun (FunIntSet.listItems funArities)
	   val constraints = mapAppend typingRules renamed
       in
	   (varEnv, funEnv, constraints, renamed)
       end

   (* type index recorded for variable x; prints and raises CrDirectError if absent *)
   fun lookupVarEnv varEnv x =
       let fun sameVar (y,_) = Var.equal (x,y)
       in case List.find sameVar varEnv of
	      NONE => (print "lookupVarEnv fails\n";
		       raise CrDirectError)
	    | SOME (_,ty) => ty
       end

   (* (argument types, result type) of symbol f; prints and raises CrDirectError if absent *)
   fun lookupFunEnv funEnv f =
       let fun sameFun (g,_,_) = Fun.equal (f,g)
       in case List.find sameFun funEnv of
	      NONE => (print "lookupFunEnv fails\n";
		       raise CrDirectError)
	    | SOME (_,argTys,retTy) => (argTys,retTy)
       end

   (* type of the root of t: a variable's own type or its head symbol's result type *)
   fun typesOfTerm (varEnv,funEnv) t =
       case t of
	   Var (x,_) => lookupVarEnv varEnv x
	 | Fun (f,_,_) => let val (_,retTy) = lookupFunEnv funEnv f
			  in retTy
			  end

   (* substitutes dst for src, leaving every other value alone *)
   fun replace (src,dst) k = if k <> src then k else dst

   (* renames type index m to n throughout both environments *)
   fun updateEnv (m,n) (varEnv,funEnv) =
       let val r = replace (m,n)
	   fun onVar (v,k) = (v, r k)
	   fun onFun (f,ks,k) = (f, L.map r ks, r k)
       in
	   (L.map onVar varEnv, L.map onFun funEnv)
       end

   (* renames type index m to n in every pending (term, type) constraint *)
   fun updateTermConstraints (m,n) ts =
       let val r = replace (m,n)
       in L.map (fn (t,k) => (t, r k)) ts
       end

   (* Solves constraints (t, n), "the root type of t is n", one at a time:
      the current root type m of t is renamed to n everywhere.  For a
      function application, the arguments' constraints against the symbol's
      argument types are queued in front of the remaining ones. *)
   fun inferTypesFromTermConstraints env [] = env
     | inferTypesFromTermConstraints (env as (varEnv,funEnv)) ((t,n)::rest) =
       let
	   val (m, pending) =
	       case t of
		   Var (v,_) => (lookupVarEnv varEnv v, rest)
		 | Fun (f,ts,_) =>
		   let val (argTys, retTy) = lookupFunEnv funEnv f
		   in (retTy, LP.zip (ts,argTys) @ rest)
		   end
       in
	   inferTypesFromTermConstraints
	       (updateEnv (m,n) env)
	       (updateTermConstraints (m,n) pending)
       end


   (* Computes the maximal type sets of a function environment.
      Every type index occurring in funEnv becomes a graph node; an entry
      (f, ms, m) adds edges from the node of its result type m to the nodes
      of its argument types ms.  After Graph.reflexiveTransitiveClosure each
      node yields the set of nodes reachable from it -- roughly, the types
      that can occur below a term of that type (assumes the Graph module
      takes and returns adjacency-list arrays -- TODO confirm).  The sets not
      contained in any other set are returned as sets of type indices and
      printed as a side effect.  Prints and raises CrDirectError if a type
      index cannot be located (should not happen). *)
   fun getMaximalTypes funEnv = 
       let val initSet = L.foldl (fn ((_,ms,m),set) => IS.addList (set, m::ms))
				 IS.empty
				 funEnv
(*	   val _ = print (prIntSet initSet) *)
	   (* node i of the graph stands for the i-th type index in inits *)
	   val inits = IS.listItems initSet
	   (* NOTE(review): idxes is never used *)
	   val idxes = L.tabulate (length inits,fn i=>i)
	   (* node position of type index n (linear search) *)
	   fun indexOf n = let fun indexOfSub [] _ = (print ("getMaximalType failed\n" ^ 
							       (Int.toString n));
						      raise CrDirectError)
				 | indexOfSub (m::ms) i = if n = m 
							    then i
							    else indexOfSub ms (i+1)
			   in indexOfSub inits 0 
			   end
	   (* adjacency lists: result-type node -> argument-type nodes *)
	   val gInit = Array.array (length inits, [])
	   val _  = L.app (fn (_,ms,m) => 
			      let val ys = Array.sub (gInit,indexOf m)
			      in Array.update (gInit,
					       indexOf m,
					       union (L.map indexOf ms) ys)
			      end)
			  funEnv
(*	   val _ = print (Graph.toString gInit) *)
	   val gTrans = Graph.reflexiveTransitiveClosure gInit
(*	   val _ = print (Graph.toString gTrans) *)
	   (* one reachable set per node; Array.foldl conses, so the list is in
	      reverse node order *)
	   val all = Array.foldl (fn (ms,nss) => (IS.addList(IS.empty,ms))::nss)
				 [] gTrans
	   (* keeps the sets not included in any other set; among equal sets the
	      last occurrence is the one kept *)
	   fun leaveMaximals xss = 
	       let fun leaveMaximalsSub [] zss = zss
		     | leaveMaximalsSub (xset::yss) zss = 
		       if L.exists (fn yset => IS.isSubset (xset,yset)) yss
			  orelse L.exists (fn zset => IS.isSubset (xset,zset)) zss
		       then leaveMaximalsSub yss zss
		       else leaveMaximalsSub 
			    (L.filter (fn yset => not (IS.isSubset (yset,xset))) yss)
			    (xset::zss)
	       in leaveMaximalsSub xss [] 
	       end
	   (* translate node positions back to type indices *)
	   val maximals = L.map (fn set => IS.map (fn i => L.nth (inits,i)) set)
				(leaveMaximals all)

	   val _ = print "maximal types: "
	   val _ = L.app (fn mset => print (prIntSet mset)) maximals 
	   val _ = print "\n"

       in
	   maximals
       end



   (* Infers a type index for every symbol of rs and prints the assignment.
      It then forms one component per maximal type set from getMaximalTypes.
      Each component pairs that set with every rule of rs whose lhs root type
      lies in it, so a rule may appear in several components. *)
   fun persistentComponents rs =
       let
	   val (varEnv0,funEnv0,constraints,_) = makeInitialEnv rs
	   val env as (_,funEnv) = inferTypesFromTermConstraints (varEnv0,funEnv0) constraints
	   fun showSort (f,argTys,retTy) =
	       print (" " ^ Fun.toString f ^ " : " ^ PrintUtil.prProd Int.toString argTys
		      ^ "=>" ^ Int.toString retTy ^ "\n")
	   val _ = print "Sort Assignment:\n"
	   val _ = L.app showSort funEnv
	   val maximals = getMaximalTypes funEnv
	   fun inComponent tySet (l,_) = IS.member (tySet, typesOfTerm env l)
       in
	   L.map (fn tySet => (tySet, L.filter (inComponent tySet) rs)) maximals
       end

   (* Runs direct on each persistent component until the answer is settled:
      CR if every component is CR, NotCR as soon as one is NotCR, Unknown
      otherwise.  With a single component, Unknown is reported and decomp is
      applied to the whole rule list instead. *)
   fun tryPersistentDecomposition rs (direct,decomp) =
       let
	   val _ = print "\nTry Persistent Decomposition for...\n"
	   val _ = print (Trs.prRules rs)
	   val components = persistentComponents rs
	   (* announce a component by its type set, then prove it *)
	   fun runOn (tySet, rules) =
	       (print (prIntSet tySet ^ "\n");
		print "(ps)";
		direct rules)
	   (* once one component is inconclusive only NotCR can still decide *)
	   fun searchNotCR [] = Unknown
	     | searchNotCR (c::cs) =
	       (case runOn c of
		    NotCR => NotCR
		  | _ => searchNotCR cs)
	   fun proveAll [] = CR
	     | proveAll (c::cs) =
	       (case runOn c of
		    CR => proveAll cs
		  | NotCR => NotCR
		  | Unknown => searchNotCR cs)
       in
	   case components of
	       [_] => (print "Persistent Decomposition failed";
		       report Unknown;
		       decomp rs)
	     | _ => let val result = proveAll components
		    in print "Result by Persistent Decomposition";
		       report result
		    end
       end

   (* Not referenced anywhere in this structure and not exported through
      CR_DIRECT, so it currently has no effect. *)
   val maxRewriteLenForParallelClosedCheck : int ref = ref 2

   (* Sufficient confluence test for TRSs that are not left-linear, based on
      a sort assignment inferred from the rules.  Assumes rs is not
      left-linear; otherwise this check is not useful.

      isTerminating -- called as (isTerminating true rules) on the rules
                       applicable to terms of non-linear types; the messages
                       below treat it as an innermost-termination check.
      (_, _)        -- critical pairs supplied by the caller.  NOTE: they are
                       not used; inside and outside critical pairs are
                       recomputed below from the sort-attached rules.
      rs            -- the rewrite rules.

      Steps:
        1. Rename rs and infer a type index per symbol.
        2. Collect the types of non-linear lhs variables.
        3. Close that set under "argument type of a symbol whose result type
           is in the set".
        4. Keep the rules whose lhs root type lies in the closure.
      If those rules pass isTerminating, rs is accepted when either
        - it is non-overlapping, or
        - every inside critical pair <p,q> has q among the parallel one-step
          reducts of p, and every outside critical pair has a common parallel
          one-step reduct, in both cases via substitutions accepted by
          groundInstPreserving.

      Returns true on success.  false does not establish non-confluence; it
      is also returned when termination could not be shown.  Raises
      CrDirectError if a type lookup or the sort attachment fails. *)
   fun isConfluentWeakLeftLinearSystem isTerminating (_,_) rs = 
       let val (varEnv,funEnv,cs,rs') = makeInitialEnv rs
	   val _ = print (Trs.prRules rs')
	   val (varEnv',funEnv') = inferTypesFromTermConstraints (varEnv,funEnv) cs
	   val _ = print "Sort Assignment:\n"
	   val _ = L.app (fn (f,ms,m) => print (" " ^ (Fun.toString f)
						^ " : "
						^ (PrintUtil.prProd Int.toString ms)
						^ "=>"
						^ (Int.toString m) ^ "\n"))
			 funEnv'

	   (* variables occurring more than once in some left-hand side *)
	   val nonLinearVars = LU.mapAppend (fn (l,_) => Term.nonLinearVarListInTerm l) rs'
	   val _ = print "non-linear variables: "
	   val _ = print ((LU.toStringCommaCurly Var.toString nonLinearVars) ^ "\n")

	   val nonLinearTypes = LU.eliminateDuplication (L.map (lookupVarEnv varEnv') nonLinearVars)
	   val _ = print "non-linear types: "
	   val _ = print ((LU.toStringCommaCurly Int.toString nonLinearTypes) ^ "\n")

	   (* closure of types under "argument type of a function symbol whose
	      result type is already in the set" *)
	   fun getLeqTypes funEnv' types = 
	       let fun addTypes tmp = LU.mapAppend (fn (f,args,ret) => if LU.member ret tmp then args else []) 
						   funEnv'
		   fun getClosure tmp = let val added = addTypes tmp
					in if L.all (fn ty => LU.member ty tmp) added
					   then tmp
					   else getClosure (LU.union (added, tmp))
					end
	       in getClosure types 
	       end

	   val typesLeqNonLinearTypes = getLeqTypes funEnv' nonLinearTypes
	   val _ = print "types leq non-linear types: "
	   (* report the closure computed just above, which is what the label names *)
	   val _ = print ((LU.toStringCommaCurly Int.toString typesLeqNonLinearTypes) ^ "\n")

	   (* rules whose lhs root type lies in that closure *)
	   val rulesApplicableToNonLinearTypes = 
	       L.filter (fn (l,r) => LU.member (typesOfTerm (varEnv',funEnv') l) typesLeqNonLinearTypes) rs'

	   val _ = print "rules applicable to terms of non-linear types:\n"
	   val _ = print (Trs.prRules rulesApplicableToNonLinearTypes)

	   (* the inferred type indices are turned into sort names *)
	   fun makeSort (args,ty) = if (null args) 
				    then (Sort.fromString (Int.toString ty))
				    else Sort.Proc (L.map (fn sy => (Sort.fromString (Int.toString sy))) args,
						    Sort.fromString (Int.toString ty))
	   fun makeDecl (f,args,ty) = {sym=f, sort=makeSort (args,ty)}
	   val decls = L.map makeDecl funEnv' 
	   (* NOTE(review): sorts are attached to rs rather than the renamed rs';
	      this presumably relies on renaming touching only variables -- confirm *)
	   val sortedRules = case Trs.attachSortToRules decls rs of
				 SOME ss => ss
			       | NONE => (print "isConfluentWeakLeftLinearSystem\n";
					  raise CrDirectError)

	   val nonLinearSorts = L.map (fn ty => Sort.fromString (Int.toString ty)) nonLinearTypes

	   (* critical pairs are recomputed here from the sort-attached rules *)
	   val inCps = Cr.insideCriticalPairs sortedRules
	   val outCps = Cr.outsideCriticalPairsInOneside sortedRules

	   (* A substitution is accepted when none of its bindings x -> t has x
	      non-linear and t either containing a variable whose sort is not a
	      non-linear sort, or having a non-variable subterm unifiable with a
	      left-hand side of rs'.  Prints each verdict. *)
	   fun groundInstPreserving sublist =
	       let val lhs' = L.map (fn (l,r) => l) rs'
		   fun checkimage (x,t) = 
		       LU.member' Var.equal x nonLinearVars
		       andalso
		       (TS.foldr (fn (y,ans) => ans orelse not (LU.member' Sort.equal 
									  (Term.sortOfTerm y) nonLinearSorts))
				false (Term.varTermSetInTerm t) 
			orelse
			ListXProd.foldX (fn (l,u,ans) => ans orelse (isSome (Subst.unify l u)))
					(lhs', Term.nonVarSubterms t) false)

		   fun checksub sub =
		       (print ("is " ^ (Subst.toStringWithVarSort sub) ^ " ground inst. preserving? ");
			if VM.isEmpty (VM.filteri checkimage sub)
			then (print "(yes)\n";true)
			else (print "(no)\n"; false))
	       in L.all checksub sublist
	       end

	   (* inside CP <p,q>: q must be a parallel one-step reduct of p obtained
	      with ground-instance-preserving substitutions *)
	   fun checkInCps (p,q) = 
	       let val redsublist = parallelOneStepReductsWithSubst rs' p
		   val _ = print ("inner CP <p,q> = <" ^ (Term.toStringWithVarSort p) 
				  ^ ", " ^ (Term.toStringWithVarSort q) ^ ">\n")
		   val _ = print ("parallel reducts of p: " 
				  ^ (LU.toStringCommaCurly Term.toStringWithVarSort 
							   (L.map (fn (t,_) => t) redsublist)))
		   val _ = print "\n"
	       in L.exists (fn (t,sublist) => Term.equal (t,q) andalso groundInstPreserving sublist)
			   redsublist
	       end

	   (* outside CP <p,q>: p and q must have a common parallel one-step
	      reduct, both reached with ground-instance-preserving substitutions *)
	   fun checkOutCps (p,q) = 
	       let val predsublist = parallelOneStepReductsWithSubst rs' p
		   val qredsublist = parallelOneStepReductsWithSubst rs' q
		   val _ = print ("outer CP <p,q> = <" ^ (Term.toStringWithVarSort p) 
				  ^ ", " ^ (Term.toStringWithVarSort q) ^ ">\n")
		   val _ = print ("parallel reducts of p: " 
				  ^ (LU.toStringCommaCurly Term.toStringWithVarSort
							   (L.map (fn (t,_) => t) predsublist)))
		   val _ = print "\n"
		   val _ = print ("parallel reducts of q: " 
				  ^ (LU.toStringCommaCurly Term.toStringWithVarSort
							   (L.map (fn (t,_) => t) qredsublist)))
		   val _ = print "\n"
	       in ListXProd.foldX 
		  (fn ((x,xs),(y,ys),ans) => ans 
					     orelse (Term.equal (x,y) 
						     andalso (print ("join at " ^ (Term.toStringWithVarSort x) 
								     ^ "\n");true)
						     andalso (print ("check subst from p\n");groundInstPreserving xs)
						     andalso (print ("check subst from q\n");groundInstPreserving ys)))
		  (predsublist, qredsublist) false
	       end

       in if isTerminating true rulesApplicableToNonLinearTypes
	  then let val _ = print "terms of non-linear types are innermost terminating\n"
	       in if Cr.isNonOverlapping rs
		  then (print "Non-Overlapping\n"; true)
		  else if L.all checkInCps inCps
			  andalso
			  L.all checkOutCps outCps
		  then (print "Parallel-Closed\n"; true)
		  else false
	       end
	  else (print "unknown innermost-termination for terms of non-linear types\n"; false)
       end

   (* val _ = isConfluentWeakLeftLinearSystem (fn _ => (fn _ => true)) *)
   (* 	   (IOFotrs.rdRules [ "f(?x,?x) -> f(g(?x),?x)", *)
   (* 			     "f(g(?x),?x) -> f(h(?x),h(?x))", *)
   (* 			     "h(g(?x)) -> g(g(h(?x)))" ]) *)



  end
end
