Triển khai bộ động cơ tự động tính đạo hàm bằng Java, xây dựng DAG động và thực hiện lan truyền ngược

Triển khai một bộ động cơ tự động tính đạo hàm cho dữ liệu vô hướng bằng Java, xây dựng đồ thị DAG động và thực hiện lan truyền ngược. Dự án này chia nhỏ mỗi neuron thành các phép cộng và nhân nhỏ, sau đó xây dựng một mạng nơ-ron sâu hoàn chỉnh để thực hiện phân loại nhị phân.

Dự án được lưu trữ tại: https://github.com/jiangnanboy/micrograd4j

Ví dụ

Dưới đây là một số ví dụ trên dữ liệu vô hướng:

test/TestEngine.java

Value x = new Value(-4.0);
Value y = new Value(2.0);
Value z = x.add(y);
Value w = x.mul(y).add(y.pow(3));
z = z.add(z.add(1));
z = z.add(z.add(1).add(x.neg()));
w = w.add(w.mul(2).add((y.add(x).relu())));
w = w.add(w.mul(3).add((y.add(x.neg())).relu()));
Value v = z.add(w.neg());
Value u = v.pow(2);
Value t = u.div(2.0);
t = t.add(u.rdiv(10.0));

t.backward();
System.out.println("x.data -> " + x.data + "; " + "x.grad -> " + x.grad); // x.data -> -4.0; x.grad -> 138.83381924198252
System.out.println("y.data -> " + y.data + "; " + "y.grad -> " + y.grad); // y.data -> 2.0; y.grad -> 645.5772594752186
System.out.println("t.data -> " + t.data + "; " + "t.grad -> " + t.grad); // t.data -> 24.70408163265306; t.grad -> 1.0
// so sánh với torch
// x.data -> tensor([-4.], dtype=torch.float64); x.grad -> tensor([138.8338], dtype=torch.float64)
// y.data -> tensor([2.], dtype=torch.float64); y.grad -> tensor([645.5773], dtype=torch.float64)
// t.data -> tensor([24.7041], dtype=torch.float64)

Đào tạo mạng nơ-ron

Demo.java

File `Demo.java` cung cấp ví dụ hoàn chỉnh để đào tạo mạng nơ-ron 2 lớp (MLP) cho bài toán phân loại nhị phân. Bằng cách khởi tạo mạng nơ-ron từ module `micrograd.nn.mlp.java`, và triển khai một hàm mất mát SVM "max-margin" đơn giản, sau đó sử dụng SGD để tối ưu hóa. Như trong mã nguồn, sử dụng mạng nơ-ron 2 lớp với hai lớp ẩn 16 nút, thực hiện ranh giới quyết định trên tập dữ liệu moon dataset (resources/test_data/test_data.txt):

MLP of [Layer of [ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2),ReLuNeuron(2)],Layer of [ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16),ReLuNeuron(16)],Layer of [LinearNeuron(16)]]
số tham số 337
bước 0 mất mát 1.9638621133484333, độ chính xác 50.0%
bước 1 mất mát 2.3490456150665735, độ chính xác 50.0%
bước 2 mất mát 2.6464776189077397, độ chính xác 50.0%
bước 3 mất mát 0.6152116902047455, độ chính xác 74.0%
bước 4 mất mát 1.139655680191554, độ chính xác 52.0%
bước 5 mất mát 0.5383808737647844, độ chính xác 78.0%
bước 6 mất mát 0.6543561574610679, độ chính xác 75.0%
bước 7 mất mát 0.5203942292803446, độ chính xác 78.0%
bước 8 mất mát 0.46172360123857714, độ chính xác 77.0%
bước 9 mất mát 0.5481515887605712, độ chính xác 77.0%
bước 10 mất mát 0.43029113281491693, độ chính xác 78.0%
bước 11 mất mát 0.3941153832954229, độ chính xác 79.0%
bước 12 mất mát 0.3653660105656176, độ chính xác 81.0%
bước 13 mất mát 0.36678183332025927, độ chính xác 81.0%
bước 14 mất mát 0.33462358607763043, độ chính xác 82.0%
bước 15 mất mát 0.33043574832504513, độ chính xác 83.0%
bước 16 mất mát 0.20555882746098753, độ chính xác 89.0%
bước 17 mất mát 0.17954274179722446, độ chính xác 92.0%
bước 18 mất mát 0.1634517690997521, độ chính xác 90.0%
bước 19 mất mát 0.15808448751828796, độ chính xác 93.0%
bước 20 mất mát 0.17070764235037178, độ chính xác 90.0%
bước 21 mất mát 0.15112896847332996, độ chính xác 93.0%
bước 22 mất mát 0.14816380049355352, độ chính xác 93.0%
bước 23 mất mát 0.12741924284543907, độ chính xác 93.0%
bước 24 mất mát 0.1258592888164256, độ chính xác 96.0%
bước 25 mất mát 0.14448993222619533, độ chính xác 94.0%
bước 26 mất mát 0.11703575880664031, độ chính xác 95.0%
bước 27 mất mát 0.11991399250076275, độ chính xác 95.0%
bước 28 mất mát 0.11250644859559832, độ chính xác 96.0%
bước 29 mất mát 0.1122712379123405, độ chính xác 97.0%
bước 30 mất mát 0.10848166745964823, độ chính xác 97.0%
bước 31 mất mát 0.11053301474073045, độ chính xác 96.0%
bước 32 mất mát 0.11475675943130205, độ chính xác 96.0%
bước 33 mất mát 0.1261635901826707, độ chính xác 93.0%
bước 34 mất mát 0.15131709864479434, độ chính xác 94.0%
bước 35 mất mát 0.10893801341199083, độ chính xác 95.0%
bước 36 mất mát 0.09271950174394382, độ chính xác 97.0%
bước 37 mất mát 0.09110418044688984, độ chính xác 97.0%
bước 38 mất mát 0.09912837412748972, độ chính xác 97.0%
bước 39 mất mát 0.11986141502645908, độ chính xác 96.0%
bước 40 mất mát 0.16106703014875767, độ chính xác 93.0%
bước 41 mất mát 0.09798468198520184, độ chính xác 97.0%
bước 42 mất mát 0.08102368944867655, độ chính xác 98.0%
bước 43 mất mát 0.07303947184840244, độ chính xác 99.0%
bước 44 mất mát 0.0863052809487441, độ chính xác 97.0%
bước 45 mất mát 0.07291825732593486, độ chính xác 100.0%
bước 46 mất mát 0.1057557980145795, độ chính xác 96.0%
bước 47 mất mát 0.08093449824345554, độ chính xác 97.0%
bước 48 mất mát 0.06319761143918433, độ chính xác 100.0%
bước 49 mất mát 0.06386736914872647, độ chính xác 98.0%
bước 50 mất mát 0.06845829278120481, độ chính xác 100.0%
bước 51 mất mát 0.09904393774556877, độ chính xác 96.0%
bước 52 mất mát 0.07282111419678025, độ chính xác 97.0%
bước 53 mất mát 0.05540132230996909, độ chính xác 100.0%
bước 54 mất mát 0.06998143976127322, độ chính xác 97.0%
bước 55 mất mát 0.05986002955127303, độ chính xác 100.0%
bước 56 mất mát 0.09534546654833871, độ chính xác 96.0%
bước 57 mất mát 0.06014013456733181, độ chính xác 98.0%
bước 58 mất mát 0.047855074405145484, độ chính xác 100.0%
bước 59 mất mát 0.054283928016275594, độ chính xác 98.0%
bước 60 mất mát 0.04528611993382831, độ chính xác 100.0%
bước 61 mất mát 0.05462375094558794, độ chính xác 99.0%
bước 62 mất mát 0.042032793145952985, độ chính xác 100.0%
bước 63 mất mát 0.04338790757350784, độ chính xác 100.0%
bước 64 mất mát 0.051753586897849514, độ chính xác 99.0%
bước 65 mất mát 0.03645154714588962, độ chính xác 100.0%
bước 66 mất mát 0.035129307532627406, độ chính xác 100.0%
bước 67 mất mát 0.040085759825092944, độ chính xác 100.0%
bước 68 mất mát 0.05215369584037617, độ chính xác 99.0%
bước 69 mất mát 0.03633940406301827, độ chính xác 100.0%
bước 70 mất mát 0.03888015127347711, độ chính xác 100.0%
bước 71 mất mát 0.04090424005630395, độ chính xác 100.0%
bước 72 mất mát 0.031172216887366416, độ chính xác 100.0%
bước 73 mất mát 0.04072426213271741, độ chính xác 100.0%
bước 74 mất mát 0.059378521342605975, độ chính xác 98.0%
bước 75 mất mát 0.041849846606535956, độ chính xác 100.0%
bước 76 mất mát 0.03390850067201953, độ chính xác 100.0%
bước 77 mất mát 0.02882639946719858, độ chính xác 100.0%
bước 78 mất mát 0.040177016098820253, độ chính xác 100.0%
bước 79 mất mát 0.031580874763228226, độ chính xác 100.0%
bước 80 mất mát 0.02911959861027716, độ chính xác 100.0%
bước 81 mất mát 0.03476876690968894, độ chính xác 100.0%
bước 82 mất mát 0.026663940738996236, độ chính xác 100.0%
bước 83 mất mát 0.025912574698691876, độ chính xác 100.0%
bước 84 mất mát 0.02846805443278455, độ chính xác 100.0%
bước 85 mất mát 0.02539113644948084, độ chính xác 100.0%
bước 86 mất mát 0.026658747343023592, độ chính xác 100.0%
bước 87 mất mát 0.024365215229248158, độ chính xác 100.0%
bước 88 mất mát 0.02408029822395616, độ chính xác 100.0%
bước 89 mất mát 0.023459113242738115, độ chính xác 100.0%
bước 90 mất mát 0.02343612411952584, độ chính xác 100.0%
bước 91 mất mát 0.022919043489410436, độ chính xác 100.0%
bước 92 mất mát 0.022826550953514414, độ chính xác 100.0%
bước 93 mất mát 0.02272174326823475, độ chính xác 100.0%
bước 94 mất mát 0.022656645555664323, độ chính xác 100.0%
bước 95 mất mát 0.021745650879217204, độ chính xác 100.0%
bước 96 mất mát 0.022068520750511193, độ chính xác 100.0%
bước 97 mất mát 0.021523017591105996, độ chính xác 100.0%
bước 98 mất mát 0.021910340545673795, độ chính xác 100.0%
bước 99 mất mát 0.02094203506234891, độ chính xác 100.0%

Trích dẫn

Nếu bạn sử dụng micrograd4j trong nghiên cứu của mình, vui lòng trích dẫn như sau:

@software{micrograd4j,
  author = {Shi Yan},
  title = {A micro Autograd engine developed with java},
  year = {2022},
  url = {https://github.com/jiangnanboy/micrograd4j},
}

Giấy phép

MIT

Thẻ: Java Automatic Differentiation deep learning backpropagation Neural Networks

Đăng vào ngày 14 tháng 6 lúc 01:05