1 package com.github.smokestack.ejb;
2
3 import java.lang.annotation.Annotation;
4 import java.lang.reflect.Field;
5 import java.lang.reflect.InvocationTargetException;
6 import java.lang.reflect.Method;
7 import java.lang.reflect.Modifier;
8 import java.util.HashMap;
9 import java.util.Map;
10 import java.util.Vector;
11
12 import static org.hamcrest.MatcherAssert.assertThat;
13 import static org.hamcrest.Matchers.*;
14
15 import javax.annotation.PostConstruct;
16 import javax.annotation.PreDestroy;
17 import javax.ejb.EJB;
18 import javax.ejb.MessageDriven;
19 import javax.ejb.Stateful;
20 import javax.ejb.Stateless;
21 import javax.persistence.EntityManagerFactory;
22 import javax.persistence.Persistence;
23 import javax.persistence.PersistenceContext;
24 import javax.persistence.PersistenceUnit;
25
26 import com.github.smokestack.ejb.internal.ClassFinder;
27
28 import org.apache.commons.lang.ArrayUtils;
29 import org.apache.commons.lang.StringUtils;
30 import org.apache.commons.lang.builder.ReflectionToStringBuilder;
31 import org.apache.commons.lang.builder.ToStringStyle;
32
33 public class MockEJBContainer {
34
35 protected Map<Class<?>, Object> beans=new HashMap<Class<?>, Object>();
36
37 private ClassFinder classFinder;
38
39 @Override
40 public String toString(){
41 return ReflectionToStringBuilder.toString(this, ToStringStyle.MULTI_LINE_STYLE);
42 }
43
44
45
46
47
48
49
50
51
52
53
54 @SuppressWarnings("unchecked")
55 public <T> T getInstance(Class<T> clazz) throws MockEJBContainerException {
56 try {
57 if (beans.containsKey(clazz)){
58 return (T)beans.get(clazz);
59 }
60 T instance = clazz.newInstance();
61 if (clazz.isAnnotationPresent(Stateless.class) ||
62 clazz.isAnnotationPresent(Stateful.class) ||
63 clazz.isAnnotationPresent(MessageDriven.class)){
64
65
66 beans.put(clazz, instance);
67 }
68
69 injectMembers(clazz, instance);
70 if (clazz.isAnnotationPresent(Stateless.class) ||
71 clazz.isAnnotationPresent(Stateful.class) ||
72 clazz.isAnnotationPresent(MessageDriven.class)){
73 callMethodAnnotated(PostConstruct.class, clazz, instance);
74 }
75 return instance;
76 } catch (InstantiationException e) {
77 throw new MockEJBContainerException("for "+clazz.getName(), e);
78 } catch (IllegalAccessException e) {
79 throw new MockEJBContainerException("for "+clazz.getName(), e);
80 } catch (IllegalArgumentException e) {
81 throw new MockEJBContainerException("for "+clazz.getName(), e);
82 } catch (ClassNotFoundException e) {
83 throw new MockEJBContainerException("for "+clazz.getName(), e);
84 }
85 }
86
87
88
89
90
91
92
93
94 protected <T> void injectMembers(Class<T> clazz, T instance) throws IllegalArgumentException, IllegalAccessException, InstantiationException, ClassNotFoundException {
95 Field[] allFields=getAllFields(clazz);
96 for (Field f: allFields){
97 f.setAccessible(true);
98 if (f.isAnnotationPresent(EJB.class)){
99 Class<?> ejbClass=f.getType();
100 if (ejbClass.isInterface() || Modifier.isAbstract(ejbClass.getModifiers())){
101 ClassFinder finder=getClassFinder();
102 Vector<Class<?>> impls=finder.findSubclasses(ejbClass);
103 assertThat("expected single implementation", impls.size(), is(1));
104 f.set(instance, this.getInstance(impls.get(0)));
105 } else {
106
107 f.set(instance, this.getInstance(ejbClass));
108 }
109 } else if (f.isAnnotationPresent(PersistenceContext.class)){
110 String unitName=f.getName();
111 PersistenceContext pc=f.getAnnotation(PersistenceContext.class);
112 String pun=pc.unitName();
113 if (StringUtils.isNotEmpty(pun)){
114 unitName=pun;
115 }
116 EntityManagerFactory factory = Persistence.createEntityManagerFactory(unitName, System.getProperties());
117
118 f.set(instance, factory.createEntityManager());
119 } else if (f.isAnnotationPresent(PersistenceUnit.class)){
120 String unitName=f.getName();
121 PersistenceUnit pc=f.getAnnotation(PersistenceUnit.class);
122 String pun=pc.unitName();
123 if (StringUtils.isNotEmpty(pun)){
124 unitName=pun;
125 }
126
127 f.set(instance, Persistence.createEntityManagerFactory(unitName, System.getProperties()));
128 }
129 }
130 }
131
132 private ClassFinder getClassFinder() {
133 if (classFinder==null){
134 classFinder=new ClassFinder();
135 }
136 return classFinder;
137 }
138
139
140
141
142 protected <T> Field[] getAllFields(Class<T> clazz){
143 Field[] fields=clazz.getDeclaredFields();
144 Class<?> superClass=clazz.getSuperclass();
145 if (superClass!=null){
146 fields=(Field[]) ArrayUtils.addAll(fields, getAllFields(superClass));
147 }
148 return fields;
149 }
150
151 public void cleanInstances() {
152 for(Class<?> clazz:beans.keySet()){
153 Object instance=beans.get(clazz);
154 callMethodAnnotated(PreDestroy.class, clazz, instance);
155 }
156 beans.clear();
157 }
158
159 private void callMethodAnnotated(Class<? extends Annotation> annotatedWith, Class<?> clazz, Object instance) {
160 for(Method m:getAllMethods(clazz)){
161 if (m.isAnnotationPresent(annotatedWith)){
162 try {
163 m.setAccessible(true);
164 m.invoke(instance, new Object[]{});
165 } catch (Exception e) {
166 throw new MockEJBContainerException("call to @"+annotatedWith.toString()+" failed", e);
167 }
168 }
169 }
170
171 }
172
173
174
175
176 protected <T> Method[] getAllMethods(Class<T> clazz){
177 Method[] methods=clazz.getDeclaredMethods();
178 Class<?> superClass=clazz.getSuperclass();
179 if (superClass!=null){
180 methods=(Method[]) ArrayUtils.addAll(methods, getAllMethods(superClass));
181 }
182 return methods;
183 }
184 }