(* file: term.sml *)
(* description: first-order terms *)
(* author: Takahito Aoto *)

(* Interface for first-order terms: construction, inspection, printing,
   parsing from strings, and position-based access and replacement. *)
signature NUE_TERM = 
sig 
    (* Keys identifying variables and function symbols. *)
    type var_key = NueVar.key
    type fun_key = NueFun.key
    (* A position is a path of argument indices; the empty list is the root. *)
    type position = int list
    (* A term is a variable or a function symbol applied to argument terms. *)
    datatype term = Var of var_key | Fun of fun_key * term list
    (* A context is represented as a function that plugs a term into a hole. *)
    type context = term -> term

    (* Raised by symbol, subterm and contexts when a position does not
       address a subterm of the given term. *)
    exception PositionNotInTerm
		  
    (* Root symbol of a term. *)
    val root: term -> NueSymbol.symbol
    (* Direct arguments of a term; empty for a variable. *)
    val args: term -> term list 
    val isVar: term -> bool 
    val isFun: term -> bool 
    (* Variable keys / function keys occurring in a term. *)
    val vars: term -> var_key list 
    val funs: term -> fun_key list 

    (* Textual form: variables are prefixed with "?", applications are
       written f(t1,...,tn), constants without parentheses. *)
    val toString: term -> string
    (* Prints each term on its own line. *)
    val printList: term list -> unit
    (* Parses a term from a string. *)
    val fromString: string -> term

    (* The crop* functions consume a prefix of the input and return the
       parsed value together with the remaining string, or NONE. *)
    val cropId: string -> (string * string) option
    val cropTerm: string -> (term * string) option
    (* First argument is the separator key, second the input string. *)
    val cropKeySeparatedTermPair: string -> string -> ((term * term) * string) option
    (* Reads exactly one "term key term"; raises Fail on syntax errors. *)
    val readKeySeparatedTermPair: string -> string -> (term * term)
    (* Reads "start term key term sep ... sep term key term stop";
       the triple is (start, sep, stop), followed by key and input. *)
    val readMultipleKeySepartedTermPairs: string * string * string -> string -> string -> (term * term) list

    (* Prints a position as dot-separated indices; the root is "e". *)
    val prPos: position -> string
    (* Positions of a term; argument indices are generated starting at 1. *)
    val pos: term -> position list
    (* Positions paired with their depth (root has depth 0). *)
    val posWithDepth: term -> (position * int) list
    (* Symbol / subterm at a position; raise PositionNotInTerm if absent. *)
    val symbol: term -> position -> NueSymbol.symbol
    val subterm: term -> position -> term

    (* makeContext t p yields the context of t with a hole at p. *)
    val makeContext: term -> position -> context
    val fillContext: context -> term -> term
    (* replaceSubterm t p u replaces the subterm of t at p by u. *)
    val replaceSubterm: term -> position -> term -> term

    (* Number of symbol occurrences (variables and function symbols). *)
    val size: term -> int
    (* Pairs (i,j), i < j, of 1-based direct-argument indices that hold
       the same variable; empty for a variable term. *)
    val pattX: term -> (int * int) list
end

structure NueTerm : NUE_TERM =
struct
local 
    structure LP = ListPair
    structure SU = NueStringUtil
    structure LU = NueListUtil
    structure Fun = NueFun
    structure Var = NueVar
    structure Symbol = NueSymbol
    open PrintUtil

in

(* Keys for variables and function symbols come from NueVar / NueFun. *)
type var_key = Var.key
type fun_key = Fun.key
(* A position is a path of argument indices; [] denotes the root. *)
type position = int list	   
(* First-order term: a variable, or a function symbol with its arguments. *)
datatype term = Var of var_key | Fun of fun_key * term list
(* A context is a function that places a term into its hole. *)
type context = term -> term
			   
(* Raised when a position does not address a subterm of a term. *)
exception PositionNotInTerm
	      
(* Root symbol of a term: the variable itself, or the head function symbol. *)
fun root t =
    case t of
        Var v => Symbol.V v
      | Fun (head, _) => Symbol.F head

(* Direct arguments of a term; a variable has none. *)
fun args (t : term) : term list =
    case t of
        Fun (_, arguments) => arguments
      | Var _ => []
			    
(* True exactly when the term is a variable. *)
fun isVar t =
    case t of
        Var _ => true
      | Fun _ => false
		  
(* True exactly when the term is a function application (constants included). *)
fun isFun t =
    case t of
        Fun _ => true
      | Var _ => false
		  
(* Variable keys occurring in a term.  The per-argument results are merged
   right-to-left with LU.union, exactly as a right-nested union chain. *)
fun vars (Var v) = [v]
  | vars (Fun (_, arguments)) =
    foldr (fn (arg, collected) => LU.union (vars arg, collected)) [] arguments
    
(* Function keys occurring in a term: the head symbol is added (LU.add) to
   the right-nested LU.union of the function keys of the arguments. *)
fun funs (Var _) = []
  | funs (Fun (head, arguments)) =
    let
        val below =
            foldr (fn (arg, collected) => LU.union (funs arg, collected)) [] arguments
    in
        LU.add head below
    end

(* Renders a term as text.  Variables are written with a leading "?",
   constants as their bare name, and applications as f(t1,...,tn) with
   comma-separated arguments and no spaces. *)
fun toString (Var v) = "?" ^ Var.toString v
  | toString (Fun (c, [])) = Fun.toString c
  | toString (Fun (f, arguments)) =
    Fun.toString f ^ "(" ^ String.concatWith "," (map toString arguments) ^ ")"

(* Prints the terms one per line (each followed by a newline) in a single
   call to print. *)
fun printList ts =
    let
        fun line t = toString t ^ "\n"
    in
        print (String.concat (map line ts))
    end

(* Characters that the term lexer is told to treat as special symbols:
   parentheses and the argument separator. *)
structure TermSpecialSymbols = struct val special = [#"(", #")", #","]  end
(* Lexer instantiated for the special symbols above. *)
structure TermLex = Lexical (TermSpecialSymbols)
(* Parser combinators instantiated over that lexer. *)
structure TermParsing = Parsing (TermLex)

(* Term parser built from the TermParsing combinators.
   Grammar implemented below:
     term     ::= id "(" termlist | atom
     termlist ::= termseq ")"
     termseq  ::= term "," termseq | term
     atom     ::= id
   NOTE(review): this reading assumes `--` sequences two parsers, `||` is
   alternation, `>>` maps the result, and `"k" $-- p` expects keyword k and
   keeps only p's result — confirm against the Parsing functor.
   Under that reading an application needs at least one argument, so an
   empty argument list such as "f()" is not accepted. *)
local
    (* Builds an application; the identifier must denote a function symbol,
       otherwise Fail is raised. *)
    fun makeFun (id, ts) = (case Symbol.fromString id of
			    Symbol.F f => Fun (f, ts)
			  | Symbol.V _ => raise Fail "Syntax error: function symbol expected")
    fun makeList (t, ts) = t::ts
    fun makeList1 t = t::nil
    (* A bare identifier is a constant or a variable, as decided by
       Symbol.fromString. *)
    fun makeAtom id  = (case Symbol.fromString id of
			    Symbol.F c => Fun (c, []) 
			  | Symbol.V x => Var x)

    open TermParsing
    (* Precedences of the combinators; the grammar rules below rely on
       them to group without extra parentheses. *)
    infix 6 $--
    infix 5 --
    infix 3 >>
    infix 0 ||

    fun term toks =
        ( id --  "(" $-- termlist >> makeFun || atom ) toks
    and termlist toks =
        ( termseq -- $ ")" >> #1 ) toks
    and termseq toks =
        ( term -- "," $-- termseq >> makeList || term >> makeList1 ) toks
    and atom toks  =
        ( id >> makeAtom ) toks
in 
(* Parses a term from a string by running the `term` parser via `reader`. *)
fun fromString str = reader term str
end (* of local *)

(* Delegates to the term lexer's cropId; the result is an optional pair of
   strings (per the NUE_TERM signature: the identifier and the remainder). *)
fun cropId input = TermLex.cropId input

(* Crops one term from the front of a string.  An identifier is cropped
   first; SU.scanBalancedPar then splits what follows into an initial part
   and a remainder, and identifier ^ initial part is parsed as a term.
   Returns NONE when no identifier can be cropped. *)
fun cropTerm str =
    let
        fun parseAfterId (id, body) =
            let
                val (init, rest) = SU.scanBalancedPar body
                val parsed = fromString (id ^ init)
            in
                (parsed, rest)
            end
    in
        Option.map parseAfterId (cropId str)
    end

(* Crops "term key term" from the front of a string and returns the pair of
   terms with the remaining input.  Returns NONE when no leading term can
   be cropped; once a left-hand side has been read, a missing key or a
   missing right-hand side raises Fail. *)
fun cropKeySeparatedTermPair key str =
    let
        (* The key has been consumed: a right-hand side term must follow. *)
        fun afterKey lhs tail =
            case cropTerm tail of
                SOME (rhs, rest) => SOME ((lhs, rhs), rest)
              | NONE => raise Fail ("Syntax error: term expected after " ^ key)

        (* The left-hand side has been read: the key must follow. *)
        fun afterLhs (lhs, tail) =
            case SU.scanKey key tail of
                SOME tail' => afterKey lhs tail'
              | NONE => raise Fail ("Syntax error: " ^ key ^ " expected after term" )
    in
        case cropTerm str of
            SOME parsed => afterLhs parsed
          | NONE => NONE
    end


(* Reads a pair of terms from a string of the form "term key term".
   Raises Fail if the string does not start with such a pair, or if
   anything other than an ending (per SU.ending) follows it. *)
fun readKeySeparatedTermPair key str =
    case cropKeySeparatedTermPair key str of
        NONE => raise Fail ("Syntax error: not a " ^ key ^ " separated term pair")
      | SOME (pair, tail) =>
        if not (SU.ending tail)
        then raise Fail ("Syntax error: trailing " ^ tail)
        else pair

(* Reads a list of term pairs from a string of the form
     start term key term sep ... sep term key term stop
   where (start, sep, stop) are the delimiters and key separates the two
   terms of each pair.  An empty body between start and stop yields [].
   Raises Fail when the start/stop delimiters are missing, when a pair is
   malformed, when a separator is not followed by a pair, or when input
   that is neither a pair nor an ending (per SU.ending) is left over. *)
fun readMultipleKeySepartedTermPairs (start,sep,stop) key str =
    let
        fun rdItem s = cropKeySeparatedTermPair key s

        (* Accumulates pairs in reverse order; each further pair must be
           introduced by sep. *)
        fun rdRemainingItems ans s =
            case SU.scanKey sep s of
                SOME rest =>
                (case rdItem rest of
                     SOME (new, s2) => rdRemainingItems (new :: ans) s2
                   | NONE => raise Fail ("Syntax error: starting term pair expected " ^ rest))
              | NONE =>
                if SU.ending s
                then rev ans
                else raise Fail ("Syntax error: trailing " ^ s)
    in
        case SU.scanBeginEnd (start,stop) str of
            NONE => raise Fail ("Syntax error: " ^ start ^ "..." ^ stop ^ " expected")
          | SOME str1 =>
            (case rdItem str1 of
                 SOME (tp, rest) => rdRemainingItems [tp] rest
                 (* No pair could be cropped: this is only an empty list if
                    the body really is empty.  Otherwise report the input
                    instead of silently dropping it, mirroring the trailing
                    check in rdRemainingItems. *)
               | NONE =>
                 if SU.ending str1
                 then []
                 else raise Fail ("Syntax error: starting term pair expected " ^ str1))
    end

(* Prints a position: "e" for the root, otherwise the indices joined by
   dots, e.g. [1,2,1] is printed as "1.2.1". *)
fun prPos [] = "e"
  | prPos indices = String.concatWith "." (map Int.toString indices)

(* Positions of a term.  A variable has only the root position [].  For an
   application the root is added (LU.add) to the positions of the
   arguments, each prefixed by its argument index; indices start at 1. *)
fun pos t =
    case t of
        Var _ => [[]]
      | Fun (_, arguments) => LU.add [] (posList arguments 1)
(* Positions of the arguments, where i is the index of the first one. *)
and posList [] _ = []
  | posList (arg :: others) i =
    let
        val underArg = map (fn p => i :: p) (pos arg)
        val underOthers = posList others (i + 1)
    in
        LU.union (underArg, underOthers)
    end

(* Positions of a term paired with their depth.  The root has depth 0 and
   each argument is one level deeper than its parent.  The root comes
   first, followed by the positions under each argument from left to
   right. *)
fun posWithDepth t =
    let
        (* Positions of a term whose own root lies at the given depth. *)
        fun walk depth (Var _) = [([], depth)]
          | walk depth (Fun (_, children)) =
            ([], depth) :: walkArgs (depth + 1) 1 children
        (* Positions under the arguments; i is the index of the first one
           and depth is the depth of the arguments themselves. *)
        and walkArgs _ _ [] = []
          | walkArgs depth i (child :: siblings) =
            let
                val below = walk depth child
                val prefixed = map (fn (p, k) => (i :: p, k)) below
            in
                prefixed @ walkArgs depth (i + 1) siblings
            end
    in
        walk 0 t
    end
	
(* Symbol found at a position of a term.  The argument at each step is
   looked up with LU.get; PositionNotInTerm is raised when the path goes
   below a variable or LU.get finds no argument for an index. *)
fun symbol t path =
    case (t, path) of
        (Var v, []) => Symbol.V v
      | (Fun (f, _), []) => Symbol.F f
      | (Var _, _ :: _) => raise PositionNotInTerm
      | (Fun (_, arguments), i :: rest) =>
        (case LU.get arguments i of
             SOME child => symbol child rest
           | NONE => raise PositionNotInTerm)
							  
(* Subterm found at a position of a term; the empty position yields the
   term itself.  PositionNotInTerm is raised when the path goes below a
   variable or LU.get finds no argument for an index. *)
fun subterm t path =
    case (path, t) of
        ([], _) => t
      | (_ :: _, Var _) => raise PositionNotInTerm
      | (i :: rest, Fun (_, arguments)) =>
        (case LU.get arguments i of
             SOME child => subterm child rest
           | NONE => raise PositionNotInTerm)

(* Builds the context obtained from t by making a hole at position ps.
   The result is a function: applied to a term, it returns t with the
   subterm at ps replaced by that term.  The position is only checked when
   the context is filled; an invalid one raises PositionNotInTerm then. *)
fun makeContext t ps =
    let
        fun plug filler =
            let
                (* Rebuilds the term along the path, substituting at its end. *)
                fun go _ [] = filler
                  | go (Var _) (_ :: _) = raise PositionNotInTerm
                  | go (Fun (f, arguments)) (i :: rest) =
                    (case LU.mapSpecified (fn arg => go arg rest) arguments i of
                         SOME arguments' => Fun (f, arguments')
                       | NONE => raise PositionNotInTerm)
            in
                go t ps
            end
    in
        plug
    end
							 
(* Fills the hole of a context with a term; a context is a function, so
   this is plain application. *)
fun fillContext ctx filler = ctx filler

(* Replaces the subterm of t at position p by u: the context of t with a
   hole at p is built and applied directly to u. *)
fun replaceSubterm t p u =
    let
        val ctx = makeContext t p
    in
        ctx u
    end

(* Size of a term: one for the term's own symbol plus the sizes of all
   arguments; a variable has size 1. *)
fun size (Var _) = 1
  | size (Fun (_, arguments)) =
    foldl (fn (arg, total) => total + size arg) 1 arguments
					 
(* Computes the variable pattern of a term: all pairs (i,j) with i < j of
   1-based indices of direct arguments that are the same variable.
   Arguments that are not variables contribute nothing, and a variable
   term has the empty pattern. *)
fun pattX (Var _) = []
  | pattX (Fun (_, arguments)) = pattXList arguments 1
(* n is the index of the first argument in the list. *)
and pattXList [] _ = []
  | pattXList (Fun _ :: rest) n = pattXList rest (n + 1)
  | pattXList (Var x :: rest) n =
    let
        (* Pairs (n,m) for every later argument, at index m, equal to x. *)
        fun later [] _ = []
          | later (u :: us) m =
            let
                val tail = later us (m + 1)
            in
                case u of
                    Var y => if x = y then (n, m) :: tail else tail
                  | Fun _ => tail
            end
    in
        later rest (n + 1) @ pattXList rest (n + 1)
    end
				    
end (* of local *)    
end


