From a6a381cf088bd3fd0bd5478004c494374adae01d Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Fri, 24 Nov 2023 15:05:55 +0100 Subject: [PATCH 1/4] ok --- CHANGELOG.md | 1 + src/data_type/function.rs | 126 +++++- src/data_type/value.rs | 4 +- src/differential_privacy/aggregates.rs | 511 ++++++++++++++++++++++++- src/expr/aggregate.rs | 12 + src/expr/implementation.rs | 2 +- src/expr/mod.rs | 2 +- src/expr/sql.rs | 75 ++++ src/relation/rewriting.rs | 34 ++ src/sql/expr.rs | 11 +- src/sql/mod.rs | 1 + 11 files changed, 760 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71b4cba6..a7e1bc0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ## Added +- implemented `DISTINCT`` in aggregations [#197](https://github.com/Qrlew/qrlew/issues/197) - Implemented math functions: `PI`, `DEGREES`, `TAN`, `RANDOM`, `LOG10`, `LOG2`, `SQUARE` [#196](https://github.com/Qrlew/qrlew/issues/196) ## [0.5.2] - 2023-11-19 diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 00afd453..ccb4d8b2 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -1,7 +1,7 @@ use std::{ borrow::BorrowMut, cell::RefCell, - cmp, collections, + cmp, collections::{self, HashSet}, convert::{Infallible, TryFrom, TryInto}, error, fmt, hash::Hasher, @@ -1906,6 +1906,21 @@ pub fn mean() -> impl Function { ) } +/// Mean distinct aggregation +pub fn mean_distinct() -> impl Function { + // Only works on types that can be converted to floats + Aggregate::from( + data_type::Float::full(), + |values| { + let (count, sum) = values.into_iter().collect::>().into_iter().fold((0.0, 0.0), |(count, sum), value| { + (count + 1.0, sum + f64::from(value)) + }); + (sum / count).into() + }, + |(intervals, _size)| Ok(intervals.into_interval()), + ) +} + /// Aggregate as a list pub fn list() -> impl Function { null() @@ -1935,6 +1950,30 @@ pub fn count() -> impl Function { )) } +/// Count distinct aggregation +pub fn count_distinct() -> impl Function { + Polymorphic::from(( + // Any implementation + Aggregate::from( + DataType::Any, + |values| (values.iter().cloned().collect::>().len()as i64).into(), + |(_dt, size)| Ok(data_type::Integer::from_interval(1, *size.max().unwrap())), + ), + // Optional implementation + Aggregate::from( + data_type::Optional::from(DataType::Any), + |values| { + values + .iter() + .filter_map(|value| value.as_ref().and(Some(1))) + .sum::() + .into() + }, + |(_dt, size)| Ok(data_type::Integer::from_interval(0, *size.max().unwrap())), + ), + )) +} + /// Min aggregation pub fn min() -> impl Function { Polymorphic::from(( @@ -2035,6 +2074,32 @@ pub fn sum() -> impl Function { )) } +/// Sum distinct aggregation +pub fn sum_distinct() -> impl Function { + Polymorphic::from(( + // Integer implementation + Aggregate::from( + data_type::Integer::full(), + |values| values.iter().cloned().collect::>().into_iter().map(|f| *f).sum::().into(), + |(intervals, size)| { + Ok(data_type::Integer::try_from(multiply().super_image( + &DataType::structured_from_data_types([intervals.into(), size.into()]), + )?)?) + }, + ), + // Float implementation + Aggregate::from( + data_type::Float::full(), + |values| values.iter().cloned().collect::>().into_iter().map(|f| *f).sum::().into(), + |(intervals, size)| { + Ok(data_type::Float::try_from(multiply().super_image( + &DataType::structured_from_data_types([intervals.into(), size.into()]), + )?)?) + }, + ), + )) +} + /// Agg groups aggregation pub fn agg_groups() -> impl Function { null() @@ -2066,6 +2131,34 @@ pub fn std() -> impl Function { ) } +/// Standard deviation distinct aggregation +pub fn std_distinct() -> impl Function { + // Only works on types that can be converted to floats + Aggregate::from( + data_type::Float::full(), + |values| { + let (count, sum, sum_2) = + values + .into_iter() + .collect::>() + .into_iter() + .fold((0.0, 0.0, 0.0), |(count, sum, sum_2), value| { + let value: f64 = value.into(); + ( + count + 1.0, + sum + f64::from(value), + sum_2 + (f64::from(value) * f64::from(value)), + ) + }); + ((sum_2 - sum * sum / count) / (count - 1.)).sqrt().into() + }, + |(intervals, _size)| match (intervals.min(), intervals.max()) { + (Some(&min), Some(&max)) => Ok(data_type::Float::from_interval(0., (max - min) / 2.)), + _ => Ok(data_type::Float::from_min(0.)), + }, + ) +} + /// Variance aggregation pub fn var() -> impl Function { // Only works on types that can be converted to floats @@ -2095,6 +2188,37 @@ pub fn var() -> impl Function { ) } +/// Variance distinct aggregation +pub fn var_distinct() -> impl Function { + // Only works on types that can be converted to floats + Aggregate::from( + data_type::Float::full(), + |values| { + let (count, sum, sum_2) = + values + .into_iter() + .collect::>() + .into_iter() + .fold((0.0, 0.0, 0.0), |(count, sum, sum_2), value| { + let value: f64 = value.into(); + ( + count + 1.0, + sum + f64::from(value), + sum_2 + (f64::from(value) * f64::from(value)), + ) + }); + ((sum_2 - sum * sum / count) / (count - 1.)).into() + }, + |(intervals, _size)| match (intervals.min(), intervals.max()) { + (Some(&min), Some(&max)) => Ok(data_type::Float::from_interval( + 0., + ((max - min) / 2.).powi(2), + )), + _ => Ok(data_type::Float::from_min(0.)), + }, + ) +} + #[cfg(test)] mod tests { use super::{ diff --git a/src/data_type/value.rs b/src/data_type/value.rs index eb8575ef..2c4c8928 100644 --- a/src/data_type/value.rs +++ b/src/data_type/value.rs @@ -259,7 +259,7 @@ impl Variant for Boolean { impl_variant_conversions!(Boolean); /// Integer value -#[derive(Clone, Hash, PartialEq, PartialOrd, Debug, Deserialize, Serialize)] +#[derive(Clone, Hash, PartialEq, PartialOrd, Debug, Deserialize, Serialize, Eq)] pub struct Integer(i64); impl DataTyped for Integer { @@ -333,6 +333,8 @@ impl_variant_conversions!(Enum); #[derive(Clone, PartialEq, PartialOrd, Debug, Deserialize, Serialize)] pub struct Float(f64); +impl Eq for Float {} + #[allow(clippy::derive_hash_xor_eq)] impl hash::Hash for Float { fn hash(&self, state: &mut H) { diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index 2933a0d4..ce327b0a 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -1,12 +1,13 @@ use crate::{ - builder::With, + builder::{With, WithIterator}, data_type::DataTyped, differential_privacy::private_query::PrivateQuery, differential_privacy::{private_query, DPRelation, Error, Result}, - expr::{aggregate, AggregateColumn, Expr}, + expr::{aggregate, AggregateColumn, Expr, Column, Identifier}, privacy_unit_tracking::PUPRelation, - relation::{field::Field, Map, Reduce, Relation, Variant as _}, - DataType, Ready, + relation::{field::Field, Map, Reduce, Relation, Variant}, + DataType, Ready, display::Dot, + }; use std::{cmp, collections::HashMap, ops::Deref}; @@ -219,15 +220,117 @@ impl Reduce { delta: f64, ) -> Result { let pup_input = PUPRelation::try_from(self.input().clone())?; - pup_input.differentially_private_aggregates( - self.named_aggregates() + + // Split the aggregations with different DISTINCT clauses + let reduces = self.split_distinct_aggregates(); + let epsilon = epsilon / (cmp::max(reduces.len(), 1) as f64); + let delta = delta / (cmp::max(reduces.len(), 1) as f64); + + let (relation, private_query) = reduces.iter() + .map(|r| pup_input.clone().differentially_private_aggregates( + r.named_aggregates() + .into_iter() + .map(|(n, agg)| (n, agg.clone())) + .collect(), + self.group_by(), + epsilon, + delta, + )) + .reduce(|acc, dp_rel| { + let acc = acc?; + let dp_rel = dp_rel?; + Ok(DPRelation::new( + acc.relation().clone().natural_inner_join(dp_rel.relation().clone()), + acc.private_query().clone().compose(dp_rel.private_query().clone()) + )) + }) + .unwrap()? + .into(); + + let relation: Relation = Relation::map() + .input(relation) + .with_iter(self.fields().into_iter().map(|f| (f.name(), Expr::col(f.name())))) + .build(); + Ok((relation, private_query).into()) + } + + fn split_distinct_aggregates(&self) -> Vec { + let mut distinct_map: HashMap, Vec<(String, AggregateColumn)>> = HashMap::new(); + let mut first_aggs: Vec<(String, AggregateColumn)> = vec![]; + for (agg, f) in self.aggregate().iter().zip(self.fields()) { + match agg.aggregate() { + aggregate::Aggregate::CountDistinct + | aggregate::Aggregate::SumDistinct + | aggregate::Aggregate::MeanDistinct + | aggregate::Aggregate::VarDistinct + | aggregate::Aggregate::StdDistinct => distinct_map.entry(Some(agg.column().clone())).or_insert(Vec::new()).push((f.name().to_string(), agg.clone())), + aggregate::Aggregate::First => first_aggs.push((f.name().to_string(), agg.clone())), + _ => distinct_map.entry(None).or_insert(Vec::new()).push((f.name().to_string(), agg.clone())), + } + } + + first_aggs.extend( + self.group_by_columns() + .into_iter() + .map(|x| (x.to_string(), AggregateColumn::new(aggregate::Aggregate::First, x.clone()))) + .collect::>() + ); + + distinct_map.into_iter() + .map(|(identifier, mut aggs)| { + aggs.extend(first_aggs.clone()); + self.rewrite_distinct(identifier, aggs) + }) + .collect() + } + + fn rewrite_distinct(&self, identifier: Option, aggs: Vec<(String, AggregateColumn)>) -> Reduce { + let builder = Relation::reduce() + .input(self.input().clone()); + if let Some(identifier) = identifier { + let mut group_by = self.group_by_columns() .into_iter() - .map(|(n, agg)| (n, agg.clone())) - .collect(), - self.group_by(), - epsilon, - delta, - ) + .map(|c| c.clone()) + .collect::>(); + group_by.push(identifier); + + let first_aggs = group_by.clone() + .into_iter() + .map(|c| (c.to_string(), AggregateColumn::new(aggregate::Aggregate::First, c))); + + let group_by = group_by.into_iter() + .map(|c| Expr::from(c.clone())) + .collect::>(); + + let reduce: Relation = builder + .group_by_iter(group_by) + .with_iter(first_aggs) + .build(); + + let aggs = aggs.into_iter() + .map(|(s, agg)| { + let new_agg = match agg.aggregate() { + aggregate::Aggregate::MeanDistinct => aggregate::Aggregate::Mean, + aggregate::Aggregate::CountDistinct => aggregate::Aggregate::Count, + aggregate::Aggregate::SumDistinct => aggregate::Aggregate::Sum, + aggregate::Aggregate::StdDistinct => aggregate::Aggregate::Std, + aggregate::Aggregate::VarDistinct => aggregate::Aggregate::Var, + aggregate::Aggregate::First => aggregate::Aggregate::First, + _ => todo!(), + }; + (s, AggregateColumn::new(new_agg, agg.column().clone())) + }); + Relation::reduce() + .input(reduce) + .group_by_iter(self.group_by().to_vec()) + .with_iter(aggs) + .build() + } else { + builder + .group_by_iter(self.group_by().clone().to_vec()) + .with_iter(aggs) + .build() + } } } @@ -243,6 +346,8 @@ mod tests { privacy_unit_tracking::{PrivacyUnitTracking, Strategy}, sql::parse, Relation, + relation::{Schema, Variant as _}, + privacy_unit_tracking::PrivacyUnit }; use std::ops::Deref; @@ -526,4 +631,386 @@ mod tests { println!("{query}"); _ = database.query(query).unwrap(); } + + #[test] + fn test_differentially_private_aggregates_with_distinct_aggregates() { + let mut database = postgresql::test_database(); + let relations = database.relations(); + + let table = relations + .get(&["item_table".to_string()]) + .unwrap() + .deref() + .clone(); + let (epsilon, delta) = (1., 1e-3); + + let privacy_unit_tracking = PrivacyUnitTracking::from(( + &relations, + vec![ + ( + "item_table", + vec![("order_id", "order_table", "id")], + "date", + ), + ("order_table", vec![], "date"), + ], + Strategy::Hard, + )); + let pup_table = privacy_unit_tracking + .table(&table.try_into().unwrap()) + .unwrap(); + + // with group by + let reduce = Reduce::new( + "my_reduce".to_string(), + vec![ + ("count_price".to_string(), AggregateColumn::count("price")), + ("count_distinct_price".to_string(), AggregateColumn::count_distinct("price")), + ("sum_price".to_string(), AggregateColumn::sum("price")), + ("sum_distinct_price".to_string(), AggregateColumn::sum_distinct("price")), + ("item".to_string(), AggregateColumn::first("item")), + ], + vec![expr!(item)], + pup_table.deref().clone().into(), + ); + let relation = Relation::from(reduce.clone()); + + let dp_relation = reduce + .differentially_private_aggregates(epsilon, delta) + .unwrap(); + dp_relation.display_dot().unwrap(); + assert_eq!(dp_relation.schema().len(), 5); + assert!(dp_relation + .data_type() + .is_subset_of(&DataType::structured(vec![ + ("count_price", DataType::float()), + ("count_distinct_price", DataType::float()), + ("sum_price", DataType::float()), + ("sum_distinct_price", DataType::float()), + ("item", DataType::text()), + ]))); + + let query: &str = &ast::Query::from(&relation).to_string(); + println!("{query}"); + _ = database + .query(query) + .unwrap() + .iter() + .map(ToString::to_string); + + // no group by + let reduce = Reduce::new( + "my_reduce".to_string(), + vec![ + ("count_price".to_string(), AggregateColumn::count("price")), + ("count_distinct_price".to_string(), AggregateColumn::count_distinct("price")), + ("sum_price".to_string(), AggregateColumn::sum("price")), + ("sum_distinct_price".to_string(), AggregateColumn::sum_distinct("price")), + ], + vec![], + pup_table.deref().clone().into(), + ); + let relation = Relation::from(reduce.clone()); + + let dp_relation = reduce + .differentially_private_aggregates(epsilon, delta) + .unwrap(); + dp_relation.display_dot().unwrap(); + assert_eq!(dp_relation.schema().len(), 4); + assert!(dp_relation + .data_type() + .is_subset_of(&DataType::structured(vec![ + ("count_price", DataType::float()), + ("count_distinct_price", DataType::float()), + ("sum_price", DataType::float()), + ("sum_distinct_price", DataType::float()), + ]))); + + let query: &str = &ast::Query::from(&relation).to_string(); + println!("{query}"); + _ = database + .query(query) + .unwrap() + .iter() + .map(ToString::to_string); + } + + #[test] + fn test_split_distinct_aggregates() { + let schema: Schema = vec![ + ("a", DataType::float_interval(-2., 2.)), + ("b", DataType::integer_interval(0, 10)), + ("c", DataType::float_interval(0., 20.)), + ("d", DataType::float_interval(0., 1.)), + ] + .into_iter() + .collect(); + let table: Relation = Relation::table() + .name("table") + .schema(schema.clone()) + .size(1000) + .build(); + + // No distinct + no group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .build(); + let reduces = reduce.split_distinct_aggregates(); + assert_eq!(reduces.len(), 1); + assert_eq!( + reduces[0].data_type(), + DataType::structured([ + ("sum_a", DataType::float_interval(-2000., 2000.)) + ]) + ); + + // No distinct + group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .group_by(expr!(b)) + .build(); + let reduces = reduce.split_distinct_aggregates(); + assert_eq!(reduces.len(), 1); + assert_eq!( + reduces[0].data_type(), + DataType::structured([ + ("sum_a", DataType::float_interval(-2000., 2000.)), + ("b", DataType::integer_interval(0, 10)), + ]) + ); + + // simple distinct + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .build(); + let reduces = reduce.split_distinct_aggregates(); + assert_eq!(reduces.len(), 1); + Relation::from(reduces[0].clone()).display_dot().unwrap(); + assert_eq!( + reduces[0].data_type(), + DataType::structured([ + ("sum_distinct_a", DataType::float_interval(-2000., 2000.)) + ]) + ); + + // simple distinct with group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .group_by(expr!(b)) + .build(); + let reduces = reduce.split_distinct_aggregates(); + assert_eq!(reduces.len(), 1); + Relation::from(reduces[0].clone()).display_dot().unwrap(); + assert_eq!( + reduces[0].data_type(), + DataType::structured([ + ("sum_distinct_a", DataType::float_interval(-2000., 2000.)), + ("b", DataType::integer_interval(0, 10)), + ]) + ); + + // simple distinct with group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with_group_by_column("b") + .build(); + let reduces = reduce.split_distinct_aggregates(); + assert_eq!(reduces.len(), 1); + Relation::from(reduces[0].clone()).display_dot().unwrap(); + assert_eq!( + reduces[0].data_type(), + DataType::structured([ + ("sum_distinct_a", DataType::float_interval(-2000., 2000.)), + ("b", DataType::integer_interval(0, 10)), + ]) + ); + + // multi distinct + no group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .build(); + let reduces = reduce.split_distinct_aggregates(); + assert_eq!(reduces.len(), 3); + Relation::from(reduces[0].clone()).display_dot().unwrap(); + Relation::from(reduces[1].clone()).display_dot().unwrap(); + Relation::from(reduces[2].clone()).display_dot().unwrap(); + + // multi distinct + group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .with(("my_c", AggregateColumn::first("c"))) + .group_by(expr!(c)) + .build(); + let reduces = reduce.split_distinct_aggregates(); + assert_eq!(reduces.len(), 3); + Relation::from(reduces[0].clone()).display_dot().unwrap(); + Relation::from(reduces[1].clone()).display_dot().unwrap(); + Relation::from(reduces[2].clone()).display_dot().unwrap(); + } + + #[test] + fn test_distinct_differentially_private_aggregates() { + let schema: Schema = vec![ + ("a", DataType::float_interval(-2., 2.)), + ("b", DataType::integer_interval(0, 10)), + ("c", DataType::float_interval(10., 20.)), + (PrivacyUnit::privacy_unit(), DataType::text()), + (PrivacyUnit::privacy_unit_weight(), DataType::float()), + ] + .into_iter() + .collect(); + let table: Relation = Relation::table() + .name("table") + .schema(schema.clone()) + .size(1000) + .build(); + + let (epsilon, delta) = (1.0, 1e-5); + + // No distinct + no group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .build(); + let dp_relation = reduce.differentially_private_aggregates(epsilon.clone(), delta.clone()).unwrap(); + assert_eq!( + dp_relation.private_query(), + &PrivateQuery::gaussian_from_epsilon_delta_sensitivity( + epsilon.clone(), + delta.clone(), + 2. + ) + ); + assert_eq!( + dp_relation.relation().data_type(), + DataType::structured([ + ("sum_a", DataType::float_interval(-2000., 2000.)) + ]) + ); + + // No distinct + group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .group_by(expr!(b)) + .build(); + let dp_relation = reduce.differentially_private_aggregates(epsilon.clone(), delta.clone()).unwrap(); + assert_eq!( + dp_relation.private_query(), + &PrivateQuery::gaussian_from_epsilon_delta_sensitivity( + epsilon.clone(), + delta.clone(), + 2. + ) + ); + assert_eq!( + dp_relation.relation().data_type(), + DataType::structured([ + ("sum_a", DataType::float_interval(-2000., 2000.)) + ]) + ); + + // simple distinct + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .build(); + let dp_relation = reduce.differentially_private_aggregates(epsilon.clone(), delta.clone()).unwrap(); + //dp_relation.relation().display_dot().unwrap(); + assert_eq!( + dp_relation.private_query(), + &PrivateQuery::gaussian_from_epsilon_delta_sensitivity( + epsilon.clone(), + delta.clone(), + 2. + ) + ); + assert_eq!( + dp_relation.relation().data_type(), + DataType::structured([ + ("sum_distinct_a", DataType::float_interval(-2000., 2000.)) + ]) + ); + + // simple distinct with group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with_group_by_column("b") + .build(); + let dp_relation = reduce.differentially_private_aggregates(epsilon.clone(), delta.clone()).unwrap(); + //dp_relation.relation().display_dot().unwrap(); + assert_eq!( + dp_relation.private_query(), + &PrivateQuery::gaussian_from_epsilon_delta_sensitivity( + epsilon.clone(), + delta.clone(), + 2. + ) + ); + assert_eq!( + dp_relation.relation().data_type(), + DataType::structured([ + ("sum_distinct_a", DataType::float_interval(-2000., 2000.)), + ("b", DataType::integer_interval(0, 10)) + ]) + ); + + // multi distinct + no group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .build(); + let dp_relation = reduce.differentially_private_aggregates(epsilon.clone(), delta.clone()).unwrap(); + //dp_relation.relation().display_dot().unwrap(); + assert_eq!( + dp_relation.relation().data_type(), + DataType::structured([ + ("sum_a", DataType::float_interval(-2000., 2000.)), + ("sum_distinct_a", DataType::float_interval(-2000., 2000.)), + ("count_b", DataType::float_interval(0., 1000.)), + ("count_distinct_b", DataType::float_interval(0., 1000.)), + ]) + ); + + // multi distinct + group by + let reduce: Reduce = Relation::reduce() + .input(table.clone()) + .with(("sum_a", AggregateColumn::sum("a"))) + .with(("sum_distinct_a", AggregateColumn::sum_distinct("a"))) + .with(("count_b", AggregateColumn::count("b"))) + .with(("count_distinct_b", AggregateColumn::count_distinct("b"))) + .with(("my_c", AggregateColumn::first("c"))) + .group_by(expr!(c)) + .build(); + let dp_relation = reduce.differentially_private_aggregates(epsilon.clone(), delta.clone()).unwrap(); + dp_relation.relation().display_dot().unwrap(); + assert_eq!( + dp_relation.relation().data_type(), + DataType::structured([ + ("sum_a", DataType::float_interval(-2000., 2000.)), + ("sum_distinct_a", DataType::float_interval(-2000., 2000.)), + ("count_b", DataType::float_interval(0., 1000.)), + ("count_distinct_b", DataType::float_interval(0., 1000.)), + ("my_c", DataType::float_interval(10., 20.)), + ]) + ); + } + } diff --git a/src/expr/aggregate.rs b/src/expr/aggregate.rs index 82a78a0a..8350ad26 100644 --- a/src/expr/aggregate.rs +++ b/src/expr/aggregate.rs @@ -5,6 +5,8 @@ use itertools::Itertools; use super::{implementation, Result}; use crate::data_type::{value::Value, DataType}; + + /// The list of operators /// inspired by: https://docs.rs/sqlparser/latest/sqlparser/ast/enum.BinaryOperator.html /// and mostly: https://docs.rs/polars/latest/polars/prelude/enum.AggExpr.html @@ -18,14 +20,19 @@ pub enum Aggregate { First, Last, Mean, + MeanDistinct, List, Count, + CountDistinct, Quantile(f64), Quantiles(&'static [f64]), Sum, + SumDistinct, AggGroups, Std, + StdDistinct, Var, + VarDistinct } // TODO make sure f64::nan do not happen @@ -83,6 +90,11 @@ impl fmt::Display for Aggregate { Aggregate::AggGroups => write!(f, "agg_groups"), Aggregate::Std => write!(f, "std"), Aggregate::Var => write!(f, "var"), + Aggregate::MeanDistinct => write!(f, "mean_distinct"), + Aggregate::CountDistinct => write!(f, "count_distinct"), + Aggregate::SumDistinct => write!(f, "sum_distinct"), + Aggregate::StdDistinct => write!(f, "std_distinct"), + Aggregate::VarDistinct => write!(f, "var_distinct"), } } } diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 8d7eee58..4fc5e423 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -125,7 +125,7 @@ macro_rules! aggregate_implementations { } aggregate_implementations!( - [Min, Max, Median, NUnique, First, Last, Mean, List, Count, Sum, AggGroups, Std, Var], + [Min, Max, Median, NUnique, First, Last, Mean, List, Count, Sum, AggGroups, Std, Var, MeanDistinct, CountDistinct, SumDistinct, StdDistinct, VarDistinct], x, { match x { diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 4c166ed9..83becd96 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -489,7 +489,7 @@ macro_rules! impl_aggregation_constructors { }; } -impl_aggregation_constructors!(First, Last, Min, Max, Count, Mean, Sum, Var, Std); +impl_aggregation_constructors!(First, Last, Min, Max, Count, Mean, Sum, Var, Std, CountDistinct, MeanDistinct, SumDistinct, VarDistinct, StdDistinct); /// An aggregate function expression #[derive(Clone, Debug, Hash, PartialEq, Eq)] diff --git a/src/expr/sql.rs b/src/expr/sql.rs index b50d9ec2..1cc5c3dc 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -396,6 +396,81 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { filter: None, null_treatment: None, }), + expr::aggregate::Aggregate::MeanDistinct => ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: String::from("avg"), + quote_style: None, + }]), + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], + over: None, + distinct: true, + special: false, + order_by: vec![], + filter: None, + null_treatment: None, + }), + expr::aggregate::Aggregate::CountDistinct => ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: String::from("count"), + quote_style: None, + }]), + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], + over: None, + distinct: true, + special: false, + order_by: vec![], + filter: None, + null_treatment: None, + }), + expr::aggregate::Aggregate::SumDistinct => ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: String::from("sum"), + quote_style: None, + }]), + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], + over: None, + distinct: true, + special: false, + order_by: vec![], + filter: None, + null_treatment: None, + }), + expr::aggregate::Aggregate::StdDistinct => ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: String::from("stddev"), + quote_style: None, + }]), + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], + over: None, + distinct: true, + special: false, + order_by: vec![], + filter: None, + null_treatment: None, + }), + expr::aggregate::Aggregate::VarDistinct => ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: String::from("variance"), + quote_style: None, + }]), + args: vec![ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + argument, + ))], + over: None, + distinct: true, + special: false, + order_by: vec![], + filter: None, + null_treatment: None, + }), } } diff --git a/src/relation/rewriting.rs b/src/relation/rewriting.rs index 08d7ff5b..c4a8c843 100644 --- a/src/relation/rewriting.rs +++ b/src/relation/rewriting.rs @@ -741,6 +741,40 @@ impl Relation { .right_names(right_names) .build()) } + + /// Returns the outer join between `self` and `right` where + /// the output names of the fields are conserved. + /// The joining criteria is the equality of columns with the same name + pub fn natural_inner_join(self, right: Self) -> Relation { + let mut left_names: Vec = vec![]; + let mut right_names: Vec = vec![]; + let mut names: Vec<(String, Expr)> = vec![]; + for f in self.fields() { + let col = f.name().to_string(); + left_names.push(col.clone()); + names.push((col.clone(), Expr::col(col))); + } + for f in right.fields() { + let col = f.name().to_string(); + if left_names.contains(&col) { + right_names.push(format!("right_{}", col)); + } else { + right_names.push(col.clone()); + names.push((col.clone(), Expr::col(col))); + } + } + let join: Relation = Relation::join() + .left(self.clone()) + .right(right.clone()) + .inner() + .left_names(left_names) + .right_names(right_names) + .build(); + Relation::map() + .input(join) + .with_iter(names) + .build() + } } impl With<(&str, Expr)> for Relation { diff --git a/src/sql/expr.rs b/src/sql/expr.rs index 7070d910..3a0e3d34 100644 --- a/src/sql/expr.rs +++ b/src/sql/expr.rs @@ -522,9 +522,6 @@ impl<'a, T: Clone, V: Visitor<'a, T>> visitor::Visitor<'a, ast::Expr, T> for V { ast::Expr::TypedString { data_type, value } => todo!(), ast::Expr::MapAccess { column, keys } => todo!(), ast::Expr::Function(function) => { - if function.distinct { - todo!() - } self.function(function, { let mut result = vec![]; for function_arg in function.args.iter() { @@ -909,6 +906,9 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { .collect(); let flat_args = flat_args?; let function_name: &str = &function.name.0.iter().join(".").to_lowercase(); + if function.distinct && !["count", "sum", "avg", "variance", "stddev"].contains(&function_name) { + todo!() + } Ok(match function_name { // Functions Opposite, Not, Exp, Ln, Log, Abs, Sin, Cos "opposite" => Expr::opposite(flat_args[0].clone()), @@ -1001,10 +1001,15 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { // Aggregates "min" => Expr::min(flat_args[0].clone()), "max" => Expr::max(flat_args[0].clone()), + "count" if function.distinct => Expr::count_distinct(flat_args[0].clone()), "count" => Expr::count(flat_args[0].clone()), + "avg" if function.distinct => Expr::mean_distinct(flat_args[0].clone()), "avg" => Expr::mean(flat_args[0].clone()), + "sum" if function.distinct => Expr::sum_distinct(flat_args[0].clone()), "sum" => Expr::sum(flat_args[0].clone()), + "variance" if function.distinct => Expr::var_distinct(flat_args[0].clone()), "variance" => Expr::var(flat_args[0].clone()), + "stddev" if function.distinct => Expr::std_distinct(flat_args[0].clone()), "stddev" => Expr::std(flat_args[0].clone()), _ => todo!(), }) diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 05f16582..a68101da 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -152,6 +152,7 @@ mod tests { "SELECT CAST(x AS float) FROM table_2", // integer => float "SELECT CAST('true' AS boolean) FROM table_2", // integer => float "SELECT CEIL(3 * b), FLOOR(3 * b), TRUNC(3 * b), ROUND(3 * b) FROM table_1", + "SELECT SUM(DISTINCT a), SUM(a) FROM table_1" ] { let res1 = database.query(query).unwrap(); let relation = Relation::try_from(parse(query).unwrap().with(&database.relations())).unwrap(); From 7e38cb02ca2d260bfd24dae9551e537cf4fc42e5 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Mon, 27 Nov 2023 08:41:39 +0100 Subject: [PATCH 2/4] add docstrings --- src/differential_privacy/aggregates.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/differential_privacy/aggregates.rs b/src/differential_privacy/aggregates.rs index ce327b0a..86becd7c 100644 --- a/src/differential_privacy/aggregates.rs +++ b/src/differential_privacy/aggregates.rs @@ -226,6 +226,7 @@ impl Reduce { let epsilon = epsilon / (cmp::max(reduces.len(), 1) as f64); let delta = delta / (cmp::max(reduces.len(), 1) as f64); + // Rewritten into differential privacy each `Reduce` then join them. let (relation, private_query) = reduces.iter() .map(|r| pup_input.clone().differentially_private_aggregates( r.named_aggregates() @@ -254,6 +255,10 @@ impl Reduce { Ok((relation, private_query).into()) } + + /// Returns a Vec of rewritten `Reduce` whose each item corresponds to a specific `DISTINCT` clause + /// (e.g.: SUM(DISTINCT a) or COUNT(DISTINCT a) have the same `DISTINCT` clause). The original `Reduce`` + /// has been rewritten with `GROUP BY`s for each `DISTINCT` clause. fn split_distinct_aggregates(&self) -> Vec { let mut distinct_map: HashMap, Vec<(String, AggregateColumn)>> = HashMap::new(); let mut first_aggs: Vec<(String, AggregateColumn)> = vec![]; @@ -284,6 +289,21 @@ impl Reduce { .collect() } + /// Rewrite the `DISTINCT` aggregate with a `GROUP BY` + /// + /// # Arguments + /// - `self`: we reuse the `input` and `group_by` fields of the current `Reduce + /// - `identifier`: The optionnal column `Identifier` associated with the `DISTINCT`, if `None` then the aggregates + /// contain no `DISTINCT`. + /// - `aggs` the vector of the `AggregateColumn` with their names + /// + /// Example 1 : + /// (SELECT sum(DISTINCT col1), count(*) FROM table GROUP BY a, Some(col1), ("my_sum", sum(col1))) + /// --> SELECT a AS a, sum(col1) AS my_sum FROM (SELECT a AS a, sum(col1) AS col1 FROM table GROUP BY a, col1) GROUP BY a + /// + /// Example 2 : + /// (SELECT sum(DISTINCT col1), count(*) FROM table GROUP BY a, None, ("my_count", count(*))) + /// --> SELECT a AS a, count(*) AS my_count FROM table GROUP BY a fn rewrite_distinct(&self, identifier: Option, aggs: Vec<(String, AggregateColumn)>) -> Reduce { let builder = Relation::reduce() .input(self.input().clone()); From 2077f5d7fb31c232cb584e1ac83d5777edbd2078 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Mon, 27 Nov 2023 09:26:39 +0100 Subject: [PATCH 3/4] implement inverse trigo functions --- CHANGELOG.md | 1 + src/data_type/function.rs | 84 ++++++++++++++++++++++++++++++++++++++ src/expr/function.rs | 16 +++++++- src/expr/implementation.rs | 4 +- src/expr/mod.rs | 71 +++++++++++++++++++++++++++++++- src/expr/sql.rs | 32 ++++++++++++++- src/sql/expr.rs | 3 ++ 7 files changed, 205 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a7e1bc0c..884fa4b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added - implemented `DISTINCT`` in aggregations [#197](https://github.com/Qrlew/qrlew/issues/197) - Implemented math functions: `PI`, `DEGREES`, `TAN`, `RANDOM`, `LOG10`, `LOG2`, `SQUARE` [#196](https://github.com/Qrlew/qrlew/issues/196) +- Implemented inverse trigo functions: `ASIN`, `ACOS`, `ATAN` [#198](https://github.com/Qrlew/qrlew/issues/198) ## [0.5.2] - 2023-11-19 ## Added diff --git a/src/data_type/function.rs b/src/data_type/function.rs index ccb4d8b2..6da1b5ba 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -1724,6 +1724,30 @@ pub fn cos() -> impl Function { ) } +/// inverse sine +pub fn asin() -> impl Function { + PartitionnedMonotonic::univariate( + data_type::Float::from_interval(-1., 1.), + |x| x.asin() + ) +} + +/// inverse cosine +pub fn acos() -> impl Function { + PartitionnedMonotonic::univariate( + data_type::Float::from_interval(-1., 1.), + |x| x.acos() + ) +} + +/// inverse tangent +pub fn atan() -> impl Function { + PartitionnedMonotonic::univariate( + data_type::Float::default(), + |x| x.atan() + ) +} + pub fn least() -> impl Function { Polymorphic::from(( PartitionnedMonotonic::bivariate( @@ -3717,4 +3741,64 @@ mod tests { println!("im({}) = {}", set, im); assert!(im == DataType::integer_value(0)); } + + #[test] + fn test_asin() { + println!("\nTest asin"); + let fun = asin(); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + println!("data_type = {}", fun.data_type()); + + let set = DataType::float_interval(-1., 1.); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::float_interval(-std::f64::consts::PI / 2., std::f64::consts::PI / 2.)); + + let set = DataType::float_value(0.); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::float_value(0.)); + } + + #[test] + fn test_acos() { + println!("\nTest acos"); + let fun = acos(); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + println!("data_type = {}", fun.data_type()); + + let set = DataType::float_interval(-1., 1.); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::float_interval(0., std::f64::consts::PI)); + + let set = DataType::float_value(0.); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::float_value(std::f64::consts::PI / 2.)); + } + + #[test] + fn test_atan() { + println!("\nTest atan"); + let fun = atan(); + println!("type = {}", fun); + println!("domain = {}", fun.domain()); + println!("co_domain = {}", fun.co_domain()); + println!("data_type = {}", fun.data_type()); + + let set = DataType::float_min(0.); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::float_interval(0., std::f64::consts::PI / 2.)); + + let set = DataType::float_value(0.); + let im = fun.super_image(&set).unwrap(); + println!("im({}) = {}", set, im); + assert!(im == DataType::float_value(0.)); + } } diff --git a/src/expr/function.rs b/src/expr/function.rs index a073ef8c..267a994c 100644 --- a/src/expr/function.rs +++ b/src/expr/function.rs @@ -68,7 +68,10 @@ pub enum Function { Ceil, Floor, Round, - Trunc + Trunc, + Asin, + Acos, + Atan } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] @@ -136,6 +139,9 @@ impl Function { | Function::CastAsDate | Function::CastAsTime | Function::Sign + | Function::Asin + | Function::Acos + | Function::Atan // Binary Functions | Function::Pow | Function::Position @@ -203,7 +209,10 @@ impl Function { | Function::CastAsTime | Function::Ceil | Function::Floor - | Function::Sign => Arity::Unary, + | Function::Sign + | Function::Asin + | Function::Acos + | Function::Atan => Arity::Unary, // Binary Function Function::Pow | Function::Position @@ -296,6 +305,9 @@ impl fmt::Display for Function { Function::CastAsDate => "cast_as_date", Function::CastAsTime => "cast_as_time", Function::Sign => "sign", + Function::Asin => "asin", + Function::Acos => "acos", + Function::Atan => "atan", // Binary Functions Function::Pow => "pow", Function::Position => "position", diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index 4fc5e423..a198c445 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -40,12 +40,12 @@ macro_rules! function_implementations { } // All functions: -// Unary: Opposite, Not, Exp, Ln, Abs, Sin, Cos, CharLength, Lower, Upper, Md5, Ceil, Floor, Sign +// Unary: Opposite, Not, Exp, Ln, Abs, Sin, Cos, CharLength, Lower, Upper, Md5, Ceil, Floor, Sign, Asin, Acos, Atan // Binary: Plus, Minus, Multiply, Divide, Modulo, StringConcat, Gt, Lt, GtEq, LtEq, Eq, NotEq, And, Or, Xor, BitwiseOr, BitwiseAnd, BitwiseXor, Position, Concat, Greatest, Least, Round, Trunc // Ternary: Case, Position // Nary: Concat function_implementations!( - [Opposite, Not, Exp, Ln, Log, Abs, Sin, Cos, Sqrt, Md5, Ceil, Floor, Sign], + [Opposite, Not, Exp, Ln, Log, Abs, Sin, Cos, Sqrt, Md5, Ceil, Floor, Sign, Asin, Acos, Atan], [ Plus, Minus, diff --git a/src/expr/mod.rs b/src/expr/mod.rs index 83becd96..b55a81ae 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -297,7 +297,10 @@ impl_unary_function_constructors!( CastAsTime, Ceil, Floor, - Sign + Sign, + Asin, + Acos, + Atan ); // TODO Complete that /// Implement binary function constructors @@ -3101,4 +3104,70 @@ mod tests { DataType::float_value(3.141592653589793) ); } + + #[test] + fn test_asin() { + println!("asin"); + let expression = expr!(asin(a)); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + + let set = DataType::structured([ + ("a", DataType::float_interval(-1., 1.)), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::float_interval(-std::f64::consts::PI / 2., std::f64::consts::PI / 2.) + ); + } + + #[test] + fn test_acos() { + println!("acos"); + let expression = expr!(acos(a)); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + + let set = DataType::structured([ + ("a", DataType::float_interval(-1., 1.)), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::float_interval(0., std::f64::consts::PI) + ); + } + + #[test] + fn test_atan() { + println!("atan"); + let expression = expr!(atan(a)); + println!("expression = {}", expression); + println!("expression domain = {}", expression.domain()); + println!("expression co domain = {}", expression.co_domain()); + println!("expression data type = {}", expression.data_type()); + + let set = DataType::structured([ + ("a", DataType::float()), + ]); + println!( + "expression super image = {}", + expression.super_image(&set).unwrap() + ); + assert_eq!( + expression.super_image(&set).unwrap(), + DataType::float_interval(-std::f64::consts::PI / 2., std::f64::consts::PI / 2.) + ); + } } diff --git a/src/expr/sql.rs b/src/expr/sql.rs index 1cc5c3dc..d1cbfb09 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -188,7 +188,10 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { | expr::function::Function::SubstrWithSize | expr::function::Function::Ceil | expr::function::Function::Floor - | expr::function::Function::Sign => ast::Expr::Function(ast::Function { + | expr::function::Function::Sign + | expr::function::Function::Asin + | expr::function::Function::Acos + | expr::function::Function::Atan => ast::Expr::Function(ast::Function { name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), args: arguments .into_iter() @@ -824,6 +827,33 @@ mod tests { assert_eq!(gen_expr, true_expr); } + #[test] + fn test_inverse_trigo() { + let str_expr = "asin(x)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + let expr = Expr::try_from(&ast_expr).unwrap(); + println!("expr = {}", expr); + let gen_expr = ast::Expr::from(&expr); + println!("ast::expr = {gen_expr}"); + assert_eq!(ast_expr, gen_expr); + + let str_expr = "acos(x)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + let expr = Expr::try_from(&ast_expr).unwrap(); + println!("expr = {}", expr); + let gen_expr = ast::Expr::from(&expr); + println!("ast::expr = {gen_expr}"); + assert_eq!(ast_expr, gen_expr); + + let str_expr = "atan(x)"; + let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); + let expr = Expr::try_from(&ast_expr).unwrap(); + println!("expr = {}", expr); + let gen_expr = ast::Expr::from(&expr); + println!("ast::expr = {gen_expr}"); + assert_eq!(ast_expr, gen_expr); + } + #[test] fn test_random() { let str_expr = "random()"; diff --git a/src/sql/expr.rs b/src/sql/expr.rs index 3a0e3d34..e2861759 100644 --- a/src/sql/expr.rs +++ b/src/sql/expr.rs @@ -928,6 +928,9 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { "sin" => Expr::sin(flat_args[0].clone()), "cos" => Expr::cos(flat_args[0].clone()), "tan" => Expr::divide(Expr::sin(flat_args[0].clone()), Expr::cos(flat_args[0].clone())), + "asin" => Expr::asin(flat_args[0].clone()), + "acos" => Expr::acos(flat_args[0].clone()), + "atan" => Expr::atan(flat_args[0].clone()), "sqrt" => Expr::sqrt(flat_args[0].clone()), "pow" => Expr::pow(flat_args[0].clone(), flat_args[1].clone()), "power" => Expr::pow(flat_args[0].clone(), flat_args[1].clone()), From 351438fe81fa351518201544666c24e1e2a74909 Mon Sep 17 00:00:00 2001 From: victoria de sainte agathe Date: Mon, 27 Nov 2023 09:58:32 +0100 Subject: [PATCH 4/4] ok --- CHANGELOG.md | 1 - src/data_type/function.rs | 84 -------------------------------------- src/expr/function.rs | 16 +------- src/expr/implementation.rs | 4 +- src/expr/mod.rs | 71 +------------------------------- src/expr/sql.rs | 32 +-------------- src/sql/expr.rs | 3 -- 7 files changed, 6 insertions(+), 205 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 884fa4b3..a7e1bc0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Added - implemented `DISTINCT`` in aggregations [#197](https://github.com/Qrlew/qrlew/issues/197) - Implemented math functions: `PI`, `DEGREES`, `TAN`, `RANDOM`, `LOG10`, `LOG2`, `SQUARE` [#196](https://github.com/Qrlew/qrlew/issues/196) -- Implemented inverse trigo functions: `ASIN`, `ACOS`, `ATAN` [#198](https://github.com/Qrlew/qrlew/issues/198) ## [0.5.2] - 2023-11-19 ## Added diff --git a/src/data_type/function.rs b/src/data_type/function.rs index 6da1b5ba..ccb4d8b2 100644 --- a/src/data_type/function.rs +++ b/src/data_type/function.rs @@ -1724,30 +1724,6 @@ pub fn cos() -> impl Function { ) } -/// inverse sine -pub fn asin() -> impl Function { - PartitionnedMonotonic::univariate( - data_type::Float::from_interval(-1., 1.), - |x| x.asin() - ) -} - -/// inverse cosine -pub fn acos() -> impl Function { - PartitionnedMonotonic::univariate( - data_type::Float::from_interval(-1., 1.), - |x| x.acos() - ) -} - -/// inverse tangent -pub fn atan() -> impl Function { - PartitionnedMonotonic::univariate( - data_type::Float::default(), - |x| x.atan() - ) -} - pub fn least() -> impl Function { Polymorphic::from(( PartitionnedMonotonic::bivariate( @@ -3741,64 +3717,4 @@ mod tests { println!("im({}) = {}", set, im); assert!(im == DataType::integer_value(0)); } - - #[test] - fn test_asin() { - println!("\nTest asin"); - let fun = asin(); - println!("type = {}", fun); - println!("domain = {}", fun.domain()); - println!("co_domain = {}", fun.co_domain()); - println!("data_type = {}", fun.data_type()); - - let set = DataType::float_interval(-1., 1.); - let im = fun.super_image(&set).unwrap(); - println!("im({}) = {}", set, im); - assert!(im == DataType::float_interval(-std::f64::consts::PI / 2., std::f64::consts::PI / 2.)); - - let set = DataType::float_value(0.); - let im = fun.super_image(&set).unwrap(); - println!("im({}) = {}", set, im); - assert!(im == DataType::float_value(0.)); - } - - #[test] - fn test_acos() { - println!("\nTest acos"); - let fun = acos(); - println!("type = {}", fun); - println!("domain = {}", fun.domain()); - println!("co_domain = {}", fun.co_domain()); - println!("data_type = {}", fun.data_type()); - - let set = DataType::float_interval(-1., 1.); - let im = fun.super_image(&set).unwrap(); - println!("im({}) = {}", set, im); - assert!(im == DataType::float_interval(0., std::f64::consts::PI)); - - let set = DataType::float_value(0.); - let im = fun.super_image(&set).unwrap(); - println!("im({}) = {}", set, im); - assert!(im == DataType::float_value(std::f64::consts::PI / 2.)); - } - - #[test] - fn test_atan() { - println!("\nTest atan"); - let fun = atan(); - println!("type = {}", fun); - println!("domain = {}", fun.domain()); - println!("co_domain = {}", fun.co_domain()); - println!("data_type = {}", fun.data_type()); - - let set = DataType::float_min(0.); - let im = fun.super_image(&set).unwrap(); - println!("im({}) = {}", set, im); - assert!(im == DataType::float_interval(0., std::f64::consts::PI / 2.)); - - let set = DataType::float_value(0.); - let im = fun.super_image(&set).unwrap(); - println!("im({}) = {}", set, im); - assert!(im == DataType::float_value(0.)); - } } diff --git a/src/expr/function.rs b/src/expr/function.rs index 267a994c..a073ef8c 100644 --- a/src/expr/function.rs +++ b/src/expr/function.rs @@ -68,10 +68,7 @@ pub enum Function { Ceil, Floor, Round, - Trunc, - Asin, - Acos, - Atan + Trunc } #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] @@ -139,9 +136,6 @@ impl Function { | Function::CastAsDate | Function::CastAsTime | Function::Sign - | Function::Asin - | Function::Acos - | Function::Atan // Binary Functions | Function::Pow | Function::Position @@ -209,10 +203,7 @@ impl Function { | Function::CastAsTime | Function::Ceil | Function::Floor - | Function::Sign - | Function::Asin - | Function::Acos - | Function::Atan => Arity::Unary, + | Function::Sign => Arity::Unary, // Binary Function Function::Pow | Function::Position @@ -305,9 +296,6 @@ impl fmt::Display for Function { Function::CastAsDate => "cast_as_date", Function::CastAsTime => "cast_as_time", Function::Sign => "sign", - Function::Asin => "asin", - Function::Acos => "acos", - Function::Atan => "atan", // Binary Functions Function::Pow => "pow", Function::Position => "position", diff --git a/src/expr/implementation.rs b/src/expr/implementation.rs index a198c445..4fc5e423 100644 --- a/src/expr/implementation.rs +++ b/src/expr/implementation.rs @@ -40,12 +40,12 @@ macro_rules! function_implementations { } // All functions: -// Unary: Opposite, Not, Exp, Ln, Abs, Sin, Cos, CharLength, Lower, Upper, Md5, Ceil, Floor, Sign, Asin, Acos, Atan +// Unary: Opposite, Not, Exp, Ln, Abs, Sin, Cos, CharLength, Lower, Upper, Md5, Ceil, Floor, Sign // Binary: Plus, Minus, Multiply, Divide, Modulo, StringConcat, Gt, Lt, GtEq, LtEq, Eq, NotEq, And, Or, Xor, BitwiseOr, BitwiseAnd, BitwiseXor, Position, Concat, Greatest, Least, Round, Trunc // Ternary: Case, Position // Nary: Concat function_implementations!( - [Opposite, Not, Exp, Ln, Log, Abs, Sin, Cos, Sqrt, Md5, Ceil, Floor, Sign, Asin, Acos, Atan], + [Opposite, Not, Exp, Ln, Log, Abs, Sin, Cos, Sqrt, Md5, Ceil, Floor, Sign], [ Plus, Minus, diff --git a/src/expr/mod.rs b/src/expr/mod.rs index b55a81ae..83becd96 100644 --- a/src/expr/mod.rs +++ b/src/expr/mod.rs @@ -297,10 +297,7 @@ impl_unary_function_constructors!( CastAsTime, Ceil, Floor, - Sign, - Asin, - Acos, - Atan + Sign ); // TODO Complete that /// Implement binary function constructors @@ -3104,70 +3101,4 @@ mod tests { DataType::float_value(3.141592653589793) ); } - - #[test] - fn test_asin() { - println!("asin"); - let expression = expr!(asin(a)); - println!("expression = {}", expression); - println!("expression domain = {}", expression.domain()); - println!("expression co domain = {}", expression.co_domain()); - println!("expression data type = {}", expression.data_type()); - - let set = DataType::structured([ - ("a", DataType::float_interval(-1., 1.)), - ]); - println!( - "expression super image = {}", - expression.super_image(&set).unwrap() - ); - assert_eq!( - expression.super_image(&set).unwrap(), - DataType::float_interval(-std::f64::consts::PI / 2., std::f64::consts::PI / 2.) - ); - } - - #[test] - fn test_acos() { - println!("acos"); - let expression = expr!(acos(a)); - println!("expression = {}", expression); - println!("expression domain = {}", expression.domain()); - println!("expression co domain = {}", expression.co_domain()); - println!("expression data type = {}", expression.data_type()); - - let set = DataType::structured([ - ("a", DataType::float_interval(-1., 1.)), - ]); - println!( - "expression super image = {}", - expression.super_image(&set).unwrap() - ); - assert_eq!( - expression.super_image(&set).unwrap(), - DataType::float_interval(0., std::f64::consts::PI) - ); - } - - #[test] - fn test_atan() { - println!("atan"); - let expression = expr!(atan(a)); - println!("expression = {}", expression); - println!("expression domain = {}", expression.domain()); - println!("expression co domain = {}", expression.co_domain()); - println!("expression data type = {}", expression.data_type()); - - let set = DataType::structured([ - ("a", DataType::float()), - ]); - println!( - "expression super image = {}", - expression.super_image(&set).unwrap() - ); - assert_eq!( - expression.super_image(&set).unwrap(), - DataType::float_interval(-std::f64::consts::PI / 2., std::f64::consts::PI / 2.) - ); - } } diff --git a/src/expr/sql.rs b/src/expr/sql.rs index d1cbfb09..1cc5c3dc 100644 --- a/src/expr/sql.rs +++ b/src/expr/sql.rs @@ -188,10 +188,7 @@ impl<'a> expr::Visitor<'a, ast::Expr> for FromExprVisitor { | expr::function::Function::SubstrWithSize | expr::function::Function::Ceil | expr::function::Function::Floor - | expr::function::Function::Sign - | expr::function::Function::Asin - | expr::function::Function::Acos - | expr::function::Function::Atan => ast::Expr::Function(ast::Function { + | expr::function::Function::Sign => ast::Expr::Function(ast::Function { name: ast::ObjectName(vec![ast::Ident::new(function.to_string())]), args: arguments .into_iter() @@ -827,33 +824,6 @@ mod tests { assert_eq!(gen_expr, true_expr); } - #[test] - fn test_inverse_trigo() { - let str_expr = "asin(x)"; - let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); - let expr = Expr::try_from(&ast_expr).unwrap(); - println!("expr = {}", expr); - let gen_expr = ast::Expr::from(&expr); - println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr, gen_expr); - - let str_expr = "acos(x)"; - let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); - let expr = Expr::try_from(&ast_expr).unwrap(); - println!("expr = {}", expr); - let gen_expr = ast::Expr::from(&expr); - println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr, gen_expr); - - let str_expr = "atan(x)"; - let ast_expr: ast::Expr = parse_expr(str_expr).unwrap(); - let expr = Expr::try_from(&ast_expr).unwrap(); - println!("expr = {}", expr); - let gen_expr = ast::Expr::from(&expr); - println!("ast::expr = {gen_expr}"); - assert_eq!(ast_expr, gen_expr); - } - #[test] fn test_random() { let str_expr = "random()"; diff --git a/src/sql/expr.rs b/src/sql/expr.rs index e2861759..3a0e3d34 100644 --- a/src/sql/expr.rs +++ b/src/sql/expr.rs @@ -928,9 +928,6 @@ impl<'a> Visitor<'a, Result> for TryIntoExprVisitor<'a> { "sin" => Expr::sin(flat_args[0].clone()), "cos" => Expr::cos(flat_args[0].clone()), "tan" => Expr::divide(Expr::sin(flat_args[0].clone()), Expr::cos(flat_args[0].clone())), - "asin" => Expr::asin(flat_args[0].clone()), - "acos" => Expr::acos(flat_args[0].clone()), - "atan" => Expr::atan(flat_args[0].clone()), "sqrt" => Expr::sqrt(flat_args[0].clone()), "pow" => Expr::pow(flat_args[0].clone(), flat_args[1].clone()), "power" => Expr::pow(flat_args[0].clone(), flat_args[1].clone()),