mercredi 1 juillet 2015

Multi threaded Black Scholes PDE Solver, waiting for worker threads

The purpose is to construct a multi-threaded PDE solver for the Black-Scholes model with constant parameters, using the threading facilities of C++11 and an explicit finite-difference scheme.

At each iteration the vector currentSolution_m is split into as many blocks as there are CPUs, and then:

std::vector<std::future<bool> >confirmation

is used to wait for each worker to finish the current iteration.

Since the bool returned by propagateOneTimeStep() is unused, what other, more elegant model could be used for this purpose?

    #ifndef MultiThreadPDE_Explicit_h
    #define MultiThreadPDE_Explicit_h

    #include <algorithm>
    #include <cmath>
    #include <future>
    #include <stdexcept>
    #include <thread>
    #include <vector>

    #include "ThreadPool.h"
    #include "spline.h"

    using namespace std;

    /****************************************************************************
    * Author: Horacio Aliaga    
    * Copyright (C) 2015 Horacio Aliaga (horacio.aliaga at gmail.com)
    * Description:  Library providing a multithreading Black Scholes PDE Solver
    *               using the explicit method of order one.  
    *               The algorithm is stable provided the infinity norm
    *               of the heat equation matrix is lower than 1
    * Dependencies: this library is using ThreadPool, that can be found on:
    *               http://ift.tt/Onhtkh
    *
    *               and is also using spline.h
    * Copyright (C) 2011, 2014 Tino Kluge (ttk448 at gmail.com)
    *               
    ****************************************************************************/

    /**
     * Explicit, order-one finite-difference solver for a one-dimensional
     * parabolic PDE. The spatial sweep of every time step is split into
     * contiguous blocks that are executed by the worker threads of a ThreadPool.
     *
     * T supplies the model. The solver calls, on the referenced process:
     *   - Payoff(x, t)                 : initial values and upper boundary value
     *   - PdeAdvectionConvection(x, t) : coefficient of the first spatial derivative
     *   - PdeDifusion(x, t)            : coefficient of the second spatial derivative
     *   - PdeSourceSink(x, t)          : zeroth-order coefficient
     *   - getSpot()                    : getPV() evaluates the solution at log(getSpot())
     *
     * Only a reference to the process is stored, so the process must outlive
     * the solver. Instances are neither copyable nor assignable.
     */
    template<class T>
    class MultiThreadPDE_Explicit
    {
    public:

        /// Binds the solver to a process without building a grid.
        /// All sizes and bounds are zero and the vectors are empty, so
        /// propagate() returns immediately on such an instance.
        MultiThreadPDE_Explicit(T &process) : process_m(process) {}

        /// Builds the grid and fills the solution with Payoff(x, initialTime).
        /// The spatial grid has numberOfSpatialNodes interior nodes plus the two
        /// boundary nodes at index 0 and numberOfSpatialNodes + 1.
        MultiThreadPDE_Explicit(
            T &process,                      //process returns the diffusion, advection and sink/source terms of the parabolic equation
            int numberOfTimeNodes,           //number of nodes of the time domain
            double initialTime,              //lower boundary for time domain
            double finalTime,                //upper boundary for time domain
            int numberOfSpatialNodes,        //number of interior nodes of the spatial domain
            double lowerBoundLogUnderlying,  //lower boundary for spatial domain
            double upperBoundLogUnderlying   //upper boundary for spatial domain
            ) :
            // initializers are listed in member declaration order
            initialTime_m(initialTime),
            finalTime_m(finalTime),
            totalTimePoints_m(numberOfTimeNodes),
            spatialLowerBound_m(lowerBoundLogUnderlying),
            spatialUpperBound_m(upperBoundLogUnderlying),
            totalSpatialPoints_m(numberOfSpatialNodes),
            deltaX_m((upperBoundLogUnderlying - lowerBoundLogUnderlying) / (numberOfSpatialNodes + 1.0)),
            deltaT_m((finalTime - initialTime) / (numberOfTimeNodes - 1.0)),
            process_m(process),
            spatialGrid_m(numberOfSpatialNodes + 2),
            currentSolution_m(numberOfSpatialNodes + 2),
            previousSolution_m(numberOfSpatialNodes + 2)
        {
            //initializes lower and upper boundaries of spatialGrid_m and currentSolution_m
            spatialGrid_m.front() = spatialLowerBound_m;
            currentSolution_m.front() = process_m.Payoff(spatialGrid_m.front(), initialTime_m);
            spatialGrid_m.back() = spatialUpperBound_m;
            currentSolution_m.back() = process_m.Payoff(spatialGrid_m.back(), initialTime_m);

            //initializes non-boundary values of spatialGrid_m and currentSolution_m
            for (int r = 1; r <= totalSpatialPoints_m; r++)
            {
                spatialGrid_m[r] = spatialLowerBound_m + r*deltaX_m;
                currentSolution_m[r] = process_m.Payoff(spatialGrid_m[r], initialTime_m);
            }
        }

        virtual ~MultiThreadPDE_Explicit() {}

        /// Advances the solution through totalTimePoints_m - 1 explicit steps.
        ///
        /// For every step the interior nodes [1, totalSpatialPoints_m] are split
        /// into contiguous, non-overlapping blocks (sizes differ by at most one)
        /// and each block is handed to the ThreadPool. The main thread waits on
        /// the returned futures before touching the boundaries; future::get()
        /// also rethrows any exception raised by a worker.
        /// The boundary nodes 0 and totalSpatialPoints_m + 1 are then updated by
        /// the main thread.
        void propagate()
        {
            if (totalTimePoints_m < 2)
                return; // no time step to perform

            // hardware_concurrency() may return 0 when the value is not computable;
            // never use more workers than there are interior nodes, and at least one.
            const unsigned hardwareThreads = std::thread::hardware_concurrency();
            const int workers = std::max(1, std::min(static_cast<int>(hardwareThreads), totalSpatialPoints_m));
            const int baseSize = totalSpatialPoints_m / workers;
            const int remainder = totalSpatialPoints_m % workers;

            ThreadPool pool(workers);
            std::vector<std::future<void> > pending;
            pending.reserve(workers);

            for (int n = 0; n < (totalTimePoints_m - 1); n++)
            {
                const double currentTime = initialTime_m + n*deltaT_m;
                previousSolution_m.swap(currentSolution_m);

                pending.clear();
                int first = 1;
                for (int i = 0; i < workers; ++i)
                {
                    // the first `remainder` blocks take one extra node
                    const int count = baseSize + (i < remainder ? 1 : 0);
                    if (count < 1)
                        continue;
                    const int last = first + count - 1; // inclusive
                    pending.push_back(pool.enqueue(
                        [this, first, last, currentTime] { this->propagateOneTimeStep(first, last, currentTime); }));
                    first = last + 1;
                }

                //wait for all blocks of this step before starting the next one
                for (std::future<void>& done : pending) done.get();

                //updates lower and upper boundary of currentSolution_m
                currentSolution_m.front() = 0.0;
                currentSolution_m.back() = process_m.Payoff(spatialGrid_m.back(), currentTime);
            }
        }

        /// Returns the current solution evaluated at log(process.getSpot()),
        /// using a cubic spline through (spatialGrid_m, currentSolution_m).
        const double getPV() const
        {
            tk::spline spl;
            spl.set_points(spatialGrid_m, currentSolution_m);
            return spl(std::log(process_m.getSpot()));
        }

        //Stability is based on calculation of infinite Norm of transition matrix
        //for the particular case of Black Scholes with constant parameters,
        //this matrix is independent of Underlying and time
        //the function below is true only for this special case
        // x = 0.0 and t = 1.0 are completely arbitrary arguments
        //that allow quick determination of the infinite norm
        bool isStable() const
        {
            const double x = 0;
            const double t = 1.0;
            const double advection = process_m.PdeAdvectionConvection(x, t) * 0.5 / deltaX_m;
            const double diffusion = process_m.PdeDifusion(x, t) / deltaX_m / deltaX_m;
            const double shortRate = process_m.PdeSourceSink(x, t);
            const double L = deltaT_m*(diffusion - advection);
            const double D = (1.0 - deltaT_m*(2.0*diffusion - shortRate));
            const double U = deltaT_m*(diffusion + advection);
            const double infiniteNorm = std::abs(L) + std::abs(D) + std::abs(U);
            return infiniteNorm < 1.0;
        }

        MultiThreadPDE_Explicit(const MultiThreadPDE_Explicit&) = delete;
        void operator=(const MultiThreadPDE_Explicit&) = delete;

    private:

        /// Writes currentSolution_m[begin..end] (inclusive) from previousSolution_m.
        /// Each node only reads previousSolution_m and writes its own entry of
        /// currentSolution_m, so disjoint ranges can run concurrently.
        /// @throws std::invalid_argument if begin < 1 or end > totalSpatialPoints_m
        void propagateOneTimeStep(int begin, int end, double t)
        {
            // minimum begin needs to be 1
            if (begin < 1)
                throw std::invalid_argument("propagateOneTimeStep:: received index array less than 1");
            // maximum end needs to be totalSpatialPoints_m
            if (end > totalSpatialPoints_m)
                throw std::invalid_argument("propagateOneTimeStep:: received index array greater than than totalSpatialPoints_m");

            for (int r = begin; r <= end; r++)
            {
                const double x = spatialGrid_m[r];
                //advection, diffusion and shortRate are constant
                //on Black Scholes model with constant parameters
                //the unoptimized generalized calculation below
                //is left for a future extension to non constant parameters
                const double advection = process_m.PdeAdvectionConvection(x, t) * 0.5 / deltaX_m;
                const double diffusion = process_m.PdeDifusion(x, t) / deltaX_m / deltaX_m;
                const double shortRate = process_m.PdeSourceSink(x, t);

                // coefficients applied to the lower neighbour, the node itself and the upper neighbour
                const double L = deltaT_m*(diffusion - advection);
                const double D = (1.0 - deltaT_m*(2.0*diffusion - shortRate));
                const double U = deltaT_m*(diffusion + advection);

                currentSolution_m[r] = previousSolution_m[r - 1] * L;
                currentSolution_m[r] += previousSolution_m[r] * D;
                currentSolution_m[r] += previousSolution_m[r + 1] * U;
            }
        }

        // scalars are zero-initialized so the one-argument constructor
        // does not leave them indeterminate
        double initialTime_m = 0.0;
        double finalTime_m = 0.0;
        int totalTimePoints_m = 0;
        double spatialLowerBound_m = 0.0;
        double spatialUpperBound_m = 0.0;
        int totalSpatialPoints_m = 0;
        double deltaX_m = 0.0, deltaT_m = 0.0;

        T& process_m;
        std::vector<double> spatialGrid_m;      // node 0 and node totalSpatialPoints_m + 1 are the boundaries
        std::vector<double> currentSolution_m;  // solution at the step being computed
        std::vector<double> previousSolution_m; // solution at the preceding step

    };


    #endif

#ifndef __SpotProcess_h__
#define __SpotProcess_h__

#include <vector>
#include <math.h>
using namespace std;


class SpotProcess{
public:
    //this class encapsulates the pay off of a european call option on spot
    //also allows retrieving the advection, difusion and source/sink terms
    //of the black scholes PDE of Spot process, with short term risk free rate 
    //and a cost of carry rate
    //the PDE is expressed in units of natural logarithm of Spot, since this way
    //the discretization error is lower than when using the PDE in Spot units

    // Default parameters: spot 100, strike 100, rate 1%, cost of carry 1%,
    // volatility 50%, expiry 1.0.
    SpotProcess()
        : SpotProcess(100, 100, 1E-2, 1E-2, 0.50, 1.0)
    {
    }

    SpotProcess(const double& Spot, const double& Strike, const double& ShortTermRiskFreeRate, const double& CostOfCarry, const double& Sigma, const double& Expiry)
        : Spot_m(Spot),
          Strike_m(Strike),
          shortTermRiskFreeRate_m(ShortTermRiskFreeRate),
          CostOfCarry_m(CostOfCarry),
          Sigma_m(Sigma),
          Expiry_m(Expiry)
    {
    }

    // Payoff is virtual, so the class can be used as a polymorphic base;
    // the destructor must then be virtual as well.
    virtual ~SpotProcess() = default;

    // Returns rate - cost of carry - sigma^2 / 2. Both arguments are ignored
    // because the parameters are constant.
    inline double PdeAdvectionConvection(double S, double t) const 
    {
        (void)S; (void)t;
        return (shortTermRiskFreeRate_m - CostOfCarry_m - 0.5*Sigma_m*Sigma_m);
    }

    // Returns sigma^2 / 2. Both arguments are ignored.
    inline double PdeDifusion(double S, double t) const 
    {
        (void)S; (void)t;
        return 0.5*Sigma_m*Sigma_m;
    }

    // Returns minus the short term risk free rate. Both arguments are ignored.
    inline double PdeSourceSink(const double& S, const double& t) const 
    {
        (void)S; (void)t;
        return -shortTermRiskFreeRate_m;
    }

    // Call payoff max(exp(x) - Strike, 0); x is the natural logarithm of the
    // underlying. t is ignored by this implementation.
    virtual double Payoff(const double& x, const double& t) const 
    {
        (void)t;
        const double intrinsic = std::exp(x) - Strike_m; // evaluate exp() only once
        return intrinsic > 0 ? intrinsic : 0;
    }

    double getSigma() const 
    {
        return Sigma_m;
    }

    double getShortTermRiskFreeRate() const 
    {
        return shortTermRiskFreeRate_m;
    }

    double getSpot() const 
    {
        return Spot_m;
    }

    // non-copyable
    SpotProcess& operator=(const SpotProcess&) = delete;

    SpotProcess(const SpotProcess&) = delete;

private:
    double Spot_m;
    double Strike_m;
    double shortTermRiskFreeRate_m;
    double CostOfCarry_m;
    double Sigma_m;
    double Expiry_m;   // stored but not read by any member of this class

};

#endif //__SpotProcess_h__


#include "stdafx.h"
#include "SpotProcess.h"
#include "MultiThreadPDE_Explicit.h"

using namespace std;
using namespace tk;

int main()
{
    SpotProcess sp;
    double maxS = std::log(400);
    double minS = std::log(0.1);
    MultiThreadPDE_Explicit<SpotProcess> mtpde(sp, 100, 0.0, 1.0, 50, minS, maxS);
    double pv = 0.0;
    if (mtpde.isStable())
    {
        mtpde.propagate();
        pv = mtpde.getPV();
    }
    return 0;
}

#ifndef THREAD_POOL_H
#define THREAD_POOL_H

#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>

/// <summary>
/// Fixed-size pool of worker threads fed from a single task queue.
/// Tasks are submitted with enqueue(), which returns a std::future for the
/// task's result. The destructor lets the workers drain the queue and then
/// joins them.
/// </summary>
class ThreadPool {
public:
    ThreadPool(size_t);
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        ->std::future<typename std::result_of<F(Args...)>::type>;
    ~ThreadPool();
    bool areTasksEmpty();
private:
    // need to keep track of threads so we can join them
    std::vector< std::thread > workers;

    // synchronization: queue_mutex guards both `tasks` and `stop`
    std::mutex queue_mutex;
    std::condition_variable condition;
    // the task queue
    std::queue< std::function<void()> > tasks;
    bool stop;
};


// The non-template members below are defined in a header, so they are marked
// inline to avoid multiple-definition errors when several translation units
// include this file.

/// <summary>
/// the constructor just launches some amount of workers
/// </summary>
inline ThreadPool::ThreadPool(size_t threads)
    : stop(false)
{
    workers.reserve(threads);
    for (size_t i = 0; i < threads; ++i)
        workers.emplace_back(
            [this]
            {
                for (;;)
                {
                    std::function<void()> task;

                    {
                        std::unique_lock<std::mutex> lock(this->queue_mutex);
                        // sleep until there is work or the pool is shutting down
                        this->condition.wait(lock,
                            [this]{ return this->stop || !this->tasks.empty(); });
                        // on shutdown, exit only once the queue has been drained
                        if (this->stop && this->tasks.empty())
                            return;
                        task = std::move(this->tasks.front());
                        this->tasks.pop();
                    }
                    // run the task outside the lock
                    task();
                }
            }
        );
}

/// <summary>
/// add new work item to the pool; the returned future yields the result of
/// f(args...) or rethrows the exception it raised.
/// Throws std::runtime_error if the pool has already been stopped.
/// </summary>
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;

    // packaged_task is move-only while std::function requires a copyable
    // callable, hence the shared_ptr wrapper
    auto task = std::make_shared< std::packaged_task<return_type()> >(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...)
        );

    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);

        // don't allow enqueueing after stopping the pool
        if (stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");

        tasks.emplace([task](){ (*task)(); });
    }
    condition.notify_one();
    return res;
}

/// <summary>
/// the destructor joins all threads
/// </summary>
inline ThreadPool::~ThreadPool()
{
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread &worker : workers)
        worker.join();
}

/// <summary>
/// true when no task is waiting in the queue (tasks currently being executed
/// by a worker have already been removed from the queue)
/// </summary>
inline bool ThreadPool::areTasksEmpty()
{
    std::unique_lock<std::mutex> lock(this->queue_mutex);
    return tasks.empty();
}

#endif

/*
* spline.h
*
* simple cubic spline interpolation library without external
* dependencies
*
* ---------------------------------------------------------------------
* Copyright (C) 2011, 2014 Tino Kluge (ttk448 at gmail.com)
*
*  This program is free software; you can redistribute it and/or
*  modify it under the terms of the GNU General Public License
*  as published by the Free Software Foundation; either version 2
*  of the License, or (at your option) any later version.
*
*  This program is distributed in the hope that it will be useful,
*  but WITHOUT ANY WARRANTY; without even the implied warranty of
*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
*  GNU General Public License for more details.
*
*  You should have received a copy of the GNU General Public License
*  along with this program.  If not, see <http://ift.tt/NfiWTa>.
* ---------------------------------------------------------------------
*
*/


#ifndef _tk_spline_h
#define _tk_spline_h

#include <cstdio>
#include <cassert>
#include <vector>
#include <algorithm>


// The implementation lives in this header. Everything is placed in the named
// namespace tk and every out-of-class member definition is marked inline, so
// the header can be included from several translation units without duplicate
// symbols while tk::spline and tk::band_matrix name the same types everywhere.
namespace tk {

    // band matrix solver
    //
    // Storage: m_upper[k][i] holds A(i, i+k) and m_lower[k][i] holds A(i, i-k).
    // The main diagonal is m_upper[0]; m_lower[0] is therefore free and is used
    // by lu_decompose() to keep the reciprocal of the original diagonal.
    class band_matrix {
    private:
        std::vector< std::vector<double> > m_upper;  // upper band
        std::vector< std::vector<double> > m_lower;  // lower band
    public:
        band_matrix() {}                             // empty matrix, dim() == 0
        band_matrix(int dim, int n_u, int n_l);      // see resize()
        void resize(int dim, int n_u, int n_l);      // init with dim,n_u,n_l
        int dim() const;                             // matrix dimension
        int num_upper() const {                      // number of upper off-diagonals
            return static_cast<int>(m_upper.size()) - 1;
        }
        int num_lower() const {                      // number of lower off-diagonals
            return static_cast<int>(m_lower.size()) - 1;
        }
        // access operator
        double & operator () (int i, int j);            // write
        double   operator () (int i, int j) const;      // read
        // we can store an additional diagonal (in m_lower)
        double& saved_diag(int i);
        double  saved_diag(int i) const;
        void lu_decompose();
        // NOTE: the "*_propagate" members solve linear systems
        std::vector<double> r_propagate(const std::vector<double>& b) const;
        std::vector<double> l_propagate(const std::vector<double>& b) const;
        std::vector<double> lu_propagate(const std::vector<double>& b,
            bool is_lu_decomposed = false);
    };


    // spline interpolation
    class spline {
    private:
        std::vector<double> m_x, m_y;           // x,y coordinates of points
        // interpolation parameters
        // f(x) = a*(x-x_i)^3 + b*(x-x_i)^2 + c*(x-x_i) + y_i
        // (m_d is declared but not used by any member)
        std::vector<double> m_a, m_b, m_c, m_d;
    public:
        // x must be strictly increasing and hold at least two points
        void set_points(const std::vector<double>& x,
            const std::vector<double>& y, bool cubic_spline = true);
        // evaluates the spline; set_points() must have been called before
        double operator() (double x) const;
    };


    // ---------------------------------------------------------------------
    // implementation part
    // ---------------------------------------------------------------------


    // band_matrix implementation
    // -------------------------

    inline band_matrix::band_matrix(int dim, int n_u, int n_l) {
        resize(dim, n_u, n_l);
    }

    // allocates a dim x dim matrix with n_u upper and n_l lower off-diagonals
    inline void band_matrix::resize(int dim, int n_u, int n_l) {
        assert(dim > 0);
        assert(n_u >= 0);
        assert(n_l >= 0);
        m_upper.resize(n_u + 1);
        m_lower.resize(n_l + 1);
        for (size_t i = 0; i < m_upper.size(); i++) {
            m_upper[i].resize(dim);
        }
        for (size_t i = 0; i < m_lower.size(); i++) {
            m_lower[i].resize(dim);
        }
    }

    inline int band_matrix::dim() const {
        if (m_upper.size() > 0) {
            return static_cast<int>(m_upper[0].size());
        }
        else {
            return 0;
        }
    }


    // defines the new operator (), so that we can access the elements
    // by A(i,j), index going from i=0,...,dim()-1
    inline double & band_matrix::operator () (int i, int j) {
        int k = j - i;       // what band is the entry
        assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
        assert((-num_lower() <= k) && (k <= num_upper()));
        // k=0 -> diagonal, k<0 lower left part, k>0 upper right part
        if (k >= 0)   return m_upper[k][i];
        else          return m_lower[-k][i];
    }
    inline double band_matrix::operator () (int i, int j) const {
        int k = j - i;       // what band is the entry
        assert((i >= 0) && (i < dim()) && (j >= 0) && (j < dim()));
        assert((-num_lower() <= k) && (k <= num_upper()));
        // k=0 -> diagonal, k<0 lower left part, k>0 upper right part
        if (k >= 0)   return m_upper[k][i];
        else          return m_lower[-k][i];
    }
    // second diag (used in LU decomposition), saved in m_lower
    inline double band_matrix::saved_diag(int i) const {
        assert((i >= 0) && (i < dim()));
        return m_lower[0][i];
    }
    inline double & band_matrix::saved_diag(int i) {
        assert((i >= 0) && (i < dim()));
        return m_lower[0][i];
    }

    // LR-Decomposition of a band matrix (in place)
    inline void band_matrix::lu_decompose() {
        int  i_max, j_max;
        int  j_min;
        double x;

        // preconditioning
        // normalize row i so that a_ii=1; the scale factor is kept in
        // saved_diag(i) and applied to the right hand side in l_propagate()
        for (int i = 0; i < this->dim(); i++) {
            assert(this->operator()(i, i) != 0.0);
            this->saved_diag(i) = 1.0 / this->operator()(i, i);
            j_min = std::max(0, i - this->num_lower());
            j_max = std::min(this->dim() - 1, i + this->num_upper());
            for (int j = j_min; j <= j_max; j++) {
                this->operator()(i, j) *= this->saved_diag(i);
            }
            this->operator()(i, i) = 1.0;          // prevents rounding errors
        }

        // Gauss LR-Decomposition
        for (int k = 0; k < this->dim(); k++) {
            i_max = std::min(this->dim() - 1, k + this->num_lower());  // num_lower not a mistake!
            for (int i = k + 1; i <= i_max; i++) {
                assert(this->operator()(k, k) != 0.0);
                x = -this->operator()(i, k) / this->operator()(k, k);
                this->operator()(i, k) = -x;                         // assembly part of L
                j_max = std::min(this->dim() - 1, k + this->num_upper());
                for (int j = k + 1; j <= j_max; j++) {
                    // assembly part of R
                    this->operator()(i, j) = this->operator()(i, j) + x*this->operator()(k, j);
                }
            }
        }
    }

    // solves Ly=b (forward substitution; b is scaled by the saved diagonal)
    inline std::vector<double> band_matrix::l_propagate(const std::vector<double>& b) const {
        assert(this->dim() == (int)b.size());
        std::vector<double> x(this->dim());
        int j_start;
        double sum;
        for (int i = 0; i < this->dim(); i++) {
            sum = 0;
            j_start = std::max(0, i - this->num_lower());
            for (int j = j_start; j < i; j++) sum += this->operator()(i, j)*x[j];
            x[i] = (b[i] * this->saved_diag(i)) - sum;
        }
        return x;
    }

    // solves Rx=y (backward substitution)
    inline std::vector<double> band_matrix::r_propagate(const std::vector<double>& b) const {
        assert(this->dim() == (int)b.size());
        std::vector<double> x(this->dim());
        int j_stop;
        double sum;
        for (int i = this->dim() - 1; i >= 0; i--) {
            sum = 0;
            j_stop = std::min(this->dim() - 1, i + this->num_upper());
            for (int j = i + 1; j <= j_stop; j++) sum += this->operator()(i, j)*x[j];
            x[i] = (b[i] - sum) / this->operator()(i, i);
        }
        return x;
    }

    // solves Ax=b; decomposes the matrix first unless is_lu_decomposed is true
    inline std::vector<double> band_matrix::lu_propagate(const std::vector<double>& b,
        bool is_lu_decomposed) {
        assert(this->dim() == (int)b.size());
        std::vector<double>  x, y;
        if (is_lu_decomposed == false) {
            this->lu_decompose();
        }
        y = this->l_propagate(b);
        x = this->r_propagate(y);
        return x;
    }


    // spline implementation
    // -----------------------

    // stores the points and computes the piecewise polynomial coefficients;
    // cubic_spline == false gives piecewise linear interpolation
    inline void spline::set_points(const std::vector<double>& x,
        const std::vector<double>& y, bool cubic_spline) {
        assert(x.size() == y.size());
        // the right-boundary code below reads x[n-2], so two points are required
        assert(x.size() >= 2);
        m_x = x;
        m_y = y;
        int   n = static_cast<int>(x.size());
        // TODO sort x and y, rather than returning an error
        for (int i = 0; i < n - 1; i++) {
            assert(m_x[i] < m_x[i + 1]);
        }

        if (cubic_spline == true) { // cubic spline interpolation
            // setting up the matrix and right hand side of the equation system
            // for the parameters b[]
            band_matrix A(n, 1, 1);
            std::vector<double>  rhs(n);
            for (int i = 1; i < n - 1; i++) {
                A(i, i - 1) = 1.0 / 3.0*(x[i] - x[i - 1]);
                A(i, i) = 2.0 / 3.0*(x[i + 1] - x[i - 1]);
                A(i, i + 1) = 1.0 / 3.0*(x[i + 1] - x[i]);
                rhs[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i]) - (y[i] - y[i - 1]) / (x[i] - x[i - 1]);
            }
            // boundary conditions, zero curvature b[0]=b[n-1]=0
            A(0, 0) = 2.0;
            A(0, 1) = 0.0;
            rhs[0] = 0.0;
            A(n - 1, n - 1) = 2.0;
            A(n - 1, n - 2) = 0.0;
            rhs[n - 1] = 0.0;

            // solve the equation system to obtain the parameters b[]
            m_b = A.lu_propagate(rhs);

            // calculate parameters a[] and c[] based on b[]
            m_a.resize(n);
            m_c.resize(n);
            for (int i = 0; i < n - 1; i++) {
                m_a[i] = 1.0 / 3.0*(m_b[i + 1] - m_b[i]) / (x[i + 1] - x[i]);
                m_c[i] = (y[i + 1] - y[i]) / (x[i + 1] - x[i])
                    - 1.0 / 3.0*(2.0*m_b[i] + m_b[i + 1])*(x[i + 1] - x[i]);
            }
        }
        else { // linear interpolation
            m_a.resize(n);
            m_b.resize(n);
            m_c.resize(n);
            for (int i = 0; i < n - 1; i++) {
                m_a[i] = 0.0;
                m_b[i] = 0.0;
                m_c[i] = (m_y[i + 1] - m_y[i]) / (m_x[i + 1] - m_x[i]);
            }
        }

        // for the right boundary we define
        // f_{n-1}(x) = b*(x-x_{n-1})^2 + c*(x-x_{n-1}) + y_{n-1}
        double h = x[n - 1] - x[n - 2];
        // m_b[n-1] is determined by the boundary condition
        m_a[n - 1] = 0.0;
        m_c[n - 1] = 3.0*m_a[n - 2] * h*h + 2.0*m_b[n - 2] * h + m_c[n - 2];   // = f'_{n-2}(x_{n-1})
    }

    inline double spline::operator() (double x) const {
        size_t n = m_x.size();
        // evaluating before set_points() would index empty vectors
        assert(n > 0);
        // find the closest point m_x[idx] < x, idx=0 even if x<m_x[0]
        std::vector<double>::const_iterator it;
        it = std::lower_bound(m_x.begin(), m_x.end(), x);
        int idx = std::max(int(it - m_x.begin()) - 1, 0);

        double h = x - m_x[idx];
        double interpol;
        if (x < m_x[0]) {
            // extrapolation to the left
            interpol = ((m_b[0])*h + m_c[0])*h + m_y[0];
        }
        else if (x > m_x[n - 1]) {
            // extrapolation to the right
            interpol = ((m_b[n - 1])*h + m_c[n - 1])*h + m_y[n - 1];
        }
        else {
            // interpolation
            interpol = ((m_a[idx] * h + m_b[idx])*h + m_c[idx])*h + m_y[idx];
        }
        return interpol;
    }


} // namespace tk

#endif /* _tk_spline_h */

Aucun commentaire:

Enregistrer un commentaire