#include <iostream>
#include <set>
#include <cmath>
#include <cassert>
#include <cstdlib> // atoi and atof
#include <iomanip> // setw
#include "pinned_space.h"
#include "heat_trad.h"
#include "yz_pair.h"
#include "params.h"
using namespace std;
using namespace twong;

// Declaration of Functions (they're defined after the main function).
double N(double x, double n);
double N_neg_infty(double n);
void dimension_info(set<yz_pair> &collection, pinned_space &space);
double utility(double x);
double utility_sums(int verbosity, int losses, set<yz_pair> &collection, pinned_space &u, double dx, double dy, double dz, double final_tau, double A, double R, double gamma, double Co, double C_old, double n, double Bo);
double tax_invest(const double T, const double dx, const double init_frac, const int verbosity, const int losses);
void solve_onestep(pinned_space *&p_cur, pinned_space *&p_next, set<yz_pair> &active, heat_trad &heat, IC_params ic, fin_params fin, grid_params grid, const double refln_cutoff, const double sig_sq_div2, const int i, const int &verbosity);

/******************************************************************************
 *                                                                            *
 * MAIN FUNCTION                                                              *
 *                                                                            *
 * This processes the command-line arguments and determines if the user wants *
 * to run:                                                                    *
 *   1) a single iteration of the tax investment program                      *
 *   or                                                                       *
 *   2) multiple iterations of the tax investment program in order to         *
 *      optimize (that is, maximize) the expected return.                     *
 *                                                                            *
 ******************************************************************************/
int main(int argc, char **argv)
// Program entry point.  Dispatches on the number of command-line arguments:
//   4 arguments  <T> <dx> <init_frac> <verbosity>
//                runs tax_invest() once (with losses = 1).
//   6 arguments  <T> <dx> <losses> <r> <s> <t>
//                maximizes the value returned by tax_invest() over the initial
//                stock fraction with Successive Parabolic Interpolation (SPI),
//                starting from the three fractions r, s and t.
//   otherwise    prints a usage note.
// Returns 0 in every case, including when an error message has been printed.
{
	// argv[0] is the program name, so the user supplied (argc - 1) arguments.
	const int num_args = argc - 1;

	if (num_args == 4)
	{
		/**************************************************************
		 *                                                            *
		 * We're running one iteration of the tax investment program. *
		 *                                                            *
		 **************************************************************/

		// Get the command-line arguments.
		const double T = atof(argv[1]); // T is in real time; it is not the final_tau!
		const double dx = atof(argv[2]);
		const double init_frac = atof(argv[3]);
		const int verbosity = atoi(argv[4]);

		// Stop if the command-line arguments are unacceptable values.
		if (verbosity < 0 || verbosity > 1 || T < 0 || init_frac < 0 || dx <= 0)
		{
			cout << "Input parameter(s) out of range, stopped." << endl
			     << endl;
			return 0;
		}

		// Run the tax investment program.  The last argument (losses = 1)
		// selects the sum that considers capital gains losses.
		tax_invest(T, dx, init_frac, verbosity, 1);

		// We're done!
		return 0;
	}
	else if (num_args == 6)
	{
		/**************************************************************
		 *                                                            *
		 * We're optimizing the expected return.  Do this using a     *
		 * method called "Successive Parabolic Interpolation (SPI)."  *
		 *                                                            *
		 **************************************************************/

		// Get the command-line arguments.
		const double T = atof(argv[1]); // T is in real time; it is not the final_tau!
		const double dx = atof(argv[2]);
		const int losses = atoi(argv[3]);
		double r = atof(argv[4]); // r = initial stock fraction 1
		double s = atof(argv[5]); // s = initial stock fraction 2
		double t = atof(argv[6]); // t = initial stock fraction 3

		// Stop if the command-line arguments are unacceptable values.
		// NOTE(review): the bracket update at the bottom of the loop appears to
		// assume r < s < t, but that ordering is not validated here.
		if (T < 0 || r < 0 || s < 0 || t < 0 || dx <= 0 || losses < 0 || losses > 1)
		{
			cout << "Input parameter(s) out of range, stopped." << endl
			     << endl;
			return 0;
		}

		// Stopping tolerance in x: if two successive locations of the
		// parabolic max are closer than this, report the last one as the
		// argmax and stop.
		const double x_tolerance = 1e-3;
		// Smallest difference in V between neighbouring points that is still
		// treated as meaningful.
		const double y_tolerance = 1e-9;

		// Maximum number of SPI iterations before giving up.
		const int maxiter = 25;

		// Location of the previous interpolated parabola's maximum.
		// There is no previous maximum until one iteration has completed.
		double lastmax = 0;
		bool lastmax_flag = false;

		// V at the three stock fractions r, s and t.  Each flag is
		//  - false when V must be (re)calculated
		//  - true  when the stored V is still valid
		double Vr = 0;
		double Vs = 0;
		double Vt = 0;
		bool Vr_flag = false;
		bool Vs_flag = false;
		bool Vt_flag = false;

		// The verbosity starts at -2 and drops to -3 after the first
		// evaluation at r.  Both values are below every verbosity threshold
		// tested in this file, so the evaluations print nothing.
		bool first_run = true;
		int verbosity = -2;
		for (int iter = 1; iter <= maxiter; iter++)
		{
			// Evaluate V only at the points whose stored value is stale.
			if (!Vr_flag)
			{
				Vr = tax_invest(T, dx, r, verbosity, losses);
				Vr_flag = true;
				verbosity = -3;
			}
			if (!Vs_flag)
			{
				Vs = tax_invest(T, dx, s, verbosity, losses);
				Vs_flag = true;
			}
			if (!Vt_flag)
			{
				Vt = tax_invest(T, dx, t, verbosity, losses);
				Vt_flag = true;
			}

			// Print the table header once, before the first row.
			if (first_run)
			{
				cout << "Iteration  Initial fraction 1  \t Initial fraction 2  \t Initial fraction 3  \t V(frac 2)-V(frac 1) \t      V(frac 2)      \t V(frac 3)-V(frac 2) " << endl
				     << "--------- ---------------------\t---------------------\t---------------------\t---------------------\t---------------------\t---------------------" << endl;
				first_run = false;
			}

			// Print this iteration's r, s, and t and their corresponding function values.
			cout.setf(ios::scientific);
			cout.precision(10);
			cout << "Iter " << setw(3) << iter << ": "
			     << setw(20) << r << "\t"
			     << setw(20) << s << "\t"
			     << setw(20) << t << "\t"
			     << setw(20) << Vs - Vr << "\t"
			     << setw(20) << Vs << "\t"
			     << setw(20) << Vt - Vs
			     << endl;

			// Location of the vertex of the parabola through (r,Vr), (s,Vs), (t,Vt).
			// A zero denominator yields inf/NaN here rather than a trap; the
			// checks below are what report degenerate brackets.
			double x = (r+s) - ((Vs - Vr)*(t-r)*(t-s))/((s-r)*(Vt-Vs) - (Vs-Vr)*(t-s));
			x /= 2;

			// Stop if neighbouring V values are too close to tell apart.
			// fabs (not abs) keeps this unambiguously a floating-point comparison.
			if (fabs(Vs - Vr) < y_tolerance || fabs(Vt - Vs) < y_tolerance)
			{
				cout << "\t  Error: try to increase x-tolerance or reduce y_tolerance." << endl
				     << "\t  current value of x (i.e., the parabola's max): " << x << endl;
				return 0;
			}

			// Make sure the interpolated parabola has the correct concavity:
			// V(s) must lie strictly above the chord through (r,Vr) and (t,Vt).
			double m = (Vt - Vr) / (t - r);
			double line_s = m * (s - r) + Vr;
			if (Vs <= line_s)
			{
				// The value printed is Vs, i.e. V at fraction 2.
				cout << "\t  Error: interpolated parabola has wrong concavity. You may need to decrease refln_cutoff or increase y_tolerance" << endl
				     << "\t  V(frac 2)    : " << Vs << endl
				     << "\t  line_s: " << line_s << endl;
				return 0;
			}

			// Converged when the vertex moved less than x_tolerance since the last iteration.
			if (lastmax_flag && fabs(x - lastmax) < x_tolerance)
			{
				cout << endl
				     << "Success after " << iter << " iterations." << endl
				     << "The optimal initial fraction in stock is: " << x << endl;
				return 0;
			}

			// Build the next bracket from x and two of the old points.
			// Stored V values move with their points; only the slot that
			// receives x is marked stale.
			if (x < r)
			{
				// x is left of the bracket: (x, r, s).
				t = s;
				s = r;
				r = x;

				Vt = Vs;
				Vs = Vr;
				Vr_flag = false;
			}
			else if (x < s)
			{
				// x is between r and s: (r, x, s).
				t = s;
				s = x;

				Vt = Vs;
				Vs_flag = false;
			}
			else if (x < t)
			{
				// x is between s and t: (s, x, t).
				r = s;
				s = x;

				Vr = Vs;
				Vs_flag = false;
			}
			else
			{
				// x is right of the bracket: (s, t, x).
				r = s;
				s = t;
				t = x;

				Vr = Vs;
				Vs = Vt;
				Vt_flag = false;
			}

			// Remember this vertex for the next iteration's convergence test.
			lastmax = x;
			lastmax_flag = true;
		}

		// Error if we reach here: the iteration budget ran out.
		cout << "Error: did not find max by " << maxiter << " iterations." << endl;

		// We're done!
		return 0;
	}

	// If we reach here, we have incorrect command-line arguments.
	// Print a usage note.
	cout << "Usage: " << argv[0] << " <T> <dx> <init_frac> <verbosity>" << endl
	     << "               or" << endl
	     << "       " << argv[0] << " <T> <dx> <losses> <r> <s> <t>" << endl
	     << endl
	     << "       <T>           total time to solve the heat equation" << endl
	     << "       <dx>          spacing between points in x" << endl
	     << endl
	     << "       <init_frac>   initial fraction of wealth in stocks" << endl 
	     << "       <verbosity>   0 only print financial parameters and utility sum" << endl
	     << "                     1 also print computational parameters and dimension info" << endl
	     << endl
	     << "       <losses>      1 to consider capital gains losses" << endl
	     << "                     0 to ignore capital gains losses" << endl
	     << "       <r> <s> <t>   initial fraction of wealth in stocks for optimization" << endl
	     << endl;

	return 0;
}
 
/******************************************************************************
 *                                                                            *
 * TAX INVEST(MENT)                                                           *
 *                                                                            *
 * Given the total real time, dx, the initial fraction of wealth in stocks,   *
 * the verbosity, and the losses flag, this function runs the solver and      *
 * returns the expected utility computed by utility_sums().                   *
 *                                                                            *
 ******************************************************************************/
double tax_invest(const double T, const double dx, const double init_frac, const int verbosity, const int losses)
{
	/**********************************************************************
	 *                                                                    *
	 * FINANCIAL PARAMETERS & CONSTANTS                                   *
	 *                                                                    *
	 **********************************************************************/

	// Financial Parameters.
	const double mu = .07;
	const double sigma = .2;
	const double r = .04;
	const double So = 100;
	const double Bo = 100;
	assert(So >= Bo);
	const double R = .30;
	const double C_old = 100000*exp(.04*T);
	
	//
	// Quantities derived from the financial parameters.
	//
	
	// Money spent on stocks initially in actual time = T dollars.
	const double Co = -1.0 * init_frac * C_old; 

	// Number of stocks bought initially. In the paper, this is the control,
	// but here the initial fraction in stocks is the control which
	// specifies n.
	
	const double n = init_frac * C_old * exp(-1.0 * r * T) / So;
	const double sig_sq_div2 = sigma * sigma / 2.0; 
	const double final_tau = T * sig_sq_div2;
	const double A = (mu - sig_sq_div2) / sig_sq_div2;
	
	// Print the financial parameters:
	if (verbosity >= 0)
	{
		cout << "*** Financial Parameters ***" << endl
		     << endl
		     << "mu = " << mu << "; \t sigma = " << sigma << "; \t r = " << r << endl
		     << "S(0) = " << So << "; \t B(0) = " << Bo << "; \t R = " << R << endl
		     << "C_old = " << C_old << "; \t Co = " << Co << "; \t n = " << n << endl << endl;
	}

	/**********************************************************************
	 *                                                                    *
	 * COMPUTATIONAL PARAMETERS                                           *
	 *                                                                    *
	 **********************************************************************/

	// Initial Condition (IC) Parameters.
	const double x0 = log(So);
	const double y0 = log(Bo);
	const double z0 = 0;
	
	// Computational Grid Parameters.
	const double dtau = .48 * dx * dx; // Picked.48 because needs to be just less than .5
	const double dy = dx;
	const double gamma = N_neg_infty(n) * R * exp(ceil(y0/dy)*dy + r*final_tau/sig_sq_div2);
	const double dz = gamma*dx;
	const double heat_cutoff = 1e-25;
	const double refln_cutoff = 1e-12;

	// Print the computational parameters if the verbosity is appropriate.
	if (verbosity >= 1)
	{
		cout << "*** Computational Parameters ***" << endl
		     << endl
		     << "Beginning cash in time 0 dollars = " << (C_old + Co) * exp (-1.0 * r * T) << endl
		     << "Begin with " << n << " shares which have a value in time 0 dollars of " << n * So << endl
		     << "Requested final real time  T = " << T << " years" << endl
		     << "Requested final tau (computer t) = " << final_tau << endl
		     << "dtau = " << dtau << endl
		     << "dx = " << dx << endl
		     << "dy = " << dy << endl
		     << "dz = " << dz << endl;
	}
	
	// Define the computational spaces.
	pinned_space *p_cur  = new pinned_space;
	pinned_space *p_next = new pinned_space;

	// Define a set to hold the active arrays.
	set<yz_pair> active;
	
	/**********************************************************************
	 *                                                                    *
	 * INITIAL CONDITIONS                                                 *
	 *                                                                    *
	 **********************************************************************/

	// Define the indicies where the IC is located in the pinned space.
	int x0_index = int(x0/dx); // this is the floor of (x0/dx).
	int y0_index = int(y0/dy);
	int z0_index = round(z0/dz);
	add_yz_pair(y0_index, z0_index, active);
	add_yz_pair(y0_index + 1, z0_index, active);
	
	//
	// Divide the initial condition among its closest points.
	//
	
	// This is the height of the delta function approximation in this
	// geometry.
	double delta_height = 1 / (pow(So, A/2.0)*Bo*(dx*dy*dz));

	// This is the amount the floor function cut off from x0.
	// It's between 0 and 1.
	double xrem = x0/dx - x0_index; 

	// This is the amount the floor function cut off from y0.
	// It's between 0 and 1.
	double yrem = y0/dx - y0_index; 

	// Check with stencil we're using.
	if (yrem <= xrem)
	{
		// We're using a lower triangular stencil.
		(*p_cur)[z0_index][y0_index][x0_index] = (1 - xrem)*delta_height;
		(*p_cur)[z0_index][y0_index][x0_index + 1] = (xrem - yrem)*delta_height;
		(*p_cur)[z0_index][y0_index + 1][x0_index + 1] = (yrem)*delta_height;
	}
	else if ((yrem > xrem) && (y0_index < x0_index))
	{
		// We're using an upper triangular stencil.
		(*p_cur)[z0_index][y0_index][x0_index] = (1 - yrem)*delta_height;
		(*p_cur)[z0_index][y0_index + 1][x0_index] = (yrem - xrem)*delta_height;
		(*p_cur)[z0_index][y0_index + 1][x0_index + 1] = (xrem)*delta_height;
	}
	else
	{
		// We're using a diagonal stencil.
		
		// Make sure xrem and yrem are within machine precision of each
		// other.
		assert(fabs(xrem - yrem) < 1e-14);

		// Assign the points according to the stencil.
		(*p_cur)[z0_index][y0_index][x0_index] = (1 - xrem)*delta_height;
		(*p_cur)[z0_index][y0_index + 1][x0_index + 1] = (xrem)*delta_height;
	}
	
	if (verbosity >= 0)
	{
		cout << "The IC starts based from (x,y,z) = (" << x0_index * dx << "," << y0_index * dy << "," << z0_index * dz << ") with a value of " << delta_height << endl;
	}
	
	/**************************************************************
	 *                                                            *
	 * INITIALIZATIONS                                            *
	 *                                                            *
	 **************************************************************/

	// Calculate the number of iterations that gets us closest to, but below, the requested time.
	int total_iter = int(final_tau / dtau);

	// Print the number of iterations, which depends on whether or not dtau
	// divides evenly into final_tau.
	if (verbosity >= 0)
	{
		if (total_iter*dtau != final_tau)
		{
			cout << "Number of iterations: " << total_iter + 1 << endl;
		}
		else
		{
			cout << "Number of iterations: " << total_iter << endl;
		}
	}

	// Create the heat equation solver.
	heat_trad heat(heat_cutoff, dx, dtau);

	/**************************************************************
	 *                                                            *
	 * MAIN LOOP                                                  *
	 *                                                            *
	 **************************************************************/

	// Run the heat equation for the necessary number of iterations.
	for (int i = 0; i < total_iter; ++i)
	{
		// Gather the parameters and solve for one time iteration.
		IC_params ic = {x0_index, y0_index, z0_index, xrem, yrem};
		fin_params fin = {n, R, r};
		grid_params grid = {dx, dy, dz, final_tau, gamma};
	        solve_onestep(p_cur, p_next, active, heat, ic, fin, grid, refln_cutoff, sig_sq_div2, i, verbosity);
	}

	// If our time steps don't divide evenly into the total time we want,
	// run for one more iteration with an adjusted time step.
	if (total_iter*dtau < final_tau)
	{
		// Set the adjusted time step.
		heat.set_dt(final_tau - total_iter*dtau);

		// Gather the parameters and solve for one time iteration.
		IC_params ic = {x0_index, y0_index, z0_index, xrem, yrem};
		fin_params fin = {n, R, r};
		grid_params grid = {dx, dy, dz, final_tau, gamma};
	        solve_onestep(p_cur, p_next, active, heat, ic, fin, grid, refln_cutoff, sig_sq_div2, total_iter + 1, verbosity);
	}

	/**************************************************************
	 *                                                            *
	 * FINAL RESULTS                                              *
	 *                                                            *
	 **************************************************************/

	// Print some information on the dimensions of the space.
	if (verbosity >= 1)
	{
		dimension_info(active, *p_cur);
	}
	
	// Compute and print the sum of the utility function and other useful sums
	double utility_sum = utility_sums(verbosity, losses, active, *p_cur, dx, dy, dz, final_tau, A, R, gamma, Co, C_old, n, Bo);

	if (verbosity >= 0)
	{
		cout << "E N D  O F  R U N -------- E N D  O F  R U N ------------E N D  O F  R U N ---------------" << endl 
		     << endl << endl << endl << endl;
	}

	// End the program.
	return utility_sum;
}


 
 
 
 

double prob_sum(set<yz_pair> &collection, pinned_space &space)
// Given the probability space, returns the sum of the probability.  Note it does not
// include the points behind the boundary.  The active (y,z) pairs are in "collection".
{
	// Start with a sum of 0.
	double sum = 0.0;

	// Go through each (y,z) pair.
	for (set<yz_pair>::iterator iter = collection.begin(); iter != collection.end(); iter++)
	{
		// Get the value of the point behind the boundary.
		double behind_boundary = space[iter->z][iter->y][(space[iter->z][iter->y]).getStartingIndex()];

		// Sum up each x value, but don't include the point behind the boundary.
		sum += space[iter->z][iter->y].sum() - behind_boundary;
	}
	return sum;
}

double N(double x, double n)
// Returns N(x).  In this implementation N does not depend on x: the result is
// always the share count n that was passed in.
{
	(void)x; // intentionally unused
	const double shares = n;
	return shares;
}

double N_neg_infty(double n)
// Returns N(-infinity), which in this implementation is simply n.
{
	const double limit_value = n;
	return limit_value;
}

void dimension_info(set<yz_pair> &collection, pinned_space &space)
// Prints information about the dimensions of the (y,z) pairs of "collection"
// in the pinned space, "space".
{
	// Declare the variables that we want to keep track of.
	int maxz; // largest value of z
	int minz; // smallest value of z
	int maxy; // largest value of y
	int miny; // smallest value of y
	int maxxlen; // largest value of the length of x
	yz_pair maxxlen_yz; // the (y,z) pair at which the largest value of the length of x occurs.
	int sumxlen = 0; // sum of the lengths of x
	int numpts = 0; // number of (y,z) pairs
	bool startflag = true;

	// Go through each (y,z) pair.
	for (set<yz_pair>::iterator iter = collection.begin(); iter != collection.end(); iter++)
	{
		// If we just started, give each variable a starting value.
		if (startflag == true)
		{
			// Assign starting values.
			maxz = iter->z;
			minz = iter->z;
			maxy = iter->y;
			miny = iter->y;
			maxxlen = space[iter->z][iter->y].getLength();
			maxxlen_yz = *iter;
			sumxlen += space[iter->z][iter->y].getLength();
			numpts++;
			startflag = false;

			// Go to the next iterator.
			continue;
		}

		if (maxz < iter->z)
		{
			maxz = iter->z;
		}
		if (minz > iter->z)
		{
			minz = iter->z;
		}
		if (maxy < iter->y)
		{
			maxy = iter->y;
		}
		if (miny > iter->y)
		{
			miny = iter->y;
		}
		if (maxxlen < space[iter->z][iter->y].getLength())
		{
			maxxlen = space[iter->z][iter->y].getLength();
			maxxlen_yz = *iter;
		}
		sumxlen += space[iter->z][iter->y].getLength();
		numpts++;
	}

	cout << endl << "**************************** Dimensions of Space ****************************" << endl << endl
	     << "Largest value of z: " << maxz << endl
	     << "Smallest value of z: " << minz << endl
	     << "Largest value of y: " << maxy << endl
	     << "Smallest value of y: " << miny << endl
	     << "Largest value of the length of x: " << maxxlen << endl
	     << "  (y,z) pair at which largest value of length of x occurs: (" << maxxlen_yz.y << "," << maxxlen_yz.z << ")" << endl
	     << "Sum of lengths of x: " << sumxlen << endl
	     << "Number of (y,z) pairs: " << numpts << endl
	     << "Average length of x: " << sumxlen / numpts << endl;
}





/**************************************************************
 *                                                            *
 * UTILITY                                                    *
 *                                                            *
 **************************************************************/

double utility(double x)
// Returns the utility of x, which is x raised to the power 0.1.
{
	const double exponent = .1;
	const double value = pow(x, exponent);
	return value;
}

double utility_sums(int verbosity, int losses, set<yz_pair> &collection, pinned_space &u, double dx, double dy, double dz, double final_tau, double A, double R, double gamma, double Co, double C_old, double n, double Bo)
// Accumulates three weighted sums over every interior grid point (x,y,z) of
// the arrays of "u" addressed by the (y,z) pairs in "collection":
//   util_sum    utility(C_old + Co + z + n*(e^x*(1-R) + e^y*R)) * e^(A/2*x + y) * u
//   nozutilsum  utility(C_old + Co     + n*(e^x*(1-R) + Bo *R)) * e^(A/2*x + y) * u
//   psum                                                         e^(A/2*x + y) * u
// Each sum is scaled by exp(-A*A/4*final_tau) * dx*dy*dz when printed or returned.
//
// The first entry of each x array (its starting index) is skipped.
//
// Printing: verbosity >= -1 prints both utility sums; verbosity >= 0 also
// prints the scaled psum.
//
// Returns the scaled util_sum when losses == 1, otherwise the scaled
// nozutilsum.  The "gamma" parameter is not used.
{
	// Initialize the sums.
	double util_sum = 0;
	double nozutilsum = 0;
	double psum = 0;
	const double time_factor_times_dxdydz = exp(-1.0*A*A/4.0*final_tau) * dx * dy * dz;
	
	// Go through each (y,z) pair and calculate the utility and other sums
	for (set<yz_pair>::iterator iter = collection.begin(); iter != collection.end(); ++iter)
	{
		// Everything that depends only on (y,z) is computed once per array.
		const int y_index = iter->y;
		const int z_index = iter->z;
		const double y = y_index * dy;
		const double z = z_index * dz;
		const double exp_y = exp(y);

		// Half-open x range [first_x, end_x); the entry at the starting
		// index itself is skipped.
		const int first_x = u[z_index][y_index].getStartingIndex() + 1;
		const int end_x = u[z_index][y_index].getStartingIndex() + u[z_index][y_index].getLength();

		// Go through each x-coordinate in the pinned, growable array.
		for (int x_index = first_x; x_index < end_x; ++x_index)
		{
			const double x = x_index * dx;

			// Shared by all three sums: evaluate the exponential weight and
			// read the grid value once.  The products below keep the original
			// left-to-right order so the floating-point results are unchanged.
			const double weight = exp(A/2.0*x + y);
			const double density = u[z_index][y_index][x_index];

			util_sum += utility(C_old + Co + z + n *(exp(x)*(1-R) + exp_y*R)) * weight * density;
			nozutilsum += utility(C_old + Co  /* no + z !!! */ + n *(exp(x)*(1-R) + Bo*R)) * weight * density;
			psum += weight * density;
		}
	}
	
	// Print the Final Utility and other sums
	if (verbosity >= -1)
	{
		cout << endl << "**************************** Utility Sums ****************************" << endl
		     << "Expected Utility at time T with all captial losses realized: " <<  time_factor_times_dxdydz * util_sum << endl 
		     << "Expected Utility at time T with no captial losses realized: " << time_factor_times_dxdydz * nozutilsum << endl << endl;

	}
	
	if (verbosity >=0)
	{
		cout << "Total probability at time T (stays close to 1): " << time_factor_times_dxdydz * psum << endl << endl;
	}

	// Return the utility sum so the optimization program can use it.
	if (losses == 1)
	{
		return time_factor_times_dxdydz * util_sum;
	}
	return time_factor_times_dxdydz * nozutilsum;
}






/**************************************************************
 *                                                            *
 * ONE TIME ITERATION                                         *
 *                                                            *
 **************************************************************/

void solve_onestep(pinned_space *&p_cur, pinned_space *&p_next, set<yz_pair> &active, heat_trad &heat, IC_params ic, fin_params fin, grid_params grid, const double refln_cutoff, const double sig_sq_div2, const int i, const int &verbosity)
// Solves one time step of the tax_investment problem.  "heat" should already be
// defined with the appropriate time step and heat cutoff.
//
// The step has three phases:
//   1) Reflections: for every active (y,z) array, the value at its boundary
//      point is split between up to two target arrays at y-1 (one at z+1, one
//      at z).  A first pass zeroes the targets; a second pass accumulates into
//      them and registers the target arrays in "active".
//   2) heat.nextstep() is called for every active array, reading from *p_cur
//      and writing to *p_next.
//   3) p_cur and p_next are swapped, so on return p_cur refers to the newest data.
//
// Parameters:
//   p_cur, p_next  current / scratch spaces; the pointers are swapped on return
//   active         set of (y,z) pairs to process; may gain entries here
//   ic             initial-condition indices and remainders; only used to pick
//                  end_index for the arrays holding the initial condition
//   fin            financial parameters; n, R and r are read
//   grid           grid parameters; dy, gamma and final_tau are read
//   refln_cutoff   a reflection is applied only if its amount exceeds this
//   sig_sq_div2    sigma^2 / 2, used to convert (final_tau - t) in the exponent
//   i              step index; the time is taken as i * heat.get_dt()
//   verbosity      not read by this function
{
	/**************************************************************
	 *                                                            *
	 * REFLECTIONS                                                *
	 *                                                            *
	 **************************************************************/
	
	// exponential growth factor due to the u in the boundary equ: 2u_x + u_y - a*u_z = -u.
	double growth_factor = exp(grid.dy);  
	
	// Check to see if anything reflects at the boundaries.  If so, we need to zero the targets.
	// NOTE: refln_flag is set below but never read in this function.
	bool refln_flag = false; // Flag to print an extra endline if we're reflecting.
	for (set<yz_pair>::iterator iter = active.begin(); iter != active.end(); iter++)
	{
		// Get the x index of this array's boundary (one past the starting index).
		int boundary = (*p_cur)[iter->z][iter->y].getStartingIndex() + 1;

		// Create a pointer to the boundary so we don't have to keep retrieving it.
		// NOTE(review): this assumes that indexing other arrays of *p_cur below
		// does not relocate this array's storage -- confirm against pinned_space.
		double *refln_pt = &(*p_cur)[iter->z][iter->y][boundary];

		// Calculate "a", which is needed to calculate the proportion of
		// "p" that is reflected to which point.
		// NOTE(review): t is derived from the step index and the solver's
		// current dt.  The caller's final partial step passes total_iter + 1
		// after shortening dt; if get_dt() returns that shortened step, t is
		// not the elapsed time for that call -- confirm.
		double y = iter->y * grid.dy;
		double t = i * heat.get_dt();
		double a = N(exp(y), fin.n) * fin.R * exp(y + fin.r*(grid.final_tau - t)/sig_sq_div2);

		// Check the first reflection target: the fraction a/gamma goes to (y-1, z+1).
		if ((a/grid.gamma) * growth_factor * (*refln_pt) > refln_cutoff)
		{
			// Set the reflection flag.
			refln_flag = true;

			// Zero the first target so the second pass can accumulate into it.
			(*p_cur)[iter->z + 1][iter->y - 1][boundary - 2] = 0;
		}

		// Check the second reflection target: the fraction (1 - a/gamma) goes to (y-1, z).
		if ((1 - a/grid.gamma) * growth_factor * (*refln_pt) > refln_cutoff)
		{
			// Zero the second target so the second pass can accumulate into it.
			(*p_cur)[iter->z][iter->y - 1][boundary - 2] = 0;
		}
	}

	// Second pass: apply the reflections, using the same conditions as above.
	// NOTE(review): add_yz_pair() presumably inserts into "active" while it is
	// being iterated.  std::set insertion does not invalidate iterators, but
	// whether a newly added pair is visited later in this same pass depends on
	// yz_pair's ordering, which is not visible here.
	for (set<yz_pair>::iterator iter = active.begin(); iter != active.end(); iter++)
	{
		// Get the index of this array's boundary.
		int boundary = (*p_cur)[iter->z][iter->y].getStartingIndex() + 1;

		// Create a pointer to the boundary so we don't have to keep retrieving it.
		double *refln_pt = &(*p_cur)[iter->z][iter->y][boundary];

		// Calculate "a", which is needed to calculate the proportion of
		// "p" that is reflected to which point.
		double y = iter->y * grid.dy;
		double t = i * heat.get_dt();
		double a = N(exp(y), fin.n) * fin.R * exp(y + fin.r*(grid.final_tau - t)/sig_sq_div2);

		// Check the first reflection target.
		if ((a/grid.gamma) * growth_factor * (*refln_pt) > refln_cutoff)
		{
			// Set the reflection flag.
			refln_flag = true;

			// Reflect the information to the first array.
			(*p_cur)[iter->z + 1][iter->y - 1][boundary - 2] += (a/grid.gamma) * growth_factor * (*refln_pt);

			// Add the first array to the set of active arrays.
			add_yz_pair(iter->y - 1, iter->z + 1, active);
		}

		// Check the second reflection target.
		if ((1 - a/grid.gamma) * growth_factor * (*refln_pt) > refln_cutoff)
		{
			// Reflect the information to the second array.
			(*p_cur)[iter->z][iter->y - 1][boundary - 2] += (1 - a/grid.gamma) * growth_factor * (*refln_pt);

			// Add the second array to the set of active arrays.
			add_yz_pair(iter->y - 1, iter->z, active);
		}
	}
	
	/**************************************************************
	 *                                                            *
	 * SOLVE THE HEAT EQUATION IN THE X-DIRECTION                 *
	 *                                                            *
	 **************************************************************/

	// Solve the heat equation on each existing array.
	for (set<yz_pair>::iterator iter = active.begin(); iter != active.end(); iter++)
	{
		// Determine "end_index", which is what index we want to
		// solve the heat equation up to.
		
		// If "end_index" is equal to the boundary, then it solves
		// until the cutoff is reached.  Use this unless we're at
		// an initial condition array.
		int end_index = (*p_cur)[iter->z][iter->y].getStartingIndex() + 1;

		// See which distribution of the IC we used.
		// Then, check if we're at an IC array (z == z0_index and y is
		// y0_index or y0_index + 1).  The lower triangular and diagonal
		// branches pick the same end_index values; the upper triangular
		// branch uses x0_index for both arrays.
		if (ic.yrem <= ic.xrem) // We chose the lower triangular stencil
		{
			if (iter->z == ic.z0_index)
			{
				if (iter->y == ic.y0_index)
				{
					end_index = ic.x0_index;
				}
				if (iter->y == (ic.y0_index + 1))
				{
					end_index = ic.x0_index + 1;
				}
			}
		}
		else if ((ic.yrem > ic.xrem) && (ic.y0_index < ic.x0_index)) // We chose the upper triangular stencil
		{
			if (iter->z == ic.z0_index)
			{
				if (iter->y == ic.y0_index)
				{
					end_index = ic.x0_index;
				}
				if (iter->y == (ic.y0_index + 1))
				{
					end_index = ic.x0_index;
				}
			}
		}
		else // We chose the diagonal stencil
		{
			if (iter->z == ic.z0_index)
			{
				if (iter->y == ic.y0_index)
				{
					end_index = ic.x0_index;
				}
				if (iter->y == (ic.y0_index + 1))
				{
					end_index = ic.x0_index + 1;
				}
			}
		}

		// Solve the heat equation: read this array from *p_cur, write it to *p_next.
		heat.nextstep((*p_next)[iter->z][iter->y], (*p_cur)[iter->z][iter->y], end_index);
	}

	// Swap the arrays for the next iteration.  The pointers are passed by
	// reference, so the caller's p_cur now refers to the freshly written data.
	pinned_space *tmp_ptr = p_cur;
	p_cur = p_next;
	p_next = tmp_ptr;
}
