Skip to content

Added Support for translating Array.IndexOf methods for byte arrays for SqlServer & SQLite #34457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ public class SqlServerByteArrayMethodTranslator : IMethodCallTranslator
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;

// NOTE: Might want to move these to a shared file, similar to EnumerableMethods
private static readonly MethodInfo IndexOfMethodInfo
= typeof(Array)
.GetGenericMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static, (_, t) =>
{
return [t[0].MakeArrayType(), t[0]];
})!;

private static readonly MethodInfo IndexOfWithStartingPositionMethodInfo
= typeof(Array)
.GetGenericMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static, (_, t) =>
{
return [t[0].MakeArrayType(), t[0], typeof(int)];
})!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand All @@ -40,40 +55,107 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& arguments.Count >= 1
&& arguments[0].Type == typeof(byte[]))
{
var source = arguments[0];
var sourceTypeMapping = source.TypeMapping;
var methodDefinition = method.GetGenericMethodDefinition();
if (methodDefinition.Equals(EnumerableMethods.Contains))
{
// NOTE: Should this be refactored to use the TranslateIndexOf method?? Everything is same expect one check
var source = arguments[0];
var sourceTypeMapping = source.TypeMapping;

var value = arguments[1] is SqlConstantExpression constantValue
? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping);
var value = arguments[1] is SqlConstantExpression constantValue
? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping);

return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"CHARINDEX",
[value, source],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int)),
_sqlExpressionFactory.Constant(0));
}
return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"CHARINDEX",
[value, source],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int)),
_sqlExpressionFactory.Constant(0));
}

if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.FirstWithoutPredicate)
&& arguments[0].Type == typeof(byte[]))
{
return _sqlExpressionFactory.Convert(
if (methodDefinition.Equals(EnumerableMethods.FirstWithoutPredicate))
{
return _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"SUBSTRING",
[arguments[0], _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)],
nullable: true,
argumentsPropagateNullability: [true, true, true],
typeof(byte[])),
method.ReturnType);
}

if (methodDefinition.Equals(IndexOfMethodInfo))
{
return TranslateIndexOf(method, arguments[0], arguments[1], null);
}

if (methodDefinition.Equals(IndexOfWithStartingPositionMethodInfo))
{
return TranslateIndexOf(method, arguments[0], arguments[1], arguments[2]);
}
}

return null;
}

private SqlExpression TranslateIndexOf(
MethodInfo method,
SqlExpression source,
SqlExpression valueToSearch,
SqlExpression? startIndex
)
{
var sourceTypeMapping = source.TypeMapping;
var sqlArguments = new List<SqlExpression>
{
valueToSearch is SqlConstantExpression { Value: byte constantValue }
? _sqlExpressionFactory.Constant(new byte[] { constantValue }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(valueToSearch, typeof(byte[]), sourceTypeMapping),
source
};

if (startIndex is not null)
{
sqlArguments.Add(
startIndex is SqlConstantExpression { Value : int index }
? _sqlExpressionFactory.Constant(index + 1, typeof(int))
: _sqlExpressionFactory.Add(startIndex, _sqlExpressionFactory.Constant(1))
);
}

var argumentsPropagateNullability = Enumerable.Repeat(true, sqlArguments.Count);

SqlExpression charIndexExpr;
var storeType = sourceTypeMapping?.StoreType;
if (storeType == "varbinary(max)")
{
charIndexExpr = _sqlExpressionFactory.Function(
"CHARINDEX",
sqlArguments,
nullable: true,
argumentsPropagateNullability: argumentsPropagateNullability,
typeof(long));

charIndexExpr = _sqlExpressionFactory.Convert(charIndexExpr, typeof(int));
}
else
{
charIndexExpr = _sqlExpressionFactory.Function(
"CHARINDEX",
sqlArguments,
nullable: true,
argumentsPropagateNullability: argumentsPropagateNullability,
method.ReturnType);
}


return _sqlExpressionFactory.Subtract(charIndexExpr, _sqlExpressionFactory.Constant(1));
}
}
98 changes: 98 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6259,6 +6259,104 @@ public virtual Task Byte_array_filter_by_length_parameter(bool async)
ss => ss.Set<Squad>().Where(w => w.Banner != null && w.Banner.Length == someByteArr.Length));
}

#region Byte Array IndexOf

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_literal_casts_to_int(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1) == 1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast here shouldn't be needed, no?

Suggested change
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1) == 1),
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, 1) == 1),

Copy link
Author

@nikhil197 nikhil197 Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No actually, it is needed. Without this it's picking non-generic version of the method IndexOf(Array arr, object value) (because 1 is an Int32 by default and I haven't specified the type argument on the IndexOf).

Do we want to support that too?

ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1) == 1)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_parameter_casts_to_int(bool async)
{
byte b = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, b) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, b) == 0)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_literal_does_not_cast(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5) == 1)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_parameter_does_not_cast(bool async)
{
byte b = 4;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, b) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, b) == 0)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_literal_casts_to_int(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1, 1) == 1)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_parameter_casts_to_int(bool async)
{
byte b = 0;
int startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, b, startPos) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, b, startPos) == 0)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_literal_does_not_cast(bool async)
{
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5, 1) == 1)
);
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_parameter_does_not_cast(bool async)
{
byte b = 4;
int startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, b, startPos) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, b, startPos) == 0)
);
}

#endregion

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_bool_coming_from_optional_navigation(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public class Squad
{
public Squad()
{
Members = new List<Gear>();
Members = [];
}

// non-auto generated key
Expand Down
Loading