题目设计 思考如下背景:
存在两家公司 $a$ 和 $b$ ,和一个公开训练集 $TrainSet$ 。市面上现有的模型都基于 $TrainSet$ 训练,但是效果很差。
这时 $b$ 公司花费大量人力物力成本,精简出了更优质的训练集 $TrainSet_b$ ,训练出了模型 $B$ ,模型一夜之间爆火。
此时,作为 $a$ 公司的业务人员,老板希望你能得到 $TrainSet_b$ ,而你梳理资源后,发现只有公开训练集 $TrainSet$ 和模型 $B$ 的SaaS服务或许能提供帮助。
本题模拟上述现实背景,要求选手完成数据集窃取 ,导致模型 $B$ 隐私泄露,即 Membership Inference Attack 。
The Easier 2 Do 举个栗子 教小孩认识不同的瓶装饮品,你分别向他展示了可口可乐(不是百事可乐!)和农夫山泉,让他记住了他们的特征。
当你拿着芬达出现时,小孩会根据它的细节(橙子味),来猜测 这是一瓶可乐。
当你拿着雪碧出现时,小孩会根据它的细节(无色),来猜测 这是一瓶矿泉水。
当你拿着农夫山泉出现时,小孩会发现我见过这个!激动的肯定 这是一瓶矿泉水。
不妨这样认为 $$ \begin{array} 小孩 \rightarrow 模型 \\ \{可乐,农夫山泉\} \rightarrow 训练集 \\ \{芬达、雪碧、农夫山泉\} \rightarrow 测试集 \end{array} $$
那么上面的三个情况,就分别对应:
模型对未知 样本预测
模型训练过拟合
模型对已知 样本预测
显然,模型对于训练集的样本反馈更加剧烈 。 反应到softmax上,就是置信度非常高。
Unpresented 在训练集中,大概存在70%左右的数据,可以通过上述的方式进行判断。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 from pwn import *r = remote(url, port) for i in range (500 ): r.recvuntil(b'top_10_pred :' ) data=r.recvline() threshold = 0.9 top_10_pred = eval (data)[0 ] max_prob = max (top_10_pred) if max_prob > threshold: inp = 1 else : inp = 0 success(f"{i+1 } {inp} " ) r.sendlineafter(b'set?' ,str (inp).encode()) context.log_level = 'debug' r.interactive()
考虑到该题不符合该考点题目要求难度,因此未放出。
那剩下的数据呢?是否也存在这样的关系?只是不容易被肉眼看出?
Why not try the more difficult one? 赛题介绍 题目连接后如下
通过题目给出的,模型对于 $ idx $ 样本的 $ {pred}_{top_{10}} $ 预测值,进行判断:这个样本是否存在于当前模型的训练集中? (本题的题目简介应该注明题目交互方式。虽然赛中选手们都基本根据报错猜到了,但还是抱歉。)
How2solve 模型的通性 为了节省读论文的时间,我会用一种更抽象、但更快速的方式解释。
有两个模型 $A$ 和 $B$ ,每个模型有自己的训练集 $T_a$ 和 $T_b$
模型预测是函数: $A(x)$ 和 $B(x)$ , $x$ 是样本。 例如:$A(T_a)$ 表示, $A$ 模型在自己的训练集 $T_a$ 上的预测结果。
$A(x) \sim T_a$ 则表示 样本预测的结果 和 样本是否存在于训练集中 之间的关系。
假设存在这种关系,那么对于不同的模型,关系是否是同一种关系?
即 $A(x) \sim T_a$ 和 $B(x) \sim T_b$ 是否相同?
Reza Shokri 的论文证明了这确实存在
Membership Inference Attacks against Machine Learning Models
在以训练参数 为变量的假设中,他通过大量的重复实验证明,白盒、灰盒、黑盒,甚至是不同训练集 训练得到的Shadow Model,均和原模型在上述关系上存在一致性。
理论存在,实践开始 根据论文整理攻击思路如下:
通过一样的参数训练数个Shadow Model
建立Shadow Dataset ,储存每个影子模型的训练集和测试集,以及对应样本在模型上的预测结果
得到一份如下所示的数据,训练Attack Model
idx
pred_top10
label
1
[[0.05, 0.02, …, 0.23]]
in
2
[[0.01, 0.008, …, 0.53]]
out
……
使用Attack Model预测题目数据
训练题目模型的脚本在附件已经给出,添加储存数据部分即可
笔者这里一共训练了128个影子模型,使用CatBoost训练,对于题目给出的0.85还是比较简单的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 import pandas as pdimport numpy as npfrom matplotlib import pyplot as pltfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import precision_recall_fscore_support, roc_curve, aucfrom catboost import CatBoostClassifierdf_shadow = pd.read_csv('attack_dset.csv' ) print ("Reading attack dataset:" , 'attack_dset.csv' )print (df_shadow.head)print ("data shape:" , df_shadow.shape)y = df_shadow["is_member" ] X = df_shadow.drop(["is_member" ], axis=1 ) X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2 , random_state=42 ) model = CatBoostClassifier( iterations=200 , depth=2 , learning_rate=0.01 , loss_function="Logloss" , verbose=True , ) model.fit(X_train, y_train) accuracy = model.score(X_test, y_test) precision, recall, f1_score, _ = precision_recall_fscore_support( y_test, model.predict(X_test), average="binary" ) print ("accuracy:" , accuracy)print ("precision:" , precision)print ("recall:" , recall)print ("f1_score:" , f1_score)fpr, tpr, thresholds = roc_curve(y_test, model.predict_proba(X_test)[:, 1 ]) print ("mean fpr:" , np.mean(fpr))print ("mean tpr:" , np.mean(tpr))save_path = f"./attack/{model.__class__.__name__} .pkl" model.save_model(save_path)
接着使用训练好的模型预测即可,搓一个交互板子预测即可。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 from catboost import CatBoostClassifier, Poolimport numpy as npfrom pwn import remotesh = remote("127.0.0.1" , 9999 ) attack_model = CatBoostClassifier() attack_model.load_model('CatBoostClassifier.pkl' ) sh.recvuntil(b"which ones are in this model's training set?\n" ) flag = sh.recvline() while b'flag' not in flag: sh.recvuntil(b"top_10_pred : " ) top_10 = sh.recvline() top_10 = eval (top_10.decode()) inp = attack_model.predict(top_10)[0 ] sh.sendlineafter(b"Is this picture in the training set?" , str (inp).encode()) flag = sh.recvline() print (flag)
I can RET to the best time point. I will definitely find it!
U have my words!