用 Apache Spark 來訓練 Word2vec 詞向量 Skip-gram Word Embeddings

作者: Yong-Siang Shih / Sun 30 August 2015 / 分類: Notes

Scala, Spark, word embeddings, word2vec

先前曾提過 word2vec 可以把語料中的詞轉換成詞向量。 雖然原本提供的 word2vec 工具速度已經很快, 但是如果要訓練更大規模的語料還是需要不少時間。 例如之前在處理 ClueWeb09 時,以實驗室的機器來說,就算只處理中文部份, 也要一個月以上才能跑完。 此時除了購買更強大的機器外,如果已經有不少機器, 或許可以利用平行運算的方式來加速。

最近剛好接觸到了 Apache Spark,他是一個開源的運算平台, 可以運用多台電腦進行平行運算。 且因為把很多資料直接放在記憶體中處理,又比 Apache Hadoop 單純的 MapReduce 更快一些。更重要的是,他的機器學習函式庫 MLlib 已經實作了 word2vec 當中的 skip-gram 模型,正好可以直接拿來訓練詞向量。

安裝

關於如何將 Spark 安裝在一個 cluster 上,可以參考官方文件。 這裡我們只簡單的安裝單機版的 Spark 方便快速的實驗。 我們將會使用 Ubuntu 14.04 作為實驗平台。

安裝 Java 8

首先安裝 Java 8,如果你已經有裝了則可以跳過。

sudo add-apt-repository ppa:webupd8team/java
sudo apt-get update
sudo apt-get install oracle-java8-installer
sudo apt-get install oracle-java8-set-default

安裝 Apache Spark

接下來,到下載頁面下載 Spark, 我是選擇 Spark 1.4.1 Pre-built for Hadoop 2.6 and later 的 binary。 不過如果有新版應該變化也不大。直接解壓縮就可以用囉:

tar -xzf spark-1.4.1-bin-hadoop2.6.tgz

安裝 sbt

因為我打算用 Scala 所以我們還得安裝 sbt 這個 build tool。 於是依照 sbt 的下載頁面的說明安裝 sbt:

echo "deb https://dl.bintray.com/sbt/debian /" | sudo tee -a /etc/apt/sources.list.d/sbt.list
sudo apt-key adv --keyserver hkp://keyserver.ubuntu.com:80 --recv 642AC823
sudo apt-get update
sudo apt-get install sbt

下載資料集

為了簡單起見,使用 100MB 的 Wikipedia dump:

wget http://mattmahoney.net/dc/text8.zip -O text8.gz
gzip -d text8.gz -f

由於是單機,所以可以放在任意資料夾,如果是在 cluster 的話,需要放在每台機器都能存取的同一位置。

編寫 Spark 程式

首先建立專案資料夾:

mkdir -p sparkw2v/src/main/scala/

然後編輯 sparkw2v/sparkw2v.sbt 檔案,程式版本參考官方文件:

name := "Spark Word2Vec"

version := "1.0"

scalaVersion := "2.10.4"

libraryDependencies ++= Seq(
    "org.apache.spark" %% "spark-core" % "1.4.1" % "provided",
    "org.apache.spark" %% "spark-mllib" % "1.4.1"
)

最後則是程式碼 sparkw2v/src/main/scala/SparkW2V.scala 本身,注意要設定輸出入 {input directory}{output directory}的路徑。 同樣的如果是在 cluster 的話,需要放在每台機器都能存取的同一位置。 Word2Vec 參數設定則可參考 API 文件

import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf

import org.apache.spark.mllib.feature.Word2Vec

object SparkW2V {
  def main(args: Array[String]) {
    val text8 = "{input directory}/text8"
    val output = "{output directory}/model"
    val conf = new SparkConf().setAppName("Spark Word2Vec")
    val sc = new SparkContext(conf)

    val input = sc.textFile(text8).map(line => line.split(" ").toSeq)
    val word2vec = new Word2Vec()

    val model = word2vec.fit(input)
    model.save(sc, output)
  }
}

Build

使用 sbt 進行 package:

cd sparkw2v
sbt package

根據版本不同,產生的檔案名稱也不同,我的輸出是 sparkw2v/target/scala-2.10/spark-word2vec_2.10-1.0.jar

執行

接下來就可以回到根目錄,利用 spark-submit 執行程式。筆者不確定記憶體的需求為何,所以開的大了些:

cd ..
spark-1.4.1-bin-hadoop2.6/bin/spark-submit --class SparkW2V --master local[*] --executor-memory 20G --driver-memory 10G sparkw2v/target/scala-2.10/spark-word2vec_2.10-1.0.jar

輸出的向量會放在 {output directory}/model/data/ 底下,而且是存成 Parquet 的格式,不太方便。 所以我們使用 spark-shell 快速的將檔案轉成文字檔:

# 執行 Spark Shell
./spark-1.4.1-bin-hadoop2.6/bin/spark-shell

# 從這裡開始是 Spark Shell
#
# Welcome to
#       ____              __
#      / __/__  ___ _____/ /__
#     _\ \/ _ \/ _ `/ __/  '_/
#    /___/ .__/\_,_/_/ /_/\_\   version 1.4.1
#       /_/
#

# 讀取檔案
scala> val d = sqlContext.read.parquet("{output directory}/model/data")
d: org.apache.spark.sql.DataFrame = [word: string, vector: array<float>]

# 檢查格式
scala> d.first
res2: org.apache.spark.sql.Row = [latifolia,ArrayBuffer(-0.08103186, 0.14688604, -0.060668133, -0.25648367, -0.06855837, -0...

# 輸出到 output directory/vectors/
scala> d.map{r => r.getString(0) + " " + r.getSeq(1).mkString(" ")}.saveAsTextFile("{output directory}/vectors")

如此便完成了!

原始碼

相關程式碼放在 shaform/experiments/spark_word2vec

Dots
Yong-Siang Shih

作者

Yong-Siang Shih

軟體工程師,機器學習科學家,開放原始碼愛好者。曾在 Appier 從事機器學習系統開發,也曾在 Google, IBM, Microsoft 擔任軟體實習生。喜好探索學習新科技。* 在 GitHub 上追蹤我

載入 Disqus 評論