blob: a45183d147c45fd734a7f43cd7a302e76313549b [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2007 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import array
import itertools
import re
import six
from six.moves import http_client
import struct
try:
import google.net.proto.proto1 as proto1
except ImportError:
class ProtocolBufferDecodeError(Exception): pass
class ProtocolBufferEncodeError(Exception): pass
class ProtocolBufferReturnError(Exception): pass
else:
ProtocolBufferDecodeError = proto1.ProtocolBufferDecodeError
ProtocolBufferEncodeError = proto1.ProtocolBufferEncodeError
ProtocolBufferReturnError = proto1.ProtocolBufferReturnError
__all__ = ['ProtocolMessage', 'Encoder', 'Decoder',
'ExtendableProtocolMessage',
'ProtocolBufferDecodeError',
'ProtocolBufferEncodeError',
'ProtocolBufferReturnError']
URL_RE = re.compile('^(https?)://([^/]+)(/.*)$')
class ProtocolMessage:
def __init__(self, contents=None):
raise NotImplementedError
def Clear(self):
raise NotImplementedError
def IsInitialized(self, debug_strs=None):
raise NotImplementedError
def Encode(self):
try:
return self._CEncode()
except (NotImplementedError, AttributeError):
e = Encoder()
self.Output(e)
return e.buffer().tobytes()
def SerializeToString(self):
return self.Encode()
def SerializePartialToString(self):
try:
return self._CEncodePartial()
except (NotImplementedError, AttributeError):
e = Encoder()
self.OutputPartial(e)
return e.buffer().tobytes()
def _CEncode(self):
raise NotImplementedError
def _CEncodePartial(self):
raise NotImplementedError
def ParseFromString(self, s):
self.Clear()
self.MergeFromString(s)
def ParsePartialFromString(self, s):
self.Clear()
self.MergePartialFromString(s)
def MergeFromString(self, s):
self.MergePartialFromString(s)
dbg = []
if not self.IsInitialized(dbg):
raise ProtocolBufferDecodeError('\n\t'.join(dbg))
def MergePartialFromString(self, s):
try:
self._CMergeFromString(s)
except (NotImplementedError, AttributeError):
a = array.array('B')
a.frombytes(six.ensure_binary(s))
d = Decoder(a, 0, len(a))
self.TryMerge(d)
def _CMergeFromString(self, s):
raise NotImplementedError
def __getstate__(self):
return self.Encode()
def __setstate__(self, contents_):
self.__init__(contents=contents_)
def sendCommand(self, server, url, response, follow_redirects=1,
secure=0, keyfile=None, certfile=None):
data = self.Encode()
if secure:
if keyfile and certfile:
conn = http_client.HTTPSConnection(server, key_file=keyfile,
cert_file=certfile)
else:
conn = http_client.HTTPSConnection(server)
else:
conn = http_client.HTTPConnection(server)
conn.putrequest("POST", url)
conn.putheader("Content-Length", "%d" %len(data))
conn.endheaders()
conn.send(data)
resp = conn.getresponse()
if follow_redirects > 0 and resp.status == 302:
m = URL_RE.match(resp.getheader('Location'))
if m:
protocol, server, url = m.groups()
return self.sendCommand(server, url, response,
follow_redirects=follow_redirects - 1,
secure=(protocol == 'https'),
keyfile=keyfile,
certfile=certfile)
if resp.status != 200:
raise ProtocolBufferReturnError(resp.status)
if response is not None:
response.ParseFromString(resp.read())
return response
def sendSecureCommand(self, server, keyfile, certfile, url, response,
follow_redirects=1):
return self.sendCommand(server, url, response,
follow_redirects=follow_redirects,
secure=1, keyfile=keyfile, certfile=certfile)
def __str__(self, prefix="", printElemNumber=0):
raise NotImplementedError
def ToASCII(self):
return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII)
def ToShortASCII(self):
return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII)
_NUMERIC_ASCII = 0
_SYMBOLIC_SHORT_ASCII = 1
_SYMBOLIC_FULL_ASCII = 2
def _CToASCII(self, output_format):
raise NotImplementedError
def ParseASCII(self, ascii_string):
raise NotImplementedError
def ParseASCIIIgnoreUnknown(self, ascii_string):
raise NotImplementedError
def Equals(self, other):
raise NotImplementedError
def __eq__(self, other):
if other.__class__ is self.__class__:
return self.Equals(other)
return NotImplemented
def __ne__(self, other):
if other.__class__ is self.__class__:
return not self.Equals(other)
return NotImplemented
def Output(self, e):
dbg = []
if not self.IsInitialized(dbg):
raise ProtocolBufferEncodeError('\n\t'.join(dbg))
self.OutputUnchecked(e)
return
def OutputUnchecked(self, e):
raise NotImplementedError
def OutputPartial(self, e):
raise NotImplementedError
def Parse(self, d):
self.Clear()
self.Merge(d)
return
def Merge(self, d):
self.TryMerge(d)
dbg = []
if not self.IsInitialized(dbg):
raise ProtocolBufferDecodeError('\n\t'.join(dbg))
return
def TryMerge(self, d):
raise NotImplementedError
def CopyFrom(self, pb):
if (pb == self): return
self.Clear()
self.MergeFrom(pb)
def MergeFrom(self, pb):
raise NotImplementedError
def lengthVarInt32(self, n):
return self.lengthVarInt64(n)
def lengthVarInt64(self, n):
if n < 0:
return 10
result = 0
while 1:
result += 1
n >>= 7
if n == 0:
break
return result
def lengthString(self, n):
return self.lengthVarInt32(n) + n
def DebugFormat(self, value):
return "%s" % value
def DebugFormatInt32(self, value):
if (value <= -2000000000 or value >= 2000000000):
return self.DebugFormatFixed32(value)
return "%d" % value
def DebugFormatInt64(self, value):
if (value <= -20000000000000 or value >= 20000000000000):
return self.DebugFormatFixed64(value)
return "%d" % value
def DebugFormatString(self, value):
def escape(c):
o = ord(c)
if o == 10: return r"\n"
if o == 39: return r"\'"
if o == 34: return r'\"'
if o == 92: return r"\\"
if o >= 127 or o < 32: return "\\%03o" % o
return c
return '"' + "".join(escape(c) for c in value) + '"'
def DebugFormatFloat(self, value):
return "%ff" % value
def DebugFormatFixed32(self, value):
if (value < 0): value += (1<<32)
return "0x%x" % value
def DebugFormatFixed64(self, value):
if (value < 0): value += (1<<64)
return "0x%x" % value
def DebugFormatBool(self, value):
if value:
return "true"
else:
return "false"
TYPE_DOUBLE = 1
TYPE_FLOAT = 2
TYPE_INT64 = 3
TYPE_UINT64 = 4
TYPE_INT32 = 5
TYPE_FIXED64 = 6
TYPE_FIXED32 = 7
TYPE_BOOL = 8
TYPE_STRING = 9
TYPE_GROUP = 10
TYPE_FOREIGN = 11
_TYPE_TO_DEBUG_STRING = {
TYPE_INT32: ProtocolMessage.DebugFormatInt32,
TYPE_INT64: ProtocolMessage.DebugFormatInt64,
TYPE_UINT64: ProtocolMessage.DebugFormatInt64,
TYPE_FLOAT: ProtocolMessage.DebugFormatFloat,
TYPE_STRING: ProtocolMessage.DebugFormatString,
TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32,
TYPE_FIXED64: ProtocolMessage.DebugFormatFixed64,
TYPE_BOOL: ProtocolMessage.DebugFormatBool }
class Encoder:
NUMERIC = 0
DOUBLE = 1
STRING = 2
STARTGROUP = 3
ENDGROUP = 4
FLOAT = 5
MAX_TYPE = 6
def __init__(self):
self.buf = array.array('B')
return
def buffer(self):
return self.buf
def put8(self, v):
if v < 0 or v >= (1<<8): raise ProtocolBufferEncodeError("u8 too big")
self.buf.append(v & 255)
return
def put16(self, v):
if v < 0 or v >= (1<<16): raise ProtocolBufferEncodeError("u16 too big")
self.buf.append((v >> 0) & 255)
self.buf.append((v >> 8) & 255)
return
def put32(self, v):
if v < 0 or v >= (1<<32): raise ProtocolBufferEncodeError("u32 too big")
self.buf.append((v >> 0) & 255)
self.buf.append((v >> 8) & 255)
self.buf.append((v >> 16) & 255)
self.buf.append((v >> 24) & 255)
return
def put64(self, v):
if v < 0 or v >= (1<<64): raise ProtocolBufferEncodeError("u64 too big")
self.buf.append((v >> 0) & 255)
self.buf.append((v >> 8) & 255)
self.buf.append((v >> 16) & 255)
self.buf.append((v >> 24) & 255)
self.buf.append((v >> 32) & 255)
self.buf.append((v >> 40) & 255)
self.buf.append((v >> 48) & 255)
self.buf.append((v >> 56) & 255)
return
def putVarInt32(self, v):
buf_append = self.buf.append
if v & 127 == v:
buf_append(v)
return
if v >= 0x80000000 or v < -0x80000000:
raise ProtocolBufferEncodeError("int32 too big")
if v < 0:
v += 0x10000000000000000
while True:
bits = v & 127
v >>= 7
if v:
bits |= 128
buf_append(bits)
if not v:
break
return
def putVarInt64(self, v):
buf_append = self.buf.append
if v >= 0x8000000000000000 or v < -0x8000000000000000:
raise ProtocolBufferEncodeError("int64 too big")
if v < 0:
v += 0x10000000000000000
while True:
bits = v & 127
v >>= 7
if v:
bits |= 128
buf_append(bits)
if not v:
break
return
def putVarUint64(self, v):
buf_append = self.buf.append
if v < 0 or v >= 0x10000000000000000:
raise ProtocolBufferEncodeError("uint64 too big")
while True:
bits = v & 127
v >>= 7
if v:
bits |= 128
buf_append(bits)
if not v:
break
return
def putFloat(self, v):
a = array.array('B')
a.frombytes(struct.pack("<f", v))
self.buf.extend(a)
return
def putDouble(self, v):
a = array.array('B')
a.frombytes(struct.pack("<d", v))
self.buf.extend(a)
return
def putBoolean(self, v):
if v:
self.buf.append(1)
else:
self.buf.append(0)
return
def putPrefixedString(self, v):
v = six.ensure_binary(v)
self.putVarInt32(len(v))
self.buf.frombytes(v)
def putRawString(self, v):
self.buf.frombytes(six.ensure_binary(v))
_TYPE_TO_METHOD = {
TYPE_DOUBLE: putDouble,
TYPE_FLOAT: putFloat,
TYPE_FIXED64: put64,
TYPE_FIXED32: put32,
TYPE_INT32: putVarInt32,
TYPE_INT64: putVarInt64,
TYPE_UINT64: putVarUint64,
TYPE_BOOL: putBoolean,
TYPE_STRING: putPrefixedString }
_TYPE_TO_BYTE_SIZE = {
TYPE_DOUBLE: 8,
TYPE_FLOAT: 4,
TYPE_FIXED64: 8,
TYPE_FIXED32: 4,
TYPE_BOOL: 1 }
class Decoder:
def __init__(self, buf, idx, limit):
self.buf = buf
self.idx = idx
self.limit = limit
return
def avail(self):
return self.limit - self.idx
def buffer(self):
return self.buf
def pos(self):
return self.idx
def skip(self, n):
if self.idx + n > self.limit: raise ProtocolBufferDecodeError("truncated")
self.idx += n
return
def skipData(self, tag):
t = tag & 7
if t == Encoder.NUMERIC:
self.getVarInt64()
elif t == Encoder.DOUBLE:
self.skip(8)
elif t == Encoder.STRING:
n = self.getVarInt32()
self.skip(n)
elif t == Encoder.STARTGROUP:
while 1:
t = self.getVarInt32()
if (t & 7) == Encoder.ENDGROUP:
break
else:
self.skipData(t)
if (t - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP):
raise ProtocolBufferDecodeError("corrupted")
elif t == Encoder.ENDGROUP:
raise ProtocolBufferDecodeError("corrupted")
elif t == Encoder.FLOAT:
self.skip(4)
else:
raise ProtocolBufferDecodeError("corrupted")
def get8(self):
if self.idx >= self.limit: raise ProtocolBufferDecodeError("truncated")
c = self.buf[self.idx]
self.idx += 1
return c
def get16(self):
if self.idx + 2 > self.limit: raise ProtocolBufferDecodeError("truncated")
c = self.buf[self.idx]
d = self.buf[self.idx + 1]
self.idx += 2
return (d << 8) | c
def get32(self):
if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
c = self.buf[self.idx]
d = self.buf[self.idx + 1]
e = self.buf[self.idx + 2]
f = self.buf[self.idx + 3]
self.idx += 4
return (f << 24) | (e << 16) | (d << 8) | c
def get64(self):
if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
c = self.buf[self.idx]
d = self.buf[self.idx + 1]
e = self.buf[self.idx + 2]
f = self.buf[self.idx + 3]
g = self.buf[self.idx + 4]
h = self.buf[self.idx + 5]
i = self.buf[self.idx + 6]
j = self.buf[self.idx + 7]
self.idx += 8
return ((j << 56) | (i << 48) | (h << 40) | (g << 32) | (f << 24)
| (e << 16) | (d << 8) | c)
def getVarInt32(self):
b = self.get8()
if not (b & 128):
return b
result = 0
shift = 0
while 1:
result |= ((b & 127) << shift)
shift += 7
if not (b & 128):
if result >= 0x10000000000000000:
raise ProtocolBufferDecodeError("corrupted")
break
if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
b = self.get8()
if result >= 0x8000000000000000:
result -= 0x10000000000000000
if result >= 0x80000000 or result < -0x80000000:
raise ProtocolBufferDecodeError("corrupted")
return result
def getVarInt64(self):
result = self.getVarUint64()
if result >= (1 << 63):
result -= (1 << 64)
return result
def getVarUint64(self):
result = 0
shift = 0
while 1:
if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
b = self.get8()
result |= ((b & 127) << shift)
shift += 7
if not (b & 128):
if result >= (1 << 64): raise ProtocolBufferDecodeError("corrupted")
return result
return result
def getFloat(self):
if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
a = self.buf[self.idx:self.idx+4]
self.idx += 4
return struct.unpack("<f", a)[0]
def getDouble(self):
if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
a = self.buf[self.idx:self.idx+8]
self.idx += 8
return struct.unpack("<d", a)[0]
def getBoolean(self):
b = self.get8()
if b != 0 and b != 1: raise ProtocolBufferDecodeError("corrupted")
return b
def getPrefixedString(self):
length = self.getVarInt32()
if self.idx + length > self.limit:
raise ProtocolBufferDecodeError("truncated")
r = self.buf[self.idx : self.idx + length]
self.idx += length
return r.tobytes()
def getRawString(self):
r = self.buf[self.idx:self.limit]
self.idx = self.limit
return r.tobytes()
_TYPE_TO_METHOD = {
TYPE_DOUBLE: getDouble,
TYPE_FLOAT: getFloat,
TYPE_FIXED64: get64,
TYPE_FIXED32: get32,
TYPE_INT32: getVarInt32,
TYPE_INT64: getVarInt64,
TYPE_UINT64: getVarUint64,
TYPE_BOOL: getBoolean,
TYPE_STRING: getPrefixedString }
class ExtensionIdentifier(object):
__slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated',
'default', 'containing_cls', 'composite_cls', 'message_name')
def __init__(self, full_name, number, field_type, wire_tag, is_repeated,
default):
self.full_name = full_name
self.number = number
self.field_type = field_type
self.wire_tag = wire_tag
self.is_repeated = is_repeated
self.default = default
class ExtendableProtocolMessage(ProtocolMessage):
def HasExtension(self, extension):
self._VerifyExtensionIdentifier(extension)
return extension in self._extension_fields
def ClearExtension(self, extension):
self._VerifyExtensionIdentifier(extension)
if extension in self._extension_fields:
del self._extension_fields[extension]
def GetExtension(self, extension, index=None):
self._VerifyExtensionIdentifier(extension)
if extension in self._extension_fields:
result = self._extension_fields[extension]
else:
if extension.is_repeated:
result = []
elif extension.composite_cls:
result = extension.composite_cls()
else:
result = extension.default
if extension.is_repeated:
result = result[index]
return result
def SetExtension(self, extension, *args):
self._VerifyExtensionIdentifier(extension)
if extension.composite_cls:
raise TypeError(
'Cannot assign to extension "%s" because it is a composite type.' %
extension.full_name)
if extension.is_repeated:
try:
index, value = args
except ValueError:
raise TypeError(
"SetExtension(extension, index, value) for repeated extension "
"takes exactly 4 arguments: (%d given)" % (len(args) + 2))
self._extension_fields[extension][index] = value
else:
try:
(value,) = args
except ValueError:
raise TypeError(
"SetExtension(extension, value) for singular extension "
"takes exactly 3 arguments: (%d given)" % (len(args) + 2))
self._extension_fields[extension] = value
def MutableExtension(self, extension, index=None):
self._VerifyExtensionIdentifier(extension)
if extension.composite_cls is None:
raise TypeError(
'MutableExtension() cannot be applied to "%s", because it is not a '
'composite type.' % extension.full_name)
if extension.is_repeated:
if index is None:
raise TypeError(
'MutableExtension(extension, index) for repeated extension '
'takes exactly 2 arguments: (1 given)')
return self.GetExtension(extension, index)
if extension in self._extension_fields:
return self._extension_fields[extension]
else:
result = extension.composite_cls()
self._extension_fields[extension] = result
return result
def ExtensionList(self, extension):
self._VerifyExtensionIdentifier(extension)
if not extension.is_repeated:
raise TypeError(
'ExtensionList() cannot be applied to "%s", because it is not a '
'repeated extension.' % extension.full_name)
if extension in self._extension_fields:
return self._extension_fields[extension]
result = []
self._extension_fields[extension] = result
return result
def ExtensionSize(self, extension):
self._VerifyExtensionIdentifier(extension)
if not extension.is_repeated:
raise TypeError(
'ExtensionSize() cannot be applied to "%s", because it is not a '
'repeated extension.' % extension.full_name)
if extension in self._extension_fields:
return len(self._extension_fields[extension])
return 0
def AddExtension(self, extension, value=None):
self._VerifyExtensionIdentifier(extension)
if not extension.is_repeated:
raise TypeError(
'AddExtension() cannot be applied to "%s", because it is not a '
'repeated extension.' % extension.full_name)
if extension in self._extension_fields:
field = self._extension_fields[extension]
else:
field = []
self._extension_fields[extension] = field
if extension.composite_cls:
if value is not None:
raise TypeError(
'value must not be set in AddExtension() for "%s", because it is '
'a message type extension. Set values on the returned message '
'instead.' % extension.full_name)
msg = extension.composite_cls()
field.append(msg)
return msg
field.append(value)
def _VerifyExtensionIdentifier(self, extension):
if extension.containing_cls != self.__class__:
raise TypeError("Containing type of %s is %s, but not %s."
% (extension.full_name,
extension.containing_cls.__name__,
self.__class__.__name__))
def _MergeExtensionFields(self, x):
for ext, val in x._extension_fields.items():
if ext.is_repeated:
for single_val in val:
if ext.composite_cls is None:
self.AddExtension(ext, single_val)
else:
self.AddExtension(ext).MergeFrom(single_val)
else:
if ext.composite_cls is None:
self.SetExtension(ext, val)
else:
self.MutableExtension(ext).MergeFrom(val)
def _ListExtensions(self):
return sorted(
(ext for ext in self._extension_fields
if (not ext.is_repeated) or self.ExtensionSize(ext) > 0),
key=lambda item: item.number)
def _ExtensionEquals(self, x):
extensions = self._ListExtensions()
if extensions != x._ListExtensions():
return False
for ext in extensions:
if ext.is_repeated:
if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False
for e1, e2 in itertools.izip(self.ExtensionList(ext),
x.ExtensionList(ext)):
if e1 != e2: return False
else:
if self.GetExtension(ext) != x.GetExtension(ext): return False
return True
def _OutputExtensionFields(self, out, partial, extensions, start_index,
end_field_number):
def OutputSingleField(ext, value):
out.putVarInt32(ext.wire_tag)
if ext.field_type == TYPE_GROUP:
if partial:
value.OutputPartial(out)
else:
value.OutputUnchecked(out)
out.putVarInt32(ext.wire_tag + 1)
elif ext.field_type == TYPE_FOREIGN:
if partial:
out.putVarInt32(value.ByteSizePartial())
value.OutputPartial(out)
else:
out.putVarInt32(value.ByteSize())
value.OutputUnchecked(out)
else:
Encoder._TYPE_TO_METHOD[ext.field_type](out, value)
for ext_index, ext in enumerate(
itertools.islice(extensions, start_index, None), start=start_index):
if ext.number >= end_field_number:
return ext_index
if ext.is_repeated:
for field in self._extension_fields[ext]:
OutputSingleField(ext, field)
else:
OutputSingleField(ext, self._extension_fields[ext])
return len(extensions)
def _ParseOneExtensionField(self, wire_tag, d):
number = wire_tag >> 3
if number in self._extensions_by_field_number:
ext = self._extensions_by_field_number[number]
if wire_tag != ext.wire_tag:
return
if ext.field_type == TYPE_FOREIGN:
length = d.getVarInt32()
tmp = Decoder(d.buffer(), d.pos(), d.pos() + length)
if ext.is_repeated:
self.AddExtension(ext).TryMerge(tmp)
else:
self.MutableExtension(ext).TryMerge(tmp)
d.skip(length)
elif ext.field_type == TYPE_GROUP:
if ext.is_repeated:
self.AddExtension(ext).TryMerge(d)
else:
self.MutableExtension(ext).TryMerge(d)
else:
value = Decoder._TYPE_TO_METHOD[ext.field_type](d)
if ext.is_repeated:
self.AddExtension(ext, value)
else:
self.SetExtension(ext, value)
else:
d.skipData(wire_tag)
def _ExtensionByteSize(self, partial):
size = 0
for extension, value in self._extension_fields.iteritems():
ftype = extension.field_type
tag_size = self.lengthVarInt64(extension.wire_tag)
if ftype == TYPE_GROUP:
tag_size *= 2
if extension.is_repeated:
size += tag_size * len(value)
for single_value in value:
size += self._FieldByteSize(ftype, single_value, partial)
else:
size += tag_size + self._FieldByteSize(ftype, value, partial)
return size
def _FieldByteSize(self, ftype, value, partial):
size = 0
if ftype == TYPE_STRING:
size = self.lengthString(len(value))
elif ftype == TYPE_FOREIGN or ftype == TYPE_GROUP:
if partial:
size = self.lengthString(value.ByteSizePartial())
else:
size = self.lengthString(value.ByteSize())
elif ftype == TYPE_INT64 or ftype == TYPE_UINT64 or ftype == TYPE_INT32:
size = self.lengthVarInt64(value)
else:
if ftype in Encoder._TYPE_TO_BYTE_SIZE:
size = Encoder._TYPE_TO_BYTE_SIZE[ftype]
else:
raise AssertionError(
'Extension type %d is not recognized.' % ftype)
return size
def _ExtensionDebugString(self, prefix, printElemNumber):
res = ''
extensions = self._ListExtensions()
for extension in extensions:
value = self._extension_fields[extension]
if extension.is_repeated:
cnt = 0
for e in value:
elm=""
if printElemNumber: elm = "(%d)" % cnt
if extension.composite_cls is not None:
res += prefix + "[%s%s] {\n" % (extension.full_name, elm)
res += e.__str__(prefix + " ", printElemNumber)
res += prefix + "}\n"
else:
if extension.composite_cls is not None:
res += prefix + "[%s] {\n" % extension.full_name
res += value.__str__(
prefix + " ", printElemNumber)
res += prefix + "}\n"
else:
if extension.field_type in _TYPE_TO_DEBUG_STRING:
text_value = _TYPE_TO_DEBUG_STRING[
extension.field_type](self, value)
else:
text_value = self.DebugFormat(value)
res += prefix + "[%s]: %s\n" % (extension.full_name, text_value)
return res
@staticmethod
def _RegisterExtension(cls, extension, composite_cls=None):
extension.containing_cls = cls
extension.composite_cls = composite_cls
if composite_cls is not None:
extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME
actual_handle = cls._extensions_by_field_number.setdefault(
extension.number, extension)
if actual_handle is not extension:
raise AssertionError(
'Extensions "%s" and "%s" both try to extend message type "%s" with '
'field number %d.' %
(extension.full_name, actual_handle.full_name,
cls.__name__, extension.number))