Skip to content

Commit a7fdcbc

Browse files
shashi4uOceania2018
authored andcommitted
Added support for user defined decimal precision for np.around() and TensorEngine.Round()
1 parent 1fed94d commit a7fdcbc

3 files changed

Lines changed: 107 additions & 0 deletions

File tree

src/NumSharp.Core/Backends/Default/Math/Default.Round.cs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ public partial class DefaultEngine
99
{
1010
public override NDArray Round(in NDArray nd, Type dtype) => Round(nd, dtype?.GetTypeCode());
1111

12+
public override NDArray Round(in NDArray nd, int decimals, Type dtype) => Round(nd, decimals, dtype?.GetTypeCode());
13+
1214
public override NDArray Round(in NDArray nd, NPTypeCode? typeCode = null)
1315
{
1416
if (nd.size == 0)
@@ -59,6 +61,61 @@ public override NDArray Round(in NDArray nd, NPTypeCode? typeCode = null)
5961
}
6062
default:
6163
throw new NotSupportedException();
64+
#endif
65+
}
66+
}
67+
}
68+
69+
public override NDArray Round(in NDArray nd, int decimals, NPTypeCode? typeCode = null)
70+
{
71+
if (nd.size == 0)
72+
return nd.Clone();
73+
74+
var @out = Cast(nd, ResolveUnaryReturnType(nd, typeCode), copy: true);
75+
var len = @out.size;
76+
77+
unsafe
78+
{
79+
switch (@out.GetTypeCode)
80+
{
81+
#if _REGEN
82+
%foreach except(supported_numericals, "Decimal"),except(supported_numericals_lowercase, "decimal")%
83+
case NPTypeCode.#1:
84+
{
85+
var out_addr = (#2*)@out.Address;
86+
for (int i = 0; i < len; i++) out_addr[i] = Converts.To#1(Math.Round(out_addr[i], decimals));
87+
return @out;
88+
}
89+
%
90+
case NPTypeCode.Decimal:
91+
{
92+
var out_addr = (decimal*)@out.Address;
93+
for (int i = 0; i < len; i++) out_addr[i] = (DecimalEx.Round(out_addr[i], decimals));
94+
return @out;
95+
}
96+
default:
97+
throw new NotSupportedException();
98+
#else
99+
case NPTypeCode.Double:
100+
{
101+
var out_addr = (double*)@out.Address;
102+
for (int i = 0; i < len; i++) out_addr[i] = Converts.ToDouble(Math.Round(out_addr[i], decimals));
103+
return @out;
104+
}
105+
case NPTypeCode.Single:
106+
{
107+
var out_addr = (float*)@out.Address;
108+
for (int i = 0; i < len; i++) out_addr[i] = Converts.ToSingle(Math.Round(out_addr[i], decimals));
109+
return @out;
110+
}
111+
case NPTypeCode.Decimal:
112+
{
113+
var out_addr = (decimal*)@out.Address;
114+
for (int i = 0; i < len; i++) out_addr[i] = (decimal.Round(out_addr[i], decimals));
115+
return @out;
116+
}
117+
default:
118+
throw new NotSupportedException();
62119
#endif
63120
}
64121
}

src/NumSharp.Core/Backends/TensorEngine.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ public abstract class TensorEngine
7979
public abstract NDArray Ceil(in NDArray nd, Type dtype);
8080
public abstract NDArray Ceil(in NDArray nd, NPTypeCode? typeCode = null);
8181
public abstract NDArray Round(in NDArray nd, Type dtype);
82+
public abstract NDArray Round(in NDArray nd, int decimals, Type dtype);
8283
public abstract NDArray Round(in NDArray nd, NPTypeCode? typeCode = null);
84+
public abstract NDArray Round(in NDArray nd, int decimals, NPTypeCode? typeCode = null);
8385
public abstract (NDArray Fractional, NDArray Intergral) ModF(in NDArray nd, Type dtype);
8486
public abstract (NDArray Fractional, NDArray Intergral) ModF(in NDArray nd, NPTypeCode? typeCode = null);
8587

src/NumSharp.Core/Math/np.round.cs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ public static partial class np
1616
public static NDArray round_(in NDArray x, NPTypeCode? outType = null)
1717
=> x.TensorEngine.Round(x, outType);
1818

19+
/// <summary>
20+
/// Evenly round to the given number of decimals.
21+
/// </summary>
22+
/// <param name="x">Input array</param>
23+
/// <param name="decimals">Number of decimal places to round to</param>
24+
/// <param name="outType">The dtype the returned ndarray should be of, only non integer values are supported.</param>
25+
/// <returns>An array of the same type as a, containing the rounded values. Unless out was specified, a new array is created. A reference to the result is returned.
26+
/// The real and imaginary parts of complex numbers are rounded separately.The result of rounding a float is a float.</returns>
27+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.around.html</remarks>
28+
public static NDArray round_(in NDArray x, int decimals, NPTypeCode? outType = null)
29+
=> x.TensorEngine.Round(x, decimals, outType);
30+
1931
/// <summary>
2032
/// Evenly round to the given number of decimals.
2133
/// </summary>
@@ -27,6 +39,18 @@ public static NDArray round_(in NDArray x, NPTypeCode? outType = null)
2739
public static NDArray round_(in NDArray x, Type outType)
2840
=> x.TensorEngine.Round(x, outType);
2941

42+
/// <summary>
43+
/// Evenly round to the given number of decimals.
44+
/// </summary>
45+
/// <param name="x">Input array</param>
46+
/// <param name="decimals">Number of decimal places to round to</param>
47+
/// <param name="outType">The dtype the returned ndarray should be of, only non integer values are supported.</param>
48+
/// <returns>An array of the same type as a, containing the rounded values. Unless out was specified, a new array is created. A reference to the result is returned.
49+
/// The real and imaginary parts of complex numbers are rounded separately.The result of rounding a float is a float.</returns>
50+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.around.html</remarks>
51+
public static NDArray round_(in NDArray x, int decimals, Type outType)
52+
=> x.TensorEngine.Round(x, decimals, outType);
53+
3054
/// <summary>
3155
/// Evenly round to the given number of decimals.
3256
/// </summary>
@@ -38,6 +62,18 @@ public static NDArray round_(in NDArray x, Type outType)
3862
public static NDArray around(in NDArray x, NPTypeCode? outType = null)
3963
=> x.TensorEngine.Round(x, outType);
4064

65+
/// <summary>
66+
/// Evenly round to the given number of decimals.
67+
/// </summary>
68+
/// <param name="x">Input array</param>
69+
/// <param name="decimals">Number of decimal places to round to</param>
70+
/// <param name="outType">The dtype the returned ndarray should be of, only non integer values are supported.</param>
71+
/// <returns>An array of the same type as a, containing the rounded values. Unless out was specified, a new array is created. A reference to the result is returned.
72+
/// The real and imaginary parts of complex numbers are rounded separately.The result of rounding a float is a float.</returns>
73+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.around.html</remarks>
74+
public static NDArray around(in NDArray x, int decimals, NPTypeCode? outType = null)
75+
=> x.TensorEngine.Round(x, decimals, outType);
76+
4177
/// <summary>
4278
/// Evenly round to the given number of decimals.
4379
/// </summary>
@@ -48,5 +84,17 @@ public static NDArray around(in NDArray x, NPTypeCode? outType = null)
4884
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.around.html</remarks>
4985
public static NDArray around(in NDArray x, Type outType)
5086
=> x.TensorEngine.Round(x, outType);
87+
88+
/// <summary>
89+
/// Evenly round to the given number of decimals.
90+
/// </summary>
91+
/// <param name="x">Input array</param>
92+
/// <param name="decimals">Number of decimal places to round to</param>
93+
/// <param name="outType">The dtype the returned ndarray should be of, only non integer values are supported.</param>
94+
/// <returns>An array of the same type as a, containing the rounded values. Unless out was specified, a new array is created. A reference to the result is returned.
95+
/// The real and imaginary parts of complex numbers are rounded separately.The result of rounding a float is a float.</returns>
96+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.around.html</remarks>
97+
public static NDArray around(in NDArray x, int decimals, Type outType)
98+
=> x.TensorEngine.Round(x, decimals, outType);
5199
}
52100
}

0 commit comments

Comments
 (0)