protoc-gen-lua 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. #!/usr/bin/env python
  2. # -*- encoding:utf8 -*-
  3. # protoc-gen-erl
  4. # Google's Protocol Buffers project, ported to lua.
  5. # https://code.google.com/p/protoc-gen-lua/
  6. #
  7. # Copyright (c) 2010 , 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
  8. # All rights reserved.
  9. #
  10. # Use, modification and distribution are subject to the "New BSD License"
  11. # as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.
  12. import sys
  13. import os.path as path
  14. from cStringIO import StringIO
  15. import plugin_pb2
  16. import google.protobuf.descriptor_pb2 as descriptor_pb2
  17. _packages = {}
  18. _files = {}
  19. _message = {}
  20. FDP = plugin_pb2.descriptor_pb2.FieldDescriptorProto
  21. if sys.platform == "win32":
  22. import msvcrt, os
  23. msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
  24. msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
  25. class CppType:
  26. CPPTYPE_INT32 = 1
  27. CPPTYPE_INT64 = 2
  28. CPPTYPE_UINT32 = 3
  29. CPPTYPE_UINT64 = 4
  30. CPPTYPE_DOUBLE = 5
  31. CPPTYPE_FLOAT = 6
  32. CPPTYPE_BOOL = 7
  33. CPPTYPE_ENUM = 8
  34. CPPTYPE_STRING = 9
  35. CPPTYPE_MESSAGE = 10
  36. CPP_TYPE ={
  37. FDP.TYPE_DOUBLE : CppType.CPPTYPE_DOUBLE,
  38. FDP.TYPE_FLOAT : CppType.CPPTYPE_FLOAT,
  39. FDP.TYPE_INT64 : CppType.CPPTYPE_INT64,
  40. FDP.TYPE_UINT64 : CppType.CPPTYPE_UINT64,
  41. FDP.TYPE_INT32 : CppType.CPPTYPE_INT32,
  42. FDP.TYPE_FIXED64 : CppType.CPPTYPE_UINT64,
  43. FDP.TYPE_FIXED32 : CppType.CPPTYPE_UINT32,
  44. FDP.TYPE_BOOL : CppType.CPPTYPE_BOOL,
  45. FDP.TYPE_STRING : CppType.CPPTYPE_STRING,
  46. FDP.TYPE_MESSAGE : CppType.CPPTYPE_MESSAGE,
  47. FDP.TYPE_BYTES : CppType.CPPTYPE_STRING,
  48. FDP.TYPE_UINT32 : CppType.CPPTYPE_UINT32,
  49. FDP.TYPE_ENUM : CppType.CPPTYPE_ENUM,
  50. FDP.TYPE_SFIXED32 : CppType.CPPTYPE_INT32,
  51. FDP.TYPE_SFIXED64 : CppType.CPPTYPE_INT64,
  52. FDP.TYPE_SINT32 : CppType.CPPTYPE_INT32,
  53. FDP.TYPE_SINT64 : CppType.CPPTYPE_INT64
  54. }
  55. def printerr(*args):
  56. sys.stderr.write(" ".join(args))
  57. sys.stderr.write("\n")
  58. sys.stderr.flush()
  59. class TreeNode(object):
  60. def __init__(self, name, parent=None, filename=None, package=None):
  61. super(TreeNode, self).__init__()
  62. self.child = []
  63. self.parent = parent
  64. self.filename = filename
  65. self.package = package
  66. if parent:
  67. self.parent.add_child(self)
  68. self.name = name
  69. def add_child(self, child):
  70. self.child.append(child)
  71. def find_child(self, child_names):
  72. if child_names:
  73. for i in self.child:
  74. if i.name == child_names[0]:
  75. return i.find_child(child_names[1:])
  76. raise StandardError
  77. else:
  78. return self
  79. def get_child(self, child_name):
  80. for i in self.child:
  81. if i.name == child_name:
  82. return i
  83. return None
  84. def get_path(self, end = None):
  85. pos = self
  86. out = []
  87. while pos and pos != end:
  88. out.append(pos.name)
  89. pos = pos.parent
  90. out.reverse()
  91. return '.'.join(out)
  92. def get_global_name(self):
  93. return self.get_path()
  94. def get_local_name(self):
  95. pos = self
  96. while pos.parent:
  97. pos = pos.parent
  98. if self.package and pos.name == self.package[-1]:
  99. break
  100. return self.get_path(pos)
  101. def __str__(self):
  102. return self.to_string(0)
  103. def __repr__(self):
  104. return str(self)
  105. def to_string(self, indent = 0):
  106. return ' '*indent + '<TreeNode ' + self.name + '(\n' + \
  107. ','.join([i.to_string(indent + 4) for i in self.child]) + \
  108. ' '*indent +')>\n'
  109. class Env(object):
  110. filename = None
  111. package = None
  112. extend = None
  113. descriptor = None
  114. message = None
  115. context = None
  116. register = None
  117. def __init__(self):
  118. self.message_tree = TreeNode('')
  119. self.scope = self.message_tree
  120. def get_global_name(self):
  121. return self.scope.get_global_name()
  122. def get_local_name(self):
  123. return self.scope.get_local_name()
  124. def get_ref_name(self, type_name):
  125. try:
  126. node = self.lookup_name(type_name)
  127. except:
  128. # if the child doesn't be founded, it must be in this file
  129. return type_name[len('.'.join(self.package)) + 2:]
  130. if node.filename != self.filename:
  131. return node.filename + '_pb.' + node.get_local_name()
  132. return node.get_local_name()
  133. def lookup_name(self, name):
  134. names = name.split('.')
  135. if names[0] == '':
  136. return self.message_tree.find_child(names[1:])
  137. else:
  138. return self.scope.parent.find_child(names)
  139. def enter_package(self, package):
  140. if not package:
  141. return self.message_tree
  142. names = package.split('.')
  143. pos = self.message_tree
  144. for i, name in enumerate(names):
  145. new_pos = pos.get_child(name)
  146. if new_pos:
  147. pos = new_pos
  148. else:
  149. return self._build_nodes(pos, names[i:])
  150. return pos
  151. def enter_file(self, filename, package):
  152. self.filename = filename
  153. self.package = package.split('.')
  154. self._init_field()
  155. self.scope = self.enter_package(package)
  156. def exit_file(self):
  157. self._init_field()
  158. self.filename = None
  159. self.package = []
  160. self.scope = self.scope.parent
  161. def enter(self, message_name):
  162. self.scope = TreeNode(message_name, self.scope, self.filename,
  163. self.package)
  164. def exit(self):
  165. self.scope = self.scope.parent
  166. def _init_field(self):
  167. self.descriptor = []
  168. self.context = []
  169. self.message = []
  170. self.register = []
  171. def _build_nodes(self, node, names):
  172. parent = node
  173. for i in names:
  174. parent = TreeNode(i, parent, self.filename, self.package)
  175. return parent
  176. class Writer(object):
  177. def __init__(self, prefix=None):
  178. self.io = StringIO()
  179. self.__indent = ''
  180. self.__prefix = prefix
  181. def getvalue(self):
  182. return self.io.getvalue()
  183. def __enter__(self):
  184. self.__indent += ' '
  185. return self
  186. def __exit__(self, type, value, trackback):
  187. self.__indent = self.__indent[:-4]
  188. def __call__(self, data):
  189. self.io.write(self.__indent)
  190. if self.__prefix:
  191. self.io.write(self.__prefix)
  192. self.io.write(data)
  193. DEFAULT_VALUE = {
  194. FDP.TYPE_DOUBLE : '0.0',
  195. FDP.TYPE_FLOAT : '0.0',
  196. FDP.TYPE_INT64 : '0',
  197. FDP.TYPE_UINT64 : '0',
  198. FDP.TYPE_INT32 : '0',
  199. FDP.TYPE_FIXED64 : '0',
  200. FDP.TYPE_FIXED32 : '0',
  201. FDP.TYPE_BOOL : 'false',
  202. FDP.TYPE_STRING : '""',
  203. FDP.TYPE_MESSAGE : 'nil',
  204. FDP.TYPE_BYTES : '""',
  205. FDP.TYPE_UINT32 : '0',
  206. FDP.TYPE_ENUM : '1',
  207. FDP.TYPE_SFIXED32 : '0',
  208. FDP.TYPE_SFIXED64 : '0',
  209. FDP.TYPE_SINT32 : '0',
  210. FDP.TYPE_SINT64 : '0',
  211. }
  212. def code_gen_enum_item(index, enum_value, env):
  213. full_name = env.get_local_name() + '.' + enum_value.name
  214. obj_name = full_name.upper().replace('.', '_') + '_ENUM'
  215. env.descriptor.append(
  216. "%s = protobuf.EnumValueDescriptor();\n"% obj_name
  217. )
  218. context = Writer(obj_name)
  219. context('.name = "%s"\n' % enum_value.name)
  220. context('.index = %d\n' % index)
  221. context('.number = %d\n' % enum_value.number)
  222. env.context.append(context.getvalue())
  223. return obj_name
  224. def code_gen_enum(enum_desc, env):
  225. env.enter(enum_desc.name)
  226. full_name = env.get_local_name()
  227. obj_name = full_name.upper().replace('.', '_')
  228. env.descriptor.append(
  229. "%s = protobuf.EnumDescriptor();\n"% obj_name
  230. )
  231. context = Writer(obj_name)
  232. context('.name = "%s"\n' % enum_desc.name)
  233. context('.full_name = "%s"\n' % env.get_global_name())
  234. values = []
  235. for i, enum_value in enumerate(enum_desc.value):
  236. values.append(code_gen_enum_item(i, enum_value, env))
  237. context('.values = {%s}\n' % ','.join(values))
  238. env.context.append(context.getvalue())
  239. env.exit()
  240. return obj_name
  241. def code_gen_field(index, field_desc, env):
  242. full_name = env.get_local_name() + '.' + field_desc.name
  243. obj_name = full_name.upper().replace('.', '_') + '_FIELD'
  244. env.descriptor.append(
  245. "local %s = protobuf.FieldDescriptor();\n"% obj_name
  246. )
  247. context = Writer(obj_name)
  248. context('.name = "%s"\n' % field_desc.name)
  249. context('.full_name = "%s"\n' % (
  250. env.get_global_name() + '.' + field_desc.name))
  251. context('.number = %d\n' % field_desc.number)
  252. context('.index = %d\n' % index)
  253. context('.label = %d\n' % field_desc.label)
  254. pb_name = ""
  255. if field_desc.HasField('type_name'):
  256. type_name = env.get_ref_name(field_desc.type_name)
  257. names = type_name.split('.')
  258. if len(names) > 1:
  259. pb_name = names[0]
  260. type_name = names[0]+'.'+names[1].upper()
  261. if field_desc.type == FDP.TYPE_MESSAGE:
  262. context('.message_type = %s\n' % type_name)
  263. else:
  264. context('.enum_type = %s\n' % type_name)
  265. else:
  266. type_name = type_name.upper()
  267. if field_desc.type == FDP.TYPE_MESSAGE:
  268. context('.message_type = %s\n' % type_name)
  269. else:
  270. context('.enum_type = %s\n' % type_name)
  271. if field_desc.HasField("default_value"):
  272. context('.has_default_value = true\n')
  273. value = field_desc.default_value
  274. if field_desc.type == FDP.TYPE_STRING:
  275. context('.default_value = "%s"\n'%value)
  276. elif field_desc.type == FDP.TYPE_ENUM:
  277. pb_name = pb_name+'.'+value
  278. context('.default_value = %s\n'%pb_name)
  279. else:
  280. context('.default_value = %s\n'%value)
  281. else:
  282. context('.has_default_value = false\n')
  283. if field_desc.label == FDP.LABEL_REPEATED:
  284. default_value = "{}"
  285. elif field_desc.HasField('type_name'):
  286. default_value = "nil"
  287. else:
  288. default_value = DEFAULT_VALUE[field_desc.type]
  289. context('.default_value = %s\n' % default_value)
  290. if field_desc.HasField('extendee'):
  291. type_name = env.get_ref_name(field_desc.extendee)
  292. env.register.append(
  293. "%s.RegisterExtension(%s)\n" % (type_name, obj_name)
  294. )
  295. context('.type = %d\n' % field_desc.type)
  296. context('.cpp_type = %d\n\n' % CPP_TYPE[field_desc.type])
  297. env.context.append(context.getvalue())
  298. return obj_name
  299. def code_gen_message(message_descriptor, env, containing_type = None):
  300. env.enter(message_descriptor.name)
  301. full_name = env.get_local_name()
  302. obj_name = full_name.upper().replace('.', '_')
  303. env.descriptor.append(
  304. "%s = protobuf.Descriptor();\n"% obj_name
  305. )
  306. context = Writer(obj_name)
  307. context('.name = "%s"\n' % message_descriptor.name)
  308. context('.full_name = "%s"\n' % env.get_global_name())
  309. nested_types = []
  310. for msg_desc in message_descriptor.nested_type:
  311. msg_name = code_gen_message(msg_desc, env, obj_name)
  312. nested_types.append(msg_name)
  313. context('.nested_types = {%s}\n' % ', '.join(nested_types))
  314. enums = []
  315. for enum_desc in message_descriptor.enum_type:
  316. enums.append(code_gen_enum(enum_desc, env))
  317. context('.enum_types = {%s}\n' % ', '.join(enums))
  318. fields = []
  319. for i, field_desc in enumerate(message_descriptor.field):
  320. fields.append(code_gen_field(i, field_desc, env))
  321. context('.fields = {%s}\n' % ', '.join(fields))
  322. if len(message_descriptor.extension_range) > 0:
  323. context('.is_extendable = true\n')
  324. else:
  325. context('.is_extendable = false\n')
  326. extensions = []
  327. for i, field_desc in enumerate(message_descriptor.extension):
  328. extensions.append(code_gen_field(i, field_desc, env))
  329. context('.extensions = {%s}\n' % ', '.join(extensions))
  330. if containing_type:
  331. context('.containing_type = %s\n' % containing_type)
  332. env.message.append('%s = protobuf.Message(%s)\n' % (full_name,
  333. obj_name))
  334. env.context.append(context.getvalue())
  335. env.exit()
  336. return obj_name
  337. def write_header(writer):
  338. writer("""-- Generated By protoc-gen-lua Do not Edit
  339. """)
  340. def code_gen_file(proto_file, env, is_gen):
  341. filename = path.splitext(proto_file.name)[0]
  342. env.enter_file(filename, proto_file.package)
  343. includes = []
  344. for f in proto_file.dependency:
  345. inc_file = path.splitext(f)[0]
  346. includes.append(inc_file)
  347. # for field_desc in proto_file.extension:
  348. # code_gen_extensions(field_desc, field_desc.name, env)
  349. for enum_desc in proto_file.enum_type:
  350. code_gen_enum(enum_desc, env)
  351. for enum_value in enum_desc.value:
  352. env.message.append('%s = %d\n' % (enum_value.name,
  353. enum_value.number))
  354. for msg_desc in proto_file.message_type:
  355. code_gen_message(msg_desc, env)
  356. if is_gen:
  357. lua = Writer()
  358. write_header(lua)
  359. lua('local protobuf = require "protobuf"\n')
  360. for i in includes:
  361. lua('local %s_pb = require("%s_pb")\n' % (i, i))
  362. lua("module('%s_pb')\n" % env.filename)
  363. lua('\n\n')
  364. map(lua, env.descriptor)
  365. lua('\n')
  366. map(lua, env.context)
  367. lua('\n')
  368. env.message.sort()
  369. map(lua, env.message)
  370. lua('\n')
  371. map(lua, env.register)
  372. _files[env.filename+ '_pb.lua'] = lua.getvalue()
  373. env.exit_file()
  374. def main():
  375. plugin_require_bin = sys.stdin.read()
  376. code_gen_req = plugin_pb2.CodeGeneratorRequest()
  377. code_gen_req.ParseFromString(plugin_require_bin)
  378. env = Env()
  379. for proto_file in code_gen_req.proto_file:
  380. code_gen_file(proto_file, env,
  381. proto_file.name in code_gen_req.file_to_generate)
  382. code_generated = plugin_pb2.CodeGeneratorResponse()
  383. for k in _files:
  384. file_desc = code_generated.file.add()
  385. file_desc.name = k
  386. file_desc.content = _files[k]
  387. sys.stdout.write(code_generated.SerializeToString())
  388. if __name__ == "__main__":
  389. main()