local ffi = require "ffi"
local bit = require "bit"

local M = {}

-- Arbitrary-precision integer. `sign` is -1, 0 or 1; `size` is the number of
-- 32-bit limbs in use; `values` is a variable-length limb array with the
-- least significant limb at index 0.
ffi.cdef [[ typedef struct bigint { int8_t sign; uint32_t size; uint32_t values[?]; } bigint; ]]

-- Metatable for bigint cdata; its metamethods are filled in later in the file.
local int_mt = {}

--- Trim high-order zero limbs (keeping at least one) and give zero the sign 0.
-- Mutates `int` in place.
-- @return the same `int`
local function normalize(int)
  local top = int.size
  while top > 1 and int.values[top - 1] == 0 do
    top = top - 1
  end
  int.size = top
  if top == 1 and int.values[0] == 0 then
    int.sign = 0
  end
  return int
end

--- Add the magnitudes of `a` and `b`, ignoring their signs.
-- @return a new, non-normalized bigint with max(a.size, b.size) + 1 limbs
--   and sign 1
local function unsigned_add(a, b)
  local width = math.max(a.size, b.size) + 1
  local out = ffi.new("bigint", width)
  -- A column sum is at most 2 * (2^32 - 1) + 1 < 2^33, so ordinary Lua
  -- numbers (doubles) hold it exactly.
  local carry = 0
  for k = 0, width - 1 do
    local column = carry
    if k < a.size then
      column = column + a.values[k]
    end
    if k < b.size then
      column = column + b.values[k]
    end
    out.values[k] = column % 2^32
    carry = column >= 2^32 and 1 or 0
  end
  out.size = width
  out.sign = 1
  return out
end

--- Subtract the magnitude of `b` from that of `a`, ignoring their signs.
-- Assumes |a| >= |b|: a borrow out of the top limb is silently dropped.
-- @return a new, non-normalized bigint with a.size limbs and sign 1
local function unsigned_sub(a, b)
  local out = ffi.new("bigint", a.size)
  local borrow = 0
  for k = 0, a.size - 1 do
    local diff = a.values[k] - borrow
    if k < b.size then
      diff = diff - b.values[k]
    end
    if diff < 0 then
      diff = diff + 2^32
      borrow = 1
    else
      borrow = 0
    end
    out.values[k] = diff
  end
  out.size = a.size
  out.sign = 1
  return out
end
-- NOTE: the original author reported "some sort of bug" with unsigned_sub,
-- and JIT compilation of it is switched off immediately after this
-- definition (require"jit".off). The root cause was never identified.
-- The original author saw an unexplained bug when unsigned_sub was JIT
-- compiled ("some sort of bug or something occurs with this!"), so it is
-- kept interpreted. NOTE(review): root cause never identified.
require"jit".off(unsigned_sub)

--- Coerce a Lua number (or numeric string) to a bigint; bigints pass through.
local function to_int(v)
  if M.is_bigint(v) then
    return v
  end
  return M.int(v)
end

--- Whether `v` can take part in a bigint equality test (bigint, number or
-- string). LuaJIT invokes __eq on cdata even against e.g. nil.
local function is_int_operand(v)
  local t = type(v)
  return t == "number" or t == "string" or M.is_bigint(v)
end

--- Compare the magnitudes of two normalized bigints, ignoring sign.
-- @return 1 if |a| > |b|, -1 if |a| < |b|, 0 if equal
local function unsigned_cmp(a, b)
  if a.size > b.size then
    return 1
  elseif a.size < b.size then
    return -1
  end
  for i = a.size - 1, 0, -1 do
    if a.values[i] > b.values[i] then
      return 1
    elseif b.values[i] > a.values[i] then
      return -1
    end
  end
  return 0
end

--- Three-way signed comparison of two normalized bigints.
-- @return 1 if a > b, -1 if a < b, 0 if equal
local function signed_cmp(a, b)
  if a.sign > b.sign then
    return 1
  elseif a.sign < b.sign then
    return -1
  end
  if a.sign > 0 then
    return unsigned_cmp(a, b)
  end
  return -unsigned_cmp(a, b)
end

--- Shift a bigint left by `n` (>= 0) bits.
-- Returns `int` itself when n == 0 or int is zero, otherwise a new
-- normalized bigint with the same sign.
local function lshift(int, n)
  if n == 0 or int.sign == 0 then
    return int
  end
  assert(n > 0)
  local limb_shift = math.floor(n / 32)
  local bit_shift = n % 32
  local size = int.size + limb_shift + 1
  local result = ffi.new("bigint", size)
  result.size = size
  for i = 0, int.size - 1 do
    -- Widen to 64 bits so the bits shifted out of this limb land in the
    -- next one; assigning to a uint32_t field keeps the low 32 bits.
    local val = bit.lshift(ffi.cast("uint64_t", int.values[i]), bit_shift)
    local index = i + limb_shift
    result.values[index] = bit.bor(val, result.values[index])
    result.values[index + 1] = bit.rshift(val, 32)
  end
  result.sign = int.sign
  return normalize(result)
end

function int_mt.__add(a, b)
  a, b = to_int(a), to_int(b)
  local result
  if a.sign < 0 and b.sign < 0 then
    result = unsigned_add(a, b)
    result.sign = -1
  elseif a.sign < 0 or b.sign < 0 then
    local neg, pos
    if a.sign < 0 then
      neg, pos = a, b
    else
      neg, pos = b, a
    end
    -- Subtract the smaller magnitude from the larger; the result takes the
    -- sign of the larger (normalize turns an exact zero back into sign 0).
    if unsigned_cmp(pos, neg) > 0 then
      result = unsigned_sub(pos, neg)
    else
      result = unsigned_sub(neg, pos)
      result.sign = -1
    end
  else
    result = unsigned_add(a, b)
  end
  return normalize(result)
end

function int_mt.__sub(a, b)
  return a + -b
end

--- Schoolbook multiplication of two bigints (or numbers).
function int_mt.__mul(a, b)
  a, b = to_int(a), to_int(b)
  local result = ffi.new("bigint", a.size + b.size + 1)
  result.size = a.size + b.size
  for i = 0, b.size - 1 do
    local carry = 0ULL
    local bi = b.values[i]
    for j = 0, a.size - 1 do
      -- The limb product can reach (2^32-1)^2 ~ 2^64. As a Lua double it
      -- would lose its low bits (doubles are exact only up to 2^53), so
      -- compute it in uint64_t. carry + limb + product still fits in 64 bits.
      local prod = ffi.cast("uint64_t", a.values[j]) * bi
      local val = carry + result.values[i + j] + prod
      carry = bit.rshift(val, 32)
      result.values[i + j] = val
    end
    result.values[i + a.size] = carry
  end
  result.sign = a.sign * b.sign
  return normalize(result)
end

function int_mt.__unm(int)
  local result = ffi.new("bigint", int.size)
  result.size = int.size
  result.sign = -int.sign
  for i = 0, int.size - 1 do
    result.values[i] = int.values[i]
  end
  return result
end

--- Render as hexadecimal, e.g. "-0x1000000000".
function int_mt.__tostring(int)
  local s = {"0x"}
  if int.sign < 0 then
    table.insert(s, 1, "-")
  end
  for i = int.size - 1, 0, -1 do
    if i < int.size - 1 then
      s[#s + 1] = ("%08x"):format(int.values[i])
    else
      s[#s + 1] = ("%x"):format(int.values[i])
    end
  end
  return table.concat(s)
end

function int_mt.__eq(a, b)
  -- Non-numeric operands (nil, tables, ...) are simply unequal instead of
  -- raising an error inside M.int.
  if not (is_int_operand(a) and is_int_operand(b)) then
    return false
  end
  return signed_cmp(to_int(a), to_int(b)) == 0
end

function int_mt.__lt(a, b)
  return signed_cmp(to_int(a), to_int(b)) < 0
end

-- LuaJIT does not derive `<=` from __lt for cdata, so it must be explicit.
function int_mt.__le(a, b)
  return signed_cmp(to_int(a), to_int(b)) <= 0
end

-- The metatable must be complete before this call.
ffi.metatype("bigint", int_mt)

--- Create a bigint from a Lua number (floored) or copy an existing bigint.
function M.int(i)
  if M.is_bigint(i) then
    local int = ffi.new("bigint", i.size)
    int.size = i.size
    int.sign = i.sign
    for j = 0, i.size - 1 do
      int.values[j] = i.values[j]
    end
    return int
  end
  i = math.floor(i)
  -- Infinity would never leave the limb loop below and NaN fails every
  -- comparison, so reject both up front.
  assert(i == i and i ~= math.huge and i ~= -math.huge,
    "bigint: cannot convert a non-finite number")
  if i == 0 then
    local zero = ffi.new("bigint", 1) -- ffi.new zero-fills values[0]
    zero.sign = 0
    zero.size = 1
    return zero
  end
  local sign = 1
  if i < 0 then
    sign = -1
    i = -i
  end
  -- Count limbs exactly so the VLA is never written past its end. A
  -- log2-based estimate is easy to get wrong at limb boundaries
  -- (e.g. 2^32 already needs two limbs).
  local limbs, rest = 0, i
  while rest > 0 do
    limbs = limbs + 1
    rest = math.floor(rest / 2^32)
  end
  local int = ffi.new("bigint", limbs)
  int.sign = sign
  int.size = limbs
  for j = 0, limbs - 1 do
    int.values[j] = i % 2^32
    i = math.floor(i / 2^32)
  end
  return int
end

function M.is_bigint(v)
  return type(v) == "cdata" and ffi.typeof(v) == ffi.typeof"bigint"
end

-- Binary floating point value: mantissa (bigint) * 2^exp. `precision` is a
-- limb budget for the mantissa (math.huge = unlimited).
local float_mt = {}

--- Drop low-order limbs of a float's mantissa.
-- Keeps at most `precision` (default f.precision) of the most significant
-- limbs, then also strips any further zero low limbs, folding them into exp.
-- Single-limb mantissas are returned unchanged. May return `f` itself; never
-- mutates it.
local function truncate(f, precision)
  precision = precision or f.precision
  if f.mantissa == 0 then
    -- Zero is represented with exp 0. Build a fresh value rather than
    -- mutating the caller's float (this runs on user-supplied operands).
    if f.exp == 0 then
      return f
    end
    return M.float(0, f.precision)
  end
  if f.mantissa.size == 1 then
    return f
  end
  if precision < 1 then
    return M.float(0)
  end
  local m = f.mantissa
  local to = math.max(m.size - precision, 0)
  while m.values[to] == 0 do
    to = to + 1
  end
  if to == 0 then
    return f
  end
  local kept = m.size - to
  local new = M.float()
  new.exp = f.exp + to * 32
  new.precision = precision
  new.mantissa = ffi.new("bigint", kept)
  new.mantissa.size = kept
  new.mantissa.sign = m.sign
  for i = to, m.size - 1 do
    new.mantissa.values[i - to] = m.values[i]
  end
  return new
end

--- Bit position p (counted from 1) of the highest set bit of a float's value,
-- so |value| lies in [2^(p-1), 2^p). Requires a nonzero mantissa.
local function top_bit(f)
  local m = f.mantissa
  local i = m.size - 1
  while i > 0 and m.values[i] == 0 do
    i = i - 1
  end
  local limb, len = m.values[i], 0
  while limb >= 1 do
    limb = math.floor(limb / 2)
    len = len + 1
  end
  return f.exp + 32 * i + len
end

--- Three-way comparison of two floats: 1 if a > b, -1 if a < b, 0 if equal.
local function cmp_float(a, b)
  local sa, sb = a.mantissa.sign, b.mantissa.sign
  if sa ~= sb then
    return sa > sb and 1 or -1
  end
  if sa == 0 then
    return 0
  end
  -- Exact magnitude test: a different top-bit position decides it without
  -- arithmetic; only equal positions need a subtraction.
  local pa, pb = top_bit(a), top_bit(b)
  if pa ~= pb then
    return (pa > pb and 1 or -1) * sa
  end
  return (a - b).mantissa.sign
end

function float_mt.__add(a, b)
  a = M.is_bigfloat(a) and a or M.float(a)
  b = M.is_bigfloat(b) and b or M.float(b)
  local f = M.float()
  f.precision = math.min(a.precision, b.precision)
  -- Truncate each operand to the result precision on its own terms (keeping
  -- its most significant limbs). Budgeting by the exponent gap instead would
  -- zero out the *larger* operand when exponents differ widely.
  a = truncate(a, f.precision)
  b = truncate(b, f.precision)
  f.exp = math.min(a.exp, b.exp)
  f.mantissa = lshift(a.mantissa, a.exp - f.exp) + lshift(b.mantissa, b.exp - f.exp)
  return truncate(f)
end

function float_mt.__sub(a, b)
  return a + -b
end

function float_mt.__mul(a, b)
  a = M.is_bigfloat(a) and a or M.float(a)
  b = M.is_bigfloat(b) and b or M.float(b)
  local f = M.float()
  f.exp = a.exp + b.exp
  f.mantissa = a.mantissa * b.mantissa
  f.precision = math.min(a.precision, b.precision)
  return truncate(f)
end

function float_mt.__unm(f)
  local new = M.float()
  new.exp = f.exp
  new.mantissa = -f.mantissa
  new.precision = f.precision
  return new
end

function float_mt.__eq(a, b)
  a = M.is_bigfloat(a) and a or M.float(a)
  b = M.is_bigfloat(b) and b or M.float(b)
  return cmp_float(a, b) == 0
end

function float_mt.__lt(a, b)
  a = M.is_bigfloat(a) and a or M.float(a)
  b = M.is_bigfloat(b) and b or M.float(b)
  return cmp_float(a, b) < 0
end

function float_mt.__le(a, b)
  a = M.is_bigfloat(a) and a or M.float(a)
  b = M.is_bigfloat(b) and b or M.float(b)
  return cmp_float(a, b) <= 0
end

--- Collect mantissa bits min..max into hex digits, least significant first.
-- Digit boundaries fall where (i - anchor) % 4 == 0 (anchor defaults to
-- min). Bit positions outside the mantissa read as 0.
local function digits_in_range(f, min, max, anchor)
  anchor = anchor or min
  local digits = {0}
  local nbits = f.mantissa.size * 32
  for i = min, max do
    if (i - anchor) % 4 == 0 and i ~= min then
      digits[#digits + 1] = 0
    end
    local b = 0
    if i >= 0 and i < nbits then
      b = bit.band(bit.rshift(f.mantissa.values[math.floor(i / 32)], i % 32), 1)
    end
    digits[#digits] = bit.bor(digits[#digits], bit.lshift(b, (i - anchor) % 4))
  end
  return digits
end

--- Render as a hexadecimal fraction, e.g. "-0x1a.8".
function float_mt.__tostring(f)
  -- Mantissa bit k has weight 2^(k + exp): bits k >= -exp form the integer
  -- part, bits below -exp the fraction (digits aligned to the point).
  local before_point = digits_in_range(f, -f.exp, f.mantissa.size * 32 - 1)
  while before_point[#before_point] == 0 and #before_point > 1 do
    table.remove(before_point, #before_point)
  end
  local after_point = digits_in_range(f, 0, -f.exp - 1, -f.exp)
  while after_point[1] == 0 and #after_point > 1 do
    table.remove(after_point, 1)
  end
  local result = {"0x"}
  if f.mantissa < 0 then
    table.insert(result, 1, "-")
  end
  local hex = "0123456789abcdef"
  for i = #before_point, 1, -1 do
    local digit = before_point[i]
    result[#result + 1] = hex:sub(digit + 1, digit + 1)
  end
  result[#result + 1] = "."
  for i = #after_point, 1, -1 do
    local digit = after_point[i]
    result[#result + 1] = hex:sub(digit + 1, digit + 1)
  end
  return table.concat(result)
end

--- Create a float from nil/0, a Lua number, a bigint, or another float.
-- @param f value to convert (nil means zero)
-- @param precision optional limb budget for the mantissa (default math.huge)
-- A float argument shares its mantissa with the result; its own precision is
-- not copied.
function M.float(f, precision)
  local new = setmetatable({}, float_mt)
  new.precision = precision or math.huge
  if f == 0 or f == nil then
    new.mantissa = M.int(0)
    new.exp = 0
    return new
  elseif M.is_bigfloat(f) then
    new.mantissa = f.mantissa
    new.exp = f.exp
    return new
  elseif M.is_bigint(f) then
    new.mantissa = f
    new.exp = 0
    return new
  end
  -- frexp gives f = m * 2^e with 0.5 <= |m| < 1; scaling m by 2^53 makes
  -- the mantissa an exact integer for any finite double.
  local m, e = math.frexp(f)
  new.mantissa, new.exp = M.int(m * 2^53), e - 53
  return truncate(new)
end

function M.is_bigfloat(v)
  return getmetatable(v) == float_mt
end

--- Convert a number, bigint or float to a Lua number (may lose precision).
-- Returns nil for any other input.
function M.tonumber(n)
  local result
  if type(n) == "number" then
    result = n
  elseif M.is_bigint(n) then
    result = 0
    for i = 0, n.size - 1 do
      result = result + n.values[i] * 2^(32 * i)
    end
    result = result * n.sign
  elseif M.is_bigfloat(n) then
    result = 0
    for i = 0, n.mantissa.size - 1 do
      result = result + n.mantissa.values[i] * 2^(32 * i + n.exp)
    end
    result = result * n.mantissa.sign
  end
  return result
end

return M