diff --git a/prometeu-compiler/prometeu-build-pipeline/src/main/java/p/studio/compiler/backend/irvm/OptimizeIRVMService.java b/prometeu-compiler/prometeu-build-pipeline/src/main/java/p/studio/compiler/backend/irvm/OptimizeIRVMService.java index 5dca8737..6399569d 100644 --- a/prometeu-compiler/prometeu-build-pipeline/src/main/java/p/studio/compiler/backend/irvm/OptimizeIRVMService.java +++ b/prometeu-compiler/prometeu-build-pipeline/src/main/java/p/studio/compiler/backend/irvm/OptimizeIRVMService.java @@ -1,5 +1,12 @@ package p.studio.compiler.backend.irvm; +import p.studio.compiler.backend.bytecode.BytecodeEmitter; +import p.studio.utilities.structures.ReadOnlyList; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Objects; @@ -7,13 +14,14 @@ public class OptimizeIRVMService { private final IRVMValidator validator; private final IRVMProfileFeatureGate profileFeatureGate; private final List passes; + private volatile List lastExecutedPassNames = List.of(); public OptimizeIRVMService() { - this(new IRVMValidator(), new IRVMProfileFeatureGate(), List.of(new NoOpPass())); + this(new IRVMValidator(), new IRVMProfileFeatureGate(), defaultPasses()); } OptimizeIRVMService(final IRVMValidator validator) { - this(validator, new IRVMProfileFeatureGate(), List.of(new NoOpPass())); + this(validator, new IRVMProfileFeatureGate(), defaultPasses()); } OptimizeIRVMService( @@ -28,7 +36,7 @@ public class OptimizeIRVMService { final List passes) { this.validator = validator; this.profileFeatureGate = profileFeatureGate; - this.passes = passes == null ? List.of(new NoOpPass()) : List.copyOf(passes); + this.passes = passes == null ? defaultPasses() : List.copyOf(passes); } public IRVMProgram optimize(final IRVMProgram input) { @@ -37,11 +45,13 @@ public class OptimizeIRVMService { throw new IllegalArgumentException("unsupported vm profile: " + program.module().vmProfile()); } var current = program; + final var executedPassNames = new ArrayList(); validator.validate(current, false); for (final var pass : passes) { if (pass == null || !pass.enabled()) { continue; } + executedPassNames.add(pass.name()); final var beforeProfile = current.module().vmProfile(); current = Objects.requireNonNull(pass.apply(current), "pass output"); if (!beforeProfile.equals(current.module().vmProfile())) { @@ -49,9 +59,33 @@ public class OptimizeIRVMService { } validator.validate(current, false); } + lastExecutedPassNames = List.copyOf(executedPassNames); return current; } + public List lastExecutedPassNames() { + return lastExecutedPassNames; + } + + static IRVMPass unreachableInstructionEliminationPass() { + return new UnreachableInstructionEliminationPass(); + } + + static IRVMPass normalizeRedundantJumpTargetsPass() { + return new NormalizeRedundantJumpTargetsPass(); + } + + static IRVMPass simplifyJumpToNextPcPass() { + return new SimplifyJumpToNextPcPass(); + } + + private static List defaultPasses() { + return List.of( + unreachableInstructionEliminationPass(), + normalizeRedundantJumpTargetsPass(), + simplifyJumpToNextPcPass()); + } + public interface IRVMPass { String name(); @@ -62,15 +96,374 @@ public class OptimizeIRVMService { IRVMProgram apply(IRVMProgram input); } - private static final class NoOpPass implements IRVMPass { + private interface FunctionTransformer { + FunctionRewrite apply(IRVMFunction function, BytecodeEmitter.FunctionPlan functionPlan); + } + + private static IRVMProgram transformProgram( + final IRVMProgram input, + final FunctionTransformer functionTransformer) { + if (input == null || input.emissionPlan() == null || input.emissionPlan().functions().isEmpty()) { + return input; + } + if (input.emissionPlan().functions().size() != input.module().functions().size()) { + return input; + } + + var changed = false; + final var rewrittenFunctions = new ArrayList(input.module().functions().size()); + final var rewrittenPlans = new ArrayList(input.emissionPlan().functions().size()); + for (var i = 0; i < input.module().functions().size(); i++) { + final var rewrite = functionTransformer.apply( + input.module().functions().get(i), + input.emissionPlan().functions().get(i)); + rewrittenFunctions.add(rewrite.function()); + rewrittenPlans.add(rewrite.functionPlan()); + if (rewrite.changed()) { + changed = true; + } + } + + if (!changed) { + return input; + } + + return new IRVMProgram( + new IRVMModule(input.module().vmProfile(), ReadOnlyList.wrap(rewrittenFunctions)), + new BytecodeEmitter.EmissionPlan( + input.emissionPlan().version(), + input.emissionPlan().constPool(), + input.emissionPlan().exports(), + ReadOnlyList.wrap(rewrittenPlans))); + } + + private static boolean isJump(final IRVMOp op) { + return op == IRVMOp.JMP || op == IRVMOp.JMP_IF_TRUE || op == IRVMOp.JMP_IF_FALSE; + } + + private static int[] pcByIndex(final List instructions) { + final var out = new int[instructions.size()]; + var pc = 0; + for (var i = 0; i < instructions.size(); i++) { + out[i] = pc; + pc += instructions.get(i).encodedSize(); + } + return out; + } + + private static HashMap indexByPc(final int[] pcByIndex) { + final var out = new HashMap(); + for (var i = 0; i < pcByIndex.length; i++) { + out.put(pcByIndex[i], i); + } + return out; + } + + private static FunctionRewrite removeInstructionIndices( + final IRVMFunction function, + final BytecodeEmitter.FunctionPlan functionPlan, + final boolean[] removeByIndex) { + var removedAny = false; + for (final var remove : removeByIndex) { + if (remove) { + removedAny = true; + break; + } + } + if (!removedAny) { + return new FunctionRewrite(function, functionPlan, false); + } + + final var oldInstructions = function.instructions().asList(); + final var oldOperations = functionPlan.operations().asList(); + final var oldPcByIndex = pcByIndex(oldInstructions); + final var oldIndexByPc = indexByPc(oldPcByIndex); + + final var newInstructions = new ArrayList(oldInstructions.size()); + final var newOperations = new ArrayList(oldOperations.size()); + final var oldPcToNewPc = new HashMap(); + + var nextPc = 0; + for (var i = 0; i < oldInstructions.size(); i++) { + if (removeByIndex[i]) { + continue; + } + oldPcToNewPc.put(oldPcByIndex[i], nextPc); + newInstructions.add(oldInstructions.get(i)); + newOperations.add(oldOperations.get(i)); + nextPc += oldInstructions.get(i).encodedSize(); + } + + var changed = true; + for (var i = 0; i < newInstructions.size(); i++) { + final var instruction = newInstructions.get(i); + if (!isJump(instruction.op())) { + continue; + } + final var resolvedOldTarget = resolveRemovedJumpTarget( + instruction.immediate(), + removeByIndex, + oldIndexByPc, + oldInstructions); + final var remappedTarget = oldPcToNewPc.get(resolvedOldTarget); + if (remappedTarget == null) { + throw new IllegalArgumentException( + "jump target cannot be remapped after optimization: target_pc=" + instruction.immediate()); + } + if (remappedTarget == instruction.immediate()) { + continue; + } + newInstructions.set(i, new IRVMInstruction(instruction.op(), remappedTarget)); + newOperations.set(i, remapJumpOperation(newOperations.get(i), instruction.op(), remappedTarget)); + } + + final var rewrittenFunction = new IRVMFunction( + function.name(), + function.paramSlots(), + function.localSlots(), + function.returnSlots(), + function.maxStackSlots(), + ReadOnlyList.wrap(newInstructions)); + final var rewrittenPlan = new BytecodeEmitter.FunctionPlan( + functionPlan.name(), + functionPlan.paramSlots(), + functionPlan.localSlots(), + functionPlan.returnSlots(), + functionPlan.maxStackSlots(), + ReadOnlyList.wrap(newOperations)); + return new FunctionRewrite(rewrittenFunction, rewrittenPlan, changed); + } + + private static int resolveRemovedJumpTarget( + final int oldTargetPc, + final boolean[] removeByIndex, + final HashMap oldIndexByPc, + final List oldInstructions) { + var currentTargetPc = oldTargetPc; + final var visited = new HashSet(); + while (visited.add(currentTargetPc)) { + final var targetIndex = oldIndexByPc.get(currentTargetPc); + if (targetIndex == null || !removeByIndex[targetIndex]) { + return currentTargetPc; + } + final var targetInstruction = oldInstructions.get(targetIndex); + if (targetInstruction.op() != IRVMOp.JMP) { + return currentTargetPc; + } + currentTargetPc = targetInstruction.immediate(); + } + return currentTargetPc; + } + + private static BytecodeEmitter.Operation remapJumpOperation( + final BytecodeEmitter.Operation operation, + final IRVMOp jumpOp, + final int targetPc) { + final var span = operation == null ? null : operation.span(); + if (jumpOp == IRVMOp.JMP) { + return BytecodeEmitter.Operation.jmp(targetPc, span); + } + if (jumpOp == IRVMOp.JMP_IF_TRUE) { + return BytecodeEmitter.Operation.jmpIfTrue(targetPc, span); + } + if (jumpOp == IRVMOp.JMP_IF_FALSE) { + return BytecodeEmitter.Operation.jmpIfFalse(targetPc, span); + } + throw new IllegalArgumentException("unexpected jump op for remap: " + jumpOp.name()); + } + + private record FunctionRewrite( + IRVMFunction function, + BytecodeEmitter.FunctionPlan functionPlan, + boolean changed) { + } + + private static final class UnreachableInstructionEliminationPass implements IRVMPass { @Override public String name() { - return "NoOpPass"; + return "UnreachableInstructionEliminationPass"; } @Override public IRVMProgram apply(final IRVMProgram input) { - return input; + return transformProgram(input, this::rewriteFunction); + } + + private FunctionRewrite rewriteFunction( + final IRVMFunction function, + final BytecodeEmitter.FunctionPlan functionPlan) { + final var instructions = function.instructions().asList(); + if (instructions.isEmpty()) { + return new FunctionRewrite(function, functionPlan, false); + } + + final var pcs = pcByIndex(instructions); + final var indexByPc = indexByPc(pcs); + final var reachable = new boolean[instructions.size()]; + final var worklist = new ArrayDeque(); + worklist.add(0); + while (!worklist.isEmpty()) { + final var index = worklist.removeFirst(); + if (index < 0 || index >= instructions.size() || reachable[index]) { + continue; + } + reachable[index] = true; + final var instruction = instructions.get(index); + if (instruction.op() == IRVMOp.HALT || instruction.op() == IRVMOp.RET) { + continue; + } + if (isJump(instruction.op())) { + final var targetIndex = indexByPc.get(instruction.immediate()); + if (targetIndex != null) { + worklist.add(targetIndex); + } + if (instruction.op() == IRVMOp.JMP) { + continue; + } + } + final var fallthroughIndex = index + 1; + if (fallthroughIndex < instructions.size()) { + worklist.add(fallthroughIndex); + } + } + + final var removeByIndex = new boolean[instructions.size()]; + var removedAny = false; + for (var i = 0; i < instructions.size(); i++) { + removeByIndex[i] = !reachable[i]; + if (removeByIndex[i]) { + removedAny = true; + } + } + if (!removedAny) { + return new FunctionRewrite(function, functionPlan, false); + } + return removeInstructionIndices(function, functionPlan, removeByIndex); + } + } + + private static final class NormalizeRedundantJumpTargetsPass implements IRVMPass { + @Override + public String name() { + return "NormalizeRedundantJumpTargetsPass"; + } + + @Override + public IRVMProgram apply(final IRVMProgram input) { + return transformProgram(input, this::rewriteFunction); + } + + private FunctionRewrite rewriteFunction( + final IRVMFunction function, + final BytecodeEmitter.FunctionPlan functionPlan) { + final var instructions = new ArrayList<>(function.instructions().asList()); + final var operations = new ArrayList<>(functionPlan.operations().asList()); + final var indexByPc = indexByPc(pcByIndex(instructions)); + var changed = false; + + for (var i = 0; i < instructions.size(); i++) { + final var instruction = instructions.get(i); + if (!isJump(instruction.op())) { + continue; + } + final var normalizedTarget = normalizeJumpTarget( + instruction.immediate(), + instructions, + indexByPc); + if (normalizedTarget == instruction.immediate()) { + continue; + } + changed = true; + instructions.set(i, new IRVMInstruction(instruction.op(), normalizedTarget)); + operations.set(i, remapJumpOperation(operations.get(i), instruction.op(), normalizedTarget)); + } + + if (!changed) { + return new FunctionRewrite(function, functionPlan, false); + } + return new FunctionRewrite( + new IRVMFunction( + function.name(), + function.paramSlots(), + function.localSlots(), + function.returnSlots(), + function.maxStackSlots(), + ReadOnlyList.wrap(instructions)), + new BytecodeEmitter.FunctionPlan( + functionPlan.name(), + functionPlan.paramSlots(), + functionPlan.localSlots(), + functionPlan.returnSlots(), + functionPlan.maxStackSlots(), + ReadOnlyList.wrap(operations)), + true); + } + + private int normalizeJumpTarget( + final int targetPc, + final List instructions, + final HashMap indexByPc) { + var currentPc = targetPc; + final var visited = new HashSet(); + while (visited.add(currentPc)) { + final var targetIndex = indexByPc.get(currentPc); + if (targetIndex == null) { + return currentPc; + } + final var targetInstruction = instructions.get(targetIndex); + if (targetInstruction.op() != IRVMOp.JMP) { + return currentPc; + } + currentPc = targetInstruction.immediate(); + } + return currentPc; + } + } + + private static final class SimplifyJumpToNextPcPass implements IRVMPass { + @Override + public String name() { + return "SimplifyJumpToNextPcPass"; + } + + @Override + public IRVMProgram apply(final IRVMProgram input) { + return transformProgram(input, this::rewriteFunction); + } + + private FunctionRewrite rewriteFunction( + final IRVMFunction function, + final BytecodeEmitter.FunctionPlan functionPlan) { + var currentFunction = function; + var currentFunctionPlan = functionPlan; + var changed = false; + + while (true) { + final var instructions = currentFunction.instructions().asList(); + if (instructions.size() <= 1) { + break; + } + final var pcs = pcByIndex(instructions); + final var removeByIndex = new boolean[instructions.size()]; + var removedAny = false; + for (var i = 0; i < instructions.size() - 1; i++) { + final var instruction = instructions.get(i); + if (instruction.op() == IRVMOp.JMP && instruction.immediate() == pcs[i + 1]) { + removeByIndex[i] = true; + removedAny = true; + } + } + if (!removedAny) { + break; + } + + final var rewritten = removeInstructionIndices(currentFunction, currentFunctionPlan, removeByIndex); + currentFunction = rewritten.function(); + currentFunctionPlan = rewritten.functionPlan(); + changed = true; + } + + return new FunctionRewrite(currentFunction, currentFunctionPlan, changed); } } } diff --git a/prometeu-compiler/prometeu-build-pipeline/src/test/java/p/studio/compiler/backend/irvm/OptimizeIRVMServiceTest.java b/prometeu-compiler/prometeu-build-pipeline/src/test/java/p/studio/compiler/backend/irvm/OptimizeIRVMServiceTest.java index 5721168c..6a2614b0 100644 --- a/prometeu-compiler/prometeu-build-pipeline/src/test/java/p/studio/compiler/backend/irvm/OptimizeIRVMServiceTest.java +++ b/prometeu-compiler/prometeu-build-pipeline/src/test/java/p/studio/compiler/backend/irvm/OptimizeIRVMServiceTest.java @@ -1,12 +1,14 @@ package p.studio.compiler.backend.irvm; import org.junit.jupiter.api.Test; +import p.studio.compiler.backend.bytecode.BytecodeEmitter; import p.studio.utilities.structures.ReadOnlyList; import java.util.ArrayList; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -23,6 +25,7 @@ class OptimizeIRVMServiceTest { assertEquals(validProgram(), output); assertEquals(List.of("A", "B"), order); + assertEquals(List.of("A", "B"), service.lastExecutedPassNames()); } @Test @@ -35,6 +38,7 @@ class OptimizeIRVMServiceTest { service.optimize(validProgram()); assertEquals(List.of("B"), order); + assertEquals(List.of("B"), service.lastExecutedPassNames()); } @Test @@ -57,6 +61,72 @@ class OptimizeIRVMServiceTest { assertTrue(thrown.getMessage().contains("vm profile")); } + @Test + void optimizeDefaultPassesMustRemoveUnreachableInstructions() { + final var service = new OptimizeIRVMService(); + final var input = singleFunctionProgram( + ReadOnlyList.from( + new IRVMInstruction(IRVMOp.RET, null), + new IRVMInstruction(IRVMOp.HALT, null)), + ReadOnlyList.from( + BytecodeEmitter.Operation.ret(), + BytecodeEmitter.Operation.halt())); + + final var optimized = service.optimize(input); + + assertNotEquals(input, optimized); + assertEquals(1, optimized.module().functions().getFirst().instructions().size()); + assertEquals(IRVMOp.RET, optimized.module().functions().getFirst().instructions().getFirst().op()); + assertEquals(List.of( + "UnreachableInstructionEliminationPass", + "NormalizeRedundantJumpTargetsPass", + "SimplifyJumpToNextPcPass"), service.lastExecutedPassNames()); + } + + @Test + void normalizeRedundantJumpTargetsPassMustCollapseJumpChain() { + final var service = new OptimizeIRVMService(new IRVMValidator(), List.of( + OptimizeIRVMService.normalizeRedundantJumpTargetsPass())); + final var input = singleFunctionProgram( + ReadOnlyList.from( + new IRVMInstruction(IRVMOp.JMP, 6), + new IRVMInstruction(IRVMOp.JMP, 12), + new IRVMInstruction(IRVMOp.RET, null)), + ReadOnlyList.from( + BytecodeEmitter.Operation.jmp(6, null), + BytecodeEmitter.Operation.jmp(12, null), + BytecodeEmitter.Operation.ret())); + + final var optimized = service.optimize(input); + final var firstInstruction = optimized.module().functions().getFirst().instructions().getFirst(); + final var firstOperation = optimized.emissionPlan().functions().getFirst().operations().getFirst(); + + assertEquals(IRVMOp.JMP, firstInstruction.op()); + assertEquals(12, firstInstruction.immediate()); + assertEquals(BytecodeEmitter.OperationKind.JMP, firstOperation.kind()); + assertEquals(12, firstOperation.immediate()); + } + + @Test + void simplifyJumpToNextPcPassMustRemoveDirectFallthroughJump() { + final var service = new OptimizeIRVMService(new IRVMValidator(), List.of( + OptimizeIRVMService.simplifyJumpToNextPcPass())); + final var input = singleFunctionProgram( + ReadOnlyList.from( + new IRVMInstruction(IRVMOp.JMP, 6), + new IRVMInstruction(IRVMOp.RET, null)), + ReadOnlyList.from( + BytecodeEmitter.Operation.jmp(6, null), + BytecodeEmitter.Operation.ret())); + + final var optimized = service.optimize(input); + + assertEquals(1, optimized.module().functions().getFirst().instructions().size()); + assertEquals(IRVMOp.RET, optimized.module().functions().getFirst().instructions().getFirst().op()); + assertEquals(1, optimized.emissionPlan().functions().getFirst().operations().size()); + assertEquals(BytecodeEmitter.OperationKind.RET, optimized.emissionPlan().functions().getFirst().operations().getFirst().kind()); + } + private OptimizeIRVMService.IRVMPass namedPass( final String name, final List order, @@ -81,14 +151,34 @@ class OptimizeIRVMServiceTest { } private IRVMProgram validProgram() { - return new IRVMProgram(new IRVMModule( - "core-v1", - ReadOnlyList.from(new IRVMFunction( - "main", + return singleFunctionProgram( + ReadOnlyList.from(new IRVMInstruction(IRVMOp.HALT, null)), + ReadOnlyList.from(BytecodeEmitter.Operation.halt())); + } + + private IRVMProgram singleFunctionProgram( + final ReadOnlyList instructions, + final ReadOnlyList operations) { + return new IRVMProgram( + new IRVMModule( + "core-v1", + ReadOnlyList.from(new IRVMFunction( + "main", + 0, + 0, + 0, + 1, + instructions))), + new BytecodeEmitter.EmissionPlan( 0, - 0, - 0, - 1, - ReadOnlyList.from(new IRVMInstruction(IRVMOp.HALT, null)))))); + ReadOnlyList.empty(), + ReadOnlyList.empty(), + ReadOnlyList.from(new BytecodeEmitter.FunctionPlan( + "main", + 0, + 0, + 0, + 1, + operations)))); } }