Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions datafusion/functions/src/math/factorial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use arrow::array::{ArrayRef, AsArray, Int64Array};
use arrow::array::{ArrayRef, AsArray, Decimal256Array};
use std::sync::Arc;

use arrow::datatypes::DataType::Int64;
use arrow::datatypes::{DataType, Int64Type};
use arrow::datatypes::DataType::{Decimal256, Int64};
use arrow::datatypes::{DECIMAL256_MAX_PRECISION, DataType, Int64Type};
use arrow_buffer::i256;

use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
DataFusionError, Result, ScalarValue, internal_err, utils::take_function_args,
};
use datafusion_expr::{
ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Expand Down Expand Up @@ -63,6 +64,8 @@ impl FactorialFunc {
}
}

/// Data type reported by `factorial`'s `return_type`: a `Decimal256` with
/// `DECIMAL256_MAX_PRECISION` digits of precision and a scale of 0.
const FACTORIAL_RETURN_TYPE: DataType = Decimal256(DECIMAL256_MAX_PRECISION, 0);

impl ScalarUDFImpl for FactorialFunc {
fn name(&self) -> &str {
"factorial"
Expand All @@ -73,7 +76,7 @@ impl ScalarUDFImpl for FactorialFunc {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Int64)
Ok(FACTORIAL_RETURN_TYPE)
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand All @@ -82,13 +85,21 @@ impl ScalarUDFImpl for FactorialFunc {
match arg {
ColumnarValue::Scalar(scalar) => {
if scalar.is_null() {
return Ok(ColumnarValue::Scalar(ScalarValue::Int64(None)));
return Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
None,
DECIMAL256_MAX_PRECISION,
0,
)));
}

match scalar {
ScalarValue::Int64(Some(v)) => {
let result = compute_factorial(v)?;
Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(result))))
Ok(ColumnarValue::Scalar(ScalarValue::Decimal256(
Some(result),
DECIMAL256_MAX_PRECISION,
0,
)))
}
_ => {
internal_err!(
Expand All @@ -100,9 +111,13 @@ impl ScalarUDFImpl for FactorialFunc {
}
ColumnarValue::Array(array) => match array.data_type() {
Int64 => {
let result: Int64Array = array
let result = array
.as_primitive::<Int64Type>()
.try_unary(compute_factorial)?;
.iter()
.map(|value| value.map(compute_factorial).transpose())
.collect::<Result<Vec<_>>>()?;
let result = Decimal256Array::from(result)
.with_precision_and_scale(DECIMAL256_MAX_PRECISION, 0)?;
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
}
other => {
Expand Down Expand Up @@ -139,14 +154,23 @@ const FACTORIALS: [i64; 21] = [
6402373705728000,
121645100408832000,
2432902008176640000,
]; // if return type changes, this constant needs to be updated accordingly
];

fn compute_factorial(n: i64) -> Result<i64> {
fn compute_factorial(n: i64) -> Result<i256> {
if n < 0 {
Ok(1)
} else if n < FACTORIALS.len() as i64 {
Ok(FACTORIALS[n as usize])
} else {
exec_err!("Overflow happened on FACTORIAL({n})")
return Ok(i256::from(1));
}

if n < FACTORIALS.len() as i64 {
return Ok(i256::from(FACTORIALS[n as usize]));
}

let mut result = i256::from(FACTORIALS[FACTORIALS.len() - 1]);
for value in FACTORIALS.len() as i64..=n {
result = result.checked_mul(i256::from(value)).ok_or_else(|| {
DataFusionError::Execution(format!("Overflow happened on FACTORIAL({n})"))
})?;
}

Ok(result)
}
16 changes: 13 additions & 3 deletions datafusion/sqllogictest/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -461,19 +461,29 @@ select round(exp(a), 5), round(exp(e), 5), round(exp(f), 5) from signed_integers
## factorial

# factorial scalar function
query III rowsort
query RRR rowsort
select factorial(0), factorial(10), factorial(15);
----
1 3628800 1307674368000

query TR
select arrow_typeof(factorial(21)), factorial(21);
----
Decimal256(76, 0) 51090942171709440000

query R
select factorial(21);
----
51090942171709440000

# factorial scalar nulls
query I rowsort
query R rowsort
select factorial(null);
----
NULL

# factorial with columns
query III rowsort
query RRR rowsort
select factorial(a), factorial(e), factorial(f) from unsigned_integers;
----
1 24 3628800
Expand Down
Loading