diff --git a/README.md b/README.md index af31e1bd22b0dc86dec6b1a46caaf05e36fd4391..a63fcb37ff8d8fec1d02913a03dd2960d451fb78 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,33 @@ # data-distribution #### 介绍 -训练数据分布分析,包括数据标签和数据质量 +训练数据分布分析,包括 +- 数据标签生成; +- 数据质量评分; +- 数据标签/质量分布统计分析。 #### 软件架构 -软件架构说明 +基于prompt工程直接调用大模型能力。 #### 安装教程 - -1. xxxx -2. xxxx -3. xxxx +直接下载代码。 #### 使用说明 - -1. xxxx -2. xxxx -3. xxxx - -#### 参与贡献 - -1. Fork 本仓库 -2. 新建 Feat_xxx 分支 -3. 提交代码 -4. 新建 Pull Request - - -#### 特技 - -1. 使用 Readme\_XXX.md 来支持不同的语言,例如 Readme\_en.md, Readme\_zh.md -2. Gitee 官方博客 [blog.gitee.com](https://blog.gitee.com) -3. 你可以 [https://gitee.com/explore](https://gitee.com/explore) 这个地址来了解 Gitee 上的优秀开源项目 -4. [GVP](https://gitee.com/gvp) 全称是 Gitee 最有价值开源项目,是综合评定出的优秀开源项目 -5. Gitee 官方提供的使用手册 [https://gitee.com/help](https://gitee.com/help) -6. Gitee 封面人物是一档用来展示 Gitee 会员风采的栏目 [https://gitee.com/gitee-stars/](https://gitee.com/gitee-stars/) +1、生成标签或质量评分 + +运行examples/run.sh,参数说明: +- llm-host: 模型服务地址; +- llm-port: 模型服务端口; +- input-file: 原始的文件; +- output-file: 在原始文件中加入标签/质量评分后写入新文件; +- tag-file: 需要自定义标签体系时,将标签体系写入该csv文件; +- num-workers: 并行调用模型服务的线程数; +- task-type:任务类型,支持"tag"和"quality"。 + +2、统计标签/质量评分分布 + +运行examples/analysis.sh,参数说明: +- input-file: 待分析的数据文件; +- output-file: 分析结果输出文件; +- task-type: 任务类型,支持"tag"和"quality"。 \ No newline at end of file diff --git a/examples/analysis.sh b/examples/analysis.sh new file mode 100644 index 0000000000000000000000000000000000000000..10ddf6f09e2c26e007e863f2d7dc05008953188d --- /dev/null +++ b/examples/analysis.sh @@ -0,0 +1,5 @@ +python src/analysis.py \ + --input-file=file_to_analysis \ + --output_file=file_to_output_result \ + --task-type=tag_or_quality + diff --git a/examples/run.sh b/examples/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..392c9cf022fa43123343e805b3d0242435f7aab0 --- /dev/null +++ b/examples/run.sh @@ -0,0 +1,8 @@ +python src/generate.py \ + --llm-host=http://x.x.x.x \ + --llm-port=xxxx \ + --input-file=path_to_input_file \ + --output-file=path_to_ouput_file \ + --tag-file=path_to_tag_file \ + --num-workers=20 \ + --task-type=tag diff --git a/src/analysis.py b/src/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..9744dd4699b70d04e2de777afb1fbacf70184c12 --- /dev/null +++ b/src/analysis.py @@ -0,0 +1,46 @@ +import gzip +import json +import argparse +import logging + +logger = logging.getLogger(__file__) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--input-file', type=str, required=True) + parser.add_argument('--output-file', type=str, required=True) + parser.add_argument('--task_type', + type=str, + choices=['tag', 'quality'], + default='tag') + args, _ = parser.parse_known_args() + return args + + +args = parse_args() + + +def analysis_file(): + value_dict = dict() + with gzip.open(args.input_file, 'rt', encoding='utf-8') as f: + for line in f: + # 解析每一行 + record = json.loads(line) + # tag或quality值 + v = record[args.task_type] + if v in value_dict: + value_dict[v] += 1 + else: + value_dict[v] = 1 + return value_dict + + +def write_result_to_file(value_dict): + with open(args.output_file, 'w') as f: + for v in value_dict: + f.write(f"{v},{value_dict[v]}\n") + + +if __name__ == '__main__': + write_result_to_file(analysis_file()) diff --git a/src/generate.py b/src/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..ba89965e92bb8a99b6df3c888a40549d8d064437 --- /dev/null +++ b/src/generate.py @@ -0,0 +1,153 @@ +import requests +import gzip +import json +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +import threading + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger('__file__') + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--llm-host', type=str, required=True) + parser.add_argument('--llm-port', type=int, required=True) + parser.add_argument('--input-file', type=str, required=True) + parser.add_argument('--output-file', type=str, required=True) + parser.add_argument('--tag-file', type=str, default='src/tags.csv') + parser.add_argument('--num-workers', type=int, default=10) + parser.add_argument('--task-type', + type=str, + choices=['tag', 'quality'], + default='tag') + args, _ = parser.parse_known_args() + return args + + +args = parse_args() + +tag_str = None +tags_list = None +if args.tag_file: + with open(args.tag_file, 'r') as f: + tags_str = f.readline() + tags_list = tags_str.split(',') + + +def send_query(query): + query_text = None + if 'text' in query: + query_text = query['text'] + if args.task_type == 'tag': + prompt = make_prompt_tag(query_text) + elif args.task_type == 'quality': + prompt = make_prompt_quality(query_text) + else: + raise ValueError("Unsupported task type.") + payload = { + "model": "m_model", + "messages": [{ + "role": "user", + "content": prompt + }], + "max_tokens": 20, + "presence_penalty": 1.03, + "frequency_penalty": 1.0, + "seed": None, + "temperature": 0.5, + "top_p": 0.95, + "stream": False + } + try: + response = requests.post( + f"{args.llm_host}:{args.llm_port}/v1/chat/completions", + headers={'Content-Type': 'application/json'}, + json=payload) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logger.warning(f"Request failed: {e}") + return None + + +def make_prompt_tag(text): + + if args.tag_file: + prompt = f"""请将以下文本分类为以下类别中的一个:{tags_str}。请仔细阅读文本,并根据其主要内容选择最合适的类别。 + + 请避免选择“其它”类别,除非文本内容确实与上述类别完全不相关。请以“XXX”的格式输出答案。 + + 以下是文本: + \"{text}\" + """ + else: + prompt = f"""请将以下文本分类为以下类别中的一个:安全合规、逻辑推理、文本理解、知识问答、AI智能体。请仔细阅读文本,并根据其主要内容选择最合适的类别。 + + - 如果文本涉及算法、自主学习或人工智能系统,请选择“AI智能体”; + - 如果涉及数据保护、法规或合规,请选择“安全合规”; + - 如果涉及数学证明或逻辑推理,请选择“逻辑推理”; + - 如果文本与语言理解、句子结构分析或文本解释有关,请选择“文本理解”; + - 如果文本涉及通过询问或回答问题获取知识或信息,请选择“知识问答”。 + + 请避免选择“其它”类别,除非文本内容确实与上述五个类别完全不相关。请以“XXX”的格式输出答案。 + + 以下是文本: + \"{text}\" + """ + return prompt + + +def make_prompt_quality(text): + prompt = f"""请你对以下中文数据样本进行全面的质量评估,涵盖以下几个方面: + + 1. **完整性**:数据样本是否包含所有必要的信息,是否存在内容遗漏或关键数据缺失。 + 2. **语义一致性**:数据样本的内容在语义上是否连贯一致,是否存在逻辑矛盾或语义模糊的问题。 + 3. **语法正确性**:数据样本的语法是否正确,包括标点符号、拼写、用词和句子结构等方面。 + 4. **数据多样性**:数据样本是否在表达上具有多样性,是否为预训练模型提供了足够丰富的语言现象。 + + 请综合考虑上述各项内容,给出该数据样本的整体质量评分(1-10)。 + 请按照整体质量{1-10}进行评分,分数值取整数,直接将评分的分数值作为输出答案。 + + 以下是文本: + \"{text}\" + """ + return prompt + + +def process_record(record): + resp = send_query(record) + answer = resp.get("choices")[0].get("message").get("content").strip('\n') + + if args.task_type == 'tag': + if answer not in tags_list: + answer = "其它" + record['tag'] = answer + elif args.task_type == 'quality': + record['quality'] = answer + return record + + +def write_to_file(record, f_out, lock): + with lock: + f_out.write(json.dumps(record) + '\n') + + +def process_file(input_filename, output_filename): + # Create a lock object + lock = threading.Lock() + + with gzip.open(input_filename, 'rt', encoding='utf-8') as f_in: + with gzip.open(output_filename, 'wt', encoding='utf-8') as f_out: + # Create a thread pool + with ThreadPoolExecutor(max_workers=args.num_workers) as executor: + for line in f_in: + record = json.loads(line) + future = executor.submit(process_record, record) + future.add_done_callback( + lambda fut: write_to_file(fut.result(), f_out, lock)) + + +if __name__ == "__main__": + process_file(args.input_file, args.output_file) diff --git a/src/tags.csv b/src/tags.csv new file mode 100644 index 0000000000000000000000000000000000000000..da91a9a5a234a77c015442be08f83e8d4e64796f --- /dev/null +++ b/src/tags.csv @@ -0,0 +1 @@ +安全合规,逻辑推理,文本理解,知识问答,AI智能体 \ No newline at end of file