/*
* Copyright 2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.yarn.boot.actuate.endpoint.mvc;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.NodeId;
import org.apache.hadoop.yarn.api.records.Priority;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.actuate.autoconfigure.EndpointWebMvcAutoConfiguration;
import org.springframework.boot.actuate.autoconfigure.ManagementServerPropertiesAutoConfiguration;
import org.springframework.boot.autoconfigure.hateoas.HypermediaAutoConfiguration;
import org.springframework.boot.test.SpringApplicationConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.core.task.TaskExecutor;
import org.springframework.http.MediaType;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.annotation.DirtiesContext.ClassMode;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.web.WebAppConfiguration;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.yarn.am.YarnAppmaster;
import org.springframework.yarn.am.allocate.ContainerAllocateData;
import org.springframework.yarn.am.allocate.ContainerAllocator;
import org.springframework.yarn.am.cluster.ContainerCluster;
import org.springframework.yarn.am.cluster.ContainerClusterStateMachineConfiguration;
import org.springframework.yarn.am.cluster.ManagedContainerClusterAppmaster;
import org.springframework.yarn.am.container.ContainerLauncher;
import org.springframework.yarn.am.grid.GridProjection;
import org.springframework.yarn.am.grid.GridProjectionFactory;
import org.springframework.yarn.am.grid.GridProjectionFactoryLocator;
import org.springframework.yarn.am.grid.support.AnyGridProjection;
import org.springframework.yarn.am.grid.support.DefaultGridProjectionFactory;
import org.springframework.yarn.am.grid.support.GridProjectionFactoryRegistry;
import org.springframework.yarn.am.grid.support.ProjectionData;
import org.springframework.yarn.am.grid.support.ProjectionDataRegistry;
import org.springframework.yarn.boot.MockUtils;
import org.springframework.yarn.boot.TestUtils;
import org.springframework.yarn.boot.actuate.endpoint.YarnContainerClusterEndpoint;
import org.springframework.yarn.boot.actuate.endpoint.mvc.YarnContainerClusterWithCustomProjectionsMvcEndpointTests.TestConfiguration;
import org.springframework.yarn.listener.ContainerAllocatorListener;
@RunWith(SpringJUnit4ClassRunner.class)
@SpringApplicationConfiguration(classes = { TestConfiguration.class })
@WebAppConfiguration
@DirtiesContext(classMode=ClassMode.AFTER_EACH_TEST_METHOD)
public class YarnContainerClusterWithCustomProjectionsMvcEndpointTests {
private final static String BASE = "/" + YarnContainerClusterEndpoint.ENDPOINT_ID;
@Autowired
private WebApplicationContext context;
private MockMvc mvc;
private TestYarnAppmaster appmaster;
@Before
public void setUp() {
mvc = MockMvcBuilders.webAppContextSetup(context).build();
appmaster = (TestYarnAppmaster) context.getBean(YarnAppmaster.class);
}
/**
* Test cluster create using any projection.
*/
@Test
public void testClusterCreateCustomProjection() throws Exception {
String content = "{\"clusterId\":\"cluster1\",\"clusterDef\":\"cluster1\",\"projection\":\"custom\",\"projectionData\":{\"any\":1}},\"extraProperties\":{\"key1\":\"value1\"}}";
mvc.
perform(post(BASE).content(content).contentType(MediaType.APPLICATION_JSON)).
andExpect(status().isCreated()).
andExpect(content().string(is(""))).
andExpect(header().string("Location", endsWith(BASE + "/cluster1")));
Map<String, ContainerCluster> clusters = TestUtils.readField("clusters", appmaster);
assertThat(clusters.size(), is(1));
assertThat(clusters.containsKey("cluster1"), is(true));
assertThat(clusters.get("cluster1").getGridProjection().getProjectionData().getAny(), is(1));
}
@Import({ ContainerClusterStateMachineConfiguration.class, EndpointWebMvcAutoConfiguration.class, ManagementServerPropertiesAutoConfiguration.class,
HypermediaAutoConfiguration.class })
@EnableWebMvc
@Configuration
public static class TestConfiguration {
@Bean
public YarnContainerClusterEndpoint endpoint() {
return new YarnContainerClusterEndpoint();
}
@Bean
public YarnContainerClusterMvcEndpoint mvcEndpoint() {
return new YarnContainerClusterMvcEndpoint(endpoint());
}
@Bean
public YarnAppmaster appMaster() {
TestContainerAllocator allocator = new TestContainerAllocator();
TestContainerLauncher launcher = new TestContainerLauncher();
TestYarnAppmaster appmaster = new TestYarnAppmaster();
appmaster.setAllocator(allocator);
appmaster.setLauncher(launcher);
return appmaster;
}
@Bean
public TaskExecutor taskExecutor() {
return new SyncTaskExecutor();
}
@Bean
public GridProjectionFactoryLocator gridProjectionFactoryLocator() {
GridProjectionFactoryRegistry registry = new GridProjectionFactoryRegistry();
registry.addGridProjectionFactory(new DefaultGridProjectionFactory());
registry.addGridProjectionFactory(new TestGridProjectionFactory());
return registry;
}
@Bean
public ProjectionDataRegistry projectionDataRegistry() {
Map<String, ProjectionData> defaults = new HashMap<String, ProjectionData>();
defaults.put("cluster1", new ProjectionData(null, null, null, "any", 0));
return new ProjectionDataRegistry(defaults);
}
}
protected static class TestGridProjectionFactory implements GridProjectionFactory {
public TestGridProjectionFactory() {
}
@Override
public GridProjection getGridProjection(ProjectionData projectionData) {
GridProjection projection = null;
if ("custom".equalsIgnoreCase(projectionData.getType())) {
AnyGridProjection p = new AnyGridProjection();
p.setPriority(projectionData.getPriority());
projection = p;
}
if (projection != null) {
projection.setProjectionData(projectionData);
}
return projection;
}
@Override
public Set<String> getRegisteredProjectionTypes() {
return new HashSet<String>(Arrays.asList("custom"));
}
}
protected static class TestYarnAppmaster extends ManagedContainerClusterAppmaster {
@Override
protected void onInit() throws Exception {
setConfiguration(new org.apache.hadoop.conf.Configuration());
super.onInit();
}
@Override
protected void doStart() {
}
@Override
protected void doStop() {
}
}
protected static class TestContainerAllocator implements ContainerAllocator {
ArrayList<ContainerAllocateData> containerAllocateData;
ArrayList<Container> releaseContainers;
public void resetTestData() {
containerAllocateData = null;
releaseContainers = null;
}
@Override
public void allocateContainers(int count) {
}
@Override
public void allocateContainers(ContainerAllocateData containerAllocateData) {
if (this.containerAllocateData == null) {
this.containerAllocateData = new ArrayList<ContainerAllocateData>();
}
this.containerAllocateData.add(containerAllocateData);
}
@Override
public void releaseContainers(List<Container> containers) {
if (this.releaseContainers == null) {
this.releaseContainers = new ArrayList<Container>();
}
this.releaseContainers.addAll(containers);
}
@Override
public void releaseContainer(ContainerId containerId) {
}
@Override
public void addListener(ContainerAllocatorListener listener) {
}
@Override
public void setProgress(float progress) {
}
}
protected static class TestContainerLauncher implements ContainerLauncher {
Container container;
public void resetTestData() {
container = null;
}
@Override
public void launchContainer(Container container, List<String> commands) {
this.container = container;
}
}
protected Container allocateContainer(Object appmaster, int id) throws Exception {
return allocateContainer(appmaster, id, null);
}
protected Container allocateContainer(Object appmaster, int id, String host) throws Exception {
ContainerId containerId = MockUtils.getMockContainerId(MockUtils.getMockApplicationAttemptId(0, 0), 0);
NodeId nodeId = MockUtils.getMockNodeId(host, 0);
Priority priority = MockUtils.getMockPriority(0);
Container container = MockUtils.getMockContainer(containerId, nodeId, null, priority);
TestUtils.callMethod("onContainerAllocated", appmaster, new Object[]{container}, new Class<?>[]{Container.class});
return container;
}
}