
Commit 283fc468 authored by guanchen

The LOCAL JOIN does not handle duplicate columns effectively; optimize it.

parent cf3ffe73
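Background for this change (not part of the commit itself): when two PySpark DataFrames that share non-key column names are joined on a key column, the join key is deduplicated but every other shared name appears twice in the result, and selecting that name afterwards fails with an ambiguity error. A minimal sketch of the problem, using hypothetical sample data:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DuplicateColumnDemo").getOrCreate()

# Both frames carry a non-key "status" column in addition to the join key.
left = spark.createDataFrame([(1, "Active")], ["user_id", "status"])
right = spark.createDataFrame([(1, "Pending")], ["user_id", "status"])

joined = left.join(right, "user_id", "inner")
joined.printSchema()        # user_id, status, status -> two columns named "status"
# joined.select("status")   # raises AnalysisException: reference "status" is ambiguous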
@@ -15,11 +15,28 @@ class Join:
inputs = task.get("input").get("data")
asset_en_name1 = inputs[0].get("assetName")
asset_en_name2 = inputs[1].get("assetName")
table_name1 = inputs[0].get("tableName")
table_name2 = inputs[1].get("tableName")
task_src1 = inputs[0].get("taskSrc")
task_src2 = inputs[1].get("taskSrc")
data_id_1 = inputs[0].get("dataID")
data_id_2 = inputs[1].get("dataID")
column_name1 = inputs[0].get("params").get("field")
column_name2 = inputs[1].get("params").get("field")
# Not an initial task: the input comes from an upstream computation result, so the tableName is a UUID generated at runtime
if task_src1 != '':
    table_name1 = data_id_1
if task_src2 != '':
    table_name2 = data_id_2
join_type = task.get("module").get("params").get("joinType")
mysql_url1, mysql_prop1, table_name1 = self.sdk.get_source_conn_info(asset_en_name1, chain_info_id)
mysql_url2, mysql_prop2, table_name2 = self.sdk.get_source_conn_info(asset_en_name2, chain_info_id)
mysql_url1, mysql_prop1, _ = self.sdk.get_source_conn_info(asset_en_name1, chain_info_id)
mysql_url2, mysql_prop2, _ = self.sdk.get_source_conn_info(asset_en_name2, chain_info_id)
self.sdk.logger.info("============计算方读取数据")
self.sdk.logger.info("表1资产名:" + asset_en_name1)
@@ -47,12 +64,56 @@ class Join:
self.sdk.logger.info("============更新输出信息上链")
self.sdk.update_output_id(output_id, data_name, final_result)
def join(self, df1: DataFrame, df2: DataFrame, cond, how):
def join(self, df1: DataFrame, df2: DataFrame, column, join_type):
self.sdk.logger.log("Local join, left table sample data: ")
df1.show()
self.sdk.logger.info("Local join, right table sample data: ")
df2.show()
df = df1.join(df2, cond, how)
repeated_columns = [c for c in df1.columns if c in df2.columns]
for repeated_column in repeated_columns:
    df = df.drop(repeated_column)
return df
\ No newline at end of file
df = df1.join(df2, column, join_type)
self.sdk.logger.info("Local join finished, result table sample data: ")
df.show()
# repeated_columns = [c for c in df1.columns if c in df2.columns]
# for repeated_column in repeated_columns:
# df = df.drop(repeated_column)
# return df
non_key_columns_df1 = [c for c in df1.columns if c != column]
non_key_columns_df2 = [c for c in df2.columns if c != column]
repeated_non_key_columns = set(non_key_columns_df1) & set(non_key_columns_df2)
self.sdk.logger.info(f"Repeated non-key columns, would be dropped...: {str(repeated_non_key_columns)}")
# Drop the identified duplicate non-key columns
for col in repeated_non_key_columns:
    df = df.drop(col)
if repeated_non_key_columns:
    self.sdk.logger.info("Repeated non-key columns dropped, result table sample data:")
    df.show()
return df
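One caveat with the drop-based approach above (not addressed by this commit): dropping a duplicated column by name in PySpark can remove every column with that name, so the values from both inputs are lost. A common alternative is to rename the right-hand side's clashing non-key columns with a suffix before joining; a minimal sketch, where the _right suffix is an arbitrary choice:

from pyspark.sql import DataFrame

def join_with_suffix(df1: DataFrame, df2: DataFrame, column: str, join_type: str,
                     suffix: str = "_right") -> DataFrame:
    # Rename clashing non-key columns on the right side instead of dropping them,
    # so the values from both inputs survive the join.
    clashing = (set(df1.columns) & set(df2.columns)) - {column}
    for c in clashing:
        df2 = df2.withColumnRenamed(c, c + suffix)
    return df1.join(df2, column, join_type)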
def task_run_test(self, task):
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
# Initialize the SparkSession
spark = SparkSession.builder.appName("JoinTest").getOrCreate()
# Build users_df
data_users = [
(1, "user1@example.com", "Active"),
(2, "user2@example.com", "Active"),
(3, "user3@example.com", "Inactive")
]
schema_users = ["user_id", "email", "status"]
users_df = spark.createDataFrame(data=data_users, schema=schema_users)
# Build orders_df
data_orders = [
(1001, 1, 50, "Pending"),
(1002, 2, 75, "Active"),
(1003, 1, 100, "Active")
]
schema_orders = ["order_id", "user_id", "order_amount", "status"]
orders_df = spark.createDataFrame(data=data_orders, schema=schema_orders)
# Show the structure of the datasets for verification
users_df.show()
orders_df.show()
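In the excerpt the test only builds and displays the two frames; the actual invocation is presumably below the truncated diff. A sketch of how it might be exercised, assuming "inner" is an accepted joinType value (the valid values are not shown here):

# Hypothetical invocation: join on user_id; both frames carry a non-key "status"
# column, which is exactly the duplicate case this commit targets.
result_df = self.join(users_df, orders_df, "user_id", "inner")
result_df.show()
spark.stop()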