package com.casic.util; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.BeanFactoryPostProcessor; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.support.AbstractBeanDefinition; import org.springframework.beans.factory.support.BeanDefinitionBuilder; import org.springframework.beans.factory.support.BeanDefinitionReaderUtils; import org.springframework.beans.factory.support.BeanDefinitionRegistry; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import org.springframework.context.annotation.Bean; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import java.util.List; import java.util.Map; import java.util.concurrent.ForkJoinPool; import java.util.function.Function; /** * @Description: Spring应用上下文工具 * @Author: wangpeng * @Date: 2022/8/11 18:04 */ @Component public class SpringContextUtil implements BeanFactoryPostProcessor, ApplicationContextAware { /** * Spring应用上下文环境 */ private static ApplicationContext applicationContext; private static ConfigurableListableBeanFactory beanFactory; /** * 实现ApplicationContextAware接口的回调方法,设置上下文环境 */ @Override public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { SpringContextUtil.applicationContext = applicationContext; } @Override public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { SpringContextUtil.beanFactory = beanFactory; } public static ApplicationContext getApplicationContext() { return applicationContext; } // public static Object getBean(String beanId) throws BeansException { // return applicationContext.getBean(beanId); // } public static <T> T getBean(Class<T> requiredType) { return (T) applicationContext.getBean(requiredType); } @SuppressWarnings("unchecked") public static <T> T getBean(String name) throws BeansException { return (T) beanFactory.getBean(name); } public static boolean containsBean(String name) { return beanFactory.containsBean(name); } public static <T> T registerBean(String beanName, Class<T> clazz, Function<BeanDefinitionBuilder, AbstractBeanDefinition> function) { // 生成bean定义 BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.genericBeanDefinition(clazz); // 执行bean处理函数 AbstractBeanDefinition beanDefinition = function.apply(beanDefinitionBuilder); BeanDefinitionRegistry beanFactory = (BeanDefinitionRegistry) SpringContextUtil.beanFactory; // 判断是否通过beanName注册 if (StringUtils.isNotBlank(beanName) && !containsBean(beanName)) { beanFactory.registerBeanDefinition(beanName, beanDefinition); return getBean(beanName); } else { // 非命名bean注册 String name = BeanDefinitionReaderUtils.registerWithGeneratedName(beanDefinition, beanFactory); return getBean(name); } } public static <T> T registerBean(String beanName, Class<T> clazz, List<Object> args, Map<String, Object> property) { return registerBean(beanName, clazz, beanDefinitionBuilder -> { // 放入构造参数 if (!CollectionUtils.isEmpty(args)) { args.forEach(beanDefinitionBuilder::addConstructorArgValue); } // 放入属性 if (!CollectionUtils.isEmpty(property)) { property.forEach(beanDefinitionBuilder::addPropertyValue); } return beanDefinitionBuilder.getBeanDefinition(); }); } @Bean public ForkJoinPool forkJoinPool() { return new ForkJoinPool(100); } }