import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg


#--------------------------------------------------------
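"""
The geometry is defined by three lists:
  nodes    -- one row per node:    [node number, x, y]
  elements -- one row per element: [element number, vertex 1, vertex 2, vertex 3]
  edges    -- one row per edge:    [edge number, endpoint 1, endpoint 2, boundary flag]
              (a nonzero boundary flag marks a Dirichlet edge)
"""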
def fem(m):
    """Solve the model problem on a uniform m x m mesh of the unit square.

    Assembles the P1 stiffness matrix and load vector including all nodes,
    deletes the Dirichlet nodes (i.e. imposes u_h = 0 there), solves
    B_D u_D = l_D and returns the discrete energy
    a(u_h, u_h) = u_D^T B_D u_D, which equals l_D^T u_D because B_D u_D = l_D.

    Parameters
    ----------
    m : int
        Number of cells per coordinate direction (passed to geometry_uniform_mesh).

    Returns
    -------
    float
        The discrete energy; 0.0 if there are no active (non-Dirichlet) nodes.
    """
    nodes, elements, edges = geometry_uniform_mesh(m)
    dirichlet_nodes = determine_dirichlet_nodes(edges)

    # Stiffness matrix and load vector INCLUDING the boundary nodes ...
    B, l = assemble_neumann_problem(nodes, elements)
    # ... then delete the rows/columns that belong to Dirichlet nodes.
    BD, lD = apply_Dirichlet_bc(B, l, dirichlet_nodes, nodes)

    if lD.size == 0:
        # e.g. m = 1: every node lies on the Dirichlet boundary, so u_h = 0.
        return 0.0

    uD = scipy.sparse.linalg.spsolve(BD, lD)
    energy = float(lD @ uD)
    return energy

def plot_function(nodes,elements,u):
    """Plot the FEM function u as a surface over the triangular mesh.

    nodes    -- rows [node number, x, y]; columns 1 and 2 give the coordinates.
    elements -- rows [element number, vertex 1, vertex 2, vertex 3]; the last
                three columns are used as the triangle connectivity.
    u        -- nodal values, used as the surface height.
    Draws into matplotlib figure 1 and shows it.
    """
    import matplotlib.pyplot as plt
    x_coords = nodes[:, 1]
    y_coords = nodes[:, 2]
    triangles = elements[:, 1:]
    figure = plt.figure(1)
    axes = figure.add_subplot(projection='3d')
    axes.plot_trisurf(x_coords, y_coords, triangles, Z=u)
    plt.show()
    
    

#---------------------------------------------------------
#subroutines for FEM
#---------------------------------------------------------
#---------------------------------------------------------
def apply_Dirichlet_bc(B,l,dirichlet_nodes,nodes):
    """Restrict the system (B, l) to the active, i.e. non-Dirichlet, nodes.

    The rows and columns of B and the entries of l that belong to Dirichlet
    nodes are deleted. ``nodes`` is not used; it is kept for the call
    signature.

    Returns
    -------
    (BD, lD) : the reduced matrix and the reduced right-hand side.
    """
    n_dofs = B.shape[0]
    keep = np.zeros(n_dofs, dtype=bool)
    keep[determine_active_nodes(n_dofs, dirichlet_nodes)] = True

    reduced_rows = B[keep]          # drop Dirichlet rows ...
    BD = reduced_rows[:, keep]      # ... then Dirichlet columns
    lD = l[keep]
    return BD, lD


def assemble_neumann_problem(nodes,elements):
    """Assemble the *unconstrained* stiffness matrix B and load vector l.

    Possible Dirichlet nodes are ignored here; they are removed afterwards
    (see apply_Dirichlet_bc).

    Parameters
    ----------
    nodes : ndarray
        Rows [node number, x, y]. Row k must describe node k (node numbering
        starts at 0).
    elements : ndarray
        Rows [element number, vertex 1, vertex 2, vertex 3]; may be a float
        array, the vertex numbers are cast to int.

    Returns
    -------
    B : scipy.sparse csr_array, shape (N, N), N = number of nodes
    l : ndarray, shape (N,)
    """
    N = nodes.shape[0]
    i_ = list()
    j_ = list()
    vals = list()

    l = np.zeros(N)

    for element in elements:
        # Global numbers of the three vertices of this element.
        local_node = np.asarray(element[1:4], dtype=int)
        # 2 x 3 array, one column per vertex (the layout the element routines expect).
        vertex_coords = np.vstack((nodes[local_node, 1], nodes[local_node, 2]))

        BK = element_stiffness_matrix(vertex_coords)
        lK = element_load_vector(vertex_coords)

        # COO triplets: entry (a, b) of BK contributes to
        # B[local_node[a], local_node[b]]. Indices that occur several times
        # are summed when the COO matrix is built.
        i_.extend(np.repeat(local_node, 3))
        j_.extend(np.tile(local_node, 3))
        vals.extend(BK.ravel())

        # The three vertices of a (non-degenerate) triangle are distinct, so
        # this fancy-index update does not hit any index twice.
        l[local_node] += lK

    # Finally convert to CSR, which is better suited for slicing and solving.
    B = sp.coo_array((vals, (i_, j_)), shape=(N, N)).tocsr()

    return B, l



#---------------------------------------------------------
def element_stiffness_matrix(vertex_coords):
    """Return the 3 x 3 P1 stiffness matrix of one triangle.

    Parameters
    ----------
    vertex_coords : array, shape (2, 3)
        Column k holds the (x, y) coordinates of vertex k.

    Returns
    -------
    BK : ndarray, shape (3, 3)
        BK[i, j] = integral over K of grad(phi_i) . grad(phi_j). Only the
        stiffness matrix is computed (no mass matrix).

    Raises numpy.linalg.LinAlgError for a degenerate (zero-area) triangle.
    """
    v1 = vertex_coords[:, 0]
    v2 = vertex_coords[:, 1]
    v3 = vertex_coords[:, 2]

    # Affine map F_K(xhat) = v1 + FKprime @ xhat from the reference triangle.
    FKprime = np.column_stack((v2 - v1, v3 - v1))
    detFKprime = abs(np.linalg.det(FKprime))

    # Rows: reference gradients of N1, N2, N3 (constant for P1 elements).
    ref_grads = np.array([[-1.0, -1.0],
                          [1.0, 0.0],
                          [0.0, 1.0]])
    # grad(phi_i) = FKprime^{-T} grad(N_i); written row-wise: ref_grads @ FKprime^{-1}.
    grads = ref_grads @ np.linalg.inv(FKprime)

    # The integrand is constant, and |K| = detFKprime / 2.
    BK = 0.5 * detFKprime * (grads @ grads.T)
    return BK
#---------------------------------------------------------
def element_load_vector(vertex_coords, rhs=None, n_gauss=4):
    """Return the 3-entry P1 load vector lK[i] = integral over K of rhs * phi_i.

    Parameters
    ----------
    vertex_coords : array, shape (2, 3)
        Column k holds the (x, y) coordinates of vertex k.
    rhs : callable, optional
        Right-hand side evaluated at a single point p (accessed as p[0], p[1]).
        Defaults to the module-level function ``f``.
    n_gauss : int, optional
        Gauss-Legendre points per direction. The default 4 integrates
        polynomial integrands of total degree <= 6 exactly; rhs = f (quadratic)
        times a P1 basis function is cubic.

    Returns
    -------
    lK : ndarray, shape (3,)
    """
    if rhs is None:
        rhs = f

    v1 = vertex_coords[:, 0]
    v2 = vertex_coords[:, 1]
    v3 = vertex_coords[:, 2]

    FKprime = np.column_stack((v2 - v1, v3 - v1))
    detFKprime = abs(np.linalg.det(FKprime))

    # Tensor Gauss-Legendre rule on [0,1]^2, collapsed onto the reference
    # triangle via xi = s, eta = (1 - s) * t (Jacobian 1 - s).
    t, w = np.polynomial.legendre.leggauss(n_gauss)
    t = 0.5 * (t + 1.0)
    w = 0.5 * w
    s_grid, t_grid = np.meshgrid(t, t, indexing="ij")
    xi = s_grid.ravel()
    eta = ((1.0 - s_grid) * t_grid).ravel()
    weights = (np.outer(w, w) * (1.0 - s_grid)).ravel()

    # Reference shape functions N1, N2, N3 at the quadrature points, shape (3, nq).
    shape_vals = np.vstack((1.0 - xi - eta, xi, eta))
    # Quadrature points mapped into the physical triangle, shape (2, nq).
    points = v1[:, None] + FKprime @ np.vstack((xi, eta))
    f_vals = np.array([rhs(p) for p in points.T], dtype=float)

    lK = detFKprime * (shape_vals @ (weights * f_vals))
    return lK

#---------------------------------------------------------
def determine_dirichlet_nodes(edges):
    """Return the nodes that are endpoints of Dirichlet edges.

    Parameters
    ----------
    edges : ndarray
        Rows [edge number, endpoint 1, endpoint 2, flag]; a nonzero flag marks
        a Dirichlet edge.

    Returns
    -------
    [] if there is no Dirichlet edge. Otherwise a (k, 1) integer array, as
    returned by np.argwhere, holding each Dirichlet node once in increasing
    order.
    """
    is_dirichlet_edge = edges[:, 3] != 0
    if not np.any(is_dirichlet_edge):
        return []

    endpoints = edges[is_dirichlet_edge][:, 1:3].astype(int)

    # Mark every endpoint in a boolean table; this removes duplicates and sorts.
    marker = np.zeros(endpoints.max() + 1, dtype=bool)
    marker[endpoints] = True
    return np.argwhere(marker)

#---------------------------------------------------------
def determine_active_nodes(N,dirichlet_nodes):
    """Return the active nodes, i.e. those not on the Dirichlet part of the boundary.

    Parameters
    ----------
    N : int
        Total number of nodes; nodes are numbered 0, ..., N-1.
    dirichlet_nodes : sequence or array of int
        Node numbers on the Dirichlet part of the boundary. Any shape is
        accepted, e.g. the (k, 1) array from determine_dirichlet_nodes. An
        empty sequence means there are no Dirichlet nodes.

    Returns
    -------
    numpy.ndarray
        Sorted 1-D integer array of active node numbers. It is usable directly
        as an index and ``len`` gives the number of active nodes.
    """
    is_active = np.ones(N, dtype=bool)
    if len(dirichlet_nodes) > 0:
        is_active[np.asarray(dirichlet_nodes, dtype=int).ravel()] = False
    return np.flatnonzero(is_active)

#---------------------------------------------------------
def compute_integral(nodes,elements,u):
    """Compute the integral of u over Omega, for u given by its nodal values.

    On each triangle K the vertex (trapezoidal) rule
        |K| / 3 * (u(v1) + u(v2) + u(v3)),   |K| = |det FK'| / 2,
    is used. It is exact for piecewise-linear (P1) functions.

    Parameters
    ----------
    nodes : ndarray
        Rows [node number, x, y]; row k describes node k.
    elements : ndarray
        Rows [element number, vertex 1, vertex 2, vertex 3] (float or int).
    u : array_like
        Nodal values; a column vector of shape (N, 1) is accepted as well.

    Returns
    -------
    float
    """
    tri = np.asarray(elements)[:, 1:4].astype(int)    # (n_elements, 3) vertex numbers
    node_arr = np.asarray(nodes)
    ex = node_arr[:, 1][tri]                          # x of each vertex
    ey = node_arr[:, 2][tri]                          # y of each vertex

    # |det FK'| with FK' = [v2 - v1, v3 - v1], for all elements at once.
    detFKprime = np.abs((ex[:, 1] - ex[:, 0]) * (ey[:, 2] - ey[:, 0])
                        - (ex[:, 2] - ex[:, 0]) * (ey[:, 1] - ey[:, 0]))

    u_vertices = np.ravel(u)[tri]
    return float(np.sum(detFKprime / 6.0 * u_vertices.sum(axis=1)))

#---------------------------------------------------------
def N1(x):
    """Reference shape function N1(xi, eta) = 1 - xi - eta (equals 1 at vertex (0, 0))."""
    return 1 - x[0] - x[1]

#---------------------------------------------------------
def grad_N1(x):
    """Gradient of N1 = 1 - xi - eta on the reference triangle.

    N1 is linear, so the gradient is the constant (-1, -1). x is accepted
    only so the call signature matches N1.
    """
    return np.asarray([-1, -1])
#---------------------------------------------------------
#---------------------------------------------------------
def N2(x):
    """Reference shape function N2(xi, eta) = xi (equals 1 at vertex (1, 0))."""
    return x[0]
#---------------------------------------------------------
def grad_N2(x):
    """Gradient of N2 = xi on the reference triangle: the constant (1, 0).

    x is accepted only so the call signature matches N2.
    """
    return np.asarray([1, 0])
#---------------------------------------------------------
#---------------------------------------------------------
def N3(x):
    """Reference shape function N3(xi, eta) = eta (equals 1 at vertex (0, 1))."""
    return x[1]
#---------------------------------------------------------
def grad_N3(x):
    """Gradient of N3 = eta on the reference triangle: the constant (0, 1).

    x is accepted only so the call signature matches N3.
    """
    return np.asarray([0, 1])
#---------------------------------------------------------
def f(x):
    """Right-hand side f(x1, x2) = 2*(x1*(1 - x1) + x2*(1 - x2)).

    x is read as x[0], x[1]. A single point works, and so does a 2 x n array
    of points, which is evaluated elementwise.

    -Laplace(u) = f holds for u(x, y) = x*(1-x)*y*(1-y) + (x*x - y*y), because
    x*x - y*y is harmonic. This file imposes homogeneous Dirichlet conditions:
    Dirichlet nodes are removed, so u_h = 0 there. With those conditions the
    relevant exact solution is u(x, y) = x*(1-x)*y*(1-y), which vanishes on
    the boundary of the unit square.
    """
    x1 = x[0]
    x2 = x[1]
    y = 2*(x1*(1-x1)+x2*(1-x2))
    return y

#---------------------------------------------------------
#
#----------------------------
def geometry_triangle():
    """Return the simplest geometry: a single triangle.

    All three edges carry flag 0, i.e. all edges are Neumann edges.

    Returns
    -------
    (nodes, elements, edges) as numpy arrays with rows
    [node, x, y], [element, v1, v2, v3] and [edge, endpoint 1, endpoint 2, flag].
    """
    nodes = np.asarray([[0, 0.0, 0.0],
                        [1, 1.0, 0.0],
                        [2, 1.0, 1.0]])
    elements = np.asarray([[0, 0, 1, 2]])
    edges = np.asarray([[0, 0, 1, 0],
                        [1, 1, 2, 0],
                        [2, 0, 2, 0]])
    return nodes, elements, edges

#----------------------------
def geometry_square():
    """Return a simple geometry: the unit square split into two triangles.

    All edges carry flag 0, i.e. all edges are Neumann edges.

    Returns
    -------
    (nodes, elements, edges) as numpy arrays with rows
    [node, x, y], [element, v1, v2, v3] and [edge, endpoint 1, endpoint 2, flag].
    """
    nodes = [
        [0, 0.0, 0.0],
        [1, 1.0, 0.0],
        [2, 1.0, 1.0],
        [3, 0.0, 1.0]]

    elements = [
        [0, 0, 1, 2],
        [1, 0, 2, 3]
    ]

    edges = [
        [0, 0, 1, 0],
        [1, 1, 2, 0],
        [2, 0, 2, 0],
        [3, 0, 3, 0],
        [4, 3, 2, 0]
    ]
    return np.asarray(nodes), np.asarray(elements), np.asarray(edges)

#%---------------------------------------------------------
def geometry_uniform_mesh(n):
    """Generate a uniform triangular mesh of [0,1]^2 with (n+1)^2 nodes.

    Nodes are numbered row by row (x varies fastest), with spacing h = 1/n.
    Each of the n*n grid cells is split along its lower-left -> upper-right
    diagonal into the triangles (ll, lr, ur) and (ll, ur, ul). Dirichlet
    conditions are prescribed on the full boundary: every boundary edge gets
    flag 1, every interior edge flag 0.

    Edge order: for each row of cells, three edges per cell (bottom, diagonal,
    left), then the right boundary edge of that row; finally the n top
    boundary edges.

    Returns
    -------
    nodes    : float array ((n+1)^2, 4-1), rows [node, x, y]
    elements : float array (2 n^2, 4),     rows [element, v1, v2, v3]
    edges    : float array (3 n^2 + 2 n, 4), rows [edge, endpoint 1, endpoint 2, flag]
    """
    h = 1.0 / n
    stride = n + 1                           # nodes per grid row
    nodes = np.zeros([stride * stride, 3])
    elements = np.zeros([2 * n * n, 4])
    edges = np.zeros([3 * n * n + 2 * n, 4])

    # nodes
    for row in range(stride):
        for col in range(stride):
            idx = row * stride + col
            nodes[idx] = [idx, col * h, row * h]

    # elements: two triangles per cell
    for row in range(n):
        for col in range(n):
            cell = row * n + col
            lower_left = row * stride + col
            lower_right = lower_left + 1
            upper_left = lower_left + stride
            upper_right = upper_left + 1
            elements[2 * cell] = [2 * cell, lower_left, lower_right, upper_right]
            elements[2 * cell + 1] = [2 * cell + 1, lower_left, upper_right, upper_left]

    # edges
    e = 0
    for row in range(n):
        on_bottom = 1 if row == 0 else 0
        for col in range(n):
            on_left = 1 if col == 0 else 0
            lower_left = row * stride + col
            lower_right = lower_left + 1
            upper_left = lower_left + stride
            upper_right = upper_left + 1
            edges[e] = [e, lower_left, lower_right, on_bottom]
            edges[e + 1] = [e + 1, lower_left, upper_right, 0]
            edges[e + 2] = [e + 2, lower_left, upper_left, on_left]
            e += 3
        # vertical edge closing this row of cells on the right boundary x = 1
        edges[e] = [e, row * stride + n, (row + 1) * stride + n, 1]
        e += 1

    # horizontal edges on the top boundary y = 1
    for col in range(n):
        edges[e] = [e, n * stride + col, n * stride + col + 1, 1]
        e += 1

    return nodes, elements, edges
#------------------------------------------
def gauleg(n):
    """Return (x, w): the n-point Gauss-Legendre nodes and weights on [-1, 1]."""
    return np.polynomial.legendre.leggauss(n)


###actually run the simulation
# Convergence study: solve on uniform meshes with m = 2, 4, ..., 2**mmax cells
# per direction and compare the discrete energies with the exact energy.
import shutil

exact_energy = 1/45
mmax = 5
energy = np.zeros(mmax)
h = np.zeros(mmax)

for mm in range(1, mmax+1):
    print("mm=",mm)
    m = 2**mm
    h[mm-1] = 1.0/m
    energy[mm-1] = fem(m)

# Extrapolate the energies with Aitken's delta^2 process. From three
# consecutive values E0, E1, E2 the extrapolated value is
#     E2 - (E2 - E1)**2 / (E2 - 2*E1 + E0).
# Entry k uses energy[k-2], energy[k-1], energy[k], so entries 2..mmax-1
# are filled, exactly the slice plotted below.
extrapolated_energy = np.zeros(mmax)
for k in range(2, mmax):
    e0, e1, e2 = energy[k-2], energy[k-1], energy[k]
    second_difference = e2 - 2.0*e1 + e0
    if second_difference == 0.0:
        # Aitken's formula is undefined here; fall back to the latest value.
        extrapolated_energy[k] = e2
    else:
        extrapolated_energy[k] = e2 - (e2 - e1)**2 / second_difference

import matplotlib.pyplot as plt
# usetex requires a LaTeX installation; without one, matplotlib's built-in
# mathtext still renders the "$O(h^2)$" legend entry.
plt.rcParams['text.usetex'] = shutil.which('latex') is not None
plt.figure(1)
plt.loglog(h, exact_energy - energy, '*-',
           h, 0.1*np.power(h, 2),
           h[2:mmax], abs(exact_energy-extrapolated_energy[2:mmax]))
plt.xlabel("Gitterweite h")
plt.ylabel("Fehler in der Energie")
plt.legend(["Fehler","$O(h^2)$","Fehler der Extrapolation"])
plt.title("FEM auf Einheitsquadrat; gleichmaessiges Dreiecksgitter, glatte Loesung")
plt.show()
