#include <cctype>
#include <iostream>
#include <stdexcept>
#include <string>
#include <string.h>
#include <strlib.h>
#include <error.h>
#include <vector.h>
#include "math.h"
#include "tokenscanner.h"
#include "map.h"
#include "queue.h"
using namespace std;

// One node of the species tree. An interior node represents a species-merging event; a leaf
// (isLeaf == true) represents a sampled species. The tree is built by buildTree and then
// filled in by loadTimes, loadInitialSamples and markSubtree.
struct node{
    node* left = nullptr;   // Left daughter node
    node* right = nullptr;  // Right daughter node
    node* parent = nullptr; // Parent node
    // Branch length; set by loadTimes. Initialized so the debug printers are well-defined
    // even if they run before the times are loaded.
    double time = 0;
    bool isLeaf = false;       // Set by buildTree
    bool isInSubtree = false;  // Membership in the minimal S subtree; set by markSubtree
    // Leaves only (empty for interior nodes): initial lineage counts, index 0 = S, index 1 = C.
    Vector<int> initial;
    // Number of initial S (index 0) and C (index 1) lineages below this node. nodeCheck fills it
    // for every node; getBounds fills it for the node it is called on (with a third, M, slot).
    // Used to mark the minimal S subtree and to set summation bounds.
    Vector<int> bounds;
};

// Returns the next complete phrase of a tree string, starting at index `pos`, and advances `pos`
// just past it. A phrase is either a single leaf character or a balanced parenthesized subtree.
// Example: with phrase = "(((AB)(CD))E)" and pos = 1, the first call returns "((AB)(CD))" and the
// next call returns "E". Calling it on "((AB)(CD))" with pos = 1 returns "(AB)", then "(CD)".
// Used when parsing input text into the tree structure.
//
// Throws std::out_of_range if `pos` is outside the string, and std::invalid_argument if the
// phrase starting at `pos` is not balanced (a stray ')' or a '(' that is never closed).
std::string nextPhrase(const std::string& phrase, int& pos){
    const int length = static_cast<int>(phrase.size());
    if(pos < 0 || pos >= length){
        throw std::out_of_range("nextPhrase: position outside phrase \"" + phrase + "\"");
    }
    int depth = 0;      // Number of currently open '(' (the phrase ends when this returns to 0).
    int currPos = pos;
    while(currPos < length){
        const char ch = phrase[currPos];
        if(ch == '(') depth++;
        else if(ch == ')') depth--;
        if(depth == 0) break;  // A leaf character, or the ')' that closes the subtree.
        if(depth < 0){
            throw std::invalid_argument("nextPhrase: unexpected ')' in \"" + phrase + "\"");
        }
        currPos++;
    }
    if(currPos >= length){
        throw std::invalid_argument("nextPhrase: unbalanced parentheses in \"" + phrase + "\"");
    }
    // currPos now indexes the last character of the phrase.
    const std::string next = phrase.substr(pos, currPos + 1 - pos);
    pos = currPos + 1;
    return next;
}


// Builds the binary tree described by `phrase` and returns its root. Uses nextPhrase to split a
// parenthesized phrase "(XY)" into its two daughter phrases X and Y and recurses into each.
// A phrase that does not start with '(' (e.g. "E") is a leaf and becomes a node with isLeaf set.
// Every child's `parent` pointer is set to the node that owns it.
//
// Example: "((AB)(CD))" yields a root whose left child has leaves A, B and whose right child
// has leaves C, D.
node* buildTree(const string& phrase){
    // Leaf: a bare label (this also covers a whole-tree phrase that is a single leaf).
    if(phrase.empty() || phrase[0] != '('){
        node* leaf = new node;
        leaf->isLeaf = true;
        return leaf;
    }
    int pos = 1;  // Skip the opening '('; nextPhrase advances pos past each daughter phrase.
    const string leftPhrase = nextPhrase(phrase, pos);
    const string rightPhrase = nextPhrase(phrase, pos);

    // Children are allocated only by the recursive calls. Previously a node was pre-allocated
    // for each child and then leaked whenever the child turned out to be a subtree.
    node* currNode = new node;
    currNode->left = buildTree(leftPhrase);
    currNode->right = buildTree(rightPhrase);
    currNode->left->parent = currNode;
    currNode->right->parent = currNode;
    return currNode;
}

// Debugging aid: prints every node below `root` in in-order sequence (left subtree, node,
// right subtree), one line per node: address, branch length, isLeaf, isInSubtree, and the
// initial-sample vector (no separator is printed before the vector).
void inorder(node* root){
    if(root != NULL){
        inorder(root->left);
        cout << root << " " << root->time << " " << root->isLeaf << " "
             << root->isInSubtree << root->initial << endl;
        inorder(root->right);
    }
}

// Registers every node below currNode in the memoization table `calcs`, giving each one an
// empty configuration -> probability map (any existing entry for a node is replaced).
// Nodes are visited in in-order sequence.
void getNodes(Map<node*, Map<string, long double> >& calcs, node* currNode){
    if(currNode == NULL) return;
    getNodes(calcs, currNode->left);
    calcs.put(currNode, Map<string, long double>());
    getNodes(calcs, currNode->right);
}

// Computes upper bounds for a node: after getBounds(n, n), n->bounds[i] is the sum of
// initial[i] over every leaf at or below n (index 0 = S, index 1 = C, any further index the same
// way). n->bounds always ends up with at least three slots; the third (M) stays 0 unless a leaf
// supplies it.
//
// Call it with the node in question as BOTH arguments. That top-level call (currNode == root)
// clears root->bounds first, so the result does not depend on anything stored there before
// (for example by nodeCheck). Recursive calls pass the same `root` down and only accumulate into it.
void getBounds(node* currNode, node* root){
    if(currNode == NULL) return;
    if(currNode == root) root->bounds.clear();
    getBounds(currNode->left, root);
    getBounds(currNode->right, root);
    if(!currNode->isLeaf) return;

    // Size the totals to hold every entry of this leaf (and at least S, C, M). Previously the
    // first leaf's values were discarded while the slots were being created, and a hard-coded
    // third index was read from leaves that only store S and C counts.
    const int leafEntries = currNode->initial.size();
    const int slots = (leafEntries > 3) ? leafEntries : 3;
    while(root->bounds.size() < slots) root->bounds.add(0);
    for(int i = 0; i < leafEntries; i++){
        root->bounds.set(i, root->bounds.get(i) + currNode->initial.get(i));
    }
}

// Returns the sum of all entries of initialS, i.e. the total number of input S lineages.
// Needed for marking the S subtree, because the loading process consumes the initialS vector.
int getTotalS(Vector<int> initialS){
    int runningTotal = 0;
    int idx = 0;
    while(idx < initialS.size()){
        runningTotal += initialS.get(idx);
        ++idx;
    }
    return runningTotal;
}

// Returns how many initial lineages of the given type (0 = S, 1 = C, 2 = M) lie below currNode.
// A leaf reports its own initial count. An interior node sums its daughters' `bounds` entries,
// so both daughters must already have been processed (nodeCheck visits nodes in post-order).
// The "64" suffix is a joke about a baseball pitcher and carries no technical meaning.
int numBelow64(node* currNode, int type){
    if(!currNode->isLeaf){
        return currNode->left->bounds.get(type) + currNode->right->bounds.get(type);
    }
    return currNode->initial.get(type);
}

// Post-order pass that fills bounds[0] (S below) and bounds[1] (C below) for every node.
// It marks as in the S subtree every node with strictly between 0 and totalS S lineages below it.
// That gives at least a minimally-defining subset of the S subtree; markSubtree completes it.
void nodeCheck(node* currNode, int totalS){
    if(currNode == NULL) return;
    nodeCheck(currNode->left, totalS);
    nodeCheck(currNode->right, totalS);
    for(int type = 0; type < 2; type++){
        const int below = numBelow64(currNode, type);
        if(type < currNode->bounds.size()) currNode->bounds.set(type, below);
        else currNode->bounds.add(below);
    }
    const int sBelow = currNode->bounds.get(0);
    if(sBelow > 0 && sBelow < totalS) currNode->isInSubtree = true;
}

// Marks currNode and every node below it as part of the S subtree (pre-order traversal).
void finishMarking(node* currNode){
    if(currNode == NULL) return;
    currNode->isInSubtree = true;
    finishMarking(currNode->left);
    finishMarking(currNode->right);
}

// Breadth-first search for the root of the S subtree: the shallowest node whose left and right
// daughters are both marked isInSubtree. Returns NULL when no such node exists (e.g. when all S
// lineages are in a single leaf, in which case nothing should be marked).
node* findSubRoot(node* root){
    Queue<node*> frontier;
    frontier.enqueue(root);
    while(!frontier.isEmpty()){
        node* candidate = frontier.dequeue();
        node* leftChild = candidate->left;
        node* rightChild = candidate->right;
        const bool leftMarked = (leftChild != NULL) && leftChild->isInSubtree;
        const bool rightMarked = (rightChild != NULL) && rightChild->isInSubtree;
        if(leftMarked && rightMarked) return candidate;
        if(leftChild != NULL) frontier.enqueue(leftChild);
        if(rightChild != NULL) frontier.enqueue(rightChild);
    }
    return NULL;
}

// Marks the S subtree (the nodes below the most recent common ancestor of all S lineages).
//
// Algorithm:
// 1. nodeCheck counts the S lineages below every node. A node with more than 0 but fewer than
//    totalS of them is marked. A node holding all totalS is not marked: the first such node is the
//    subtree's root, which is deliberately left unmarked. A node holding 0 is left unmarked for now.
// 2. findSubRoot locates the subtree root as the first node whose two daughters are both marked.
// 3. Everything below that root is marked. This picks up the nodes skipped in step 1 because they
//    have 0 S lineages but still lie inside the subtree.
void markSubtree(node* root, int totalS){
    nodeCheck(root, totalS);
    node* subRoot = findSubRoot(root);
    if(subRoot == NULL) return;
    // Mark both daughters' subtrees but not the subtree root itself.
    finishMarking(subRoot->left);
    finishMarking(subRoot->right);
}

// Converts a string of decimal digits into a vector of initial sample counts, one single-digit
// count per character. Example: "302" -> {3, 0, 2}.
// Throws std::invalid_argument if any character is not a digit.
Vector<int> extractInitialSamples(string initial){
    Vector<int> initialVec;
    for(char ch : initial){
        if(ch < '0' || ch > '9'){
            throw std::invalid_argument(std::string("extractInitialSamples: expected a digit but found '")
                                        + ch + "' in \"" + initial + "\"");
        }
        // Store the digit's value. The previous char-to-int conversion stored the character code
        // (e.g. '3' -> 51) instead.
        initialVec.add(ch - '0');
    }
    return initialVec;
}

// Assigns branch lengths to the tree in in-order sequence, taking (and removing) values from the
// END of `times`. The first node visited therefore receives the last element of the vector.
void loadTimes(node* root, Vector<double>& times){
    if(root == NULL) return;
    loadTimes(root->left, times);
    const int lastIndex = times.size() - 1;
    root->time = times.get(lastIndex);
    times.remove(lastIndex);
    loadTimes(root->right, times);
}

// Appends one initial-sample value to each leaf, visiting leaves in in-order sequence and taking
// (and removing) values from the END of `initial`. Interior nodes are skipped.
void loadInitialSamples(node* root, Vector<int>& initial){
    if(root == NULL) return;
    loadInitialSamples(root->left, initial);
    if(root->isLeaf){
        const int lastIndex = initial.size() - 1;
        root->initial.add(initial.get(lastIndex));
        initial.remove(lastIndex);
    }
    loadInitialSamples(root->right, initial);
}

// Debugging aid: prints every node below `root` in pre-order sequence (node, left subtree,
// right subtree), one line per node: address, branch length, isLeaf, isInSubtree, and the
// initial-sample vector (no separator is printed before the vector).
void pre(node* root){
    if(root != NULL){
        cout << root << " " << root->time << " " << root->isLeaf << " "
             << root->isInSubtree << root->initial << endl;
        pre(root->left);
        pre(root->right);
    }
}

// Debugging aid: prints every node below `root` in post-order sequence (left subtree, right
// subtree, node), one line per node: address, branch length, isLeaf, isInSubtree, and the
// initial-sample vector (no separator is printed before the vector).
void post(node* root){
    if(root != NULL){
        post(root->left);
        post(root->right);
        cout << root << " " << root->time << " " << root->isLeaf << " "
             << root->isInSubtree << root->initial << endl;
    }
}

// Reads integer tokens from `scanner` into `initial` until a ";" token. "," tokens are skipped,
// and tokens starting with an alphanumeric character are converted with stringToInteger.
// Each token read, and the vector after each non-"," token, is echoed to cout (debug output).
// Also stops when the scanner runs out of tokens; previously a missing ";" caused an infinite loop.
// NOTE(review): the original comment said it was unclear whether this function is used at all.
void readinitials(TokenScanner scanner, Vector<int>& initial){
    while(scanner.hasMoreTokens()){
        const string next = scanner.nextToken();
        cout << next << endl;
        if(next == ",") continue;
        if(next == ";") return;
        // Cast to unsigned char: passing a negative char to isalnum is undefined behavior.
        if(isalnum(static_cast<unsigned char>(next[0]))) initial.add(stringToInteger(next));
        cout << initial << endl;
    }
}

// Parses `timesstr` one character at a time: each character is read as a single-digit integer
// (via stringToInteger) and appended to `times`. Multi-digit or fractional times are not supported.
void getTimes(string timesstr, Vector<double>& times){
    const int count = timesstr.size();
    for(int idx = 0; idx < count; ++idx){
        const int digit = stringToInteger(timesstr.substr(idx, 1));
        times.add(digit);
    }
}

// Binomial coefficient C(a, b), computed as the running product of (a - i) / (b - i) for
// i = 0 .. b-1. Intermediate values stay near the result instead of overflowing the way
// a! / (b! (a-b)!) would.
// Returns 0 when b < 0, and 0 when 0 <= a < b (one factor of the product is then zero).
double biCo(int a, int b){
    if(b < 0) return 0;
    double result = 1;
    for(int i = 0; i < b; i++){
        // 1 + (a - b) / (b - i) == (a - i) / (b - i); evaluated in floating point.
        result *= 1.0 + static_cast<double>(a - b) / (b - i);
    }
    return result;
}

// Tavaré's g function g_{n,j}(T): the probability that n lineages have coalesced to exactly j
// lineages after time T (presumably in coalescent units; confirm against the input format).
// It is evaluated as the alternating sum over k = j..n of
//   e^{-k(k-1)T/2} (-1)^{k-j} (2k-1) C(j+k-2, j) C(n-1, k-1) C(k-1, j-1) / ((n+k-1) C(n+k-2, n)),
// with the k == 1 term taken as e^0 = 1.
long double gFunction(int n, int j, double T){
    // Impossible transitions (lineages cannot be gained) and degenerate counts.
    if(j > n || n == 0 || j == 0) return 0;
    // Effectively infinite branch: everything has coalesced into a single lineage.
    if(T > 5000) return (j == 1) ? 1 : 0;
    // Heuristic cut-offs where the result is forced to 0 (presumably because the alternating sum
    // is numerically unstable there and the true value is negligible; confirm).
    if(T >= 0.1 && n >= 90 && j >= 50) return 0;
    if(T >= 1 && n >= 20 && j >= 10) return 0;

    long double total = 0;
    for(int k = j; k <= n; k++){
        const long double decay = exp((-k) * (k - 1) * T / 2);
        const int sign = ((k - j) % 2 == 0) ? 1 : -1;
        long double term;
        if(k == 1){
            term = decay * sign;
        }
        else{
            const long double numer = (2 * k - 1) * biCo(j + k - 2, j) * biCo(n - 1, k - 1) * biCo(k - 1, j - 1);
            const long double denom = (n + k - 1) * biCo(n + k - 2, n);
            term = decay * sign * numer / denom;
        }
        total += term;
    }
    // Values below 1e-6, including negative round-off from the alternating sum, are treated as 0.
    if(total < 0.000001) total = 0;
    return total;
}

// Combinatorial term for the M_S model. s_L + s_R, c_L + c_R and m_L + m_R are the S, C and M
// lineages entering the bottom of a branch of length T. s, c and m are the counts leaving its top.
// Returns K * gFunction(total in, total out, T), where K is the combinatorial weight of the
// transition. It returns 0 when fewer lineages enter than leave, or when the transition is not
// one of the cases below. isInSubtree is the branch's S-subtree membership; an M lineage can only
// be formed on a branch outside the S subtree.
long double bigF(int s_L, int s_R, int c_L, int c_R, int m_L, int m_R, int s, int c, int m, double T, bool isInSubtree){
    const int s_I = s_L + s_R;
    const int c_I = c_L + c_R;
    const int m_I = m_L + m_R;
    if(s_I + c_I + m_I < s + c + m) return 0;

    long double K;
    if(s_I == 0 && c_I > 0 && m_I == 0 && c_I >= c && c > 0 && s == 0 && m == 0) {
        // Only C lineages in and out.
        K = 1;
    }
    else if(s_I == 0 && c_I > 0 && m_I == 1 && c_I >= c && s == 0 && m == 1) {
        // C lineages plus the single M lineage; M persists.
        K = 1;
    }
    else if(s_I == 0 && c_I == 0 && m_I == 1 && c == 0 && s == 0 && m == 1){
        // Only the M lineage.
        K = 1;
    }
    else if(s_I > 0 && isInSubtree == false && m_I == 0 && s == 0 && m == 1 && c_I >= c && c_I > 0){
        // Outside the S subtree, all S lineages and some C lineages merge into one M lineage.
        K = 0;
        for(int k = c + 1; k <= c_I; k++){
            K += 2 * biCo(c_I - 1, c_I - k) / biCo(s_I + c_I, s_I) / biCo(s_I + c_I - 1, k);
        }
    }
    else if(s_I > 0 && s > 0 && s_I >= s  && c_I == 0 && c == 0 && m_I == 0 && m == 0){
        // Only S lineages in and out.
        K = 1;
    }
    // (A former branch repeating the "only C lineages" test above exactly was unreachable and has
    // been dropped.)
    else if(s_I > 0 && c_I > 0 && c > 0 && s > 0 && s_I >= s && c_I >= c && m_I == 0  && m == 0){
        // S and C lineages coalesce within their own types.
        K = biCo(s + c, s) * biCo(s_I - 1, s - 1) * biCo(c_I - 1, c - 1) / biCo(s_I + c_I - 1, s + c - 1) / biCo(s_I + c_I, s_I);
    }
    else K = 0;

    // Impossible transition: skip the comparatively expensive gFunction evaluation.
    // This function is called from calculateProb's innermost loop.
    if(K == 0) return 0;
    const long double gvalue = gFunction(s_I + c_I + m_I, s + c + m, T);
    if(gvalue == 0) return 0;
    return gvalue * K;
}

// Combinatorial term for the M_{SC} model. s_L + s_R, c_L + c_R and m_L + m_R are the S, C and M
// lineages entering the bottom of a branch of length T. s, c and m are the counts leaving its top.
// Returns K * gFunction(total in, total out, T), where K is the combinatorial weight of the
// transition. It returns 0 when fewer lineages enter than leave, or when the transition is not
// one of the cases below. The case that creates an M lineage from S and C lineages applies only
// when isRoot is true.
long double bigFRM(int s_L, int s_R, int c_L, int c_R, int m_L, int m_R, int s, int c, int m, double T, bool isRoot){
    const int s_I = s_L + s_R;
    const int c_I = c_L + c_R;
    const int m_I = m_L + m_R;
    if(s_I + c_I + m_I < s + c + m) return 0;

    long double K;
    if(s_I == 0 && c_I > 0 && m_I == 0 && c_I >= c && c > 0 && s == 0 && m == 0) {
        // Only C lineages in and out.
        K = 1;
    }
    else if(s_I == 0 && c_I == 0 && m_I == 1 && c == 0 && s == 0 && m == 1){
        // Only the M lineage.
        K = 1;
    }
    else if(s_I > 0 && s > 0 && s_I >= s  && c_I == 0 && c == 0 && m_I == 0 && m == 0){
        // Only S lineages in and out.
        K = 1;
    }
    else if(s_I > 0 && c_I > 0 && m_I == 0 && s == 0 && c == 0 && m == 1 && isRoot == true){
        // At the root, all S and C lineages merge into a single M lineage.
        K = 2 / (biCo(s_I + c_I, s_I) * (s_I + c_I - 1));
    }
    else if(s_I > 0 && c_I > 0 && m_I == 0 && s > 0 && c > 0 && s_I >= s && c_I >= c && m == 0){
        // S and C lineages coalesce within their own types.
        K = biCo(s + c, s) * biCo(s_I - 1, s - 1) * biCo(c_I - 1, c - 1) / biCo(s_I + c_I - 1, s + c - 1) / biCo(s_I + c_I, s_I);
    }
    else K = 0;

    // Impossible transition: skip the comparatively expensive gFunction evaluation.
    // This function is called from calculateProb's innermost loop.
    if(K == 0) return 0;
    const long double gvalue = gFunction(s_I + c_I + m_I, s + c + m, T);
    if(gvalue == 0) return 0;
    return gvalue * K;
}

// Computes the probability (Eq. 9 I think).  Walks recursively through the tree structure starting at the root.
// This function keeps track of the "level" of the tree it is at.
long double calculateProb(node* root, int s, int c, int m, int UB_s, int UB_c, Map<node*, Map<string, long double> >& calcs, bool rm, int level){
    level++;
    bool isRoot;
    if(level == 0){
        isRoot = true;
    }
    else(isRoot = false);
    long double value = 0, prod;
    for(int s_L = 0; s_L <= UB_s; s_L++){
        for(int s_R = 0; s_R <= UB_s - s_L; s_R++){
            for(int c_L = 0; c_L <= UB_c; c_L++){
                for(int c_R = 0; c_R <= UB_c - c_L; c_R++){
                    for(int m_L = 0; m_L <= 1; m_L++){
                        for(int m_R = 0; m_R <= 1 - m_L; m_R++){
                            if(s_L + c_L + m_L > 0 || s_R + c_R + m_R > 0){
                                    if(s_L + c_L + m_L + s_R + c_R + m_R >= s + c + m){
                                    long double leftProb, rightProb;
                                    long double fResult;
                                    if(rm == true){
                                        fResult = bigFRM(s_L, s_R, c_L, c_R, m_L, m_R, s, c, m, root->time, isRoot);
                                    }
                                    else if(rm == false) {
                                        fResult = bigF(s_L, s_R, c_L, c_R, m_L, m_R, s, c, m, root->time, root->isInSubtree);
                                    }
                                    else error("rm not bool");
                                    if(fResult == 0) prod = 0;
                                    else{
                                        if(root->left == NULL && root->right == NULL) {

                                            if(s_L == root->initial[0] && c_R == 0 && m_L == 0 && s_R == 0 && c_L == root->initial[1] && m_R == 0) {
                                                leftProb = 1;
                                                rightProb = 1;
                                            }
                                            else {
                                                leftProb = 0;
                                                rightProb = 0;
                                            }
                                        }
                                        else {
                                            string config;
                                            config += integerToString(s_L);
                                            config += integerToString(c_L);
                                            config += integerToString(m_L);
                                            Map<string, long double> nodeCalc = calcs.get(root->left);
                                            if(nodeCalc.containsKey(config)){
                                                leftProb = nodeCalc.get(config);
                                            }
                                            else{
                                                leftProb = calculateProb(root->left, s_L, c_L, m_L, UB_s, UB_c, calcs, rm, level);
                                                nodeCalc.put(config, leftProb);
                                                calcs.put(root->left, nodeCalc);
                                            }
                                            config.clear();
                                            config += integerToString(s_R);
                                            config += integerToString(c_R);
                                            config += integerToString(m_R);
                                            nodeCalc = calcs.get(root->right);
                                            if(nodeCalc.containsKey(config)){
                                                rightProb = nodeCalc.get(config);
                                            }
                                            else{
                                                rightProb = calculateProb(root->right, s_R, c_R, m_R, UB_s, UB_c, calcs, rm, level);
                                                nodeCalc.put(config, rightProb);
                                                calcs.put(root->right, nodeCalc);
                                            }
                                            config.clear();
                                        }
                                        prod = leftProb * rightProb * fResult;
                                      }
                                    value += prod;
                                  }
                            }
                        }
                    }
                }
            }
        }
    }
    level--;
    return value;
}
