#
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

function print_hdrs()
{
  # Print the fixed preamble of the generated C file:
  #   - the NVIDIA copyright banner and a "generated - DO NOT EDIT" warning;
  #   - #includes of mth_intrinsics.h and mth_tbldefs.h;
  #   - immintrin.h, guarded so the C preprocessor stops with #error for any
  #     target other than TARGET_X8664;
  #   - #include of mth_z2yy.h;
  #   - a comment describing the AVX-512 -> two AVX2 calls wrapper scheme.
  print "\
/*\n\
 *     Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.\n\
 *\n\
 * NVIDIA CORPORATION and its licensors retain all intellectual property\n\
 * and proprietary rights in and to this software, related documentation\n\
 * and any modifications thereto.  Any use, reproduction, disclosure or\n\
 * distribution of this software and related documentation without an express\n\
 * license agreement from NVIDIA CORPORATION is strictly prohibited.\n\
 *\n\
 */\n\
\n\n\
/*\n\
 *\n\
 * WARNING - this file is automatically generated. DO NOT EDIT.\n\
 *\n\
 */\n\
\n\
#include \"mth_intrinsics.h\" \n\
#include \"mth_tbldefs.h\" \n\
\n\n\
#if defined (TARGET_X8664) \n\
#include \"immintrin.h\" \n\
#else \n\
#error Unknown TARGET. Must be \"TARGET_X8664\" \n\
#endif \n\
\n\
#include \"mth_z2yy.h\"\n\
\n\
/*\n\
 * Common set of interface routines to convert an intrinsic math library call\n\
 * using Intel AVX-512 vectors in to two calls of the corresponding AVX2\n\
 * implementation.\n\
 *\n\
 * Note: code is common to both AVX-512 and KNL architectures.\n\
 *       Thus, have to use Intel intrinsics that are common to both systems.\n\
 */\n\
"
  # Disabled: because of the if (0) below, the C definitions of the
  # __gs_z2yy_* / __gd_z2yy_* split helpers are never printed.  The wrappers
  # emitted elsewhere in this script call helpers with these names;
  # presumably mth_z2yy.h (included above) now provides them.
  # NOTE(review): confirm against mth_z2yy.h and delete this dead text.
  # NOTE(review): in the *_sincos helpers, cu is passed to asm as an "m"
  # input operand although the first asm stores %ymm1 into it; if this is
  # ever revived, that operand should be an output ("=m").
if (0) {
print "\n\
static\nvrs16_t\n\
__attribute__((noinline))\n\
__gs_z2yy_x(vrs16_t x, vrs8_t(*func)(vrs8_t))\n\
{ \n\
  vrs8_t rl, ru;\n\
  ru = func((vrs8_t) _mm512_extractf64x4_pd((__m512d)x, 1));\n\
  rl = func((vrs8_t) _mm512_castps512_ps256(x));\n\
  return (vrs16_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                      (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrs16_t\n\
__attribute__((noinline))\n\
__gs_z2yy_xy(vrs16_t x, vrs16_t y, vrs8_t(*func)(vrs8_t, vrs8_t))\n\
{ \n\
  vrs8_t rl, ru;\n\
  ru = func((vrs8_t) _mm512_extractf64x4_pd((__m512d)x, 1),\n\
            (vrs8_t) _mm512_extractf64x4_pd((__m512d)y, 1));\n\
  rl = func((vrs8_t) _mm512_castps512_ps256(x),\n\
            (vrs8_t) _mm512_castps512_ps256(y));\n\
  return (vrs16_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                      (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrs16_t\n\
__attribute__((noinline))\n\
__gs_z2yy_sincos(vrs16_t x, vrs8_t(*func)(vrs8_t))\n\
{ \n\
  vrs8_t su, sl, cu;\n\
  su = func((vrs8_t) _mm512_extractf64x4_pd((__m512d)x, 1));\n\
  asm(\"vmovaps  %%ymm1, %0\" : :\"m\"(cu) :);\n\
  sl = func((vrs8_t) _mm512_castps512_ps256(x));\n\
  asm(\"vinsertf64x4 $0x1,%0,%%zmm1,%%zmm1\" : : \"m\"(cu) : );\n\
  return (vrs16_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)sl),\n\
                                      (__m256d)su, 1);\n\
}\n\
\n\
static\nvrs16_t\n\
__attribute__((noinline))\n\
__gs_z2yy_xk1(vrs16_t x, int64_t iy, vrs8_t(*func)(vrs8_t, int64_t))\n\
{\n\
  vrs8_t rl, ru;\n\
  ru = func((vrs8_t) _mm512_extractf64x4_pd((__m512d)x, 1), iy);\n\
  rl = func((vrs8_t) _mm512_castps512_ps256(x), iy);\n\
  return (vrs16_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                      (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrs16_t\n\
__attribute__((noinline))\n\
__gs_z2yy_xi(vrs16_t x, vis16_t iy, vrs8_t(*func)(vrs8_t, vis8_t))\n\
{\n\
  vrs8_t rl, ru;\n\
  ru = func((vrs8_t) _mm512_extractf64x4_pd((__m512d)x, 1),\n\
            (vis8_t) _mm512_extractf64x4_pd((__m512d)iy, 1));\n\
  rl = func((vrs8_t) _mm512_castps512_ps256(x),\n\
            (vis8_t) _mm512_castps512_ps256((__m512)iy));\n\
  return (vrs16_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                     (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrs16_t\n\
__attribute__((noinline))\n\
__gs_z2yy_xk(vrs16_t x, vid8_t iyu, vid8_t iyl, vrs8_t(*func)(vrs8_t, vid4_t, vid4_t))\n\
{\n\
  vrs8_t rl, ru;\n\
  ru = func((vrs8_t) _mm512_extractf64x4_pd((__m512d)x, 1),\n\
            (vid4_t) _mm512_extractf64x4_pd((__m512d)iyu, 1),\n\
            (vid4_t) _mm512_extractf64x4_pd((__m512d)iyu, 0));\n\
  rl = func((vrs8_t) _mm512_castps512_ps256(x),\n\
            (vid4_t) _mm512_extractf64x4_pd((__m512d)iyl, 1),\n\
            (vid4_t) _mm512_extractf64x4_pd((__m512d)iyl, 0));\n\
  return (vrs16_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                     (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrd8_t\n\
__attribute__((noinline))\n\
__gd_z2yy_x(vrd8_t x, vrd4_t(*func)(vrd4_t))\n\
{\n\
  vrd4_t rl, ru;\n\
  ru = func((vrd4_t) _mm512_extractf64x4_pd((__m512d)x, 1));\n\
  rl = func((vrd4_t) _mm512_castpd512_pd256(x));\n\
  return (vrd8_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                     (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrd8_t\n\
__attribute__((noinline))\n\
__gd_z2yy_xy(vrd8_t x, vrd8_t y, vrd4_t(*func)(vrd4_t, vrd4_t))\n\
{\n\
  vrd4_t rl, ru;\n\
  ru = func((vrd4_t) _mm512_extractf64x4_pd((__m512d)x, 1),\n\
            (vrd4_t) _mm512_extractf64x4_pd((__m512d)y, 1));\n\
  rl = func((vrd4_t) _mm512_castpd512_pd256(x),\n\
            (vrd4_t) _mm512_castpd512_pd256(y));\n\
  return (vrd8_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                     (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrd8_t\n\
__attribute__((noinline))\n\
__gd_z2yy_sincos(vrd8_t x, vrd4_t(*func)(vrd4_t))\n\
{ \n\
  vrd4_t su, sl, cu;\n\
  su = func((vrd4_t) _mm512_extractf64x4_pd((__m512d)x, 1));\n\
  asm(\"vmovaps  %%ymm1, %0\" : :\"m\"(cu) :);\n\
  sl = func((vrd4_t) _mm512_castpd512_pd256(x));\n\
  asm(\"vinsertf64x4 $0x1,%0,%%zmm1,%%zmm1\" : : \"m\"(cu) : );\n\
  return (vrd8_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)sl),\n\
                                      (__m256d)su, 1);\n\
}\n\
\n\
static\nvrd8_t\n\
__attribute__((noinline))\n\
__gd_z2yy_xk1(vrd8_t x, int64_t iy, vrd4_t(*func)(vrd4_t, int64_t))\n\
{\n\
  vrd4_t rl, ru;\n\
  ru = func((vrd4_t) _mm512_extractf64x4_pd((__m512d)x, 1), iy);\n\
  rl = func((vrd4_t) _mm512_castpd512_pd256(x), iy);\n\
  return (vrd8_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                     (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrd8_t\n\
__attribute__((noinline))\n\
__gd_z2yy_xk(vrd8_t x, vid8_t iy, vrd4_t(*func)(vrd4_t, vid4_t))\n\
{\n\
  vrd4_t rl, ru;\n\
  ru = func((vrd4_t) _mm512_extractf64x4_pd((__m512d)x, 1),\n\
            (vid4_t) _mm512_extractf64x4_pd((__m512d)iy, 1));\n\
  rl = func((vrd4_t) _mm512_castpd512_pd256(x),\n\
            (vid4_t) _mm512_castpd512_pd256((__m512d)iy));\n\
  return (vrd8_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                     (__m256d)ru, 1);\n\
}\n\
\n\
static\nvrd8_t\n\
__attribute__((noinline))\n\
__gd_z2yy_xi(vrd8_t x, vis8_t iy, vrd4_t(*func)(vrd4_t, vis4_t))\n\
{\n\
  vrd4_t rl, ru;\n\
  ru = func((vrd4_t) _mm512_extractf64x4_pd((__m512d)x, 1),\n\
            (vis4_t) _mm256_extractf128_si256((__m256i)iy, 1));\n\
  rl = func((vrd4_t) _mm512_castpd512_pd256(x),\n\
            (vis4_t) _mm256_castsi256_si128((__m256i)iy));\n\
  return (vrd8_t) _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl),\n\
                                     (__m256d)ru, 1);\n\
}\n\
\n\
"
} # if (0)
}

function init_target_arrays()
{
  # Spelling fragments for 512-bit intrinsics and types.
  # NOTE(review): _mm, __m and _si are not read anywhere else in this script.
  _mm = "_mm512"
  __m = "__m512"
  _si = "_si512"

  # Key sets that drive the generation loops; only the keys are used,
  # the values are always empty strings.
  # Variant prefixes that select the frp_* column of MTH_DISPATCH_TBL.
  frps["f"] = ""
  frps["r"] = ""
  frps["p"] = ""
  # Precisions: single and double.
  sds["s"] = ""
  sds["d"] = ""
  # Integer exponent widths for pow: "i" -> int32_t, "k" -> int64_t.
  iks["i"] = ""
  iks["k"] = ""
}

function VL(sd)
{
  # Lane count of a full-width vector for precision sd ("s" or anything
  # else, which is treated as "d").
  if (sd == "s")
    return VLS
  return VLD
}

function VL_BY2(sd)
{
  # Lane count of the half-width vector used by the underlying routine:
  # half of VLS for "s", half of VLD otherwise.
  if (sd == "s")
    return VLS / 2
  return VLD / 2
}

function VR_T(sd)
{
  # Full-width real vector type name: "vr" + precision + lane count + "_t".
  if (sd == "s")
    return "vr" sd VLS "_t"
  return "vr" sd VLD "_t"
}

function VI_T(sd)
{
  # Full-width integer vector type name: "vi" + precision + lane count + "_t".
  if (sd == "s")
    return "vi" sd VLS "_t"
  return "vi" sd VLD "_t"
}

function VR_T_BY2(sd)
{
  # Half-width real vector type name, built from the VL_BY2 lane count.
  return "vr" sd VL_BY2(sd) "_t"
}

function VI_T_BY2(sd)
{
  # Half-width integer vector type name, built from the VL_BY2 lane count.
  return "vi" sd VL_BY2(sd) "_t"
}

function arg_ne_0(yarg, a, b)
{
  # Select a when yarg is non-zero, otherwise b (a string-valued ternary).
  if (yarg != 0)
    return a
  return b
}

function func_r_decl(name, frp, sd, yarg)
{
  # Print the return type and the prototype line of the 512-bit wrapper
  # __<frp><sd>_<name>_<VL>_z2yy(x[, y]); y is present when yarg != 0.
  print "\n" VR_T(sd)
  print "__" frp sd "_" name "_" VL(sd) "_z2yy(" VR_T(sd) " x" \
        (yarg != 0 ? ", " VR_T(sd) " y" : "") ")"
}

function func_rr_def(name, frp, sd, safeval, yarg,    half, proto, helper)
{
  # Emit the 512-bit wrapper for real intrinsic NAME (one argument, or two
  # when yarg != 0).  The wrapper fetches the half-width implementation
  # from MTH_DISPATCH_TBL and passes it to the matching __g<sd>_z2yy_*
  # helper, which performs the two half-width calls.
  # safeval is accepted for interface compatibility; it is not used here.
  # half, proto and helper are locals (extra awk parameters).
  half = VR_T_BY2(sd)
  # Parameter list of the half-width routine: "T" or "T, T".  Built by
  # concatenation so the output does not depend on OFS.
  proto = half arg_ne_0(yarg, ", " half, "")
  # sincos has a dedicated helper; everything else uses _x or _xy.
  helper = "__g" sd "_z2yy_" (name != "sincos" ? "x" : name) arg_ne_0(yarg, "y", "")

  func_r_decl(name, frp, sd, yarg)
  print "{"
  print "  " half " (*fptr) (" proto ");"
  print "  fptr = (" half "(*) (" proto ")) MTH_DISPATCH_TBL[func_" name \
        "][sv_" sd "v" VL_BY2(sd) "][frp_" frp "];"
  print "  return " helper "(x" arg_ne_0(yarg, ", y", "") ", fptr);"
  print "}"
}

function func_pow_args(sd, is_scalar, ik, with_vars, by2,    ll)
{
  # Build the C parameter list of a real**integer routine, e.g.
  # "vrd8_t x, int64_t iy", or with by2 != 0 and with_vars == 0
  # "vrs8_t, vid4_t, vid4_t".
  #   sd        - "s" or "d": precision of x
  #   is_scalar - non-zero: exponent is a scalar int32_t ("i") / int64_t
  #   ik        - "i" or "k": integer exponent width
  #   with_vars - non-zero: include parameter names (x, iy, iyu, iyl)
  #   by2       - non-zero: use half-width vector types
  # ll is a local (extra awk parameter) so it does not leak as a global.
  ll = arg_ne_0(by2, VR_T_BY2(sd), VR_T(sd)) arg_ne_0(with_vars, " x", "") ", "
  if (is_scalar) {
    ll = ll ((ik == "i") ? "int32_t" : "int64_t") arg_ne_0(with_vars, " iy", "")
  } else if (sd == "s" && ik == "k") {
    # Single precision with 64-bit exponents: the exponents occupy twice the
    # width of x, so they are passed as two double-width integer vectors.
    ll = ll arg_ne_0(by2, VI_T_BY2("d"), VI_T("d")) \
         arg_ne_0(with_vars, " iyu", "") ", " \
         arg_ne_0(by2, VI_T_BY2("d"), VI_T("d")) \
         arg_ne_0(with_vars, " iyl", "")
  } else {
    ll = ll arg_ne_0(by2, VI_T_BY2(sd), VI_T(sd)) \
         arg_ne_0(with_vars, " iy", "")
  }

  return ll
}

function func_pow_decl(name, frp, sd, is_scalar, ik,    l)
{
  # Print the return type and prototype line of a 512-bit real**integer
  # wrapper, e.g. "__fd_powk1_8_z2yy(vrd8_t x, int64_t iy)".  The scalar
  # form gets a "1" suffix after ik.  Only XXfunc_pow_def calls this.
  # l is a local (extra awk parameter) so it does not leak as a global.
  print "\n" VR_T(sd)
  l = "__" frp sd "_" name arg_ne_0(is_scalar, ik "1", ik) "_" VL(sd) "_z2yy("
  l = l func_pow_args(sd, is_scalar, ik, 1, 0) ")"
  print l
}

function func_pow_decl_scalar(name, frp, sd, ik,    half)
{
  #
  # Emit the 512-bit wrapper for x(:)**iy with a scalar integer exponent:
  # "__<frp><sd>_<name><ik>1_<VL>_z2yy(x, iy)", where iy is int32_t for
  # ik == "i" and int64_t for ik == "k".
  #
  # Warning: both exponent widths dispatch to the <name>k1 (int64_t
  # exponent) entry point, so an int32_t iy is promoted to int64_t when it
  # is passed to the xk1 helper.  This assumes the int32_t and int64_t
  # power routines give bit-for-bit identical results for every exponent
  # representable as int32_t.
  #
  # half is a local (extra awk parameter).
  half = VR_T_BY2(sd)
  print VR_T(sd)
  print "__" frp sd "_" name ik "1_" VL(sd) "_z2yy(" VR_T(sd) " x, " \
        (ik == "i" ? "int32_t" : "int64_t") " iy)"
  print "{"
  print "  " half " (*fptr)(" half ", int64_t);"
  print "  fptr = (" half "(*) (" half ", int64_t)) MTH_DISPATCH_TBL[func_" name \
        "k1][sv_" sd "v" VL_BY2(sd) "][frp_" frp "];"
  print "  return __g" sd "_z2yy_xk1(x, iy, fptr);\n}\n"
}

function func_pow_decl_vect(name, frp, sd, ik,    half, ihalf)
{
  # Emit the 512-bit wrapper for x(:)**iy(:) when x and iy have the same
  # vector width, i.e. double with 64-bit ("k") or single with 32-bit ("i")
  # exponents.  The wrapper name uses `name`, matching the dispatch index.
  # half and ihalf are locals (extra awk parameters).
  half = VR_T_BY2(sd)
  ihalf = VI_T_BY2(sd)
  print VR_T(sd)
  print "__" frp sd "_" name ik "_" VL(sd) "_z2yy(" VR_T(sd) " x, " \
        VI_T(sd) " iy)"
  print "{"
  print "  " half " (*fptr)(" half ", " ihalf ");"
  print "  fptr = (" half "(*) (" half ", " ihalf ")) MTH_DISPATCH_TBL[func_" name ik \
        "][sv_" sd "v" VL_BY2(sd) "][frp_" frp "];"
  print "  return __g" sd "_z2yy_x" ik "(x, iy, fptr);\n}\n"
}

function func_pow_decl_vect_di(name, frp,    rhalf, ihalf)
{
  # Emit the 512-bit wrapper for double x(:)**iy(:) with 32-bit exponents.
  # x is a full-width double vector, but iy holds only VL("d") int32 lanes
  # and is declared with the half-width single-precision integer type
  # (VI_T_BY2("s")).  Each half-width call therefore takes VL_BY2("d")
  # int32 exponents: type "vis<VL_BY2(d)>_t", which none of the
  # VI_T helpers produce, so it is spelled out here.
  # rhalf and ihalf are locals (extra awk parameters).
  rhalf = VR_T_BY2("d")
  ihalf = "vis" VL_BY2("d") "_t"
  print VR_T("d")
  print "__" frp "d_" name "i_" VL("d") "_z2yy(" VR_T("d") " x, " \
        VI_T_BY2("s") " iy)"
  print "{"
  print "  " rhalf " (*fptr)(" rhalf ", " ihalf ");"
  print "  fptr = (" rhalf "(*) (" rhalf ", " ihalf ")) MTH_DISPATCH_TBL[func_" name \
        "i][sv_dv" VL_BY2("d") "][frp_" frp "];"
  print "  return __gd_z2yy_xi(x, iy, fptr);\n}\n"
}

function func_pow_decl_vect_sk(name, frp,    rhalf, ihalf)
{
  # Emit the 512-bit wrapper for single x(:)**iy(:) with 64-bit exponents.
  # The VL("s") int64 exponents need two full-width double-precision
  # integer vectors, iyu and iyl.  The half-width routine takes one
  # half-width single vector plus two half-width int64 vectors.
  #
  # Fix: this previously indexed the dispatch table with sv_dv4 (the
  # double-precision 4-wide column) although fptr is cast to a
  # single-precision (vrs8_t, vid4_t, vid4_t) signature.  Every other path,
  # including XXfunc_pow_def for s/k, uses sv_s v<VL_BY2("s")>.
  # rhalf and ihalf are locals (extra awk parameters).
  rhalf = VR_T_BY2("s")
  ihalf = VI_T_BY2("d")
  print VR_T("s")
  print "__" frp "s_" name "k_" VL("s") "_z2yy(" VR_T("s") " x, " \
        VI_T("d") " iyu, " VI_T("d") " iyl)"
  print "{"
  print "  " rhalf " (*fptr)(" rhalf ", " ihalf ", " ihalf ");"
  print "  fptr = (" rhalf "(*) (" rhalf ", " ihalf ", " ihalf ")) MTH_DISPATCH_TBL[func_" name \
        "k][sv_sv" VL_BY2("s") "][frp_" frp "];"
  print "  return __gs_z2yy_xk(x, iyu, iyl, fptr);\n}\n"
}

function func_pow_def(name, frp, sd, is_scalar, ik)
{
  # Emit one real**integer wrapper.  A scalar exponent has a single form.
  # A vector exponent R(:)**I(:) has four variants, depending on how the
  # widths of x and iy relate:
  #   d/k, s/i - x and iy are both 512 bits wide (the generic case)
  #   d/i      - x is 512 bits, iy is effectively 256 bits
  #   s/k      - x is 512 bits, iy is effectively 1024 bits (iyu + iyl)
  if (is_scalar) {
    func_pow_decl_scalar(name, frp, sd, ik)
    return
  }

  if ((sd == "d" && ik == "k") || (sd == "s" && ik == "i"))
    func_pow_decl_vect(name, frp, sd, ik)
  else if (sd == "d" && ik == "i")
    func_pow_decl_vect_di(name, frp)
  else
    func_pow_decl_vect_sk(name, frp)
}
function XXfunc_pow_def(name, frp, sd, is_scalar, ik)
{
  # Earlier variant of func_pow_def that expands the upper/lower half calls
  # inline instead of delegating to the __g<sd>_z2yy_* helpers.  Nothing in
  # this script calls it; do_all_pow_r2i uses func_pow_def instead.
  # NOTE(review): for sd == "s" && ik == "k", the ru call below passes the
  # upper halves of both iyu and iyl.  __gs_z2yy_xk (in the disabled block
  # of print_hdrs) instead passes both halves of iyu with the upper x half.
  # Confirm which layout is intended before reviving this code.
  func_pow_decl(name, frp, sd, is_scalar, ik)
  print "{"

  # Scalar exponent: always dispatch to the int64_t ("k1") signature.
  if (is_scalar) {
    print "  " VR_T_BY2(sd) " (*fptr)(" VR_T_BY2(sd) ", int64_t);"
    print "  fptr = (" VR_T_BY2(sd) "(*) (" \
          func_pow_args(sd, is_scalar, "k", 0, 1) \
          ")) MTH_DISPATCH_TBL[func_" name arg_ne_0(is_scalar, ik"1", ik) \
          "][sv_" sd "v" VL_BY2(sd) "][frp_" frp "];"
    print "  return __g" sd "_z2yy_xk1(x, iy, fptr);\n}\n"
    return
  }

#
# This has turned into a bit of a mess/hack.
# From here on is_scalar is always 0 (the scalar case returned above), so the
# "if (is_scalar)" branches below are never taken.
#

  print "  " VR_T_BY2(sd) " rl, ru;\n  " \
        VR_T_BY2(sd) " (*fptr) (" func_pow_args(sd, is_scalar, ik, 0, 1) ");"
  print "  fptr = (" VR_T_BY2(sd) "(*) (" \
        func_pow_args(sd, is_scalar, ik, 0, 1) \
        ")) MTH_DISPATCH_TBL[func_" name arg_ne_0(is_scalar, ik"1", ik) \
        "][sv_" sd "v" VL_BY2(sd) "][frp_" frp "];"

  # Upper-half call: x's upper 256 bits plus the matching exponent halves.
  printf ("%s", "  ru = fptr((" VR_T_BY2(sd) ") _mm512_extractf64x4_pd((__m512d)x, 1), ")
  if (is_scalar) {
    printf ("%s", "iy")
  } else {
    if (sd == "s" && ik == "k") {
      printf ("%s", "(" VI_T_BY2("d") \
             ") _mm512_extractf64x4_pd((__m512d)iyu, 1) , ("\
             VI_T_BY2("d") ") _mm512_extractf64x4_pd((__m512d)iyl, 1)")
    } else {
      printf ("%s", "(" VI_T_BY2(sd) \
              ") _mm512_extractf64x4_pd((__m512d) iy, 1)")
    }
  }
  print ");"

  # Lower-half call: the low 256 bits of x (a cast, not an extract) and of
  # the exponent vector(s).
  printf ("%s", "  rl = fptr((" VR_T_BY2(sd) ") _mm512_castp" sd "512_p" sd "256(x), ")
  if (is_scalar) {
    printf ("%s", "iy")
  } else {
    if (sd == "s" && ik == "k") {
      printf("%s", "(" VI_T_BY2("d") \
             ") _mm512_castps512_ps256((__m512)iyu) , ("\
             VI_T_BY2("d") ") _mm512_castps512_ps256((__m512)iyl)")
    } else {
    printf ("%s", "(" VI_T_BY2(sd) \
            ") _mm512_castps512_ps256((__m512)iy)")
    }
  }
  print ");"
  # Recombine: rl becomes the low 256 bits, ru is inserted as the high half.
  print "  return (" VR_T(sd) ") _mm512_insertf64x4(_mm512_castpd256_pd512((__m256d)rl), (__m256d)ru, 1);"
  print "}"
}

function do_all_rr(name, safeval, yarg,    frp, sd)
{
  # Emit the wrapper for real intrinsic NAME for every variant prefix
  # (keys of frps) and precision (keys of sds).  yarg != 0 selects the
  # two-argument form.  safeval is forwarded unchanged to func_rr_def.
  # frp and sd are locals (extra awk parameters), so the loop variables
  # no longer leak into, or depend on, global state.
  for (frp in frps) {
    for (sd in sds) {
      func_rr_def(name, frp, sd, safeval, yarg)
    }
  }
}

function do_all_pow_r2i(    frp, sd, ik)
{
  # Emit every real**integer "pow" wrapper.  For each variant prefix,
  # precision and exponent width, it emits one scalar-exponent and one
  # vector-exponent entry point.
  # frp, sd and ik are locals (extra awk parameters), so the loop variables
  # do not leak into global state.
  for (frp in frps) {
    for (sd in sds) {
      for (ik in iks) {
        func_pow_def("pow", frp, sd, 1, ik)
        func_pow_def("pow", frp, sd, 0, ik)
      }
    }
  }
}

BEGIN {
  # Only the x86-64 (AVX-512) target is supported.  The diagnostic goes to
  # stderr: stdout carries the generated C source, and an error text there
  # would silently end up in the generated file.
  if (TARGET != "X8664") {
    print "TARGET must be X8664" > "/dev/stderr"
    exit(1)
  }

  # Vector geometry: 512-bit registers hold VLS floats or VLD doubles.
  MAX_VREG_SIZE = 512
  VLS = 16
  VLD = 8

  # Initialize the associative arrays that drive the generation loops.
  init_target_arrays()

  print_hdrs()

  # Values for the yarg argument of do_all_rr.
  one_arg = 0
  two_args = 1

  do_all_rr("acos", 0, one_arg)
  do_all_rr("asin", 0, one_arg)
  do_all_rr("atan", 0, one_arg)
  do_all_rr("atan2", 1, two_args)
  do_all_rr("cos", 0, one_arg)
  do_all_rr("sin", 0, one_arg)
  do_all_rr("tan", 0, one_arg)
  do_all_rr("sincos", 0, one_arg)
  do_all_rr("cosh", 0, one_arg)
  do_all_rr("sinh", 0, one_arg)
  do_all_rr("tanh", 0, one_arg)
  do_all_rr("exp", 0, one_arg)
  do_all_rr("log", 1, one_arg)
  do_all_rr("log10", 1, one_arg)
  do_all_rr("pow", 0, two_args)
  do_all_rr("mod", 1, two_args)
  do_all_rr("aint", 0, one_arg)
  do_all_rr("ceil", 0, one_arg)
  do_all_rr("floor", 0, one_arg)
  # Intentionally not generated (not used):
  #   do_all_rr("div", 1, two_args)
  #   do_all_rr("sqrt", 0, one_arg)

  # Real raised to integer power, scalar and vector exponents.
  do_all_pow_r2i()
}
