(******************************************************************************
 * Copyright (c) 2014-2015, Toyama&Aoto Laboratory, Tohoku University
 * Copyright (c) 2016-2023, Aoto Laboratory, Niigata University
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 * 
 *  1. Redistributions of source code must retain the above copyright notice, 
 *     this list of conditions and the following disclaimer.
 *  2. Redistributions in binary form must reproduce the above copyright 
 *     notice, this list of conditions and the following disclaimer in the 
 *     documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
 * POSSIBILITY OF SUCH DAMAGE.
 ******************************************************************************)
(******************************************************************************
 * file: rwtools/term_rewriting/term.sml
 * description: definition and utility functions for Terms
 * author: AOTO Takahito
 * 
 ******************************************************************************)

(* Key signature for many-sorted first-order terms: every variable and every
   function application carries a sort. *)
signature TERM_KEY =
   sig
   include ORD_KEY2 (* { type ord_key, fun compare, fun equal, fun toString }*)
   type var_key = Var.ord_key
   type fun_key = Fun.ord_key
   type sort_key = Sort.ord_key
   datatype term = Var of var_key * sort_key | Fun of fun_key * term list * sort_key
       (** Sorts are recorded in terms, but they are not checked here. **)

   end

(* Concrete term representation together with a total order on terms.
   Ordering (and therefore equality) looks at variable names, function
   symbols and argument lists only: the sort components are ignored. *)
structure TermKey : TERM_KEY = 
struct
   type var_key = Var.ord_key
   type fun_key = Fun.ord_key
   type sort_key = Sort.ord_key
   datatype term = Var of var_key * sort_key | Fun of fun_key * term list * sort_key
   type ord_key = term

   (* Variables precede applications; variables are ordered by Var.compare,
      applications first by symbol and then lexicographically by arguments. *)
   fun compare (Var (x,_), Var (y,_)) = Var.compare (x,y)
     | compare (Var _, Fun _) = LESS
     | compare (Fun _, Var _) = GREATER
     | compare (Fun (f,ts,_), Fun (g,ss,_)) =
       (case Fun.compare (f,g) of
            EQUAL => compareList ts ss
          | unequal => unequal)
   (* lexicographic extension of compare to argument lists;
      a proper prefix is smaller *)
   and compareList [] [] = EQUAL
     | compareList [] (_::_) = LESS
     | compareList (_::_) [] = GREATER
     | compareList (t::ts) (s::ss) =
       (case compare (t,s) of
            EQUAL => compareList ts ss
          | unequal => unequal)

   fun equal pair = (compare pair = EQUAL)

   (* Plain printing: a constant prints as its symbol, an application as its
      symbol followed by ListUtil.toStringCommaRound of its arguments. *)
   fun toString (Var (x,_)) = Var.toString x
     | toString (Fun (f,args,_)) =
       let val head = Fun.toString f
       in if null args
          then head
          else head ^ ListUtil.toStringCommaRound toString args
       end

end

(* Ordered sets and maps keyed by terms, by pairs/triples of terms, and by
   term*int, int*term and int-list*term pairs.  Term components are ordered
   by TermKey.compare, which ignores sorts (NOTE(review): the pair/triple
   functors presumably combine their components' orders lexicographically). *)
structure TermSet = RedBlackSetFn (TermKey) : ORD_SET
structure TermMap = RedBlackMapFn (TermKey) : ORD_MAP
structure TermPair = OrdKey2PairFn (structure A = TermKey structure B = TermKey) : ORD_KEY2
structure TermPairSet = RedBlackSetFn (TermPair) : ORD_SET
structure TermPairMap = RedBlackMapFn (TermPair) : ORD_MAP
structure TermTriple = OrdKey2TripleFn (structure A = TermKey structure B = TermKey structure C = TermKey) : ORD_KEY2
structure TermTripleSet = RedBlackSetFn (TermTriple) : ORD_SET
structure TermTripleMap = RedBlackMapFn (TermTriple) : ORD_MAP
structure TermInt = OrdKey2PairFn (structure A = TermKey structure B = Int2) : ORD_KEY2
structure TermIntSet = RedBlackSetFn (TermInt) : ORD_SET
structure TermIntMap = RedBlackMapFn (TermInt) : ORD_MAP
structure IntTerm = OrdKey2PairFn (structure A = Int2 structure B = TermKey) : ORD_KEY2
structure IntTermSet = RedBlackSetFn (IntTerm) : ORD_SET
structure IntTermMap = RedBlackMapFn (IntTerm) : ORD_MAP
structure IntListTerm = OrdKey2PairFn (structure A = IntList structure B = TermKey) : ORD_KEY2
structure IntListTermSet = RedBlackSetFn (IntListTerm) : ORD_SET
structure IntListTermMap = RedBlackMapFn (IntListTerm) : ORD_MAP
							     
(* (variable, sort) pairs and sets of them *)
structure SortedVar = OrdKey2PairFn (structure A = Var structure B = Sort) : ORD_KEY2
structure SortedVarSet = RedBlackSetFn (SortedVar) : ORD_SET

(* Operations on many-sorted terms: printing, construction/destruction,
   collecting variables/symbols/sorts, variable renaming, sort attachment,
   positions and subterms. *)
signature TERM = 
   sig
   include TERM_KEY

   (* a function-symbol declaration: the symbol together with its sort *)
   type decl = {sym:fun_key, sort:sort_key}
   val prDecl: decl -> string 
   val prDecls: decl list -> string 
   val prDecl2: int -> decl -> string 
   val eqDecl: decl * decl -> bool

(*** basic operations on terms ***)
   val toStringWithSort: term -> string
   val toStringWithVarSort: term -> string
   val toStringWithoutQuestion: term -> string  (* without the leading '?' on variables *)
   val toProofTree: term -> unit -> string

   val mkFunTerm: fun_key * term list * sort_key -> term
   val decFunTerm: term -> (fun_key * term list * sort_key) option
   val mkVarTerm: var_key * sort_key -> term
   val decVarTerm: term -> (var_key * sort_key) option

   val termSize: term -> int
   val termDepth: term -> int
   val isVar: term -> bool
   val isFun: term -> bool
   val isConstant: term -> bool
   val funRootOfTerm: term -> Fun.ord_key option
   val varRootOfTerm: term -> Var.ord_key option
   val argsOfTerm: term -> term list
   val haveSameRoots:  term * term -> bool
   val termListMinus: term list * term list -> term list

   val sortSetInTerm: term -> SortSet.set
   val sortSetInTerms: term list -> SortSet.set

   val sortSetInDecl: decl -> SortSet.set
   val sortSetInDecls: decl list -> SortSet.set

   val declToTerm: decl -> term
   val declsToTerms: decl list -> term list

   val varSetInTerm: term -> VarSet.set
   val varSetInTerms: term list -> VarSet.set

   val varTermSetInTerm: term -> TermSet.set
   val varTermSetInTerms: term list -> TermSet.set

   val varListInTerm: term -> var_key list
   val varListInTerms: term list -> var_key list

   (* "linear" variables are those that occur exactly once *)
   val linearVarListInTerm: term -> var_key list
   val linearVarListInTerms: term list -> var_key list
   val linearVarSetInTerm: term -> VarSet.set
   val linearVarSetInTerms: term list -> VarSet.set

   val nonLinearVarListInTerm: term -> var_key list
   val nonLinearVarSetInTerm: term -> VarSet.set
   val linearize: term -> term

   val sortedVarSetInTerm: term -> SortedVarSet.set
   val sortedVarSetInTerms: term list -> SortedVarSet.set
   val sortedVars: term -> (var_key * sort_key) list

   val funSetInTerm: term -> FunSet.set
   val funSetInTerms: term list -> FunSet.set

   val nonConstantFunSetInTerm: term -> FunSet.set
   val nonConstantFunSetInTerms: term list -> FunSet.set

   val funAritySetInTerm: term -> FunIntSet.set
   val funAritySetInTerms: term list -> FunIntSet.set

   val funArityMapInTerm: term -> int FunMap.map
   val funArityMapInTerms: term list -> int FunMap.map

   val signatureInTerm: term -> (sort_key list * sort_key) FunMap.map
   val signatureInTerms: term list -> (sort_key list * sort_key) FunMap.map

   val sortOfTerm: term -> sort_key
   val haveSameSort: term * term -> bool
   val isGroundTerm: term -> bool
   val isLinearTerm: term -> bool
   val isShallowTerm: term -> bool
   val numberOfFSymsInTerm: term -> int
   val changeFunNameInTerm: fun_key * fun_key -> term -> term

   val isConstructorTerm: fun_key list -> term -> bool
   val isBasicTerm: (fun_key list * fun_key list) -> term -> bool

(*** variables (indices and linearity) ***)
   val maxVarIndexInTerm: term -> int
   val maxVarIndexInTerms: term list -> int
   val rangeVarIndexInTerm: term -> int * int 
   val rangeVarIndexInTerms: term list -> int * int
   val increaseVarIndexBy: int -> term -> term 
  val linearVarSet: term -> VarSet.set

(*** functions on sorts ***)
  val attachRootSortToTerm: sort_key -> term -> term 
  val attachSortToTerm: decl list -> term -> term option
  val attachSortToTermWithEnv: decl list 
			       -> (term * sort_key VarMap.map * sort_key option)
			       -> (term * sort_key VarMap.map * sort_key option) option

(*** positions and subterms ***)

   type position
   val prPosition: position -> string
   val prPosInProofTree:  int -> unit -> string
   val prPositionInProofTree: position -> unit -> string
   val isPrefixPosition: position -> position -> bool
   val positionsInTerm: term -> position list
   val funPositionsInTerm: term -> position list
   val varPositionsInTerm: term -> position list
					    
   val subterm: position -> term -> term option
   val isASubterm: term ->  term -> bool
   val isAProperSubterm: term ->  term -> bool
   val hasASubtermOC: term ->  term -> position option
   val replaceSubterm: term -> position -> term -> term option
   val collectSubtermOCs: (term -> bool) -> term -> position list
   val posListMinus: position list * position list -> position list
   val subterms: term ->  term list
   val properSubterms: term ->  term list
   val varSubterms: term -> term list
   val nonVarSubterms: term -> term list
   val nonVarProperSubterms: term -> term list

   val setOfSubterms: term -> TermSet.set
   val setOfProperSubterms: term -> TermSet.set
   val sortOfVarInTerm: term -> var_key -> sort_key option

   val hasNoOccurrence: var_key -> term -> bool
   val hasOneOccurrence: var_key -> term -> bool
   val hasAtMostOneOccurrence: var_key -> term -> bool


   val constructorCap: FunSet.set -> term -> term * VarSet.set
       (* constructor cap and introduced fresh variable symbols *) 
   val basicCap: (FunSet.set * FunSet.set) -> term -> (term * VarSet.set) option
       (* basic cap and introduced fresh variable symbols *) 

   val embGt: term * term -> bool
   val embGe: term * term -> bool

  (* congruence closure *)
   val checkCC: (term * term) list -> term * term -> bool

   end;

structure Term : TERM = 
   struct
   open TermKey

   (* as in TermKey, the key type is the term datatype itself *)
   type ord_key = term
   (* a function-symbol declaration: the symbol with its (base or functional) sort *)
   type decl = {sym:fun_key, sort:sort_key}

   local 
       (* short aliases for the structures used throughout Term *)
       structure CU = CertifyUtil
       structure L = List
       structure LP = ListPair
       structure LU = ListUtil
       structure VS = VarSet
       structure VM = VarMap
       structure FS = FunSet
       structure FM = FunMap
       structure S = Sort
       structure SS = SortSet
       structure SVS = SortedVarSet
       structure FIS = FunIntSet
       structure TS = TermSet
       structure TPM = TermPairMap
       structure TPS = TermPairSet
       (* mapAppend f [x1,...,xn] = (f x1) @ ... @ (f xn), order preserved *)
       fun mapAppend f xs = List.foldr (fn (x,ys) => List.@(f x, ys)) [] xs
   in

   (* ordering, equality and plain printing are exactly those of TermKey
      (compare/equal ignore sorts) *)
   val (compare, equal, toString) = (TermKey.compare, TermKey.equal, TermKey.toString)

   (* Print a term with every node (variable, constant or application)
      followed by ":" and its sort. *)
   fun toStringWithSort term =
       let
           fun annotate (text, ty) = text ^ ":" ^ Sort.toString ty
       in
           case term of
               Var (x,ty) => annotate (Var.toString x, ty)
             | Fun (f,[],ty) => annotate (Fun.toString f, ty)
             | Fun (f,ts,ty) =>
               annotate (Fun.toString f ^ ListUtil.toStringCommaRound toStringWithSort ts, ty)
       end

   (* Print a term annotating only the variables with ":" and their sort. *)
   fun toStringWithVarSort t =
       case t of
           Var (x,ty) => Var.toString x ^ ":" ^ Sort.toString ty
         | Fun (f,args,_) =>
           if null args
           then Fun.toString f
           else Fun.toString f ^ ListUtil.toStringCommaRound toStringWithVarSort args

   (* Like toString, but variables are printed with
      Var.toStringWithoutQuestion. *)
   fun toStringWithoutQuestion t =
       case t of
           Var (x,_) => Var.toStringWithoutQuestion x
         | Fun (f,args,_) =>
           let val head = Fun.toString f
           in if null args
              then head
              else head ^ ListUtil.toStringCommaRound toStringWithoutQuestion args
           end

  (* Proof-tree encoding of a term: variables are delegated to
     Var.toProofTree; an application is enclosed in "funapp" with the
     symbol first and each argument enclosed in "arg". *)
  fun toProofTree (Var (x,_)) () = Var.toProofTree x ()
    | toProofTree (Fun (f,args,_)) () =
      let
          val head = fn _ => Fun.toProofTree f ()
          fun wrapArg arg = fn _ => CU.encloseProofTreeBy "arg" (fn _ => toProofTree arg ())
      in
          if null args
          then CU.encloseProofTreeBy "funapp" head
          else CU.encloseProofTreesBy "funapp" (head :: L.map wrapArg args)
      end

   (* Thin constructors/destructors for the term datatype. *)
   val mkFunTerm = Fun
   fun decFunTerm t = case t of Fun parts => SOME parts | Var _ => NONE

   val mkVarTerm = Var
   fun decVarTerm t = case t of Var parts => SOME parts | Fun _ => NONE

   (* number of nodes (variable and symbol occurrences) of a term *)
   fun termSize (Var _) = 1
     | termSize (Fun (_,args,_)) =
       1 + L.foldl (fn (a,acc) => acc + termSize a) 0 args

   (* height of a term; variables and constants have depth 1 *)
   fun termDepth (Var _) = 1
     | termDepth (Fun (_,args,_)) =
       1 + L.foldl (fn (a,acc) => Int.max (acc, termDepth a)) 0 args

   (* simple structural predicates and projections *)
   fun isVar t = case t of Var _ => true | Fun _ => false

   fun isFun t = not (isVar t)

   (* a constant is an application with no arguments *)
   fun isConstant t = case t of Fun (_,[],_) => true | _ => false

   fun funRootOfTerm t = case t of Fun (f,_,_) => SOME f | Var _ => NONE

   fun varRootOfTerm t = case t of Var (x,_) => SOME x | Fun _ => NONE

   (* arguments of the root; [] for a variable *)
   fun argsOfTerm t = case t of Fun (_,args,_) => args | Var _ => []

   (* true iff both terms are the same variable, or both are applications of
      the same function symbol (arguments and sorts are not compared) *)
   fun haveSameRoots pair =
       case pair of
           (Var (x,_), Var (y,_)) => Var.equal (x,y)
         | (Fun (f,_,_), Fun (g,_,_)) => Fun.equal (f,g)
         | _ => false

   (* all sorts attached to any node of a term / list of terms *)
   fun sortSetInTerm t =
       case t of
           Var (_,ty) => SS.singleton ty
         | Fun (_,args,ty) => SS.add (sortSetInTerms args, ty)
   and sortSetInTerms ts =
       L.foldl (fn (t,acc) => SS.union (acc, sortSetInTerm t)) SS.empty ts

   (* variables of a term / list of terms, as a set *)
   fun varSetInTerm (Var (x,_)) = VS.singleton x
     | varSetInTerm (Fun (_,args,_)) = varSetInTerms args
   and varSetInTerms ts =
       L.foldl (fn (t,acc) => VS.union (acc, varSetInTerm t)) VS.empty ts

   (* variable subterms (with their sorts) of a term / list of terms *)
   fun varTermSetInTerm (v as Var _) = TS.singleton v
     | varTermSetInTerm (Fun (_,args,_)) = varTermSetInTerms args
   and varTermSetInTerms ts =
       L.foldr (fn (t,acc) => TS.union (varTermSetInTerm t, acc)) TS.empty ts

   (* all variable occurrences, left to right, with repetitions *)
   fun varListInTerm (Var (x,_)) = [x]
     | varListInTerm (Fun (_,args,_)) = varListInTerms args
   and varListInTerms ts = mapAppend varListInTerm ts

   (* Variables occurring only once among all occurrences in ts, in
      left-to-right order (assuming ListUtil.deleteAll' removes every copy). *)
   fun linearVarListInTerms ts =
       let
           fun keepLinear [] = []
             | keepLinear (x::rest) =
               if not (ListUtil.member' Var.equal x rest)
               then x :: keepLinear rest
               else keepLinear (ListUtil.deleteAll' Var.equal x rest)
       in
           keepLinear (varListInTerms ts)
       end
   fun linearVarListInTerm t = linearVarListInTerms [t]
   fun linearVarSetInTerm t = VS.addList (VS.empty, linearVarListInTerm t)
   fun linearVarSetInTerms ts = VS.addList (VS.empty, linearVarListInTerms ts)

(******* incorrect
   fun linearVarSetInTerm (Var (x,_)) = VS.singleton x
     | linearVarSetInTerm (Fun (f,[],_)) = VS.empty
     | linearVarSetInTerm (Fun (f,t::ts,_)) = 
       let val vs1 = linearVarSetInTerm t
	   val vs2 = linearVarSetInTerms ts
       in
	   VS.union (VS.difference (vs1,vs2), VS.difference (vs2,vs1))
       end
   and linearVarSetInTerms [] =  VS.empty
     | linearVarSetInTerms (t::ts) =  
       let val vs1 = linearVarSetInTerm t
	   val vs2 = linearVarSetInTerms ts
       in
	   VS.union (VS.difference (vs1,vs2), VS.difference (vs2,vs1))
       end
***)

   (* Variables occurring more than once in t, each listed once at the
      place of its first occurrence. *)
   fun nonLinearVarListInTerm t =
       let
           fun collect [] = []
             | collect (x::rest) =
               if not (ListUtil.member' Var.equal x rest)
               then collect rest
               else x :: collect (ListUtil.deleteAll' Var.equal x rest)
       in
           collect (varListInTerm t)
       end
   fun nonLinearVarSetInTerm t = VS.addList (VS.empty, nonLinearVarListInTerm t)

   (* Rename variable occurrences, left to right, so that no two occurrences
      share a name: an occurrence whose name is already taken by an earlier
      (possibly renamed) occurrence gets its index bumped by 1 until unused. *)
   fun linearize term =
       let
           fun fresh x used =
               if LU.member' Var.equal x used
               then fresh (Var.increaseIndexBy 1 x) used
               else (x, x::used)
           fun walk (Var (x,ty), used) =
               let val (x', used') = fresh x used
               in (Var (x',ty), used') end
             | walk (Fun (f,args,ty), used) =
               let val (args', used') = walkArgs (args, used)
               in (Fun (f,args',ty), used') end
           and walkArgs ([], used) = ([], used)
             | walkArgs (a::rest, used) =
               let val (a', used1) = walk (a, used)
                   val (rest', used2) = walkArgs (rest, used1)
               in (a'::rest', used2) end
       in
           #1 (walk (term, []))
       end


   (* (variable, sort) pairs of a term / list of terms, as a set *)
   fun sortedVarSetInTerm (Var (x,ty)) = SVS.singleton (x,ty)
     | sortedVarSetInTerm (Fun (_,args,_)) = sortedVarSetInTerms args
   and sortedVarSetInTerms ts =
       L.foldr (fn (t,acc) => SVS.union (sortedVarSetInTerm t, acc)) SVS.empty ts

   (* occurrence order sensitive list: the distinct (variable, sort) pairs,
      in order of first occurrence from left to right *)
   local
       fun samePair ((x,s), (y,s')) = Var.equal (x,y) andalso Sort.equal (s,s')
       (* every (variable, sort) occurrence, left to right, with repetitions *)
       fun occurrences (Var (x,ty)) = [(x,ty)]
         | occurrences (Fun (_,args,_)) = mapAppend occurrences args
       (* keep each pair only at its first occurrence *)
       fun firstOccurrences seen [] = rev seen
         | firstOccurrences seen (p::ps) =
           if LU.member' samePair p seen
           then firstOccurrences seen ps
           else firstOccurrences (p::seen) ps
   in
   fun sortedVars t = firstOccurrences [] (occurrences t)
   fun sortedVarsList ts = firstOccurrences [] (mapAppend occurrences ts)
   end
						      

   (* function symbols (constants included) of a term / list of terms *)
   fun funSetInTerm (Var _) = FS.empty
     | funSetInTerm (Fun (f,args,_)) = FS.add (funSetInTerms args, f)
   and funSetInTerms ts =
       L.foldr (fn (t,acc) => FS.union (funSetInTerm t, acc)) FS.empty ts

   (* function symbols applied to at least one argument *)
   fun nonConstantFunSetInTerm (Fun (f, args as _::_, _)) =
       FS.add (nonConstantFunSetInTerms args, f)
     | nonConstantFunSetInTerm _ = FS.empty
   and nonConstantFunSetInTerms ts =
       L.foldr (fn (t,acc) => FS.union (nonConstantFunSetInTerm t, acc)) FS.empty ts

   (* (symbol, number of arguments) pairs of every application *)
   fun funAritySetInTerm (Var _) = FIS.empty
     | funAritySetInTerm (Fun (f,args,_)) =
       FIS.add (funAritySetInTerms args, (f, L.length args))
   and funAritySetInTerms ts =
       L.foldr (fn (t,acc) => FIS.union (funAritySetInTerm t, acc)) FIS.empty ts

   (* symbol -> arity map; on conflicts the root wins over its arguments and
      an earlier (left) argument wins over later ones *)
   local
       fun preferLeft (l, _) = l
   in
   fun funArityMapInTerm (Var _) = FM.empty
     | funArityMapInTerm (Fun (f,args,_)) =
       FM.insert (funArityMapInTerms args, f, L.length args)
   and funArityMapInTerms ts =
       L.foldr (fn (t,acc) => FM.unionWith preferLeft (funArityMapInTerm t, acc)) FM.empty ts
   end

(*   val isWellSortedTerm: term -> sort_key option *)

   (* the sort attached to the root of a term *)
   fun sortOfTerm t = case t of Var (_,ty) => ty | Fun (_,_,ty) => ty

   fun haveSameSort (t1, t2) = Sort.equal (sortOfTerm t1, sortOfTerm t2)

   (* symbol -> (argument sorts, result sort) as observed in the term(s);
      the root wins over its arguments, earlier arguments over later ones *)
   local
       fun keepFirst (a, _) = a
   in
   fun signatureInTerm (Var _) = FM.empty
     | signatureInTerm (Fun (f,args,ty)) =
       FM.insert (signatureInTerms args, f, (L.map sortOfTerm args, ty))
   and signatureInTerms ts =
       L.foldr (fn (t,acc) => FM.unionWith keepFirst (signatureInTerm t, acc)) FM.empty ts
   end

   (* "symbol : sort" *)
   fun prDecl ({sym, sort}:decl) = String.concat [Fun.toString sym, " : ", Sort.toString sort]
   val prDecls = PrintUtil.prList prDecl

   (* Symbol name padded with spaces up to column maxlen+3 (at least three
      spaces when the name is longer than maxlen), then Sort.toString2 of
      its sort. *)
   fun prDecl2 maxlen ({sym, sort}:decl) =
       let
           val name = Fun.toString sym
           val width = String.size name
           val gap = if width > maxlen then 3 else (maxlen + 3) - width
       in
           name ^ CharVector.tabulate (gap, fn _ => #" ") ^ Sort.toString2 sort
       end


   (* declarations are equal iff their symbols and their sorts are equal *)
   fun eqDecl ({sym=f, sort=s}:decl, {sym=g, sort=s'}:decl) =
       Fun.equal (f,g) andalso Sort.equal (s,s')

   (* the base sorts occurring in a declaration's sort *)
   fun sortSetInDecl ({sort, ...}:decl) =
       AtomSet.foldl (fn (a,acc) => SortSet.add (acc, Sort.Base a))
                     SortSet.empty
                     (Sort.basicSortSetInSort sort)

   fun sortSetInDecls decls =
       let fun addDecl (d, acc) = SortSet.union (sortSetInDecl d, acc)
       in List.foldl addDecl SortSet.empty decls end

   (* declToTerm d: a constant when d's sort is a base type; otherwise d's
      symbol applied to variables x1, x2, ... carrying the argument sorts,
      with the result sort at the root. *)
   fun declToTerm ({sym=f, sort=ty}:decl) =
       if Sort.isBaseType ty
       then Fun (f, [], ty)
       else
           let
               val argSorts = Sort.args ty
               val indices = L.tabulate (L.length argSorts, fn i => i + 1)
               fun argVar (i, s) = Var (Var.fromStringAndInt ("x", i), s)
           in
               Fun (f, LP.map argVar (indices, argSorts), Sort.return ty)
           end

   fun declsToTerms decls = L.map declToTerm decls


   (* a term is ground when it contains no variable *)
   fun isGroundTerm t =
       case t of
           Var _ => false
         | Fun (_,args,_) => L.all isGroundTerm args

   (* a term is linear when no variable occurs twice *)
   fun isLinearTerm t =
       let
           (* SOME of the variables seen so far, or NONE once a repeat is found *)
           fun scan (_, NONE) = NONE
             | scan (Var (x,_), SOME seen) =
               if VS.member (seen,x) then NONE else SOME (VS.add (seen,x))
             | scan (Fun (_,args,_), SOME seen) = L.foldl scan (SOME seen) args
       in
           isSome (scan (t, SOME VS.empty))
       end

   (* every argument of the root is a variable or a ground term *)
   fun isShallowTerm (Var _) = true
     | isShallowTerm (Fun (_,args,_)) =
       not (L.exists (fn a => not (isVar a) andalso not (isGroundTerm a)) args)

   (* all symbols of the term belong to csyms *)
   fun isConstructorTerm csyms t =
       case t of
           Var _ => true
         | Fun (f,args,_) =>
           LU.member' Fun.equal f csyms andalso L.all (isConstructorTerm csyms) args

   (* root symbol in dsyms, all arguments constructor terms (variables count as basic) *)
   fun isBasicTerm (csyms,dsyms) t =
       case t of
           Var _ => true
         | Fun (f,args,_) =>
           LU.member' Fun.equal f dsyms andalso L.all (isConstructorTerm csyms) args


   (* number of function-symbol occurrences in a term *)
   fun numberOfFSymsInTerm t =
       case t of
           Var _ => 0
         | Fun (_,args,_) => 1 + L.foldl (fn (a,n) => n + numberOfFSymsInTerm a) 0 args

   (* change h to g: rename every occurrence of symbol h to g *)
   fun changeFunNameInTerm (h,g) t =
       case t of
           Var _ => t
         | Fun (f,args,ty) =>
           Fun (if Fun.equal (f,h) then g else f,
                L.map (changeFunNameInTerm (h,g)) args,
                ty)

   (* replace the sort at the root (only) by ty *)
   fun attachRootSortToTerm ty t =
       case t of
           Var (x,_) => Var (x,ty)
         | Fun (f,args,_) => Fun (f,args,ty)

   (* Returns SOME sortedTerm if the term is well-sorted w.r.t. decls, NONE
      otherwise: every application gets the sort from its declaration, and
      an undeclared symbol, an arity mismatch or an argument whose sort
      differs from the declared one yields NONE.  A variable is returned as
      is (its sort is kept, not inferred). *)
   fun attachSortToTerm (decls:decl list) term =
       case term of
           Var _ => SOME term
         | Fun (f,ts,_) =>
           let
               val spec = L.find (fn d => Fun.equal (#sym d, f)) decls
               val sortedArgs = L.map (attachSortToTerm decls) ts
           in
               case spec of
                   NONE => NONE
                 | SOME {sort, ...} =>
                   if not (L.all isSome sortedArgs) then NONE
                   else
                       let val args = L.map valOf sortedArgs
                       in case (sort, args) of
                              (S.Base _, []) => SOME (Fun (f, [], sort))
                            | (S.Proc (argTys, retTy), _::_) =>
                              if L.length argTys = L.length args
                                 andalso LP.all Sort.equal (L.map sortOfTerm args, argTys)
                              then SOME (Fun (f, args, retTy))
                              else NONE
                            | _ => NONE
                       end
           end

   (* attachSortToTermWithEnv decls (t, env, opty): sort-check/infer t.
      Returns SOME (sortedTerm, env', SOME s) if t is well-sorted, and NONE
      otherwise.  env maps variables to their already fixed sorts and is
      extended with the variables sorted here; opty, when SOME s, is the sort
      t is required to have.  Applications take their sorts from their
      declarations in decls (the sort already on t is ignored).  A variable
      absent from env can only be sorted when opty is SOME. *)
   fun attachSortToTermWithEnv (decls:decl list) (Var (x,_),env,opty) = 
       (case VM.find (env,x) of
	   SOME ty => (case opty of
			  SOME ty' => if Sort.equal (ty,ty')
				      then SOME (Var (x,ty),env,opty)
				      else NONE
			| NONE => SOME (Var (x,ty),VM.insert (env,x,ty), SOME ty))
	 (* x not yet sorted: take the expected sort, if any, and record it *)
	 | NONE => (case opty of
			SOME ty' => SOME (Var (x,ty'),VM.insert (env,x,ty'), opty)
		      | NONE => NONE))
     | attachSortToTermWithEnv (decls:decl list) (Fun (f,ts,_),env,opty) = 
       (case L.find (fn x => Fun.equal (#sym x, f)) decls of
	   NONE => NONE
	 | SOME spec =>
	   let
	       (* a base sort is treated as a nullary function sort *)
	       val (argsorts,retsort) =  case #sort (spec:decl) of
					     Sort.Base ty => ([],Sort.Base ty)
					   | Sort.Proc (tys,ty) => (tys,ty)
	   in case opty of
		  NONE => (case attachSortToTermWithEnvList decls (ts,env,argsorts) of
			       SOME (ss,env',_) => SOME (Fun (f,ss,retsort),env',SOME retsort)
			     | NONE => NONE)
		| SOME ty => if Sort.equal (ty,retsort)
			     then case attachSortToTermWithEnvList decls (ts,env,argsorts) of
				      SOME (ss,env',_) => SOME (Fun (f,ss,retsort),env',SOME retsort)
				    | NONE => NONE
			     else NONE
	   end)
   (* sort a list of arguments against the expected argument sorts, threading
      env from left to right; NONE on a length mismatch or any failure *)
   and attachSortToTermWithEnvList (decls:decl list) ([],env,[]) = SOME ([],env,[])
     | attachSortToTermWithEnvList (decls:decl list) (ti::ts,env,ty::tys) = 
       (case attachSortToTermWithEnv decls (ti,env,SOME ty) of
	   SOME (ti',env2,_) => 
	   (case attachSortToTermWithEnvList (decls:decl list) (ts,env2,tys) of
	       SOME (ts',env3,_) =>  SOME (ti'::ts',env3,ty::tys)
	     | NONE => NONE)
	 | NONE => NONE)
     | attachSortToTermWithEnvList (decls:decl list) _ = NONE

   (* elements of ts (order and repetitions kept) not equal, up to sorts,
      to any element of ss *)
   fun termListMinus (ts, ss) =
       let fun inSS t = L.exists (fn s => equal (s,t)) ss
       in L.filter (not o inSS) ts
       end
       
   (* a position: the list of (1-based) argument indices leading from the root *)
   type position = Pos.ord_key

   val prPosition = Pos.toString

   (* proof-tree leaf for one position component *)
   fun prPosInProofTree i () = CU.encloseProofLeafBy "position" (Int.toString i)
   (* proof-tree for a whole position, one leaf per component *)
   fun prPositionInProofTree pos () =
       CU.encloseProofTreesBy "positionInTerm" (L.map (fn i => prPosInProofTree i) pos)

   (* isPrefixPosition p q: true iff position p is a prefix of position q.
      The root [] is a prefix of every position, and every position is a
      prefix of itself.  A non-empty p is never a prefix of [] (previously
      this case was missing and raised Match). *)
   fun isPrefixPosition [] _ = true
     | isPrefixPosition (_::_) [] = false
     | isPrefixPosition (x::xs) (y::ys) = x = y andalso isPrefixPosition xs ys

   local
       (* prefix every position produced for the i-th argument (1-based)
          with i, concatenating the results from left to right *)
       fun liftArgPositions posOf args =
           let
               fun go _ [] = []
                 | go i (a::rest) = L.map (fn p => i::p) (posOf a) @ go (i+1) rest
           in
               go 1 args
           end
   in
   (* all positions, in pre-order (root first) *)
   fun positionsInTerm (Var _) = [[]]
     | positionsInTerm (Fun (_,args,_)) = [] :: liftArgPositions positionsInTerm args
   (* positions of function-symbol nodes, in pre-order *)
   fun funPositionsInTerm (Var _) = []
     | funPositionsInTerm (Fun (_,args,_)) = [] :: liftArgPositions funPositionsInTerm args
   (* positions of variable nodes, left to right *)
   fun varPositionsInTerm (Var _) = [[]]
     | varPositionsInTerm (Fun (_,args,_)) = liftArgPositions varPositionsInTerm args
   end

   local
       exception NotSubtermOC;

       (* a position component p selects argument number p (1-based); any p
          outside 1..length ts does not denote an argument.  Checking p >= 1
          as well keeps L.nth/L.take from raising Subscript on 0 or
          negative components. *)
       fun isArgIndex (p, ts) = 1 <= p andalso p <= L.length ts

       (* subterm' pos t: the subterm of t at pos;
          raises NotSubtermOC if pos is not a position of t *)
       fun subterm' [] term = term
         | subterm' (_::_) (Var _) = raise NotSubtermOC
         | subterm' (p::ps) (Fun (_,ts,_)) =
           if isArgIndex (p, ts)
           then subterm' ps (L.nth (ts,p-1))
           else raise NotSubtermOC

       (* replaceSubterm' t pos u: t with its subterm at pos replaced by u;
          raises NotSubtermOC if pos is not a position of t *)
       fun replaceSubterm' _ [] u = u
         | replaceSubterm' (Var _) (_::_) _ = raise NotSubtermOC
         | replaceSubterm' (Fun (f,ts,sort)) (p::ps) u =
           if isArgIndex (p, ts)
           then Fun (f,
                     L.take (ts, p-1)
                     @ (replaceSubterm' (L.nth (ts,p-1)) ps u :: L.drop (ts,p)),
                     sort)
           else raise NotSubtermOC
   in
   (* SOME (subterm of t at pos), or NONE if pos is not a position of t *)
   fun subterm pos t = SOME (subterm' pos t)
       handle NotSubtermOC => NONE
   (* SOME (t with its subterm at pos replaced by u), or NONE if pos is not
      a position of t *)
   fun replaceSubterm t pos u = SOME (replaceSubterm' t pos u)
       handle NotSubtermOC => NONE
   end

  (* isASubterm s t: true iff s is a subterm of t (t itself included),
     compared with equal, i.e. ignoring sorts *)
   fun isASubterm s t =
       equal (s,t) orelse L.exists (isASubterm s) (argsOfTerm t)

  (* isAProperSubterm s t: true iff s is a proper subterm of t, i.e. a
     subterm of one of t's arguments *)
   fun isAProperSubterm s t = L.exists (isASubterm s) (argsOfTerm t)

   (* hasASubtermOC s t: SOME p for the first position p (pre-order, left
      to right) at which s occurs in t, NONE if s is not a subterm of t *)
   fun hasASubtermOC s t =
       if equal (s,t)
       then SOME []
       else hasASubtermOCList s (argsOfTerm t) 1
   (* search an argument list whose head is argument number n *)
   and hasASubtermOCList _ [] _ = NONE
     | hasASubtermOCList s (t::ts) n =
       (case hasASubtermOC s t of
            NONE => hasASubtermOCList s ts (n+1)
          | SOME p => SOME (n::p))


   (* collectSubtermOCs pred t: positions of all subterms of t satisfying
      pred, in pre-order (a node before its arguments, arguments left to
      right) *)
   fun collectSubtermOCs pred t =
       let
           (* revpos is the path to the current node, innermost index first *)
           fun visit (node, revpos) =
               let
                   val fromArgs = visitArgs (argsOfTerm node, 1, revpos)
               in
                   if pred node then revpos :: fromArgs else fromArgs
               end
           and visitArgs ([], _, _) = []
             | visitArgs (arg::rest, i, revpos) =
               let
                   val fromRest = visitArgs (rest, i+1, revpos)
                   val fromArg = visit (arg, i::revpos)
               in
                   fromArg @ fromRest
               end
       in
           L.map rev (visit (t, []))
       end

(*    fun varSubterms s =  *)
(*        let *)
(* 	   fun vsSub (s as Var _) vs = if L.exists (fn t => equal (s,t)) vs *)
(* 				       then vs *)
(* 				       else (s::vs) *)
(* 	     | vsSub (s as Fun (_,ts,_)) vs = vsSubList ts vs *)
(* 	   and vsSubList [] vs = vs *)
(* 	     | vsSubList (s::ss) vs = vsSubList ss (vsSub s vs) *)
(*        in *)
(* 	   vsSub s [] *)
(*        end *)

   (* Multiset difference of position lists: each element of qs cancels at
      most one equal element of ps; the order of ps is kept. *)
   fun posListMinus (ps, qs) =
       let
           (* qs without its first element equal to p, or NONE if p is absent *)
           fun removeOne _ [] = NONE
             | removeOne p (q::rest) =
               if p = q
               then SOME rest
               else Option.map (fn rest' => q::rest') (removeOne p rest)
           fun loop [] _ = []
             | loop (p::ps') pending =
               (case removeOne p pending of
                    NONE => p :: loop ps' pending
                  | SOME pending' => loop ps' pending')
       in
           loop ps qs
       end
   

   local
       (* allSubtermsSub ss cond acc: depth-first, left-to-right walk over the
          worklist ss.  Each term satisfying cond that is not already in acc
          (compared with equal, i.e. ignoring sorts) is pushed onto acc, so the
          result lists the collected terms in reverse order of discovery.
          Terms failing cond are not collected, but their arguments are still
          visited.  A term satisfying cond that is already in acc is not
          re-entered, since its arguments were visited when it was collected. *)
       fun allSubtermsSub [] cond acc = acc
	 | allSubtermsSub (s0::ss) cond acc = 
	   if not (cond s0)
	   (* then allSubtermsSub ss cond acc   -- bug corrected at 2017/02/22 *)
	   (* (the old line above skipped the arguments of non-matching terms) *)
	   then allSubtermsSub ((argsOfTerm s0) @ ss) cond acc 
	   else if L.exists (fn t => equal (s0,t)) acc
	   then allSubtermsSub ss cond acc
	   else 
	       case s0 of
		   Var _ => allSubtermsSub ss cond (s0::acc)
		 | Fun (_,ts,_) => allSubtermsSub (ts @ ss) cond (s0::acc)
   in
   (* distinct subterms of s, s included *)
   fun subterms s = allSubtermsSub [s] (fn _ => true) []
   (* distinct subterms of the arguments of s ([] for a variable) *)
   fun properSubterms (Var _)  = []
     | properSubterms (Fun (_,ts,_)) = allSubtermsSub ts (fn _ => true) []
   (* distinct variable subterms of s *)
   fun varSubterms s = allSubtermsSub [s] isVar []
   (* distinct non-variable subterms of s, s included *)
   fun nonVarSubterms s = allSubtermsSub [s] (not o isVar) []
   (* distinct non-variable subterms of the arguments of s *)
   fun nonVarProperSubterms (Var _)  = []
     | nonVarProperSubterms (Fun (_,ts,_)) = allSubtermsSub ts (not o isVar) []
   end

   local
       (* number of occurrences of variable x, saturating at 2 (callers only
          distinguish 0, 1 and "more than one") *)
       fun occ x (Var (y,_)) = if Var.equal (x,y) then 1 else 0
         | occ x (Fun (_,ts,_)) = occList x ts 0
       and occList _ [] n = n
         | occList x (t::ts) n =
           let val n' = n + occ x t
           in if n' >= 2 then 2 else occList x ts n' end
   in
   (* x does not occur in t / in any of ts *)
   fun hasNoOccurrence x t = occ x t = 0
   fun hasNoOccurrenceList x ts = occList x ts 0 = 0
   (* x occurs exactly once in t / across ts *)
   fun hasOneOccurrence x t = occ x t = 1
   fun hasOneOccurrenceList x ts = occList x ts 0 = 1
   (* x occurs at most once in t / across ts *)
   fun hasAtMostOneOccurrence x t = occ x t <= 1
   fun hasAtMostOneOccurrenceList x ts = occList x ts 0 <= 1
   end


   (* setOfSubterms t: all subterms of t, t included;
      setOfProperSubterms t: all subterms of the arguments of t *)
   fun setOfSubterms t = TS.add (setOfProperSubterms t, t)
   and setOfProperSubterms t =
       L.foldl (fn (arg,acc) => TS.union (setOfSubterms arg, acc)) TS.empty (argsOfTerm t)

   (* alias of linearVarSetInTerm *)
   fun linearVarSet t = linearVarSetInTerm t

   (* sortOfVarInTerm t x: SOME of the sort attached to an occurrence of
      variable x found by varSubterms, NONE if x does not occur in t *)
   fun sortOfVarInTerm t x =
       let
           fun isX u = case varRootOfTerm u of
                           SOME y => Var.equal (x,y)
                         | NONE => false
       in
           Option.map sortOfTerm (L.find isX (varSubterms t))
       end


(** use the same function 
   fun linearVarSet t = 
       let fun linVS (lvset,vset) (Var (x,_)) = 
	       if VS.member (vset, x) 
	       then if VS.member (lvset, x)
		    then (VS.delete(lvset, x), vset)
		    else (lvset, vset)
	       else (VS.add(lvset, x), VS.add(vset, x))
	     | linVS (lvset,vset) (Fun (_,ts,_)) = linVSList (lvset,vset) ts
	   and linVSList (lvset,vset) [] = (lvset,vset)
	     | linVSList (lvset,vset) (t::ts) =
	       let val (lvset2, vset2) = linVS (lvset,vset) t
	       in linVSList (lvset2,vset2) ts
	       end
       in
	   (fn (x,y) => x) (linVS (VS.empty, VS.empty) t)
       end
***)    

   (* largest variable index in a term / list of terms; 0 when there are no
      variables (0 is also a lower bound of the result) *)
   fun maxVarIndexInTerm (Var ((_,i),_)) = i
     | maxVarIndexInTerm (Fun (_,args,_)) = maxVarIndexInTerms args
   and maxVarIndexInTerms ts =
       L.foldl (fn (t,m) => Int.max (maxVarIndexInTerm t, m)) 0 ts

   local
   (* smaller of two optional lower bounds; NONE means "no variable seen" *)
   fun minOpt (NONE, b) = b
     | minOpt (a, NONE) = a
     | minOpt (SOME i, SOME j) = SOME (Int.min (i,j))
   (* (least index if any variable occurs, greatest index or 0) *)
   fun rangeOp (Var ((_,index),_)) = (SOME index, index)
     | rangeOp (Fun (_,args,_)) = rangeListOp args
   and rangeListOp terms =
       L.foldr (fn (u,(lo,hi)) =>
                   let val (lo',hi') = rangeOp u
                   in (minOpt (lo',lo), Int.max (hi',hi))
                   end)
               (NONE, 0)
               terms
   (* a missing lower bound is reported as 0 *)
   fun defaultLow (lo, hi) = (getOpt (lo, 0), hi)
   in
   (* rangeVarIndexInTerm t / rangeVarIndexInTerms ts:
      (least, greatest) variable index; (0, 0) when no variable occurs. *)
   fun rangeVarIndexInTerm t = defaultLow (rangeOp t)
   fun rangeVarIndexInTerms ts = defaultLow (rangeListOp ts)
   end

   (* increaseVarIndexBy n t: rename every variable of t by shifting its index
      with Var.increaseIndexBy n; symbols and sorts are kept unchanged. *)
   fun increaseVarIndexBy n term =
       case term of
           Var (x, sort) => Var (Var.increaseIndexBy n x, sort)
         | Fun (f, args, sort) => Fun (f, L.map (increaseVarIndexBy n) args, sort)

   (* constructorCap cSymSet t:
      replace every maximal subterm of t whose root symbol is NOT in cSymSet
      (including t itself when its root is not in cSymSet) by a fresh variable
      named "z", keeping the original subterm's sort; variables of t are kept.
      Fresh indices continue upward from maxVarIndexInTerm t, assigned in
      left-to-right order.  Returns the capped term together with the set of
      fresh variables that were introduced. *)
   fun constructorCap cSymSet t =
       let val next = ref (maxVarIndexInTerm t)
           val freshName = Atom.atom "z"
           fun freshVar sort =
               let val _ = next := !next + 1
                   val v = (freshName, !next)
               in (Var (v, sort), VS.singleton v)
               end
           fun capTerm (u as Var _) = (u, VS.empty)
             | capTerm (Fun (f, args, sort)) =
               if FS.member (cSymSet, f)
               then let (* List.map visits args left to right, so numbering is stable *)
                        val (args', vsets) = LP.unzip (L.map capTerm args)
                    in (Fun (f, args', sort), L.foldl VS.union VS.empty vsets)
                    end
               else freshVar sort
       in
           capTerm t
       end

   (* basicCap (cSymSet,dSymSet) t:
      NONE when t is a variable or its root symbol is in cSymSet.  Otherwise
      the root is kept and, inside its arguments, every maximal subterm whose
      root is NOT in cSymSet is replaced by a fresh variable named "z" (same
      sort), numbered left to right from maxVarIndexInTerm t + 1.  Returns the
      capped term and the set of fresh variables.  dSymSet is not consulted. *)
   fun basicCap (cSymSet,dSymSet) (Var _) = NONE
     | basicCap (cSymSet,dSymSet) (term as Fun (root,args,ty)) =
       if FS.member (cSymSet, root) then NONE
       else
           let val next = ref (maxVarIndexInTerm term)
               val freshName = Atom.atom "z"
               fun freshVar sort =
                   let val _ = next := !next + 1
                       val v = (freshName, !next)
                   in (Var (v, sort), VS.singleton v)
                   end
               fun capTerm (u as Var _) = (u, VS.empty)
                 | capTerm (Fun (f, us, sort)) =
                   if FS.member (cSymSet, f)
                   then let val (us', vsets) = LP.unzip (L.map capTerm us)
                        in (Fun (f, us', sort), L.foldl VS.union VS.empty vsets)
                        end
                   else freshVar sort
               (* List.map visits args left to right, so numbering is stable *)
               val (args', vsets) = LP.unzip (L.map capTerm args)
           in
               SOME (Fun (root, args', ty), L.foldl VS.union VS.empty vsets)
           end

  (* Homeomorphic embedding between terms:
   *   embGe (s,t) decides s \ge^emb t, i.e. t is embedded in s (s = t allowed);
   *   embGt (s,t) decides s \gt^emb t, the strict part of \ge^emb.
   * Every top-level call memoizes its sub-judgements in two term-pair maps
   * (geMap for \ge^emb, gtMap for \gt^emb) that live only for that call. *)
   local
       val runDebug = ref false
       fun debug f = if (!runDebug) then f () else ()
       datatype order_type = GE | GT
       fun emb orderType (s,t) =
           let val geMap = ref TermPairMap.empty
               val gtMap = ref TermPairMap.empty
               (* memoized s \gt^emb t *)
               fun lookupGt (s,t) =
                   case TPM.find (!gtMap, (s,t)) of
                       SOME b => (debug (fn _ => print ( "gtMap used ("
                                                        ^ (toString s) ^ " :gt: "
                                                        ^ (toString t) ^ ")\n" ));
                                  b)
                     | NONE => let val _ =  debug (fn _ => print ( "try to judge  ("
                                                                  ^ (toString s) ^ " :gt: "
                                                                  ^ (toString t) ^ ")\n" ))
                                   val b = embGtSub (s,t)
                               in gtMap := TPM.insert (!gtMap, (s,t), b);
                                  (* s \gt^emb t implies s \ge^emb t, but a negative strict
                                   * result says nothing about \ge^emb (e.g. when s = t).
                                   * Recording `false` in geMap here used to overwrite valid
                                   * entries, so that embGt (f(a,g(a)), f(a,a)) wrongly
                                   * returned false.  Only a positive result is propagated. *)
                                  if b then geMap := TPM.insert (!geMap, (s,t), true) else ();
                                  b
                               end
               (* memoized s \ge^emb t *)
               and lookupGe (s,t) =
                   case TPM.find (!geMap, (s,t)) of
                       SOME b => (debug (fn _ => print ( "geMap used ("
                                                        ^ (toString s) ^ " :ge: "
                                                        ^ (toString t) ^ ")\n" ));
                                  b)
                     | NONE => let val _ =  debug (fn _ => print ( "try to judge  ("
                                                                  ^ (toString s) ^ " :ge: "
                                                                  ^ (toString t) ^ ")\n" ))
                                   val b = embGeSub (s,t)
                               in geMap := TPM.insert (!geMap, (s,t), b); b
                               end
               (* one unfolding of \ge^emb: identical variables; same root with all
                * arguments related; or t embedded in some argument of s *)
               and embGeSub (Var (x,_), Var (y,_)) = (Var.compare (x,y) = EQUAL)
                 | embGeSub (Var _, Fun _) = false
                 | embGeSub (s as Fun (f,ss,_), t as Var _) = L.exists (fn si => lookupGe (si,t)) ss
                 | embGeSub (s as Fun (f,ss,_), t as Fun (g,ts,_)) =
                   (Fun.compare (f,g) = EQUAL
                    andalso LP.all (fn (si,ti) => lookupGe (si,ti)) (ss,ts))
                   orelse
                   L.exists (fn si => lookupGe (si,t)) ss
               (* one unfolding of \gt^emb: same root with all arguments \ge^emb and
                * at least one \gt^emb; or t embedded in some argument of s *)
               and embGtSub (Var _, _) = false
                 | embGtSub (s as Fun (f,ss,_), t as Var _) = L.exists (fn si => lookupGe (si,t)) ss
                 | embGtSub (s as Fun (f,ss,_), t as Fun (g,ts,_)) =
                   (Fun.compare (f,g) = EQUAL
                    andalso LP.all (fn (si,ti) => lookupGe (si,ti)) (ss,ts)
                    andalso LP.exists (fn (si,ti) => lookupGt (si,ti)) (ss,ts))
                   orelse
                   L.exists (fn si => lookupGe (si,t)) ss
           in case orderType of
                  GE => lookupGe (s,t)
                | GT => lookupGt (s,t)
           end
   in
   (* embGt (s,t): s \gt^emb t *)
   fun embGt (s,t) = emb GT (s,t)
   (* embGe (s,t): s \ge^emb t *)
   fun embGe (s,t) = emb GE (s,t)
   end


  (* Congruence Closure
   * checkCC eqs (s,t) decides whether (s,t) belongs to the congruence closure
   * of eqs, restricted to the finite universe of subterms of s, t and the
   * sides of eqs.  The closure is approximated by iterating congStep
   * (reflexivity, symmetry, one transitivity step, one congruence step) from
   * eqs until (s,t) appears or a fixpoint is reached. *)
   local
       (* { (t,t) | t in tset } *)
       fun reflStep tset = TS.foldl (fn (t,set) => TPS.add (set,(t,t))) TPS.empty tset
       (* { (t,s) | (s,t) in eqset } *)
       fun symStep eqset = TPS.map (fn (s,t) => (t,s)) eqset
       (* { (u,w) | (u,v) and (v,w) in eqset } *)
       fun transStep eqset =
           let fun transSub (fst,snd) = TPS.foldl (fn ((snd',trd),set) => if equal (snd,snd')
                                                                           then  (TPS.add (set, (fst,trd)))
                                                                           else set)
                                                  TPS.empty
                                                  eqset
           in TPS.foldl (fn ((fst,snd),set) => TPS.union (set, transSub (fst,snd)))
                        TPS.empty
                        eqset
           end
       (* For each (f,n) in falist, all pairs (f(u1..un), f(w1..wn)) with every
        * (ui,wi) drawn from eqset, kept only when both sides lie in univSet.
        * The number of candidates per symbol is |eqset|^n.
        * NOTE(review): candidates are built with Sort.null, so this relies on
        * TS membership not distinguishing sorts -- confirm. *)
       fun cong0Step falist univSet eqset =
           let val eqs = TPS.listItems eqset
               val unary = L.map (fn (x,y) => ([x],[y])) eqs
               fun init 0 = SOME [([],[])]
                 | init 1 = SOME unary
                 | init _ = NONE
               (* memo table: arity n |-> list of (us, ws) argument-list pairs *)
               val table = ref init
               fun lookup n =
                   case (!table) n of
                       SOME xs => xs
                     | NONE =>
                       let val entry = ListXProd.mapX (fn ((x,y),(xs,ys)) => (x::xs, y::ys))
                                                      (eqs, lookup (n-1))
                           (* Snapshot the table as it is now (lookup (n-1) may have
                            * extended it) and fall back to that snapshot for other
                            * arities.  Dereferencing `table` at call time instead would
                            * resolve to the extension itself, making every arity return
                            * `entry`, so lower-arity symbols lost their congruence pairs. *)
                           val previous = !table
                           fun extended m = if m = n then SOME entry else previous m
                       in table := extended; entry
                       end
           in L.foldl (fn ((f,ar),set) => let val args = lookup ar
                                              val pairs = L.map (fn (ts,ss) => (Fun (f,ts,Sort.null), Fun (f,ss,Sort.null))) args
                                              val result = L.filter (fn (s,t) => TS.member (univSet,s) andalso TS.member (univSet,t)) pairs
                                          in TPS.addList (set,result)
                                          end)
                      TPS.empty
                      falist
           end
       (* one closure step, restricted to univSet x univSet *)
       fun congStep falist univSet eqset =
           let val R = reflStep univSet
               val S = symStep eqset
               val T = transStep eqset
               val C = cong0Step falist univSet eqset (* prefiltered by univSet^2 *)
               val allset = TPS.union (TPS.union (TPS.union (TPS.union (eqset,R), S), T), C)
           in TPS.filter (fn (s,t) => TS.member (univSet,s) andalso TS.member (univSet,t)) allset
           end

   in fun checkCC eqs (s,t) =
          let val falist = FunIntSet.listItems (funAritySetInTerms (LU.mapAppend (fn (l,r)=> [l,r]) ((s,t)::eqs)))
              val univSet = L.foldl (fn ((u,v),set) => TS.addList(TS.addList (set, subterms u), subterms v)) TermSet.empty ((s,t)::eqs)
              val eqset = TPS.addList (TermPairSet.empty, eqs)
              (* iterate congStep until (s,t) is derived or nothing new is added *)
              fun check CC = if TPS.member (CC, (s,t))
                             then true
                             else let val nextCC = congStep falist univSet CC
                                  in if TPS.equal (CC,nextCC)
                                     then false
                                     else check nextCC
                                  end
          in check eqset
          end
   end (* local *)





   end (* of local *)

   end (* of structure Term *)




