Skip to content

Commit

Permalink
findById Optimization (#35261)
Browse files Browse the repository at this point in the history
* findById Optimization

* reformat

* add hashcode override and licence to BasicItem

* improve point read test

* address linting errors

* fix linting errors

* fix linting

* some enhancements

* fix build errors

* fix testFindByIdNull integration test failure

* add return null for CosmosAccessException

* address Fabian review comment

* update change log

* address review comments

* fix build errors

* fix build errors

* fix build errors

* fix build errors

* add testFindByIdWithInvalidId

* add missing license header

* add repository api checks for point read/query

* change pointReadWarningLogged to primitive

* edit change log and warning text

* add extra checks per Kushagra's comments

* Update CHANGELOG.md

---------

Co-authored-by: Theo van Kraay <thvankra@microsoft.com>
Co-authored-by: Kushagra Thapar <kushuthapar@gmail.com>
  • Loading branch information
3 people authored Jun 9, 2023
1 parent 43f68e7 commit 8236c59
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 6 deletions.
1 change: 1 addition & 0 deletions sdk/spring/azure-spring-data-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#### Bugs Fixed

#### Other Changes
* Optimized default implementation of `findById(ID id)` from `CrudRepository` so that it will execute point reads where id is also the partition key, and log a warning where it is not. The new behaviour is more optimal, especially for large containers with many partitions - see [PR 35261](https://github.com/Azure/azure-sdk-for-java/pull/35261).

### 3.35.0 (2023-05-25)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public class CosmosTemplate implements CosmosOperations, ApplicationContextAware
private final int maxBufferedItemCount;
private final int responseContinuationTokenLimitInKb;
private final DatabaseThroughputConfig databaseThroughputConfig;

private boolean pointReadWarningLogged = false;
private ApplicationContext applicationContext;

/**
Expand Down Expand Up @@ -351,6 +351,15 @@ public <T> T findById(Object id, Class<T> domainType, PartitionKey partitionKey)
public <T> T findById(String containerName, Object id, Class<T> domainType) {
Assert.hasText(containerName, "containerName should not be null, empty or only whitespaces");
Assert.notNull(domainType, "domainType should not be null");
CosmosEntityInformation<?, ?> cosmosEntityInformation = CosmosEntityInformation.getInstance(domainType);
String containerPartitionKey = cosmosEntityInformation.getPartitionKeyFieldName();
if ("id".equals(containerPartitionKey) && id != null) {
return findById(id, domainType, new PartitionKey(id));
}
if (!this.pointReadWarningLogged) {
LOGGER.warn("The partitionKey is not id!! Consider using findById(ID id, PartitionKey partitionKey) instead to avoid the need for using a cross partition query which results in higher latency and cost than necessary. See https://aka.ms/PointReadsInSpring for more info.");
this.pointReadWarningLogged = true;
}
String finalContainerName = getContainerNameOverride(containerName);
final String query = "select * from root where root.id = @ROOT_ID";
final SqlParameter param = new SqlParameter("@ROOT_ID", CosmosUtils.getStringIDValue(id));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public class ReactiveCosmosTemplate implements ReactiveCosmosOperations, Applica
private final int responseContinuationTokenLimitInKb;
private final IsNewAwareAuditingHandler cosmosAuditingHandler;
private final DatabaseThroughputConfig databaseThroughputConfig;

private boolean pointReadWarningLogged = false;
private ApplicationContext applicationContext;

/**
Expand Down Expand Up @@ -325,6 +325,15 @@ public <T> Mono<T> findById(Object id, Class<T> domainType) {
public <T> Mono<T> findById(String containerName, Object id, Class<T> domainType) {
Assert.hasText(containerName, "containerName should not be null, empty or only whitespaces");
Assert.notNull(domainType, "domainType should not be null");
CosmosEntityInformation<?, ?> cosmosEntityInformation = CosmosEntityInformation.getInstance(domainType);
String containerPartitionKey = cosmosEntityInformation.getPartitionKeyFieldName();
if ("id".equals(containerPartitionKey) && id != null) {
return findById(id, domainType, new PartitionKey(id));
}
if (!this.pointReadWarningLogged) {
LOGGER.warn("The partitionKey is not id!! Consider using findById(ID id, PartitionKey partitionKey) instead to avoid the need for using a cross partition query which results in higher latency and cost than necessary. See https://aka.ms/PointReadsInSpring for more info.");
this.pointReadWarningLogged = true;
}
final String finalContainerName = getContainerNameOverride(containerName);
final String query = "select * from root where root.id = @ROOT_ID";
final SqlParameter param = new SqlParameter("@ROOT_ID", CosmosUtils.getStringIDValue(id));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.azure.spring.data.cosmos.core.query.CriteriaType;
import com.azure.spring.data.cosmos.domain.AuditableEntity;
import com.azure.spring.data.cosmos.domain.AutoScaleSample;
import com.azure.spring.data.cosmos.domain.BasicItem;
import com.azure.spring.data.cosmos.domain.GenIdEntity;
import com.azure.spring.data.cosmos.domain.Person;
import com.azure.spring.data.cosmos.exception.CosmosAccessException;
Expand Down Expand Up @@ -103,10 +104,14 @@ public class CosmosTemplateIT {
private static final Person TEST_PERSON_3 = new Person(ID_3, NEW_FIRST_NAME, NEW_LAST_NAME, HOBBIES,
ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY);

private static final BasicItem BASIC_ITEM = new BasicItem(ID_1);

private static final String PRECONDITION_IS_NOT_MET = "is not met";

private static final String WRONG_ETAG = "WRONG_ETAG";

private static final String INVALID_ID = "http://xxx.html";

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final JsonNode NEW_PASSPORT_IDS_BY_COUNTRY_JSON = OBJECT_MAPPER.convertValue(NEW_PASSPORT_IDS_BY_COUNTRY, JsonNode.class);

Expand Down Expand Up @@ -135,6 +140,7 @@ public class CosmosTemplateIT {

private Person insertedPerson;

private BasicItem pointReadItem;
@Autowired
private ApplicationContext applicationContext;
@Autowired
Expand All @@ -157,9 +163,11 @@ public void setUp() throws ClassNotFoundException {
}

collectionManager.ensureContainersCreatedAndEmpty(cosmosTemplate, Person.class,
GenIdEntity.class, AuditableEntity.class);
GenIdEntity.class, AuditableEntity.class, BasicItem.class);
insertedPerson = cosmosTemplate.insert(Person.class.getSimpleName(), TEST_PERSON,
new PartitionKey(TEST_PERSON.getLastName()));
pointReadItem = cosmosTemplate.insert(BasicItem.class.getSimpleName(), BASIC_ITEM,
new PartitionKey(BASIC_ITEM.getId()));
}

private CosmosTemplate createCosmosTemplate(CosmosConfig config, String dbName) throws ClassNotFoundException {
Expand Down Expand Up @@ -187,6 +195,7 @@ public void testInsertDuplicateIdShouldFailWithConflictException() {
}
}


@Test(expected = CosmosAccessException.class)
public void testInsertShouldFailIfColumnNotAnnotatedWithAutoGenerate() {
final Person person = new Person(null, FIRST_NAME, LAST_NAME, HOBBIES, ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY);
Expand All @@ -212,6 +221,20 @@ public void testFindAll() {
assertThat(responseDiagnosticsTestUtils.getCosmosResponseStatistics().getRequestCharge()).isGreaterThan(0);
}

@Test
public void testFindByIdPointRead() {
final BasicItem result = cosmosTemplate.findById(BasicItem.class.getSimpleName(),
BASIC_ITEM.getId(), BasicItem.class);
assertEquals(result, BASIC_ITEM);
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics()).isNotNull();
assertThat(responseDiagnosticsTestUtils.getCosmosResponseStatistics()).isNull();
final BasicItem nullResult = cosmosTemplate.findById(BasicItem.class.getSimpleName(),
NOT_EXIST_ID, BasicItem.class);
assertThat(nullResult).isNull();
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics()).isNotNull();
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics().toString().contains("\"requestOperationType\":\"Read\"")).isTrue();
}

@Test
public void testFindById() {
final Person result = cosmosTemplate.findById(Person.class.getSimpleName(),
Expand All @@ -227,6 +250,17 @@ public void testFindById() {
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics()).isNotNull();
}

@Test
public void testFindByIdWithInvalidId() {
try {
cosmosTemplate.findById(BasicItem.class.getSimpleName(),
INVALID_ID, BasicItem.class);
fail();
} catch (CosmosAccessException ex) {
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics()).isNotNull();
}
}

@Test
public void testFindByMultiIds() {
cosmosTemplate.insert(TEST_PERSON_2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.azure.spring.data.cosmos.core.query.CriteriaType;
import com.azure.spring.data.cosmos.domain.AuditableEntity;
import com.azure.spring.data.cosmos.domain.AutoScaleSample;
import com.azure.spring.data.cosmos.domain.BasicItem;
import com.azure.spring.data.cosmos.domain.GenIdEntity;
import com.azure.spring.data.cosmos.domain.Person;
import com.azure.spring.data.cosmos.exception.CosmosAccessException;
Expand Down Expand Up @@ -67,8 +68,10 @@
import static com.azure.spring.data.cosmos.common.TestConstants.AGE;
import static com.azure.spring.data.cosmos.common.TestConstants.FIRST_NAME;
import static com.azure.spring.data.cosmos.common.TestConstants.HOBBIES;
import static com.azure.spring.data.cosmos.common.TestConstants.ID_1;
import static com.azure.spring.data.cosmos.common.TestConstants.LAST_NAME;
import static com.azure.spring.data.cosmos.common.TestConstants.NEW_PASSPORT_IDS_BY_COUNTRY;
import static com.azure.spring.data.cosmos.common.TestConstants.NOT_EXIST_ID;
import static com.azure.spring.data.cosmos.common.TestConstants.PASSPORT_IDS_BY_COUNTRY;
import static com.azure.spring.data.cosmos.common.TestConstants.PATCH_AGE_1;
import static com.azure.spring.data.cosmos.common.TestConstants.PATCH_AGE_INCREMENT;
Expand All @@ -95,12 +98,15 @@ public class ReactiveCosmosTemplateIT {
private static final Person TEST_PERSON_4 = new Person(TestConstants.ID_4, TestConstants.NEW_FIRST_NAME,
TestConstants.NEW_LAST_NAME, TestConstants.HOBBIES, TestConstants.ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY);

private static final BasicItem BASIC_ITEM = new BasicItem(ID_1);
private static final String PRECONDITION_IS_NOT_MET = "is not met";
private static final String WRONG_ETAG = "WRONG_ETAG";

private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final JsonNode NEW_PASSPORT_IDS_BY_COUNTRY_JSON = OBJECT_MAPPER.convertValue(NEW_PASSPORT_IDS_BY_COUNTRY, JsonNode.class);

private static final String INVALID_ID = "http://xxx.html";

private static final CosmosPatchOperations operations = CosmosPatchOperations
.create()
.replace("/age", PATCH_AGE_1);
Expand Down Expand Up @@ -128,10 +134,14 @@ public class ReactiveCosmosTemplateIT {
private static ReactiveCosmosTemplate cosmosTemplate;
private static String containerName;
private static CosmosEntityInformation<Person, String> personInfo;

private static CosmosEntityInformation<BasicItem, String> itemInfo;
private static AzureKeyCredential azureKeyCredential;

private Person insertedPerson;

private BasicItem pointReadItem;

@Autowired
private ApplicationContext applicationContext;
@Autowired
Expand All @@ -150,14 +160,17 @@ public void setUp() throws ClassNotFoundException {
cosmosClientBuilder.credential(azureKeyCredential);
client = CosmosFactory.createCosmosAsyncClient(cosmosClientBuilder);
personInfo = new CosmosEntityInformation<>(Person.class);
itemInfo = new CosmosEntityInformation<>(BasicItem.class);
containerName = personInfo.getContainerName();
cosmosTemplate = createReactiveCosmosTemplate(cosmosConfig, TestConstants.DB_NAME);
}

collectionManager.ensureContainersCreatedAndEmpty(cosmosTemplate, Person.class, GenIdEntity.class, AuditableEntity.class);
collectionManager.ensureContainersCreatedAndEmpty(cosmosTemplate, Person.class, GenIdEntity.class, AuditableEntity.class, BasicItem.class);

insertedPerson = cosmosTemplate.insert(TEST_PERSON,
new PartitionKey(personInfo.getPartitionKeyFieldValue(TEST_PERSON))).block();
pointReadItem = cosmosTemplate.insert(BASIC_ITEM,
new PartitionKey(BASIC_ITEM.getId())).block();
}

private ReactiveCosmosTemplate createReactiveCosmosTemplate(CosmosConfig config, String dbName) throws ClassNotFoundException {
Expand Down Expand Up @@ -186,6 +199,28 @@ public void testInsertDuplicateId() {
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics()).isNotNull();
}

@Test
public void testFindByIdWithInvalidId() {
final Mono<BasicItem> readMono = cosmosTemplate.findById(BasicItem.class.getSimpleName(),
INVALID_ID, BasicItem.class);
StepVerifier.create(readMono)
.expectErrorMatches(ex -> ex instanceof CosmosAccessException)
.verify();
}

@Test
public void testFindByIdPointRead() {
final Mono<BasicItem> findById = cosmosTemplate.findById(BasicItem.class.getSimpleName(),
BASIC_ITEM.getId(),
BasicItem.class);
StepVerifier.create(findById)
.consumeNextWith(actual -> Assert.assertEquals(actual, BASIC_ITEM))
.verifyComplete();
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics()).isNotNull();
assertThat(responseDiagnosticsTestUtils.getCosmosResponseStatistics()).isNull();
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics().toString().contains("\"requestOperationType\":\"Read\"")).isTrue();
}

@Test
public void testFindByID() {
final Mono<Person> findById = cosmosTemplate.findById(Person.class.getSimpleName(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.spring.data.cosmos.domain;

import com.azure.spring.data.cosmos.core.mapping.Container;
import com.azure.spring.data.cosmos.core.mapping.PartitionKey;
import org.springframework.data.annotation.Id;

import java.util.Objects;

@Container()
public class BasicItem {

@Id
@PartitionKey
private String id;

public BasicItem(String id) {
this.id = id;
}

public String getId() {
return id;
}

public void setId(String id) {
this.id = id;
}

public BasicItem() {
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
final BasicItem item = (BasicItem) o;
return Objects.equals(id, item.id);
}

@Override
public int hashCode() {
return Objects.hash(id);
}

@Override
public String toString() {
return "BasicItem{"
+ "id='"
+ id
+ '\''
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.azure.cosmos.models.CosmosPatchOperations;
import com.azure.cosmos.models.PartitionKey;
import com.azure.spring.data.cosmos.IntegrationTestCollectionManager;
import com.azure.spring.data.cosmos.common.ResponseDiagnosticsTestUtils;
import com.azure.spring.data.cosmos.common.TestConstants;
import com.azure.spring.data.cosmos.common.TestUtils;
import com.azure.spring.data.cosmos.core.CosmosTemplate;
Expand Down Expand Up @@ -52,6 +53,9 @@ public class AddressRepositoryIT {
@Autowired
private CosmosTemplate template;

@Autowired
private ResponseDiagnosticsTestUtils responseDiagnosticsTestUtils;

@Rule
public ExpectedException expectedException = ExpectedException.none();

Expand All @@ -77,6 +81,15 @@ public void setUp() {
TEST_ADDRESS2_PARTITION1, TEST_ADDRESS4_PARTITION3));
}

@Test
public void testFindById() {
// test findById (ID id) cross partition
final Address result = repository.findById(TEST_ADDRESS1_PARTITION1.getPostalCode()).get();
assertThat(responseDiagnosticsTestUtils.getCosmosResponseStatistics()).isNotNull();
assertThat(responseDiagnosticsTestUtils.getCosmosDiagnostics().toString().contains("\"requestOperationType\":\"Query\"")).isTrue();
assertThat(result).isEqualTo(TEST_ADDRESS1_PARTITION1);
}

@Test
public void testFindAll() {
// findAll cross partition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
package com.azure.spring.data.cosmos.repository.integration;

import com.azure.spring.data.cosmos.IntegrationTestCollectionManager;
import com.azure.spring.data.cosmos.common.ResponseDiagnosticsTestUtils;
import com.azure.spring.data.cosmos.core.CosmosTemplate;
import com.azure.spring.data.cosmos.core.ResponseDiagnostics;
import com.azure.spring.data.cosmos.domain.Question;
import com.azure.spring.data.cosmos.repository.TestRepositoryConfig;
import com.azure.spring.data.cosmos.repository.repository.QuestionRepository;
Expand All @@ -21,6 +23,8 @@
import java.util.List;
import java.util.Optional;

import static org.assertj.core.api.Assertions.assertThat;

@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = TestRepositoryConfig.class)
public class QuestionRepositoryIT {
Expand All @@ -29,6 +33,8 @@ public class QuestionRepositoryIT {

private static final String QUESTION_URL = "http://xxx.html";

private static final String NULL_ID = "null-id";

private static final Question QUESTION = new Question(QUESTION_ID, QUESTION_URL);

@ClassRule
Expand All @@ -40,6 +46,9 @@ public class QuestionRepositoryIT {
@Autowired
private QuestionRepository repository;

@Autowired
private ResponseDiagnosticsTestUtils responseDiagnosticsTestUtils;

@Before
public void setUp() {
collectionManager.ensureContainersCreatedAndEmpty(template, Question.class);
Expand All @@ -49,14 +58,14 @@ public void setUp() {
@Test
public void testFindById() {
final Optional<Question> optional = this.repository.findById(QUESTION_ID);

assertThat(responseDiagnosticsTestUtils.getCosmosResponseStatistics()).isNull();
Assert.assertTrue(optional.isPresent());
Assert.assertEquals(QUESTION, optional.get());
}

@Test
public void testFindByIdNull() {
final Optional<Question> byId = this.repository.findById(QUESTION_URL);
final Optional<Question> byId = this.repository.findById(NULL_ID);
Assert.assertFalse(byId.isPresent());
}

Expand Down
Loading

0 comments on commit 8236c59

Please sign in to comment.