///////////////////////////////////////////////////////////////////////////////
//
//  Copyright (2008) Alexander Stukowski
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2 of the License, or
//  (at your option) any later version.
//
//  OVITO is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
///////////////////////////////////////////////////////////////////////////////

#include <core/Core.h>
#include <core/data/units/ParameterUnit.h>
#include <atomviz/atoms/AtomsObject.h>
#include <atomviz/atoms/datachannels/PositionDataChannel.h>
#include <core/gui/properties/FloatPropertyUI.h>
#include <boost/iterator/counting_iterator.hpp>

#include "NearestNeighborList.h"
#include "ChemicalElements.h"

namespace AtomViz {

// Plugin class registration for NearestNeighborList (base class RefTarget) and
// metadata of its single property field, the neighbor cutoff radius
// (label "Cutoff radius", unit WorldParameterUnit).
// NOTE(review): "NearestNeighborCutoff" is presumably the identifier used when
// the field is serialized; do not rename it without a file-compatibility plan.
IMPLEMENT_SERIALIZABLE_PLUGIN_CLASS(NearestNeighborList, RefTarget)
DEFINE_PROPERTY_FIELD(NearestNeighborList, "NearestNeighborCutoff", _nearestNeighborCutoff)
SET_PROPERTY_FIELD_LABEL(NearestNeighborList, _nearestNeighborCutoff, "Cutoff radius")
SET_PROPERTY_FIELD_UNITS(NearestNeighborList, _nearestNeighborCutoff, WorldParameterUnit)

/******************************************************************************
* Constructor.
* The cutoff starts at zero. Newly created (not loaded) lists then take their
* cutoff from the application settings, where NearestNeighborListEditor
* remembers the last radius entered by the user.
******************************************************************************/
NearestNeighborList::NearestNeighborList(bool isLoading) : RefTarget(isLoading),
	_nearestNeighborCutoff(0.0)
{
	INIT_PROPERTY_FIELD(NearestNeighborList, _nearestNeighborCutoff);

	// Objects being loaded from a file skip the settings lookup.
	if(isLoading)
		return;

	// Note: the group name is misspelled, but it is a persisted settings key and
	// must stay identical to the one used by NearestNeighborListEditor::memorizeCutoff().
	QSettings settings;
	settings.beginGroup("atomviz/neigborlist");
	const FloatType defaultCutoff = settings.value("DefaultCutoff", 0.0).value<FloatType>();
	settings.endGroup();

	setNearestNeighborCutoff(defaultCutoff);
}

/******************************************************************************
* Computes the nearest neighbor list.
* Throws an exception on error (including non-finite atomic positions).
* Returns false when the operation has been canceled by the user.
*
* The atoms are first sorted into a regular grid of bins spanned by the
* simulation cell vectors. Each bin is at least as wide as the cutoff radius
* (unless a single bin spans the whole cell in that direction), so neighbors
* are always found in the same or an adjacent bin. The bin pairs are then
* processed by Kernel in eight parallel passes. In each pass only every second
* bin along each axis serves as the origin of a 2x2x2 bin group. Together with
* the even number of bins, this keeps concurrently processed groups disjoint.
******************************************************************************/
bool NearestNeighborList::build(AtomsObject* input, bool noProgressIndicator)
{
	CHECK_OBJECT_POINTER(input);
	clear();

	PositionDataChannel* posChannel = static_object_cast<PositionDataChannel>(input->getStandardDataChannel(DataChannel::PositionChannel));
	if(!posChannel) throw Exception(tr("Input object does not contain atomic positions. Position channel is missing."));

	FloatType cutoffRadius = nearestNeighborCutoff();
	// Written as a negated '>' test so that a NaN cutoff is rejected as well.
	if(!(cutoffRadius > 0.0))
		throw Exception(tr("Invalid parameter: Neighbor cutoff radius must be positive."));
	FloatType cutoffRadiusSquared = cutoffRadius * cutoffRadius;

	AffineTransformation simCell = input->simulationCell()->cellMatrix();
	// Negated test so that a NaN determinant (cell containing NaNs) is rejected too.
	if(!(fabs(simCell.determinant()) > FLOATTYPE_EPSILON))
		throw Exception(tr("Simulation cell is degenerate."));
	AffineTransformation simCellInverse = simCell.inverse();
	array<bool,3> pbc = input->simulationCell()->periodicity();

	// Unit normals of the cell faces. Projecting cell vector i onto perpVecs[i]
	// gives the width of the cell along that direction.
	Vector3 perpVecs[3];
	perpVecs[0] = Normalize(CrossProduct(simCell.column(1), simCell.column(2)));
	perpVecs[1] = Normalize(CrossProduct(simCell.column(2), simCell.column(0)));
	perpVecs[2] = Normalize(CrossProduct(simCell.column(0), simCell.column(1)));

	// Calculate the number of bins required in each spatial direction.
	array<int,3> binDim;
	binDim.assign(1);
	for(size_t i=0; i<3; i++) {
		FloatType widthInCutoffs = fabs(DotProduct(simCell.column(i), perpVecs[i])) / cutoffRadius;
		// Apply the upper limit of 60 bins in floating point before converting to int.
		// Converting an out-of-range (or NaN) value to int is undefined behavior.
		binDim[i] = (widthInCutoffs < (FloatType)60) ? (int)floor(widthInCutoffs) : 60;
		binDim[i] = max(binDim[i], 1);

		// Only accept an even number of bins (exception is exactly 1 bin) to avoid problems with the parallel processing and periodic boundary conditions.
		binDim[i] &= ~1;

		if(binDim[i] < 4) {
			if(pbc[i]) {
				VerboseLogger() << "Periodic simulation cell too small: axis:" << i << "  cutoff radius:" << cutoffRadius << "   cell size:" << DotProduct(simCell.column(i), perpVecs[i]) << "   perpvec:" << perpVecs[i] << "   cellvec:" << simCell.column(i) << endl;
				throw Exception(tr("Periodic simulation cell is smaller than four times the neighbor cutoff radius."));
			}
			binDim[i] = 1;
		}
	}

	// Show progress dialog.
	scoped_ptr<ProgressIndicator> progress;
	if(!noProgressIndicator)
		progress.reset(new ProgressIndicator(tr("Building neighbor lists (using %n processor(s))", NULL, QThread::idealThreadCount())));

	// A 3d array of bins. Each bin is the head of a singly linked list of atoms.
	QVector<NeighborListAtom*> bins(binDim[0] * binDim[1] * binDim[2]);

	// Allocate output array.
	atoms.resize(input->atomsCount());

	// Measure computation time.
	QTime timer;
	timer.start();

	// Sort atoms into bins.
	const Point3* p = posChannel->constDataPoint3();
	vector<NeighborListAtom>::iterator a = atoms.begin();
	int atomIndex = 0;
	for(; a != atoms.end(); ++a, ++p, ++atomIndex) {
		a->index = atomIndex;

		// Transform atom position from absolute coordinates to reduced coordinates.
		a->pos = *p;
		Point3 reducedp = simCellInverse * (*p);

		int indices[3];
		for(size_t k=0; k<3; k++) {
			// x - x is zero for every finite x but NaN for NaN and +/-infinity.
			// Non-finite coordinates would otherwise produce an out-of-range bin index.
			if(!(reducedp[k] - reducedp[k] == 0)) {
				clear();
				throw Exception(tr("Invalid atomic position: a coordinate of atom %1 is not a finite number.").arg(atomIndex));
			}

			// Shift atom position to make it be inside simulation cell.
			// This is the closed form of repeatedly adding/subtracting one cell vector
			// until the reduced coordinate lies in [0,1]. A value of exactly 1 (upper
			// cell face) is left unshifted.
			if(pbc[k]) {
				if(reducedp[k] < 0) {
					FloatType shift = floor(reducedp[k]);		// Negative: number of cell vectors to add.
					reducedp[k] -= shift;
					a->pos -= simCell.column(k) * shift;
				}
				else if(reducedp[k] > 1) {
					FloatType shift = ceil(reducedp[k]) - 1;	// Number of cell vectors to subtract.
					reducedp[k] -= shift;
					a->pos -= simCell.column(k) * shift;
				}
			}
			else {
				reducedp[k] = max(reducedp[k], (FloatType)0);
				reducedp[k] = min(reducedp[k], (FloatType)1);
			}

			// Determine the atom's bin from its reduced position in the simulation cell.
			indices[k] = (int)(reducedp[k] * binDim[k]);
			if(indices[k] == binDim[k]) indices[k] = binDim[k]-1;
			OVITO_ASSERT(indices[k] >= 0 && indices[k] < binDim[k]);
		}

		// Put atom into its bin.
		NeighborListAtom*& binList = bins[indices[0] + indices[1]*binDim[0] + indices[2]*binDim[0]*binDim[1]];
		a->nextInBin = binList;
		binList = &*a;
	}

	VerboseLogger() << "Neighbor list binning took" << timer.restart() << "msec." << endl;

	// Eight passes, one for each combination of odd/even bin origins.
	Vector3I offset;
	for(offset.X = 0; offset.X <= 1; offset.X++) {
		for(offset.Y = 0; offset.Y <= 1; offset.Y++) {
			for(offset.Z = 0; offset.Z <= 1; offset.Z++) {

				// Put together list of bins to process in this iteration.
				QVector<Point3I> binsToProcess;
				binsToProcess.reserve(binDim[0]*binDim[1]*binDim[2]/8);
				Point3I binIndex;
				for(binIndex.X = offset.X; binIndex.X < binDim[0]; binIndex.X += 2)
					for(binIndex.Y = offset.Y; binIndex.Y < binDim[1]; binIndex.Y += 2)
						for(binIndex.Z = offset.Z; binIndex.Z < binDim[2]; binIndex.Z += 2)
							binsToProcess.push_back(binIndex);

				// Execute neighbor list code for each atom in a parallel fashion.
				Kernel kernel(bins, binDim, simCell, pbc, cutoffRadiusSquared, offset);
				QFuture<void> future = QtConcurrent::map(binsToProcess, kernel);
				if(progress) {
					progress->setLabelText(tr("Building neighbor lists (%1/8) (using %n processor(s))", NULL, QThread::idealThreadCount()).arg(offset.X*4+offset.Y*2+offset.Z+1));
					progress->waitForFuture(future);
				}
				else future.waitForFinished();

				// Throw away results obtained so far if the user cancels the calculation.
				if(future.isCanceled()) {
					clear();
					return false;
				}
			}
		}
	}

	VerboseLogger() << "Neighbor list building took" << timer.elapsed() << "msec." << endl;

	return true;
}

// Bin pairs examined by Kernel::operator() for each 2x2x2 group of bins.
// Each entry holds two offsets relative to the group's origin bin. Atom pairs
// between these two bins are tested against the cutoff radius.
// The first entry pairs the origin bin with itself. The remaining 13 entries
// cover each of the 13 distinct neighbor directions (the 26 neighbor bins
// modulo sign) exactly once. The trailing comments give the direction
// second - first for the entries containing negative components.
// The order of the entries determines the order in which neighbors are
// appended to the atoms' neighbor lists.
static const Vector3I stencils[][2] = {
		{ Vector3I(0,0,0), Vector3I(0,0,0) },
		{ Vector3I(0,0,0), Vector3I(0,0,1) },
		{ Vector3I(0,0,0), Vector3I(0,1,0) },
		{ Vector3I(0,0,0), Vector3I(0,1,1) },
		{ Vector3I(0,0,0), Vector3I(1,0,0) },
		{ Vector3I(0,0,0), Vector3I(1,0,1) },
		{ Vector3I(0,0,0), Vector3I(1,1,0) },
		{ Vector3I(0,0,0), Vector3I(1,1,1) },

		{ Vector3I(1,0,0), Vector3I(0,1,0) },	// -1  1  0
		{ Vector3I(1,0,0), Vector3I(0,0,1) },	// -1  0  1
		{ Vector3I(0,1,0), Vector3I(0,0,1) },   //  0 -1  1

		{ Vector3I(1,0,0), Vector3I(0,1,1) },   // -1  1  1
		{ Vector3I(0,1,0), Vector3I(1,0,1) },   //  1 -1  1
		{ Vector3I(0,0,1), Vector3I(1,1,0) },   //  1  1 -1
};

/******************************************************************************
* Finds the neighbors of all atoms in the 2x2x2 group of bins whose lowest
* corner is binOrigin.
*
* For every entry of the stencils table, the two referenced bins (offsets from
* binOrigin) are located. Bin offsets that step past the upper end of the grid
* are handled per direction:
*  - periodic directions: wrap to bin 0, with the matching cell vector folded
*    into the distance vector;
*  - non-periodic directions: the stencil entry is skipped.
* Every atom pair within the cutoff radius is then appended to both atoms'
* neighbor lists.
*
* Throws an Exception when, within one stencil entry, the number of pairs
* found exceeds 400 times the number of atoms visited so far.
******************************************************************************/
void NearestNeighborList::Kernel::operator()(const Point3I& binOrigin)
{
	OVITO_ASSERT(binOrigin.X < binDim[0]);
	OVITO_ASSERT(binOrigin.Y < binDim[1]);
	OVITO_ASSERT(binOrigin.Z < binDim[2]);

	const size_t stencilCount = sizeof(stencils) / sizeof(stencils[0]);
	for(size_t s = 0; s < stencilCount; ++s) {
		Point3I firstBin = binOrigin + stencils[s][0];
		Point3I secondBin = binOrigin + stencils[s][1];

		// Resolve bin coordinates that point one past the last bin.
		Vector3 pbcOffset(NULL_VECTOR);
		bool outsideCell = false;
		for(size_t dim = 0; dim < 3 && !outsideCell; ++dim) {
			if(firstBin[dim] == binDim[dim]) {
				if(pbc[dim]) {
					firstBin[dim] = 0;
					pbcOffset += simCell.column(dim);
				}
				else outsideCell = true;
			}
			if(!outsideCell && secondBin[dim] == binDim[dim]) {
				if(pbc[dim]) {
					secondBin[dim] = 0;
					pbcOffset -= simCell.column(dim);
				}
				else outsideCell = true;
			}
		}
		if(outsideCell)
			continue;

		const int firstIndex = firstBin.X + firstBin.Y*binDim[0] + firstBin.Z*binDim[0]*binDim[1];
		const int secondIndex = secondBin.X + secondBin.Y*binDim[0] + secondBin.Z*binDim[0]*binDim[1];
		// Within a single bin, the inner loop starts after atom1 so each pair is visited once.
		const bool sameBin = (firstBin == secondBin);

		int pairCount = 0;
		int atomCount = 0;
		NeighborListAtom* atom1 = bins[firstIndex];
		while(atom1 != NULL) {
			NeighborListAtom* atom2 = sameBin ? atom1->nextInBin : bins[secondIndex];
			for(; atom2 != NULL; atom2 = atom2->nextInBin) {
				Vector3 delta = atom1->pos - atom2->pos + pbcOffset;
				// Negated '>' test, exactly mirroring the acceptance condition of a "skip if farther" check.
				if(!(LengthSquared(delta) > cutoffRadiusSquared)) {
					OVITO_ASSERT(atom1 != atom2);
					atom1->neighbors.append(atom2);
					atom2->neighbors.append(atom1);
					pairCount++;
				}
			}
			atomCount++;

			// Abort when the pair count becomes unreasonably large relative to the atoms visited so far.
			if(pairCount > atomCount * 400)
				throw Exception(tr("The average number of nearest neighbors per atom exceeds the reasonable limit. Atomic positions seem to be invalid or cutoff radius too large."));

			atom1 = atom1->nextInBin;
		}
	}
}

// Plugin class registration for the properties editor of NearestNeighborList (base class PropertiesEditor).
IMPLEMENT_PLUGIN_CLASS(NearestNeighborListEditor, PropertiesEditor)

/******************************************************************************
* Sets up the UI widgets of the editor.
******************************************************************************/
void NearestNeighborListEditor::createUI(const RolloutInsertionParameters& rolloutParams)
{
	// Create a rollout.
	QWidget* rollout = createRollout(tr("Neighbor list"), rolloutParams);

    // Create the rollout contents.
	QGridLayout* layout = new QGridLayout(rollout);
#ifndef Q_WS_MAC
	layout->setContentsMargins(4,4,4,4);
	layout->setSpacing(0);
#endif
	layout->setColumnStretch(1, 1);

	// Cutoff parameter.
	FloatPropertyUI* cutoffRadiusPUI = new FloatPropertyUI(this, PROPERTY_FIELD_DESCRIPTOR(NearestNeighborList, _nearestNeighborCutoff));
	layout->addWidget(cutoffRadiusPUI->label(), 0, 0);
	layout->addLayout(cutoffRadiusPUI->createFieldLayout(), 0, 1);
	cutoffRadiusPUI->setMinValue(0);
	connect(cutoffRadiusPUI->spinner(), SIGNAL(spinnerValueChanged()), this, SLOT(memorizeCutoff()));

	// Selection box for predefined cutoff radii.
	nearestNeighborPresetsBox = new QComboBox(rollout);
	nearestNeighborPresetsBox->addItem(tr("Choose..."));
	for(size_t i=0; i<NumberOfChemicalElements; i++) {
		if(ChemicalElements[i].structure == ChemicalElement::FaceCenteredCubic) {
			FloatType r = ChemicalElements[i].latticeParameter * 0.5 * (1.0 + 1.0/sqrt(2.0));
			nearestNeighborPresetsBox->addItem(QString("%1 (%2) - FCC - %3").arg(ChemicalElements[i].elementName).arg(i).arg(r, 0, 'f', 2), r);
		}
		else if(ChemicalElements[i].structure == ChemicalElement::BodyCenteredCubic) {
			FloatType r = ChemicalElements[i].latticeParameter * (1.0 + (sqrt(2.0)-1.0)*0.5);
			nearestNeighborPresetsBox->addItem(QString("%1 (%2) - BCC - %3").arg(ChemicalElements[i].elementName).arg(i).arg(r, 0, 'f', 2), r);
		}
	}
	layout->addWidget(new QLabel(tr("Presets:")), 1, 0);
	layout->addWidget(nearestNeighborPresetsBox, 1, 1);
	connect(nearestNeighborPresetsBox, SIGNAL(activated(int)), this, SLOT(onSelectNearestNeighborPreset(int)));
}

/******************************************************************************
* Is called when the user has selected an item in the radius preset box.
*
* Applies the preset radius stored as item data to the edited neighbor list,
* as a single undoable operation. It also remembers the radius as the default
* for new lists. Afterwards the box is always reset to the "Choose..."
* placeholder, even when no object is being edited, so it never shows a stale
* selection.
******************************************************************************/
void NearestNeighborListEditor::onSelectNearestNeighborPreset(int index)
{
	// The placeholder item has no item data, so it yields r == 0 and is ignored.
	FloatType r = nearestNeighborPresetsBox->itemData(index).value<FloatType>();
	if(r != 0 && editObject()) {
		NearestNeighborList* obj = static_object_cast<NearestNeighborList>(editObject());
		UNDO_MANAGER.beginCompoundOperation(tr("Change Cutoff Radius"));
		obj->setNearestNeighborCutoff(r);
		UNDO_MANAGER.endCompoundOperation();
		memorizeCutoff();
	}
	nearestNeighborPresetsBox->setCurrentIndex(0);
}


/******************************************************************************
* Stores the current cutoff radius in the application settings
* so it can be used as default value for new neighbor lists.
******************************************************************************/
void NearestNeighborListEditor::memorizeCutoff()
{
	if(!editObject()) return;
	NearestNeighborList* nnList = static_object_cast<NearestNeighborList>(editObject());

	QSettings settings;
	settings.beginGroup("atomviz/neigborlist");
	settings.setValue("DefaultCutoff", nnList->nearestNeighborCutoff());
	settings.endGroup();
}

};	// End of namespace AtomViz

