//                                               -*- C++ -*-
/**
 *  @file  ProductCovarianceModel.cxx
 *
 *  Copyright 2005-2015 Airbus-EDF-IMACS-Phimeca
 *
 *  This library is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Lesser General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This library is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public License
 *  along with this library.  If not, see <http://www.gnu.org/licenses/>.
 *
 *  @author: schueller
 */
#include "ProductCovarianceModel.hxx"
#include "PersistentObjectFactory.hxx"
#include "Exception.hxx"
#include "AbsoluteExponential.hxx"

BEGIN_NAMESPACE_OPENTURNS

/* Class name initialization for the persistent collection of covariance models */
TEMPLATE_CLASSNAMEINIT(PersistentCollection< CovarianceModel >);

/* File-local factory object for PersistentCollection< CovarianceModel >, built with that name
 * (presumably consumed by the persistence mechanism used in save()/load() -- confirm) */
static Factory< PersistentCollection< CovarianceModel > > RegisteredFactory1("PersistentCollection< CovarianceModel >");

/* Class name initialization for ProductCovarianceModel */
CLASSNAMEINIT(ProductCovarianceModel);

/* File-local factory object for ProductCovarianceModel, built with the class name */
static Factory<ProductCovarianceModel> RegisteredFactory("ProductCovarianceModel");


/* Default constructor: the product is made of spatialDimension copies of AbsoluteExponential(1) */
ProductCovarianceModel::ProductCovarianceModel(const UnsignedInteger spatialDimension)
  : CovarianceModelImplementation(spatialDimension)
  , collection_(spatialDimension, AbsoluteExponential(1))
{
  // Every factor is the same model, so the first one provides the value
  // that is replicated over all the components of the scale...
  const NumericalScalar commonScale = collection_[0].getScale()[0];
  setScale(NumericalPoint(spatialDimension, commonScale));
  // ...and likewise for the amplitude
  const NumericalScalar commonAmplitude = collection_[0].getAmplitude()[0];
  setAmplitude(NumericalPoint(spatialDimension, commonAmplitude));
}

/* Parameters constructor
 *
 * Builds the product model from the given collection of covariance models.
 * The base class is default-constructed and all the work (validation of the
 * collection, spatial dimension, scale and amplitude) is delegated to
 * setCollection(), which throws InvalidArgumentException on an empty collection.
 */
ProductCovarianceModel::ProductCovarianceModel(const CovarianceModelCollection & collection)
  : CovarianceModelImplementation()
{
  setCollection(collection);
}

/* Collection accessor */
void ProductCovarianceModel::setCollection(const CovarianceModelCollection & collection)
{
  // Check if the given models have a spatial dimension=1
  const UnsignedInteger size(collection.getSize());
  if (size == 0) throw InvalidArgumentException(HERE) << "Error: the collection must have a positive size, here size=0";
  NumericalPoint scale(0);
  NumericalPoint amplitude(0);
  spatialDimension_ = 0;
  for (UnsignedInteger i = 0; i < size; ++i)
  {
    const UnsignedInteger localDimension(collection[i].getSpatialDimension());
    const NumericalPoint localScale(collection[i].getScale());
    const NumericalPoint localAmplitude(collection[i].getAmplitude());
    for (UnsignedInteger j = 0; j < localDimension; ++j)
    {
      scale.add(localScale[j]);
      amplitude.add(localAmplitude[j]);
    }
    spatialDimension_ += localDimension;
  }
  setScale(scale);
  setAmplitude(amplitude);
  collection_ = collection;
}

/* Collection accessor: read-only access to the stored factors (no copy is made) */
const ProductCovarianceModel::CovarianceModelCollection & ProductCovarianceModel::getCollection() const
{
  return collection_;
}

/* Virtual constructor: returns a heap-allocated copy of this object, built with
 * the copy constructor. The returned pointer is a plain new-expression result. */
ProductCovarianceModel * ProductCovarianceModel::clone() const
{
  return new ProductCovarianceModel(*this);
}

/* Computation of the covariance function
 *
 * The points s and t are cut into consecutive blocks, one per factor, the size
 * of each block being the spatial dimension of the factor. The result is the
 * 1x1 covariance matrix holding the product of the (0, 0) entries of the
 * factors evaluated on their own blocks.
 * Throws InvalidArgumentException if s or t has not the expected dimension.
 */
CovarianceMatrix ProductCovarianceModel::operator() (const NumericalPoint & s,
    const NumericalPoint & t) const
{
  if (s.getDimension() != spatialDimension_) throw InvalidArgumentException(HERE) << "Error: the point s has dimension=" << s.getDimension() << ", expected dimension=" << spatialDimension_;
  if (t.getDimension() != spatialDimension_) throw InvalidArgumentException(HERE) << "Error: the point t has dimension=" << t.getDimension() << ", expected dimension=" << spatialDimension_;
  const UnsignedInteger factorsNumber = collection_.getSize();
  NumericalScalar product = 1.0;
  // Position of the first component of the current block in s and t
  UnsignedInteger offset = 0;
  for (UnsignedInteger k = 0; k < factorsNumber; ++k)
  {
    const UnsignedInteger blockSize = collection_[k].getSpatialDimension();
    // Extract the components of s and t seen by the current factor
    NumericalPoint blockS(blockSize);
    NumericalPoint blockT(blockSize);
    std::copy(s.begin() + offset, s.begin() + offset + blockSize, blockS.begin());
    std::copy(t.begin() + offset, t.begin() + offset + blockSize, blockT.begin());
    product *= collection_[k](blockS, blockT)(0, 0);
    offset += blockSize;
  }
  CovarianceMatrix result(1);
  result(0, 0) = product;
  return result;
}

/* Gradient
 *
 * The covariance is the product of the factors, each factor acting on its own
 * block of components of s and t. By the product rule, the block of the
 * gradient associated with the factor i is the partial gradient of the factor i
 * multiplied by the values of all the other factors.
 *
 * The result is a (spatialDimension_ x 1) matrix whose rows follow the order of
 * the components of s. The differentiation convention is the one of the
 * partialGradient() method of the factors.
 * Throws InvalidArgumentException if s or t has not the expected dimension.
 */
Matrix ProductCovarianceModel::partialGradient(const NumericalPoint & s,
    const NumericalPoint & t) const
{
  if (s.getDimension() != spatialDimension_) throw InvalidArgumentException(HERE) << "Error: the point s has dimension=" << s.getDimension() << ", expected dimension=" << spatialDimension_;
  if (t.getDimension() != spatialDimension_) throw InvalidArgumentException(HERE) << "Error: the point t has dimension=" << t.getDimension() << ", expected dimension=" << spatialDimension_;
  const UnsignedInteger size(collection_.getSize());
  // Value of each factor on its own block
  NumericalPoint localValues(size, 1.0);
  Matrix gradient(spatialDimension_, 1);
  // First pass: evaluate each factor and store its unscaled partial gradient
  // in the rows of the result associated with its block
  UnsignedInteger start = 0;
  for (UnsignedInteger i = 0; i < size; ++i)
  {
    const UnsignedInteger localSpatialDimension(collection_[i].getSpatialDimension());
    const UnsignedInteger stop(start + localSpatialDimension);
    NumericalPoint localS(localSpatialDimension);
    std::copy(s.begin() + start, s.begin() + stop, localS.begin());
    NumericalPoint localT(localSpatialDimension);
    std::copy(t.begin() + start, t.begin() + stop, localT.begin());
    localValues[i] = collection_[i](localS, localT)(0, 0);
    const Matrix localGradient(collection_[i].partialGradient(localS, localT));
    for (UnsignedInteger j = 0; j < localSpatialDimension; ++j)
      gradient(start + j, 0) = localGradient(j, 0);
    start = stop;
  }
  // Second pass: scale each block by the product of the other factors.
  // The product is formed explicitly instead of dividing the full product by
  // the local value, so that a zero factor needs no special treatment.
  start = 0;
  for (UnsignedInteger i = 0; i < size; ++i)
  {
    const UnsignedInteger localSpatialDimension(collection_[i].getSpatialDimension());
    NumericalScalar coefficient(1.0);
    for (UnsignedInteger j = 0; j < size; ++j)
      if (j != i) coefficient *= localValues[j];
    for (UnsignedInteger j = 0; j < localSpatialDimension; ++j)
      gradient(start + j, 0) *= coefficient;
    start += localSpatialDimension;
  }
  return gradient;
}

/* Parameters accessor */
void ProductCovarianceModel::setParameters(const NumericalPoint & parameters)
{
  const UnsignedInteger parametersDimension(getParameters().getDimension());
  if (parameters.getDimension() != parametersDimension) throw InvalidArgumentException(HERE) << "Error: parameters dimension should be 1 (got " << parameters.getDimension() << ")";
  UnsignedInteger start(0);
  for (UnsignedInteger i = 0; i < collection_.getSize(); ++i)
  {
    const UnsignedInteger atomParametersDimension(collection_[i].getParameters().getDimension());
    const UnsignedInteger stop(start + atomParametersDimension);
    NumericalPoint atomParameters(atomParametersDimension);
    std::copy(parameters.begin() + start, parameters.begin() + stop, atomParameters.begin());
    start = stop;
    collection_[i].setParameters(atomParameters);
  }
}

/* Parameters accessor
 *
 * Returns the concatenation of the parameters of the factors, in the order of
 * the collection. The description of each parameter is the one given by its
 * factor, prefixed by "model_<index of the factor>_".
 */
NumericalPointWithDescription ProductCovarianceModel::getParameters() const
{
  NumericalPointWithDescription allParameters(0);
  Description allNames(0);
  for (UnsignedInteger modelIndex = 0; modelIndex < collection_.getSize(); ++modelIndex)
  {
    const NumericalPointWithDescription localParameters(collection_[modelIndex].getParameters());
    allParameters.add(localParameters);
    // Prefix the local names so that they stay unique within the product
    const Description localNames(localParameters.getDescription());
    const UnsignedInteger localSize = localNames.getSize();
    for (UnsignedInteger k = 0; k < localSize; ++k)
      allNames.add(OSS() << "model_" << modelIndex << "_" << localNames[k]);
  }
  allParameters.setDescription(allNames);
  return allParameters;
}

/* Is it a stationary model ? It is if and only if all the factors are stationary */
Bool ProductCovarianceModel::isStationary() const
{
  const UnsignedInteger factorsNumber = collection_.getSize();
  // Skip the leading stationary factors and stop at the first one that is not
  UnsignedInteger k = 0;
  while ((k < factorsNumber) && collection_[k].isStationary()) ++k;
  // Reaching the end means that no non-stationary factor has been found
  return k == factorsNumber;
}

/* String converter: class name, spatial dimension and the collection of factors */
String ProductCovarianceModel::__repr__() const
{
  return OSS() << "class=" << ProductCovarianceModel::GetClassName()
         << " input dimension=" << spatialDimension_
         << " models=" << collection_;
}

/* String converter: returns exactly the same text as __repr__().
 * The offset argument is not used by this implementation. */
String ProductCovarianceModel::__str__(const String & offset) const
{
  return __repr__();
}

/* Marginal accessor
 *
 * Returns the implementation of the model stored at position index in the
 * collection. Throws InvalidArgumentException if index >= dimension_ (the
 * output dimension, according to the error message).
 * NOTE(review): the index is validated against dimension_ but is then used to
 * address collection_; confirm that both bounds are meant to agree.
 */
ProductCovarianceModel::Implementation ProductCovarianceModel::getMarginal(const UnsignedInteger index) const
{
  if (index >= dimension_) throw InvalidArgumentException(HERE) << "Error: index=" << index << " must be less than output dimension=" << dimension_;
  return collection_[index].getImplementation();
}

/* Method save() stores the object through the StorageManager
 *
 * The attributes of the base class are saved first, then the collection of
 * factors under the key "collection_".
 */
void ProductCovarianceModel::save(Advocate & adv) const
{
  CovarianceModelImplementation::save(adv);
  adv.saveAttribute("collection_", collection_);
}

/* Method load() reloads the object from the StorageManager
 *
 * Mirror of save(): the attributes of the base class are reloaded first, then
 * the collection of factors from the key "collection_". The collection is
 * assigned directly, without going through setCollection().
 */
void ProductCovarianceModel::load(Advocate & adv)
{
  CovarianceModelImplementation::load(adv);
  adv.loadAttribute("collection_", collection_);
}

END_NAMESPACE_OPENTURNS
