Well, I have this bit of code that is slowing down the program hugely: it is linear in complexity, but it is called so many times that the program as a whole becomes quadratic. If possible I would like to reduce its computational complexity; otherwise I'll just optimize it where I can. So far I have reduced it to:

def table(n):
    a = 1
    while 2*a <= n:
        if (-a*a)%n == 1: return a
        a += 1

Anyone see anything I've missed? Thanks!

EDIT: I forgot to mention: n is always a prime number.

EDIT 2: Here is my new improved program (thanks for all the contributions!):

def table(n):
    if n == 2: return 1
    if n%4 != 1: return

    a1 = n-1
    for a in range(1, n//2+1):
        if (a*a)%n == a1: return a

EDIT 3: Testing it in its real context, it is much faster! The question appears solved, but there are many useful answers. I should also mention that, in addition to the optimizations above, I have memoized the function using Python dictionaries.

+5  A: 

Ignoring the algorithm for a moment (yes, I know, bad idea), the running time of this can be decreased hugely just by switching from while to for.

for a in range(1, n // 2 + 1):

(Hope this doesn't have an off-by-one error. I'm prone to making these.)

Another thing I would try is checking whether the step width can be increased.
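A quick way to settle the off-by-one worry is to run both loops side by side (a sketch; the helper names are hypothetical). The range is written with `//` because in Python 3 `n / 2` is a float and `range` rejects it:

```python
def table_while(n):
    # OP's original loop
    a = 1
    while 2 * a <= n:
        if (-a * a) % n == 1:
            return a
        a += 1


def table_for(n):
    # Same bounds (1 .. n // 2), expressed with range()
    for a in range(1, n // 2 + 1):
        if (-a * a) % n == 1:
            return a


def small_primes(limit):
    # Trial division is plenty for a quick check
    return [p for p in range(2, limit)
            if all(p % d for d in range(2, int(p ** 0.5) + 1))]


for p in small_primes(2000):
    assert table_while(p) == table_for(p), p
```

Both loops visit exactly a = 1, ..., n // 2, so they agree on every input.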

Konrad Rudolph
Did you mean the running time can be *decreased*?
Robert Gamble
Thanks; that saves about 20%!
Robert: hmm … I like offering counter-productive hints. ;-) Thanks.
Konrad Rudolph
+2  A: 

It looks like you're trying to find the square root of -1 modulo n. Unfortunately, this is not an easy problem, depending on which values of n are passed to your function. Depending on n, there might not even be a solution. See Wikipedia for more information on this problem.
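To make the "there might not even be a solution" point concrete, here is a brute-force sketch (the helper name is hypothetical, and it is only sensible for small n) that lists which moduli admit a square root of -1:

```python
def has_sqrt_minus_one(n):
    """True if some x satisfies x*x = -1 (mod n); brute force, small n only."""
    return any((x * x + 1) % n == 0 for x in range(n))


solvable = [n for n in range(2, 40) if has_sqrt_minus_one(n)]
print(solvable)
```

Among the primes, only 2 and primes of the form 4k + 1 show up; composite moduli such as 10 (3*3 = 9 = -1 mod 10) can also have solutions.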

Adam Rosenfield
n is always a prime number so that may help.
+1  A: 

Is it possible for you to cache the results?

When you calculate a large n you are given the results for the lower n's almost for free.

Peter Olsson
+2  A: 

(Building on Adam's answer.) Look at the Wikipedia page on quadratic reciprocity:

For an odd prime p, x^2 ≡ −1 (mod p) is solvable if and only if p ≡ 1 (mod 4).

Then you can skip the search for a root entirely for those odd primes n that are not congruent to 1 modulo 4:

def table(n):
    if n == 2: return 1
    if n%4 != 1: return None   # or raise exception
    ...
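A quick sanity check of the criterion (a sketch with hypothetical helper names): for every odd prime below 3000, a brute-force search finds a square root of -1 exactly when p % 4 == 1.

```python
def is_prime(n):
    # Trial division; fine for a small sanity check
    if n < 2:
        return False
    d = 2
    while d * d <= n:
        if n % d == 0:
            return False
        d += 1
    return True


def brute_sqrt_minus_one(p):
    # Smallest x in [1, p // 2] with x*x = -1 (mod p), or None
    for x in range(1, p // 2 + 1):
        if (x * x) % p == p - 1:
            return x
    return None


for p in filter(is_prime, range(3, 3000)):
    found = brute_sqrt_minus_one(p) is not None
    assert found == (p % 4 == 1), p
```

Searching only up to p // 2 is enough because the roots come in pairs x and p - x.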
Federico Ramponi
That is great! Thanks!
+4  A: 

Consider pre-computing the results and storing them in a file. Nowadays many platforms have a huge disk capacity. Then, obtaining the result will be an O(1) operation.
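A minimal sketch of that idea, assuming a JSON file and a list of primes known in advance (the file name and helper names are hypothetical). The table is computed once, written to disk, and later loaded into a dictionary so each lookup is O(1):

```python
import json


def solve(p):
    # Brute force as in OP's Edit 2: smallest a with a*a = -1 (mod p), or None
    if p == 2:
        return 1
    if p % 4 != 1:
        return None
    for a in range(1, p // 2 + 1):
        if (a * a) % p == p - 1:
            return a
    return None


def save_table(primes, path):
    # Compute once and persist; JSON object keys must be strings
    with open(path, "w") as f:
        json.dump({str(p): solve(p) for p in primes}, f)


def load_table(path):
    # Load once at startup into a dict; None round-trips through JSON null
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}
```

This only helps if the set of primes is known ahead of time and computing the table once is affordable.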

Diomidis Spinellis
I wouldn't do it to disk, just to memory.
Loren Pechtel
@me.yahoo.com/loren.pechtel, he means loading the file into memory on startup instead of calculating the table each time you start the application.
strager
That isn't practical: computing them in the first place was too slow. But normally that would be an option.
+2  A: 

Edit 2: Surprisingly, strength-reducing the squaring reduces the time a lot, at least on my Python 2.5 installation. (I'm surprised because I thought interpreter overhead was taking most of the time, and this doesn't reduce the count of operations in the inner loop.) It reduces the time for table(1234577) from 0.572 s to 0.146 s.

def table(n):
    n1 = n - 1
    square = 0
    for delta in xrange(1, n, 2):
        square += delta              # square == k*k (mod n), with k = delta // 2 + 1
        if n <= square: square -= n  # both terms are < n, so one subtraction suffices
        if square == n1: return delta // 2 + 1

strager posted the same idea but I think less tightly coded. Again, jug's answer is best.
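Since `xrange` no longer exists in Python 3, here is a direct port of the strength-reduced loop, checked against plain squaring (a sketch; `table_direct` is a hypothetical reference helper used only for the comparison). It relies on 1 + 3 + ... + (2k - 1) = k*k, so after adding `delta = 2k - 1` the running sum is k*k modulo n:

```python
def table(n):
    """Smallest k with k*k = -1 (mod n), using only additions in the loop."""
    n1 = n - 1
    square = 0
    for delta in range(1, n, 2):
        square += delta          # now congruent to k*k (mod n), k = delta // 2 + 1
        if square >= n:          # square and delta are both < n, so one subtraction suffices
            square -= n
        if square == n1:
            return delta // 2 + 1
    return None


def table_direct(n):
    # Reference implementation: square each candidate explicitly
    for a in range(1, n // 2 + 1):
        if (a * a) % n == n - 1:
            return a
    return None
```

Both scan candidates in increasing order, so they return the same (smallest) root.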

Original answer: Another trivial coding tweak on top of Konrad Rudolph's:

def table(n):
    n1 = n - 1
    for a in xrange(1, n // 2 + 1):
        if (a*a) % n == n1: return a

Speeds it up measurably on my laptop. (About 25% for table(1234577).)

Edit: I didn't notice the python3.0 tag; but the main change was hoisting part of the calculation out of the loop, not the use of xrange. (Academic since there's a better algorithm.)
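To measure the hoisting effect yourself, a rough harness like this can be used (a sketch with hypothetical function names; absolute numbers depend on the machine and Python version, and nothing here reproduces the figures quoted in this thread). It times the original test `(-a*a) % n == 1` against the hoisted `(a*a) % n == n1`:

```python
import timeit


def table_unhoisted(n):
    for a in range(1, n // 2 + 1):
        if (-a * a) % n == 1:      # negation evaluated on every iteration
            return a


def table_hoisted(n):
    n1 = n - 1                     # computed once, outside the loop
    for a in range(1, n // 2 + 1):
        if (a * a) % n == n1:
            return a


if __name__ == "__main__":
    n = 1234577  # the value used for the timings in this thread
    assert table_unhoisted(n) == table_hoisted(n)
    for f in (table_unhoisted, table_hoisted):
        print(f.__name__, timeit.timeit(lambda: f(n), number=1))
```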

Darius Bacon
This speedup should not concern Python 3.0, only older versions. However, using `//` is of course superior.
Konrad Rudolph
Hoisting n-1 (or -a) out of the loop was the change I had in mind. That accounts for the majority of the speedup.
Darius Bacon
+5  A: 

Take a look at http://modular.fas.harvard.edu/ent/ent_py . The function sqrtmod does the job if you set a = -1 and p = n.

A small point you missed is that the running time of your improved algorithm is still in the order of the square root of n. As long as you only have small primes n (let's say less than 2^64) that's OK, and you should probably prefer your implementation to a more complex one.

If the prime n becomes bigger, you might have to switch to an algorithm using a little bit of number theory. To my knowledge your problem can be solved only with a probabilistic algorithm in time log(n)^3. If I remember correctly, assuming the Riemann hypothesis holds (which most people do), one can show that the running time of the following algorithm (in Ruby; sorry, I don't know Python) is log(log(n))*log(n)^3:

class Integer
  # calculate b to the power of e modulo self
  def power(b, e)
    raise 'power only defined for integer base' unless b.is_a? Integer
    raise 'power only defined for integer exponent' unless e.is_a? Integer
    raise 'power is implemented only for non-negative exponent' if e < 0
    return 1 if e.zero?
    x = power(b, e>>1)
    x *= x
    (e & 1).zero? ? x % self : (x*b) % self
  end
  # Fermat test (probabilistic prime number test): b**(n-1) == 1 (mod n) when n is prime
  def prime?(b = 2)
    raise "base must be at least 2 in prime?" if b < 2
    raise "base must be an integer in prime?" unless b.is_a? Integer
    power(b, self - 1) == 1
  end
  # find square root of -1 modulo prime
  def sqrt_of_minus_one
    return 1 if self == 2
    return false if (self & 3) != 1
    raise 'sqrt_of_minus_one works only for primes' unless prime?
    # now just try all numbers (each succeeds with probability 1/2)
    2.upto(self) do |b|
      e = self >> 1
      e >>= 1 while (e & 1).zero?
      x = power(b, e)
      next if [1, self-1].include? x
      loop do
        y = (x*x) % self
        return x if y == self-1
        raise 'sqrt_of_minus_one works only for primes' if y == 1
        x = y
      end
    end
  end
end

# find a prime
p = loop do
      x = rand(1<<512)
      next if (x & 3) != 1
      break x if x.prime?
    end

puts "%x" % p
puts "%x" % p.sqrt_of_minus_one

The slow part is now finding the prime (which takes approx. log(n)^4 integer operations); finding the square root of -1 still takes less than a second, even for 512-bit primes.
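For readers who prefer Python, here is a rough port of the core of the Ruby `sqrt_of_minus_one` (a sketch: it uses Python's built-in three-argument `pow`, skips the primality test and simply assumes `p` is prime; the function name is hypothetical). Write p - 1 = 2^s * m with m odd. For a quadratic non-residue b, x = b^m has order exactly 2^s, so repeated squaring reaches -1, and the value just before that is a square root of -1:

```python
def sqrt_of_minus_one(p):
    """Return some x with x*x = -1 (mod p), assuming p is prime."""
    if p == 2:
        return 1
    if p % 4 != 1:
        return None              # no solution for primes p = 3 (mod 4)
    m = p - 1
    while m % 2 == 0:            # strip factors of 2: p - 1 = 2**s * m with m odd
        m //= 2
    for b in range(2, p):
        x = pow(b, m, p)
        if x in (1, p - 1):
            continue             # squaring 1 or -1 only yields 1, so try the next base
        while True:
            y = x * x % p
            if y == p - 1:
                return x         # x*x = -1 (mod p)
            if y == 1:
                raise ValueError("p is not prime")
            x = y
    return None
```

It returns one of the two roots, not necessarily the smaller one; use `min(x, p - x)` to match OP's function.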

+1  A: 

One thing you are doing is repeating the calculation -a*a over and over again.

Create a table of the values once and then do lookups in the main loop.

Also, although this probably doesn't apply to you since your function is named table: if you call a function that takes time to calculate, you should cache the result in a table and just do a table lookup when you call it again with the same value. This saves you from calculating all of the values up front when you first run, and you still never waste time repeating a calculation.

Rex Logan
A lookup would probably be slower than a multiply. Multiply is pretty cheap compared to an unsigned modulus.
strager
It's academic since there's a better algorithm, but my answer hoisted part of the -a*a calculation out of the loop.
Darius Bacon
+1  A: 

Based off OP's second edit:

def table(n):
    if n == 2: return 1
    if n%4 != 1: return

    mod = 0
    a1 = n - 1
    for a in xrange(1, a1, 2):
        mod += a

        while mod >= n: mod -= n
        if mod == a1: return a//2 + 1
strager
should be while mod >= n: mod -= n
Rex Logan
@Rex Logan, Hmm... I guess you're right, but when I tested values 0..31337 they all worked. I think this is because a1=n-1: if it was a1=n-1 I could perhaps do mod>n-1. Weird side effect in my favor, I guess. I'll correct it regardless. Thanks!
strager
A: 

To squeeze out even more: since a*a >= n-1, the smallest candidate is a = sqrt(n-1). Starting with strager's code:

import math

def table(n):
    if n == 2: return 1
    if n%4 != 1: return

    a1 = n - 1
    a0 = int(math.sqrt(a1))-1
    mod = (a0*a0) % n
    a0 = 2*a0 + 1

    for a in range(a0, a1, 2):
        mod += a
        if mod >= n: mod -= n
        if mod == a1: return a//2 + 1

Edit:

Out of curiosity I came up with two more versions that are faster than the function above. The first one comes from combining the formula for a sum of odd numbers with the division-with-remainder equation, and from the fact that for the test to succeed everything must be an integer.
The equation is ((a+1)^2)/4 = n-1 + q*n, so a = 2*sqrt(n-1 + q*n) - 1,
so for a to be an integer, sqrt(n-1 + q*n) has to be an integer. If you check the values of q and find one whose square root is an integer, then you have your answer (that square root, (a+1)/2, is the value returned). I also noticed that the answers tend to be at the high end, so I go backwards through the loop. Although faster, this one is slowed down by the sqrt.
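Spelled out, here is a worked restatement of that equation in the loop's notation, where a runs over odd numbers and k = (a + 1)/2 is the returned value:

```latex
\begin{aligned}
1 + 3 + \dots + a &= \left(\tfrac{a+1}{2}\right)^2 = k^2, \qquad k = \tfrac{a+1}{2},\\
k^2 \equiv n - 1 \pmod{n} &\iff k^2 = n - 1 + q\,n \ \text{ for some integer } q \ge 0,\\
&\implies k = \sqrt{n - 1 + q\,n}, \qquad a = 2k - 1.
\end{aligned}
```

Requiring the smaller root, k ≤ (n - 1)/2, bounds q by (n - 1)(n - 5)/(4n); the code's `qm = ((n-2)*(n-2))//(4*n)` is a slightly looser bound that still guarantees k ≤ n/2. Note that q = 0 is a valid case (for example n = 5 or n = 17, where n - 1 is itself a square).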

import math

def table(n):
    if n == 2: return 1
    if n%4 != 1: return


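    s = [ 0 , 1 , 4 , 9 ]  # a perfect square can only end in these nibbles (squares mod 16)

    a1 = n - 1
    qm = ((n-2)*(n-2))//(4*n)
    qn = qm*n +a1
    for q in range(qm, -1, -1):  # include q = 0 (n - 1 itself may be a square, e.g. n = 5, 17)
        qn_l = qn & 0x0f
        if qn_l in s : #only check values that might pass test
            sq = int(round(math.sqrt(qn)))
            if sq*sq == qn:  # exact integer check; a float-only test is unreliable for large qn
                return sq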
        qn -= n

The next function, which in my limited testing is the fastest, takes the first version above and reverses the loop so that it starts at the top of the range and goes backwards. At least for my test cases, it is about 10 times faster than the solution posted by the OP.

def table(n):
    if n == 2: return 1
    if n%4 != 1: return

    a1 = n - 1
    a0 = a1 // 2
    mod = (a0*a0) % n
    #start at the top and go backwards
    for a in range(n-2, 0, -2):
        if mod == a1:
            return a//2 + 1
        mod -= a
        if mod < 0: mod += n
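Since the "about 10 times faster" figure comes from limited tests, here is a small harness (a sketch; the function names are hypothetical) that first checks the reversed loop against OP's Edit 2 on every prime below 5000 and then times both. The reversal only pays off if the smaller root tends to sit near n/2 for the inputs you care about:

```python
import timeit


def table_op(n):
    # OP's Edit 2
    if n == 2:
        return 1
    if n % 4 != 1:
        return None
    a1 = n - 1
    for a in range(1, n // 2 + 1):
        if (a * a) % n == a1:
            return a


def table_reversed(n):
    # Backwards loop: mod starts at ((n - 1) / 2)**2 mod n and walks down
    if n == 2:
        return 1
    if n % 4 != 1:
        return None
    a1 = n - 1
    a0 = a1 // 2
    mod = (a0 * a0) % n
    for a in range(n - 2, 0, -2):
        if mod == a1:
            return a // 2 + 1
        mod -= a                 # k*k - (2k - 1) == (k - 1)**2
        if mod < 0:
            mod += n


def is_prime(n):
    return n > 1 and all(n % d for d in range(2, int(n ** 0.5) + 1))


for p in filter(is_prime, range(2, 5000)):
    assert table_op(p) == table_reversed(p), p

if __name__ == "__main__":
    n = 1234577
    for f in (table_op, table_reversed):
        print(f.__name__, timeit.timeit(lambda: f(n), number=1))
```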
Rex Logan
+1  A: 

Does Stack Overflow need a 'refactor my code' label?

Gabriel
+1  A: 

I went through and fixed the Harvard version to make it work with Python 3: http://modular.fas.harvard.edu/ent/ent_py

I made some slight changes to make the results exactly the same as the OP's function. There are two possible answers and I forced it to return the smaller answer.

import timeit

def table(n):

    if n == 2: return 1
    if n%4 != 1: return

    a1=n-1

    def inversemod(a, p):
        x, y = xgcd(a, p)
        return x%p

    def xgcd(a, b):
        x_sign = 1
        if a < 0: a = -a; x_sign = -1
        x = 1; y = 0; r = 0; s = 1
        while b != 0:
            (c, q) = (a%b, a//b)
            (a, b, r, s, x, y) = (b, c, x-q*r, y-q*s, r, s)
        return (x*x_sign, y)

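    def mul(x, y):
        # product of x[0] + x[1]*t and y[0] + y[1]*t modulo n, where t*t = a1 = -1 (mod n)
        return ((x[0]*y[0]+a1*y[1]*x[1])%n,(x[0]*y[1]+x[1]*y[0])%n)

    def pow(x, nn):
        # square-and-multiply exponentiation in that ring (shadows the builtin pow)
        ans = (1,0)
        xpow = x
        while nn != 0:
            if nn%2 != 0:
                ans = mul(ans, xpow)
            xpow = mul(xpow, xpow)
            nn >>= 1
        return ans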

    for z in range(2,n) :
        u, v = pow((1,z), a1//2)
        if v != 0:
            vinv = inversemod(v, n)
            if (vinv*vinv)%n == a1:
                vinv %= n
                if vinv <= n//2:
                    return vinv
                else:
                    return n-vinv


tt=0
pri = [ 5,13,17,29,37,41,53,61,73,89,97,1234577,5915587277,3267000013,3628273133,2860486313,5463458053,3367900313 ]
for x in pri:
    t=timeit.Timer('q=table('+str(x)+')','from __main__ import table')
    tt +=t.timeit(number=100)
    print("table(",x,")=",table(x))

print('total time=',tt/100)

This version takes about 3ms to run through the test cases above.

For comparison, using the prime number 1234577:
OP's Edit 2: 745 ms
The accepted answer: 522 ms
The above function: 0.2 ms
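On newer Pythons the same idea can be written more compactly (a sketch, assuming Python 3.8+ for the modular inverse `pow(v, -1, p)`, which replaces `xgcd`/`inversemod`; the function name is hypothetical). It works with pairs (u, v) standing for u + v*t where t*t = -1, raises 1 + z*t to the power (p - 1)/2, and whenever v != 0 checks whether the inverse of v squares to -1, exactly as the version above does:

```python
def sqrt_minus_one(p):
    """Smaller square root of -1 modulo the prime p, or None if there is none."""
    if p == 2:
        return 1
    if p % 4 != 1:
        return None

    def mul(x, y):
        # (x0 + x1*t) * (y0 + y1*t) with t*t = -1, reduced mod p
        return ((x[0] * y[0] - x[1] * y[1]) % p,
                (x[0] * y[1] + x[1] * y[0]) % p)

    def power(x, e):
        # square-and-multiply in the ring of pairs
        result = (1, 0)
        while e:
            if e & 1:
                result = mul(result, x)
            x = mul(x, x)
            e >>= 1
        return result

    for z in range(1, p):
        _, v = power((1, z), (p - 1) // 2)
        if v:
            r = pow(v, -1, p)            # modular inverse (Python 3.8+)
            if r * r % p == p - 1:
                return min(r, p - r)     # match OP's convention: the smaller root
    return None
```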

Rex Logan