#!/usr/bin/env python3
"""
Fit example with data by M. Fitzsimmons et al.,
https://doi.org/10.5281/zenodo.4072376.
The sample is a ~50 nm Pt film on a Si substrate.
Single-event data from the Spallation Neutron Source
Beamline 4A (MagRef), recorded with 60 Hz pulses and a
wavelength band of roughly 4-7 Å, in 100 steps of 2theta.
"""

import os, sys
import matplotlib.pyplot as plt
import numpy as np
import bornagain as ba
from bornagain import ba_plot as bp
from bornagain import angstrom

# Experimental data file. It is looked up in the directory named by the
# BA_DATA_DIR environment variable; if that is unset, the bare file name is
# used (i.e. resolved relative to the current working directory).
datadir = os.environ.get('BA_DATA_DIR', '')
filename = 'RvsQ_36563_36662.txt.gz'
filepath = os.path.join(datadir, filename)

# Q-window (same units as the converted data, nm^-1) to which both the
# loaded data and the plotted simulation are restricted.
qmin, qmax = 0.18, 2.4

# Number of q points used for plotting the simulated curve.
scan_size = 1500

# SLD values held fixed during the fit for the Pt film and the Si substrate
# (presumably the real and imaginary parts, as passed to MaterialBySLD).
sldPt = (6.3568e-06, 1.8967e-09)
sldSi = (2.0728e-06, 2.3747e-11)

####################################################################
#                  Create Sample and Simulation                    #
####################################################################


def get_sample(P):
    """
    Build the vacuum / Pt / Si multilayer described by the mapping *P*.

    Keys read from P:
      "t_pt/nm" -- thickness of the Pt layer,
      "r_pt/nm" -- roughness of the interface on top of the Pt layer,
      "r_si/nm" -- roughness of the interface on top of the Si substrate.
    The SLDs come from the module-level constants sldPt and sldSi.
    """
    # Materials, from top (ambient) to bottom (substrate).
    mat_vacuum = ba.MaterialBySLD("Vacuum", 0, 0)
    mat_pt = ba.MaterialBySLD("Pt", *sldPt)
    mat_si = ba.MaterialBySLD("Si", *sldSi)

    top = ba.Layer(mat_vacuum)
    film = ba.Layer(mat_pt, P["t_pt/nm"])
    bottom = ba.Layer(mat_si)

    rough_si = ba.LayerRoughness(P["r_si/nm"])
    rough_pt = ba.LayerRoughness(P["r_pt/nm"])

    multilayer = ba.MultiLayer()
    multilayer.addLayer(top)
    multilayer.addLayerWithTopRoughness(film, rough_pt)
    multilayer.addLayerWithTopRoughness(bottom, rough_si)
    return multilayer


def get_simulation(q_axis, parameters):
    """
    Create a specular simulation of the sample on the given q points.

    q_axis: q values of the scan (passed to ba.QzScan).
    parameters: mapping containing the sample keys read by get_sample plus
        "q_offset" (passed to scan.setOffset) and "q_res/q" (width passed
        to the q-resolution setup).

    Returns a ba.SpecularSimulation (not yet run).
    """
    sample = get_sample(parameters)

    scan = ba.QzScan(q_axis)
    scan.setOffset(parameters["q_offset"])

    # Resolution kernel: a Gaussian (mean 0, std 1) discretized with
    # n_samples points; n_sig is the sampling range in standard deviations.
    n_sig = 4.0
    n_samples = 25
    distr = ba.DistributionGaussian(0., 1., n_samples, n_sig)
    # NOTE(review): the key name "q_res/q" suggests a *relative* resolution,
    # but the value is applied via setAbsoluteQResolution -- confirm intent.
    scan.setAbsoluteQResolution(distr, parameters["q_res/q"])

    return ba.SpecularSimulation(scan, sample)


def run_simulation(q_axis, fitParams, fixed_params=None):
    """
    Build the simulation for q_axis and run it.

    fitParams: mapping of parameter names to values.
    fixed_params: mapping of additional (fixed) parameters; its entries
        override fitParams on name clashes. When omitted, the module-level
        ``fixedParams`` (set by the script's main section) is used if it
        exists, otherwise an empty mapping. This lets the function also be
        used when the module is imported rather than run as a script.

    Returns the result of ``simulation.simulate()``.
    """
    if fixed_params is None:
        fixed_params = globals().get("fixedParams", {})
    parameters = dict(fitParams, **fixed_params)

    simulation = get_simulation(q_axis, parameters)
    return simulation.simulate()


def qr(result):
    """
    Extract the (q, R) curve from a simulation result as NumPy arrays.

    Both quantities are requested in ba.Coords_QSPACE coordinates: the
    first array holds the bin centers, the second the simulated values.
    """
    coords = ba.Coords_QSPACE
    centers = result.convertedBinCenters(coords)
    q_values = np.array(centers)
    reflectivity = np.array(result.array(coords))
    return q_values, reflectivity


####################################################################
#                         Plot Handling                            #
####################################################################


def plot(q, r, exp, filename, P=None):
    """
    Plot the simulated result together with the experimental data and save
    the figure to *filename*.

    q, r: simulated q values and reflectivity.
    exp: data array as returned by get_Experimental_data; exp[0] is plotted
        as x, exp[1] as y, exp[2] as y error bars and exp[3] as x error bars.
    filename: output path passed to savefig.
    P: optional mapping of parameter names to numeric values, written onto
        the plot as "name = value" annotations.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.errorbar(exp[0],
                exp[1],
                xerr=exp[3],
                yerr=exp[2],
                label="R",
                fmt='.',
                markersize=1.,
                linewidth=0.6,
                color='r')

    ax.plot(q, r, label="Simulation", color='C0', linewidth=0.5)

    ax.set_yscale('log')

    # Single superscript group so mathtext renders "nm^-1" correctly.
    ax.set_xlabel("Q [nm$^{-1}$]")
    ax.set_ylabel("R")

    if P is not None:
        # Stack annotations upward from mid-height (axes-fraction coords).
        y = 0.5
        for n, v in P.items():
            ax.text(0.7, y, f"{n} = {v:.3g}", transform=ax.transAxes)
            y += 0.05

    # Operate on this figure explicitly rather than on pyplot's "current"
    # figure, so the function is unaffected by other open figures.
    fig.tight_layout()
    fig.savefig(filename)


####################################################################
#                          Data Handling                           #
####################################################################


def get_Experimental_data(filepath, qmin, qmax, q_unit=None):
    """
    Read experimental data, remove duplicate q values, convert q to nm^-1.

    The file is read with ``np.genfromtxt(unpack=True)``, so each file
    column becomes one row of the returned 2-D array. Row 0 is q, row 1 the
    reflectivity, row 2 its uncertainty; row 3 is converted with the same
    factor as q (presumably the q resolution -- confirm against the data).

    Steps:
      * points whose q value repeats an earlier one are dropped (the first
        occurrence is kept), and the data are sorted by increasing q;
      * rows 0 and 3 are divided by *q_unit*;
      * the data are cut to the points nearest to qmin and qmax (inclusive).

    q_unit: divisor for rows 0 and 3; defaults to bornagain's ``angstrom``,
        which converts q from A^-1 to nm^-1.
    """
    if q_unit is None:
        q_unit = angstrom

    data = np.genfromtxt(filepath, unpack=True)

    # np.unique both sorts by q and keeps only the first occurrence of each
    # value. Comparing each point with its neighbour before sorting would
    # miss duplicates that are not adjacent in the file.
    _, keep = np.unique(data[0], return_index=True)
    data = data[:, keep]

    data[0] = data[0] / q_unit
    data[3] = data[3] / q_unit

    # q is sorted, so the two nearest-point indices bound a contiguous range.
    min_index = np.argmin(np.abs(data[0] - qmin))
    max_index = np.argmin(np.abs(data[0] - qmax))

    return data[:, min_index:max_index + 1]


####################################################################
#                          Fit Function                            #
####################################################################


def run_fit_ba(q_axis, r_data, r_uncertainty, simulationFactory,
               startParams):
    """
    Fit one reflectivity data set with BornAgain's chi2 objective.

    q_axis: q values of the data; passed as first argument to
        simulationFactory.
    r_data, r_uncertainty: measured reflectivity and its uncertainty.
    simulationFactory: callable ``(q_axis, params)`` that is wrapped as the
        simulation callback of ba.FitObjective. NOTE(review): BornAgain's
        FitObjective expects this callback to return a simulation object
        (it runs the simulation itself) -- make sure the factory does so.
    startParams: mapping name -> (start value, min, max).

    Returns a dict mapping each fit parameter name to its final value.
    """
    fit_objective = ba.FitObjective()
    fit_objective.setObjectiveMetric("chi2")

    # One data set, with weight 1.
    fit_objective.addSimulationAndData(
        lambda P: simulationFactory(q_axis, P), r_data,
        r_uncertainty, 1)

    fit_objective.initPrint(10)

    P = ba.Parameters()
    for name, p in startParams.items():
        P.add(name, p[0], min=p[1], max=p[2])

    minimizer = ba.Minimizer()
    result = minimizer.minimize(fit_objective.evaluate, P)
    fit_objective.finalize(result)

    return {r.name(): r.value for r in result.parameters()}


####################################################################
#                          Main Function                           #
####################################################################

if __name__ == '__main__':
    # "<script> fit" fits starting from startParams; without that argument
    # the parameters stored from an earlier fit are simulated and plotted.
    if len(sys.argv) > 1 and sys.argv[1] == "fit":
        fixedParams = {
            # parameters can be moved here to keep them fixed
        }
        # Entries use the same (value, min, max) form as startParams;
        # only the value is kept.
        fixedParams = {d: v[0] for d, v in fixedParams.items()}

        startParams = {
            # own starting values: (start, min, max)
            "q_offset": (0, -0.02, 0.02),
            "q_res/q": (0, 0, 0.02),
            "t_pt/nm": (53, 40, 60),
            "r_si/nm": (1.22, 0, 5),
            "r_pt/nm": (0.25, 0, 5),
        }
        fit = True

    else:
        startParams = {}
        fixedParams = {
            # parameters from our own fit run
            'q_offset': 0.015085985992837999,
            'q_res/q': 0.010156450689003465,
            't_pt/nm': 48.564838355355405,
            'r_si/nm': 1.2857515425763575,
            'r_pt/nm': 0.2868252673771518,
        }
        fit = False

    PInitial = {d: v[0] for d, v in startParams.items()}

    qzs = np.linspace(qmin, qmax, scan_size)
    q, r = qr(run_simulation(qzs, PInitial))
    data = get_Experimental_data(filepath, qmin, qmax)

    plot(q, r, data, "PtLayerFit_initial.pdf",
         dict(PInitial, **fixedParams))

    if fit:
        print("Start fit")

        # The fit objective needs a callable that *returns a simulation*
        # (it runs it itself). run_simulation returns an already computed
        # result, so hand over a builder that merges the fixed parameters
        # and returns the simulation object instead.
        def build_simulation(q_axis, P):
            return get_simulation(q_axis, dict(P, **fixedParams))

        fitResult = run_fit_ba(data[0], data[1], data[2], build_simulation,
                               startParams)

        print("Fit Result:")
        print(fitResult)

        q, r = qr(run_simulation(qzs, fitParams=fitResult))
        plot(q, r, data, "PtLayerFit_fit.pdf",
             dict(fitResult, **fixedParams))

    plt.show()
