"""
Simulation of interferometer for measuring two-state system ('spinometer')
See pps 1-4, Fig. 1 in Sidles 1996, http://xxx.lanl.gov/pdf/quant-ph/9612001
"""

import random
from math import sin, cos, pi
from scipy import matrix
from scipy.linalg import norm

# Constants

i = 1j  # imaginary unit; read as a module global by A_matrix/B_matrix, so don't reuse the name

# Quantum mechanics

I = matrix([[1, 0], [0, 1]])    # 2x2 identity matrix (integer entries)
Sz = matrix([[1, 0], [0, -1]])  # Pauli spin matrix sigma_z (integer entries)

def probability(v):
    """Return the squared norm of v (sum of |v_k|^2), i.e. the probability weight of state v."""
    magnitude = norm(v)  # Euclidean/Frobenius norm from scipy.linalg
    return magnitude**2

def normalize(v):
    """
    Return v scaled to unit norm.

    Raises ValueError for a zero vector: it has no direction, and dividing by
    its zero norm would otherwise silently produce a vector of NaNs that
    propagates through the rest of the simulation.
    """
    length = norm(v)
    if length == 0:
        raise ValueError('cannot normalize a zero vector')
    return v/length

def polarization(v):
    """
    Qubit polarization <v|Sz|v> (eqn 8, p. 4).

    v is a 2x1 state vector (matrix). The result is a 1x1 matrix, not a bare
    scalar; callers extract the value with (...).real[0,0].
    """
    bra = v.H  # conjugate transpose of the column vector
    return bra*Sz*v

def polarize(p):
    """
    Given p in -1..1, return a unit-norm state v = [ a;b ] whose polarization(v) is p.

    With Sz = diag(1, -1), polarization(v) = |a|^2 - |b|^2. Choose real,
    non-negative a = sqrt((1+p)/2) and b = sqrt((1-p)/2). Then
    a^2 - b^2 == p and a^2 + b^2 == 1. The state is therefore already
    normalized, and normalizing it again (as initialize does) leaves its
    polarization equal to p.

    (The earlier choice a = (1+p)/2, b = (1-p)/2 satisfied a - b == p and
    a + b == 1. It gave polarization p only before normalization; after
    normalization the polarization became 2p/(1+p^2).)

    Raises ValueError if p is outside [-1, 1] (or is NaN), since no state
    has such a polarization.
    """
    if not -1.0 <= p <= 1.0:
        raise ValueError('polarization must be in -1..1, got %r' % (p,))
    # list form rather than the matrix('...') string form, which can't hold expressions
    return matrix([[((1.0 + p) / 2.0) ** 0.5], [((1.0 - p) / 2.0) ** 0.5]])

# A and B matrices, as in eqns 4 and 5 on p. 3 of the paper,
# EXCEPT with a revised phase convention: a factor of i in the 2nd term.

def A_matrix(theta):
    """
    Detector-a matrix: (1/2) * ((cos(theta) + i) I + i sin(theta) Sz).

    Compare eqns 4-5 of the paper, here with the revised phase convention.
    Since I and Sz are diagonal, so is the result. Together with
    B_matrix(theta) it satisfies A^H A + B^H B = I, so the a/b detection
    probabilities sum to 1.
    """
    c, s = cos(theta), sin(theta)
    identity_part = (c + i)*I
    z_part = i*s*Sz
    return (1./2) * (identity_part + z_part)

def B_matrix(theta):
    """
    Detector-b matrix: (1/2) * ((cos(theta) - i) I + i sin(theta) Sz).

    Same as A_matrix except for the sign of i in the identity coefficient
    (compare eqns 4-5 of the paper, with the revised phase convention).
    The result is diagonal.
    """
    c, s = cos(theta), sin(theta)
    identity_part = (c - i)*I
    z_part = i*s*Sz
    return (1./2) * (identity_part + z_part)

# global variables

# Simulation state, (re)assigned by initialize(): the current spin state vector
# and the per-photon A/B matrices. All are None until initialize() runs.
spin = A = B = None

# functions

def initialize(polarization, T1, seed):
    """
    Reset the module-level simulation state: spin, A and B.

    polarization: in range -1 (down) .. 1 (up). It is passed to polarize() to
        build the initial spin amplitudes, which are then normalized.
        (Inside this function the name shadows the module-level
        polarization() function; that function is not used here.)
    T1: time constant. The per-photon phase shift is theta = 3/(4*T1), so T1
        must be nonzero.
    seed: random seed. Any truthy value seeds the random module so runs are
        repeatable; a falsy value (0 or None) leaves the generator alone.
    """
    global spin, A, B
    spin = normalize(polarize(polarization))
    # Author's note: the 3/4 factor does indeed give the expected
    # (presumably 1 - exp(-t/T1)) approach.
    theta = 3./(4*T1)
    A = A_matrix(theta)
    B = B_matrix(theta)
    if seed:
        random.seed(seed)  # make runs repeatable

def trial():
    """
    Send one photon through the interferometer (cf. eq. 1, p. 2) and update the spin.

    Detector 'a' fires with probability |A*spin|^2; otherwise detector 'b'
    fires. The global spin is replaced by the normalized state A*spin or
    B*spin, depending on which detector fired. (spin is kept normalized by
    initialize() and by this function.)

    Returns (detector, pol): detector is 'a' or 'b', and pol is the real
    scalar polarization of the updated spin.
    """
    global spin
    candidate = A*spin
    if random.random() < probability(candidate):
        detector = 'a'
    else:
        detector = 'b'
        candidate = B*spin
    spin = normalize(candidate)
    return detector, (polarization(spin).real)[0,0]

def trial_a():
    """
    Deterministic variant of trial(): the photon always goes to detector a.

    The spin is updated to normalize(A*spin) with no random draw. Repeated
    calls show how T1 (through the phase shift in A) drives the polarization.
    Returns ('a', pol), where pol is the real scalar polarization of the
    updated spin.
    """
    global spin
    spin = normalize(A*spin)
    pol = polarization(spin).real  # 1x1 matrix
    return 'a', pol[0,0]
