From 99343a483534b4e1948da39fad80b0d9d951a46c Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Sun, 7 Jul 2024 18:44:40 +0800 Subject: [PATCH 1/6] leap2023 implement --- textattack/attack_recipes/leap_2023.py | 41 +++ textattack/search_methods/__init__.py | 1 + .../leap_particle_swarm_optimization.py | 237 ++++++++++++++++++ 3 files changed, 279 insertions(+) create mode 100644 textattack/attack_recipes/leap_2023.py create mode 100644 textattack/search_methods/leap_particle_swarm_optimization.py diff --git a/textattack/attack_recipes/leap_2023.py b/textattack/attack_recipes/leap_2023.py new file mode 100644 index 00000000..1917b8fd --- /dev/null +++ b/textattack/attack_recipes/leap_2023.py @@ -0,0 +1,41 @@ +""" + +LEAP +================================== + +(LEAP: Efficient and Automated Test Method for NLP Software) + +""" +from textattack import Attack +from textattack.constraints.pre_transformation import ( + MaxModificationRate, + StopwordModification, +) +from textattack.goal_functions import UntargetedClassification +from textattack.search_methods import LEAPParticleSwarmOptimization +from textattack.transformations import WordSwapWordNet + +from .attack_recipe import AttackRecipe + +class LEAP2023(AttackRecipe): + @staticmethod + def build(model_wrapper): + # + # Swap words with their synonyms extracted based on the WordNet. + # + transformation = WordSwapWordNet() + # + # MaxModificationRate = 0.16 in AG's News + # + constraints = [MaxModificationRate(max_rate=0.16), StopwordModification()] + # + # + # Use untargeted classification for demo, can be switched to targeted one + # + goal_function = UntargetedClassification(model_wrapper) + # + # Perform word substitution with LEAP algorithm. + # + search_method = LEAPParticleSwarmOptimization(pop_size=60, max_iters=20, post_turn_check=True, max_turn_retries=20) + + return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/search_methods/__init__.py b/textattack/search_methods/__init__.py index b5e26291..0a6e2a3b 100644 --- a/textattack/search_methods/__init__.py +++ b/textattack/search_methods/__init__.py @@ -15,3 +15,4 @@ from .alzantot_genetic_algorithm import AlzantotGeneticAlgorithm from .improved_genetic_algorithm import ImprovedGeneticAlgorithm from .particle_swarm_optimization import ParticleSwarmOptimization +from .leap_particle_swarm_optimization import LEAPParticleSwarmOptimization \ No newline at end of file diff --git a/textattack/search_methods/leap_particle_swarm_optimization.py b/textattack/search_methods/leap_particle_swarm_optimization.py new file mode 100644 index 00000000..c3ad38dd --- /dev/null +++ b/textattack/search_methods/leap_particle_swarm_optimization.py @@ -0,0 +1,237 @@ +""" + +LEAP Particle Swarm Optimization +==================================== + +LEAP, an automated test method that uses LEvy flight-based Adaptive Particle +swarm optimization integrated with textual features to generate adversarial test cases. 
+ +al +``_ +``_ +""" + +import copy + +import numpy as np +from scipy.special import gamma as gamma +from textattack.goal_function_results import GoalFunctionResultStatus +from textattack.search_methods import ParticleSwarmOptimization, PopulationMember +from textattack.shared import utils +from textattack.shared.validators import transformation_consists_of_word_swaps + +def sigmax(alpha): + numerator = gamma(alpha + 1.0) * np.sin(np.pi * alpha / 2.0) + denominator = gamma((alpha + 1) / 2.0) * alpha * np.power(2.0, (alpha - 1.0) / 2.0) + + return np.power(numerator / denominator, 1.0 / alpha) + +def vf(alpha): + x = np.random.normal(0, 1) + y = np.random.normal(0, 1) + + x = x * sigmax(alpha) + + return x / np.power(np.abs(y), 1.0 / alpha) + +def K(alpha): + k = alpha * gamma((alpha + 1.0) / (2.0 * alpha)) / gamma(1.0 / alpha) + k *= np.power(alpha * gamma((alpha + 1.0) / 2.0) / (gamma(alpha + 1.0) * np.sin(np.pi * alpha / 2.0)), 1.0 / alpha) + + return k + +def C(alpha): + x = np.array((0.75, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.95, 1.99)) + y = np.array( + (2.2085, 2.483, 2.7675, 2.945, 2.941, 2.9005, 2.8315, 2.737, 2.6125, 2.4465, 2.206, 1.7915, 1.3925, 0.6089)) + + return np.interp(alpha, x, y) + +def levy(alpha, gamma=1, n=1): + w = 0 + for i in range(0, n): + v = vf(alpha) + + while v < -10: + v = vf(alpha) + + w += v * ((K(alpha) - 1.0) * np.exp(-v / C(alpha)) + 1.0) + + z = 1.0 / np.power(n, 1.0 / alpha) * w * gamma + + return z + +def get_one_levy(min, max): + while 1: + temp = levy(1.5, 1) + if min <= temp <= max: + break + else: + continue + return temp + +def softmax(x, axis=1): + row_max = x.max(axis=axis) + + # Each element of the row needs to be subtracted from the corresponding maximum value, otherwise exp(x) will overflow, resulting in the inf case + row_max = row_max.reshape(-1, 1) + x = x - row_max + + # Calculate the exponential power of e + x_exp = np.exp(x) + x_sum = np.sum(x_exp, axis=axis, keepdims=True) + s = x_exp / x_sum + return s + +class LEAPParticleSwarmOptimization(ParticleSwarmOptimization): + def _greedy_perturb(self, pop_member, original_result): + best_neighbors, prob_list = self._get_best_neighbors( + pop_member.result, original_result + ) + random_result = best_neighbors[np.argsort(prob_list)[-1]] + pop_member.attacked_text = random_result.attacked_text + pop_member.result = random_result + return True + + def perform_search(self, initial_result): + self._search_over = False + population = self._initialize_population(initial_result, self.pop_size) + + # Initialize velocities + v_init = [] + v_init_rand = np.random.uniform(-self.v_max, self.v_max, self.pop_size) + v_init_levy = [] + while 1: + temp = levy(1.5, 1) + if -self.v_max <= temp <= self.v_max: + v_init_levy.append(temp) + else: + continue + if len(v_init_levy) == self.pop_size: + break + for i in range(self.pop_size): + if np.random.uniform(-self.v_max, self.v_max, ) < levy(1.5, 1): + v_init.append(v_init_rand[i]) + else: + v_init.append(v_init_levy[i]) + v_init = np.array(v_init) + + velocities = np.array( + [ + [v_init[t] for _ in range(initial_result.attacked_text.num_words)] + for t in range(self.pop_size) + ] + ) + + global_elite = max(population, key=lambda x: x.score) + if ( + self._search_over + or global_elite.result.goal_status == GoalFunctionResultStatus.SUCCEEDED + ): + return global_elite.result + + local_elites = copy.copy(population) + + pop_fit_list = [] + for i in range(len(population)): + pop_fit_list.append(population[i].score) + pop_fit = 
np.array(pop_fit_list) + fit_ave = round(pop_fit.mean(), 3) + fit_min = pop_fit.min() + + # start iterations + omega = [] + for i in range(self.max_iters): + for k in range(len(population)): + if population[k].score < fit_ave: + omega.append(self.omega_2 + ((population[k].score - fit_min) * + (self.omega_1 - self.omega_2)) / + (fit_ave - fit_min)) + else: + omega.append(get_one_levy(0.5, 0.8)) + C1 = self.c1_origin - i / self.max_iters * (self.c1_origin - self.c2_origin) + C2 = self.c2_origin + i / self.max_iters * (self.c1_origin - self.c2_origin) + P1 = C1 + P2 = C2 + + for k in range(len(population)): + # calculate the probability of turning each word + pop_mem_words = population[k].words + local_elite_words = local_elites[k].words + assert len(pop_mem_words) == len( + local_elite_words + ), "PSO word length mismatch!" + + for d in range(len(pop_mem_words)): + velocities[k][d] = omega[k] * velocities[k][d] + (1 - omega[k]) * ( + self._equal(pop_mem_words[d], local_elite_words[d]) + + self._equal(pop_mem_words[d], global_elite.words[d]) + ) + turn_list = np.array([velocities[k]]) + turn_prob = softmax(turn_list)[0] + + if np.random.uniform() < P1: + # Move towards local elite + population[k] = self._turn( + local_elites[k], + population[k], + turn_prob, + initial_result.attacked_text, + ) + + if np.random.uniform() < P2: + # Move towards global elite + population[k] = self._turn( + global_elite, + population[k], + turn_prob, + initial_result.attacked_text, + ) + + # Check if there is any successful attack in the current population + pop_results, self._search_over = self.get_goal_results( + [p.attacked_text for p in population] + ) + if self._search_over: + # if `get_goal_results` gets cut short by query budget, resize population + population = population[: len(pop_results)] + for k in range(len(pop_results)): + population[k].result = pop_results[k] + + top_member = max(population, key=lambda x: x.score) + if ( + self._search_over + or top_member.result.goal_status == GoalFunctionResultStatus.SUCCEEDED + ): + return top_member.result + + # Mutation based on the current change rate + for k in range(len(population)): + change_ratio = initial_result.attacked_text.words_diff_ratio( + population[k].attacked_text + ) + # Referred from the original source code + p_change = 1 - 2 * change_ratio + if np.random.uniform() < p_change: + self._perturb(population[k], initial_result) + + if self._search_over: + break + + # Check if there is any successful attack in the current population + top_member = max(population, key=lambda x: x.score) + if ( + self._search_over + or top_member.result.goal_status == GoalFunctionResultStatus.SUCCEEDED + ): + return top_member.result + + # Update the elite if the score is increased + for k in range(len(population)): + if population[k].score > local_elites[k].score: + local_elites[k] = copy.copy(population[k]) + + if top_member.score > global_elite.score: + global_elite = copy.copy(top_member) + + return global_elite.result From 7213e0f29e941a9feecf401de957f908d14c0609 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 8 Jul 2024 02:11:26 +0800 Subject: [PATCH 2/6] format the code --- textattack/attack_recipes/__init__.py | 1 + textattack/attack_recipes/leap_2023.py | 4 +- textattack/search_methods/__init__.py | 2 +- ...py => particle_swarm_optimization_leap.py} | 45 +++++++++++++++---- 4 files changed, 41 insertions(+), 11 deletions(-) rename textattack/search_methods/{leap_particle_swarm_optimization.py => particle_swarm_optimization_leap.py} (90%) 
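Reviewer note on the Lévy-flight component implemented above: perform_search seeds each particle's velocity from a mixture of uniform draws and Lévy-flight draws (alpha = 1.5) that are rejection-sampled into [-v_max, v_max]. The short Python sketch below mirrors that initialization outside the class so it can be inspected in isolation. It is illustrative only and assumes this patch series is installed; pop_size=60 matches the recipe's setting, while v_max=3.0 is an assumed placeholder rather than a value taken from the code.

import numpy as np

# `levy` is the module-level helper added by this patch series; the import path
# follows the rename to particle_swarm_optimization_leap.py in PATCH 2/6.
from textattack.search_methods.particle_swarm_optimization_leap import levy

v_max, pop_size = 3.0, 60  # v_max is an assumed placeholder for this sketch

# One uniform candidate velocity per particle.
v_init_rand = np.random.uniform(-v_max, v_max, pop_size)

# Lévy-flight candidates, rejection-sampled into [-v_max, v_max].
v_init_levy = []
while len(v_init_levy) < pop_size:
    sample = levy(1.5, 1)
    if -v_max <= sample <= v_max:
        v_init_levy.append(sample)

# Each particle keeps either its uniform draw or its Lévy draw, decided by the
# same random comparison used in perform_search.
v_init = np.array(
    [
        v_init_rand[i]
        if np.random.uniform(-v_max, v_max) < levy(1.5, 1)
        else v_init_levy[i]
        for i in range(pop_size)
    ]
)
print(v_init.shape)  # (pop_size,)
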
diff --git a/textattack/attack_recipes/__init__.py b/textattack/attack_recipes/__init__.py index 6e865dde..137c1959 100644 --- a/textattack/attack_recipes/__init__.py +++ b/textattack/attack_recipes/__init__.py @@ -39,6 +39,7 @@ from .pso_zang_2020 import PSOZang2020 from .checklist_ribeiro_2020 import CheckList2020 from .clare_li_2020 import CLARE2020 +from .leap_2023 import LEAP2023 from .french_recipe import FrenchRecipe from .spanish_recipe import SpanishRecipe from .chinese_recipe import ChineseRecipe diff --git a/textattack/attack_recipes/leap_2023.py b/textattack/attack_recipes/leap_2023.py index 1917b8fd..9149b81b 100644 --- a/textattack/attack_recipes/leap_2023.py +++ b/textattack/attack_recipes/leap_2023.py @@ -12,7 +12,7 @@ StopwordModification, ) from textattack.goal_functions import UntargetedClassification -from textattack.search_methods import LEAPParticleSwarmOptimization +from textattack.search_methods import ParticleSwarmOptimizationLEAP from textattack.transformations import WordSwapWordNet from .attack_recipe import AttackRecipe @@ -36,6 +36,6 @@ def build(model_wrapper): # # Perform word substitution with LEAP algorithm. # - search_method = LEAPParticleSwarmOptimization(pop_size=60, max_iters=20, post_turn_check=True, max_turn_retries=20) + search_method = ParticleSwarmOptimizationLEAP(pop_size=60, max_iters=20, post_turn_check=True, max_turn_retries=20) return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/search_methods/__init__.py b/textattack/search_methods/__init__.py index 0a6e2a3b..4338cd0f 100644 --- a/textattack/search_methods/__init__.py +++ b/textattack/search_methods/__init__.py @@ -15,4 +15,4 @@ from .alzantot_genetic_algorithm import AlzantotGeneticAlgorithm from .improved_genetic_algorithm import ImprovedGeneticAlgorithm from .particle_swarm_optimization import ParticleSwarmOptimization -from .leap_particle_swarm_optimization import LEAPParticleSwarmOptimization \ No newline at end of file +from .particle_swarm_optimization_leap import ParticleSwarmOptimizationLEAP \ No newline at end of file diff --git a/textattack/search_methods/leap_particle_swarm_optimization.py b/textattack/search_methods/particle_swarm_optimization_leap.py similarity index 90% rename from textattack/search_methods/leap_particle_swarm_optimization.py rename to textattack/search_methods/particle_swarm_optimization_leap.py index c3ad38dd..cac12a47 100644 --- a/textattack/search_methods/leap_particle_swarm_optimization.py +++ b/textattack/search_methods/particle_swarm_optimization_leap.py @@ -16,16 +16,15 @@ import numpy as np from scipy.special import gamma as gamma from textattack.goal_function_results import GoalFunctionResultStatus -from textattack.search_methods import ParticleSwarmOptimization, PopulationMember -from textattack.shared import utils -from textattack.shared.validators import transformation_consists_of_word_swaps +from textattack.search_methods import ParticleSwarmOptimization + def sigmax(alpha): numerator = gamma(alpha + 1.0) * np.sin(np.pi * alpha / 2.0) denominator = gamma((alpha + 1) / 2.0) * alpha * np.power(2.0, (alpha - 1.0) / 2.0) - return np.power(numerator / denominator, 1.0 / alpha) + def vf(alpha): x = np.random.normal(0, 1) y = np.random.normal(0, 1) @@ -34,19 +33,45 @@ def vf(alpha): return x / np.power(np.abs(y), 1.0 / alpha) + def K(alpha): k = alpha * gamma((alpha + 1.0) / (2.0 * alpha)) / gamma(1.0 / alpha) - k *= np.power(alpha * gamma((alpha + 1.0) / 2.0) / (gamma(alpha + 1.0) * np.sin(np.pi * alpha / 
2.0)), 1.0 / alpha) + k *= np.power( + alpha + * gamma((alpha + 1.0) / 2.0) + / (gamma(alpha + 1.0) * np.sin(np.pi * alpha / 2.0)), + 1.0 / alpha, + ) return k + def C(alpha): - x = np.array((0.75, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.95, 1.99)) + x = np.array( + (0.75, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.95, 1.99) + ) y = np.array( - (2.2085, 2.483, 2.7675, 2.945, 2.941, 2.9005, 2.8315, 2.737, 2.6125, 2.4465, 2.206, 1.7915, 1.3925, 0.6089)) + ( + 2.2085, + 2.483, + 2.7675, + 2.945, + 2.941, + 2.9005, + 2.8315, + 2.737, + 2.6125, + 2.4465, + 2.206, + 1.7915, + 1.3925, + 0.6089, + ) + ) return np.interp(alpha, x, y) + def levy(alpha, gamma=1, n=1): w = 0 for i in range(0, n): @@ -61,6 +86,7 @@ def levy(alpha, gamma=1, n=1): return z + def get_one_levy(min, max): while 1: temp = levy(1.5, 1) @@ -83,7 +109,10 @@ def softmax(x, axis=1): s = x_exp / x_sum return s -class LEAPParticleSwarmOptimization(ParticleSwarmOptimization): +class ParticleSwarmOptimizationLEAP(ParticleSwarmOptimization): + """Attacks a model with word substiutitions using a variant of Particle Swarm + Optimization (PSO) algorithm called LEAP. + """ def _greedy_perturb(self, pop_member, original_result): best_neighbors, prob_list = self._get_best_neighbors( pop_member.result, original_result From 7d52ce39033e74a0138f41aacd1ee7c8b9cd24e6 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 8 Jul 2024 02:26:09 +0800 Subject: [PATCH 3/6] add leap to recipe names --- textattack/attack_args.py | 1 + .../particle_swarm_optimization_leap.py | 12 +++--------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/textattack/attack_args.py b/textattack/attack_args.py index b99f6dc5..8697ea52 100644 --- a/textattack/attack_args.py +++ b/textattack/attack_args.py @@ -37,6 +37,7 @@ "checklist": "textattack.attack_recipes.CheckList2020", "clare": "textattack.attack_recipes.CLARE2020", "a2t": "textattack.attack_recipes.A2TYoo2021", + "leap": "textattack.attack_recipes.LEAP2023", } diff --git a/textattack/search_methods/particle_swarm_optimization_leap.py b/textattack/search_methods/particle_swarm_optimization_leap.py index cac12a47..74b676c2 100644 --- a/textattack/search_methods/particle_swarm_optimization_leap.py +++ b/textattack/search_methods/particle_swarm_optimization_leap.py @@ -18,13 +18,11 @@ from textattack.goal_function_results import GoalFunctionResultStatus from textattack.search_methods import ParticleSwarmOptimization - def sigmax(alpha): numerator = gamma(alpha + 1.0) * np.sin(np.pi * alpha / 2.0) denominator = gamma((alpha + 1) / 2.0) * alpha * np.power(2.0, (alpha - 1.0) / 2.0) return np.power(numerator / denominator, 1.0 / alpha) - def vf(alpha): x = np.random.normal(0, 1) y = np.random.normal(0, 1) @@ -33,7 +31,6 @@ def vf(alpha): return x / np.power(np.abs(y), 1.0 / alpha) - def K(alpha): k = alpha * gamma((alpha + 1.0) / (2.0 * alpha)) / gamma(1.0 / alpha) k *= np.power( @@ -45,7 +42,6 @@ def K(alpha): return k - def C(alpha): x = np.array( (0.75, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.95, 1.99) @@ -71,7 +67,6 @@ def C(alpha): return np.interp(alpha, x, y) - def levy(alpha, gamma=1, n=1): w = 0 for i in range(0, n): @@ -86,9 +81,8 @@ def levy(alpha, gamma=1, n=1): return z - def get_one_levy(min, max): - while 1: + while True: temp = levy(1.5, 1) if min <= temp <= max: break @@ -126,11 +120,11 @@ def perform_search(self, initial_result): self._search_over = False population = self._initialize_population(initial_result, self.pop_size) - 
# Initialize velocities + # Initialize velocities v_init = [] v_init_rand = np.random.uniform(-self.v_max, self.v_max, self.pop_size) v_init_levy = [] - while 1: + while True: temp = levy(1.5, 1) if -self.v_max <= temp <= self.v_max: v_init_levy.append(temp) From 26f628ff98ad224776316ebbf7a16c228288a5b9 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 8 Jul 2024 02:47:50 +0800 Subject: [PATCH 4/6] update doc --- docs/3recipes/attack_recipes.rst | 5 ++++- docs/api/search_methods.rst | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/3recipes/attack_recipes.rst b/docs/3recipes/attack_recipes.rst index 477cb1e3..874bd885 100644 --- a/docs/3recipes/attack_recipes.rst +++ b/docs/3recipes/attack_recipes.rst @@ -36,6 +36,7 @@ Attacks on classification models 14. TextBugger (TextBugger: Generating Adversarial Text Against Real-world Applications) 15. Pruthi (Combating Adversarial Misspellings with Robust Word Recognition 2019) 16. CLARE (Contextualized Perturbation for Textual Adversarial Attack 2020) +17. LEAP (LEAP: Efficient and Automated Test Method for NLP Software 2023) @@ -136,7 +137,9 @@ Attacks on classification models :members: :noindex: - +.. automodule:: textattack.attack_recipes.leap_2023 + :members: + :noindex: diff --git a/docs/api/search_methods.rst b/docs/api/search_methods.rst index 979dadcb..2aed38cb 100644 --- a/docs/api/search_methods.rst +++ b/docs/api/search_methods.rst @@ -44,3 +44,7 @@ ParticleSwarmOptimization .. autoclass:: textattack.search_methods.ParticleSwarmOptimization :members: +ParticleSwarmOptimizationLEAP +-------------------------- +.. autoclass:: textattack.search_methods.ParticleSwarmOptimizationLEAP + :members: From be325c8f2b54c4c52d8b42d934d2495fa40efdfc Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 8 Jul 2024 03:32:39 +0800 Subject: [PATCH 5/6] format --- textattack/attack.py | 6 ++-- textattack/attack_recipes/leap_2023.py | 5 ++- .../classification/targeted_classification.py | 2 +- textattack/search_methods/__init__.py | 2 +- .../particle_swarm_optimization_leap.py | 35 ++++++++++++++----- 5 files changed, 35 insertions(+), 15 deletions(-) diff --git a/textattack/attack.py b/textattack/attack.py index 7e05f93e..47537d1b 100644 --- a/textattack/attack.py +++ b/textattack/attack.py @@ -372,9 +372,9 @@ def filter_transformations( uncached_texts.append(transformed_text) else: # promote transformed_text to the top of the LRU cache - self.constraints_cache[(current_text, transformed_text)] = ( - self.constraints_cache[(current_text, transformed_text)] - ) + self.constraints_cache[ + (current_text, transformed_text) + ] = self.constraints_cache[(current_text, transformed_text)] if self.constraints_cache[(current_text, transformed_text)]: filtered_texts.append(transformed_text) filtered_texts += self._filter_transformations_uncached( diff --git a/textattack/attack_recipes/leap_2023.py b/textattack/attack_recipes/leap_2023.py index 9149b81b..cc507628 100644 --- a/textattack/attack_recipes/leap_2023.py +++ b/textattack/attack_recipes/leap_2023.py @@ -17,6 +17,7 @@ from .attack_recipe import AttackRecipe + class LEAP2023(AttackRecipe): @staticmethod def build(model_wrapper): @@ -36,6 +37,8 @@ def build(model_wrapper): # # Perform word substitution with LEAP algorithm. 
# - search_method = ParticleSwarmOptimizationLEAP(pop_size=60, max_iters=20, post_turn_check=True, max_turn_retries=20) + search_method = ParticleSwarmOptimizationLEAP( + pop_size=60, max_iters=20, post_turn_check=True, max_turn_retries=20 + ) return Attack(goal_function, constraints, transformation, search_method) diff --git a/textattack/goal_functions/classification/targeted_classification.py b/textattack/goal_functions/classification/targeted_classification.py index 3b1ad3ac..6041b625 100644 --- a/textattack/goal_functions/classification/targeted_classification.py +++ b/textattack/goal_functions/classification/targeted_classification.py @@ -11,7 +11,7 @@ class TargetedClassification(ClassificationGoalFunction): """A targeted attack on classification models which attempts to maximize the score of the target label. - Complete when the arget label is the predicted label. + Complete when the target label is the predicted label. """ def __init__(self, *args, target_class=0, **kwargs): diff --git a/textattack/search_methods/__init__.py b/textattack/search_methods/__init__.py index 4338cd0f..8645c474 100644 --- a/textattack/search_methods/__init__.py +++ b/textattack/search_methods/__init__.py @@ -15,4 +15,4 @@ from .alzantot_genetic_algorithm import AlzantotGeneticAlgorithm from .improved_genetic_algorithm import ImprovedGeneticAlgorithm from .particle_swarm_optimization import ParticleSwarmOptimization -from .particle_swarm_optimization_leap import ParticleSwarmOptimizationLEAP \ No newline at end of file +from .particle_swarm_optimization_leap import ParticleSwarmOptimizationLEAP diff --git a/textattack/search_methods/particle_swarm_optimization_leap.py b/textattack/search_methods/particle_swarm_optimization_leap.py index 74b676c2..bec99114 100644 --- a/textattack/search_methods/particle_swarm_optimization_leap.py +++ b/textattack/search_methods/particle_swarm_optimization_leap.py @@ -15,14 +15,17 @@ import numpy as np from scipy.special import gamma as gamma + from textattack.goal_function_results import GoalFunctionResultStatus from textattack.search_methods import ParticleSwarmOptimization + def sigmax(alpha): numerator = gamma(alpha + 1.0) * np.sin(np.pi * alpha / 2.0) denominator = gamma((alpha + 1) / 2.0) * alpha * np.power(2.0, (alpha - 1.0) / 2.0) return np.power(numerator / denominator, 1.0 / alpha) + def vf(alpha): x = np.random.normal(0, 1) y = np.random.normal(0, 1) @@ -31,6 +34,7 @@ def vf(alpha): return x / np.power(np.abs(y), 1.0 / alpha) + def K(alpha): k = alpha * gamma((alpha + 1.0) / (2.0 * alpha)) / gamma(1.0 / alpha) k *= np.power( @@ -42,6 +46,7 @@ def K(alpha): return k + def C(alpha): x = np.array( (0.75, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.95, 1.99) @@ -67,6 +72,7 @@ def C(alpha): return np.interp(alpha, x, y) + def levy(alpha, gamma=1, n=1): w = 0 for i in range(0, n): @@ -81,6 +87,7 @@ def levy(alpha, gamma=1, n=1): return z + def get_one_levy(min, max): while True: temp = levy(1.5, 1) @@ -90,6 +97,7 @@ def get_one_levy(min, max): continue return temp + def softmax(x, axis=1): row_max = x.max(axis=axis) @@ -103,10 +111,11 @@ def softmax(x, axis=1): s = x_exp / x_sum return s + class ParticleSwarmOptimizationLEAP(ParticleSwarmOptimization): - """Attacks a model with word substiutitions using a variant of Particle Swarm - Optimization (PSO) algorithm called LEAP. 
- """ + """Attacks a model with word substiutitions using a variant of Particle + Swarm Optimization (PSO) algorithm called LEAP.""" + def _greedy_perturb(self, pop_member, original_result): best_neighbors, prob_list = self._get_best_neighbors( pop_member.result, original_result @@ -115,11 +124,11 @@ def _greedy_perturb(self, pop_member, original_result): pop_member.attacked_text = random_result.attacked_text pop_member.result = random_result return True - + def perform_search(self, initial_result): self._search_over = False population = self._initialize_population(initial_result, self.pop_size) - + # Initialize velocities v_init = [] v_init_rand = np.random.uniform(-self.v_max, self.v_max, self.pop_size) @@ -133,7 +142,10 @@ def perform_search(self, initial_result): if len(v_init_levy) == self.pop_size: break for i in range(self.pop_size): - if np.random.uniform(-self.v_max, self.v_max, ) < levy(1.5, 1): + if np.random.uniform( + -self.v_max, + self.v_max, + ) < levy(1.5, 1): v_init.append(v_init_rand[i]) else: v_init.append(v_init_levy[i]) @@ -167,9 +179,14 @@ def perform_search(self, initial_result): for i in range(self.max_iters): for k in range(len(population)): if population[k].score < fit_ave: - omega.append(self.omega_2 + ((population[k].score - fit_min) * - (self.omega_1 - self.omega_2)) / - (fit_ave - fit_min)) + omega.append( + self.omega_2 + + ( + (population[k].score - fit_min) + * (self.omega_1 - self.omega_2) + ) + / (fit_ave - fit_min) + ) else: omega.append(get_one_levy(0.5, 0.8)) C1 = self.c1_origin - i / self.max_iters * (self.c1_origin - self.c2_origin) From e59bd3b11843cae6cc4ba872bb5f6045591ab8d1 Mon Sep 17 00:00:00 2001 From: jstzwj <1103870790@qq.com> Date: Mon, 8 Jul 2024 03:45:26 +0800 Subject: [PATCH 6/6] reformat files --- textattack/metrics/attack_metrics/words_perturbed.py | 6 +++--- textattack/search_methods/particle_swarm_optimization.py | 6 +++--- textattack/shared/validators.py | 5 +---- .../sentence_transformations/back_transcription.py | 5 +++-- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/textattack/metrics/attack_metrics/words_perturbed.py b/textattack/metrics/attack_metrics/words_perturbed.py index 38c11b29..6104de1b 100644 --- a/textattack/metrics/attack_metrics/words_perturbed.py +++ b/textattack/metrics/attack_metrics/words_perturbed.py @@ -65,9 +65,9 @@ def calculate(self, results): self.all_metrics["avg_word_perturbed"] = self.avg_number_word_perturbed_num() self.all_metrics["avg_word_perturbed_perc"] = self.avg_perturbation_perc() self.all_metrics["max_words_changed"] = self.max_words_changed - self.all_metrics["num_words_changed_until_success"] = ( - self.num_words_changed_until_success - ) + self.all_metrics[ + "num_words_changed_until_success" + ] = self.num_words_changed_until_success return self.all_metrics diff --git a/textattack/search_methods/particle_swarm_optimization.py b/textattack/search_methods/particle_swarm_optimization.py index fdc48aa0..639f513b 100644 --- a/textattack/search_methods/particle_swarm_optimization.py +++ b/textattack/search_methods/particle_swarm_optimization.py @@ -120,9 +120,9 @@ def _turn(self, source_text, target_text, prob, original_text): & indices_to_replace ) if "last_transformation" in source_text.attacked_text.attack_attrs: - new_text.attack_attrs["last_transformation"] = ( - source_text.attacked_text.attack_attrs["last_transformation"] - ) + new_text.attack_attrs[ + "last_transformation" + ] = source_text.attacked_text.attack_attrs["last_transformation"] if not self.post_turn_check or 
(new_text.words == source_text.words):
                 break
diff --git a/textattack/shared/validators.py b/textattack/shared/validators.py
index 45513a2a..55f4ed08 100644
--- a/textattack/shared/validators.py
+++ b/textattack/shared/validators.py
@@ -25,10 +25,7 @@
         r"^textattack.models.helpers.word_cnn_for_classification.*",
         r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
     ],
-    (
-        NonOverlappingOutput,
-        MinimizeBleu,
-    ): [
+    (NonOverlappingOutput, MinimizeBleu,): [
         r"^textattack.models.helpers.t5_for_text_to_text.*",
     ],
 }
diff --git a/textattack/transformations/sentence_transformations/back_transcription.py b/textattack/transformations/sentence_transformations/back_transcription.py
index 81cc8aff..c902b6d5 100644
--- a/textattack/transformations/sentence_transformations/back_transcription.py
+++ b/textattack/transformations/sentence_transformations/back_transcription.py
@@ -12,8 +12,9 @@


 class BackTranscription(SentenceTransformation):
-    """A type of sentence level transformation that takes in a text input, converts it into
-    synthesized speech using ASR, and transcribes it back to text using TTS.
+    """A type of sentence level transformation that takes in a text input,
+    converts it into synthesized speech using TTS, and transcribes it back to
+    text using ASR.

     tts_model: text-to-speech model from huggingface
     asr_model: automatic speech recognition model from huggingface
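
Closing note for reviewers: the sketch below shows one way the new recipe could be exercised end to end through TextAttack's standard Attacker API once this branch is installed. It is a minimal, illustrative example and not part of the patches; the checkpoint name textattack/bert-base-uncased-ag-news is only an example of an AG's News classifier (chosen to match the MaxModificationRate=0.16 comment in the recipe) and can be swapped for any HuggingFace sequence-classification model.

import transformers

from textattack import AttackArgs, Attacker
from textattack.attack_recipes import LEAP2023
from textattack.datasets import HuggingFaceDataset
from textattack.models.wrappers import HuggingFaceModelWrapper

# Example AG's News classifier; any sequence-classification checkpoint works.
checkpoint = "textattack/bert-base-uncased-ag-news"
model = transformers.AutoModelForSequenceClassification.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

# Build the attack exactly as leap_2023.py defines it: WordNet synonym swaps,
# MaxModificationRate(0.16) + StopwordModification constraints, and the LEAP
# particle swarm search method.
attack = LEAP2023.build(model_wrapper)

dataset = HuggingFaceDataset("ag_news", split="test")
attack_args = AttackArgs(num_examples=10)
Attacker(attack, dataset, attack_args).attack_dataset()

Since PATCH 3/6 also registers the recipe under the "leap" key in ATTACK_RECIPE_NAMES, the equivalent run should be reachable from the command line with textattack attack --recipe leap once the branch is installed.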