Edit in response to jonalm's comment:
jonalm: N~3^n not n~3^N. N is max element in a and n is number of
elements in a.
n is ~ 2^20. If N is ~ 3^n then N is ~ 3^(2^20) > 10^(500297).
Scientists estimate (http://www.stormloader.com/ajy/reallife.html) that there are only around 10^87 particles in the universe. So there is no (naive) way a computer can work with something of size 10^(500297), such as an array with one entry for each possible value from 0 to N-1.
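The exponent is easy to double-check with logarithms, without ever building 3^(2^20) itself. This is only a back-of-the-envelope sketch:

```python
import math

n = 2**20
# log10 of N = 3**n, computed without ever constructing N itself
log10_N = n * math.log10(3)
exponent = math.floor(log10_N)   # N lies between 10**exponent and 10**(exponent + 1)
digits = exponent + 1            # number of decimal digits in N
print(f"3**(2**20) is about 10**{log10_N:.1f} ({digits} decimal digits)")
```

So N has about half a million decimal digits, and anything with one slot per value from 0 to N-1 is far out of reach.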
jonalm: I am however a bit curious about the pv() function you define. (I
do not manage to run it as text.find() is not defined (I guess it's in another
module)). How does this function work and what is its advantage?
pv is a little helper function I wrote to debug the value of variables. It works like
print(), except that when you call pv(x) it prints the literal variable name (or expression string), a colon, and then the variable's value.
If you put
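#!/usr/bin/env python
import traceback

def pv(var):
    # Look one frame up the stack: the line of code that called pv(...)
    (filename, line_number, function_name, text) = traceback.extract_stack()[-2]
    # Print whatever sits between the first '(' and the final character, then the value
    print('%s: %s' % (text[text.find('(') + 1:-1], var))

x = 1
pv(x)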
in a script you should get
x: 1
The modest advantage of using pv over print is that it saves you typing. Instead of having to
write
print('x: %s'%x)
you can just slap down
pv(x)
When there are multiple variables to track, it's helpful to label the variables.
I just got tired of writing it all out.
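For what it's worth, Python 3.8 and later build this kind of labeling into f-strings through the `=` specifier. It is not what pv does (it prints `=` rather than `: `, and uses repr() of the value), but on a modern interpreter it is a one-line alternative:

```python
x = 1
y = [1, 2, 3]
# The '=' specifier prints the expression text, '=', then repr() of its value
print(f"{x=}")         # x=1
print(f"{x=}, {y=}")   # x=1, y=[1, 2, 3]
print(f"{x + 1 = }")   # spaces around '=' are preserved: x + 1 = 2
```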
The pv function works by using the traceback module to peek at the line of code
used to call the pv function itself. (See http://docs.python.org/library/traceback.html#module-traceback) That line of code is stored as a string in the variable text.
text.find() is a call to the usual string method find(). For instance, if
text='pv(x)'
then
text.find('(') == 2 # The index of the '(' in string text
text[text.find('(')+1:-1] == 'x' # Everything in between the parentheses
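The slice can be tried on a few sample call lines. `label_from_call` below is a hypothetical name for just the string-handling part of pv. Because the slice assumes the call is the whole line and ends with ')', expressions that contain parentheses still work, but a trailing comment on the same line gets mangled:

```python
def label_from_call(text):
    # Same slice pv uses: everything after the first '(' up to (not including) the last character
    return text[text.find('(') + 1:-1]

print(label_from_call('pv(x)'))            # x
print(label_from_call('pv(a + b)'))        # a + b
print(label_from_call('pv(f(x))'))         # f(x)  (first '(' and final ')' are the outer pair)
print(label_from_call('pv(x)  # debug'))   # x)  # debu  (a trailing comment breaks the slice)
```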
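I'm assuming n ~ 3^N, and n ~ 2**20.
The idea is to work modulo N. Instead of forming all n*n pairwise sums, you only need to count how many times each value 0..N-1 occurs in a and in b, so every array involved has length N.
The second idea (important when n is huge) is to use NumPy ndarrays of dtype 'object', because with a fixed-width integer dtype you run the risk of overflowing the largest value that dtype can hold. An object array stores Python ints, which have arbitrary precision.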
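A minimal illustration of that overflow risk, with a value picked so that int64 overflows:

```python
import numpy as np

big = 2**62
fixed = np.array([big], dtype=np.int64)
exact = np.array([big], dtype=object)

# int64 array arithmetic wraps around silently: 2**64 mod 2**64 == 0
print(fixed * 4)   # [0]
# object arrays hold Python ints, which grow as needed
print(exact * 4)   # [18446744073709551616]
```

One thing to keep in mind in the script that follows: only `result` is an object array. `count*wa` is still computed in the integer dtype that np.bincount returns, so that product is where an overflow would show up if the counts ever got large enough.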
#!/usr/bin/env python
import traceback
import numpy as np

def pv(var):
    (filename, line_number, function_name, text) = traceback.extract_stack()[-2]
    print('%s: %s' % (text[text.find('(') + 1:-1], var))
You can change n to be 2**20, but below I show what happens with small n
so the output is easier to read.
n=100
N=int(np.exp(1./3*np.log(n)))
pv(N)
# N: 4
a=np.random.randint(N,size=n)
b=np.random.randint(N,size=n)
pv(a)
pv(b)
# a: [1 0 3 0 1 0 1 2 0 2 1 3 1 0 1 2 2 0 2 3 3 3 1 0 1 1 2 0 1 2 3 1 2 1 0 0 3
# 1 3 2 3 2 1 1 2 2 0 3 0 2 0 0 2 2 1 3 0 2 1 0 2 3 1 0 1 1 0 1 3 0 2 2 0 2
# 0 2 3 0 2 0 1 1 3 2 2 3 2 0 3 1 1 1 1 2 3 3 2 2 3 1]
# b: [1 3 2 1 1 2 1 1 1 3 0 3 0 2 2 3 2 0 1 3 1 0 0 3 3 2 1 1 2 0 1 2 0 3 3 1 0
# 3 3 3 1 1 3 3 3 1 1 0 2 1 0 0 3 0 2 1 0 2 2 0 0 0 1 1 3 1 1 1 2 1 1 3 2 3
# 3 1 2 1 0 0 2 3 1 0 2 1 1 1 1 3 3 0 2 2 3 2 0 1 3 1]
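The line `N=int(np.exp(1./3*np.log(n)))` takes the floor of the cube root of n in floating point, which gives N=4 for n=100. Floating-point rounding could in principle make that land one below the true root when n is an exact cube (the related expression `1000 ** (1/3)` evaluates to 9.999999999999998, for instance). If that matters, an integer-only version is a short loop; `icbrt` is a hypothetical helper, not a NumPy function:

```python
def icbrt(n):
    """Largest integer r with r**3 <= n, for n >= 0 (hypothetical helper)."""
    r = round(n ** (1.0 / 3))   # float estimate, which may be slightly off
    while r ** 3 > n:           # estimate too big: step down
        r -= 1
    while (r + 1) ** 3 <= n:    # estimate too small: step up
        r += 1
    return r

print(icbrt(100))      # 4, same N as above for n=100
print(icbrt(1000))     # 10
print(icbrt(2**20))    # 101
```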
wa holds the number of 0s, 1s, 2s, 3s in a
wb holds the number of 0s, 1s, 2s, 3s in b
wa = np.bincount(a, minlength=N)  # minlength=N keeps the length N even if some value never occurs
wb = np.bincount(b, minlength=N)
pv(wa)
pv(wb)
# wa: [24 28 28 20]
# wb: [21 34 20 25]
result=np.zeros(N,dtype='object')
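np.bincount only returns as many slots as the largest value present plus one. So if the value N-1 happens not to occur in a, wa comes out shorter than N, and the roll-and-add loop further down either misaligns or fails on the shape mismatch. Passing `minlength` pads the counts:

```python
import numpy as np

a = np.array([0, 1, 1, 0])   # the value 3 never appears
N = 4
print(np.bincount(a))                # [2 2]      length 2, not N
print(np.bincount(a, minlength=N))   # [2 2 0 0]  one slot for each of 0..N-1
```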
Think of a 0 as a token or chip. Similarly for 1, 2 and 3.
Think of wa=[24 28 28 20] as meaning there is a bag with 24 0-chips, 28 1-chips, 28 2-chips and 20 3-chips.
You have a wa-bag and a wb-bag. When you draw a chip from each bag, you add their values together and reduce the sum modulo N; the result is a new chip.
Imagine taking a 1-chip from the wb-bag and adding it to each chip in the wa-bag.
1-chip + 0-chip = 1-chip
1-chip + 1-chip = 2-chip
1-chip + 2-chip = 3-chip
1-chip + 3-chip = 4-chip = 0-chip (we are working modulo N=4)
Since there are 34 1-chips in the wb-bag, when you add them to all the chips in the wa=[24 28 28 20] bag, you get
34*24 1-chips
34*28 2-chips
34*28 3-chips
34*20 0-chips
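In NumPy terms, this "add a 1-chip" step is a multiplication followed by a shift: every count moves one place to the right, and np.roll wraps the count that falls off the end back to the front, which is exactly the "mod N" part. For the 34 1-chips:

```python
import numpy as np

wa = np.array([24, 28, 28, 20])   # counts of 0-, 1-, 2-, 3-chips in a
partial = 34 * wa                 # 34 1-chips from b paired with every chip in a
shifted = np.roll(partial, 1)     # a j-chip plus a 1-chip is a ((j + 1) % 4)-chip
print(partial)   # [816 952 952 680]
print(shifted)   # [680 816 952 952]  the 34*20 from 1 + 3 = 4 = 0 wraps to the front
```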
This is just the partial count due to the 34 1-chips. You also have to handle the other
types of chips in the wb-bag, but this shows you the method used below:
for i, count in enumerate(wb):
    partial_count = count * wa
    pv(partial_count)
    shifted_partial_count = np.roll(partial_count, i)
    pv(shifted_partial_count)
    result += shifted_partial_count
# partial_count: [504 588 588 420]
# shifted_partial_count: [504 588 588 420]
# partial_count: [816 952 952 680]
# shifted_partial_count: [680 816 952 952]
# partial_count: [480 560 560 400]
# shifted_partial_count: [560 400 480 560]
# partial_count: [600 700 700 500]
# shifted_partial_count: [700 700 500 600]
pv(result)
# result: [2444 2504 2520 2532]
This is the final result: of the n*n = 10000 pairwise sums (mod N), 2444 are 0s, 2504 are 1s, 2520 are 2s and 2532 are 3s.
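The loop computes a circular (cyclic) convolution of wa and wb, which takes O(N^2) operations. As an alternative that this answer does not use, NumPy's FFT computes the same circular convolution in O(N log N). The catch is that it works in floating point, so for very large counts the rounding back to integers needs care. A sketch, with `cyclic_counts_fft` as a hypothetical name:

```python
import numpy as np

def cyclic_counts_fft(wa, wb):
    """Circular convolution of two count vectors of equal length N, via FFT.
    Entry k is the number of pairs whose sum is congruent to k mod N."""
    conv = np.fft.ifft(np.fft.fft(wa) * np.fft.fft(wb)).real
    return np.rint(conv).astype(np.int64)   # round away floating-point noise

wa = np.array([24, 28, 28, 20])
wb = np.array([21, 34, 20, 25])
print(cyclic_counts_fft(wa, wb))   # [2444 2504 2520 2532]
```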
# This is a test to make sure the result is correct.
# It uses a very memory-intensive method: c has n*n entries,
# which is far too many when n is large.
if n > 1000:
    print('n is too large to run the check')
else:
    c = (a[:] + b[:, np.newaxis])
    c = c.ravel()
    c = c % N
    result2 = np.bincount(c, minlength=N)
    pv(result2)
    assert np.array_equal(result, result2)
    # result2: [2444 2504 2520 2532]
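Putting the pieces together, here is a sketch of the same method as a reusable function, with `minlength=N` so that missing values cannot shorten the count arrays. It is checked against the brute-force version on a few small random inputs. The names `mod_sum_counts` and `brute_force` are hypothetical, chosen for this sketch:

```python
import numpy as np

def mod_sum_counts(a, b, N):
    """For each k in 0..N-1, count the pairs (x, y), x from a and y from b, with (x + y) % N == k."""
    wa = np.bincount(a, minlength=N)
    wb = np.bincount(b, minlength=N)
    result = np.zeros(N, dtype=object)    # Python ints, so the running totals cannot overflow
    for i, count in enumerate(wb):
        result += np.roll(count * wa, i)  # adding an i-chip shifts counts i places, mod N
    return result

def brute_force(a, b, N):
    """Memory-hungry check: form all n*n sums explicitly."""
    c = (a + b[:, np.newaxis]).ravel() % N
    return np.bincount(c, minlength=N)

rng = np.random.default_rng(0)
for n, N in [(10, 2), (50, 3), (200, 7)]:
    a = rng.integers(N, size=n)
    b = rng.integers(N, size=n)
    assert np.array_equal(mod_sum_counts(a, b, N), brute_force(a, b, N))
print("all checks passed")
```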