(* Blame view: src/SML/canon.sml (3.88 KB), committed by Ulrich Schoepp.
   The web page's line-number gutter (1-155) is not part of the source. *)

(* Canonicalisation of Tree IR.
 *
 * canon_stm turns a statement into an equivalent sequence of statements that
 * contains no Seq and no ESeq.  canon_exp turns an expression into a
 * statement prefix plus an ESeq-free expression.  Every Call occurring as a
 * proper subexpression is hoisted out and its result stored in a fresh
 * temporary, so a Call can remain only at the top of an expression.
 *
 * Evaluation order is preserved: whenever the statements of a later operand
 * are placed after an earlier operand, the earlier operand's value is first
 * saved in a fresh temporary, unless it is a Const or Name. *)
structure Canon : CANON = struct

  open Tree

  (* A statement sequence kept as a snoc list, so adding a statement at the
   * end is O(1).  Empty is the empty sequence; Snoc(ss, s) is ss followed
   * by s. *)
  datatype canon_stm =
      Empty
    | Snoc of canon_stm * stm

  (* append xs ys: the sequence xs followed by ys.
   * Cost is linear in the length of ys; xs is shared, not copied. *)
  fun append xs Empty = xs
    | append xs (Snoc(ys, y)) = Snoc(append xs ys, y)

  (* The statements of a sequence, last statement first. *)
  fun toListRev Empty = []
    | toListRev (Snoc(xs, x)) = x :: toListRev xs

  (* The statements of a sequence in execution order.
   * Single tail-recursive walk with an accumulator: peeling statements off
   * the end of the snoc list and consing them onto the front of acc yields
   * the list in order, with no separate reversal pass. *)
  fun toList xs =
    let fun go (Empty, acc) = acc
          | go (Snoc(ys, y), acc) = go (ys, y :: acc)
    in
      go (xs, [])
    end

  (* A canonical expression: statements to run first, then an expression
   * to evaluate. *)
  type canon_exp = canon_stm * exp

  (* Const and Name are treated as values that no statement can change.
   * Every other expression (Temp, Param, Mem, ...) is conservatively
   * assumed to be possibly affected by later statements. *)
  fun isConstant (Const _) = true
    | isConstant (Name _) = true
    | isConstant _ = false

  (* extend (cs, e) ss: run cs, then ss, and yield the value e had *before*
   * ss ran.  With an empty ss nothing needs saving.  A Const or Name can be
   * used after ss unchanged.  Any other e is first moved into a fresh
   * temporary, which is then used as the result. *)
  fun extend ce Empty = ce
    | extend (cs, ce) ss =
      if isConstant ce then
        (append cs ss, ce)
      else
        let val t = Temp.fresh ()
        in
          (append (Snoc(cs, Move(Temp t, ce))) ss, Temp t)
        end

  (* join ce1 (s2, x2): place the statements of ce1 before s2, and pair the
   * (possibly saved) expression of ce1 with x2. *)
  fun join (ce1 : canon_exp) ((s2, e2) : canon_stm * 'a)
    : canon_stm * (exp * 'a) =
    let val (s1', e1') = extend ce1 s2
    in
      (s1', (e1', e2))
    end

  (* Combine a list of canonical expressions in left-to-right order: the
   * statements of each element run before those of the elements after it.
   * The tail is joined first, so temporaries for later elements are
   * allocated before those for earlier ones. *)
  fun joinList [] = (Empty, [])
    | joinList (ce :: ces) =
      let val rest = joinList ces
          val (ss, (e, es)) = join ce rest
      in
        (ss, e :: es)
      end

  (* Apply f to the expression part of a canonical expression. *)
  fun lift f (s, e) = (s, f e)

  (* Join two canonical operands (first before second) and combine their
   * expressions with f. *)
  fun lift2 f ce1 ce2 =
    lift f (join ce1 ce2)

  (* canon_exp e: a statement prefix and an expression that, run in that
   * order, compute the same value as e.  The resulting expression contains
   * no ESeq, and a Call may occur in it only at the top. *)
  fun canon_exp (e : exp) =
    case e of
      Const _ => (Empty, e)
    | Name _ => (Empty, e)
    | Temp _ => (Empty, e)
    | Param _ => (Empty, e)
    | BinOp(bo, e1, e2) =>
        let val ce1 = canon_exp_no_top_call e1
            val ce2 = canon_exp_no_top_call e2
        in
          lift2 (fn (x1, x2) => BinOp(bo, x1, x2)) ce1 ce2
        end
    | Mem(a) =>
        lift (fn x => Mem(x)) (canon_exp_no_top_call a)
    | Call(f, args) =>
        let val cf = canon_exp_no_top_call f
            val cargs = joinList (map canon_exp_no_top_call args)
        in
          lift2 (fn (x1, x2) => Call(x1, x2)) cf cargs
        end
    | ESeq(stms, body) =>
        let val cs = canon_stm (Seq stms)
            val (cbody, body') = canon_exp body
        in
          (append cs cbody, body')
        end

  (* Like canon_exp, but if the resulting expression is a Call, its result
   * is moved into a fresh temporary, so the returned expression is never a
   * Call.  Used for every proper subexpression. *)
  and canon_exp_no_top_call (e : exp) =
    let val (ss, e') = canon_exp e
    in
      case e' of
        Call _ =>
          let val t = Temp.fresh ()
          in
            (Snoc(ss, Move(Temp t, e')), Temp t)
          end
      | _ => (ss, e')
    end

  (* canon_stm s: an equivalent sequence of statements containing no Seq and
   * no ESeq.  The left operand of every resulting Move is a Mem, Temp or
   * Param; any other left operand raises Fail. *)
  and canon_stm (s : stm) =
    case s of
      Move(Mem(a), e2) =>
        let val ce1 = canon_exp_no_top_call a
            val ce2 = canon_exp e2
            val (ss, s') = lift2 (fn (x, y) => Move(Mem(x), y)) ce1 ce2
        in
          Snoc(ss, s')
        end
    | Move(Temp(t), e2) =>
        let val (ss, s') = lift (fn y => Move(Temp(t), y)) (canon_exp e2)
        in
          Snoc(ss, s')
        end
    | Move(Param(i), e2) =>
        let val (ss, s') = lift (fn y => Move(Param(i), y)) (canon_exp e2)
        in
          Snoc(ss, s')
        end
    | Move(ESeq(s1, e1), e2) =>
        (* The statements s1 are placed before the whole move, including
         * before any statements needed to evaluate e2. *)
        let val cs1 = canon_stm (Seq s1)
            val cs2 = canon_stm (Move(e1, e2))
        in
          append cs1 cs2
        end
    | Move _ =>
        raise (Fail "MOVE must have MEM, TEMP, PARAM or ESEQ as left operand.")
    | Jump(e, ls) =>
        let val (ss, s') = lift (fn x => Jump(x, ls)) (canon_exp_no_top_call e)
        in
          Snoc(ss, s')
        end
    | CJump(ro, e1, e2, l1, l2) =>
        let val ce1 = canon_exp_no_top_call e1
            val ce2 = canon_exp_no_top_call e2
            val (ss, s') = lift2 (fn (x, y) => CJump(ro, x, y, l1, l2)) ce1 ce2
        in
          Snoc(ss, s')
        end
    | Seq(stms) =>
        (* Left fold: statements are canonicalised first-to-last, and each
         * append walks only the sequence just produced.  The total cost is
         * linear in the output size.  Appending onto an already
         * canonicalised tail instead would re-walk that whole tail for
         * every element, which is quadratic in the length of stms. *)
        List.foldl (fn (st, acc) => append acc (canon_stm st)) Empty stms
    | Label(l) =>
        Snoc(Empty, Label(l))

  (* Canonicalise each statement of a function body and concatenate the
   * results in order; name, nparams and ret are copied unchanged. *)
  fun canon_func (f : func) : func =
    { name = #name f,
      nparams = #nparams f,
      body = List.concat (map (toList o canon_stm) (#body f)),
      ret = #ret f }

  (* Canonicalise every function of a program. *)
  fun canonize (p : prg) : prg =
    { funcs = map canon_func (#funcs p) }
end