from math import *
import numpy as np
import  matplotlib.pyplot as plt
import scipy.sparse as sp
import scipy.sparse.linalg



#-------------------------------------------------------------------------------
def fem_1D(m, p):
    r"""
    Solve -u'' + u = f on (xleft, xright) = (0, pi) with the p-version FEM on a
    uniform mesh and return the energy of the discrete solution.

    m = parameter for the mesh (here: number of elements of the uniform mesh)
    p = polynomial degree

    Boundary conditions: Dirichlet u = 0 at the left endpt,
    Neumann u' = g at the right endpt.

    numbering of the unknowns is such that the first and the last
    unknown correspond to the left and right endpts

    returns: energy B(u_N, u_N) = \int (u_N')^2 + u_N^2 of the FEM solution u_N
    """
    g = -1                    # value of the Neumann bc at the right endpt

    # define the mesh
    xleft = 0
    xright = pi
    mesh = np.linspace(xleft, xright, m + 1)

    # stiffness matrix and load vector without boundary conditions
    B, l = assemble_neumann_problem(mesh, p)
    B = sp.csr_matrix(B)                     # accept a dense or a sparse B
    l = np.array(l, dtype=float).ravel()     # own copy: l is modified below

    # Neumann bc at the right endpt (last unknown): integrating -u'' v by parts
    # leaves the boundary term u'(xright) v(xright) = g v(xright) on the rhs
    l[-1] += g

    # Dirichlet bc u(xleft) = 0 at the left endpt (first unknown):
    # drop its row and column and solve for the remaining unknowns
    B_free = B[1:, 1:]
    l_free = l[1:]
    u_free = np.atleast_1d(scipy.sparse.linalg.spsolve(B_free.tocsc(), l_free))

    # energy B(u_N, u_N); the Dirichlet unknown is zero and contributes nothing
    energy = float(u_free @ (B_free @ u_free))
    return energy

#-------------------------------------------------------------------------------
def assemble_neumann_problem(mesh, p):
    r"""Assemble the global matrix and load vector without boundary conditions.

    input: a mesh, i.e., a (sorted) list of nodes x1, x2,...,xM
    input: polynomial degree p

    output: B = sparse (N x N) CSR matrix of the bilinear form \int u'v' + uv
                (sum of the element matrices from element_stiffness_matrix)
            l = load vector of length N (sum of the element load vectors)
            where N = total number of DOF as returned by local_to_global_maps

    raises: ValueError if the mesh has fewer than two nodes
    """
    mesh = np.asarray(mesh, dtype=float)
    n_elements = len(mesh) - 1
    if n_elements < 1:
        raise ValueError("the mesh must contain at least two nodes")

    T, n_dof = local_to_global_maps(mesh, p)
    n_loc = p + 1                  # number of shape functions per element

    rows, cols, vals = [], [], []
    l = np.zeros(n_dof)
    for K in range(n_elements):
        vertex_coords = mesh[K:K + 2]
        dofs = T[K, :]
        B_K = np.asarray(element_stiffness_matrix(vertex_coords, p))
        l_K = element_load_vector(vertex_coords, p)
        # B_K.ravel() is row-major: entry (i, j) goes to (dofs[i], dofs[j])
        rows.append(np.repeat(dofs, n_loc))
        cols.append(np.tile(dofs, n_loc))
        vals.append(B_K.ravel())
        # the DOFs of a single element are distinct, so fancy-index += is safe
        l[dofs] += l_K

    # converting COO -> CSR sums the duplicate entries of shared vertex DOFs
    B = sp.coo_matrix(
        (np.concatenate(vals), (np.concatenate(rows), np.concatenate(cols))),
        shape=(n_dof, n_dof),
    ).tocsr()
    return B, l
    
#-------------------------------------------------------------------------------
def local_to_global_maps(mesh, p):
    r"""input: the mesh and the polynomial degree
    output: matrix T, where T[K,i] gives the number of the global DOF corresponding
            to the function N_{i+1} on element K (K and i are 0-based)
            N = total number of DOF
            recall: N_1, N_2 are linears, N_i for i \ge 3 are bubbles
            global numbering: for each element, first the left endpt, then all bubbles, then
            the right endpt; hence the first DOF is the left endpt of the mesh and the
            last DOF is its right endpt

    raises: ValueError if p < 1
    """
    if p < 1:
        raise ValueError("polynomial degree p must be >= 1")
    n_elements = max(len(mesh) - 1, 0)

    # every element contributes p new DOFs (its bubbles and its right endpt),
    # so the left endpt of element K has global number K*p
    left = p * np.arange(n_elements)

    T = np.empty((n_elements, p + 1), dtype=int)
    T[:, 0] = left                               # N_1: left endpt
    T[:, 1] = left + p                           # N_2: right endpt (= next element's left endpt)
    T[:, 2:] = left[:, None] + np.arange(1, p)   # N_3..N_{p+1}: bubbles between the endpts
    N = n_elements * p + 1
    return T, N
#-------------------------------------------------------------------------------
def element_stiffness_matrix(vertex_coords, p):
    r"""
    Compute the element matrix B = K + M of the bilinear form \int u'v' + uv,
    i.e. the element stiffness matrix K plus the element mass matrix M.

     vertex_coords[0], vertex_coords[1] are the two endpoints of the element
     p >= 1 polynomial degree to be employed

    returns: (p+1) x (p+1) array, rows/columns ordered like N_1, ..., N_{p+1}
    """
    h = vertex_coords[1] - vertex_coords[0]
    q = p + 1              # number of Gaussian quadrature points to be employed
    x, w = gauleg(q)       # x = quadrature points on (-1,1); w = quadrature weights (1D arrays)
    PHI = N(x, p + 1)            # PHI[i, k] = N_{i+1}(x_k)
    GRAD_PHI = grad_N(x, p + 1)  # GRAD_PHI[i, k] = N_{i+1}'(x_k)

    # q = p+1 Gauss points integrate polynomials up to degree 2p+1 exactly,
    # so both integrals below are computed exactly.

    # element stiffness matrix: \int_{-1}^1 N_i' N_j' / (h/2)
    # (affine map: d/dx = (2/h) d/dxi and dx = (h/2) dxi)
    K = (GRAD_PHI * w) @ GRAD_PHI.T / (h / 2)

    # element mass matrix: \int_{-1}^1 N_i N_j * (h/2)
    M = (PHI * w) @ PHI.T * (h / 2)

    B = K + M
    return B

#-------------------------------------------------------------------------------
def element_load_vector(vertex_coords, p):
    r"""
    Compute the element load vector l_i = \int_K f N_i dx by Gaussian quadrature.

     vertex_coords[0], vertex_coords[1] are the two endpoints of the element
     p >= 1 polynomial degree to be employed

    returns: array of length p+1, ordered like the shape functions N_1, ..., N_{p+1}
    """
    h = vertex_coords[1] - vertex_coords[0]
    q = p + 1              # number of Gaussian quadrature points to be employed
    x, w = gauleg(q)       # x = quadrature points on (-1,1); w = quadrature weights (1D arrays)
    PHI = N(x, p + 1)      # PHI[i, k] = N_{i+1}(x_k)
    # evaluate the right-hand side at the quadrature points mapped to the element
    F = f(h / 2 * (x + 1) + vertex_coords[0])

    # l[i] = sum_k w[k] * PHI[i, k] * F[k] * h/2, i.e. \int_{-1}^1 f N_i * h/2
    # (h/2 is the Jacobian of the affine map (-1,1) -> element)
    l = PHI @ (w * F) * h / 2
    return l

#-------------------------------------------------------------------------------
def N(x, n):
    r"""
    Calculate the first n trial (shape) functions on the reference element [-1,1].

    x: 1D array-like of points in [-1,1] (a scalar is treated as a single point)
    n: number of shape functions to evaluate, n >= 0

    N_1 = (1-x)/2
    N_2 = (1+x)/2
    N_n = \sqrt{(2n-3)/2}\int_{-1}^x L_{n-2}(t) dt     if n \ge 3

    returns: array U of shape (n, len(x)) with U[i-1, k] = N_i(x[k])
    raises: ValueError if n < 0
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    x = np.atleast_1d(np.asarray(x, dtype=float))
    # always build the two linears, then return only the first n rows (allows n < 2)
    U = np.zeros((max(n, 2), len(x)))
    U[0, :] = (1 - x) / 2
    U[1, :] = (1 + x) / 2
    if n > 2:
        P = legtable(x, n)
        for i in range(3, n + 1):
            k = i - 2
            # \int_{-1}^x L_k(t) dt = (L_{k+1}(x) - L_{k-1}(x)) / (2k+1)
            U[i - 1, :] = sqrt((2 * i - 3) / 2) / (2 * k + 1) * (P[k + 1, :] - P[k - 1, :])
    return U[:n]

#-------------------------------------------------------------------------------
def grad_N(x, n):
    r"""
    Calculate the derivatives of the first n trial (shape) functions on [-1,1].

    x: 1D array-like of points in [-1,1] (a scalar is treated as a single point)
    n: number of shape functions to differentiate, n >= 0

    N_1' = -1/2
    N_2' =  1/2
    N_n' = \sqrt{(2n-3)/2} L_{n-2}(x)     if n \ge 3

    returns: array U of shape (n, len(x)) with U[i-1, k] = N_i'(x[k])
    raises: ValueError if n < 0
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    x = np.atleast_1d(np.asarray(x, dtype=float))
    # always build the two linears, then return only the first n rows (allows n < 2)
    U = np.zeros((max(n, 2), len(x)))
    U[0, :] = -0.5
    U[1, :] = 0.5
    if n > 2:
        P = legtable(x, n)
        for i in range(3, n + 1):
            # derivative of \int_{-1}^x L_{i-2}(t) dt is L_{i-2}(x)
            U[i - 1, :] = sqrt((2 * i - 3) / 2) * P[i - 2, :]
    return U[:n]

#-------------------------------------------------------------------------------
def legtable(x, m):
    r"""
    Tabulate the Legendre polynomials L_0, ..., L_m at the points x.

    input:  vector x \subset (-1,1) of points (1D array-like; a scalar is one point)
    m: the maximal order of the Legendre polynomials, m >= 0
    output: a matrix P of shape (m+1, len(x)) with P[i, k] = L_i(x[k])
    raises: ValueError if m < 0
    """
    if m < 0:
        raise ValueError("m must be non-negative")
    x = np.atleast_1d(np.asarray(x, dtype=float))
    P = np.ones((m + 1, len(x)))
    if m >= 1:
        P[1, :] = x
    # three-term recurrence: i L_i = (2i-1) x L_{i-1} - (i-1) L_{i-2}
    for i in range(2, m + 1):
        P[i, :] = ((2 * i - 1) * x * P[i - 1, :] - (i - 1) * P[i - 2, :]) / i
    return P

def gauleg(n):
    """
    Return the n-point Gauss-Legendre quadrature rule on (-1,1).

    returns: (points, weights), both as 1D arrays of length n
    """
    points, weights = np.polynomial.legendre.leggauss(n)
    return points, weights


#-------------------------------------------------------------------------------
def f(x):
    """
    Right-hand side f(x) = 2 sin(x).

    The exact solution is u(x) = sin(x) on (0, pi), hence -u'' + u = 2 sin(x);
    the exact energy is pi.
    """
    return 2 * np.sin(x)
#-------------------------------------------------------------------------------



# The code for testing
if __name__ == "__main__":
    import shutil

    # render labels with LaTeX only if it is installed; otherwise matplotlib's
    # built-in mathtext handles the $...$ parts (usetex without LaTeX fails)
    plt.rcParams['text.usetex'] = shutil.which('latex') is not None

    pmax = 10
    mmax = 7
    # create convergence plots for different values of p and uniform meshes
    exact_energy = pi
    energy = np.zeros((mmax, pmax))
    h = np.zeros((mmax, pmax))
    for p in range(1, pmax + 1):
        for m in range(mmax):
            M = 2**(m + 1)                    # number of elements
            energy[m, p - 1] = fem_1D(M, p)
            h[m, p - 1] = pi / M              # mesh size of the uniform mesh on (0, pi)

    # energy-norm error: for the Galerkin solution ||u - u_N||_E^2 = E(u) - E(u_N)
    # (up to quadrature errors in the load vector)
    error = np.sqrt(np.abs(exact_energy - energy))
    print(h[:, 1])
    print(error)

    plt.figure(1)
    plt.loglog(h[:, 0], error[:, 0], '*-k',
               h[:, 1], error[:, 1], 'o--b',
               h[:, 2], error[:, 2], 'd:r',
               h[:, 3], error[:, 3], '+-.m')
    plt.xlabel('mesh size h')
    plt.ylabel('energy norm error')
    plt.title(r'$-u^{\prime\prime} + u =f$; solution $u(x) = sin(x)$; uniform meshes')
    plt.legend(['p=1', 'p=2', 'p=3', 'p=4'], loc='lower right')

    plt.figure(2)
    plt.semilogy(range(1, pmax + 1), error[0, :], '*-k')
    plt.axis([1, pmax, 10**(-8), 1])
    plt.xlabel('polynomial degree $p$')
    plt.ylabel('energy norm error')
    plt.title(r'$-u^{\prime\prime} + u =f$; solution $u(x) = sin(x)$; p-method')

    plt.show()
