diff --git a/dev.cm b/dev.cm new file mode 100644 index 0000000..f395964 --- /dev/null +++ b/dev.cm @@ -0,0 +1,16 @@ +library (0.1.0) + library (lib.cm) + source (-) +is + $BUCHAREST-ML/sml-test/lib.cm + + lib.cm + + test ( + all-suites.sml + frame-suite.sml + ) + + test/test-cases ( + factorial/factorial.sml + ) diff --git a/lib.cm b/lib.cm index 657e0d4..5ee4fa4 100644 --- a/lib.cm +++ b/lib.cm @@ -1,25 +1,41 @@ library (0.1.0) + library ($BUCHAREST-ML/sml-foundation/lib.cm) source (-) is - $/basis.cm $/smlnj-lib.cm - src/attr.sml - src/class.sml - src/class-name.sml - src/const.sml - src/const-pool.sig - src/const-pool.sml - src/descriptor.sml - src/field.sml - src/instr.sml - src/labeled-instr.sml - src/main.sml - src/member.fun - src/method.sml - src/method-handle.sml - src/prim.sml - src/text.sig - src/text.sml - src/util.sig - src/util.sml + $BUCHAREST-ML/sml-foundation/lib.cm + + src ( + attr.sml + class.sml + class-name.sml + const.sml + const-pool.sig + const-pool.sml + descriptor.sml + field.sml + instr.sml + java.sml + labeled-instr.sml + main.sml + member.fun + method.sml + method-handle.sml + prim.sml + text.sig + text.sml + util.sig + util.sml + + extensions/list.sig + extensions/list.sml + extensions/word8-vector.sml + + stack-map/frame.sml + stack-map/stack-lang.sml + stack-map/stack-map.sml + stack-map/verification-type.sml + stack-map/verifier.sig + stack-map/verifier.sml + ) diff --git a/lib/core/lib.cm b/lib/core/lib.cm new file mode 100644 index 0000000..f779cce --- /dev/null +++ b/lib/core/lib.cm @@ -0,0 +1,8 @@ +library (0.1.0) + signature LIST + structure List +is + $/basis.cm + + src/list.sig + src/list.sml diff --git a/lib/core/src/list.sig b/lib/core/src/list.sig new file mode 100644 index 0000000..f5de07a --- /dev/null +++ b/lib/core/src/list.sig @@ -0,0 +1,8 @@ +signature LIST = + sig + include LIST + + val countWhere : ('a -> bool) -> 'a list -> int + + val foldl : 'a list -> { seed : 'b, step : 'a * 'b -> 'b } -> 'b + end diff --git 
a/lib/core/src/list.sml b/lib/core/src/list.sml new file mode 100644 index 0000000..e29f872 --- /dev/null +++ b/lib/core/src/list.sml @@ -0,0 +1,31 @@ +structure List : LIST = (* Basis List plus record-argument folds and a few list helpers. *) + struct + open List + + fun foldl list { seed, step } = List.foldl step seed list + fun foldr list { seed, step } = List.foldr step seed list (* NOTE(review): list.sig re-specifies only foldl; this record-style foldr may not match the foldr inherited via `include LIST` - confirm. *) + + fun takeWhile _ [] = [] + | takeWhile p (x :: xs) = if p x then x :: takeWhile p xs else [] + + fun takeUntil p = takeWhile (not o p) + + fun countWhere predicate list = (* Number of elements of `list` for which `predicate` holds. *) + List.foldl (fn (x, n) => if predicate x then n + 1 else n) 0 list + + fun bound xs ys a b = (* Consumes xs and ys in lockstep: returns a when ys runs out (first or together with xs), b when only xs runs out. *) + case (xs, ys) of + (_, []) => a + | ([], _) => b + | (_ :: xs, _ :: ys) => bound xs ys a b + + (** + * Returns the list with more elements. + *) + fun max a b = bound a b a b + + (** + * Returns the list with fewer elements. + *) + fun min a b = bound a b b a + end diff --git a/lib/core/src/other.sml b/lib/core/src/other.sml new file mode 100644 index 0000000..58f79b0 --- /dev/null +++ b/lib/core/src/other.sml @@ -0,0 +1,11 @@ +signature FOLDABLE = + sig + type 'a t + type ('a, 'b) arrow = 'a -> 'b + type 'a monoid = { + zero : 'a, + plus : 'a * 'a -> 'a + } + + val foldMap : 'm monoid -> ('a -> 'm) -> 'a t -> 'm + end diff --git a/log.md b/log.md new file mode 100644 index 0000000..8479d33 --- /dev/null +++ b/log.md @@ -0,0 +1,175 @@ +# Work Log + +## 2024-05-20 11:19:12 + +https://github.com/GaloisInc/jvm-verifier + +Continue here: + +``` +- CM.make "dev.cm"; Factorial.main (); +[scanning dev.cm] +[scanning $BUCHAREST-ML/sml-test/lib.cm] +[scanning $BUCHAREST-ML/sml-foundation/lib.cm] +[scanning (dev.cm):lib.cm] +[New bindings added.] +val it = true : bool + +uncaught exception Fail [Fail: not implemented: leastUpperBound: Reference =/= Integer] + raised at: src/stack-map/verification-type.sml:47.19-47.98 +``` + +## 2023-04-24 09:56:20 + +> Each stack map frame described in the entries table relies on the previous +> frame for some of its semantics.
The first stack map frame of a method is +> implicit, and computed from the method descriptor by the type checker +> (§4.10.1.6). The stack_map_frame structure at entries[0] therefore describes +> the second stack map frame of the method. + +— JVMS20, §4.7.4 + +Okay... so I'll have to generate all the stack frames and just after that drop +the ones that are not required by the specification: + +> It is illegal to have code after an **unconditional branch** without a stack +> map frame being provided for it. + +— JVMS20, §4.10.1.6 + +> • Conditional branch: ifeq, ifne, iflt, ifle, ifgt, ifge, ifnull, ifnonnull, if_icmpeq, if_icmpne, if_icmplt, if_icmple, if_icmpgt if_icmpge, if_acmpeq, if_acmpne. +> • Compound conditional branch: tableswitch, lookupswitch. +> • **Unconditional branch**: goto, goto_w, jsr, jsr_w, ret. + +— JVMS20, §2.11.7 + +Steps: + + 1. Build a control-flow graph of the instructions + 2. Basic blocks + +Could we generate the StackMap frames without building the CFG? And only use +a CFG when we want to keep just the required frames? + +--- + +From: From Stack Maps to Software Certificates [slides], slide 6 + +> [Java bytecode verification] is formalized as a data flow problem. + +--- + +Slide 27: + +> BCV typing with interfaces +> Two kind of reference types : classes and interfaces. +> Interfaces introduces a form of intersection types. that must be represented in the type hierarchy. +> The byte code verifier opts for an alternative solution : +> Treat interfaces as java.lang.Object and defer type checking to run-time. +> Type checking rule : +> isJavaAssignable(class(_,_), class(To, L)) :- loadedClass(To, L, ToClass), classIsInterface(ToClass). 
+ +## Efficient Bytecode Verification and Compilation in a Virtual Machine + +Figure 2.1, page 5 + +``` +todo ← true +while todo = true do + todo ← false + for all i in all instructions of a method do + if i was changed then + todo ← true + check whether stack and local variable types match definition of i + calculate new state after i + for all s in all successor instructions of i do + if current state for s ≠ new state derived from i then + assume state after i as new entry state for s + mark s as changed + end if + end for + end if + end for +end while +``` + +## 2023-04-22 22:06:59 + +Taken from a draft in igstan.ro: + +--- +title: JVM Bytecode Verification +author: Ionuț G. Stan +--- + +I'll briefly describe the process of bytecode verification that the JVM +performs during the loading of .class files. + +## Motivation + +Why was verification needed? + +## Complications + +There's a hierarchy of verification types that induces a notion of subtyping. +Store and load instructions need to check subtyping conformance and this +requires loading external classes into the system. I'm hoping to avoid this +somehow, as it requires the side-effect of reading files from disk. In +addition, it means that I might need to write a parser for .class files and a +reader for JARs (which are ZIP archives). + +## 2022-04-16 11:37:19 + +> An additional problem with compile-time checking is **version skew**. A user +> may have successfully compiled a class, say PurchaseStockOptions, to be a +> subclass of TradingClass. But the definition of TradingClass might have +> changed since the time the class was compiled in a way that is not compatible +> with pre-existing binaries. + +— JVMS8, §4.10 + +> The intent is that a **stack map frame** must appear at the beginning of each +> **basic block** in a method. The stack map frame specifies the verification +> type of each operand stack entry and of each local variable at the start of +> each basic block.
+ +— JVMS8, §4.10.1 + +### Read + + - Lightweight Bytecode Verification + +## 2022-04-17 14:09:19 + +### Java Bytecode Verification — An Overview + +For every instruction `i`: + +``` +i : in(i) → out(i) +in(i) = lub { out(j) | j predecessor of i } +``` + +i₀ = first instruction + +``` +in(i₀) = (ε, (P₀ ... Pₙ₋₁, ⊤ ... ⊤)) +``` + +Pₖ are the types of the method parameters + +> The dataflow framework presented above requires that the type algebra, +> ordered by the subtyping relation, constitutes a semi-lattice. +— §3.3 + +### Dataflow Analysis [slides] + +> Available expressions is a **forward must** analysis +> +> • **Forward** = Data flow from in to out +> • **Must** = At join points, only keep facts that hold on all paths that are joined + +> Liveness is a **backwards may** analysis +> • To know if a variable is live, we need to look at the future uses of it. We +> propagate facts backwards, from Out to In. +> • Variable is live if it is used on some path diff --git a/readme.md b/readme.md index 8d68c9d..73488aa 100644 --- a/readme.md +++ b/readme.md @@ -1,5 +1,11 @@ # JVM bytecode assembler in Standard ML +## Similar Projects + + - BiteScript: https://github.com/headius/bitescript + +## Running Main + ``` $ sml Standard ML of New Jersey v110.80 [built: Sun Aug 28 21:15:09 2016] @@ -36,3 +42,19 @@ Hello, World! val it = () : unit - ``` + +## Running Tests + +``` +- CM.make "dev.cm"; AllSuites.run (); +[scanning dev.cm] +[scanning $BUCHAREST-ML/sml-test/lib.cm] +[scanning $BUCHAREST-ML/sml-foundation/lib.cm] +[scanning (dev.cm):lib.cm] +[New bindings added.]
+val it = true : bool +[+] 1 < 2 +[-] 1 = 1 — NotEqual +val it = () : unit +- +``` diff --git a/src/attr.sml b/src/attr.sml index 467a4b2..cfeade0 100644 --- a/src/attr.sml +++ b/src/attr.sml @@ -6,7 +6,7 @@ structure ExceptionInfo = structure ConstantValue = struct datatype t = - Integer of Integer.t + | Integer of Integer.t | Long of Long.t | Float of Float.t | Double of Double.t @@ -19,14 +19,14 @@ structure Attr = open Util datatype t = - Custom + | Custom | ConstantValue of ConstantValue.t | Code of { code : LabeledInstr.t list, exceptionTable : ExceptionInfo.t list, attributes : t list } - | StackMapTable + | StackMapTable of StackMap.frame list | Exceptions of ClassName.t list | BootstrapMethods of ConstPool.bootstrap_method list | InnerClasses @@ -48,9 +48,12 @@ structure Attr = | LocalVariableTypeTable | Deprecated - fun compile constPool attr = + fun isStackMapTable attr = + case attr of StackMapTable _ => true | _ => false + + fun compile constPool attr nameAndType = case attr of - Code code => compileCode constPool code + | Code code => compileCode constPool code nameAndType | ConstantValue value => compileConstantValue constPool value | Exceptions exceptions => compileExceptions constPool exceptions | Synthetic => compileSynthetic constPool @@ -58,21 +61,59 @@ structure Attr = | Signature typeSignature => compileSignature constPool typeSignature | SourceFile value => compileSourceFile constPool value | BootstrapMethods methods => compileBootstrapMethods constPool methods + | StackMapTable frames => compileStackMapTable constPool frames | attribute => raise Fail "not implemented" (* https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.7.3 *) - and compileCode constPool { code, exceptionTable, attributes } = + and compileCode constPool { code, exceptionTable, attributes } nameAndType = let fun compileExceptions constPool exceptionTable = (u2 0, constPool) (* TODO: add exceptions *) - fun compileAttributes constPool attributes = - (u2 0, 
constPool) (* TODO: add attributes *) + fun compileAttributes constPool stackMapTable attributes = + let + open List.Op infixr <&> + + val detectStackMapTable = { + seed = false, + step = fn (attr, seen) => seen orelse isStackMapTable attr + } + + val compileAttr = { + seed = { bytes = vec [], length = 0, constPool = constPool }, + step = fn (attr, { bytes, length, constPool }) => + let + val (attrBytes, constPool) = compile constPool attr nameAndType + in + { + bytes = Word8Vector.concat [bytes, attrBytes], + length = length + 1, + constPool = constPool + } + end + } + + val (seenStackMapTable, { bytes, length, constPool }) = + List.stepl (detectStackMapTable <&> compileAttr) attributes + in + if seenStackMapTable + then (Word8Vector.concat [u2 length, bytes], constPool) + else + let + (*) Add the StackMapTable attribute if the caller hasn't + (*) provided one already. + val (attrBytes, constPool) = compile constPool stackMapTable nameAndType + val bytes = Word8Vector.concat [bytes, attrBytes] + in + (Word8Vector.concat [u2 (length + 1), bytes], constPool) + end + end val (attrNameIndex, constPool) = ConstPool.withUtf8 constPool "Code" - val (instrBytes, constPool) = compileInstructions constPool code + (* TODO: generate and add StackMapTable only if version >= 50 *) + val (instrBytes, constPool, stackMapAttr) = compileInstructions constPool code nameAndType val (exceptionBytes, constPool) = compileExceptions constPool exceptionTable - val (attributeBytes, constPool) = compileAttributes constPool attributes + val (attributeBytes, constPool) = compileAttributes constPool stackMapAttr attributes val attributeLength = Word8Vector.length instrBytes + Word8Vector.length exceptionBytes + @@ -88,18 +129,32 @@ structure Attr = (bytes, constPool) end - and compileInstructions constPool code = + and displayVerifierResult result = let - val result = LabeledInstr.compileList constPool code + fun displayStackLangList list = + String.concatWith "; " (List.map StackLang.toString 
list) + in + List.app (fn { instrs, offset } => Console.println (Int.toString offset ^ ": "^ displayStackLangList instrs)) result + end + and compileInstructions constPool code nameAndType = + let + val result = LabeledInstr.compileList constPool code + val maxLocals = #maxLocals result + val stackLang = Verifier.verify (#offsetedInstrs result) + val () = Console.println "---------------------------------------------" + val () = displayVerifierResult stackLang + val stackMapFrames = StackLang.compileCompact (StackLang.interpret stackLang (Option.valOf nameAndType) maxLocals) + (* val () = List.app (Console.println o StackMap.toString) stackMapFrames *) + val stackMapAttr = StackMapTable stackMapFrames val bytes = Word8Vector.concat [ u2 (#maxStack result), - u2 (#maxLocals result), + u2 maxLocals, u4 (Word8Vector.length (#bytes result)), (#bytes result) ] in - (bytes, (#constPool result)) + (bytes, #constPool result, stackMapAttr) end and compileConstantValue constPool value = @@ -107,7 +162,7 @@ structure Attr = val (attrNameIndex, constPool) = ConstPool.withUtf8 constPool "ConstantValue" val (constValueIndex, constPool) = case value of - ConstantValue.Integer value => raise Fail "not implemented" + | ConstantValue.Integer value => raise Fail "not implemented" | ConstantValue.Long value => raise Fail "not implemented" | ConstantValue.Float value => raise Fail "not implemented" | ConstantValue.Double value => raise Fail "not implemented" @@ -216,9 +271,24 @@ structure Attr = (bytes, constPool) end + and compileStackMapTable constPool frames = + let + val (stackMapBytes, constPool) = StackMap.compileFrames constPool frames + val (attrIndex, constPool) = ConstPool.withUtf8 constPool "StackMapTable" + val attributeLength = 2 + Word8Vector.length stackMapBytes + val bytes = Word8Vector.concat [ + u2 attrIndex, + u4 attributeLength, + u2 (List.length frames), + stackMapBytes + ] + in + (bytes, constPool) + end + fun minimumVersion attr = case attr of - Custom => { major 
= 45, minor = 3 } + | Custom => { major = 45, minor = 3 } | SourceFile _ => { major = 45, minor = 3 } | InnerClasses => { major = 45, minor = 3 } | ConstantValue _ => { major = 45, minor = 3 } @@ -237,7 +307,7 @@ structure Attr = | RuntimeVisibleAnnotations => { major = 49, minor = 0 } | RuntimeInvisibleAnnotations => { major = 49, minor = 0 } | LocalVariableTypeTable => { major = 49, minor = 0 } - | StackMapTable => { major = 50, minor = 0 } + | StackMapTable _ => { major = 50, minor = 0 } | BootstrapMethods _ => { major = 51, minor = 0 } | MethodParameters => { major = 52, minor = 0 } | RuntimeVisibleTypeAnnotations => { major = 52, minor = 0 } diff --git a/src/class-name.sml b/src/class-name.sml index ed62c40..f8eef3e 100644 --- a/src/class-name.sml +++ b/src/class-name.sml @@ -4,6 +4,6 @@ structure ClassName = struct type t = Text.t - fun fromParts parts = Text.concatWith "/" parts + fun fromParts parts = String.concatWith "/" parts fun fromString s = s end diff --git a/src/class.sml b/src/class.sml index 22fe803..a6aded8 100644 --- a/src/class.sml +++ b/src/class.sml @@ -3,7 +3,7 @@ structure Class = structure Flag = struct datatype t = - PUBLIC + | PUBLIC | FINAL | SUPER | INTERFACE @@ -14,7 +14,7 @@ structure Class = fun compile flag : Word.word = case flag of - PUBLIC => 0wx0001 + | PUBLIC => 0wx0001 | FINAL => 0wx0010 | SUPER => 0wx0020 | INTERFACE => 0wx0200 @@ -50,7 +50,7 @@ structure Class = val constPool = ConstPool.empty val magic = vec [0wxCA, 0wxFE, 0wxBA, 0wxBE] val minorVersion = u2 0 - val majorVersion = u2 49 + val majorVersion = u2 52 val (thisClassIndex, constPool) = ConstPool.withClass constPool thisClass val (superClassIndex, constPool) = ConstPool.withClass constPool superClass @@ -82,7 +82,7 @@ structure Class = if List.null bootstrapMethods then attributes else Attr.BootstrapMethods bootstrapMethods :: attributes - val (attrsBytes, constPool) = compileMany Attr.compile constPool attributes + val (attrsBytes, constPool) = compileMany (fn 
cp => fn attr => Attr.compile cp attr NONE) constPool attributes val constPoolBytes = ConstPool.compile constPool in Word8Vector.concat [ diff --git a/src/compilable.sig b/src/compilable.sig new file mode 100644 index 0000000..03cd252 --- /dev/null +++ b/src/compilable.sig @@ -0,0 +1,33 @@ +(* Reader Monad *) +signature CONFIGURABLE = + sig + type 'a t + type config + + val from : (config -> 'a) -> 'a t + + val get : key -> + + val run : config -> 'a t -> 'a + end + +structure Configurable = + struct + type 'computation t = int + + fun from f = + raise Fail "not implemented" + + fun run config = + end + +Configurable.from (fn config => + +) + +signature COMPILABLE = + sig + type t + + val compile : ConstPool.t -> t -> (Word8Vector.vector, ConstPool.t) Configurable.t + end diff --git a/src/const-pool.sml b/src/const-pool.sml index 9d1e053..249d3aa 100644 --- a/src/const-pool.sml +++ b/src/const-pool.sml @@ -28,7 +28,7 @@ structure ConstPool :> CONST_POOL = end) datatype entry = - Class of entry_index + | Class of entry_index | String of entry_index | Utf8 of Text.t | Long of Long.t @@ -48,7 +48,7 @@ structure ConstPool :> CONST_POOL = fun ordinal entry = case entry of - Class _ => 0 + | Class _ => 0 | String _ => 1 | Utf8 _ => 2 | Long _ => 3 @@ -65,12 +65,12 @@ structure ConstPool :> CONST_POOL = fun tupleCompare ((a, x), (b, y)) = case Int.compare (a, b) of - EQUAL => Int.compare (x, y) + | EQUAL => Int.compare (x, y) | other => other fun compare operands = case operands of - (Class a, Class b) => Int.compare (a, b) + | (Class a, Class b) => Int.compare (a, b) | (String a, String b) => Int.compare (a, b) | (Utf8 a, Utf8 b) => Text.compare (a, b) | (Long a, Long b) => Long.compare (a, b) @@ -109,7 +109,7 @@ structure ConstPool :> CONST_POOL = fun withEntry (constPool as { counter, entries, bootstrap }) entry = case Map.find (entries, entry) of - SOME entryIndex => (entryIndex, constPool) + | SOME entryIndex => (entryIndex, constPool) | NONE => let val counter = 
counter + 1 @@ -174,7 +174,7 @@ structure ConstPool :> CONST_POOL = open MethodHandle val makeSymbolRef = case kind of - GetField => withSymbolRef Fieldref + | GetField => withSymbolRef Fieldref | GetStatic => withSymbolRef Fieldref | PutField => withSymbolRef Fieldref | PutStatic => withSymbolRef Fieldref @@ -204,7 +204,7 @@ structure ConstPool :> CONST_POOL = let val (index, constPool) = case methodParam of - Const.Integer a => withInteger constPool a + | Const.Integer a => withInteger constPool a | Const.Float a => withFloat constPool a | Const.String a => withString constPool a | Const.Class a => withClass constPool a @@ -226,7 +226,7 @@ structure ConstPool :> CONST_POOL = val entry = { methodRef = methodHandleIndex, arguments = arguments } in case BootstrapMethodsMap.find (bootEntries, entry) of - SOME entryIndex => (entryIndex, constPool) + | SOME entryIndex => (entryIndex, constPool) | NONE => let val newEntries = BootstrapMethodsMap.insert (bootEntries, entry, bootCounter) @@ -268,7 +268,7 @@ structure ConstPool :> CONST_POOL = let val entryBytes = case entry of - Class entryIndex => Word8Vector.concat [vec [0wx7], u2 entryIndex] + | Class entryIndex => Word8Vector.concat [vec [0wx7], u2 entryIndex] | String entryIndex => Word8Vector.concat [vec [0wx8], u2 entryIndex] | Utf8 value => let diff --git a/src/const.sml b/src/const.sml index 1b4c784..9188fe1 100644 --- a/src/const.sml +++ b/src/const.sml @@ -1,7 +1,7 @@ structure Const = struct datatype t = - Integer of Integer.t + | Integer of Integer.t | Float of Float.t | Long of Long.t | Double of Double.t diff --git a/src/descriptor.sml b/src/descriptor.sml index f39b99c..ad12a83 100644 --- a/src/descriptor.sml +++ b/src/descriptor.sml @@ -4,12 +4,12 @@ structure Descriptor = struct datatype t = - Raw of Text.t + | Raw of Text.t | Field of simple | Method of { params : simple list, return : return } and simple = - Bool + | Bool | Byte | Char | Double @@ -21,14 +21,14 @@ structure Descriptor = | Array of 
simple and return = - Void + | Void | Type of simple - fun fromString s = Raw s + fun fromString s = Raw s (* ← parse descriptor *) fun paramsCount descriptor = case descriptor of - Raw d => 1 (* TODO *) + | Raw d => 1 (* TODO *) | Field _ => raise Fail "paramsCount called on field descriptor" | Method { params, ... } => let @@ -43,7 +43,7 @@ structure Descriptor = fun returnCount descriptor = case descriptor of - Method { params, return = Void } => 0 + | Method { params, return = Void } => 0 | Method _ => 1 | Field _ => raise Fail "returnCount called on field descriptor" | Raw d => 0 (* TODO *) @@ -52,7 +52,7 @@ structure Descriptor = let fun simple f = case f of - Byte => "B" + | Byte => "B" | Char => "C" | Double => "D" | Float => "F" @@ -64,11 +64,11 @@ structure Descriptor = | Array elemType => "[" ^ simple elemType fun return r = case r of - Void => "V" + | Void => "V" | Type t => simple t in case descriptor of - Raw d => d + | Raw d => d | Field f => simple f | Method { params, return = r } => "(" ^ String.concat (List.map simple params) ^ ")" ^ return r diff --git a/src/extensions/list.sig b/src/extensions/list.sig new file mode 100644 index 0000000..69462f4 --- /dev/null +++ b/src/extensions/list.sig @@ -0,0 +1,37 @@ +signature LIST = + sig + include LIST + + type ('a, 's) stepper = { + seed : 's, + step : 'a * 's -> 's + } + + val foldMapState : + 'a list + -> { + monoid : 'b -> 'b -> 'b, + step : 'a -> 's -> ('b * 's), + seed : 'b * 's + } + -> ('b * 's) + + val stepl : ('a, 's) stepper -> 'a list -> 's + val stepr : ('a, 's) stepper -> 'a list -> 's + + structure Op : + sig + (** + * Operator for composing fold functions. + *) + val & : + ('a * 'state_1 -> 'state_1) * ('a * 'state_2 -> 'state_2) + -> 'a * ('state_1 * 'state_2) + -> 'state_1 * 'state_2 + + (** + * Operator for composing `stepper`s. 
+ *) + val <&> : ('a, 'b) stepper * ('a, 'c) stepper -> ('a, 'b * 'c) stepper + end + end diff --git a/src/extensions/list.sml b/src/extensions/list.sml new file mode 100644 index 0000000..0678db8 --- /dev/null +++ b/src/extensions/list.sml @@ -0,0 +1,38 @@ +structure List : LIST = + struct + open List + + type ('a, 's) stepper = { + seed : 's, + step : 'a * 's -> 's + } + + fun foldMapState list { monoid, step, seed } = + let + fun fold (elem, (r1, state)) = + let + val (r2, state) = step elem state + in + (monoid r1 r2, state) + end + in + List.foldl fold seed list + end + + fun stepl { step, seed } = List.foldl step seed + fun stepr { step, seed } = List.foldr step seed + + structure Op = + struct + infixr & + + (* https://smlnj-gforge.cs.uchicago.edu/tracker/?group_id=33&atid=215&func=detail&aid=129 *) + fun f & g = + fn (elem, (a, b)) => (f (elem, a), g (elem, b)) + + fun <&> (f : ('a, 's) stepper, g : ('a, 't) stepper) = { + seed = (#seed f, #seed g), + step = #step f & #step g + } + end + end diff --git a/src/extensions/word8-vector.sml b/src/extensions/word8-vector.sml new file mode 100644 index 0000000..8daadd2 --- /dev/null +++ b/src/extensions/word8-vector.sml @@ -0,0 +1,6 @@ +structure Word8Vector = + struct + open Word8Vector + + fun join a b = concat [a, b] + end diff --git a/src/field.sml b/src/field.sml index a4fea04..a31de57 100644 --- a/src/field.sml +++ b/src/field.sml @@ -3,7 +3,7 @@ structure Field = structure Flag = struct datatype t = - PUBLIC + | PUBLIC | PRIVATE | PROTECTED | STATIC @@ -15,7 +15,7 @@ structure Field = fun compile flag : Word.word = case flag of - PUBLIC => 0wx0001 + | PUBLIC => 0wx0001 | PRIVATE => 0wx0002 | PROTECTED => 0wx0004 | STATIC => 0wx0008 diff --git a/src/index.sml b/src/index.sml new file mode 100644 index 0000000..7f723a8 --- /dev/null +++ b/src/index.sml @@ -0,0 +1,4 @@ +structure Index = + struct + type t = int + end diff --git a/src/indexed-instr.sml b/src/indexed-instr.sml new file mode 100644 index 
0000000..34416ec --- /dev/null +++ b/src/indexed-instr.sml @@ -0,0 +1,18 @@ +signature INDEXED_INSTR = (* Abstract view of an Index.t paired with an Instr.t. *) + sig + type t + + val index : t -> Index.t + val instr : t -> Instr.t + end + + +structure IndexedInstr : INDEXED_INSTR = + struct + type t = Index.t * Instr.t (* pair: index first, instruction second *) + + val fromPair = Fn.id + + fun index (i, _) = i (* first component of the pair *) + fun instr (_, i) = i (* second component of the pair *) + end diff --git a/src/instr.sml b/src/instr.sml index 4aa6352..5943f38 100644 --- a/src/instr.sml +++ b/src/instr.sml @@ -1,7 +1,7 @@ structure ArrayType = struct datatype t = - BOOLEAN + | BOOLEAN | CHAR | FLOAT | DOUBLE @@ -12,7 +12,7 @@ structure ArrayType = fun compile t : Word8.word = case t of - BOOLEAN => 0w4 + | BOOLEAN => 0w4 | CHAR => 0w5 | FLOAT => 0w6 | DOUBLE => 0w7 @@ -30,7 +30,7 @@ structure Instr = type index = Word8.word datatype t = - nop (* Constants *) + | nop (* Constants *) | aconst_null | iconst_m1 | iconst_0 @@ -277,9 +277,38 @@ structure Instr = infix +: val op +: = Word8Vector.prepend infix :+ val op :+ = Word8Vector.append + fun storeIndex instr = (* SOME n for the store instructions matched below (n is the operand, or the digit in the _<n> suffix); NONE for every other instruction. *) + case instr of + | istore i => SOME (Word8.toInt i) + | lstore i => SOME (Word8.toInt i) + | fstore i => SOME (Word8.toInt i) + | dstore i => SOME (Word8.toInt i) + | astore i => SOME (Word8.toInt i) + | istore_0 => SOME 0 + | lstore_0 => SOME 0 + | fstore_0 => SOME 0 + | dstore_0 => SOME 0 + | astore_0 => SOME 0 + | istore_1 => SOME 1 + | lstore_1 => SOME 1 + | fstore_1 => SOME 1 + | dstore_1 => SOME 1 + | astore_1 => SOME 1 + | istore_2 => SOME 2 + | lstore_2 => SOME 2 + | fstore_2 => SOME 2 + | dstore_2 => SOME 2 + | astore_2 => SOME 2 + | istore_3 => SOME 3 + | lstore_3 => SOME 3 + | fstore_3 => SOME 3 + | dstore_3 => SOME 3 + | astore_3 => SOME 3 + | _ => NONE + fun compile constPool instr = case instr of - nop => (vec [0wx0], 0, constPool) + | nop => (vec [0wx0], 0, constPool) | aconst_null => (vec [0wx1], 1, constPool) | iconst_m1 => (vec [0wx2], 1, constPool) | iconst_0 => (vec [0wx3], 1, constPool) @@ -316,7 +345,7 @@ structure Instr = end in case
const of - Const.Integer a => ldc ConstPool.withInteger a + | Const.Integer a => ldc ConstPool.withInteger a | Const.Float a => ldc ConstPool.withFloat a | Const.String a => ldc ConstPool.withString a | Const.Class a => ldc ConstPool.withClass a @@ -641,4 +670,210 @@ structure Instr = | breakpoint => (vec [0wxCA], 0, constPool) | impdep1 => (vec [0wxFE], 0, constPool) | impdep2 => (vec [0wxFF], 0, constPool) + + fun toString instr = + case instr of + | nop => "nop" + | aconst_null => "aconst_null" + | iconst_m1 => "iconst_m1" + | iconst_0 => "iconst_0" + | iconst_1 => "iconst_1" + | iconst_2 => "iconst_2" + | iconst_3 => "iconst_3" + | iconst_4 => "iconst_4" + | iconst_5 => "iconst_5" + | lconst_0 => "lconst_0" + | lconst_1 => "lconst_1" + | fconst_0 => "fconst_0" + | fconst_1 => "fconst_1" + | fconst_2 => "fconst_2" + | dconst_0 => "dconst_0" + | dconst_1 => "dconst_1" + | bipush word => "bipush" + | sipush short => "sipush" + | ldc const => "ldc" + | iload index => "iload" + | lload index => "lload" + | fload index => "fload" + | dload index => "dload" + | aload index => "aload" + | iload_0 => "iload_0" + | iload_1 => "iload_1" + | iload_2 => "iload_2" + | iload_3 => "iload_3" + | lload_0 => "lload_0" + | lload_1 => "lload_1" + | lload_2 => "lload_2" + | lload_3 => "lload_3" + | fload_0 => "fload_0" + | fload_1 => "fload_1" + | fload_2 => "fload_2" + | fload_3 => "fload_3" + | dload_0 => "dload_0" + | dload_1 => "dload_1" + | dload_2 => "dload_2" + | dload_3 => "dload_3" + | aload_0 => "aload_0" + | aload_1 => "aload_1" + | aload_2 => "aload_2" + | aload_3 => "aload_3" + | iaload => "iaload" + | laload => "laload" + | faload => "faload" + | daload => "daload" + | aaload => "aaload" + | baload => "baload" + | caload => "caload" + | saload => "saload" + | istore index => "istore" + | lstore index => "lstore" + | fstore index => "fstore" + | dstore index => "dstore" + | astore index => "astore" + | istore_0 => "istore_0" + | istore_1 => "istore_1" + | istore_2 => 
"istore_2" + | istore_3 => "istore_3" + | lstore_0 => "lstore_0" + | lstore_1 => "lstore_1" + | lstore_2 => "lstore_2" + | lstore_3 => "lstore_3" + | fstore_0 => "fstore_0" + | fstore_1 => "fstore_1" + | fstore_2 => "fstore_2" + | fstore_3 => "fstore_3" + | dstore_0 => "dstore_0" + | dstore_1 => "dstore_1" + | dstore_2 => "dstore_2" + | dstore_3 => "dstore_3" + | astore_0 => "astore_0" + | astore_1 => "astore_1" + | astore_2 => "astore_2" + | astore_3 => "astore_3" + | iastore => "iastore" + | lastore => "lastore" + | fastore => "fastore" + | dastore => "dastore" + | aastore => "aastore" + | bastore => "bastore" + | castore => "castore" + | sastore => "sastore" + | pop => "pop" + | pop2 => "pop2" + | dup => "dup" + | dup_x1 => "dup_x1" + | dup_x2 => "dup_x2" + | dup2 => "dup2" + | dup2_x1 => "dup2_x1" + | dup2_x2 => "dup2_x2" + | swap => "swap" + | iadd => "iadd" + | ladd => "ladd" + | fadd => "fadd" + | dadd => "dadd" + | isub => "isub" + | lsub => "lsub" + | fsub => "fsub" + | dsub => "dsub" + | imul => "imul" + | lmul => "lmul" + | fmul => "fmul" + | dmul => "dmul" + | idiv => "idiv" + | ldiv => "ldiv" + | fdiv => "fdiv" + | ddiv => "ddiv" + | irem => "irem" + | lrem => "lrem" + | frem => "frem" + | drem => "drem" + | ineg => "ineg" + | lneg => "lneg" + | fneg => "fneg" + | dneg => "dneg" + | ishl => "ishl" + | lshl => "lshl" + | ishr => "ishr" + | lshr => "lshr" + | iushr => "iushr" + | lushr => "lushr" + | iand => "iand" + | land => "land" + | ior => "ior" + | lor => "lor" + | ixor => "ixor" + | lxor => "lxor" + | iinc (index, inc) => "iinc" + | i2l => "i2l" + | i2f => "i2f" + | i2d => "i2d" + | l2i => "l2i" + | l2f => "l2f" + | l2d => "l2d" + | f2i => "f2i" + | f2l => "f2l" + | f2d => "f2d" + | d2i => "d2i" + | d2l => "d2l" + | d2f => "d2f" + | i2b => "i2b" + | i2c => "i2c" + | i2s => "i2s" + | lcmp => "lcmp" + | fcmpl => "fcmpl" + | fcmpg => "fcmpg" + | dcmpl => "dcmpl" + | dcmpg => "dcmpg" + | ifeq offset => "ifeq" + | ifne offset => "ifne" + | iflt offset 
=> "iflt" + | ifge offset => "ifge" + | ifgt offset => "ifgt" + | ifle offset => "ifle" + | if_icmpeq offset => "if_icmpeq" + | if_icmpne offset => "if_icmpne" + | if_icmplt offset => "if_icmplt" + | if_icmpge offset => "if_icmpge" + | if_icmpgt offset => "if_icmpgt" + | if_icmple offset => "if_icmple" + | if_acmpeq offset => "if_acmpeq" + | if_acmpne offset => "if_acmpne" + | getstatic _ => "getstatic" + | putstatic _ => "putstatic" + | getfield _ => "getfield" + | putfield _ => "putfield" + | invokevirtual _ => "invokevirtual" + | invokespecial _ => "invokespecial" + | invokestatic _ => "invokestatic" + | invokeinterface _ => "invokeinterface" + | invokedynamic _ => "invokedynamic" + | new className => "new" + | newarray _ => "newarray" + | anewarray index => "anewarray" + | arraylength => "arraylength" + | athrow => "athrow" + | checkcast index => "checkcast" + | instanceof index => "instanceof" + | monitorenter => "monitorenter" + | monitorexit => "monitorexit" + | goto offset => "goto" + | jsr offset => "jsr" + | ret index => "ret" + | tableswitch => "tableswitch" + | lookupswitch => "lookupswitch" + | ireturn => "ireturn" + | lreturn => "lreturn" + | freturn => "freturn" + | dreturn => "dreturn" + | areturn => "areturn" + | return => "return" + | wide => "wide" + | multianewarray _ => "multianewarray" + | ifnull offset => "ifnull" + | ifnonnull offset => "ifnonnull" + | goto_w offset => "goto_w" + | jsr_w offset => "jsr_w" + | breakpoint => "breakpoint" + | impdep1 => "impdep1" + | impdep2 => "impdep2" end diff --git a/src/java.sml b/src/java.sml new file mode 100644 index 0000000..283a349 --- /dev/null +++ b/src/java.sml @@ -0,0 +1,46 @@ +structure java = + let + structure D = Descriptor + val class = ClassName.fromString + in + struct + structure lang = + struct + structure Integer = + struct + val toString = { + class = class "java/lang/Integer", + name = "toString", + descriptor = D.Method { + params = [D.Int], + return = D.Type (D.Object (class 
"java/lang/String")) + } + } + end + + structure System = + struct + val out = { + class = class "java/lang/System", + name = "out", + descriptor = D.Field (D.Object (class "java/io/PrintStream")) + } + end + end + + structure io = + struct + structure PrintStream = + struct + val println = { + class = class "java/io/PrintStream", + name = "println", + descriptor = D.Method { + params = [D.Object (class "java/lang/String")], + return = D.Void + } + } + end + end + end + end diff --git a/src/labeled-instr.sml b/src/labeled-instr.sml index a803eb7..82eefa8 100644 --- a/src/labeled-instr.sml +++ b/src/labeled-instr.sml @@ -8,7 +8,7 @@ structure LabeledInstr = type label = string datatype t = - INSTR of Instr.t + | INSTR of Instr.t | LABEL of label | GOTO of { label : label, @@ -221,124 +221,187 @@ structure LabeledInstr = val impdep1 = INSTR Instr.impdep1 val impdep2 = INSTR Instr.impdep2 + fun toString instr = + case instr of + | INSTR instr => "INSTR " ^ Instr.toString instr + | LABEL label => "LABEL " ^ label + | GOTO { label, instr, ... 
} => + "GOTO (" ^ label ^ ", " ^ Instr.toString (instr 0) ^ ")" + structure LabelMap = BinaryMapFn(struct type ord_key = string val compare = String.compare end) - structure State = - struct - type t = { - constPool : ConstPool.t, - stackSize : int, - maxStack : int, - maxLocals : int, - bytes : Word8Vector.vector, - seenLabels : Instr.offset LabelMap.map - } - end - fun compileList constPool instrs = let - fun traverse [] state = state - | traverse (instr :: rest) (state as { offset, constPool, stackSize, maxStack, maxLocals, bytes, seenLabels }) = - case instr of - GOTO { label, instr, byteCount } => let in - case LabelMap.find (seenLabels, label) of - SOME labelOffset => - let - val instr = instr (labelOffset - offset) - val (opcodes, stackDiff, constPool) = Instr.compile constPool instr - in - traverse rest { - offset = offset + Word8Vector.length opcodes, - constPool = constPool, - stackSize = stackSize + stackDiff, - maxStack = Int.max (maxStack, stackSize + stackDiff), - maxLocals = maxLocals, - bytes = Word8Vector.concat [bytes, opcodes], - seenLabels = seenLabels - } - end - | NONE => - let + fun traverseLabel label state rest = + let + val { index, offset, constPool, stackSize, maxStack, ... } = state + val { maxLocals, bytes, seenLabels, offsetedInstrs, ... } = state + in + traverse rest { + index = index, + offset = offset, + constPool = constPool, + stackSize = stackSize, + maxStack = maxStack, + maxLocals = maxLocals, + bytes = bytes, + seenLabels = LabelMap.insert (seenLabels, label, (offset, index)), + offsetedInstrs = offsetedInstrs + } + end + + and traverseInstr instr state rest = + let + val { index, offset, constPool, stackSize, maxStack, ... } = state + val { maxLocals, bytes, seenLabels, offsetedInstrs, ... 
} = state + val (opcodes, stackDiff, constPool) = Instr.compile constPool instr + val storeIndex = Option.getOpt (Instr.storeIndex instr, 0) + 1 + in + traverse rest { + index = index + 1, + offset = offset + Word8Vector.length opcodes, + constPool = constPool, + stackSize = stackSize + stackDiff, + maxStack = Int.max (maxStack, stackSize + stackDiff), + maxLocals = Int.max (maxLocals, storeIndex), + bytes = Word8Vector.concat [bytes, opcodes], + seenLabels = seenLabels, + offsetedInstrs = (offset, instr) :: offsetedInstrs + } + end + + and traverseGoto { label, instr, byteCount } state rest = + let + val { index, offset, constPool, stackSize, maxStack, ... } = state + val { maxLocals, bytes, seenLabels, offsetedInstrs, ... } = state + in + case LabelMap.find (seenLabels, label) of + | SOME (labelOffset, labelIndex) => + let + val offsetedInstr = instr labelIndex + val instr = instr (labelOffset - offset) + val (opcodes, stackDiff, constPool) = Instr.compile constPool instr + in + traverse rest { + index = index + 1, + offset = offset + Word8Vector.length opcodes, + constPool = constPool, + stackSize = stackSize + stackDiff, + maxStack = Int.max (maxStack, stackSize + stackDiff), + maxLocals = maxLocals, + bytes = Word8Vector.concat [bytes, opcodes], + seenLabels = seenLabels, + offsetedInstrs = (offset, offsetedInstr) :: offsetedInstrs + } + end + | NONE => + let + (* + * We don't have a label yet; traverse the rest of the + * instruction stream and then try again, maybe a label + * has been found. + *) + val result = traverse rest { + index = index + 1, + offset = offset + byteCount, + constPool = constPool, + (* + * We don't have an instruction stackDiff here, so we just + * reset these counters and compensate later. 
+ *) + stackSize = 0, + maxStack = 0, + maxLocals = maxLocals, + bytes = Util.vec [], + seenLabels = seenLabels, + offsetedInstrs = [] + } + in + case LabelMap.find (#seenLabels result, label) of + | NONE => raise Fail ("undefined label: " ^ label) + | SOME (labelOffset, labelIndex) => + let + (* + * We're doing a kind of a dirty thing here. We're misusing + * the instruction's offset field by putting the *index* of + * the target instruction. The index as it appears in our + * instruction list, not in the final byte stream. + *) + val offsetedInstr = instr labelIndex + val instr = instr (labelOffset - offset) + val (opcodes, stackDiff, constPool) = + Instr.compile (#constPool result) instr + in + { (* - * We don't have a label yet; traverse the rest of the - * instruction stream and then try again, maybe a label - * has been found. + * These values are only read inside recursive calls, not + * when the function returns, so nobody will look at them, + * which means we can use default values and save some + * computations. 
*) - val result = traverse rest { - offset = offset + byteCount, - constPool = constPool, - stackSize = stackSize, - maxStack = maxStack, - maxLocals = maxLocals, - bytes = Util.vec [], - seenLabels = seenLabels - } - in - case LabelMap.find (#seenLabels result, label) of - NONE => raise Fail ("undefined label: " ^ label) - | SOME labelOffset => - let - val instr = instr (labelOffset - offset) - val (opcodes, stackDiff, constPool) = - Instr.compile (#constPool result) instr - in - { - offset = #offset result, - constPool = constPool, - stackSize = stackSize + stackDiff, - maxStack = Int.max (maxStack, stackSize + stackDiff), - maxLocals = #maxLocals result, - bytes = Word8Vector.concat [bytes, opcodes, #bytes result], - seenLabels = #seenLabels result - } - end - end - end - | LABEL label => - traverse rest { - offset = offset, - constPool = constPool, - stackSize = stackSize, - maxStack = maxStack, - maxLocals = maxLocals, - bytes = bytes, - seenLabels = LabelMap.insert (seenLabels, label, offset) - } - | INSTR instr => - let - val (opcodes, stackDiff, constPool) = Instr.compile constPool instr - in - traverse rest { - offset = offset + Word8Vector.length opcodes, - constPool = constPool, - stackSize = stackSize + stackDiff, - maxStack = Int.max (maxStack, stackSize + stackDiff), - maxLocals = maxLocals, - bytes = Word8Vector.concat [bytes, opcodes], - seenLabels = seenLabels - } - end + index = 0, + offset = 0, + stackSize = 0, - val seed = { - offset = 0, - constPool = constPool, - stackSize = 0, - maxStack = 0, - maxLocals = 10, (* TODO: compute maxLocals *) - bytes = Util.vec [], - seenLabels = LabelMap.empty - } + (* + * The following are values that will be read on return, + * so we have to put the real values. + *) + constPool = constPool, + seenLabels = #seenLabels result, + (* + * Here's where we compensate for the fact that above we + * didn't know the stack diff amount of an instruction. 
+ *) + maxStack = Int.max ( + maxStack, + #maxStack result + stackSize + stackDiff + ), + maxLocals = #maxLocals result, + bytes = Word8Vector.concat [ + bytes, + opcodes, + #bytes result + ], + offsetedInstrs = List.concat [ + #offsetedInstrs result, + [(offset, offsetedInstr)], + offsetedInstrs + ] + } + end + end + end + + and traverse instrs state = + case instrs of + | [] => state + | (LABEL label :: rest) => traverseLabel label state rest + | (INSTR instr :: rest) => traverseInstr instr state rest + | (GOTO goto :: rest) => traverseGoto goto state rest - val result = traverse instrs seed + val { bytes, maxStack, maxLocals, constPool, offsetedInstrs, ... } = + traverse instrs { + index = 0, + offset = 0, + constPool = constPool, + stackSize = 0, + maxStack = 0, + maxLocals = 0, + bytes = Util.vec [], + seenLabels = LabelMap.empty, + offsetedInstrs = [] + } in { - bytes = #bytes result, - maxStack = #maxStack result, - maxLocals = #maxLocals result, - constPool = #constPool result + bytes = bytes, + maxStack = maxStack, + maxLocals = maxLocals, + constPool = constPool, + offsetedInstrs = List.rev offsetedInstrs } end end diff --git a/src/main.sml b/src/main.sml index 7c0fc25..d08c3e3 100644 --- a/src/main.sml +++ b/src/main.sml @@ -1,5 +1,7 @@ structure Main = struct + open Fn.Syntax infix |> + structure Instr = LabeledInstr fun symbol class name descriptor = { @@ -154,16 +156,16 @@ structure Main = code = let open Instr in [ aload_0, arraylength, - iconst_1, + iconst_0, if_icmpne "else", - getstatic (symbol "java/lang/System" "out" "Ljava/io/PrintStream;"), + getstatic java.lang.System.out, ldc (Const.String "T"), - invokevirtual (symbol "java/io/PrintStream" "println" "(Ljava/lang/String;)V"), + invokevirtual java.io.PrintStream.println, goto "return", label "else", - getstatic (symbol "java/lang/System" "out" "Ljava/io/PrintStream;"), + getstatic java.lang.System.out, ldc (Const.String "F"), - invokevirtual (symbol "java/io/PrintStream" "println" 
"(Ljava/lang/String;)V"), + invokevirtual java.io.PrintStream.println, label "return", return ] end @@ -171,9 +173,110 @@ structure Main = ] } - val class = Class.from { + val factorial = Method.from { + name = "main", + accessFlags = [Method.Flag.PUBLIC, Method.Flag.STATIC], + descriptor = Descriptor.Method { + return = Descriptor.Void, + params = [ + Descriptor.Array (Descriptor.Object (ClassName.fromString "java/lang/String")) + ] + }, + attributes = [ + Attr.Code { + exceptionTable = [], + attributes = [], + code = let open Instr in [ + iconst_5, + istore_1, + iconst_1, + istore_2, + label "enter-while", + iload_1, + ifle "exit-while", + iload_2, + iload_1, + imul, + istore_2, + iinc (0w1, ~ 0w1), + goto "enter-while", + label "exit-while", + getstatic java.lang.System.out, + iload_2, + invokestatic java.lang.Integer.toString, + invokevirtual java.io.PrintStream.println, + return + ] end + } + ] + } + + val nestedLoops = Method.from { + name = "main", + accessFlags = [Method.Flag.PUBLIC, Method.Flag.STATIC], + descriptor = Descriptor.Method { + return = Descriptor.Void, + params = [ + Descriptor.Array (Descriptor.Object (ClassName.fromString "java/lang/String")) + ] + }, + attributes = [ + Attr.Code { + exceptionTable = [], + attributes = [ + (* Attr.StackMapTable [ + StackMap.Append { + offsetDelta = 2, + extraLocals = 1, + locals = [VerificationType.Integer] + }, + StackMap.Append { + offsetDelta = 7, + extraLocals = 1, + locals = [VerificationType.Integer] + }, + StackMap.Same { offsetDelta = 23 }, + StackMap.Chop { minusLocals = 1, offsetDelta = 9 } + ] *) + ], + code = let open Instr in [ + iconst_0, + istore_1, + label "goto-2", + iload_1, + bipush 0w10, + if_icmpge "exit", + iconst_0, + istore_2, + label "goto-1", + iload_2, + bipush 0w10, + if_icmpge "iinc", + getstatic java.lang.System.out, + iload_1, + iload_2, + iadd, + invokestatic java.lang.Integer.toString, + invokevirtual java.io.PrintStream.println, + iinc (0w2, 0w1), + goto "goto-1", + label 
"iinc", + iinc (0w1, 0w1), + iload_2, + iconst_3, + iadd, + pop, + goto "goto-2", + label "exit", + return + ] end + } + ] + } + + fun class name = Class.from { accessFlags = [Class.Flag.PUBLIC], - thisClass = ClassName.fromString "Main", + thisClass = ClassName.fromString name, superClass = ClassName.fromString "java/lang/Object", interfaces = [], attributes = [Attr.SourceFile "main.sml"], @@ -188,7 +291,7 @@ structure Main = } ], (* methods = [main, printString, bootstrap] *) - methods = [withBranch] + methods = [nestedLoops] } val trim = @@ -196,9 +299,9 @@ structure Main = string o dropl isSpace o dropr isSpace o full end - fun java classPath className = + fun java { classpath } className = let - val proc = Unix.execute ("/usr/bin/java", ["-cp", classPath, className]) + val proc = Unix.execute ("/usr/bin/java", ["-cp", classpath, className]) val output = TextIO.inputAll (Unix.textInstreamOf proc) in Unix.reap proc @@ -207,13 +310,26 @@ structure Main = fun main () = let + val className = "Main" val workDir = OS.FileSys.getDir () - val bytes = Class.compile class - val f = BinIO.openOut (OS.Path.joinDirFile { dir = workDir, file = "Main.class" }) - val _ = BinIO.output (f, bytes) - val _ = BinIO.closeOut f - val output = java workDir "Main" + val binDir = OS.Path.joinDirFile { dir = workDir, file = "bin" } + val fileName = OS.Path.joinDirFile { dir = binDir, file = className ^ ".class" } + val classFile = BinIO.openOut fileName + val bytes = Class.compile (class className) + val _ = BinIO.output (classFile, bytes) + val _ = BinIO.closeOut classFile + val output = java { classpath = binDir } className in print (output ^ "\n") end + + fun stackMap () = + let + val { offsetedInstrs, maxLocals, ... 
} = Instr.compileList ConstPool.empty (Method.code nestedLoops) + in + offsetedInstrs + |> Verifier.verify + |> (fn instrs => StackLang.interpret instrs (Method.nameAndType nestedLoops) maxLocals) + |> StackLang.compileCompact + end end diff --git a/src/member.fun b/src/member.fun index 5a82e29..4da619c 100644 --- a/src/member.fun +++ b/src/member.fun @@ -25,7 +25,7 @@ functor Member(Flag : sig type t val compile : t -> Word.word end) = fun compileAttrs (attr, (bytes, constPool)) = let - val (attrBytes, constPool) = Attr.compile constPool attr + val (attrBytes, constPool) = Attr.compile constPool attr (SOME { name, descriptor }) in (Word8Vector.concat [bytes, attrBytes], constPool) end @@ -34,7 +34,7 @@ functor Member(Flag : sig type t val compile : t -> Word.word end) = val (attrBytes, constPool) = List.foldl compileAttrs seed attributes val (nameIndex, constPool) = ConstPool.withUtf8 constPool name val (descIndex, constPool) = ConstPool.withUtf8 constPool (Descriptor.compile descriptor) - val methodBytes = Word8Vector.concat [ + val memberBytes = Word8Vector.concat [ u2 (Word.toInt (mask accessFlags)), u2 nameIndex, u2 descIndex, @@ -42,6 +42,6 @@ functor Member(Flag : sig type t val compile : t -> Word.word end) = attrBytes ] in - (methodBytes, constPool) + (memberBytes, constPool) end end diff --git a/src/method-handle.sml b/src/method-handle.sml index 4541f4e..08465d5 100644 --- a/src/method-handle.sml +++ b/src/method-handle.sml @@ -1,7 +1,7 @@ structure MethodHandle = struct datatype t = - GetField + | GetField | GetStatic | PutField | PutStatic @@ -13,7 +13,7 @@ structure MethodHandle = fun value kind = case kind of - GetField => 1 + | GetField => 1 | GetStatic => 2 | PutField => 3 | PutStatic => 4 diff --git a/src/method.sml b/src/method.sml index 5247755..716285c 100644 --- a/src/method.sml +++ b/src/method.sml @@ -3,7 +3,7 @@ structure Method = structure Flag = struct datatype t = - PUBLIC + | PUBLIC | PRIVATE | PROTECTED | STATIC @@ -18,7 +18,7 @@ 
structure Method = fun compile flag : Word.word = case flag of - PUBLIC => 0wx0001 + | PUBLIC => 0wx0001 | PRIVATE => 0wx0002 | PROTECTED => 0wx0004 | STATIC => 0wx0008 @@ -33,4 +33,7 @@ structure Method = end structure M = Member(Flag) open M + + fun code ({ attributes = [Attr.Code { code, ... }], ... } : t) = code + | code _ = raise Fail "bug: method without code attribute (abstract method?)" end diff --git a/src/stack-map/frame.dot b/src/stack-map/frame.dot new file mode 100644 index 0000000..0ac215a --- /dev/null +++ b/src/stack-map/frame.dot @@ -0,0 +1,41 @@ +digraph { + rankdir = LR + + node [ + shape = circle + fontname = courier + ] + + node [group = chop] + Chop1 + Chop2 + Chop3 + + node [group = same] + Same + Full + + node [group = append] + Append1 + Append2 + Append3 + + Same -> Same [label = "α ≣ β"] + Same -> Full [label = "α ≢ β"] + + Same -> Chop1 [label = "(α, ⊤)"] + Chop1 -> Chop2 [label = "(α, ⊤)"] + Chop2 -> Chop3 [label = "(α, ⊤)"] + + Chop1 -> Full [label = "∀ α, β"] + Chop2 -> Full [label = "∀ α, β"] + Chop3 -> Full [label = "∀ α, β"] + + Same -> Append1 [label = "(⊤, β)"] + Append1 -> Append2 [label = "(⊤, β)"] + Append2 -> Append3 [label = "(⊤, β)"] + + Append1 -> Full [label = "∀ α, β"] + Append2 -> Full [label = "∀ α, β"] + Append3 -> Full [label = "∀ α, β"] +} diff --git a/src/stack-map/frame.pdf b/src/stack-map/frame.pdf new file mode 100644 index 0000000..8248c45 Binary files /dev/null and b/src/stack-map/frame.pdf differ diff --git a/src/stack-map/frame.sml b/src/stack-map/frame.sml new file mode 100644 index 0000000..ae6e6e9 --- /dev/null +++ b/src/stack-map/frame.sml @@ -0,0 +1,60 @@ +structure Frame = + struct + datatype diff = + | Same + | Full + | Chop of int + | Append of int + + fun toString diff = + case diff of + | Same => "Same" + | Full => "Full" + | Chop n => "Chop " ^ Int.toString n + | Append n => "Append " ^ Int.toString n + + fun localsDifference xs ys = + let + fun chop n xs ys = + if n = 3 + then + case (xs, ys) of 
+ | ([], []) => Chop n + | _ => Full + else + case (xs, ys) of + | ([], []) => Chop n + | (_ :: xs, []) => chop (n + 1) xs [] + | (_ :: xs, VerificationType.Top :: ys) => chop (n + 1) xs ys + | _ => Full + + fun append n xs ys = + if n = 3 + then + case (xs, ys) of + | ([], []) => Append n + | _ => Full + else + case (xs, ys) of + | ([], []) => Append n + | ([], _ :: ys) => append (n + 1) [] ys + | (VerificationType.Top :: xs, _ :: ys) => append (n + 1) xs ys + | _ => Full + + fun same xs ys = + case (xs, ys) of + | ([], []) => Same + | ([], _ :: ys) => append 1 [] ys + | (_ :: xs, []) => chop 1 xs [] + | (x :: xs, y :: ys) => + if x = y + then same xs ys + else + case (x, y) of + | (VerificationType.Top, _) => append 1 xs ys + | (_, VerificationType.Top) => chop 1 xs ys + | _ => Full + in + same xs ys + end + end diff --git a/src/stack-map/stack-lang.sml b/src/stack-map/stack-lang.sml new file mode 100644 index 0000000..62a5680 --- /dev/null +++ b/src/stack-map/stack-lang.sml @@ -0,0 +1,356 @@ +structure StackLang = + struct + type local_index = int + + type indexed_type = { + index : local_index, + vtype : VerificationType.t + } + + datatype t = + | Push of VerificationType.t + | Pop of VerificationType.t + | Load of local_index * VerificationType.t (* indexed_type *) + | Store of local_index * VerificationType.t (* indexed_type *) + | Local of local_index * VerificationType.t (* indexed_type *) + | Branch of { targetOffset : int, fallsThrough : bool } + + exception StackUnderflow + exception UnassignedLocal + + fun toString t = + case t of + | Push vtype => "Push " ^ VerificationType.toString vtype + | Pop vtype => "Pop " ^ VerificationType.toString vtype + | Load (index, vtype) => "Load ("^ Int.toString index ^", "^ VerificationType.toString vtype ^")" + | Store (index, vtype) => "Store ("^ Int.toString index ^", "^ VerificationType.toString vtype ^")" + | Local (index, vtype) => "Local ("^ Int.toString index ^", "^ VerificationType.toString vtype ^")" + | 
Branch { targetOffset, fallsThrough } => + "Branch { targetOffset = "^ Int.toString targetOffset + ^", fallsThrough = "^ Bool.toString fallsThrough ^" }" + + fun interpret instrs { name, descriptor } maxLocals = + let + fun mergeFrames prev curr = + let + val { stack = prevStack, locals = prevLocals } = prev + val { stack = currStack, locals = currLocals } = curr + in + if List.length prevStack <> List.length currStack + then raise Fail "mergeFrames: different stack lengths" + else + let + val mergedLocals = + ListPair.map (Fn.uncurry VerificationType.leastUpperBound) (prevLocals, currLocals) + in + { + stack = currStack, + locals = mergedLocals + } + end + end + + fun generateFrame instrs state = + let + (* TODO: handle longs and doubles which occupy two slots *) + fun fold (instr, { stack, locals, frameMap, fallsThrough }) = + case instr of + | Push vType => { + stack = vType :: stack, + locals = locals, + frameMap = frameMap, + fallsThrough = true + } + | Pop vType => { + stack = List.tl stack handle Empty => raise StackUnderflow, + locals = locals, + frameMap = frameMap, + fallsThrough = true + } + | Load (index, vType) => { + stack = vType :: stack, + locals = locals, + frameMap = frameMap, + fallsThrough = true + } + | Local (index, vType) => { + stack = stack, + locals = List.update (locals, index, vType) handle Subscript => raise UnassignedLocal, + frameMap = frameMap, + fallsThrough = true + } + | Store (index, vType) => { + stack = List.tl stack handle Empty => raise StackUnderflow, + locals = List.update (locals, index, vType) handle Subscript => raise UnassignedLocal, + frameMap = frameMap, + fallsThrough = true + } + | Branch { targetOffset, fallsThrough } => + let + val frame = { stack = stack, locals = locals } + val mergedFrame = + case IntBinaryMap.find (frameMap, targetOffset) of + NONE => { + offset = NONE, (* TODO: use ~1 (negative 1) *) + frame = frame, + isBranchTarget = true + } + | SOME { offset, frame = prevFrame, isBranchTarget } => { + 
offset = offset, + frame = mergeFrames prevFrame frame, + isBranchTarget = true + } + in + { + stack = stack, + locals = locals, + frameMap = IntBinaryMap.insert (frameMap, targetOffset, mergedFrame), + fallsThrough = fallsThrough + } + end + in + List.foldl fold state instrs + end + + fun eval instrss = + let + fun fold (index, { offset, instrs }, { stack, locals, frameMap, fallsThrough }) = + let + val frame = { stack = stack, locals = locals } + val mergedFrame = + case IntBinaryMap.find (frameMap, index) of + | NONE => { + offset = SOME offset, + frame = frame, + isBranchTarget = false + } + | SOME { offset = NONE, frame = prevFrame, isBranchTarget } => { + offset = SOME offset, + frame = mergeFrames prevFrame frame, + isBranchTarget = isBranchTarget + } + | SOME { offset, frame = prevFrame, isBranchTarget } => { + offset = offset, + frame = mergeFrames prevFrame frame, + isBranchTarget = isBranchTarget + } + val frameMap = IntBinaryMap.insert (frameMap, index, mergedFrame) + in + generateFrame instrs { + stack = if fallsThrough then stack else [], + locals = locals, + frameMap = frameMap, + fallsThrough = true + } + end + + val locals = + let + val args = VerificationType.methodParams descriptor + val nonArgCount = maxLocals - List.length args + val locals = List.tabulate (nonArgCount, fn _ => VerificationType.Top) + in + args @ locals + end + + val seed = { + stack = [], + locals = locals, + frameMap = IntBinaryMap.empty, + fallsThrough = true + } + + val { frameMap, ... } = List.foldli fold seed instrss + in + IntBinaryMap.listItems frameMap + end + + fun unwrapOffset (index, item) = + case item of + | { offset = NONE, ... 
} => raise Fail ("bug: NONE offset at index: " ^ Int.toString index) + | { offset = SOME offset, frame, isBranchTarget } => { + offset = offset, + frame = frame, + isBranchTarget = isBranchTarget + } + in + Console.println ("StackLang: interpreting method: " ^ name ^ ":" ^ Descriptor.compile descriptor); + List.mapi unwrapOffset (eval instrs) + end + + (* TODO: update function *) + fun compile frameSets = (* folds { offset, frames } records into (isBranchTarget, StackMap frame) pairs, returned in input order; only the head of frames is inspected *) + let + fun compile ({ offset, frames }, { prevLocals, compiled, lastOffset }) = + let + val isBranchTarget = List.length frames > 1 + val offsetDelta = offset - lastOffset + in + case List.hd frames of + | { stack = [], locals } => + let + val localsSize = List.length locals + val lastLocalsSize = List.length prevLocals + val localsDiff = lastLocalsSize - localsSize (* positive when this frame has fewer locals than the previous one, negative when it has more *) + val stackMapFrame = + if localsDiff = 0 + then StackMap.Same { offsetDelta = offsetDelta } + else + if localsDiff > 0 (* fewer locals than before: chop; this test was inverted, which produced a negative minusLocals *) + then StackMap.Chop { + offsetDelta = offsetDelta, + minusLocals = localsDiff + } + else StackMap.Append { + offsetDelta = offsetDelta, + extraLocals = ~localsDiff, (* localsDiff is negative here, so negate it to get the number of added locals *) + locals = List.drop (locals, lastLocalsSize) (* the locals beyond the previous frame's length *) + } + in + { + prevLocals = locals, + compiled = (isBranchTarget, stackMapFrame) :: compiled, + lastOffset = offset + } + end + | { stack = [a], locals } => + let + val localsSize = List.length locals + val lastLocalsSize = List.length prevLocals + val localsDiff = lastLocalsSize - localsSize + val stackMapFrame = + if localsDiff = 0 + then StackMap.SameLocals1StackItem { offsetDelta = offsetDelta, stack = a } + else StackMap.Full { + offsetDelta = offsetDelta, + stack = [a], + locals = locals + } + in + { + prevLocals = locals, + compiled = (isBranchTarget, stackMapFrame) :: compiled, + lastOffset = offset + } + end + | { stack, locals } => + let + val stackMapFrame = StackMap.Full { + offsetDelta = offsetDelta, + stack = stack, + locals = locals + } + in + { + prevLocals = locals, + compiled = (isBranchTarget, stackMapFrame) :: compiled, + lastOffset = offset + } + end + end
+ + val state = { + prevLocals = [], + compiled = [], + lastOffset = 0 + } + in + List.rev (#compiled (List.foldl compile state frameSets)) + end + + fun compileCompact frameSets = + case frameSets of + | [] => [] + | { frame = { locals, stack }, offset, isBranchTarget } :: frameSets => + let + fun compile ({ offset, frame, isBranchTarget }, state as { prevLocals, compiled, prevOffset }) = + let + (* val isBranchTarget = List.length frames > 1 *) + val offsetDelta = offset - (if prevOffset = 0 then prevOffset else prevOffset + 1) + in + if not isBranchTarget + then state + else + (* TODO: intersect frames *) + case frame of + { stack = [], locals } => + let + val stackMapFrame = + case Frame.localsDifference prevLocals locals of + Frame.Same => StackMap.Same { offsetDelta = offsetDelta } + | Frame.Full => StackMap.Full { + offsetDelta = offsetDelta, + stack = [], + locals = locals + } + | Frame.Chop n => StackMap.Chop { + offsetDelta = offsetDelta, + minusLocals = n + } + | Frame.Append n => + let + val locals = + case List.drop (locals, List.length locals - n) of + [a, VerificationType.Top, VerificationType.Top] => [a] + | [a, b, VerificationType.Top] => [a, b] + | [a, VerificationType.Top] => [a] + | other => other + in + StackMap.Append { + offsetDelta = offsetDelta, + extraLocals = List.length locals, + locals = locals + } + end + in + { + prevLocals = locals, + compiled = stackMapFrame :: compiled, + prevOffset = offset + } + end + | { stack = [a], locals } => + let + val localsSize = List.length locals + val lastLocalsSize = List.length prevLocals + val localsDiff = lastLocalsSize - localsSize + val stackMapFrame = + if localsDiff = 0 + then StackMap.SameLocals1StackItem { offsetDelta = offsetDelta, stack = a } + else StackMap.Full { + offsetDelta = offsetDelta, + stack = [a], + locals = locals + } + in + { + prevLocals = locals, + compiled = stackMapFrame :: compiled, + prevOffset = offset + } + end + | { stack, locals } => + let + val stackMapFrame = 
StackMap.Full { + offsetDelta = offsetDelta, + stack = stack, + locals = locals + } + in + { + prevLocals = locals, + compiled = stackMapFrame :: compiled, + prevOffset = offset + } + end + end + + val state = { + prevOffset = 0, + prevLocals = locals, + compiled = [] + } + in + List.rev (#compiled (List.foldl compile state frameSets)) + end + end diff --git a/src/stack-map/stack-map.sml b/src/stack-map/stack-map.sml new file mode 100644 index 0000000..a35b42f --- /dev/null +++ b/src/stack-map/stack-map.sml @@ -0,0 +1,146 @@ +(* §4.7.4 *) +structure StackMap = + struct + datatype frame = + (* + * This frame type indicates that the frame has exactly the same local + * variables as the previous frame and that the operand stack is empty. + * + * This entry also handles the `same_frame_extended` case. The distinction + * is based on the `offsetDelta` value. If it's small enough to be + * represented using a `same_frame`, it will use that. Otherwise, it + * will use a `same_frame_extended` with an explicit `offset_delta`. + *) + | Same of { offsetDelta : int } + + (* + * This frame type indicates that the frame has exactly the same local + * variables as the previous frame and that the operand stack has one + * entry. + * + * This entry also handles the `same_locals_1_stack_item_frame_extended` + * case. Similarly to `same_frame`, the distinction is made based on the + * value of the `offsetDelta` field — if it's too large, then we use the + * extended case. + *) + | SameLocals1StackItem of { + offsetDelta : int, + stack: VerificationType.t + } + + (* + * This frame type indicates that the frame has the same local variables + * as the previous frame except that the last k local variables are + * absent, and that the operand stack is empty.
+ *) + | Chop of { + offsetDelta : int, + minusLocals : int + } + + (* + * This frame type indicates that the frame has the same locals as the + * previous frame except that k additional locals are defined, and that + * the operand stack is empty. + *) + | Append of { + offsetDelta : int, + extraLocals : int, + locals : VerificationType.t list + } + + (* + * This frame contains all the stack and locals information, explicitly. + *) + | Full of { + offsetDelta : int, + locals : VerificationType.t list, + stack : VerificationType.t list + } + + (* structure Frame = + struct + type t = frame + + fun toString frame = + case frame of + | Same { offsetDelta } => "Same { "^ Int.toString offsetDelta ^" }" + | + end *) + + fun toString frame = + case frame of + | Same { offsetDelta } => "Same { "^ Int.toString offsetDelta ^" }" + | SameLocals1StackItem { offsetDelta, stack } => + "SameLocals1StackItem { "^ Int.toString offsetDelta ^ ", "^ VerificationType.toString stack ^"}" + | _ => raise Fail "not implemented" + + open Util + + fun compile constPool frame = + let + fun same { offsetDelta } = + if offsetDelta <= 63 + then (u1 offsetDelta, constPool) + else (Word8Vector.prepend (0w251, u2 offsetDelta), constPool) + + fun sameLocals1StackItem { offsetDelta, stack } = + let + val (vtype, constPool) = VerificationType.compile constPool stack + in + if offsetDelta <= 63 + then (Word8Vector.concat [u1 (offsetDelta + 64), vtype], constPool) + else (Word8Vector.concat [u1 247, u2 offsetDelta, vtype], constPool) + end + + fun chop { minusLocals, offsetDelta } = + if minusLocals < 1 orelse minusLocals > 3 + then raise Fail ("chop frame with invalid minusLocals value: " ^ Int.toString minusLocals) + else (Word8Vector.concat [u1 (251 - minusLocals), u2 offsetDelta], constPool) + + fun append { extraLocals, offsetDelta, locals } = + if extraLocals < 1 orelse extraLocals > 3 + then raise Fail ("append frame with invalid extraLocals value: " ^ Int.toString extraLocals) + else + let + val 
(localBytes, constPool) = VerificationType.compileList constPool locals + val bytes = Word8Vector.concat [ + u1 (251 + extraLocals), + u2 offsetDelta, + localBytes + ] + in + (bytes, constPool) + end + + fun full { offsetDelta, locals, stack } = + let + val (localBytes, constPool) = VerificationType.compileList constPool locals + val (stackBytes, constPool) = VerificationType.compileList constPool stack + val bytes = Word8Vector.concat [ + u1 255, + u2 offsetDelta, + u2 (List.length locals), + localBytes, + u2 (List.length stack), + stackBytes + ] + in + (bytes, constPool) + end + in + case frame of + | Same a => same a + | SameLocals1StackItem a => sameLocals1StackItem a + | Chop a => chop a + | Append a => append a + | Full a => full a + end + + fun compileFrames constPool frames = + List.foldMapState frames { + monoid = Word8Vector.join, + step = Fn.swap compile, + seed = (vec [], constPool) + } + end diff --git a/src/stack-map/verification-type.sml b/src/stack-map/verification-type.sml new file mode 100644 index 0000000..8a3f084 --- /dev/null +++ b/src/stack-map/verification-type.sml @@ -0,0 +1,139 @@ +structure VerificationType = + struct + open Util + + (** + * Verification Type Hierarchy + * + * See: JVMS18 / $4.10.1.2 / Verification Type System + * + * ``` + * top + * ____________/\____________ + * / \ + * / \ + * oneWord twoWord + * / | \ / \ + * / | \ / \ + * int float reference long double + * / \ + * / \_____________ + * / \ + * / \ + * uninitialized +------------------+ + * / \ | Java reference | + * / \ | type hierarchy | + * uninitializedThis uninitialized(Offset) +------------------+ + * | + * | + * null + * ``` + *) + datatype t = + | Top + (* | OneWord *) (* See diagram above *) + (* | TwoWord *) + | Integer + | Float + | Long + | Double + | Null + | Array of t + | Object of ClassName.t + | Uninitialized of Instr.offset + | UninitializedThis + | Reference + + fun toString t = + case t of + | Top => "Top" + | Integer => "Integer" + | Float => 
"Float" + | Long => "Long" + | Double => "Double" + | Null => "Null" + | Array _ => "Array" + | Object _ => "Object" + | Uninitialized _ => "Uninitialized" + | UninitializedThis => "UninitializedThis" + | Reference => "Reference" + + (** + * See: Java Bytecode Verification — An Overview, p7 + *) + fun leastUpperBound a b = + case (a, b) of + | (Top, _) => Top + | (_, Top) => Top + | (a, b) => + if a = b + then a + else + (* TODO *) + raise Fail ("not implemented: leastUpperBound: " ^ toString a ^ " =/= " ^ toString b) + + fun isTop Top = true + | isTop _ = false + + fun fromSimple simple = + case simple of + | Descriptor.Bool => Integer + | Descriptor.Byte => Integer + | Descriptor.Char => Integer + | Descriptor.Double => Double + | Descriptor.Float => Float + | Descriptor.Int => Integer + | Descriptor.Long => Long + | Descriptor.Short => Integer + | Descriptor.Object class => Object class + | Descriptor.Array elem => Array (fromSimple elem) + + fun methodReturn descriptor = + case descriptor of + | Descriptor.Method { return = Descriptor.Void, ... } => Top + | Descriptor.Method { return = Descriptor.Type simple, ... } => fromSimple simple + | _ => raise Fail "illegal: descriptor is not a method" + + fun methodParams descriptor = + case descriptor of + | Descriptor.Method { params, ... 
} => List.map fromSimple params + | _ => raise Fail "illegal: descriptor is not a method" + + fun fromDescriptor descriptor = + case descriptor of + | Descriptor.Raw text => raise Fail "not implemented" + | Descriptor.Method { params, return } => raise Fail "not implemented" + | Descriptor.Field simple => fromSimple simple + + fun fromConst const = + case const of + | Const.Integer _ => Integer + | Const.Float _ => Float + | Const.Long _ => Long + | Const.Double _ => Double + | Const.String _ => Object (ClassName.fromParts ["java", "lang", "String"]) + | Const.Class _ => Object (ClassName.fromParts ["java", "lang", "Class"]) + | Const.MethodType _ => Object (ClassName.fromParts ["java", "lang", "invoke", "MethodType"]) + | Const.MethodHandle _ => Object (ClassName.fromParts ["java", "lang", "invoke", "MethodHandle"]) + + fun compile constPool vtype = + case vtype of + | Top => (u1 0, constPool) + | Integer => (u1 1, constPool) + | Float => (u1 2, constPool) + | Long => (u1 4, constPool) + | Double => (u1 3, constPool) + | Null => (u1 5, constPool) + | Array t => raise Fail "not implemented" + | Object t => raise Fail "not implemented" + | Uninitialized offset => raise Fail "not implemented" + | UninitializedThis => (u1 6, constPool) + | Reference => raise Fail "not implemented" + + fun compileList constPool vtypes = + List.foldMapState vtypes { + monoid = Word8Vector.join, + step = Fn.swap compile, + seed = (vec [], constPool) + } + end diff --git a/src/stack-map/verifier.sig b/src/stack-map/verifier.sig new file mode 100644 index 0000000..2436039 --- /dev/null +++ b/src/stack-map/verifier.sig @@ -0,0 +1,4 @@ +signature VERIFIER = + sig + val verify : ('offset * Instr.t) list -> { offset : 'offset, instrs : StackLang.t list } list + end diff --git a/src/stack-map/verifier.sml b/src/stack-map/verifier.sml new file mode 100644 index 0000000..86e5c77 --- /dev/null +++ b/src/stack-map/verifier.sml @@ -0,0 +1,280 @@ +structure Verifier : VERIFIER = + let + open Instr 
StackLang + + (** + * TODO: this function needs to receive the containing methods signature, + * so that we can verify params and return instructions. + * + * For example, `areturn` must verify that the stack contains a reference + * which is a subtype of the declared return type. + * + * > An areturn instruction is type safe iff the enclosing method has a + * > declared return type, ReturnType, that is a reference type, and one + * > can validly pop a type matching ReturnType off the incoming operand + * > stack. + * — JVMS23, p238 + *) + fun verify instrs = + let + fun transition instr = + case instr of + | nop => [] + | aconst_null => raise Fail "not implemented: aconst_null" + | iconst_m1 => [Push VerificationType.Integer] + | iconst_0 => [Push VerificationType.Integer] + | iconst_1 => [Push VerificationType.Integer] + | iconst_2 => [Push VerificationType.Integer] + | iconst_3 => [Push VerificationType.Integer] + | iconst_4 => [Push VerificationType.Integer] + | iconst_5 => [Push VerificationType.Integer] + | lconst_0 => raise Fail "not implemented: lconst_0" + | lconst_1 => raise Fail "not implemented: lconst_1" + | fconst_0 => raise Fail "not implemented: fconst_0" + | fconst_1 => raise Fail "not implemented: fconst_1" + | fconst_2 => raise Fail "not implemented: fconst_2" + | dconst_0 => raise Fail "not implemented: dconst_0" + | dconst_1 => raise Fail "not implemented: dconst_1" + | bipush word => [Push VerificationType.Integer] + | sipush short => raise Fail "not implemented: sipush" + | ldc const => [Push (VerificationType.fromConst const)] + | iload index => raise Fail "not implemented: iload" + | lload index => raise Fail "not implemented: lload" + | fload index => raise Fail "not implemented: fload" + | dload index => raise Fail "not implemented: dload" + | aload index => raise Fail "not implemented: aload" + | iload_0 => [Load (0, VerificationType.Integer)] + | iload_1 => [Load (1, VerificationType.Integer)] + | iload_2 => [Load (2, 
VerificationType.Integer)] + | iload_3 => [Load (3, VerificationType.Integer)] + | lload_0 => raise Fail "not implemented: lload_0" + | lload_1 => raise Fail "not implemented: lload_1" + | lload_2 => raise Fail "not implemented: lload_2" + | lload_3 => raise Fail "not implemented: lload_3" + | fload_0 => raise Fail "not implemented: fload_0" + | fload_1 => raise Fail "not implemented: fload_1" + | fload_2 => raise Fail "not implemented: fload_2" + | fload_3 => raise Fail "not implemented: fload_3" + | dload_0 => raise Fail "not implemented: dload_0" + | dload_1 => raise Fail "not implemented: dload_1" + | dload_2 => raise Fail "not implemented: dload_2" + | dload_3 => raise Fail "not implemented: dload_3" + | aload_0 => [Load (0, VerificationType.Reference)] + | aload_1 => [Load (1, VerificationType.Reference)] + | aload_2 => [Load (2, VerificationType.Reference)] + | aload_3 => [Load (3, VerificationType.Reference)] + | iaload => raise Fail "not implemented: iaload" + | laload => raise Fail "not implemented: laload" + | faload => raise Fail "not implemented: faload" + | daload => raise Fail "not implemented: daload" + | aaload => raise Fail "not implemented: aaload" + | baload => raise Fail "not implemented: baload" + | caload => raise Fail "not implemented: caload" + | saload => raise Fail "not implemented: saload" + | istore index => raise Fail "not implemented: istore" + | lstore index => raise Fail "not implemented: lstore" + | fstore index => raise Fail "not implemented: fstore" + | dstore index => raise Fail "not implemented: dstore" + | astore index => raise Fail "not implemented: astore" + | istore_0 => [Store (0, VerificationType.Integer)] + | istore_1 => [Store (1, VerificationType.Integer)] + | istore_2 => [Store (2, VerificationType.Integer)] + | istore_3 => [Store (3, VerificationType.Integer)] + | lstore_0 => raise Fail "not implemented: lstore_0" + | lstore_1 => raise Fail "not implemented: lstore_1" + | lstore_2 => raise Fail "not implemented: 
lstore_2" + | lstore_3 => raise Fail "not implemented: lstore_3" + | fstore_0 => raise Fail "not implemented: fstore_0" + | fstore_1 => raise Fail "not implemented: fstore_1" + | fstore_2 => raise Fail "not implemented: fstore_2" + | fstore_3 => raise Fail "not implemented: fstore_3" + | dstore_0 => raise Fail "not implemented: dstore_0" + | dstore_1 => raise Fail "not implemented: dstore_1" + | dstore_2 => raise Fail "not implemented: dstore_2" + | dstore_3 => raise Fail "not implemented: dstore_3" + | astore_0 => raise Fail "not implemented: astore_0" + | astore_1 => raise Fail "not implemented: astore_1" + | astore_2 => raise Fail "not implemented: astore_2" + | astore_3 => raise Fail "not implemented: astore_3" + | iastore => raise Fail "not implemented: iastore" + | lastore => raise Fail "not implemented: lastore" + | fastore => raise Fail "not implemented: fastore" + | dastore => raise Fail "not implemented: dastore" + | aastore => raise Fail "not implemented: aastore" + | bastore => raise Fail "not implemented: bastore" + | castore => raise Fail "not implemented: castore" + | sastore => raise Fail "not implemented: sastore" + | pop => [Pop VerificationType.Top] + | pop2 => raise Fail "not implemented: pop2" + | dup => raise Fail "not implemented: dup" + | dup_x1 => raise Fail "not implemented: dup_x1" + | dup_x2 => raise Fail "not implemented: dup_x2" + | dup2 => raise Fail "not implemented: dup2" + | dup2_x1 => raise Fail "not implemented: dup2_x1" + | dup2_x2 => raise Fail "not implemented: dup2_x2" + | swap => raise Fail "not implemented: swap" + | iadd => [ + Pop VerificationType.Integer, + Pop VerificationType.Integer, + Push VerificationType.Integer + ] + | ladd => raise Fail "not implemented: ladd" + | fadd => raise Fail "not implemented: fadd" + | dadd => raise Fail "not implemented: dadd" + | isub => raise Fail "not implemented: isub" + | lsub => raise Fail "not implemented: lsub" + | fsub => raise Fail "not implemented: fsub" + | dsub => raise Fail 
"not implemented: dsub" + | imul => [ + Pop VerificationType.Integer, + Pop VerificationType.Integer, + Push VerificationType.Integer + ] + | lmul => raise Fail "not implemented: lmul" + | fmul => raise Fail "not implemented: fmul" + | dmul => raise Fail "not implemented: dmul" + | idiv => raise Fail "not implemented: idiv" + | ldiv => raise Fail "not implemented: ldiv" + | fdiv => raise Fail "not implemented: fdiv" + | ddiv => raise Fail "not implemented: ddiv" + | irem => raise Fail "not implemented: irem" + | lrem => raise Fail "not implemented: lrem" + | frem => raise Fail "not implemented: frem" + | drem => raise Fail "not implemented: drem" + | ineg => raise Fail "not implemented: ineg" + | lneg => raise Fail "not implemented: lneg" + | fneg => raise Fail "not implemented: fneg" + | dneg => raise Fail "not implemented: dneg" + | ishl => raise Fail "not implemented: ishl" + | lshl => raise Fail "not implemented: lshl" + | ishr => raise Fail "not implemented: ishr" + | lshr => raise Fail "not implemented: lshr" + | iushr => raise Fail "not implemented: iushr" + | lushr => raise Fail "not implemented: lushr" + | iand => raise Fail "not implemented: iand" + | land => raise Fail "not implemented: land" + | ior => raise Fail "not implemented: ior" + | lor => raise Fail "not implemented: lor" + | ixor => raise Fail "not implemented: ixor" + | lxor => raise Fail "not implemented: lxor" + | iinc (index, _) => [Local (Word8.toInt index, VerificationType.Integer)] + | i2l => raise Fail "not implemented: i2l" + | i2f => raise Fail "not implemented: i2f" + | i2d => raise Fail "not implemented: i2d" + | l2i => raise Fail "not implemented: l2i" + | l2f => raise Fail "not implemented: l2f" + | l2d => raise Fail "not implemented: l2d" + | f2i => raise Fail "not implemented: f2i" + | f2l => raise Fail "not implemented: f2l" + | f2d => raise Fail "not implemented: f2d" + | d2i => raise Fail "not implemented: d2i" + | d2l => raise Fail "not implemented: d2l" + | d2f => raise 
Fail "not implemented: d2f" + | i2b => raise Fail "not implemented: i2b" + | i2c => raise Fail "not implemented: i2c" + | i2s => raise Fail "not implemented: i2s" + | lcmp => raise Fail "not implemented: lcmp" + | fcmpl => raise Fail "not implemented: fcmpl" + | fcmpg => raise Fail "not implemented: fcmpg" + | dcmpl => raise Fail "not implemented: dcmpl" + | dcmpg => raise Fail "not implemented: dcmpg" + | ifeq offset => raise Fail "not implemented: ifeq" + | ifne offset => raise Fail "not implemented: ifne" + | iflt offset => raise Fail "not implemented: iflt" + | ifge offset => raise Fail "not implemented: ifge" + | ifgt offset => raise Fail "not implemented: ifgt" + | ifle offset => [ + Pop VerificationType.Integer, + Branch { targetOffset = offset, fallsThrough = true } + ] + | if_icmpeq offset => raise Fail "not implemented: if_icmpeq" + | if_icmpne offset => [ + Pop VerificationType.Integer, + Pop VerificationType.Integer, + Branch { targetOffset = offset, fallsThrough = true } + ] + | if_icmplt offset => raise Fail "not implemented: if_icmplt" + | if_icmpge offset => [ + Pop VerificationType.Integer, + Pop VerificationType.Integer, + Branch { targetOffset = offset, fallsThrough = true } + ] + | if_icmpgt offset => raise Fail "not implemented: if_icmpgt" + | if_icmple offset => raise Fail "not implemented: if_icmple" + | if_acmpeq offset => raise Fail "not implemented: if_acmpeq" + | if_acmpne offset => raise Fail "not implemented: if_acmpne" + | getstatic { descriptor, ... } => [Push (VerificationType.fromDescriptor descriptor)] + | putstatic _ => raise Fail "not implemented: putstatic" + | getfield _ => raise Fail "not implemented: getfield" + | putfield _ => raise Fail "not implemented: putfield" + | invokevirtual { descriptor, class, ... 
} => + let + val paramTypes = List.revMap Pop (VerificationType.methodParams descriptor) + val thisType = [Pop (VerificationType.Object class)] + val returnType = + case VerificationType.methodReturn descriptor of + | VerificationType.Top => [] + | verificationType => [Push verificationType] + in + List.concat [paramTypes, thisType, returnType] + end + | invokespecial _ => raise Fail "not implemented: invokespecial" + | invokestatic { descriptor, class, ... } => + let + val paramTypes = List.revMap Pop (VerificationType.methodParams descriptor) + val returnType = + case VerificationType.methodReturn descriptor of + | VerificationType.Top => [] + | verificationType => [Push verificationType] + in + List.concat [paramTypes, returnType] + end + | invokeinterface _ => raise Fail "not implemented: invokeinterface" + | invokedynamic _ => raise Fail "not implemented: invokedynamic" + | new className => raise Fail "not implemented: new" + | newarray _ => raise Fail "not implemented: newarray" + | anewarray index => raise Fail "not implemented: anewarray" + | arraylength => [ + Pop (VerificationType.Array VerificationType.Top), + Push VerificationType.Integer + ] + | athrow => raise Fail "not implemented: athrow" + | checkcast index => raise Fail "not implemented: checkcast" + | instanceof index => raise Fail "not implemented: instanceof" + | monitorenter => raise Fail "not implemented: monitorenter" + | monitorexit => raise Fail "not implemented: monitorexit" + | goto offset => [Branch { targetOffset = offset, fallsThrough = false }] + | tableswitch => raise Fail "not implemented: tableswitch" + | lookupswitch => raise Fail "not implemented: lookupswitch" + | ireturn => [Pop VerificationType.Integer] + | lreturn => [Pop VerificationType.Long] + | freturn => [Pop VerificationType.Float] + | dreturn => [Pop VerificationType.Double] + | areturn => raise Fail "not implemented" + | return => [] + | wide => raise Fail "not implemented: wide" + | multianewarray _ => raise Fail 
"not implemented: multianewarray" + | ifnull offset => raise Fail "not implemented: ifnull" + | ifnonnull offset => raise Fail "not implemented: ifnonnull" + | goto_w offset => raise Fail "not implemented: goto_w" + | breakpoint => raise Fail "not implemented: breakpoint" + | impdep1 => raise Fail "not implemented: impdep1" + | impdep2 => raise Fail "not implemented: impdep2" + + | ret index => raise Fail "Illegal instruction: ret is disallowed. See §4.9.1." + | jsr offset => raise Fail "Illegal instruction: jsr is disallowed. See §4.9.1." + | jsr_w offset => raise Fail "Illegal instruction: jsr_w is disallowed. See §4.9.1." + in + Console.println ("VERIFY........"); + List.mapPartial + (fn (offset, instr) => + case transition instr of + | [] => NONE + | xs => SOME ({ offset = offset, instrs = xs })) + instrs + end + in + struct + val verify = verify + end + end diff --git a/src/stack-map/verifier2.sml b/src/stack-map/verifier2.sml new file mode 100644 index 0000000..34e17d6 --- /dev/null +++ b/src/stack-map/verifier2.sml @@ -0,0 +1,28 @@ +(* + * Write it using the final tagless approach? Is it worth it? 
+ *) +(* structure StackLang = + struct + type t = int + end *) + + + +structure Verifier2 : + sig + val verify : Instr.t list -> StackMap.frame list + end + = + struct + fun verify instrs = + let + fun fold (instr, state) = + raise Fail "not implemented" + + val seed = [] + + val r = List.foldl fold seed instrs + in + raise Fail "not implemented" + end + end diff --git a/test/all-suites.sml b/test/all-suites.sml new file mode 100644 index 0000000..be956c0 --- /dev/null +++ b/test/all-suites.sml @@ -0,0 +1,5 @@ +structure AllSuites = TestRunner( + val all = List.concat [ + FrameSuite.all + ] +) diff --git a/test/frame-suite.sml b/test/frame-suite.sml new file mode 100644 index 0000000..0d0d37b --- /dev/null +++ b/test/frame-suite.sml @@ -0,0 +1,18 @@ +structure FrameSuite : TEST_SUITE = + struct + infix --> + infix === + + open TestSuite + + val all = [ + "1 < 2" --> (fn _ => + Ints.assert 1 op< 2 + ), + + "1 = 1" --> let in fn _ => + 1 === 2 + (* Ints.assert 1 op= 2 *) + end + ] + end diff --git a/test/test-cases/factorial/Factorial.java b/test/test-cases/factorial/Factorial.java new file mode 100644 index 0000000..cd61031 --- /dev/null +++ b/test/test-cases/factorial/Factorial.java @@ -0,0 +1,33 @@ +/* + +Compile with `-g:none` to forgo all debug symbols: + +``` +$ javac -g:none Factorial.java +``` + +Inspect generated bytecode with: + +``` +$ javap -v Factorial +``` + +*/ + +public class Factorial { + public static void main(String[] args) { + int r = factorial(5); + System.out.println(Integer.toString(r)); + } + + public static int factorial(int n) { + int r = 1; + + while (n > 0) { + r = n * r; + n--; + } + + return r; + } +} diff --git a/test/test-cases/factorial/factorial.sml b/test/test-cases/factorial/factorial.sml new file mode 100644 index 0000000..d7f9733 --- /dev/null +++ b/test/test-cases/factorial/factorial.sml @@ -0,0 +1,161 @@ +structure Factorial = + let + val className = "Factorial" + + structure D = Descriptor + structure Instr = LabeledInstr + 
+ val main = Method.from { + name = "main", + accessFlags = [Method.Flag.PUBLIC, Method.Flag.STATIC], + descriptor = Descriptor.Method { + return = Descriptor.Void, + params = [ + Descriptor.Array (Descriptor.Object (ClassName.fromString "java/lang/String")) + ] + }, + attributes = [ + Attr.Code { + exceptionTable = [], + attributes = [], + code = let open Instr in [ + iconst_5, + invokestatic { + class = ClassName.fromString className, + name = "factorial", + descriptor = D.Method { + params = [D.Int], + return = D.Type D.Int + } + }, + istore_1, + getstatic java.lang.System.out, + iload_1, + invokestatic java.lang.Integer.toString, + invokevirtual java.io.PrintStream.println, + return + ] end + } + ] + } + + val factorial = Method.from { + name = "factorial", + accessFlags = [Method.Flag.PUBLIC, Method.Flag.STATIC], + descriptor = Descriptor.Method { + return = Descriptor.Type Descriptor.Int, + params = [ + Descriptor.Int + ] + }, + attributes = [ + Attr.Code { + exceptionTable = [], + attributes = [], + (* + 0: Push Integer + 1: Store (1, Integer) + 2: Load (0, Integer) + 3: Pop Integer; Branch { targetOffset = 10, fallsThrough = true } + 6: Load (0, Integer) + 7: Load (1, Integer) + 8: Pop Integer; Pop Integer; Push Integer + 9: Store (1, Integer) + 10: Local (0, Integer) + 13: Branch { targetOffset = 2, fallsThrough = false } + 16: Load (1, Integer) + 17: Pop Integer + *) + code = let open Instr in [ + iconst_1, (* int r = 1; *) + istore_1, + label "enter-while", (* while (n > 0) { *) + iload_0, + ifle "exit-while", + iload_0, (* r = n * r; *) + iload_1, + imul, + istore_1, + iinc (0w0, ~ 0w1), (* n-- *) + goto "enter-while", (* } *) + label "exit-while", + iload_1, (* return r; *) + ireturn + ] end + } + ] + } + in + (* DSL sketch attempt *) + (* let open Assembly.DSL in + class [PUBLIC] "Factorial" [] [] [ + field [PRIVATE, STATIC, FINAL] "java/lang/String" "Hello, World!" 
+ + method [PUBLIC, STATIC] INT "factorial" [INT] [ + iconst_1, + istore_1, + label "enter-while", + iload_0, + ifle "exit-while", + iload_0, + iload_1, + imul, + istore_1, + iinc (0w0, ~ 0w1), + goto "enter-while", + label "exit-while", + iload_1, + ireturn + ] end + ] + end *) + + struct + fun class name = Class.from { + accessFlags = [Class.Flag.PUBLIC], + thisClass = ClassName.fromString name, + superClass = ClassName.fromString "java/lang/Object", + interfaces = [], + attributes = [Attr.SourceFile "main.sml"], + fields = [ + Field.from { + name = "message", + accessFlags = [Field.Flag.PRIVATE, Field.Flag.STATIC, Field.Flag.FINAL], + descriptor = Descriptor.Field (Descriptor.Object (ClassName.fromString "java/lang/String")), + attributes = [ + Attr.ConstantValue (ConstantValue.String "Hello, World!") + ] + } + ], + methods = [main, factorial] + } + + val trim = + let open Char Substring in + string o dropl isSpace o dropr isSpace o full + end + + fun java { classpath } className = + let + val proc = Unix.execute ("/usr/bin/java", ["-cp", classpath, className]) + val output = TextIO.inputAll (Unix.textInstreamOf proc) + in + Unix.reap proc + ; trim output + end + + fun main () = + let + val workDir = OS.FileSys.getDir () + val binDir = OS.Path.joinDirFile { dir = workDir, file = "bin" } + val fileName = OS.Path.joinDirFile { dir = binDir, file = className ^ ".class" } + val classFile = BinIO.openOut fileName + val bytes = Class.compile (class className) + val _ = BinIO.output (classFile, bytes) + val _ = BinIO.closeOut classFile + val output = java { classpath = binDir } className + in + print (output ^ "\n") + end + end + end