Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPL cidr function #706

Merged
merged 13 commits into from
Oct 30, 2024
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
"com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test",
"com.github.sbt" % "junit-interface" % "0.13.3" % "test",
"org.projectlombok" % "lombok" % "1.18.30",
"com.github.seancfoley" % "ipaddress" % "5.5.1",
),
libraryDependencies ++= deps(sparkVersion),
// ANTLR settings
Expand Down
2 changes: 2 additions & 0 deletions docs/ppl-lang/PPL-Example-Commands.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ _- **Limitation: new field added by eval command with a function cannot be dropp
- `source = table | where a not in (1, 2, 3) | fields a,b,c`
- `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4]
- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b < '2024-09-10' or b > '2025-09-10', i.e. outside the interval ['2024-09-10', '2025-09-10']
- `source = table | where cidrmatch(ip, '192.169.1.0/24')`
- `source = table | where cidrmatch(ipv6, '2003:db8::/32')`

```sql
source = table | eval status_category =
Expand Down
1 change: 1 addition & 0 deletions docs/ppl-lang/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ For additional examples see the next [documentation](PPL-Example-Commands.md).

- [`Cryptographic Functions`](functions/ppl-cryptographic.md)

- [`IP Address Functions`](functions/ppl-ip.md)

---
### PPL On Spark
Expand Down
35 changes: 35 additions & 0 deletions docs/ppl-lang/functions/ppl-ip.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
## PPL IP Address Functions

### `CIDRMATCH`

**Description**

`CIDRMATCH(ip, cidr)` checks if ip is within the specified cidr range.

**Argument type:**
- STRING, STRING
- Return type: **BOOLEAN**

Example:

os> source=ips | where cidrmatch(ip, '192.169.1.0/24') | fields ip
fetched rows / total rows = 1/1
+--------------+
| ip |
|--------------|
| 192.169.1.5 |
+--------------+

os> source=ipsv6 | where cidrmatch(ip, '2003:db8::/32') | fields ip
fetched rows / total rows = 1/1
+-----------------------------------------+
| ip |
|-----------------------------------------|
| 2003:0db8:0000:0000:0000:0000:0000:0000 |
+-----------------------------------------+

Note:
- `ip` can be an IPv4 or an IPv6 address
- `cidr` can be an IPv4 or an IPv6 block
- `ip` and `cidr` must be either both IPv4 or both IPv6
- `ip` and `cidr` must both be valid and non-empty/non-null
2 changes: 2 additions & 0 deletions docs/ppl-lang/ppl-where-command.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ PPL query:
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`
- `source = table | where a between 1 and 4` - Note: This returns a >= 1 and a <= 4, i.e. [1, 4]
- `source = table | where b not between '2024-09-10' and '2025-09-10'` - Note: This returns b < '2024-09-10' or b > '2025-09-10', i.e. outside the interval ['2024-09-10', '2025-09-10']
- `source = table | where cidrmatch(ip, '192.169.1.0/24')`
- `source = table | where cidrmatch(ipv6, '2003:db8::/32')`

- `source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -669,4 +669,30 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| (11, null, false)
| """.stripMargin)
}

protected def createIpAddressTable(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
| (
| id INT,
| ipAddress STRING,
| isV6 BOOLEAN,
| isValid BOOLEAN
| )
| USING $tableType $tableOptions
|""".stripMargin)

sql(s"""
| INSERT INTO $testTable
| VALUES (1, '127.0.0.1', false, true),
| (2, '192.168.1.0', false, true),
| (3, '192.168.1.1', false, true),
| (4, '192.168.2.1', false, true),
| (5, '192.168.2.', false, false),
| (6, '2001:db8::ff00:12:3455', true, true),
| (7, '2001:db8::ff00:12:3456', true, true),
| (8, '2001:db8::ff00:13:3457', true, true),
| (9, '2001:db8::ff00:12:', true, false)
| """.stripMargin)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLCidrmatchITSuite
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createIpAddressTable(testTable)
}

protected override def afterEach(): Unit = {
super.afterEach()
// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("test cidrmatch for ipv4 for 192.168.1.0/24") {
val frame = sql(s"""
| source = $testTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.168.1.0/24')
| """.stripMargin)

val results = frame.collect()
assert(results.length == 2)
}

test("test cidrmatch for ipv4 for 192.169.1.0/24") {
val frame = sql(s"""
| source = $testTable | where isV6 = false and isValid = true and cidrmatch(ipAddress, '192.169.1.0/24')
| """.stripMargin)

val results = frame.collect()
assert(results.length == 0)
}

test("test cidrmatch for ipv6 for 2001:db8::/32") {
val frame = sql(s"""
| source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2001:db8::/32')
| """.stripMargin)

val results = frame.collect()
assert(results.length == 3)
}

test("test cidrmatch for ipv6 for 2003:db8::/32") {
val frame = sql(s"""
| source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '2003:db8::/32')
| """.stripMargin)

val results = frame.collect()
assert(results.length == 0)
}

test("test cidrmatch for ipv6 with ipv4 cidr") {
val frame = sql(s"""
| source = $testTable | where isV6 = true and isValid = true and cidrmatch(ipAddress, '192.169.1.0/24')
| """.stripMargin)

assertThrows[SparkException](frame.collect())
}

test("test cidrmatch for invalid ipv4 addresses") {
val frame = sql(s"""
| source = $testTable | where isV6 = false and isValid = false and cidrmatch(ipAddress, '192.169.1.0/24')
| """.stripMargin)

assertThrows[SparkException](frame.collect())
}

test("test cidrmatch for invalid ipv6 addresses") {
val frame = sql(s"""
| source = $testTable | where isV6 = true and isValid = false and cidrmatch(ipAddress, '2003:db8::/32')
| """.stripMargin)

assertThrows[SparkException](frame.collect())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ ISNULL: 'ISNULL';
ISNOTNULL: 'ISNOTNULL';
ISPRESENT: 'ISPRESENT';
BETWEEN: 'BETWEEN';
CIDRMATCH: 'CIDRMATCH';

// FLOWCONTROL FUNCTIONS
IFNULL: 'IFNULL';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ booleanExpression
| isEmptyExpression # isEmptyExpr
| valueExpressionList NOT? IN LT_SQR_PRTHS subSearch RT_SQR_PRTHS # inSubqueryExpr
| EXISTS LT_SQR_PRTHS subSearch RT_SQR_PRTHS # existsSubqueryExpr
| cidrMatchFunctionCall # cidrFunctionCallExpr
;

isEmptyExpression
Expand Down Expand Up @@ -519,6 +520,10 @@ booleanFunctionCall
: conditionFunctionBase LT_PRTHS functionArgs RT_PRTHS
;

cidrMatchFunctionCall
: CIDRMATCH LT_PRTHS ipAddress = functionArg COMMA cidrBlock = functionArg RT_PRTHS
;

convertedDataType
: typeName = DATE
| typeName = TIME
Expand Down Expand Up @@ -1116,4 +1121,5 @@ keywordsCanBeId
| SEMI
| ANTI
| BETWEEN
| CIDRMATCH
;
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.sql.ast.expression.AttributeList;
import org.opensearch.sql.ast.expression.Between;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Cidr;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
Expand Down Expand Up @@ -322,4 +323,7 @@ public T visitExistsSubquery(ExistsSubquery node, C context) {
public T visitWindow(Window node, C context) {
return visitChildren(node, context);
}
public T visitCidr(Cidr node, C context) {
return visitChildren(node, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import lombok.AllArgsConstructor;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.sql.ast.AbstractNodeVisitor;

import java.util.Arrays;
import java.util.List;

/** AST node that represents CIDR function. */
@AllArgsConstructor
@Getter
@EqualsAndHashCode(callSuper = false)
@ToString
public class Cidr extends UnresolvedExpression {
private UnresolvedExpression ipAddress;
private UnresolvedExpression cidrBlock;

@Override
public List<UnresolvedExpression> getChild() {
return Arrays.asList(ipAddress, cidrBlock);
}

@Override
public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitCidr(this, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.expression.function;

import inet.ipaddr.AddressStringException;
import inet.ipaddr.IPAddressString;
import inet.ipaddr.IPAddressStringParameters;
import scala.Function2;
import scala.Serializable;
import scala.runtime.AbstractFunction2;


public interface SerializableUdf {

Function2<String,String,Boolean> cidrFunction = new SerializableAbstractFunction2<>() {

IPAddressStringParameters valOptions = new IPAddressStringParameters.Builder()
.allowEmpty(false)
.setEmptyAsLoopback(false)
.allow_inet_aton(false)
.allowSingleSegment(false)
.toParams();

@Override
public Boolean apply(String ipAddress, String cidrBlock) {

IPAddressString parsedIpAddress = new IPAddressString(ipAddress, valOptions);
salyh marked this conversation as resolved.
Show resolved Hide resolved

try {
parsedIpAddress.validate();
} catch (AddressStringException e) {
throw new RuntimeException("The given ipAddress '"+ipAddress+"' is invalid. It must be a valid IPv4 or IPv6 address. Error details: "+e.getMessage());
}

IPAddressString parsedCidrBlock = new IPAddressString(cidrBlock, valOptions);

try {
parsedCidrBlock.validate();
salyh marked this conversation as resolved.
Show resolved Hide resolved
} catch (AddressStringException e) {
throw new RuntimeException("The given cidrBlock '"+cidrBlock+"' is invalid. It must be a valid CIDR or netmask. Error details: "+e.getMessage());
}

if(parsedIpAddress.isIPv4() && parsedCidrBlock.isIPv6() || parsedIpAddress.isIPv6() && parsedCidrBlock.isIPv4()) {
throw new RuntimeException("The given ipAddress '"+ipAddress+"' and cidrBlock '"+cidrBlock+"' are not compatible. Both must be either IPv4 or IPv6.");
}

return parsedCidrBlock.contains(parsedIpAddress);
}
};

abstract class SerializableAbstractFunction2<T1,T2,R> extends AbstractFunction2<T1,T2,R>
implements Serializable {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.SortDirection;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.plans.logical.*;
Expand Down Expand Up @@ -88,6 +89,7 @@
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.ast.tree.Window;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.expression.function.SerializableUdf;
import org.opensearch.sql.ppl.utils.AggregatorTranslator;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
Expand All @@ -100,7 +102,11 @@
import scala.collection.IterableLike;
import scala.collection.Seq;

import java.util.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Stack;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -879,5 +885,24 @@ public Expression visitBetween(Between node, CatalystPlanContext context) {
context.retainAllNamedParseExpressions(p -> p);
return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(new GreaterThanOrEqual(value, lower), new LessThanOrEqual(value, upper)));
}

@Override
public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, CatalystPlanContext context) {
analyze(node.getIpAddress(), context);
Expression ipAddressExpression = context.getNamedParseExpressions().pop();
analyze(node.getCidrBlock(), context);
Expression cidrBlockExpression = context.getNamedParseExpressions().pop();

ScalaUDF udf = new ScalaUDF(SerializableUdf.cidrFunction,
DataTypes.BooleanType,
seq(ipAddressExpression,cidrBlockExpression),
seq(),
Option.empty(),
Option.apply("cidr"),
false,
true);

return context.getNamedParseExpressions().push(udf);
}
}
}
Loading
Loading