Given a generic class definition like public class ConstrainedNumber<T>: IEquatable<ConstrainedNumber<T>>, IEquatable<T>, IComparable<ConstrainedNumber<T>>, IComparable<T>, IComparable where T:struct, IComparable, IComparable<T>, IEquatable<T>, how can I define arithmetic operators for it?

The following does not compile, because the '+' operator cannot be applied to types 'T' and 'T':

public static T operator +( ConstrainedNumber<T> x, ConstrainedNumber<T> y)
{
    return x._value + y._value;
}

The generic type 'T' is constrained with the 'where' keyword as you can see, but I need a constraint for number types that have arithmetic operators (IArithmetic?).

'T' will be a primitive number type such as int, float, etc. Is there a 'where' constraint for such types?

A:

Unfortunately there is no way to constrain a generic parameter to be an integral type (Edit: I guess "arithmetical type" might be a better word as this does not pertain to just integers).

It would be nice to be able to do something like this:

where T : integral // or "arithmetical" depending on how pedantic you are

or

where T : IArithmetic

I would suggest that you read Generic Operators by our very own Marc Gravell and Jon Skeet. It explains why this is such a difficult problem and what can be done to work around it.

.NET 2.0 introduced generics into the .NET world, which opened the door for many elegant solutions to existing problems. Generic constraints can be used to restrict the type-arguments to known interfaces etc, to ensure access to functionality - or for simple equality/inequality tests the Comparer<T>.Default and EqualityComparer<T>.Default singletons implement IComparer<T> and IEqualityComparer<T> respectively (allowing us to sort elements for instance, without having to know anything about the "T" in question).

With all this, though, there is still a big gap when it comes to operators. Because operators are declared as static methods, there is no IMath or similar equivalent interface that all the numeric types implement; and indeed, the flexibility of operators would make this very hard to do in a meaningful way. Worse: many of the operators on primitive types don't even exist as operators; instead there are direct IL methods. [emphasis mine] To make the situation even more complex, Nullable<> demands the concept of "lifted operators", where the inner "T" describes the operators applicable to the nullable type - but this is implemented as a language feature, and is not provided by the runtime (making reflection even more fun).
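The gap the article describes can be made concrete. In this sketch (the class and method names are hypothetical), a generic Max<T> needs no constraint at all because Comparer<T>.Default supplies the comparison. A reflection check then shows that decimal exposes a real public op_Addition method, while int does not, because int addition is the IL add instruction:

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

public static class OperatorGapDemo
{
    // Comparison needs no constraint: Comparer<T>.Default uses IComparable<T>
    // (falling back to IComparable) on T at run time.
    public static T Max<T>(IEnumerable<T> items)
    {
        Comparer<T> comparer = Comparer<T>.Default;
        bool any = false;
        T best = default(T);
        foreach (T item in items)
        {
            if (!any || comparer.Compare(item, best) > 0)
            {
                best = item;
                any = true;
            }
        }
        if (!any) throw new InvalidOperationException("Sequence contains no elements.");
        return best;
    }

    // There is no Comparer-style singleton for '+'. This checks whether a type
    // declares a public static op_Addition(type, type) method to bind to.
    public static bool HasPublicAddition(Type type)
    {
        MethodInfo method = type.GetMethod(
            "op_Addition",
            BindingFlags.Public | BindingFlags.Static,
            null,
            new[] { type, type },
            null);
        return method != null;
    }
}
```

The asymmetry is why a generic "call op_Addition" strategy alone cannot cover the primitive types.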

Andrew Hare
That article was almost entirely written by Marc Gravell, in fact - who also wrote the helper classes in MiscUtil.
Jon Skeet
I corrected the citation - thanks for pointing that out!
Andrew Hare
"...there is no IMath or similar equivalent interface that all the numeric types implement." Actually, all the numeric types implement IConvertible, so for every numeric type (byte, sbyte, short, ushort, int, uint, long, ulong, float, double, decimal) my arithmetic overloads can convert the operands with a call to IConvertible.ToDecimal (internally Convert.ToDecimal). The decimal type (28-29 significant digits) can store every integral type exactly, including Int64 and UInt64 (at most 20 digits). Float and double are the exception: converting a double keeps at most 15 significant digits, and values beyond decimal's range of about ±7.9e28 throw an OverflowException. Also note that decimal arithmetic is implemented in software, so it is slower than native integer or floating-point arithmetic.
Triynko
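The range claims about converting to decimal can be checked directly. A small sketch (the class and method names are hypothetical) showing that the integral extremes convert to decimal exactly, while double.MaxValue does not fit:

```csharp
using System;

public static class DecimalConversionDemo
{
    // The extremes of long and ulong convert to decimal without any loss.
    public static bool IntegralValuesConvertExactly()
    {
        return Convert.ToDecimal(long.MaxValue) == 9223372036854775807m
            && Convert.ToDecimal(long.MinValue) == -9223372036854775808m
            && Convert.ToDecimal(ulong.MaxValue) == 18446744073709551615m;
    }

    // double's range is far larger than decimal's, so this conversion throws.
    public static bool DoubleMaxOverflows()
    {
        try
        {
            Convert.ToDecimal(double.MaxValue);
            return false;
        }
        catch (OverflowException)
        {
            return true;
        }
    }
}
```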
A: 

Unfortunately, this is not possible as there is not an IArithmetic (as you said) interface defined for integers. You can wrap those primitive types in classes that do implement such an interface.
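A minimal sketch of the wrapping idea, assuming you are free to change what T is. IArithmeticValue<T>, IntValue, and DoubleValue are hypothetical names (nothing like them exists in the BCL), and ConstrainedNumber<T> is trimmed down to the members needed for the operators:

```csharp
using System;

// Hypothetical interface: the contract each wrapper type promises to fulfil.
public interface IArithmeticValue<T>
{
    T Add(T other);
    T Subtract(T other);
}

// Wraps an int and implements the interface with the native operators.
public struct IntValue : IArithmeticValue<IntValue>
{
    public readonly int Value;
    public IntValue(int value) { Value = value; }
    public IntValue Add(IntValue other) { return new IntValue(Value + other.Value); }
    public IntValue Subtract(IntValue other) { return new IntValue(Value - other.Value); }
}

// Wraps a double the same way.
public struct DoubleValue : IArithmeticValue<DoubleValue>
{
    public readonly double Value;
    public DoubleValue(double value) { Value = value; }
    public DoubleValue Add(DoubleValue other) { return new DoubleValue(Value + other.Value); }
    public DoubleValue Subtract(DoubleValue other) { return new DoubleValue(Value - other.Value); }
}

// The where clause now guarantees that Add and Subtract exist on T.
public class ConstrainedNumber<T> where T : struct, IArithmeticValue<T>
{
    private readonly T _value;

    public ConstrainedNumber(T value) { _value = value; }

    public T Value { get { return _value; } }

    public static ConstrainedNumber<T> operator +(ConstrainedNumber<T> x, ConstrainedNumber<T> y)
    {
        return new ConstrainedNumber<T>(x._value.Add(y._value));
    }

    public static ConstrainedNumber<T> operator -(ConstrainedNumber<T> x, ConstrainedNumber<T> y)
    {
        return new ConstrainedNumber<T>(x._value.Subtract(y._value));
    }
}
```

The trade-off is that callers work with IntValue rather than int, and every numeric type you want to support needs its own wrapper.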

Mehrdad Afshari
A:

No, this does not work. But there are some suggestions on how to solve the problem. I did the following (using some ideas from different sources on the net):

using System;
using System.Diagnostics;
using System.Reflection;
using System.Reflection.Emit;

public delegate TResult BinaryOperator<TLeft, TRight, TResult>(TLeft left, TRight right);

/// <summary>
/// Provide efficient generic access to either native or static operators for the given type combination.
/// </summary>
/// <typeparam name="TLeft">The type of the left operand.</typeparam>
/// <typeparam name="TRight">The type of the right operand.</typeparam>
/// <typeparam name="TResult">The type of the result value.</typeparam>
/// <remarks>Inspired by Keith Farmer's code on CodeProject:<br/>http://www.codeproject.com/KB/cs/genericoperators.aspx</remarks>
public static class Operator<TLeft, TRight, TResult> {
 private static BinaryOperator<TLeft, TRight, TResult> addition;
 private static BinaryOperator<TLeft, TRight, TResult> bitwiseAnd;
 private static BinaryOperator<TLeft, TRight, TResult> bitwiseOr;
 private static BinaryOperator<TLeft, TRight, TResult> division;
 private static BinaryOperator<TLeft, TRight, TResult> exclusiveOr;
 private static BinaryOperator<TLeft, TRight, TResult> leftShift;
 private static BinaryOperator<TLeft, TRight, TResult> modulus;
 private static BinaryOperator<TLeft, TRight, TResult> multiply;
 private static BinaryOperator<TLeft, TRight, TResult> rightShift;
 private static BinaryOperator<TLeft, TRight, TResult> subtraction;

 /// <summary>
 /// Gets the addition operator + (either native or "op_Addition").
 /// </summary>
 /// <value>The addition operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> Addition {
  get {
   if (addition == null) {
    addition = CreateOperator("op_Addition", OpCodes.Add);
   }
   return addition;
  }
 }

 /// <summary>
 /// Gets the modulus operator % (either native or "op_Modulus").
 /// </summary>
 /// <value>The modulus operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> Modulus {
  get {
   if (modulus == null) {
    modulus = CreateOperator("op_Modulus", OpCodes.Rem);
   }
   return modulus;
  }
 }

 /// <summary>
 /// Gets the exclusive or operator ^ (either native or "op_ExclusiveOr").
 /// </summary>
 /// <value>The exclusive or operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> ExclusiveOr {
  get {
   if (exclusiveOr == null) {
    exclusiveOr = CreateOperator("op_ExclusiveOr", OpCodes.Xor);
   }
   return exclusiveOr;
  }
 }

 /// <summary>
 /// Gets the bitwise and operator &amp; (either native or "op_BitwiseAnd").
 /// </summary>
 /// <value>The bitwise and operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> BitwiseAnd {
  get {
   if (bitwiseAnd == null) {
    bitwiseAnd = CreateOperator("op_BitwiseAnd", OpCodes.And);
   }
   return bitwiseAnd;
  }
 }

 /// <summary>
 /// Gets the division operator / (either native or "op_Division").
 /// </summary>
 /// <value>The division operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> Division {
  get {
   if (division == null) {
    division = CreateOperator("op_Division", OpCodes.Div);
   }
   return division;
  }
 }

 /// <summary>
 /// Gets the multiplication operator * (either native or "op_Multiply").
 /// </summary>
 /// <value>The multiplication operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> Multiply {
  get {
   if (multiply == null) {
    multiply = CreateOperator("op_Multiply", OpCodes.Mul);
   }
   return multiply;
  }
 }

 /// <summary>
 /// Gets the bitwise or operator | (either native or "op_BitwiseOr").
 /// </summary>
 /// <value>The bitwise or operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> BitwiseOr {
  get {
   if (bitwiseOr == null) {
    bitwiseOr = CreateOperator("op_BitwiseOr", OpCodes.Or);
   }
   return bitwiseOr;
  }
 }

 /// <summary>
 /// Gets the left shift operator &lt;&lt; (either native or "op_LeftShift").
 /// </summary>
 /// <value>The left shift operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> LeftShift {
  get {
   if (leftShift == null) {
    leftShift = CreateOperator("op_LeftShift", OpCodes.Shl);
   }
   return leftShift;
  }
 }

 /// <summary>
 /// Gets the right shift operator &gt;&gt; (either native or "op_RightShift").
 /// </summary>
 /// <value>The right shift operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> RightShift {
  get {
   if (rightShift == null) {
    rightShift = CreateOperator("op_RightShift", OpCodes.Shr);
   }
   return rightShift;
  }
 }

 /// <summary>
 /// Gets the subtraction operator - (either native or "op_Subtraction").
 /// </summary>
 /// <value>The subtraction operator.</value>
 public static BinaryOperator<TLeft, TRight, TResult> Subtraction {
  get {
   if (subtraction == null) {
    subtraction = CreateOperator("op_Subtraction", OpCodes.Sub);
   }
   return subtraction;
  }
 }

 private static BinaryOperator<TLeft, TRight, TResult> CreateOperator(string operatorName, OpCode opCode) {
  if (operatorName == null) {
   throw new ArgumentNullException("operatorName");
  }
  bool isPrimitive = true;
  bool isLeftNullable;
  bool isRightNullable = false;
  Type leftType = typeof(TLeft);
  Type rightType = typeof(TRight);
  MethodInfo operatorMethod = LookupOperatorMethod(ref leftType, operatorName, ref isPrimitive, out isLeftNullable) ??
                              LookupOperatorMethod(ref rightType, operatorName, ref isPrimitive, out isRightNullable);
  DynamicMethod method = new DynamicMethod(string.Format("{0}:{1}:{2}:{3}", operatorName, typeof(TLeft).FullName, typeof(TRight).FullName, typeof(TResult).FullName), typeof(TResult),
                                           new Type[] {typeof(TLeft), typeof(TRight)});
  Debug.WriteLine(method.Name, "Generating operator method");
  ILGenerator generator = method.GetILGenerator();
  if (isPrimitive) {
   Debug.WriteLine("Primitives using opcode", "Emitting operator code");
   generator.Emit(OpCodes.Ldarg_0);
   if (isLeftNullable) {
    generator.EmitCall(OpCodes.Call, typeof(TLeft).GetMethod("op_Explicit", BindingFlags.Public|BindingFlags.Static), null);
   }
   IlTypeHelper.ILType stackType = IlTypeHelper.EmitWidening(generator, IlTypeHelper.GetILType(leftType), IlTypeHelper.GetILType(rightType));
   generator.Emit(OpCodes.Ldarg_1);
   if (isRightNullable) {
    generator.EmitCall(OpCodes.Call, typeof(TRight).GetMethod("op_Explicit", BindingFlags.Public | BindingFlags.Static), null);
   }
   stackType = IlTypeHelper.EmitWidening(generator, IlTypeHelper.GetILType(rightType), stackType);
   generator.Emit(opCode);
   if (typeof(TResult) == typeof(object)) {
    generator.Emit(OpCodes.Box, IlTypeHelper.GetPrimitiveType(stackType));
   } else {
    Type resultType = typeof(TResult);
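    if (IsNullable(ref resultType)) {
     // Convert to the underlying type first so the Nullable<T> constructor gets a matching argument.
     IlTypeHelper.EmitExplicit(generator, stackType, IlTypeHelper.GetILType(resultType));
     generator.Emit(OpCodes.Newobj, typeof(TResult).GetConstructor(new Type[] {resultType}));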
    } else {
     IlTypeHelper.EmitExplicit(generator, stackType, IlTypeHelper.GetILType(resultType));
    }
   }
  } else if (operatorMethod != null) {
   Debug.WriteLine("Call to static operator method", "Emitting operator code");
   generator.Emit(OpCodes.Ldarg_0);
   generator.Emit(OpCodes.Ldarg_1);
   generator.EmitCall(OpCodes.Call, operatorMethod, null);
   if (typeof(TResult).IsPrimitive && operatorMethod.ReturnType.IsPrimitive) {
    IlTypeHelper.EmitExplicit(generator, IlTypeHelper.GetILType(operatorMethod.ReturnType), IlTypeHelper.GetILType(typeof(TResult)));
   } else if (!typeof(TResult).IsAssignableFrom(operatorMethod.ReturnType)) {
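    Debug.WriteLine("Conversion to return type", "Emitting operator code");
    // Convert.ChangeType takes and returns object: box the operator's result first, then unbox the converted value.
    if (operatorMethod.ReturnType.IsValueType) {
     generator.Emit(OpCodes.Box, operatorMethod.ReturnType);
    }
    generator.Emit(OpCodes.Ldtoken, typeof(TResult));
    generator.EmitCall(OpCodes.Call, typeof(Type).GetMethod("GetTypeFromHandle", new Type[] {typeof(RuntimeTypeHandle)}), null);
    generator.EmitCall(OpCodes.Call, typeof(Convert).GetMethod("ChangeType", new Type[] {typeof(object), typeof(Type)}), null);
    generator.Emit(OpCodes.Unbox_Any, typeof(TResult));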
   }
  } else {
   Debug.WriteLine("Throw NotSupportedException", "Emitting operator code");
   generator.ThrowException(typeof(NotSupportedException));
  }
  generator.Emit(OpCodes.Ret);
  return (BinaryOperator<TLeft, TRight, TResult>)method.CreateDelegate(typeof(BinaryOperator<TLeft, TRight, TResult>));
 }

 private static bool IsNullable(ref Type type) {
  if (type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(Nullable<>))) {
   type = type.GetGenericArguments()[0];
   return true;
  }
  return false;
 }

 private static MethodInfo LookupOperatorMethod(ref Type type, string operatorName, ref bool isPrimitive, out bool isNullable) {
  isNullable = IsNullable(ref type);
  if (!type.IsPrimitive) {
   isPrimitive = false;
   foreach (MethodInfo methodInfo in type.GetMethods(BindingFlags.Static|BindingFlags.Public)) {
    if (methodInfo.Name == operatorName) {
     bool isMatch = true;
     foreach (ParameterInfo parameterInfo in methodInfo.GetParameters()) {
      switch (parameterInfo.Position) {
      case 0:
       if (parameterInfo.ParameterType != typeof(TLeft)) {
        isMatch = false;
       }
       break;
      case 1:
       if (parameterInfo.ParameterType != typeof(TRight)) {
        isMatch = false;
       }
       break;
      default:
       isMatch = false;
       break;
      }
     }
     if (isMatch) {
      if (typeof(TResult).IsAssignableFrom(methodInfo.ReturnType) || typeof(IConvertible).IsAssignableFrom(methodInfo.ReturnType)) {
       return methodInfo; // full signature match
      }
     }
    }
   }
  }
  return null;
 }
}

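/// <summary>
/// Maps primitive CLR types to IL evaluation-stack categories and emits the widening
/// and narrowing conversions (conv.*) needed by the primitive operator path.
/// </summary>
internal static class IlTypeHelper {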
 [Flags]
 public enum ILType {
  None = 0,
  Unsigned = 1,
  B8 = 2,
  B16 = 4,
  B32 = 8,
  B64 = 16,
  Real = 32,
  I1 = B8, // 2
  U1 = B8|Unsigned, // 3
  I2 = B16, // 4
  U2 = B16|Unsigned, // 5
  I4 = B32, // 8
  U4 = B32|Unsigned, // 9
  I8 = B64, //16
  U8 = B64|Unsigned, //17
  R4 = B32|Real, //40
  R8 = B64|Real //48
 }

 public static ILType GetILType(Type type) {
  if (type == null) {
   throw new ArgumentNullException("type");
  }
  if (!type.IsPrimitive) {
   throw new ArgumentException("IL native operations requires primitive types", "type");
  }
  if (type == typeof(double)) {
   return ILType.R8;
  }
  if (type == typeof(float)) {
   return ILType.R4;
  }
  if (type == typeof(ulong)) {
   return ILType.U8;
  }
  if (type == typeof(long)) {
   return ILType.I8;
  }
  if (type == typeof(uint)) {
   return ILType.U4;
  }
  if (type == typeof(int)) {
   return ILType.I4;
  }
  if (type == typeof(short)) {
   return ILType.I2;
  }
  if (type == typeof(ushort)) {
   return ILType.U2;
  }
  if (type == typeof(byte)) {
   return ILType.U1;
  }
  if (type == typeof(sbyte)) {
   return ILType.I1;
  }
  return ILType.None;
 }

 public static Type GetPrimitiveType(ILType iLType) {
  switch (iLType) {
  case ILType.R8:
   return typeof(double);
  case ILType.R4:
   return typeof(float);
  case ILType.U8:
   return typeof(ulong);
  case ILType.I8:
   return typeof(long);
  case ILType.U4:
   return typeof(uint);
  case ILType.I4:
   return typeof(int);
  case ILType.U2:
   return typeof(ushort);
  case ILType.I2:
   return typeof(short);
  case ILType.U1:
   return typeof(byte);
  case ILType.I1:
   return typeof(sbyte);
  }
  throw new ArgumentOutOfRangeException("iLType");
 }

 public static ILType EmitWidening(ILGenerator generator, ILType onStackIL, ILType otherIL) {
  if (generator == null) {
   throw new ArgumentNullException("generator");
  }
  if (onStackIL == ILType.None) {
   throw new ArgumentException("Stack needs a value", "onStackIL");
  }
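  // The IL evaluation stack holds 8- and 16-bit integers as int32, so treat them as I4.
  // (Promoting straight to I8 skipped the conv.i8 needed for mixed int/long operands.)
  if (onStackIL < ILType.I4) {
   onStackIL = ILType.I4;
  }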
  if ((onStackIL < otherIL) && (onStackIL != ILType.R4)) {
   switch (otherIL) {
   case ILType.R4:
   case ILType.R8:
    if ((onStackIL&ILType.Unsigned) == ILType.Unsigned) {
     generator.Emit(OpCodes.Conv_R_Un);
    } else if (onStackIL != ILType.R4) {
     generator.Emit(OpCodes.Conv_R8);
    } else {
     return ILType.R4;
    }
    return ILType.R8;
   case ILType.U8:
   case ILType.I8:
    if ((onStackIL&ILType.Unsigned) == ILType.Unsigned) {
     generator.Emit(OpCodes.Conv_U8);
     return ILType.U8;
    }
    if (onStackIL != ILType.I8) {
     generator.Emit(OpCodes.Conv_I8);
    }
    return ILType.I8;
   }
  }
  return onStackIL;
 }

 public static void EmitExplicit(ILGenerator generator, ILType onStackIL, ILType otherIL) {
  if (otherIL != onStackIL) {
   switch (otherIL) {
   case ILType.I1:
    generator.Emit(OpCodes.Conv_I1);
    break;
   case ILType.I2:
    generator.Emit(OpCodes.Conv_I2);
    break;
   case ILType.I4:
    generator.Emit(OpCodes.Conv_I4);
    break;
   case ILType.I8:
    generator.Emit(OpCodes.Conv_I8);
    break;
   case ILType.U1:
    generator.Emit(OpCodes.Conv_U1);
    break;
   case ILType.U2:
    generator.Emit(OpCodes.Conv_U2);
    break;
   case ILType.U4:
    generator.Emit(OpCodes.Conv_U4);
    break;
   case ILType.U8:
    generator.Emit(OpCodes.Conv_U8);
    break;
   case ILType.R4:
    generator.Emit(OpCodes.Conv_R4);
    break;
   case ILType.R8:
    generator.Emit(OpCodes.Conv_R8);
    break;
   }
  }
 }
}

Use like this: int i = Operator<int, int, int>.Addition(3, 5);
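The heart of the dynamic-method technique can be shown in a much smaller sketch. It is not the Operator class above: the names EmittedAdd and Adder<T> are hypothetical, and it only handles two operands of the same primitive type (such as int, long, float, or double), with none of the widening, nullable, or op_Addition lookup logic:

```csharp
using System;
using System.Reflection.Emit;

public static class EmittedAdd
{
    // Builds a delegate whose body is just: ldarg.0, ldarg.1, add, ret.
    // It never looks up op_Addition, so it only suits primitive numeric types.
    public static Func<T, T, T> Create<T>()
    {
        var method = new DynamicMethod("Add_" + typeof(T).Name, typeof(T), new[] { typeof(T), typeof(T) });
        ILGenerator il = method.GetILGenerator();
        il.Emit(OpCodes.Ldarg_0);
        il.Emit(OpCodes.Ldarg_1);
        il.Emit(OpCodes.Add);
        il.Emit(OpCodes.Ret);
        return (Func<T, T, T>)method.CreateDelegate(typeof(Func<T, T, T>));
    }
}

// Caches one delegate per T, so the IL is generated only once per type.
public static class Adder<T>
{
    public static readonly Func<T, T, T> Add = EmittedAdd.Create<T>();
}
```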

Lucero
An interesting approach, but it looks like it belongs in a C# compiler, rather than a C# program. It looks like this is an automated version of manually building an IArithmetic<T> class, and implementations of it such as IntArithmetic:IArithmetic<int>, etc., which would provide compile-time checks.
Triynko
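A minimal sketch of the manual alternative described in that comment, with one small "calculator" type per primitive. IArithmetic<T>, IntArithmetic, DoubleArithmetic, and Numeric.Dot are hypothetical names used for illustration; everything here is checked at compile time:

```csharp
using System;

// Hypothetical calculator interface: one implementation per numeric type.
public interface IArithmetic<T>
{
    T Add(T x, T y);
    T Multiply(T x, T y);
}

public struct IntArithmetic : IArithmetic<int>
{
    public int Add(int x, int y) { return x + y; }
    public int Multiply(int x, int y) { return x * y; }
}

public struct DoubleArithmetic : IArithmetic<double>
{
    public double Add(double x, double y) { return x + y; }
    public double Multiply(double x, double y) { return x * y; }
}

public static class Numeric
{
    // TMath is a struct, so default(TMath) needs no allocation, and the JIT
    // generates specialised code for each value-type argument.
    public static T Dot<T, TMath>(T[] a, T[] b) where TMath : struct, IArithmetic<T>
    {
        if (a.Length != b.Length) throw new ArgumentException("Arrays must have the same length.");
        TMath math = default(TMath);
        T sum = default(T); // zero for the built-in numeric types
        for (int i = 0; i < a.Length; i++)
        {
            sum = math.Add(sum, math.Multiply(a[i], b[i]));
        }
        return sum;
    }
}
```

Usage: Numeric.Dot<int, IntArithmetic>(new[] { 1, 2, 3 }, new[] { 4, 5, 6 }) computes 1*4 + 2*5 + 3*6. The cost is an extra type argument at every call site.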
I've regretted several times so far that generic type constraints were not more flexible. Two things: specifying method signatures, for instance through implicit interface mapping, would be great (but I understand the technical challenges of this). Second, more BCL interfaces for the primitives etc.
Lucero
(comments are too short ;) ) Indeed, this is an approach to overcome some of these limitations. But it will obviously not provide compile-time checking, except perhaps if combined with a toolset such as PostSharp to inject "good" code directly and complain about invalid combinations.
Lucero
Clever approach
neontapir
A:

There is currently no support in .NET generics for indicating that a type supports operators.

This is an oft-requested feature.

It can be partially worked around (see MiscUtil), but this will not give you the syntax you desire.

ShuggyCoUk
A:

I think the best you'd be able to do is use IConvertible as a constraint and do something like:

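// Declared inside ConstrainedNumber<T>, with IConvertible added to the class's where clause:
// an operator cannot have its own where clause, and one operand must be the containing type.
// Requires: using System; using System.Globalization;
public static T operator +(ConstrainedNumber<T> x, ConstrainedNumber<T> y)
{
    var type = typeof(T);
    if (type == typeof(String) ||
        type == typeof(DateTime)) throw new NotSupportedException(String.Format("The type {0} is not supported", type.FullName));

    try
    {
        double sum = x._value.ToDouble(NumberFormatInfo.CurrentInfo) + y._value.ToDouble(NumberFormatInfo.CurrentInfo);
        // (T)(Object)sum would throw InvalidCastException unless T is double, so convert explicitly.
        return (T)Convert.ChangeType(sum, type, NumberFormatInfo.CurrentInfo);
    }
    catch (Exception ex) { throw new InvalidOperationException("The operation failed.", ex); }
}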

The IConvertible constraint alone won't stop someone from using a String or DateTime, which is why the code checks for them manually; still, IConvertible should get you close enough and lets you perform the operation.

Daniel Schaffer
I actually changed it to a Double, since it has the widest range of the built-in numeric types, and then it gets boxed back to whatever T is. This should work, but I make no guarantees about its performance :D
Daniel Schaffer
Yeah but then you'd have to do it specifically for each type. This is a bit hacky, true, but it should give you coverage for everything.
Daniel Schaffer
Argh stop deleting your comments :D
Daniel Schaffer
Ok! +Involves writing no new code. +Primitive number types all implement IConvertible. +Subclasses always specify 'T' as a primitive numeric type, so there is no risk of error. +The where constraint involves no cast to the interface, so 'ToDouble' is as fast as a non-generic call, and the ToDouble implementation generally just returns 'this'.
Triynko
Sorry about the comment rush. I was analyzing the pros/cons of this answer. If subclasses using ConstrainedNumber<T> always specify T as a numeric type, the type-checking and error handling are unnecessary, which is a huge plus. The only downside I see is the required highest-precision arithmetic and boxing.
Triynko
np :) generics make my head spin sometimes too!
Daniel Schaffer
I could actually avoid the boxing by just having the operator return the double, since arithmetic expressions using the operator can handle casting (without boxing) as necessary. I don't have to convert back to ConstrainedNumber<T> until the point of assignment anyway, which is where the constraining matters.
Triynko
Performance-wise, I'm not worried about ToDouble. It just calls the static, type-specific overload Convert.ToDouble(int x), which just returns x, since it's an implicit conversion! Since we constrained T to IConvertible with 'where', it's checked at compile time, so ToDouble is a standard call.
Triynko
NumberFormatInfo.CurrentInfo can be null instead, since ToDouble ignores the parameter for the primitive types (according to Reflector). Final code: public static double operator +( ConstrainedNumber<T> x, ConstrainedNumber<T> y) {return x._value.ToDouble(null) + y._value.ToDouble(null);}
Triynko
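For reference, here is that final operator placed in a compilable class. This is a minimal sketch: only the members needed for the operator are shown (the question's IEquatable/IComparable implementations are omitted), and the constructor, Value property, and FromDouble helper are hypothetical additions for the round trip back to T:

```csharp
using System;

public class ConstrainedNumber<T> where T : struct, IComparable, IComparable<T>, IEquatable<T>, IConvertible
{
    private readonly T _value;

    public ConstrainedNumber(T value) { _value = value; }

    public T Value { get { return _value; } }

    // Both operands are widened to double; the caller decides when to convert back.
    public static double operator +(ConstrainedNumber<T> x, ConstrainedNumber<T> y)
    {
        // A null IFormatProvider is fine for the numeric primitives.
        return x._value.ToDouble(null) + y._value.ToDouble(null);
    }

    // Converting back uses Convert.ChangeType; a plain (T)(object)d cast
    // would throw InvalidCastException unless T is double.
    public static ConstrainedNumber<T> FromDouble(double d)
    {
        return new ConstrainedNumber<T>((T)Convert.ChangeType(d, typeof(T)));
    }
}
```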
A: 

I have seen some potential solutions involving expression trees, where the operator expression is created manually.

It's not perfect, because you lose compile-time verification, but it might do the trick for you.

Here's an article about that.
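A minimal sketch of the expression-tree idea (the class name GenericOperator<T> is hypothetical, and this is not the code from the article mentioned above). Expression.Add binds to the IL add instruction for primitive types and to op_Addition for types such as decimal:

```csharp
using System;
using System.Linq.Expressions;

// Compiles "(a, b) => a + b" once per T. If T defines no + operator,
// Expression.Add throws InvalidOperationException, which surfaces as a
// TypeInitializationException on first use of GenericOperator<T>.
public static class GenericOperator<T>
{
    public static readonly Func<T, T, T> Add = CreateAdd();

    private static Func<T, T, T> CreateAdd()
    {
        ParameterExpression a = Expression.Parameter(typeof(T), "a");
        ParameterExpression b = Expression.Parameter(typeof(T), "b");
        BinaryExpression body = Expression.Add(a, b);
        return Expression.Lambda<Func<T, T, T>>(body, a, b).Compile();
    }
}
```

Inside ConstrainedNumber<T>, the operator body could then be return GenericOperator<T>.Add(x._value, y._value); with the type check deferred to run time.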

Denis Troller
A: 

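There aren't any constraints available for that, but there is a way to get around the problem:

// Declared inside ConstrainedNumber<T>: a user-defined operator needs at least one operand
// of the containing type, so it cannot take two bare T values.
public static T operator -(ConstrainedNumber<T> foo, ConstrainedNumber<T> bar)
{
    return (T)System.Convert.ChangeType(
            System.Convert.ToDecimal(foo._value)
                -
            System.Convert.ToDecimal(bar._value),
                typeof(T));
}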
GoodEnough
A: 

I just did this after looking here. My Vector4<T> class holds four numbers of type T (one per axis) and provides the usual vector math. Just add two implicit operators to convert to and from Vector4<Decimal>. This is probably as concise as you're going to get, but as you point out, it is more precise, and thus heavier, than it needs to be. Like you guys, I wish there were an INumeric or something!


public static Vector4<T> operator +(Vector4<T> a, Vector4<T> b)
{
    Vector4<Decimal> A = a;
    Vector4<Decimal> B = b;

    var result = new Vector4<Decimal>(A.X + B.X, A.Y + B.Y, A.Z + B.Z, A.W + B.W);

    return result;
}

public static implicit operator Vector4<Decimal>(Vector4<T> v)
{
    return new Vector4<Decimal>(
        Convert.ToDecimal(v.X), 
        Convert.ToDecimal(v.Y), 
        Convert.ToDecimal(v.Z), 
        Convert.ToDecimal(v.W));
}

public static implicit operator Vector4<T>(Vector4<Decimal> v)
{
    return new Vector4<T>(
        (T)Convert.ChangeType(v.X, typeof(T)), 
        (T)Convert.ChangeType(v.Y, typeof(T)), 
        (T)Convert.ChangeType(v.Z, typeof(T)), 
        (T)Convert.ChangeType(v.W, typeof(T)));
}

George R