--- lib/system/scrypt.lua
-- Password hashing with the scrypt key-derivation function.
-- https://github.com/bungle/lua-resty-scrypt
-- @module scrypt
-- Produces and verifies one-way scrypt password hashes (hashing, not encryption).
--
local ffi = require "mffi"
local bit = require "bit"
local util = require "util"
local scrypt = util.loadDll("scrypt")
local print = util.print
local crypto

local band = bit.band
local rshift = bit.rshift
local hex = ffi.cstrAnchor("0123456789abcdef")

-- libscrypt entry points. crypt() calls libscrypt_scrypt through the `scrypt`
-- library handle when ffi.arch is "arm64"; libscrypt_check is declared but
-- not called anywhere in this file.
ffi.cdef [[
int libscrypt_check(char *mcf, const char *password);
int libscrypt_scrypt(const uint8_t *, size_t, const uint8_t *, size_t, uint64_t, uint32_t, uint32_t, uint8_t *, size_t);
]]
-- scrypt (lua-resty-scrypt style) entry points, also resolved through the
-- `scrypt` handle returned by util.loadDll("scrypt"):
--   crypto_scrypt     - key derivation used by crypt() on non-arm64 targets
--   calibrate         - used by calibrate() to pick N, r and p
--   scryptenc_cpuperf - used by cpuperf()
-- NOTE(review): the arm64 branch in crypt() suggests the library build differs
-- per target; confirm calibrate/scryptenc_cpuperf are exported on arm64 too.
ffi.cdef [[
int crypto_scrypt(
    const uint8_t *passwd,
    size_t passwdlen,
    const uint8_t *salt,
    size_t saltlen,
    uint64_t N,
    uint32_t _r,
    uint32_t _p,
    uint8_t *buf,
    size_t buflen);
int calibrate(
    size_t maxmem,
    double maxmemfrac,
    double maxtime,
    uint64_t *n,
    uint32_t *r,
    uint32_t *p);
int scryptenc_cpuperf(double *);
]]

--- Write the lowercase hexadecimal encoding of `len` source bytes into `dst`.
-- Port of nginx's ngx_hex_dump (src/core/ngx_string.c): every byte becomes two
-- characters, high nibble first, emitted in the same order as the input
-- (the output is NOT reversed).
-- @param dst writable 0-based byte buffer (cdata) with room for `len * 2` bytes
-- @param src source bytes: a 0-based byte buffer (cdata) or a Lua string
-- @param len number of source bytes to encode
-- @return number of bytes written to `dst` (always `len * 2`)
local function hex_dump(dst, src, len)
	-- Lua strings cannot be indexed like cdata (and are 1-based), so read
	-- their bytes with string.byte instead.
	local fromString = type(src) == "string"
	local pos = 0
	for i = 0, len - 1 do
		local byte
		if fromString then
			byte = src:byte(i + 1)
		else
			byte = src[i]
		end
		dst[pos] = hex[rshift(byte, 4)]
		dst[pos + 1] = hex[band(byte, 0xf)]
		pos = pos + 2
	end
	return pos
end

-- Module-level state used by crypt() and calibrate().
-- s: current derived-key length in bytes (the size of buffer b); crypt() may change it.
local s = 32
-- z: length of the hex-encoded key held in h (always 2 * s).
local z = 64
-- t: ctype for variable-length byte buffers.
local t = ffi.typeof("uint8_t[?]")
-- Default scrypt cost parameters N, r, p read by crypt();
-- calibrate() overwrites them with the values it picks.
local n = ffi.newAnchor("uint64_t[1]", 32768)
local r = ffi.newAnchor("uint32_t[1]", 8)
local p = ffi.newAnchor("uint32_t[1]", 1)
-- b: raw derived-key output buffer (s bytes); h: its hex encoding (z bytes).
-- NOTE(review): presumably ffi.newAnchor keeps these allocations alive until
-- ffi.free is called on them (crypt() does so when the key size changes) - confirm in mffi.
local b = ffi.newAnchor(t, s)
local h = ffi.newAnchor(t, z)

--- Fill a caller-supplied buffer with random bytes via RAND_pseudo_bytes.
-- The system/crypto module is loaded lazily on first use and cached.
-- @param buf writable byte buffer (cdata) of at least `bufLen` bytes
-- @param bufLen number of bytes to write
-- @return the RAND_pseudo_bytes result, so callers can detect failure
--   (presumably the OpenSSL convention: 1 strong, 0 not strong, -1 unsupported)
-- @raise error when system/crypto does not provide RAND_pseudo_bytes
local function randomBytes(buf, bufLen)
	if not crypto then
		crypto = require"system/crypto".crypto or {}
	end
	-- With the `or {}` fallback above, a missing binding would otherwise fail
	-- with an opaque "attempt to call a nil value".
	if not crypto.RAND_pseudo_bytes then
		error("scrypt.randomBytes: system/crypto does not provide RAND_pseudo_bytes", 2)
	end
	return crypto.RAND_pseudo_bytes(buf, bufLen)
end

--- Generate `len` random bytes and return them hex-encoded.
-- The result is `len * 2` characters drawn from [0-9a-f]. It never contains
-- "$", which makes it safe to use as a salt in the "$"-separated hash format.
-- @param len number of random bytes to generate
-- @return lowercase hex string, or nil when RAND_pseudo_bytes reports -1
--   (OpenSSL: "not supported by the current RAND method")
local function random(len)
	if not crypto then
		crypto = require"system/crypto".crypto or {}
	end
	local raw = ffi.newNoAnchor(t, len)
	-- The buffer is not filled on -1, so do not hand back predictable bytes.
	if crypto.RAND_pseudo_bytes(raw, len) == -1 then
		return nil
	end
	local encoded = ffi.newNoAnchor(t, len * 2)
	hex_dump(encoded, raw, len)
	return ffi.string(encoded, len * 2)
end

--- Derive an scrypt hash and return it in the storable format
-- "<N hex>$<r hex>$<p hex>$<salt>$<key hex>".
-- @param opts either the secret itself (any non-table value, converted with
--   tostring) or a table with optional fields:
--   secret   (string, default "")
--   salt     (string, default: random hex from random(saltsize))
--   saltsize (number of random salt bytes, clamped to 8..32, default 8)
--   keysize  (derived key length in bytes, clamped to 16..512, default 32)
--   n, r, p  (scrypt cost parameters; default to the module-level values,
--             initially 32768, 8, 1, or whatever calibrate() last stored)
-- Options apply to this call only; they do not change later calls' defaults.
-- @return the hash string, or false when scrypt fails, no random salt could be
--   generated, or the salt contains "$" (check() could never parse the result)
local function crypt(opts)
	local secret, salt, saltsize, keysize = '', nil, 8, 32
	-- Per-call copies, so that options (including those check() passes)
	-- never leak into subsequent crypt() calls.
	local costN, costR, costP = n[0], r[0], p[0]
	if type(opts) ~= "table" then
		secret = tostring(opts)
	else
		if type(opts.secret) == "string" then
			secret = opts.secret
		end
		if type(opts.keysize) == "number" then
			keysize = math.min(math.max(opts.keysize, 16), 512)
		end
		if type(opts.n) == "number" then
			costN = opts.n
		end
		if type(opts.r) == "number" then
			costR = opts.r
		end
		if type(opts.p) == "number" then
			costP = opts.p
		end
		if type(opts.salt) == "string" then
			salt = opts.salt
		end
		if type(opts.saltsize) == "number" then
			saltsize = math.min(math.max(opts.saltsize, 8), 32)
		end
	end
	-- The output buffers are cached at module level; reallocate them only
	-- when this call needs a different key size than the cached one.
	if keysize ~= s then
		s, z = keysize, keysize * 2
		ffi.free(b)
		ffi.free(h)
		b, h = ffi.newAnchor(t, s), ffi.newAnchor(t, z)
	end
	if not salt then
		salt = random(saltsize)
		if not salt then
			return false
		end
	end
	-- "$" separates the fields of the result; a salt containing it would
	-- produce a hash that can never be verified.
	if salt:find("$", 1, true) then
		return false
	end
	-- ffi.cast("uint8_t *", b) is needed for pure Lua 5.1 ffi
	local out = ffi.cast("uint8_t *", b)
	local ret
	if ffi.arch ~= "arm64" then
		ret = scrypt.crypto_scrypt(secret, #secret, salt, #salt, costN, costR, costP, out, s)
	else
		ret = scrypt.libscrypt_scrypt(secret, #secret, salt, #salt, costN, costR, costP, out, s)
	end
	if ret ~= 0 then
		return false
	end
	hex_dump(h, b, s)
	return string.format("%02x$%02x$%02x$%s$%s", tonumber(costN), costR, costP, salt, ffi.string(h, z))
end

--- Verify a secret against a hash produced by crypt().
-- The cost parameters N, r, p, the salt and the key size are all read back
-- from the hash itself and passed explicitly to crypt().
-- @param secret candidate secret; non-string values are converted with
--   tostring, matching how crypt() treats a non-table argument
-- @param hash stored hash "<N hex>$<r hex>$<p hex>$<salt>$<key hex>"
-- @return true when the secret matches; false otherwise. Malformed input is
--   reported with util.printError and yields false instead of raising.
local function check(secret, hash)
	if hash == nil then
		util.printError("hash is nil")
		return false
	elseif type(hash) ~= "string" then
		util.printError("hash is not a string")
		return false
	elseif hash == "" then
		util.printError("hash is empty string")
		return false
	elseif secret == nil then
		util.printError("secret is nil")
		return false
	end
	-- Exactly five "$"-separated fields, anchored at both ends ("%$" is a literal "$").
	local nHex, rHex, pHex, salt, keyHex = hash:match("^([^$]*)%$([^$]*)%$([^$]*)%$([^$]*)%$([^$]*)$")
	if not nHex then
		util.printError("hash has invalid format")
		return false
	end
	local costN, costR, costP = tonumber(nHex, 16), tonumber(rHex, 16), tonumber(pHex, 16)
	-- The key is hex-encoded: two characters per byte, so a non-empty even length.
	if not (costN and costR and costP) or #keyHex == 0 or #keyHex % 2 ~= 0 then
		util.printError("hash has invalid format")
		return false
	end
	local passHash = crypt({
		secret = tostring(secret),
		salt = salt,
		n = costN,
		r = costR,
		p = costP,
		keysize = #keyHex / 2,
	})
	return passHash == hash
end

--- Ask the scrypt library to choose cost parameters for the given limits.
-- The chosen N, r, p are written into the module-level parameters that
-- crypt() reads as its defaults.
-- @param maxmem memory limit passed to the C calibrate() (default 1048576;
--   presumably bytes - confirm against the library)
-- @param maxmemfrac memory fraction passed through (default 0.5)
-- @param maxtime time limit passed through (default 0.2; presumably seconds)
-- @return N, r, p on success, or false when the library reports an error
local function calibrate(maxmem, maxmemfrac, maxtime)
	-- Non-number arguments (including nil) fall back to the defaults.
	local function numberOr(value, fallback)
		if type(value) == "number" then
			return value
		end
		return fallback
	end
	local status = scrypt.calibrate(
		numberOr(maxmem, 1048576),
		numberOr(maxmemfrac, 0.5),
		numberOr(maxtime, 0.2),
		n, r, p)
	if status ~= 0 then
		return false
	end
	return tonumber(n[0]), r[0], p[0]
end

--- Estimate scrypt memory use for the given cost parameters.
-- Computes 128*r*p + 256*r + 128*r*N.
-- @param n cost parameter N (default 32768)
-- @param r block-size parameter (default 8)
-- @param p parallelization parameter (default 1)
-- @return the estimate as a number
local function memoryuse(n, r, p)
	local blockSize = r or 8
	local parallel = p or 1
	local cost = n or 32768
	return 128 * blockSize * parallel + 256 * blockSize + 128 * blockSize * cost
end

--- Measure CPU performance with the library's scryptenc_cpuperf().
-- @return the value scryptenc_cpuperf writes to its output argument
--   (presumably salsa20/8 cores per second, as in Tarsnap scrypt - confirm),
--   or false when the call returns a non-zero error code
local function cpuperf()
	local result = ffi.newNoAnchor("double[1]")
	-- On failure the output is left unset (zero), which would look like a
	-- real measurement; report the error instead.
	if scrypt.scryptenc_cpuperf(result) ~= 0 then
		return false
	end
	return result[0]
end

if not ... then -- self-test: runs only when this file is executed directly, not via require "scrypt"

	-- hex_dump: "ABCz" is 0x41 0x42 0x43 0x7a.
	local text = "ABCz"
	local len = #text
	local src = ffi.newNoAnchor(t, len)
	for i = 0, len - 1 do
		src[i] = text:byte(i + 1)
	end
	local dst = ffi.newNoAnchor(t, len * 2)
	hex_dump(dst, src, len)
	local encoded = ffi.string(dst, len * 2)
	print(text .. " -> " .. encoded)
	assert(encoded == "4142437a", "hex_dump produced " .. encoded)

	local opts = {
		secret = "My Secret",
		keysize = 32,
		n = 32768,
		r = 8,
		p = 1,
		-- salt     = "random (saltsize) bytes generated with OpenSSL",
		saltsize = 8
	}
	print("crypt(opts)")
	local time = util.seconds()
	local timeAll = time
	local hash = crypt(opts)
	time = util.seconds(time)
	assert(hash, "crypt(opts) failed")
	print("crypt(opts) time: " .. time .. " seconds")
	print("check('My Secret')")
	time = util.seconds()
	local valid = check("My Secret", hash)
	time = util.seconds(time)
	print("check('My Secret') time: " .. time .. " seconds")
	assert(valid == true, "check('My Secret') should succeed")
	print("check('My Guess')")
	valid = check("My Guess", hash)
	assert(valid == false, "check('My Guess') should fail")

	print("crypt('My Secret')")
	hash = crypt("My Secret") -- returns a hash that can be stored in db
	assert(hash, "crypt('My Secret') failed")
	print("check('My Secret')")
	valid = check("My Secret", hash)
	assert(valid == true, "check('My Secret') should succeed")
	print("check('My Guess')")
	valid = check("My Guess", hash)
	assert(valid == false, "check('My Guess') should fail")

	print("calibrate()")
	local calN, calR, calP = calibrate() -- returns N, r, p calibration values
	print("calibrate: N=" .. tostring(calN) .. ", r=" .. tostring(calR) .. ", p=" .. tostring(calP))

	print("scryptenc_cpuperf(d)")
	print("scryptenc_cpuperf: " .. tostring(cpuperf()))
	time = util.seconds(timeAll)
	print("time all: " .. time .. " seconds")
	-- `jit` exists only under LuaJIT, not under plain Lua 5.1 with an ffi module.
	if jit then
		print(jit.version .. " " .. jit.arch)
	end
end

-- Public API of the scrypt module.
return {
	randomBytes = randomBytes,
	random = random,
	crypt = crypt,
	check = check,
	calibrate = calibrate,
	memoryuse = memoryuse,
	cpuperf = cpuperf,
}
