首先下载代码:https://github.com/whk6688/rnn

例子1:预测下文

private void train(CharText ctext, double lr) {
        Map<Integer, String> indexChar = ctext.getIndexChar();
        Map<String, DoubleMatrix> charVector = ctext.getCharVector();
        List<String> sequence = ctext.getSequence();
        for (int i = 0; i < 100; i++) {
            double error = 0;
            double num = 0;
            double start = System.currentTimeMillis();
            for (int s = 0; s < sequence.size(); s++) {
                String seq = sequence.get(s);
                if (seq.length() < 3) {
                    continue;
                }

                Map<String, DoubleMatrix> acts = new HashMap<>();
                // forward pass
                System.out.print(String.valueOf(seq.charAt(0)+"->"));
                for (int t = 0; t < seq.length() - 1; t++) {
                    DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
                    acts.put("x" + t, xt);

                    gru.active(t, acts);

                    DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
                    acts.put("py" + t, predcitYt);
                    DoubleMatrix trueYt = charVector.get(String.valueOf(seq.charAt(t + 1)));
                    acts.put("y" + t, trueYt);

                    System.out.print(indexChar.get(predcitYt.argmax()));
                    //error += LossFunction.getMeanCategoricalCrossEntropy(predcitYt, trueYt);

                }

                System.out.println();

                // bptt
                gru.bptt(acts, seq.length() - 2, lr);

                num +=  seq.length();
            }
            System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
        }
    }

    private void test(CharText ctext) {

        Map<Integer, String> indexChar = ctext.getIndexChar();
        Map<String, DoubleMatrix> charVector = ctext.getCharVector();
        Map<String, DoubleMatrix> acts = new HashMap<>();     
        String seq="不";     
        int t=0;
        DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
        acts.put("x" + t, xt);
        gru.active(t, acts);         
        DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
        acts.put("py" + t, predcitYt); 
        System.out.print(indexChar.get(predcitYt.argmax()));
    }

训练的文本为:

行尸走肉
金蝉脱壳
百里挑一
金玉满堂
不花不四
不花千里
不花×××
背水一战
霸王别姬
天上人间
不吐不快
海阔天空
情非得已
满腹经纶
兵临城下
春暖花开
插翅难逃
黄道吉日
天下无双
偷天换日
两小无猜
卧虎藏龙
珠光宝气
簪缨世族
×××
绘声绘影
国色天香
相亲相爱
八仙过海
金玉良缘
掌上明珠
皆大欢喜
逍遥法外

当输入“不”时,下一个词会提示为“花”。 因为此算法是有时间概念的,因此当你在加入两条不字开头的成语,会发现结果不同。

例子二:预测结果

public static void main(String[] args) {  
        loadData();  
        int hiddenSize = 4;//隐含层数量  
        double lr = 0.1;  
        gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);//4是输入层,3是输出层  
        for (int i = 0; i < 2000; i++) {//迭代2000次  
            double error = 0;  
            double num = 0;  
            double start = System.currentTimeMillis();  
            Map<String, DoubleMatrix> acts = new HashMap<>();  
            for (int s = 0; s < train_x.length; s++) {  
                double newx[][] = new double[1][4];  
                newx[0] = train_x[s];  
                DoubleMatrix xt = new DoubleMatrix(newx);//获取字的矩阵  
                //System.out.println(xt.getColumns()+" "+xt.getRows());  
                acts.put("x" + s, xt);  
                gru.active(s, acts);  
                DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));  
                acts.put("py" + s, predcitYt);  

                double newy[][] = new double[1][3];  
                newy[0] = train_y[s];  
                DoubleMatrix trueYt = new DoubleMatrix(newy);  
                acts.put("y" + s, trueYt);  

                //System.out.println(predcitYt.argmax()+"-->"+trueYt.argmax());

                if(predcitYt.argmax()!=trueYt.argmax())  
                    error++;  

                // bptt  
                num ++;  
            }  
            gru.bptt(acts, train_x.length-1, lr);  
            System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");  
        }//结束迭代  

        //开始测试  
        int num = 0,error = 0;  
        Map<String, DoubleMatrix> acts = new HashMap<>();  
        for(int s = 0; s<test_x.length;s++){  
            double newx[][] = new double[1][4];  
            newx[0] = test_x[s];  
            DoubleMatrix xt = new DoubleMatrix(newx);  
            acts.put("x" + s, xt);  
            gru.active(s, acts);  
            DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
            acts.put("py" + s, predcitYt);

            double newy[][] = new double[1][3];  
            newy[0] = test_y[s];  
            DoubleMatrix trueYt = new DoubleMatrix(newy);  
            acts.put("y" + s, trueYt);  
            if(predcitYt.argmax()!=trueYt.argmax())  
                error++;  
            // bptt  
            num ++;  
        }  
        System.out.println("错误数:"+error+"/"+num);  
    }  

这个例子来预测花的种类,当然也可以使用决策树来实现。换种方式也感觉挺好

引用:https://blog.csdn.net/czs1130/article/details/70717348