ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

1-cascade MRI reconstruction: dataset.py

2021-07-11 17:01:32  阅读:220  来源: 互联网

标签:self py mask Nx cascade und MRI np data


首先配置文件不能少config.yaml

# Model Parameters
network:
        num_cascades: 6
        num_layers: 5 # Number of layers in the CNN per cascade
        num_filters: 64
        kernel_size: 3
        stride: 1
        padding: 1 #A padding of 1 is needed to keep the image in the same size
        noise: null #Noise in the measurements. To be used in the data consistency step

#Dataset parameters
dataset:
        data_path: 'data1/'
        acceleration_factor: 4.0
        fraction: 0.8 #train set size
        shuffle: 3 #Seed for numpy random generator
        sample_n: 10
        acq_noise: 0 #acquisation noise
        centred: False
        norm: 'ortho'  #norm: 'ortho' or null. if 'ortho', performs unitary transform, otherwise normal dft

# Training parameters
train:
        batch_size: 1
        num_epochs: 5
        early_stop: 100

        # Adam Optimizer Parameters
        learning_rate: 0.001
        b_1: 0.9
        b_2: 0.999
        l2: 0.0000001

        # Miscellaneous
        output_path: 'logs'
        cuda: False

单步执行

import os
import torch
import numpy as np
from math import ceil
from helpers_1 import *
from scipy.io import loadmat
from numpy.lib.stride_tricks import as_strided
import yaml
args = yaml.load(open('config.yaml', 'r'), Loader=yaml.FullLoader)

在这里插入图片描述
dataset = OCMRDataset(fold=‘train’, **args[‘dataset’])

        self.evalset = evalset
        self.data_path = data_path
        self.acc = acceleration_factor
        self.sample_n = sample_n
        self.noise = acq_noise
        self.centred = centred
        self.norm = norm
        self.files = os.listdir(self.data_path)
        if shuffle:
            np.random.seed(shuffle)
            np.random.shuffle(self.files) 
        if fold == 'train':
            self.files = self.files[:int(len(self.files) * fraction)]

在这里插入图片描述

    def __getitem__(self, idx):
        if self.evalset and idx == 0:
            np.random.seed(9001)
        data = loadmat(os.path.join(self.data_path, self.files[idx]))['xn'] * 1e3

在这里插入图片描述

        data = np.expand_dims(data, 0)

因为这里batch_size设的1,所以就有一个256*256
在这里插入图片描述

        mask = self.cartesian_mask(data.shape)
    def cartesian_mask(self, shape):
        N, Nx, Ny = int(np.prod(shape[:-2])), shape[-2], shape[-1]
        pdf_x = normal_pdf(Nx, 0.5/(Nx/10.)**2)

在这里插入图片描述

def normal_pdf(length, sensitivity):
    return np.exp(-sensitivity * (np.arange(length) - length / 2)**2)

在这里插入图片描述
在这里插入图片描述

        lmda = Nx/(2.*self.acc)
        n_lines = int(Nx / self.acc)

        # add uniform distribution
        pdf_x += lmda * 1./Nx

        if self.sample_n:
            pdf_x[Nx//2-self.sample_n//2:Nx//2+self.sample_n//2] = 0
            pdf_x /= np.sum(pdf_x)
            n_lines -= self.sample_n

在这里插入图片描述
在这里插入图片描述

        mask = np.zeros((N, Nx))
        for i in range(N):
            idx = np.random.choice(Nx, n_lines, False, pdf_x)
            mask[i, idx] = 1

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

        if self.sample_n:
            mask[:, Nx//2-self.sample_n//2:Nx//2+self.sample_n//2] = 1

        size = mask.itemsize
        mask = as_strided(mask, (N, Nx, Ny), (size * Nx, size, 0))
        mask = mask.reshape(shape)

在这里插入图片描述
mask上下很多行都是0,中间位置1越来越密集
在这里插入图片描述

        if not self.centred:
            mask = ifftshift(mask, axes=(-1, -2))

        return mask

反过来了1变0,0变1
在这里插入图片描述
在这里插入图片描述

        data_und, k_und = self.undersample(data, mask)

在这里插入图片描述

        assert x.shape == mask.shape
        # zero mean complex Gaussian noise
        noise_power = self.noise
        nz = np.sqrt(.5)*(np.random.normal(0, 1, x.shape) + 1j * np.random.normal(0, 1, x.shape))
        nz = nz * np.sqrt(noise_power)

在这里插入图片描述

        if self.norm == 'ortho':
            # multiplicative factor
            nz = nz * np.sqrt(np.prod(mask.shape[-2:]))

在这里插入图片描述

        if self.centred:
            x_f = fft2c(x, norm=self.norm)
            x_fu = mask * (x_f + nz)
            x_u = ifft2c(x_fu, norm=self.norm)
            return x_u, x_fu
        else:
            x_f = fft2(x, norm=self.norm)
            x_fu = mask * (x_f + nz)
            x_u = ifft2(x_fu, norm=self.norm)
            return x_u, x_fu

在这里插入图片描述
在这里插入图片描述
data_und, k_und = x_u, x_fu
在这里插入图片描述
在这里插入图片描述

        data_gnd = format_data(data)
def format_data(data, mask=False):
    if mask: 
        data = data * (1+1j)
    data = complex2real(data)
def complex2real(x):
	x_real = np.real(x)

在这里插入图片描述

x_imag = np.imag(x)

在这里插入图片描述

y = np.array([x_real, x_imag]).astype(np.float)

在这里插入图片描述

    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

在这里插入图片描述

def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)

在这里插入图片描述
data_gnd = data.squeeze(0)
在这里插入图片描述

        data_und = format_data(data_und)
def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)
def complex2real(x):
    x_real = np.real(x)
    x_imag = np.imag(x)
    y = np.array([x_real, x_imag]).astype(np.float)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

data = complex2real(data)即等于y
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
data_und = data.squeeze(0)
在这里插入图片描述

        k_und = format_data(k_und)
def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)
def complex2real(x):
    x_real = np.real(x)
    x_imag = np.imag(x)
    y = np.array([x_real, x_imag]).astype(np.float)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
k_und = data.squeeze(0)
在这里插入图片描述

        mask = format_data(mask, mask=True)
def format_data(data, mask=False):

    data = complex2real(data)
    return data.squeeze(0)
def complex2real(x):
    x_real = np.real(x)
    x_imag = np.imag(x)
    y = np.array([x_real, x_imag]).astype(np.float)
    # re-order in convenient order
    if x.ndim >= 3:
        y = y.swapaxes(0, 1)
    return y

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
data = complex2real(data)即等于y
在这里插入图片描述
mask = data.squeeze(0)
在这里插入图片描述

        return {
            'image': data_und,
            'k': k_und.transpose(1,2,0),
            'mask': mask.transpose(1,2,0),
            'full': data_gnd
        }

sample = dataset[0]即4个tensor
第一个tensor。‘image’: data_und
第二个tensor。‘k’: k_und.transpose(1,2,0)
在这里插入图片描述
第三个tensor。‘mask’: mask.transpose(1,2,0)
在这里插入图片描述
第四个tensor。‘full’: data_gnd

输出:
Sample image shape: (2, 256, 256)
Sample full shape: (2, 256, 256)
在这里插入图片描述

标签:self,py,mask,Nx,cascade,und,MRI,np,data
来源: https://blog.csdn.net/xuru_0927/article/details/118654663

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有