// AppendIrisCodeUtils / casic / iris / CasicSegPostProcess.cpp
#include "CasicSegPostProcess.h"

#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <vector>

namespace iristrt
{

/**
 * Splits a binary image into one mask per connected foreground component.
 *
 * @param input        8-bit single-channel binary image (non-zero = foreground).
 * @param connectivity 4 or 8, forwarded to cv::connectedComponents.
 * @return one CV_8UC1 image per foreground component, with 255 on the
 *         component's pixels and 0 elsewhere, in label order. Empty when the
 *         image has no foreground.
 */
std::vector<cv::Mat> CasicSegPostProcess::getConnections(cv::Mat input, int connectivity) {
	cv::Mat labels;
	// Labels are 0 .. nLabels-1, where 0 is the background. Keep them 32-bit:
	// narrowing to 8 bits would merge every label >= 255 into one value.
	const int nLabels = cv::connectedComponents(input, labels, connectivity, CV_32S);

	std::vector<cv::Mat> conns;
	conns.reserve(nLabels > 1 ? static_cast<std::size_t>(nLabels - 1) : 0);
	for (int label = 1; label < nLabels; ++label) {
		// Mat == scalar yields an 8-bit mask: 255 where equal, 0 elsewhere.
		cv::Mat component = (labels == label);
		conns.push_back(component);
	}
	return conns;
}

/**
 * Debug helper: shows every component mask in the "con" window, one after
 * another, blocking on a key press for each, then closes all HighGUI windows.
 *
 * @param conns component masks to display, in order.
 */
void CasicSegPostProcess::showConnections(std::vector<cv::Mat> conns) {
	for (const cv::Mat & component : conns) {
		cv::imshow("con", component);
		cv::waitKey(0);
	}
	cv::destroyAllWindows();
}

/**
 * Builds every (mask, iris, pupil) component triplet in which both the iris
 * and the pupil component lie within `chessboardDistance` of the mask
 * component (see getChessboardDistance).
 *
 * Triplets are emitted in mask -> iris -> pupil input order.
 *
 * @param maskConns          mask components.
 * @param irisConns          iris components.
 * @param pupilConns         pupil components.
 * @param chessboardDistance maximum chessboard distance for "adjacent".
 * @return all qualifying triplets (possibly empty).
 */
std::vector<CasicTriplet>  CasicSegPostProcess::getTripleSet(std::vector<cv::Mat> maskConns,
															std::vector<cv::Mat> irisConns, 
															std::vector<cv::Mat> pupilConns,
															int chessboardDistance) {
	std::vector<CasicTriplet> triplet_set;
	for (const cv::Mat & maskConn : maskConns) {
		// The mask-to-pupil test depends only on the mask component. Evaluate it
		// once per mask instead of once per (iris, pupil) pair: each evaluation
		// runs findContours twice plus an O(n*m) point scan.
		std::vector<const cv::Mat *> nearPupils;
		for (const cv::Mat & pupilConn : pupilConns) {
			if (getChessboardDistance(maskConn, pupilConn, chessboardDistance) <= chessboardDistance) {
				nearPupils.push_back(&pupilConn);
			}
		}
		if (nearPupils.empty()) {
			continue;  // no triplet can include this mask component
		}

		for (const cv::Mat & irisConn : irisConns) {
			if (getChessboardDistance(maskConn, irisConn, chessboardDistance) > chessboardDistance) {
				continue;
			}
			for (const cv::Mat * pupilConn : nearPupils) {
				CasicTriplet triplet;
				triplet.maskConn = maskConn;
				triplet.irisConn = irisConn;
				triplet.pupilConn = *pupilConn;
				triplet_set.push_back(triplet);
			}
		}
	}
	return triplet_set;
}

/**
 * Debug helper: displays the three components of a triplet in the windows
 * "triple-mask", "triple-iris" and "triple-pupil", blocking on a key press
 * after each one.
 *
 * @param triple the triplet to visualize.
 */
void CasicSegPostProcess::showTriple(CasicTriplet triple) {
	cv::imshow("triple-mask", triple.maskConn);
	cv::waitKey(0);

	cv::imshow("triple-iris", triple.irisConn);
	cv::waitKey(0);

	cv::imshow("triple-pupil", triple.pupilConn);
	cv::waitKey(0);
}

/**
 * Chessboard (Chebyshev, max(|dx|, |dy|)) distance between the contour points
 * of two binary components.
 *
 * Contours of `con1` are retrieved with RETR_TREE, which includes hole
 * boundaries. Contours of `con2` use RETR_EXTERNAL, which covers outer
 * boundaries only.
 *
 * @param con1               8-bit binary image of the first component.
 * @param con2               8-bit binary image of the second component.
 * @param chessboardDistance early-exit threshold: the search stops at the
 *                           first point pair found within this distance.
 * @return the exact minimum distance when it exceeds the threshold. Otherwise
 *         some value <= chessboardDistance (not necessarily the minimum).
 *         std::numeric_limits<int>::max() when either image has no contours.
 */
int CasicSegPostProcess::getChessboardDistance(cv::Mat con1, cv::Mat con2, int chessboardDistance) {
	// Contour points of the first component, including hole boundaries.
	std::vector<std::vector<cv::Point>> contours1;
	cv::findContours(con1, contours1, cv::RETR_TREE, cv::CHAIN_APPROX_SIMPLE, cv::Point());
	std::vector<cv::Point> points1;
	for (const auto & contour : contours1) {
		points1.insert(points1.end(), contour.begin(), contour.end());
	}

	// Outer contour points of the second component.
	std::vector<std::vector<cv::Point>> contours2;
	cv::findContours(con2, contours2, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE, cv::Point());
	std::vector<cv::Point> points2;
	for (const auto & contour : contours2) {
		points2.insert(points2.end(), contour.begin(), contour.end());
	}

	// Sentinel: "no point pair found". A fixed value such as 1000 would make
	// empty components look adjacent for thresholds >= that value.
	int distance = std::numeric_limits<int>::max();
	for (const cv::Point & p1 : points1) {
		for (const cv::Point & p2 : points2) {
			const int d = std::max(std::abs(p1.x - p2.x), std::abs(p1.y - p2.y));
			if (d < distance) {
				distance = d;
				if (distance <= chessboardDistance) {
					// Callers only compare against the threshold, so stop the whole
					// scan instead of just the inner loop.
					return distance;
				}
			}
		}
	}
	return distance;
}

/**
 * Picks the triplet whose three components together cover the most
 * foreground pixels (sum of cv::countNonZero over maskConn, irisConn and
 * pupilConn).
 *
 * @param tripleSet candidate triplets.
 * @return the highest-scoring triplet; ties go to the earliest one. A
 *         default-constructed CasicTriplet is returned for an empty set.
 */
CasicTriplet CasicSegPostProcess::getMaxTriple(std::vector<CasicTriplet> tripleSet) {
	if (tripleSet.empty()) {
		return CasicTriplet();
	}

	// Start below any possible count so that index 0 is always a valid
	// fallback, even when every candidate scores zero.
	long maxCount = -1;
	std::size_t maxIndex = 0;
	for (std::size_t i = 0; i < tripleSet.size(); ++i) {
		const CasicTriplet & triplet = tripleSet[i];
		const long count = static_cast<long>(cv::countNonZero(triplet.maskConn))
			+ cv::countNonZero(triplet.irisConn)
			+ cv::countNonZero(triplet.pupilConn);
		if (count > maxCount) {
			maxCount = count;
			maxIndex = i;
		}
	}
	return tripleSet[maxIndex];
}

/**
 * Algebraic least-squares circle fit.
 *
 * Fits x^2 + y^2 + a*x + b*y + c = 0 to the points by solving the normal
 * equations in closed form. The centre is (-a/2, -b/2) and the radius is
 * sqrt(a^2 + b^2 - 4c) / 2.
 *
 * @param counter points to fit.
 * @return {center_x, center_y, radius}, each truncated toward zero to int.
 *         {0, 0, 0} when there are fewer than 3 points, the system is
 *         singular (e.g. all points collinear), or the result is not finite
 *         or does not fit in an int.
 */
std::vector<int> CasicSegPostProcess::leastSquareCircleFitting(std::vector<cv::Point> counter) {
	// Layout {center_x, center_y, radius}; stays all-zero if no fit is possible.
	std::vector<int> result(3, 0);

	if (counter.size() < 3) {
		return result;
	}

	// Each accumulator needs its own initializer: `double a, b = 0.0;` only
	// initializes `b`.
	double sum_x = 0.0;
	double sum_y = 0.0;
	double sum_x2 = 0.0;
	double sum_y2 = 0.0;
	double sum_x3 = 0.0;
	double sum_y3 = 0.0;
	double sum_xy = 0.0;
	double sum_x1y2 = 0.0;
	double sum_x2y1 = 0.0;

	for (const cv::Point & point : counter) {
		const double x = static_cast<double>(point.x);
		const double y = static_cast<double>(point.y);
		const double x2 = x * x;
		const double y2 = y * y;
		sum_x += x;
		sum_y += y;
		sum_x2 += x2;
		sum_y2 += y2;
		sum_x3 += x2 * x;
		sum_y3 += y2 * y;
		sum_xy += x * y;
		sum_x1y2 += x * y2;
		sum_x2y1 += x2 * y;
	}

	const double N = static_cast<double>(counter.size());
	const double C = N * sum_x2 - sum_x * sum_x;
	const double D = N * sum_xy - sum_x * sum_y;
	const double E = N * sum_x3 + N * sum_x1y2 - (sum_x2 + sum_y2) * sum_x;
	const double G = N * sum_y2 - sum_y * sum_y;
	const double H = N * sum_x2y1 + N * sum_y3 - (sum_x2 + sum_y2) * sum_y;

	// Singular normal equations (e.g. collinear points): no unique circle.
	const double det = C * G - D * D;
	if (det == 0.0) {
		return result;
	}

	const double a = (H * D - E * G) / det;
	const double b = (H * C - E * D) / (D * D - G * C);
	const double c = -(a * sum_x + b * sum_y + sum_x2 + sum_y2) / N;

	const double cx = a / -2.0;
	const double cy = b / -2.0;
	const double r2 = a * a + b * b - 4.0 * c;
	// `!(r2 >= 0.0)` also rejects NaN.
	if (!std::isfinite(cx) || !std::isfinite(cy) || !(r2 >= 0.0)) {
		return result;
	}
	const double r = std::sqrt(r2) / 2.0;

	// Converting an out-of-range double to int is undefined behaviour.
	const double kIntMax = static_cast<double>(std::numeric_limits<int>::max());
	if (std::abs(cx) >= kIntMax || std::abs(cy) >= kIntMax || r >= kIntMax) {
		return result;
	}

	result[0] = static_cast<int>(cx);
	result[1] = static_cast<int>(cy);
	result[2] = static_cast<int>(r);
	return result;
}

/**
 * Turns the three segmentation maps into a cleaned mask plus fitted iris and
 * pupil circles.
 *
 * The steps, in order:
 * 1. Convert all maps to CV_8UC1 and binarize them in place (mask at 127,
 *    iris at 90, pupil with Otsu).
 * 2. Close the iris map with a 5x5 rectangle.
 * 3. Split each map into 8-connected components.
 * 4. Keep the mutually adjacent mask/iris/pupil triplet covering the most
 *    pixels.
 * 5. Fit circles to the chosen iris and pupil components.
 * 6. Restrict `mask` to the region inside the iris circle and outside the
 *    pupil circle.
 *
 * @param mask        in/out: binarized, then ANDed with the ring image.
 * @param iris        in/out: replaced by its binarized, closed version.
 * @param pupil       in/out: replaced by its binarized version.
 * @param irisCircle  out: {x, y, r} fitted to every pixel of the chosen
 *                    iris component.
 * @param pupilCircle out: {x, y, r} fitted to the outer contour points of the
 *                    chosen pupil component.
 * @return false when a map has no components or no adjacent triplet exists
 *         (the maps are already binarized at that point; the circle outputs
 *         are untouched); true otherwise.
 */
bool CasicSegPostProcess::postProcess(cv::Mat & mask, cv::Mat & iris, cv::Mat & pupil, std::vector<int> & irisCircle, std::vector<int> & pupilCircle) {
	// Maximum chessboard distance (pixels) for two components to count as adjacent.
	const int kChessboardDistance = 15;

	// Binarize the three maps in place.
	mask.convertTo(mask, CV_8UC1);
	iris.convertTo(iris, CV_8UC1);
	pupil.convertTo(pupil, CV_8UC1);
	cv::threshold(mask, mask, 127, 255, cv::THRESH_BINARY);
	cv::threshold(iris, iris, 90, 255, cv::THRESH_BINARY);
	cv::threshold(pupil, pupil, 0, 255, cv::THRESH_BINARY + cv::THRESH_OTSU);

	// Close small gaps in the iris map before labelling it.
	const cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(5, 5));
	cv::morphologyEx(iris, iris, cv::MORPH_CLOSE, kernel);

	const std::vector<cv::Mat> maskComponents = getConnections(mask, 8);
	const std::vector<cv::Mat> irisComponents = getConnections(iris, 8);
	const std::vector<cv::Mat> pupilComponents = getConnections(pupil, 8);
	if (maskComponents.empty() || irisComponents.empty() || pupilComponents.empty()) {
		return false;
	}

	const std::vector<CasicTriplet> candidates =
		getTripleSet(maskComponents, irisComponents, pupilComponents, kChessboardDistance);
	if (candidates.empty()) {
		return false;
	}

	CasicTriplet best = getMaxTriple(candidates);

	// Iris circle: fit to every foreground pixel of the chosen iris component.
	// Read the values out immediately, before pupilCircle is assigned, in case
	// the caller passed the same vector for both outputs.
	cv::Mat irisComponent = best.irisConn;
	std::vector<cv::Point> irisPixels;
	cv::findNonZero(irisComponent, irisPixels);
	irisCircle = leastSquareCircleFitting(irisPixels);
	const cv::Point irisCenter(irisCircle[0], irisCircle[1]);
	const int irisRadius = irisCircle[2];

	// Pupil circle: fit to the outer contour points of the chosen pupil component.
	cv::Mat pupilComponent = best.pupilConn;
	std::vector<std::vector<cv::Point>> pupilContours;
	cv::findContours(pupilComponent, pupilContours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
	std::vector<cv::Point> pupilEdge;
	for (const auto & contour : pupilContours) {
		pupilEdge.insert(pupilEdge.end(), contour.begin(), contour.end());
	}
	pupilCircle = leastSquareCircleFitting(pupilEdge);
	const cv::Point pupilCenter(pupilCircle[0], pupilCircle[1]);
	const int pupilRadius = pupilCircle[2];

	// Keep only mask pixels inside the iris disc and outside the pupil disc.
	cv::Mat ring = cv::Mat::zeros(mask.size(), CV_8UC1);
	cv::circle(ring, irisCenter, irisRadius, 255, -1);
	cv::circle(ring, pupilCenter, pupilRadius, 0, -1);
	cv::bitwise_and(mask, ring, mask);
	return true;
}

} // end of namespace iristrt