Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
import com.ibm.wala.util.intset.IntSet;
import com.ibm.wala.util.intset.OrdinalSet;
import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -751,20 +750,16 @@ private Map<PointsToSetVariable, TensorType> getShapeSourceCalls(
op,
builder,
(CGNode src, SSAAbstractInvokeInstruction call) -> {
try {
if (call.getNumberOfUses() > param)
targets.put(
builder
.getPropagationSystem()
.findOrCreatePointsToSet(
builder
.getPointerAnalysis()
.getHeapModel()
.getPointerKeyForLocal(src, call.getDef())),
TensorType.shapeArg(src, call.getUse(param)));
} catch (IOException e) {
throw new RuntimeException("Error while processing shape source call: " + call, e);
}
if (call.getNumberOfUses() > param)
targets.put(
builder
.getPropagationSystem()
.findOrCreatePointsToSet(
builder
.getPointerAnalysis()
.getHeapModel()
.getPointerKeyForLocal(src, call.getDef())),
TensorType.shapeArg(src, call.getUse(param)));
});
return targets;
}
Expand Down Expand Up @@ -856,13 +851,9 @@ private Map<PointsToSetVariable, TensorType> getSetShapeCallsSyntactic(
}
}
if (!receiverEligible) continue;
try {
targets.put(
builder.getPropagationSystem().findOrCreatePointsToSet(receiverKey),
TensorType.shapeArg(caller, call.getUse(1)));
} catch (IOException e) {
throw new RuntimeException("Error while processing set_shape call: " + call, e);
}
targets.put(
builder.getPropagationSystem().findOrCreatePointsToSet(receiverKey),
TensorType.shapeArg(caller, call.getUse(1)));
}
}
return targets;
Expand Down Expand Up @@ -971,12 +962,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)

for (PointsToSetVariable v : sources) init.put(v, getTensorTypes(v, builder));

Map<PointsToSetVariable, TensorType> placeholders = null;
try {
placeholders = handleShapeSourceOp(builder, dataflow, placeholder, 2);
} catch (IOException e) {
throw new RuntimeException("Error while processing placeholder calls.", e);
}
Map<PointsToSetVariable, TensorType> placeholders =
handleShapeSourceOp(builder, dataflow, placeholder, 2);
LOGGER.fine("Placeholders: " + placeholders);

for (Map.Entry<PointsToSetVariable, TensorType> e : placeholders.entrySet())
Expand Down Expand Up @@ -1015,12 +1002,7 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
}

Map<PointsToSetVariable, TensorType> shapeOps = HashMapFactory.make();

try {
shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2));
} catch (IOException e) {
throw new RuntimeException("Error while processing reshape calls.", e);
}
shapeOps.putAll(handleShapeSourceOp(builder, dataflow, reshape, 2));

handlePassThroughOp(builder, dataflow, convert_to_tensor, 1);

Expand Down Expand Up @@ -1098,8 +1080,7 @@ private Map<PointsToSetVariable, TensorType> handleShapeSourceOp(
PropagationCallGraphBuilder builder,
Graph<PointsToSetVariable> dataflow,
MethodReference op,
int shapeSrcOperand)
throws IOException {
int shapeSrcOperand) {
Map<PointsToSetVariable, TensorType> reshapeTypes =
getShapeSourceCalls(op, builder, shapeSrcOperand);
for (PointsToSetVariable to : reshapeTypes.keySet()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@
return new TensorType(FLOAT32.name().toLowerCase(Locale.ROOT), Arrays.asList(batch, vec));
}

public static TensorType shapeArg(CGNode node, int literalVn) throws IOException {
public static TensorType shapeArg(CGNode node, int literalVn) {
logger.fine(() -> node.getIR().toString());
Map<Integer, Dimension<?>> dims = new TreeMap<>();
DefUse du = node.getDU();
Expand Down Expand Up @@ -632,14 +632,19 @@
((AstMethod) node.getMethod())
.debugInfo()
.getInstructionPosition(du.getDef(val).iIndex());
System.err.println(p);
SourceBuffer b = new SourceBuffer(p);
String expr = b.toString();
System.err.println(expr);
Integer ival = PythonInterpreter.interpretAsInt(expr);
if (ival != null) {
dims.put(index, new NumericDim(ival));
continue;
// `SourceBuffer(Position)` reads the underlying source file. If the file
// is unavailable (synthetic / detached position), fall through to the
// symbolic-dim fallback below rather than propagating the I/O failure.
try {
SourceBuffer b = new SourceBuffer(p);
String expr = b.toString();
Integer ival = PythonInterpreter.interpretAsInt(expr);

Check warning on line 641 in com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java

View check run for this annotation

Codecov / codecov/patch

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java#L639-L641

Added lines #L639 - L641 were not covered by tests
if (ival != null) {
dims.put(index, new NumericDim(ival));
continue;

Check warning on line 644 in com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java

View check run for this annotation

Codecov / codecov/patch

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java#L643-L644

Added lines #L643 - L644 were not covered by tests
}
} catch (IOException e) {
logger.fine(() -> "Could not read source for shape-arg position " + p + ": " + e);

Check warning on line 647 in com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java

View check run for this annotation

Codecov / codecov/patch

com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorType.java#L646-L647

Added lines #L646 - L647 were not covered by tests
}
}
dims.put(index, new SymbolicDim("?"));
Expand Down
Loading