diff --git a/telethon_generator/generators/errors.py b/telethon_generator/generators/errors.py index 6f5a6021..386575be 100644 --- a/telethon_generator/generators/errors.py +++ b/telethon_generator/generators/errors.py @@ -26,13 +26,16 @@ def generate_errors(errors, f): # Error classes generation for error in errors: - f.write('\n\nclass {}({}):\n' - ' def __init__(self, **kwargs):\n' - ' '.format(error.name, error.subclass)) + f.write('\n\nclass {}({}):\n '.format(error.name, error.subclass)) if error.has_captures: - f.write("self.{} = int(kwargs.get('capture', 0))\n " + f.write('def __init__(self, request, capture=0):\n ' + ' self.request = request\n ') + f.write(' self.{} = int(capture)\n ' .format(error.capture_name)) + else: + f.write('def __init__(self, request):\n ' + ' self.request = request\n ') f.write('super(Exception, self).__init__(' '{}'.format(repr(error.description))) @@ -40,7 +43,12 @@ def generate_errors(errors, f): if error.has_captures: f.write('.format({0}=self.{0})'.format(error.capture_name)) - f.write(" + self._fmt_request(kwargs['request']))\n") + f.write(' + self._fmt_request(self.request))\n\n') + f.write(' def __reduce__(self):\n ') + if error.has_captures: + f.write('return type(self), (self.request, self.{})\n'.format(error.capture_name)) + else: + f.write('return type(self), (self.request,)\n') # Create the actual {CODE: ErrorClassName} dict once classes are defined f.write('\n\nrpc_errors_dict = {\n') diff --git a/tests/telethon/test_pickle.py b/tests/telethon/test_pickle.py new file mode 100644 index 00000000..4854c66f --- /dev/null +++ b/tests/telethon/test_pickle.py @@ -0,0 +1,35 @@ +import pickle + +from telethon.errors import RPCError, BadRequestError, FileIdInvalidError, NetworkMigrateError + + +def _assert_equality(error, unpickled_error): + assert error.code == unpickled_error.code + assert error.message == unpickled_error.message + assert type(error) == type(unpickled_error) + assert str(error) == str(unpickled_error) + + +def test_base_rpcerror_pickle(): + error = RPCError("request", "message", 123) + unpickled_error = pickle.loads(pickle.dumps(error)) + _assert_equality(error, unpickled_error) + + +def test_rpcerror_pickle(): + error = BadRequestError("request", "BAD_REQUEST", 400) + unpickled_error = pickle.loads(pickle.dumps(error)) + _assert_equality(error, unpickled_error) + + +def test_fancy_rpcerror_pickle(): + error = FileIdInvalidError("request") + unpickled_error = pickle.loads(pickle.dumps(error)) + _assert_equality(error, unpickled_error) + + +def test_fancy_rpcerror_capture_pickle(): + error = NetworkMigrateError(request="request", capture=5) + unpickled_error = pickle.loads(pickle.dumps(error)) + _assert_equality(error, unpickled_error) + assert error.new_dc == unpickled_error.new_dc