task2
This commit is contained in:
30
main.py
Normal file
30
main.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import numpy as np
|
||||
from scipy.stats import wasserstein_distance
|
||||
|
||||
def calculate_emd(vector1, vector2):
|
||||
vec1 = np.array(vector1)
|
||||
vec2 = np.array(vector2)
|
||||
|
||||
positions = np.arange(len(vec1))
|
||||
emd = wasserstein_distance(positions, positions, vec1, vec2)
|
||||
|
||||
total_weight = np.sum(vec1) # 权重和
|
||||
emd = emd * total_weight
|
||||
|
||||
return emd
|
||||
|
||||
def main():
|
||||
# input_str = input().strip()
|
||||
line1 = input("输入第一行:").strip()
|
||||
line2 = input("输入第二行:").strip()
|
||||
|
||||
vector1 = list(map(float, line1.split()))
|
||||
vector2 = list(map(float, line2.split()))
|
||||
|
||||
# 计算EMD距离
|
||||
emd_value = calculate_emd(vector1, vector2)
|
||||
|
||||
print(int(round(emd_value)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user