qwen_lora_test/data2csv.py

import os
import json

import pandas as pd
import requests

# # Inspect the JSON file structure:
# print_json_tree(data)
class image_caption_json():
    """Wraps a COCO-style captions JSON file (e.g. captions_train2014.json):

    {
        "info": {...},
        "licenses": [...],
        "images": [...],
        "annotations": [...]
    }
    """
    def __init__(self, json_path, data_number):
        self.json_path = json_path
        self.data_number = data_number
        with open(self.json_path, 'r') as f:
            self.data = json.load(f)
    # Print the top-level JSON structure as a tree: each key, then the number of
    # elements it holds.
    def print_json_tree(self, indent=1):
        for key, value in self.data.items():
            print(' ' * indent + str(key))
            # Lists: count their elements
            if isinstance(value, list):
                value_count = len(value)
            # Dicts: count their keys
            elif isinstance(value, dict):
                value_count = len(value.keys())
            # Scalars count as a single value
            else:
                value_count = 1
            print(value_count)
            # if isinstance(value, dict):
            #     print()
            #     print_json_tree(value, indent + 1)
            # else:
            #     print(': ' + str(value))
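    # Example output for a COCO captions file (one line for the key, one for the
    # element count; the counts below are illustrative and depend on the file):
    #
    #    info
    #   6
    #    images
    #   82783
    #    annotations
    #   414113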
    def download_image(self, url, save_dir, filename, timeout=10):
        # 1. Make sure the target directory exists
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, exist_ok=True)
        try:
            response = requests.get(url, stream=True, timeout=timeout)
            response.raise_for_status()  # Raise if the request failed
            img_path = f"{save_dir}/{filename}"
            with open(img_path, "wb") as f:
                for chunk in response.iter_content(1024):
                    f.write(chunk)
            print(f"✅ Successfully saved {filename} to: {img_path} from: {url}\n")
            return img_path
        except Exception as e:
            print(f"❌ Failed to download from [{url}]: {e}\n")
            return None
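    # A minimal standalone usage sketch (the URL and paths below are hypothetical,
    # shown only to illustrate the call signature):
    #
    #   ic = image_caption_json('/path/to/captions_train2014.json', 1)
    #   ic.download_image('http://example.com/img.jpg', './images', 'img.jpg')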
    def image_annotation(self, save_dir, csv_path):
        # if os.path.exists(save_dir):
        #     print('coco_2014_caption directory already exists, skipping data processing')
        #     return 0
        # Track which mirror succeeded most recently: 0 = "coco", 1 = "flickr"
        LAST_SUCCESSFUL_SOURCE = 0
        # Lists collecting the local image paths and their captions
        image_paths = []
        captions = []
        # Process the first data_number images
        for i, img_info in enumerate(self.data['images'][:self.data_number]):
            # Pull the image metadata
            img_id = img_info['id']
            filename = img_info['file_name']
            coco_url = img_info['coco_url']
            flickr_url = img_info['flickr_url']
            # Look up the caption belonging to this image
            # # Match all captions:
            # match_annotation = [annotation['caption'] for annotation in self.data['annotations'] if annotation['image_id'] == img_id]
            # Match only the first caption
            caption = next((annotation['caption'] for annotation in self.data['annotations'] if annotation['image_id'] == img_id), None)
            print(f"{i+1}. Image ID: {img_id}")
            print(f"   Filename: {filename}")
            print(f"   Caption: {caption}")
            # print(f"   coco_url: {coco_url}")
            # Download the image, trying the mirror that worked last time first
            if LAST_SUCCESSFUL_SOURCE:
                first_url = flickr_url
                second_url = coco_url
            else:
                first_url = coco_url
                second_url = flickr_url
            image_path = self.download_image(first_url, save_dir, filename)
            if not image_path:
                image_path = self.download_image(second_url, save_dir, filename)
                if image_path:
                    # The fallback mirror worked, so prefer it from now on
                    LAST_SUCCESSFUL_SOURCE = 1 - LAST_SUCCESSFUL_SOURCE
                else:
                    print(f"❌❌ Failed to download {filename} from both sources\n")
                    continue
            # Record the local path and its caption
            image_paths.append(image_path)
            captions.append(caption)
        # Save the image paths and captions as a CSV file
        df = pd.DataFrame({
            'image_path': image_paths,
            'caption': captions
        })
        df.to_csv(csv_path, index=False)
        print('Data processing finished')
if __name__ == '__main__':
    file_path = '/root/PMN_WS/coco_2014_caption/annotations/captions_train2014.json'
    MAX_DATA_NUMBER = 100
    image_caption = image_caption_json(file_path, MAX_DATA_NUMBER)
    save_dir = '/root/PMN_WS/qwen-test/coco_2014_image'
    csv_path = '/root/PMN_WS/qwen-test/coco-2024-dataset.csv'
    image_caption.image_annotation(save_dir, csv_path)
    # image_caption.print_json_tree()
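# A minimal sketch of consuming the generated CSV afterwards (assumes only that
# pandas is installed and image_annotation() above has been run):
#
#   df = pd.read_csv(csv_path)
#   for image_path, caption in zip(df['image_path'], df['caption']):
#       ...  # e.g. build (image, text) pairs for downstream fine-tuning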