implements PR-05.6

This commit is contained in:
bQUARKz 2026-03-09 07:13:46 +00:00
parent 48ce448203
commit 0cc836246f
Signed by: bquarkz
SSH Key Fingerprint: SHA256:Z7dgqoglWwoK6j6u4QC87OveEq74WOhFN+gitsxtkf8
2 changed files with 497 additions and 14 deletions

View File

@ -1,5 +1,12 @@
package p.studio.compiler.backend.irvm;
import p.studio.compiler.backend.bytecode.BytecodeEmitter;
import p.studio.utilities.structures.ReadOnlyList;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
@ -7,13 +14,14 @@ public class OptimizeIRVMService {
private final IRVMValidator validator;
private final IRVMProfileFeatureGate profileFeatureGate;
private final List<IRVMPass> passes;
private volatile List<String> lastExecutedPassNames = List.of();
/**
 * Creates a service backed by a new {@link IRVMValidator}, a new
 * {@link IRVMProfileFeatureGate} and the pass list returned by {@code defaultPasses()}.
 */
public OptimizeIRVMService() {
    // A constructor body may start with only one explicit constructor invocation,
    // so this delegates exactly once, to the default pass pipeline.
    this(new IRVMValidator(), new IRVMProfileFeatureGate(), defaultPasses());
}
/**
 * Creates a service that uses the supplied validator together with a new
 * {@link IRVMProfileFeatureGate} and the pass list returned by {@code defaultPasses()}.
 *
 * @param validator validator handed to the three-argument constructor
 */
OptimizeIRVMService(final IRVMValidator validator) {
    // Single explicit constructor invocation: delegate to the default pass pipeline.
    this(validator, new IRVMProfileFeatureGate(), defaultPasses());
}
OptimizeIRVMService(
@ -28,7 +36,7 @@ public class OptimizeIRVMService {
final List<IRVMPass> passes) {
this.validator = validator;
this.profileFeatureGate = profileFeatureGate;
this.passes = passes == null ? List.of(new NoOpPass()) : List.copyOf(passes);
this.passes = passes == null ? defaultPasses() : List.copyOf(passes);
}
public IRVMProgram optimize(final IRVMProgram input) {
@ -37,11 +45,13 @@ public class OptimizeIRVMService {
throw new IllegalArgumentException("unsupported vm profile: " + program.module().vmProfile());
}
var current = program;
final var executedPassNames = new ArrayList<String>();
validator.validate(current, false);
for (final var pass : passes) {
if (pass == null || !pass.enabled()) {
continue;
}
executedPassNames.add(pass.name());
final var beforeProfile = current.module().vmProfile();
current = Objects.requireNonNull(pass.apply(current), "pass output");
if (!beforeProfile.equals(current.module().vmProfile())) {
@ -49,9 +59,33 @@ public class OptimizeIRVMService {
}
validator.validate(current, false);
}
lastExecutedPassNames = List.copyOf(executedPassNames);
return current;
}
/**
 * Returns the current value of the {@code lastExecutedPassNames} field.
 *
 * <p>The field starts out as an empty list and {@code optimize} replaces it with an
 * immutable copy of the pass names it collected, right before returning its result.
 *
 * @return the most recently stored list of executed pass names
 */
public List<String> lastExecutedPassNames() {
    return this.lastExecutedPassNames;
}
/**
 * Factory for the pass that eliminates unreachable instructions.
 *
 * @return a new {@code UnreachableInstructionEliminationPass}
 */
static IRVMPass unreachableInstructionEliminationPass() {
    final IRVMPass pass = new UnreachableInstructionEliminationPass();
    return pass;
}
/**
 * Factory for the pass that retargets jumps pointing at unconditional jumps.
 *
 * @return a new {@code NormalizeRedundantJumpTargetsPass}
 */
static IRVMPass normalizeRedundantJumpTargetsPass() {
    final IRVMPass pass = new NormalizeRedundantJumpTargetsPass();
    return pass;
}
/**
 * Factory for the pass that removes unconditional jumps to the immediately following pc.
 *
 * @return a new {@code SimplifyJumpToNextPcPass}
 */
static IRVMPass simplifyJumpToNextPcPass() {
    final IRVMPass pass = new SimplifyJumpToNextPcPass();
    return pass;
}
/**
 * Builds the default pipeline, in execution order: unreachable-instruction elimination,
 * jump-target normalization, then jump-to-next-pc simplification.
 *
 * @return an immutable three-element list with a new instance of each default pass
 */
private static List<IRVMPass> defaultPasses() {
    final IRVMPass eliminateUnreachable = unreachableInstructionEliminationPass();
    final IRVMPass normalizeTargets = normalizeRedundantJumpTargetsPass();
    final IRVMPass dropJumpToNext = simplifyJumpToNextPcPass();
    return List.of(eliminateUnreachable, normalizeTargets, dropJumpToNext);
}
public interface IRVMPass {
String name();
@ -62,15 +96,374 @@ public class OptimizeIRVMService {
IRVMProgram apply(IRVMProgram input);
}
private static final class NoOpPass implements IRVMPass {
/**
 * Per-function rewrite step used by {@code transformProgram}: receives one function and the
 * emission plan entry at the same index, and returns the rewrite result for that pair.
 *
 * <p>The passes in this file supply it as a method reference, so it is declared as a
 * functional interface to have the compiler enforce the single-abstract-method shape.
 */
@FunctionalInterface
private interface FunctionTransformer {
    /**
     * @param function     the function to rewrite
     * @param functionPlan the emission plan entry paired with {@code function}
     * @return the rewritten pair together with a flag telling whether anything changed
     */
    FunctionRewrite apply(IRVMFunction function, BytecodeEmitter.FunctionPlan functionPlan);
}
/**
 * Applies {@code functionTransformer} to every function of {@code input}, pairing each
 * function with the emission plan entry at the same index.
 *
 * <p>The input is returned as-is when it is {@code null}, has no emission plan, has an empty
 * plan, when the plan and module disagree on the number of functions, or when no rewrite
 * reported a change. Otherwise a new program is assembled that keeps the original vm profile,
 * plan version, constant pool and exports.
 *
 * @param input               program to transform; may be {@code null}
 * @param functionTransformer rewrite applied to each function/plan pair, in index order
 * @return {@code input} itself when nothing changed, otherwise a new program
 */
private static IRVMProgram transformProgram(
        final IRVMProgram input,
        final FunctionTransformer functionTransformer) {
    if (input == null) {
        return input;
    }
    final var plan = input.emissionPlan();
    if (plan == null || plan.functions().isEmpty()) {
        return input;
    }
    final var module = input.module();
    final var functionCount = module.functions().size();
    if (plan.functions().size() != functionCount) {
        // Function and plan lists cannot be paired by index; leave the program untouched.
        return input;
    }

    final var functions = new ArrayList<IRVMFunction>(functionCount);
    final var plans = new ArrayList<BytecodeEmitter.FunctionPlan>(plan.functions().size());
    var anyChanged = false;
    for (var index = 0; index < functionCount; index++) {
        final var rewrite = functionTransformer.apply(
                module.functions().get(index),
                plan.functions().get(index));
        functions.add(rewrite.function());
        plans.add(rewrite.functionPlan());
        anyChanged |= rewrite.changed();
    }

    if (anyChanged) {
        final var rewrittenModule = new IRVMModule(module.vmProfile(), ReadOnlyList.wrap(functions));
        final var rewrittenPlan = new BytecodeEmitter.EmissionPlan(
                plan.version(),
                plan.constPool(),
                plan.exports(),
                ReadOnlyList.wrap(plans));
        return new IRVMProgram(rewrittenModule, rewrittenPlan);
    }
    return input;
}
/**
 * Tells whether {@code op} is one of the three jump opcodes handled by the passes.
 *
 * @param op opcode to classify
 * @return {@code true} for {@code JMP}, {@code JMP_IF_TRUE} and {@code JMP_IF_FALSE}
 */
private static boolean isJump(final IRVMOp op) {
    if (op == IRVMOp.JMP) {
        return true;
    }
    if (op == IRVMOp.JMP_IF_TRUE) {
        return true;
    }
    return op == IRVMOp.JMP_IF_FALSE;
}
/**
 * Computes the pc of every instruction as the running sum of the encoded sizes of the
 * instructions before it; the first instruction is at pc 0.
 *
 * @param instructions instructions in order
 * @return an array where element {@code i} is the pc of {@code instructions.get(i)}
 */
private static int[] pcByIndex(final List<IRVMInstruction> instructions) {
    final var offsets = new int[instructions.size()];
    var position = 0;
    var offset = 0;
    for (final var instruction : instructions) {
        offsets[position++] = offset;
        offset += instruction.encodedSize();
    }
    return offsets;
}
/**
 * Inverts a pc-by-index array into a pc-to-index lookup.
 *
 * <p>Entries are inserted in ascending index order, so if the same pc occurs more than once
 * the highest index wins.
 *
 * @param pcByIndex array where element {@code i} is the pc of instruction {@code i}
 * @return a new map from pc to instruction index
 */
private static HashMap<Integer, Integer> indexByPc(final int[] pcByIndex) {
    final var lookup = new HashMap<Integer, Integer>();
    var index = 0;
    for (final int pc : pcByIndex) {
        lookup.put(pc, index);
        index++;
    }
    return lookup;
}
/**
 * Removes the instructions flagged in {@code removeByIndex} from a function and from its
 * emission plan, then re-points the surviving jumps at the new pcs.
 *
 * <p>Instructions and plan operations are treated as parallel lists: the entry at index
 * {@code i} of one corresponds to index {@code i} of the other. A jump whose old target was a
 * removed unconditional jump is first followed through {@code resolveRemovedJumpTarget}, and
 * the resulting old pc is then translated to its pc in the compacted instruction list.
 *
 * @param function      function whose instructions are filtered
 * @param functionPlan  emission plan entry paired with {@code function}
 * @param removeByIndex {@code true} at index {@code i} marks instruction {@code i} for removal
 * @return an unchanged rewrite (same objects, {@code changed == false}) when nothing is
 *         flagged; otherwise new function/plan objects with {@code changed == true}
 * @throws IllegalArgumentException when a surviving jump resolves to an old pc that has no
 *         counterpart among the kept instructions
 */
private static FunctionRewrite removeInstructionIndices(
        final IRVMFunction function,
        final BytecodeEmitter.FunctionPlan functionPlan,
        final boolean[] removeByIndex) {
    var removedAny = false;
    for (final var remove : removeByIndex) {
        if (remove) {
            removedAny = true;
            break;
        }
    }
    if (!removedAny) {
        return new FunctionRewrite(function, functionPlan, false);
    }

    final var oldInstructions = function.instructions().asList();
    final var oldOperations = functionPlan.operations().asList();
    final var oldPcByIndex = pcByIndex(oldInstructions);
    final var oldIndexByPc = indexByPc(oldPcByIndex);

    // Compact both lists and record where each kept instruction lands.
    final var newInstructions = new ArrayList<IRVMInstruction>(oldInstructions.size());
    final var newOperations = new ArrayList<BytecodeEmitter.Operation>(oldOperations.size());
    final var oldPcToNewPc = new HashMap<Integer, Integer>();
    var nextPc = 0;
    for (var i = 0; i < oldInstructions.size(); i++) {
        if (removeByIndex[i]) {
            continue;
        }
        oldPcToNewPc.put(oldPcByIndex[i], nextPc);
        newInstructions.add(oldInstructions.get(i));
        newOperations.add(oldOperations.get(i));
        nextPc += oldInstructions.get(i).encodedSize();
    }

    // Retarget every surviving jump from old pcs to new pcs.
    for (var i = 0; i < newInstructions.size(); i++) {
        final var instruction = newInstructions.get(i);
        if (!isJump(instruction.op())) {
            continue;
        }
        final var resolvedOldTarget = resolveRemovedJumpTarget(
                instruction.immediate(),
                removeByIndex,
                oldIndexByPc,
                oldInstructions);
        final var remappedTarget = oldPcToNewPc.get(resolvedOldTarget);
        if (remappedTarget == null) {
            throw new IllegalArgumentException(
                    "jump target cannot be remapped after optimization: target_pc=" + instruction.immediate());
        }
        // Compare by value: remappedTarget is a boxed Integer, and == between two boxed
        // values compares references, which misses equal targets outside the small-value cache.
        if (remappedTarget.equals(instruction.immediate())) {
            continue;
        }
        newInstructions.set(i, new IRVMInstruction(instruction.op(), remappedTarget));
        newOperations.set(i, remapJumpOperation(newOperations.get(i), instruction.op(), remappedTarget));
    }

    final var rewrittenFunction = new IRVMFunction(
            function.name(),
            function.paramSlots(),
            function.localSlots(),
            function.returnSlots(),
            function.maxStackSlots(),
            ReadOnlyList.wrap(newInstructions));
    final var rewrittenPlan = new BytecodeEmitter.FunctionPlan(
            functionPlan.name(),
            functionPlan.paramSlots(),
            functionPlan.localSlots(),
            functionPlan.returnSlots(),
            functionPlan.maxStackSlots(),
            ReadOnlyList.wrap(newOperations));
    // At least one instruction was removed, so the rewrite always counts as a change.
    return new FunctionRewrite(rewrittenFunction, rewrittenPlan, true);
}
/**
 * Follows a jump target through removed unconditional jumps until it lands on something that
 * should not be followed further.
 *
 * <p>The walk stops and returns the current pc when that pc maps to no instruction, maps to an
 * instruction that is kept, maps to a removed instruction that is not {@code JMP}, or has
 * already been visited (which guards against cycles).
 *
 * @param oldTargetPc     original jump target, as a pc in the old instruction list
 * @param removeByIndex   removal flags indexed by old instruction index
 * @param oldIndexByPc    lookup from old pc to old instruction index
 * @param oldInstructions the old instruction list
 * @return the old pc at which the walk stopped
 */
private static int resolveRemovedJumpTarget(
        final int oldTargetPc,
        final boolean[] removeByIndex,
        final HashMap<Integer, Integer> oldIndexByPc,
        final List<IRVMInstruction> oldInstructions) {
    final var seen = new HashSet<Integer>();
    var pc = oldTargetPc;
    for (;;) {
        if (!seen.add(pc)) {
            // Already visited: a cycle of removed jumps, stop here.
            return pc;
        }
        final var index = oldIndexByPc.get(pc);
        final var notRemoved = index == null || !removeByIndex[index];
        if (notRemoved) {
            return pc;
        }
        final var removedInstruction = oldInstructions.get(index);
        if (removedInstruction.op() != IRVMOp.JMP) {
            return pc;
        }
        pc = removedInstruction.immediate();
    }
}
/**
 * Builds the emission-plan operation for a jump of kind {@code jumpOp} aimed at
 * {@code targetPc}, carrying over the span of the operation it replaces.
 *
 * @param operation operation being replaced; may be {@code null}, in which case the span is
 *                  {@code null}
 * @param jumpOp    jump opcode of the instruction being retargeted
 * @param targetPc  new target pc
 * @return a new jump operation of the matching kind
 * @throws IllegalArgumentException when {@code jumpOp} is not one of the three jump opcodes
 */
private static BytecodeEmitter.Operation remapJumpOperation(
        final BytecodeEmitter.Operation operation,
        final IRVMOp jumpOp,
        final int targetPc) {
    final var span = operation != null ? operation.span() : null;
    final var unconditional = jumpOp == IRVMOp.JMP;
    if (unconditional) {
        return BytecodeEmitter.Operation.jmp(targetPc, span);
    }
    if (jumpOp != IRVMOp.JMP_IF_TRUE && jumpOp != IRVMOp.JMP_IF_FALSE) {
        throw new IllegalArgumentException("unexpected jump op for remap: " + jumpOp.name());
    }
    return jumpOp == IRVMOp.JMP_IF_TRUE
            ? BytecodeEmitter.Operation.jmpIfTrue(targetPc, span)
            : BytecodeEmitter.Operation.jmpIfFalse(targetPc, span);
}
/**
 * Outcome of rewriting a single function.
 *
 * @param function     the resulting function; the original instance when nothing changed
 * @param functionPlan the emission plan entry paired with {@code function}
 * @param changed      {@code true} when the rewrite produced something different from its input
 */
private record FunctionRewrite(
IRVMFunction function,
BytecodeEmitter.FunctionPlan functionPlan,
boolean changed) {
}
/**
 * Pass that removes every instruction that cannot be reached from the first instruction of
 * its function.
 *
 * <p>Reachability is computed per function with a worklist seeded with instruction index 0.
 * {@code HALT} and {@code RET} have no successors, {@code JMP} has only its target, conditional
 * jumps have both their target and the next instruction, and every other instruction falls
 * through to the next one. Jump targets that do not match the pc of any instruction are ignored.
 */
private static final class UnreachableInstructionEliminationPass implements IRVMPass {
    @Override
    public String name() {
        return "UnreachableInstructionEliminationPass";
    }

    @Override
    public IRVMProgram apply(final IRVMProgram input) {
        return transformProgram(input, this::rewriteFunction);
    }

    /**
     * Marks the reachable instructions of one function and removes the rest.
     *
     * @param function     function to analyse
     * @param functionPlan emission plan entry paired with {@code function}
     * @return an unchanged rewrite when the function is empty or fully reachable, otherwise
     *         the result of {@code removeInstructionIndices} for the unreachable indices
     */
    private FunctionRewrite rewriteFunction(
            final IRVMFunction function,
            final BytecodeEmitter.FunctionPlan functionPlan) {
        final var instructions = function.instructions().asList();
        if (instructions.isEmpty()) {
            return new FunctionRewrite(function, functionPlan, false);
        }
        final var pcs = pcByIndex(instructions);
        final var indexByPc = indexByPc(pcs);
        final var reachable = new boolean[instructions.size()];
        final var worklist = new ArrayDeque<Integer>();
        worklist.add(0);
        while (!worklist.isEmpty()) {
            final var index = worklist.removeFirst();
            if (index < 0 || index >= instructions.size() || reachable[index]) {
                continue;
            }
            reachable[index] = true;
            final var instruction = instructions.get(index);
            if (instruction.op() == IRVMOp.HALT || instruction.op() == IRVMOp.RET) {
                // Terminators: nothing after them is reached through this instruction.
                continue;
            }
            if (isJump(instruction.op())) {
                final var targetIndex = indexByPc.get(instruction.immediate());
                if (targetIndex != null) {
                    worklist.add(targetIndex);
                }
                if (instruction.op() == IRVMOp.JMP) {
                    // Unconditional jump never falls through.
                    continue;
                }
            }
            final var fallthroughIndex = index + 1;
            if (fallthroughIndex < instructions.size()) {
                worklist.add(fallthroughIndex);
            }
        }

        final var removeByIndex = new boolean[instructions.size()];
        var removedAny = false;
        for (var i = 0; i < instructions.size(); i++) {
            removeByIndex[i] = !reachable[i];
            if (removeByIndex[i]) {
                removedAny = true;
            }
        }
        if (!removedAny) {
            return new FunctionRewrite(function, functionPlan, false);
        }
        return removeInstructionIndices(function, functionPlan, removeByIndex);
    }
}
/**
 * Pass that retargets jumps whose destination is itself an unconditional jump, so that they
 * point at the end of the jump chain instead. No instruction is added or removed.
 */
private static final class NormalizeRedundantJumpTargetsPass implements IRVMPass {
    @Override
    public String name() {
        return "NormalizeRedundantJumpTargetsPass";
    }

    @Override
    public IRVMProgram apply(final IRVMProgram input) {
        return transformProgram(
                input,
                (function, functionPlan) -> rewriteFunction(function, functionPlan));
    }

    /**
     * Rewrites the jumps of one function in index order. Targets are resolved against the
     * working copy of the instruction list, so earlier rewrites are visible to later ones.
     *
     * @param function     function to rewrite
     * @param functionPlan emission plan entry paired with {@code function}
     * @return an unchanged rewrite when no jump target moved, otherwise new function and plan
     *         objects holding the retargeted jumps
     */
    private FunctionRewrite rewriteFunction(
            final IRVMFunction function,
            final BytecodeEmitter.FunctionPlan functionPlan) {
        final var workingInstructions = new ArrayList<>(function.instructions().asList());
        final var workingOperations = new ArrayList<>(functionPlan.operations().asList());
        final var pcToIndex = indexByPc(pcByIndex(workingInstructions));

        var retargeted = false;
        for (var position = 0; position < workingInstructions.size(); position++) {
            final var jump = workingInstructions.get(position);
            if (isJump(jump.op())) {
                final var finalTarget = normalizeJumpTarget(
                        jump.immediate(),
                        workingInstructions,
                        pcToIndex);
                if (finalTarget != jump.immediate()) {
                    retargeted = true;
                    workingInstructions.set(position, new IRVMInstruction(jump.op(), finalTarget));
                    workingOperations.set(
                            position,
                            remapJumpOperation(workingOperations.get(position), jump.op(), finalTarget));
                }
            }
        }

        if (retargeted) {
            final var rewrittenFunction = new IRVMFunction(
                    function.name(),
                    function.paramSlots(),
                    function.localSlots(),
                    function.returnSlots(),
                    function.maxStackSlots(),
                    ReadOnlyList.wrap(workingInstructions));
            final var rewrittenPlan = new BytecodeEmitter.FunctionPlan(
                    functionPlan.name(),
                    functionPlan.paramSlots(),
                    functionPlan.localSlots(),
                    functionPlan.returnSlots(),
                    functionPlan.maxStackSlots(),
                    ReadOnlyList.wrap(workingOperations));
            return new FunctionRewrite(rewrittenFunction, rewrittenPlan, true);
        }
        return new FunctionRewrite(function, functionPlan, false);
    }

    /**
     * Follows {@code targetPc} through consecutive {@code JMP} instructions.
     *
     * @param targetPc     pc the jump currently points at
     * @param instructions instruction list used to look up what sits at each pc
     * @param indexByPc    lookup from pc to instruction index
     * @return the first pc on the chain that maps to no instruction, maps to a non-{@code JMP}
     *         instruction, or was already visited
     */
    private int normalizeJumpTarget(
            final int targetPc,
            final List<IRVMInstruction> instructions,
            final HashMap<Integer, Integer> indexByPc) {
        final var seen = new HashSet<Integer>();
        var pc = targetPc;
        for (;;) {
            if (!seen.add(pc)) {
                // Revisited pc: the chain loops, stop where we are.
                return pc;
            }
            final var index = indexByPc.get(pc);
            if (index == null) {
                return pc;
            }
            final var landing = instructions.get(index);
            if (landing.op() != IRVMOp.JMP) {
                return pc;
            }
            pc = landing.immediate();
        }
    }
}
/**
 * Pass that deletes unconditional jumps whose target is the pc of the instruction directly
 * after them. The scan is repeated on the rewritten function until no such jump is left.
 */
private static final class SimplifyJumpToNextPcPass implements IRVMPass {
    @Override
    public String name() {
        return "SimplifyJumpToNextPcPass";
    }

    @Override
    public IRVMProgram apply(final IRVMProgram input) {
        return transformProgram(
                input,
                (function, functionPlan) -> rewriteFunction(function, functionPlan));
    }

    /**
     * Repeatedly removes jump-to-next instructions from one function.
     *
     * <p>The last instruction is never a candidate because nothing follows it.
     *
     * @param function     function to simplify
     * @param functionPlan emission plan entry paired with {@code function}
     * @return the final function/plan pair and whether at least one round removed something
     */
    private FunctionRewrite rewriteFunction(
            final IRVMFunction function,
            final BytecodeEmitter.FunctionPlan functionPlan) {
        var workingFunction = function;
        var workingPlan = functionPlan;
        var simplified = false;
        for (;;) {
            final var body = workingFunction.instructions().asList();
            final var count = body.size();
            if (count <= 1) {
                return new FunctionRewrite(workingFunction, workingPlan, simplified);
            }
            final var offsets = pcByIndex(body);
            final var redundant = new boolean[count];
            var foundAny = false;
            for (var i = 0; i + 1 < count; i++) {
                final var candidate = body.get(i);
                final var jumpsToNext =
                        candidate.op() == IRVMOp.JMP && candidate.immediate() == offsets[i + 1];
                if (jumpsToNext) {
                    redundant[i] = true;
                    foundAny = true;
                }
            }
            if (!foundAny) {
                return new FunctionRewrite(workingFunction, workingPlan, simplified);
            }
            final var round = removeInstructionIndices(workingFunction, workingPlan, redundant);
            workingFunction = round.function();
            workingPlan = round.functionPlan();
            simplified = true;
        }
    }
}
}

View File

@ -1,12 +1,14 @@
package p.studio.compiler.backend.irvm;
import org.junit.jupiter.api.Test;
import p.studio.compiler.backend.bytecode.BytecodeEmitter;
import p.studio.utilities.structures.ReadOnlyList;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@ -23,6 +25,7 @@ class OptimizeIRVMServiceTest {
assertEquals(validProgram(), output);
assertEquals(List.of("A", "B"), order);
assertEquals(List.of("A", "B"), service.lastExecutedPassNames());
}
@Test
@ -35,6 +38,7 @@ class OptimizeIRVMServiceTest {
service.optimize(validProgram());
assertEquals(List.of("B"), order);
assertEquals(List.of("B"), service.lastExecutedPassNames());
}
@Test
@ -57,6 +61,72 @@ class OptimizeIRVMServiceTest {
assertTrue(thrown.getMessage().contains("vm profile"));
}
/**
 * With the default pipeline, a HALT placed after a RET must be removed and all three default
 * passes must be reported as executed, in order.
 */
@Test
void optimizeDefaultPassesMustRemoveUnreachableInstructions() {
    final ReadOnlyList<IRVMInstruction> instructions = ReadOnlyList.from(
            new IRVMInstruction(IRVMOp.RET, null),
            new IRVMInstruction(IRVMOp.HALT, null));
    final ReadOnlyList<BytecodeEmitter.Operation> operations = ReadOnlyList.from(
            BytecodeEmitter.Operation.ret(),
            BytecodeEmitter.Operation.halt());
    final var service = new OptimizeIRVMService();
    final var input = singleFunctionProgram(instructions, operations);

    final var optimized = service.optimize(input);

    assertNotEquals(input, optimized);
    final var optimizedInstructions = optimized.module().functions().getFirst().instructions();
    assertEquals(1, optimizedInstructions.size());
    assertEquals(IRVMOp.RET, optimizedInstructions.getFirst().op());
    final var expectedPassNames = List.of(
            "UnreachableInstructionEliminationPass",
            "NormalizeRedundantJumpTargetsPass",
            "SimplifyJumpToNextPcPass");
    assertEquals(expectedPassNames, service.lastExecutedPassNames());
}
/**
 * A jump to pc 6, where pc 6 holds a jump to pc 12, must be retargeted straight to pc 12 in
 * both the instruction list and the emission plan.
 */
@Test
void normalizeRedundantJumpTargetsPassMustCollapseJumpChain() {
    final var passes = List.of(OptimizeIRVMService.normalizeRedundantJumpTargetsPass());
    final var service = new OptimizeIRVMService(new IRVMValidator(), passes);
    final ReadOnlyList<IRVMInstruction> instructions = ReadOnlyList.from(
            new IRVMInstruction(IRVMOp.JMP, 6),
            new IRVMInstruction(IRVMOp.JMP, 12),
            new IRVMInstruction(IRVMOp.RET, null));
    final ReadOnlyList<BytecodeEmitter.Operation> operations = ReadOnlyList.from(
            BytecodeEmitter.Operation.jmp(6, null),
            BytecodeEmitter.Operation.jmp(12, null),
            BytecodeEmitter.Operation.ret());

    final var optimized = service.optimize(singleFunctionProgram(instructions, operations));

    final var firstInstruction = optimized.module().functions().getFirst().instructions().getFirst();
    final var firstOperation = optimized.emissionPlan().functions().getFirst().operations().getFirst();
    assertEquals(IRVMOp.JMP, firstInstruction.op());
    assertEquals(12, firstInstruction.immediate());
    assertEquals(BytecodeEmitter.OperationKind.JMP, firstOperation.kind());
    assertEquals(12, firstOperation.immediate());
}
/**
 * A leading jump to pc 6, where pc 6 is the very next instruction, must be dropped from both
 * the instruction list and the emission plan, leaving only the RET.
 */
@Test
void simplifyJumpToNextPcPassMustRemoveDirectFallthroughJump() {
    final var passes = List.of(OptimizeIRVMService.simplifyJumpToNextPcPass());
    final var service = new OptimizeIRVMService(new IRVMValidator(), passes);
    final ReadOnlyList<IRVMInstruction> instructions = ReadOnlyList.from(
            new IRVMInstruction(IRVMOp.JMP, 6),
            new IRVMInstruction(IRVMOp.RET, null));
    final ReadOnlyList<BytecodeEmitter.Operation> operations = ReadOnlyList.from(
            BytecodeEmitter.Operation.jmp(6, null),
            BytecodeEmitter.Operation.ret());

    final var optimized = service.optimize(singleFunctionProgram(instructions, operations));

    final var optimizedInstructions = optimized.module().functions().getFirst().instructions();
    assertEquals(1, optimizedInstructions.size());
    assertEquals(IRVMOp.RET, optimizedInstructions.getFirst().op());
    final var optimizedOperations = optimized.emissionPlan().functions().getFirst().operations();
    assertEquals(1, optimizedOperations.size());
    assertEquals(BytecodeEmitter.OperationKind.RET, optimizedOperations.getFirst().kind());
}
private OptimizeIRVMService.IRVMPass namedPass(
final String name,
final List<String> order,
@ -81,14 +151,34 @@ class OptimizeIRVMServiceTest {
}
private IRVMProgram validProgram() {
return new IRVMProgram(new IRVMModule(
"core-v1",
ReadOnlyList.from(new IRVMFunction(
"main",
return singleFunctionProgram(
ReadOnlyList.from(new IRVMInstruction(IRVMOp.HALT, null)),
ReadOnlyList.from(BytecodeEmitter.Operation.halt()));
}
private IRVMProgram singleFunctionProgram(
final ReadOnlyList<IRVMInstruction> instructions,
final ReadOnlyList<BytecodeEmitter.Operation> operations) {
return new IRVMProgram(
new IRVMModule(
"core-v1",
ReadOnlyList.from(new IRVMFunction(
"main",
0,
0,
0,
1,
instructions))),
new BytecodeEmitter.EmissionPlan(
0,
0,
0,
1,
ReadOnlyList.from(new IRVMInstruction(IRVMOp.HALT, null))))));
ReadOnlyList.empty(),
ReadOnlyList.empty(),
ReadOnlyList.from(new BytecodeEmitter.FunctionPlan(
"main",
0,
0,
0,
1,
operations))));
}
}