Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libraries/System.Collections/ref/System.Collections.cs
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ public abstract partial class EqualityComparer<T> : System.Collections.Generic.I
{
protected EqualityComparer() { }
public static System.Collections.Generic.EqualityComparer<T> Create(System.Func<T?, T?, bool> equals, System.Func<T, int>? getHashCode = null) { throw null; }
public static System.Collections.Generic.EqualityComparer<T> Create<TKey>(System.Func<T?, TKey?> keySelector, System.Collections.Generic.IEqualityComparer<TKey>? keyComparer = null) { throw null; }
public static System.Collections.Generic.EqualityComparer<T> Default { get { throw null; } }
public abstract bool Equals(T? x, T? y);
public abstract int GetHashCode([System.Diagnostics.CodeAnalysis.DisallowNullAttribute] T obj);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,126 @@ public void EqualityComparerCreate_DelegatesUsed()
Assert.Equal(2, getHashCodeCalls);
}

[Fact]
public void EqualityComparerCreate_KeySelectorNull_Throws()
{
AssertExtensions.Throws<ArgumentNullException>("keySelector", () => EqualityComparer<string>.Create<int>(keySelector: null));
}

[Fact]
public void EqualityComparerCreate_KeySelectorUsed()
{
var original = "foo";
var otherEqualLen = "bar";
var otherLongerLen = "fooo";

var comparer = EqualityComparer<string>.Create(str => str.Length);

Assert.True(comparer.Equals(original, original));
Assert.True(comparer.Equals(original, otherEqualLen));
Assert.False(comparer.Equals(original, otherLongerLen));

Assert.Equal(comparer.GetHashCode(original), comparer.GetHashCode(otherEqualLen));
}

[Fact]
public void EqualityComparerCreate_KeySelectorPassesNullToSelector()
{
int selectorCalls = 0;
var comparer = EqualityComparer<string>.Create<string>(str =>
{
selectorCalls++;
return str ?? "default";
});

// Null is passed through to keySelector in Equals, mapping to "default"
Assert.True(comparer.Equals(null, null));
Assert.Equal(2, selectorCalls);

selectorCalls = 0;
Assert.True(comparer.Equals(null, "default"));
Assert.Equal(2, selectorCalls);

selectorCalls = 0;
Assert.True(comparer.Equals("default", null));
Assert.Equal(2, selectorCalls);

selectorCalls = 0;
Assert.False(comparer.Equals(null, "other"));
Assert.Equal(2, selectorCalls);

// Null is passed through to keySelector in GetHashCode
selectorCalls = 0;
int hashCode = comparer.GetHashCode(null);
Assert.Equal(1, selectorCalls);
Assert.Equal(comparer.GetHashCode("default"), hashCode);
}

[Fact]
public void EqualityComparerCreate_KeySelectorReturnsNullKey()
{
var comparer = EqualityComparer<string>.Create<string>(str => str == "nil" ? null : str);

// When keySelector returns null for both, they are equal
Assert.True(comparer.Equals("nil", "nil"));

// When keySelector returns null for one side only, they are not equal
Assert.False(comparer.Equals("nil", "foo"));
Assert.False(comparer.Equals("foo", "nil"));

// GetHashCode returns 0 for a null key
Assert.Equal(0, comparer.GetHashCode("nil"));
Assert.Equal(comparer.GetHashCode("nil"), comparer.GetHashCode("nil"));
}

[Fact]
public void EqualityComparerCreate_KeySelectorNotHandlingNull_Throws()
{
var comparer = EqualityComparer<string>.Create<int>(str => str.Length);

// keySelector doesn't guard against null, so NullReferenceException propagates
Assert.Throws<NullReferenceException>(() => comparer.Equals(null, "foo"));
Assert.Throws<NullReferenceException>(() => comparer.Equals("foo", null));
Assert.Throws<NullReferenceException>(() => comparer.GetHashCode(null));
}

[Fact]
public void EqualityComparerCreate_KeySelectorComparerUsed()
{
var evenLen1 = "12";
var evenLen2 = "1234";
var evenLen3 = "123456";
var oddLen1 = "1";
var oddLen2 = "123";

bool isEven(int len) => len % 2 == 0;

var evenOrOddComparer = EqualityComparer<int>.Create(equals: (len1, len2) => isEven(len1) == isEven(len2), getHashCode: len => isEven(len) ? 0 : 1);
var comparer = EqualityComparer<string>.Create(str => str?.Length ?? 0, keyComparer: evenOrOddComparer);

Assert.True(comparer.Equals(evenLen1, evenLen1));
Comment thread
eiriktsarpalis marked this conversation as resolved.
Assert.True(comparer.Equals(evenLen1, evenLen2));
Assert.True(comparer.Equals(evenLen1, evenLen3));
Assert.True(comparer.Equals(oddLen1, oddLen2));

Assert.False(comparer.Equals(evenLen1, oddLen1));
Assert.False(comparer.Equals(evenLen1, oddLen2));
Assert.False(comparer.Equals(oddLen1, evenLen2));

Assert.True(comparer.Equals(null, null));
Assert.True(comparer.Equals(evenLen1, null));
Assert.True(comparer.Equals(null, evenLen1));
Assert.False(comparer.Equals(oddLen1, null));
Assert.False(comparer.Equals(null, oddLen1));
Comment thread
eiriktsarpalis marked this conversation as resolved.

Assert.Equal(comparer.GetHashCode(evenLen1), comparer.GetHashCode(evenLen2));
Assert.Equal(comparer.GetHashCode(oddLen1), comparer.GetHashCode(oddLen2));
Assert.NotEqual(comparer.GetHashCode(evenLen1), comparer.GetHashCode(oddLen1));
Assert.Equal(comparer.GetHashCode(null), comparer.GetHashCode(null));
Assert.Equal(comparer.GetHashCode(null), comparer.GetHashCode(evenLen1));
Assert.NotEqual(comparer.GetHashCode(null), comparer.GetHashCode(oddLen1));
}
Comment thread
weitzhandler marked this conversation as resolved.
Comment thread
eiriktsarpalis marked this conversation as resolved.

[Fact]
public void EqualityComparerCreate_ArgsNotDereferenced()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,29 @@ public static EqualityComparer<T> Create(Func<T?, T?, bool> equals, Func<T, int>
return new DelegateEqualityComparer<T>(equals, getHashCode);
}

/// <summary>
/// Creates an <see cref="EqualityComparer{T}"/> by using the specified key selector and optional key comparer.
/// </summary>
/// <param name="keySelector">The delegate to use to select a comparison key from each element.</param>
/// <param name="keyComparer">An optional comparer to use when comparing keys. The default comparer of <typeparamref name="TKey"/> is used if none is specified.</param>
/// <returns>The new comparer.</returns>
/// <exception cref="ArgumentNullException">The <paramref name="keySelector"/> delegate was null.</exception>
public static EqualityComparer<T> Create<TKey>(Func<T?, TKey?> keySelector, IEqualityComparer<TKey>? keyComparer = null)
{
Comment thread
eiriktsarpalis marked this conversation as resolved.
ArgumentNullException.ThrowIfNull(keySelector);

keyComparer ??= EqualityComparer<TKey>.Default;

return new DelegateEqualityComparer<T>(
equals: (itemX, itemY) =>
keyComparer.Equals(x: keySelector(itemX), y: keySelector(itemY)),
getHashCode: obj =>
{
TKey? key = keySelector(obj);
return key is null ? 0 : keyComparer.GetHashCode(key);
});
Comment thread
weitzhandler marked this conversation as resolved.
}

public abstract bool Equals(T? x, T? y);
public abstract int GetHashCode([DisallowNull] T obj);

Expand Down
Loading