(* file: cp.sml *)
(* description: CP (Constant Propagation) algorithm *)
(* author: Masaomi Yamaguchi *)

(* Interface of the constant-propagation (CP) normal-form computation. *)
signature NUE_CP_NF = 
sig
    (* A CPNF is an association list of (key term, value term) pairs. *)
    type cpnf = (NueTerm.term * NueTerm.term) list
    (* Carries two terms together with the association list built at the
       moment the exception was raised. *)
    exception Witness of (NueTerm.term * NueTerm.term * cpnf)
    (* replace cpnf t: rewrites the constants of t that have an entry in cpnf. *)
    val replace: cpnf -> NueTerm.term -> NueTerm.term
    (* CPa trs eqs hat consts: runs the CP algorithm. The second component of
       the result is NONE on normal completion, or SOME (t1,t2) when a
       Witness was raised internally. *)
    val CPa: NueTrs.trs -> NueTrs.eqs -> (NueTerm.term -> NueTerm.term) -> NueTerm.term list -> (cpnf * (NueTerm.term * NueTerm.term) option)
    (* Pretty-prints a CPNF. *)
    val toString: cpnf -> string
    (* Pretty-prints the CPNF component of a CPa result. *)
    val toStringOP: (cpnf * (NueTerm.term * NueTerm.term) option) -> string
end

structure NueCP: NUE_CP_NF =
struct 

local 
    structure L = List
    structure LU = NueListUtil
    structure T = NueTerm
    structure S = NueSubst
    structure AL = NueAssocList
    structure H = NueHat  
    structure Rewrite = NueRewrite
in
(* CPNF: association list of (key term, value term) pairs. *)
type cpnf = (T.term * T.term) list
(* Raised by witAdd; carries two terms and the association list
   assembled at that point. CPa catches it and returns its contents. *)
exception Witness of (T.term * T.term * cpnf)
			 
(* Build the set of representatives of the constants: hat is applied to each
   element of C, front to back, and the results are collected with LU.add
   starting from the empty list. *)
fun repConst hat C =
    let
        fun collect [] reps = reps
          | collect (c :: rest) reps = collect rest (LU.add (hat c) reps)
    in
        collect C []
    end

(* Compute the initial value used for Xnf for the constant c.
   For every equation of eqs that relates the constant c to a non-constant
   term f(...), in either orientation, hat (f(...)) is collected with LU.add.
   Equations between two constants, and every other shape of equation,
   contribute nothing.
   Only defined when the second argument is a constant; any other term
   raises Match, just as the inexhaustive clause set would. *)
fun CP1 hat (T.Fun (c,[])) eqs =
    let
        (* The non-constant side of an equation of the shape c <--> f(...),
           if the equation has that shape. *)
        fun partner (T.Fun (_,[]), T.Fun (_,[])) = NONE
          | partner (T.Fun (d,[]), rhs as T.Fun _) =
            if c = d then SOME rhs else NONE
          | partner (lhs as T.Fun _, T.Fun (d,[])) =
            if c = d then SOME lhs else NONE
          | partner _ = NONE
        fun step (eq, acc) =
            case partner eq of
                SOME t => LU.add (hat t) acc
              | NONE => acc
    in
        foldl step [] eqs
    end
  | CP1 _ _ _ = raise Match
	 
(* Replace the constants occurring in a term using the CPNF Xnf.
   A variable is returned unchanged. A constant is looked up in Xnf with
   AL.find: the associated term is returned when there is one, otherwise the
   constant itself. For any other application the replacement is applied to
   each argument. The looked-up term is not rewritten again. *)
fun replace Xnf term =
    case term of
        T.Var _ => term
      | T.Fun (_,[]) => getOpt (AL.find term Xnf, term)
      | T.Fun (f,args) => T.Fun (f, map (replace Xnf) args)
				       
				       
(* Run step 2 for a single Y_i = (c,Ti).
   Each term r of Ti is rewritten with replace Xnf; when the result r' is a
   normal form with respect to trs (Rewrite.isNF), the pair (c,r') is
   recorded and r is marked for removal.
   Returns (recorded pairs, (c, Ti without the removed terms)). *)
fun sub_Yi Xnf trs (c,Ti) =
    let
        fun scan [] found removed = (found, removed)
          | scan (r :: rest) found removed =
            let
                val r' = replace Xnf r
            in
                if Rewrite.isNF trs r'
                then scan rest (LU.add (c,r') found) (LU.add r removed)
                else scan rest found removed
            end
        val (found, removed) = scan Ti [] []
    in
        (found, (c, LU.difference (Ti, removed)))
    end

(* Check whether one of the depth-1 subterms (immediate arguments) of a term
   is a variable. A bare variable has no arguments, so the answer is false. *)
fun hasVar term =
    let
        fun isVar (T.Var _) = true
          | isVar _ = false
    in
        case term of
            T.Var _ => false
          | T.Fun (_,args) => L.exists isVar args
    end

(* Rename variables: concretely, the index i of every variable (x,i)
   occurring in the term is incremented by 1. *)
fun alpha term =
    case term of
        T.Var (name,idx) => T.Var (name, idx + 1)
      | T.Fun (f,args) => T.Fun (f, map alpha args)
				     
(* Essentially an association-list add, except that a witness is reported by
   raising Witness:
   - key c unbound and t has a variable among its immediate arguments:
     raises Witness (t, alpha t, xs extended with (c,t) and (c,alpha t));
   - key c unbound otherwise: returns (c,t) :: xs;
   - key c bound to t itself: returns xs unchanged;
   - key c bound to a different term t': raises
     Witness (t, t', xs extended with (c,t)). *)
fun witAdd (c,t) xs =
    let
        val existing = AL.find c xs
    in
        case existing of
            NONE =>
            if hasVar t
            then
                let
                    val renamed = alpha t
                in
                    raise Witness (t, renamed,
                                   LU.add (c,renamed) (LU.add (c,t) xs))
                end
            else (c,t) :: xs
          | SOME old =>
            if t = old
            then xs
            else raise Witness (t, old, LU.add (c,t) xs)
    end

(* Essentially an association-list union: the pairs of the first list are
   inserted into the second one, front to back, through witAdd. A Witness
   raised by witAdd propagates to the caller. *)
fun witUnion (xs,ys) = foldl (fn (pair,acc) => witAdd pair acc) ys xs

(* Constant Propagation algorithm.
   trs   : rewrite system used for the normal-form test (Rewrite.isNF)
   cpEqs : equations handed to CP1
   hat   : maps a term to its representative
   C     : list of constants
   Result: (Xnf, NONE) once a round produces no new pair, or
           (CPnf, SOME (t1,t2)) with the contents of a Witness exception
           raised while seeding or extending Xnf. *)
fun CPa trs cpEqs hat C =
    (let
        (* set of representatives of the constants *)
        val repC = repConst hat C
        (* the collection of Y_i, one per representative *)
        val Y0 = map (fn c => (c, CP1 hat c cpEqs)) repC
        (* seed Xnf with the constants that are already normal forms *)
        fun seed (c,acc) =
            if Rewrite.isNF trs c then witAdd (hat c,c) acc else acc
        val Xnf0 = foldl seed [] C
        (* process one Y_i, accumulating new pairs and the shrunken Y_i *)
        fun round Xnf (Yi,(pairs,Ys)) =
            let
                val (pairs_i,Yi') = sub_Yi Xnf trs Yi
            in
                (pairs_i @ pairs, Yi' :: Ys)
            end
        (* iterate until a full pass over Y yields no new pair *)
        fun main Xnf Y =
            case foldl (round Xnf) ([],[]) Y of
                ([],_) => Xnf
              | (X_tmp,Y') => main (witUnion (X_tmp,Xnf)) Y'
    in
        (main Xnf0 Y0, NONE)
    end) handle Witness (t1,t2,CPnf) => (CPnf, SOME (t1,t2))

(* Render one CPNF entry: both terms printed with T.toString, separated by
   an ellipsis. *)
fun prPair (c,t) = String.concat [T.toString c, " … ", T.toString t]
(* Render a whole CPNF with LU.toStringCommaLnSquare. *)
fun toString cpnf = LU.toStringCommaLnSquare prPair cpnf
(* Render the CPNF component of a CPa result; the optional witness pair is
   not printed. *)
fun toStringOP (cpnf,_) = toString cpnf

end (* of local *)
end (* of struct *)
    
