|
13 | 13 | */
|
14 | 14 | package com.lancedb.lance;
|
15 | 15 |
|
| 16 | +import com.lancedb.lance.ipc.LanceScanner; |
16 | 17 | import com.lancedb.lance.schema.ColumnAlteration;
|
17 | 18 | import com.lancedb.lance.schema.SqlExpressions;
|
18 | 19 |
|
| 20 | +import org.apache.arrow.c.ArrowArrayStream; |
| 21 | +import org.apache.arrow.c.Data; |
19 | 22 | import org.apache.arrow.memory.BufferAllocator;
|
20 | 23 | import org.apache.arrow.memory.RootAllocator;
|
| 24 | +import org.apache.arrow.vector.FieldVector; |
| 25 | +import org.apache.arrow.vector.IntVector; |
21 | 26 | import org.apache.arrow.vector.VectorSchemaRoot;
|
22 | 27 | import org.apache.arrow.vector.ipc.ArrowReader;
|
23 | 28 | import org.apache.arrow.vector.types.pojo.ArrowType;
|
@@ -351,6 +356,107 @@ void testAddColumnBySqlExpressions() {
|
351 | 356 | }
|
352 | 357 | }
|
353 | 358 |
|
| 359 | + @Test |
| 360 | + void testAddColumnsByStream() throws IOException { |
| 361 | + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); |
| 362 | + String datasetPath = tempDir.resolve(testMethodName).toString(); |
| 363 | + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { |
| 364 | + TestUtils.SimpleTestDataset testDataset = |
| 365 | + new TestUtils.SimpleTestDataset(allocator, datasetPath); |
| 366 | + |
| 367 | + try (Dataset initialDataset = testDataset.createEmptyDataset()) { |
| 368 | + try (Dataset datasetV1 = testDataset.write(1, 3)) { |
| 369 | + assertEquals(3, datasetV1.countRows()); |
| 370 | + } |
| 371 | + } |
| 372 | + |
| 373 | + dataset = Dataset.open(datasetPath, allocator); |
| 374 | + |
| 375 | + Schema newColumnSchema = |
| 376 | + new Schema( |
| 377 | + Collections.singletonList(Field.nullable("age", new ArrowType.Int(32, true))), null); |
| 378 | + |
| 379 | + try (VectorSchemaRoot vector = VectorSchemaRoot.create(newColumnSchema, allocator); |
| 380 | + ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) { |
| 381 | + |
| 382 | + IntVector ageVector = (IntVector) vector.getVector("age"); |
| 383 | + ageVector.allocateNew(3); |
| 384 | + ageVector.set(0, 25); |
| 385 | + ageVector.set(1, 30); |
| 386 | + ageVector.set(2, 35); |
| 387 | + vector.setRowCount(3); |
| 388 | + |
| 389 | + class SimpleVectorReader extends ArrowReader { |
| 390 | + private boolean batchLoaded = false; |
| 391 | + |
| 392 | + protected SimpleVectorReader(BufferAllocator allocator) { |
| 393 | + super(allocator); |
| 394 | + } |
| 395 | + |
| 396 | + @Override |
| 397 | + public boolean loadNextBatch() { |
| 398 | + if (!batchLoaded) { |
| 399 | + batchLoaded = true; |
| 400 | + return true; |
| 401 | + } |
| 402 | + return false; |
| 403 | + } |
| 404 | + |
| 405 | + @Override |
| 406 | + public VectorSchemaRoot getVectorSchemaRoot() { |
| 407 | + return vector; |
| 408 | + } |
| 409 | + |
| 410 | + @Override |
| 411 | + public long bytesRead() { |
| 412 | + return vector.getFieldVectors().stream().mapToLong(FieldVector::getBufferSize).sum(); |
| 413 | + } |
| 414 | + |
| 415 | + @Override |
| 416 | + protected void closeReadSource() {} |
| 417 | + |
| 418 | + @Override |
| 419 | + protected Schema readSchema() { |
| 420 | + return newColumnSchema; |
| 421 | + } |
| 422 | + } |
| 423 | + |
| 424 | + try (ArrowReader reader = new SimpleVectorReader(allocator)) { |
| 425 | + Data.exportArrayStream(allocator, reader, stream); |
| 426 | + |
| 427 | + dataset.addColumns(stream, Optional.of(3L)); |
| 428 | + |
| 429 | + Schema expectedSchema = |
| 430 | + new Schema( |
| 431 | + Arrays.asList( |
| 432 | + Field.nullable("id", new ArrowType.Int(32, true)), |
| 433 | + Field.nullable("name", new ArrowType.Utf8()), |
| 434 | + Field.nullable("age", new ArrowType.Int(32, true))), |
| 435 | + null); |
| 436 | + Schema actualSchema = dataset.getSchema(); |
| 437 | + assertEquals(expectedSchema.getFields(), actualSchema.getFields()); |
| 438 | + |
| 439 | + try (LanceScanner scanner = dataset.newScan()) { |
| 440 | + try (ArrowReader resultReader = scanner.scanBatches()) { |
| 441 | + assertTrue(resultReader.loadNextBatch()); |
| 442 | + VectorSchemaRoot root = resultReader.getVectorSchemaRoot(); |
| 443 | + assertEquals(3, root.getRowCount()); |
| 444 | + |
| 445 | + IntVector idVector = (IntVector) root.getVector("id"); |
| 446 | + IntVector ageVectorResult = (IntVector) root.getVector("age"); |
| 447 | + for (int i = 0; i < 3; i++) { |
| 448 | + assertEquals(i, idVector.get(i)); |
| 449 | + assertEquals(25 + i * 5, ageVectorResult.get(i)); |
| 450 | + } |
| 451 | + } |
| 452 | + } |
| 453 | + } |
| 454 | + } |
| 455 | + } catch (Exception e) { |
| 456 | + fail("Exception occurred during test: " + e.getMessage(), e); |
| 457 | + } |
| 458 | + } |
| 459 | + |
354 | 460 | @Test
|
355 | 461 | void testDropPath() {
|
356 | 462 | String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName();
|
|
0 commit comments