diff --git a/compiler/forget/src/HIR/ValidateConsistentIdentifiers.ts b/compiler/forget/src/HIR/ValidateConsistentIdentifiers.ts index 6444ea7271..a8fde74b2b 100644 --- a/compiler/forget/src/HIR/ValidateConsistentIdentifiers.ts +++ b/compiler/forget/src/HIR/ValidateConsistentIdentifiers.ts @@ -14,7 +14,11 @@ import { SourceLocation, } from "./HIR"; import { printPlace } from "./PrintHIR"; -import { eachInstructionValueOperand, eachTerminalOperand } from "./visitors"; +import { + eachInstructionLValue, + eachInstructionValueOperand, + eachTerminalOperand, +} from "./visitors"; /** * Validation pass to check that there is a 1:1 mapping between Identifier objects and IdentifierIds, @@ -46,7 +50,9 @@ export function validateConsistentIdentifiers(fn: HIRFunction): void { ); } assignments.add(instr.lvalue.identifier.id); - validate(identifiers, instr.lvalue.identifier, instr.lvalue.loc); + for (const operand of eachInstructionLValue(instr)) { + validate(identifiers, operand.identifier, operand.loc); + } for (const operand of eachInstructionValueOperand(instr.value)) { validate(identifiers, operand.identifier, operand.loc); } diff --git a/compiler/forget/src/HIR/visitors.ts b/compiler/forget/src/HIR/visitors.ts index 8aaeb8c996..6155b0a880 100644 --- a/compiler/forget/src/HIR/visitors.ts +++ b/compiler/forget/src/HIR/visitors.ts @@ -7,19 +7,34 @@ import { assertExhaustive } from "../Utils/utils"; import { - BasicBlock, BlockId, - HIR, Instruction, InstructionValue, makeInstructionId, Pattern, Place, - ReactiveScope, - ScopeId, + ReactiveInstruction, Terminal, } from "./HIR"; +export function* eachInstructionLValue( + instr: ReactiveInstruction +): Iterable { + if (instr.lvalue !== null) { + yield instr.lvalue; + } + switch (instr.value.kind) { + case "StoreLocal": { + yield instr.value.lvalue.place; + break; + } + case "Destructure": { + yield* eachPatternOperand(instr.value.lvalue.pattern); + break; + } + } +} + export function* eachInstructionOperand(instr: Instruction): Iterable { yield* eachInstructionValueOperand(instr.value); } @@ -54,12 +69,10 @@ export function* eachInstructionValueOperand( break; } case "StoreLocal": { - yield instrValue.lvalue.place; yield instrValue.value; break; } case "Destructure": { - yield* eachPatternOperand(instrValue.lvalue.pattern); yield instrValue.value; break; } @@ -200,6 +213,26 @@ export function* eachPatternOperand(pattern: Pattern): Iterable { } } +export function mapInstructionLValues( + instr: Instruction, + fn: (place: Place) => Place +): void { + switch (instr.value.kind) { + case "StoreLocal": { + const lvalue = instr.value.lvalue; + lvalue.place = fn(lvalue.place); + break; + } + case "Destructure": { + mapPatternOperands(instr.value.lvalue.pattern, fn); + break; + } + } + if (instr.lvalue !== null) { + instr.lvalue = fn(instr.lvalue); + } +} + export function mapInstructionOperands( instr: Instruction, fn: (place: Place) => Place @@ -671,32 +704,3 @@ export function* eachTerminalOperand(terminal: Terminal): Iterable { } } } - -/** - * Iterates over all {@link Place}s within a {@link BasicBlock}. - */ -export function* eachBlockOperand(block: BasicBlock): Iterable { - for (const instr of block.instructions) { - yield* eachInstructionOperand(instr); - if (instr.lvalue != null) { - yield instr.lvalue; - } - } - yield* eachTerminalOperand(block.terminal); -} - -export function* eachReactiveScope(ir: HIR): Iterable { - const seenScopes: Set = new Set(); - for (const [, block] of ir.blocks) { - for (const operand of eachBlockOperand(block)) { - const scope = operand.identifier.scope; - if (scope != null) { - if (seenScopes.has(scope.id)) { - continue; - } - seenScopes.add(scope.id); - yield scope; - } - } - } -} diff --git a/compiler/forget/src/Inference/InferAliasForStores.ts b/compiler/forget/src/Inference/InferAliasForStores.ts index cb59236c37..1a3191ac15 100644 --- a/compiler/forget/src/Inference/InferAliasForStores.ts +++ b/compiler/forget/src/Inference/InferAliasForStores.ts @@ -12,8 +12,8 @@ import { Place, } from "../HIR/HIR"; import { + eachInstructionLValue, eachInstructionValueOperand, - eachPatternOperand, } from "../HIR/visitors"; import DisjointSet from "../Utils/DisjointSet"; @@ -27,12 +27,8 @@ export function inferAliasForStores( if (lvalue.effect !== Effect.Store) { continue; } - if (value.kind === "StoreLocal") { - maybeAlias(aliases, value.lvalue.place, value.value, instr.id); - } else if (value.kind === "Destructure") { - for (const place of eachPatternOperand(value.lvalue.pattern)) { - maybeAlias(aliases, place, value.value, instr.id); - } + for (const operand of eachInstructionLValue(instr)) { + maybeAlias(aliases, lvalue, operand, instr.id); } for (const operand of eachInstructionValueOperand(value)) { if ( diff --git a/compiler/forget/src/Inference/InferMutableLifetimes.ts b/compiler/forget/src/Inference/InferMutableLifetimes.ts index 00e2d53e44..8f222815ea 100644 --- a/compiler/forget/src/Inference/InferMutableLifetimes.ts +++ b/compiler/forget/src/Inference/InferMutableLifetimes.ts @@ -14,7 +14,7 @@ import { Place, } from "../HIR/HIR"; import { printInstruction, printPlace } from "../HIR/PrintHIR"; -import { eachInstructionOperand, eachPatternOperand } from "../HIR/visitors"; +import { eachInstructionLValue, eachInstructionOperand } from "../HIR/visitors"; import { assertExhaustive } from "../Utils/utils"; /** @@ -118,32 +118,20 @@ export function inferMutableLifetimes( } for (const instr of block.instructions) { - if (instr.value.kind === "StoreLocal") { - inferPlace(instr.value.value, instr, inferMutableRangeForStores); - instr.value.lvalue.place.identifier.mutableRange.start = instr.id; - instr.value.lvalue.place.identifier.mutableRange.end = - makeInstructionId(instr.id + 1); - } else if (instr.value.kind === "Destructure") { - inferPlace(instr.value.value, instr, inferMutableRangeForStores); - for (const place of eachPatternOperand(instr.value.lvalue.pattern)) { - place.identifier.mutableRange.start = instr.id; - place.identifier.mutableRange.end = makeInstructionId(instr.id + 1); - } - } else { - for (const input of eachInstructionOperand(instr)) { - inferPlace(input, instr, inferMutableRangeForStores); - } + for (const operand of eachInstructionLValue(instr)) { + const lvalueId = operand.identifier; + + // lvalue start being mutable when they're initially assigned a + // value. + lvalueId.mutableRange.start = instr.id; + + // Let's be optimistic and assume this lvalue is not mutable by + // default. + lvalueId.mutableRange.end = makeInstructionId(instr.id + 1); + } + for (const operand of eachInstructionOperand(instr)) { + inferPlace(operand, instr, inferMutableRangeForStores); } - - const lvalueId = instr.lvalue.identifier; - - // lvalue start being mutable when they're initially assigned a - // value. - lvalueId.mutableRange.start = instr.id; - - // Let's be optimistic and assume this lvalue is not mutable by - // default. - lvalueId.mutableRange.end = makeInstructionId(instr.id + 1); } } } diff --git a/compiler/forget/src/ReactiveScopes/BuildReactiveBlocks.ts b/compiler/forget/src/ReactiveScopes/BuildReactiveBlocks.ts index 07e3c6d41a..91dca360eb 100644 --- a/compiler/forget/src/ReactiveScopes/BuildReactiveBlocks.ts +++ b/compiler/forget/src/ReactiveScopes/BuildReactiveBlocks.ts @@ -18,6 +18,7 @@ import { ReactiveStatement, ScopeId, } from "../HIR"; +import { eachInstructionLValue } from "../HIR/visitors"; import { assertExhaustive } from "../Utils/utils"; import { eachReactiveValueOperand, mapTerminalBlocks } from "./visitors"; @@ -178,22 +179,22 @@ function visitBlock(context: Context, block: ReactiveBlock): void { } } -export function getInstructionScope({ - id, - lvalue, - value, -}: ReactiveInstruction): ReactiveScope | null { +export function getInstructionScope( + instr: ReactiveInstruction +): ReactiveScope | null { invariant( - lvalue !== null, + instr.lvalue !== null, "Expected lvalues to not be null when assigning scopes. " + "Pruning lvalues too early can result in missing scope information." ); - const lvalueScope = getPlaceScope(id, lvalue); - if (lvalueScope !== null) { - return lvalueScope; + for (const operand of eachInstructionLValue(instr)) { + const operandScope = getPlaceScope(instr.id, operand); + if (operandScope !== null) { + return operandScope; + } } - for (const operand of eachReactiveValueOperand(value)) { - const operandScope = getPlaceScope(id, operand); + for (const operand of eachReactiveValueOperand(instr.value)) { + const operandScope = getPlaceScope(instr.id, operand); if (operandScope !== null) { return operandScope; } diff --git a/compiler/forget/src/ReactiveScopes/InferReactiveIdentifiers.ts b/compiler/forget/src/ReactiveScopes/InferReactiveIdentifiers.ts index d2f9a49401..3f1e020e9e 100644 --- a/compiler/forget/src/ReactiveScopes/InferReactiveIdentifiers.ts +++ b/compiler/forget/src/ReactiveScopes/InferReactiveIdentifiers.ts @@ -13,6 +13,7 @@ import { ReactiveFunction, ReactiveInstruction, } from "../HIR/HIR"; +import { eachInstructionLValue } from "../HIR/visitors"; import { assertExhaustive } from "../Utils/utils"; import { eachReactiveValueOperand, @@ -66,6 +67,9 @@ class Visitor extends ReactiveFunctionVisitor { state.reactivityMap.set(lval.identifier.id, hasReactiveInput); if (hasReactiveInput) { + for (const lvalue of eachInstructionLValue(instr)) { + state.reactivityMap.set(lvalue.identifier.id, true); + } // all mutating effects must also be marked as reactive for (const operand of eachReactiveValueOperand(value)) { switch (operand.effect) { diff --git a/compiler/forget/src/ReactiveScopes/visitors.ts b/compiler/forget/src/ReactiveScopes/visitors.ts index 37933a6372..73bc814667 100644 --- a/compiler/forget/src/ReactiveScopes/visitors.ts +++ b/compiler/forget/src/ReactiveScopes/visitors.ts @@ -17,7 +17,10 @@ import { ReactiveTerminalStatement, ReactiveValue, } from "../HIR/HIR"; -import { eachInstructionValueOperand } from "../HIR/visitors"; +import { + eachInstructionLValue, + eachInstructionValueOperand, +} from "../HIR/visitors"; import { assertExhaustive } from "../Utils/utils"; export function visitReactiveFunction( @@ -69,8 +72,8 @@ export class ReactiveFunctionVisitor { } traverseInstruction(instruction: ReactiveInstruction, state: TState): void { this.visitID(instruction.id, state); - if (instruction.lvalue !== null) { - this.visitLValue(instruction.id, instruction.lvalue, state); + for (const operand of eachInstructionLValue(instruction)) { + this.visitLValue(instruction.id, operand, state); } this.visitValue(instruction.id, instruction.value, state); } diff --git a/compiler/forget/src/SSA/EliminateRedundantPhi.ts b/compiler/forget/src/SSA/EliminateRedundantPhi.ts index 773ca6aae7..cff7d3e7a6 100644 --- a/compiler/forget/src/SSA/EliminateRedundantPhi.ts +++ b/compiler/forget/src/SSA/EliminateRedundantPhi.ts @@ -7,7 +7,11 @@ import invariant from "invariant"; import { BlockId, HIRFunction, Identifier, Place } from "../HIR/HIR"; -import { eachInstructionOperand, eachTerminalOperand } from "../HIR/visitors"; +import { + eachInstructionLValue, + eachInstructionOperand, + eachTerminalOperand, +} from "../HIR/visitors"; /** * Pass to eliminate redundant phi nodes: @@ -79,6 +83,9 @@ export function eliminateRedundantPhi(fn: HIRFunction) { // Rewrite all instruction lvalues and operands for (const instr of block.instructions) { + for (const place of eachInstructionLValue(instr)) { + rewritePlace(place, rewrites); + } for (const place of eachInstructionOperand(instr)) { rewritePlace(place, rewrites); } diff --git a/compiler/forget/src/SSA/EnterSSA.ts b/compiler/forget/src/SSA/EnterSSA.ts index f2c25eef87..81ac35baef 100644 --- a/compiler/forget/src/SSA/EnterSSA.ts +++ b/compiler/forget/src/SSA/EnterSSA.ts @@ -15,8 +15,8 @@ import { import { printIdentifier } from "../HIR/PrintHIR"; import { eachTerminalSuccessor, + mapInstructionLValues, mapInstructionOperands, - mapPatternOperands, mapTerminalOperands, } from "../HIR/visitors"; @@ -224,24 +224,8 @@ export default function enterSSA(func: HIRFunction): void { } for (const instr of block.instructions) { - if (instr.value.kind === "StoreLocal") { - const oldPlace = instr.value.lvalue.place; - const newPlace = builder.definePlace(oldPlace); - instr.value.lvalue.place = newPlace; - - instr.value.value = builder.getPlace(instr.value.value); - } else if (instr.value.kind === "Destructure") { - mapPatternOperands(instr.value.lvalue.pattern, (place) => - builder.definePlace(place) - ); - instr.value.value = builder.getPlace(instr.value.value); - } else { - mapInstructionOperands(instr, (place) => builder.getPlace(place)); - } - - const oldPlace = instr.lvalue; - const newPlace = builder.definePlace(oldPlace); - instr.lvalue = newPlace; + mapInstructionLValues(instr, (lvalue) => builder.definePlace(lvalue)); + mapInstructionOperands(instr, (place) => builder.getPlace(place)); } mapTerminalOperands(block.terminal, (place) => builder.getPlace(place)); diff --git a/compiler/forget/src/SSA/LeaveSSA.ts b/compiler/forget/src/SSA/LeaveSSA.ts index d75a8d6ac7..1acd79c4bf 100644 --- a/compiler/forget/src/SSA/LeaveSSA.ts +++ b/compiler/forget/src/SSA/LeaveSSA.ts @@ -24,6 +24,7 @@ import { } from "../HIR/HIR"; import { printPlace } from "../HIR/PrintHIR"; import { + eachInstructionLValue, eachInstructionValueOperand, eachPatternOperand, eachTerminalOperand, @@ -188,6 +189,9 @@ export function leaveSSA(fn: HIRFunction): void { value.lvalue.kind = kind; } rewritePlace(lvalue, rewrites, declarations); + for (const operand of eachInstructionLValue(instr)) { + rewritePlace(operand, rewrites, declarations); + } for (const operand of eachInstructionValueOperand(instr.value)) { rewritePlace(operand, rewrites, declarations); } diff --git a/compiler/forget/src/TypeInference/InferTypes.ts b/compiler/forget/src/TypeInference/InferTypes.ts index 6e19bcc5a3..c33e9ec1ed 100644 --- a/compiler/forget/src/TypeInference/InferTypes.ts +++ b/compiler/forget/src/TypeInference/InferTypes.ts @@ -9,7 +9,7 @@ import { TypeId, TypeVar, } from "../HIR/HIR"; -import { eachInstructionOperand } from "../HIR/visitors"; +import { eachInstructionLValue, eachInstructionOperand } from "../HIR/visitors"; function isPrimitiveBinaryOp(op: t.BinaryExpression["operator"]) { switch (op) { @@ -50,6 +50,9 @@ function apply(func: HIRFunction, unifier: Unifier) { phi.type = unifier.get(phi.type); } for (const instr of block.instructions) { + for (const operand of eachInstructionLValue(instr)) { + operand.identifier.type = unifier.get(operand.identifier.type); + } for (const place of eachInstructionOperand(instr)) { place.identifier.type = unifier.get(place.identifier.type); }