View Javadoc

1   package com.github.smokestack.ejb;
2   
3   import java.lang.annotation.Annotation;
4   import java.lang.reflect.Field;
5   import java.lang.reflect.InvocationTargetException;
6   import java.lang.reflect.Method;
7   import java.lang.reflect.Modifier;
8   import java.util.HashMap;
9   import java.util.Map;
10  import java.util.Vector;
11  
12  import static org.hamcrest.MatcherAssert.assertThat;
13  import static org.hamcrest.Matchers.*;
14  
15  import javax.annotation.PostConstruct;
16  import javax.annotation.PreDestroy;
17  import javax.ejb.EJB;
18  import javax.ejb.MessageDriven;
19  import javax.ejb.Stateful;
20  import javax.ejb.Stateless;
21  import javax.persistence.EntityManagerFactory;
22  import javax.persistence.Persistence;
23  import javax.persistence.PersistenceContext;
24  import javax.persistence.PersistenceUnit;
25  
26  import com.github.smokestack.ejb.internal.ClassFinder;
27  
28  import org.apache.commons.lang.ArrayUtils;
29  import org.apache.commons.lang.StringUtils;
30  import org.apache.commons.lang.builder.ReflectionToStringBuilder;
31  import org.apache.commons.lang.builder.ToStringStyle;
32  
33  public class MockEJBContainer {
34  	
35  	protected Map<Class<?>, Object> beans=new HashMap<Class<?>, Object>();
36  	
37  	private ClassFinder classFinder;
38  
39  	@Override
40  	public String toString(){
41  		return ReflectionToStringBuilder.toString(this, ToStringStyle.MULTI_LINE_STYLE);
42  	}
43  
44  	/**
45  	 * Setup and return an EJB instance
46  	 * @param <T>
47  	 * @param clazz
48  	 * @return
49  	 * @throws InstantiationException
50  	 * @throws IllegalAccessException
51  	 * @throws ClassNotFoundException 
52  	 * @throws IllegalArgumentException 
53  	 */
54  	@SuppressWarnings("unchecked")
55  	public <T> T getInstance(Class<T> clazz) throws MockEJBContainerException {
56  		try {
57  			if (beans.containsKey(clazz)){
58  				return (T)beans.get(clazz);
59  			}
60  			T instance = clazz.newInstance();
61  			if (clazz.isAnnotationPresent(Stateless.class) ||
62  					clazz.isAnnotationPresent(Stateful.class) ||
63  					clazz.isAnnotationPresent(MessageDriven.class)){
64  				// TODO: assume all need to be cached ...
65  				// TODO: how to setup binds maybe as http://openejb.apache.org/jndi-names.html
66  				beans.put(clazz, instance);			
67  			}
68  			// TODO: Inject others too?
69  			injectMembers(clazz, instance);
70  			if (clazz.isAnnotationPresent(Stateless.class) ||
71  					clazz.isAnnotationPresent(Stateful.class) ||
72  					clazz.isAnnotationPresent(MessageDriven.class)){
73  				callMethodAnnotated(PostConstruct.class, clazz, instance);
74  			}
75  			return instance;
76  		} catch (InstantiationException e) {
77  			throw new MockEJBContainerException("for "+clazz.getName(), e);
78  		} catch (IllegalAccessException e) {
79  			throw new MockEJBContainerException("for "+clazz.getName(), e);
80  		} catch (IllegalArgumentException e) {
81  			throw new MockEJBContainerException("for "+clazz.getName(), e);
82  		} catch (ClassNotFoundException e) {
83  			throw new MockEJBContainerException("for "+clazz.getName(), e);
84  		}
85  	}
86  
87  	/**
88  	 * Inject values for member variables
89  	 * @throws IllegalAccessException 
90  	 * @throws IllegalArgumentException 
91  	 * @throws InstantiationException 
92  	 * @throws ClassNotFoundException 
93  	 */
94  	protected <T> void injectMembers(Class<T> clazz, T instance) throws IllegalArgumentException, IllegalAccessException, InstantiationException, ClassNotFoundException {
95  		Field[] allFields=getAllFields(clazz);
96  		for (Field f: allFields){
97  			f.setAccessible(true);
98  			if (f.isAnnotationPresent(EJB.class)){
99  				Class<?> ejbClass=f.getType();
100 				if (ejbClass.isInterface() || Modifier.isAbstract(ejbClass.getModifiers())){
101 					ClassFinder finder=getClassFinder();
102 					Vector<Class<?>> impls=finder.findSubclasses(ejbClass);
103 					assertThat("expected single implementation", impls.size(), is(1));
104 					f.set(instance, this.getInstance(impls.get(0)));						
105 				} else {
106 					// TODO: do we assert this configuration?
107 					f.set(instance, this.getInstance(ejbClass));
108 				}
109 			} else if (f.isAnnotationPresent(PersistenceContext.class)){
110 				String unitName=f.getName();
111 				PersistenceContext pc=f.getAnnotation(PersistenceContext.class);
112 				String pun=pc.unitName();
113 				if (StringUtils.isNotEmpty(pun)){
114 					unitName=pun;
115 				}
116 		        EntityManagerFactory factory = Persistence.createEntityManagerFactory(unitName, System.getProperties());
117 		        // TODO: do we assert this configuration?
118                 f.set(instance, factory.createEntityManager());
119 			} else if (f.isAnnotationPresent(PersistenceUnit.class)){
120 				String unitName=f.getName();
121 				PersistenceUnit pc=f.getAnnotation(PersistenceUnit.class);
122 				String pun=pc.unitName();
123 				if (StringUtils.isNotEmpty(pun)){
124 					unitName=pun;
125 				}
126 		        // TODO: do we assert this configuration?
127                 f.set(instance, Persistence.createEntityManagerFactory(unitName, System.getProperties()));				
128 			}
129 		}
130 	}
131 
132 	private ClassFinder getClassFinder() {
133 		if (classFinder==null){
134 			classFinder=new ClassFinder();
135 		}
136 		return classFinder;
137 	}
138 
139 	/**
140 	 * Recursively get all fields
141 	 */
142 	protected <T> Field[] getAllFields(Class<T> clazz){
143 		Field[] fields=clazz.getDeclaredFields();
144 		Class<?> superClass=clazz.getSuperclass();
145 		if (superClass!=null){
146 			fields=(Field[]) ArrayUtils.addAll(fields, getAllFields(superClass));
147 		}
148 		return fields;
149 	}
150 
151 	public void cleanInstances() {
152 		for(Class<?> clazz:beans.keySet()){
153 			Object instance=beans.get(clazz);
154 			callMethodAnnotated(PreDestroy.class, clazz, instance);
155 		}
156 		beans.clear();
157 	}
158 
159 	private void callMethodAnnotated(Class<? extends Annotation> annotatedWith, Class<?> clazz, Object instance) {
160 		for(Method m:getAllMethods(clazz)){
161 			if (m.isAnnotationPresent(annotatedWith)){
162 				try {
163 					m.setAccessible(true);
164 					m.invoke(instance, new Object[]{});
165 				} catch (Exception e) {
166 					throw new MockEJBContainerException("call to @"+annotatedWith.toString()+" failed", e);
167 				}
168 			}
169 		}
170 		
171 	}
172 	
173 	/**
174 	 * Recursively get all fields
175 	 */
176 	protected <T> Method[] getAllMethods(Class<T> clazz){
177 		Method[] methods=clazz.getDeclaredMethods();
178 		Class<?> superClass=clazz.getSuperclass();
179 		if (superClass!=null){
180 			methods=(Method[]) ArrayUtils.addAll(methods, getAllMethods(superClass));
181 		}
182 		return methods;
183 	}
184 }