- Kiến thức nền tảng về mạng neural
1.1 Đơn vị xử lý thần kinh (Neuron)
Mạng neural (Neural Network) được cấu tạo từ nhiều đơn vị xử lý thần kinh kết nối với nhau. Neuron là thành phần cơ bản nhất của mạng neural, và toàn bộ mạng được xây dựng từ nhiều neuron. Cấu trúc của một neuron như sau:
Trong đó, x₁, x₂, x₃ và 1 là các đầu vào, hw,b(x) là đầu ra.
Hàm f(x) được gọi là hàm kích hoạt (activation function). Hai hàm kích hoạt phổ biến nhất là hàm sigmoid và hàm tanh (hyperbolic tangent).
Hàm sigmoid có công thức:
Hàm tanh có công thức:
1.2 Kiến trúc mạng neural
Mạng neural bao gồm nhiều tầng, với các neuron ở các tầng liền kề có mối quan hệ đầu vào. Tầng đầu tiên được gọi là tầng đầu vào (input layer), tầng cuối cùng là tầng đầu ra (output layer), và các tầng ở giữa được gọi là tầng ẩn (hidden layers).
1.3 Lan truyền tiến và lan truyền ngược
Giả sử mạng neural có n tầng, được đánh số lần lượt là L₁, L₂, ..., Lₙ. Số lượng neuron ở tầng thứ p (với p = 1,2,...,n) là mₚ. aⱽᵏ) biểu thị giá trị đầu ra của neuron thứ j ở tầng k.
Đối với tầng L₁ (tầng đầu vào), ta có:
Đầu ra của neuron thứ j ở tầng (k+1) được tính bằng:
Giả sử lỗi của một mẫu huấn luyện là:
Hàm lỗi tổng thể được định nghĩa như sau:
Để tránh overfitting, ta thêm thành phần L2 regularization:
Mục tiêu là tìm (w,b) để J(w,b) đạt giá trị nhỏ nhất. Ta sử dụng phương pháp gradient descent, cập nhật w và b theo công thức sau ở mỗi lần lặp:
Đối với neuron thứ j ở tầng cuối cùng (tầng đầu ra), sai số (residual) được tính bằng:
Sai số của neuron thứ i ở tầng k được tính bằng:
Quy trình tìm (w,b) như sau:
-
Với mọi k, khởi tạo w⁽ᵏ⁾:=0, b⁽ᵏ⁾:=0;
-
Lan truyền tiến: Với mỗi mẫu dữ liệu, tính toán hw,b(x) dựa trên các giá trị đầu vào và w⁽ᵏ⁾, b⁽ᵏ⁾;
-
Lan truyền ngược: Tính toán sai số cho từng neuron ở mỗi tầng;
-
Cập nhật giá trị của w và b.
Lặp lại các bước (2)~(4) cho đến khi đạt số lần lặp mong muốn.
- Triển khai mạng neural trong MLlib
Lớp mạng neural trong MLlib là NeuralNet. Các tham số chính bao gồm:
- Size: Mảng[Int] - Số lượng neuron ở mỗi tầng
- Layer: Số lượng tầng của mạng neural
- Activation_function: Hàm kích hoạt, có thể là 'sigm' hoặc 'tanh'
- Output_function: Hàm đầu ra, có thể là 'sigm', 'softmax' hoặc 'linear'
Mã nguồn triển khai:
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import breeze.linalg.{
DenseMatrix => BDM,
max => Bmax,
min => Bmin
}
import scala.collection.mutable.ArrayBuffer
object NeuralNetworkExample {
def main(args: Array[String]): Unit = {
// Cấu hình môi trường chạy
val conf = new SparkConf()
.setAppName("Neural Network Example")
.setMaster("spark://master:7077")
.setJars(Seq("E:\\Projects\\SparkML\\NeuralNetwork.jar"))
val sc = new SparkContext(conf)
Logger.getRootLogger.setLevel(Level.WARN)
// Tạo dữ liệu mẫu ngẫu nhiên
val sampleRows = 1000
val sampleCols = 5
val randomSamples = RandSampleData.generateMatrix(sampleRows, sampleCols, -10, 10, "sphere")
// Chuẩn hóa dữ liệu
val maxValues = Bmax(randomSamples(::, breeze.linalg.*))
val minValues = Bmin(randomSamples(::, breeze.linalg.*))
val normalized1 = randomSamples - (BDM.ones[Double](randomSamples.rows, 1)) * minValues
val normalized2 = normalized1 :/ ((BDM.ones[Double](normalized1.rows, 1)) * (maxValues - minValues))
// Chuyển đổi định dạng dữ liệu
val sampleList = ArrayBuffer[BDM[Double]]()
for (i <- 0 until sampleRows) {
val matrixRow = normalized2(i, ::)
val rowVector = matrixRow.inner
val rowArray = rowVector.toArray
val sampleMatrix = new BDM(1, rowArray.length, rowArray)
sampleList += sampleMatrix
}
val parallelSamples = sc.parallelize(sampleList, 10)
sc.setCheckpointDir("hdfs://master:9000/ml/checkpoint")
parallelSamples.checkpoint()
val trainingData = parallelSamples.map(f => (new BDM(1, 1, f(::, 0).data), f(::, 1 to -1)))
// Huấn luyện mô hình
val options = Array(100.0, 50.0, 0.0)
trainingData.cache()
val totalExamples = trainingData.count()
println(s"Số lượng mẫu: $totalExamples")
val neuralModel = new NeuralNet()
.setSize(Array(5, 10, 10, 10, 10, 10, 1))
.setLayer(7)
.setActivation_function("tanh_opt")
.setLearningRate(2.0)
.setScaling_learningRate(1.0)
.setWeightPenaltyL2(0.0)
.setNonSparsityPenalty(0.0)
.setSparsityTarget(0.05)
.setInputZeroMaskedFraction(0.0)
.setDropoutFraction(0.0)
.setOutput_function("sigm")
.train(trainingData, options)
// Kiểm tra mô hình
val predictions = neuralModel.predict(trainingData)
val predictionError = neuralModel.Loss(predictions)
println(s"Lỗi của mạng neural = $predictionError")
val samplePredictions = predictions.map(f => (f.label.data(0), f.predict_label.data(0))).take(100)
println("Kết quả dự đoán")
println("Giá trị thực\t\tDự đoán\t\tLỗi")
for (i <- 0 until samplePredictions.length)
println(f"${samplePredictions(i)._1}%.6f\t\t${samplePredictions(i)._2}%.6f\t\t${samplePredictions(i)._2 - samplePredictions(i)._1}%.6f")
// Hiển thị trọng số của các tầng
var layerWeights = neuralModel.weights(0)
for (i <- 0 to 5) {
layerWeights = neuralModel.weights(i)
println(s"Trọng số của tầng ${i+1}")
for (j <- 0 to layerWeights.rows - 1) {
for (k <- 0 to layerWeights.cols - 1) {
print(f"${layerWeights(j, k)}%.6f\t")
}
println()
}
}
}
}
Mã trên xây dựng một mạng neural 7 tầng với số lượng neuron ở các tầng tương ứng là Array(5, 10, 10, 10, 10, 10, 1), và kiểm tra hàm Sphere.
Kết quả thực hiện:
Số lượng mẫu: 1000
epoch: numepochs = 1 , Took = 17 seconds; Full-batch train mse = 0.066738, val mse = 0.000000.
epoch: numepochs = 2 , Took = 12 seconds; Full-batch train mse = 0.069649, val mse = 0.000000.
epoch: numepochs = 3 , Took = 10 seconds; Full-batch train mse = 0.055260, val mse = 0.000000.
epoch: numepochs = 4 , Took = 10 seconds; Full-batch train mse = 0.016346, val mse = 0.000000.
epoch: numepochs = 5 , Took = 9 seconds; Full-batch train mse = 0.013802, val mse = 0.000000.
epoch: numepochs = 6 , Took = 13 seconds; Full-batch train mse = 0.045142, val mse = 0.000000.
epoch: numepochs = 7 , Took = 7 seconds; Full-batch train mse = 0.031211, val mse = 0.000000.
epoch: numepochs = 8 , Took = 7 seconds; Full-batch train mse = 0.016334, val mse = 0.000000.
epoch: numepochs = 9 , Took = 9 seconds; Full-batch train mse = 0.013348, val mse = 0.000000.
epoch: numepochs = 10 , Took = 7 seconds; Full-batch train mse = 0.017879, val mse = 0.000000.
epoch: numepochs = 11 , Took = 7 seconds; Full-batch train mse = 0.012627, val mse = 0.000000.
epoch: numepochs = 12 , Took = 7 seconds; Full-batch train mse = 0.018080, val mse = 0.000000.
epoch: numepochs = 13 , Took = 7 seconds; Full-batch train mse = 0.016755, val mse = 0.000000.
epoch: numepochs = 14 , Took = 7 seconds; Full-batch train mse = 0.012250, val mse = 0.000000.
epoch: numepochs = 15 , Took = 7 seconds; Full-batch train mse = 0.044833, val mse = 0.000000.
epoch: numepochs = 16 , Took = 7 seconds; Full-batch train mse = 0.024345, val mse = 0.000000.
epoch: numepochs = 17 , Took = 7 seconds; Full-batch train mse = 0.039005, val mse = 0.000000.
epoch: numepochs = 18 , Took = 7 seconds; Full-batch train mse = 0.012298, val mse = 0.000000.
epoch: numepochs = 19 , Took = 7 seconds; Full-batch train mse = 0.012371, val mse = 0.000000.
epoch: numepochs = 20 , Took = 6 seconds; Full-batch train mse = 0.014077, val mse = 0.000000.
epoch: numepochs = 21 , Took = 7 seconds; Full-batch train mse = 0.040328, val mse = 0.000000.
epoch: numepochs = 22 , Took = 6 seconds; Full-batch train mse = 0.036575, val mse = 0.000000.
epoch: numepochs = 23 , Took = 6 seconds; Full-batch train mse = 0.033986, val mse = 0.000000.
epoch: numepochs = 24 , Took = 6 seconds; Full-batch train mse = 0.026421, val mse = 0.000000.
epoch: numepochs = 25 , Took = 6 seconds; Full-batch train mse = 0.036776, val mse = 0.000000.
epoch: numepochs = 26 , Took = 6 seconds; Full-batch train mse = 0.011838, val mse = 0.000000.
epoch: numepochs = 27 , Took = 6 seconds; Full-batch train mse = 0.010749, val mse = 0.000000.
epoch: numepochs = 28 , Took = 6 seconds; Full-batch train mse = 0.012717, val mse = 0.000000.
epoch: numepochs = 29 , Took = 6 seconds; Full-batch train mse = 0.011883, val mse = 0.000000.
epoch: numepochs = 30 , Took = 7 seconds; Full-batch train mse = 0.010562, val mse = 0.000000.
epoch: numepochs = 31 , Took = 6 seconds; Full-batch train mse = 0.010591, val mse = 0.000000.
epoch: numepochs = 32 , Took = 6 seconds; Full-batch train mse = 0.010389, val mse = 0.000000.
epoch: numepochs = 33 , Took = 6 seconds; Full-batch train mse = 0.015908, val mse = 0.000000.
epoch: numepochs = 34 , Took = 6 seconds; Full-batch train mse = 0.012413, val mse = 0.000000.
epoch: numepochs = 35 , Took = 6 seconds; Full-batch train mse = 0.010442, val mse = 0.000000.
epoch: numepochs = 36 , Took = 6 seconds; Full-batch train mse = 0.056686, val mse = 0.000000.
epoch: numepochs = 37 , Took = 6 seconds; Full-batch train mse = 0.054850, val mse = 0.000000.
epoch: numepochs = 38 , Took = 6 seconds; Full-batch train mse = 0.019422, val mse = 0.000000.
epoch: numepochs = 39 , Took = 6 seconds; Full-batch train mse = 0.016443, val mse = 0.000000.
epoch: numepochs = 40 , Took = 6 seconds; Full-batch train mse = 0.010289, val mse = 0.000000.
epoch: numepochs = 41 , Took = 7 seconds; Full-batch train mse = 0.022615, val mse = 0.000000.
epoch: numepochs = 42 , Took = 6 seconds; Full-batch train mse = 0.010723, val mse = 0.000000.
epoch: numepochs = 43 , Took = 6 seconds; Full-batch train mse = 0.010289, val mse = 0.000000.
epoch: numepochs = 44 , Took = 6 seconds; Full-batch train mse = 0.033933, val mse = 0.000000.
epoch: numepochs = 45 , Took = 7 seconds; Full-batch train mse = 0.030156, val mse = 0.000000.
epoch: numepochs = 46 , Took = 7 seconds; Full-batch train mse = 0.022068, val mse = 0.000000.
epoch: numepochs = 47 , Took = 7 seconds; Full-batch train mse = 0.029382, val mse = 0.000000.
epoch: numepochs = 48 , Took = 6 seconds; Full-batch train mse = 0.021275, val mse = 0.000000.
epoch: numepochs = 49 , Took = 6 seconds; Full-batch train mse = 0.039427, val mse = 0.000000.
epoch: numepochs = 50 , Took = 7 seconds; Full-batch train mse = 0.016674, val mse = 0.000000.
Lỗi của mạng neural = 0.016674267332022572
Kết quả dự đoán
Giá trị thực Dự đoán Lỗi
0.604893 0.190976 -0.413917
0.591746 0.357267 -0.234479
0.579818 0.192327 -0.387491
0.398089 0.192644 -0.205445
0.414092 0.195298 -0.218795
0.088474 0.191101 0.102627
0.358346 0.211703 -0.146643
0.296353 0.354909 0.058556
0.219472 0.191566 -0.027907
0.535717 0.360182 -0.175534
0.554781 0.191250 -0.363531
0.405299 0.218263 -0.187036
0.476532 0.344091 -0.132441
0.057596 0.191437 0.133841
0.254152 0.291695 0.037543
0.273122 0.194527 -0.078595
0.021104 0.191318 0.170214
0.240983 0.334303 0.093320
0.630081 0.359500 -0.270581
0.418276 0.195478 -0.222798
0.252640 0.194558 -0.058083
0.166199 0.191265 0.025066
0.007724 0.190945 0.183220
0.089267 0.191971 0.102704
0.482286 0.192444 -0.289842
0.121666 0.192421 0.070755
0.288349 0.309397 0.021048
0.388173 0.190992 -0.197181
0.345884 0.195769 -0.150115
0.199586 0.193489 -0.006097
0.313404 0.198285 -0.115119
0.317757 0.192116 -0.125642
0.487894 0.191207 -0.296687
0.435984 0.360434 -0.075550
0.173600 0.191433 0.017834
0.362936 0.200448 -0.162488
0.462762 0.211199 -0.251563
0.496521 0.191016 -0.305505
0.126186 0.199396 0.073210
0.452762 0.192416 -0.260346
0.283772 0.201612 -0.082160
0.345902 0.360121 0.014219
0.196150 0.194086 -0.002063
0.221358 0.276164 0.054806
0.433565 0.191503 -0.242062
0.095667 0.190873 0.095206
0.298306 0.199597 -0.098709
0.307053 0.343221 0.036168
0.070527 0.191187 0.120661
0.550118 0.202402 -0.347717
0.318943 0.191767 -0.127176
0.085855 0.208486 0.122632
0.202457 0.192181 -0.010276
0.171277 0.191338 0.020061
0.377919 0.203501 -0.174418
0.241910 0.190898 -0.051012
0.405780 0.356181 -0.049600
0.208344 0.191031 -0.017313
0.496759 0.191523 -0.305236
0.234226 0.192021 -0.042204
0.180459 0.204200 0.023742
0.230961 0.191262 -0.039699
0.406449 0.191730 -0.214719
0.116916 0.192802 0.075886
0.056969 0.190836 0.133867
0.471645 0.250661 -0.220984
0.620847 0.358227 -0.262620
0.465590 0.208306 -0.257285
0.505221 0.198679 -0.306542
0.412723 0.359821 -0.052902
0.169609 0.191355 0.021746
0.197224 0.190805 -0.006418
0.433576 0.209208 -0.224369
0.149656 0.190906 0.041250
0.301522 0.192259 -0.109262
0.000000 0.191439 0.191439
0.365560 0.191898 -0.173662
0.396316 0.194516 -0.201801
0.313326 0.191688 -0.121638
0.503471 0.333933 -0.169539
0.422458 0.353997 -0.068461
0.085231 0.191322 0.106092
0.260809 0.190954 -0.069855
0.132464 0.193043 0.060579
0.130557 0.192246 0.061689
0.236254 0.191761 -0.044493
0.501957 0.308155 -0.193802
0.030391 0.190835 0.160444
0.342746 0.191201 -0.151544
0.451497 0.191612 -0.259885
0.531777 0.351540 -0.180237
0.367773 0.331728 -0.036045
0.416005 0.222780 -0.193224
0.365065 0.325628 -0.039437
0.314009 0.324089 0.010080
0.292589 0.192116 -0.100473
0.465862 0.214683 -0.251179
0.228024 0.191587 -0.036437
0.500358 0.191494 -0.308864
0.444844 0.190862 -0.253982
Trọng số của tầng 1
1.374171 1.099796 -2.307752 2.094696 2.245881 0.718695
-1.188581 0.250464 -1.253987 1.535571 0.144009 1.211066
-0.237848 -0.513377 0.535559 -0.986276 2.234245 -0.521692
2.049615 -0.900046 1.340620 2.118526 1.038388 -0.011886
2.401718 0.534206 2.188686 -0.604587 0.061698 -0.484290
-1.234468 0.779040 -0.220676 -2.041414 -0.932451 0.798505
0.843464 1.861270 2.290145 1.229188 2.363957 -2.117557
2.148848 2.253852 -1.879143 -0.230116 2.434251 -2.184430
1.344634 0.394114 -1.458897 2.656729 -0.857682 -1.991447
1.427775 0.637960 -0.378303 1.415869 1.531836 -1.201613
Trọng số của tầng 2
-1.264966 2.045363 0.390870 1.093048 -1.571245 -0.965506 2.170980 -1.025175 -1.523080 1.769571 0.234782
-0.729776 -0.757614 -0.165234 -1.851681 1.371553 1.957373 0.082469 1.019040 0.386794 0.529335 0.433566
-1.494115 0.474415 -0.432909 1.831860 1.745830 -2.095790 -0.121956 0.037855 -0.661687 -1.298892 -2.014673
-2.036968 -0.650295 -0.503150 1.203369 -0.359007 -1.213966 -0.624731 -1.172636 0.149298 0.780509 0.427563
1.944461 -1.684950 -2.214371 1.452957 -0.817315 -2.043758 1.414190 0.800293 0.510399 1.072725 0.772831
1.328010 -0.572318 1.896532 1.913148 0.711631 -1.133708 1.462404 -1.266144 0.207436 1.760537 0.550331
-0.128898 -1.581367 0.317952 -1.409350 0.377473 2.097325 1.379673 -1.322477 0.051766 0.279797 0.264919
0.355307 -1.236625 -1.760872 1.002497 -2.174841 -0.657898 2.301789 1.147911 -0.384481 -1.757162 0.005889
1.088915 -0.490752 -1.531422 -2.068119 -0.696665 1.717284 -1.030902 1.789453 1.521370 -2.031079 -0.236383
-0.990033 -1.005252 -0.256403 0.574722 -0.672770 -0.096898 1.828632 -1.202745 1.462449 1.866887 0.337038
Trọng số của tầng 3
1.836453 -1.101651 1.955119 -1.573246 -1.180083 -1.192389 0.063129 1.457982 1.728884 1.029547 -0.683558
-1.281489 1.985545 -0.213372 -0.820422 -0.826010 -1.397489 -0.187893 1.585297 -0.947519 -1.010036 1.006970
0.733671 -0.819010 -1.982199 -2.078944 -1.181293 -0.120074 -1.887663 -0.896891 1.347102 -0.152197 -0.927397
-0.542751 -1.637916 1.660354 -0.282633 -1.708930 -0.308678 -0.838637 0.459111 0.267541 -1.848201 -0.231304
-0.375857 0.490077 0.686011 -0.957255 0.303810 -1.733808 -0.918754 -2.143514 1.961464 0.513675 1.736428
1.430016 -0.028645 0.943471 -1.256118 -2.286119 0.839281 1.046352 0.225247 -2.065346 -1.398478 0.806485
-0.020037 1.357183 1.816811 1.173229 -1.121605 -0.799247 1.473700 -1.262947 -2.166720 -0.134442 -0.069909
-1.147538 0.793998 -0.603432 1.260982 -1.055654 -1.260039 0.205150 -0.936812 1.848619 -1.748954 0.596231
-1.667731 1.647581 0.863563 -1.794107 0.235767 1.934559 1.866522 1.715231 1.444815 1.458522 -2.097657
0.860551 0.482214 1.517683 1.465255 1.909458 1.114471 0.489185 -1.621701 -0.266487 -0.702859 1.035102
Trọng số của tầng 4
-0.911096 1.393613 -0.933040 1.236417 0.653925 -0.349759 -0.527293 1.007356 -0.775414 0.397822 0.651697
-2.131394 0.687022 -0.140635 -1.460180 -0.504634 0.454909 -0.346967 -1.124234 -1.272399 0.775430 -0.197188
1.268461 1.433232 -1.973464 -2.150023 1.691873 -0.534257 -1.823694 -0.924831 -1.058514 -0.758573 -1.478647
-1.300421 -1.420639 -1.832267 1.853288 -1.534657 0.189377 -2.148761 0.193008 2.260043 -0.743915 -0.794404
-1.599848 -1.652410 1.594752 -1.866233 1.233801 -0.488452 -1.933618 -0.162073 -0.257516 0.814798 1.399598
2.094657 -0.994127 1.644555 -1.800673 -1.639945 -0.040082 -1.598300 -0.520300 1.787558 -1.094556 1.206877
2.125425 -0.322957 -1.390981 -1.767540 1.619649 -0.113302 1.224949 -0.962328 0.594987 -0.458925 -1.448088
0.411246 1.826639 -1.344491 -1.846602 -1.271082 -2.367224 0.079851 1.192851 0.122415 1.581079 -1.800780
-1.759527 1.270323 -0.016649 -1.979594 -0.020393 -1.431414 -1.365137 1.730292 1.982386 0.352328 0.734322
1.394420 1.249955 -1.549642 1.848772 0.736186 0.631093 -0.881346 0.004308 0.433216 0.808797 0.071110
Trọng số của tầng 5
0.035857 0.609456 1.796846 -2.362628 2.224782 -1.078429 -1.055450 0.586141 -1.368938 -1.627240 -1.650355
1.630199 -0.890691 0.242432 0.400952 -1.075857 -0.905124 -2.201466 -0.601656 -1.546189 0.306886 -0.191094
0.394938 -0.765227 0.680900 1.234827 2.197797 0.745647 2.177222 0.577448 0.671330 -1.422241 0.392331
2.016733 0.907297 -1.173787 0.762875 0.578593 -0.251727 2.052501 1.329991 -0.301679 1.643322 1.465803
-0.709371 -0.363654 0.146708 1.222945 2.036536 -0.025436 1.153799 -1.232430 0.776147 0.879949 -0.587393
0.313975 -1.657314 0.703810 -1.408819 2.097157 -0.703637 0.416049 0.826217 -1.479536 0.285206 -1.542169
-1.369208 -0.185644 1.477393 1.566187 2.230336 1.686712 -2.256248 -0.707443 0.316977 -0.738799 -1.153392
-1.225574 -1.012741 -1.375621 1.930937 0.551201 -1.886053 1.834320 1.749923 -0.054702 0.141541 0.534773
0.868562 1.011788 -1.068028 -0.642395 2.095731 -0.329274 0.411534 0.244882 -0.366905 1.094661 -1.297273
0.929739 -0.848332 0.988637 -1.149095 0.488179 1.201938 -1.840580 1.612710 -2.048002 -0.975730 0.521178
Trọng số của tầng 6
-1.154830 -1.600131 -2.387282 -1.082677 -0.013943 0.105339 0.574241 -1.501416 0.805800 1.847965 -0.235089