├── .idea
├── .name
├── .gitignore
├── compiler.xml
├── vcs.xml
├── misc.xml
└── gradle.xml
├── app
├── .gitignore
├── src
│ ├── main
│ │ ├── res
│ │ │ ├── values
│ │ │ │ ├── strings.xml
│ │ │ │ ├── colors.xml
│ │ │ │ └── themes.xml
│ │ │ ├── mipmap-hdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ ├── mipmap-mdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ ├── mipmap-xhdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ ├── mipmap-xxhdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ ├── mipmap-xxxhdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ ├── drawable
│ │ │ │ ├── rect_4dp_solid.xml
│ │ │ │ ├── rect_4dp.xml
│ │ │ │ ├── ic_baseline_send_100.xml
│ │ │ │ ├── ic_baseline_arrow_back_100.xml
│ │ │ │ ├── ic_baseline_cleaning_services_100.xml
│ │ │ │ └── ic_launcher_background.xml
│ │ │ ├── mipmap-anydpi-v26
│ │ │ │ ├── ic_launcher.xml
│ │ │ │ └── ic_launcher_round.xml
│ │ │ ├── layout
│ │ │ │ ├── activity_main.xml
│ │ │ │ ├── holder_talk_answer.xml
│ │ │ │ ├── holder_talk_question.xml
│ │ │ │ ├── fragment_main.xml
│ │ │ │ ├── fragment_write.xml
│ │ │ │ ├── fragment_answer.xml
│ │ │ │ └── include_header.xml
│ │ │ ├── xml
│ │ │ │ ├── backup_rules.xml
│ │ │ │ └── data_extraction_rules.xml
│ │ │ ├── values-night
│ │ │ │ └── themes.xml
│ │ │ └── drawable-v24
│ │ │ │ └── ic_launcher_foreground.xml
│ │ ├── java
│ │ │ └── com
│ │ │ │ └── litesnap
│ │ │ │ └── open
│ │ │ │ └── rwkv
│ │ │ │ ├── GptTokenizer.java
│ │ │ │ ├── Atts.java
│ │ │ │ ├── PathManager.java
│ │ │ │ ├── App.java
│ │ │ │ ├── MyRunnable.java
│ │ │ │ ├── HexUtils.java
│ │ │ │ ├── Vocab.java
│ │ │ │ ├── GptModel.java
│ │ │ │ ├── Talk.java
│ │ │ │ ├── PreferencesManager.java
│ │ │ │ ├── MainActivity.java
│ │ │ │ ├── Pair.java
│ │ │ │ ├── StringUtils.java
│ │ │ │ ├── FileUtils.java
│ │ │ │ ├── MyAdapter.java
│ │ │ │ ├── WorldTokenizerImp.java
│ │ │ │ ├── MainFragment.java
│ │ │ │ ├── PreferencesUtils.java
│ │ │ │ ├── SampleLogits.java
│ │ │ │ ├── GptTokenizerImp.java
│ │ │ │ ├── WriteFragment.java
│ │ │ │ ├── GPTByteUtils.java
│ │ │ │ ├── OnnxModelImp.java
│ │ │ │ └── TalkFragment.java
│ │ └── AndroidManifest.xml
│ ├── test
│ │ └── java
│ │ │ └── com
│ │ │ └── litesnap
│ │ │ └── open
│ │ │ └── rwkv
│ │ │ └── ExampleUnitTest.java
│ └── androidTest
│ │ └── java
│ │ └── com
│ │ └── litesnap
│ │ └── open
│ │ └── rwkv
│ │ └── ExampleInstrumentedTest.java
├── proguard-rules.pro
└── build.gradle
├── 5.jpg
├── gradle
└── wrapper
│ ├── gradle-wrapper.jar
│ └── gradle-wrapper.properties
├── .gitignore
├── settings.gradle
├── gradle.properties
├── README.md
├── gradlew.bat
└── gradlew
/.idea/.name:
--------------------------------------------------------------------------------
1 | RWKV-Android
--------------------------------------------------------------------------------
/app/.gitignore:
--------------------------------------------------------------------------------
1 | /build
--------------------------------------------------------------------------------
/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/5.jpg
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/app/src/main/res/values/strings.xml:
--------------------------------------------------------------------------------
1 |
2 | RWKV-Android
3 |
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-hdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-hdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-mdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-mdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-xhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZTMIDGO/RWKV-Android/HEAD/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/.idea/compiler.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/rect_4dp_solid.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/rect_4dp.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Thu Jun 08 12:55:17 CST 2023
2 | distributionBase=GRADLE_USER_HOME
3 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
4 | distributionPath=wrapper/dists
5 | zipStorePath=wrapper/dists
6 | zipStoreBase=GRADLE_USER_HOME
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .gradle
3 | /local.properties
4 | /.idea/caches
5 | /.idea/libraries
6 | /.idea/modules.xml
7 | /.idea/workspace.xml
8 | /.idea/navEditor.xml
9 | /.idea/assetWizardSettings.xml
10 | .DS_Store
11 | /build
12 | /captures
13 | .externalNativeBuild
14 | .cxx
15 | local.properties
16 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/GptTokenizer.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import java.util.List;
4 |
/**
 * Converts text to a list of tokens and back again.
 *
 * <p>Created by ZTMIDGO 2022/9/9
 */
public interface GptTokenizer {

    /**
     * Splits {@code text} into tokens.
     *
     * @param text input text to tokenize
     * @return the token list (element type not visible here — presumably integer token ids)
     */
    List encode(String text);

    /**
     * Reassembles text from a token list previously produced by {@link #encode(String)}.
     *
     * @param tokens tokens to decode
     * @return the decoded text
     */
    String decode(List tokens);
}
12 |
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/Atts.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
/**
 * Preference keys for the generation settings read through PreferencesManager.
 *
 * <p>Created by ZTMIDGO 2023/6/20
 */
public interface Atts {
    /** Key for the top-k sampling setting (stored as an int). */
    String TOP_K = "top_k";
    /** Key for the top-p sampling setting (stored as a float). */
    String TOP_P = "top_p";
    /** Key for the sampling temperature (stored as a float). */
    String TEMP = "temp";
    /** Key for the length setting (stored as an int; presumably max tokens to generate). */
    String LEN = "len";
    /** Key for the first penalty value (stored as a float). */
    String P1 = "p1";
    /** Key for the second penalty value (stored as a float). */
    String P2 = "p2";
}
14 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/PathManager.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.Context;
4 |
5 | /**
6 | * Created by ZTMIDGO 2023/6/8
7 | */
8 | public class PathManager {
9 | public static String getModelPath(Context context){
10 | return context.getFilesDir().getAbsolutePath() + "/model";
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/App.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.app.Application;
4 |
5 | /**
6 | * Created by ZTMIDGO 2023/6/20
7 | */
8 | public class App extends Application {
9 |
10 | @Override
11 | public void onCreate() {
12 | super.onCreate();
13 | PreferencesUtils.init(this);
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/MyRunnable.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
/**
 * A {@link Runnable} carrying a cooperative cancellation flag.
 *
 * <p>Subclasses are expected to poll {@link #isCancel()} from {@code run()} and stop early once it
 * returns {@code true}. A cancel request normally comes from a different thread than the one
 * executing {@code run()}. The flag is therefore {@code volatile}, so the write is guaranteed to
 * become visible to the worker thread; a plain field may be cached and never observed.
 */
public abstract class MyRunnable implements Runnable {
    private volatile boolean isCancel;

    /**
     * @return {@code true} once cancellation has been requested via {@link #setCancel(boolean)}
     */
    public boolean isCancel() {
        return isCancel;
    }

    /**
     * Requests (or clears) cancellation.
     *
     * @param cancel {@code true} to ask the running task to stop
     */
    public void setCancel(boolean cancel) {
        isCancel = cancel;
    }
}
14 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/ic_baseline_send_100.xml:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/settings.gradle:
--------------------------------------------------------------------------------
1 | pluginManagement {
2 | repositories {
3 | gradlePluginPortal()
4 | google()
5 | mavenCentral()
6 | }
7 | }
8 | dependencyResolutionManagement {
9 | repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
10 | repositories {
11 | google()
12 | mavenCentral()
13 | }
14 | }
15 | rootProject.name = "RWKV-Android"
16 | include ':app'
17 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/ic_baseline_arrow_back_100.xml:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/app/src/main/res/values/colors.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | #FFBB86FC
4 | #FF6200EE
5 | #FF3700B3
6 | #FF03DAC5
7 | #FF018786
8 | #FF000000
9 | #FFFFFFFF
10 |
--------------------------------------------------------------------------------
/app/src/main/res/layout/activity_main.xml:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/app/src/test/java/com/litesnap/open/rwkv/ExampleUnitTest.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import org.junit.Test;
4 |
5 | import static org.junit.Assert.*;
6 |
7 | /**
8 | * Example local unit test, which will execute on the development machine (host).
9 | *
10 | * @see Testing documentation
11 | */
12 | public class ExampleUnitTest {
13 | @Test
14 | public void addition_isCorrect() {
15 | assertEquals(4, 2 + 2);
16 | }
17 | }
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/HexUtils.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
/**
 * Hex-encoding helpers.
 *
 * <p>Created by ZTMIDGO 2023/7/21
 */
public class HexUtils {

    /**
     * Concatenates the lowercase hexadecimal form of every char in {@code chars}.
     *
     * <p>Each char's hex string is left-padded with a single '0' when its length is odd, so each
     * char contributes either 2 or 4 digits (e.g. 'A' to "41", U+0100 to "0100").
     *
     * <p>NOTE(review): because the per-char width varies, the output cannot be split back into
     * chars unambiguously; confirm callers only use it as a one-way key.
     *
     * @param chars characters to encode (must not be null)
     * @return the concatenated hex digits; empty for an empty array
     */
    public static String charsToHex(char[] chars){
        StringBuilder out = new StringBuilder();
        for (char c : chars) {
            String digits = Integer.toHexString(c);
            if ((digits.length() & 1) == 1) {
                out.append('0');
            }
            out.append(digits);
        }
        return out.toString();
    }
}
17 |
--------------------------------------------------------------------------------
/app/src/main/res/xml/backup_rules.xml:
--------------------------------------------------------------------------------
1 |
8 |
9 |
13 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/Vocab.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import java.util.Map;
4 |
/**
 * Container for a tokenizer vocabulary: a wrapper object whose {@code model} holds the token map
 * and merge rules.
 *
 * <p>There are no setters, so the fields are presumably populated by reflection-based
 * deserialisation (Gson is a dependency of this app). TODO confirm against the loader.
 *
 * <p>Created by ZTMIDGO 2023/6/8
 */
public class Vocab {
    private Inner model;

    /**
     * @return the nested vocabulary payload, or {@code null} if it was never populated
     */
    public Inner getModel() {
        return model;
    }

    /**
     * Vocabulary payload.
     *
     * <p>Declared {@code static}: it never touches the enclosing {@code Vocab}, so there is no
     * reason to capture a hidden outer-instance reference. It can also now be instantiated on its
     * own, which deserialisers need for nested types.
     */
    public static class Inner {
        private Map vocab;
        private String[] merges;

        /**
         * @return the token map (key/value types not visible here — presumably token string to id)
         */
        public Map getVocab() {
            return vocab;
        }

        /**
         * @return the merge rules array, or {@code null} if absent
         */
        public String[] getMerges() {
            return merges;
        }
    }
}
28 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/ic_baseline_cleaning_services_100.xml:
--------------------------------------------------------------------------------
1 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/app/src/main/res/xml/data_extraction_rules.xml:
--------------------------------------------------------------------------------
1 |
6 |
7 |
8 |
12 |
13 |
19 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/GptModel.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import java.util.List;
4 |
/**
 * Abstraction over the text-generation model driven by the UI.
 *
 * <p>Created by ZTMIDGO 2022/9/9
 */
public interface GptModel {

    /**
     * Receives generation progress.
     */
    interface Callback{
        /**
         * @param token    the token just produced
         * @param index    position of this token in the current run
         * @param maxCount the {@code maxCount} requested for the run
         * @param isEnd    whether this is the final callback of the run
         */
        void callback(int token, int index, int maxCount, boolean isEnd);
    }

    /** Starts generating up to {@code maxCount} tokens from {@code arrays} (presumably prompt token ids). */
    void generate(List arrays, int maxCount, Callback callback);

    /** Picks one index from candidate {@code indexes} weighted by {@code probs}; returns the chosen value. */
    int sample(List indexes, List probs);

    /** Configures sampling: temperature, top-p and top-k. */
    void setTop(float temp, float topp, int topk);

    /** Configures the two penalty values (see Atts.P1 / Atts.P2). */
    void setPenalty(float v1, float v2);

    /** @return whether a generation run is currently in progress */
    boolean isRunning();

    /** Requests that the current generation run stop. */
    void cancel();

    /** Resets model state held between runs (exact scope is implementation-defined). */
    void clean();

    /** Releases the underlying model resources. */
    void close();
}
21 |
--------------------------------------------------------------------------------
/.idea/gradle.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
18 |
19 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/Talk.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
/**
 * One chat entry shown in the conversation list: either a user question or a model answer.
 *
 * <p>Created by ZTMIDGO 2023/6/20
 */
public class Talk {
    /** Entry typed by the user. */
    public static final int TYPE_QUESTION = 0;
    /** Entry produced by the model. */
    public static final int TYPE_ANSWER = 1;

    private int type;
    private String text;

    /**
     * @param type one of {@link #TYPE_QUESTION} or {@link #TYPE_ANSWER}
     * @param text message body
     */
    public Talk(int type, String text) {
        this.text = text;
        this.type = type;
    }

    /** @return {@link #TYPE_QUESTION} or {@link #TYPE_ANSWER} */
    public int getType() {
        return type;
    }

    /** @return the message body */
    public String getText() {
        return text;
    }

    /** @param type new entry type */
    public void setType(int type) {
        this.type = type;
    }

    /** @param text new message body */
    public void setText(String text) {
        this.text = text;
    }
}
34 |
--------------------------------------------------------------------------------
/app/proguard-rules.pro:
--------------------------------------------------------------------------------
1 | # Add project specific ProGuard rules here.
2 | # You can control the set of applied configuration files using the
3 | # proguardFiles setting in build.gradle.
4 | #
5 | # For more details, see
6 | # http://developer.android.com/guide/developing/tools/proguard.html
7 |
8 | # If your project uses WebView with JS, uncomment the following
9 | # and specify the fully qualified class name to the JavaScript interface
10 | # class:
11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12 | # public *;
13 | #}
14 |
15 | # Uncomment this to preserve the line number information for
16 | # debugging stack traces.
17 | #-keepattributes SourceFile,LineNumberTable
18 |
19 | # If you keep the line number information, uncomment this to
20 | # hide the original source file name.
21 | #-renamesourcefileattribute SourceFile
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/PreferencesManager.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | /**
4 | * Created by ZTMIDGO 2023/6/20
5 | */
6 | public class PreferencesManager {
7 | public static int getTopK(){
8 | return PreferencesUtils.getInt(Atts.TOP_K, 0);
9 | }
10 |
11 | public static int getLen(){
12 | return PreferencesUtils.getInt(Atts.LEN, 512);
13 | }
14 |
15 | public static float getP1(){
16 | return PreferencesUtils.getFloat(Atts.P1, 0.7f);
17 | }
18 |
19 | public static float getP2(){
20 | return PreferencesUtils.getFloat(Atts.P2, 0.4f);
21 | }
22 |
23 | public static float getTemp(){
24 | return PreferencesUtils.getFloat(Atts.TEMP, 1.0f);
25 | }
26 |
27 | public static float getTopp(){
28 | return PreferencesUtils.getFloat(Atts.TOP_P, 0.1f);
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/app/src/main/res/values/themes.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
16 |
--------------------------------------------------------------------------------
/app/src/main/res/values-night/themes.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
16 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/MainActivity.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.app.ProgressDialog;
4 | import android.os.Bundle;
5 | import android.view.View;
6 | import android.widget.EditText;
7 | import android.widget.TextView;
8 |
9 | import androidx.annotation.Nullable;
10 | import androidx.appcompat.app.AppCompatActivity;
11 |
12 | import java.io.File;
13 | import java.util.concurrent.ExecutorService;
14 | import java.util.concurrent.Executors;
15 |
16 | import ai.onnxruntime.OrtSession;
17 |
18 | public class MainActivity extends AppCompatActivity {
19 | @Override
20 | protected void onCreate(@Nullable Bundle savedInstanceState) {
21 | super.onCreate(savedInstanceState);
22 | setContentView(R.layout.activity_main);
23 | getSupportFragmentManager().beginTransaction().replace(R.id.container, MainFragment.newInstance()).commit();
24 | }
25 | }
--------------------------------------------------------------------------------
/app/src/androidTest/java/com/litesnap/open/rwkv/ExampleInstrumentedTest.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.Context;
4 |
5 | import androidx.test.platform.app.InstrumentationRegistry;
6 | import androidx.test.ext.junit.runners.AndroidJUnit4;
7 |
8 | import org.junit.Test;
9 | import org.junit.runner.RunWith;
10 |
11 | import static org.junit.Assert.*;
12 |
13 | /**
14 | * Instrumented test, which will execute on an Android device.
15 | *
16 | * @see Testing documentation
17 | */
18 | @RunWith(AndroidJUnit4.class)
19 | public class ExampleInstrumentedTest {
20 | @Test
21 | public void useAppContext() {
22 | // Context of the app under test.
23 | Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
24 | assertEquals("com.litesnap.open.rwkv", appContext.getPackageName());
25 | }
26 | }
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/Pair.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import java.util.Objects;
4 |
/**
 * Minimal mutable 2-tuple with value-based equality.
 *
 * <p>The fields are public and mutable, so avoid mutating a Pair while it is used as a hash key.
 *
 * <p>Created by ZTMIDGO 2022/9/9
 *
 * @param <A> type of the first component
 * @param <B> type of the second component
 */
public class Pair<A, B> {
    public A first;
    public B second;

    public Pair(A first, B second) {
        this.first = first;
        this.second = second;
    }

    /**
     * Two pairs are equal when they have the same runtime class and equal components.
     * Null components are allowed and compare equal to each other.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Pair<?, ?> other = (Pair<?, ?>) o;
        // Objects.equals is null-safe, matching the null-tolerant Objects.hash in hashCode();
        // the previous first.equals(...) threw NullPointerException on null components.
        return Objects.equals(first, other.first) && Objects.equals(second, other.second);
    }

    @Override
    public int hashCode() {
        return Objects.hash(first, second);
    }

    /** @return {@code "(first, second)"} */
    @Override
    public String toString() {
        // Plain concatenation: passing the values to String.format as the pattern made any '%'
        // inside them be parsed as a format specifier (and usually throw).
        return "(" + first + ", " + second + ")";
    }
}
35 |
--------------------------------------------------------------------------------
/app/src/main/res/layout/holder_talk_answer.xml:
--------------------------------------------------------------------------------
1 |
2 |
6 |
24 |
--------------------------------------------------------------------------------
/app/src/main/res/layout/holder_talk_question.xml:
--------------------------------------------------------------------------------
1 |
2 |
6 |
24 |
--------------------------------------------------------------------------------
/app/src/main/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
16 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/StringUtils.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.text.TextUtils;
4 |
5 | import java.math.BigDecimal;
6 | import java.math.RoundingMode;
7 |
/**
 * Small string and number helpers.
 */
public class StringUtils {

    /**
     * Returns {@code true} if ANY of the given strings is {@code null} or has length zero.
     * Returns {@code false} when called with no arguments.
     *
     * @param strings strings to check (the array itself must not be null)
     */
    public static boolean isEmpty(String...strings){
        for (String string : strings){
            // Same test as android.text.TextUtils.isEmpty, without the Android-only dependency,
            // so this helper also works in plain-JVM unit tests.
            if (string == null || string.isEmpty()){
                return true;
            }
        }
        return false;
    }

    /**
     * Splits {@code text} into one String per Unicode code point, so characters outside the BMP
     * (e.g. emoji stored as surrogate pairs) stay intact.
     *
     * @param text text to split (must not be null)
     * @return one element per code point
     */
    public static String[] toArrays(String text){
        int[] codePoints = text.codePoints().toArray();
        String[] words = new String[codePoints.length];
        for (int i = 0; i < codePoints.length; i++){
            words[i] = new String(Character.toChars(codePoints[i]));
        }
        return words;
    }

    /**
     * Rounds {@code value} to {@code places} decimal places using HALF_UP.
     *
     * <p>Uses {@link BigDecimal#valueOf(double)}, which starts from the shortest decimal string
     * of the double. The previous {@code new BigDecimal(double)} started from the exact binary
     * expansion, so e.g. 1.005 (really 1.00499999…) rounded to 1.00 instead of 1.01.
     * NaN and infinities are returned unchanged instead of throwing NumberFormatException.
     *
     * @param value  number to round
     * @param places number of decimal places, must be &gt;= 0
     * @return the rounded value
     * @throws IllegalArgumentException if {@code places} is negative
     */
    public static double round(double value, int places) {
        if (places < 0) {
            throw new IllegalArgumentException("places must be >= 0: " + places);
        }
        if (Double.isNaN(value) || Double.isInfinite(value)) {
            return value;
        }
        return BigDecimal.valueOf(value).setScale(places, RoundingMode.HALF_UP).doubleValue();
    }
}
38 |
--------------------------------------------------------------------------------
/gradle.properties:
--------------------------------------------------------------------------------
1 | # Project-wide Gradle settings.
2 | # IDE (e.g. Android Studio) users:
3 | # Gradle settings configured through the IDE *will override*
4 | # any settings specified in this file.
5 | # For more details on how to configure your build environment visit
6 | # http://www.gradle.org/docs/current/userguide/build_environment.html
7 | # Specifies the JVM arguments used for the daemon process.
8 | # The setting is particularly useful for tweaking memory settings.
9 | org.gradle.jvmargs=-Xmx8048m -Dfile.encoding=UTF-8
10 | # When configured, Gradle will run in incubating parallel mode.
11 | # This option should only be used with decoupled projects. More details, visit
12 | # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
13 | # org.gradle.parallel=true
14 | # AndroidX package structure to make it clearer which packages are bundled with the
15 | # Android operating system, and which are packaged with your app's APK
16 | # https://developer.android.com/topic/libraries/support-library/androidx-rn
17 | android.useAndroidX=true
18 | # Enables namespacing of each library's R class so that its R class includes only the
19 | # resources declared in the library itself and none from the library's dependencies,
20 | # thereby reducing the size of the R class for that library
21 | android.nonTransitiveRClass=true
--------------------------------------------------------------------------------
/app/src/main/res/layout/fragment_main.xml:
--------------------------------------------------------------------------------
1 |
2 |
14 |
20 |
27 |
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RWKV-Android
2 |
3 | 适用于 RWKV-World-0.4B 的模型,运行在CPU上进行推理
4 |
5 | 使用方式:下载模型 https://huggingface.co/TIEMING/rwkv-world-0.4B-onnx/tree/main ,把模型复制到Assets文件夹
6 |
7 | 下载APP安装体验(RWKV-0.4B-World-CHNtuned-INT8): https://huggingface.co/TIEMING/rwkv-world-0.4B-onnx/blob/main/world-CHNtuned-int8.apk
8 |
9 | 下载APP安装体验(RWKV-0.4B-Pile-INT8): https://drive.google.com/file/d/1Rrx8SlErId6TLCCL1SKhA3ba9o0rJCVz/view?usp=sharing
10 |
11 | Works with RWKV-World-0.4B models; inference runs on the CPU.
12 |
13 | Usage: download the model from https://huggingface.co/TIEMING/rwkv-world-0.4B-onnx/tree/main and copy it into the app's assets folder.
14 |
15 | Try the prebuilt APK (RWKV-0.4B-World-CHNtuned-INT8): https://huggingface.co/TIEMING/rwkv-world-0.4B-onnx/blob/main/world-CHNtuned-int8.apk
16 |
17 | Try the prebuilt APK (RWKV-0.4B-Pile-INT8): https://drive.google.com/file/d/1Rrx8SlErId6TLCCL1SKhA3ba9o0rJCVz/view?usp=sharing
18 |
19 | 
20 | 
21 | 
22 |
23 | 
24 |
--------------------------------------------------------------------------------
/app/build.gradle:
--------------------------------------------------------------------------------
1 | plugins {
2 | id 'com.android.application'
3 | }
4 |
5 | android {
6 | namespace 'com.litesnap.open.rwkv'
7 | compileSdk 32
8 |
9 | defaultConfig {
10 | applicationId "com.litesnap.open.rwkv"
11 | minSdk 24
12 | targetSdk 32
13 | versionCode 1
14 | versionName "1.0"
15 |
16 | testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
17 | }
18 |
19 | buildTypes {
20 | release {
21 | minifyEnabled false
22 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
23 | }
24 | }
25 |
26 | aaptOptions {
27 | noCompress 'tflite', 'txt', 'json', "ort"
28 | }
29 |
30 | compileOptions {
31 | sourceCompatibility JavaVersion.VERSION_1_8
32 | targetCompatibility JavaVersion.VERSION_1_8
33 | }
34 | }
35 |
36 | dependencies {
37 |
38 | implementation 'androidx.appcompat:appcompat:1.4.1'
39 | implementation 'com.google.android.material:material:1.5.0'
40 | implementation 'androidx.constraintlayout:constraintlayout:2.1.3'
41 | implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
42 | implementation 'com.google.code.gson:gson:2.9.1'
43 | testImplementation 'junit:junit:4.13.2'
44 | androidTestImplementation 'androidx.test.ext:junit:1.1.3'
45 | androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
46 | }
--------------------------------------------------------------------------------
/app/src/main/res/drawable-v24/ic_launcher_foreground.xml:
--------------------------------------------------------------------------------
1 |
7 |
8 |
9 |
15 |
18 |
21 |
22 |
23 |
24 |
30 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/FileUtils.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.res.AssetManager;
4 | import android.util.Log;
5 |
6 | import java.io.File;
7 | import java.io.FileOutputStream;
8 | import java.io.IOException;
9 | import java.io.InputStream;
10 | import java.io.OutputStream;
11 |
12 | /**
13 | * Created by ZTMIDGO 2023/4/21
14 | */
15 | public class FileUtils {
16 |
17 | public static void copyAssets(AssetManager assetManager, String path, File outPath) throws IOException {
18 | String[] assets = assetManager.list(path);
19 |
20 | if (assets != null) {
21 | if (assets.length == 0) {
22 | copyFile(assetManager, path, outPath);
23 | } else {
24 | File dir = new File(outPath, path);
25 | if (!dir.exists()) {
26 | if (!dir.mkdirs()) {
27 | Log.v("copyAssets", "Failed to create directory " + dir.getAbsolutePath());
28 | }
29 | }
30 |
31 | String[] var5 = assets;
32 | int var6 = assets.length;
33 |
34 | for(int var7 = 0; var7 < var6; ++var7) {
35 | String asset = var5[var7];
36 | copyAssets(assetManager, path + "/" + asset, outPath);
37 | }
38 | }
39 |
40 | }
41 | }
42 |
43 | private static void copyFile(AssetManager assetManager, String fileName, File outPath) throws IOException {
44 | Log.v("copyFile", "Copy " + fileName + " to " + outPath);
45 | InputStream in = assetManager.open(fileName);
46 | OutputStream out = new FileOutputStream(outPath + "/" + fileName);
47 | byte[] buffer = new byte[4000];
48 |
49 | int read;
50 | while((read = in.read(buffer)) != -1) {
51 | out.write(buffer, 0, read);
52 | }
53 |
54 | in.close();
55 | out.close();
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/app/src/main/res/layout/fragment_write.xml:
--------------------------------------------------------------------------------
1 |
2 |
9 |
10 |
19 |
29 |
30 |
38 |
46 |
47 |
--------------------------------------------------------------------------------
/app/src/main/res/layout/fragment_answer.xml:
--------------------------------------------------------------------------------
1 |
2 |
9 |
10 |
15 |
19 |
23 |
31 |
45 |
53 |
54 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/MyAdapter.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.Context;
4 | import android.view.LayoutInflater;
5 | import android.view.View;
6 | import android.view.ViewGroup;
7 | import android.widget.TextView;
8 |
9 | import androidx.annotation.NonNull;
10 | import androidx.recyclerview.widget.RecyclerView;
11 |
12 | import java.util.List;
13 |
14 | /**
15 | * Created by ZTMIDGO 2023/6/20
16 | */
17 | public class MyAdapter extends RecyclerView.Adapter {
18 | private Context context;
19 | private LayoutInflater inflater;
20 | private List dataList;
21 |
22 | public MyAdapter(Context context, LayoutInflater inflater, List dataList) {
23 | this.context = context;
24 | this.inflater = inflater;
25 | this.dataList = dataList;
26 | }
27 |
28 | @NonNull
29 | @Override
30 | public Holder onCreateViewHolder(@NonNull ViewGroup parent, int viewType) {
31 | switch (viewType){
32 | case Talk.TYPE_QUESTION:
33 | return new Holder(inflater.inflate(R.layout.holder_talk_question, parent, false));
34 | case Talk.TYPE_ANSWER:
35 | return new Holder(inflater.inflate(R.layout.holder_talk_answer, parent, false));
36 | default:
37 | return null;
38 | }
39 | }
40 |
41 | @Override
42 | public void onAttachedToRecyclerView(@NonNull RecyclerView recyclerView) {
43 | super.onAttachedToRecyclerView(recyclerView);
44 | recyclerView.setItemAnimator(null);
45 | }
46 |
47 | @Override
48 | public void onBindViewHolder(@NonNull Holder holder, int position) {
49 | holder.bind(dataList.get(position));
50 | }
51 |
52 | @Override
53 | public int getItemViewType(int position) {
54 | return dataList.get(position).getType();
55 | }
56 |
57 | @Override
58 | public int getItemCount() {
59 | return dataList.size();
60 | }
61 |
62 | public void add(Talk bean){
63 | dataList.add(bean);
64 | notifyItemInserted(dataList.size());
65 | }
66 |
67 | public void clean(){
68 | dataList.clear();
69 | notifyDataSetChanged();
70 | }
71 |
72 | public class Holder extends RecyclerView.ViewHolder {
73 | private TextView mTextView;
74 | public Holder(@NonNull View itemView) {
75 | super(itemView);
76 | mTextView = itemView.findViewById(R.id.text);
77 | }
78 |
79 | public void bind(Talk bean){
80 | mTextView.setText(bean.getText());
81 | }
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/gradlew.bat:
--------------------------------------------------------------------------------
1 | @rem
2 | @rem Copyright 2015 the original author or authors.
3 | @rem
4 | @rem Licensed under the Apache License, Version 2.0 (the "License");
5 | @rem you may not use this file except in compliance with the License.
6 | @rem You may obtain a copy of the License at
7 | @rem
8 | @rem https://www.apache.org/licenses/LICENSE-2.0
9 | @rem
10 | @rem Unless required by applicable law or agreed to in writing, software
11 | @rem distributed under the License is distributed on an "AS IS" BASIS,
12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | @rem See the License for the specific language governing permissions and
14 | @rem limitations under the License.
15 | @rem
16 |
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem
@rem Finds a Java runtime - JAVA_HOME first, otherwise java.exe on the PATH - and
@rem runs org.gradle.wrapper.GradleWrapperMain from gradle\wrapper\gradle-wrapper.jar,
@rem forwarding every argument given to this script. Command echo stays on when
@rem the DEBUG environment variable is set.
@rem
@rem NOTE: this appears to be the stock Gradle wrapper launcher; prefer
@rem regenerating it with the Gradle wrapper task over editing it by hand.

@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal

@rem APP_HOME is the directory that contains this script.
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%

@rem Resolve any "." and ".." in APP_HOME to make it shorter.
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi

@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"

@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome

@rem No JAVA_HOME: probe whether a java.exe on the PATH can be started.
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute

echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:findJavaFromJavaHome
@rem Strip any double quotes from JAVA_HOME before building the java.exe path.
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe

if exist "%JAVA_EXE%" goto execute

echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.

goto fail

:execute
@rem Setup the command line

@rem Classpath for the wrapper launcher class started below.
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar


@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*

:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd

:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
rem NOTE: both exit paths below report 1, not the exit code of the Java process.
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1

:mainEnd
if "%OS%"=="Windows_NT" endlocal

:omega
90 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/WorldTokenizerImp.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.Context;
4 | import android.util.Log;
5 |
6 | import com.google.gson.Gson;
7 |
8 | import java.io.BufferedInputStream;
9 | import java.io.BufferedReader;
10 | import java.io.FileInputStream;
11 | import java.io.FileReader;
12 | import java.io.InputStreamReader;
13 | import java.util.ArrayList;
14 | import java.util.Arrays;
15 | import java.util.HashMap;
16 | import java.util.HashSet;
17 | import java.util.LinkedHashSet;
18 | import java.util.List;
19 | import java.util.Map;
20 | import java.util.Set;
21 | import java.util.regex.MatchResult;
22 | import java.util.regex.Matcher;
23 | import java.util.regex.Pattern;
24 | import java.util.stream.Stream;
25 |
26 | /**
27 | * Created by ZTMIDGO 2022/9/15
28 | */
29 | public class WorldTokenizerImp implements GptTokenizer {
30 | private final String VOCAB_NAME = "vocab.json";
31 | private final Map encoder = new HashMap<>();
32 | private final Map decoder = new HashMap<>();
33 | private final Set tiesSet = new HashSet<>();
34 | private final Context context;
35 |
36 | public WorldTokenizerImp(Context context){
37 | this.context = context;
38 | fillDecoder();
39 | fillEncoder();
40 | }
41 |
42 | @Override
43 | public List encode(String text) {
44 | List result = new ArrayList<>();
45 | char[] chars = text.toCharArray();
46 | int position = 0;
47 | int start = 0;
48 |
49 | while (position <= chars.length){
50 |
51 | char[] copy = Arrays.copyOfRange(chars, start, position);
52 | if (!tiesSet.contains(HexUtils.charsToHex(copy)) || position == chars.length){
53 | while (position > start){
54 | String word = new String(Arrays.copyOfRange(chars, start, position));
55 | if (encoder.containsKey(word)) {
56 | result.add(encoder.get(word));
57 | start = position;
58 | break;
59 | }else {
60 | if (-- position <= start){
61 | start += 1;
62 | position = start;
63 | break;
64 | }
65 | }
66 | }
67 | }
68 |
69 | position ++;
70 | }
71 | return result;
72 | }
73 |
74 | @Override
75 | public String decode(List tokens) {
76 | StringBuilder sb = new StringBuilder();
77 | for (int i : tokens){
78 | if (decoder.containsKey(i)) sb.append(decoder.get(i));
79 | }
80 | return sb.toString();
81 | }
82 |
83 | private void addTies(String word){
84 | char[] chars = word.toCharArray();
85 |
86 | for (int i = 1; i <= chars.length; i++) {
87 | tiesSet.add(HexUtils.charsToHex(Arrays.copyOf(chars, i)));
88 | }
89 | }
90 |
91 | private void fillEncoder(){
92 | try {
93 | for (Map.Entry entry : decoder.entrySet()){
94 | encoder.put(entry.getValue(), entry.getKey());
95 | }
96 | }catch (Exception e){
97 | e.printStackTrace();
98 | }
99 | }
100 |
101 | private void fillDecoder(){
102 | try {
103 | String path = PathManager.getModelPath(context) + "/" + VOCAB_NAME;
104 | Map map = new HashMap<>();
105 | map.putAll(new Gson().fromJson(new FileReader(path), decoder.getClass()));
106 | for (Map.Entry entry : map.entrySet()){
107 | addTies(entry.getValue());
108 | decoder.put(Integer.parseInt(entry.getKey()), entry.getValue());
109 | }
110 | }catch (Exception e){
111 | e.printStackTrace();
112 | }
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/MainFragment.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.app.ProgressDialog;
4 | import android.os.Bundle;
5 | import android.os.Handler;
6 | import android.view.LayoutInflater;
7 | import android.view.View;
8 | import android.view.ViewGroup;
9 |
10 | import androidx.annotation.NonNull;
11 | import androidx.annotation.Nullable;
12 | import androidx.fragment.app.Fragment;
13 |
14 | import java.io.File;
15 | import java.util.concurrent.ExecutorService;
16 | import java.util.concurrent.Executors;
17 |
18 | /**
19 | * Created by ZTMIDGO 2023/6/20
20 | */
21 | public class MainFragment extends Fragment {
22 | private final ExecutorService exec = Executors.newCachedThreadPool();
23 |
24 | public static MainFragment newInstance() {
25 |
26 | Bundle args = new Bundle();
27 |
28 | MainFragment fragment = new MainFragment();
29 | fragment.setArguments(args);
30 | return fragment;
31 | }
32 |
33 | private View mCopyView;
34 | private View mWriteView;
35 | private View mAnswerView;
36 |
37 | private ProgressDialog dialog;
38 | private Handler uiHandler;
39 |
40 | private boolean isCopy = false;
41 |
42 | @Override
43 | public void onCreate(@Nullable Bundle savedInstanceState) {
44 | super.onCreate(savedInstanceState);
45 | uiHandler = new Handler();
46 | dialog = new ProgressDialog(getActivity());
47 | dialog.setCancelable(false);
48 | }
49 |
50 | @Override
51 | public void onDestroy() {
52 | super.onDestroy();
53 | exec.shutdownNow();
54 | }
55 |
56 | @Nullable
57 | @Override
58 | public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
59 | View view = inflater.inflate(R.layout.fragment_main, container, false);
60 | mCopyView = view.findViewById(R.id.copy);
61 | mWriteView = view.findViewById(R.id.write);
62 | mAnswerView = view.findViewById(R.id.answer);
63 |
64 | File file = new File(PathManager.getModelPath(getActivity()) + "/model.onnx");
65 | isCopy = file.exists();
66 | setEnable(isCopy);
67 |
68 | mCopyView.setOnClickListener(new View.OnClickListener() {
69 | @Override
70 | public void onClick(View v) {
71 | dialog.show();
72 | exec.execute(new Runnable() {
73 | @Override
74 | public void run() {
75 | try {
76 | FileUtils.copyAssets(getActivity().getAssets(), "model", getActivity().getFilesDir().getAbsoluteFile());
77 | isCopy = true;
78 | }catch (Exception e){
79 | isCopy = false;
80 | }finally {
81 | uiHandler.post(new Runnable() {
82 | @Override
83 | public void run() {
84 | setEnable(isCopy);
85 | dialog.dismiss();
86 | }
87 | });
88 | }
89 | }
90 | });
91 | }
92 | });
93 |
94 | mWriteView.setOnClickListener(new View.OnClickListener() {
95 | @Override
96 | public void onClick(View v) {
97 | getActivity().getSupportFragmentManager().beginTransaction().add(R.id.container, WriteFragment.newInstance()).addToBackStack(null).commit();
98 | }
99 | });
100 |
101 | mAnswerView.setOnClickListener(new View.OnClickListener() {
102 | @Override
103 | public void onClick(View v) {
104 | getActivity().getSupportFragmentManager().beginTransaction().add(R.id.container, TalkFragment.newInstance()).addToBackStack(null).commit();
105 | }
106 | });
107 | return view;
108 | }
109 |
110 | private void setEnable(boolean isEnable){
111 | mWriteView.setEnabled(isEnable);
112 | mAnswerView.setEnabled(isEnable);
113 | }
114 | }
115 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/PreferencesUtils.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.Context;
4 | import android.content.SharedPreferences;
5 |
6 | public class PreferencesUtils {
7 | private static Context context;
8 |
9 | public static void init(Context c){
10 | context = c;
11 | }
12 |
13 | public static String getString(String key, String value){
14 | if (context == null){
15 | return null;
16 | }
17 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
18 | return sp.getString(key, value);
19 | }
20 |
21 | public static int getInt(String key, int value){
22 | if (context == null){
23 | return value;
24 | }
25 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
26 | return sp.getInt(key, value);
27 | }
28 |
29 | public static long getLong(String key, long value){
30 | if (context == null){
31 | return value;
32 | }
33 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
34 | return sp.getLong(key, value);
35 | }
36 |
37 | public static float getFloat(String key, float value){
38 | if (context == null){
39 | return value;
40 | }
41 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
42 | return sp.getFloat(key, value);
43 | }
44 |
45 | public static boolean getBoolean(String key, boolean value){
46 | if (context == null){
47 | return value;
48 | }
49 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
50 | return sp.getBoolean(key, value);
51 | }
52 |
53 | public static void setProperty(String key, String value){
54 | if (context == null){
55 | return;
56 | }
57 |
58 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
59 | SharedPreferences.Editor editor=sp.edit();
60 | editor.putString(key, value);
61 | editor.commit();
62 | }
63 |
64 | public static void setProperty(String key, float value){
65 | if (context == null){
66 | return;
67 | }
68 |
69 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
70 | SharedPreferences.Editor editor=sp.edit();
71 | editor.putFloat(key, value);
72 | editor.commit();
73 | }
74 |
75 | public static void setProperty(String key, long value){
76 | if (context == null){
77 | return;
78 | }
79 |
80 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
81 | SharedPreferences.Editor editor=sp.edit();
82 | editor.putLong(key, value);
83 | editor.commit();
84 | }
85 |
86 | public static void setProperty(String key, int value){
87 | if (context == null){
88 | return;
89 | }
90 |
91 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
92 | SharedPreferences.Editor editor=sp.edit();
93 | editor.putInt(key, value);
94 | editor.commit();
95 | }
96 |
97 | public static void setProperty(String key, boolean value){
98 | if (context == null){
99 | return;
100 | }
101 |
102 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
103 | SharedPreferences.Editor editor=sp.edit();
104 | editor.putBoolean(key, value);
105 | editor.commit();
106 | }
107 |
108 | public static void removeProperty(String key){
109 | if (context == null){
110 | return;
111 | }
112 |
113 | SharedPreferences sp=context.getSharedPreferences(context.getPackageName(), Context.MODE_PRIVATE);
114 | SharedPreferences.Editor editor=sp.edit();
115 | editor.remove(key);
116 | editor.commit();
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/SampleLogits.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import java.util.Arrays;
4 | import java.util.Comparator;
5 | import java.util.Random;
6 |
/**
 * Token sampling utilities: softmax, nucleus (top-p) and top-k filtering, temperature, and a
 * weighted random draw. The small array helpers are named after the NumPy operations they
 * imitate ({@code argsort}, {@code cumsum}, {@code argmax}, {@code choice}).
 */
public class SampleLogits {
    /** Shared source of randomness for {@link #choice(int[], float[])}. */
    public static final Random RANDOM = new Random();

    /**
     * Samples an index into {@code logits}.
     *
     * <ol>
     *   <li>Softmax the logits into probabilities.</li>
     *   <li>Top-p: walk the probabilities in descending order and find the first position where
     *       the running sum exceeds {@code top_p}. Zero every probability below the one at that
     *       position.</li>
     *   <li>Top-k: if {@code 0 < top_k < logits.length}, keep only the {@code top_k} most likely
     *       entries.</li>
     *   <li>Temperature: if {@code temperature != 1}, raise each probability to
     *       {@code 1 / temperature}.</li>
     *   <li>Renormalise and draw with {@link #choice(int[], float[])}.</li>
     * </ol>
     *
     * @return the sampled index
     * @throws IllegalArgumentException if {@code logits} is empty
     */
    public static int sample(float[] logits, float temperature, float top_p, int top_k) {
        if (logits.length == 0) {
            throw new IllegalArgumentException("logits must not be empty");
        }
        float[] probs = softmax(logits);
        int[] sorted_ids = argsort(probs, true);        // ascending probability
        float[] sorted_probs = sit(probs, sorted_ids);  // descending probability
        float[] cumulative_probs = cumsum(sorted_probs);
        // Smallest probability that is still inside the nucleus.
        float cutoff = sorted_probs[argmax(cumulative_probs, top_p)];

        for (int i = 0; i < probs.length; i++) {
            if (probs[i] < cutoff) {
                probs[i] = 0;
            }
        }

        if (top_k < probs.length && top_k > 0) {
            // sorted_ids is ascending, so everything before its last top_k entries is dropped.
            for (int i = 0; i < sorted_ids.length - top_k; i++) {
                probs[sorted_ids[i]] = 0;
            }
        }

        // Temperature applies whether or not top-k filtering ran.
        if (temperature != 1) {
            // Divide by the peak first so the largest entry maps to 1 and cannot underflow to 0
            // at low temperatures; the common factor cancels in the normalisation below.
            float peak = 0f;
            for (float p : probs) {
                if (p > peak) {
                    peak = p;
                }
            }
            float scale = peak > 0f ? peak : 1f;
            double exponent = 1.0 / temperature;
            for (int i = 0; i < probs.length; i++) {
                probs[i] = (float) Math.pow(probs[i] / scale, exponent);
            }
        }

        float sum = sum(probs);
        for (int i = 0; i < probs.length; i++) {
            probs[i] = probs[i] / sum;
        }

        int[] indexs = new int[probs.length];
        for (int i = 0; i < indexs.length; i++) {
            indexs[i] = i;
        }
        return choice(indexs, probs);
    }

    /**
     * Numerically stable softmax: the maximum is subtracted before exponentiating, so large
     * logits do not overflow to infinity and yield NaN.
     */
    public static float[] softmax(float[] input) {
        float max = Float.NEGATIVE_INFINITY;
        for (float value : input) {
            if (value > max) {
                max = value;
            }
        }
        double[] exps = new double[input.length];
        double total = 0.0;
        for (int i = 0; i < input.length; i++) {
            exps[i] = Math.exp(input[i] - max);
            total += exps[i];
        }
        float[] output = new float[input.length];
        for (int i = 0; i < input.length; i++) {
            output[i] = (float) (exps[i] / total);
        }
        return output;
    }

    /**
     * Returns the indexes of {@code a} ordered by value, ascending or descending. The sort is
     * stable, so equal values keep their index order.
     */
    public static int[] argsort(final float[] a, final boolean ascending) {
        Integer[] indexes = new Integer[a.length];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = i;
        }
        Arrays.sort(indexes, new Comparator<Integer>() {
            @Override
            public int compare(final Integer i1, final Integer i2) {
                return (ascending ? 1 : -1) * Float.compare(a[i1], a[i2]);
            }
        });
        return asArray(indexes);
    }

    /** Unboxes numbers into an {@code int[]} (via {@link Number#intValue()}). */
    @SafeVarargs
    public static <T extends Number> int[] asArray(final T... a) {
        int[] b = new int[a.length];
        for (int i = 0; i < b.length; i++) {
            b[i] = a[i].intValue();
        }
        return b;
    }

    /**
     * Gathers {@code floats} by {@code ids} in reverse order:
     * {@code result[n - 1 - i] = floats[ids[i]]}. With ascending ids this yields descending values.
     */
    public static float[] sit(float[] floats, int[] ids) {
        float[] result = new float[floats.length];
        for (int i = 0; i < ids.length; i++) {
            result[(ids.length - 1) - i] = floats[ids[i]];
        }
        return result;
    }

    /** Running sum: {@code output[i] = input[0] + ... + input[i]}. */
    public static float[] cumsum(float[] input) {
        float[] output = new float[input.length];
        float cumulativeSum = 0.0f;
        for (int i = 0; i < input.length; i++) {
            cumulativeSum += input[i];
            output[i] = cumulativeSum;
        }
        return output;
    }

    /**
     * Returns the first index whose element is strictly greater than {@code value}, or 0 if
     * there is none (the same result as NumPy's {@code argmax(input > value)}). Every element,
     * including the last, is examined.
     */
    public static int argmax(float[] input, float value) {
        for (int i = 0; i < input.length; i++) {
            if (input[i] > value) {
                return i;
            }
        }
        return 0;
    }

    /** Sum of all elements. */
    public static float sum(float[] input) {
        float total = 0.0f;
        for (float value : input) {
            total += value;
        }
        return total;
    }

    /**
     * Draws one element of {@code a} with probability proportional to {@code p}, which is
     * expected to sum to about 1.
     *
     * <p>Entries with zero probability are never returned. If float rounding leaves the total
     * just below the random draw, the last entry with positive probability is returned. The
     * result is -1 only when no entry has positive probability.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static int choice(int[] a, float[] p) {
        if (a.length != p.length) {
            throw new IllegalArgumentException("a and p must have the same length");
        }

        float r = RANDOM.nextFloat(); // uniform in [0, 1)
        float cumulativeProbability = 0.0f;
        int lastPositive = -1;
        for (int i = 0; i < a.length; i++) {
            if (p[i] > 0) {
                lastPositive = i;
            }
            cumulativeProbability += p[i];
            // Strict '<' so a zero-probability entry cannot be picked, even when r == 0.
            if (r < cumulativeProbability) {
                return a[i];
            }
        }
        return lastPositive >= 0 ? a[lastPositive] : -1;
    }
}
126 |
--------------------------------------------------------------------------------
/app/src/main/res/drawable/ic_launcher_background.xml:
--------------------------------------------------------------------------------
1 |
2 |
7 |
10 |
15 |
20 |
25 |
30 |
35 |
40 |
45 |
50 |
55 |
60 |
65 |
70 |
75 |
80 |
85 |
90 |
95 |
100 |
105 |
110 |
115 |
120 |
125 |
130 |
135 |
140 |
145 |
150 |
155 |
160 |
165 |
170 |
171 |
--------------------------------------------------------------------------------
/gradlew:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | #
4 | # Copyright 2015 the original author or authors.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # https://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | #
18 |
19 | ##############################################################################
20 | ##
21 | ## Gradle start up script for UN*X
22 | ##
23 | ##############################################################################
24 |
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
# NOTE: this appears to be the stock Gradle wrapper launcher; prefer regenerating it with the
# Gradle wrapper task over editing it by hand.
PRG="$0"
# Need this for relative symlinks.
# Follow the chain of symlinks ending at this script, so that APP_HOME is the directory of the
# real file rather than of a link to it.
while [ -h "$PRG" ] ; do
    ls=`ls -ld "$PRG"`
    # Extract the link target from the "name -> target" form printed by ls -ld.
    link=`expr "$ls" : '.*-> \(.*\)$'`
    if expr "$link" : '/.*' > /dev/null; then
        # Absolute target: use it as is.
        PRG="$link"
    else
        # Relative target: resolve it against the directory of the current link.
        PRG=`dirname "$PRG"`"/$link"
    fi
done
# cd into the script directory to capture its physical path, then return to the caller's cwd.
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null

APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`

# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
# The embedded double quotes are intentional: the value is re-parsed by the final `eval set --`.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'

# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"

# Print all arguments as one message on stdout.
warn () {
    echo "$*"
}

# Print all arguments as one message surrounded by blank lines, then exit with status 1.
die () {
    echo
    echo "$*"
    echo
    exit 1
}

# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
  CYGWIN* )
    cygwin=true
    ;;
  Darwin* )
    darwin=true
    ;;
  MINGW* )
    msys=true
    ;;
  NONSTOP* )
    nonstop=true
    ;;
esac

# Classpath for the wrapper launcher class started at the end of this script.
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar


# Determine the Java command to use to start the JVM.
# JAVA_HOME wins when set (and must point at an executable java); otherwise `java` must be on PATH.
if [ -n "$JAVA_HOME" ] ; then
    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
        # IBM's JDK on AIX uses strange locations for the executables
        JAVACMD="$JAVA_HOME/jre/sh/java"
    else
        JAVACMD="$JAVA_HOME/bin/java"
    fi
    if [ ! -x "$JAVACMD" ] ; then
        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
    fi
else
    JAVACMD="java"
    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.

Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi

# Increase the maximum file descriptors if we can.
# Skipped on Cygwin, macOS and NonStop; failures here only print a warning.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
    MAX_FD_LIMIT=`ulimit -H -n`
    if [ $? -eq 0 ] ; then
        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
            MAX_FD="$MAX_FD_LIMIT"
        fi
        ulimit -n $MAX_FD
        if [ $? -ne 0 ] ; then
            warn "Could not set maximum file descriptor limit: $MAX_FD"
        fi
    else
        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
    fi
fi

# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi

# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`

    JAVACMD=`cygpath --unix "$JAVACMD"`

    # We build the pattern for arguments to be converted via cygpath
    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
    SEP=""
    for dir in $ROOTDIRSRAW ; do
        ROOTDIRS="$ROOTDIRS$SEP$dir"
        SEP="|"
    done
    OURCYGPATTERN="(^($ROOTDIRS))"
    # Add a user-defined pattern to the cygpath arguments
    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
    fi
    # Now convert the arguments - kludge to limit ourselves to /bin/sh
    # Non-option arguments that start with a top-level directory are converted with cygpath.
    # Every argument is stored in args0, args1, ... and re-assembled by the case below.
    i=0
    for arg in "$@" ; do
        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
        CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option

        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
        else
            eval `echo args$i`="\"$arg\""
        fi
        i=`expr $i + 1`
    done
    # NOTE(review): only 0-9 arguments are re-assembled; with 10 or more no branch matches and
    # the original, unconverted arguments are passed through unchanged.
    case $i in
        0) set -- ;;
        1) set -- "$args0" ;;
        2) set -- "$args0" "$args1" ;;
        3) set -- "$args0" "$args1" "$args2" ;;
        4) set -- "$args0" "$args1" "$args2" "$args3" ;;
        5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
        6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
        7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
        8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
        9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
    esac
fi

# Escape application args
# save prints each argument wrapped in single quotes, with any embedded ' rewritten as '\''.
# Each quoted argument is followed by a trailing backslash, so the result can be re-parsed
# safely by the eval below.
save () {
    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
    echo " "
}
APP_ARGS=`save "$@"`

# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"

# Replace this shell process with the JVM.
exec "$JAVACMD" "$@"
186 |
--------------------------------------------------------------------------------
/app/src/main/res/layout/include_header.xml:
--------------------------------------------------------------------------------
1 |
2 |
7 |
15 |
21 |
26 |
35 |
36 |
42 |
47 |
56 |
57 |
63 |
68 |
77 |
78 |
84 |
89 |
98 |
99 |
105 |
110 |
119 |
120 |
126 |
131 |
140 |
141 |
142 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/GptTokenizerImp.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.Context;
4 |
5 | import com.google.gson.Gson;
6 |
7 | import java.io.BufferedReader;
8 | import java.io.FileReader;
9 | import java.io.InputStreamReader;
10 | import java.nio.charset.StandardCharsets;
11 | import java.util.ArrayList;
12 | import java.util.Arrays;
13 | import java.util.HashMap;
14 | import java.util.LinkedHashSet;
15 | import java.util.List;
16 | import java.util.Map;
17 | import java.util.Set;
18 | import java.util.regex.MatchResult;
19 | import java.util.regex.Matcher;
20 | import java.util.regex.Pattern;
21 | import java.util.stream.Stream;
22 |
23 | /**
24 | * Created by ZTMIDGO 2022/9/15
25 | */
26 | public class GptTokenizerImp implements GptTokenizer {
27 | private final Pattern pattern = Pattern.compile("'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+");
28 | private final String VOCAB_NAME = "vocab.json";
29 |
30 | private final Map encoder = new HashMap<>();
31 | private final Map decoder = new HashMap<>();
32 | private final Map, Integer> bpeRanks = new HashMap<>();
33 | private final Context context;
34 |
35 | public GptTokenizerImp(Context context){
36 | this.context = context;
37 | fillEncoder();
38 | fillDecoder();
39 | }
40 |
41 | @Override
42 | public String decode(List tokens) {
43 | StringBuilder sb = new StringBuilder();
44 | for (int value : tokens){
45 | if (decoder.containsKey(value)) sb.append(decoder.get(value));
46 | }
47 | List result = new ArrayList<>();
48 | for (int i = 0; i < sb.length(); i++){
49 | String key = String.valueOf(sb.charAt(i));
50 | if (GPTByteUtils.BYTE_DECODER.containsKey(key)){
51 | result.add(GPTByteUtils.BYTE_DECODER.get(key));
52 | }
53 | }
54 | int[] ints = new int[result.size()];
55 | for (int i = 0; i < result.size(); i++) ints[i] = result.get(i);
56 | return new String(ints, 0, ints.length);
57 | }
58 |
59 | @Override
60 | public List encode(String text){
61 | List stringList = new ArrayList<>();
62 | Matcher matcher = pattern.matcher(text);
63 | while (matcher.find()){
64 | MatchResult result = matcher.toMatchResult();
65 | String value = result.group();
66 | Stream stream = value.codePoints().boxed();
67 | StringBuilder sb = new StringBuilder();
68 | Object[] array = stream.toArray();
69 | for (Object o : array){
70 | if (GPTByteUtils.BYTE_ENCODER.containsKey(o)){
71 | sb.append(GPTByteUtils.BYTE_ENCODER.get(o));
72 | }
73 | }
74 | stringList.add(sb.toString());
75 | }
76 |
77 | List> strings = new ArrayList<>();
78 | for (String string : stringList){
79 | strings.add(bpe(string));
80 | }
81 |
82 | List result = new ArrayList<>();
83 | for (List list : strings){
84 | for (String string : list) {
85 | if (encoder.containsKey(string)){
86 | result.add(encoder.get(string));
87 | }
88 | }
89 | }
90 | return result;
91 | }
92 |
93 | private List bpe(String token){
94 | if (token.length() <= 1) return Arrays.asList(token);
95 |
96 | List word = new ArrayList<>(token.length());
97 | for (int i = 0; i < token.length(); i++) word.add(String.valueOf(token.charAt(i)));
98 | Set> pairs = getPairs(word);
99 |
100 | while (true){
101 | Pair min = null;
102 | int minValue = 0;
103 | for (Pair pair : pairs){
104 | if (!bpeRanks.containsKey(pair)) {
105 | continue ;
106 | }
107 | int value = bpeRanks.get(pair);
108 | if (min == null || value < minValue){
109 | min = pair;
110 | minValue = value;
111 | }
112 | }
113 |
114 | if (min == null) break;
115 |
116 | int i = 0;
117 | List newWord = new ArrayList<>();
118 | while (i < word.size()){
119 | int j = -1;
120 | for (int x =0; x < word.size(); x++){
121 | if (x >= i && word.get(x).equals(min.first)){
122 | j = x;
123 | break;
124 | }
125 | }
126 | if (j != -1){
127 | newWord.addAll(word.subList(i, j));
128 | i = j;
129 | }else {
130 | newWord.addAll(word.subList(i, word.size()));
131 | break;
132 | }
133 |
134 | if (word.get(i).equals(min.first) && i < word.size() - 1 && word.get(i + 1).equals(min.second)){
135 | newWord.add(min.first + min.second);
136 | i += 2;
137 | }else {
138 | newWord.add(word.get(i));
139 | i += 1;
140 | }
141 | }
142 |
143 | word = newWord;
144 | if (word.size() == 1) {
145 | break;
146 | } else {
147 | pairs = getPairs(word);
148 | }
149 | }
150 | return word;
151 | }
152 |
153 | private Set> getPairs(List word){
154 | Set> result = new LinkedHashSet<>();
155 | for (int i =0; i < word.size() - 1; i++){
156 | result.add(new Pair<>(word.get(i), word.get(i + 1)));
157 | }
158 | return result;
159 | }
160 |
161 | private void fillEncoder(){
162 | try {
163 | String path = PathManager.getModelPath(context) + "/" + VOCAB_NAME;
164 | Gson gson = new Gson();
165 | Vocab vocab = gson.fromJson(new FileReader(path), Vocab.class);
166 | encoder.putAll(vocab.getModel().getVocab());
167 | fillBpeRanks(vocab.getModel().getMerges());
168 | }catch (Exception e){
169 | e.printStackTrace();
170 | }
171 | }
172 |
173 | private void fillBpeRanks(String[] array){
174 | for (int i = 0; i < array.length; i++){
175 | String[] data = array[i].split(" ");
176 | if (data.length >= 2) {
177 | bpeRanks.put(new Pair<>(data[0], data[1]), i);
178 | }
179 | }
180 | }
181 |
182 | private void fillDecoder(){
183 | for (Map.Entry entry : encoder.entrySet())
184 | decoder.put(entry.getValue(), entry.getKey());
185 | }
186 | }
187 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/WriteFragment.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.app.ProgressDialog;
4 | import android.os.Bundle;
5 | import android.os.Handler;
6 | import android.text.TextUtils;
7 | import android.view.LayoutInflater;
8 | import android.view.View;
9 | import android.view.ViewGroup;
10 | import android.widget.EditText;
11 | import android.widget.ScrollView;
12 |
13 | import androidx.annotation.NonNull;
14 | import androidx.annotation.Nullable;
15 | import androidx.fragment.app.Fragment;
16 |
17 | import java.util.ArrayList;
18 | import java.util.Arrays;
19 | import java.util.List;
20 | import java.util.concurrent.ExecutorService;
21 | import java.util.concurrent.Executors;
22 |
23 | /**
24 | * Created by ZTMIDGO 2023/6/20
25 | */
26 | public class WriteFragment extends Fragment {
27 | private final ExecutorService exec = Executors.newCachedThreadPool();
28 |
29 | public static WriteFragment newInstance() {
30 |
31 | Bundle args = new Bundle();
32 |
33 | WriteFragment fragment = new WriteFragment();
34 | fragment.setArguments(args);
35 | return fragment;
36 | }
37 |
38 | private EditText mContentView;
39 | private EditText mTopKView;
40 | private EditText mLenView;
41 | private EditText mP1View;
42 | private EditText mP2View;
43 | private EditText mTempView;
44 | private EditText mTopPView;
45 | private View mStartView;
46 | private ScrollView mScrollView;
47 |
48 | private Handler uiHandler;
49 | private ProgressDialog dialog;
50 | private GptTokenizer tokenizer;
51 | private OnnxModelImp model;
52 |
53 | @Override
54 | public void onCreate(@Nullable Bundle savedInstanceState) {
55 | super.onCreate(savedInstanceState);
56 | uiHandler = new Handler();
57 | dialog = new ProgressDialog(getActivity());
58 | dialog.setCancelable(false);
59 | }
60 |
61 | @Override
62 | public void onDestroy() {
63 | super.onDestroy();
64 | exec.shutdownNow();
65 | if (model != null) model.close();
66 | }
67 |
68 | @Nullable
69 | @Override
70 | public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
71 | View view = inflater.inflate(R.layout.fragment_write, container, false);
72 | mContentView = view.findViewById(R.id.content);
73 | mTopKView = view.findViewById(R.id.top_k);
74 | mLenView = view.findViewById(R.id.len);
75 | mP1View = view.findViewById(R.id.p1);
76 | mP2View = view.findViewById(R.id.p2);
77 | mStartView = view.findViewById(R.id.start);
78 | mScrollView = view.findViewById(R.id.scroll);
79 | mTempView = view.findViewById(R.id.temp);
80 | mTopPView = view.findViewById(R.id.top_p);
81 |
82 | mTopKView.setText(String.valueOf(PreferencesManager.getTopK()));
83 | mLenView.setText(String.valueOf(PreferencesManager.getLen()));
84 | mP1View.setText(String.valueOf(PreferencesManager.getP1()));
85 | mP2View.setText(String.valueOf(PreferencesManager.getP2()));
86 | mTempView.setText(String.valueOf(PreferencesManager.getTemp()));
87 | mTopPView.setText(String.valueOf(PreferencesManager.getTopp()));
88 |
89 | mStartView.setOnClickListener(new View.OnClickListener() {
90 | @Override
91 | public void onClick(View v) {
92 | String topKStr = mTopKView.getText().toString();
93 | String lenStr = mLenView.getText().toString();
94 | String p1Str = mP1View.getText().toString();
95 | String p2Str = mP2View.getText().toString();
96 | String text = mContentView.getText().toString();
97 | String tempStr = mTempView.getText().toString();
98 | String toppStr = mTopPView.getText().toString();
99 |
100 | final int topK = TextUtils.isEmpty(topKStr) ? PreferencesManager.getTopK() : Integer.parseInt(topKStr);
101 | final int len = TextUtils.isEmpty(lenStr) ? PreferencesManager.getLen() : Integer.parseInt(lenStr);
102 | final float p1 = TextUtils.isEmpty(p1Str) ? PreferencesManager.getP1() : Float.parseFloat(p1Str);
103 | final float p2 = TextUtils.isEmpty(p2Str) ? PreferencesManager.getP2() : Float.parseFloat(p2Str);
104 | final float temp = TextUtils.isEmpty(tempStr) ? PreferencesManager.getTemp() : Float.parseFloat(tempStr);
105 | final float topp = TextUtils.isEmpty(toppStr) ? PreferencesManager.getTopp() : Float.parseFloat(toppStr);
106 |
107 |
108 | PreferencesUtils.setProperty(Atts.LEN, (int)len);
109 | PreferencesUtils.setProperty(Atts.TOP_K, (int)topK);
110 | PreferencesUtils.setProperty(Atts.P1, p1 * 1f);
111 | PreferencesUtils.setProperty(Atts.P2, p2 * 1f);
112 | PreferencesUtils.setProperty(Atts.TEMP, temp * 1f);
113 | PreferencesUtils.setProperty(Atts.TOP_P, topp * 1f);
114 |
115 | if (model != null && model.isRunning()) return;
116 |
117 | if (model == null){
118 | dialog.show();
119 | exec.execute(new MyRunnable() {
120 | @Override
121 | public void run() {
122 | tokenizer = new WorldTokenizerImp(getActivity());
123 | model = new OnnxModelImp(getActivity(), OnnxModelImp.MODE_WRITE);
124 | uiHandler.post(new Runnable() {
125 | @Override
126 | public void run() {
127 | working(text, temp, topp, topK, len, p1, p2);
128 | }
129 | });
130 | }
131 | });
132 | }else {
133 | working(text, temp, topp, topK, len, p1, p2);
134 | }
135 | }
136 | });
137 |
138 | return view;
139 | }
140 |
141 | private void working(String text, float temp, float topp, int topK, int len, float p1, float p2){
142 | dialog.dismiss();
143 | model.setTop(temp, topp, topK);
144 | model.setPenalty(p1, p2);
145 | List integers = new ArrayList<>();
146 | integers.add(11);
147 | integers.addAll(tokenizer.encode(text));
148 | model.generate(integers, len, new GptModel.Callback() {
149 | @Override
150 | public void callback(int token, int index, int maxCount, boolean isEnd) {
151 | uiHandler.post(new MyRunnable() {
152 | @Override
153 | public void run() {
154 | mContentView.append(tokenizer.decode(Arrays.asList(token)));
155 | mScrollView.fullScroll(ScrollView.FOCUS_DOWN);
156 | }
157 | });
158 | }
159 | });
160 | }
161 | }
162 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/GPTByteUtils.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import java.util.HashMap;
4 | import java.util.Map;
5 |
/**
 * Reversible mapping between byte values (0-255) and single printable characters, as used by the
 * byte-level BPE tokenizer.
 *
 * <p>Bytes in 33-126, 161-172 and 174-255 map to the character with the same code. Every other
 * byte is assigned, in ascending byte order, the next code point starting at U+0100, so 0-32 map
 * to U+0100..U+0120, 127-160 to U+0121..U+0142 and 173 to U+0143. This matches GPT-2's
 * bytes_to_unicode table. BYTE_DECODER is the exact inverse of BYTE_ENCODER.
 */
public class GPTByteUtils {
    public static final Map<Integer, String> BYTE_ENCODER = new HashMap<>();
    public static final Map<String, Integer> BYTE_DECODER = new HashMap<>();

    static {
        int nextRemapped = 0x100;
        for (int b = 0; b < 256; b++) {
            boolean printable = (b >= 33 && b <= 126) || (b >= 161 && b <= 172) || (b >= 174 && b <= 255);
            int codePoint;
            if (printable) {
                codePoint = b;
            } else {
                codePoint = nextRemapped;
                nextRemapped++;
            }
            put(b, String.valueOf((char) codePoint));
        }
        reversal();
    }

    /** Adds one byte -> character entry to the forward table. */
    private static void put(Integer key, String value){
        BYTE_ENCODER.put(key, value);
    }

    /** Fills BYTE_DECODER with the inverse (character -> byte) of every BYTE_ENCODER entry. */
    public static void reversal(){
        for (Integer byteValue : BYTE_ENCODER.keySet()) {
            BYTE_DECODER.put(BYTE_ENCODER.get(byteValue), byteValue);
        }
    }
}
277 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/OnnxModelImp.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.content.Context;
4 | import android.util.Log;
5 |
6 | import java.nio.DoubleBuffer;
7 | import java.nio.FloatBuffer;
8 | import java.nio.IntBuffer;
9 | import java.nio.LongBuffer;
10 | import java.util.ArrayList;
11 | import java.util.Arrays;
12 | import java.util.Collections;
13 | import java.util.Comparator;
14 | import java.util.HashMap;
15 | import java.util.Iterator;
16 | import java.util.LinkedHashMap;
17 | import java.util.LinkedHashSet;
18 | import java.util.List;
19 | import java.util.Map;
20 | import java.util.Random;
21 | import java.util.Set;
22 | import java.util.concurrent.ExecutorService;
23 | import java.util.concurrent.Executors;
24 |
25 | import ai.onnxruntime.OnnxTensor;
26 | import ai.onnxruntime.OnnxValue;
27 | import ai.onnxruntime.OrtEnvironment;
28 | import ai.onnxruntime.OrtException;
29 | import ai.onnxruntime.OrtSession;
30 |
31 | /**
32 | * Created by ZTMIDGO 2022/9/15
33 | */
34 | public class OnnxModelImp implements GptModel {
35 | public static final int MODE_WRITE = 0;
36 | public static final int MODE_TALK = 1;
37 |
38 | private final String MODEL_NAME = "model.onnx";
39 | private final OrtEnvironment environment = OrtEnvironment.getEnvironment();
40 | private final OrtSession.SessionOptions options = new OrtSession.SessionOptions();
41 | private final Map map = new LinkedHashMap<>();
42 | private final Random random = new Random();
43 | private final ExecutorService exec = Executors.newCachedThreadPool();
44 |
45 | private final Context context;
46 |
47 | private float temp = 1f;
48 | private float topp = 0.1f;
49 | private int topk = 0;
50 |
51 | private final int layer = 24;
52 | private final int embd = 1024;
53 | private final int sequenceLength = 1;
54 | private final List inputNames = new ArrayList<>();
55 |
56 | private OrtSession.Result ort;
57 | private OrtSession session;
58 | private MyRunnable runnable;
59 |
60 | private int mode = MODE_TALK;
61 | private float presence = 0.7f;
62 | private float frequency = 0.4f;
63 | private boolean isRunnable;
64 |
65 | public OnnxModelImp(Context context, int mode){
66 | this.context = context;
67 | this.mode = mode;
68 | try {
69 | String path = PathManager.getModelPath(context) + "/" + MODEL_NAME;
70 | options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.BASIC_OPT);
71 | session = environment.createSession(path, options);
72 | inputNames.addAll(session.getInputNames());
73 | fillMap();
74 | }catch (Exception e){
75 | e.printStackTrace();
76 | }
77 | }
78 |
79 | @Override
80 | public void generate(List arrays, int maxCount, Callback callback) {
81 | if (runnable != null) runnable.setCancel(true);
82 | isRunnable = true;
83 | closeResult();
84 | runnable = new MyRunnable() {
85 | @Override
86 | public void run() {
87 | try {
88 | Map occurrence = new HashMap<>();
89 |
90 | if (mode == MODE_WRITE) fillMap();
91 |
92 | int nextToken = 0;
93 | int size = maxCount + arrays.size();
94 |
95 | for (int i = 0; i < size; i++) {
96 | int[] paddedTokens = new int[sequenceLength];
97 | IntBuffer buffer = IntBuffer.wrap(paddedTokens);
98 |
99 | if (!arrays.isEmpty()){
100 | nextToken = arrays.remove(0);
101 | }
102 |
103 | paddedTokens[0] = nextToken;
104 |
105 | OnnxTensor idx = OnnxTensor.createTensor(environment, buffer, new long[]{sequenceLength});
106 |
107 | map.put(inputNames.get(0), idx);
108 |
109 | ort = session.run(map);
110 | float[] outputLogits = (float[]) ort.get(0).getValue();
111 |
112 | for (Map.Entry entry : occurrence.entrySet()){
113 | int x = entry.getKey();
114 | outputLogits[x] = outputLogits[x] - (presence + entry.getValue() * frequency);
115 | }
116 |
117 | nextToken = SampleLogits.sample(outputLogits, temp, topp, topk);
118 |
119 | if (isCancel()) return;
120 |
121 | if (!occurrence.containsKey(nextToken)) occurrence.put(nextToken, 0f);
122 | occurrence.put(nextToken, occurrence.get(nextToken) + 1f);
123 |
124 | if (arrays.isEmpty()){
125 | if (mode == MODE_TALK && (nextToken == 60807 || nextToken == 23692 || nextToken == 33161 || nextToken == 82 || nextToken == 24281 || nextToken == 53648 || nextToken == 40301)) break;
126 | if (callback != null) callback.callback(nextToken, i, maxCount, false);
127 | fillMap(ort);
128 | }else {
129 | fillMap(ort);
130 | }
131 | }
132 | }catch (Exception e){
133 | e.printStackTrace();
134 | }finally {
135 | isRunnable = false;
136 | closeResult();
137 | }
138 | }
139 | };
140 | exec.execute(runnable);
141 | }
142 |
143 | @Override
144 | public int sample(List indexes, List probs){
145 | int index = randomIndex(probs);
146 | return indexes.get(index);
147 | }
148 |
149 | @Override
150 | public void close() {
151 | if (runnable != null) runnable.setCancel(true);
152 | exec.shutdown();
153 | if (session != null) {
154 | try {
155 | closeResult();
156 | session.close();
157 | options.close();
158 | }catch (Exception e){
159 | e.printStackTrace();
160 | }
161 | }
162 | }
163 |
164 | @Override
165 | public void cancel() {
166 | if (runnable != null) runnable.setCancel(true);
167 | }
168 |
169 | @Override
170 | public void setTop(float temp, float topp, int topk) {
171 | this.temp = temp;
172 | this.topp = topp;
173 | this.topk = topk;
174 | }
175 |
176 | @Override
177 | public void setPenalty(float v1, float v2) {
178 | presence = v1;
179 | frequency = v2;
180 | }
181 |
182 | @Override
183 | public void clean() {
184 | try {
185 | fillMap();
186 | } catch (Exception e) {
187 | e.printStackTrace();
188 | }
189 | }
190 |
191 | @Override
192 | public boolean isRunning() {
193 | return isRunnable;
194 | }
195 |
196 | private void fillMap() throws Exception {
197 | for (String name : inputNames){
198 | float[] buff = new float[layer * embd];
199 | if (name.equals("pp_att")) Arrays.fill(buff, (float) -1e30);
200 | OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(buff), new long[]{layer, embd});
201 | map.put(name, inputTensor);
202 | }
203 | }
204 |
205 | /*private void fillMap() throws Exception {
206 | for (int i = 0; i < inputNames.size(); i++){
207 | String name = inputNames.get(i);
208 | float[] buff = new float[embd];
209 | OnnxTensor inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(buff), new long[]{embd});
210 | map.put(name, inputTensor);
211 | }
212 | }*/
213 |
214 | private void fillMap(OrtSession.Result result){
215 | if (result == null) return;
216 | for (int x = 0; x < inputNames.size(); x++){
217 | map.put(inputNames.get(x),(OnnxTensor) result.get(x));
218 | }
219 | }
220 |
221 | private void closeResult(){
222 | if (ort != null){
223 | ort.close();
224 | ort = null;
225 | }
226 | }
227 |
228 | private int randomIndex(List probs){
229 | float sun = 0;
230 | for (float value : probs) sun += value;
231 | float rnd = sun * random.nextFloat();
232 | float acc = 0f;
233 | for (int i = 0; i < probs.size(); i++){
234 | acc += probs.get(i);
235 | if (rnd < acc) return i;
236 | }
237 | return probs.size() - 1;
238 | }
239 | }
240 |
--------------------------------------------------------------------------------
/app/src/main/java/com/litesnap/open/rwkv/TalkFragment.java:
--------------------------------------------------------------------------------
1 | package com.litesnap.open.rwkv;
2 |
3 | import android.app.ProgressDialog;
4 | import android.os.Bundle;
5 | import android.os.Handler;
6 | import android.text.TextUtils;
7 | import android.util.Log;
8 | import android.view.LayoutInflater;
9 | import android.view.View;
10 | import android.view.ViewGroup;
11 | import android.widget.EditText;
12 |
13 | import androidx.annotation.NonNull;
14 | import androidx.annotation.Nullable;
15 | import androidx.fragment.app.Fragment;
16 | import androidx.recyclerview.widget.LinearLayoutManager;
17 | import androidx.recyclerview.widget.RecyclerView;
18 |
19 | import java.util.ArrayList;
20 | import java.util.Arrays;
21 | import java.util.List;
22 | import java.util.concurrent.ExecutorService;
23 | import java.util.concurrent.Executors;
24 | import java.util.concurrent.TimeUnit;
25 |
26 | /**
27 | * Created by ZTMIDGO 2023/6/20
28 | */
29 | public class TalkFragment extends Fragment {
30 | private final ExecutorService exec = Executors.newCachedThreadPool();
31 |
32 | public static TalkFragment newInstance() {
33 |
34 | Bundle args = new Bundle();
35 |
36 | TalkFragment fragment = new TalkFragment();
37 | fragment.setArguments(args);
38 | return fragment;
39 | }
40 |
41 | private EditText mTopKView;
42 | private EditText mLenView;
43 | private EditText mP1View;
44 | private EditText mP2View;
45 | private EditText mTempView;
46 | private EditText mTopPView;
47 | private View mCleanView;
48 | private View mSendView;
49 | private EditText mEditText;
50 | private RecyclerView mRecyclerView;
51 | private LinearLayoutManager mLayoutManager;
52 |
53 | private Handler uiHandler;
54 | private ProgressDialog dialog;
55 | private GptTokenizer tokenizer;
56 | private OnnxModelImp model;
57 | private MyAdapter mAdapter;
58 |
59 | @Override
60 | public void onCreate(@Nullable Bundle savedInstanceState) {
61 | super.onCreate(savedInstanceState);
62 | uiHandler = new Handler();
63 | dialog = new ProgressDialog(getActivity());
64 | dialog.setCancelable(false);
65 | }
66 |
67 | @Override
68 | public void onDestroy() {
69 | super.onDestroy();
70 | exec.shutdownNow();
71 | if (model != null) model.close();
72 | }
73 | @Nullable
74 | @Override
75 | public View onCreateView(@NonNull LayoutInflater inflater, @Nullable ViewGroup container, @Nullable Bundle savedInstanceState) {
76 | View view = inflater.inflate(R.layout.fragment_answer, container, false);
77 | mTopKView = view.findViewById(R.id.top_k);
78 | mLenView = view.findViewById(R.id.len);
79 | mP1View = view.findViewById(R.id.p1);
80 | mP2View = view.findViewById(R.id.p2);
81 | mCleanView = view.findViewById(R.id.clean);
82 | mSendView = view.findViewById(R.id.send);
83 | mEditText = view.findViewById(R.id.edit);
84 | mRecyclerView = view.findViewById(R.id.recycler_view);
85 | mTempView = view.findViewById(R.id.temp);
86 | mTopPView = view.findViewById(R.id.top_p);
87 |
88 | mLayoutManager = new LinearLayoutManager(getActivity(), RecyclerView.VERTICAL, false);
89 | mAdapter = new MyAdapter(getActivity(), inflater, new ArrayList<>());
90 | mRecyclerView.setLayoutManager(mLayoutManager);
91 | mLayoutManager.setStackFromEnd(true);
92 | mRecyclerView.setAdapter(mAdapter);
93 |
94 | mTopKView.setText(String.valueOf(PreferencesManager.getTopK()));
95 | mLenView.setText(String.valueOf(PreferencesManager.getLen()));
96 | mP1View.setText(String.valueOf(PreferencesManager.getP1()));
97 | mP2View.setText(String.valueOf(PreferencesManager.getP2()));
98 | mTempView.setText(String.valueOf(PreferencesManager.getTemp()));
99 | mTopPView.setText(String.valueOf(PreferencesManager.getTopp()));
100 |
101 | mSendView.setOnClickListener(new View.OnClickListener() {
102 | @Override
103 | public void onClick(View v) {
104 | String topKStr = mTopKView.getText().toString();
105 | String lenStr = mLenView.getText().toString();
106 | String p1Str = mP1View.getText().toString();
107 | String p2Str = mP2View.getText().toString();
108 | String text = mEditText.getText().toString();
109 | String tempStr = mTempView.getText().toString();
110 | String toppStr = mTopPView.getText().toString();
111 |
112 | final int topK = TextUtils.isEmpty(topKStr) ? PreferencesManager.getTopK() : Integer.parseInt(topKStr);
113 | final int len = TextUtils.isEmpty(lenStr) ? PreferencesManager.getLen() : Integer.parseInt(lenStr);
114 | final float p1 = TextUtils.isEmpty(p1Str) ? PreferencesManager.getP1() : Float.parseFloat(p1Str);
115 | final float p2 = TextUtils.isEmpty(p2Str) ? PreferencesManager.getP2() : Float.parseFloat(p2Str);
116 | final float temp = TextUtils.isEmpty(tempStr) ? PreferencesManager.getTemp() : Float.parseFloat(tempStr);
117 | final float topp = TextUtils.isEmpty(toppStr) ? PreferencesManager.getTopp() : Float.parseFloat(toppStr);
118 |
119 | PreferencesUtils.setProperty(Atts.LEN, (int)len);
120 | PreferencesUtils.setProperty(Atts.TOP_K, (int)topK);
121 | PreferencesUtils.setProperty(Atts.P1, p1 * 1f);
122 | PreferencesUtils.setProperty(Atts.P2, p2 * 1f);
123 | PreferencesUtils.setProperty(Atts.TEMP, temp * 1f);
124 | PreferencesUtils.setProperty(Atts.TOP_P, topp * 1f);
125 |
126 | if (model != null && model.isRunning()) return;
127 |
128 | if (model == null){
129 | dialog.show();
130 | exec.execute(new MyRunnable() {
131 | @Override
132 | public void run() {
133 | tokenizer = new WorldTokenizerImp(getActivity());
134 | model = new OnnxModelImp(getActivity(), OnnxModelImp.MODE_TALK);
135 | uiHandler.post(new Runnable() {
136 | @Override
137 | public void run() {
138 | working(text, temp, topp, topK, len, p1, p2);
139 | }
140 | });
141 | }
142 | });
143 | }else {
144 | working(text, temp, topp, topK, len, p1, p2);
145 | }
146 | }
147 | });
148 |
149 | mCleanView.setOnClickListener(new View.OnClickListener() {
150 | @Override
151 | public void onClick(View v) {
152 | if (model == null || model.isRunning()) return;
153 |
154 | mAdapter.clean();
155 | model.clean();
156 | }
157 | });
158 |
159 | return view;
160 | }
161 |
162 | private void working(String text, float temp, float topp, int topK, int len, float p1, float p2){
163 | dialog.dismiss();
164 |
165 | if (TextUtils.isEmpty(text)) return;
166 |
167 | model.setTop(temp, topp, topK);
168 | model.setPenalty(p1, p2);
169 | mAdapter.add(new Talk(Talk.TYPE_QUESTION, text));
170 | final Talk answer = new Talk(Talk.TYPE_ANSWER, "");
171 | mAdapter.add(answer);
172 | mLayoutManager.scrollToPositionWithOffset(mAdapter.getItemCount() - 1, Integer.MIN_VALUE);
173 |
174 | List integers = new ArrayList<>();
175 | integers.add(11);
176 | integers.add(261);
177 | integers.add(53648);
178 | integers.add(59);
179 | integers.addAll(tokenizer.encode(text));
180 | integers.add(261);
181 | integers.add(40301);
182 | integers.add(59);
183 |
184 | mEditText.setText("");
185 | model.generate(integers, len, new GptModel.Callback() {
186 | long time = System.currentTimeMillis();
187 | int laster = 0;
188 | @Override
189 | public void callback(int token, int index, int maxCount, boolean isEnd) {
190 | if (TimeUnit.MILLISECONDS.toMillis(System.currentTimeMillis() - time) >= 1000){
191 | Log.e("Dong", "callback: 每秒生成 "+(index - laster));
192 | time = System.currentTimeMillis();
193 | laster = index;
194 | }
195 | answer.setText((answer.getText() + tokenizer.decode(Arrays.asList(token))).replaceFirst("(\n\n)$", ""));
196 | uiHandler.post(new MyRunnable() {
197 | @Override
198 | public void run() {
199 | mAdapter.notifyItemChanged(mAdapter.getItemCount() - 1);
200 | mLayoutManager.scrollToPositionWithOffset(mAdapter.getItemCount() - 1, Integer.MIN_VALUE);
201 | }
202 | });
203 | }
204 | });
205 | }
206 | }
207 |
--------------------------------------------------------------------------------