InjectProvider.java:
package geektime.tdd.di;
import jakarta.inject.Inject;
import java.lang.reflect.*;
import java.util.*;
import java.util.function.BiFunction;
import java.util.stream.Stream;
import static java.util.Arrays.stream;
import static java.util.stream.Stream.concat;
/**
 * Reflection-based {@link ContextConfig.ComponentProvider} that instantiates a concrete class and
 * performs constructor, field and method injection on it.
 *
 * <p>Injection points are members annotated with {@link Inject}. The component class and all of its
 * superclasses (up to, but excluding, {@link Object}) are scanned. Inject methods are invoked
 * superclass-first. An inject method is skipped when a subclass overrides it, whether or not the
 * override carries {@code @Inject` (an annotated override is injected in its place). "Override" is
 * decided by name and parameter types only.
 *
 * <p>NOTE: members are not made accessible via {@code setAccessible}; members that are not
 * reflectively accessible fail at {@link #get} with a wrapped {@link IllegalAccessException}.
 *
 * @param <T> the concrete component type produced by this provider
 */
class InjectionProvider<T> implements ContextConfig.ComponentProvider<T> {
    private final Constructor<T> injectConstructor;
    private final List<Field> injectFields;
    private final List<Method> injectMethods;

    /**
     * Resolves all injection points of {@code component}.
     *
     * @param component the class to instantiate
     * @throws IllegalComponentException if the class is abstract (including interfaces), has more than
     *         one public {@code @Inject} constructor, has no public {@code @Inject} constructor and no
     *         no-arg constructor, declares a final {@code @Inject} field, or declares a generic
     *         {@code @Inject} method
     */
    public InjectionProvider(Class<T> component) {
        if (Modifier.isAbstract(component.getModifiers())) throw new IllegalComponentException();
        this.injectConstructor = getInjectConstructor(component);
        this.injectFields = getInjectFields(component);
        this.injectMethods = getInjectMethods(component);
        // Final fields cannot be assigned after construction.
        if (injectFields.stream().anyMatch(f -> Modifier.isFinal(f.getModifiers())))
            throw new IllegalComponentException();
        // Type parameters of a generic method cannot be resolved from the context.
        if (injectMethods.stream().anyMatch(m -> m.getTypeParameters().length != 0))
            throw new IllegalComponentException();
    }

    /**
     * Creates a new instance. It calls the inject constructor, then assigns every inject field, then
     * invokes every inject method. Each dependency is resolved from {@code context}.
     *
     * @throws java.util.NoSuchElementException if {@code context} cannot supply a dependency
     * @throws RuntimeException wrapping any reflective failure (constructor/method exceptions,
     *         instantiation or access errors)
     */
    @Override
    public T get(Context context) {
        try {
            T instance = injectConstructor.newInstance(toDependencies(context, injectConstructor));
            for (Field field : injectFields) field.set(instance, toDependency(context, field));
            for (Method method : injectMethods) method.invoke(instance, toDependencies(context, method));
            return instance;
        } catch (InvocationTargetException | InstantiationException | IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Lists every dependency of this component as an unqualified {@link Context.Ref}. The list covers
     * the constructor parameters, then the inject fields, then the inject method parameters.
     */
    @Override
    public List<Context.Ref> getDependencies() {
        return concat(concat(stream(injectConstructor.getParameters()).map(Parameter::getParameterizedType),
                        injectFields.stream().map(Field::getGenericType)),
                injectMethods.stream().flatMap(m -> stream(m.getParameters()).map(Parameter::getParameterizedType)))
                .map(Context.Ref::of).toList();
    }

    /**
     * Collects the {@code @Inject} methods of {@code component} and its superclasses, superclass-first.
     * A method is dropped when an already-collected subclass inject method overrides it, or when any
     * class between {@code component} and the declaring class overrides it without {@code @Inject}.
     */
    private static <T> List<Method> getInjectMethods(Class<T> component) {
        List<Method> injectMethods = traverse(component, (methods, current) -> injectable(current.getDeclaredMethods())
                .filter(m -> isNotOverriddenByInjectMethod(methods, m))
                .filter(m -> isNotOverriddenByNonInjectMethod(component, current, m)).toList());
        // traverse() walks subclass -> superclass; reverse so superclass methods are invoked first.
        Collections.reverse(injectMethods);
        return injectMethods;
    }

    /** Collects the {@code @Inject} fields declared by {@code component} and its superclasses. */
    private static <T> List<Field> getInjectFields(Class<T> component) {
        return traverse(component, (fields, current) -> injectable(current.getDeclaredFields()).toList());
    }

    /**
     * Returns the single public {@code @Inject} constructor. If there is none, it falls back to the
     * declared no-arg constructor.
     */
    @SuppressWarnings("unchecked")
    private static <Type> Constructor<Type> getInjectConstructor(Class<Type> implementation) {
        List<Constructor<?>> injectConstructors = injectable(implementation.getConstructors()).toList();
        if (injectConstructors.size() > 1) throw new IllegalComponentException();
        return (Constructor<Type>) injectConstructors.stream().findFirst().orElseGet(() -> defaultConstructor(implementation));
    }

    /** Returns the declared no-arg constructor, or throws {@link IllegalComponentException} if absent. */
    private static <Type> Constructor<Type> defaultConstructor(Class<Type> implementation) {
        try {
            return implementation.getDeclaredConstructor();
        } catch (NoSuchMethodException e) {
            throw new IllegalComponentException();
        }
    }

    /**
     * Walks from {@code component} up to (excluding) {@link Object}. At each level it appends whatever
     * {@code finder} returns. The finder receives the members collected so far and the current class.
     */
    private static <T> List<T> traverse(Class<?> component, BiFunction<List<T>, Class<?>, List<T>> finder) {
        List<T> members = new ArrayList<>();
        Class<?> current = component;
        while (current != Object.class) {
            members.addAll(finder.apply(members, current));
            current = current.getSuperclass();
        }
        return members;
    }

    /** Filters {@code elements} down to those annotated with {@link Inject}. */
    private static <T extends AnnotatedElement> Stream<T> injectable(T[] elements) {
        return stream(elements).filter(e -> e.isAnnotationPresent(Inject.class));
    }

    /** True when {@code o} has the same name and parameter types as {@code m}. */
    private static boolean isOverride(Method m, Method o) {
        return o.getName().equals(m.getName()) && Arrays.equals(o.getParameterTypes(), m.getParameterTypes());
    }

    /**
     * True when no class from {@code component} down to (excluding) {@code declaring} overrides
     * {@code m} with a method that lacks {@code @Inject}. Checking every intermediate class, not just
     * the component itself, matters: dispatch would otherwise invoke a non-inject override.
     */
    private static boolean isNotOverriddenByNonInjectMethod(Class<?> component, Class<?> declaring, Method m) {
        for (Class<?> c = component; c != declaring; c = c.getSuperclass()) {
            boolean overridden = stream(c.getDeclaredMethods())
                    .filter(o -> !o.isAnnotationPresent(Inject.class))
                    .anyMatch(o -> isOverride(m, o));
            if (overridden) return false;
        }
        return true;
    }

    /** True when none of the already-collected (subclass) inject methods overrides {@code m}. */
    private static boolean isNotOverriddenByInjectMethod(List<Method> injectMethods, Method m) {
        return injectMethods.stream().noneMatch(o -> isOverride(m, o));
    }

    /** Resolves one argument per parameter of {@code executable} from {@code context}. */
    private static Object[] toDependencies(Context context, Executable executable) {
        return stream(executable.getParameters()).map(p -> toDependency(context, p.getParameterizedType())).toArray(Object[]::new);
    }

    /** Resolves the value to assign to {@code field}. */
    private static Object toDependency(Context context, Field field) {
        return toDependency(context, field.getGenericType());
    }

    /** Resolves {@code type} from {@code context}; throws NoSuchElementException if it is absent. */
    private static Object toDependency(Context context, Type type) {
        return context.get(Context.Ref.of(type)).get();
    }
}
Context.java:
package geektime.tdd.di;
import java.lang.annotation.Annotation;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.Objects;
import java.util.Optional;
public interface Context {
    /**
     * Looks up the component described by {@code ref}.
     *
     * @param ref the requested component (optionally a container such as a provider, optionally qualified)
     * @return the component, or {@link Optional#empty()} if none is available for {@code ref}
     */
    <ComponentType> Optional<ComponentType> get(Ref<ComponentType> ref);

    /**
     * Describes a component request with three parts:
     * <ul>
     *   <li>the component class;</li>
     *   <li>for a parameterized request such as {@code Provider<X>}, the raw container type, with the
     *       component taken from the first type argument;</li>
     *   <li>an optional qualifier annotation.</li>
     * </ul>
     *
     * <p>To capture a generic type, subclass anonymously: {@code new Ref<Provider<X>>() {}}. Two refs
     * are equal when container, component and qualifier are all equal. How the ref was built does not
     * matter.
     *
     * <p>NOTE(review): a parameterized type whose first type argument is not a plain class (e.g.
     * {@code Provider<List<X>>}) is not supported and fails with a ClassCastException in {@code init}.
     */
    class Ref<ComponentType> {
        /** Unqualified ref to {@code component}. */
        public static <ComponentType> Ref<ComponentType> of(Class<ComponentType> component) {
            return new Ref<>(component, null);
        }

        /** Ref to {@code component} qualified by {@code qualifier}. */
        public static <ComponentType> Ref<ComponentType> of(Class<ComponentType> component, Annotation qualifier) {
            return new Ref<>(component, qualifier);
        }

        /** Unqualified ref for an arbitrary (possibly parameterized) reflective type. */
        public static Ref of(Type type) {
            return new Ref<>(type, null);
        }

        private Type container;
        private Class<ComponentType> component;
        private Annotation qualifier;

        Ref(Type type, Annotation qualifier) {
            init(type);
            this.qualifier = qualifier;
        }

        /** Type-capturing constructor: reads the type argument of the anonymous subclass. */
        protected Ref() {
            Type type = ((ParameterizedType) getClass().getGenericSuperclass()).getActualTypeArguments()[0];
            init(type);
        }

        @SuppressWarnings("unchecked")
        private void init(Type type) {
            if (type instanceof ParameterizedType container) {
                this.container = container.getRawType();
                this.component = (Class<ComponentType>) container.getActualTypeArguments()[0];
            } else
                this.component = (Class<ComponentType>) type;
        }

        /** Raw container type, or {@code null} for a plain (non-parameterized) request. */
        public Type getContainer() {
            return container;
        }

        /** The requested component class (the first type argument when this is a container request). */
        public Class<?> getComponent() {
            return component;
        }

        /** True when this ref was built from a parameterized type. */
        public boolean isContainer() {
            return container != null;
        }

        /** The qualifier, or {@code null} when unqualified. */
        public Annotation getQualifier() {
            return qualifier;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            // instanceof (not getClass()) so that anonymous type-capturing subclasses compare equal to an
            // equivalent ref built through of(...).
            if (!(o instanceof Ref<?> ref)) return false;
            return Objects.equals(container, ref.container)
                    && Objects.equals(component, ref.component)
                    && Objects.equals(qualifier, ref.qualifier);
        }

        @Override
        public int hashCode() {
            return Objects.hash(container, component, qualifier);
        }
    }
}
ContextConfig.java:
package geektime.tdd.di;
import jakarta.inject.Provider;
import java.lang.annotation.Annotation;
import java.util.*;
/**
 * Collects component bindings and produces a {@link Context} that resolves them.
 *
 * <p>Unqualified bindings are keyed by type. Qualified bindings are keyed by (type, qualifier).
 * {@link #getContext()} validates every binding's dependencies before returning.
 */
public class ContextConfig {
    private final Map<Class<?>, ComponentProvider<?>> providers = new HashMap<>();
    private final Map<Component, ComponentProvider<?>> components = new HashMap<>();

    /** Binds {@code type} (unqualified) to a fixed {@code instance}. */
    public <Type> void bind(Class<Type> type, Type instance) {
        providers.put(type, (ComponentProvider<Type>) context -> instance);
    }

    /** Binds {@code type} under each of {@code qualifiers} to a fixed {@code instance}. */
    public <Type> void bind(Class<Type> type, Type instance, Annotation... qualifiers) {
        for (Annotation qualifier : qualifiers)
            components.put(new Component(type, qualifier), context -> instance);
    }

    /** Key for a qualified binding; the {@code qualifiers} component holds a single qualifier. */
    record Component(Class<?> type, Annotation qualifiers) {
    }

    /** Binds {@code type} (unqualified) to instances of {@code implementation} built by injection. */
    public <Type, Implementation extends Type>
    void bind(Class<Type> type, Class<Implementation> implementation) {
        providers.put(type, new InjectionProvider<>(implementation));
    }

    /** Binds {@code type} under each of {@code qualifiers} to instances of {@code implementation}. */
    public <Type, Implementation extends Type>
    void bind(Class<Type> type, Class<Implementation> implementation, Annotation... qualifiers) {
        for (Annotation qualifier : qualifiers)
            components.put(new Component(type, qualifier), new InjectionProvider<>(implementation));
    }

    /**
     * Validates all bindings and returns a context over them.
     *
     * @throws DependencyNotFoundException if any binding (qualified or not) depends on an unbound type
     * @throws CyclicDependenciesFoundException if the unqualified bindings contain a direct dependency cycle
     */
    public Context getContext() {
        providers.keySet().forEach(component -> checkDependencies(component, new Stack<>()));
        // Qualified bindings were previously never validated, so a missing dependency surfaced only at
        // resolution time (as NoSuchElementException from Optional.get()).
        components.forEach((component, provider) -> checkDependencies(component.type(), provider, new Stack<>()));
        return new Context() {
            @Override
            @SuppressWarnings("unchecked")
            public <ComponentType> Optional<ComponentType> get(Ref<ComponentType> ref) {
                if (ref.getQualifier() != null) {
                    return Optional.ofNullable(components.get(new Component(ref.getComponent(), ref.getQualifier()))).map(provider -> (ComponentType) provider.get(this));
                }
                if (ref.isContainer()) {
                    // Provider<X> is the only supported container; it defers creation until Provider.get().
                    if (ref.getContainer() != Provider.class) return Optional.empty();
                    return (Optional<ComponentType>) Optional.ofNullable(providers.get(ref.getComponent()))
                            .map(provider -> (Provider<Object>) () -> provider.get(this));
                }
                return Optional.ofNullable(providers.get(ref.getComponent())).map(provider -> (ComponentType) provider.get(this));
            }
        };
    }

    /** Checks the unqualified binding for {@code component}; {@code visiting} is the current path. */
    private void checkDependencies(Class<?> component, Stack<Class<?>> visiting) {
        checkDependencies(component, providers.get(component), visiting);
    }

    /**
     * Checks that every dependency of {@code provider} is bound, then recurses into each non-container
     * dependency to detect cycles. Container dependencies (e.g. {@code Provider<X>}) must be bound but
     * are not followed, because they do not force eager construction.
     */
    private void checkDependencies(Class<?> component, ComponentProvider<?> provider, Stack<Class<?>> visiting) {
        for (Context.Ref dependency : provider.getDependencies()) {
            if (!providers.containsKey(dependency.getComponent()))
                throw new DependencyNotFoundException(component, dependency.getComponent());
            if (!dependency.isContainer()) {
                if (visiting.contains(dependency.getComponent())) throw new CyclicDependenciesFoundException(visiting);
                visiting.push(dependency.getComponent());
                checkDependencies(dependency.getComponent(), visiting);
                visiting.pop();
            }
        }
    }

    /** Produces component instances on demand and reports what they depend on. */
    interface ComponentProvider<T> {
        T get(Context context);

        /** Dependencies of the produced component; none by default. */
        default List<Context.Ref> getDependencies() {
            return List.of();
        }
    }
}