package com.hankcs.hanlp.mining.word2vec;

import com.hankcs.hanlp.utility.Predefine;
import com.sun.xml.bind.v2.runtime.reflect.opt.Const;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.nio.charset.Charset;
import java.util.Comparator;
import java.util.Locale;

/* loaded from: input_file:WEB-INF/lib/hanlp-portable-1.8.4.jar:com/hankcs/hanlp/mining/word2vec/Word2VecTraining.class */
class Word2VecTraining {
    /** Number of entries of {@link #expTable} that the static initializer fills (indices 0..999). */
    static final int EXP_TABLE_SIZE = 1000;
    /**
     * Half-width of the input range covered by {@link #expTable}. Not referenced by name in this
     * file; the literal 6.0 appears in the static initializer and in the training loop.
     */
    static final int MAX_EXP = 6;
    /** Length of the sampling table allocated by {@link #initUnigramTable(Corpus)}. */
    static final int TABLE_SIZE = 100000000;
    /**
     * Not referenced by name in this file; the training loop uses the literal 1000 as the cap on
     * buffered words per sentence - presumably this constant, inlined. TODO confirm.
     */
    static final int MAX_SENTENCE_LENGTH = 1000;
    /** Wall-clock time (ms) recorded by {@link #trainModel()} just before the worker threads start. */
    long timeStart;
    // NOTE: the three weight arrays below are static, so they are shared by every
    // Word2VecTraining instance in the JVM and are updated by the worker threads
    // without synchronization.
    /** Input-side weights; allocated and randomly initialised by initNet, written to the output file by trainModel. */
    static double[] syn0;
    /** Hierarchical-softmax weights; allocated by initNet only when that mode is enabled, nulled after training. */
    static double[] syn1;
    /** Negative-sampling weights; allocated by initNet only when the configured negative count is positive. */
    static double[] syn1neg;
    /** Sampling table of vocabulary indices built by initUnigramTable; nulled after training. */
    int[] table;
    /** Training configuration supplied to the constructor. */
    private final Config config;
    /** Count of worker threads that have not finished yet; trainModel waits on this object's monitor until it reaches zero. */
    int threadCount;
    /** Character encoding used when writing the output vector file. */
    static final Charset ENCODING = Charset.forName("UTF-8");
    /** Lookup table of e^x / (e^x + 1) values; 1001 slots, of which the static initializer fills the first 1000. */
    static final double[] expTable = new double[1001];

    /* loaded from: input_file:WEB-INF/lib/hanlp-portable-1.8.4.jar:com/hankcs/hanlp/mining/word2vec/Word2VecTraining$TrainModelThread.class */
    /**
     * Worker thread that runs the training loop over its share of the corpus and updates the
     * static weight arrays of {@link Word2VecTraining} ({@code syn0}, {@code syn1},
     * {@code syn1neg}).
     *
     * <p>Those arrays and {@link #wordCountActual} are read and written by all workers without
     * any synchronization. When a worker ends (normally or with an exception) it decrements
     * {@code vec.threadCount} and notifies {@code vec}.
     */
    static class TrainModelThread extends Thread {
        /** Maximum number of words buffered per sentence (the sentence buffer has one extra slot). */
        private static final int SENTENCE_CAPACITY = 1000;
        /** Modulus applied when indexing the sampling table; equals the table length allocated by the trainer. */
        private static final int UNIGRAM_TABLE_SIZE = 100000000;
        /** Activations outside (-6, 6) bypass the sigmoid lookup table. */
        private static final double MAX_ACTIVATION = 6.0d;
        /**
         * Scale from (activation + 6) to an index into {@code expTable}.
         * Presumably 1000 / 6 / 2 evaluated with integer division (= 83) - kept as-is.
         */
        private static final double EXP_INDEX_SCALE = 83.0d;
        /** Word index that is skipped without being counted. */
        private static final int WORD_IGNORED = -1;
        /** Word index that ends the current pass over the corpus. */
        private static final int WORD_END_OF_PASS = -2;
        /** Word index that is counted and then terminates the sentence being buffered. */
        private static final int WORD_END_OF_SENTENCE = -3;

        final Word2VecTraining vec;
        final Corpus corpus;
        final Config config;
        /** Current learning rate; decays linearly with overall progress, floored at 1e-4 of the start value. */
        float alpha;
        final float startingAlpha;
        final float trainWords;
        final int id;
        final int vocabSize;
        final long timeStart;
        final int[] table;
        final VocabWord[] vocab;
        /** Words processed so far by all workers; shared and updated without synchronization. */
        static int wordCountActual = 0;
        /** State of this worker's pseudo-random sequence (seeded with the worker id). */
        private long nextRandom;

        /**
         * @param word2VecTraining owning trainer; supplies the sampling table and start time and is notified on exit
         * @param corpus           corpus view read by this worker
         * @param config           training configuration
         * @param i                worker id, also used as the random seed and passed to {@code corpus.rewind}
         */
        public TrainModelThread(Word2VecTraining word2VecTraining, Corpus corpus, Config config, int i) {
            this.vec = word2VecTraining;
            this.corpus = corpus;
            this.config = config;
            this.alpha = config.getAlpha();
            this.startingAlpha = this.alpha;
            this.id = i;
            this.table = word2VecTraining.table;
            this.trainWords = corpus.getTrainWords();
            this.timeStart = word2VecTraining.timeStart;
            this.vocabSize = corpus.getVocabSize();
            this.vocab = corpus.getVocab();
        }

        /**
         * Runs the training loop and always signals completion to the owning trainer.
         *
         * @throws RuntimeException wrapping any {@link IOException} raised by the corpus
         */
        @Override // java.lang.Thread, java.lang.Runnable
        public void run() {
            try {
                train();
            } catch (IOException e) {
                throw new RuntimeException(e);
            } finally {
                // Signal completion even on failure; otherwise the trainer would wait forever.
                synchronized (this.vec) {
                    this.vec.threadCount--;
                    this.vec.notify();
                }
            }
        }

        /** Main loop: buffers one sentence at a time and trains on each buffered position. */
        private void train() throws IOException {
            final float iter = this.config.getIter();
            final int layer1Size = this.config.getLayer1Size();
            final int numThreads = this.config.getNumThreads();
            final int window = this.config.getWindow();
            final int negative = this.config.getNegative();
            final boolean useCbow = this.config.useContinuousBagOfWords();
            final boolean useHierarchicalSoftmax = this.config.useHierarchicalSoftmax();
            final float sample = this.config.getSample();

            final double[] syn0 = Word2VecTraining.syn0;
            final int[] sentence = new int[SENTENCE_CAPACITY + 1];
            final double[] neu1 = new double[layer1Size];   // averaged context vector (CBOW)
            final double[] neu1e = new double[layer1Size];  // accumulated gradient for the input vector(s)
            int word = 0;
            int sentenceLength = 0;
            int sentencePosition = 0;
            long wordCount = 0;
            long lastWordCount = 0;
            long remainingIterations = (int) iter;
            this.nextRandom = this.id;

            this.corpus.rewind(numThreads, this.id);
            while (true) {
                if (wordCount - lastWordCount > 10000) {
                    wordCountActual = (int) (wordCountActual + (wordCount - lastWordCount));
                    lastWordCount = wordCount;
                    reportProgressAndDecayAlpha(iter, remainingIterations);
                }

                if (sentenceLength == 0) {
                    // Buffer the next sentence.
                    while (true) {
                        word = this.corpus.readWordIndex();
                        if (word == WORD_END_OF_PASS) {
                            break;
                        }
                        if (word == WORD_IGNORED) {
                            continue;
                        }
                        wordCount++;
                        if (word == WORD_END_OF_SENTENCE) {
                            break;
                        }
                        if (sample > 0.0f) {
                            // Sub-sampling: randomly discard frequent words.
                            double keep = ((Math.sqrt(this.vocab[word].f9cn / (sample * this.trainWords)) + 1.0d) * (sample * this.trainWords)) / this.vocab[word].f9cn;
                            this.nextRandom = Word2VecTraining.nextRandom(this.nextRandom);
                            if (keep < (this.nextRandom & 65535) / 65536.0d) {
                                continue;
                            }
                        }
                        sentence[sentenceLength] = word;
                        sentenceLength++;
                        if (sentenceLength >= SENTENCE_CAPACITY) {
                            break;
                        }
                    }
                    sentencePosition = 0;
                }

                if (word == WORD_END_OF_PASS || ((float) wordCount) > this.trainWords / numThreads) {
                    // This worker's share is done for the current pass.
                    wordCountActual = (int) (wordCountActual + (wordCount - lastWordCount));
                    remainingIterations--;
                    if (remainingIterations <= 0) {
                        this.corpus.shutdown();
                        return;
                    }
                    wordCount = 0;
                    lastWordCount = 0;
                    sentenceLength = 0;
                    this.corpus.rewind(numThreads, this.id);
                    continue;
                }

                word = sentence[sentencePosition];
                if (word == WORD_IGNORED) {
                    continue;
                }
                for (int k = 0; k < layer1Size; k++) {
                    neu1[k] = 0.0d;
                    neu1e[k] = 0.0d;
                }
                this.nextRandom = Word2VecTraining.nextRandom(this.nextRandom);
                // NOTE(review): the int cast can be negative, so b may be negative and widen the
                // effective window; kept unchanged because it affects trained results.
                final int b = ((int) this.nextRandom) % window;

                if (useCbow) {
                    // Input -> hidden: average the context word vectors.
                    long contextCount = 0;
                    for (int a = b; a < ((window * 2) + 1) - b; a++) {
                        if (a == window) {
                            continue;
                        }
                        int c = (sentencePosition - window) + a;
                        if (c < 0 || c >= sentenceLength) {
                            continue;
                        }
                        int contextWord = sentence[c];
                        if (contextWord == WORD_IGNORED) {
                            continue;
                        }
                        for (int k = 0; k < layer1Size; k++) {
                            neu1[k] = neu1[k] + syn0[k + (contextWord * layer1Size)];
                        }
                        contextCount++;
                    }
                    if (contextCount != 0) {
                        for (int k = 0; k < layer1Size; k++) {
                            neu1[k] = neu1[k] / contextCount;
                        }
                        if (useHierarchicalSoftmax) {
                            applyHierarchicalSoftmax(word, neu1, 0, neu1e, layer1Size);
                        }
                        if (negative > 0) {
                            applyNegativeSampling(word, neu1, 0, neu1e, layer1Size, negative);
                        }
                        // Hidden -> input: propagate the gradient to every context word.
                        for (int a = b; a < ((window * 2) + 1) - b; a++) {
                            if (a == window) {
                                continue;
                            }
                            int c = (sentencePosition - window) + a;
                            if (c < 0 || c >= sentenceLength) {
                                continue;
                            }
                            int contextWord = sentence[c];
                            if (contextWord == WORD_IGNORED) {
                                continue;
                            }
                            for (int k = 0; k < layer1Size; k++) {
                                int idx = k + (contextWord * layer1Size);
                                syn0[idx] = syn0[idx] + neu1e[k];
                            }
                        }
                    }
                } else {
                    // Skip-gram: train each context word against the current word.
                    for (int a = b; a < ((window * 2) + 1) - b; a++) {
                        if (a == window) {
                            continue;
                        }
                        int c = (sentencePosition - window) + a;
                        if (c < 0 || c >= sentenceLength) {
                            continue;
                        }
                        int contextWord = sentence[c];
                        if (contextWord == WORD_IGNORED) {
                            continue;
                        }
                        int l1 = contextWord * layer1Size;
                        for (int k = 0; k < layer1Size; k++) {
                            neu1e[k] = 0.0d;
                        }
                        if (useHierarchicalSoftmax) {
                            applyHierarchicalSoftmax(word, syn0, l1, neu1e, layer1Size);
                        }
                        if (negative > 0) {
                            applyNegativeSampling(word, syn0, l1, neu1e, layer1Size, negative);
                        }
                        for (int k = 0; k < layer1Size; k++) {
                            syn0[k + l1] = syn0[k + l1] + neu1e[k];
                        }
                    }
                }

                sentencePosition++;
                if (sentencePosition >= sentenceLength) {
                    sentenceLength = 0;
                }
            }
        }

        /** Prints progress (or invokes the configured callback) and recomputes the learning rate. */
        private void reportProgressAndDecayAlpha(float iter, long remainingIterations) {
            long now = System.currentTimeMillis();
            float progress = wordCountActual / ((iter * this.trainWords) + 1.0f);
            long elapsed = (now - this.timeStart) + 1;
            if (this.config.getCallback() == null) {
                // The leading %c with 13 emits a carriage return so the line is overwritten in place.
                System.err.printf("%cAlpha: %f  iter: %d  Progress: %.2f%%  Words/thread/sec: %.2fk", 13, this.alpha, remainingIterations, progress * 100.0f, wordCountActual / ((float) elapsed));
                String humanTime = Utility.humanTime((((float) elapsed) / progress) * (1.0f - progress));
                if (humanTime.length() > 0) {
                    System.err.printf("  ETD: %s", humanTime);
                }
                System.err.flush();
            } else {
                this.config.getCallback().training(this.alpha, progress * 100.0f);
            }
            this.alpha = this.startingAlpha * (1.0f - (wordCountActual / ((iter * this.trainWords) + 1.0f)));
            if (this.alpha < this.startingAlpha * 1.0E-4d) {
                this.alpha = this.startingAlpha * 1.0E-4f;
            }
        }

        /**
         * Hierarchical-softmax update for {@code word}: walks the word's code path, adds the input
         * gradient to {@code neu1e} and updates {@code syn1}. The input vector is
         * {@code input[inputOffset .. inputOffset + layer1Size)}.
         */
        private void applyHierarchicalSoftmax(int word, double[] input, int inputOffset, double[] neu1e, int layer1Size) {
            final double[] syn1 = Word2VecTraining.syn1;
            final VocabWord entry = this.vocab[word];
            for (int d = 0; d < entry.codelen; d++) {
                final int l2 = entry.point[d] * layer1Size;
                double f = 0.0d;
                for (int k = 0; k < layer1Size; k++) {
                    f += input[k + inputOffset] * syn1[k + l2];
                }
                // Written as a positive range test so that NaN is skipped as well.
                if (f > -MAX_ACTIVATION && f < MAX_ACTIVATION) {
                    final double g = ((1 - entry.code[d]) - Word2VecTraining.expTable[(int) ((f + MAX_ACTIVATION) * EXP_INDEX_SCALE)]) * this.alpha;
                    for (int k = 0; k < layer1Size; k++) {
                        neu1e[k] = neu1e[k] + (g * syn1[k + l2]);
                    }
                    for (int k = 0; k < layer1Size; k++) {
                        syn1[k + l2] = syn1[k + l2] + (g * input[k + inputOffset]);
                    }
                }
            }
        }

        /**
         * Negative-sampling update: one positive sample ({@code word}, label 1) followed by
         * {@code negative} draws from the sampling table (label 0). A draw equal to {@code word}
         * is skipped. Adds the input gradient to {@code neu1e} and updates {@code syn1neg}.
         */
        private void applyNegativeSampling(int word, double[] input, int inputOffset, double[] neu1e, int layer1Size, int negative) {
            final double[] syn1neg = Word2VecTraining.syn1neg;
            for (int d = 0; d < negative + 1; d++) {
                final int target;
                final long label;
                if (d == 0) {
                    target = word;
                    label = 1;
                } else {
                    this.nextRandom = Word2VecTraining.nextRandom(this.nextRandom);
                    int sampled = this.table[Math.abs((int) ((this.nextRandom >> 16) % UNIGRAM_TABLE_SIZE))];
                    if (sampled == 0) {
                        sampled = Math.abs((int) ((this.nextRandom % (this.vocabSize - 1)) + 1));
                    }
                    if (sampled == word) {
                        continue;
                    }
                    target = sampled;
                    label = 0;
                }
                final int l2 = target * layer1Size;
                double f = 0.0d;
                for (int k = 0; k < layer1Size; k++) {
                    f += input[k + inputOffset] * syn1neg[k + l2];
                }
                final double g;
                if (f > MAX_ACTIVATION) {
                    g = ((float) (label - 1)) * this.alpha;
                } else if (f < -MAX_ACTIVATION) {
                    g = ((float) label) * this.alpha;
                } else {
                    g = (label - Word2VecTraining.expTable[(int) ((f + MAX_ACTIVATION) * EXP_INDEX_SCALE)]) * this.alpha;
                }
                for (int k = 0; k < layer1Size; k++) {
                    neu1e[k] = neu1e[k] + (g * syn1neg[k + l2]);
                }
                for (int k = 0; k < layer1Size; k++) {
                    syn1neg[k + l2] = syn1neg[k + l2] + (g * input[k + inputOffset]);
                }
            }
        }
    }

    /* loaded from: input_file:WEB-INF/lib/hanlp-portable-1.8.4.jar:com/hankcs/hanlp/mining/word2vec/Word2VecTraining$VocabWordComparator.class */
    /** Orders {@link VocabWord}s by their count field in descending order. */
    static class VocabWordComparator implements Comparator<VocabWord> {
        VocabWordComparator() {
        }

        /**
         * @return a negative value when {@code vocabWord} has the larger count, a positive value
         *         when {@code vocabWord2} has the larger count, zero when the counts are equal
         */
        @Override // java.util.Comparator
        public int compare(VocabWord vocabWord, VocabWord vocabWord2) {
            // Integer.compare avoids the overflow a plain subtraction can produce.
            return Integer.compare(vocabWord2.f9cn, vocabWord.f9cn);
        }
    }

    /**
     * Creates a trainer bound to the given configuration. Nothing is read or allocated until
     * {@link #trainModel()} is called.
     *
     * @param config training configuration; stored as-is
     */
    public Word2VecTraining(Config config) {
        this.config = config;
    }

    /**
     * @return the configuration instance that was passed to the constructor
     */
    public Config getConfig() {
        return config;
    }

    /**
     * Learns the vocabulary from the configured corpus, trains the model with
     * {@code config.getNumThreads()} worker threads, and writes the resulting vectors as text:
     * a header line {@code "<vocabSize> <layer1Size>"} followed by one line per word.
     *
     * <p>If no output file is configured, the method returns right after the vocabulary has been
     * learned and logged.
     *
     * <p>Numbers are formatted with {@link Locale#ROOT} so the decimal separator is always
     * {@code '.'} regardless of the JVM's default locale. If the calling thread is interrupted
     * while waiting for the workers, the wait continues and the interrupt flag is restored
     * afterwards.
     *
     * @throws IOException if the corpus or the output file cannot be opened, read or closed
     */
    public void trainModel() throws IOException {
        final int layer1Size = this.config.getLayer1Size();
        final TextFileCorpus textFileCorpus = new TextFileCorpus(this.config);
        Predefine.logger.info("learning vocabulary");
        textFileCorpus.learnVocab();
        Predefine.logger.info("sorting vocabulary");
        textFileCorpus.sortVocab();
        final int vocabSize = textFileCorpus.getVocabSize();
        final VocabWord[] vocab = textFileCorpus.getVocab();
        Predefine.logger.info("Vocab size: " + vocabSize);
        Predefine.logger.info("Words in train file: " + textFileCorpus.getTrainWords());
        if (this.config.getOutputFile() == null) {
            return;
        }

        initNet(textFileCorpus);
        if (this.config.getNegative() > 0) {
            initUnigramTable(textFileCorpus);
        }
        this.timeStart = System.currentTimeMillis();
        // The progress counter is static; reset it so a second training run in the same JVM
        // does not start with stale progress and a fully decayed learning rate.
        TrainModelThread.wordCountActual = 0;
        this.threadCount = this.config.getNumThreads();
        for (int i = 0; i < this.config.getNumThreads(); i++) {
            new TrainModelThread(this, new CacheCorpus(textFileCorpus), this.config, i).start();
        }
        textFileCorpus.shutdown();

        // Wait until every worker has signalled completion.
        boolean interrupted = false;
        synchronized (this) {
            while (this.threadCount > 0) {
                try {
                    wait();
                } catch (InterruptedException e) {
                    // Keep waiting (workers still write to the shared arrays) but remember it.
                    interrupted = true;
                }
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        System.err.println();
        Predefine.logger.info(String.format("finished training in %s", Utility.humanTime(System.currentTimeMillis() - this.timeStart)));

        // Only syn0 is written out; release the structures that are no longer needed.
        syn1 = null;
        this.table = null;

        FileOutputStream fileOutputStream = null;
        OutputStreamWriter outputStreamWriter = null;
        PrintWriter printWriter = null;
        try {
            fileOutputStream = new FileOutputStream(this.config.getOutputFile());
            outputStreamWriter = new OutputStreamWriter(fileOutputStream, ENCODING);
            printWriter = new PrintWriter(outputStreamWriter);
            Predefine.logger.info("now saving the word vectors to the file " + this.config.getOutputFile());
            printWriter.printf(Locale.ROOT, "%d %d\n", vocabSize, layer1Size);
            for (int w = 0; w < vocabSize; w++) {
                printWriter.print(vocab[w].word);
                final int rowStart = w * layer1Size;
                for (int k = 0; k < layer1Size; k++) {
                    printWriter.printf(Locale.ROOT, " %f", syn0[rowStart + k]);
                }
                printWriter.println();
            }
        } finally {
            try {
                textFileCorpus.close();
            } finally {
                // Close the writers even if closing the corpus fails.
                Utility.closeQuietly((Writer) printWriter);
                Utility.closeQuietly((Writer) outputStreamWriter);
                Utility.closeQuietly((OutputStream) fileOutputStream);
            }
        }
    }

    /**
     * Builds {@link #table}, an array of {@code TABLE_SIZE} vocabulary indices in which each word
     * occupies a share of slots proportional to its count raised to the power 0.75. Indices are
     * assigned in vocabulary order.
     *
     * <p>The index never advances past the last vocabulary entry. With an empty vocabulary the
     * table is allocated and left zero-filled.
     *
     * @param corpus source of the vocabulary array and its size
     */
    void initUnigramTable(Corpus corpus) {
        final int vocabSize = corpus.getVocabSize();
        final VocabWord[] vocab = corpus.getVocab();
        this.table = new int[TABLE_SIZE];
        if (vocabSize <= 0) {
            return;
        }

        // The running total is truncated to long after every addition (unchanged behaviour).
        long trainWordsPow = 0;
        for (int i = 0; i < vocabSize; i++) {
            trainWordsPow = (long) (trainWordsPow + Math.pow(vocab[i].f9cn, 0.75d));
        }

        final int lastIndex = vocabSize - 1;
        int wordIndex = 0;
        double cumulative = Math.pow(vocab[0].f9cn, 0.75d) / trainWordsPow;
        for (int slot = 0; slot < TABLE_SIZE; slot++) {
            this.table[slot] = wordIndex;
            // Advance only while another entry exists, so vocab[vocabSize] is never read.
            if (slot / (double) TABLE_SIZE > cumulative && wordIndex < lastIndex) {
                wordIndex++;
                cumulative += Math.pow(vocab[wordIndex].f9cn, 0.75d) / trainWordsPow;
            }
        }
    }

    /**
     * Allocates the static weight arrays and initialises them:
     * <ul>
     *   <li>{@code syn0} - always allocated; each entry is set to a pseudo-random value in
     *       [-0.5, 0.5) divided by {@code layer1Size}, from a sequence seeded with 1;</li>
     *   <li>{@code syn1} - allocated (all zeros) only when hierarchical softmax is enabled;</li>
     *   <li>{@code syn1neg} - allocated (all zeros) only when the negative count is positive.</li>
     * </ul>
     * Finally asks the corpus to build its binary tree.
     *
     * @param corpus source of the vocabulary size
     * @throws IllegalStateException if {@code vocabSize * layer1Size} does not fit in an array
     */
    void initNet(Corpus corpus) {
        final int layer1Size = this.config.getLayer1Size();
        final int vocabSize = corpus.getVocabSize();

        // Multiply in long: an int product could overflow and yield an undersized or negative length.
        final long weightCount = (long) vocabSize * layer1Size;
        if (weightCount > Integer.MAX_VALUE - 127) {
            throw new IllegalStateException("vocabSize (" + vocabSize + ") * layer1Size (" + layer1Size
                    + ") exceeds the maximum array length");
        }
        final int size = (int) weightCount;

        syn0 = posixMemAlign128(size);
        // Freshly allocated Java arrays are already zero-filled, so no explicit clearing is needed.
        if (this.config.useHierarchicalSoftmax()) {
            syn1 = posixMemAlign128(size);
        }
        if (this.config.getNegative() > 0) {
            syn1neg = posixMemAlign128(size);
        }

        long random = 1;
        for (int idx = 0; idx < size; idx++) {
            random = nextRandom(random);
            syn0[idx] = (((random & 65535) / 65536.0d) - 0.5d) / layer1Size;
        }
        corpus.createBinaryTree();
    }

    /**
     * Allocates a zero-filled {@code double} array whose length is {@code i} rounded up to the
     * next multiple of 128. A length that is already a multiple of 128 (including 0) is used
     * unchanged.
     *
     * @param i minimum number of elements required
     * @return a new array of the rounded-up length
     */
    static double[] posixMemAlign128(int i) {
        final int remainder = i % 128;
        if (remainder <= 0) {
            return new double[i];
        }
        return new double[i - remainder + 128];
    }

    /**
     * Advances a linear congruential sequence by one step: {@code j * 25214903917 + 11},
     * with ordinary 64-bit wrap-around on overflow.
     *
     * @param j current state
     * @return next state
     */
    static long nextRandom(long j) {
        final long multiplier = 25214903917L;
        final long increment = 11L;
        long next = j * multiplier;
        next += increment;
        return next;
    }

    static {
        for (int i = 0; i < 1000; i++) {
            expTable[i] = Math.exp((((i / 1000.0d) * 2.0d) - 1.0d) * 6.0d);
            expTable[i] = expTable[i] / (expTable[i] + 1.0d);
        }
    }
}
