package com.wps.ai.runner;

import android.content.Context;
import android.net.ParseException;
import cn.wps.shareplay.message.Message;
import com.fasterxml.jackson.core.util.MinimalPrettyPrinter;
import com.wps.ai.AiAgent;
import com.wps.ai.runner.RunnerFactory;
import com.wps.ai.runner.bean.classify.ClassifierBean;
import com.wps.ai.runner.bean.classify.PrimaryCategory;
import com.wps.ai.util.TFUtil;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.regex.Pattern;
import org.tensorflow.lite.a;

/* loaded from: classes15.dex */
public class VolunteerClassifierRunner extends BaseRunner<String, String> {
    public static final String CHAR2ID_FILE = "char2id.csv";
    private static final int DIM_BATCH_SIZE = 1;
    private static final int DIM_INPUT = 6000;
    private static final String LABEL = "label";
    public static final String MODEL_FILE = "textCNN.tflite";
    private static final int N_CLASSES = 2;
    private static final String SENIOR_HIGH = "senior_high";
    private static final String STOPWORD_TABLE = "stopwords.txt";

    /** Size in bytes of one int32 token id written into the model input buffer. */
    private static final int BYTES_PER_ID = Integer.SIZE / Byte.SIZE;
    /** Output column used as the baseline score (presumably {@link CLASSES#no}). */
    private static final int INDEX_NO = 0;
    /** Output column reported as the score (presumably {@link CLASSES#yes}). */
    private static final int INDEX_YES = 1;
    /** Charset used to decode the vocabulary and stopword files. */
    private static final String ASSET_CHARSET = "gbk";
    /** Number of files {@link #escortModel()} expects in the function directory. */
    private static final int EXPECTED_ASSET_COUNT = 3;
    /** Second constructor argument of the TFLite interpreter (presumably its thread count). */
    private static final int INTERPRETER_THREADS = 4;

    /** Stopword table shared by all instances; null until {@link #loadStopword()} succeeds. */
    private static HashMap<String, Integer> mStopword;
    private ClassifierBean mLabelBean;
    private String mModuleDir;
    private ByteBuffer mNetworkInput;
    private float[][] mNetworkOutput;
    private String mStopWord;
    private a mTextCNN;
    /** Character -> token id map; null until {@link #readCsv()} succeeds. */
    private HashMap<String, Integer> mVocabulary;

    /* loaded from: classes14.dex */
    public enum CLASSES {
        no,
        yes
    }

    /**
     * Text normalisation helpers applied before tokenisation.
     */
    /* loaded from: classes14.dex */
    public static class TextContentUtil {
        /** Tokens made only of digits and ASCII punctuation; replaced by {@code <PAD>}. */
        private static final Pattern NUMERIC_OR_PUNCT_TOKEN =
                Pattern.compile("^[0-9;_,:\\=\\(\\)\\[\\]\\{\\}\\.\\-\\+\\'\"]+$");
        /** Tokens containing at least one ASCII upper-case letter get lower-cased. */
        private static final Pattern HAS_ASCII_UPPER = Pattern.compile(".*[A-Z]+.*");
        private static final String PAD = "<PAD>";
        private static final String SEP = MinimalPrettyPrinter.DEFAULT_ROOT_VALUE_SEPARATOR;

        private TextContentUtil() {
        }

        /**
         * Normalises free text into a space-separated token string.
         *
         * <p>Slashes become separators. Numeric/punctuation-only tokens and stopwords are
         * replaced by {@code <PAD>}. Tokens containing ASCII upper-case letters are lower-cased.
         * Runs of consecutive {@code <PAD>} tokens are collapsed to one.
         *
         * @param str raw text, may be null
         * @return the normalised text, or null when {@code str} is null
         */
        public static String formatContent(String str) {
            if (str == null) {
                return null;
            }
            // Read the shared table once; tolerate it not being loaded yet instead of NPE-ing.
            HashMap<String, Integer> stopwords = VolunteerClassifierRunner.mStopword;
            StringBuilder sb = new StringBuilder();
            for (String token : str.trim().replace("/", SEP).split("\\s+")) {
                boolean isStopword = stopwords != null && stopwords.containsKey(token);
                if (NUMERIC_OR_PUNCT_TOKEN.matcher(token).matches() || isStopword) {
                    sb.append(PAD);
                } else if (HAS_ASCII_UPPER.matcher(token).matches()) {
                    // Locale.ROOT: the default locale (e.g. Turkish) would map 'I' to a dotless i.
                    sb.append(token.toLowerCase(Locale.ROOT));
                } else {
                    sb.append(token);
                }
                sb.append(SEP);
            }
            return sb.toString().trim()
                    .replaceAll("^(<PAD> )+", "<PAD> ")
                    .replaceAll("( <PAD>)+$", " <PAD>")
                    .replaceAll("( <PAD>)+", " <PAD>");
        }

        /**
         * Splits text into space-prefixed units.
         *
         * <p>Each non-ASCII-alphanumeric character becomes its own unit. An ASCII letter or
         * digit starts a unit that extends over following letters, digits, spaces, '-', '.'
         * and '\''.
         *
         * @param str text to split, may be null
         * @return the formatted string, or null when {@code str} is null
         */
        public static String formatSourceString(String str) {
            if (str == null) {
                return null;
            }
            StringBuilder sb = new StringBuilder();
            int length = str.length();
            int i = 0;
            while (i < length) {
                char head = str.charAt(i++);
                sb.append(SEP).append(head);
                if (isAsciiAlphanumeric(head)) {
                    // Stop at the first non-continuation character. The previous version
                    // advanced the index only on a match and therefore looped forever here.
                    while (i < length && isWordContinuation(str.charAt(i))) {
                        sb.append(str.charAt(i++));
                    }
                }
            }
            return sb.toString();
        }

        /** True for [a-zA-Z0-9]. */
        private static boolean isAsciiAlphanumeric(char c) {
            return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9');
        }

        /** True for [a-zA-Z0-9 \-.'], i.e. characters that may continue an ASCII word. */
        private static boolean isWordContinuation(char c) {
            return isAsciiAlphanumeric(c) || c == ' ' || c == '-' || c == '.' || c == '\'';
        }
    }

    public VolunteerClassifierRunner(Context context) {
        super(context);
        this.mNetworkInput = null;
        this.mNetworkOutput = null;
    }

    /**
     * Converts the model output into the serialized {@link ClassifierBean} result.
     *
     * <p>One primary category is always added, carrying the column-1 score. Its category is
     * {@link #SENIOR_HIGH} when column 1 outscores column 0, otherwise "".
     *
     * @param fArr model output of shape [1][{@link #N_CLASSES}]
     * @return {@code ClassifierBean.toString()} of the built bean
     */
    private String argmaxLabel(float[][] fArr) {
        float[] scores = fArr[0];
        float positiveScore = scores[INDEX_YES];
        TFUtil.log(getLogPrefix() + " -> score " + positiveScore);
        ClassifierBean bean = new ClassifierBean();
        this.mLabelBean = bean;
        bean.setCode(Integer.parseInt(this.mState.toString()));
        PrimaryCategory category = new PrimaryCategory();
        // Set the label on the entry we add ourselves. The previous version re-fetched get(0),
        // which only refers to this entry if the bean's list starts out empty.
        category.setCategory(positiveScore > scores[INDEX_NO] ? SENIOR_HIGH : "");
        category.setFrom("content");
        category.setScore(positiveScore);
        bean.getPrimaryCategory().add(category);
        TFUtil.log(getLogPrefix() + bean.toString());
        return bean.toString();
    }

    /**
     * Returns the last file in {@code funcDir} whose name starts with {@code prefix}.
     *
     * @param logEach when true, logs every directory entry examined
     * @return the matching file, or null if none matches or the directory cannot be listed
     */
    private static File findFuncFile(File funcDir, String prefix, boolean logEach) {
        File[] files = funcDir.listFiles();
        if (files == null) {
            return null;
        }
        File match = null;
        for (File candidate : files) {
            if (logEach) {
                TFUtil.log(" RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY path " + candidate);
            }
            if (candidate.getName().startsWith(prefix)) {
                match = candidate;
            }
        }
        return match;
    }

    /**
     * Memory-maps the TFLite model file from the function directory.
     *
     * @throws IOException if the model file is missing or cannot be mapped
     */
    private MappedByteBuffer loadModelFile(Context context) throws IOException {
        File funcPath = RunnerEnv.getFuncPath(context, RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY);
        TFUtil.log(" RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY path " + funcPath.toString());
        File file = findFuncFile(funcPath, MODEL_FILE, true);
        if (file == null) {
            String msg = " RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY local model invalid or not downloaded";
            TFUtil.log(msg);
            throw new IOException(msg);
        }
        // The mapping stays valid after the channel is closed, so the stream need not outlive this.
        try (FileInputStream in = new FileInputStream(file); FileChannel channel = in.getChannel()) {
            return channel.map(FileChannel.MapMode.READ_ONLY, 0L, channel.size());
        }
    }

    /**
     * Loads the shared stopword table once, with one entry per line.
     *
     * <p>The static field is only published after the whole file was read, so a failed load
     * is retried on the next call.
     *
     * @throws IOException if the stopword file is missing or unreadable
     */
    private void loadStopword() throws ParseException, IOException, IllegalArgumentException {
        if (mStopword != null) {
            return;
        }
        File file = findFuncFile(
                RunnerEnv.getFuncPath(AiAgent.getContext(), RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY),
                STOPWORD_TABLE, false);
        if (file == null) {
            TFUtil.log("RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY local stopword_file invalid or not downloaded");
            throw new IOException("RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY local stopword_file invalid or not downloaded");
        }
        HashMap<String, Integer> stopwords = new HashMap<>(900);
        try (FileInputStream in = new FileInputStream(file);
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, ASSET_CHARSET))) {
            String line;
            while ((line = reader.readLine()) != null) {
                stopwords.put(line, 1);
            }
        }
        mStopword = stopwords;
    }

    /**
     * Encodes {@code str} into {@link #mNetworkInput} as {@link #DIM_INPUT} int ids.
     *
     * <p>Each UTF-16 char is looked up in the vocabulary; unknown chars map to 0. The text is
     * truncated to {@link #DIM_INPUT} chars and right-padded with 0.
     *
     * @param str normalised text; null is treated as empty
     */
    private void preProcess(String str) {
        String text = str == null ? "" : str;
        int processLength = Math.min(text.length(), DIM_INPUT);
        TFUtil.log(getLogPrefix() + " original length: " + text.length() + ", process length: " + processLength);
        ByteBuffer input = this.mNetworkInput;
        // Reuse the buffer from the start; without this the second call overflowed it.
        input.rewind();
        for (int i = 0; i < processLength; i++) {
            Integer id = this.mVocabulary.get(String.valueOf(text.charAt(i)));
            input.putInt(id != null ? id : 0);
        }
        for (int i = processLength; i < DIM_INPUT; i++) {
            input.putInt(0);
        }
        // Leave the buffer positioned at the start for readers that honour the position.
        input.rewind();
    }

    /**
     * Loads the character -> id vocabulary once.
     *
     * <p>Each line holds at least two fields split on {@code Message.SEPARATE}. The field is
     * only published after the whole file was read, so a failed load is retried.
     *
     * @throws IOException if the vocabulary file is missing or unreadable
     * @throws IllegalArgumentException if a line has fewer than two fields or a non-integer id
     */
    private void readCsv() throws ParseException, IOException, IllegalArgumentException {
        if (this.mVocabulary != null) {
            return;
        }
        File file = findFuncFile(
                RunnerEnv.getFuncPath(AiAgent.getContext(), RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY),
                CHAR2ID_FILE, false);
        if (file == null) {
            TFUtil.log("RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY local char2id invalid or not downloaded");
            throw new IOException("RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY local char2id invalid or not downloaded");
        }
        HashMap<String, Integer> vocabulary = new HashMap<>(33912);
        try (FileInputStream in = new FileInputStream(file);
             BufferedReader reader = new BufferedReader(new InputStreamReader(in, ASSET_CHARSET))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(Message.SEPARATE);
                if (fields.length < 2) {
                    throw new IllegalArgumentException(file + " file is illegal format!");
                }
                vocabulary.put(fields[0], Integer.valueOf(fields[1]));
            }
        }
        this.mVocabulary = vocabulary;
    }

    /** Releases the interpreter, if one was created. */
    @Override // com.wps.ai.runner.BaseRunner, com.wps.ai.runner.Runner
    public void close() {
        a aVar = this.mTextCNN;
        if (aVar != null) {
            aVar.close();
            this.mTextCNN = null;
        }
    }

    /**
     * Reports whether the function directory exists and holds exactly
     * {@link #EXPECTED_ASSET_COUNT} entries.
     */
    @Override // com.wps.ai.runner.BaseRunner
    public boolean escortModel() {
        File funcPath = RunnerEnv.getFuncPath(getContext(), RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY);
        if (!funcPath.exists()) {
            return false;
        }
        File[] files = funcPath.listFiles();
        return files != null && files.length == EXPECTED_ASSET_COUNT;
    }

    @Override // com.wps.ai.runner.BaseRunner
    public RunnerFactory.AiFunc getAiFunc() {
        return RunnerFactory.AiFunc.VOLTUNTEER_CLASSIFY;
    }

    /**
     * Classifies {@code str}.
     *
     * @return the serialized classification, or null if the model or vocabulary is not loaded
     */
    @Override // com.wps.ai.runner.BaseRunner
    public String internalProcess(String str) {
        if (this.mTextCNN == null || this.mVocabulary == null) {
            return null;
        }
        preProcess(TextContentUtil.formatContent(str));
        this.mTextCNN.b(this.mNetworkInput, this.mNetworkOutput);
        return argmaxLabel(this.mNetworkOutput);
    }

    /**
     * Allocates the I/O buffers and loads the vocabulary, stopwords and model.
     *
     * <p>Failures are logged, not thrown. {@link #internalProcess(String)} returns null until
     * a load succeeds.
     */
    @Override // com.wps.ai.runner.BaseRunner
    public void loadModel() {
        this.mNetworkInput = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * DIM_INPUT * BYTES_PER_ID)
                .order(ByteOrder.nativeOrder());
        this.mNetworkOutput = new float[DIM_BATCH_SIZE][N_CLASSES];
        try {
            readCsv();
            loadStopword();
            if (this.mTextCNN == null) {
                this.mTextCNN = new a(loadModelFile(AiAgent.getContext()), INTERPRETER_THREADS);
            }
            TFUtil.log("VolunteerClassifier: model successfully loaded");
        } catch (Exception e) {
            // Log the exception itself so null-message exceptions (e.g. NPE) stay identifiable.
            TFUtil.e("VolunteerClassifier failed loading model:" + e);
        }
    }
}
