ContextConfig.java:
package geektime.tdd.di;
import java.util.*;
import static java.util.List.of;
/**
 * Registry of component bindings from which a {@link Context} can be obtained.
 *
 * <p>A type is bound either to a ready-made instance or to an implementation class that is
 * instantiated through a {@code ConstructorInjectionProvider}. Dependencies between the
 * registered components are verified every time {@link #getContext()} is called.
 */
public class ContextConfig {
    /** Provider registered for each bound type; a later binding for the same type replaces it. */
    private final Map<Class<?>, ComponentProvider<?>> providers = new HashMap<>();

    /**
     * Binds {@code type} to an existing instance, which is handed out on every lookup.
     *
     * @param type     the type under which the instance is registered
     * @param instance the object returned for that type
     * @param <Type>   the bound type
     */
    public <Type> void bind(Class<Type> type, Type instance) {
        providers.put(type, new InstanceProvider<>(instance));
    }

    /**
     * Binds {@code type} to an implementation class created via a
     * {@code ConstructorInjectionProvider}.
     *
     * @param type             the type under which the component is registered
     * @param implementation   the class to instantiate for that type
     * @param <Type>           the bound type
     * @param <Implementation> the implementing class
     */
    public <Type, Implementation extends Type>
    void bind(Class<Type> type, Class<Implementation> implementation) {
        ComponentProvider<Implementation> provider = new ConstructorInjectionProvider<>(implementation);
        providers.put(type, provider);
    }

    /**
     * Validates the dependencies of every registered component and returns a lookup context.
     *
     * <p>The returned context reads from this configuration's provider map and passes itself
     * to the provider so that nested dependencies can be resolved.
     *
     * @return a context backed by the registered providers
     * @throws DependencyNotFoundException      if a component depends on a type that is not bound
     * @throws CyclicDependenciesFoundException if the dependency walk re-enters a type that is
     *                                          still being visited
     */
    public Context getContext() {
        for (Class<?> component : providers.keySet()) {
            checkDependencies(component, new Stack<>());
        }
        return new Context() {
            @Override
            public <Type> Optional<Type> get(Class<Type> type) {
                ComponentProvider<?> provider = providers.get(type);
                if (provider == null) {
                    return Optional.empty();
                }
                return Optional.ofNullable((Type) provider.get(this));
            }
        };
    }

    /**
     * Depth-first walk over the dependencies of {@code component}.
     *
     * @param component the component whose dependencies are inspected
     * @param visiting  the dependencies on the current walk path; pushed before descending and
     *                  popped once the descent returns normally
     */
    private void checkDependencies(Class<?> component, Stack<Class<?>> visiting) {
        List<Class<?>> required = providers.get(component).getDependencies();
        for (Class<?> dependency : required) {
            if (!providers.containsKey(dependency)) {
                throw new DependencyNotFoundException(component, dependency);
            }
            if (visiting.contains(dependency)) {
                throw new CyclicDependenciesFoundException(visiting);
            }
            visiting.push(dependency);
            checkDependencies(dependency, visiting);
            visiting.pop();
        }
    }

    /**
     * Strategy for producing a component and reporting the types it depends on.
     *
     * @param <T> the type of component produced
     */
    interface ComponentProvider<T> {
        /** Produces the component, resolving anything it needs from {@code context}. */
        T get(Context context);

        /** Lists the types that must be bound for {@link #get(Context)} to succeed. */
        List<Class<?>> getDependencies();
    }

    /**
     * Provider that always returns one fixed instance and declares no dependencies.
     *
     * @param <T> the type of the held instance
     */
    private static final class InstanceProvider<T> implements ComponentProvider<T> {
        private final T instance;

        InstanceProvider(T instance) {
            this.instance = instance;
        }

        @Override
        public T get(Context context) {
            return instance;
        }

        @Override
        public List<Class<?>> getDependencies() {
            return List.of();
        }
    }
}
ConstructorInjectionProvider.java:
package geektime.tdd.di;
import jakarta.inject.Inject;
import java.lang.reflect.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import static java.util.Arrays.stream;
import static java.util.stream.Stream.concat;
/**
 * {@link ContextConfig.ComponentProvider} that creates a component reflectively.
 *
 * <p>The component is instantiated through its single public {@code @Inject} constructor, or
 * through its no-argument constructor when no public constructor is annotated. Afterwards every
 * {@code @Inject} field found in the class hierarchy is set and every {@code @Inject} method is
 * invoked, with all values looked up by type in the supplied {@link Context}.
 *
 * @param <T> the component type produced by this provider
 */
class ConstructorInjectionProvider<T> implements ContextConfig.ComponentProvider<T> {
    private final Constructor<T> injectConstructor;
    private final List<Field> injectFields;
    private final List<Method> injectMethods;

    /**
     * Inspects {@code component} and records its injection points.
     *
     * @param component the concrete class to instantiate
     * @throws IllegalComponentException if the class is abstract, has more than one public
     *                                   {@code @Inject} constructor, has neither an annotated
     *                                   public constructor nor a no-argument constructor,
     *                                   declares a final {@code @Inject} field, or declares an
     *                                   {@code @Inject} method with its own type parameters
     */
    public ConstructorInjectionProvider(Class<T> component) {
        if (Modifier.isAbstract(component.getModifiers())) {
            throw new IllegalComponentException();
        }
        this.injectConstructor = getInjectConstructor(component);
        this.injectFields = getInjectFields(component);
        this.injectMethods = getInjectMethods(component);
        if (injectFields.stream().anyMatch(f -> Modifier.isFinal(f.getModifiers()))) {
            throw new IllegalComponentException();
        }
        if (injectMethods.stream().anyMatch(m -> m.getTypeParameters().length != 0)) {
            throw new IllegalComponentException();
        }
    }

    /**
     * Creates a new instance: constructor first, then fields, then methods.
     *
     * @param context the source of every dependency value
     * @return the fully injected instance
     * @throws java.util.NoSuchElementException if the context has no value for a required type
     * @throws RuntimeException                 wrapping any reflective failure raised while
     *                                          constructing or injecting
     */
    @Override
    public T get(Context context) {
        try {
            T instance = injectConstructor.newInstance(resolveAll(context, injectConstructor.getParameterTypes()));
            for (Field field : injectFields) {
                field.set(instance, resolve(context, field.getType()));
            }
            for (Method method : injectMethods) {
                method.invoke(instance, resolveAll(context, method.getParameterTypes()));
            }
            return instance;
        } catch (InvocationTargetException | InstantiationException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Lists the types this provider looks up, in the order constructor parameters, field types,
     * method parameters.
     *
     * @return the required types; a type appears once per injection point that needs it
     */
    @Override
    public List<Class<?>> getDependencies() {
        Stream<Class<?>> constructorTypes = stream(injectConstructor.getParameterTypes());
        Stream<Class<?>> fieldTypes = injectFields.stream().map(Field::getType);
        Stream<Class<?>> methodTypes = injectMethods.stream().flatMap(m -> stream(m.getParameterTypes()));
        return concat(concat(constructorTypes, fieldTypes), methodTypes).toList();
    }

    /** Looks up one value per entry of {@code types}, preserving order. */
    private static Object[] resolveAll(Context context, Class<?>[] types) {
        return stream(types).map(type -> resolve(context, type)).toArray(Object[]::new);
    }

    /** Looks up a single value, failing with a message that names the missing type. */
    private static Object resolve(Context context, Class<?> type) {
        return context.get(type).orElseThrow(() ->
                new java.util.NoSuchElementException("No component bound for " + type.getName()));
    }

    /**
     * Collects the {@code @Inject} methods to invoke, superclass methods first.
     *
     * <p>A method is skipped when any class below it in the hierarchy declares a method with the
     * same name and parameter types, whether or not that declaration is annotated: an annotated
     * override is already collected from the subclass, and an un-annotated override opts out of
     * injection. Matching is by name and parameter types only. Compiler-generated bridge methods
     * are never collected, because the method they delegate to is collected on its own; they
     * still count as declarations when looking for overrides.
     */
    private static <T> List<Method> getInjectMethods(Class<T> component) {
        List<Method> injectMethods = new ArrayList<>();
        // Every method declared by the classes visited so far, i.e. by subclasses of `current`.
        List<Method> declaredBySubclasses = new ArrayList<>();
        for (Class<?> current = component;
             current != null && current != Object.class;
             current = current.getSuperclass()) {
            Method[] declared = current.getDeclaredMethods();
            for (Method method : declared) {
                if (method.isAnnotationPresent(Inject.class)
                        && !method.isBridge()
                        && declaredBySubclasses.stream().noneMatch(o -> hasSameSignature(o, method))) {
                    injectMethods.add(method);
                }
            }
            declaredBySubclasses.addAll(Arrays.asList(declared));
        }
        // Collected from the leaf upwards; reverse so that superclass methods run first.
        Collections.reverse(injectMethods);
        injectMethods.forEach(ConstructorInjectionProvider::makeAccessible);
        return injectMethods;
    }

    /** Collects every {@code @Inject} field declared by {@code component} and its superclasses. */
    private static <T> List<Field> getInjectFields(Class<T> component) {
        List<Field> injectFields = new ArrayList<>();
        for (Class<?> current = component;
             current != null && current != Object.class;
             current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                if (field.isAnnotationPresent(Inject.class)) {
                    injectFields.add(field);
                }
            }
        }
        injectFields.forEach(ConstructorInjectionProvider::makeAccessible);
        return injectFields;
    }

    /**
     * Picks the constructor used for instantiation: the single public constructor annotated
     * with {@code @Inject}, otherwise the declared no-argument constructor.
     *
     * @throws IllegalComponentException if several public constructors are annotated, or none is
     *                                   annotated and no no-argument constructor is declared
     */
    private static <Type> Constructor<Type> getInjectConstructor(Class<Type> implementation) {
        List<Constructor<?>> injectConstructors = stream(implementation.getConstructors())
                .filter(c -> c.isAnnotationPresent(Inject.class)).collect(Collectors.toList());
        if (injectConstructors.size() > 1) {
            throw new IllegalComponentException();
        }
        if (!injectConstructors.isEmpty()) {
            // Safe: the constructor was obtained from Class<Type>, so it constructs a Type.
            @SuppressWarnings("unchecked")
            Constructor<Type> annotated = (Constructor<Type>) injectConstructors.get(0);
            return annotated;
        }
        try {
            return implementation.getDeclaredConstructor();
        } catch (NoSuchMethodException e) {
            throw new IllegalComponentException();
        }
    }

    /** Tells whether two methods share a name and an identical parameter type list. */
    private static boolean hasSameSignature(Method first, Method second) {
        return first.getName().equals(second.getName())
                && Arrays.equals(first.getParameterTypes(), second.getParameterTypes());
    }

    /**
     * Best-effort attempt to lift Java access checks so that non-public injection points can be
     * used. A refusal is ignored here; the regular access check then applies at injection time.
     */
    private static void makeAccessible(AccessibleObject member) {
        try {
            member.trySetAccessible();
        } catch (SecurityException ignored) {
            // Denied by a security manager: leave the member as it is.
        }
    }
}
Context.java:
package geektime.tdd.di;
import java.util.Optional;
/**
 * Lookup of components by the type they are registered under.
 */
public interface Context {
    /**
     * Looks up the component associated with {@code type}.
     *
     * @param type the class object identifying the requested component
     * @param <T>  the requested component type
     * @return an {@link Optional} holding the component, or an empty one when the
     *         implementation has nothing to return for {@code type}
     */
    <T> Optional<T> get(Class<T> type);
}