123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464 |
- #!/usr/bin/env python
- # -*- encoding:utf8 -*-
- # protoc-gen-erl
- # Google's Protocol Buffers project, ported to lua.
- # https://code.google.com/p/protoc-gen-lua/
- #
- # Copyright (c) 2010 , 林卓毅 (Zhuoyi Lin) netsnail@gmail.com
- # All rights reserved.
- #
- # Use, modification and distribution are subject to the "New BSD License"
- # as listed at <url: http://www.opensource.org/licenses/bsd-license.php >.
- import sys
- import os.path as path
- from cStringIO import StringIO
- import plugin_pb2
- import google.protobuf.descriptor_pb2 as descriptor_pb2
# Shorthand for the generated FieldDescriptorProto class; its TYPE_* and
# LABEL_* constants are used throughout this file.
FDP = plugin_pb2.descriptor_pb2.FieldDescriptorProto

# Module-level registries.  In this file only `_files` is populated
# (by code_gen_file: output file name -> generated Lua source).
_packages = {}
_files = {}
_message = {}

if sys.platform == "win32":
    # The request/response exchanged with protoc is serialized protobuf,
    # so both standard streams are switched to binary mode on Windows.
    import msvcrt
    import os
    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
class CppType:
    """Integer codes written to the generated `.cpp_type` descriptor field.

    The codes run consecutively from 1 (INT32) to 10 (MESSAGE).
    """

    (
        CPPTYPE_INT32,
        CPPTYPE_INT64,
        CPPTYPE_UINT32,
        CPPTYPE_UINT64,
        CPPTYPE_DOUBLE,
        CPPTYPE_FLOAT,
        CPPTYPE_BOOL,
        CPPTYPE_ENUM,
        CPPTYPE_STRING,
        CPPTYPE_MESSAGE,
    ) = range(1, 11)
# Lookup table: FieldDescriptorProto type -> CppType code.  It is written
# here grouped by target code; the resulting mapping has one entry per
# proto type listed below.
CPP_TYPE = dict(
    (proto_type, cpp_type)
    for cpp_type, proto_types in (
        (CppType.CPPTYPE_DOUBLE, (FDP.TYPE_DOUBLE,)),
        (CppType.CPPTYPE_FLOAT, (FDP.TYPE_FLOAT,)),
        (CppType.CPPTYPE_INT64,
         (FDP.TYPE_INT64, FDP.TYPE_SFIXED64, FDP.TYPE_SINT64)),
        (CppType.CPPTYPE_UINT64, (FDP.TYPE_UINT64, FDP.TYPE_FIXED64)),
        (CppType.CPPTYPE_INT32,
         (FDP.TYPE_INT32, FDP.TYPE_SFIXED32, FDP.TYPE_SINT32)),
        (CppType.CPPTYPE_UINT32, (FDP.TYPE_FIXED32, FDP.TYPE_UINT32)),
        (CppType.CPPTYPE_BOOL, (FDP.TYPE_BOOL,)),
        (CppType.CPPTYPE_STRING, (FDP.TYPE_STRING, FDP.TYPE_BYTES)),
        (CppType.CPPTYPE_MESSAGE, (FDP.TYPE_MESSAGE,)),
        (CppType.CPPTYPE_ENUM, (FDP.TYPE_ENUM,)),
    )
    for proto_type in proto_types
)
def printerr(*args):
    """Write *args* to stderr separated by single spaces, then a newline.

    Arguments are converted with ``str()`` first, so non-string values
    (numbers, None, descriptors, ...) can be passed directly; joining them
    unconverted would raise ``TypeError``.
    """
    sys.stderr.write(" ".join([str(arg) for arg in args]))
    sys.stderr.write("\n")
    sys.stderr.flush()
class TreeNode(object):
    """One named node in a tree of scopes.

    Each node records its name, its parent node, its child nodes, and the
    file name / package (a list of name components) it was created under.
    """

    def __init__(self, name, parent=None, filename=None, package=None):
        """Create a node and, when *parent* is given, attach it as a child."""
        super(TreeNode, self).__init__()
        self.name = name
        self.child = []
        self.parent = parent
        self.filename = filename
        self.package = package
        if parent is not None:
            parent.add_child(self)

    def add_child(self, child):
        """Append *child* to this node's list of children."""
        self.child.append(child)

    def find_child(self, child_names):
        """Follow *child_names* downwards from this node.

        Returns this node itself for an empty path.  At each step the first
        child with the requested name is taken.

        Raises:
            LookupError: if some component of the path has no matching child.
        """
        node = self
        for wanted in child_names or ():
            found = node.get_child(wanted)
            if found is None:
                raise LookupError(
                    "no child named %r under %r" % (wanted, node.get_path()))
            node = found
        return node

    def get_child(self, child_name):
        """Return the first direct child called *child_name*, or None."""
        for candidate in self.child:
            if candidate.name == child_name:
                return candidate
        return None

    def get_path(self, end=None):
        """Return the dotted names from below *end* down to this node.

        With ``end=None`` the walk goes all the way to the root, whose name
        is included in the result.
        """
        names = []
        pos = self
        while pos is not None and pos is not end:
            names.append(pos.name)
            pos = pos.parent
        names.reverse()
        return '.'.join(names)

    def get_global_name(self):
        """Return the dotted path from the root down to this node."""
        return self.get_path()

    def get_local_name(self):
        """Return the dotted path of this node below its package node.

        Ancestors are walked upwards until one is named like the last
        component of ``self.package`` (or the root is reached); the path
        below that ancestor is returned.
        """
        pos = self
        while pos.parent:
            pos = pos.parent
            if self.package and pos.name == self.package[-1]:
                break
        return self.get_path(pos)

    def __str__(self):
        return self.to_string(0)

    def __repr__(self):
        return str(self)

    def to_string(self, indent=0):
        """Render this subtree as text; children are rendered at indent + 4."""
        pad = ' ' * indent
        children = ','.join(
            [node.to_string(indent + 4) for node in self.child])
        return pad + '<TreeNode ' + self.name + '(\n' + children + pad + ')>\n'
class Env(object):
    """Code-generation state shared while walking the .proto files.

    Holds the tree of scopes seen so far (``message_tree``), the scope that
    is currently being generated (``scope``) and, per file, four lists of
    Lua source fragments that the code_gen_* functions append to.
    """

    filename = None    # name of the file being processed, without extension
    package = None     # package of that file as a list of name components
    extend = None
    descriptor = None  # fragments that create descriptor objects
    message = None     # fragments that create message classes / enum values
    context = None     # fragments that fill in descriptor attributes
    register = None    # fragments that register extensions

    def __init__(self):
        self.message_tree = TreeNode('')
        self.scope = self.message_tree

    def get_global_name(self):
        """Return the global dotted name of the current scope."""
        return self.scope.get_global_name()

    def get_local_name(self):
        """Return the package-relative dotted name of the current scope."""
        return self.scope.get_local_name()

    def get_ref_name(self, type_name):
        """Return the name to use when referring to *type_name*.

        A type found in another file is returned as
        ``<filename>_pb.<local name>``; a type found in the current file as
        its local name.  If the lookup fails, the type is assumed to belong
        to the current file and the leading ``.<package>.`` is cut off
        *type_name*.
        """
        try:
            node = self.lookup_name(type_name)
        except Exception:
            # Not found in the tree: it must be defined in this file.
            # Only the leading dot is dropped when the file has no package
            # (otherwise the first character of the name would be lost too).
            package_path = '.'.join(self.package or [])
            if package_path:
                return type_name[len(package_path) + 2:]
            return type_name[1:]
        if node.filename != self.filename:
            return node.filename + '_pb.' + node.get_local_name()
        return node.get_local_name()

    def lookup_name(self, name):
        """Resolve dotted *name* to a tree node.

        Names starting with ``.`` are resolved from the root of the tree,
        all others from the parent of the current scope.  Lookup failures
        propagate as exceptions.
        """
        names = name.split('.')
        if names[0] == '':
            return self.message_tree.find_child(names[1:])
        return self.scope.parent.find_child(names)

    def enter_package(self, package):
        """Return the node for dotted *package*, creating missing nodes.

        An empty package maps to the root of the tree.
        """
        if not package:
            return self.message_tree
        names = package.split('.')
        pos = self.message_tree
        for depth, name in enumerate(names):
            existing = pos.get_child(name)
            if not existing:
                # Nothing below this point exists yet: build the remainder.
                return self._build_nodes(pos, names[depth:])
            pos = existing
        return pos

    def enter_file(self, filename, package):
        """Start a new file: reset the fragment lists and enter its package."""
        self.filename = filename
        self.package = package.split('.')
        self._init_field()
        self.scope = self.enter_package(package)

    def exit_file(self):
        """Finish the current file: clear per-file state and leave its scope."""
        self._init_field()
        self.filename = None
        self.package = []
        self.scope = self.scope.parent

    def enter(self, message_name):
        """Open a new child scope called *message_name* and make it current."""
        self.scope = TreeNode(message_name, self.scope, self.filename,
                              self.package)

    def exit(self):
        """Return to the parent of the current scope."""
        self.scope = self.scope.parent

    def _init_field(self):
        """Replace the four per-file fragment lists with fresh empty lists."""
        self.descriptor = []
        self.context = []
        self.message = []
        self.register = []

    def _build_nodes(self, node, names):
        """Create a chain of nodes named *names* below *node*; return the last."""
        parent = node
        for name in names:
            parent = TreeNode(name, parent, self.filename, self.package)
        return parent
class Writer(object):
    """Accumulates text; every written chunk is preceded by indent + prefix.

    Calling the writer appends one chunk.  Using it as a context manager
    indents the chunks written inside the ``with`` block by one step.
    """

    # One indentation step.  __enter__ and __exit__ both derive from this
    # constant so that leaving a block removes exactly what entering added.
    _INDENT_STEP = '    '

    def __init__(self, prefix=None):
        """Create an empty writer; *prefix* (if truthy) precedes each chunk."""
        self.io = StringIO()
        self.__indent = ''
        self.__prefix = prefix

    def getvalue(self):
        """Return everything written so far as one string."""
        return self.io.getvalue()

    def __enter__(self):
        self.__indent += self._INDENT_STEP
        return self

    def __exit__(self, type, value, trackback):
        # Returns None, so exceptions raised inside the block propagate.
        self.__indent = self.__indent[:-len(self._INDENT_STEP)]

    def __call__(self, data):
        """Append current indentation, the prefix (if any) and *data*."""
        self.io.write(self.__indent)
        if self.__prefix:
            self.io.write(self.__prefix)
        self.io.write(data)
# Lookup table: FieldDescriptorProto type -> Lua source text emitted as
# `.default_value` when a field declares no explicit default.  It is
# written here grouped by the emitted literal.
DEFAULT_VALUE = dict(
    (proto_type, lua_literal)
    for lua_literal, proto_types in (
        ('0.0', (FDP.TYPE_DOUBLE, FDP.TYPE_FLOAT)),
        ('0', (
            FDP.TYPE_INT64, FDP.TYPE_UINT64, FDP.TYPE_INT32,
            FDP.TYPE_FIXED64, FDP.TYPE_FIXED32, FDP.TYPE_UINT32,
            FDP.TYPE_SFIXED32, FDP.TYPE_SFIXED64,
            FDP.TYPE_SINT32, FDP.TYPE_SINT64,
        )),
        ('false', (FDP.TYPE_BOOL,)),
        ('""', (FDP.TYPE_STRING, FDP.TYPE_BYTES)),
        ('nil', (FDP.TYPE_MESSAGE,)),
        ('1', (FDP.TYPE_ENUM,)),
    )
    for proto_type in proto_types
)
def code_gen_enum_item(index, enum_value, env):
    """Generate the Lua EnumValueDescriptor for one enum value.

    Appends the descriptor creation line to ``env.descriptor`` and the
    attribute assignments (name, index, number) to ``env.context``.

    Returns the Lua variable name of the generated descriptor.
    """
    qualified = '%s.%s' % (env.get_local_name(), enum_value.name)
    obj_name = qualified.replace('.', '_').upper() + '_ENUM'

    env.descriptor.append("%s = protobuf.EnumValueDescriptor();\n" % obj_name)

    writer = Writer(obj_name)
    writer('.name = "%s"\n' % enum_value.name)
    writer('.index = %d\n' % index)
    writer('.number = %d\n' % enum_value.number)
    env.context.append(writer.getvalue())

    return obj_name
def code_gen_enum(enum_desc, env):
    """Generate the Lua EnumDescriptor for *enum_desc* and all its values.

    The enum gets its own scope in *env* for the duration of the call.
    Returns the Lua variable name of the generated descriptor.
    """
    env.enter(enum_desc.name)

    obj_name = env.get_local_name().replace('.', '_').upper()
    env.descriptor.append("%s = protobuf.EnumDescriptor();\n" % obj_name)

    writer = Writer(obj_name)
    writer('.name = "%s"\n' % enum_desc.name)
    writer('.full_name = "%s"\n' % env.get_global_name())

    # Each value emits its own descriptor/context fragments as it is visited.
    value_names = [
        code_gen_enum_item(position, item, env)
        for position, item in enumerate(enum_desc.value)
    ]
    writer('.values = {%s}\n' % ','.join(value_names))

    env.context.append(writer.getvalue())
    env.exit()
    return obj_name
def _lua_quote(text):
    """Return *text* as a double-quoted Lua string literal."""
    escaped = (text.replace('\\', '\\\\')
                   .replace('"', '\\"')
                   .replace('\n', '\\n')
                   .replace('\r', '\\r'))
    return '"%s"' % escaped


def _descriptor_ref(ref_name):
    """Turn a name from Env.get_ref_name into a descriptor variable reference.

    ``file_pb.Outer.Inner`` becomes ``file_pb.OUTER_INNER`` and a same-file
    ``Outer.Inner`` becomes ``OUTER_INNER``, matching the variable names that
    code_gen_message / code_gen_enum declare (upper-cased local name with
    dots replaced by underscores).
    """
    module, sep, local_name = ref_name.partition('_pb.')
    if not sep:
        # No "<file>_pb." prefix: the type lives in the current file.
        return ref_name.upper().replace('.', '_')
    return module + '_pb.' + local_name.upper().replace('.', '_')


def code_gen_field(index, field_desc, env):
    """Generate the Lua FieldDescriptor for one field or extension.

    Appends the descriptor creation line to ``env.descriptor``, the attribute
    assignments to ``env.context`` and, for extensions, a RegisterExtension
    call to ``env.register``.

    Args:
        index: position of the field within its containing list.
        field_desc: the FieldDescriptorProto to translate.
        env: the current Env.

    Returns:
        The Lua variable name of the generated descriptor.
    """
    full_name = env.get_local_name() + '.' + field_desc.name
    obj_name = full_name.upper().replace('.', '_') + '_FIELD'
    env.descriptor.append(
        "local %s = protobuf.FieldDescriptor();\n" % obj_name
    )

    context = Writer(obj_name)
    context('.name = "%s"\n' % field_desc.name)
    context('.full_name = "%s"\n' % (
        env.get_global_name() + '.' + field_desc.name))
    context('.number = %d\n' % field_desc.number)
    context('.index = %d\n' % index)
    context('.label = %d\n' % field_desc.label)

    type_ref = None
    if field_desc.HasField('type_name'):
        type_ref = _descriptor_ref(env.get_ref_name(field_desc.type_name))
        if field_desc.type == FDP.TYPE_MESSAGE:
            context('.message_type = %s\n' % type_ref)
        else:
            context('.enum_type = %s\n' % type_ref)

    if field_desc.HasField("default_value"):
        context('.has_default_value = true\n')
        value = field_desc.default_value
        if field_desc.type == FDP.TYPE_STRING:
            # The default text arrives unescaped, so quote it for Lua.
            context('.default_value = %s\n' % _lua_quote(value))
        elif field_desc.type == FDP.TYPE_ENUM and type_ref:
            # Read the number from the enum value's own descriptor
            # (<ENUM>_<VALUE>_ENUM, as named by code_gen_enum_item).  This
            # yields a valid expression for same-file and nested enums too.
            context('.default_value = %s_%s_ENUM.number\n' % (
                type_ref, value.upper()))
        else:
            context('.default_value = %s\n' % value)
    else:
        context('.has_default_value = false\n')
        if field_desc.label == FDP.LABEL_REPEATED:
            default_value = "{}"
        elif field_desc.HasField('type_name'):
            default_value = "nil"
        else:
            default_value = DEFAULT_VALUE[field_desc.type]
        context('.default_value = %s\n' % default_value)

    if field_desc.HasField('extendee'):
        extendee = env.get_ref_name(field_desc.extendee)
        env.register.append(
            "%s.RegisterExtension(%s)\n" % (extendee, obj_name)
        )

    context('.type = %d\n' % field_desc.type)
    context('.cpp_type = %d\n\n' % CPP_TYPE[field_desc.type])
    env.context.append(context.getvalue())
    return obj_name
def code_gen_message(message_descriptor, env, containing_type = None):
    """Generate the Lua Descriptor for a message and everything inside it.

    Nested messages, enums, fields and extensions are generated (in that
    order) while the message's scope is current in *env*.  A line creating
    the message class is appended to ``env.message``.

    Args:
        message_descriptor: the DescriptorProto to translate.
        env: the current Env.
        containing_type: Lua variable name of the enclosing message's
            descriptor when this message is nested, otherwise None.

    Returns:
        The Lua variable name of the generated descriptor.
    """
    env.enter(message_descriptor.name)

    full_name = env.get_local_name()
    obj_name = full_name.replace('.', '_').upper()
    env.descriptor.append("%s = protobuf.Descriptor();\n" % obj_name)

    writer = Writer(obj_name)
    writer('.name = "%s"\n' % message_descriptor.name)
    writer('.full_name = "%s"\n' % env.get_global_name())

    nested_names = [
        code_gen_message(nested, env, obj_name)
        for nested in message_descriptor.nested_type
    ]
    writer('.nested_types = {%s}\n' % ', '.join(nested_names))

    enum_names = [
        code_gen_enum(enum_desc, env)
        for enum_desc in message_descriptor.enum_type
    ]
    writer('.enum_types = {%s}\n' % ', '.join(enum_names))

    field_names = [
        code_gen_field(position, field, env)
        for position, field in enumerate(message_descriptor.field)
    ]
    writer('.fields = {%s}\n' % ', '.join(field_names))

    extendable = 'true' if len(message_descriptor.extension_range) > 0 else 'false'
    writer('.is_extendable = %s\n' % extendable)

    extension_names = [
        code_gen_field(position, field, env)
        for position, field in enumerate(message_descriptor.extension)
    ]
    writer('.extensions = {%s}\n' % ', '.join(extension_names))

    if containing_type:
        writer('.containing_type = %s\n' % containing_type)

    env.message.append('%s = protobuf.Message(%s)\n' % (full_name, obj_name))
    env.context.append(writer.getvalue())
    env.exit()
    return obj_name
def write_header(writer):
    """Emit the "generated file" banner comment by calling *writer* once."""
    banner = "-- Generated By protoc-gen-lua Do not Edit\n"
    writer(banner)
def code_gen_file(proto_file, env, is_gen):
    """Process one .proto file and, if requested, render its Lua module.

    The file's enums and messages are always walked so that their names are
    recorded in *env*.  Lua source is only assembled when *is_gen* is true.

    Args:
        proto_file: the FileDescriptorProto to process.
        env: the shared Env.
        is_gen: whether to store generated source for this file in
            ``_files`` under ``<name>_pb.lua``.
    """
    filename = path.splitext(proto_file.name)[0]
    env.enter_file(filename, proto_file.package)

    includes = [path.splitext(dep)[0] for dep in proto_file.dependency]

    for enum_desc in proto_file.enum_type:
        code_gen_enum(enum_desc, env)
        # Values of top-level enums are also exposed as plain module
        # variables holding their number.
        for enum_value in enum_desc.value:
            env.message.append('%s = %d\n' % (enum_value.name,
                                              enum_value.number))

    for msg_desc in proto_file.message_type:
        code_gen_message(msg_desc, env)

    if is_gen:
        lua = Writer()
        write_header(lua)
        lua('local protobuf = require "protobuf"\n')
        for inc in includes:
            lua('local %s_pb = require("%s_pb")\n' % (inc, inc))
        lua("module('%s_pb')\n" % env.filename)

        # Explicit loops instead of map(): map() is lazy on Python 3 and
        # would silently write nothing there.
        lua('\n\n')
        for fragment in env.descriptor:
            lua(fragment)
        lua('\n')
        for fragment in env.context:
            lua(fragment)
        lua('\n')
        env.message.sort()
        for fragment in env.message:
            lua(fragment)
        lua('\n')
        for fragment in env.register:
            lua(fragment)

        _files[env.filename + '_pb.lua'] = lua.getvalue()
    env.exit_file()
def main():
    """Run as a protoc plugin.

    Reads a serialized CodeGeneratorRequest from stdin, processes every
    file it contains and writes a serialized CodeGeneratorResponse holding
    the generated files to stdout.
    """
    # Both messages are binary.  Where the standard streams are text
    # wrappers (Python 3) use the underlying byte streams; otherwise the
    # streams themselves already carry bytes.
    stdin = getattr(sys.stdin, 'buffer', sys.stdin)
    stdout = getattr(sys.stdout, 'buffer', sys.stdout)

    code_gen_req = plugin_pb2.CodeGeneratorRequest()
    code_gen_req.ParseFromString(stdin.read())

    env = Env()
    wanted = set(code_gen_req.file_to_generate)
    for proto_file in code_gen_req.proto_file:
        # Every file is walked so its types are known to later files;
        # output is produced only for the files protoc asked for.
        code_gen_file(proto_file, env, proto_file.name in wanted)

    code_generated = plugin_pb2.CodeGeneratorResponse()
    for name, content in _files.items():
        file_desc = code_generated.file.add()
        file_desc.name = name
        file_desc.content = content

    stdout.write(code_generated.SerializeToString())
    stdout.flush()


if __name__ == "__main__":
    main()
|