summaryrefslogtreecommitdiff
path: root/bignum.lua
blob: 886f3f28f30094659bcf5a074374fa25f3f47c36 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
local ffi = require "ffi"
local bit = require "bit"

-- Module table; returned at the end of the file.
local M = {}

-- Arbitrary-precision integer. The magnitude is stored as 32-bit limbs,
-- least significant first (values[0] is the lowest limb); the sign is kept
-- separately in `sign` (-1, 0 or 1). `size` is the number of limbs in use,
-- which may be smaller than the allocated length of the variable-length
-- `values` array.
ffi.cdef [[
	typedef struct bigint {
		int8_t sign;
		uint32_t size;
		uint32_t values[?];
	} bigint;
]]

-- Metatable for bigint cdata; the metamethods are defined below and the
-- table is attached with ffi.metatype.
local int_mt = {}

--- Canonicalize a bigint in place.
-- Drops most-significant limbs that are zero (always keeping at least one
-- limb) and forces the sign of a zero value to 0.
-- @param int bigint to fix up (mutated)
-- @return the same `int`
local function normalize(int)
	local used = int.size
	while used > 1 and int.values[used - 1] == 0 do
		used = used - 1
	end
	int.size = used

	local is_zero = used == 1 and int.values[0] == 0
	if is_zero then
		int.sign = 0
	end
	return int
end

--- Add the magnitudes of two bigints, ignoring their signs.
-- The result has one more limb than the wider operand (room for the final
-- carry), sign 1, and is NOT normalized.
-- @return new bigint |a| + |b|
local function unsigned_add(a, b)
	local limbs = math.max(a.size, b.size) + 1
	local out = ffi.new("bigint", limbs)
	out.size = limbs
	out.sign = 1

	-- 64-bit accumulator: low 32 bits become the limb, high bits the carry.
	local acc = 0ULL
	for k = 0, limbs - 1 do
		if k < a.size then acc = acc + a.values[k] end
		if k < b.size then acc = acc + b.values[k] end
		out.values[k] = acc -- narrowing store keeps the low 32 bits
		acc = bit.rshift(acc, 32)
	end
	return out
end

--- Subtract magnitudes: |a| - |b|, ignoring signs.
-- Expects |a| >= |b| (callers pass the larger magnitude first); limbs of b
-- beyond a.size are not read. The result has a.size limbs, sign 1, and is
-- NOT normalized.
-- @return new bigint
local function unsigned_sub(a, b)
	local out = ffi.new("bigint", a.size)
	out.size = a.size
	out.sign = 1

	-- The borrow is a signed 64-bit cdata so each limb difference is an exact
	-- signed integer.
	local borrow = 0LL
	for k = 0, a.size - 1 do
		local subtrahend = 0
		if k < b.size then subtrahend = b.values[k] end
		local d = a.values[k] - subtrahend - borrow
		if d < 0 then
			out.values[k] = d + 2^32
			borrow = 1LL
		else
			out.values[k] = d
			borrow = 0LL
		end
	end
	return out
end
-- Work-around: the original author observed wrong behaviour from this
-- function when JIT-compiled (cause not identified), so it is pinned to the
-- interpreter.
require"jit".off(unsigned_sub)

--- Three-way comparison of magnitudes, ignoring sign.
-- Assumes both operands are normalized, so more limbs means larger.
-- @return 1 if |a| > |b|, -1 if |a| < |b|, 0 if equal
local function unsigned_cmp(a, b)
	if a.size ~= b.size then
		return a.size > b.size and 1 or -1
	end

	-- Same width: the first differing limb from the top decides.
	local i = a.size - 1
	while i >= 0 do
		local x, y = a.values[i], b.values[i]
		if x ~= y then
			return x > y and 1 or -1
		end
		i = i - 1
	end
	return 0
end

--- Three-way comparison of two bigints, taking sign into account.
-- @return 1 if a > b, -1 if a < b, 0 (possibly -0) if equal
local function signed_cmp(a, b)
	if a.sign ~= b.sign then
		return a.sign > b.sign and 1 or -1
	end

	-- Same sign: compare magnitudes, reversing the order unless positive.
	local order = unsigned_cmp(a, b)
	if a.sign > 0 then
		return order
	end
	return -order
end

--- Shift a bigint's magnitude left by `n` bits (multiply by 2^n).
-- @param int bigint
-- @param n non-negative integer bit count
-- @return a new normalized bigint with int's sign, or `int` itself when
--   n == 0 or int is zero
local function lshift(int, n)
	if n == 0 then return int end
	if int == 0 then return int end
	assert(n > 0, "lshift: shift count must be non-negative")

	-- Whole-limb and in-limb parts of the shift, as integers, so the VLA
	-- length and `size` field never depend on truncating a fraction.
	local limb_shift = math.floor(n / 32)
	local bit_shift = n % 32

	-- One spare limb receives the bits pushed out of the top word.
	local size = int.size + limb_shift + 1
	local result = ffi.new("bigint", size)
	result.size = size
	result.sign = int.sign

	for i = 0, int.size - 1 do
		-- Shift in 64 bits so bits crossing the limb boundary are kept.
		local val = bit.lshift(ffi.cast("uint64_t", int.values[i]), bit_shift)
		local index = i + limb_shift
		-- Low half merges with the carry-out already stored by limb i-1.
		result.values[index] = bit.bor(val, result.values[index])
		result.values[index + 1] = bit.rshift(val, 32)
	end

	return normalize(result)
end

--- Signed addition of bigints; non-bigint operands are converted with M.int.
-- @return a new normalized bigint
function int_mt.__add(a, b)
	if not M.is_bigint(a) then a = M.int(a) end
	if not M.is_bigint(b) then b = M.int(b) end

	local a_neg, b_neg = a.sign < 0, b.sign < 0
	local sum
	if a_neg == b_neg then
		-- Same sign: magnitudes add, sign follows the operands.
		sum = unsigned_add(a, b)
		if a_neg then sum.sign = -1 end
	else
		-- Opposite signs: subtract the smaller magnitude from the larger.
		local pos, neg = b, a
		if b_neg then pos, neg = a, b end
		if unsigned_cmp(pos, neg) > 0 then
			sum = unsigned_sub(pos, neg)
		else
			sum = unsigned_sub(neg, pos)
			sum.sign = -1
		end
	end
	return normalize(sum)
end

--- Subtraction expressed as addition of the negated right operand.
function int_mt.__sub(a, b)
	local negated = -b
	return a + negated
end

--- Multiply two bigints with schoolbook limb multiplication.
-- Non-bigint operands are converted with M.int first.
-- @return a new normalized bigint
function int_mt.__mul(a, b)
	a = M.is_bigint(a) and a or M.int(a)
	b = M.is_bigint(b) and b or M.int(b)

	local size = a.size + b.size
	local result = ffi.new("bigint", size + 1)
	result.size = size

	for i = 0, b.size - 1 do
		-- Promote one factor to uint64_t so the 32x32-bit limb product is an
		-- exact 64-bit integer; as a plain Lua number it can exceed 2^53 and
		-- lose its low bits.
		local factor = ffi.cast("uint64_t", b.values[i])
		local carry = 0ULL
		for j = 0, a.size - 1 do
			-- carry + limb + product <= 2^64 - 1, so this never wraps.
			local val = carry + result.values[i + j] + factor * a.values[j]
			result.values[i + j] = val -- narrowing store keeps the low 32 bits
			carry = bit.rshift(val, 32)
		end
		result.values[i + a.size] = carry
	end
	result.sign = a.sign * b.sign

	return normalize(result)
end

--- Negation: a copy of `int` with the sign flipped.
function int_mt.__unm(int)
	local copy = M.int(int) -- bigint argument: exact limb-for-limb copy
	copy.sign = -int.sign
	return copy
end

--- Hexadecimal rendering, most significant limb first.
-- The top limb is printed without padding; every lower limb as 8 digits.
-- Negative values get a leading "-".
function int_mt.__tostring(int)
	local top = int.size - 1
	local parts = {}
	for i = top, 0, -1 do
		local fmt = (i == top) and "%x" or "%08x"
		parts[#parts + 1] = fmt:format(int.values[i])
	end

	local head = "0x"
	if int.sign < 0 then head = "-" .. head end
	return head .. table.concat(parts)
end

-- True for values __eq can convert to a bigint: bigints themselves, and
-- numbers / strings (handed to M.int).
local function is_int_operand(v)
	local t = type(v)
	return t == "number" or t == "string" or M.is_bigint(v)
end

--- Equality for bigints.
-- Number operands are converted with M.int. Under LuaJIT this metamethod is
-- also consulted for comparisons such as `bigint == nil`; such operands now
-- compare unequal instead of raising an error while converting them.
function int_mt.__eq(a, b)
	if not (is_int_operand(a) and is_int_operand(b)) then
		return false
	end
	a = M.is_bigint(a) and a or M.int(a)
	b = M.is_bigint(b) and b or M.int(b)

	return signed_cmp(a, b) == 0
end

--- Strict less-than for bigints; non-bigint operands go through M.int.
function int_mt.__lt(a, b)
	if not M.is_bigint(a) then a = M.int(a) end
	if not M.is_bigint(b) then b = M.int(b) end
	local order = signed_cmp(a, b)
	return order < 0
end

-- Attach int_mt so arithmetic, comparison and tostring work on bigint cdata.
ffi.metatype("bigint", int_mt)

--- Construct a bigint.
-- @param i an existing bigint (copied limb for limb) or a Lua number, which
--   is rounded down with math.floor. Non-finite numbers raise an error.
-- @return a new bigint (zero has sign 0 and one zero limb)
function M.int(i)
	if M.is_bigint(i) then
		local int = ffi.new("bigint", i.size)
		int.size = i.size
		int.sign = i.sign
		for j = 0, i.size - 1 do
			int.values[j] = i.values[j]
		end
		return int
	end

	i = math.floor(i)
	-- NaN fails i == i; infinities would never terminate the limb loops.
	assert(i == i and i ~= math.huge and i ~= -math.huge,
		"bigint: cannot convert a non-finite number")

	if i == 0 then
		local int = ffi.new("bigint", 1)
		int.sign = 0
		int.size = 1
		return int
	end

	local sign = 1
	if i < 0 then
		sign = -1
		i = -i
	end

	-- Count exactly how many 32-bit limbs the magnitude needs, so the fill
	-- loop below never writes past the allocated array.
	local size = 1
	local rest = math.floor(i / 2^32)
	while rest > 0 do
		size = size + 1
		rest = math.floor(rest / 2^32)
	end

	local int = ffi.new("bigint", size)
	int.sign = sign
	int.size = 0
	while i > 0 do
		int.values[int.size] = i % 2^32
		i = math.floor(i / 2^32)
		int.size = int.size + 1
	end

	return int
end

-- ctype of the bigint struct, looked up once after the cdef above.
local bigint_ctype = ffi.typeof("bigint")

--- Test whether `v` is a bigint cdata value.
-- @return boolean
function M.is_bigint(v)
	if type(v) ~= "cdata" then
		return false
	end
	return ffi.typeof(v) == bigint_ctype
end

-- Metatable shared by all bigfloats. A bigfloat is a plain table
-- { mantissa = bigint, exp = number, precision = number } whose value is
-- mantissa * 2^exp; `precision` bounds the number of 32-bit mantissa limbs
-- kept by truncate.
local float_mt = {}

--- Drop low-order 32-bit limbs from a bigfloat's mantissa.
-- Keeps at most `precision` limbs and additionally strips low limbs that are
-- zero, folding every dropped limb into the exponent. Bits are discarded,
-- not rounded. Single-limb mantissas are returned unchanged.
-- @param f bigfloat (never mutated)
-- @param precision optional limb limit, defaults to f.precision
-- @return `f` itself when nothing changes, otherwise a new bigfloat. A
--   precision below 1 yields M.float(0), which has unlimited precision.
local function truncate(f, precision)
	precision = precision or f.precision

	local mantissa = f.mantissa
	if mantissa == 0 then
		-- Canonical zero (exponent 0), built fresh rather than by mutating
		-- the caller's float.
		return M.float(0, f.precision)
	end
	if mantissa.size == 1 then return f end
	if precision < 1 then return M.float(0) end

	-- First limb to keep: enforce the precision limit, then skip zero limbs.
	-- Terminates because a nonzero normalized mantissa has a nonzero top limb.
	local to = math.max(mantissa.size - precision, 0)
	while mantissa.values[to] == 0 do
		to = to + 1
	end
	if to == 0 then
		return f
	end

	local kept = mantissa.size - to
	local new = M.float()
	new.exp = f.exp + to * 32
	new.precision = precision
	new.mantissa = ffi.new("bigint", kept)
	new.mantissa.size = kept
	new.mantissa.sign = mantissa.sign
	for i = to, mantissa.size - 1 do
		new.mantissa.values[i - to] = mantissa.values[i]
	end
	return new
end

-- Exponent of the power of two just above |f|: for nonzero f,
-- 2^(t-1) <= |f| < 2^t. Assumes the mantissa is normalized (nonzero top
-- limb for nonzero values).
local function float_top_bit(f)
	local m = f.mantissa
	local top = m.values[m.size - 1]
	local len = 0
	while top > 0 do
		top = math.floor(top / 2)
		len = len + 1
	end
	return f.exp + 32 * (m.size - 1) + len
end

--- Three-way comparison of two bigfloats.
-- @return 1 if a > b, -1 if a < b, 0 (possibly -0) if equal
local function cmp_float(a, b)
	local sa, sb = a.mantissa.sign, b.mantissa.sign
	if sa ~= sb then
		return sa > sb and 1 or -1
	end

	-- Fast path: if the leading-bit positions differ, the magnitudes are
	-- ordered the same way; negate the answer for negative numbers.
	local a_top, b_top = float_top_bit(a), float_top_bit(b)
	if a_top > b_top then
		return sa
	elseif a_top < b_top then
		return -sa
	end

	-- Same leading bit: fall back to the sign of the exact difference.
	return (a - b).mantissa.sign
end

--- Add two bigfloats; numbers and bigints are converted with M.float.
-- Both mantissas are aligned to the smaller exponent and added exactly, then
-- the sum is truncated to the lesser of the two precisions. Operands are not
-- truncated before alignment: limiting each operand by its shift distance
-- could discard the larger operand entirely when precision is finite.
-- @return a new bigfloat
function float_mt.__add(a, b)
	a = M.is_bigfloat(a) and a or M.float(a)
	b = M.is_bigfloat(b) and b or M.float(b)

	local f = M.float()
	f.precision = math.min(a.precision, b.precision)
	f.exp = math.min(a.exp, b.exp)

	-- Shift distances are non-negative because f.exp is the minimum exponent.
	local a_shifted = lshift(a.mantissa, a.exp - f.exp)
	local b_shifted = lshift(b.mantissa, b.exp - f.exp)
	f.mantissa = a_shifted + b_shifted

	return truncate(f)
end

--- Subtraction expressed as addition of the negated right operand.
function float_mt.__sub(a, b)
	local negated = -b
	return a + negated
end

--- Multiply two bigfloats: mantissas multiply, exponents add, and the
-- result is truncated to the lesser precision. Non-bigfloat operands are
-- converted with M.float.
function float_mt.__mul(a, b)
	if not M.is_bigfloat(a) then a = M.float(a) end
	if not M.is_bigfloat(b) then b = M.float(b) end

	local product = M.float()
	product.exp = a.exp + b.exp
	product.mantissa = a.mantissa * b.mantissa
	product.precision = math.min(a.precision, b.precision)
	return truncate(product)
end

--- Negation: same exponent and precision, negated mantissa.
function float_mt.__unm(f)
	local negated = M.float()
	negated.precision = f.precision
	negated.exp = f.exp
	negated.mantissa = -f.mantissa
	return negated
end

--- Equality for bigfloats via cmp_float; other operands go through M.float.
function float_mt.__eq(a, b)
	if not M.is_bigfloat(a) then a = M.float(a) end
	if not M.is_bigfloat(b) then b = M.float(b) end
	local order = cmp_float(a, b)
	return order == 0
end

--- Strict less-than for bigfloats via cmp_float; other operands go through
-- M.float.
function float_mt.__lt(a, b)
	if not M.is_bigfloat(a) then a = M.float(a) end
	if not M.is_bigfloat(b) then b = M.float(b) end
	local order = cmp_float(a, b)
	return order < 0
end

--- Pack the mantissa bits of `f` with indices in [min, max] into hex digits.
-- Bit i lands at position (i - anchor) % 4 of its digit; a new digit starts
-- whenever that position is 0 (except for the very first bit). Bits outside
-- the mantissa read as 0.
-- @param f bigfloat
-- @param min first mantissa bit index (inclusive)
-- @param max last mantissa bit index (inclusive)
-- @param anchor bit index aligned to a digit boundary (default: min)
-- @return array of digit values 0..15, least significant first ({0} if empty)
local function digits_in_range(f, min, max, anchor)
	anchor = anchor or min

	local mantissa = f.mantissa
	local nbits = mantissa.size * 32
	local digits = {0}
	for i = min, max do
		local pos = (i - anchor) % 4
		if pos == 0 and i ~= min then
			digits[#digits + 1] = 0
		end

		if i >= 0 and i < nbits then
			-- Index the limb with an explicit integer rather than relying on
			-- cdata indexing to truncate the fractional i / 32.
			local limb = mantissa.values[math.floor(i / 32)]
			local b = bit.band(bit.rshift(limb, i % 32), 1)
			digits[#digits] = bit.bor(digits[#digits], bit.lshift(b, pos))
		end
	end
	return digits
end

--- Render a bigfloat as a hexadecimal fraction, e.g. 2^-6 -> "0x0.04" and
-- the integer 5 -> "0x5.0". Negative values get a leading "-".
function float_mt.__tostring(f)
	local hex = "0123456789abcdef"

	-- Integer part: mantissa bits whose weight 2^(i + exp) is >= 1.
	-- Digits are least significant first, so leading zeros sit at the end.
	local int_digits = digits_in_range(f, -f.exp, f.mantissa.size * 32 - 1)
	while #int_digits > 1 and int_digits[#int_digits] == 0 do
		int_digits[#int_digits] = nil
	end

	-- Fractional part, grouped from the binary point; trailing zeros are the
	-- lowest digits, stored first.
	local frac_digits = digits_in_range(f, 0, -f.exp - 1, -f.exp)
	while #frac_digits > 1 and frac_digits[1] == 0 do
		table.remove(frac_digits, 1)
	end

	local out = {}
	if f.mantissa < 0 then
		out[#out + 1] = "-"
	end
	out[#out + 1] = "0x"
	for i = #int_digits, 1, -1 do
		local d = int_digits[i]
		out[#out + 1] = hex:sub(d + 1, d + 1)
	end
	out[#out + 1] = "."
	for i = #frac_digits, 1, -1 do
		local d = frac_digits[i]
		out[#out + 1] = hex:sub(d + 1, d + 1)
	end
	return table.concat(out)
end

--- Construct a bigfloat.
-- @param f nil or 0 (zero), an existing bigfloat (mantissa and exponent are
--   shared; its precision is not copied), a bigint (exponent 0), or a Lua
--   number (converted exactly with math.frexp, then truncated)
-- @param precision optional mantissa limit in 32-bit limbs (default
--   math.huge, i.e. unlimited)
-- @return a new bigfloat
function M.float(f, precision)
	local new = setmetatable({}, float_mt)
	new.precision = precision or math.huge

	-- Test nil first: for a bigint argument, LuaJIT routes `f == nil` through
	-- the bigint __eq metamethod, which would try to convert nil to a number.
	if f == nil or f == 0 then
		new.mantissa = M.int(0)
		new.exp = 0
		return new
	elseif M.is_bigfloat(f) then
		new.mantissa = f.mantissa
		new.exp = f.exp
		return new
	elseif M.is_bigint(f) then
		new.mantissa = f
		new.exp = 0
		return new
	end

	-- |m| is in [0.5, 1), so m * 2^53 is an exact integer for any double.
	local m, e = math.frexp(f)
	new.mantissa, new.exp = M.int(m * 2^53), e - 53
	return truncate(new)
end

--- Test whether `v` is a bigfloat (a table carrying float_mt).
-- @return boolean
function M.is_bigfloat(v)
	local mt = getmetatable(v)
	return mt == float_mt
end

--- Convert a number, bigint or bigfloat to a Lua number.
-- Each limb is scaled by its power of two and summed, so results beyond
-- double precision are rounded. Any other argument yields nil.
function M.tonumber(n)
	if type(n) == "number" then
		return n
	end

	local mantissa, exp
	if M.is_bigint(n) then
		mantissa, exp = n, 0
	elseif M.is_bigfloat(n) then
		mantissa, exp = n.mantissa, n.exp
	else
		return nil
	end

	local total = 0
	for i = 0, mantissa.size - 1 do
		total = total + mantissa.values[i] * 2^(32 * i + exp)
	end
	return total * mantissa.sign
end

return M