Constant propagation/folding

Implements constant propagation/constant folding for a conservative subset of 
the language. The approach is described in detail in the comments in the file 
itself, a key note here is that this pass currently emits what looks like 
garbage: 

``` 

// input 

const x = 1; 

const y = x + 1; 

// output 

const x = 1; 

2; // <---- you'll see a bunch of lines like this 

const y = 2; 

``` 

These useless lines occur where previously there was a temporary getting 
calculated that was used later (so we saved it until it was used), but now it 
isn't used later so we just emit it in-place. Dead code elimination (DCE) can 
eliminate these and other useless statements later. 

Note that a key motivation for implementing this pass is to reduce memoization 
blocks to what is strictly required for dynamic computations. Why memoize at 
runtime when we compute at build time?
This commit is contained in:
Joe Savona
2023-01-09 15:30:20 -08:00
parent 6f9c9cf9ea
commit dce5371763
30 changed files with 570 additions and 243 deletions
+4
View File
@@ -14,6 +14,7 @@ import {
ReactiveFunction,
} from "./HIR";
import { inferMutableRanges, inferReferenceEffects } from "./Inference";
import { constantPropagation } from "./Optimization";
import {
buildReactiveFunction,
codegenReactiveFunction,
@@ -53,6 +54,9 @@ export default function (
eliminateRedundantPhi(ir);
logHIRFunction("eliminateRedundantPhi", ir);
constantPropagation(ir);
logHIRFunction("constantPropagation", ir);
inferTypes(ir);
logHIRFunction("inferTypes", ir);
+5 -5
View File
@@ -360,7 +360,7 @@ export default class HIRBuilder {
/**
* Helper to shrink a CFG eliminate jump-only blocks.
*/
function shrink(func: HIR): void {
export function shrink(func: HIR): void {
const gotos = new Map();
/**
* Given a target block for some terminator, resolves the ideal block that should be
@@ -408,7 +408,7 @@ function shrink(func: HIR): void {
}
}
function removeUnreachableFallthroughs(func: HIR): void {
export function removeUnreachableFallthroughs(func: HIR): void {
const visited: Set<BlockId> = new Set();
for (const [_, block] of func.blocks) {
visited.add(block.id);
@@ -434,7 +434,7 @@ function removeUnreachableFallthroughs(func: HIR): void {
* Converts the graph to reverse-postorder, with predecessor blocks appearing
* before successors except in the case of back links (ie loops).
*/
function reversePostorderBlocks(func: HIR): void {
export function reversePostorderBlocks(func: HIR): void {
const visited: Set<BlockId> = new Set();
const postorder: Array<BlockId> = [];
function visit(blockId: BlockId) {
@@ -521,7 +521,7 @@ function reversePostorderBlocks(func: HIR): void {
func.blocks = blocks;
}
function markInstructionIds(func: HIR) {
export function markInstructionIds(func: HIR) {
let id = 0;
const visited = new Set<Instruction>();
for (const [_, block] of func.blocks) {
@@ -537,7 +537,7 @@ function markInstructionIds(func: HIR) {
}
}
function markPredecessors(func: HIR) {
export function markPredecessors(func: HIR) {
for (const [, block] of func.blocks) {
block.preds.clear();
}
+9 -2
View File
@@ -6,7 +6,14 @@
*/
export { lower } from "./BuildHIR";
export { HIRFunction, ReactiveFunction } from "./HIR";
export { Environment } from "./HIRBuilder";
export * from "./HIR";
export {
Environment,
markInstructionIds,
markPredecessors,
removeUnreachableFallthroughs,
reversePostorderBlocks,
shrink,
} from "./HIRBuilder";
export { mergeConsecutiveBlocks } from "./MergeConsecutiveBlocks";
export { printFunction } from "./PrintHIR";
@@ -0,0 +1,270 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
import {
BlockId,
GotoVariant,
HIRFunction,
IdentifierId,
InstructionValue,
markInstructionIds,
markPredecessors,
mergeConsecutiveBlocks,
Place,
Primitive,
removeUnreachableFallthroughs,
reversePostorderBlocks,
shrink,
} from "../HIR";
import { eliminateRedundantPhi } from "../SSA";
/**
* Applies constant propagation and constant folding to the given function.
* Note that because HIR operands are always a Place, constants cannot be directly
* propagated into the HIR itself (the closest option would be to copy constants to
* new temporaries just before each use, and update usage sites to reference those
* new temporaries).
*
* Instead this pass implements constant folding, in which constant values are
* propagated internally to the pass and subsequent operations are removed/folded where
* possible.
*
* Note that this pass may prune control flow blocks that are unreachable, for example
* a consequent or alternate branch if an `if` test is provably truthy or falsey.
* If (and only if) terminals change, the pass re-runs various stages to ensure the
* CFG is in minimal form. This means instruction ids *may* change as a result of this
* pass.
*/
export function constantPropagation(fn: HIRFunction): void {
const haveTerminalsChanged = applyConstantPropagation(fn);
if (haveTerminalsChanged) {
// If terminals have changed then blocks may have become newly unreachable.
// Re-run minification of the graph (incl reordering instruction ids)
shrink(fn.body);
reversePostorderBlocks(fn.body);
removeUnreachableFallthroughs(fn.body);
markInstructionIds(fn.body);
markPredecessors(fn.body);
// Now that predecessors are updated, prune phi operands that can never be reached
for (const [, block] of fn.body.blocks) {
for (const phi of block.phis) {
for (const [predecessor] of phi.operands) {
if (!block.preds.has(predecessor)) {
phi.operands.delete(predecessor);
}
}
}
}
// By removing some phi operands, there may be phis that were not previously
// redundant but now are
eliminateRedundantPhi(fn);
// Finally, merge together any blocks that are now guaranteed to execute
// consecutively
mergeConsecutiveBlocks(fn);
}
}
function applyConstantPropagation(fn: HIRFunction): boolean {
let hasChanges = false;
// A set of blocks whose terminals can't (yet) be safely rewritten
const valueBlocks = new Set<BlockId>();
const constants: Constants = new Map();
for (const [, block] of fn.body.blocks) {
// Initialize phi values if all operands have the same known constant value.
// Note that this analysis uses a single-pass only, so it will never fill in
// phi values for blocks that have a back-edge.
for (const phi of block.phis) {
let value: Primitive | null = null;
for (const [, operand] of phi.operands) {
const operandValue = constants.get(operand.id) ?? null;
if (operandValue === null) {
value = null;
break;
}
if (value === null) {
value = operandValue;
} else if (operandValue.value !== value.value) {
value = null;
break;
}
}
if (value !== null) {
constants.set(phi.id.id, value);
}
}
for (const instr of block.instructions) {
const value = evaluateInstruction(constants, instr.value);
if (value !== null) {
instr.value = value;
constants.set(instr.lvalue.place.identifier.id, value);
}
}
if (valueBlocks.has(block.id)) {
// can't rewrite terminals in value blocks yet
continue;
}
const terminal = block.terminal;
switch (terminal.kind) {
case "if": {
const testValue = read(constants, terminal.test);
if (testValue !== null && testValue.kind === "Primitive") {
hasChanges = true;
const targetBlockId = Boolean(testValue.value)
? terminal.consequent
: terminal.alternate;
block.terminal = {
kind: "goto",
variant: GotoVariant.Break,
block: targetBlockId,
id: terminal.id,
};
}
break;
}
case "while": {
valueBlocks.add(terminal.test);
break;
}
case "for": {
valueBlocks.add(terminal.init);
valueBlocks.add(terminal.test);
valueBlocks.add(terminal.update);
break;
}
default: {
// no-op
}
}
}
return hasChanges;
}
function evaluateInstruction(
constants: Constants,
instr: InstructionValue
): Constant | null {
switch (instr.kind) {
case "Primitive": {
return instr;
}
case "BinaryExpression": {
const lhsValue = read(constants, instr.left);
const rhsValue = read(constants, instr.right);
if (lhsValue !== null && rhsValue !== null) {
const lhs = lhsValue.value;
const rhs = rhsValue.value;
switch (instr.operator) {
case "+": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs + rhs, loc: instr.loc };
}
return null;
}
case "-": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs - rhs, loc: instr.loc };
}
return null;
}
case "*": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs * rhs, loc: instr.loc };
}
return null;
}
case "/": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs / rhs, loc: instr.loc };
}
return null;
}
case "<": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs < rhs, loc: instr.loc };
}
return null;
}
case "<=": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs <= rhs, loc: instr.loc };
}
return null;
}
case ">": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs > rhs, loc: instr.loc };
}
return null;
}
case ">=": {
if (typeof lhs === "number" && typeof rhs === "number") {
return { kind: "Primitive", value: lhs >= rhs, loc: instr.loc };
}
return null;
}
case "==": {
return { kind: "Primitive", value: lhs == rhs, loc: instr.loc };
}
case "===": {
return { kind: "Primitive", value: lhs === rhs, loc: instr.loc };
}
case "!=": {
return { kind: "Primitive", value: lhs != rhs, loc: instr.loc };
}
case "!==": {
return { kind: "Primitive", value: lhs !== rhs, loc: instr.loc };
}
default: {
// TODO: handle more cases
return null;
}
}
}
return null;
}
case "PropertyLoad": {
const objectValue = read(constants, instr.object);
if (objectValue !== null) {
if (
typeof objectValue.value === "string" &&
instr.property === "length"
) {
return {
kind: "Primitive",
value: objectValue.value.length,
loc: instr.loc,
};
}
}
return null;
}
case "Identifier": {
return read(constants, instr);
}
default: {
// TODO: handle more cases
return null;
}
}
}
/**
* Recursively read the value of a place: if it is a constant place, attempt to read
* from that place until reaching a primitive or finding a value that is unset.
*/
function read(constants: Constants, place: Place): Constant | null {
return constants.get(place.identifier.id) ?? null;
}
type Constant = Primitive;
type Constants = Map<IdentifierId, Constant>;
@@ -0,0 +1,8 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
export { constantPropagation } from "./ConstantPropagation";
@@ -17,7 +17,8 @@ function f() {
function f() {
const x = 1;
const x$0 = 2;
return x$0 + x$0 + x$0;
4;
return 6;
}
```
@@ -16,8 +16,10 @@ function f() {
```javascript
function f() {
const x = 1;
const x$0 = x + 1;
const x$1 = x$0 + 1;
1;
const x$0 = 2;
1;
const x$1 = 3;
const x$2 = x$1 >>> 1;
}
@@ -0,0 +1,37 @@
## Input
```javascript
function foo() {
let y = 0;
for (const x = 100; x < 10; x) {
y = y + 1;
}
return y;
}
```
## Code
```javascript
function foo() {
const $ = React.useMemoCache();
let y;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
y = 0;
for (const x = 100; 10, false; 100) {
y = y + 1;
}
$[0] = y;
} else {
y = $[0];
}
return y;
}
```
@@ -0,0 +1,7 @@
function foo() {
let y = 0;
for (const x = 100; x < 10; x) {
y = y + 1;
}
return y;
}
@@ -0,0 +1,45 @@
## Input
```javascript
function foo(a, b, c) {
let x;
if (a) {
x = 2 - 1;
} else {
x = 0 + 1;
}
if (x === 1) {
return b;
} else {
return c;
}
}
```
## Code
```javascript
function foo(a, b, c) {
const x = undefined;
let x$0 = undefined;
if (a) {
2;
1;
const x$1 = 1;
x$0 = x$1;
} else {
0;
1;
const x$2 = 1;
x$0 = x$2;
}
1;
true;
return b;
}
```
@@ -0,0 +1,13 @@
function foo(a, b, c) {
let x;
if (a) {
x = 2 - 1;
} else {
x = 0 + 1;
}
if (x === 1) {
return b;
} else {
return c;
}
}
@@ -0,0 +1,39 @@
## Input
```javascript
function foo() {
let x = 100;
let y = 0;
while (x < 10) {
y += 1;
}
return y;
}
```
## Code
```javascript
function foo() {
const $ = React.useMemoCache();
const x = 100;
let y;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
y = 0;
while ((10, false)) {
y = y + 1;
}
$[0] = y;
} else {
y = $[0];
}
return y;
}
```
@@ -0,0 +1,8 @@
function foo() {
let x = 100;
let y = 0;
while (x < 10) {
y += 1;
}
return y;
}
@@ -0,0 +1,45 @@
## Input
```javascript
function foo() {
const a = 1;
const b = 2;
const c = 3;
const d = a + b;
const e = d * c;
const f = e / d;
const g = f - e;
if (g) {
console.log("foo");
}
const h = g;
const i = h;
const j = i;
return j;
}
```
## Code
```javascript
function foo() {
const a = 1;
const b = 2;
const c = 3;
const d = 3;
const e = 9;
const f = 3;
const g = -6;
console.log("foo");
const h = -6;
const i = -6;
const j = -6;
return j;
}
```
@@ -0,0 +1,18 @@
function foo() {
const a = 1;
const b = 2;
const c = 3;
const d = a + b;
const e = d * c;
const f = e / d;
const g = f - e;
if (g) {
console.log("foo");
}
const h = g;
const i = h;
const j = i;
return j;
}
@@ -14,7 +14,9 @@ function foo(a, b, c) {
```javascript
function foo(a, b, c) {
a[b] = c[b];
a[1 + 2] = c[b * 4];
1;
2;
a[3] = c[b * 4];
}
```
@@ -21,39 +21,14 @@ function foo() {
```javascript
function foo() {
const $ = React.useMemoCache();
const x = 1;
const y = 2;
let x$0;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
x$0 = x;
if (y === 2) {
const x$1 = 3;
x$0 = x$1;
}
$[0] = x$0;
} else {
x$0 = $[0];
}
let x$2;
if ($[1] === Symbol.for("react.memo_cache_sentinel")) {
x$2 = x$0;
if (y === 3) {
const x$3 = 5;
x$2 = x$3;
}
$[1] = x$2;
} else {
x$2 = $[1];
}
const y$4 = x$2;
2;
true;
const x$0 = 3;
3;
false;
const y$1 = x$0;
}
```
@@ -18,24 +18,12 @@ function foo() {
```javascript
function foo() {
const $ = React.useMemoCache();
const x = 1;
const y = 2;
let x$0;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
x$0 = x;
if (y === 2) {
const x$1 = 3;
x$0 = x$1;
}
$[0] = x$0;
} else {
x$0 = $[0];
}
const y$2 = x$0;
2;
true;
const x$0 = 3;
const y$1 = x$0;
}
```
@@ -21,7 +21,7 @@ function foo() {
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
x = 1;
for (const i = 0; i < 10; i) {
for (const i = 0; 10, true; 0) {
x = x + 1;
}
@@ -21,11 +21,7 @@ function foo() {
function foo() {
const x = 1;
const y = 2;
if (y) {
const z = x + y;
} else {
const z = x;
}
const z = 3;
}
```
@@ -28,49 +28,11 @@ function foo(a, b, c, d) {
```javascript
function foo(a, b, c, d) {
const $ = React.useMemoCache();
const x = 0;
const c_0 = $[0] !== a;
const c_1 = $[1] !== b;
const c_2 = $[2] !== c;
const c_3 = $[3] !== d;
let x$0;
if (c_0 || c_1 || c_2 || c_3) {
x$0 = undefined;
if (true) {
if (true) {
const x$1 = a;
x$0 = x$1;
} else {
const x$2 = b;
x$0 = x$2;
}
x$3;
x$0 = x$3;
} else {
if (true) {
const x$4 = c;
x$0 = x$4;
} else {
const x$5 = d;
x$0 = x$5;
}
x$6;
x$0 = x$6;
}
$[0] = a;
$[1] = b;
$[2] = c;
$[3] = d;
$[4] = x$0;
} else {
x$0 = $[4];
}
true;
true;
const x$0 = a;
x$0;
return x$0;
}
@@ -25,7 +25,8 @@ function foo(a, b, c) {
while (a) {
while (b) {
while (c) {
x + 1;
1;
1;
}
}
}
@@ -25,38 +25,18 @@ function foo() {
const $ = React.useMemoCache();
const x = 1;
const y = 2;
let x$0;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
x$0 = x;
let y$1 = y;
if (x > 1) {
const x$2 = 2;
x$0 = x$2;
} else {
const y$3 = 3;
y$1 = y$3;
}
$[0] = x$0;
} else {
x$0 = $[0];
}
const c_1 = $[1] !== x$0;
const c_2 = $[2] !== y$1;
1;
false;
const y$0 = 3;
let t;
if (c_1 || c_2) {
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
t = {
x: x$0,
y: y$1,
x: x,
y: y$0,
};
$[1] = x$0;
$[2] = y$1;
$[3] = t;
$[0] = t;
} else {
t = $[3];
t = $[0];
}
return t;
@@ -17,22 +17,10 @@ function foo() {
```javascript
function foo() {
const $ = React.useMemoCache();
const x = 1;
let x$0;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
x$0 = x;
if (x === 1) {
const x$1 = 2;
x$0 = x$1;
}
$[0] = x$0;
} else {
x$0 = $[0];
}
1;
true;
const x$0 = 2;
return x$0;
}
@@ -28,57 +28,11 @@ function foo(a, b, c, d) {
```javascript
function foo(a, b, c, d) {
const $ = React.useMemoCache();
const x = 0;
if (true) {
const c_0 = $[0] !== a;
const c_1 = $[1] !== b;
let x$0;
if (c_0 || c_1) {
x$0 = undefined;
if (true) {
const x$1 = a;
x$0 = x$1;
} else {
const x$2 = b;
x$0 = x$2;
}
$[0] = a;
$[1] = b;
$[2] = x$0;
} else {
x$0 = $[2];
}
x$0;
} else {
const c_3 = $[3] !== c;
const c_4 = $[4] !== d;
let x$3;
if (c_3 || c_4) {
x$3 = undefined;
if (true) {
const x$4 = c;
x$3 = x$4;
} else {
const x$5 = d;
x$3 = x$5;
}
$[3] = c;
$[4] = d;
$[5] = x$3;
} else {
x$3 = $[5];
}
x$3;
}
true;
true;
const x$0 = a;
x$0;
}
```
@@ -20,25 +20,10 @@ function foo() {
```javascript
function foo() {
const $ = React.useMemoCache();
const y = 2;
let y$0;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
y$0 = undefined;
if (y > 1) {
const y$1 = 1;
y$0 = y$1;
} else {
const y$2 = 2;
y$0 = y$2;
}
$[0] = y$0;
} else {
y$0 = $[0];
}
1;
true;
const y$0 = 1;
const x = y$0;
}
@@ -19,9 +19,7 @@ function foo() {
function foo() {
const x = 1;
const y = 2;
if (y) {
const z = x + y;
}
const z = 3;
}
```
@@ -30,25 +30,30 @@ function foo() {
function foo() {
const $ = React.useMemoCache();
const x = 1;
2;
1;
let x$0;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
x$0 = undefined;
bb1: switch (x) {
case x === 1: {
const x$1 = x + 1;
case true: {
1;
const x$1 = 2;
x$0 = x$1;
break bb1;
}
case x === 2: {
const x$2 = x + 2;
case false: {
2;
const x$2 = 3;
x$0 = x$2;
break bb1;
}
default: {
const x$3 = x + 3;
3;
const x$3 = 4;
x$0 = x$3;
}
}
@@ -16,22 +16,10 @@ function foo() {
```javascript
function foo() {
const $ = React.useMemoCache();
const x = 1;
let x$0;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
x$0 = x;
if (x === 1) {
const x$1 = 2;
x$0 = x$1;
}
$[0] = x$0;
} else {
x$0 = $[0];
}
1;
true;
const x$0 = 2;
throw x$0;
}
@@ -18,8 +18,9 @@ function foo() {
```javascript
function foo() {
const x = 1;
while (x < 10) {
x + 1;
while ((10, true)) {
1;
2;
}
return x;