blob: a45183d147c45fd734a7f43cd7a302e76313549b [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2007 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import array
import itertools
import re
import ssl
import struct

import six
from six.moves import http_client
try:
  # Reuse proto1's exception classes when that module is importable, so both
  # modules raise (and callers catch) the very same types.
  from google.net.proto import proto1
except ImportError:
  # Stand-alone fallbacks for when google.net.proto.proto1 is unavailable.
  class ProtocolBufferDecodeError(Exception):
    """Raised when bytes cannot be decoded into a (fully initialized) message."""

  class ProtocolBufferEncodeError(Exception):
    """Raised when a value or an uninitialized message cannot be encoded."""

  class ProtocolBufferReturnError(Exception):
    """Raised by ProtocolMessage.sendCommand on a non-200 HTTP status."""
else:
  ProtocolBufferDecodeError = proto1.ProtocolBufferDecodeError
  ProtocolBufferEncodeError = proto1.ProtocolBufferEncodeError
  ProtocolBufferReturnError = proto1.ProtocolBufferReturnError
59
# Public names exported by ``from <this module> import *``.
__all__ = [
    'ProtocolMessage',
    'Encoder',
    'Decoder',
    'ExtendableProtocolMessage',
    'ProtocolBufferDecodeError',
    'ProtocolBufferEncodeError',
    'ProtocolBufferReturnError',
]

# Splits an absolute http(s) URL into (scheme, server, path); used by
# ProtocolMessage.sendCommand to follow a redirect's Location header.
URL_RE = re.compile(r'^(https?)://([^/]+)(/.*)$')
67
68
class ProtocolMessage:
  """Abstract base class for proto1-style message classes.

  Subclasses override the methods below that raise NotImplementedError
  (construction, Clear, IsInitialized, OutputUnchecked, TryMerge, ...). This
  base builds the generic plumbing on top of them:
    * (de)serialization to and from bytes via Encoder/Decoder,
    * a minimal HTTP POST "RPC" helper (sendCommand),
    * equality operators delegating to Equals,
    * the value formatters used when producing debug strings.
  """

  def __init__(self, contents=None):
    """Creates a message, optionally parsed from serialized ``contents``.

    Subclasses must override; __setstate__ relies on this accepting the bytes
    produced by Encode().
    """
    raise NotImplementedError

  def Clear(self):
    """Resets the message to its empty state. Subclasses must override."""
    raise NotImplementedError

  def IsInitialized(self, debug_strs=None):
    """Returns True if the message is complete enough to encode/decode.

    When ``debug_strs`` is a list, a description of each problem is appended
    to it; Output() and Merge() join those into their exception messages.
    Subclasses must override.
    """
    raise NotImplementedError

  def Encode(self):
    """Returns the serialized message as bytes.

    Uses the optional _CEncode fast path when a subclass provides one,
    otherwise encodes with the pure-Python Encoder through Output().

    Raises:
      ProtocolBufferEncodeError: if the message is not initialized.
    """
    try:
      return self._CEncode()
    except (NotImplementedError, AttributeError):
      # No fast path available: fall back to the pure-Python encoder.
      e = Encoder()
      self.Output(e)
      return e.buffer().tobytes()

  def SerializeToString(self):
    """Alias of Encode()."""
    return self.Encode()

  def SerializePartialToString(self):
    """Like Encode(), but does not require the message to be initialized."""
    try:
      return self._CEncodePartial()
    except (NotImplementedError, AttributeError):
      e = Encoder()
      self.OutputPartial(e)
      return e.buffer().tobytes()

  def _CEncode(self):
    """Optional encoding fast-path hook (presumably a native implementation,
    judging by the name); the base raises NotImplementedError so Encode()
    falls back to Output()."""
    raise NotImplementedError

  def _CEncodePartial(self):
    """Partial-encoding counterpart of _CEncode; unimplemented in the base."""
    raise NotImplementedError

  def ParseFromString(self, s):
    """Replaces the contents with those decoded from ``s``.

    Raises:
      ProtocolBufferDecodeError: on malformed input or if the result is not
        initialized.
    """
    self.Clear()
    self.MergeFromString(s)

  def ParsePartialFromString(self, s):
    """Like ParseFromString, without the final initialization check."""
    self.Clear()
    self.MergePartialFromString(s)

  def MergeFromString(self, s):
    """Merges the message decoded from ``s`` into this one, then verifies it.

    Raises:
      ProtocolBufferDecodeError: on malformed input or if the merged message
        is not initialized.
    """
    self.MergePartialFromString(s)
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferDecodeError('\n\t'.join(dbg))

  def MergePartialFromString(self, s):
    """Merges the message decoded from ``s`` (str or bytes) into this one."""
    try:
      self._CMergeFromString(s)
    except (NotImplementedError, AttributeError):
      # Pure-Python fallback: wrap the bytes in a Decoder over the whole
      # buffer and let the subclass's TryMerge consume it.
      a = array.array('B')
      a.frombytes(six.ensure_binary(s))
      d = Decoder(a, 0, len(a))
      self.TryMerge(d)

  def _CMergeFromString(self, s):
    """Optional decoding fast-path hook; unimplemented in the base."""
    raise NotImplementedError

  def __getstate__(self):
    """Pickle support: the state is the encoded message."""
    return self.Encode()

  def __setstate__(self, contents_):
    """Pickle support: rebuilds the message from encoded bytes."""
    self.__init__(contents=contents_)

  def sendCommand(self, server, url, response, follow_redirects=1,
                  secure=0, keyfile=None, certfile=None):
    """POSTs this message to ``server``/``url`` and parses the reply.

    Args:
      server: host to connect to (``host`` or ``host:port``).
      url: request path.
      response: message whose ParseFromString() receives the reply body, or
        None to ignore the body.
      follow_redirects: how many HTTP 302 redirects may still be followed.
      secure: if true, use HTTPS.
      keyfile: client private-key file; used only with ``certfile``.
      certfile: client certificate file; used only with ``keyfile``.

    Returns:
      ``response``.

    Raises:
      ProtocolBufferReturnError: with the HTTP status when it is not 200
        (including a 302 that cannot or may no longer be followed).
    """
    data = self.Encode()
    if secure:
      if keyfile and certfile:
        # HTTPSConnection's key_file/cert_file arguments were removed in
        # Python 3.12; an SSL context carrying the client cert is equivalent.
        context = ssl.create_default_context()
        context.load_cert_chain(certfile, keyfile)
        conn = http_client.HTTPSConnection(server, context=context)
      else:
        conn = http_client.HTTPSConnection(server)
    else:
      conn = http_client.HTTPConnection(server)

    redirect = None
    try:
      conn.putrequest("POST", url)
      conn.putheader("Content-Length", "%d" % len(data))
      conn.endheaders()
      conn.send(data)
      resp = conn.getresponse()
      if follow_redirects > 0 and resp.status == 302:
        location = resp.getheader('Location')
        # A missing Location header is treated like an unparsable one.
        m = URL_RE.match(location) if location else None
        if m:
          redirect = m.groups()
      if redirect is None:
        if resp.status != 200:
          raise ProtocolBufferReturnError(resp.status)
        if response is not None:
          response.ParseFromString(resp.read())
    finally:
      # Always release the socket, including before following a redirect.
      conn.close()

    if redirect is not None:
      protocol, server, url = redirect
      return self.sendCommand(server, url, response,
                              follow_redirects=follow_redirects - 1,
                              secure=(protocol == 'https'),
                              keyfile=keyfile,
                              certfile=certfile)
    return response

  def sendSecureCommand(self, server, keyfile, certfile, url, response,
                        follow_redirects=1):
    """sendCommand() over HTTPS with the given client key/certificate."""
    return self.sendCommand(server, url, response,
                            follow_redirects=follow_redirects,
                            secure=1, keyfile=keyfile, certfile=certfile)

  def __str__(self, prefix="", printElemNumber=0):
    """Returns the debug text form. Subclasses must override."""
    raise NotImplementedError

  def ToASCII(self):
    """Returns the full symbolic ASCII form (via _CToASCII)."""
    return self._CToASCII(ProtocolMessage._SYMBOLIC_FULL_ASCII)

  def ToShortASCII(self):
    """Returns the short symbolic ASCII form (via _CToASCII)."""
    return self._CToASCII(ProtocolMessage._SYMBOLIC_SHORT_ASCII)

  # Output-format selectors accepted by _CToASCII.
  _NUMERIC_ASCII = 0
  _SYMBOLIC_SHORT_ASCII = 1
  _SYMBOLIC_FULL_ASCII = 2

  def _CToASCII(self, output_format):
    """Renders the message as ASCII in ``output_format``; unimplemented here."""
    raise NotImplementedError

  def ParseASCII(self, ascii_string):
    """Parses the ASCII form; unimplemented in the base."""
    raise NotImplementedError

  def ParseASCIIIgnoreUnknown(self, ascii_string):
    """Like ParseASCII, ignoring unknown fields; unimplemented in the base."""
    raise NotImplementedError

  def Equals(self, other):
    """Returns True if ``other`` has the same contents. Subclasses override."""
    raise NotImplementedError

  def __eq__(self, other):
    # Only messages of exactly the same class are compared field-by-field;
    # anything else defers to Python's default handling.
    if other.__class__ is self.__class__:
      return self.Equals(other)
    return NotImplemented

  def __ne__(self, other):
    if other.__class__ is self.__class__:
      return not self.Equals(other)
    return NotImplemented

  def Output(self, e):
    """Writes the message to Encoder ``e`` after checking IsInitialized().

    Raises:
      ProtocolBufferEncodeError: if the message is not initialized.
    """
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferEncodeError('\n\t'.join(dbg))
    self.OutputUnchecked(e)

  def OutputUnchecked(self, e):
    """Writes the message to ``e`` without checks. Subclasses must override."""
    raise NotImplementedError

  def OutputPartial(self, e):
    """Writes a possibly-uninitialized message to ``e``; subclass hook."""
    raise NotImplementedError

  def Parse(self, d):
    """Replaces the contents with the message read from Decoder ``d``."""
    self.Clear()
    self.Merge(d)

  def Merge(self, d):
    """Merges the message read from ``d``, then checks IsInitialized().

    Raises:
      ProtocolBufferDecodeError: if the merged message is not initialized.
    """
    self.TryMerge(d)
    dbg = []
    if not self.IsInitialized(dbg):
      raise ProtocolBufferDecodeError('\n\t'.join(dbg))

  def TryMerge(self, d):
    """Merges fields read from ``d`` without checks; subclass hook."""
    raise NotImplementedError

  def CopyFrom(self, pb):
    """Makes this message a copy of ``pb``."""
    # Identity guard: copying onto itself would Clear() the source first.
    # (An `==` test here would run a full, needless Equals() comparison.)
    if pb is self:
      return
    self.Clear()
    self.MergeFrom(pb)

  def MergeFrom(self, pb):
    """Merges the contents of ``pb`` into this message; subclass hook."""
    raise NotImplementedError

  def lengthVarInt32(self, n):
    """Encoded size of ``n`` as an int32 varint (negatives take 10 bytes)."""
    return self.lengthVarInt64(n)

  def lengthVarInt64(self, n):
    """Number of bytes the varint encoding of ``n`` occupies.

    Negative values are written as 64-bit two's complement, i.e. 10 bytes.
    """
    if n < 0:
      return 10
    result = 0
    while 1:
      result += 1
      n >>= 7
      if n == 0:
        break
    return result

  def lengthString(self, n):
    """Size of a length-prefixed field whose payload is ``n`` bytes long."""
    return self.lengthVarInt32(n) + n

  def DebugFormat(self, value):
    """Generic debug formatter: ``"%s" % value``."""
    return "%s" % value

  def DebugFormatInt32(self, value):
    """Decimal, except magnitudes >= 2e9 are shown as unsigned 32-bit hex."""
    if (value <= -2000000000 or value >= 2000000000):
      return self.DebugFormatFixed32(value)
    return "%d" % value

  def DebugFormatInt64(self, value):
    """Decimal, except magnitudes >= 2e13 are shown as unsigned 64-bit hex."""
    if (value <= -20000000000000 or value >= 20000000000000):
      return self.DebugFormatFixed64(value)
    return "%d" % value

  def DebugFormatString(self, value):
    """Returns ``value`` as a double-quoted, C-style escaped literal.

    Accepts text (``str``) as well as binary data (``bytes``/``bytearray``,
    which is what Decoder.getPrefixedString returns). Newline, both quote
    characters and backslash get backslash escapes; any other code point
    below 32 or at/above 127 becomes an octal escape of at least 3 digits.
    """
    def escape(o):
      if o == 10: return r"\n"
      if o == 39: return r"\'"
      if o == 34: return r'\"'
      if o == 92: return r"\\"
      if o >= 127 or o < 32: return "\\%03o" % o
      return chr(o)

    if isinstance(value, (bytes, bytearray)):
      codes = value  # iterating bytes already yields integer byte values
    else:
      codes = (ord(c) for c in value)
    return '"' + "".join(escape(o) for o in codes) + '"'

  def DebugFormatFloat(self, value):
    """Fixed-point with an ``f`` suffix, e.g. ``1.500000f``."""
    return "%ff" % value

  def DebugFormatFixed32(self, value):
    """Lower-case hex of ``value`` as an unsigned 32-bit number."""
    if (value < 0): value += (1 << 32)
    return "0x%x" % value

  def DebugFormatFixed64(self, value):
    """Lower-case hex of ``value`` as an unsigned 64-bit number."""
    if (value < 0): value += (1 << 64)
    return "0x%x" % value

  def DebugFormatBool(self, value):
    """``"true"`` or ``"false"`` according to the truthiness of ``value``."""
    if value:
      return "true"
    else:
      return "false"
441
442
# Field type codes: ExtensionIdentifier.field_type holds one of these, and
# they key the per-type encode/decode/size/debug tables.
TYPE_DOUBLE = 1
TYPE_FLOAT = 2
TYPE_INT64 = 3
TYPE_UINT64 = 4
TYPE_INT32 = 5
TYPE_FIXED64 = 6
TYPE_FIXED32 = 7
TYPE_BOOL = 8
TYPE_STRING = 9
TYPE_GROUP = 10    # sub-message delimited by start/end tags
TYPE_FOREIGN = 11  # length-prefixed embedded message


# Debug formatter per non-message field type. Types missing here (for
# example TYPE_DOUBLE) are rendered with ProtocolMessage.DebugFormat by
# ExtendableProtocolMessage._ExtensionDebugString.
_TYPE_TO_DEBUG_STRING = {
    TYPE_INT32: ProtocolMessage.DebugFormatInt32,
    TYPE_INT64: ProtocolMessage.DebugFormatInt64,
    TYPE_UINT64: ProtocolMessage.DebugFormatInt64,  # same decimal/hex rule
    TYPE_FLOAT: ProtocolMessage.DebugFormatFloat,
    TYPE_STRING: ProtocolMessage.DebugFormatString,
    TYPE_FIXED32: ProtocolMessage.DebugFormatFixed32,
    TYPE_FIXED64: ProtocolMessage.DebugFormatFixed64,
    TYPE_BOOL: ProtocolMessage.DebugFormatBool,
}
465
466
467
class Encoder:
  """Serializes primitive values into a growable ``array.array('B')``.

  Each put* method appends exactly one value in wire format; buffer()
  returns everything written so far.
  """

  # Wire types: the low three bits of a field tag.
  NUMERIC = 0
  DOUBLE = 1
  STRING = 2
  STARTGROUP = 3
  ENDGROUP = 4
  FLOAT = 5
  MAX_TYPE = 6

  def __init__(self):
    self.buf = array.array('B')

  def buffer(self):
    """Returns the output array itself (not a copy)."""
    return self.buf

  def _PutLittleEndian(self, v, num_bytes):
    """Appends the low ``num_bytes`` bytes of ``v``, least significant first."""
    self.buf.extend((v >> shift) & 255 for shift in range(0, 8 * num_bytes, 8))

  def _PutVarint(self, v):
    """Appends non-negative ``v`` as a base-128 varint.

    Seven bits per byte, lowest group first; every byte except the last has
    its high (continuation) bit set.
    """
    append = self.buf.append
    while v > 127:
      append((v & 127) | 128)
      v >>= 7
    append(v)

  def put8(self, v):
    """Appends one byte; ``v`` must lie in [0, 2**8)."""
    if v < 0 or v >= (1 << 8):
      raise ProtocolBufferEncodeError("u8 too big")
    self._PutLittleEndian(v, 1)

  def put16(self, v):
    """Appends ``v`` as 2 little-endian bytes; ``v`` in [0, 2**16)."""
    if v < 0 or v >= (1 << 16):
      raise ProtocolBufferEncodeError("u16 too big")
    self._PutLittleEndian(v, 2)

  def put32(self, v):
    """Appends ``v`` as 4 little-endian bytes; ``v`` in [0, 2**32)."""
    if v < 0 or v >= (1 << 32):
      raise ProtocolBufferEncodeError("u32 too big")
    self._PutLittleEndian(v, 4)

  def put64(self, v):
    """Appends ``v`` as 8 little-endian bytes; ``v`` in [0, 2**64)."""
    if v < 0 or v >= (1 << 64):
      raise ProtocolBufferEncodeError("u64 too big")
    self._PutLittleEndian(v, 8)

  def putVarInt32(self, v):
    """Appends a signed 32-bit varint.

    Negative values are sign-extended to 64 bits and therefore occupy
    10 bytes.
    """
    if v & 127 == v:
      # Fast path: 0..127 is a single byte without a continuation bit.
      self.buf.append(v)
      return
    if v >= 0x80000000 or v < -0x80000000:
      raise ProtocolBufferEncodeError("int32 too big")
    if v < 0:
      v += 0x10000000000000000
    self._PutVarint(v)

  def putVarInt64(self, v):
    """Appends a signed 64-bit varint (negatives as two's complement)."""
    if v >= 0x8000000000000000 or v < -0x8000000000000000:
      raise ProtocolBufferEncodeError("int64 too big")
    if v < 0:
      v += 0x10000000000000000
    self._PutVarint(v)

  def putVarUint64(self, v):
    """Appends an unsigned 64-bit varint; ``v`` in [0, 2**64)."""
    if v < 0 or v >= 0x10000000000000000:
      raise ProtocolBufferEncodeError("uint64 too big")
    self._PutVarint(v)

  def putFloat(self, v):
    """Appends ``v`` as a little-endian IEEE-754 single (4 bytes)."""
    self.buf.frombytes(struct.pack("<f", v))

  def putDouble(self, v):
    """Appends ``v`` as a little-endian IEEE-754 double (8 bytes)."""
    self.buf.frombytes(struct.pack("<d", v))

  def putBoolean(self, v):
    """Appends a single byte: 1 if ``v`` is truthy, else 0."""
    self.buf.append(1 if v else 0)

  def putPrefixedString(self, v):
    """Appends ``v`` (text is converted to bytes first) with a varint length."""
    payload = six.ensure_binary(v)
    self.putVarInt32(len(payload))
    self.buf.frombytes(payload)

  def putRawString(self, v):
    """Appends the bytes of ``v`` with no length prefix."""
    self.buf.frombytes(six.ensure_binary(v))

  # Encoder method for each field type code.
  _TYPE_TO_METHOD = {
      TYPE_DOUBLE: putDouble,
      TYPE_FLOAT: putFloat,
      TYPE_FIXED64: put64,
      TYPE_FIXED32: put32,
      TYPE_INT32: putVarInt32,
      TYPE_INT64: putVarInt64,
      TYPE_UINT64: putVarUint64,
      TYPE_BOOL: putBoolean,
      TYPE_STRING: putPrefixedString,
  }

  # Encoded size of each fixed-width field type.
  _TYPE_TO_BYTE_SIZE = {
      TYPE_DOUBLE: 8,
      TYPE_FLOAT: 4,
      TYPE_FIXED64: 8,
      TYPE_FIXED32: 4,
      TYPE_BOOL: 1,
  }
621
class Decoder:
  """Reads wire-format values from ``buf[idx:limit]``, advancing ``idx``.

  ``buf`` is normally an ``array.array('B')``. That is what
  ProtocolMessage.MergePartialFromString passes in, and string reads call
  ``.tobytes()`` on slices of it. Reads are bounds-checked against ``limit``
  rather than ``len(buf)``, so several Decoders can share one buffer while
  each sees only its own sub-range.
  """

  def __init__(self, buf, idx, limit):
    self.buf = buf      # shared byte buffer
    self.idx = idx      # current read position
    self.limit = limit  # exclusive end of the readable range

  def avail(self):
    """Returns the number of bytes left before ``limit``."""
    return self.limit - self.idx

  def buffer(self):
    """Returns the underlying buffer (shared, not copied)."""
    return self.buf

  def pos(self):
    """Returns the current read position within the buffer."""
    return self.idx

  def skip(self, n):
    """Advances past ``n`` bytes.

    Raises:
      ProtocolBufferDecodeError: "corrupted" if ``n`` is negative, or
        "truncated" if fewer than ``n`` bytes remain.
    """
    # A negative length read from the wire would move the cursor backwards
    # and can make callers re-read the same bytes forever.
    if n < 0: raise ProtocolBufferDecodeError("corrupted")
    if self.idx + n > self.limit: raise ProtocolBufferDecodeError("truncated")
    self.idx += n

  def skipData(self, tag):
    """Skips the value that follows ``tag``, according to its wire type.

    Groups are skipped recursively up to their matching ENDGROUP tag.

    Raises:
      ProtocolBufferDecodeError: on truncated input, a stray or mismatched
        ENDGROUP, or an unknown wire type.
    """
    wire_type = tag & 7
    if wire_type == Encoder.NUMERIC:
      self.getVarInt64()
    elif wire_type == Encoder.DOUBLE:
      self.skip(8)
    elif wire_type == Encoder.STRING:
      self.skip(self.getVarInt32())
    elif wire_type == Encoder.STARTGROUP:
      while True:
        inner = self.getVarInt32()
        if (inner & 7) == Encoder.ENDGROUP:
          break
        self.skipData(inner)
      # The end tag must carry the same field number as the start tag.
      if (inner - Encoder.ENDGROUP) != (tag - Encoder.STARTGROUP):
        raise ProtocolBufferDecodeError("corrupted")
    elif wire_type == Encoder.ENDGROUP:
      raise ProtocolBufferDecodeError("corrupted")
    elif wire_type == Encoder.FLOAT:
      self.skip(4)
    else:
      raise ProtocolBufferDecodeError("corrupted")

  def get8(self):
    """Reads one unsigned byte."""
    if self.idx >= self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    self.idx += 1
    return c

  def get16(self):
    """Reads a little-endian unsigned 16-bit integer."""
    if self.idx + 2 > self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    d = self.buf[self.idx + 1]
    self.idx += 2
    return (d << 8) | c

  def get32(self):
    """Reads a little-endian unsigned 32-bit integer."""
    if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    d = self.buf[self.idx + 1]
    e = self.buf[self.idx + 2]
    f = self.buf[self.idx + 3]
    self.idx += 4
    return (f << 24) | (e << 16) | (d << 8) | c

  def get64(self):
    """Reads a little-endian unsigned 64-bit integer."""
    if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
    c = self.buf[self.idx]
    d = self.buf[self.idx + 1]
    e = self.buf[self.idx + 2]
    f = self.buf[self.idx + 3]
    g = self.buf[self.idx + 4]
    h = self.buf[self.idx + 5]
    i = self.buf[self.idx + 6]
    j = self.buf[self.idx + 7]
    self.idx += 8
    return ((j << 56) | (i << 48) | (h << 40) | (g << 32) | (f << 24)
            | (e << 16) | (d << 8) | c)

  def getVarInt32(self):
    """Reads a varint and returns it as a signed 32-bit value.

    Negative int32s are written as 10-byte 64-bit two's complement (see
    Encoder.putVarInt32), so the raw value is first reinterpreted as int64
    and then range-checked.

    Raises:
      ProtocolBufferDecodeError: on truncation, a varint longer than 64
        bits, or a value outside the int32 range.
    """
    b = self.get8()
    if not (b & 128):
      return b  # fast path: single-byte value

    result = 0
    shift = 0
    while 1:
      result |= ((b & 127) << shift)
      shift += 7
      if not (b & 128):
        if result >= 0x10000000000000000:
          raise ProtocolBufferDecodeError("corrupted")
        break
      if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
      b = self.get8()

    if result >= 0x8000000000000000:
      result -= 0x10000000000000000
    if result >= 0x80000000 or result < -0x80000000:
      raise ProtocolBufferDecodeError("corrupted")
    return result

  def getVarInt64(self):
    """Reads a varint and reinterprets it as a signed 64-bit value."""
    result = self.getVarUint64()
    if result >= (1 << 63):
      result -= (1 << 64)
    return result

  def getVarUint64(self):
    """Reads an unsigned 64-bit varint.

    Raises:
      ProtocolBufferDecodeError: on truncation or a value of 2**64 or more.
    """
    result = 0
    shift = 0
    while 1:
      if shift >= 64: raise ProtocolBufferDecodeError("corrupted")
      b = self.get8()
      result |= ((b & 127) << shift)
      shift += 7
      if not (b & 128):
        if result >= (1 << 64): raise ProtocolBufferDecodeError("corrupted")
        return result

  def getFloat(self):
    """Reads a little-endian IEEE-754 single (4 bytes)."""
    if self.idx + 4 > self.limit: raise ProtocolBufferDecodeError("truncated")
    # unpack_from reads in place instead of copying a slice first.
    value = struct.unpack_from("<f", self.buf, self.idx)[0]
    self.idx += 4
    return value

  def getDouble(self):
    """Reads a little-endian IEEE-754 double (8 bytes)."""
    if self.idx + 8 > self.limit: raise ProtocolBufferDecodeError("truncated")
    value = struct.unpack_from("<d", self.buf, self.idx)[0]
    self.idx += 8
    return value

  def getBoolean(self):
    """Reads one byte that must be 0 or 1 and returns it as an int."""
    b = self.get8()
    if b != 0 and b != 1: raise ProtocolBufferDecodeError("corrupted")
    return b

  def getPrefixedString(self):
    """Reads a varint length followed by that many bytes; returns bytes.

    Raises:
      ProtocolBufferDecodeError: "corrupted" for a negative length,
        "truncated" if the payload runs past ``limit``.
    """
    length = self.getVarInt32()
    if length < 0:
      raise ProtocolBufferDecodeError("corrupted")
    if self.idx + length > self.limit:
      raise ProtocolBufferDecodeError("truncated")
    r = self.buf[self.idx : self.idx + length]
    self.idx += length
    return r.tobytes()

  def getRawString(self):
    """Returns every remaining byte up to ``limit`` and consumes them."""
    r = self.buf[self.idx:self.limit]
    self.idx = self.limit
    return r.tobytes()

  # Decoder method for each field type code.
  _TYPE_TO_METHOD = {
      TYPE_DOUBLE: getDouble,
      TYPE_FLOAT: getFloat,
      TYPE_FIXED64: get64,
      TYPE_FIXED32: get32,
      TYPE_INT32: getVarInt32,
      TYPE_INT64: getVarInt64,
      TYPE_UINT64: getVarUint64,
      TYPE_BOOL: getBoolean,
      TYPE_STRING: getPrefixedString,
  }
791
792
793
794
795
class ExtensionIdentifier(object):
  """Describes one extension field of an ExtendableProtocolMessage.

  ``containing_cls``, ``composite_cls`` and ``message_name`` start out as
  None. ExtendableProtocolMessage._RegisterExtension fills in the first two;
  it sets ``message_name`` only for message-typed extensions.
  """
  __slots__ = ('full_name', 'number', 'field_type', 'wire_tag', 'is_repeated',
               'default', 'containing_cls', 'composite_cls', 'message_name')

  def __init__(self, full_name, number, field_type, wire_tag, is_repeated,
               default):
    """
    Args:
      full_name: fully-qualified extension name, used in messages and output.
      number: field number of the extension.
      field_type: one of the module's TYPE_* codes.
      wire_tag: the complete tag (number and wire type) written before values.
      is_repeated: whether the extension holds a list of values.
      default: value GetExtension returns for an unset scalar extension.
    """
    self.full_name = full_name
    self.number = number
    self.field_type = field_type
    self.wire_tag = wire_tag
    self.is_repeated = is_repeated
    self.default = default
    # Initialise the remaining slots so reading them before registration
    # yields None instead of an AttributeError for an unset slot.
    self.containing_cls = None
    self.composite_cls = None
    self.message_name = None
807
class ExtendableProtocolMessage(ProtocolMessage):
  """ProtocolMessage with support for extension fields.

  Neither attribute below is created in this class; presumably generated
  subclasses provide them (confirm against the code generator):
    * ``self._extension_fields``: dict ExtensionIdentifier -> value (a list
      for repeated extensions).
    * ``cls._extensions_by_field_number``: dict field number ->
      ExtensionIdentifier, filled by _RegisterExtension.
  """

  def HasExtension(self, extension):
    """Returns True if ``extension`` currently has a stored value."""
    self._VerifyExtensionIdentifier(extension)
    return extension in self._extension_fields

  def ClearExtension(self, extension):
    """Removes any stored value for ``extension``; a no-op if unset."""
    self._VerifyExtensionIdentifier(extension)
    self._extension_fields.pop(extension, None)

  def GetExtension(self, extension, index=None):
    """Returns the value of ``extension`` without modifying this message.

    An unset singular extension yields a fresh (unstored) ``composite_cls()``
    for message types, or ``extension.default`` otherwise. A repeated
    extension needs ``index`` and behaves like an empty list when unset.
    """
    self._VerifyExtensionIdentifier(extension)
    if extension in self._extension_fields:
      result = self._extension_fields[extension]
    else:
      if extension.is_repeated:
        result = []
      elif extension.composite_cls:
        result = extension.composite_cls()
      else:
        result = extension.default
    if extension.is_repeated:
      result = result[index]
    return result

  def SetExtension(self, extension, *args):
    """Assigns a scalar extension.

    Call as SetExtension(ext, value) for a singular extension or
    SetExtension(ext, index, value) to replace an existing element of a
    repeated one.

    Raises:
      TypeError: for message-typed extensions (use MutableExtension /
        AddExtension) or a wrong number of arguments.
    """
    self._VerifyExtensionIdentifier(extension)
    if extension.composite_cls:
      raise TypeError(
          'Cannot assign to extension "%s" because it is a composite type.' %
          extension.full_name)
    if extension.is_repeated:
      try:
        index, value = args
      except ValueError:
        raise TypeError(
            "SetExtension(extension, index, value) for repeated extension "
            "takes exactly 4 arguments: (%d given)" % (len(args) + 2))
      self._extension_fields[extension][index] = value
    else:
      try:
        (value,) = args
      except ValueError:
        raise TypeError(
            "SetExtension(extension, value) for singular extension "
            "takes exactly 3 arguments: (%d given)" % (len(args) + 2))
      self._extension_fields[extension] = value

  def MutableExtension(self, extension, index=None):
    """Returns a modifiable message-typed extension value.

    A singular extension is created and stored on first access; a repeated
    extension requires ``index`` and returns that existing element.

    Raises:
      TypeError: if the extension is not message-typed, or ``index`` is
        missing for a repeated extension.
    """
    self._VerifyExtensionIdentifier(extension)
    if extension.composite_cls is None:
      raise TypeError(
          'MutableExtension() cannot be applied to "%s", because it is not a '
          'composite type.' % extension.full_name)
    if extension.is_repeated:
      if index is None:
        raise TypeError(
            'MutableExtension(extension, index) for repeated extension '
            'takes exactly 2 arguments: (1 given)')
      return self.GetExtension(extension, index)
    if extension in self._extension_fields:
      return self._extension_fields[extension]
    result = extension.composite_cls()
    self._extension_fields[extension] = result
    return result

  def ExtensionList(self, extension):
    """Returns the live list backing a repeated extension (created if absent).

    Raises:
      TypeError: if the extension is not repeated.
    """
    self._VerifyExtensionIdentifier(extension)
    if not extension.is_repeated:
      raise TypeError(
          'ExtensionList() cannot be applied to "%s", because it is not a '
          'repeated extension.' % extension.full_name)
    return self._extension_fields.setdefault(extension, [])

  def ExtensionSize(self, extension):
    """Returns the number of elements in a repeated extension (0 if unset).

    Raises:
      TypeError: if the extension is not repeated.
    """
    self._VerifyExtensionIdentifier(extension)
    if not extension.is_repeated:
      raise TypeError(
          'ExtensionSize() cannot be applied to "%s", because it is not a '
          'repeated extension.' % extension.full_name)
    return len(self._extension_fields.get(extension, ()))

  def AddExtension(self, extension, value=None):
    """Appends to a repeated extension.

    For message-typed extensions a new ``composite_cls()`` is appended and
    returned (``value`` must be None). Otherwise ``value`` is appended and
    None is returned.

    Raises:
      TypeError: if the extension is not repeated, or a value is supplied
        for a message-typed extension.
    """
    self._VerifyExtensionIdentifier(extension)
    if not extension.is_repeated:
      raise TypeError(
          'AddExtension() cannot be applied to "%s", because it is not a '
          'repeated extension.' % extension.full_name)
    field = self._extension_fields.setdefault(extension, [])

    if extension.composite_cls:
      if value is not None:
        raise TypeError(
            'value must not be set in AddExtension() for "%s", because it is '
            'a message type extension. Set values on the returned message '
            'instead.' % extension.full_name)
      msg = extension.composite_cls()
      field.append(msg)
      return msg

    field.append(value)

  def _VerifyExtensionIdentifier(self, extension):
    """Raises TypeError unless ``extension`` was registered for this class."""
    # getattr tolerates identifiers that were never registered (their
    # containing_cls is None or, for older identifiers, an unset slot).
    containing_cls = getattr(extension, 'containing_cls', None)
    if containing_cls != self.__class__:
      raise TypeError("Containing type of %s is %s, but not %s."
                      % (extension.full_name,
                         getattr(containing_cls, '__name__', containing_cls),
                         self.__class__.__name__))

  def _MergeExtensionFields(self, x):
    """Merges every extension value of message ``x`` into this message."""
    for ext, val in x._extension_fields.items():
      if ext.is_repeated:
        for single_val in val:
          if ext.composite_cls is None:
            self.AddExtension(ext, single_val)
          else:
            self.AddExtension(ext).MergeFrom(single_val)
      else:
        if ext.composite_cls is None:
          self.SetExtension(ext, val)
        else:
          self.MutableExtension(ext).MergeFrom(val)

  def _ListExtensions(self):
    """Returns set extensions sorted by field number (empty lists omitted)."""
    return sorted(
        (ext for ext in self._extension_fields
         if (not ext.is_repeated) or self.ExtensionSize(ext) > 0),
        key=lambda item: item.number)

  def _ExtensionEquals(self, x):
    """Returns True if ``x`` has the same extensions with equal values."""
    extensions = self._ListExtensions()
    if extensions != x._ListExtensions():
      return False
    for ext in extensions:
      if ext.is_repeated:
        if self.ExtensionSize(ext) != x.ExtensionSize(ext): return False
        # Sizes match, so the built-in zip pairs up every element
        # (itertools.izip does not exist on Python 3).
        for e1, e2 in zip(self.ExtensionList(ext), x.ExtensionList(ext)):
          if e1 != e2: return False
      else:
        if self.GetExtension(ext) != x.GetExtension(ext): return False
    return True

  def _OutputExtensionFields(self, out, partial, extensions, start_index,
                             end_field_number):
    """Writes extensions[start_index:] whose number is < end_field_number.

    Args:
      out: Encoder to write to.
      partial: if true, sub-messages are written with OutputPartial (no
        initialization check) and sized with ByteSizePartial.
      extensions: extensions sorted by field number (see _ListExtensions).
      start_index: index in ``extensions`` to resume from.
      end_field_number: exclusive upper bound on field numbers to write.

    Returns:
      The index of the first extension not written, or len(extensions).
    """
    def OutputSingleField(ext, value):
      out.putVarInt32(ext.wire_tag)
      if ext.field_type == TYPE_GROUP:
        # Groups are delimited by a start tag and an end tag (wire_tag + 1).
        if partial:
          value.OutputPartial(out)
        else:
          value.OutputUnchecked(out)
        out.putVarInt32(ext.wire_tag + 1)
      elif ext.field_type == TYPE_FOREIGN:
        # Embedded messages are length-prefixed.
        if partial:
          out.putVarInt32(value.ByteSizePartial())
          value.OutputPartial(out)
        else:
          out.putVarInt32(value.ByteSize())
          value.OutputUnchecked(out)
      else:
        Encoder._TYPE_TO_METHOD[ext.field_type](out, value)

    for ext_index, ext in enumerate(
        itertools.islice(extensions, start_index, None), start=start_index):
      if ext.number >= end_field_number:
        return ext_index
      if ext.is_repeated:
        for field in self._extension_fields[ext]:
          OutputSingleField(ext, field)
      else:
        OutputSingleField(ext, self._extension_fields[ext])
    return len(extensions)

  def _ParseOneExtensionField(self, wire_tag, d):
    """Decodes the extension value that follows ``wire_tag`` in Decoder ``d``.

    Unregistered field numbers, and tags whose wire type differs from the
    registered one, are skipped so that ``d`` is left at the next tag.

    Raises:
      ProtocolBufferDecodeError: on truncated or corrupted input.
    """
    number = wire_tag >> 3
    ext = self._extensions_by_field_number.get(number)
    if ext is None or wire_tag != ext.wire_tag:
      # Not parseable as the registered extension: consume the value anyway.
      # Returning without skipping would make the caller read payload bytes
      # as the next tag.
      d.skipData(wire_tag)
      return
    if ext.field_type == TYPE_FOREIGN:
      length = d.getVarInt32()
      # Validate before building the sub-decoder so it cannot read past the
      # parent's limit (or move backwards on a negative length).
      if length < 0:
        raise ProtocolBufferDecodeError("corrupted")
      if length > d.avail():
        raise ProtocolBufferDecodeError("truncated")
      tmp = Decoder(d.buffer(), d.pos(), d.pos() + length)
      if ext.is_repeated:
        self.AddExtension(ext).TryMerge(tmp)
      else:
        self.MutableExtension(ext).TryMerge(tmp)
      d.skip(length)
    elif ext.field_type == TYPE_GROUP:
      if ext.is_repeated:
        self.AddExtension(ext).TryMerge(d)
      else:
        self.MutableExtension(ext).TryMerge(d)
    else:
      value = Decoder._TYPE_TO_METHOD[ext.field_type](d)
      if ext.is_repeated:
        self.AddExtension(ext, value)
      else:
        self.SetExtension(ext, value)

  def _ExtensionByteSize(self, partial):
    """Returns the encoded size of all stored extensions, tags included."""
    size = 0
    for extension, value in self._extension_fields.items():
      ftype = extension.field_type
      tag_size = self.lengthVarInt64(extension.wire_tag)
      if ftype == TYPE_GROUP:
        tag_size *= 2  # start tag plus end tag
      if extension.is_repeated:
        size += tag_size * len(value)
        for single_value in value:
          size += self._FieldByteSize(ftype, single_value, partial)
      else:
        size += tag_size + self._FieldByteSize(ftype, value, partial)
    return size

  def _FieldByteSize(self, ftype, value, partial):
    """Returns the encoded size of one value of type ``ftype``, tag excluded.

    Raises:
      AssertionError: for an unknown type code.
    """
    if ftype == TYPE_STRING:
      # Size the bytes actually written (putPrefixedString encodes text
      # first), not the number of characters.
      return self.lengthString(len(six.ensure_binary(value)))
    if ftype == TYPE_FOREIGN or ftype == TYPE_GROUP:
      if partial:
        body = value.ByteSizePartial()
      else:
        body = value.ByteSize()
      if ftype == TYPE_GROUP:
        # Groups have no length prefix; their delimiting tags are counted by
        # the caller.
        return body
      return self.lengthString(body)
    if ftype == TYPE_INT64 or ftype == TYPE_UINT64 or ftype == TYPE_INT32:
      return self.lengthVarInt64(value)
    if ftype in Encoder._TYPE_TO_BYTE_SIZE:
      return Encoder._TYPE_TO_BYTE_SIZE[ftype]
    raise AssertionError(
        'Extension type %d is not recognized.' % ftype)

  def _ExtensionDebugString(self, prefix, printElemNumber):
    """Returns the debug text for all set extensions, in field-number order.

    Message-typed values are rendered as ``[name] { ... }`` blocks, other
    values as ``[name]: value`` lines. Repeated elements get an ``(index)``
    suffix on the name when ``printElemNumber`` is true.
    """
    def FormatScalar(extension, value):
      # Per-type formatter when one is registered, else the generic "%s".
      formatter = _TYPE_TO_DEBUG_STRING.get(extension.field_type)
      if formatter is not None:
        return formatter(self, value)
      return self.DebugFormat(value)

    res = ''
    for extension in self._ListExtensions():
      value = self._extension_fields[extension]
      if extension.is_repeated:
        for cnt, e in enumerate(value):
          elm = ""
          if printElemNumber: elm = "(%d)" % cnt
          if extension.composite_cls is not None:
            res += prefix + "[%s%s] {\n" % (extension.full_name, elm)
            res += e.__str__(prefix + " ", printElemNumber)
            res += prefix + "}\n"
          else:
            # Repeated scalar extensions used to be omitted entirely.
            res += prefix + "[%s%s]: %s\n" % (
                extension.full_name, elm, FormatScalar(extension, e))
      elif extension.composite_cls is not None:
        res += prefix + "[%s] {\n" % extension.full_name
        res += value.__str__(prefix + " ", printElemNumber)
        res += prefix + "}\n"
      else:
        res += prefix + "[%s]: %s\n" % (extension.full_name,
                                        FormatScalar(extension, value))
    return res

  @staticmethod
  def _RegisterExtension(cls, extension, composite_cls=None):
    """Binds ``extension`` to message class ``cls``.

    Records the containing/composite classes on the identifier and indexes it
    by field number for parsing.

    Raises:
      AssertionError: if another extension already uses that field number.
    """
    extension.containing_cls = cls
    extension.composite_cls = composite_cls
    if composite_cls is not None:
      extension.message_name = composite_cls._PROTO_DESCRIPTOR_NAME
    actual_handle = cls._extensions_by_field_number.setdefault(
        extension.number, extension)
    if actual_handle is not extension:
      raise AssertionError(
          'Extensions "%s" and "%s" both try to extend message type "%s" with '
          'field number %d.' %
          (extension.full_name, actual_handle.full_name,
           cls.__name__, extension.number))