Skip to content

Commit

Permalink
Add test for fetchAndApplyDynamicConf
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Mar 13, 2024
1 parent afd30ee commit 8773b65
Showing 1 changed file with 83 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,45 @@
package org.apache.uniffle.shuffle.manager;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.SparkConf;
import org.apache.spark.shuffle.RssSparkConfig;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;

import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;

import static org.apache.uniffle.common.rpc.StatusCode.INVALID_REQUEST;
import static org.apache.uniffle.common.rpc.StatusCode.SUCCESS;
import static org.apache.uniffle.shuffle.manager.RssShuffleManagerBase.getTaskAttemptIdForBlockId;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrowsExactly;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class RssShuffleManagerBaseTest {

Expand Down Expand Up @@ -628,4 +647,67 @@ public void testGetTaskAttemptIdWithSpeculation() {
// check that a lower mapIndex works as expected
assertEquals(bits("11111111|00"), getTaskAttemptIdForBlockId(255, 0, 4, false, 10));
}

@Test
void testFetchAndApplyDynamicConf() {
ClientType clientType = ClientType.GRPC;
String coordinators = "host1,host2,host3";
int timeout = RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get() / 10;

SparkConf conf = new SparkConf();
conf.set(RssSparkConfig.RSS_CLIENT_TYPE, clientType.toString());
conf.set(RssSparkConfig.RSS_COORDINATOR_QUORUM, coordinators);
conf.set(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS, timeout);

CoordinatorClientFactory mockFactoryInstance = mock(CoordinatorClientFactory.class);
CoordinatorClient mockClient1 = mock(CoordinatorClient.class);
CoordinatorClient mockClient2 = mock(CoordinatorClient.class);
CoordinatorClient mockClient3 = mock(CoordinatorClient.class);

Map<String, String> clientConf1 = ImmutableMap.of("rss.config.from", "client1");
Map<String, String> clientConf2 = ImmutableMap.of("rss.config.from", "client2");
Map<String, String> clientConf3 = ImmutableMap.of("rss.config.from", "client3");

when(mockClient1.fetchClientConf(any(RssFetchClientConfRequest.class)))
.thenReturn(new RssFetchClientConfResponse(INVALID_REQUEST, "error", clientConf1));
when(mockClient2.fetchClientConf(any(RssFetchClientConfRequest.class)))
.thenReturn(new RssFetchClientConfResponse(SUCCESS, "success", clientConf2));
when(mockClient3.fetchClientConf(any(RssFetchClientConfRequest.class)))
.thenReturn(new RssFetchClientConfResponse(SUCCESS, "success", clientConf3));

List<CoordinatorClient> mockClients = Arrays.asList(mockClient1, mockClient2, mockClient3);
when(mockFactoryInstance.createCoordinatorClient(clientType, coordinators))
.thenReturn(mockClients);

assertFalse(conf.contains("rss.config.from"));
assertFalse(conf.contains("spark.rss.config.from"));

try (MockedStatic<CoordinatorClientFactory> mockFactoryStatic =
mockStatic(CoordinatorClientFactory.class)) {
mockFactoryStatic.when(CoordinatorClientFactory::getInstance).thenReturn(mockFactoryInstance);
RssShuffleManagerBase.fetchAndApplyDynamicConf(conf);
}

assertFalse(conf.contains("rss.config.from"));
assertTrue(conf.contains("spark.rss.config.from"));
assertEquals("client2", conf.get("spark.rss.config.from"));

ArgumentCaptor<RssFetchClientConfRequest> request1 =
ArgumentCaptor.forClass(RssFetchClientConfRequest.class);
ArgumentCaptor<RssFetchClientConfRequest> request2 =
ArgumentCaptor.forClass(RssFetchClientConfRequest.class);
ArgumentCaptor<RssFetchClientConfRequest> request3 =
ArgumentCaptor.forClass(RssFetchClientConfRequest.class);

verify(mockClient1, times(1)).fetchClientConf(request1.capture());
verify(mockClient2, times(1)).fetchClientConf(request2.capture());
verify(mockClient3, never()).fetchClientConf(request3.capture());

assertEquals(timeout, request1.getValue().getTimeoutMs());
assertEquals(timeout, request2.getValue().getTimeoutMs());

verify(mockClient1, times(1)).close();
verify(mockClient2, times(1)).close();
verify(mockClient3, times(1)).close();
}
}

0 comments on commit 8773b65

Please sign in to comment.