Mạng Neural trong Spark MLlib: Từ Nguyên lý đến Thực thi

  1. Kiến thức nền tảng về mạng neural

1.1 Đơn vị xử lý thần kinh (Neuron)

Mạng neural (Neural Network) được cấu tạo từ nhiều đơn vị xử lý thần kinh kết nối với nhau. Neuron là thành phần cơ bản nhất của mạng neural, và toàn bộ mạng được xây dựng từ nhiều neuron. Cấu trúc của một neuron như sau:

Trong đó, x₁, x₂, x₃ và 1 là các đầu vào, hw,b(x) là đầu ra.

Hàm f(x) được gọi là hàm kích hoạt (activation function). Hai hàm kích hoạt phổ biến nhất là hàm sigmoid và hàm tanh (hyperbolic tangent).

Hàm sigmoid có công thức:

Hàm tanh có công thức:

1.2 Kiến trúc mạng neural

Mạng neural bao gồm nhiều tầng, với các neuron ở các tầng liền kề có mối quan hệ đầu vào. Tầng đầu tiên được gọi là tầng đầu vào (input layer), tầng cuối cùng là tầng đầu ra (output layer), và các tầng ở giữa được gọi là tầng ẩn (hidden layers).

1.3 Lan truyền tiến và lan truyền ngược

Giả sử mạng neural có n tầng, được đánh số lần lượt là L₁, L₂, ..., Lₙ. Số lượng neuron ở tầng thứ p (với p = 1,2,...,n) là mₚ. aⱽᵏ) biểu thị giá trị đầu ra của neuron thứ j ở tầng k.

Đối với tầng L₁ (tầng đầu vào), ta có:

Đầu ra của neuron thứ j ở tầng (k+1) được tính bằng:

Giả sử lỗi của một mẫu huấn luyện là:

Hàm lỗi tổng thể được định nghĩa như sau:

Để tránh overfitting, ta thêm thành phần L2 regularization:

Mục tiêu là tìm (w,b) để J(w,b) đạt giá trị nhỏ nhất. Ta sử dụng phương pháp gradient descent, cập nhật w và b theo công thức sau ở mỗi lần lặp:

Đối với neuron thứ j ở tầng cuối cùng (tầng đầu ra), sai số (residual) được tính bằng:

Sai số của neuron thứ i ở tầng k được tính bằng:

Quy trình tìm (w,b) như sau:

  1. Với mọi k, khởi tạo w⁽ᵏ⁾:=0, b⁽ᵏ⁾:=0;

  2. Lan truyền tiến: Với mỗi mẫu dữ liệu, tính toán hw,b(x) dựa trên các giá trị đầu vào và w⁽ᵏ⁾, b⁽ᵏ⁾;

  3. Lan truyền ngược: Tính toán sai số cho từng neuron ở mỗi tầng;

  4. Cập nhật giá trị của w và b.

Lặp lại các bước (2)~(4) cho đến khi đạt số lần lặp mong muốn.

  1. Triển khai mạng neural trong MLlib

Lớp mạng neural trong MLlib là NeuralNet. Các tham số chính bao gồm:

  • Size: Mảng[Int] - Số lượng neuron ở mỗi tầng
  • Layer: Số lượng tầng của mạng neural
  • Activation_function: Hàm kích hoạt, có thể là 'sigm' hoặc 'tanh'
  • Output_function: Hàm đầu ra, có thể là 'sigm', 'softmax' hoặc 'linear'

Mã nguồn triển khai:

import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import breeze.linalg.{
  DenseMatrix => BDM,
  max => Bmax,
  min => Bmin
}
import scala.collection.mutable.ArrayBuffer

object NeuralNetworkExample {

  def main(args: Array[String]): Unit = {
    // Cấu hình môi trường chạy
    val conf = new SparkConf()
      .setAppName("Neural Network Example")
      .setMaster("spark://master:7077")
      .setJars(Seq("E:\\Projects\\SparkML\\NeuralNetwork.jar"))
    val sc = new SparkContext(conf)
    Logger.getRootLogger.setLevel(Level.WARN)

    // Tạo dữ liệu mẫu ngẫu nhiên
    val sampleRows = 1000
    val sampleCols = 5
    val randomSamples = RandSampleData.generateMatrix(sampleRows, sampleCols, -10, 10, "sphere")
    
    // Chuẩn hóa dữ liệu
    val maxValues = Bmax(randomSamples(::, breeze.linalg.*))
    val minValues = Bmin(randomSamples(::, breeze.linalg.*))
    val normalized1 = randomSamples - (BDM.ones[Double](randomSamples.rows, 1)) * minValues
    val normalized2 = normalized1 :/ ((BDM.ones[Double](normalized1.rows, 1)) * (maxValues - minValues))
    
    // Chuyển đổi định dạng dữ liệu
    val sampleList = ArrayBuffer[BDM[Double]]()
    for (i <- 0 until sampleRows) {
      val matrixRow = normalized2(i, ::)
      val rowVector = matrixRow.inner
      val rowArray = rowVector.toArray
      val sampleMatrix = new BDM(1, rowArray.length, rowArray)
      sampleList += sampleMatrix
    }
    
    val parallelSamples = sc.parallelize(sampleList, 10)
    sc.setCheckpointDir("hdfs://master:9000/ml/checkpoint")
    parallelSamples.checkpoint()
    
    val trainingData = parallelSamples.map(f => (new BDM(1, 1, f(::, 0).data), f(::, 1 to -1)))
    
    // Huấn luyện mô hình
    val options = Array(100.0, 50.0, 0.0)
    trainingData.cache()
    val totalExamples = trainingData.count()
    println(s"Số lượng mẫu: $totalExamples")
    
    val neuralModel = new NeuralNet()
      .setSize(Array(5, 10, 10, 10, 10, 10, 1))
      .setLayer(7)
      .setActivation_function("tanh_opt")
      .setLearningRate(2.0)
      .setScaling_learningRate(1.0)
      .setWeightPenaltyL2(0.0)
      .setNonSparsityPenalty(0.0)
      .setSparsityTarget(0.05)
      .setInputZeroMaskedFraction(0.0)
      .setDropoutFraction(0.0)
      .setOutput_function("sigm")
      .train(trainingData, options)

    // Kiểm tra mô hình
    val predictions = neuralModel.predict(trainingData)
    val predictionError = neuralModel.Loss(predictions)
    println(s"Lỗi của mạng neural = $predictionError")
    
    val samplePredictions = predictions.map(f => (f.label.data(0), f.predict_label.data(0))).take(100)
    println("Kết quả dự đoán")
    println("Giá trị thực\t\tDự đoán\t\tLỗi")
    for (i <- 0 until samplePredictions.length)
      println(f"${samplePredictions(i)._1}%.6f\t\t${samplePredictions(i)._2}%.6f\t\t${samplePredictions(i)._2 - samplePredictions(i)._1}%.6f")

    // Hiển thị trọng số của các tầng
    var layerWeights = neuralModel.weights(0)
    for (i <- 0 to 5) {
      layerWeights = neuralModel.weights(i)
      println(s"Trọng số của tầng ${i+1}")
      for (j <- 0 to layerWeights.rows - 1) {
        for (k <- 0 to layerWeights.cols - 1) {
          print(f"${layerWeights(j, k)}%.6f\t")
        }
        println()
      }
    }
  }
}

Mã trên xây dựng một mạng neural 7 tầng với số lượng neuron ở các tầng tương ứng là Array(5, 10, 10, 10, 10, 10, 1), và kiểm tra hàm Sphere.

Kết quả thực hiện:

Số lượng mẫu: 1000
epoch: numepochs = 1 , Took = 17 seconds; Full-batch train mse = 0.066738, val mse = 0.000000.
epoch: numepochs = 2 , Took = 12 seconds; Full-batch train mse = 0.069649, val mse = 0.000000.
epoch: numepochs = 3 , Took = 10 seconds; Full-batch train mse = 0.055260, val mse = 0.000000.
epoch: numepochs = 4 , Took = 10 seconds; Full-batch train mse = 0.016346, val mse = 0.000000.
epoch: numepochs = 5 , Took = 9 seconds; Full-batch train mse = 0.013802, val mse = 0.000000.
epoch: numepochs = 6 , Took = 13 seconds; Full-batch train mse = 0.045142, val mse = 0.000000.
epoch: numepochs = 7 , Took = 7 seconds; Full-batch train mse = 0.031211, val mse = 0.000000.
epoch: numepochs = 8 , Took = 7 seconds; Full-batch train mse = 0.016334, val mse = 0.000000.
epoch: numepochs = 9 , Took = 9 seconds; Full-batch train mse = 0.013348, val mse = 0.000000.
epoch: numepochs = 10 , Took = 7 seconds; Full-batch train mse = 0.017879, val mse = 0.000000.
epoch: numepochs = 11 , Took = 7 seconds; Full-batch train mse = 0.012627, val mse = 0.000000.
epoch: numepochs = 12 , Took = 7 seconds; Full-batch train mse = 0.018080, val mse = 0.000000.
epoch: numepochs = 13 , Took = 7 seconds; Full-batch train mse = 0.016755, val mse = 0.000000.
epoch: numepochs = 14 , Took = 7 seconds; Full-batch train mse = 0.012250, val mse = 0.000000.
epoch: numepochs = 15 , Took = 7 seconds; Full-batch train mse = 0.044833, val mse = 0.000000.
epoch: numepochs = 16 , Took = 7 seconds; Full-batch train mse = 0.024345, val mse = 0.000000.
epoch: numepochs = 17 , Took = 7 seconds; Full-batch train mse = 0.039005, val mse = 0.000000.
epoch: numepochs = 18 , Took = 7 seconds; Full-batch train mse = 0.012298, val mse = 0.000000.
epoch: numepochs = 19 , Took = 7 seconds; Full-batch train mse = 0.012371, val mse = 0.000000.
epoch: numepochs = 20 , Took = 6 seconds; Full-batch train mse = 0.014077, val mse = 0.000000.
epoch: numepochs = 21 , Took = 7 seconds; Full-batch train mse = 0.040328, val mse = 0.000000.
epoch: numepochs = 22 , Took = 6 seconds; Full-batch train mse = 0.036575, val mse = 0.000000.
epoch: numepochs = 23 , Took = 6 seconds; Full-batch train mse = 0.033986, val mse = 0.000000.
epoch: numepochs = 24 , Took = 6 seconds; Full-batch train mse = 0.026421, val mse = 0.000000.
epoch: numepochs = 25 , Took = 6 seconds; Full-batch train mse = 0.036776, val mse = 0.000000.
epoch: numepochs = 26 , Took = 6 seconds; Full-batch train mse = 0.011838, val mse = 0.000000.
epoch: numepochs = 27 , Took = 6 seconds; Full-batch train mse = 0.010749, val mse = 0.000000.
epoch: numepochs = 28 , Took = 6 seconds; Full-batch train mse = 0.012717, val mse = 0.000000.
epoch: numepochs = 29 , Took = 6 seconds; Full-batch train mse = 0.011883, val mse = 0.000000.
epoch: numepochs = 30 , Took = 7 seconds; Full-batch train mse = 0.010562, val mse = 0.000000.
epoch: numepochs = 31 , Took = 6 seconds; Full-batch train mse = 0.010591, val mse = 0.000000.
epoch: numepochs = 32 , Took = 6 seconds; Full-batch train mse = 0.010389, val mse = 0.000000.
epoch: numepochs = 33 , Took = 6 seconds; Full-batch train mse = 0.015908, val mse = 0.000000.
epoch: numepochs = 34 , Took = 6 seconds; Full-batch train mse = 0.012413, val mse = 0.000000.
epoch: numepochs = 35 , Took = 6 seconds; Full-batch train mse = 0.010442, val mse = 0.000000.
epoch: numepochs = 36 , Took = 6 seconds; Full-batch train mse = 0.056686, val mse = 0.000000.
epoch: numepochs = 37 , Took = 6 seconds; Full-batch train mse = 0.054850, val mse = 0.000000.
epoch: numepochs = 38 , Took = 6 seconds; Full-batch train mse = 0.019422, val mse = 0.000000.
epoch: numepochs = 39 , Took = 6 seconds; Full-batch train mse = 0.016443, val mse = 0.000000.
epoch: numepochs = 40 , Took = 6 seconds; Full-batch train mse = 0.010289, val mse = 0.000000.
epoch: numepochs = 41 , Took = 7 seconds; Full-batch train mse = 0.022615, val mse = 0.000000.
epoch: numepochs = 42 , Took = 6 seconds; Full-batch train mse = 0.010723, val mse = 0.000000.
epoch: numepochs = 43 , Took = 6 seconds; Full-batch train mse = 0.010289, val mse = 0.000000.
epoch: numepochs = 44 , Took = 6 seconds; Full-batch train mse = 0.033933, val mse = 0.000000.
epoch: numepochs = 45 , Took = 7 seconds; Full-batch train mse = 0.030156, val mse = 0.000000.
epoch: numepochs = 46 , Took = 7 seconds; Full-batch train mse = 0.022068, val mse = 0.000000.
epoch: numepochs = 47 , Took = 7 seconds; Full-batch train mse = 0.029382, val mse = 0.000000.
epoch: numepochs = 48 , Took = 6 seconds; Full-batch train mse = 0.021275, val mse = 0.000000.
epoch: numepochs = 49 , Took = 6 seconds; Full-batch train mse = 0.039427, val mse = 0.000000.
epoch: numepochs = 50 , Took = 7 seconds; Full-batch train mse = 0.016674, val mse = 0.000000.
Lỗi của mạng neural = 0.016674267332022572
Kết quả dự đoán
Giá trị thực            Dự đoán                Lỗi
0.604893                0.190976               -0.413917
0.591746                0.357267               -0.234479
0.579818                0.192327               -0.387491
0.398089                0.192644               -0.205445
0.414092                0.195298               -0.218795
0.088474                0.191101               0.102627
0.358346                0.211703               -0.146643
0.296353                0.354909               0.058556
0.219472                0.191566               -0.027907
0.535717                0.360182               -0.175534
0.554781                0.191250               -0.363531
0.405299                0.218263               -0.187036
0.476532                0.344091               -0.132441
0.057596                0.191437               0.133841
0.254152                0.291695               0.037543
0.273122                0.194527               -0.078595
0.021104                0.191318               0.170214
0.240983                0.334303               0.093320
0.630081                0.359500               -0.270581
0.418276                0.195478               -0.222798
0.252640                0.194558               -0.058083
0.166199                0.191265               0.025066
0.007724                0.190945               0.183220
0.089267                0.191971               0.102704
0.482286                0.192444               -0.289842
0.121666                0.192421               0.070755
0.288349                0.309397               0.021048
0.388173                0.190992               -0.197181
0.345884                0.195769               -0.150115
0.199586                0.193489               -0.006097
0.313404                0.198285               -0.115119
0.317757                0.192116               -0.125642
0.487894                0.191207               -0.296687
0.435984                0.360434               -0.075550
0.173600                0.191433               0.017834
0.362936                0.200448               -0.162488
0.462762                0.211199               -0.251563
0.496521                0.191016               -0.305505
0.126186                0.199396               0.073210
0.452762                0.192416               -0.260346
0.283772                0.201612               -0.082160
0.345902                0.360121               0.014219
0.196150                0.194086               -0.002063
0.221358                0.276164               0.054806
0.433565                0.191503               -0.242062
0.095667                0.190873               0.095206
0.298306                0.199597               -0.098709
0.307053                0.343221               0.036168
0.070527                0.191187               0.120661
0.550118                0.202402               -0.347717
0.318943                0.191767               -0.127176
0.085855                0.208486               0.122632
0.202457                0.192181               -0.010276
0.171277                0.191338               0.020061
0.377919                0.203501               -0.174418
0.241910                0.190898               -0.051012
0.405780                0.356181               -0.049600
0.208344                0.191031               -0.017313
0.496759                0.191523               -0.305236
0.234226                0.192021               -0.042204
0.180459                0.204200               0.023742
0.230961                0.191262               -0.039699
0.406449                0.191730               -0.214719
0.116916                0.192802               0.075886
0.056969                0.190836               0.133867
0.471645                0.250661               -0.220984
0.620847                0.358227               -0.262620
0.465590                0.208306               -0.257285
0.505221                0.198679               -0.306542
0.412723                0.359821               -0.052902
0.169609                0.191355               0.021746
0.197224                0.190805               -0.006418
0.433576                0.209208               -0.224369
0.149656                0.190906               0.041250
0.301522                0.192259               -0.109262
0.000000                0.191439               0.191439
0.365560                0.191898               -0.173662
0.396316                0.194516               -0.201801
0.313326                0.191688               -0.121638
0.503471                0.333933               -0.169539
0.422458                0.353997               -0.068461
0.085231                0.191322               0.106092
0.260809                0.190954               -0.069855
0.132464                0.193043               0.060579
0.130557                0.192246               0.061689
0.236254                0.191761               -0.044493
0.501957                0.308155               -0.193802
0.030391                0.190835               0.160444
0.342746                0.191201               -0.151544
0.451497                0.191612               -0.259885
0.531777                0.351540               -0.180237
0.367773                0.331728               -0.036045
0.416005                0.222780               -0.193224
0.365065                0.325628               -0.039437
0.314009                0.324089               0.010080
0.292589                0.192116               -0.100473
0.465862                0.214683               -0.251179
0.228024                0.191587               -0.036437
0.500358                0.191494               -0.308864
0.444844                0.190862               -0.253982

Trọng số của tầng 1
1.374171    1.099796    -2.307752    2.094696    2.245881    0.718695
-1.188581    0.250464    -1.253987    1.535571    0.144009    1.211066
-0.237848    -0.513377    0.535559    -0.986276    2.234245    -0.521692
2.049615    -0.900046    1.340620    2.118526    1.038388    -0.011886
2.401718    0.534206    2.188686    -0.604587    0.061698    -0.484290
-1.234468    0.779040    -0.220676    -2.041414    -0.932451    0.798505
0.843464    1.861270    2.290145    1.229188    2.363957    -2.117557
2.148848    2.253852    -1.879143    -0.230116    2.434251    -2.184430
1.344634    0.394114    -1.458897    2.656729    -0.857682    -1.991447
1.427775    0.637960    -0.378303    1.415869    1.531836    -1.201613

Trọng số của tầng 2
-1.264966    2.045363    0.390870    1.093048    -1.571245    -0.965506    2.170980    -1.025175    -1.523080    1.769571    0.234782
-0.729776    -0.757614    -0.165234    -1.851681    1.371553    1.957373    0.082469    1.019040    0.386794    0.529335    0.433566
-1.494115    0.474415    -0.432909    1.831860    1.745830    -2.095790    -0.121956    0.037855    -0.661687    -1.298892    -2.014673
-2.036968    -0.650295    -0.503150    1.203369    -0.359007    -1.213966    -0.624731    -1.172636    0.149298    0.780509    0.427563
1.944461    -1.684950    -2.214371    1.452957    -0.817315    -2.043758    1.414190    0.800293    0.510399    1.072725    0.772831
1.328010    -0.572318    1.896532    1.913148    0.711631    -1.133708    1.462404    -1.266144    0.207436    1.760537    0.550331
-0.128898    -1.581367    0.317952    -1.409350    0.377473    2.097325    1.379673    -1.322477    0.051766    0.279797    0.264919
0.355307    -1.236625    -1.760872    1.002497    -2.174841    -0.657898    2.301789    1.147911    -0.384481    -1.757162    0.005889
1.088915    -0.490752    -1.531422    -2.068119    -0.696665    1.717284    -1.030902    1.789453    1.521370    -2.031079    -0.236383
-0.990033    -1.005252    -0.256403    0.574722    -0.672770    -0.096898    1.828632    -1.202745    1.462449    1.866887    0.337038

Trọng số của tầng 3
1.836453    -1.101651    1.955119    -1.573246    -1.180083    -1.192389    0.063129    1.457982    1.728884    1.029547    -0.683558
-1.281489    1.985545    -0.213372    -0.820422    -0.826010    -1.397489    -0.187893    1.585297    -0.947519    -1.010036    1.006970
0.733671    -0.819010    -1.982199    -2.078944    -1.181293    -0.120074    -1.887663    -0.896891    1.347102    -0.152197    -0.927397
-0.542751    -1.637916    1.660354    -0.282633    -1.708930    -0.308678    -0.838637    0.459111    0.267541    -1.848201    -0.231304
-0.375857    0.490077    0.686011    -0.957255    0.303810    -1.733808    -0.918754    -2.143514    1.961464    0.513675    1.736428
1.430016    -0.028645    0.943471    -1.256118    -2.286119    0.839281    1.046352    0.225247    -2.065346    -1.398478    0.806485
-0.020037    1.357183    1.816811    1.173229    -1.121605    -0.799247    1.473700    -1.262947    -2.166720    -0.134442    -0.069909
-1.147538    0.793998    -0.603432    1.260982    -1.055654    -1.260039    0.205150    -0.936812    1.848619    -1.748954    0.596231
-1.667731    1.647581    0.863563    -1.794107    0.235767    1.934559    1.866522    1.715231    1.444815    1.458522    -2.097657
0.860551    0.482214    1.517683    1.465255    1.909458    1.114471    0.489185    -1.621701    -0.266487    -0.702859    1.035102

Trọng số của tầng 4
-0.911096    1.393613    -0.933040    1.236417    0.653925    -0.349759    -0.527293    1.007356    -0.775414    0.397822    0.651697
-2.131394    0.687022    -0.140635    -1.460180    -0.504634    0.454909    -0.346967    -1.124234    -1.272399    0.775430    -0.197188
1.268461    1.433232    -1.973464    -2.150023    1.691873    -0.534257    -1.823694    -0.924831    -1.058514    -0.758573    -1.478647
-1.300421    -1.420639    -1.832267    1.853288    -1.534657    0.189377    -2.148761    0.193008    2.260043    -0.743915    -0.794404
-1.599848    -1.652410    1.594752    -1.866233    1.233801    -0.488452    -1.933618    -0.162073    -0.257516    0.814798    1.399598
2.094657    -0.994127    1.644555    -1.800673    -1.639945    -0.040082    -1.598300    -0.520300    1.787558    -1.094556    1.206877
2.125425    -0.322957    -1.390981    -1.767540    1.619649    -0.113302    1.224949    -0.962328    0.594987    -0.458925    -1.448088
0.411246    1.826639    -1.344491    -1.846602    -1.271082    -2.367224    0.079851    1.192851    0.122415    1.581079    -1.800780
-1.759527    1.270323    -0.016649    -1.979594    -0.020393    -1.431414    -1.365137    1.730292    1.982386    0.352328    0.734322
1.394420    1.249955    -1.549642    1.848772    0.736186    0.631093    -0.881346    0.004308    0.433216    0.808797    0.071110

Trọng số của tầng 5
0.035857    0.609456    1.796846    -2.362628    2.224782    -1.078429    -1.055450    0.586141    -1.368938    -1.627240    -1.650355
1.630199    -0.890691    0.242432    0.400952    -1.075857    -0.905124    -2.201466    -0.601656    -1.546189    0.306886    -0.191094
0.394938    -0.765227    0.680900    1.234827    2.197797    0.745647    2.177222    0.577448    0.671330    -1.422241    0.392331
2.016733    0.907297    -1.173787    0.762875    0.578593    -0.251727    2.052501    1.329991    -0.301679    1.643322    1.465803
-0.709371    -0.363654    0.146708    1.222945    2.036536    -0.025436    1.153799    -1.232430    0.776147    0.879949    -0.587393
0.313975    -1.657314    0.703810    -1.408819    2.097157    -0.703637    0.416049    0.826217    -1.479536    0.285206    -1.542169
-1.369208    -0.185644    1.477393    1.566187    2.230336    1.686712    -2.256248    -0.707443    0.316977    -0.738799    -1.153392
-1.225574    -1.012741    -1.375621    1.930937    0.551201    -1.886053    1.834320    1.749923    -0.054702    0.141541    0.534773
0.868562    1.011788    -1.068028    -0.642395    2.095731    -0.329274    0.411534    0.244882    -0.366905    1.094661    -1.297273
0.929739    -0.848332    0.988637    -1.149095    0.488179    1.201938    -1.840580    1.612710    -2.048002    -0.975730    0.521178

Trọng số của tầng 6
-1.154830    -1.600131    -2.387282    -1.082677    -0.013943    0.105339    0.574241    -1.501416    0.805800    1.847965    -0.235089

Thẻ: Spark MLlib Neural Network Machine Learning Big Data

Đăng vào ngày 24 tháng 5 lúc 22:33