From a097d830a7745c8bc2e75124153201240a045dc1 Mon Sep 17 00:00:00 2001 From: baixinbs Date: Sun, 11 Aug 2024 14:05:57 +0000 Subject: [PATCH 1/3] init --- analysis.sh | 3 + run.sh | 8 +++ src/analysis.py | 50 ++++++++++++++++ src/generate.py | 153 ++++++++++++++++++++++++++++++++++++++++++++++++ src/tags.csv | 1 + 5 files changed, 215 insertions(+) create mode 100644 analysis.sh create mode 100644 run.sh create mode 100644 src/analysis.py create mode 100644 src/generate.py create mode 100644 src/tags.csv diff --git a/analysis.sh b/analysis.sh new file mode 100644 index 0000000..a119830 --- /dev/null +++ b/analysis.sh @@ -0,0 +1,3 @@ +python src/analysis.py \ + --input-file=file_to_analysis \ + --output_file=file_to_output_result \ No newline at end of file diff --git a/run.sh b/run.sh new file mode 100644 index 0000000..392c9cf --- /dev/null +++ b/run.sh @@ -0,0 +1,8 @@ +python src/generate.py \ + --llm-host=http://x.x.x.x \ + --llm-port=xxxx \ + --input-file=path_to_input_file \ + --output-file=path_to_ouput_file \ + --tag-file=path_to_tag_file \ + --num-workers=20 \ + --task-type=tag diff --git a/src/analysis.py b/src/analysis.py new file mode 100644 index 0000000..a838d6c --- /dev/null +++ b/src/analysis.py @@ -0,0 +1,50 @@ +import gzip +import json +import argparse +import logging + +logger = logging.getLogger(__file__) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--input-file', type=str, required=True) + parser.add_argument('--output-file', type=str, required=True) + parser.add_argument('--task_type', + type=str, + choices=['tag', 'quality'], + default='tag') + args, _ = parser.parse_known_args() + return args + + +args = parse_args() + + +def analysis_file(): + value_dict = dict() + cnt = 0 + with gzip.open(args.input_file, 'rt', encoding='utf-8') as f: + for line in f: + # 解析每一行 + record = json.loads(line) + # tag或quality值 + v = record[args.task_type] + if v in value_dict: + value_dict[v] += 1 + else: + value_dict[v] = 1 + cnt += 1 + if cnt == 20: + break + return value_dict + + +def write_result_to_file(value_dict): + with open(args.output_file, 'w') as f: + for v in value_dict: + f.write(f"{v},{value_dict[v]}\n") + + +if __name__ == '__main__': + write_result_to_file(analysis_file()) diff --git a/src/generate.py b/src/generate.py new file mode 100644 index 0000000..ba89965 --- /dev/null +++ b/src/generate.py @@ -0,0 +1,153 @@ +import requests +import gzip +import json +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +import threading + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger('__file__') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--llm-host', type=str, required=True) + parser.add_argument('--llm-port', type=int, required=True) + parser.add_argument('--input-file', type=str, required=True) + parser.add_argument('--output-file', type=str, required=True) + parser.add_argument('--tag-file', type=str, default='src/tags.csv') + parser.add_argument('--num-workers', type=int, default=10) + parser.add_argument('--task-type', + type=str, + choices=['tag', 'quality'], + default='tag') + args, _ = parser.parse_known_args() + return args + + +args = parse_args() + +tag_str = None +tags_list = None +if args.tag_file: + with open(args.tag_file, 'r') as f: + tags_str = f.readline() + tags_list = tags_str.split(',') + + +def send_query(query): + query_text = None + if 'text' in query: + query_text = query['text'] + if args.task_type == 'tag': + prompt = make_prompt_tag(query_text) + elif args.task_type == 'quality': + prompt = make_prompt_quality(query_text) + else: + raise ValueError("Unsupported task type.") + payload = { + "model": "m_model", + "messages": [{ + "role": "user", + "content": prompt + }], + "max_tokens": 20, + "presence_penalty": 1.03, + "frequency_penalty": 1.0, + "seed": None, + "temperature": 0.5, + "top_p": 0.95, + "stream": False + } + try: + response = requests.post( + f"{args.llm_host}:{args.llm_port}/v1/chat/completions", + headers={'Content-Type': 'application/json'}, + json=payload) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logger.warning(f"Request failed: {e}") + return None + + +def make_prompt_tag(text): + + if args.tag_file: + prompt = f"""请将以下文本分类为以下类别中的一个:{tags_str}。请仔细阅读文本,并根据其主要内容选择最合适的类别。 + + 请避免选择“其它”类别,除非文本内容确实与上述类别完全不相关。请以“XXX”的格式输出答案。 + + 以下是文本: + \"{text}\" + """ + else: + prompt = f"""请将以下文本分类为以下类别中的一个:安全合规、逻辑推理、文本理解、知识问答、AI智能体。请仔细阅读文本,并根据其主要内容选择最合适的类别。 + + - 如果文本涉及算法、自主学习或人工智能系统,请选择“AI智能体”; + - 如果涉及数据保护、法规或合规,请选择“安全合规”; + - 如果涉及数学证明或逻辑推理,请选择“逻辑推理”; + - 如果文本与语言理解、句子结构分析或文本解释有关,请选择“文本理解”; + - 如果文本涉及通过询问或回答问题获取知识或信息,请选择“知识问答”。 + + 请避免选择“其它”类别,除非文本内容确实与上述五个类别完全不相关。请以“XXX”的格式输出答案。 + + 以下是文本: + \"{text}\" + """ + return prompt + + +def make_prompt_quality(text): + prompt = f"""请你对以下中文数据样本进行全面的质量评估,涵盖以下几个方面: + + 1. **完整性**:数据样本是否包含所有必要的信息,是否存在内容遗漏或关键数据缺失。 + 2. **语义一致性**:数据样本的内容在语义上是否连贯一致,是否存在逻辑矛盾或语义模糊的问题。 + 3. **语法正确性**:数据样本的语法是否正确,包括标点符号、拼写、用词和句子结构等方面。 + 4. **数据多样性**:数据样本是否在表达上具有多样性,是否为预训练模型提供了足够丰富的语言现象。 + + 请综合考虑上述各项内容,给出该数据样本的整体质量评分(1-10)。 + 请按照整体质量{1-10}进行评分,分数值取整数,直接将评分的分数值作为输出答案。 + + 以下是文本: + \"{text}\" + """ + return prompt + + +def process_record(record): + resp = send_query(record) + answer = resp.get("choices")[0].get("message").get("content").strip('\n') + + if args.task_type == 'tag': + if answer not in tags_list: + answer = "其它" + record['tag'] = answer + elif args.task_type == 'quality': + record['quality'] = answer + return record + + +def write_to_file(record, f_out, lock): + with lock: + f_out.write(json.dumps(record) + '\n') + + +def process_file(input_filename, output_filename): + # Create a lock object + lock = threading.Lock() + + with gzip.open(input_filename, 'rt', encoding='utf-8') as f_in: + with gzip.open(output_filename, 'wt', encoding='utf-8') as f_out: + # Create a thread pool + with ThreadPoolExecutor(max_workers=args.num_workers) as executor: + for line in f_in: + record = json.loads(line) + future = executor.submit(process_record, record) + future.add_done_callback( + lambda fut: write_to_file(fut.result(), f_out, lock)) + + +if __name__ == "__main__": + process_file(args.input_file, args.output_file) diff --git a/src/tags.csv b/src/tags.csv new file mode 100644 index 0000000..da91a9a --- /dev/null +++ b/src/tags.csv @@ -0,0 +1 @@ +安全合规,逻辑推理,文本理解,知识问答,AI智能体 \ No newline at end of file -- Gitee From c4ac3d754fba041f784e38dc28babce8bac036d0 Mon Sep 17 00:00:00 2001 From: baixinbs Date: Sun, 11 Aug 2024 14:43:48 +0000 Subject: [PATCH 2/3] add README --- README.md | 41 ++++++++++++++--------------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index af31e1b..02b43b4 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,24 @@ # data-distribution #### 介绍 -训练数据分布分析,包括数据标签和数据质量 +训练数据分布分析,包括 +- 数据标签生成; +- 数据质量评分; +- 数据标签/质量分布统计分析。 #### 软件架构 -软件架构说明 +基于prompt工程直接调用大模型能力。 #### 安装教程 - -1. xxxx -2. xxxx -3. xxxx +直接下载代码。 #### 使用说明 - -1. xxxx -2. xxxx -3. xxxx - -#### 参与贡献 - -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request - - -#### 特技 - -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) +运行run.sh,参数说明: +- llm-host: 模型服务地址; +- llm-port: 模型服务端口; +- input-file: 原始的文件; +- output-file: 在原始文件中加入标签/质量评分后写入新文件; +- tag-file: 需要自定义标签体系时,将标签体系写入该csv文件; +- num-workers: 并行调用模型服务的线程数; +- task-type:任务类型,支持"tag"和"quality"两种。 -- Gitee From 8f52e0bdae65ad5437d45714e951eaf36aeb9051 Mon Sep 17 00:00:00 2001 From: baixinbs Date: Mon, 12 Aug 2024 01:06:38 +0000 Subject: [PATCH 3/3] Add analysis to README --- README.md | 13 +++++++++++-- analysis.sh => examples/analysis.sh | 4 +++- run.sh => examples/run.sh | 0 src/analysis.py | 4 ---- 4 files changed, 14 insertions(+), 7 deletions(-) rename analysis.sh => examples/analysis.sh (45%) rename run.sh => examples/run.sh (100%) diff --git a/README.md b/README.md index 02b43b4..a63fcb3 100644 --- a/README.md +++ b/README.md @@ -14,11 +14,20 @@ 直接下载代码。 #### 使用说明 -运行run.sh,参数说明: +1、生成标签或质量评分 + +运行examples/run.sh,参数说明: - llm-host: 模型服务地址; - llm-port: 模型服务端口; - input-file: 原始的文件; - output-file: 在原始文件中加入标签/质量评分后写入新文件; - tag-file: 需要自定义标签体系时,将标签体系写入该csv文件; - num-workers: 并行调用模型服务的线程数; -- task-type:任务类型,支持"tag"和"quality"两种。 +- task-type:任务类型,支持"tag"和"quality"。 + +2、统计标签/质量评分分布 + +运行examples/analysis.sh,参数说明: +- input-file: 待分析的数据文件; +- output-file: 分析结果输出文件; +- task-type: 任务类型,支持"tag"和"quality"。 \ No newline at end of file diff --git a/analysis.sh b/examples/analysis.sh similarity index 45% rename from analysis.sh rename to examples/analysis.sh index a119830..10ddf6f 100644 --- a/analysis.sh +++ b/examples/analysis.sh @@ -1,3 +1,5 @@ python src/analysis.py \ --input-file=file_to_analysis \ - --output_file=file_to_output_result \ No newline at end of file + --output_file=file_to_output_result \ + --task-type=tag_or_quality + diff --git a/run.sh b/examples/run.sh similarity index 100% rename from run.sh rename to examples/run.sh diff --git a/src/analysis.py b/src/analysis.py index a838d6c..9744dd4 100644 --- a/src/analysis.py +++ b/src/analysis.py @@ -23,7 +23,6 @@ args = parse_args() def analysis_file(): value_dict = dict() - cnt = 0 with gzip.open(args.input_file, 'rt', encoding='utf-8') as f: for line in f: # 解析每一行 @@ -34,9 +33,6 @@ def analysis_file(): value_dict[v] += 1 else: value_dict[v] = 1 - cnt += 1 - if cnt == 20: - break return value_dict -- Gitee