fin_diff_block_solve_vector Function

private function fin_diff_block_solve_vector(this, rhs, offset) result(solution)

Solves the linear(ised) system represented by this finite difference block, for a given right hand side state vector (represented by a vector field). Optionally, the differential operator can be augmented by adding an offset, i.e. a vector field which is added to the operator.

Arguments

Type | Intent | Optional Attributes | Name
class(fin_diff_block), intent(inout) :: this
class(cheb1d_vector_field), intent(in) :: rhs

The right hand side of the linear(ised) system.

class(cheb1d_vector_field), intent(in), optional :: offset

An offset to add to the differential operator

Return Value class(vector_field), pointer


Calls

fin_diff_block_solve_vector calls `str`.

Contents


Source Code

  function fin_diff_block_solve_vector(this, rhs, offset) result(solution)
    !* Author: Chris MacMackin
    !  Date: December 2016
    !
    ! Solves the linear(ised) system represented by this finite
    ! difference block, for a given right hand side state vector
    ! (represented by a vector field). Optionally, the differential
    ! operator can be augmented by adding an offset, i.e. a vector
    ! field which is added to the operator.
    !
    ! Each vector component of `rhs` is solved separately against the
    ! tridiagonal operator (with the matching component of `offset`
    ! added, if present). The LU factorisation is cached in `this` and
    ! reused when neither this call nor the previous one used an
    ! offset. A nonzero status from the tridiagonal solver is logged
    ! as an error for every component on which it occurs.
    !
    ! @Warning Currently this is only implemented for a 1-D field.
    !
    ! @Bug For some reason, calls to the `vector_dimensions()` method
    ! produce a segfault when `rhs` is
    ! `class(vector_field)`. Everything works fine if it is
    ! `class(cheb1d_vector_field)`, so this is used as a workaround.
    !
    class(fin_diff_block), intent(inout)      :: this
    class(cheb1d_vector_field), intent(in)    :: rhs
      !! The right hand side of the linear(ised) system.
    class(cheb1d_vector_field), optional, intent(in) :: offset
      !! An offset to add to the differential operator
    class(vector_field), pointer              :: solution
      !! The solution, with metadata copied from `rhs` and marked as
      !! a temporary.

    real(r8), dimension(:), allocatable :: sol_vector, diag_vector
    integer                             :: flag, n, n_dims, i, j, k
    class(scalar_field), pointer        :: component, ocomponent
    real(r8)                            :: forward_err, &
                                           backward_err, &
                                           condition_num
    character(len=1)                    :: factor
    character(len=:), allocatable       :: msg
    logical                             :: all_solved, factor_failed

    call rhs%guard_temp()
    if (present(offset)) call offset%guard_temp()
    allocate(sol_vector(rhs%raw_size()))
    n = size(this%diagonal)
    ! Allocate the arrays used to hold the factorisation of the
    ! tridiagonal matrix
    if (.not. allocated(this%pivots)) then
      allocate(this%l_multipliers(n-1))
      allocate(this%u_diagonal(n))
      allocate(this%u_superdiagonal1(n-1))
      allocate(this%u_superdiagonal2(n-2))
      allocate(this%pivots(n))
      factor = 'N'
    else if (.not. present(offset) .and. .not. this%had_offset) then
      ! The cached factors describe the unmodified operator: reuse them.
      factor = 'F'
    else
      factor = 'N'
    end if

    call rhs%allocate_scalar_field(component)
    call component%guard_temp()
    if (present(offset)) then
#ifdef DEBUG
      if (offset%elements() /= n) then
        call logger%fatal('fin_diff_block%solve_for','Offset field has '// &
                          'different resolution than finite difference block.')
        error stop
      else if (offset%vector_dimensions() < rhs%vector_dimensions()) then
        call logger%fatal('fin_diff_block%solve_for','Offset field has '// &
                          'different number of vector components than '// &
                          'field being solved for.')
        error stop
      end if
#endif
      call offset%allocate_scalar_field(ocomponent)
      call ocomponent%guard_temp()
      allocate(diag_vector(n))
    end if

    all_solved = .true.
    factor_failed = .false.
    n_dims = rhs%vector_dimensions()
    components: do i = 1, n_dims
      ! Without an offset every component shares one matrix, so the
      ! factorisation from the first component is reused for the rest.
      if (i > 1 .and. .not. present(offset)) factor = 'F'
      component = rhs%component(i)
      if (present(offset)) then
        ocomponent = offset%component(i)
        ! NOTE(review): the offset is added to every row whose diagonal
        ! is exactly zero, and again explicitly to rows 1/n below. If
        ! row 1 or n has a zero diagonal it may receive the offset
        ! twice -- confirm that cannot happen for this operator.
        where (this%diagonal == 0._r8)
          diag_vector = this%diagonal + ocomponent%raw()
        elsewhere
          diag_vector = this%diagonal
        end where
        if (.not. any(1 == this%boundary_locs)) then
          diag_vector(1) = diag_vector(1) + ocomponent%get_element(1)
        end if
        if (.not. any(n == this%boundary_locs)) then
          diag_vector(n) = diag_vector(n) + ocomponent%get_element(n)
        end if
        ! End points that are free boundaries also take the offset.
        do j = 1, size(this%boundary_locs)
          k = this%boundary_locs(j)
          if ((k == 1 .or. k == n) .and. this%boundary_types(j) == free_boundary) then
            diag_vector(k) = diag_vector(k) + ocomponent%get_element(k)
          end if
        end do
        call la_gtsvx(this%sub_diagonal, diag_vector, this%super_diagonal, &
                      component%raw(), sol_vector((i-1)*n+1:i*n),          &
                      this%l_multipliers, this%u_diagonal,                 &
                      this%u_superdiagonal1, this%u_superdiagonal2,        &
                      this%pivots, factor, 'N', forward_err, backward_err, &
                      condition_num, flag)
      else
        call la_gtsvx(this%sub_diagonal, this%diagonal, this%super_diagonal, &
                      component%raw(), sol_vector((i-1)*n+1:i*n),            &
                      this%l_multipliers, this%u_diagonal,                   &
                      this%u_superdiagonal1, this%u_superdiagonal2,          &
                      this%pivots, factor, 'N', forward_err, backward_err,   &
                      condition_num, flag)
      end if
      ! Check every component's status, not just the last one, so a
      ! failure in an earlier component is not silently lost.
      if (flag /= 0) then
        all_solved = .false.
        msg = 'Tridiagonal matrix solver returned with flag '//str(flag)// &
              ' for vector component '//str(i)
        call logger%error('fin_diff_block%solve_for',msg)
        ! Per LAPACK ?gtsvx, 0 < flag <= n means U is exactly singular
        ! and the stored factors are incomplete; flag == n+1 only warns
        ! that the matrix is ill-conditioned.
        if (flag > 0 .and. flag <= n) factor_failed = .true.
      end if
    end do components
    call component%clean_temp()
    if (present(offset)) call ocomponent%clean_temp()
    ! The cached factors always describe the last matrix factorised,
    ! whatever status the solver returned. `had_offset` must therefore
    ! be updated unconditionally; otherwise a later offset-free call
    ! could reuse the factors of an offset matrix. If the factorisation
    ! broke down, also force a fresh one on the next call.
    this%had_offset = present(offset) .or. factor_failed
#ifdef DEBUG
    if (all_solved .and. n_dims > 0) then
      msg = 'Tridiagonal matrix solver returned with estimated condition '// &
            'number '//str(condition_num)
      call logger%debug('fin_diff_block%solve_for',msg)
    end if
#endif
    call rhs%allocate_vector_field(solution)
    call solution%unset_temp()
    call solution%assign_meta_data(rhs)
    call solution%set_from_raw(sol_vector)
    if (present(offset)) call offset%clean_temp()
    call rhs%clean_temp(); call solution%set_temp()
  end function fin_diff_block_solve_vector