(******************************************************************************
 * Copyright (c) 2012-2013, Toyama&Aoto Laboratory, Tohoku University
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without 
 * modification, are permitted provided that the following conditions are met:
 * 
 *  1. Redistributions of source code must retain the above copyright notice, 
 *     this list of conditions and the following disclaimer.
 *  2. Redistributions in binary form must reproduce the above copyright 
 *     notice, this list of conditions and the following disclaimer in the 
 *     documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 
 * POSSIBILITY OF SUCH DAMAGE.
 ******************************************************************************)
(******************************************************************************
 * file: rwtools/util/ring.sml
 * description: utilities for ring, vector, matrix, polynomials
 * author: AOTO Takahito
 * 
 ******************************************************************************)

(* Ring: a carrier type [elem] with addition ([zero], [plus], additive
   inverse [minus]) and multiplication ([one], [times]), together with
   [equal] to compare elements and [toString] to print them.
   Nothing here requires [times] to be commutative; the matrix ring built
   by Matrix2RingFn is an instance where it is not. *)
signature RING = 
sig
    type elem
    val toString: elem -> string
    val equal: elem * elem -> bool
    val zero: elem                      (* additive identity *)
    val plus: elem * elem -> elem
    val minus: elem -> elem             (* additive inverse (negation) *)
    val one: elem                       (* multiplicative identity *)
    val times: elem * elem -> elem
end

(* The ring of integers (type int) with the usual arithmetic.
   Arithmetic raises Overflow exactly as the int primitives do. *)
structure IntRing : RING =
struct
    type elem = int
    val zero = 0
    val one = 1
    val toString = Int.toString
    val equal : int * int -> bool = op =
    val plus : int * int -> int = op +
    val minus : int -> int = ~
    val times : int * int -> int = op *
end

signature VEC = (* vectors of a fixed dimension with entries of type elem *)
sig
   type elem                          (* entry type *)
   type vec
   val toString: vec -> string
   val equal: vec * vec -> bool
   val fromList: elem list -> vec
   val toList: vec -> elem list
   val dimension: int                 (* number of entries *)
   val zero: vec
   val units: vec list                (* unit (standard basis) vectors *)
   val minus: vec -> vec
   val plus: vec * vec -> vec
   val times: vec * vec -> elem       (* inner (dot) product *)
   val multiply: elem -> vec -> vec   (* scalar multiplication *)
end

(* Signature used to fix the dimension of vectors and matrices
   (the Dimension argument of Ring2VecFn and Ring2MatrixFn) *)
signature DIMENSION = sig val dimension: int end

(* Vectors of length Dimension.dimension over Ring, represented as plain
   lists of ring elements. *)
functor Ring2VecFn (structure Ring: RING
                    structure Dimension: DIMENSION) : VEC = 
struct
   local 
       structure LP = ListPair
       structure LU = ListUtil
   in
   type elem = Ring.elem
   type vec = Ring.elem list 
   val dimension = Dimension.dimension

   fun toList v = v
   (* keep the first [dimension] entries; List.take raises Subscript
      when v has fewer entries than that *)
   fun fromList v = List.take (v, dimension)

   fun toString v = LU.toStringSpaceSquare Ring.toString v
   (* false when the lengths differ (ListPair.allEq) *)
   fun equal (u, v) = LP.allEq Ring.equal (u, v)

   val zero = List.tabulate (dimension, fn _ => Ring.zero)
   (* presumably the i-th basis vector: zero with Ring.one put at index i
      by ListUtil.replaceNth *)
   val units = List.tabulate (dimension, fn i => LU.replaceNth (zero, i, Ring.one))

   fun minus v = List.map Ring.minus v
   (* raises ListPair.UnequalLengths when the lengths differ *)
   fun plus (u, v) = LP.mapEq Ring.plus (u, v)
   fun multiply c v = List.map (fn e => Ring.times (c, e)) v
   (* dot product; ListPair.foldl ignores surplus entries of the longer list *)
   fun times (u, v) =
       LP.foldl (fn (a, b, acc) => Ring.plus (Ring.times (a, b), acc))
                Ring.zero (u, v)
   end
end

(* Integer vectors of length 2 *)
structure IntVec = Ring2VecFn (structure Dimension = struct val dimension = 2 end
                               structure Ring = IntRing) : VEC

(* Integer vectors of length 3 *)
structure IntVec3 = Ring2VecFn (structure Dimension = struct val dimension = 3 end
                                structure Ring = IntRing) : VEC

(***
val h0 = IntVec.zero
val _ = print ((IntVec.toString h0) ^ "\n")

val hs = IntVec.units
val h1 = List.nth (hs,0)
val h2 = List.nth (hs,1)
val _ = print ((IntVec.toString h1) ^ "\n")
val _ = print ((IntVec.toString h2) ^ "\n")

val h3 = IntVec.plus (h1,h2)
val _ = print ((IntVec.toString h3) ^ "\n")

val h4 = IntVec.multiply 2 h3
val _ = print ((IntVec.toString h4) ^ "\n")

val v1 = IntVec.fromList [2,3,4]  (* drop the last *)
val _ = print ((IntVec.toString v1) ^ "\n")

val _ = print ((IntRing.toString (IntVec.times (h1,v1))) ^ "\n")
***)

(* Square matrices of a fixed dimension *)
signature MATRIX = 
sig
   type elem
   type vec
   type matrix
   val toString: matrix -> string
   val equal: matrix * matrix -> bool
   val fromList: elem list list -> matrix   (* given as a list of rows *)
   val toList: matrix -> elem list list
   val transpose: matrix -> matrix
   val plus: matrix * matrix -> matrix
   val minus: matrix -> matrix
   val vmtimes: vec * matrix -> vec         (* row vector times matrix *)
   val mvtimes: matrix * vec -> vec         (* matrix times column vector *)
   val times: matrix * matrix -> matrix
   val zero: matrix
   val unit: matrix                         (* identity matrix *)
   val element: matrix -> int * int -> elem option
   (* Entry positions are 1-based (row, column):
                    (1,1)  (1,2)  (1,3)
                    (2,1)  (2,2)  (2,3)
                    (3,1)  (3,2)  (3,3)  *)
   val row: matrix -> int -> vec option  (* i-th row, i.e. a horizontal line *)
   val column: matrix -> int -> vec option (* j-th column, i.e. a vertical line *)
end

(* Square matrices of dimension Dimension.dimension whose entries are
   elements of Ring.  The original note says these are intended for matrix
   interpretations with a special ordering; no ordering is defined in this
   functor itself.
   Matrices are stored row by row: the matrix
        [1 2]
        [3 4]
   is represented as [[1,2],[3,4]].
 *)
functor Ring2MatrixFn (structure Ring: RING 
                       structure Dimension: DIMENSION) : MATRIX =
struct
   local 
       structure L = List
       structure LP = ListPair
       structure LU = ListUtil
       (* rows and columns are vectors of the same dimension over Ring *)
       structure V = Ring2VecFn (structure Ring = Ring
                                 structure Dimension = Dimension) : VEC
   in
   type elem = Ring.elem
   type vec = V.vec
   type matrix = V.vec list
   val dimension = Dimension.dimension

   fun toString xss = LU.toStringSpaceLnSquare V.toString xss
   (* false when the numbers of rows differ (ListPair.allEq) *)
   fun equal (xss,yss) = LP.allEq V.equal (xss,yss)
   (* raises ListPair.UnequalLengths when the numbers of rows differ *)
   fun plus (xss,yss) = LP.mapEq V.plus (xss,yss)
   fun minus xss = L.map V.minus xss

   (* no validation: the given rows are used as they are *)
   fun fromList xss = xss
   fun toList xss = xss

   (* transpose [[1,2,3],[4,5,6],[7,8,9]] = [[1,4,7],[2,5,8],[3,6,9]]
      Raises ListPair.UnequalLengths when the rows have different lengths. *)
   fun transpose mx = 
       let fun transposeList [] = []
             | transposeList [xs] = L.map (fn x => [x]) xs
             | transposeList (xs::xss) =
               LP.mapEq (fn (x,col) => x::col) (xs, transposeList xss)
       in fromList (transposeList (toList mx))
       end

   (* row vector xs times matrix yss: entry j is xs . (column j of yss) *)
   fun vmtimes (xs,yss) = L.map (fn ys => V.times (xs,ys)) (transpose yss)

   (* matrix xss times column vector ys: entry i is (row i of xss) . ys *)
   fun mvtimes (xss,ys) = L.map (fn xs => V.times (xs,ys)) xss

   (* matrix product: entry (i,j) is (row i of xss) . (column j of yss).
      The columns of yss are computed once and shared by every row of xss
      instead of being recomputed per row. *)
   fun times (xss,yss) =
       let val cols = transpose yss
       in L.map (fn xs => L.map (fn ys => V.times (xs,ys)) cols) xss
       end

   (* dimension rows of zeros *)
   val zero = L.tabulate (dimension, fn _ => V.zero)

   (* identity matrix: its rows are exactly the unit vectors of V *)
   val unit = V.units

   (* entry at 1-based position (i,j); NONE when out of range *)
   fun element xss (i,j) = (SOME (L.nth (L.nth (xss,i-1), j-1)))
       handle Subscript => NONE

   (* i-th row (1-based); NONE when out of range *)
   fun row xss i = (SOME (L.nth (xss,i-1)))
       handle Subscript => NONE

   (* j-th column (1-based); NONE when some row has no j-th entry *)
   fun column xss j = (SOME (L.map (fn xs => L.nth (xs,j-1)) xss))
       handle Subscript => NONE

   end
end

(* 2x2 integer matrices *)
structure IntMatrix = Ring2MatrixFn (structure Dimension = struct val dimension = 2 end
                                     structure Ring = IntRing) : MATRIX

(***
val m1 = IntMatrix.fromList [[1,2],[3,4]]
val _ = print (IntMatrix.toString m1)
val _ = print (IntVec.toString (valOf (IntMatrix.row m1 1))^ "\n")
val _ = print (IntVec.toString (valOf (IntMatrix.column m1 1))^ "\n")

val h1 = IntVec.fromList [1,3]
val _ = print ((IntVec.toString h1) ^ "\n")
val h2 = IntMatrix.vmtimes (h1,m1)
val _ = print ((IntVec.toString h2) ^ "\n")
val h3 = IntMatrix.mvtimes (m1,h1)
val _ = print ((IntVec.toString h3) ^ "\n")

val m1' = IntMatrix.fromList [[1,5],[7,2]]
val _ = print (IntMatrix.toString m1')

val m2 = IntMatrix.plus (m1,m1')
val _ = print (IntMatrix.toString m2)

val m3 = IntMatrix.times (m1,m1')
val _ = print (IntMatrix.toString m3)

val m3' = IntMatrix.times (m1',m1)
val _ = print (IntMatrix.toString m3')

val m4 = IntMatrix.zero
val _ = print (IntMatrix.toString m4)

val m5 = IntMatrix.unit
val _ = print (IntMatrix.toString m5)
***)

(* 3x3 integer matrices *)
structure IntMatrix3 = Ring2MatrixFn (structure Dimension = struct val dimension = 3 end
                                      structure Ring = IntRing) : MATRIX
(***
val m1 = IntMatrix3.fromList [[1,2,3],[4,5,6],[7,8,9]]
val _ = print (IntMatrix3.toString m1)

val m2 = IntMatrix3.transpose m1
val _ = print (IntMatrix3.toString m2)
***)

(* Square matrices of a fixed dimension form a ring: the ring operations are
   the matrix operations, and the identity matrix is the ring's one. *)
functor Matrix2RingFn (Matrix: MATRIX) : RING =
struct
    type elem = Matrix.matrix
    val zero = Matrix.zero
    val one = Matrix.unit
    val toString = Matrix.toString
    val equal = Matrix.equal
    val plus = Matrix.plus
    val minus = Matrix.minus
    val times = Matrix.times
end

(* The ring whose elements are 2x2 integer matrices *)
structure IntMatrixRing = Matrix2RingFn (IntMatrix)

(* Named POLY2 because the name POLY is already used by Polynomial Order. *)
signature POLY2 =
sig
    type elem                                     (* coefficient type *)
    type poly
    val toString: poly -> string
    val equal: poly * poly -> bool
    val zero: poly
    val plus: poly * poly -> poly
    val minus: poly -> poly
    val one: poly
    val times: poly * poly -> poly
    (* composition: the i-th polynomial of the list is substituted for variable i *)
    val apply: poly -> poly list -> poly
    (* evaluation under an assignment of values to variable indices *)
    val eval: (int -> elem) -> poly -> elem
    (* terms given as (variable indices, coefficient) pairs *)
    val fromList: (int list * elem) list -> poly
    val toList: poly -> (int list * elem) list
    val coefficients: poly -> elem list
    val constant: poly -> elem                    (* constant (degree-0) term *)
end

(* Prefix used when printing the variables of a polynomial *)
signature VAR_NAME = sig val varName: string end

(* Multivariate polynomials with coefficients in Ring.
   A polynomial is a finite map from monomials to coefficients.  A monomial
   is the list of its variable indices in ascending order, repeated for
   powers (so [1,1] is x_1^2); [] is the constant monomial.  For example
   2x_1 + 3x_1x_2 + 5 is the map { [1] -> 2, [1,2] -> 3, [] -> 5 }.
   Every operation that goes through [plus] drops zero coefficients.
   Variables commute with one another (monomials are sorted), but products
   of coefficients are always formed left-to-right, so a non-commutative
   coefficient ring such as IntMatrixRing is multiplied in written order. *)
functor Ring2PolyFn (structure Ring: RING 
                     structure VarName: VAR_NAME) : POLY2 =
struct
    local 
       structure ILM = IntListMap2
       structure L = List
       structure LU = ListUtil
    in
    type elem = Ring.elem
    type poly = Ring.elem ILM.map

    val varString = VarName.varName

    (* render a variable list, each index n shown as varString ^ n;
       the separator is whatever ListUtil.toStringAst uses *)
    fun prVarList xs = 
        LU.toStringAst (fn n => varString ^ (Int.toString n)) xs

    (* one term: "(c)" for a constant, "(c)*<variables>" otherwise *)
    fun toStringMono (xs,n) = 
        if null xs then "(" ^ (Ring.toString n) ^ ")"
        else "(" ^ (Ring.toString n) ^ ")" ^ "*" ^ (prVarList xs) 

    (* the zero polynomial has no terms, so Ring.zero is printed for it *)
    fun toString p = 
        let val str = LU.toStringPlus toStringMono (ILM.listItemsi p)
        in if str = "" then (Ring.toString Ring.zero) else str
        end

    (* merge two ascending index lists into one ascending list, keeping
       duplicates; this is the product of two monomials *)
    fun merge ([],ys) = ys
      | merge (xs,[]) = xs
      | merge (x::xs,y::ys) =
        if x > y then y :: merge (x::xs,ys) else x :: merge (xs,y::ys)

    (* put a variable list into ascending order *)
    fun normalize xs = L.foldl (fn (x,acc) => merge ([x],acc)) [] xs

    val zero = ILM.empty

    (* add coefficientwise, then drop terms whose coefficient became zero *)
    fun plus (p,q) = ILM.filter (fn x => not (Ring.equal (x, Ring.zero)))
                                (ILM.unionWith (fn (x,y) => Ring.plus (x,y)) (p,q))

    fun minus p = ILM.map (fn x => Ring.minus x) p

    (* p = q iff p - q has no terms left *)
    fun equal (p,q) = case ILM.listItemsi (plus (p, minus q)) of
                          [] => true
                        | _ => false

    val one = ILM.singleton ([],Ring.one)

    local
        (* multiply every term (ys,m) of q by the term (xs,n), giving
           (merge (xs,ys), n*m) with the left factor's coefficient first.
           For polynomials built by this functor the keys of q are distinct
           ascending lists, so the merged keys are distinct as well and
           insert never overwrites an earlier entry. *)
        fun times1 ((xs,n), q) = 
            ILM.foldli 
                (fn (ys,m,acc) => ILM.insert (acc, merge (xs,ys), Ring.times (n,m)))
                ILM.empty
                q
    in
    (* p * q; each coefficient product has p's coefficient on the left *)
    fun times (p,q) = ILM.foldli
                          (fn (xs,n,r) => plus (times1 ((xs,n),q), r))
                          ILM.empty
                          p
    end

    (* composition: variable i is replaced by nth (qs, i-1), and the term
       (xs,n) becomes n * q_x1 * q_x2 * ... multiplied left-to-right.
       Raises Subscript when a variable index exceeds the length of qs. *)
    fun apply1 (xs,n) qs = L.foldl (fn (x,acc) => times (acc, L.nth (qs,x-1)))
                                   (ILM.singleton ([],n)) xs
    fun apply p qs = ILM.foldli
                         (fn (xs,n,r) => plus (apply1 (xs,n) qs, r))
                         ILM.empty p

    (* evaluation under env (variable index -> element); each term
       n*x_i1*...*x_ik is evaluated left-to-right as printed *)
    fun eval1 env (xs,n) = L.foldl (fn (x,acc) => Ring.times (acc, env x))
                                   n xs
    fun eval env p = ILM.foldli
                         (fn (xs,n,r) => Ring.plus (eval1 env (xs,n), r))
                         Ring.zero p

    (* variable lists are sorted so that e.g. [2,1] and [1,2] denote the
       same monomial (times relies on keys being ascending) *)
    fun initMono (xs,n) = ILM.singleton (normalize xs, n)
    fun fromList xss = L.foldl (fn (m,p) => plus (initMono m, p)) zero xss
    fun toList p = ILM.listItemsi p

    (* coefficients of all terms, in the map's key order *)
    fun coefficients poly = ILM.listItems poly

    (* the constant term, i.e. the coefficient of the empty monomial *)
    fun constant poly = case ILM.find (poly,[]) of
                            SOME n => n | NONE => Ring.zero
    end
end

(* Polynomials with integer coefficients, e.g. 2x+3xy+5;
   variables are printed with the prefix "x" *)
structure IntPoly = Ring2PolyFn (structure VarName = struct val varName = "x" end
                                 structure Ring = IntRing) : POLY2
                     
(* The polynomials of a POLY2 structure form a ring under polynomial
   addition and multiplication. *)
functor Poly2RingFn (Poly: POLY2) : RING =
struct
    type elem = Poly.poly
    val zero = Poly.zero
    val one = Poly.one
    val toString = Poly.toString
    val equal = Poly.equal
    val plus = Poly.plus
    val minus = Poly.minus
    val times = Poly.times
end

(* The ring whose elements are integer polynomials, e.g. 2x+3xy+5.
   Addition is coefficientwise, e.g. (2x+3xy+5) + (5x+7) = 7x+3xy+12. *)
structure IntPolyRing = Poly2RingFn (IntPoly) : RING

(* Polynomials whose coefficients are themselves integer polynomials;
   variables are printed with the prefix "xx" *)
structure IntPolyPoly = Ring2PolyFn (structure VarName = struct val varName = "xx" end
                                     structure Ring = IntPolyRing) : POLY2

(* The ring whose elements are polynomials with integer-polynomial
   coefficients, e.g. (2x+3)a + (3y+1)ab + (2xy+x).
   Addition is coefficientwise, e.g.
     [(2x+3)a + (2xy+x)] + [(x+3)a + (x+y+1)b]
       = (3x+6)a + (x+y+1)b + (2xy+x) *)
structure IntPolyPolyRing = Poly2RingFn (IntPolyPoly) : RING

(* Polynomials whose coefficients are 2x2 integer matrices, e.g.
     [2 0]*A + [1 0]*A*B + [0 5]
     [1 1]     [0 1]       [3 1]
   where * and + on coefficients are matrix product and matrix sum. *)
structure IntMatrixPoly = Ring2PolyFn (structure VarName = struct val varName = "A" end
                                       structure Ring = IntMatrixRing) : POLY2
                     
(***
val p = IntPoly.zero
val _ = print ((IntPoly.toString p) ^ "\n")

val p1 = IntPoly.fromList [([1], 1)]
val _ = print ("p1 = " ^ (IntPoly.toString p1) ^ "\n")

val p2 = IntPoly.fromList [([1], 2),([1,2], 5),([2],10),([],3)]
val _ = print ("p2 = " ^ (IntPoly.toString p2) ^ "\n")

val _ = print ("-p2 = " ^(IntPoly.toString (IntPoly.minus p2)) ^ "\n")

val _ = print ("p1+p2 = " ^(IntPoly.toString (IntPoly.plus (p1,p2)))^ "\n")

val _ = print ("p1*p2 = " ^(IntPoly.toString (IntPoly.times (p1,p2)))^ "\n")

val env = (fn x => case x of 1 => 0 | _ => 1)

val e = IntPoly.eval env p2
val  _ = print ("p2[0,1]= " ^ (Int.toString e) ^ "\n")

val q = IntPoly.fromList [([1,2],3),([],1)]
val _ = print ("q = " ^(IntPoly.toString q) ^ "\n")

val _ = print ("q[p1,p2] = " ^ (IntPoly.toString (IntPoly.apply q [p1,p2])) ^ "\n")

val r1 = IntPoly.fromList [([1],2),([],1)]
val _ = print ("r1 = " ^(IntPoly.toString r1) ^ "\n")

val r2 = IntPoly.fromList [([1],3),([],3)]
val _ = print ("r2 = " ^(IntPoly.toString r2) ^ "\n")

val q1 = IntPoly.fromList [([3],1)]
val _ = print ("q1 = " ^(IntPoly.toString q1) ^ "\n")

val pp0 = IntPolyPoly.one
val _ = print ("pp0 = " ^(IntPolyPoly.toString pp0) ^ "\n")

val pp1 = IntPolyPoly.fromList [([1],r1),([2],r2)]
val _ = print ("pp1 = r1*a1+r2*a2 = " ^(IntPolyPoly.toString pp1) ^ "\n")

val pp2 = IntPolyPoly.fromList [([1],r1),([1],r2)]
val _ = print ("pp2 = r1*x1+r2*x1 = " ^(IntPolyPoly.toString pp2) ^ "\n")

val _ = print ("r1*r2 = " ^(IntPoly.toString (IntPoly.times (r1,r2)))^ "\n")

val pp3 = IntPolyPoly.times (IntPolyPoly.fromList [([1],r1)], 
			     IntPolyPoly.fromList [([1],r2)])

val _ = print ("pp3 = (r1*x1)**(r2*x1) = " 
	       ^ (IntPolyPoly.toString pp3) ^ "\n")

val pp4 = IntPolyPoly.minus pp3

val _ = print ("pp4 = - pp3 = " ^ (IntPolyPoly.toString pp4) ^ "\n")

val _ = print ("pp4 - pp4 = " ^ (IntPolyPoly.toString 
				    (IntPolyPoly.plus (pp4, IntPolyPoly.minus pp4)))
	       ^ "\n")


val qq = IntPolyPoly.fromList [([2],r1),([1],r2)]

val _ = print ("qq = " ^ (IntPolyPoly.toString qq) ^ "\n")

val _ = print  ("qq[pp2,pp2]=" ^ (IntPolyPoly.toString
				      (IntPolyPoly.apply qq [pp2,pp2])) ^ "\n")
***)

(* Vectors of length 2 whose entries are integer polynomials *)
structure IntPolyVec = Ring2VecFn (structure Dimension = struct val dimension = 2 end
                                   structure Ring = IntPolyRing) : VEC

(* Vectors of length 3 whose entries are integer polynomials *)
structure IntPolyVec3 = Ring2VecFn (structure Dimension = struct val dimension = 3 end
                                    structure Ring = IntPolyRing) : VEC

(* Vectors of length 2 whose entries are polynomials with
   integer-polynomial coefficients *)
structure IntPolyPolyVec = Ring2VecFn (structure Dimension = struct val dimension = 2 end
                                       structure Ring = IntPolyPolyRing) : VEC

(* The same, of length 3 *)
structure IntPolyPolyVec3 = Ring2VecFn (structure Dimension = struct val dimension = 3 end
                                        structure Ring = IntPolyPolyRing) : VEC

(* 2x2 matrices whose entries are integer polynomials *)
structure IntPolyMatrix = Ring2MatrixFn (structure Dimension = struct val dimension = 2 end
                                         structure Ring = IntPolyRing) : MATRIX

(* 3x3 matrices whose entries are integer polynomials *)
structure IntPolyMatrix3 = Ring2MatrixFn (structure Dimension = struct val dimension = 3 end
                                          structure Ring = IntPolyRing) : MATRIX

(* 2x2 matrices whose entries are polynomials with integer-polynomial
   coefficients *)
structure IntPolyPolyMatrix = Ring2MatrixFn (structure Dimension = struct val dimension = 2 end
                                             structure Ring = IntPolyPolyRing) : MATRIX

(* The same, 3x3 *)
structure IntPolyPolyMatrix3 = Ring2MatrixFn (structure Dimension = struct val dimension = 3 end
                                              structure Ring = IntPolyPolyRing) : MATRIX

