30 lines
745 B
Python
30 lines
745 B
Python
import numpy as np
|
|
from scipy.stats import wasserstein_distance
|
|
|
|
def calculate_emd(vector1, vector2):
    """Return the 1-D Earth Mover's Distance between two weight vectors.

    Each vector is treated as a histogram whose i-th entry is the mass
    located at position ``i``. ``scipy.stats.wasserstein_distance`` is used
    with those positions as the support.

    SciPy normalizes each weight vector to sum to 1, so the normalized
    distance is multiplied back by the total mass of ``vector1``. This equals
    the classical (unnormalized) EMD when both vectors carry the same total
    mass. Otherwise it is the distance between the two normalized
    distributions, scaled by the mass of ``vector1``.

    Args:
        vector1: Sequence of non-negative weights (the source histogram).
        vector2: Sequence of non-negative weights with the same length as
            ``vector1``.

    Returns:
        The distance as a Python float. Returns ``0.0`` when both vectors are
        empty or both have zero total mass (there is nothing to move).

    Raises:
        ValueError: If the vectors differ in length, if a weight is negative,
            if exactly one vector has zero total mass, or (from SciPy) if a
            weight is non-finite.
    """
    vec1 = np.asarray(vector1, dtype=float)
    vec2 = np.asarray(vector2, dtype=float)

    if len(vec1) != len(vec2):
        raise ValueError(
            f"vectors must have the same length, got {len(vec1)} and {len(vec2)}"
        )
    # Check this before the zero-mass shortcut below, so that e.g. [1, -1]
    # is still rejected rather than treated as "zero mass".
    if (vec1 < 0).any() or (vec2 < 0).any():
        raise ValueError("weights must be non-negative")

    total1 = vec1.sum()
    total2 = vec2.sum()
    # SciPy rejects weight vectors that do not sum to a positive number, but
    # two empty / all-zero histograms are trivially at distance 0.
    if total1 == 0 and total2 == 0:
        return 0.0
    if total1 == 0 or total2 == 0:
        raise ValueError(
            "cannot compare a zero-mass vector with a non-zero-mass vector"
        )

    positions = np.arange(len(vec1))
    normalized_distance = wasserstein_distance(positions, positions, vec1, vec2)
    # Undo SciPy's normalization by scaling with the total mass of vector1.
    return float(normalized_distance * total1)
|
|
|
|
def main():
    """Read two rows of whitespace-separated numbers and print their EMD.

    The prompts ask for the first and second line respectively. Each row is
    parsed into floats, the distance is computed with ``calculate_emd``, and
    the result is printed as an integer via Python's built-in ``round``.
    """
    first_row = input("输入第一行:").strip()
    second_row = input("输入第二行:").strip()

    vector1, vector2 = (
        [float(token) for token in row.split()] for row in (first_row, second_row)
    )

    # Compute the EMD between the two rows.
    distance = calculate_emd(vector1, vector2)
    print(int(round(distance)))
|
|
|
|
if __name__ == "__main__":
    # Run the interactive prompt only when executed as a script, not on import.
    main()