package com.wps.ai.runner;

import android.content.Context;
import android.net.ParseException;
import cn.wps.shareplay.message.Message;
import com.google.gson.Gson;
import com.igexin.sdk.GTIntentService;
import com.wps.ai.AiAgent;
import com.wps.ai.runner.RunnerFactory;
import com.wps.ai.runner.bean.classify.SecondaryCategory;
import com.wps.ai.util.TFUtil;
import defpackage.hnv;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import java.util.regex.Pattern;
import org.json.JSONException;
import org.json.JSONObject;

/* loaded from: classes3.dex */
public class NovelClassifierRunner extends BaseRunner<String, String> {
    public static final String CHAR2ID_FILE = "novel_textcnn_vocabulary.csv";
    private static final int DIM_BATCH_SIZE = 1;
    /** Number of token ids fed to the network per document (longer input is truncated, shorter is zero-padded). */
    private static final int DIM_INPUT = 30000;
    private static final String LABEL = "label";
    public static final String LABEL_FILE = "labels.json";
    public static final String MODEL_FILE = "novel_textCNN.tflite";
    private static final String SENIOR_HIGH = "senior_high";

    /** Class count used when no (non-empty) label map is available. */
    private static final int DEFAULT_N_CLASSES = 7;
    /** Each token id is written with {@link ByteBuffer#putInt}, i.e. 4 bytes. */
    private static final int BYTES_PER_TOKEN = 4;
    /** Initial capacity hint for the vocabulary map. */
    private static final int VOCABULARY_INITIAL_CAPACITY = 33912;
    /** File count {@link #escortModel()} expects in the function directory (presumably model, vocabulary, labels — TODO confirm). */
    private static final int EXPECTED_FUNC_FILE_COUNT = 3;
    /** Category reported when the predicted index cannot be mapped to a label. */
    private static final String FALLBACK_LABEL = "others";

    /** Index (as a string key) -> category name, parsed from {@link #LABEL_FILE}; may be null if the file is absent. */
    private JSONObject mLabel;
    /** Direct, native-order buffer holding {@link #DIM_INPUT} int token ids. */
    private ByteBuffer mNetworkInput;
    /** Model output of shape [DIM_BATCH_SIZE][mNumClasses]. */
    private float[][] mNetworkOutput;
    /** Number of output classes; per instance so runners cannot resize each other's output. */
    private int mNumClasses = DEFAULT_N_CLASSES;
    /** Obfuscated model wrapper, built from the mapped model file (presumably a TFLite interpreter — TODO confirm). */
    private hnv mTextCNN;
    /** Token -> id map read from {@link #CHAR2ID_FILE}; null until fully loaded. */
    private HashMap<String, Integer> mVocabulary;

    /** Normalizes raw text into the space-separated token form looked up in the vocabulary. */
    public static class TextContentUtil {
        /** Tokens made only of digits and ASCII punctuation; each becomes {@code <PAD>}. */
        private static final Pattern NUMERIC_OR_PUNCT_TOKEN =
                Pattern.compile("^[0-9;_,:\\=\\(\\)\\[\\]\\{\\}\\.\\-\\+\\'\"]+$");
        private static final Pattern WHITESPACE = Pattern.compile("\\s+");
        private static final Pattern HAS_ASCII_UPPER = Pattern.compile(".*[A-Z]+.*");
        private static final Pattern LEADING_PADS = Pattern.compile("^(<PAD> )+");
        private static final Pattern TRAILING_PADS = Pattern.compile("( <PAD>)+$");
        private static final Pattern REPEATED_PADS = Pattern.compile("( <PAD>)+");

        private TextContentUtil() {
        }

        /**
         * Formats text for the classifier.
         *
         * <p>Trims the input, turns '/' into spaces, and splits on whitespace. Each token made
         * only of digits/punctuation is replaced by {@code <PAD>}. Tokens containing an ASCII
         * uppercase letter are lowercased. Consecutive {@code <PAD>} tokens collapse into one.
         *
         * @param str raw text, may be null
         * @return the normalized, single-space separated text, or null if {@code str} is null
         */
        public static String formatContent(String str) {
            if (str == null) {
                return null;
            }
            String[] tokens = WHITESPACE.split(str.trim().replace("/", " "));
            StringBuilder sb = new StringBuilder();
            for (String token : tokens) {
                if (NUMERIC_OR_PUNCT_TOKEN.matcher(token).matches()) {
                    sb.append("<PAD>");
                } else if (HAS_ASCII_UPPER.matcher(token).matches()) {
                    // Locale.ROOT: default-locale lowercasing (e.g. Turkish "I" -> "ı") would miss vocabulary entries.
                    sb.append(token.toLowerCase(Locale.ROOT));
                } else {
                    sb.append(token);
                }
                sb.append(' ');
            }
            String joined = sb.toString().trim();
            joined = LEADING_PADS.matcher(joined).replaceAll("<PAD> ");
            joined = TRAILING_PADS.matcher(joined).replaceAll(" <PAD>");
            return REPEATED_PADS.matcher(joined).replaceAll(" <PAD>");
        }
    }

    public NovelClassifierRunner(Context context) {
        super(context);
        this.mNetworkInput = null;
        this.mNetworkOutput = null;
    }

    /**
     * Picks the highest-scoring class of the first batch row and serializes it as a
     * {@link SecondaryCategory} JSON string (from = "content").
     *
     * <p>Falls back to {@link #FALLBACK_LABEL} when the label map is missing or has no entry
     * for the winning index.
     */
    private String argmaxLabel(float[][] fArr) {
        float[] scores = fArr[0];
        if (scores.length > 1) {
            TFUtil.log(getLogPrefix() + "  score " + scores[1]);
        }
        // Seed with the first element so all-negative scores still produce a valid index.
        int bestIndex = -1;
        float bestScore = -1.0f;
        for (int i = 0; i < scores.length; i++) {
            if (bestIndex == -1 || scores[i] > bestScore) {
                bestIndex = i;
                bestScore = scores[i];
            }
        }
        String category;
        if (this.mLabel == null) {
            TFUtil.e(getLogPrefix() + " label map not loaded");
            category = FALLBACK_LABEL;
        } else {
            try {
                category = this.mLabel.getString(String.valueOf(bestIndex));
            } catch (JSONException e) {
                TFUtil.e(getLogPrefix() + e.getMessage());
                category = FALLBACK_LABEL;
            }
        }
        TFUtil.log(getLogPrefix() + " predict index: " + bestIndex);
        SecondaryCategory secondaryCategory = new SecondaryCategory();
        secondaryCategory.setCategory(category);
        secondaryCategory.setScore(bestScore);
        secondaryCategory.setFrom("content");
        return new Gson().toJson(secondaryCategory);
    }

    /**
     * Sizes the class count from the label map (if non-empty) and allocates the input and
     * output buffers.
     */
    private void initInputOutput() {
        if (this.mLabel != null && this.mLabel.length() > 0) {
            this.mNumClasses = this.mLabel.length();
        }
        ByteBuffer input = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * DIM_INPUT * BYTES_PER_TOKEN);
        input.order(ByteOrder.nativeOrder());
        this.mNetworkInput = input;
        this.mNetworkOutput = new float[DIM_BATCH_SIZE][this.mNumClasses];
    }

    /**
     * Loads the index -> category map from {@link #LABEL_FILE} under the model run dir, if present.
     * A malformed file is logged and leaves {@link #mLabel} unchanged.
     */
    private void loadLabel() throws IOException {
        File labelFile = new File(new File(TFUtil.getModelRunDir(getContext()), RunnerFactory.AiFunc.NOVEL_CLASSIFY.toString()), LABEL_FILE);
        if (!labelFile.exists()) {
            return;
        }
        try (FileInputStream in = new FileInputStream(labelFile)) {
            this.mLabel = new JSONObject(TFUtil.convertStreamToString(in));
        } catch (JSONException e) {
            TFUtil.e(getLogPrefix() + e.getMessage());
        }
    }

    /**
     * Finds a file in the NOVEL_CLASSIFY function directory whose name starts with
     * {@code prefix}. The last match in listing order wins.
     *
     * @return the matching file, or null if none exists or the directory cannot be listed
     */
    private File findFuncFile(Context context, String prefix, boolean logEntries) {
        File funcPath = RunnerEnv.getFuncPath(context, RunnerFactory.AiFunc.NOVEL_CLASSIFY);
        if (logEntries) {
            TFUtil.log(getLogPrefix() + " path " + funcPath.toString());
        }
        File[] entries = funcPath.listFiles();
        if (entries == null) {
            return null;
        }
        File match = null;
        for (File entry : entries) {
            if (logEntries) {
                TFUtil.log(getLogPrefix() + " path " + entry.toString());
            }
            if (entry.getName().startsWith(prefix)) {
                match = entry;
            }
        }
        return match;
    }

    /**
     * Memory-maps the model file read-only.
     *
     * @throws IOException if the model file is missing or cannot be mapped
     */
    private MappedByteBuffer loadModelFile(Context context) throws IOException {
        File modelFile = findFuncFile(context, MODEL_FILE, true);
        if (modelFile == null) {
            TFUtil.log(getLogPrefix() + " local model invalid or not downloaded");
            throw new IOException(getLogPrefix() + " local model invalid or not downloaded");
        }
        // A mapping stays valid after its channel is closed, so the stream can be released here.
        try (FileInputStream in = new FileInputStream(modelFile);
             FileChannel channel = in.getChannel()) {
            return channel.map(FileChannel.MapMode.READ_ONLY, 0L, channel.size());
        }
    }

    /**
     * Encodes space-separated tokens as {@link #DIM_INPUT} int ids in {@link #mNetworkInput}.
     * Unknown tokens map to id 0. Input is truncated or zero-padded to {@link #DIM_INPUT}.
     * The buffer is rewound afterwards, so it is reusable across calls.
     */
    private void preProcess(String str) {
        ByteBuffer input = this.mNetworkInput;
        // Reset position before writing, otherwise the second call overflows the buffer.
        input.clear();
        String[] tokens = str.split(" ");
        int used = Math.min(tokens.length, DIM_INPUT);
        for (int i = 0; i < used; i++) {
            Integer id = this.mVocabulary.get(tokens[i]);
            input.putInt(id != null ? id.intValue() : 0);
        }
        for (int i = used; i < DIM_INPUT; i++) {
            input.putInt(0);
        }
        input.rewind();
    }

    /**
     * Loads the token -> id vocabulary from the GBK-encoded {@link #CHAR2ID_FILE}, once.
     * The map is published only after the whole file has been read successfully, so a failed
     * load can be retried.
     *
     * @throws IOException if the file is missing or unreadable
     * @throws IllegalArgumentException if a line has fewer than two fields or a non-integer id
     */
    private void readCsv() throws ParseException, IOException, IllegalArgumentException {
        if (this.mVocabulary != null) {
            return;
        }
        File csvFile = findFuncFile(AiAgent.getContext(), CHAR2ID_FILE, false);
        if (csvFile == null) {
            TFUtil.log(getLogPrefix() + " local char2id invalid or not downloaded");
            throw new IOException(getLogPrefix() + " local char2id invalid or not downloaded");
        }
        HashMap<String, Integer> vocabulary = new HashMap<>(VOCABULARY_INITIAL_CAPACITY);
        try (FileInputStream in = new FileInputStream(csvFile);
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, "gbk"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(Message.SEPARATE);
                if (fields.length < 2) {
                    throw new IllegalArgumentException(csvFile + " file is illegal format!");
                }
                vocabulary.put(fields[0], Integer.valueOf(fields[1]));
            }
        }
        this.mVocabulary = vocabulary;
    }

    /** Releases the model wrapper, if loaded. */
    @Override
    public void close() {
        hnv model = this.mTextCNN;
        if (model != null) {
            model.close();
            this.mTextCNN = null;
        }
    }

    /** Returns true when the function directory exists and holds exactly the expected number of files. */
    @Override
    public boolean escortModel() {
        File funcPath = RunnerEnv.getFuncPath(getContext(), RunnerFactory.AiFunc.NOVEL_CLASSIFY);
        if (!funcPath.exists()) {
            return false;
        }
        File[] entries = funcPath.listFiles();
        return entries != null && entries.length == EXPECTED_FUNC_FILE_COUNT;
    }

    @Override
    public RunnerFactory.AiFunc getAiFunc() {
        return RunnerFactory.AiFunc.NOVEL_CLASSIFY;
    }

    /**
     * Classifies {@code str} and returns the serialized {@link SecondaryCategory}.
     *
     * @return the JSON result, or null when the input is null or the model/vocabulary/buffers
     *     are not loaded
     */
    @Override
    public String internalProcess(String str) {
        if (str == null || this.mTextCNN == null || this.mVocabulary == null
                || this.mNetworkInput == null || this.mNetworkOutput == null) {
            return null;
        }
        preProcess(TextContentUtil.formatContent(str));
        // Presumably runs inference, writing scores into mNetworkOutput — TODO confirm hnv.b semantics.
        this.mTextCNN.b(this.mNetworkInput, this.mNetworkOutput);
        return argmaxLabel(this.mNetworkOutput);
    }

    /**
     * Loads labels, vocabulary and (once) the model, then allocates I/O buffers.
     * Any failure is logged and leaves the runner unloaded, so {@link #internalProcess} returns null.
     */
    @Override
    public void loadModel() {
        try {
            loadLabel();
            readCsv();
            if (this.mTextCNN == null) {
                // Second argument is 4; presumably a thread count — TODO confirm hnv constructor.
                this.mTextCNN = new hnv(loadModelFile(AiAgent.getContext()), 4);
                TFUtil.log(getLogPrefix() + " model successfully loaded");
            }
            initInputOutput();
        } catch (Exception e) {
            TFUtil.e(getLogPrefix() + " failed loading model:" + e.getMessage());
        }
    }
}
