号称面试的题目总是非常有趣的,这里是又一个例子:

【原题出处

【问题梗概】

求一个函数的一阶导数。

【代码方案】

 

[csharp] 
  1. namespace Derivative  
  2. {  
  3.     class Program  
  4.     {  
  5.         // 求一个节点表达的算式的导函数    
  6.         static Expression GetDerivative(Expression node)  
  7.         {  
  8.             if (node.NodeType == ExpressionType.Add  
  9.                 || node.NodeType == ExpressionType.Subtract)  
  10.             {   // 该节点在做加减法,套用加减法导数公式    
  11.                 BinaryExpression binexp = (BinaryExpression)node;  
  12.                 Expression dleft = GetDerivative(binexp.Left);  
  13.                 Expression dright = GetDerivative(binexp.Right);  
  14.                 BinaryExpression resbinexp;  
  15.   
  16.                 if (node.NodeType == ExpressionType.Add)  
  17.                     resbinexp = Expression.Add(dleft, dright);  
  18.                 else  
  19.                     resbinexp = Expression.Subtract(dleft, dright);  
  20.                 return resbinexp;  
  21.             }  
  22.             else if (node.NodeType == ExpressionType.Multiply)  
  23.             {   // 该节点在做乘法,套用乘法导数公式    
  24.                 BinaryExpression binexp = (BinaryExpression)node;  
  25.                 Expression left = binexp.Left;  
  26.                 Expression right = binexp.Right;  
  27.   
  28.                 Expression dleft = GetDerivative(left);  
  29.                 Expression dright = GetDerivative(right);  
  30.   
  31.                 return Expression.Add(Expression.Multiply(dleft, right),  
  32.                     Expression.Multiply(left, dright));  
  33.             }  
  34.             else if (node.NodeType == ExpressionType.Parameter)  
  35.             {   // 该节点是x本身(叶子节点),故而其导数即常数1    
  36.                 return Expression.Constant(1.0);  
  37.             }  
  38.             else if (node.NodeType == ExpressionType.Constant)  
  39.             {   // 该节点是一个常数(叶子节点),故其导数为零    
  40.                 return Expression.Constant(0.0);  
  41.             }  
  42.             else if (node.NodeType == ExpressionType.Call)  
  43.             {  
  44.                 MethodCallExpression callexp = (MethodCallExpression)node;  
  45.                 Expression arg0 = callexp.Arguments[0];  
  46.                 // 一下一元函数求导后均需要乘上自变量的导数  
  47.                 Expression darg0 = GetDerivative(arg0);  
  48.                 if (callexp.Method.Name == "Exp")  
  49.                 {  
  50.                     // 指数函数的导数还是其本身  
  51.                     return Expression.Multiply(  
  52.                            Expression.Call(null, callexp.Method, arg0), darg0);  
  53.                 }  
  54.                 else if (callexp.Method.Name == "Sin")  
  55.                 {  
  56.                     // 正弦函数的倒数是余弦函数  
  57.                     MethodInfo miCos = typeof(Math).GetMethod("Cos",   
  58.                                        BindingFlags.Public | BindingFlags.Static);  
  59.                     return Expression.Multiply(  
  60.                            Expression.Call(null, miCos, arg0), darg0);  
  61.                 }  
  62.                 else if (callexp.Method.Name == "Cos")  
  63.                 {  
  64.                     // 余弦函数的导数是正弦函数的相反数  
  65.                     MethodInfo miSin = typeof(Math).GetMethod("Sin",   
  66.                                        BindingFlags.Public | BindingFlags.Static);  
  67.                     return Expression.Multiply(  
  68.                            Expression.Negate(Expression.Call(null, miSin, arg0)), darg0);  
  69.                 }  
  70.             }  
  71.   
  72.             throw new NotImplementedException();    // 其余的尚未实现            
  73.         }  
  74.   
  75.         static Func<doubledouble> GetDerivative(Expression<Func<doubledouble>> func)  
  76.         {  
  77.             // 从Lambda表达式中获得函数体    
  78.             Expression resBody = GetDerivative(func.Body);  
  79.   
  80.             // 需要续用Lambda表达式的自变量    
  81.             ParameterExpression parX = func.Parameters[0];  
  82.   
  83.             Expression<Func<doubledouble>> resFunc  
  84.                 = (Expression<Func<doubledouble>>)Expression.Lambda(resBody, parX);  
  85.   
  86.             Console.WriteLine("diff function = {0}", resFunc);  
  87.   
  88.             // 编译成CLR的IL表达的函数    
  89.             return resFunc.Compile();  
  90.         }  
  91.   
  92.         static double GetDerivative(Expression<Func<doubledouble>> func, double x)  
  93.         {  
  94.             Func<doubledouble> diff = GetDerivative(func);  
  95.             return diff(x);  
  96.         }  
  97.   
  98.         static void Main(string[] args)  
  99.         {  
  100.             // 举例:求出函数f(x) = cos(x*x)+sin(3*x)+exp(2*x)在x=2.0处的导数    
  101.             double y = GetDerivative(x => Math.Cos(x*x) + Math.Sin(3*x) + Math.Exp(2*x), 2.0);  
  102.             Console.WriteLine("f'(x) = {0}", y);  
  103.         }  
  104.     }  
  105. }    

 

 

【实现大意】

用表达式分解并递归求导(过程是相当容易的,比想象的还容易)。目前只是实现了一个最简单的模型。
【优势】
给出的是解析解,在求导运算方面没有任何数值解的误差,输出运算也是瞬时的,时间复杂度仅和表达式复杂度相关。
【限制】
1. 函数只能以Lambda表达式输入,只能是能求出解析解的表达式
2. 目前只实现了加减法和乘法
【后续扩展】
1. 实现其他运算符(没有太大难度,只是比较繁琐而已)
2. 表达式树优化(也不太难的,根据情况定),最基本的可以从常数乘法开始……
3. 条件运算符的处理(这个会变得极难极复杂,但一定程度上实现分段函数求导),其他特殊情况(对求导还可以,如果考虑求不定积分问题可能会有很多特殊情况和hardcode)
4. 输入端向字符串解析过渡;复杂运算符->逐渐向自定义的数据结构过渡?……
...