"""Unit tests for PySdk: configuration loading and MySQL input reading."""
import unittest
from unittest.mock import patch, mock_open, MagicMock

import yaml

from py_sdk import PySdk  # assumes the PySdk class lives in py_sdk.py
from pyspark.sql import SparkSession, Row


class TestPySdk(unittest.TestCase):
    def setUp(self):
        """Build a sample config and patch yaml.safe_load to return it."""
        # Sample configuration matching the schema PySdk expects.
        self.example_config = {
            'redis': {'host': 'localhost', 'port': 6379, 'database': 0,
                      'password': 'secret'},
            'mysql': {'host': 'localhost', 'port': 3306, 'database': 'test_db',
                      'username': 'user', 'password': 'pass'},
            'portManager': {'protocol': 'http', 'host': 'localhost',
                            'port': 8080, 'path': '/api'},
            'jobManager': {'protocol': 'http', 'host': 'localhost',
                           'port': 8081, 'path': '/jobs', 'result': '/result'},
            'serverInfo': {'ip': '192.168.1.1', 'port': 3000},
            'mira_ida_access_service': {'ip': '192.168.1.2', 'port': 50051},
        }
        self.config_yaml = yaml.safe_dump(self.example_config)
        self.job_id = "job123"
        self.task_name = "task_example"
        self.org_did = "org123"

        # Mocked MySQL connection info: (jdbc_url, properties, first_table_name),
        # the tuple shape returned by PySdk.get_source_conn_info.
        mysql_url = 'jdbc:mysql://localhost:3306/mydatabase'
        mysql_prop = {
            'user': 'myuser',
            'password': 'mypassword',
            'driver': 'com.mysql.cj.jdbc.Driver',
        }
        first_table_name = 'my_table'
        self.mock_conn_info = (mysql_url, mysql_prop, first_table_name)

        # Patch yaml.safe_load so every test parses the sample config above,
        # regardless of what ./config.yaml actually contains.
        # NOTE(review): builtins.open is NOT patched, so PySdk still opens a
        # real ./config.yaml — confirm that file exists in the test environment.
        patcher = patch('yaml.safe_load', return_value=self.example_config)
        self.addCleanup(patcher.stop)  # ensure the patch is undone after each test
        self.mock_safe_load = patcher.start()

    @patch('pyspark.sql.SparkSession')
    def test_init(self, mock_spark):
        """PySdk.__init__ should expose the server address from the config."""
        # NOTE(review): patching 'pyspark.sql.SparkSession' only takes effect if
        # py_sdk resolves the class through that module path — verify py_sdk's
        # import style if this test starts a real session.
        sdk = PySdk('./config.yaml', self.job_id, self.task_name, self.org_did)
        self.assertEqual(sdk.server_ip, '192.168.1.1')
        self.assertEqual(sdk.server_port, 3000)

    @patch('pyspark.sql.DataFrameReader.jdbc')
    @patch('py_sdk.PySdk.get_source_conn_info')
    def test_read_input_from_mysql_v2(self, mock_get_source_conn_info,
                                      mock_jdbc):
        """read_input_from_mysql_v2 should join per-table frames into rows."""
        mock_get_source_conn_info.return_value = self.mock_conn_info

        # Real local SparkSession used only to build in-memory DataFrames.
        spark = SparkSession.builder.appName("TestSession").getOrCreate()
        # Register cleanup immediately so the session is released even when an
        # assertion below fails (previously spark.stop() leaked on failure).
        self.addCleanup(spark.stop)

        sdk = PySdk('./config.yaml', self.job_id, self.task_name, self.org_did)
        sdk.spark = spark

        # Two mock source tables sharing the same keys.
        data1 = [Row(key=1, value="A"), Row(key=2, value="B")]
        df1 = spark.createDataFrame(data1)
        data2 = [Row(key=1, value="C"), Row(key=2, value="D")]
        df2 = spark.createDataFrame(data2)
        # Each jdbc() call returns the next mocked DataFrame.
        mock_jdbc.side_effect = [df1, df2]

        # Mocked asset_en_info input describing the two source tables.
        asset_en_info = [
            {'asset_en_name': 'table1', 'column': 'value', 'key': 'key'},
            {'asset_en_name': 'table2', 'column': 'value', 'key': 'key'},
        ]

        result = sdk.read_input_from_mysql_v2(asset_en_info, "chain")

        # Expected rows are base64-encoded "key,joined-values" strings:
        # 'MQ==' == b64('1'), 'QSwxLEM=' == b64('A,1,C'), etc.
        expected_data = ['MQ==,QSwxLEM=', 'Mg==,QiwyLEQ=']
        self.assertEqual(result, expected_data)


if __name__ == '__main__':
    unittest.main()