package testtherunner;

import org.junit.internal.runners.JUnit4ClassRunner;
import org.junit.internal.runners.InitializationError;
import org.junit.internal.runners.TestMethod;
import org.junit.internal.runners.MethodRoadie;
import org.junit.runner.notification.RunNotifier;
import org.junit.runner.Description;
import org.junit.Assert;

import java.lang.reflect.Method;
import java.lang.reflect.InvocationTargetException;

/**
 * JUnit runner that executes each test in a dedicated thread inside its own
 * {@link ThreadGroup} and reports a failure when threads of that group are
 * still alive once the test has finished.
 *
 * <p>Threads started by the test inherit the test thread's group unless they
 * explicitly pick another one, which is what makes the leftover check work.
 */
public class ThreadPoliceRunner extends JUnit4ClassRunner {
    /**
     * Creates a runner for the given test class.
     *
     * @param klass the test class to run
     * @throws InitializationError if the superclass constructor rejects the class
     */
    public ThreadPoliceRunner(Class<?> klass) throws InitializationError {
        super(klass);
    }

    /**
     * Runs a single test method through {@link MyMethodRoadie} so that the
     * thread-leak check is applied to it.
     *
     * <p>If the test instance cannot be created, the test is reported as
     * aborted and the method is not run.
     *
     * @param method   the test method to execute
     * @param notifier receives the abort / failure notifications
     */
    protected void invokeTestMethod(Method method, RunNotifier notifier) {
        Description description = methodDescription(method);
        Object test;
        try {
            test = createTest();
        } catch (InvocationTargetException e) {
            // Report the exception thrown by the test's constructor, not the reflection wrapper.
            notifier.testAborted(description, e.getCause());
            return;
        } catch (Exception e) {
            notifier.testAborted(description, e);
            return;
        }
        TestMethod testMethod = wrapMethod(method);
        new MyMethodRoadie(test, testMethod, notifier, description).run();
    }

    /**
     * {@link MethodRoadie} that runs the test in a fresh thread group and
     * fails the test if that group still contains live threads afterwards.
     */
    public static class MyMethodRoadie extends MethodRoadie {
        /** Initial size of the buffer used to snapshot the threads of a group. */
        private static final int INITIAL_SNAPSHOT_SIZE = 16;

        /**
         * @param test        the test instance the method is invoked on
         * @param method      the wrapped test method
         * @param notifier    receives test events
         * @param description description of the test being run
         */
        public MyMethodRoadie(Object test, TestMethod method, RunNotifier notifier, Description description) {
            super(test, method, notifier, description);
        }

        /**
         * Runs the inherited {@code runTest()} on a new thread in a new thread
         * group, waits for it to finish, then records a failure if any thread
         * of that group is still alive. Leftover threads are interrupted.
         *
         * @throws RuntimeException wrapping the {@link InterruptedException}
         *         if the calling thread is interrupted while waiting
         */
        public void runTest() {
            ThreadGroup testGroup = new ThreadGroup("Test execution");
            Runnable testRunner = new Runnable() {
                public void run() {
                    MyMethodRoadie.super.runTest();
                }
            };
            Thread testThread = new Thread(testGroup, testRunner);
            testThread.start();
            try {
                testThread.join();
            } catch (InterruptedException e) {
                // Keep the interrupt visible to callers and do not leave the
                // test (or anything it spawned) running unattended.
                Thread.currentThread().interrupt();
                testGroup.interrupt();
                throw new RuntimeException(e);
            }

            // The test thread itself has terminated at this point, so every
            // thread still counted here was left behind by the test.
            int active = countLiveThreads(testGroup);
            try {
                Assert.assertEquals("You have threads loitering after test completion.", 0, active);
            } catch (AssertionError e) {
                addFailure(e);
            }

            // ThreadGroup.stop() is unsafe, unsupported on recent JDKs and gone
            // from the newest ones; interruption is the supported way to ask
            // the stragglers to terminate. Note that it is cooperative: a
            // thread that ignores interrupts will keep running.
            testGroup.interrupt();
        }

        /**
         * Counts the live threads in {@code group} and its subgroups.
         *
         * <p>{@link ThreadGroup#enumerate(Thread[])} silently drops threads
         * that do not fit into the array, so the buffer is doubled until the
         * snapshot leaves at least one slot free.
         */
        private static int countLiveThreads(ThreadGroup group) {
            int capacity = Math.max(INITIAL_SNAPSHOT_SIZE, group.activeCount() * 2);
            while (true) {
                int found = group.enumerate(new Thread[capacity]);
                if (found < capacity) {
                    return found;
                }
                capacity *= 2;
            }
        }
    }
}

