这是我自己写的精简版 nebula-spark-writer 核心代码,避开了上面提到的版本问题,可以参考一下,希望对你有用。这套代码已经在生产环境正常使用,达到了预期效果,并完成了几十 TB 的数据迁移。
package com.qihoo.finance.graph.data.nebula
import com.vesoft.nebula.client.graph.NebulaPoolConfig
import com.vesoft.nebula.client.graph.data.{HostAddress, ResultSet}
import com.vesoft.nebula.client.graph.net.{NebulaPool, Session}
import com.vesoft.nebula.connector.connector.Address
import com.vesoft.nebula.connector.exception.GraphConnectException
import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
/**
 * GraphProvider for Nebula Graph Service.
 *
 * Builds a [[NebulaPool]] limited to a single connection over the given graph
 * addresses at construction time, and executes statements through one
 * [[Session]] that is opened lazily by [[switchSpace]].
 *
 * @param addresses (host, port) pairs of the graph services to connect to
 */
class GraphProvider(addresses: List[Address]) extends AutoCloseable with Serializable {
  // private[this] lazy val LOG = Logger.getLogger(this.getClass)
  @transient val nebulaPoolConfig = new NebulaPoolConfig
  @transient val pool: NebulaPool = new NebulaPool
  @transient val address = new ListBuffer[HostAddress]()
  for (addr <- addresses) {
    address.append(new HostAddress(addr._1, addr._2))
  }
  // One connection per provider.
  nebulaPoolConfig.setMaxConnSize(1)
  pool.init(address.asJava, nebulaPoolConfig)

  /** Session used by [[submit]]; stays null until [[switchSpace]] has been called. */
  var session: Session = null

  /**
   * Releases the session (if one was opened) and then closes the pool.
   * The pool is closed even when releasing the session fails.
   */
  override def close(): Unit = {
    try {
      if (session != null) {
        session.release()
        session = null
      }
    } finally {
      pool.close()
    }
  }

  /**
   * Opens the session on first use and switches it to the given space.
   *
   * @param user     user name passed to the pool when the session is created
   * @param password password passed to the pool when the session is created
   * @param space    name of the space to `use`
   * @return true when the `use` statement succeeded
   */
  def switchSpace(user: String, password: String, space: String): Boolean = {
    if (session == null) {
      session = pool.getSession(user, password, true)
    }
    val switchStatement = s"use $space"
    submit(switchStatement).isSucceeded
  }

  /**
   * Executes a statement on the current session.
   *
   * @param statement the statement text to execute
   * @return the result set returned by the session
   * @throws GraphConnectException when no session has been opened yet
   */
  def submit(statement: String): ResultSet = {
    if (session == null) {
      throw new GraphConnectException("session is null")
    }
    session.execute(statement)
  }
}
package com.qihoo.finance.graph.data.nebula
import java.util.regex.Pattern
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
/**
 * Base class of the Nebula writers: holds the statement templates and renders
 * the fields of a Spark [[Row]] as literal text for an INSERT statement.
 *
 * Original author: huangzhaolai-jk, 2021/8/2, version 1.0.0
 */
abstract class NebulaBaseWriter extends Serializable {
  // INSERT <vertex|edge> <tag-or-edge-name>(<property names>) VALUES <values>
  val BATCH_INSERT_TEMPLATE = "INSERT %s %s(%s) VALUES %s"
  // <vertex id>: (<property values>)
  val INSERT_VALUE_TEMPLATE = "%s: (%s)"
  // <policy>(<vertex id>): (<property values>)
  val INSERT_VALUE_TEMPLATE_WITH_POLICY = "%s(%s): (%s)"
  // <src>-><dst>: (<property values>)
  val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE = "%s->%s: (%s)"
  // <policy>("<id>")
  val ENDPOINT_TEMPLATE = "%s(\"%s\")"
  val EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY = "%s->%s: (%s)"
  // <src>-><dst>@<rank>: (<property values>)
  val EDGE_VALUE_TEMPLATE = "%s->%s@%d: (%s)"
  val EDGE_VALUE_TEMPLATE_WITH_POLICY = "%s->%s@%d: (%s)"

  /** Writes the given DataFrame; implemented by the concrete writers. */
  def write(rdd: DataFrame): Unit

  /**
   * Extracts a value from the row by field name and renders it as literal text.
   * When the field is null it is replaced by a type-specific default
   * (empty string, 0, 0.0, false, "[]" or "{}").
   *
   * String values are wrapped in double quotes with backslashes and quotes
   * escaped; the field named "contact_name" instead has its special characters
   * removed. Date values are wrapped in double quotes, matching the quoted
   * empty string used for a null date.
   * NOTE(review): a property declared with a native date type may need a
   * different literal form than a quoted string - confirm against the schema.
   *
   * @param row   The row value.
   * @param field The field name.
   * @return the rendered text; a String for every type except a non-null
   *         timestamp, which is returned as a Long holding epoch seconds
   * @throws IllegalArgumentException when the field has a data type that is not handled here
   */
  def extraValue(row: Row, field: String): Any = {
    val index = row.schema.fieldIndex(field)
    val isNull = row.isNullAt(index)
    row.schema.fields(index).dataType match {
      case StringType =>
        if (isNull) {
          "\"\""
        } else {
          val str = row.getString(index)
          if ("contact_name".equals(field)) {
            quote(removeSpecialChar(str))
          } else if (str != null && !str.equals("")) {
            quote(escape(str))
          } else {
            "\"\""
          }
        }
      case ByteType =>
        if (isNull) "0" else row.getByte(index).toString
      case ShortType =>
        if (isNull) "0" else row.getShort(index).toString
      case IntegerType =>
        if (isNull) "0" else row.getInt(index).toString
      case LongType =>
        if (isNull) "0" else row.getLong(index).toString
      case FloatType =>
        if (isNull) "0.0" else row.getFloat(index).toString
      case DoubleType =>
        if (isNull) "0.0" else row.getDouble(index).toString
      case _: DecimalType =>
        if (isNull) "0.0" else row.getDecimal(index).toString
      case BooleanType =>
        if (isNull) "false" else row.getBoolean(index).toString
      case TimestampType =>
        // Milliseconds since the epoch are truncated to whole seconds.
        if (isNull) "0" else row.getTimestamp(index).getTime / 1000L
      case _: DateType =>
        // Quoted so that yyyy-MM-dd is not emitted as bare text inside the statement.
        if (isNull) "\"\"" else quote(row.getDate(index).toString)
      case _: ArrayType =>
        if (isNull) "\"[]\"" else row.getSeq(index).mkString("\"[", ",", "]\"")
      case _: MapType =>
        if (isNull) "\"{}\"" else row.getMap(index).mkString("\"{", ",", "}\"")
      case other =>
        throw new IllegalArgumentException(s"unsupported data type $other for field $field")
    }
  }

  /** Wraps the text in double quotes. */
  private def quote(text: String): String = "\"" + text + "\""

  /** Escapes backslashes first, then double and single quotes. */
  private def escape(text: String): String =
    text.replace("\\", "\\\\")
      .replace("\"", "\\\"")
      .replace("'", "\\'")

  /** Characters stripped by [[removeSpecialChar]]: newline, tab, quotes, parentheses, angle brackets, slash and backslash. */
  val SPECIAL_CHAR_PATTERN: Pattern = Pattern.compile("[\n\t\"\'()<>/\\\\]")

  /** Removes every character matched by [[SPECIAL_CHAR_PATTERN]] from the value. */
  def removeSpecialChar(value: String): String = SPECIAL_CHAR_PATTERN.matcher(value).replaceAll("")
}
package com.qihoo.finance.graph.data.nebula
import com.qihoo.finance.graph.data.config.{EdgeConfig, NebulaConfig}
import com.vesoft.nebula.connector.NebulaOptions
import org.apache.commons.lang.StringUtils
import org.apache.spark.sql.{DataFrame, Encoders}
import scala.collection.mutable.ListBuffer
/**
 * Writes the rows of a DataFrame as edges of type `edgeConfig.getEdge`,
 * batching them into INSERT statements submitted through a [[NebulaWriter]].
 *
 * Original author: huangzhaolai-jk, 2021/8/2, version 1.0.0
 */
class NebulaEdgeWriter(nebulaOptions: NebulaOptions, edgeConfig: EdgeConfig, nebulaConfig: NebulaConfig) extends NebulaBaseWriter {
  // private val properties: Map[String, Integer] = metaProvider.getEdgeSchema(nebulaConfig.getSpace, edgeConfig.getEdge)
  // Property names to write; filled from the DataFrame columns on the first write.
  private var properties: Map[String, Integer] = Map()

  /**
   * Renders one edge endpoint: the id itself when no key policy is configured,
   * otherwise the id wrapped by ENDPOINT_TEMPLATE with the policy name.
   */
  private def endpoint(policy: Option[KeyPolicy.Value], id: String): String =
    policy match {
      case None => id
      case Some(KeyPolicy.HASH) => ENDPOINT_TEMPLATE.format(KeyPolicy.HASH.toString, id)
      case Some(KeyPolicy.UUID) => ENDPOINT_TEMPLATE.format(KeyPolicy.UUID.toString, id)
      case Some(_) => throw new IllegalArgumentException()
    }

  /**
   * Renders the value clause of a single edge.
   *
   * @param source one comma-separated piece of the rendered source id; only used when no key policy is set
   * @param edge   (source id, destination id, rank, property values) as produced by `write`
   * @param ranked whether the rank is included in the clause
   */
  private def edgeValue(source: String,
                        edge: (String, String, java.lang.Long, String),
                        srcKeyPolicy: Option[KeyPolicy.Value],
                        dstKeyPolicy: Option[KeyPolicy.Value],
                        ranked: Boolean): String = {
    if (srcKeyPolicy.isEmpty && dstKeyPolicy.isEmpty) {
      if (ranked) EDGE_VALUE_TEMPLATE.format(source, edge._2, edge._3, edge._4)
      else EDGE_VALUE_WITHOUT_RANKING_TEMPLATE.format(source, edge._2, edge._4)
    } else {
      // Each side is wrapped only when its own policy is set, so a policy on a
      // single side no longer fails in the ranked case.
      val src = endpoint(srcKeyPolicy, edge._1)
      val dst = endpoint(dstKeyPolicy, edge._2)
      if (ranked) EDGE_VALUE_TEMPLATE_WITH_POLICY.format(src, dst, edge._3, edge._4)
      else EDGE_VALUE_WITHOUT_RANKING_TEMPLATE_WITH_POLICY.format(src, dst, edge._4)
    }
  }

  /**
   * Repartitions the DataFrame, renders every row and inserts the edges in
   * batches of `nebulaConfig.getBatchWriteSize`.
   *
   * The rank is written only when `edgeConfig.getRank` is neither null nor 0.
   *
   * @param rdd the rows to write; all of its columns are written as edge properties
   */
  override def write(rdd: DataFrame): Unit = {
    val srcKeyPolicy: Option[KeyPolicy.Value] =
      if (StringUtils.isBlank(edgeConfig.getSrcKeyPolicy)) Option.empty
      else Option(KeyPolicy.withName(edgeConfig.getSrcKeyPolicy))
    val dstKeyPolicy: Option[KeyPolicy.Value] =
      if (StringUtils.isBlank(edgeConfig.getDstKeyPolicy)) Option.empty
      else Option(KeyPolicy.withName(edgeConfig.getDstKeyPolicy))
    if (properties.isEmpty) {
      for (elem <- rdd.columns) {
        properties += (elem -> 1)
      }
    }
    rdd.repartition(nebulaConfig.getPartitions).map { row =>
      // toList keeps one value per property even when two rendered values are equal.
      val valueStr = properties.keySet.toList.map(property => extraValue(row, property)).mkString(",")
      (String.valueOf(extraValue(row, edgeConfig.getSrcIdField)), String.valueOf(extraValue(row, edgeConfig.getDstIdField)), edgeConfig.getRank, valueStr)
    }(Encoders.tuple(Encoders.STRING, Encoders.STRING, Encoders.LONG, Encoders.STRING))
      .foreachPartition { iterator: Iterator[(String, String, java.lang.Long, String)] =>
        val writer = new NebulaWriter(nebulaOptions = nebulaOptions)
        writer.prepareSpace()
        iterator.grouped(nebulaConfig.getBatchWriteSize).foreach { edges =>
          val ranked = !(edgeConfig.getRank == null || edgeConfig.getRank == 0L)
          val values = edges
            .map { edge =>
              // NOTE(review): the rendered source id is split on "," and one clause is
              // emitted per piece; with a key policy every piece uses the whole id.
              // Kept as in the original - confirm that this is intended.
              edge._1.split(",")
                .map(source => edgeValue(source, edge, srcKeyPolicy, dstKeyPolicy, ranked))
                .mkString(",")
            }
            .toList
            .mkString(",")
          val exec = BATCH_INSERT_TEMPLATE
            .format(DataTypeEnum.EDGE.toString, edgeConfig.getEdge, properties.keySet.mkString(","), values)
          writer.submit(exec)
        }
      }
  }
}
package com.qihoo.finance.graph.data.nebula
/** The two kinds of data that can be written: vertices and edges. */
object DataTypeEnum extends Enumeration {
  type DataType = Value
  val VERTEX: Value = Value(0, "vertex")
  val EDGE: Value = Value(1, "edge")

  /**
   * Tells whether the text names one of the data types, ignoring case.
   *
   * @param dataType the name to check
   * @return true for "vertex" or "edge" in any letter case
   */
  def validDataType(dataType: String): Boolean =
    List(VERTEX, EDGE).exists(known => dataType.equalsIgnoreCase(known.toString))
}
/** Key policies that can be applied to an id; the value names are "hash" and "uuid". */
object KeyPolicy extends Enumeration {
  type POLICY = Value
  val HASH: Value = Value(0, "hash")
  val UUID: Value = Value(1, "uuid")
}
/** Operation kinds; the value names are "read" and "write". */
object OperaType extends Enumeration {
  type Operation = Value
  val READ: Value = Value(0, "read")
  val WRITE: Value = Value(1, "write")
}
package com.qihoo.finance.graph.data.nebula
import com.qihoo.finance.graph.data.config.{NebulaConfig, VertexConfig}
import com.vesoft.nebula.connector.NebulaOptions
import org.apache.commons.lang.StringUtils
import org.apache.spark.sql.{DataFrame, Encoders}
import scala.collection.mutable.ListBuffer
/**
 * Writes the rows of a DataFrame as vertices carrying the tag
 * `vertexConfig.getTag`, batching them into INSERT statements submitted
 * through a [[NebulaWriter]].
 *
 * Original author: huangzhaolai-jk, 2021/8/2, version 1.0.0
 */
class NebulaVertexWriter(nebulaOptions: NebulaOptions, vertexConfig: VertexConfig, nebulaConfig: NebulaConfig) extends NebulaBaseWriter {
  // private val properties: Map[String, Integer] = metaProvider.getTagSchema(nebulaConfig.getSpace, vertexConfig.getTag)
  // Property names to write; filled from the DataFrame columns on the first write.
  private var properties: Map[String, Integer] = Map()

  /**
   * Repartitions the DataFrame, renders every row and inserts the vertices in
   * batches of `nebulaConfig.getBatchWriteSize`.
   *
   * @param rdd the rows to write; all of its columns are written as tag properties
   */
  override def write(rdd: DataFrame): Unit = {
    val keyPolicy: Option[KeyPolicy.Value] =
      if (StringUtils.isBlank(vertexConfig.getKeyPolicy)) Option.empty
      else Option(KeyPolicy.withName(vertexConfig.getKeyPolicy))

    if (properties.isEmpty) {
      properties = rdd.columns.foldLeft(properties) { (known, column) =>
        known + (column -> Integer.valueOf(1))
      }
    }

    val rendered = rdd.repartition(nebulaConfig.getPartitions).map { row =>
      // One rendered value per property, in the iteration order of the key set.
      val propertyValues = properties.keySet.toList.map(name => extraValue(row, name))
      val joined = propertyValues.mkString(",")
      (String.valueOf(extraValue(row, vertexConfig.getIdField)), joined)
    }(Encoders.tuple(Encoders.STRING, Encoders.STRING))

    rendered.foreachPartition { iterator: Iterator[(String, String)] =>
      val writer = new NebulaWriter(nebulaOptions)
      writer.prepareSpace()
      iterator.grouped(nebulaConfig.getBatchWriteSize).foreach { tags =>
        val clauses = tags.map { tag =>
          keyPolicy match {
            case None =>
              INSERT_VALUE_TEMPLATE.format(tag._1, tag._2)
            case Some(KeyPolicy.HASH) =>
              INSERT_VALUE_TEMPLATE_WITH_POLICY.format(KeyPolicy.HASH.toString, tag._1, tag._2)
            case Some(KeyPolicy.UUID) =>
              INSERT_VALUE_TEMPLATE_WITH_POLICY.format(KeyPolicy.UUID.toString, tag._1, tag._2)
            case Some(_) =>
              throw new IllegalArgumentException
          }
        }
        val exec = BATCH_INSERT_TEMPLATE.format(
          DataTypeEnum.VERTEX.toString,
          vertexConfig.getTag,
          properties.keySet.mkString(", "),
          clauses.mkString(",")
        )
        writer.submit(exec)
      }
    }
  }
}
package com.qihoo.finance.graph.data.nebula
import java.util.concurrent.TimeUnit
import com.google.common.util.concurrent.RateLimiter
import com.vesoft.nebula.connector.NebulaOptions
import org.slf4j.LoggerFactory
import scala.util.control.Breaks
/**
 * Submits statements to one graph service through a [[GraphProvider]],
 * throttled by a rate limiter built from `nebulaOptions.rateLimit`.
 *
 * When several graph addresses are configured, one of them is picked at random
 * for this writer.
 */
class NebulaWriter(nebulaOptions: NebulaOptions) extends Serializable {
  private val LOG = LoggerFactory.getLogger(this.getClass)
  private val address: List[(String, Int)] = nebulaOptions.getGraphAddress

  // One limiter per writer, shared by every submit call. Creating it inside
  // submit gave each statement a brand-new limiter, so the configured rate was
  // never applied across statements.
  @transient private lazy val rateLimiter: RateLimiter = RateLimiter.create(nebulaOptions.rateLimit)

  var graphProvider: GraphProvider =
    if (address != null && address.size > 1) {
      val index = (new scala.util.Random).nextInt(address.size)
      new GraphProvider(List.apply(address(index)))
    } else {
      new GraphProvider(address)
    }

  // val metaProvider = new MetaProvider(nebulaOptions.getMetaAddress)
  // val isVidStringType = metaProvider.getVidType(nebulaOptions.spaceName) == VidType.STRING

  /** Opens the session if needed and switches it to `nebulaOptions.spaceName`. */
  def prepareSpace(): Unit = {
    graphProvider.switchSpace(nebulaOptions.user, nebulaOptions.passwd, nebulaOptions.spaceName)
  }

  /**
   * Executes the statement, retrying until it succeeds.
   *
   * Every attempt first waits up to `nebulaOptions.rateTimeOut` milliseconds
   * for a permit from the rate limiter; a missed permit or a failed execution
   * is logged and the attempt is repeated.
   * NOTE(review): the retry is unbounded, so a statement that can never
   * succeed keeps this call looping - confirm that this is intended.
   *
   * @param exec the statement to execute
   */
  def submit(exec: String): Unit = {
    var written = false
    while (!written) {
      if (rateLimiter.tryAcquire(nebulaOptions.rateTimeOut, TimeUnit.MILLISECONDS)) {
        val result = graphProvider.submit(exec)
        if (result.isSucceeded) {
          LOG.info("batch write succeed:")
          LOG.debug("batch write succeed: " + exec)
          written = true
        } else {
          LOG.error("failed to write for " + exec + ",error:" + result.getErrorMessage)
        }
      } else {
        LOG.error(s"failed to acquire reteLimiter for statement")
      }
    }
  }
}