[ACCEPTED]-Combining two lambda expressions in c#-lambda
OK; pretty long snippet, but here's a starter for an expression-rewriter; it doesn't handle a few cases yet (I'll fix it later), but it works for the example given and a lot of others:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text.RegularExpressions;
/// <summary>
/// Root of the sample object graph; exposes a single <see cref="Parent"/> reference.
/// </summary>
public class GrandParent
{
    /// <summary>Gets or sets the parent hanging off this grandparent.</summary>
    public Parent Parent
    {
        get;
        set;
    }
}
/// <summary>
/// Middle node of the sample object graph.
/// </summary>
public class Parent
{
    /// <summary>Gets or sets the child hanging off this parent.</summary>
    public Child Child
    {
        get;
        set;
    }

    /// <summary>Returns <paramref name="s"/> with the literal "abc" appended.</summary>
    /// <param name="s">The text to extend.</param>
    /// <returns>The concatenation of <paramref name="s"/> and "abc".</returns>
    public string Method(string s)
    {
        return string.Concat(s, "abc");
    }
}
/// <summary>
/// Leaf node of the sample object graph; carries only a name.
/// </summary>
public class Child
{
    /// <summary>Gets or sets the child's name.</summary>
    public string Name
    {
        get;
        set;
    }
}
/// <summary>
/// Helpers for composing lambda expressions.
/// </summary>
public static class ExpressionUtils
{
    /// <summary>
    /// Builds a lambda that feeds the result of <paramref name="outer"/> into
    /// <paramref name="inner"/>.
    /// </summary>
    /// <param name="outer">The lambda applied first; its parameters become the result's parameters.</param>
    /// <param name="inner">The lambda applied to the result of <paramref name="outer"/>.</param>
    /// <param name="inline">
    /// When true the invocation of <paramref name="inner"/> is handed to
    /// <c>ExpressionRewriter.AutoInline</c>; when false the invocation node is kept as-is.
    /// </param>
    /// <returns>A lambda from <typeparamref name="T1"/> to <typeparamref name="T3"/>.</returns>
    public static Expression<Func<T1, T3>> Combine<T1, T2, T3>(
        this Expression<Func<T1, T2>> outer, Expression<Func<T2, T3>> inner, bool inline)
    {
        InvocationExpression call = Expression.Invoke(inner, outer.Body);

        Expression composed;
        if (inline)
        {
            composed = new ExpressionRewriter().AutoInline(call);
        }
        else
        {
            composed = call;
        }

        return Expression.Lambda<Func<T1, T3>>(composed, outer.Parameters);
    }
}
/// <summary>
/// Rebuilds an expression tree while replacing registered sub-expressions.
/// Substitutions are registered with <see cref="Subst"/> and applied with
/// <see cref="Apply"/>; once applied, the rewriter can no longer be altered.
/// </summary>
public class ExpressionRewriter
{
    // isLocked: set once the rewriter has been used, after which Subst/Inline throw.
    // inline:   recorded by Inline() and copied to child scopes; Walk does not consult it.
    private bool isLocked, inline;

    // Maps an expression node (by reference) to the node that replaces it.
    private readonly Dictionary<Expression, Expression> subst;

    /// <summary>Creates a rewriter with no substitutions registered.</summary>
    public ExpressionRewriter()
    {
        subst = new Dictionary<Expression, Expression>();
    }

    // Creates an unlocked child scope that starts with a copy of the parent's substitutions.
    private ExpressionRewriter(ExpressionRewriter parent)
    {
        if (parent == null) throw new ArgumentNullException("parent");
        subst = new Dictionary<Expression, Expression>(parent.subst);
        inline = parent.inline;
    }

    /// <summary>
    /// Replaces an invocation of a lambda with the lambda's body, substituting each lambda
    /// parameter with the corresponding invocation argument.
    /// </summary>
    /// <param name="expression">An invocation whose target is a <see cref="LambdaExpression"/>.</param>
    /// <returns>The lambda body with its parameters replaced by the invocation arguments.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Parameter and argument counts differ.</exception>
    internal Expression AutoInline(InvocationExpression expression)
    {
        // Validate before mutating state so a bad call does not lock the rewriter.
        if (expression == null) throw new ArgumentNullException("expression");
        isLocked = true;

        LambdaExpression lambda = (LambdaExpression)expression.Expression;
        ExpressionRewriter childScope = new ExpressionRewriter(this);
        var lambdaParams = lambda.Parameters;
        var invokeArgs = expression.Arguments;
        if (lambdaParams.Count != invokeArgs.Count) throw new InvalidOperationException("Lambda/invoke mismatch");
        for (int i = 0; i < lambdaParams.Count; i++)
        {
            childScope.Subst(lambdaParams[i], invokeArgs[i]);
        }
        return childScope.Apply(lambda.Body);
    }

    // Throws if the rewriter has already been used.
    private void CheckLocked()
    {
        if (isLocked) throw new InvalidOperationException(
            "You cannot alter the rewriter after Apply has been called");
    }

    /// <summary>Registers a replacement of <paramref name="from"/> by <paramref name="to"/>.</summary>
    /// <returns>This rewriter, to allow chaining.</returns>
    /// <exception cref="InvalidOperationException">The rewriter has already been applied.</exception>
    public ExpressionRewriter Subst(Expression from,
        Expression to)
    {
        CheckLocked();
        subst.Add(from, to);
        return this;
    }

    /// <summary>Sets the inline flag.</summary>
    /// <returns>This rewriter, to allow chaining.</returns>
    /// <exception cref="InvalidOperationException">The rewriter has already been applied.</exception>
    public ExpressionRewriter Inline()
    {
        CheckLocked();
        inline = true;
        return this;
    }

    /// <summary>
    /// Applies the registered substitutions to <paramref name="expression"/> and locks the rewriter.
    /// </summary>
    /// <returns>
    /// The rewritten tree, or <paramref name="expression"/> itself when nothing needed rewriting.
    /// </returns>
    public Expression Apply(Expression expression)
    {
        isLocked = true;
        return Walk(expression) ?? expression;
    }

    // Pairs each walked term with its original, yielding the walked term when it was
    // rewritten (non-null) and the original otherwise. Stops at the shorter sequence.
    private static IEnumerable<Expression> CoalesceTerms(
        IEnumerable<Expression> sourceWithNulls, IEnumerable<Expression> replacements)
    {
        if (sourceWithNulls != null && replacements != null)
        {
            using (var left = sourceWithNulls.GetEnumerator())
            using (var right = replacements.GetEnumerator())
            {
                while (left.MoveNext() && right.MoveNext())
                {
                    yield return left.Current ?? right.Current;
                }
            }
        }
    }

    // Walks every expression in the sequence; entries are null where no rewrite was needed.
    private Expression[] Walk(IEnumerable<Expression> expressions)
    {
        if (expressions == null) return null;
        return expressions.Select(expr => Walk(expr)).ToArray();
    }

    // True when at least one walked entry was actually rewritten.
    private static bool HasValue(Expression[] expressions)
    {
        return expressions != null && expressions.Any(expr => expr != null);
    }

    // returns null if no need to rewrite that branch, otherwise
    // returns a re-written branch
    private Expression Walk(Expression expression)
    {
        if (expression == null) return null;
        Expression tmp;
        if (subst.TryGetValue(expression, out tmp)) return tmp;
        switch (expression.NodeType)
        {
            case ExpressionType.Constant:
            case ExpressionType.Parameter:
                {
                    // Leaf nodes that were not matched above never need rewriting; report
                    // that with null so parents are not rebuilt needlessly.
                    return null;
                }
            case ExpressionType.MemberAccess:
                {
                    MemberExpression me = (MemberExpression)expression;
                    Expression target = Walk(me.Expression);
                    return target == null ? null : Expression.MakeMemberAccess(target, me.Member);
                }
            case ExpressionType.Add:
            case ExpressionType.Divide:
            case ExpressionType.Multiply:
            case ExpressionType.Subtract:
            case ExpressionType.AddChecked:
            case ExpressionType.MultiplyChecked:
            case ExpressionType.SubtractChecked:
            case ExpressionType.And:
            case ExpressionType.Or:
            case ExpressionType.ExclusiveOr:
            case ExpressionType.Equal:
            case ExpressionType.NotEqual:
            case ExpressionType.AndAlso:
            case ExpressionType.OrElse:
            case ExpressionType.Power:
            case ExpressionType.Modulo:
            case ExpressionType.GreaterThan:
            case ExpressionType.GreaterThanOrEqual:
            case ExpressionType.LessThan:
            case ExpressionType.LessThanOrEqual:
            case ExpressionType.LeftShift:
            case ExpressionType.RightShift:
            case ExpressionType.Coalesce:
            case ExpressionType.ArrayIndex:
                {
                    BinaryExpression binExp = (BinaryExpression)expression;
                    Expression left = Walk(binExp.Left), right = Walk(binExp.Right);
                    return (left == null && right == null) ? null : Expression.MakeBinary(
                        binExp.NodeType, left ?? binExp.Left, right ?? binExp.Right, binExp.IsLiftedToNull,
                        binExp.Method, binExp.Conversion);
                }
            case ExpressionType.Not:
            case ExpressionType.UnaryPlus:
            case ExpressionType.Negate:
            case ExpressionType.NegateChecked:
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
            case ExpressionType.TypeAs:
            case ExpressionType.ArrayLength:
                {
                    UnaryExpression unExp = (UnaryExpression)expression;
                    Expression operand = Walk(unExp.Operand);
                    return operand == null ? null : Expression.MakeUnary(unExp.NodeType, operand,
                        unExp.Type, unExp.Method);
                }
            case ExpressionType.Conditional:
                {
                    ConditionalExpression ce = (ConditionalExpression)expression;
                    Expression test = Walk(ce.Test), ifTrue = Walk(ce.IfTrue), ifFalse = Walk(ce.IfFalse);
                    if (test == null && ifTrue == null && ifFalse == null) return null;
                    return Expression.Condition(test ?? ce.Test, ifTrue ?? ce.IfTrue, ifFalse ?? ce.IfFalse);
                }
            case ExpressionType.Call:
                {
                    MethodCallExpression mce = (MethodCallExpression)expression;
                    Expression instance = Walk(mce.Object);
                    Expression[] args = Walk(mce.Arguments);
                    if (instance == null && !HasValue(args)) return null;
                    // Keep the original receiver when only the arguments changed; passing a
                    // null instance for an instance method would make Expression.Call throw.
                    return Expression.Call(instance ?? mce.Object, mce.Method, CoalesceTerms(args, mce.Arguments));
                }
            case ExpressionType.TypeIs:
                {
                    TypeBinaryExpression tbe = (TypeBinaryExpression)expression;
                    tmp = Walk(tbe.Expression);
                    return tmp == null ? null : Expression.TypeIs(tmp, tbe.TypeOperand);
                }
            case ExpressionType.New:
                {
                    NewExpression ne = (NewExpression)expression;
                    Expression[] args = Walk(ne.Arguments);
                    // Only rebuild when at least one constructor argument was rewritten.
                    if (!HasValue(args)) return null;
                    return ne.Members == null ? Expression.New(ne.Constructor, CoalesceTerms(args, ne.Arguments))
                        : Expression.New(ne.Constructor, CoalesceTerms(args, ne.Arguments), ne.Members);
                }
            case ExpressionType.ListInit:
                {
                    ListInitExpression lie = (ListInitExpression)expression;
                    NewExpression ctor = (NewExpression)Walk(lie.NewExpression);
                    var inits = lie.Initializers.Select(init => new
                    {
                        Original = init,
                        NewArgs = Walk(init.Arguments)
                    }).ToArray();
                    if (ctor == null && !inits.Any(init => HasValue(init.NewArgs))) return null;
                    ElementInit[] initArr = inits.Select(init => Expression.ElementInit(
                        init.Original.AddMethod, CoalesceTerms(init.NewArgs, init.Original.Arguments))).ToArray();
                    return Expression.ListInit(ctor ?? lie.NewExpression, initArr);
                }
            case ExpressionType.NewArrayBounds:
            case ExpressionType.NewArrayInit:
            /* not quite right... leave as not-implemented for now
            {
                NewArrayExpression nae = (NewArrayExpression)expression;
                Expression[] expr = Walk(nae.Expressions);
                if (!HasValue(expr)) return null;
                return expression.NodeType == ExpressionType.NewArrayBounds
                    ? Expression.NewArrayBounds(nae.Type, CoalesceTerms(expr, nae.Expressions))
                    : Expression.NewArrayInit(nae.Type, CoalesceTerms(expr, nae.Expressions));
            }*/
            case ExpressionType.Invoke:
            case ExpressionType.Lambda:
            case ExpressionType.MemberInit:
            case ExpressionType.Quote:
                throw new NotImplementedException("Not implemented: " + expression.NodeType);
            default:
                throw new NotSupportedException("Not supported: " + expression.NodeType);
        }
    }
}
/// <summary>
/// Demo driver: builds a series of combined lambdas to exercise
/// <c>ExpressionUtils.Combine</c> over different expression node types.
/// </summary>
static class Program
{
    /// <summary>
    /// Constructs the sample expressions. Nothing is compiled or printed; the locals
    /// exist so the resulting trees can be inspected in a debugger.
    /// </summary>
    static void Main()
    {
        Expression<Func<GrandParent, Parent>> myFirst = gp => gp.Parent;
        Expression<Func<Parent, string>> mySecond = p => p.Child.Name;
        // inline: false keeps the Invoke node; inline: true substitutes the body.
        Expression<Func<GrandParent, string>> outputWithoutInline = myFirst.Combine(mySecond, false);
        Expression<Func<GrandParent, string>> outputWithInline = myFirst.Combine(mySecond, true);
        Expression<Func<GrandParent, string>> call =
            ExpressionUtils.Combine<GrandParent, Parent, string>(
                gp => gp.Parent, p => p.Method(p.Child.Name), true);
        unchecked
        {
            Expression<Func<double, double>> mathUnchecked =
                ExpressionUtils.Combine<double, double, double>(x => (x * x) + x, x => x - (x / x), true);
        }
        checked
        {
            Expression<Func<double, double>> mathChecked =
                ExpressionUtils.Combine<double, double, double>(x => x - (x * x), x => (x / x) + x, true);
        }
        Expression<Func<int, int>> bitwise =
            ExpressionUtils.Combine<int, int, int>(x => (x & 0x01) | 0x03, x => x ^ 0xFF, true);
        Expression<Func<int, bool>> logical =
            ExpressionUtils.Combine<int, bool, bool>(x => x == 123, x => x != false, true);
        Expression<Func<int[][], int>> arrayAccess =
            ExpressionUtils.Combine<int[][], int[], int>(x => x[0], x => x[0], true);
        Expression<Func<string, bool>> isTest =
            ExpressionUtils.Combine<string, object, bool>(s => s, s => s is Regex, true);
        Expression<Func<List<int>>> f = () => new List<int>(new int[] { 1, 1, 1 }.Length);
        Expression<Func<string, Regex>> asTest =
            ExpressionUtils.Combine<string, object, Regex>(s => s, s => s as Regex, true);
        var initTest = ExpressionUtils.Combine<int, int[], List<int>>(i => new[] { i, i, i },
            arr => new List<int>(arr.Length), true);
        var anonAndListTest = ExpressionUtils.Combine<int, int, List<int>>(
            i => new { age = i }.age, i => new List<int> { i, i }, true);
        // Disabled: the rewriter throws NotImplementedException for NewArrayBounds/NewArrayInit.
        /*
        var arrBoundsInit = ExpressionUtils.Combine<int, int[], int[]>(
            i => new int[i], arr => new int[arr[0]] , true);
        var arrInit = ExpressionUtils.Combine<int, int, int[]>(
            i => i, i => new int[1] { i }, true);*/
    }
}
I am assuming that your goal is to obtain the expression tree that you would have obtained, had you actually compiled the "combined" lambda. It's much easier to construct a new expression tree that simply invokes the given expression trees appropriately, but I assume that's not what you want.
- extract the body of first, cast it to MemberExpression. Call this firstBody.
- extract the body of second, call this secondBody
- extract the parameter of first. Call this firstParam.
- extract the parameter of second. Call this secondParam.
- Now, the hard part. Write a visitor pattern implementation which searches through secondBody looking for the single usage of secondParam. (This will be much easier if you know that it's only member access expressions, but you can solve the problem in general.) When you find it, construct a new expression of the same type as its parent, substituting in firstBody for the parameter. Continue to rebuild the transformed tree on the way back out; remember, all you have to rebuild is the "spine" of the tree that contains the parameter reference.
- the result of the visitor pass will be a rewritten secondBody with no occurrences of secondParam, only occurrences of expressions involving firstParam.
- construct a new lambda expression with that body as its body, and firstParam as its param.
- and you're done!
Matt Warren's blog might be a good thing for you to read. He designed and implemented all this stuff and has written a lot about ways to rewrite expression trees effectively. (I only did the compiler end of things.)
UPDATE:
As this related answer points out, in .NET 4 there is now a base class for expression rewriters that makes this sort of thing a lot easier.
I'm not sure what you mean by it not being a nested function call, but this will do the trick - with an example:
using System;
using System.IO;
using System.Linq.Expressions;
/// <summary>
/// Demonstrates composing two lambda expressions by nesting invocation nodes.
/// </summary>
class Test
{
    /// <summary>
    /// Builds <c>x =&gt; second(first(x))</c> using two nested invocation expressions
    /// over a fresh parameter named "x".
    /// </summary>
    /// <param name="first">The lambda applied to the input.</param>
    /// <param name="second">The lambda applied to the result of <paramref name="first"/>.</param>
    /// <returns>The composed lambda expression.</returns>
    static Expression<Func<TOuter, TInner>> Combine<TOuter, TMiddle, TInner>
        (Expression<Func<TOuter, TMiddle>> first,
         Expression<Func<TMiddle, TInner>> second)
    {
        ParameterExpression input = Expression.Parameter(typeof(TOuter), "x");
        InvocationExpression middle = Expression.Invoke(first, input);
        InvocationExpression result = Expression.Invoke(second, middle);
        return Expression.Lambda<Func<TOuter, TInner>>(result, input);
    }

    /// <summary>
    /// Composes an int-to-string lambda with a string-to-reader lambda, runs the
    /// result on 10 and writes what the reader yields to the console.
    /// </summary>
    static void Main()
    {
        Expression<Func<int, string>> first = x => (x + 1).ToString();
        Expression<Func<string, StringReader>> second = y => new StringReader(y);

        Func<int, StringReader> compiled = Combine(first, second).Compile();
        StringReader reader = compiled(10);
        string contents = reader.ReadToEnd();
        Console.WriteLine(contents);
    }
}
I don't know how efficient the generated code will be compared with a single lambda expression, but I suspect it won't be too bad.
For a complete solution have a look at LINQKit:
// Compose the two lambdas by invoking myFirst on gp and mySecond on that result.
// NOTE(review): Invoke/Expand here are presumably the LINQKit extension methods named in the surrounding text — confirm.
Expression<Func<GrandParent, string>> output = gp => mySecond.Invoke(myFirst.Invoke(gp));
// Expand is applied twice; presumably one pass per nested Invoke level — verify against the LINQKit documentation.
output = output.Expand().Expand();
output.ToString()
prints out
gp => gp.Parent.Child.Name
whereas Jon Skeet's solution yields
x => Invoke(p => p.Child.Name,Invoke(gp => gp.Parent,x))
I guess that's what you're referring to as 'nested function calls'.
Try this:
/// <summary>
/// Builds a lambda that applies <paramref name="first"/> and then feeds its result
/// into <paramref name="second"/>.
/// </summary>
/// <remarks>
/// The composition is expressed with an invocation node over the body of
/// <paramref name="first"/>, so neither lambda is compiled here and no
/// <c>Compile()</c> call ends up embedded in the resulting tree.
/// </remarks>
/// <param name="first">The lambda applied to the input; its parameter is reused by the result.</param>
/// <param name="second">The lambda applied to the result of <paramref name="first"/>.</param>
/// <returns>A lambda equivalent to <c>x =&gt; second(first(x))</c>.</returns>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public static Expression<Func<TOuter, TInner>> Combine<TOuter, TMiddle, TInner>(
    Expression<Func<TOuter, TMiddle>> first,
    Expression<Func<TMiddle, TInner>> second)
{
    if (first == null) throw new ArgumentNullException("first");
    if (second == null) throw new ArgumentNullException("second");

    // Invoke second on first's body and keep first's parameter as the only input.
    var body = Expression.Invoke(second, first.Body);
    return Expression.Lambda<Func<TOuter, TInner>>(body, first.Parameters);
}
and the usage:
// The two lambdas to chain: GrandParent -> Parent, then Parent -> string.
Expression<Func<GrandParent, Parent>> myFirst = gp => gp.Parent;
Expression<Func<Parent, string>> mySecond = p => p.Child.Name;
Expression<Func<GrandParent, string>> output = Combine(myFirst, mySecond);

// Build the sample object graph from the leaf upwards.
var child = new Child { Name = "child name" };
var parent = new Parent { Child = child };
var grandParent = new GrandParent { Parent = parent };

// Compile the combined expression and run it against the sample graph.
Func<GrandParent, string> run = output.Compile();
var childName = run(grandParent);
Console.WriteLine(childName); // prints "child name"
0
/// <summary>
/// Combines two lambdas with a short-circuiting AND: the body of
/// <paramref name="expr1"/> is AND-ed with an invocation of <paramref name="expr2"/>
/// on the parameters of <paramref name="expr1"/>.
/// </summary>
/// <param name="expr1">The left operand; its parameters are reused by the result.</param>
/// <param name="expr2">The right operand, invoked with the left operand's parameters.</param>
/// <returns>A lambda whose body is <c>expr1.Body AndAlso Invoke(expr2, ...)</c>.</returns>
public static Expression<Func<T, TResult>> And<T, TResult>(this Expression<Func<T, TResult>> expr1, Expression<Func<T, TResult>> expr2)
{
    var sharedArgs = expr1.Parameters.Cast<Expression>();
    var rightSide = Expression.Invoke(expr2, sharedArgs);
    var conjunction = Expression.AndAlso(expr1.Body, rightSide);
    return Expression.Lambda<Func<T, TResult>>(conjunction, expr1.Parameters);
}
/// <summary>
/// Combines two predicates with a short-circuiting OR: the body of
/// <paramref name="expr1"/> is OR-ed with an invocation of <paramref name="expr2"/>
/// on the parameters of <paramref name="expr1"/>.
/// </summary>
/// <param name="expr1">The left operand; its parameters are reused by the result.</param>
/// <param name="expr2">The right operand, invoked with the left operand's parameters.</param>
/// <returns>A predicate whose body is <c>expr1.Body OrElse Invoke(expr2, ...)</c>.</returns>
public static Expression<Func<T, bool>> Or<T>(this Expression<Func<T, bool>> expr1, Expression<Func<T, bool>> expr2)
{
    var sharedArgs = expr1.Parameters.Cast<Expression>();
    var rightSide = Expression.Invoke(expr2, sharedArgs);
    var disjunction = Expression.OrElse(expr1.Body, rightSide);
    return Expression.Lambda<Func<T, bool>>(disjunction, expr1.Parameters);
}
0
After a half-day's digging I came up with the following solution (much simpler than the accepted answer):
For generic lambda composition:
/// <summary>
/// Composes two lambdas so that the result computes <c>f(g(x))</c>.
/// </summary>
/// <param name="f">The lambda applied second, to the result of <paramref name="g"/>.</param>
/// <param name="g">The lambda applied first; its parameters are reused by the result.</param>
/// <returns>A lambda whose body invokes <paramref name="f"/> on an invocation of <paramref name="g"/>.</returns>
public static Expression<Func<X, Z>> Compose<X, Y, Z>(Expression<Func<Y, Z>> f, Expression<Func<X, Y>> g)
{
    var input = g.Parameters[0];
    var innerCall = Expression.Invoke(g, input);
    var outerCall = Expression.Invoke(f, innerCall);
    return Expression.Lambda<Func<X, Z>>(outerCall, g.Parameters);
}
This combines two expressions into one, i.e. applies the first expression to the result of the second.
So if we have f(y) and g(x), combine(f,g)(x) === f(g(x))
Transitive and associative, so the combinator can be chained
More specifically, for property access (needed for MVC/EF):
/// <summary>
/// Builds a lambda that reads the member selected by <paramref name="fProp"/> from
/// the object produced by <paramref name="fObj"/>.
/// </summary>
/// <param name="fObj">Produces the object to read from; its parameters are reused by the result.</param>
/// <param name="fProp">
/// A simple member access such as <c>x =&gt; x.Prop</c>. Only the accessed member is
/// taken from it; properties and fields are both accepted.
/// </param>
/// <returns>A lambda whose body accesses the selected member on the body of <paramref name="fObj"/>.</returns>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
/// <exception cref="ArgumentException">The body of <paramref name="fProp"/> is not a member access.</exception>
public static Expression<Func<X, Z>> Property<X, Y, Z>(Expression<Func<X, Y>> fObj, Expression<Func<Y, Z>> fProp)
{
    if (fObj == null) throw new ArgumentNullException("fObj");
    if (fProp == null) throw new ArgumentNullException("fProp");

    // Fail with a descriptive error instead of a NullReferenceException from a blind cast.
    var access = fProp.Body as MemberExpression;
    if (access == null)
    {
        throw new ArgumentException("fProp must be a simple member access expression, such as x => x.Prop", "fProp");
    }

    // MakeMemberAccess handles both properties and fields.
    return Expression.Lambda<Func<X, Z>>(Expression.MakeMemberAccess(fObj.Body, access.Member), fObj.Parameters);
}
Note: fProp must be a simple property access expression, such as x => x.Prop.
fObj can be any expression (but must be MVC-compatible).
With a toolkit called Layer Over LINQ, there's an extension method that does exactly this: it combines two expressions to create a new one suitable for use in LINQ to Entities.
// The two lambdas to chain: GrandParent -> Parent, then Parent -> string.
Expression<Func<GrandParent, Parent>> myFirst = gp => gp.Parent;
Expression<Func<Parent, string>> mySecond = p => p.Child.Name;
// Chain is the extension method the surrounding text attributes to Layer Over LINQ.
Expression<Func<GrandParent, string>> output = myFirst.Chain(mySecond);
More Related questions
We use cookies to improve the performance of the site. By staying on our site, you agree to the terms of use of cookies.