-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
Copy pathCaseExpression.cs
168 lines (148 loc) · 6.3 KB
/
CaseExpression.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal;
/// <summary>
///     <para>
///         A node in a SQL tree that models a CASE expression, either in its simple form (an operand matched
///         against the test of each WHEN clause) or in its searched form (each WHEN test evaluated as a condition).
///     </para>
///     <para>
///         This type is typically used by database providers (and other extensions). It is generally
///         not used in application code.
///     </para>
/// </summary>
/// <remarks>
///     The type and type mapping of the expression are those of the result of the first WHEN clause.
/// </remarks>
public class CaseExpression : SqlExpression
{
    // Private copy of the clauses handed to the constructor; exposed read-only through WhenClauses.
    private readonly List<CaseWhenClause> _whenClauses;

    /// <summary>
    ///     Initializes a <see cref="CaseExpression" /> in its simple form, where <paramref name="operand" /> is
    ///     compared with the <see cref="CaseWhenClause.Test" /> of each entry in <paramref name="whenClauses" />.
    /// </summary>
    /// <param name="operand">The expression compared with <see cref="CaseWhenClause.Test" /> of each of the <see cref="WhenClauses" />.</param>
    /// <param name="whenClauses">
    ///     The <see cref="CaseWhenClause" /> entries to compare or evaluate. The first entry is read to obtain the
    ///     type and type mapping of this expression.
    /// </param>
    /// <param name="elseResult">The fallback value used when none of the <see cref="WhenClauses" /> matches, if any.</param>
    public CaseExpression(
        SqlExpression? operand,
        IReadOnlyList<CaseWhenClause> whenClauses,
        SqlExpression? elseResult = null)
        : base(whenClauses[0].Result.Type, whenClauses[0].Result.TypeMapping)
    {
        Operand = operand;
        ElseResult = elseResult;
        _whenClauses = [.. whenClauses];
    }

    /// <summary>
    ///     Initializes a <see cref="CaseExpression" /> in its searched form, i.e. without an operand.
    /// </summary>
    /// <param name="whenClauses">The <see cref="CaseWhenClause" /> entries whose conditions are evaluated to pick a result.</param>
    /// <param name="elseResult">The fallback value used when none of the <see cref="WhenClauses" /> matches, if any.</param>
    public CaseExpression(
        IReadOnlyList<CaseWhenClause> whenClauses,
        SqlExpression? elseResult = null)
        : this(null, whenClauses, elseResult)
    {
    }

    /// <summary>
    ///     The value compared against the <see cref="WhenClauses" />; <see langword="null" /> for a searched CASE.
    /// </summary>
    public virtual SqlExpression? Operand { get; }

    /// <summary>
    ///     The <see cref="CaseWhenClause" /> entries that are matched against <see cref="Operand" /> or evaluated
    ///     as conditions to obtain the result.
    /// </summary>
    public virtual IReadOnlyList<CaseWhenClause> WhenClauses
        => _whenClauses;

    /// <summary>
    ///     The value produced when none of the <see cref="WhenClauses" /> matches.
    /// </summary>
    public virtual SqlExpression? ElseResult { get; }

    /// <inheritdoc />
    protected override Expression VisitChildren(ExpressionVisitor visitor)
    {
        var visitedOperand = (SqlExpression?)visitor.Visit(Operand);
        var anyChildChanged = visitedOperand != Operand;

        var visitedClauses = new List<CaseWhenClause>();
        foreach (var clause in WhenClauses)
        {
            var visitedTest = (SqlExpression)visitor.Visit(clause.Test);
            var visitedResult = (SqlExpression)visitor.Visit(clause.Result);

            // Reuse the existing clause instance when the visitor left both parts untouched.
            if (visitedTest == clause.Test
                && visitedResult == clause.Result)
            {
                visitedClauses.Add(clause);
                continue;
            }

            anyChildChanged = true;
            visitedClauses.Add(new CaseWhenClause(visitedTest, visitedResult));
        }

        var visitedElse = (SqlExpression?)visitor.Visit(ElseResult);
        if (visitedElse != ElseResult)
        {
            anyChildChanged = true;
        }

        if (!anyChildChanged)
        {
            return this;
        }

        return new CaseExpression(visitedOperand, visitedClauses, visitedElse);
    }

    /// <summary>
    ///     Returns an expression equivalent to this one but built from the supplied children. When every child
    ///     is the same as the current one, this instance itself is returned.
    /// </summary>
    /// <param name="operand">The <see cref="Operand" /> property of the result.</param>
    /// <param name="whenClauses">The <see cref="WhenClauses" /> property of the result.</param>
    /// <param name="elseResult">The <see cref="ElseResult" /> property of the result.</param>
    /// <returns>This expression if no children changed, or an expression with the updated children.</returns>
    public virtual CaseExpression Update(
        SqlExpression? operand,
        IReadOnlyList<CaseWhenClause> whenClauses,
        SqlExpression? elseResult)
    {
        var childrenUnchanged = operand == Operand
            && whenClauses.SequenceEqual(WhenClauses)
            && elseResult == ElseResult;

        return childrenUnchanged
            ? this
            : new CaseExpression(operand, whenClauses, elseResult);
    }

    /// <inheritdoc />
    protected override void Print(ExpressionPrinter expressionPrinter)
    {
        expressionPrinter.Append("CASE");

        // Only the simple form carries an operand after the CASE keyword.
        if (Operand is not null)
        {
            expressionPrinter.Append(" ");
            expressionPrinter.Visit(Operand);
        }

        // WHEN/ELSE lines are written inside the indented scope; END is written after it is disposed.
        using (expressionPrinter.Indent())
        {
            foreach (var clause in WhenClauses)
            {
                expressionPrinter.AppendLine().Append("WHEN ");
                expressionPrinter.Visit(clause.Test);
                expressionPrinter.Append(" THEN ");
                expressionPrinter.Visit(clause.Result);
            }

            if (ElseResult is not null)
            {
                expressionPrinter.AppendLine().Append("ELSE ");
                expressionPrinter.Visit(ElseResult);
            }
        }

        expressionPrinter.AppendLine().Append("END");
    }

    /// <inheritdoc />
    public override bool Equals(object? obj)
        => ReferenceEquals(this, obj)
            || (obj is CaseExpression other && Equals(other));

    // Member-wise comparison: base equality, then operand, WHEN clauses (in order) and ELSE result.
    private bool Equals(CaseExpression caseExpression)
    {
        if (!base.Equals(caseExpression))
        {
            return false;
        }

        // A missing operand only matches another missing operand.
        var operandsMatch = Operand is { } ownOperand
            ? ownOperand.Equals(caseExpression.Operand)
            : caseExpression.Operand is null;
        if (!operandsMatch)
        {
            return false;
        }

        if (!WhenClauses.SequenceEqual(caseExpression.WhenClauses))
        {
            return false;
        }

        // Same rule for the ELSE branch: absent on one side requires absent on the other.
        return ElseResult is { } ownElse
            ? ownElse.Equals(caseExpression.ElseResult)
            : caseExpression.ElseResult is null;
    }

    /// <inheritdoc />
    public override int GetHashCode()
    {
        var hash = new HashCode();
        hash.Add(base.GetHashCode());
        hash.Add(Operand);

        foreach (var clause in WhenClauses)
        {
            hash.Add(clause);
        }

        hash.Add(ElseResult);
        return hash.ToHashCode();
    }
}