blob: 6de3bcef7cf056f03472f64f83975e69df959699 [file] [log] [blame]
Adrià Vilanova Martínezf19ea432024-01-23 20:20:52 +01001#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
"""Protocol buffer support for message types.

For more details about protocol buffer encoding and decoding please see:

  https://protobuf.dev/programming-guides/encoding/

Public Exceptions:
  messages.DecodeError: Raised when the input is not a well-formed protocol
    buffer for the requested message type.

Public Functions:
  encode_message: Encodes a message into a protocol buffer byte string.
  decode_message: Decodes a protocol buffer byte string into a message.
"""
31import six
32
33__author__ = 'rafek@google.com (Rafe Kaplan)'
34
35
36import array
37
38from . import message_types
39from . import messages
40from . import util
41from .google_imports import ProtocolBuffer
42
43
# Public API of this module, as exposed to ``from <module> import *``.
__all__ = [
    'ALTERNATIVE_CONTENT_TYPES',
    'CONTENT_TYPE',
    'encode_message',
    'decode_message',
]
49
# Content type associated with this protocol buffer codec.
CONTENT_TYPE = 'application/octet-stream'

# Other content types associated with the same encoding.
ALTERNATIVE_CONTENT_TYPES = [
    'application/x-google-protobuf',
]
53
54
class _Encoder(ProtocolBuffer.Encoder):
    """ProtocolBuffer.Encoder extended with the writers this codec needs.

    The base encoder does not provide every put-method required to serialize
    protorpc fields. The methods below fill those gaps; they are looked up by
    variant in _VARIANT_TO_ENCODER_MAP.
    """

    # TODO(rafek): Implement the missing encoding types.
    def no_encoding(self, value):
        """Placeholder writer for variants that have no encoding yet.

        Args:
          value: Value that would have been written (unused).

        Raises:
          NotImplementedError: unconditionally.
        """
        raise NotImplementedError

    def encode_enum(self, value):
        """Write an enum as a 32-bit varint of its number.

        Args:
          value: Enum instance whose ``number`` attribute is written.
        """
        number = value.number
        self.putVarInt32(number)

    def encode_message(self, value):
        """Write a nested Message as a length-prefixed embedded message.

        Args:
          value: Message instance; it is serialized by the module-level
            encode_message function before being written.
        """
        payload = encode_message(value)
        self.putPrefixedString(payload)

    def encode_unicode_string(self, value):
        """Write a length-prefixed string, converting text to UTF-8 first.

        This writer serves both the STRING and BYTES variants. Text values
        are encoded to UTF-8; any other value is passed through unchanged.

        Args:
          value: Text or bytes value to write.
        """
        if isinstance(value, six.text_type):
            payload = value.encode('utf-8')
        else:
            payload = value
        self.putPrefixedString(payload)
100
101
class _Decoder(ProtocolBuffer.Decoder):
    """ProtocolBuffer.Decoder extended with the readers this codec needs.

    The base decoder does not provide every get-method required to decode
    protorpc fields. The methods below fill those gaps; they are looked up by
    variant in _VARIANT_TO_DECODER_MAP.
    """

    # TODO(rafek): Implement the missing decoding types.
    def no_decoding(self):
        """Placeholder reader for variants that have no decoding yet.

        Raises:
          NotImplementedError: unconditionally.
        """
        raise NotImplementedError()

    def decode_string(self):
        """Read a length-prefixed string and decode it as UTF-8.

        Returns:
          Next value in stream as a unicode string.

        Raises:
          ProtocolBuffer.ProtocolBufferDecodeError: if the bytes are not valid
            UTF-8. A bare UnicodeDecodeError would escape decode_message,
            which only converts ProtocolBufferDecodeError into
            messages.DecodeError. Raising the decoder's own error keeps
            malformed input on the documented DecodeError path.
        """
        raw = self.getPrefixedString()
        try:
            return raw.decode('UTF-8')
        except UnicodeDecodeError as err:
            raise ProtocolBuffer.ProtocolBufferDecodeError(
                'Invalid UTF-8 in string field: %s' % err)

    def decode_boolean(self):
        """Read a boolean value.

        Returns:
          Next value in stream coerced to a Python bool.
        """
        return bool(self.getBoolean())
133
134
135# Number of bits used to describe a protocol buffer bits used for the variant.
136_WIRE_TYPE_BITS = 3
137_WIRE_TYPE_MASK = 7
138
139
# Wire type used for each variant. Several variants share a wire type:
#   - varint-encoded variants map to NUMERIC;
#   - length-prefixed variants (strings, bytes, nested messages) map to STRING.
_VARIANT_TO_WIRE_TYPE = {
    messages.Variant.DOUBLE: _Encoder.DOUBLE,
    messages.Variant.FLOAT: _Encoder.FLOAT,
    messages.Variant.INT64: _Encoder.NUMERIC,
    messages.Variant.UINT64: _Encoder.NUMERIC,
    messages.Variant.INT32: _Encoder.NUMERIC,
    messages.Variant.BOOL: _Encoder.NUMERIC,
    messages.Variant.STRING: _Encoder.STRING,
    messages.Variant.MESSAGE: _Encoder.STRING,
    messages.Variant.BYTES: _Encoder.STRING,
    messages.Variant.UINT32: _Encoder.NUMERIC,
    messages.Variant.ENUM: _Encoder.NUMERIC,
    messages.Variant.SINT32: _Encoder.NUMERIC,
    messages.Variant.SINT64: _Encoder.NUMERIC,
}
156
157
# _Encoder method that writes a single value of each variant. Variants mapped
# to _Encoder.no_encoding cannot be encoded yet and raise NotImplementedError.
_VARIANT_TO_ENCODER_MAP = {
    messages.Variant.DOUBLE: _Encoder.putDouble,
    messages.Variant.FLOAT: _Encoder.putFloat,
    messages.Variant.INT64: _Encoder.putVarInt64,
    messages.Variant.UINT64: _Encoder.putVarUint64,
    messages.Variant.INT32: _Encoder.putVarInt32,
    messages.Variant.BOOL: _Encoder.putBoolean,
    messages.Variant.STRING: _Encoder.encode_unicode_string,
    messages.Variant.MESSAGE: _Encoder.encode_message,
    messages.Variant.BYTES: _Encoder.encode_unicode_string,
    messages.Variant.UINT32: _Encoder.no_encoding,  # not implemented
    messages.Variant.ENUM: _Encoder.encode_enum,
    messages.Variant.SINT32: _Encoder.no_encoding,  # not implemented
    messages.Variant.SINT64: _Encoder.no_encoding,  # not implemented
}
174
175
# Generic reader for each wire type. Used for fields the message type does
# not define, where no variant-specific decoder can be chosen.
_WIRE_TYPE_TO_DECODER_MAP = {
    _Encoder.NUMERIC: _Decoder.getVarInt64,
    _Encoder.DOUBLE: _Decoder.getDouble,
    _Encoder.STRING: _Decoder.getPrefixedString,
    _Encoder.FLOAT: _Decoder.getFloat,
}
183
184
# Variant recorded for an unrecognized field, chosen from the wire type it
# arrived with. The result is stored alongside the raw value so the value can
# be interpreted (or re-encoded) later.
_WIRE_TYPE_TO_VARIANT_MAP = {
    _Encoder.NUMERIC: messages.Variant.INT64,
    _Encoder.DOUBLE: messages.Variant.DOUBLE,
    _Encoder.STRING: messages.Variant.STRING,
    _Encoder.FLOAT: messages.Variant.FLOAT,
}
192
193
# Readable name for each wire type, for use in error messages. Each key is
# the _Encoder constant whose attribute name is the value.
_WIRE_TYPE_NAME = {
    getattr(_Encoder, wire_name): wire_name
    for wire_name in ('NUMERIC', 'DOUBLE', 'STRING', 'FLOAT')
}
201
202
# _Decoder method that reads a single value of each variant. MESSAGE values
# come back as raw bytes and are decoded recursively by decode_message.
# Variants mapped to _Decoder.no_decoding cannot be decoded yet.
_VARIANT_TO_DECODER_MAP = {
    messages.Variant.DOUBLE: _Decoder.getDouble,
    messages.Variant.FLOAT: _Decoder.getFloat,
    messages.Variant.INT64: _Decoder.getVarInt64,
    messages.Variant.UINT64: _Decoder.getVarUint64,
    messages.Variant.INT32: _Decoder.getVarInt32,
    messages.Variant.BOOL: _Decoder.decode_boolean,
    messages.Variant.STRING: _Decoder.decode_string,
    messages.Variant.MESSAGE: _Decoder.getPrefixedString,
    messages.Variant.BYTES: _Decoder.getPrefixedString,
    messages.Variant.UINT32: _Decoder.no_decoding,  # not implemented
    messages.Variant.ENUM: _Decoder.getVarInt32,
    messages.Variant.SINT32: _Decoder.no_decoding,  # not implemented
    messages.Variant.SINT64: _Decoder.no_decoding,  # not implemented
}
219
220
def encode_message(message):
    """Encode a Message instance into protocol buffer wire format.

    Both kinds of field are written in ascending field-number order:
      - known fields that have an assigned value;
      - unrecognized fields stored on the message under integer keys.

    Some fields are skipped:
      - known fields whose assigned value is None;
      - unrecognized fields whose stored variant is not a messages.Variant.

    Args:
      message: Message instance to encode.

    Returns:
      bytes containing the encoded message.

    Raises:
      messages.ValidationError if message is not initialized.
    """
    message.check_initialized()
    encoder = _Encoder()

    # Pair every field number with its Field. Unrecognized entries get None
    # so they can be handled separately below. Only integer keys are tags.
    all_fields = [(field.number, field) for field in message.all_fields()]
    all_fields.extend((key, None)
                      for key in message.all_unrecognized_fields()
                      if isinstance(key, six.integer_types))
    # Sort on the number alone. A plain tuple sort compares the second
    # element (Field vs Field/None) whenever two numbers tie, and that raises
    # TypeError on Python 3. The sort is stable, so a known field still
    # precedes an unrecognized entry with the same number.
    all_fields.sort(key=lambda entry: entry[0])

    for field_num, field in all_fields:
        if field is not None:
            # Known field.
            value = message.get_assigned_value(field.name)
            if value is None:
                continue
            variant = field.variant
            repeated = field.repeated
        else:
            # Unrecognized field.
            value, variant = message.get_unrecognized_field_info(field_num)
            if not isinstance(variant, messages.Variant):
                continue
            repeated = isinstance(value, (list, tuple))

        tag = (field_num << _WIRE_TYPE_BITS) | _VARIANT_TO_WIRE_TYPE[variant]
        field_encoder = _VARIANT_TO_ENCODER_MAP[variant]
        is_message_field = isinstance(field, messages.MessageField)

        # Write each value to the wire, preceded by its own tag.
        for item in (value if repeated else [value]):
            encoder.putVarInt32(tag)
            if is_message_field:
                item = field.value_to_message(item)
            field_encoder(encoder, item)

    return encoder.buffer().tobytes()
274
275
def decode_message(message_type, encoded_message):
    """Decode protocol buffer bytes into a Message instance.

    Fields that message_type does not define are kept through
    message.set_unrecognized_field(), under their tag number and with a
    variant chosen from their wire type.

    Args:
      message_type: Message class to decode data to.
      encoded_message: Encoded message as a bytes-like object.

    Returns:
      Decoded instance of message_type.

    Raises:
      messages.DecodeError: if an error occurs during decoding. Examples are an
        unknown or mismatched wire type, an invalid tag, an invalid enum value,
        or a buffer the underlying decoder rejects.
      messages.ValidationError: if the merged message is not initialized.
    """
    message = message_type()
    message_array = array.array('B')
    message_array.frombytes(encoded_message)
    try:
        decoder = _Decoder(message_array, 0, len(message_array))

        while decoder.avail() > 0:
            # A tag packs the field number above the low _WIRE_TYPE_BITS
            # bits, which hold the wire type.
            encoded_tag = decoder.getVarInt32()
            tag = encoded_tag >> _WIRE_TYPE_BITS
            wire_type = encoded_tag & _WIRE_TYPE_MASK

            # Explicit lookup rather than a bare ``except:``, which would
            # also swallow KeyboardInterrupt, SystemExit and unrelated bugs.
            found_wire_type_decoder = _WIRE_TYPE_TO_DECODER_MAP.get(wire_type)
            if found_wire_type_decoder is None:
                raise messages.DecodeError('No such wire type %d' % wire_type)

            if tag < 1:
                raise messages.DecodeError('Invalid tag value %d' % tag)

            try:
                field = message.field_by_number(tag)
            except KeyError:
                # Unexpected tags are ok: read the value generically by wire
                # type and keep it as an unrecognized field below.
                field = None
                wire_type_decoder = found_wire_type_decoder
            else:
                expected_wire_type = _VARIANT_TO_WIRE_TYPE[field.variant]
                if expected_wire_type != wire_type:
                    raise messages.DecodeError(
                        'Expected wire type %s but found %s' % (
                            _WIRE_TYPE_NAME[expected_wire_type],
                            _WIRE_TYPE_NAME[wire_type]))
                wire_type_decoder = _VARIANT_TO_DECODER_MAP[field.variant]

            value = wire_type_decoder(decoder)

            if field is None:
                # Save the value under the tag number, which should be unique,
                # together with a variant so it can be interpreted later.
                variant = _WIRE_TYPE_TO_VARIANT_MAP.get(wire_type)
                if variant is not None:
                    message.set_unrecognized_field(tag, value, variant)
                continue

            # Enum and Message fields need converting from their raw form.
            if isinstance(field, messages.EnumField):
                try:
                    value = field.type(value)
                except TypeError:
                    raise messages.DecodeError('Invalid enum value %s' % value)
            elif isinstance(field, messages.MessageField):
                value = decode_message(field.message_type, value)
                value = field.value_from_message(value)

            # Merge the value into the message.
            if field.repeated:
                values = getattr(message, field.name)
                if values is None:
                    setattr(message, field.name, [value])
                else:
                    values.append(value)
            else:
                setattr(message, field.name, value)
    except ProtocolBuffer.ProtocolBufferDecodeError as err:
        raise messages.DecodeError('Decoding error: %s' % str(err))

    message.check_initialized()
    return message