Skip to content

Commit

Permalink
Fixed CASE statement to work correctly when branch expression is of N…
Browse files Browse the repository at this point in the history
…umeric and Decimal type (#3381)

Issue:
- The CASE statement was unable to calculate the correct common typmod when the common type of the CASE statement was NUMERIC or DECIMAL. For CASE expressions, we relied on a PostgreSQL function to calculate the typmod when the common type was NUMERIC or DECIMAL. However, this approach showed different behavior from T-SQL. Therefore, we need to calculate the common typmod according to T-SQL documentation to ensure consistency with T-SQL behavior.

Changes made to fix the issues:
- The tsql_select_common_typmod_hook() function has been modified to return a common typmod for all branches when the branch expression is of NUMERIC or DECIMAL data types. This calculation now follows the T-SQL documentation. Additionally, we've added a 'case T_CoerceToDomain' in the 'resolve_numeric_typmod_from_exp' function to address issues related to user-defined types (UDTs).
- We have enhanced the 'resolve_numeric_typmod_from_exp' function by adding support for 'T_SubLink' and 'T_CoerceToDomain' nodes. This improvement enables accurate typmod calculation for expressions involving subqueries (T_SubLink) and user-defined types (T_CoerceToDomain).

Signed-off-by: yashneet vinayak <yashneet@amazon.com>
Co-authored-by: yashneet vinayak <yashneet@amazon.com>
  • Loading branch information
Yvinayak07 and yashneet vinayak authored Jan 9, 2025
1 parent 1efe014 commit 1924718
Show file tree
Hide file tree
Showing 36 changed files with 68,622 additions and 100 deletions.
1 change: 1 addition & 0 deletions contrib/babelfishpg_tds/src/backend/tds/tds_srv.c
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ pe_tds_init(void)
pltsql_plugin_handler_ptr->get_datum_from_date_time_struct = &TdsDateTimeTypeToDatum;
pltsql_plugin_handler_ptr->set_reset_tds_connection_flag = &SetResetTDSConnectionFlag;
pltsql_plugin_handler_ptr->get_reset_tds_connection_flag = &GetResetTDSConnectionFlag;
pltsql_plugin_handler_ptr->get_numeric_typmod_from_exp = &resolve_numeric_typmod_from_exp;

invalidate_stat_table_hook = invalidate_stat_table;
guc_newval_hook = TdsSetGucStatVariable;
Expand Down
181 changes: 174 additions & 7 deletions contrib/babelfishpg_tds/src/backend/tds/tdsresponse.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "catalog/indexing.h"
#include "catalog/pg_proc.h"
#include "catalog/pg_type.h"
#include "catalog/pg_cast.h"
#include "miscadmin.h"
#include "nodes/makefuncs.h"
#include "nodes/pathnodes.h"
Expand Down Expand Up @@ -131,15 +132,27 @@ static Oid sys_vector_oid = InvalidOid;
static Oid sys_sparsevec_oid = InvalidOid;
static Oid sys_halfvec_oid = InvalidOid;
static Oid decimal_oid = InvalidOid;
static Oid tsql_fixeddecimal_numeric_oid = InvalidOid;
static Oid tsql_numeric_fixeddecimal_oid = InvalidOid;
static Oid tsql_bit_numeric_oid = InvalidOid;
static Oid tsql_int4_bit_oid = InvalidOid;
static Oid sys_nspoid = InvalidOid;
static Oid tsql_bit_oid = InvalidOid;
static Oid tsql_fixeddecimal_oid = InvalidOid;

static void FillTabNameWithNumParts(StringInfo buf, uint8 numParts, TdsRelationMetaDataInfo relMetaDataInfo);
static void FillTabNameWithoutNumParts(StringInfo buf, uint8 numParts, TdsRelationMetaDataInfo relMetaDataInfo);
static void SetTdsEstateErrorData(void);
static void ResetTdsEstateErrorData(void);
static bool is_numeric_cast(Oid func_oid);
static void SetAttributesForColmetada(TdsColumnMetaData *col);
static int32 resolve_numeric_typmod_from_exp(Plan *plan, Node *expr);
static int32 resolve_numeric_typmod_outer_var(Plan *plan, AttrNumber attno);
static bool is_this_a_vector_datatype(Oid oid);
static bool is_tsql_fixeddecimal_numeric(Oid oid);
static bool is_tsql_numeric_fixeddecimal(Oid oid);
static bool is_tsql_bit_numeric(Oid oid);
static bool is_tsql_int4_bit(Oid oid);
static Oid LookupCastFuncName(Oid castsource, Oid casttarget);

static inline void
SendPendingDone(bool more)
Expand Down Expand Up @@ -516,6 +529,96 @@ resolve_numeric_typmod_outer_var(Plan *plan, AttrNumber attno)
return resolve_numeric_typmod_from_exp(outerplan, (Node *)tle->expr);
}

static Oid
LookupCastFuncName(Oid castsource, Oid casttarget)
{
HeapTuple tuple;
Form_pg_cast castForm;

tuple = SearchSysCache2(CASTSOURCETARGET,
ObjectIdGetDatum(castsource),
ObjectIdGetDatum(casttarget));
if (HeapTupleIsValid(tuple))
{
castForm = (Form_pg_cast) GETSTRUCT(tuple);
ReleaseSysCache(tuple);
return castForm->castfunc;
}
return InvalidOid;
}

static bool
is_tsql_bit_numeric(Oid oid)
{
if (!OidIsValid(tsql_bit_numeric_oid))
tsql_bit_numeric_oid = LookupCastFuncName(tsql_bit_oid, NUMERICOID);
return tsql_bit_numeric_oid == oid;
}

static bool
is_tsql_fixeddecimal_numeric(Oid oid)
{
if (!OidIsValid(tsql_fixeddecimal_numeric_oid))
tsql_fixeddecimal_numeric_oid = LookupCastFuncName(tsql_fixeddecimal_oid, NUMERICOID);
return tsql_fixeddecimal_numeric_oid == oid;
}

static bool
is_tsql_numeric_fixeddecimal(Oid oid)
{
if (!OidIsValid(tsql_numeric_fixeddecimal_oid))
tsql_numeric_fixeddecimal_oid = LookupCastFuncName(NUMERICOID, tsql_fixeddecimal_oid);
return tsql_numeric_fixeddecimal_oid == oid;
}

static bool
is_tsql_int4_bit(Oid oid)
{
if (!OidIsValid(tsql_int4_bit_oid))
tsql_int4_bit_oid = LookupCastFuncName(INT4OID, tsql_bit_oid);
return tsql_int4_bit_oid == oid;
}

/*
* is_numeric_cast checks if the given datatype can be cast to NUMERIC.
* This information is used when processing T_FuncExpr nodes to determine
* if resolve_numeric_typmod_from_exp should be called recursively.
* This ensures proper typmod resolution for nested numeric conversions.
*/
static bool
is_numeric_cast(Oid func_oid)
{
if (!OidIsValid(sys_nspoid))
sys_nspoid = get_namespace_oid("sys", false);

if (!OidIsValid(tsql_bit_oid))
tsql_bit_oid = GetSysCacheOid2(TYPENAMENSP, Anum_pg_type_oid, CStringGetDatum("bit"), ObjectIdGetDatum(sys_nspoid));

if (!OidIsValid(tsql_fixeddecimal_oid))
tsql_fixeddecimal_oid = GetSysCacheOid2(TYPENAMENSP, Anum_pg_type_oid, CStringGetDatum("fixeddecimal"), ObjectIdGetDatum(sys_nspoid));

if (func_oid == F_NUMERIC_INT4 ||
func_oid == F_NUMERIC_INT8 ||
func_oid == F_NUMERIC_INT2 ||
func_oid == F_NUMERIC_FLOAT4 ||
func_oid == F_NUMERIC_FLOAT8 ||
func_oid == F_INT8_INT4 ||
func_oid == F_INT4_INT8 ||
func_oid == F_INT8_INT2 ||
func_oid == F_INT2_INT8 ||
func_oid == F_INT4_INT2 ||
func_oid == F_INT2_INT4 ||
func_oid == F_INT4_NUMERIC ||
func_oid == F_INT2_NUMERIC ||
func_oid == F_INT8_NUMERIC ||
is_tsql_bit_numeric(func_oid) ||
is_tsql_int4_bit(func_oid) ||
is_tsql_fixeddecimal_numeric(func_oid) ||
is_tsql_numeric_fixeddecimal(func_oid))
return true;
return false;
}

/*
* is_numeric_datatype - returns bool if given datatype is numeric or decimal.
*/
Expand All @@ -535,7 +638,7 @@ is_numeric_datatype(Oid typid)
}

/* look for a typmod to return from a numeric expression */
static int32
int32
resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
{
if (expr == NULL)
Expand All @@ -546,14 +649,32 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
{
Const *con = (Const *) expr;
Numeric num;

if (!is_numeric_datatype(con->consttype) || con->constisnull)
int64 val;

if ((!(con->consttype == INT4OID) && !is_numeric_datatype(con->consttype)) ||
con->constisnull)
{
/* typmod is undefined */
return -1;
}
else
{
/*
* This function calculates the typmod for INT4
* constants when called from the babelfishpg_tsql
* extension (referred to as non-plan context). It
* converts the INT4 value to NUMERIC and then determines
* the appropriate typmod. This process ensures correct
* numeric precision handling in Babelfish TSQL operations.
*/
if (plan == NULL && con->consttype == INT4OID)
{
val = con->constvalue;
num = int64_to_numeric(val);
return numeric_get_typmod(num);
}
else if (plan != NULL && con->consttype == INT4OID)
return -1;
num = (Numeric) con->constvalue;
return numeric_get_typmod(num);
}
Expand All @@ -563,7 +684,7 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
Var *var = (Var *) expr;

/* If this var referes to tuple returned by its outer plan then find the original tle from it */
if (var->varno == OUTER_VAR)
if (plan != NULL && var->varno == OUTER_VAR)
{
Assert(plan);
return (resolve_numeric_typmod_outer_var(plan, var->varattno));
Expand Down Expand Up @@ -743,7 +864,6 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
precision = TDS_MAX_NUM_PRECISION;
scale = Max(scale - delta, 0);
}

/*
* Control reaching here for only arithmetic overflow
* cases
Expand All @@ -756,7 +876,7 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
FuncExpr *func = (FuncExpr *) expr;
Oid func_oid = InvalidOid;
int rettypmod = -1;

Node *arg = NULL;
/* Be smart about length-coercion functions... */
if (exprIsLengthCoercion(expr, &rettypmod))
return rettypmod;
Expand All @@ -772,6 +892,21 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
rettypmod = pltsql_plugin_handler_ptr->pltsql_read_numeric_typmod(func_oid,
func->args == NIL ? 0 : func->args->length,
func->funcresulttype);

/*
* If the following conditions are met then we will recursively find typmod from arg.
* 1) plan == NULL means we are invoking this function during parsing phase.
* 2) rettypmod == -1 means unable to find typmod till now.
* 3) check if only one args and then is that castable to numeric.
*/
if (plan == NULL &&
rettypmod == -1 &&
list_length(func->args) == 1 &&
is_numeric_cast(func_oid))
{
arg = linitial(func->args);
return resolve_numeric_typmod_from_exp(plan, arg);
}
return rettypmod;
}
case T_NullIfExpr:
Expand Down Expand Up @@ -943,6 +1078,38 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
else
return resolve_numeric_typmod_from_exp(plan, (Node *) rlt->arg);
}
case T_CoerceToDomain:
{
/* Copied from exprTypmod. */
CoerceToDomain *rlt = (CoerceToDomain *) expr;

if (rlt->resulttypmod != -1)
return rlt->resulttypmod;
else
return resolve_numeric_typmod_from_exp(plan, (Node *) rlt->arg);
}
case T_SubLink:
{
/* Copied from exprTypmod. */
const SubLink *sublink = (const SubLink *) expr;

if (sublink->subLinkType == EXPR_SUBLINK ||
sublink->subLinkType == ARRAY_SUBLINK)
{
/* get the typmod of the subselect's first target column */
Query *qtree = (Query *) sublink->subselect;
TargetEntry *tent;

if (!qtree || !IsA(qtree, Query))
elog(ERROR, "cannot get type for untransformed sublink");
tent = linitial_node(TargetEntry, qtree->targetList);
Assert(!tent->resjunk);
return resolve_numeric_typmod_from_exp(plan, (Node *) tent->expr);
/* note we don't need to care if it's an array */
}
/* otherwise, result is RECORD or BOOLEAN, typmod is -1 */
return -1;
}
/* TODO handle more Expr types if needed */
default:
return -1;
Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tds/src/include/tds_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,6 @@ extern void TDSStatementExceptionCallback(PLtsql_execstate *estate, PLtsql_stmt
bool terminate_batch);
extern void SendColumnMetadata(TupleDesc typeinfo, List *targetlist, int16 *formats);
extern bool GetTdsEstateErrorData(int *number, int *severity, int *state);
extern int32 resolve_numeric_typmod_from_exp(Plan *plan, Node *expr);

#endif /* TDS_H */
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/src/pltsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,7 @@ typedef struct PLtsql_protocol_plugin
bool (*get_reset_tds_connection_flag) ();
void (*get_tvp_typename_typeschemaname) (char *proc_name, char *target_arg_name,
char **tvp_type_name, char **tvp_type_schema_name);
int32 (*get_numeric_typmod_from_exp) (Plan *plan, Node *expr);
/* Session level GUCs */
bool quoted_identifier;
bool arithabort;
Expand Down
Loading

0 comments on commit 1924718

Please sign in to comment.