(* canon.sml *)
structure Canon : CANON = struct

  open Tree

  (* A statement sequence stored as a snoc-list: the last statement of the
     sequence is held by the outermost [Snoc], so adding one statement to
     the end is O(1). *)
  datatype canon_stm =
      Empty
    | Snoc of canon_stm * stm

  (* [append xs ys] is the sequence xs followed by ys.  It rebuilds ys on top
     of xs, so its cost is linear in the length of ys only; callers should
     keep the second argument as short as possible. *)
  fun append xs Empty = xs
    | append xs (Snoc(ys, y)) = Snoc(append xs ys, y)

  (* The statements of a sequence, last statement first. *)
  fun toListRev Empty = []
    | toListRev (Snoc(xs, x)) = x :: toListRev xs

  (* The statements of a sequence in execution order (first to last).
     Walking the snoc-list from its end while consing onto an accumulator
     produces first-to-last order directly, in a single tail-recursive pass
     with no intermediate reversed list. *)
  fun toList xs =
    let fun go acc Empty = acc
          | go acc (Snoc(rest, x)) = go (x :: acc) rest
    in
      go [] xs
    end

  (* A canonicalized expression: statements to run first, followed by an
     expression that yields the result. *)
  type canon_exp = canon_stm * exp

  (* Expressions treated as unaffected by running other statements, so their
     value never needs to be saved in a temporary. *)
  fun isConstant (Const _) = true
    | isConstant (Name _) = true
    | isConstant _ = false

  (* [extend (cs, ce) ss] schedules the statements ss after cs while keeping
     the value ce had *before* ss ran.  When ss is empty nothing changes.
     Otherwise, unless ce is a constant, ce is first moved into a fresh
     temporary, because ss may assign temporaries, parameters or memory
     (see the Move cases of canon_stm) that ce reads. *)
  fun extend ce Empty = ce
    | extend (cs, ce) ss =
      if isConstant ce then
        (append cs ss, ce)
      else
        let val t = Temp.fresh ()
        in
          (append (Snoc(cs, Move(Temp t, ce))) ss,
           Temp(t))
        end

  (* [join ce1 (s2, e2)] evaluates ce1 and then s2.  It returns the combined
     statements together with the pair (e1', e2), where e1' denotes ce1's
     value as it was before s2 executed (preserved by [extend]). *)
  fun join (ce1 : canon_exp) ((s2, e2) : canon_stm * 'a)
    : canon_stm * (exp * 'a) =
    let val (s1', e1') = extend ce1 s2
    in
      (s1', (e1', e2))
    end

  (* Combine a list of canonicalized expressions, evaluated left to right,
     into one statement sequence followed by the list of their result
     expressions. *)
  fun joinList [] = (Empty, [])
    | joinList (ce :: ces) =
      let val rest = joinList ces
          val (ss, (e', es')) = join ce rest
      in
        (ss, e' :: es')
      end

  (* Apply [f] to the result expression, leaving the statements alone. *)
  fun lift f (s, e) = (s, f e)

  (* Join two canonicalized operands (left one first) and build a result
     from their two values with [f]. *)
  fun lift2 f ce1 ce2 =
    lift f (join ce1 ce2)

  (* Canonicalize an expression: hoist every ESeq out into the statement
     part.  A Call may remain at the very top of the returned expression;
     every Call nested inside an operand is spilled to a temporary by
     [canon_exp_no_top_call]. *)
  fun canon_exp (e : exp) =
    case e of
      Const _ => (Empty, e)
    | Name _ => (Empty, e)
    | Temp _ => (Empty, e)
    | Param _ => (Empty, e)
    | BinOp(bo, e1, e2) =>
        let val ce1 = canon_exp_no_top_call e1
            val ce2 = canon_exp_no_top_call e2
        in
           lift2 (fn (x1, x2) => BinOp(bo, x1, x2)) ce1 ce2
        end
    | Mem(e) =>
        let val ce = canon_exp_no_top_call e
        in
           lift (fn x => Mem(x)) ce
        end
    | Call(f, es) =>
        (* The callee is evaluated before the arguments. *)
        let val ce = canon_exp_no_top_call f
            val ses = joinList (map canon_exp_no_top_call es)
        in
           lift2 (fn (x1, x2) => Call(x1, x2)) ce ses
        end
    | ESeq(s, e) =>
        let val cs = canon_stm (Seq s)
            val (cse, e') = canon_exp e
        in
           (append cs cse, e')
        end

  (* Like [canon_exp], but a Call at the top of the result is moved into a
     fresh temporary.  After canonicalization, calls therefore survive only
     as the immediate source of a Move. *)
  and canon_exp_no_top_call (e: exp) =
    let val (ss, e') = canon_exp e
    in
      case e' of
        Call _ =>
          let val t = Temp.fresh ()
          in
            (Snoc(ss, Move(Temp(t), e')), Temp t)
          end
      | _ => (ss, e')
    end

  (* Canonicalize a statement into a flat sequence without Seq or ESeq.
     Raises [Fail] for a Move whose destination is not MEM, TEMP, PARAM or
     ESEQ. *)
  and canon_stm (s: stm) =
    case s of
      Move(Mem(a), e2) =>
        (* The address is evaluated before the stored value. *)
        let val ce1 = canon_exp_no_top_call a
            val ce2 = canon_exp(e2)
            val (ss, s') = lift2 (fn (x, y) => Move(Mem(x), y)) ce1 ce2
        in
          Snoc(ss, s')
        end
    | Move(Temp(t), e2) =>
        let val ce2 = canon_exp(e2)
            val (ss, s') = lift (fn y => Move(Temp(t), y)) ce2
        in
          Snoc(ss, s')
        end
    | Move(Param(i), e2) =>
        let val ce2 = canon_exp(e2)
            val (ss, s') = lift (fn y => Move(Param(i), y)) ce2
        in
          Snoc(ss, s')
        end
    | Move(ESeq(s1, e1), e2) =>
        (* Run the ESeq's statements, then move into its result location. *)
        let val cs1 = canon_stm(Seq(s1))
            val cs2 = canon_stm(Move(e1, e2))
        in
          append cs1 cs2
        end
    | Move(_, _) =>
        raise (Fail "MOVE must have MEM, TEMP, PARAM or ESEQ as left operand.")
    | Jump(e, ls) =>
        let val ce = canon_exp_no_top_call e
            val (ss, s') = lift (fn x => Jump(x, ls)) ce
        in
          Snoc(ss, s')
        end
    | CJump(ro, e1, e2, l1, l2) =>
        let val ce1 = canon_exp_no_top_call e1
            val ce2 = canon_exp_no_top_call e2
            val (ss, s') = lift2 (fn (x, y) => CJump(ro, x, y, l1, l2)) ce1 ce2
        in
          Snoc(ss, s')
        end
    | Seq stms =>
        (* Fold left to right, appending each statement's canonical form to
           the accumulated prefix.  [append] only walks its second argument,
           so the total cost is linear.  Appending the canonized *remainder*
           of the list instead is quadratic in the length of the sequence.
           Statements are still canonicalized, and fresh temporaries drawn,
           in source order; an empty Seq yields Empty. *)
        List.foldl (fn (st, acc) => append acc (canon_stm st)) Empty stms
    | Label(l) =>
        Snoc(Empty, Label(l))

  (* Canonicalize every statement of a function body and concatenate the
     results in order; the other fields are copied unchanged. *)
  fun canon_func (f: func) : func =
    { name = #name f,
      nparams = #nparams f,
      body = List.concat (map (toList o canon_stm) (#body f)),
      ret = #ret f }

  (* Canonicalize every function of a program. *)
  fun canonize (p: prg) : prg =
    { funcs = map canon_func (#funcs p) }
end