数据挖掘 -- FP-Tree关联规则算法

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/l1832876815/article/details/89371965
1. 算法原理

FP-Tree算法相对于Apriori算法,减少了扫描数据集(I/O)的次数。其原理是:先扫描原始数据,找出频繁1项集,即项头表;然后按照项头表中的支持度(sup)从大到小,对每条原始记录中的项重新排序,并去掉不满足最小支持度的项。接着创建树形结构,每个节点保存项的名称和出现次数;将排序后的记录依次插入树中,建树过程即告完成。挖掘过程是倒序遍历项头表:对于每个项s,寻找树中s节点到根的路径,并合并其余分支上的s节点,路径上每个父节点的sup值为其子树中所有s节点的sup值之和,由此得到频繁项集。最后求出最大频繁项集即可。

2. 代码实现
package com.clxk1997;

/**
 * A single item record that pairs an item name with a count ({@code name -> cnt}).
 *
 * <p>The natural ordering compares records by count only (ascending); the name is
 * ignored, so the ordering is not consistent with {@code equals}.
 *
 * @Author Clxk
 * @Date 2019/4/15 20:53
 * @Version 1.0
 */
public class Data implements Comparable<Data> {

    /** Item name. */
    private String name;

    /** Number of occurrences recorded for the item. */
    private int cnt;

    /** Creates an empty record: {@code name} is {@code null} and {@code cnt} is 0. */
    public Data() {

    }

    /**
     * Creates a record with the given name and count.
     *
     * @param name item name
     * @param cnt  occurrence count
     */
    public Data(String name, int cnt) {
        this.name = name;
        this.cnt = cnt;
    }

    /** @return the item name */
    public String getName() {
        return name;
    }

    /** @param name the item name to store */
    public void setName(String name) {
        this.name = name;
    }

    /** @return the occurrence count */
    public int getCnt() {
        return cnt;
    }

    /** @param cnt the occurrence count to store */
    public void setCnt(int cnt) {
        this.cnt = cnt;
    }

    /**
     * Compares two records by count.
     *
     * @param o the record to compare against
     * @return a negative value, zero, or a positive value when this count is
     *         smaller than, equal to, or greater than the other count
     */
    @Override
    public int compareTo(Data o) {
        // Integer.compare yields all three signs; returning only 1 or 0 would
        // violate sgn(a.compareTo(b)) == -sgn(b.compareTo(a)).
        return Integer.compare(this.cnt, o.getCnt());
    }
}

package com.clxk1997;

import java.util.ArrayList;
import java.util.List;

/**
 * A node of the FP-tree. Every node except the root carries a {@link Data}
 * (item name and count); the root created by {@link #init()} has {@code null}
 * data and a {@code null} parent.
 *
 * @Author Clxk
 * @Date 2019/4/15 20:52
 * @Version 1.0
 */
public class Node {

    Data data;
    ArrayList<Node> child = new ArrayList<>();
    Node parent;

    /** @return the item stored in this node, or {@code null} for the root */
    public Data getData() {
        return data;
    }

    /** @param data the item to store in this node */
    public void setData(Data data) {
        this.data = data;
    }

    /** @return the live list of child nodes */
    public ArrayList<Node> getChild() {
        return child;
    }

    /** @param child the list to use as this node's children */
    public void setChild(ArrayList<Node> child) {
        this.child = child;
    }

    /** @return the parent node, or {@code null} for the root */
    public Node getParent() {
        return parent;
    }

    /** @param parent the node to record as this node's parent */
    public void setParent(Node parent) {
        this.parent = parent;
    }

    /**
     * Creates the root of an empty tree.
     *
     * @return a node with no data, no parent and an empty child list
     */
    public static Node init() {
        // A freshly constructed node already has null data, a null parent and
        // an empty child list, which is exactly the root state.
        return new Node();
    }

    /**
     * Inserts one transaction into the tree as a path starting at {@code root}.
     * For each item, the matching child of the current node has its count
     * incremented, or a new child with count 1 is created. Every node visited
     * on the way (new or existing, so duplicates are possible) is appended to
     * {@code Main.leaf}.
     *
     * @param list item names of the transaction, in insertion order; {@code null}
     *             is treated as an empty transaction
     * @param root root of the tree to insert into
     * @return {@code root}
     */
    public static Node putList2Tree(ArrayList<String> list, Node root) {
        if (list == null) {
            return root;
        }
        Node parent = root;
        for (String name : list) {
            ArrayList<Node> children = parent.getChild();

            // Look for an existing child that already represents this item.
            Node current = null;
            for (Node candidate : children) {
                if (candidate.getData().getName().equals(name)) {
                    current = candidate;
                    break;
                }
            }

            if (current == null) {
                // Not found: start a new branch with count 1.
                current = new Node();
                current.setData(new Data(name, 1));
                current.setParent(parent);
                current.setChild(new ArrayList<>());
                children.add(current);
            } else {
                current.getData().setCnt(current.getData().getCnt() + 1);
            }

            Main.leaf.add(current);
            parent = current;
        }
        return root;
    }

    /**
     * Prints every descendant of {@code root} in depth-first pre-order, one
     * {@code "name count"} line per node. {@code root} itself is not printed.
     *
     * @param root node whose subtree is printed; {@code null} prints nothing
     */
    public static void dfs(Node root) {
        if (root == null) {
            return;
        }
        for (Node next : root.getChild()) {
            System.out.println(next.getData().getName() + " " + next.getData().getCnt());
            dfs(next);
        }
    }

    /**
     * Sums the counts of the nodes named {@code data.getName()} in the subtree
     * rooted at {@code node}. The search stops descending at the first match on
     * each path, so a matching node's own descendants are not inspected.
     *
     * @param node root of the subtree to search; {@code null} yields 0
     * @param data item whose name is searched for
     * @return the summed count, or 0 when the name does not occur
     */
    public static int getAllChildCount(Node node, Data data) {
        if (node == null) {
            return 0;
        }
        Data own = node.getData();
        // The tree root carries no data; it can never match, only its children can.
        if (own != null && own.getName().equals(data.getName())) {
            return own.getCnt();
        }
        int total = 0;
        for (Node next : node.getChild()) {
            total += getAllChildCount(next, data);
        }
        return total;
    }

    /**
     * Computes the depth of a node, counting the root as depth 1.
     *
     * @param node the node to measure; must not be {@code null}
     * @return the number of nodes on the path from {@code node} up to the root, inclusive
     */
    public static int getDepth(Node node) {
        int depth = 1;
        for (Node up = node.getParent(); up != null; up = up.getParent()) {
            depth++;
        }
        return depth;
    }

}

package com.clxk1997;

import java.util.*;

/**
 * Console driver: reads a transaction data set and a minimum support count
 * from standard input, builds the header table and the FP-tree, and prints,
 * for every header item, the item counts collected along the path from that
 * item's deepest tree node up to the root.
 */
public class Main {

    /**
     * The transactions. After {@link #getSortedList()} each transaction holds
     * only the items of the header table, in header-table order.
     */
    private static List<ArrayList<String>> data = new ArrayList<>();
    /**
     * Header table: one entry per distinct item with its support count.
     */
    private static ArrayList<Data> list = new ArrayList<>();
    /**
     * Root of the FP-tree built by {@link #solve()}.
     */
    private static Node root;
    /**
     * Number of transactions and the minimum support count.
     */
    private static int datacnt;
    private static int minsupport;

    /**
     * Nodes recorded by {@code Node.putList2Tree} while inserting transactions.
     * Every visited node is appended, so this is not limited to leaves and may
     * contain the same node more than once.
     */
    public static List<Node> leaf = new ArrayList<>();

    /**
     * Reads the data set size, the minimum support and then one transaction per
     * line (items separated by whitespace) from standard input, and runs
     * {@link #solve()}.
     *
     * @param args unused
     */
    public static void main(String[] args) {

        Scanner scanner = new Scanner(System.in);
        System.out.println("请输入数据集大小: ");
        datacnt = scanner.nextInt();
        System.out.println("请输入最小支持度: ");
        minsupport = scanner.nextInt();
        System.out.println("请输入原始数据集: ");
        // Consume the rest of the line that held the minimum support value.
        if (scanner.hasNextLine()) {
            scanner.nextLine();
        }

        data = new ArrayList<>();
        for (int i = 0; i < datacnt && scanner.hasNextLine(); i++) {
            ArrayList<String> transaction = new ArrayList<>();
            // Split on runs of whitespace and drop empty tokens so that repeated
            // spaces or a blank line do not create an item named "".
            for (String token : scanner.nextLine().trim().split("\\s+")) {
                if (!token.isEmpty()) {
                    transaction.add(token);
                }
            }
            data.add(transaction);
        }
        // Input may end early; keep the counter in step with what was read.
        datacnt = data.size();

        solve();
    }

    /**
     * Runs the whole pipeline on the loaded transactions: counts item support,
     * prints the filtered header table, reorders the transactions, builds and
     * prints the tree, then walks the header table from last to first and
     * prints the path collected for each item plus the overall best path.
     */
    public static void solve() {

        // Support is the number of transactions containing an item, so an item
        // repeated inside one transaction is counted once. LinkedHashSet keeps
        // the first-seen order of the items.
        for (ArrayList<String> transaction : data) {
            for (String item : new LinkedHashSet<>(transaction)) {
                putIntoList(item);
            }
        }

        // Sort the header table, drop infrequent items and print it.
        sortAndOut();

        // Rewrite every transaction in header-table order and print them.
        getSortedList();

        // Insert every reordered transaction into a fresh tree.
        root = Node.init();
        for (ArrayList<String> transaction : data) {
            Node.putList2Tree(transaction, root);
        }

        System.out.println("深搜遍历树: ");
        Node.dfs(root);

        // Walk the header table from the least to the most frequent item.
        ArrayList<Data> ansdata = new ArrayList<>();
        int maxd = 0;
        int maxlen = 0;
        for (int i = list.size() - 1; i >= 0; i--) {
            Data header = list.get(i);
            Node node = getDepthNode(header);
            ArrayList<Data> curdata = new ArrayList<>();
            searchFrequence(node, header, curdata);
            System.out.println("项头: " + header.getName() + "的最大频繁项集是: ");
            int curd = outData(curdata);
            // Prefer the longer path; on equal length prefer the larger count sum.
            boolean longer = curdata.size() > maxlen;
            boolean heavierTie = curdata.size() == maxlen && curd > maxd;
            if (longer || heavierTie) {
                maxlen = curdata.size();
                maxd = curd; // remember the best sum so later ties compare against it
                ansdata = new ArrayList<>(curdata);
            }
        }
        System.out.println("所以最终频繁项集为: ");
        outData(ansdata);
    }

    /**
     * Finds, among the nodes recorded in {@link #leaf}, the deepest one whose
     * name equals {@code data.getName()}.
     *
     * @param data item to look for
     * @return the deepest matching node (the first one found on equal depth),
     *         or {@code null} when no recorded node has that name
     */
    public static Node getDepthNode(Data data) {
        Node deepest = null;
        int bestDepth = 0;
        for (Node candidate : leaf) {
            if (!candidate.getData().getName().equals(data.getName())) {
                continue;
            }
            int depth = Node.getDepth(candidate);
            if (depth > bestDepth) {
                bestDepth = depth;
                deepest = candidate;
            }
        }
        return deepest;
    }

    /**
     * Walks from {@code node} up to (but excluding) the root and, for every node
     * on the way, appends a record holding that node's name and the result of
     * {@code Node.getAllChildCount(thatNode, data)}.
     *
     * @param node    start node; {@code null} appends nothing
     * @param data    header item whose counts are summed below each path node
     * @param curdata list that receives the records, start node first
     */
    public static void searchFrequence(Node node, Data data, ArrayList<Data> curdata) {
        // The root is the only node without data, so it terminates the walk.
        for (Node cur = node; cur != null && cur.getData() != null; cur = cur.getParent()) {
            int count = Node.getAllChildCount(cur, data);
            curdata.add(new Data(cur.getData().getName(), count));
        }
    }


    /**
     * Replaces every transaction with the header-table items it contains, in
     * header-table order (items missing from the header table are dropped), and
     * prints the rewritten transactions.
     */
    public static void getSortedList() {
        List<ArrayList<String>> sorted = new ArrayList<>(data.size());

        for (ArrayList<String> transaction : data) {
            Set<String> present = new HashSet<>(transaction);
            ArrayList<String> ordered = new ArrayList<>();
            for (Data header : list) {
                if (present.contains(header.getName())) {
                    ordered.add(header.getName());
                }
            }
            sorted.add(ordered);
        }
        data = sorted;

        System.out.println("排序后的数据集: ");
        for (ArrayList<String> transaction : data) {
            System.out.println(Arrays.toString(transaction.toArray()));
        }

    }

    /**
     * Sorts the header table by count in descending order (stable, so equal
     * counts keep their first-seen order), removes the entries whose count is
     * below the minimum support, and prints the remaining entries.
     */
    public static void sortAndOut() {

        list.sort((a, b) -> Integer.compare(b.getCnt(), a.getCnt()));
        list.removeIf(entry -> entry.getCnt() < minsupport);

        System.out.println("满足支持度的项头表: ");
        for (Data entry : list) {
            System.out.println(entry.getName() + " " + entry.getCnt());
        }
    }

    /**
     * Records one occurrence of {@code s} in the header table: increments the
     * count of the existing entry with that name, or appends a new entry with
     * count 1.
     *
     * @param s item name
     */
    public static void putIntoList(String s) {

        for (Data entry : list) {
            if (entry.getName().equals(s)) {
                entry.setCnt(entry.getCnt() + 1);
                return;
            }
        }
        list.add(new Data(s, 1));
    }

    /**
     * Prints the records on one line as {@code [name,cnt  name,cnt ...]}.
     *
     * @param curdata records to print
     * @return the sum of the printed counts
     */
    public static int outData(ArrayList<Data> curdata) {
        System.out.print("[");
        int total = 0;
        boolean first = true;
        for (Data entry : curdata) {
            if (!first) {
                System.out.print("  ");
            }
            first = false;
            System.out.print(entry.getName() + "," + entry.getCnt());
            total += entry.getCnt();
        }
        System.out.println("]");
        return total;
    }

}

猜你喜欢

转载自blog.csdn.net/l1832876815/article/details/89371965