对于简单的四则运算使用递归下降生成抽象语法树

四则运算的文法,来源于网上:

Expr      ->    Term ExprTail
ExprTail  ->    + Term ExprTail
          |     - Term ExprTail
          |     null

Term      ->    Factor TermTail
TermTail  ->    * Factor TermTail
          |     / Factor TermTail
          |     null

Factor    ->    (Expr)
          |     num

       如果只是递归下降分析的话,代码非常简单,每一个非终结符对应一个函数,这样就完成了递归下降的分析:

#include<iostream>
#include<fstream> 
#include<string>
#include<stack>

using namespace std;

bool Exp(const char* ch);
bool Exp1(const char* ch);
bool Term(const char* ch);
bool Term1(const char* ch);
bool Factor(const char* ch);

int pos = 0;

int main()
{
    char ch[200];
    while (cin >> ch)
    {
        if (Exp(ch))
        {
            //右括号多余的判断
            if (')' == ch[pos])
            {
                cout << "wrong!!" << endl;
            }
            else
            {
                cout << "success!!" << endl;
            }
        }
        else
        {
            cout << "wrong!!" << endl;
        }
        memset(ch, 0, sizeof(ch));
        pos = 0;
    }
    return 0;
}
bool Exp(const char* ch)
{
    if (!Term(ch))
    {
        return false;
    }
    else
    {
        return Exp1(ch);
    }
}
bool Exp1(const char* ch)
{
    if ('+' == ch[pos] || '-' == ch[pos])
    {
        pos++;
        if (!Term(ch))
        {
            return false;
        }
        else
        {
            return Exp1(ch);
        }
    }
    return true;

}
bool Term(const char* ch)
{
    if (!Factor(ch))
    {
        return false;
    }
    else
    {
        return Term1(ch);
    }
}
bool Term1(const char* ch)
{
    if ('*' == ch[pos] || '/' == ch[pos])
    {
        pos++;
        if (!Factor(ch))
        {
            return false;
        }
        else
        {
            return Term1(ch);
        }
    }
    return true;
}
bool Factor(const char* ch)
{
    if ('(' == ch[pos])
    {
        pos++;
        if (!Exp(ch))
        {
            return false;
        }
        else if (')' != ch[pos])
        {
            return false;
        }
        pos++;
        return false;
    }
    else if ('i' == ch[pos])
    {
        pos++;
        return true;
    }
    return false;
}

       在此基础上,我们要生成抽象语法树,因为四则运算的抽象语法树一定是二叉树,所以节点信息只需要包含标识符,左右子树即可,为了简单,这里将所有的数字用i来表示,建立运算符栈和数字节点栈。依次扫描输入并对其进行递归下降分析,按照如下原则进行出入栈操作:
       1. 遇到(或者i,将其包装成节点压入运算符栈;
       2. 遇到+或者-,扫描运算符栈。如果运算符栈为空或者栈顶元素是(,则停止扫描,将新的运算符包装成节点压入运算符栈;否则弹出栈顶元素,从数字节点栈中弹出两个节点作为其的左右子树,构成新的节点压入数字节点栈。
       3. 遇到*或者/,扫描运算符栈。如果运算符栈为空或者栈顶元素是(+或者-,则停止扫描,将新的运算符包装成节点压入运算符栈;否则弹出栈顶元素,从数字节点栈中弹出两个节点作为其的左右子树,构成新的节点压入数字节点栈。
       4. 遇到),说明要对括号中的运算进行合并。扫描运算符栈,直到栈顶元素是(,将其弹出,结束;否则弹出栈顶元素,从数字节点栈中弹出两个节点作为其的左右子树,构成新的节点压入数字节点栈。
       按照上述步骤完成递归下降后,还要对未完成的节点进行合并,最终就可以得到语法树。
完整代码如下:

#include<iostream>
#include<fstream> 
#include<string>
#include<stack>

using namespace std;
/*
四则运算文法
Expr      ->    Term ExprTail
ExprTail  ->    + Term ExprTail
          |     - Term ExprTail
          |     null

Term      ->    Factor TermTail
TermTail  ->    * Factor TermTail
          |     / Factor TermTail
          |     null

Factor    ->    (Expr)
          |     num
*/

//节点类
struct node
{
    char token;
    node* left;
    node* right;
    node()
    {
        token = ' ';
        left = NULL;
        right = NULL;
    }
    node(char c)
    {
        token = c;
        left = NULL;
        right = NULL;
    }
};

stack<node*> number;//数字节点栈
stack<node*> op;//运算符栈

bool Exp(const char* ch);
bool Exp1(const char* ch);
bool Term(const char* ch);
bool Term1(const char* ch);
bool Factor(const char* ch);

//扫描输入
int pos = 0;

//打印语法树
void print_tree(node* t)
{
    if (NULL == t)
    {
        return;
    }
    else
    {
        cout << t->token << endl;
        print_tree(t->left);
        print_tree(t->right);
    }
}

int main()
{
    char ch[200] = { 0 };

    while (cin >> ch)
    {
        //是否打印
        bool flag = false;
        if (Exp(ch))
        {
            //右括号多余的判断
            if (')' == ch[pos])
            {
                cout << "wrong!!" << endl;
            }
            else
            {
                cout << "success!!" << endl;
                flag = true;
            }

        }
        else
        {
            cout << "wrong!!" << endl;
        }
        node* temp=new node();
        //对于表达式最外侧的括号(例如(i+i)),最后的节点全部在数字节点栈中
        //因此根节点要在数字栈里找
        if (op.empty())
        {
            temp = number.top();
        }
        //如果运算符栈不为空,则依次出栈构造语法树
        while (!op.empty())
        {
            temp = op.top();
            op.pop();
            //取出运算符后,需要从数字栈拿出两个节点作为左右子树,如果没有两个节点则出错
            if (number.size()<2)
            {
                cout << "illegal!!" << endl;
                break;
            }
            node* t1 = number.top();
            number.pop();
            node* t2 = number.top();
            number.pop();

            temp->left = t2;
            temp->right = t1;
            number.push(temp);
        }
        if (flag)
        {
            print_tree(temp);
        }
        //清除上一次的数据,这个很重要,不然后续会出错
        delete(temp);
        while (!number.empty())
        {
            number.pop();
        }
        while (!op.empty())
        {
            op.pop();
        }
        memset(ch, 0, sizeof(ch));
        pos = 0;

    }

    return 0;
}
bool Exp(const char* ch)
{
    if (!Term(ch))
    {
        return false;
    }
    else
    {
        return Exp1(ch);
    }

}
bool Exp1(const char* ch)
{
    //遇到+和-,如果运算符栈为空或者栈顶是(,则压栈;否则先计算栈里的运算符,因为
    //如果栈顶是+-,则顺序计算先计算栈顶,如果是*,/,则按照优先级先计算*,/
    //从运算栈顶取运算符,从数字栈取两个数字节点作为其左右子节点,组成新节点压入数字节点栈
    //最后将新的符号压栈
    if ('+' == ch[pos] || '-' == ch[pos])
    {
        while (!op.empty())
        {
            node* temp = op.top();
            if ('(' == temp->token)
            {
                break;
            }
            else
            {
                op.pop();
                if (number.size()<2)
                {
                    cout << "illegal!!" << endl;
                    return false;
                }
                node* t1 = number.top();
                number.pop();
                node* t2 = number.top();
                number.pop();
                temp->left = t2;
                temp->right = t1;
                number.push(temp);
            }
        }
        node* n = new node(ch[pos]);
        op.push(n);

        //下一个
        pos++;
        if (!Term(ch))
        {
            return false;
        }
        else
        {
            return Exp1(ch);
        }
    }
    return true;

}
bool Term(const char* ch)
{
    if (!Factor(ch))
    {
        return false;
    }
    else
    {
        return Term1(ch);
    }

}
bool Term1(const char* ch)
{
    //遇到*,/,如果运算符栈为空或者栈顶是(,则压栈;否则与栈顶运算符比较优先级,若栈顶元素是
    //+或者-,则将新的运算符包装成节点压栈;若栈顶元素是*或/,则根据运算顺序先计算栈顶元素
    if ('*' == ch[pos] || '/' == ch[pos])
    {
        while (!op.empty())
        {
            node* temp = op.top();
            if ('(' == temp->token || '+' == temp->token || '-' == temp->token)
            {
                break;
            }
            else
            {
                if (number.size()<2)
                {
                    cout << "illegal!!" << endl;
                    return false;
                }
                op.pop();
                node* t1 = number.top();
                number.pop();
                node* t2 = number.top();
                number.pop();
                temp->left = t2;
                temp->right = t1;
                number.push(temp);

            }
        }

        node* n = new node(ch[pos]);
        op.push(n);

        pos++;
        if (!Factor(ch))
        {
            return false;
        }
        else
        {
            return Term1(ch);
        }
    }
    return true;

}
bool Factor(const char* ch)
{
    if ('(' == ch[pos])
    {
        node* n = new node('(');
        op.push(n);
        pos++;
        if (!Exp(ch))
        {
            return false;
        }
        if (')' == ch[pos])
        {
            while (1)
            {
                if (op.empty())
                {
                    cout << "illegal!!" << endl;
                    return false;
                }
                node* x = op.top();
                if ('(' == x->token)
                {
                    op.pop();
                    break;
                }
                else
                {
                    if (number.size()<2)
                    {
                        cout << "illegal!!" << endl;
                        return false;
                    }
                    node* t1 = number.top();
                    number.pop();
                    node* t2 = number.top();
                    number.pop();
                    x->left = t2;
                    x->right = t1;
                    number.push(x);
                    op.pop();
                }
            }
        }
        else if (')' != ch[pos])
        {
            return false;
        }
        pos++;
        return true;
    }
    else if ('i' == ch[pos])
    {
        node* n = new node('i');
        number.push(n);
        pos++;
        return true;
    }
    return false;
}

       上述代码对于右括号的处理不是很严谨,但目前我没有更好的思路。以后会对算法进行修改。

猜你喜欢

转载自blog.csdn.net/zhang_han666/article/details/80581670