(* file: un.sml *)
(* description: check UN= for shallow TRS *)
(* author: Masaomi Yamaguchi *)

(* Interface of the UN= (uniqueness of normal forms) checker for shallow TRSs. *)
signature NUE_UNIQUENESS_OF_NORMAL_FORMS =
sig
    (* Silent check.  NONE: no counterexample was found.
       SOME (s, t): two distinct terms reported as convertible normal forms.
       When completion reports the system inconsistent, the implementation
       returns a placeholder pair of two distinct variables instead. *)
    val checkUN: NueTrs.trs -> (NueTerm.term * NueTerm.term) option

    (* Same check, but prints a report (input, intermediate systems, timings,
       counterexample and proof if any).  Returns true when the report says
       UN= holds and false otherwise. *)
    val checkUN_print: NueTrs.trs -> bool
end

structure NueUN: NUE_UNIQUENESS_OF_NORMAL_FORMS =
struct

local
    structure L = List
    structure LU = NueListUtil
    structure T = NueTerm
    structure S = NueSubst
    structure AL = NueAssocList
    structure Trs = NueTrs
    structure Rewrite = NueRewrite
    structure FlatTrs = NueFlatTrs
    structure Flatting = NueFlatting
    structure CompE = NueCompE
    structure Measure = NueMeasure
    structure OrderRewrite = NueOrderRewrite
    open PrintUtil
in

(* Raised by witAdd to abort the scan in checkConvergent as soon as two
   distinct terms sharing the same key (normal form) are found. *)
exception Witness of (T.term * T.term)

(* Create n constants !cx_0 .. !cx_(n-1), represented as variables
   T.Var ("!cx", i) that are treated as constants.  The list is ordered by
   descending index: [!cx_(n-1), ..., !cx_0].
   Raises Size if n is negative; the former hand-written recursion never
   terminated in that case. *)
fun makeConstants n =
    L.tabulate (n, fn i => T.Var ("!cx", n - 1 - i))

(* Integer power x^n for a non-negative exponent n.
   Raises Domain if n is negative; the former recursion diverged. *)
fun pow x n =
    if n < 0 then raise Domain
    else let fun loop (acc, 0) = acc
               | loop (acc, m) = loop (acc * x, m - 1)
         in loop (1, n)
         end

(* Enumerate candidate normal forms of height at most k.
     rs : the flattened TRS
     F  : non-constant function symbols paired with their arities
     C  : constant terms, including fresh ones made by makeConstants
     k  : height bound.  Callers pass k >= 1; for k <= 0 only the constants
          of C accepted by the root check are returned.
   The search starts from the members of C accepted by
   Rewrite.isNF_root rs.  Each round applies every (f, arity) in F to all
   arity-tuples of the terms found so far, and keeps the results accepted by
   Rewrite.isNF_root on the rules returned by Trs.filter f rs.
   NOTE(review): this relies on isNF_root checking reducibility at the root
   only, so that a root-irreducible term over normal-form arguments is itself
   a normal form - confirm against NueRewrite. *)
fun makeAllNFTerms rs F C k =
    let
        (* All accepted terms f(t1,...,tn) whose arguments range over args. *)
        fun make_f_NF (f, arity) args =
            let val rs' = Trs.filter f rs
                (* every arity-tuple over args, as a list of argument lists *)
                fun make_tss 0 = [[]]
                  | make_tss n =
                    LU.largeUnion (map (fn ts => map (fn t => t :: ts) args)
                                       (make_tss (n - 1)))
            in
                foldl (fn (ts, nfs) =>
                          if Rewrite.isNF_root rs' (T.Fun (f, ts))
                          then LU.add (T.Fun (f, ts)) nfs
                          else nfs)
                      [] (make_tss arity)
            end
        (* ts: normal forms of height at most h-1; h: height built this round.
           Stops once the bound k has been passed. *)
        fun main ts h =
            if h > k then ts
            else let (* newly built normal forms of height at most h *)
                     val h_ts = LU.flatten (map (fn fa => make_f_NF fa ts) F)
                 in main (LU.union (h_ts, ts)) (h + 1)
                 end
    in
        main (L.filter (fn c => Rewrite.isNF_root rs c) C) 1
    end

(* Like AL.add, but signals a counterexample.  If nft is already mapped to
   a different term t', raises Witness (t, t').  Re-adding an identical pair
   leaves xs unchanged.
   NOTE: an earlier variant did not record pairs whose key nft is a constant
   T.Fun (c, []); that pruning is currently disabled. *)
fun witAdd (nft, t) xs =
    case AL.find nft xs of
        SOME t' => if t = t' then xs else raise Witness (t, t')
      | NONE => (nft, t) :: xs

(* Check whether two distinct terms of ts are convertible.
   Each term t is keyed by OrderRewrite.linf cpTrs t (presumably its normal
   form in the completed system R^), and the keys are searched for
   collisions.
   Returns SOME (t, t') for the first collision found, where t is the later
   element of ts.  Returns NONE if there is no collision. *)
fun checkConvergent cpTrs ts =
    (foldl (fn (t, acc) => witAdd (OrderRewrite.linf cpTrs t, t) acc) [] ts;
     NONE)
    handle Witness (t1, t2) => SOME (t1, t2)

(* Symbol data of a flattened TRS needed by the check.  The result tuple is:
   - non-constant symbols with their arities,
   - the constants,
   - the maximum arity,
   - the height bound k = max (1, number of constants). *)
fun symbolInfo rs =
    let val fs = FlatTrs.function_symbols_with_arity rs
        val max_arity = FlatTrs.max_arity fs
        val F = FlatTrs.NC_function_symbols_with_arity fs
        val constant = FlatTrs.constant fs
        val k = Int.max (1, length constant)
    in (F, constant, max_arity, k)
    end

(* Candidate normal forms that must be compared.
   If every symbol is a constant (max_arity = 0), only the constants accepted
   by Rewrite.isNF_root rs are used.  Otherwise 2 * max_arity^(k-1) fresh
   constants are added and all normal forms of height at most k are
   enumerated. *)
fun candidateNFs rs F constant max_arity k =
    if max_arity = 0
    then L.filter (fn c => Rewrite.isNF_root rs c) constant
    else makeAllNFTerms rs F (makeConstants (2 * pow max_arity (k - 1)) @ constant) k

(* Verbose UN= check.
   Prints the input, the flattened TRS and the completed system R^, the
   number of candidates and the timings.  On failure it also prints a
   counterexample and its rewrite sequences.
   Returns true iff no counterexample was found.  Returns false if
   completion raises CompE.Inconsistent. *)
fun checkUN_print_sub rs =
    (let val rs = (print ("Input:\n"^(Trs.prRules rs)^"\n");
                   Measure.pf "Make it flat:\n" (Trs.prRules) (fn () =>
                                  Flatting.flatting rs))
         val (F, constant, max_arity, k) = symbolInfo rs
         val cpE = Measure.pf "Make it Complete (R^):\n" (Trs.prEqs) (fn () => CompE.comp rs)
         val cpTrs = CompE.compTrs cpE
         val ts = Measure.pf "The number of normal forms that must be checked: "
                             (fn ts => Int.toString (length ts) ^ "\n")
                             (fn () => candidateNFs rs F constant max_arity k)
         val result = (print ("Now checking all the pairs...\n\n");
                       Measure.p (fn () => checkConvergent cpTrs ts) "Time to check pairs: ")
     in
         case result of
             SOME (t1, t2) =>
             (print ("The TRS doesn't have Uniqueness of Normal Forms.\n"
                     ^ "Counter Example: \n"
                     ^ "     " ^ (T.toString t1) ^ "\n"
                     ^ "<->* " ^ (T.toString t2) ^ "\n\n");
              print "proof:\n";
              OrderRewrite.listepsToNF cpTrs t1;
              print "\n";
              OrderRewrite.listepsToNF cpTrs t2;
              false)
           | NONE => (print "The TRS has Uniqueness of Normal Forms.\n"; true)
     end) handle CompE.Inconsistent =>
                 (print ("The TRS doesn't have Uniqueness of Normal Forms because it is inconsistent.\n"); false)

(* checkUN_print_sub wrapped with a total-time measurement. *)
fun checkUN_print rs =
    Measure.p (fn () => checkUN_print_sub rs) "Total Time: "

(* Silent UN= check.
   Returns NONE when no counterexample is found, or SOME (t1, t2) with two
   distinct convertible candidate normal forms.  If completion raises
   CompE.Inconsistent, it returns a placeholder pair of two distinct
   variables x and y instead of a concrete counterexample. *)
fun checkUN rs =
    (let val rs = Flatting.flatting rs
         val (F, constant, max_arity, k) = symbolInfo rs
         (* completion runs before enumeration, so inconsistency aborts early *)
         val cpTrs = CompE.compTrs (CompE.comp rs)
     in
         checkConvergent cpTrs (candidateNFs rs F constant max_arity k)
     end) handle CompE.Inconsistent => SOME (T.Var ("x",0),T.Var ("y",0))

end (* of local *)
end (* of struct *)
