Skip to content

Commit cca98fc

Browse files
authored
feat(java): support add columns via reader (#3456)
1 parent 9ea6b7e commit cca98fc

File tree

3 files changed

+170
-0
lines changed

3 files changed

+170
-0
lines changed

java/core/lance-jni/src/blocking_dataset.rs

+47
Original file line numberDiff line numberDiff line change
@@ -1057,3 +1057,50 @@ fn inner_add_columns_by_sql_expressions(
10571057
)?;
10581058
Ok(())
10591059
}
1060+
1061+
#[no_mangle]
1062+
pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeAddColumnsByReader(
1063+
mut env: JNIEnv,
1064+
java_dataset: JObject,
1065+
arrow_array_stream_addr: jlong,
1066+
batch_size: JObject, // Optional<Long>
1067+
) {
1068+
ok_or_throw_without_return!(
1069+
env,
1070+
inner_add_columns_by_reader(&mut env, java_dataset, arrow_array_stream_addr, batch_size)
1071+
)
1072+
}
1073+
1074+
fn inner_add_columns_by_reader(
1075+
env: &mut JNIEnv,
1076+
java_dataset: JObject,
1077+
arrow_array_stream_addr: jlong,
1078+
batch_size: JObject, // Optional<Long>
1079+
) -> Result<()> {
1080+
let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream;
1081+
1082+
let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }?;
1083+
1084+
let transform = NewColumnTransform::Reader(Box::new(reader));
1085+
1086+
let batch_size = if env.call_method(&batch_size, "isPresent", "()Z", &[])?.z()? {
1087+
let batch_size_value = env.get_long_opt(&batch_size)?;
1088+
match batch_size_value {
1089+
Some(value) => Some(
1090+
value
1091+
.try_into()
1092+
.map_err(|_| Error::input_error("Batch size conversion error".to_string()))?,
1093+
),
1094+
None => None,
1095+
}
1096+
} else {
1097+
None
1098+
};
1099+
1100+
let mut dataset_guard =
1101+
unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?;
1102+
1103+
RT.block_on(dataset_guard.inner.add_columns(transform, None, batch_size))?;
1104+
1105+
Ok(())
1106+
}

java/core/src/main/java/com/lancedb/lance/Dataset.java

+17
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,23 @@ public void addColumns(SqlExpressions sqlExpressions, Optional<Long> batchSize)
287287
private native void nativeAddColumnsBySqlExpressions(
288288
SqlExpressions sqlExpressions, Optional<Long> batchSize);
289289

290+
/**
291+
* Add columns to the dataset.
292+
*
293+
* @param stream The Arrow Array Stream generated by arrow reader to add columns.
294+
* @param batchSize The number of rows to read at a time from the source dataset when applying the
295+
* transform.
296+
*/
297+
public void addColumns(ArrowArrayStream stream, Optional<Long> batchSize) {
298+
try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) {
299+
Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed");
300+
nativeAddColumnsByReader(stream.memoryAddress(), batchSize);
301+
}
302+
}
303+
304+
private native void nativeAddColumnsByReader(
305+
long arrowStreamMemoryAddress, Optional<Long> batchSize);
306+
290307
/**
291308
* Drop columns from the dataset.
292309
*

java/core/src/test/java/com/lancedb/lance/DatasetTest.java

+106
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313
*/
1414
package com.lancedb.lance;
1515

16+
import com.lancedb.lance.ipc.LanceScanner;
1617
import com.lancedb.lance.schema.ColumnAlteration;
1718
import com.lancedb.lance.schema.SqlExpressions;
1819

20+
import org.apache.arrow.c.ArrowArrayStream;
21+
import org.apache.arrow.c.Data;
1922
import org.apache.arrow.memory.BufferAllocator;
2023
import org.apache.arrow.memory.RootAllocator;
24+
import org.apache.arrow.vector.FieldVector;
25+
import org.apache.arrow.vector.IntVector;
2126
import org.apache.arrow.vector.VectorSchemaRoot;
2227
import org.apache.arrow.vector.ipc.ArrowReader;
2328
import org.apache.arrow.vector.types.pojo.ArrowType;
@@ -351,6 +356,107 @@ void testAddColumnBySqlExpressions() {
351356
}
352357
}
353358

359+
@Test
360+
void testAddColumnsByStream() throws IOException {
361+
String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
362+
String datasetPath = tempDir.resolve(testMethodName).toString();
363+
try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
364+
TestUtils.SimpleTestDataset testDataset =
365+
new TestUtils.SimpleTestDataset(allocator, datasetPath);
366+
367+
try (Dataset initialDataset = testDataset.createEmptyDataset()) {
368+
try (Dataset datasetV1 = testDataset.write(1, 3)) {
369+
assertEquals(3, datasetV1.countRows());
370+
}
371+
}
372+
373+
dataset = Dataset.open(datasetPath, allocator);
374+
375+
Schema newColumnSchema =
376+
new Schema(
377+
Collections.singletonList(Field.nullable("age", new ArrowType.Int(32, true))), null);
378+
379+
try (VectorSchemaRoot vector = VectorSchemaRoot.create(newColumnSchema, allocator);
380+
ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) {
381+
382+
IntVector ageVector = (IntVector) vector.getVector("age");
383+
ageVector.allocateNew(3);
384+
ageVector.set(0, 25);
385+
ageVector.set(1, 30);
386+
ageVector.set(2, 35);
387+
vector.setRowCount(3);
388+
389+
class SimpleVectorReader extends ArrowReader {
390+
private boolean batchLoaded = false;
391+
392+
protected SimpleVectorReader(BufferAllocator allocator) {
393+
super(allocator);
394+
}
395+
396+
@Override
397+
public boolean loadNextBatch() {
398+
if (!batchLoaded) {
399+
batchLoaded = true;
400+
return true;
401+
}
402+
return false;
403+
}
404+
405+
@Override
406+
public VectorSchemaRoot getVectorSchemaRoot() {
407+
return vector;
408+
}
409+
410+
@Override
411+
public long bytesRead() {
412+
return vector.getFieldVectors().stream().mapToLong(FieldVector::getBufferSize).sum();
413+
}
414+
415+
@Override
416+
protected void closeReadSource() {}
417+
418+
@Override
419+
protected Schema readSchema() {
420+
return newColumnSchema;
421+
}
422+
}
423+
424+
try (ArrowReader reader = new SimpleVectorReader(allocator)) {
425+
Data.exportArrayStream(allocator, reader, stream);
426+
427+
dataset.addColumns(stream, Optional.of(3L));
428+
429+
Schema expectedSchema =
430+
new Schema(
431+
Arrays.asList(
432+
Field.nullable("id", new ArrowType.Int(32, true)),
433+
Field.nullable("name", new ArrowType.Utf8()),
434+
Field.nullable("age", new ArrowType.Int(32, true))),
435+
null);
436+
Schema actualSchema = dataset.getSchema();
437+
assertEquals(expectedSchema.getFields(), actualSchema.getFields());
438+
439+
try (LanceScanner scanner = dataset.newScan()) {
440+
try (ArrowReader resultReader = scanner.scanBatches()) {
441+
assertTrue(resultReader.loadNextBatch());
442+
VectorSchemaRoot root = resultReader.getVectorSchemaRoot();
443+
assertEquals(3, root.getRowCount());
444+
445+
IntVector idVector = (IntVector) root.getVector("id");
446+
IntVector ageVectorResult = (IntVector) root.getVector("age");
447+
for (int i = 0; i < 3; i++) {
448+
assertEquals(i, idVector.get(i));
449+
assertEquals(25 + i * 5, ageVectorResult.get(i));
450+
}
451+
}
452+
}
453+
}
454+
}
455+
} catch (Exception e) {
456+
fail("Exception occurred during test: " + e.getMessage(), e);
457+
}
458+
}
459+
354460
@Test
355461
void testDropPath() {
356462
String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();

0 commit comments

Comments
 (0)