Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support consumption of instance increment operators #77098

Open
wants to merge 2 commits into
base: features/UserDefinedCompoundAssignment
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 37 additions & 32 deletions src/Compilers/CSharp/Portable/Binder/Binder_Expressions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11011,38 +11011,7 @@ private BoundConditionalAccess BindConditionalAccessExpression(ConditionalAccess
// For improved diagnostics we detect the cases where the value will be used and produce a
// more specific (though not technically correct) diagnostic here:
// "Error CS0023: Operator '?' cannot be applied to operand of type 'T'"
bool resultIsUsed = true;
CSharpSyntaxNode parent = node.Parent;

if (parent != null)
{
switch (parent.Kind())
{
case SyntaxKind.ExpressionStatement:
resultIsUsed = ((ExpressionStatementSyntax)parent).Expression != node;
break;

case SyntaxKind.SimpleLambdaExpression:
resultIsUsed = (((SimpleLambdaExpressionSyntax)parent).Body != node) || MethodOrLambdaRequiresValue(ContainingMemberOrLambda, Compilation);
break;

case SyntaxKind.ParenthesizedLambdaExpression:
resultIsUsed = (((ParenthesizedLambdaExpressionSyntax)parent).Body != node) || MethodOrLambdaRequiresValue(ContainingMemberOrLambda, Compilation);
break;

case SyntaxKind.ArrowExpressionClause:
resultIsUsed = (((ArrowExpressionClauseSyntax)parent).Expression != node) || MethodOrLambdaRequiresValue(ContainingMemberOrLambda, Compilation);
break;

case SyntaxKind.ForStatement:
// Incrementors and Initializers doesn't have to produce a value
var loop = (ForStatementSyntax)parent;
resultIsUsed = !loop.Incrementors.Contains(node) && !loop.Initializers.Contains(node);
break;
}
}

if (resultIsUsed)
if (ResultIsUsed(node))
{
return GenerateBadConditionalAccessNodeError(node, receiver, access, diagnostics);
}
Expand All @@ -11061,6 +11030,42 @@ private BoundConditionalAccess BindConditionalAccessExpression(ConditionalAccess
return new BoundConditionalAccess(node, receiver, access, accessType);
}

private bool ResultIsUsed(ExpressionSyntax node)
{
bool resultIsUsed = true;
CSharpSyntaxNode parent = node.Parent;

if (parent != null)
{
switch (parent.Kind())
{
case SyntaxKind.ExpressionStatement:
resultIsUsed = ((ExpressionStatementSyntax)parent).Expression != node;
break;

case SyntaxKind.SimpleLambdaExpression:
resultIsUsed = (((SimpleLambdaExpressionSyntax)parent).Body != node) || MethodOrLambdaRequiresValue(ContainingMemberOrLambda, Compilation);
break;

case SyntaxKind.ParenthesizedLambdaExpression:
resultIsUsed = (((ParenthesizedLambdaExpressionSyntax)parent).Body != node) || MethodOrLambdaRequiresValue(ContainingMemberOrLambda, Compilation);
break;

case SyntaxKind.ArrowExpressionClause:
resultIsUsed = (((ArrowExpressionClauseSyntax)parent).Expression != node) || MethodOrLambdaRequiresValue(ContainingMemberOrLambda, Compilation);
break;

case SyntaxKind.ForStatement:
// Incrementors and Initializers doesn't have to produce a value
var loop = (ForStatementSyntax)parent;
resultIsUsed = !loop.Incrementors.Contains(node) && !loop.Initializers.Contains(node);
break;
}
}

return resultIsUsed;
}

internal static bool MethodOrLambdaRequiresValue(Symbol symbol, CSharpCompilation compilation)
{
return symbol is MethodSymbol method &&
Expand Down
6 changes: 5 additions & 1 deletion src/Compilers/CSharp/Portable/Binder/Binder_Lookup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,10 @@ internal SingleLookupResult CheckViability(Symbol symbol, int arity, LookupOptio
{
return LookupResult.Empty();
}
else if ((options & LookupOptions.MustBeOperator) != 0 && unwrappedSymbol is not MethodSymbol { MethodKind: MethodKind.UserDefinedOperator })
{
return LookupResult.Empty();
}
else if (!IsInScopeOfAssociatedSyntaxTree(unwrappedSymbol))
{
return LookupResult.Empty();
Expand All @@ -1417,7 +1421,7 @@ internal SingleLookupResult CheckViability(Symbol symbol, int arity, LookupOptio
{
return LookupResult.WrongArity(symbol, diagInfo);
}
else if (!InCref && !unwrappedSymbol.CanBeReferencedByNameIgnoringIllegalCharacters)
else if (!InCref && !unwrappedSymbol.CanBeReferencedByNameIgnoringIllegalCharacters && (options & LookupOptions.MustBeOperator) == 0)
{
// Strictly speaking, this test should actually check CanBeReferencedByName.
// However, we don't want to pay that cost in cases where the lookup is based
Expand Down
241 changes: 238 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Operators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
Expand Down Expand Up @@ -2262,7 +2263,9 @@ public static BinaryOperatorKind SyntaxKindToBinaryOperatorKind(SyntaxKind kind)
}
}

private BoundExpression BindIncrementOperator(CSharpSyntaxNode node, ExpressionSyntax operandSyntax, SyntaxToken operatorToken, BindingDiagnosticBag diagnostics)
#nullable enable

private BoundExpression BindIncrementOperator(ExpressionSyntax node, ExpressionSyntax operandSyntax, SyntaxToken operatorToken, BindingDiagnosticBag diagnostics)
{
operandSyntax.CheckDeconstructionCompatibleArgument(diagnostics);

Expand Down Expand Up @@ -2290,7 +2293,7 @@ private BoundExpression BindIncrementOperator(CSharpSyntaxNode node, ExpressionS

// The operand has to be a variable, property or indexer, so it must have a type.
var operandType = operand.Type;
Debug.Assert((object)operandType != null);
Debug.Assert(operandType is not null);

if (operandType.IsDynamic())
{
Expand All @@ -2310,6 +2313,13 @@ private BoundExpression BindIncrementOperator(CSharpSyntaxNode node, ExpressionS
hasErrors: false);
}

// Try an in-place user-defined operator
BoundIncrementOperator? inPlaceResult = tryApplyUserDefinedInstanceOperator(node, operatorToken, kind, operand, diagnostics);
if (inPlaceResult is not null)
{
return inPlaceResult;
}

LookupResultKind resultKind;
ImmutableArray<MethodSymbol> originalUserDefinedOperators;
var best = this.UnaryOperatorOverloadResolution(kind, operand, node, diagnostics, out resultKind, out originalUserDefinedOperators);
Expand Down Expand Up @@ -2339,7 +2349,7 @@ private BoundExpression BindIncrementOperator(CSharpSyntaxNode node, ExpressionS

var resultPlaceholder = new BoundValuePlaceholder(node, signature.ReturnType).MakeCompilerGenerated();

BoundExpression resultConversion = GenerateConversionForAssignment(operandType, resultPlaceholder, diagnostics, ConversionForAssignmentFlags.IncrementAssignment);
BoundExpression? resultConversion = GenerateConversionForAssignment(operandType, resultPlaceholder, diagnostics, ConversionForAssignmentFlags.IncrementAssignment);

bool hasErrors = resultConversion.HasErrors;

Expand Down Expand Up @@ -2376,6 +2386,231 @@ private BoundExpression BindIncrementOperator(CSharpSyntaxNode node, ExpressionS
originalUserDefinedOperators,
operandType,
hasErrors);

BoundIncrementOperator? tryApplyUserDefinedInstanceOperator(ExpressionSyntax node, SyntaxToken operatorToken, UnaryOperatorKind kind, BoundExpression operand, BindingDiagnosticBag diagnostics)
{
var operandType = operand.Type;
Debug.Assert(operandType is not null);
Debug.Assert(!operandType.IsDynamic());

if (kind is not (UnaryOperatorKind.PrefixIncrement or UnaryOperatorKind.PrefixDecrement or UnaryOperatorKind.PostfixIncrement or UnaryOperatorKind.PostfixDecrement) ||
operandType.SpecialType.IsNumericType() ||
!node.IsFeatureEnabled(MessageID.IDS_FeatureUserDefinedCompoundAssignmentOperators))
{
return null;
}

bool resultIsUsed = ResultIsUsed(node);

if ((kind is (UnaryOperatorKind.PostfixIncrement or UnaryOperatorKind.PostfixDecrement) && resultIsUsed) ||
!CheckValueKind(node, operand, BindValueKind.RefersToLocation | BindValueKind.Assignable, checkingReceiver: false, BindingDiagnosticBag.Discarded))
{
return null;
}

bool checkOverflowAtRuntime = CheckOverflowAtRuntime;
CompoundUseSiteInfo<AssemblySymbol> useSiteInfo = GetNewCompoundUseSiteInfo(diagnostics);

ArrayBuilder<MethodSymbol>? methods = lookupUserDefinedInstanceOperators(
operandType,
checkedName: checkOverflowAtRuntime ?
(kind is UnaryOperatorKind.PrefixIncrement or UnaryOperatorKind.PostfixIncrement ?
WellKnownMemberNames.CheckedIncrementOperatorName :
WellKnownMemberNames.CheckedDecrementOperatorName) :
null,
ordinaryName: kind is UnaryOperatorKind.PrefixIncrement or UnaryOperatorKind.PostfixIncrement ?
WellKnownMemberNames.IncrementOperatorName :
WellKnownMemberNames.DecrementOperatorName,
ref useSiteInfo);

if (methods?.IsEmpty != false)
{
diagnostics.Add(node, useSiteInfo);
methods?.Free();
return null;
}

Debug.Assert(!methods.IsEmpty);

var overloadResolutionResult = OverloadResolutionResult<MethodSymbol>.GetInstance();
var typeArguments = ArrayBuilder<TypeWithAnnotations>.GetInstance();
var analyzedArguments = AnalyzedArguments.GetInstance();

OverloadResolution.MethodInvocationOverloadResolution(
methods,
typeArguments,
operand,
analyzedArguments,
overloadResolutionResult,
ref useSiteInfo,
OverloadResolution.Options.None);

typeArguments.Free();
diagnostics.Add(node, useSiteInfo);

BoundIncrementOperator? inPlaceResult;

if (overloadResolutionResult.Succeeded)
{
var method = overloadResolutionResult.ValidResult.Member;

ReportDiagnosticsIfObsolete(diagnostics, method, node, hasBaseReceiver: false);
ReportDiagnosticsIfUnmanagedCallersOnly(diagnostics, method, node, isDelegateConversion: false);

inPlaceResult = new BoundIncrementOperator(
node,
(kind | UnaryOperatorKind.UserDefined).WithOverflowChecksIfApplicable(checkOverflowAtRuntime),
operand,
methodOpt: method,
constrainedToTypeOpt: null,
operandPlaceholder: null,
operandConversion: null,
resultPlaceholder: null,
resultConversion: null,
LookupResultKind.Viable,
ImmutableArray<MethodSymbol>.Empty,
resultIsUsed ? operandType : GetSpecialType(SpecialType.System_Void, diagnostics, node));

methods.Free();
}
else if (overloadResolutionResult.HasAnyApplicableMember)
{
ImmutableArray<MethodSymbol> methodsArray = methods.ToImmutableAndFree();

overloadResolutionResult.ReportDiagnostics(
binder: this, location: operatorToken.GetLocation(), nodeOpt: node, diagnostics: diagnostics, name: operatorToken.ValueText,
receiver: operand, invokedExpression: node, arguments: analyzedArguments, memberGroup: methodsArray,
typeContainingConstructor: null, delegateTypeBeingInvoked: null);

inPlaceResult = new BoundIncrementOperator(
node,
(kind | UnaryOperatorKind.UserDefined).WithOverflowChecksIfApplicable(checkOverflowAtRuntime),
operand,
methodOpt: null,
constrainedToTypeOpt: null,
operandPlaceholder: null,
operandConversion: null,
resultPlaceholder: null,
resultConversion: null,
LookupResultKind.OverloadResolutionFailure,
methodsArray,
resultIsUsed ? operandType : GetSpecialType(SpecialType.System_Void, diagnostics, node));
}
else
{
inPlaceResult = null;
methods.Free();
}

analyzedArguments.Free();
overloadResolutionResult.Free();

return inPlaceResult;
}

ArrayBuilder<MethodSymbol>? lookupUserDefinedInstanceOperators(TypeSymbol lookupInType, string? checkedName, string ordinaryName, ref CompoundUseSiteInfo<AssemblySymbol> useSiteInfo)
{
var lookupResult = LookupResult.GetInstance();
ArrayBuilder<MethodSymbol>? methods = null;
if (checkedName is not null)
{
this.LookupMembersWithFallback(lookupResult, lookupInType, name: checkedName, arity: 0, ref useSiteInfo, basesBeingResolved: null, options: LookupOptions.MustBeInstance | LookupOptions.MustBeOperator);

if (lookupResult.IsMultiViable)
{
methods = ArrayBuilder<MethodSymbol>.GetInstance(lookupResult.Symbols.Count);
appendViableMethods(lookupResult, methods);
}

lookupResult.Clear();
}

this.LookupMembersWithFallback(lookupResult, lookupInType, name: ordinaryName, arity: 0, ref useSiteInfo, basesBeingResolved: null, options: LookupOptions.MustBeInstance | LookupOptions.MustBeOperator);

if (lookupResult.IsMultiViable)
{
if (methods is null)
{
methods = ArrayBuilder<MethodSymbol>.GetInstance(lookupResult.Symbols.Count);
appendViableMethods(lookupResult, methods);
}
else
{
var existing = new HashSet<MethodSymbol>(PairedOperatorComparer.Instance);

foreach (var method in methods)
{
existing.Add(method.GetLeastOverriddenMethod(ContainingType));
}

foreach (MethodSymbol method in lookupResult.Symbols)
{
if (isViable(method) && !existing.Contains(method.GetLeastOverriddenMethod(ContainingType)))
{
methods.Add(method);
}
}
}
}

lookupResult.Free();

return methods;

static void appendViableMethods(LookupResult lookupResult, ArrayBuilder<MethodSymbol> methods)
{
foreach (MethodSymbol method in lookupResult.Symbols)
{
if (isViable(method))
{
methods.Add(method);
}
}
}

static bool isViable(MethodSymbol method)
{
return method.ParameterCount == 0;
}
}
}

#nullable disable

private class PairedOperatorComparer : IEqualityComparer<MethodSymbol>
{
public static readonly PairedOperatorComparer Instance = new PairedOperatorComparer();

private PairedOperatorComparer() { }

public bool Equals(MethodSymbol x, MethodSymbol y)
{
Debug.Assert(!x.IsOverride);
Debug.Assert(!x.IsStatic);

Debug.Assert(!y.IsOverride);
Debug.Assert(!y.IsStatic);

var typeComparer = Symbols.SymbolEqualityComparer.AllIgnoreOptions;
return typeComparer.Equals(x.ContainingType, y.ContainingType) &&
SourceMemberContainerTypeSymbol.DoOperatorsPair(x, y);
}

public int GetHashCode([DisallowNull] MethodSymbol method)
{
Debug.Assert(!method.IsOverride);
Debug.Assert(!method.IsStatic);

var typeComparer = Symbols.SymbolEqualityComparer.AllIgnoreOptions;
int result = typeComparer.GetHashCode(method.ContainingType);

if (method.ParameterTypesWithAnnotations is [var typeWithAnnotations, ..])
{
result = Hash.Combine(result, typeComparer.GetHashCode(typeWithAnnotations.Type));
}

return result;
}
}

#nullable enable
Expand Down
5 changes: 5 additions & 0 deletions src/Compilers/CSharp/Portable/Binder/LookupOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ internal enum LookupOptions
/// Do not consider symbols that are parameters.
/// </summary>
MustNotBeParameter = 1 << 16,

/// <summary>
/// Consider only symbols that are user-defined operators.
/// </summary>
MustBeOperator = 1 << 17,
}

internal static class LookupOptionExtensions
Expand Down
Loading