/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

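/*
   This module contains the following operators:

      Math       abs             Absolute value
      Math       int             Integer value (truncated toward zero)
      Math       nint            Nearest integer value
      Math       sqr             Square
      Math       sqrt            Square root
      Math       exp             Exponential
      Math       ln              Natural logarithm
      Math       log10           Base 10 logarithm
      Math       sin             Sine
      Math       cos             Cosine
      Math       tan             Tangent
      Math       asin            Arc sine
      Math       acos            Arc cosine
      Math       atan            Arc tangent
      Math       pow             Power
      Math       rand            Replace values with random numbers in [0, 1]
      Math       reci            Reciprocal
      Math       not             Logical NOT (1 where the value is 0, else 0)
      Math       conj            Complex conjugate
      Math       re              Real part
      Math       im              Imaginary part
      Math       arg             Argument (phase angle) of a complex number
*/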

#include <cmath>
#include <complex>
#include <cstdlib>
#include <cdi.h>

#include "cdo_options.h"
#include "process_int.h"
#include "param_conversion.h"

/*
  Replace every value of array[0..len) that lies outside [rmin, rmax] by missval,
  incrementing nmiss once per replaced value.

  If nmiss is non-zero on entry, values already equal to missval are skipped
  (they are existing missing values). If nmiss is zero on entry, every value is
  tested. The decision is taken once, before nmiss starts to grow.
*/
static void
check_out_of_range(size_t &nmiss, const size_t len, double missval, Varray<double> &array, double rmin, double rmax)
{
  const bool skipMissing = (nmiss != 0);

  for (size_t k = 0; k < len; ++k)
    {
      const double val = array[k];
      if (skipMissing && DBL_IS_EQUAL(val, missval)) continue;

      const bool outside = val < rmin || val > rmax;
      if (outside)
        {
          array[k] = missval;
          ++nmiss;
        }
    }
}

/*
  Replace every value of array[0..len) that is strictly below rmin by missval,
  incrementing nmiss once per replaced value.

  If nmiss is non-zero on entry, values already equal to missval are left alone.
  If nmiss is zero on entry, every value is tested. The decision is taken once,
  before nmiss starts to grow.
*/
static void
check_lower_range(size_t &nmiss, const size_t len, double missval, Varray<double> &array, double rmin)
{
  const bool skipMissing = (nmiss != 0);

  for (size_t k = 0; k < len; ++k)
    {
      const double val = array[k];
      if (skipMissing && DBL_IS_EQUAL(val, missval)) continue;

      const bool below = val < rmin;
      if (below)
        {
          array[k] = missval;
          ++nmiss;
        }
    }
}

// Scalar kernels passed to math_varray_func(); each maps one non-missing value.
// clang-format off
static double func_nop(double value) { return value; }
// Truncate toward zero. std::trunc is used instead of (int)value because converting a
// NaN or out-of-int-range double (|value| >= 2^31) to int is undefined behaviour.
static double func_int(double value) { return std::trunc(value); }
static double func_sqr(double value) { return value * value; }
static double func_reci(double value) { return 1. / value; }
// Logical NOT: 1 if value equals 0, else 0.
static double func_not(double value) { return IS_EQUAL(value, 0); }
// clang-format on

/*
  Apply the scalar function `func` element-wise: array2[i] = func(array1[i]) for i < len.
  When nmiss is non-zero, input values equal to missval1 are copied through as missval1
  instead of being passed to `func`; when nmiss is zero, every value goes through `func`.
*/
static void
math_varray_func(double (*func)(double), const size_t nmiss, const size_t len, double missval1, const Varray<double> &array1,
                 Varray<double> &array2)
{
  const bool hasMissing = (nmiss > 0);

  for (size_t k = 0; k < len; ++k)
    {
      const double x = array1[k];
      array2[k] = (hasMissing && DBL_IS_EQUAL(x, missval1)) ? missval1 : func(x);
    }
}

/*
  "Square" of complex values stored as interleaved (re, im) pairs.
  Each pair (array1[2k], array1[2k+1]) becomes (re*re + im*im, 0) in array2,
  i.e. z * conj(z) = |z|^2, not z * z.
  NOTE(review): confirm that |z|^2 is the intended meaning of "sqr" for complex data.
  Missing values are not checked (no missval is passed in).
*/
static void
math_varray_sqr_cplx(const size_t len, const Varray<double> &array1, Varray<double> &array2)
{
  for (size_t k = 0; k < len; ++k)
    {
      const size_t ire = 2 * k;
      const double re = array1[ire];
      const double im = array1[ire + 1];
      array2[ire] = re * re + im * im;
      array2[ire + 1] = 0;
    }
}

static void
math_varray_sqrt_cplx(const size_t len, double missval1, const Varray<double> &array1, Varray<double> &array2)
{
  double missval2 = missval1;
  for (size_t i = 0; i < len; i++)
    {
      double abs = SQRTMN(ADDMN(MULMN(array1[2 * i], array1[2 * i]), MULMN(array1[2 * i + 1], array1[2 * i + 1])));
      array2[i * 2] = MULMN(1 / std::sqrt(2.), SQRTMN(ADDMN(array1[i * 2], abs)));
      array2[i * 2 + 1] = MULMN(1 / std::sqrt(2.), DIVMN(array1[2 * i + 1], SQRTMN(ADDMN(array1[2 * i], abs))));
    }
}

/*
  Complex conjugate of interleaved (re, im) pairs: the real part is copied and the
  imaginary part negated, for len pairs (2*len doubles).
  NOTE(review): no missval is passed in, so a missing imaginary part becomes -missval,
  which no longer compares equal to missval (unless missval is 0).
*/
static void
math_varray_conj_cplx(const size_t len, const Varray<double> &array1, Varray<double> &array2)
{
  const size_t nvals = 2 * len;
  for (size_t j = 0; j < nvals; j += 2)
    {
      array2[j] = array1[j];
      array2[j + 1] = -array1[j + 1];
    }
}

/*
  Modulus |z| = sqrt(re^2 + im^2) of interleaved (re, im) pairs.
  array2 receives one real value per pair (len values).
  The computation goes through the project's *MN missing-value-aware macros.
*/
static void
math_varray_abs_cplx(const size_t len, double missval1, const Varray<double> &array1, Varray<double> &array2)
{
  // Presumably referenced by the *MN macros alongside missval1 -- confirm in their definitions.
  const double missval2 = missval1;

  for (size_t k = 0; k < len; ++k)
    {
      const double re = array1[2 * k];
      const double im = array1[2 * k + 1];
      const double normSq = ADDMN(MULMN(re, re), MULMN(im, im));
      array2[k] = SQRTMN(normSq);
    }
}

/*
  Argument (phase angle) atan2(im, re) of interleaved (re, im) pairs.
  array2 receives one real value per pair (len values). The result is missval1
  when either component equals missval1.
*/
static void
math_varray_arg_cplx(const size_t len, double missval1, const Varray<double> &array1, Varray<double> &array2)
{
  for (size_t k = 0; k < len; ++k)
    {
      const double re = array1[2 * k];
      const double im = array1[2 * k + 1];
      const bool isMissing = DBL_IS_EQUAL(re, missval1) || DBL_IS_EQUAL(im, missval1);

      if (isMissing)
        array2[k] = missval1;
      else
        array2[k] = atan2(im, re);
    }
}

// Operator codes. In this file each value is stored as the operator's f1 field via
// cdoOperatorAdd() and recovered in Math() with cdoOperatorF1(). Order and values are
// kept exactly as listed; the line grouping is only for readability.
enum struct Oper
{
  Abs, Int, Nint, Sqr, Sqrt, Exp, Ln, Log10,
  Sin, Cos, Tan, Asin, Acos, Atan,
  Pow, Rand, Reci, Not,
  Conj, Re, Im, Arg
};

static void
addOperators(void)
{
  // clang-format off
  cdoOperatorAdd("abs",   (int)Oper::Abs,   0, nullptr);
  cdoOperatorAdd("int",   (int)Oper::Int,   0, nullptr);
  cdoOperatorAdd("nint",  (int)Oper::Nint,  0, nullptr);
  cdoOperatorAdd("sqr",   (int)Oper::Sqr,   0, nullptr);
  cdoOperatorAdd("sqrt",  (int)Oper::Sqrt,  0, nullptr);
  cdoOperatorAdd("exp",   (int)Oper::Exp,   0, nullptr);
  cdoOperatorAdd("ln",    (int)Oper::Ln,    0, nullptr);
  cdoOperatorAdd("log10", (int)Oper::Log10, 0, nullptr);
  cdoOperatorAdd("sin",   (int)Oper::Sin,   0, nullptr);
  cdoOperatorAdd("cos",   (int)Oper::Cos,   0, nullptr);
  cdoOperatorAdd("tan",   (int)Oper::Tan,   0, nullptr);
  cdoOperatorAdd("asin",  (int)Oper::Asin,  0, nullptr);
  cdoOperatorAdd("acos",  (int)Oper::Acos,  0, nullptr);
  cdoOperatorAdd("atan",  (int)Oper::Atan,  0, nullptr);
  cdoOperatorAdd("pow",   (int)Oper::Pow,   0, nullptr);
  cdoOperatorAdd("rand",  (int)Oper::Rand,  0, nullptr);
  cdoOperatorAdd("reci",  (int)Oper::Reci,  0, nullptr);
  cdoOperatorAdd("not",   (int)Oper::Not,   0, nullptr);
  cdoOperatorAdd("conj",  (int)Oper::Conj,  0, nullptr);
  cdoOperatorAdd("re",    (int)Oper::Re,    0, nullptr);
  cdoOperatorAdd("im",    (int)Oper::Im,    0, nullptr);
  cdoOperatorAdd("arg",   (int)Oper::Arg,   0, nullptr);
  // clang-format on
}

void *
Math(void *process)
{
  int nrecs;
  size_t nmiss;
  size_t i;

  cdoInitialize(process);

  addOperators();

  const auto operatorID = cdoOperatorID();
  const auto operfunc = (Oper) cdoOperatorF1(operatorID);

  double rc = 0;
  if (operfunc == Oper::Pow)
    {
      operatorInputArg("value");
      rc = parameter2double(cdoOperatorArgv(0));
    }
  else
    {
      operatorCheckArgc(0);
    }

  if (operfunc == Oper::Rand) std::srand(Options::Random_Seed);

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);
  const auto vlistID2 = vlistDuplicate(vlistID1);

  if (operfunc == Oper::Re || operfunc == Oper::Im || operfunc == Oper::Abs || operfunc == Oper::Arg)
    {
      const auto nvars = vlistNvars(vlistID2);
      for (int varID = 0; varID < nvars; ++varID)
        {
          if (vlistInqVarDatatype(vlistID2, varID) == CDI_DATATYPE_CPX32) vlistDefVarDatatype(vlistID2, varID, CDI_DATATYPE_FLT32);
          if (vlistInqVarDatatype(vlistID2, varID) == CDI_DATATYPE_CPX64) vlistDefVarDatatype(vlistID2, varID, CDI_DATATYPE_FLT64);
        }
    }

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);

  auto gridsizemax = vlistGridsizeMax(vlistID1);
  if (vlistNumber(vlistID1) != CDI_REAL) gridsizemax *= 2;

  Varray<double> array1(gridsizemax), array2(gridsizemax);

  const auto streamID2 = cdoOpenWrite(1);
  cdoDefVlist(streamID2, vlistID2);

  VarList varList1;
  varListInit(varList1, vlistID1);

  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      taxisCopyTimestep(taxisID2, taxisID1);
      cdoDefTimestep(streamID2, tsID);

      for (int recID = 0; recID < nrecs; recID++)
        {
          int varID, levelID;
          cdoInqRecord(streamID1, &varID, &levelID);
          cdoReadRecord(streamID1, &array1[0], &nmiss);

          const auto missval1 = varList1[varID].missval;
          const auto n = varList1[varID].gridsize;
          const auto number = varList1[varID].nwpv;

          if (number == CDI_REAL)
            {
              // clang-format off
              switch (operfunc)
                {
                case Oper::Abs:   math_varray_func(std::fabs, nmiss, n, missval1, array1, array2); break;
                case Oper::Int:   math_varray_func(func_int, nmiss, n, missval1, array1, array2); break;
                case Oper::Nint:  math_varray_func(std::round, nmiss, n, missval1, array1, array2); break;
                case Oper::Sqr:   math_varray_func(func_sqr, nmiss, n, missval1, array1, array2); break;
                case Oper::Sqrt:
                  for (i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : SQRTMN(array1[i]);
                  break;
                case Oper::Exp:   math_varray_func(std::exp, nmiss, n, missval1, array1, array2); break;
                case Oper::Ln:    check_lower_range(nmiss, n, missval1, array1, -1);
                                  math_varray_func(std::log, nmiss, n, missval1, array1, array2); break;
                case Oper::Log10: check_lower_range(nmiss, n, missval1, array1, -1);
                                  math_varray_func(std::log10, nmiss, n, missval1, array1, array2); break;
                case Oper::Sin:   math_varray_func(std::sin, nmiss, n, missval1, array1, array2); break;
                case Oper::Cos:   math_varray_func(std::cos, nmiss, n, missval1, array1, array2); break;
                case Oper::Tan:   math_varray_func(std::tan, nmiss, n, missval1, array1, array2); break;
                case Oper::Asin:  check_out_of_range(nmiss, n, missval1, array1, -1, 1);
                                  math_varray_func(std::asin, nmiss, n, missval1, array1, array2); break;
                case Oper::Acos:  check_out_of_range(nmiss, n, missval1, array1, -1, 1);
                                  math_varray_func(std::acos, nmiss, n, missval1, array1, array2); break;
                case Oper::Atan:  math_varray_func(std::atan, nmiss, n, missval1, array1, array2); break;
                case Oper::Pow:
                  for (i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : std::pow(array1[i], rc);
                  break;
                case Oper::Rand:
                  for (i = 0; i < n; i++) array2[i] = DBL_IS_EQUAL(array1[i], missval1) ? missval1 : ((double) std::rand()) / ((double) RAND_MAX);
                  break;
                case Oper::Reci:  math_varray_func(func_reci, nmiss, n, missval1, array1, array2); break;
                case Oper::Not:   math_varray_func(func_not, nmiss, n, missval1, array1, array2); break;
                case Oper::Re:
                case Oper::Arg:   math_varray_func(func_nop, nmiss, n, missval1, array1, array2); break;
                default: cdoAbort("Operator not implemented for real data!"); break;
                }
              // clang-format on

              nmiss = varrayNumMV(n, array2, missval1);
            }
          else
            {
              // clang-format off
              switch (operfunc)
                {
                case Oper::Sqr:   math_varray_sqr_cplx(n, array1, array2); break;
                case Oper::Sqrt:  math_varray_sqrt_cplx(n, missval1, array1, array2); break;
                case Oper::Conj:  math_varray_conj_cplx(n, array1, array2); break;
                case Oper::Re:    for (i = 0; i < n; i++) array2[i] = array1[i * 2]; break;
                case Oper::Im:    for (i = 0; i < n; i++) array2[i] = array1[i * 2 + 1]; break;
                case Oper::Abs:   math_varray_abs_cplx(n, missval1, array1, array2); break;
                case Oper::Arg:   math_varray_arg_cplx(n, missval1, array1, array2); break;
                default: cdoAbort("Fields with complex numbers are not supported by this operator!"); break;
                }
              // clang-format on

              nmiss = 0;
            }

          cdoDefRecord(streamID2, varID, levelID);
          cdoWriteRecord(streamID2, array2.data(), nmiss);
        }

      tsID++;
    }

  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  vlistDestroy(vlistID2);

  cdoFinish();

  return nullptr;
}
