From e989a8af6fb23d6fe1037cb964918b38b385d623 Mon Sep 17 00:00:00 2001
From: tauslim <douglas.lim@daimler.com>
Date: Fri, 8 Mar 2024 16:27:12 +0800
Subject: [PATCH] feat: add in cosmos db in operators

---
 pkg/driver/cosmos/cosmos.go      | 31 +++++++++++++++++++++++++++++++
 pkg/driver/cosmos/cosmos_test.go | 16 ++++++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/pkg/driver/cosmos/cosmos.go b/pkg/driver/cosmos/cosmos.go
index 0e605ba..3a332f5 100644
--- a/pkg/driver/cosmos/cosmos.go
+++ b/pkg/driver/cosmos/cosmos.go
@@ -135,6 +135,7 @@ func NewCosmosTranslator(r *gorql.RqlRootNode) (st *Translator) {
 	st.SetOpFunc(driver.GeOp, st.GetFieldValueTranslatorFunc(">=", convert))
 	st.SetOpFunc(driver.LeOp, st.GetFieldValueTranslatorFunc("<=", convert))
 	st.SetOpFunc(driver.NotOp, st.GetOpFirstTranslatorFunc(driver.NotOp, convert))
+	st.SetOpFunc(driver.InOp, st.GetSliceTranslatorFunc("ARRAY_CONTAINS", convert))
 
 	return
 }
@@ -281,6 +282,36 @@ func (ct *Translator) GetOpFirstTranslatorFunc(op string, valueAlterFunc AlterVa
 	}
 }
 
+func (ct *Translator) GetSliceTranslatorFunc(op string, alterValueFunc AlterValueFunc) driver.TranslatorOpFunc {
+	return func(n *gorql.RqlNode) (s string, err error) {
+		var values []string
+		var field string
+		var placeholder string
+		for i, a := range n.Args {
+			if i == 0 {
+				if gorql.IsValidField(a.(string)) {
+					field = fmt.Sprintf("c.%s", a.(string))
+				} else {
+					return "", fmt.Errorf("first argument must be a valid field name (arg: %s)", a)
+				}
+			} else {
+				placeholder = fmt.Sprintf("@p%s", strconv.Itoa(len(ct.args)+1))
+				convertedValue, err := alterValueFunc(a)
+				if err != nil {
+					return "", err
+				}
+				values = append(values, fmt.Sprintf("%v", convertedValue))
+			}
+		}
+		ct.args = append(ct.args, Param{
+			Name:  placeholder,
+			Value: values,
+		})
+		s += fmt.Sprintf(`%s, %s, false`, placeholder, field)
+		return op + "(" + s + ")", nil
+	}
+}
+
 // Args returns slice of arguments for WHERE statement
 func (ct *Translator) Args() []interface{} {
 	return ct.args
diff --git a/pkg/driver/cosmos/cosmos_test.go b/pkg/driver/cosmos/cosmos_test.go
index 88a4a4c..3f1595d 100644
--- a/pkg/driver/cosmos/cosmos_test.go
+++ b/pkg/driver/cosmos/cosmos_test.go
@@ -134,6 +134,22 @@ var tests = []Test{
 		WantParseError:      false,
 		WantTranslatorError: false,
 	},
+	{
+		Name: `Basic translation for IN operator`,
+		RQL:  `in(foo,bar,john,doe)`,
+		Model: new(struct {
+			Foo string `rql:"filter"`
+		}),
+		ExpectedSQL: `WHERE ARRAY_CONTAINS(@p1, c.foo, false)`,
+		ExpectedArgs: []interface{}{
+			Param{
+				Name:  "@p1",
+				Value: []string{"bar", "john", "doe"},
+			},
+		},
+		WantParseError:      false,
+		WantTranslatorError: false,
+	},
 	{
 		Name: `Mixed style translation`,
 		RQL:  `((eq(foo,42)&gt(price,10))|ge(price,500))&eq(disabled,false)`,