diff --git a/src/libraries/System.Collections/ref/System.Collections.cs b/src/libraries/System.Collections/ref/System.Collections.cs index e2578bb9d4b2b3..b97b5f48300945 100644 --- a/src/libraries/System.Collections/ref/System.Collections.cs +++ b/src/libraries/System.Collections/ref/System.Collections.cs @@ -756,6 +756,7 @@ public abstract partial class EqualityComparer : System.Collections.Generic.I { protected EqualityComparer() { } public static System.Collections.Generic.EqualityComparer Create(System.Func equals, System.Func? getHashCode = null) { throw null; } + public static System.Collections.Generic.EqualityComparer Create(System.Func keySelector, System.Collections.Generic.IEqualityComparer? keyComparer = null) { throw null; } public static System.Collections.Generic.EqualityComparer Default { get { throw null; } } public abstract bool Equals(T? x, T? y); public abstract int GetHashCode([System.Diagnostics.CodeAnalysis.DisallowNullAttribute] T obj); diff --git a/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs b/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs index d25a7cfd7ab8b6..b45874cd3e3f31 100644 --- a/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs +++ b/src/libraries/System.Collections/tests/Generic/Comparers/EqualityComparer.Tests.cs @@ -518,6 +518,126 @@ public void EqualityComparerCreate_DelegatesUsed() Assert.Equal(2, getHashCodeCalls); } + [Fact] + public void EqualityComparerCreate_KeySelectorNull_Throws() + { + AssertExtensions.Throws("keySelector", () => EqualityComparer.Create(keySelector: null)); + } + + [Fact] + public void EqualityComparerCreate_KeySelectorUsed() + { + var original = "foo"; + var otherEqualLen = "bar"; + var otherLongerLen = "fooo"; + + var comparer = EqualityComparer.Create(str => str.Length); + + Assert.True(comparer.Equals(original, original)); + Assert.True(comparer.Equals(original, otherEqualLen)); + Assert.False(comparer.Equals(original, otherLongerLen)); + + Assert.Equal(comparer.GetHashCode(original), comparer.GetHashCode(otherEqualLen)); + } + + [Fact] + public void EqualityComparerCreate_KeySelectorPassesNullToSelector() + { + int selectorCalls = 0; + var comparer = EqualityComparer.Create(str => + { + selectorCalls++; + return str ?? "default"; + }); + + // Null is passed through to keySelector in Equals, mapping to "default" + Assert.True(comparer.Equals(null, null)); + Assert.Equal(2, selectorCalls); + + selectorCalls = 0; + Assert.True(comparer.Equals(null, "default")); + Assert.Equal(2, selectorCalls); + + selectorCalls = 0; + Assert.True(comparer.Equals("default", null)); + Assert.Equal(2, selectorCalls); + + selectorCalls = 0; + Assert.False(comparer.Equals(null, "other")); + Assert.Equal(2, selectorCalls); + + // Null is passed through to keySelector in GetHashCode + selectorCalls = 0; + int hashCode = comparer.GetHashCode(null); + Assert.Equal(1, selectorCalls); + Assert.Equal(comparer.GetHashCode("default"), hashCode); + } + + [Fact] + public void EqualityComparerCreate_KeySelectorReturnsNullKey() + { + var comparer = EqualityComparer.Create(str => str == "nil" ? null : str); + + // When keySelector returns null for both, they are equal + Assert.True(comparer.Equals("nil", "nil")); + + // When keySelector returns null for one side only, they are not equal + Assert.False(comparer.Equals("nil", "foo")); + Assert.False(comparer.Equals("foo", "nil")); + + // GetHashCode returns 0 for a null key + Assert.Equal(0, comparer.GetHashCode("nil")); + Assert.Equal(comparer.GetHashCode("nil"), comparer.GetHashCode("nil")); + } + + [Fact] + public void EqualityComparerCreate_KeySelectorNotHandlingNull_Throws() + { + var comparer = EqualityComparer.Create(str => str.Length); + + // keySelector doesn't guard against null, so NullReferenceException propagates + Assert.Throws(() => comparer.Equals(null, "foo")); + Assert.Throws(() => comparer.Equals("foo", null)); + Assert.Throws(() => comparer.GetHashCode(null)); + } + + [Fact] + public void EqualityComparerCreate_KeySelectorComparerUsed() + { + var evenLen1 = "12"; + var evenLen2 = "1234"; + var evenLen3 = "123456"; + var oddLen1 = "1"; + var oddLen2 = "123"; + + bool isEven(int len) => len % 2 == 0; + + var evenOrOddComparer = EqualityComparer.Create(equals: (len1, len2) => isEven(len1) == isEven(len2), getHashCode: len => isEven(len) ? 0 : 1); + var comparer = EqualityComparer.Create(str => str?.Length ?? 0, keyComparer: evenOrOddComparer); + + Assert.True(comparer.Equals(evenLen1, evenLen1)); + Assert.True(comparer.Equals(evenLen1, evenLen2)); + Assert.True(comparer.Equals(evenLen1, evenLen3)); + Assert.True(comparer.Equals(oddLen1, oddLen2)); + + Assert.False(comparer.Equals(evenLen1, oddLen1)); + Assert.False(comparer.Equals(evenLen1, oddLen2)); + Assert.False(comparer.Equals(oddLen1, evenLen2)); + + Assert.True(comparer.Equals(null, null)); + Assert.True(comparer.Equals(evenLen1, null)); + Assert.True(comparer.Equals(null, evenLen1)); + Assert.False(comparer.Equals(oddLen1, null)); + Assert.False(comparer.Equals(null, oddLen1)); + + Assert.Equal(comparer.GetHashCode(evenLen1), comparer.GetHashCode(evenLen2)); + Assert.Equal(comparer.GetHashCode(oddLen1), comparer.GetHashCode(oddLen2)); + Assert.NotEqual(comparer.GetHashCode(evenLen1), comparer.GetHashCode(oddLen1)); + Assert.Equal(comparer.GetHashCode(null), comparer.GetHashCode(null)); + Assert.Equal(comparer.GetHashCode(null), comparer.GetHashCode(evenLen1)); + Assert.NotEqual(comparer.GetHashCode(null), comparer.GetHashCode(oddLen1)); + } + [Fact] public void EqualityComparerCreate_ArgsNotDereferenced() { diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs index 43b26125056808..1afd6d542769f4 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/EqualityComparer.cs @@ -34,6 +34,29 @@ public static EqualityComparer Create(Func equals, Func return new DelegateEqualityComparer(equals, getHashCode); } + /// + /// Creates an by using the specified key selector and optional key comparer. + /// + /// The delegate to use to select a comparison key from each element. + /// An optional comparer to use when comparing keys. The default comparer of is used if none is specified. + /// The new comparer. + /// The delegate was null. + public static EqualityComparer Create(Func keySelector, IEqualityComparer? keyComparer = null) + { + ArgumentNullException.ThrowIfNull(keySelector); + + keyComparer ??= EqualityComparer.Default; + + return new DelegateEqualityComparer( + equals: (itemX, itemY) => + keyComparer.Equals(x: keySelector(itemX), y: keySelector(itemY)), + getHashCode: obj => + { + TKey? key = keySelector(obj); + return key is null ? 0 : keyComparer.GetHashCode(key); + }); + } + public abstract bool Equals(T? x, T? y); public abstract int GetHashCode([DisallowNull] T obj);