(******************************************************************************
 * Copyright (c) 2017-2023, Aoto Laboratory, Niigata University
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 * 
 *  1. Redistributions of source code must retain the above copyright notice, 
 *     this list of conditions and the following disclaimer.
 *  2. Redistributions in binary form must reproduce the above copyright 
 *     notice, this list of conditions and the following disclaimer in the 
 *     documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
 * POSSIBILITY OF SUCH DAMAGE.
 ******************************************************************************)
(******************************************************************************
 * file: rwtools/term_rewriting/ctrs.sml
 * description: conditional term rewriting systems
 * author: AOTO Takahito
 * 
 ******************************************************************************)

(* Conditional term rewriting systems (CTRSs): printing, syntactic property
   checks, unraveling, linearizations, conditional rewriting on skolemized
   terms, constraint reducts and sort attachment. *)
signature CTRS = 
   sig
   (* flag consulted by the structure's internal debug helper *)
   val runDebug: bool ref
   (* conditional rule (l, r, [(s1,t1),...,(sn,tn)]); prRule renders it as
      "l -> r | s1 == t1, ..., sn == tn" *)
   type crule = Term.term * Term.term * (Term.term * Term.term) list
   type crules = crule list

   (* conditional equation; same shape as crule, rendered with " = " *)
   type ceq = Term.term * Term.term * (Term.term * Term.term) list
   type ceqs = ceq list

   (* pretty printers *)
   val prRule: crule -> string
   val prRules: crules -> string

   (* as above, but terms are printed with their variable sorts *)
   val prRuleWithVarSort: crule -> string
   val prRulesWithVarSort: crules -> string

   val prEq: ceq -> string
   val prEqs: ceqs -> string

   (* linearize the lhs of an unconditional rule with fresh variables; the
      equalities between fresh copies and original variables become conditions *)
   val condTOLinearizationRule: Trs.rule -> crule
   val condTOLinearizationRules: Trs.rules -> crules
   (* rename the variables of all rules in one batch (via Subst.renameTerms) *)
   val renameRules: crules -> crules

   (* translate a CTRS into an unconditional TRS using fresh symbols U0, U1, ... *)
   val unraveling: crules -> Trs.rules

   (* Var(r) is contained in Var(l) together with the variables of the conditions *)
   val ruleOfTypeIII: crule -> bool
   val rulesOfTypeIII: crules -> bool

   (* each condition lhs s_i only uses variables of l and of t_1, ..., t_{i-1} *)
   val isDeterministicRule: crule -> bool
   val isDeterministic: crules -> bool

   (* no non-variable subterm of a condition rhs unifies with an lhs
      (after shifting its variables apart from the lhss) *)
   val isAbsolutelyIrreducible: crules -> bool
   (* variables occurring more than once in l, t1..tn occur in none of r, s1..sn *)
   val isWeaklyLeftLinearRules: crules -> bool
   (* every rhs is a linear term *)
   val isRightLinearRules: crules -> bool

   val isRightStable: crules -> bool
   val isProperlyOrientedRules: crules -> bool

   (* replace variables by constants named "<x>"; the returned
      (constant, variable) pairs let unSkolemize undo the replacement *)
   val skolemize: Term.term -> (Term.term * Term.term) list * Term.term
   val skolemizeTerms: Term.term list -> (Term.term * Term.term) list * Term.term list
   val unSkolemize: (Term.term * Term.term) list -> Term.term -> Term.term

   (* conditional rewriting; conditions are solved by a non-backtracking
      search on the skolemized term; linf rewrites until no step applies
      (this may not terminate) *)
   val rootStep: crules -> Term.term -> Term.term option
   val rewriteOneStep: crules -> Term.term -> Term.term option
   val linf: crules -> Term.term -> Term.term

   (* reducts under a constraint list; conditions are checked with Term.checkCC *)
   val constraintOneStepReducts: crules -> (Term.term * Term.term) list * Term.term -> Term.term list
   val constraintZeroOrOneStepReducts: crules -> (Term.term * Term.term) list * Term.term -> Term.term list
   val constraintParallelOneStepReducts: crules -> (Term.term * Term.term) list * Term.term -> Term.term list
   val constraintParallelTwoStepsReducts: crules -> (Term.term * Term.term) list * Term.term -> Term.term list

   (* Klop & de Vrijer conditional linearization *)
   val condKdVLinearizationRule: Trs.rule -> crule
   val condKdVLinearizationRules: Trs.rules -> crules

   (* as above, but enumerating every choice of fresh copies in the rhs *)
   val condKdVLinearizationRuleMult: Trs.rule -> crules
   val condKdVLinearizationRulesMult: Trs.rules -> crules list

   (* sort inference for conditional rules; NONE when it fails *)
   val attachSortToCondRule: Term.decl list ->  crule -> crule option
   val attachSortToCondRules: Term.decl list ->  crules -> crules option
end;

structure Ctrs : CTRS = 
struct

(* A conditional rule (l, r, cs): l -> r guarded by the pairs (s, t) in cs. *)
type crule = Term.term * Term.term * (Term.term * Term.term) list
type crules = crule list

(* A conditional equation; same representation as a conditional rule. *)
type ceq = Term.term * Term.term * (Term.term * Term.term) list
type ceqs = ceq list

(* Switch consulted by `debug`; starts out enabled. *)
val runDebug : bool ref = ref true
(* Run `thunk` for its effect, but only while runDebug is set. *)
fun debug thunk = if not (!runDebug) then () else thunk ()
exception CtrsError


local 
    structure FS = FunSet
    structure L = List
    structure LU = ListUtil
    structure LP = ListPair
    structure PU = PrintUtil
    structure S = Subst
    structure T = Term
    structure TS = TermSet
    structure VM = VarMap
    structure VS = VarSet
    open PrintUtil
in

(* Printing.  A condition (s, t) is rendered "s == t"; the ...WithVarSort
   variants print terms with toStringWithVarSort instead of toString. *)
fun prCond (s, t) = String.concat [T.toString s, " == ", T.toString t]
fun prCondWithVarSort (s, t) = String.concat [T.toStringWithVarSort s, " == ", T.toStringWithVarSort t]
fun prConds conds = LU.toStringComma prCond conds
fun prCondsWithVarSort conds = LU.toStringComma prCondWithVarSort conds

(* "l -> r", followed by " | <conditions>" only when conditions are present *)
fun prRule (l, r, conds) =
    let val head = String.concat [T.toString l, " -> ", T.toString r]
    in if null conds then head else head ^ " | " ^ prConds conds
    end

fun prRuleWithVarSort (l, r, conds) =
    let val head = String.concat [Term.toStringWithVarSort l, " -> ", Term.toStringWithVarSort r]
    in if null conds then head else head ^ " | " ^ prCondsWithVarSort conds
    end

(* rule lists are laid out by LU.toStringCommaLnSquare *)
fun prRules rules = LU.toStringCommaLnSquare prRule rules
fun prRulesWithVarSort rules = LU.toStringCommaLnSquare prRuleWithVarSort rules

(* equations print like rules, with " = " between the two sides *)
fun prEq (l, r, conds) =
    let val head = String.concat [T.toString l, " = ", T.toString r]
    in if null conds then head else head ^ " | " ^ prConds conds
    end

fun prEqs eqs = LU.toStringCommaLnSquare prEq eqs

(* Linearize an unconditional rule l -> r into a conditional rule.
   The lhs is linearized with fresh variables.  Fresh copies of variables
   listed by T.linearVarListInTerm l that do not occur in r are mapped back
   to the original variable; every other fresh copy x' of an original
   variable term x yields a condition (x', x).
   NOTE(review): the condition's fresh variable is built with Sort.null
   rather than the sort of x (condKdVLinearizationRule uses the sort) --
   confirm this is intended. *)
fun condTOLinearizationRule (l,r) =
    let val lhsVars = T.varListInTerm l
        val linearLhs = T.linearize l
        val freshLhs = Subst.renameTermDisjointFrom lhsVars linearLhs
        (* fresh variable |-> original variable term *)
        val renaming = valOf (Subst.match freshLhs l)
        (*** refine the definition of linearization ***)
        val cancel = LU.differenceByAll' Var.equal (T.linearVarListInTerm l, T.varListInTerm r)
        fun isCancelled (_, t) = LU.member' Var.equal (valOf (T.varRootOfTerm t)) cancel
        val cancelMap = VM.filteri isCancelled renaming
        val keepMap = VM.filteri (not o isCancelled) renaming
        val newLhs = Subst.applySubst cancelMap freshLhs
        val conds = L.map (fn (x, t) => (T.mkVarTerm (x, Sort.null), t)) (VM.listItemsi keepMap)
    in (newLhs, r, conds)
    end

(* rule-wise lifting of condTOLinearizationRule *)
fun condTOLinearizationRules rules = L.map condTOLinearizationRule rules

(* Rename the variables of all conditional rules together.  Each rule is
   packed into one term Dummy(Dummy(l,r), Dummy(s1..sn), Dummy(t1..tn)) so
   Subst.renameTerms can treat the whole list as a batch; the terms are then
   unpacked back into rules. *)
fun renameRules crules =
    let val dummy = Fun.fromString "Dummy"
        fun pack (l, r, cs) =
            let val (lhss, rhss) = LP.unzip cs
            in T.mkFunTerm (dummy,
                            [T.mkFunTerm (dummy, [l, r], Sort.null),
                             T.mkFunTerm (dummy, lhss, Sort.null),
                             T.mkFunTerm (dummy, rhss, Sort.null)],
                            Sort.null)
            end
        (* inverse of pack; any other shape is impossible by construction *)
        fun unpack t =
            case L.map T.argsOfTerm (T.argsOfTerm t) of
                [[l, r], lhss, rhss] => (l, r, LP.zip (lhss, rhss))
              | _ => raise Bind
    in L.map unpack (Subst.renameTerms (L.map pack crules))
    end


(* Fresh symbol "U<n>" introduced by the unraveling (n is never negative). *)
fun mkUsym n = Fun.fromString ("U" ^ Int.toString n)

(* Recognize exactly the symbols mkUsym produces: "U" followed by one or more
   decimal digits.  A bare prefix test on "U" would also accept user symbols
   such as "Union" or "Up"; the collapsing step of `unraveling` could then
   rename such user symbols and corrupt the resulting TRS. *)
fun isUsym f =
    let val name = Fun.toString f
    in String.size name > 1
       andalso String.sub (name, 0) = #"U"
       andalso L.all Char.isDigit (String.explode (String.extract (name, 1, NONE)))
    end

(* true iff t is a function application whose root is a U-symbol *)
fun isRootUterm t = case T.funRootOfTerm t of NONE => false | SOME f => isUsym f


(* Unravel one conditional rule l -> r <= s1 = t1, ..., sk = tk into
     l -> U_n(s1, X1),  U_n(t1, X1) -> U_{n+1}(s2, X2),  ...,  U_{n+k-1}(tk, Xk) -> r
   where X1 are the variable terms of l and X_{i+1} adds those of t_i.
   Uses the symbols U_n .. U_{n+k-1}; an unconditional rule is kept as is. *)
fun unravelingRule n (l,r,[]) = [(l, r)]
  | unravelingRule n (l,r,(u,v)::cs) =
    let val ty = T.sortOfTerm l
        val xs = TS.listItems (T.varTermSetInTerm l)
        fun uTerm (k, hd, vars) = T.Fun (mkUsym k, hd::vars, ty)
        (* prevRhs is the rhs t_i of the condition just opened by U_k *)
        fun chain k vars prevRhs [] = [(uTerm (k, prevRhs, vars), r)]
          | chain k vars prevRhs ((s,t)::rest) =
            let val vars' = LU.union' Term.equal (vars, TS.listItems (T.varTermSetInTerm prevRhs))
            in (uTerm (k, prevRhs, vars), uTerm (k+1, s, vars')) :: chain (k+1) vars' t rest
            end
    in (l, uTerm (n, u, xs)) :: chain n xs v cs
    end


(* Translate a CTRS into an unconditional TRS.
   Each rule is unraveled by unravelingRule; the U-symbol counter advances by
   the number of conditions of each rule, so different rules get disjoint
   U-symbols.  Variable names are then normalized, and while two rules share
   an lhs but rewrite to terms rooted by different U-symbols, those U-symbols
   are merged (collapseIte).  Duplicate rules are removed at the end. *)
fun unraveling rules = 
    let fun main n [] = []
	  | main n ((lrc as (_,_,cs))::rs) = (unravelingRule n lrc) @ (main (n + (L.length cs)) rs)
	(* sigma (the renaming) is only used by the commented-out debug print *)
	fun normalizeVarName (l,r) = let val (l',r',sigma) = Trs.normalizeVarNameRule (l,r)
					 (* val _ = println (Subst.toString sigma) *)
				     in (l',r') end
	val rules1 = main 0 rules
	(* val _ = print (Trs.prRules rules1) *)

	val rules2 = L.map normalizeVarName rules1
	(* val _ = print (Trs.prRules rules2) *)
	(* apply T.changeFunNameInTerm (f,g) to both sides of every rule
	   NOTE(review): direction of the renaming (f to g or g to f) is defined
	   by T.changeFunNameInTerm -- confirm there *)
	fun changeFun (f,g) rs = L.map (fn (l,r) => (T.changeFunNameInTerm (f,g) l, T.changeFunNameInTerm (f,g) r)) rs
	(* used to escape the nested L.app scan as soon as a pair is found *)
	exception Found of (Fun.ord_key * Fun.ord_key)
	(* find two rules with equal lhs whose different rhss are rooted by
	   different U-symbols.  Note the swap: Found (g,f) is raised, so the
	   handler's (f,g) is (root of r2, root of r1). *)
	fun collapseStep rs = let val _ = L.app (fn (l1,r1) => 
						       L.app (fn (l2,r2) => 
								 if Term.equal (l1,l2)
								    andalso not (Term.equal (r1,r2))
								    andalso isRootUterm r1 
								    andalso isRootUterm r2
								 then case (Term.funRootOfTerm r1, Term.funRootOfTerm r2)
								       of (SOME f, SOME g) => if Fun.equal (f,g)
											      then ()
											      else raise Found (g,f)
									| _ => ()
								 else ()) rs) rs
			      in NONE end
			      handle Found (f,g) => SOME (f,g)
	(* merge one pair of U-symbols, renormalize variable names, and repeat
	   until collapseStep finds nothing *)
	fun collapseIte rs = case collapseStep rs of
				  NONE => rs
				| SOME (f,g) => collapseIte (L.map normalizeVarName (changeFun (g,f) rs))
    in LU.eliminateDuplication' TermPair.equal (collapseIte rules2)
    end

(* Type III: every rhs variable occurs in the lhs or in some condition. *)
fun ruleOfTypeIII (l,r,c) =
    let val allowed = VS.union (T.varSetInTerm l, Trs.varSetInRules c)
    in VS.isSubset (T.varSetInTerm r, allowed)
    end

fun rulesOfTypeIII rs = not (L.exists (not o ruleOfTypeIII) rs)

(* Deterministic: scanning the conditions left to right, the variables of
   each s_i are already bound by l or by some earlier t_j. *)
fun isDeterministicRule (l,_,cs) =
    let (* NONE once a condition fails; otherwise SOME of the bound variables *)
        fun step (_, NONE) = NONE
          | step ((s,t), SOME bound) =
            if VS.isSubset (T.varSetInTerm s, bound)
            then SOME (VS.union (bound, T.varSetInTerm t))
            else NONE
    in isSome (L.foldl step (SOME (T.varSetInTerm l)) cs)
    end

fun isDeterministic rs = not (L.exists (not o isDeterministicRule) rs)

(* Every condition rhs t is absolutely irreducible: after shifting t's
   variable indices past those of all lhss, no non-variable subterm of t
   unifies with any lhs. *)
fun isAbsolutelyIrreducible rs =
    let val lhss = L.map (fn (l,_,_) => l) rs
        val shift = T.maxVarIndexInTerms lhss + 1
        fun overlapsLhs u = L.exists (fn l => isSome (Subst.unify l u)) lhss
        fun isAbsIrr t = not (L.exists overlapsLhs (T.nonVarSubterms (T.increaseVarIndexBy shift t)))
        val condRhss = LU.mapAppend (fn (_,_,c) => L.map (fn (_,t) => t) c) rs
    in L.all isAbsIrr condRhss
    end

(* Weakly left-linear: a variable occurring more than once among l and the
   condition rhss t1..tn must not occur in r or in the condition lhss s1..sn. *)
fun isWeaklyLeftLinearRule (l,r,c) =
    let val (condLhss, condRhss) = LP.unzip c
        val bundle = T.Fun (Fun.fromString "Dummy", l :: condRhss, Sort.null)
        val forbidden = T.varSetInTerms (r :: condLhss)
    in not (L.exists (fn x => VS.member (forbidden, x)) (T.nonLinearVarListInTerm bundle))
    end

fun isWeaklyLeftLinearRules rs = not (L.exists (not o isWeaklyLeftLinearRule) rs)

(* right-linearity only inspects the rhs, via T.isLinearTerm *)
fun isRightLinearRule (_,r,_) = T.isLinearTerm r
fun isRightLinearRules rs = not (L.exists (not o isRightLinearRule) rs)

(* Right-stability of a CTRS.  For every rule l -> r <= (s1,t1),...,(sk,tk):
   - each t_i is either a linear constructor term, or a ground term in
     normal form w.r.t. the underlying unconditional rules (l -> r);
   - Var(t_i) is disjoint from Var(l, s1, t1, ..., s_{i-1}, t_{i-1}, s_i).
   Constructors are the function symbols of the system that are not the
   root of any lhs. *)
fun isRightStable rs = 
    let (* defined symbols: roots of the left-hand sides (computed once) *)
        val definedSet = FS.addList (FS.empty, L.map (fn (l,_,_) => valOf (Term.funRootOfTerm l)) rs)
        (* every symbol of the system that is not defined *)
        val cfunSet = let val (us,vs) = LP.unzip (LU.mapAppend (fn (l,r,c) => (l,r)::c) rs)
                      in FS.difference (Term.funSetInTerms (us @ vs), definedSet)
                      end
        fun isConstructorTerm t = FS.isSubset (T.funSetInTerm t, cfunSet)
        val underlyingRules = L.map (fn (l,r,_) => (l,r)) rs
        fun isRuNormalForm t = Rewrite.isNormalForm underlyingRules t
        fun checkCondRhs (_,_,c) = L.all (fn (_,t) => (T.isLinearTerm t andalso isConstructorTerm t)
                                                     orelse (T.isGroundTerm t andalso isRuNormalForm t)) c
        (* vset: variables of l, s1, t1, ..., s_i; t: the current rhs t_i *)
        fun checkVarConditionSub vset t [] = VS.isEmpty (VS.intersection (vset, T.varSetInTerm t))
          | checkVarConditionSub vset t ((u,v)::rest) =
            VS.isEmpty (VS.intersection (vset, T.varSetInTerm t))
            andalso checkVarConditionSub (VS.union (vset, T.varSetInTerms [t,u])) v rest
        fun checkVarCondition (_,_,[]) = true
          | checkVarCondition (l,_,(u,v)::rest) = checkVarConditionSub (T.varSetInTerms [l,u]) v rest
    in L.all checkCondRhs rs
       andalso L.all checkVarCondition rs
    end

(* A rule is accepted when Var(r) is contained in Var(l), or when the
   conditions pass `check`.  At each condition `check` accepts if either
   - s is covered by the variables bound so far and the remaining conditions
     pass with t's variables added, or
   - every remaining condition only mentions rhs variables that are already
     bound. *)
fun isProperlyOrientedRule (l,r,cond) =
    let val rVars = T.varSetInTerm r
        val lVars = T.varSetInTerm l
        fun restHarmless bound conds =
            L.all (fn (u,v) => VS.isSubset (VS.intersection (rVars, T.varSetInTerms [u,v]), bound)) conds
        fun check _ [] = true
          | check bound (conds as (s,t)::rest) =
            if VS.isSubset (T.varSetInTerm s, bound)
               andalso check (VS.union (bound, T.varSetInTerm t)) rest
            then true
            else restHarmless bound conds
    in VS.isSubset (rVars, lVars) orelse check lVars cond
    end

fun isProperlyOrientedRules rs = not (L.exists (not o isProperlyOrientedRule) rs)

(* Shared worker for skolemize/skolemizeTerms (previously duplicated
   verbatim).  For the given variable terms it builds one constant per
   variable, named "<" ^ Term.toString x ^ ">" and carrying x's sort, and
   returns
     - the (constant, variable) pairs consumed by unSkolemize, and
     - the substitution replacing each variable by its constant. *)
fun skolemConstsOfVars vterms =
    let val consts = L.map (fn t => T.Fun (Fun.fromString ("<" ^ Term.toString t ^ ">"), [], T.sortOfTerm t)) vterms
        val sigma = L.foldl (fn ((v,c),map) => VM.insert (map, valOf (T.varRootOfTerm v), c))
                            VM.empty (LP.zip (vterms, consts))
    in (LP.zip (consts, vterms), sigma)
    end

(* Replace every variable of `term` by its Skolem constant. *)
fun skolemize term =
    let val (cvPairs, sigma) = skolemConstsOfVars (TS.listItems (T.varTermSetInTerm term))
    in (cvPairs, Subst.applySubst sigma term)
    end

(* As skolemize, for a list of terms sharing one variable-to-constant mapping. *)
fun skolemizeTerms terms =
    let val (cvPairs, sigma) = skolemConstsOfVars (TS.listItems (T.varTermSetInTerms terms))
    in (cvPairs, L.map (Subst.applySubst sigma) terms)
    end

(* Symbols whose name looks like "<?...>", the shape skolemize produces
   (assuming variables print with a leading "?" -- TODO confirm in Term). *)
fun isSkolemConst f =
    let val name = Fun.toString f
        val opensLikeSkolem = String.isPrefix "<?" name
    in opensLikeSkolem andalso String.isSuffix ">" name
    end

(* Strip the first and the last character ("<x>" becomes "x"); raises
   Subscript on strings shorter than two characters. *)
fun dropAngles s = String.extract (s, 1, SOME (String.size s - 2))

(* Undo skolemize: each Skolem constant found in cvPairs is replaced by its
   variable; other constants, unknown Skolem constants and variables are
   left as they are. *)
fun unSkolemize cvPairs term =
    case term of
        T.Var _ => term
      | T.Fun (f, [], _) =>
        if isSkolemConst f
        then (case L.find (fn (c, _) => Term.equal (term, c)) cvPairs of
                  SOME (_, v) => v
                | NONE => term)
        else term
      | T.Fun (f, args, ty) => T.Fun (f, L.map (unSkolemize cvPairs) args, ty)


local 
(* Conditional rewriting on terms without variables: the public entry
   points below replace the variables of the input by constants (skolemize),
   rewrite, and map the constants back (unSkolemize). *)
(* we assume the variables in "term" have been renamed to constants *)
(* Try the rules of the first list in order at the root of term.  A rule
   applies when its lhs matches and isReachableSeq satisfies its conditions;
   the instantiated rhs of the first applicable rule is returned.  `crules`
   is always the full rule set (used to solve the conditions). *)
fun rootStepSub crules ([], term) = NONE
  | rootStepSub crules ((l,r,cond)::rest, term) =
    case Subst.match l term of
	NONE => rootStepSub crules (rest, term)
      | SOME sigma => case isReachableSeq crules (sigma, cond) of 
			  NONE => rootStepSub crules (rest, term)
			| SOME rho => SOME (Subst.applySubst rho r)
(* Solve the conditions left to right, extending sigma with each match. *)
and isReachableSeq crules (sigma, []) = SOME sigma
 |  isReachableSeq crules (sigma, (s,t)::cond) = isReachableSeqSub crules sigma (Subst.applySubst sigma s, Subst.applySubst sigma t) cond
(* Search for a reduct of s' matching t': if t' matches s' compatibly with
   sigma, continue with the remaining conditions; otherwise rewrite s' by one
   step (rewriteOneStepSub) and retry.  There is no backtracking: once a
   merge succeeds, a later failure of the remaining conditions is final.
   The search may not terminate if s' keeps rewriting. *)
and isReachableSeqSub crules sigma (s',t') cond = 
    case Subst.match t' s' of
	SOME rho => (case Subst.merge (sigma, rho) of
			 SOME sigma' => isReachableSeq crules (sigma', cond)
		       | NONE => (case rewriteOneStepSub crules s' of
				      NONE => NONE
				    | SOME s2' => isReachableSeqSub crules sigma (s2',t') cond))
      | NONE => (case rewriteOneStepSub crules s' of
		     NONE => NONE
		   | SOME s2' => isReachableSeqSub crules sigma (s2',t') cond)
(* One rewrite step: first at the root, otherwise in the leftmost argument
   that can be rewritten.  Variables are never rewritten. *)
and rewriteOneStepSub crules (T.Var x) = NONE
 |  rewriteOneStepSub crules (term as T.Fun (f,ts,ty)) = 
    case rootStepSub crules (crules, term) of
	SOME ans => SOME ans
      | NONE => case rewriteOneStepSubList crules ts of
		    SOME us => SOME (T.Fun (f,us,ty))
		  | NONE => NONE
(* rewrite the leftmost rewritable element of the list by one step *)
and rewriteOneStepSubList crules [] = NONE
 |  rewriteOneStepSubList crules (t::ts) =
    case rewriteOneStepSub crules t of
	SOME ans => SOME (ans::ts)
      | NONE => case rewriteOneStepSubList crules ts of 
		    SOME us => SOME (t::us)
		  | NONE => NONE
(* repeat rewriteOneStepSub until no step applies (may not terminate) *)
and linfSub crules term = case rewriteOneStepSub crules term of
			      SOME term' => ((* println (" -> " ^ Term.toString term'); *)
			       linfSub crules term')
			    | NONE => ( (* println (" nf: " ^ Term.toString term); *)
				term)
in
(* one root step on term; variables of term are treated as constants *)
fun rootStep crules term = let val (cvPairs, term') = skolemize term
			   in case rootStepSub crules (crules, term') of
				  SOME ans' => SOME (unSkolemize cvPairs ans')
				| NONE => NONE
			   end
(* one rewrite step anywhere in term (root first, then leftmost argument) *)
fun rewriteOneStep crules term = let val (cvPairs, term') = skolemize term
				 in case rewriteOneStepSub crules term' of
					SOME ans' => SOME (unSkolemize cvPairs ans')
				      | NONE => NONE
				 end
(* rewrite with rewriteOneStep's strategy until no step applies *)
fun linf crules term = let val (cvPairs, term') = skolemize term
			   val nf' = linfSub crules term'
		       in unSkolemize cvPairs nf'
		       end

end



(* val R1 = [(IOFotrs.rdTerm "lt(?x,0)", IOFotrs.rdTerm "false", []) *)
(* 	 ,(IOFotrs.rdTerm "lt(0,s(?y))", IOFotrs.rdTerm "true", []) *)
(* 	 ,(IOFotrs.rdTerm "lt(s(?x),s(?y))",IOFotrs.rdTerm "lt(?x,?y)",[]) *)
(* 	 ,(IOFotrs.rdTerm "min(cons(?x,nil))",IOFotrs.rdTerm "?x",[]) *)
(* 	 ,(IOFotrs.rdTerm "min(cons(?x,?xs))",IOFotrs.rdTerm "?x", *)
(* 	   [(IOFotrs.rdTerm "min(?xs)",IOFotrs.rdTerm "?y"), (IOFotrs.rdTerm "lt(?x,?y)",IOFotrs.rdTerm "true")]) *)
(* 	 ,(IOFotrs.rdTerm "min(cons(?x,?xs))",IOFotrs.rdTerm "?y", *)
(* 	   [(IOFotrs.rdTerm "min(?xs)",IOFotrs.rdTerm "?y"), (IOFotrs.rdTerm "lt(?x,?y)",IOFotrs.rdTerm "false")]) *)
(* 	 ] *)

(* val R2 = [(IOFotrs.rdTerm "plus(0, ?y)", IOFotrs.rdTerm "?y", []) *)
(* 	 ,(IOFotrs.rdTerm "plus(s(?x), ?y)", IOFotrs.rdTerm "s(plus(?x, ?y))", []) *)
(* 	 ,(IOFotrs.rdTerm "fib(0)", IOFotrs.rdTerm "pair(s(0), 0)", []) *)
(* 	 ,(IOFotrs.rdTerm "fib(s(?x))", IOFotrs.rdTerm "pair(?z3, ?z1)",  *)
(* 	   [(IOFotrs.rdTerm "fib(?x)", IOFotrs.rdTerm "pair(?z1, ?z2)") ,(IOFotrs.rdTerm "plus(?z1, ?z2)", IOFotrs.rdTerm "?z3")]) *)
(* 	 ] *)


(* val _ = print (prRules R1) *)
(* val _ = print (Trs.prRules (unraveling R1)) *)



(* val _ = print (prRules R2) *)
(* val _ = print (Trs.prRules (unraveling R2)) *)


(* constraint rewrite by semi-equational CTRS of type I *)
(* Root reducts of `term` by the single rule (l,r,c) under the constraint
   list cnstr.  The rule applies when l matches term and every instantiated
   condition is accepted by Term.checkCC cnstr; a variable never reduces.
   Besides sigma(r), variants of r are produced: for the conditions whose
   both sides are variables (u, v), each such u may be replaced by one of the
   v's it is paired with, in every combination.  Duplicate results are
   removed. *)
fun constraintRootStepRule (l,r,c) (cnstr, T.Var _ ) = []
  | constraintRootStepRule (l,r,c) (cnstr, term) = 
    case Subst.match l term of
	SOME sigma => if L.all (fn (s,t) => Term.checkCC cnstr (Subst.applySubst sigma s, Subst.applySubst sigma t)) c
		      then
			  (* variable-to-variable conditions, grouped by their lhs variable *)
			  let val vpairs = L.filter (fn (u,v) => T.isVar u andalso T.isVar v) c
			      val classes = LU.classify (fn ((u1,v1),(u2,v2)) => T.equal (u1,u2)) vpairs
			      (* a class is never empty, so the first clause is a guard only *)
			      fun mkmap [] = (PU.error "ctrs.sml: mkmap"; raise CtrsError)
				| mkmap ((x,y)::xs) = (x, y:: L.map (fn (u,v) => v) xs)
			      val maps = L.map mkmap classes
			      (* every choice of one v per u, as a variable map; [VM.empty]
			         when there are no variable-to-variable conditions *)
			      val rhos = L.foldr (fn ((u,vs),mapList) =>
						       ListXProd.mapX (fn (v,map) => VM.insert (map,
												valOf (T.varRootOfTerm u),
												v)) (vs, mapList))
						   [VM.empty] maps
			      val rlist = L.map (fn rho => Subst.applySubst rho r) rhos
			  in  LU.eliminateDuplication' T.equal (L.map  (Subst.applySubst sigma) rlist)
			  end
		      else []
      | NONE => []

(* Root reducts of term for every rule, concatenated in rule order. *)
fun constraintRootStep crules (cnstr, term) =
    let fun collect [] = []
          | collect (rule::rest) = constraintRootStepRule rule (cnstr, term) @ collect rest
    in collect crules
    end

(* Reducts by one step at the outermost-leftmost reducible position only:
   the root reducts if there are any, otherwise the reducts of the leftmost
   reducible argument.  Variables have no reducts. *)
fun constraintRewriteStep crules (cnstr, T.Var _) = []
  | constraintRewriteStep crules (cnstr, term as T.Fun (f,args,ty)) =
    (case constraintRootStep crules (cnstr, term) of
         [] => L.map (fn newArgs => T.Fun (f,newArgs,ty))
                     (constraintRewriteStepList crules (cnstr, args))
       | rootReducts => rootReducts)
and constraintRewriteStepList crules (cnstr, []) = []
  | constraintRewriteStepList crules (cnstr, t::ts) =
    let val headReducts = constraintRewriteStep crules (cnstr, t)
    in if null headReducts
       then L.map (fn rest => t::rest) (constraintRewriteStepList crules (cnstr, ts))
       else L.map (fn u => u::ts) headReducts
    end

(* All terms reachable from `term` by exactly one constrained rewrite step:
   the root reducts first, then the reducts at argument positions from left
   to right.  A variable has no reducts. *)
fun constraintOneStepReducts crules (cnstr, term as T.Var _ ) =  []
  | constraintOneStepReducts crules (cnstr, term as T.Fun (f,args,ty)) =
    (constraintRootStep crules (cnstr, term))
    @ L.map (fn ts => T.Fun (f,ts,ty)) (constraintOneStepReductsList crules (cnstr, args))
(* All argument lists obtained by rewriting exactly one element once.
   The empty list has no such reducts, so the base case is [].  Returning
   [[]] would make every list "reduce" to itself, so each function term
   would be reported as its own one-step reduct (and then duplicated by
   constraintZeroOrOneStepReducts). *)
and constraintOneStepReductsList crules (cnstr,[]) = []
  | constraintOneStepReductsList crules (cnstr,t::ts) = 
    let val t2s = constraintOneStepReducts crules (cnstr, t)
	val ts2s = constraintOneStepReductsList crules (cnstr, ts)
    in (L.map (fn x => x::ts) t2s) @ (L.map (fn ys => t::ys) ts2s)
    end

(* The term itself, followed by its one-step reducts. *)
fun constraintZeroOrOneStepReducts crules (cnstr, term) =
    let val oneStep = constraintOneStepReducts crules (cnstr, term)
    in term :: oneStep
    end

(* Parallel one-step reducts: the root reducts of the term, then every
   combination of parallel reducts of its arguments (including the
   unchanged term).  A variable reduces only to itself. *)
fun constraintParallelOneStepReducts crules (cnstr, term) =
    (case term of
         T.Var _ => [term]
       | T.Fun (f,args,ty) =>
         constraintRootStep crules (cnstr, term)
         @ L.map (fn newArgs => T.Fun (f,newArgs,ty))
                 (constraintParallelOneStepReductsList crules (cnstr, args)))
and constraintParallelOneStepReductsList crules (cnstr, []) = [[]]
  | constraintParallelOneStepReductsList crules (cnstr, t::ts) =
    let val headReducts = constraintParallelOneStepReducts crules (cnstr, t)
        val tailReducts = constraintParallelOneStepReductsList crules (cnstr, ts)
    in ListXProd.mapX (op ::) (headReducts, tailReducts)
    end

(* Terms reachable by two successive parallel steps, without duplicates. *)
fun constraintParallelTwoStepsReducts crules (cnstr, term) =
    let fun step t = constraintParallelOneStepReducts crules (cnstr, t)
    in LU.eliminateDuplication' Term.equal (LU.mapAppend step (step term))
    end



(*
val R3 = [(IOFotrs.rdTerm "P(Q(?x))", IOFotrs.rdTerm "P(R(?x))",
	   [(IOFotrs.rdTerm "?x", IOFotrs.rdTerm "A")])
	 ,(IOFotrs.rdTerm "Q(H(?x))", IOFotrs.rdTerm "R(?x)",
	   [(IOFotrs.rdTerm "S(?x)", IOFotrs.rdTerm "H(?x)")])
	 ,(IOFotrs.rdTerm "R(?x)", IOFotrs.rdTerm "R(H(?x))",
	   [(IOFotrs.rdTerm "S(?x)", IOFotrs.rdTerm "A")])
	 ]

val cnstr = [(IOFotrs.rdTerm "S(?x)", IOFotrs.rdTerm "H(?x)")
	    ,(IOFotrs.rdTerm "H(?x)", IOFotrs.rdTerm "A")
	    ]

val term = IOFotrs.rdTerm "P(R(?x))"

val ts =  constraintParallelOneStepReducts R3 (cnstr, term)

val _ = print (LU.toStringCommaCurly Term.toString ts)
*)

(* a conditional linearization by Klop & de Vrijer  *)
(* f(x,x,x,y) -> g(x,y)  |-> f(x1,x2,x3,y1) -> g(x1,y1) <= x1 = x2, x1 = x3 *)
(* Klop & de Vrijer conditional linearization of an unconditional rule.
   The lhs is linearized with fresh variables.  For each original variable y
   one fresh copy is chosen as representative: the first entry for y in
   VM.listItemsi order of the fresh variables.  The rhs is rewritten to use
   the representatives, and every other fresh copy x of a non-linear y
   yields a condition (representative of y, x). *)
fun condKdVLinearizationRule (l,r) =
    let val lvars = T.varListInTerm l
        val linLhs = T.linearize l
        val freshLinLhs = Subst.renameTermDisjointFrom lvars linLhs
        val renaming = valOf (Subst.match freshLinLhs l)    (* { x1 := x, x2 := x,  x3 := x,  y1 := y } *)

        (* (fresh variable, original variable, sort) for every fresh variable *)
        val list = L.map (fn (x,v) => (x, valOf (Term.varRootOfTerm v), Term.sortOfTerm v))
                         (VM.listItemsi renaming)
        (* keep only the first entry of each original variable.  The recursion
           matters: without it only the head's group was deduplicated, and the
           fold below then let the LAST copy win for every other group *)
        fun filterList [] = []
          | filterList ((x,y,ty)::xs) =
            (x,y,ty) :: filterList (L.filter (fn (_,y2,_) => not (Var.equal (y,y2))) xs)

        val filteredList = filterList list

        (* original variable |-> its representative, e.g. { x := x1, y := y1 } *)
        val reverse = L.foldl (fn ((x,y,ty),map) => VM.insert (map,y,T.Var (x,ty))) VM.empty filteredList

        val nonLinVarSet = T.nonLinearVarSetInTerm l

        val nlvList = L.rev (L.filter (fn (_,y,_) => VS.member (nonLinVarSet, y)) list)

        (* one equation per non-representative copy of a non-linear variable *)
        val cond = L.filter (fn (s,t) => not (Term.equal (s,t)))
                            (L.map (fn (x,y,ty) => (Subst.applySubst reverse (T.Var (y,ty)), T.Var (x,ty))) nlvList)
    in (freshLinLhs, Subst.applySubst reverse r, cond)
    end

(* rule-wise lifting of condKdVLinearizationRule *)
fun condKdVLinearizationRules rs = L.map condKdVLinearizationRule rs


(* a conditional linearization by Klop & de Vrijer with variations in RHS *)
(* f(x,x,x) -> g(x,x) |->
   { f(x1,x2,x3) -> g(xI,xJ) <= x1 = x2, x1 = x3  |   I = 1,2,3, J = 1,2,3  } *)
(* Klop & de Vrijer conditional linearization with rhs variations: the same
   linearized lhs and conditions as condKdVLinearizationRule, paired with
   every rhs obtained by replacing each rhs variable occurrence
   independently by one of its fresh copies.  A rhs variable that does not
   occur in the lhs is reported through PU.error and raises CtrsError. *)
fun condKdVLinearizationRuleMult (l,r) =
    let val lvars = T.varListInTerm l
        val linLhs = T.linearize l
        val freshLinLhs = Subst.renameTermDisjointFrom lvars linLhs
        val renaming = valOf (Subst.match freshLinLhs l)    (* { x1 := x, x2 := x,  x3 := x,  y1 := y } *)

        (* (original variable term, fresh variable term): [ (x,x1), (x,x2), (x,x3), (y,y1) ] *)
        val list2 = L.map (fn (x,v) => (v, T.Var (x, T.sortOfTerm v)))
                          (VM.listItemsi renaming)

        val classes = LU.classify (fn ((x1,_),(x2,_)) => T.equal (x1,x2)) list2
        (* a class is never empty, so the first clause is a guard only *)
        fun mkmap [] = (PU.error "ctrs.sml: mkmap"; raise CtrsError)
          | mkmap ((x,y)::xs) = (x, y :: L.map (fn (_,v) => v) xs)
        val maps = L.map mkmap classes   (* [ (x,[x1,x2,x3]), (y,[y1]) ] *)

        (* all instances of a term with each variable occurrence replaced by
           one of its fresh copies *)
        fun allInst (t as (T.Var _)) = (case L.find (fn (y,_) => T.equal (t,y)) maps of
                                           SOME (_,ys) => ys
                                         | NONE => (PU.error "ctrs.sml: condKdVLinear"; raise CtrsError))
          | allInst (T.Fun (f,ts,ty)) = L.map (fn us => T.Fun (f,us,ty)) (allInstList ts)
        and allInstList [] = [[]]
          | allInstList (t::ts) =
            let val tList = allInst t
                val tsList = allInstList ts
            in ListXProd.mapX (op ::) (tList,tsList)
            end

        (* (fresh variable, original variable, sort) for every fresh variable *)
        val list = L.map (fn (x,v) => (x, valOf (Term.varRootOfTerm v), Term.sortOfTerm v))
                         (VM.listItemsi renaming)
        (* keep only the first entry of each original variable; recursing so
           that every variable group, not just the head's, is deduplicated *)
        fun filterList [] = []
          | filterList ((x,y,ty)::xs) =
            (x,y,ty) :: filterList (L.filter (fn (_,y2,_) => not (Var.equal (y,y2))) xs)
        val filteredList = filterList list
        (* original variable |-> its representative fresh variable *)
        val reverse = L.foldl (fn ((x,y,ty),map) => VM.insert (map,y,T.Var (x,ty))) VM.empty filteredList
        val nonLinVarSet = T.nonLinearVarSetInTerm l
        val nlvList = L.rev (L.filter (fn (_,y,_) => VS.member (nonLinVarSet, y)) list)
        (* one equation per non-representative copy of a non-linear variable *)
        val cond = L.filter (fn (s,t) => not (Term.equal (s,t)))
                            (L.map (fn (x,y,ty) => (Subst.applySubst reverse (T.Var (y,ty)), T.Var (x,ty))) nlvList)

    in L.map (fn r' => (freshLinLhs, r', cond)) (allInst r)
    end

(* every combination of one variant per rule *)
fun condKdVLinearizationRulesMult [] = [[]]
  | condKdVLinearizationRulesMult (rule::rules) = 
    ListXProd.mapX (op ::)
                   (condKdVLinearizationRuleMult rule, condKdVLinearizationRulesMult rules)



(*
val R4 = [(IOFotrs.rdTerm "f(?x,?x)", IOFotrs.rdTerm "h(?x,f(?x,b))")
	 ,(IOFotrs.rdTerm "f(g(?y),?y)", IOFotrs.rdTerm "h(?y,f(g(?y),a))")
	 ,(IOFotrs.rdTerm "k(?x,?x,?x,?y)", IOFotrs.rdTerm "f(?x,?y)")
	 ,(IOFotrs.rdTerm "a", IOFotrs.rdTerm "b")
	 ]

val _ = print (Trs.prRules R4)

val _ = print (prRules (condKdVLinearizationRules R4))
*)


(* Attach sorts to l and then r, threading the variable environment (and the
   sort information returned for l) from the lhs into the rhs.  Returns the
   sorted pair together with the final environment, or NONE on failure. *)
fun attachSortToRuleWithSharedEnv (decls:Term.decl list) vartype (l,r) =
    let val sortedLhs = Term.attachSortToTermWithEnv decls (l,vartype,NONE)
    in case sortedLhs of
           NONE => NONE
         | SOME (l',env,opty) =>
           Option.map (fn (r',env2,_) => (l',r',env2))
                      (Term.attachSortToTermWithEnv decls (r,env,opty))
    end

(* Sort a list of pairs left to right with one shared environment. *)
fun attachSortToRulesWithSharedEnv (decls:Term.decl list) vartype rules =
    case rules of
        [] => SOME ([],vartype)
      | lr::rest =>
        (case attachSortToRuleWithSharedEnv decls vartype lr of
             NONE => NONE
           | SOME (l,r,env) =>
             Option.map (fn (sortedRest, finalEnv) => ((l,r)::sortedRest, finalEnv))
                        (attachSortToRulesWithSharedEnv decls env rest))
			       
(* Attach sorts to a conditional rule: the rule and its conditions share
   one variable environment, which is then passed to
   Trs.attachSortToRulesWithEnv.  NONE if any stage fails. *)
fun attachSortToCondRule (decls:Term.decl list) (l,r,c) =
    let fun toCondRule (SOME ((l',r')::c')) = SOME (l',r',c')
          | toCondRule _ = NONE
    in case attachSortToRulesWithSharedEnv decls VM.empty ((l,r)::c) of
           NONE => NONE
         | SOME (sortedRules,env) => toCondRule (Trs.attachSortToRulesWithEnv decls env sortedRules)
    end

(* Sort every rule; SOME of all results only if every rule succeeded. *)
fun attachSortToCondRules (decls:Term.decl list) crules =
    let val results = L.map (attachSortToCondRule decls) crules
        fun gather (SOME x, SOME acc) = SOME (x::acc)
          | gather _ = NONE
    in L.foldr gather (SOME []) results
    end


end (* of local *)

end; (* of structure Ctrs *)

