Skip to content

Commit dbfe1d2

Browse files
committed
closure
1 parent ac35173 commit dbfe1d2

6 files changed

Lines changed: 115 additions & 49 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/BlanketImplementation.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ module SatisfiesBlanketConstraint<
9696

9797
Type getTypeAt(TypePath path) {
9898
result = at.getTypeAt(blanketPath.appendInverse(path)) and
99-
not result = TUnknownType()
99+
not result instanceof UnknownType
100100
}
101101

102102
string toString() { result = at.toString() + " [blanket at " + blanketPath.toString() + "]" }

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ module ArgIsInstantiationOf<ArgSig Arg, IsInstantiationOfInputSig<Arg, AssocFunc
329329
private class ArgSubst extends ArgFinal {
330330
Type getTypeAt(TypePath path) {
331331
result = substituteLookupTraits0(this.getEnclosingItemNode(), super.getTypeAt(path)) and
332-
not result = TUnknownType()
332+
not result instanceof UnknownType
333333
}
334334
}
335335

rust/ql/lib/codeql/rust/internal/typeinference/Type.qll

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@ newtype TType =
3737
TImplTraitType(ImplTraitTypeRepr impl) or
3838
TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or
3939
TUnknownType() or
40+
TClosureParameterType(Param p) {
41+
exists(ClosureExpr ce |
42+
p = ce.getParam(_) and
43+
not p.hasTypeRepr()
44+
)
45+
} or
4046
TTypeParamTypeParameter(TypeParam t) or
4147
TAssociatedTypeTypeParameter(Trait trait, AssocType typeAlias) {
4248
getTraitAssocType(trait) = typeAlias
@@ -346,6 +352,12 @@ class PtrConstType extends PtrType {
346352
override string toString() { result = "*const" }
347353
}
348354

355+
abstract class PseudoType extends Type {
356+
override TypeParameter getPositionalTypeParameter(int i) { none() }
357+
358+
override Location getLocation() { result instanceof EmptyLocation }
359+
}
360+
349361
/**
350362
* A special pseudo type used to indicate that the actual type may have to be
351363
* inferred by propagating type information back into call arguments.
@@ -368,12 +380,18 @@ class PtrConstType extends PtrType {
368380
* into call arguments (including method call receivers), in order to avoid
369381
* combinatorial explosions.
370382
*/
371-
class UnknownType extends Type, TUnknownType {
372-
override TypeParameter getPositionalTypeParameter(int i) { none() }
383+
class UnknownType extends PseudoType, TUnknownType {
384+
override string toString() { result = "(unknown type)" }
385+
}
373386

374-
override string toString() { result = "(context typed)" }
387+
class ClosureParameterType extends PseudoType, TClosureParameterType {
388+
private Param param;
375389

376-
override Location getLocation() { result instanceof EmptyLocation }
390+
ClosureParameterType() { this = TClosureParameterType(param) }
391+
392+
Param getParam() { result = param }
393+
394+
override string toString() { result = "(closure parameter " + param + ")" }
377395
}
378396

379397
/** A type parameter. */

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ private module Input1 implements InputSig1<Location> {
3838

3939
class Type = T::Type;
4040

41+
class PseudoType = T::PseudoType;
42+
4143
class UnknownType = T::UnknownType;
4244

4345
class TypeParameter = T::TypeParameter;
@@ -645,7 +647,7 @@ private module Input3 implements InputSig3 {
645647
not tp = c.getParameterType(_, _) and
646648
// check that no explicit type arguments have been supplied for `tp`
647649
not exists(TypeArgumentPosition tapos |
648-
this.getTypeArgument(tapos, _) != TUnknownType() and
650+
this.getTypeArgument(tapos, _) = any(Type t | not t instanceof UnknownType) and
649651
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
650652
)
651653
)
@@ -734,10 +736,16 @@ private module Input3 implements InputSig3 {
734736
path2 = TypePath::singleton(tt.getPositionalTypeParameter(i))
735737
)
736738
or
737-
exists(ClosureExpr ce, int index |
738-
n1 = ce.getParam(index).getPat() and
739-
n2 = ce and
739+
exists(ClosureExpr ce, int index, Param p |
740740
path1.isEmpty() and
741+
p = ce.getParam(index)
742+
|
743+
n1 = p.getPat() and
744+
n2 = p and
745+
path2.isEmpty()
746+
or
747+
n1 = p and
748+
n2 = ce and
741749
path2 = closureParameterPath(ce.getNumberOfParams(), index)
742750
)
743751
}
@@ -881,10 +889,6 @@ private module Input3 implements InputSig3 {
881889
Type inferType(AstNode n, TypePath path) {
882890
result = M3::inferType(n, path)
883891
or
884-
isPanicMacroCall(n) and
885-
path.isEmpty() and
886-
result instanceof UnknownType
887-
or
888892
result = inferAssignmentOperationType(n, path)
889893
or
890894
result = inferTryExprType(n, path)
@@ -904,6 +908,8 @@ private module Input3 implements InputSig3 {
904908
result = inferDeconstructionPatType(n, path)
905909
or
906910
result = inferUnknownType(n, path)
911+
or
912+
result = inferParamPatType(n, path)
907913
}
908914
}
909915

@@ -1318,7 +1324,7 @@ private module ContextualTyping {
13181324
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
13191325

13201326
predicate hasTypeArgument(TypeArgumentPosition apos) {
1321-
this.getTypeArgument(apos, _) != TUnknownType()
1327+
this.getTypeArgument(apos, _) = any(Type t | not t instanceof UnknownType)
13221328
}
13231329

13241330
/**
@@ -1829,7 +1835,7 @@ private module AssocFunctionResolution {
18291835
not this.hasReceiver() and
18301836
exists(TypePath strippedTypePath, Type strippedType |
18311837
strippedType = substituteLookupTraits(this, this.getTypeAt(selfPos, strippedTypePath)) and
1832-
strippedType != TUnknownType()
1838+
not strippedType instanceof UnknownType
18331839
|
18341840
nonBlanketLikeCandidate(this, _, selfPos, _, _, strippedTypePath, strippedType)
18351841
or
@@ -1925,7 +1931,7 @@ private module AssocFunctionResolution {
19251931
FunctionPosition selfPos, DerefChain derefChain, BorrowKind borrow, TypePath path
19261932
) {
19271933
result = this.getSelfTypeAt(selfPos, derefChain, borrow, path) and
1928-
result != TUnknownType()
1934+
not result instanceof UnknownType
19291935
}
19301936

19311937
pragma[nomagic]
@@ -2584,7 +2590,7 @@ private module AssocFunctionResolution {
25842590

25852591
Type getTypeAt(TypePath path) {
25862592
result = substituteLookupTraits(afc, afc.getSelfTypeAtNoBorrow(selfPos, derefChain, path)) and
2587-
result != TUnknownType()
2593+
not result instanceof UnknownType
25882594
}
25892595

25902596
string toString() { result = afc + " [" + derefChain.toString() + "]" }
@@ -3144,12 +3150,43 @@ private Type inferUnknownType(AstNode n, TypePath path) {
31443150
or
31453151
n.(Input3::Construction).hasUnknownReturnTypeAt(path)
31463152
or
3147-
exists(Param p |
3148-
not p.hasTypeRepr() and
3149-
n = p.getPat() and
3150-
path.isEmpty()
3153+
// non-`self` parameters without type annotations always belong to closures, so
3154+
// we want
3155+
n = any(Param p | not p.hasTypeRepr()) and
3156+
path.isEmpty()
3157+
or
3158+
isPanicMacroCall(n) and
3159+
path.isEmpty()
3160+
)
3161+
}
3162+
3163+
private Type inferParamPatType(AstNode n, TypePath path) {
3164+
exists(ClosureExpr ce, Param p | p = ce.getAParam() |
3165+
n = p.getPat() and
3166+
path.isEmpty() and
3167+
result = TClosureParameterType(p)
3168+
or
3169+
exists(TypePath ret | inferType(ce, ret) = TClosureParameterType(p) |
3170+
n = ce and
3171+
path = ret and
3172+
result = TUnknownType()
3173+
or
3174+
inferType(ce, ret.appendInverse(path)) = result and
3175+
n = p
3176+
)
3177+
or
3178+
exists(AstNode n0, TypePath prefix |
3179+
inferType(n0, prefix) = TClosureParameterType(p) and
3180+
result = inferTypeCertain(n0, prefix.appendInverse(path)) and
3181+
n = p
31513182
)
31523183
)
3184+
or
3185+
n =
3186+
any(Param p |
3187+
result = inferType(p, path) and
3188+
not result instanceof UnknownType
3189+
).getPat()
31533190
}
31543191

31553192
pragma[nomagic]

rust/ql/test/library-tests/type-inference/type-inference.ql

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,19 @@ private predicate relevantNode(AstNode n) {
1010
not n.isFromMacroExpansion() and
1111
not n instanceof IdentPat and // avoid overlap in the output with the underlying `Name` node
1212
not n instanceof LiteralPat and // avoid overlap in the output with the underlying `Literal` node
13-
(n instanceof TypeMention implies n instanceof SelfParam)
13+
(n instanceof TypeMention implies n instanceof SelfParam) and
14+
not n instanceof Param
1415
}
1516

1617
query predicate inferCertainType(AstNode n, TypePath path, Type t) {
1718
t = TypeInference::inferTypeCertain(n, path) and
18-
t != TUnknownType() and
19+
not t instanceof PseudoType and
1920
relevantNode(n)
2021
}
2122

2223
query predicate inferType(AstNode n, TypePath path, Type t) {
2324
t = TypeInference::inferType(n, path) and
24-
t != TUnknownType() and
25+
not t instanceof PseudoType and
2526
relevantNode(n)
2627
}
2728

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,11 @@ signature module InputSig1<LocationSig Location> {
149149
Location getLocation();
150150
}
151151

152+
class PseudoType extends Type;
153+
152154
/**
153155
* A special pseudo type used to represent cases where the actual type needs
154-
* to be inferred from the context in a top-down manner. For example, in
156+
* to be inferred using contextual information. For example, in
155157
*
156158
* ```rust
157159
* let x = Vec::new();
@@ -161,7 +163,7 @@ signature module InputSig1<LocationSig Location> {
161163
* the element type of `x` is assigned an unknown type, which allows for type
162164
* information to flow into `x` from the call to `push`.
163165
*/
164-
class UnknownType extends Type;
166+
class UnknownType extends PseudoType;
165167

166168
/** A type parameter. */
167169
class TypeParameter extends Type;
@@ -662,7 +664,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
662664

663665
private Type getNonPseudoTypeAt(App app, TypePath path) {
664666
result = app.getTypeAt(path) and
665-
not result instanceof UnknownType
667+
not result instanceof PseudoType
666668
}
667669

668670
pragma[nomagic]
@@ -1383,7 +1385,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
13831385
exists(TypeArgumentPosition tapos |
13841386
result = a.getTypeArgument(tapos, path) and
13851387
tp = getDeclTypeParameter(target, tapos) and
1386-
not result instanceof UnknownType //and path.isEmpty())
1388+
not result instanceof PseudoType
13871389
)
13881390
}
13891391

@@ -2459,6 +2461,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
24592461
*/
24602462
predicate inferStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2);
24612463

2464+
bindingset[path]
2465+
default predicate prohibitContextualInference(AstNode n, TypePath path) { none() }
2466+
24622467
/**
24632468
* Gets the inferred certain type of `n` at `path`.
24642469
*
@@ -2532,7 +2537,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
25322537
) and
25332538
// type annotation may for example include unknown types, such as
25342539
// `x : Vec<_>` in Rust
2535-
not result instanceof UnknownType
2540+
not result instanceof PseudoType
25362541
}
25372542

25382543
/**
@@ -2631,25 +2636,28 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
26312636

26322637
pragma[nomagic]
26332638
private Type inferTypeContextualCand(AstNode n, TypePath path) {
2634-
result = inferCallArgumentTypeContextual(n, path)
2635-
or
2636-
result = inferMemberAccessReceiverTypeContextual(n, path)
2637-
or
2638-
exists(Operation o, int pos |
2639-
n = o.getOperand(pos) and
2640-
result = OperationMatching::inferAccessType(o, pos, path)
2641-
)
2642-
or
2643-
exists(Construction c, int pos |
2644-
n = c.getArgument(pos) and
2645-
result = ConstructionMatching::inferAccessType(c, pos, path)
2646-
)
2647-
or
2648-
exists(TypePath path1, AstNode n2, TypePath path2, TypePath suffix |
2649-
result = inferType(n2, path2.appendInverse(suffix)) and
2650-
path = path1.append(suffix) and
2651-
step(n, path1, n2, path2)
2652-
)
2639+
(
2640+
result = inferCallArgumentTypeContextual(n, path)
2641+
or
2642+
result = inferMemberAccessReceiverTypeContextual(n, path)
2643+
or
2644+
exists(Operation o, int pos |
2645+
n = o.getOperand(pos) and
2646+
result = OperationMatching::inferAccessType(o, pos, path)
2647+
)
2648+
or
2649+
exists(Construction c, int pos |
2650+
n = c.getArgument(pos) and
2651+
result = ConstructionMatching::inferAccessType(c, pos, path)
2652+
)
2653+
or
2654+
exists(TypePath path1, AstNode n2, TypePath path2, TypePath suffix |
2655+
result = inferType(n2, path2.appendInverse(suffix)) and
2656+
path = path1.append(suffix) and
2657+
step(n, path1, n2, path2)
2658+
)
2659+
) and
2660+
not prohibitContextualInference(n, path)
26532661
}
26542662

26552663
// todo: share logic about `hasUnknownType`
@@ -2706,6 +2714,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
27062714
prefix.isPrefixOf(path)
27072715
|
27082716
not Certain::certainTypeConflict(n, prefix, path, result)
2717+
or
2718+
result instanceof PseudoType and not result instanceof UnknownType // todo
27092719
)
27102720
or
27112721
// If `n` has an explicitly unknown type at `prefix` and at the same time a certain

0 commit comments

Comments
 (0)