chainweaver / mira / mira-mpc-engine
Commit 283fc468, authored 11 months ago by guanchen
LOCAL JOIN does not handle duplicate columns correctly; optimize it.
parent cf3ffe73
Showing 1 changed file with 70 additions and 9 deletions.
exe_local/join.py (+70, -9)
@@ -15,11 +15,28 @@ class Join:
         inputs = task.get("input").get("data")
         asset_en_name1 = inputs[0].get("assetName")
         asset_en_name2 = inputs[1].get("assetName")
         table_name1 = inputs[0].get("tableName")
         table_name2 = inputs[1].get("tableName")
+        task_src1 = inputs[0].get("taskSrc")
+        task_src2 = inputs[1].get("taskSrc")
+        data_id_1 = inputs[0].get("dataID")
+        data_id_2 = inputs[1].get("dataID")
         column_name1 = inputs[0].get("params").get("field")
         column_name2 = inputs[1].get("params").get("field")
+        # Not an initial task: the input comes from an upstream computation result, so the table name is a runtime-generated UUID
+        if task_src1 != '':
+            table_name1 = data_id_1
+        if task_src2 != '':
+            table_name2 = data_id_2
+        join_type = task.get("module").get("params").get("joinType")
-        mysql_url1, mysql_prop1, table_name1 = self.sdk.get_source_conn_info(asset_en_name1, chain_info_id)
-        mysql_url2, mysql_prop2, table_name2 = self.sdk.get_source_conn_info(asset_en_name2, chain_info_id)
+        mysql_url1, mysql_prop1, _ = self.sdk.get_source_conn_info(asset_en_name1, chain_info_id)
+        mysql_url2, mysql_prop2, _ = self.sdk.get_source_conn_info(asset_en_name2, chain_info_id)
         self.sdk.logger.info("============Computing party reads data")
         self.sdk.logger.info("Table 1 asset name: " + asset_en_name1)
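For orientation, the getters in the hunk above imply a task payload shaped roughly as sketched below. This is a hypothetical reconstruction from nothing more than the keys accessed here (input.data[*].assetName / tableName / taskSrc / dataID / params.field and module.params.joinType); all concrete values are illustrative, and the real schema used by mira-mpc-engine may contain additional fields.

    # Hypothetical task payload, inferred only from the keys read in this hunk.
    task = {
        "input": {
            "data": [
                {
                    "assetName": "asset_users",      # example value (assumed)
                    "tableName": "users",            # physical table used for initial tasks
                    "taskSrc": "",                   # empty => initial task, read the source table
                    "dataID": "",                    # runtime UUID when fed by an upstream task
                    "params": {"field": "user_id"},  # join key on this side
                },
                {
                    "assetName": "asset_orders",
                    "tableName": "orders",
                    "taskSrc": "task-123",           # non-empty => use dataID as the table name
                    "dataID": "9f2c1a7e-uuid",       # example UUID (assumed)
                    "params": {"field": "user_id"},
                },
            ]
        },
        "module": {"params": {"joinType": "inner"}},
    }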
@@ -47,12 +64,56 @@
         self.sdk.logger.info("============Updating output info on-chain")
         self.sdk.update_output_id(output_id, data_name, final_result)

-    def join(self, df1: DataFrame, df2: DataFrame, cond, how):
+    def join(self, df1: DataFrame, df2: DataFrame, column, join_type):
         self.sdk.logger.log("Local join, left table sample data: ")
         df1.show()
         self.sdk.logger.log("Local join, right table sample data: ")
         df2.show()
-        df = df1.join(df2, cond, how)
-        repeated_columns = [c for c in df1.columns if c in df2.columns]
-        for repeated_column in repeated_columns:
-            df = df.drop(repeated_column)
-        return df
\ No newline at end of file
+        df = df1.join(df2, column, join_type)
+        self.sdk.logger.info("Local join finished, result table sample data: ")
+        df.show()
+        # repeated_columns = [c for c in df1.columns if c in df2.columns]
+        # for repeated_column in repeated_columns:
+        #     df = df.drop(repeated_column)
+        # return df
+        non_key_columns_df1 = [c for c in df1.columns if c != column]
+        non_key_columns_df2 = [c for c in df2.columns if c != column]
+        repeated_non_key_columns = set(non_key_columns_df1) & set(non_key_columns_df2)
+        self.sdk.logger.info(f"Repeated non-key columns, will be dropped...: {str(repeated_non_key_columns)}")
+        # Drop the identified duplicate non-key columns
+        for col in repeated_non_key_columns:
+            df = df.drop(col)
+        if repeated_non_key_columns:
+            self.sdk.logger.log("Repeated non-key columns dropped, result table sample data:")
+            df.show()

+    def task_run_test(self, task):
+        from pyspark.sql import SparkSession
+        from pyspark.sql.functions import lit
+        # Initialize the SparkSession
+        spark = SparkSession.builder.appName("JoinTest").getOrCreate()
+        # Build users_df
+        data_users = [
+            (1, "user1@example.com", "Active"),
+            (2, "user2@example.com", "Active"),
+            (3, "user3@example.com", "Inactive")
+        ]
+        schema_users = ["user_id", "email", "status"]
+        users_df = spark.createDataFrame(data=data_users, schema=schema_users)
+        # Build orders_df
+        data_orders = [
+            (1001, 1, 50, "Pending"),
+            (1002, 2, 75, "Active"),
+            (1003, 1, 100, "Active")
+        ]
+        schema_orders = ["order_id", "user_id", "order_amount", "status"]
+        orders_df = spark.createDataFrame(data=data_orders, schema=schema_orders)
+        # Show the dataset structure for verification
+        users_df.show()
+        orders_df.show()
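The test above builds the two DataFrames but stops before exercising the join. Below is a minimal, self-contained continuation, assuming stock PySpark rather than the Join/SDK classes; variable names are illustrative only. It reproduces the behavior this commit targets: joining users_df and orders_df on user_id deduplicates the key column automatically, but the shared non-key column status would otherwise appear twice in the result, and it then applies the same drop strategy as the new Join.join().

    from pyspark.sql import SparkSession

    # Standalone sketch (assumed setup): reproduce the duplicate-column issue and the fix.
    spark = SparkSession.builder.appName("JoinTest").getOrCreate()

    users_df = spark.createDataFrame(
        [(1, "user1@example.com", "Active"),
         (2, "user2@example.com", "Active"),
         (3, "user3@example.com", "Inactive")],
        ["user_id", "email", "status"])
    orders_df = spark.createDataFrame(
        [(1001, 1, 50, "Pending"),
         (1002, 2, 75, "Active"),
         (1003, 1, 100, "Active")],
        ["order_id", "user_id", "order_amount", "status"])

    # Joining on the column name deduplicates "user_id" but keeps both "status" columns.
    joined = users_df.join(orders_df, "user_id", "inner")
    print(joined.columns)  # ['user_id', 'email', 'status', 'order_id', 'order_amount', 'status']

    # Same strategy as the new Join.join(): drop non-key columns present on both sides.
    repeated = set(c for c in users_df.columns if c != "user_id") & \
               set(c for c in orders_df.columns if c != "user_id")
    for c in repeated:
        joined = joined.drop(c)
    print(joined.columns)  # 'status' no longer appears

One design note on the drop strategy: PySpark's name-based drop removes every column matching the name, so both copies of a repeated non-key column such as status are discarded rather than one being kept, which is the trade-off this commit accepts for a clean output schema.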