147 lines
4.8 KiB
Python
147 lines
4.8 KiB
Python
import os
|
||
import pandas as pd
|
||
import json
|
||
|
||
import requests
|
||
|
||
|
||
|
||
# # Inspect the JSON file
|
||
# print_json_tree(data)
|
||
|
||
class image_caption_json():
    """Load a COCO-style caption file and export image/caption pairs.

    The annotation file is expected to have this top-level layout::

        {
            "info": {...},
            "licenses": [...],
            "images": [...],
            "annotations": [...]
        }

    Attributes:
        json_path: Path of the annotation file that was loaded.
        data_number: How many leading entries of ``images`` are processed.
        data: The parsed JSON document.
    """

    # Keys of an ``images`` entry that hold a download URL. The index into
    # this tuple is the "source" number used by image_annotation
    # (0 = coco, 1 = flickr).
    _URL_KEYS = ('coco_url', 'flickr_url')

    def __init__(self, json_path, data_number):
        """Read and parse the annotation file.

        Args:
            json_path: Path of the JSON annotation file.
            data_number: Number of images (from the start of ``images``)
                that image_annotation should process.
        """
        self.json_path = json_path
        self.data_number = data_number
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

    def print_json_tree(self, indent=1):
        """Print every top-level key followed by the size of its value.

        The size is the length of a list, the number of keys of a dict,
        and 1 for any other value.

        Args:
            indent: Number of spaces printed in front of each key.
        """
        for key, value in self.data.items():
            print(' ' * indent + str(key))
            if isinstance(value, (list, dict)):
                value_count = len(value)
            else:
                value_count = 1
            print(value_count)

    def download_image(self, url, save_dir, filename, timeout=10):
        """Download ``url`` and store it as ``filename`` inside ``save_dir``.

        Args:
            url: Address of the image. An empty value counts as a failure.
            save_dir: Target directory; created when it does not exist.
            filename: Name of the file written inside ``save_dir``.
            timeout: Seconds to wait for the server before giving up.

        Returns:
            The path of the saved file, or None when the download failed.
        """
        os.makedirs(save_dir, exist_ok=True)

        if not url:
            print(f"❌ Failed to download from[{url}]: empty URL\n")
            return None

        img_path = os.path.join(save_dir, filename)
        try:
            # The context manager releases the streamed connection; the
            # timeout keeps one stalled server from blocking the whole run.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(img_path, "wb") as f:
                    for chunk in response.iter_content(1024):
                        f.write(chunk)
        except Exception as e:
            # Deliberate best-effort boundary: any failure is reported and
            # turned into None so the caller can try the next source.
            print(f"❌ Failed to download from[{url}]: {e}\n")
            return None

        print(f"✅ Successfully saved {filename} to: {img_path} from: {url}\n")
        return img_path

    def image_annotation(self, save_dir, csv_path):
        """Download the first ``data_number`` images and write a caption CSV.

        Each image is fetched from one of its two URLs; the source that
        succeeded most recently is tried first for the next image. Images
        that cannot be downloaded from either source are skipped. The CSV
        has the columns ``image_path`` and ``caption``, where the caption is
        the first annotation found for the image (empty if there is none).

        Args:
            save_dir: Directory the images are downloaded into.
            csv_path: Destination of the CSV file.
        """
        # Index the first caption of every image once instead of rescanning
        # the whole annotation list for each image.
        first_caption = {}
        for annotation in self.data['annotations']:
            first_caption.setdefault(annotation['image_id'],
                                     annotation['caption'])

        # Index into _URL_KEYS of the source that succeeded most recently.
        preferred = 0
        image_paths = []
        captions = []

        for i, img_info in enumerate(self.data['images'][:self.data_number]):
            img_id = img_info['id']
            filename = img_info['file_name']
            caption = first_caption.get(img_id)

            print(f"{i+1}. 图片ID: {img_id}")
            print(f" 文件名: {filename}")
            print(f" Caption: {caption}")

            image_path = None
            for source in (preferred, 1 - preferred):
                # An entry may lack one of the URLs; just try the other one.
                url = img_info.get(self._URL_KEYS[source])
                if not url:
                    continue
                image_path = self.download_image(url, save_dir, filename)
                if image_path:
                    preferred = source
                    break

            if not image_path:
                print(f"❌❌ Failed to download [{filename}]\n")
                continue

            image_paths.append(image_path)
            captions.append(caption)

        df = pd.DataFrame({
            'image_path': image_paths,
            'caption': captions,
        })
        df.to_csv(csv_path, index=False)

        print('数据处理完成')
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: load the COCO 2014 training captions, download the
    # first MAX_DATA_NUMBER images and write their paths and captions to a CSV.
    MAX_DATA_NUMBER = 100

    # Input annotation file and output locations.
    file_path = '/root/PMN_WS/coco_2014_caption/annotations/captions_train2014.json'
    save_dir = '/root/PMN_WS/qwen-test/coco_2014_image'
    csv_path = '/root/PMN_WS/qwen-test/coco-2024-dataset.csv'

    image_caption = image_caption_json(file_path, MAX_DATA_NUMBER)
    image_caption.image_annotation(save_dir, csv_path)
    # To inspect the top-level structure of the annotation file instead:
    # image_caption.print_json_tree()
|